AAAI Version
This commit is contained in:
49
AAAI Supplementary Material/ForNet Creation Code/utils.py
Normal file
49
AAAI Supplementary Material/ForNet Creation Code/utils.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import os
|
||||
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def save_img(img: Image, img_name: str, base_dir: str, img_class: str = None, format="PNG", img_version=None):
|
||||
"""Save an image to a directory.
|
||||
|
||||
Args:
|
||||
img (PIL.Image): Image to save.
|
||||
img_name (str): Relative path to the image.
|
||||
base_dir (str): Base directory to save images in.
|
||||
img_class (str, optional): Image class, if not given try to extract it from the image name in ImageNet train format. Defaults to None.
|
||||
format (str, optional): Format to save the image in. Defaults to "PNG".
|
||||
img_version (int, optional): Version of the image. Will be appended to the path. Defaults to None.
|
||||
|
||||
"""
|
||||
if not img_name.endswith(f".{format}"):
|
||||
img_name = f"{img_name.split('.')[0]}.{format}"
|
||||
if img_class is None:
|
||||
img_class = img_name.split("_")[0]
|
||||
if not os.path.exists(os.path.join(base_dir, img_class)):
|
||||
os.makedirs(os.path.join(base_dir, img_class), exist_ok=True)
|
||||
if img_version is not None:
|
||||
img_name = f"{img_name.split('.')[0]}_v{img_version}.{format}"
|
||||
img.save(os.path.join(base_dir, img_class, img_name), format.lower())
|
||||
|
||||
|
||||
def already_segmented(img_name: str, base_dir: str, img_class: str = None):
|
||||
"""Check if an image was already segmented.
|
||||
|
||||
Args:
|
||||
img_name (str): Relative path to the image.
|
||||
base_dir (str): Base directory to save images in.
|
||||
img_class (str, optional): Image class, if not given try to extract it from the image name in ImageNet train format. Defaults to None.
|
||||
|
||||
Returns:
|
||||
bool: Image was segmented already.
|
||||
|
||||
"""
|
||||
img_base_name = ".".join(img_name.split(".")[:-1]) if "." in img_name else img_name
|
||||
if img_class is None:
|
||||
img_class = img_name.split("_")[0]
|
||||
if not os.path.exists(os.path.join(base_dir, img_class)):
|
||||
return False
|
||||
return any(
|
||||
file.startswith(img_base_name + "_v") or file.startswith(img_base_name + ".")
|
||||
for file in os.listdir(os.path.join(base_dir, img_class))
|
||||
)
|
||||
Reference in New Issue
Block a user