AAAI Version
73
AAAI Supplementary Material/ForNet Creation Code/.ruff.toml
Normal file
@@ -0,0 +1,73 @@
|
||||
# Exclude a variety of commonly ignored directories.
|
||||
exclude = [
|
||||
".bzr",
|
||||
".direnv",
|
||||
".eggs",
|
||||
".git",
|
||||
".git-rewrite",
|
||||
".hg",
|
||||
".ipynb_checkpoints",
|
||||
".mypy_cache",
|
||||
".nox",
|
||||
".pants.d",
|
||||
".pyenv",
|
||||
".pytest_cache",
|
||||
".pytype",
|
||||
".ruff_cache",
|
||||
".svn",
|
||||
".tox",
|
||||
".venv",
|
||||
".vscode",
|
||||
"__pypackages__",
|
||||
"_build",
|
||||
"buck-out",
|
||||
"build",
|
||||
"dist",
|
||||
"node_modules",
|
||||
"site-packages",
|
||||
"venv",
|
||||
]
|
||||
|
||||
# Same as Black.
|
||||
line-length = 120
|
||||
indent-width = 4
|
||||
|
||||
[lint]
|
||||
# Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default.
|
||||
# Unlike Flake8, Ruff doesn't enable pycodestyle warnings (`W`) or
|
||||
# McCabe complexity (`C901`) by default.
|
||||
select = ["E", "F", "D", "B", "NPY", "PD", "TD005", "TD006", "TD007", "SIM", "RET", "Q", "ICN", "I"]
|
||||
ignore = ["D203", "D213", "E501", "D100", "NPY002", "D102", "B008", "PD011", "D105", "SIM118"]
|
||||
|
||||
# Allow fixes for all enabled rules when `--fix` is provided.
|
||||
unfixable = ["B"]
|
||||
|
||||
# Allow unused variables when underscore-prefixed.
|
||||
dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
|
||||
|
||||
[format]
|
||||
# Like Black, use double quotes for strings.
|
||||
quote-style = "double"
|
||||
|
||||
# Like Black, indent with spaces, rather than tabs.
|
||||
indent-style = "space"
|
||||
|
||||
# Like Black, respect magic trailing commas.
|
||||
skip-magic-trailing-comma = true
|
||||
|
||||
# Like Black, automatically detect the appropriate line ending.
|
||||
line-ending = "auto"
|
||||
|
||||
# Enable auto-formatting of code examples in docstrings. Markdown,
|
||||
# reStructuredText code/literal blocks and doctests are all supported.
|
||||
#
|
||||
# This is currently disabled by default, but it is planned for this
|
||||
# to be opt-out in the future.
|
||||
docstring-code-format = true
|
||||
|
||||
# Set the line length limit used when formatting code snippets in
|
||||
# docstrings.
|
||||
#
|
||||
# This only has an effect when the `docstring-code-format` setting is
|
||||
# enabled.
|
||||
docstring-code-line-length = "dynamic"
|
||||
61
AAAI Supplementary Material/ForNet Creation Code/README.md
Normal file
@@ -0,0 +1,61 @@
|
||||
# Creating the ForNet Dataset
|
||||
|
||||
We can't just provide the ForNet dataset here, as it's too large to be part of the appendix and using a link will go against double-blind review rules.
|
||||
After acceptance, the dataset will be downloadable online.
|
||||
For now, we provide the scripts and steps to recreate the dataset.
|
||||
In general, if you are unsure what arguments each script allows, run it using the `--help` flag.
|
||||
|
||||
## 1. Setup paths
|
||||
|
||||
Fill in the paths in `experiments/general_srun` in the `Model Training Code` folder, as well as in `srun-general.sh`, `slurm-segment-imnet.sh` and all the `sbatch-segment-...` files.
|
||||
In particular the `--container-image`, `--container-mounts`, `--output` and `NLTK_DATA` and `HF_HOME` paths in `--export`.
|
||||
|
||||
## 2. Pretrain Filtering Models
|
||||
|
||||
Use the `Model Training Code` to pretrain an ensemble of models to use for filtering in a later step.
|
||||
Train those models on either `TinyImageNet` or `ImageNet`, depending on if you want to create `TinyForNet` or `ForNet`.
|
||||
Then fill in the relevant paths to the pretrained weights in `experiments/filter_segmentation_versions.py` lines 96/98.
|
||||
|
||||
## 3. Create the dataset
|
||||
|
||||
### Automatically: using slurm
|
||||
|
||||
You may just run the `create_dataset.py` file (on a slurm head node). That file will automatically run all the necessary steps one after another.
|
||||
|
||||
### Manually and step-by-step
|
||||
|
||||
If you want to run each step of the pipeline manually, follow these steps.
|
||||
For default arguments and settings, see the `create_dataset.py` script, even though you may not want to run it directly, it can tell you how to run all the other scripts.
|
||||
|
||||
#### 3.1 Segment Objects and Backgrounds
|
||||
|
||||
Use the segmentation script (`segment_imagenet.py`) to segment each of the dataset images.
|
||||
Watch out, as this script uses `datadings` for image loading, so you need to provide a `datadings` variant of your dataset.
|
||||
You need to provide the root folder of the dataset.
|
||||
Choose your segmentation model using the `-model` argument (LaMa or AttErase).
|
||||
If you want to use the >general< prompting strategy, set the `--parent_in_prompt` flag.
|
||||
Use `--output`/`-o` to set the output directory.
|
||||
Use `--processes` and `-id` for splitting the task up into multiple parallelizable processes.
|
||||
|
||||
#### 3.2 Filter the segmented images
|
||||
|
||||
In this step, you use the pretrained ensemble of models (from step 2) for filtering the segmented images.
|
||||
As this step is based on the training and model code, it's in the `Model Training Code` directory.
|
||||
After setting the relevant paths to the pretrained weights (see step 2), you may run the `experiments/filter_segmentation_versions.py` script using that directory as the PWD.
|
||||
|
||||
#### 3.3 Zip the dataset
|
||||
|
||||
In distributed storage settings it might be useful to read from one large (uncompressed) zip file instead of reading millions of small single files.
|
||||
To do this, run
|
||||
|
||||
```commandline
|
||||
zip -r -0 backgrounds_train.zip train/backgrounds > /dev/null 2>&1
|
||||
```
|
||||
|
||||
for the train and val backgrounds and foregrounds
|
||||
|
||||
#### 3.4 Compute the foreground size ratios
|
||||
|
||||
For the resizing step during recombination, the relative size of each object in each image is needed.
|
||||
To compute it, run the `foreground_size_ratio.py` script on your filtered dataset.
|
||||
It expects the zip files in the folder you provide as `-ds`.
|
||||
@@ -0,0 +1,98 @@
|
||||
from copy import copy
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from diffusers import DDIMScheduler, DiffusionPipeline
|
||||
from PIL import Image, ImageFilter
|
||||
from torchvision.transforms.functional import gaussian_blur, to_tensor
|
||||
|
||||
|
||||
def _preprocess_image(image, device, dtype=torch.float16):
    """Convert a PIL image to a normalized (1, 3, 1024, 1024) tensor on the target device.

    Pixel values are rescaled from [0, 1] to [-1, 1]; non-3-channel inputs are
    expanded to three channels before resizing.
    """
    tensor = to_tensor(image).unsqueeze_(0).float()
    tensor = tensor * 2 - 1  # map [0, 1] to [-1, 1]
    if tensor.shape[1] != 3:
        # broadcast a single channel to three identical channels
        tensor = tensor.expand(-1, 3, -1, -1)
    resized = F.interpolate(tensor, (1024, 1024))
    return resized.to(dtype).to(device)
|
||||
|
||||
|
||||
def _preprocess_mask(mask, device, dtype=torch.float16):
    """Convert a PIL mask to a binary (1, 1, 1024, 1024) tensor on the target device.

    The mask is resized, blurred with a large Gaussian kernel (dilating its
    border), then re-thresholded at 0.1 back to hard {0, 1} values.
    """
    tensor = to_tensor(mask.convert("L")).unsqueeze_(0).float()  # values in {0, 1}
    tensor = F.interpolate(tensor, (1024, 1024))
    blurred = gaussian_blur(tensor, kernel_size=(77, 77))
    binary = (blurred >= 0.1).float()
    return binary.to(dtype).to(device)
|
||||
|
||||
|
||||
class AttentiveEraser:
    """Attentive Eraser Pipeline + Pre- and Post-Processing."""

    prompt = ""  # unconditional removal: no text prompt
    dtype = torch.float16

    def __init__(self, num_steps=50, device=None):
        """Create the attentive eraser.

        Args:
            num_steps (int, optional): Number of steps in the diffusion process. Will start at 20%. Defaults to 50.
            device (_type_, optional): Device to run on. Defaults to 'cuda' if available, else 'cpu'.

        """
        self.num_steps = num_steps

        self.device = device if device else torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        scheduler = DDIMScheduler(
            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False
        )
        self.model_path = "stabilityai/stable-diffusion-xl-base-1.0"
        # SDXL base weights with the Attentive Eraser custom pipeline spliced in.
        self.pipeline = DiffusionPipeline.from_pretrained(
            self.model_path,
            custom_pipeline="./pipelines/pipeline_stable_diffusion_xl_attentive_eraser.py",
            scheduler=scheduler,
            variant="fp16",
            use_safetensors=True,
            torch_dtype=self.dtype,
        ).to(self.device)
        self.pipeline.enable_attention_slicing()
        self.pipeline.enable_model_cpu_offload()

    def __call__(self, image, mask):
        """Remove the masked region from `image` and return the infilled image.

        Args:
            image: PIL image or numpy array to erase from.
            mask: PIL image or numpy array; non-zero pixels mark the region to remove.

        Returns:
            PIL.Image: Copy of `image` with the diffusion result pasted into the masked area.
        """
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        if isinstance(mask, np.ndarray):
            mask = Image.fromarray(mask)

        prep_img = _preprocess_image(image, self.device, self.dtype)
        prep_mask = _preprocess_mask(mask, self.device, self.dtype)
        orig_shape = image.size

        diff_img = (
            self.pipeline(
                prompt=self.prompt,
                image=prep_img,
                mask_image=prep_mask,
                height=1024,
                width=1024,
                AAS=True,  # enable AAS
                strength=0.8,  # inpainting strength
                rm_guidance_scale=9,  # removal guidance scale
                ss_steps=9,  # similarity suppression steps
                ss_scale=0.3,  # similarity suppression scale
                AAS_start_step=0,  # AAS start step
                AAS_start_layer=34,  # AAS start layer
                AAS_end_layer=70,  # AAS end layer
                num_inference_steps=self.num_steps,  # number of inference steps # AAS_end_step = int(strength*num_inference_steps)
                guidance_scale=1,
            )
            .images[0]
            .resize(orig_shape)
        )

        # BUGFIX: the original `np.array(mask) > 0 * 255` parsed as `> (0 * 255)`,
        # i.e. `> 0` — the `* 255` was dead code and fromarray received a bool
        # array. Make the intended 0/255 uint8 paste mask explicit.
        paste_mask = Image.fromarray(((np.array(mask) > 0) * 255).astype(np.uint8))
        # Feather the paste boundary so the infill blends into the original image.
        paste_mask = paste_mask.convert("RGB").filter(ImageFilter.GaussianBlur(radius=10)).convert("L")
        out_img = copy(image)
        out_img.paste(diff_img, (0, 0), paste_mask)
        return out_img
|
||||
@@ -0,0 +1,195 @@
|
||||
import argparse
import os
import re
import subprocess


parser = argparse.ArgumentParser(description="Create Segment & Recombine dataset")
parser.add_argument("-d", "--dataset", choices=["imagenet", "tinyimagenet"], required=True, help="Dataset to use")
parser.add_argument("-r", "--dataset_root", required=True, help="Root directory of the dataset")
parser.add_argument("-t", "--threshold", type=float, default=0.3, help="Detection threshold")
parser.add_argument(
    "--out_root", type=str, required=True, help="Root directory where the output directory will be created."
)
parser.add_argument("-o", "--output", type=str, required=True, help="Output directory")
parser.add_argument(
    "--parent_labels", type=int, default=2, help="Number of parent labels to use; steps to go up the tree"
)
parser.add_argument("--output_ims", choices=["best", "all"], default="best", help="Output all or best masks")
parser.add_argument("--mask_merge_threshold", type=float, default=0.9, help="Threshold on IoU for merging masks")
parser.add_argument("--parent_in_prompt", action="store_true", help="Include parent label in the prompt")
parser.add_argument("-infill_model", choices=["LaMa", "AttErase"], default="LaMa", help="Infilling model to use")
parser.add_argument("--continue", dest="continue_", action="store_true", help="Continue from previous run")

args = parser.parse_args()
out_root = args.out_root

base_folder = os.path.dirname(__file__)
training_code_folder = os.path.join(base_folder, os.pardir, "Model Training Code")

# Output names must follow the (Tiny)INSegment_v<version>[_f<filter>] scheme.
ds_name_re = re.compile(r"(Tiny)?INSegment_v(\d*)(_f\d*)?")
name_match = ds_name_re.match(args.output)
assert name_match, f"Output name {args.output} does not match the expected format: {ds_name_re.pattern}"
assert (
    args.output_ims == "best" or name_match.group(3) is None
), "For output_ims == 'all', the filter subversions will be automatically created."
assert args.continue_ or not os.path.exists(out_root + args.output), f"Output directory {args.output} already exists."

# Persist the run settings so a continued run can be checked for consistency.
run_settings = {
    "dataset": args.dataset,
    "threshold": args.threshold,
    "output_ims": args.output_ims,
    "mask_merge_threshold": args.mask_merge_threshold,
    "parent_labels": args.parent_labels,
    "parent_in_prompt": args.parent_in_prompt,
    "infill_model": args.infill_model,
}
settings_file = [f"{key} = {run_settings[key]}" for key in sorted(run_settings)]
if os.path.exists(out_root + args.output) and os.path.exists(out_root + args.output + "/settings.txt"):
    # Continuing a previous run: the stored settings must match exactly.
    with open(out_root + args.output + "/settings.txt", "r") as f:
        old_settings = [line.strip() for line in f.read().split("\n") if line.strip()]
    assert old_settings == settings_file, (
        f"Settings file {out_root + args.output}/settings.txt does not match current settings: old: {old_settings} vs"
        f" new: {settings_file}"
    )
else:
    os.makedirs(out_root + args.output, exist_ok=True)
    with open(out_root + args.output + "/settings.txt", "w") as f:
        f.write("\n".join(settings_file) + "\n")
|
||||
|
||||
# Common sbatch invocation shared by the train and val segmentation jobs.
general_args = [
    "sbatch",
    "sbatch-segment-tinyimnet-wait" if args.dataset == "tinyimagenet" else "sbatch-segment-imagenet-wait",
    "-r",
    args.dataset_root,
    "-o",
    out_root + args.output,
    "--parent_labels",
    str(args.parent_labels),
    "--output_ims",
    args.output_ims,
    "--mask_merge_threshold",
    str(args.mask_merge_threshold),
    "-t",
    str(args.threshold),
    "-model",
    args.infill_model,
]

if args.parent_in_prompt:
    general_args.append("--parent_in_prompt")

print(f"Starting segmentation: {' '.join(general_args)} for {args.dataset}-val and {args.dataset}")
seg_procs = {"train": subprocess.Popen(general_args + ["-d", args.dataset], cwd=base_folder)}
if args.dataset == "imagenet":
    # ImageNet has a dedicated val segmentation job; tinyimagenet does not.
    general_args[1] = "sbatch-segment-imagenet-val-wait"
    seg_procs["val"] = subprocess.Popen(general_args + ["-d", args.dataset + "-val"], cwd=base_folder)

# BUGFIX: the original unconditionally called `p_val.wait()`, but `p_val` was
# only defined for the imagenet branch — a NameError for tinyimagenet. Wait on
# exactly the processes that were started and fail if any exited non-zero.
for proc in seg_procs.values():
    proc.wait()
rcodes = {name: proc.returncode for name, proc in seg_procs.items()}
if any(rcode != 0 for rcode in rcodes.values()):
    print(f"Error in segmentation: {rcodes}")
    raise SystemExit(1)
print("Segmentation done.")
|
||||
|
||||
if args.output_ims == "all":
    # Hard-link (-l) two copies of the raw output so the two filter variants
    # can be produced independently without duplicating image data.
    print("copy to subversions for filtering")
    copy_procs = [
        subprocess.Popen(["cp", "-rl", os.path.join(out_root, args.output), out_root + f"{args.output}_f{i}/"])
        for i in (1, 2)
    ]
    for proc in copy_procs:
        proc.wait()
    print("Filtering subversions copied over.")

    print("Starting filtering")
    filtering_args = ["./experiments/general_srun.sh", "experiments/filter_segmentation_versions.py"]
    filter_procs = []
    for split in ("val", "train"):
        # _f1 uses the default score weights, _f2 uses automatic score weights.
        filter_procs.append(
            subprocess.Popen(
                filtering_args + ["-f", out_root + f"{args.output}_f1" + f"/{split}", "-d", args.dataset],
                cwd=training_code_folder,
            )
        )
        filter_procs.append(
            subprocess.Popen(
                filtering_args
                + ["-f", out_root + f"{args.output}_f2" + f"/{split}", "-score_f_weights", "automatic", "-d", args.dataset],
                cwd=training_code_folder,
            )
        )
    for proc in filter_procs:
        proc.wait()

    ds_folders = [f"{args.output}_f1", f"{args.output}_f2"]
else:
    # BUGFIX: the filter suffix is capture group 3 of ds_name_re; group 2 is the
    # version number (`\d*`), which always participates in a successful match and
    # is never None — so the original `group(2) is None` check could never fire
    # and the output folder was never renamed to its `_f1` form.
    if name_match.group(3) is None:
        new_name = args.output + "_f1"
        print(f"Renaming output folder: {args.output} -> {new_name}")
        os.rename(out_root + args.output, out_root + new_name)
    else:
        new_name = args.output
    ds_folders = [new_name]
|
||||
|
||||
for folder in ds_folders:
    print(f"Zipping up {folder}")
    # Store (-0, uncompressed) each split/kind in one large zip for faster
    # reading on distributed storage.
    # BUGFIX: the original passed ">", "/dev/null" and "2>&1" as list elements.
    # Without shell=True those are literal arguments to `zip` (files to add),
    # not redirections; silence the output via stdout/stderr instead.
    zip_procs = [
        subprocess.Popen(
            ["zip", "-r", "-0", f"{kind}_{split}.zip", f"{split}/{kind}"],
            cwd=out_root + folder,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
        for kind in ("foregrounds", "backgrounds")
        for split in ("val", "train")
    ]
    for proc in zip_procs:
        proc.wait()

    print(f"Gathering foreground size ratios for {folder}")
    ratio_cmd = ["./srun-general.sh", "python", "foreground_size_ratio.py", "--root", args.dataset_root, "-ds", folder]
    p_val = subprocess.Popen(ratio_cmd + ["-mode", "val"], cwd=base_folder)
    p_train = subprocess.Popen(ratio_cmd, cwd=base_folder)  # -mode defaults to "train"
    p_val.wait()
    p_train.wait()
||||
|
After Width: | Height: | Size: 402 KiB |
|
After Width: | Height: | Size: 46 KiB |
|
After Width: | Height: | Size: 36 KiB |
|
After Width: | Height: | Size: 49 KiB |
|
After Width: | Height: | Size: 148 KiB |
|
After Width: | Height: | Size: 136 KiB |
|
After Width: | Height: | Size: 46 KiB |
@@ -0,0 +1,69 @@
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import zipfile
|
||||
from io import BytesIO
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
from PIL import Image
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
parser = argparse.ArgumentParser(description="Foreground size ratio")
parser.add_argument("-mode", choices=["train", "val"], default="train", help="Train or val data?")
parser.add_argument("-ds", "--dataset", type=str, required=True, help="Dataset to use")
parser.add_argument("--root", type=str, required=True, help="Root folder of the dataset")
args = parser.parse_args()

train = args.mode == "train"
# Absolute dataset paths are used as-is; relative ones are resolved under --root.
root = os.path.join(args.root, args.dataset) if not args.dataset.startswith("/") else args.dataset

cls_bg_ratios = {}  # per-class list of foreground/background area ratios
fg_bg_ratio_map = {}  # per-image foreground/background area ratio

with zipfile.ZipFile(f"{root}/backgrounds_{'train' if train else 'val'}.zip", "r") as bg_zip, zipfile.ZipFile(
    f"{root}/foregrounds_{'train' if train else 'val'}.zip", "r"
) as fg_zip:
    backgrounds = [f for f in bg_zip.namelist() if f.endswith(".JPEG")]
    foregrounds = [f for f in fg_zip.namelist() if f.endswith(".WEBP")]
    # PERF: membership is tested once per background; a set makes each lookup
    # O(1) instead of scanning the full foreground list (O(n^2) overall).
    foreground_set = set(foregrounds)

    print(f"Bgs: {backgrounds[:5]}, ...\nFgs: {foregrounds[:5]}, ...")

    for bg_name in tqdm(backgrounds):
        fg_name = bg_name.replace("backgrounds", "foregrounds").replace("JPEG", "WEBP")
        if fg_name not in foreground_set:
            tqdm.write(f"Skipping {bg_name} as it has no corresponding foreground")
            fg_bg_ratio_map[bg_name] = 0.0
            continue

        with bg_zip.open(bg_name) as f:
            bg_data = BytesIO(f.read())
        try:
            bg_img = Image.open(bg_data)
        except PIL.UnidentifiedImageError as e:
            print(f"Error with file={bg_name}")
            raise e
        bg_img_size = bg_img.size

        with fg_zip.open(fg_name) as f:
            fg_data = BytesIO(f.read())
        try:
            fg_img = Image.open(fg_data)
        except PIL.UnidentifiedImageError as e:
            print(f"Error with file={fg_name}")
            raise e
        # Foreground area = number of fully opaque pixels in the alpha channel.
        fg_img_pixel_size = int(np.sum(fg_img.split()[-1]) / 255)

        # Path layout is .../<class>/<image>; the class is the parent folder.
        img_cls = bg_name.split("/")[-2]
        ratio = fg_img_pixel_size / (bg_img_size[0] * bg_img_size[1])
        cls_bg_ratios.setdefault(img_cls, []).append(ratio)
        fg_bg_ratio_map[bg_name] = ratio

with open(f"{root}/fg_bg_ratios_{args.mode}.json", "w") as f:
    json.dump(fg_bg_ratio_map, f)
print(f"Saved fg_bg_ratios_{args.mode} to disk. Exiting.")
|
||||
@@ -0,0 +1,42 @@
|
||||
import argparse
|
||||
import json
|
||||
|
||||
from datadings.reader import MsgpackReader
|
||||
from datadings.torch import Dataset
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
# Build a mapping from "<synset>/<image key>" to the image's index in the
# datadings msgpack file and store it as <segment_root>/<mode>_indices.json.
parser = argparse.ArgumentParser(description="Foreground size ratio")
parser.add_argument("-mode", choices=["train", "val"], default="train", help="Train or val data?")
parser.add_argument("-ds", "--dataset", choices=["imagenet", "tinyimagenet"], required=True, help="Dataset to use")
parser.add_argument("-r", "--dataset_root", required=True, type=str, help="Root directory of the dataset")
parser.add_argument("-s", "--segment_root", required=True, type=str, help="Root directory of the segmentation dataset")
args = parser.parse_args()


# NOTE: dataset_root is concatenated without a separator — it is expected to
# end with a trailing slash.
if args.dataset == "imagenet":
    reader = MsgpackReader(f"{args.dataset_root}imagenet/msgpack/{args.mode}.msgpack")
    dataset = Dataset(reader)
elif args.dataset == "tinyimagenet":
    reader = MsgpackReader(f"{args.dataset_root}TinyINSegment/TinyIN_{args.mode}.msgpack")
    dataset = Dataset(reader)
else:
    # Unreachable with the argparse choices above; kept as a safety net.
    raise ValueError(f"Unknown dataset: {args.dataset}")

# Map numeric class labels to WordNet synset ids.
if args.dataset.startswith("imagenet"):
    # JSON file: {label index (as str): synset number}.
    with open("wordnet_data/imagenet1k_synsets.json", "r") as f:
        id_to_synset = json.load(f)
    id_to_synset = {int(k): v for k, v in id_to_synset.items()}
elif args.dataset.startswith("tinyimagenet"):
    # Text file with "n<synset>: <name>" lines; labels index the sorted synsets.
    with open("wordnet_data/tinyimagenet_synset_names.txt", "r") as f:
        synsets = f.readlines()
    id_to_synset = [int(synset.split(":")[0].strip()[1:]) for synset in synsets]
    id_to_synset = sorted(id_to_synset)

# Key format "n<8-digit synset>/<file stem>" matches the zip-archive layout
# used by the segmentation outputs.
key_to_idx = {
    f"n{id_to_synset[data['label']]:08d}/{data['key'].split('.')[0]}": i
    for i, data in enumerate(tqdm(dataset, leave=True))
}

print(len(key_to_idx))
with open(f"{args.segment_root}/{args.mode}_indices.json", "w") as f:
    json.dump(key_to_idx, f)
|
||||
@@ -0,0 +1,223 @@
|
||||
"""Use grounding DINO + Segment Anything (SAM) to perform grounded segmentation on an image.
|
||||
|
||||
Based on: https://github.com/IDEA-Research/Grounded-Segment-Anything
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import requests
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torchvision.transforms import ToTensor
|
||||
from transformers import AutoModelForMaskGeneration, AutoProcessor, pipeline
|
||||
|
||||
|
||||
@dataclass
class BoundingBox:
    """Axis-aligned bounding box in pixel coordinates."""

    xmin: int
    ymin: int
    xmax: int
    ymax: int

    @property
    def xyxy(self) -> List[float]:
        """Return the box as a flat ``[xmin, ymin, xmax, ymax]`` list."""
        return [self.xmin, self.ymin, self.xmax, self.ymax]
|
||||
|
||||
|
||||
@dataclass
class DetectionResult:
    """Detection result from Grounding DINO + Mask from SAM."""

    score: float  # detection confidence from Grounding DINO
    label: str  # candidate label the detection matched
    box: BoundingBox
    # `np.ndarray` (the type) rather than `np.array` (the factory function) is
    # the correct annotation; the mask is attached later by `segment`.
    mask: Optional[np.ndarray] = None

    @classmethod
    def from_dict(cls, detection_dict: Dict) -> "DetectionResult":
        """Create a DetectionResult from a dictionary.

        Args:
            detection_dict (Dict): Detection result dictionary with "score",
                "label" and a "box" sub-dict of xmin/ymin/xmax/ymax.

        Returns:
            DetectionResult: Detection result object (mask left as None).

        """
        return cls(
            score=detection_dict["score"],
            label=detection_dict["label"],
            box=BoundingBox(
                xmin=detection_dict["box"]["xmin"],
                ymin=detection_dict["box"]["ymin"],
                xmax=detection_dict["box"]["xmax"],
                ymax=detection_dict["box"]["ymax"],
            ),
        )
|
||||
|
||||
|
||||
def mask_to_polygon(mask: np.ndarray) -> List[List[int]]:
    """Use OpenCV to refine a mask by turning it into a polygon.

    Args:
        mask (np.ndarray): Segmentation mask.

    Returns:
        List[List[int]]: List of (x, y) coordinates representing the vertices of
            the largest contour's polygon; empty list when the mask is empty.

    """
    # Find contours in the binary mask
    contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # BUGFIX: an all-zero mask yields no contours, and `max()` over an empty
    # sequence raises ValueError — return an empty polygon instead.
    if not contours:
        return []

    # Find the contour with the largest area
    largest_contour = max(contours, key=cv2.contourArea)

    # Extract the vertices of the contour
    return largest_contour.reshape(-1, 2).tolist()
|
||||
|
||||
|
||||
def polygon_to_mask(polygon: List[Tuple[int, int]], image_shape: Tuple[int, int]) -> np.ndarray:
    """Convert a polygon to a segmentation mask.

    Args:
        polygon (list): List of (x, y) coordinates representing the vertices of the polygon.
        image_shape (tuple): Shape of the image (height, width) for the mask.

    Returns:
        np.ndarray: Segmentation mask (uint8, filled with 255 inside the
            polygon); all zeros when the polygon is empty.

    """
    # Create an empty mask
    mask = np.zeros(image_shape, dtype=np.uint8)

    # ROBUSTNESS: an empty polygon (e.g. produced from an all-zero mask) would
    # make cv2.fillPoly fail; the empty mask is the correct result here.
    if not polygon:
        return mask

    # Convert polygon to an array of points
    pts = np.array(polygon, dtype=np.int32)

    # Fill the polygon with white color (255)
    cv2.fillPoly(mask, [pts], color=(255,))

    return mask
|
||||
|
||||
|
||||
def load_image(image_str: str) -> Image.Image:
    """Load an image from a URL or file path.

    Args:
        image_str (str): URL or file path to the image.

    Returns:
        PIL.Image: Image object, converted to RGB.

    """
    # Anything starting with "http" is fetched over the network; everything
    # else is treated as a local file path.
    if image_str.startswith("http"):
        source = requests.get(image_str, stream=True).raw
    else:
        source = image_str
    return Image.open(source).convert("RGB")
|
||||
|
||||
|
||||
def _get_boxes(results: DetectionResult) -> List[List[List[float]]]:
    """Collect every detection's [xmin, ymin, xmax, ymax] box, wrapped as a batch of one."""
    return [[result.box.xyxy for result in results]]
|
||||
|
||||
|
||||
def _refine_masks(masks: torch.BoolTensor, polygon_refinement: bool = False) -> List[np.ndarray]:
|
||||
masks = masks.cpu().float()
|
||||
masks = masks.permute(0, 2, 3, 1)
|
||||
masks = masks.mean(axis=-1)
|
||||
masks = (masks > 0).int()
|
||||
masks = masks.numpy().astype(np.uint8)
|
||||
masks = list(masks)
|
||||
|
||||
if polygon_refinement:
|
||||
for idx, mask in enumerate(masks):
|
||||
shape = mask.shape
|
||||
polygon = mask_to_polygon(mask)
|
||||
mask = polygon_to_mask(polygon, shape)
|
||||
masks[idx] = mask
|
||||
|
||||
return masks
|
||||
|
||||
|
||||
# Shared model state, loaded once at import time and reused by `detect` and
# `segment` below. NOTE(review): importing this module therefore loads (and may
# download) the detector and SAM weights as a side effect.
device = "cuda" if torch.cuda.is_available() else "cpu"
detector_id = "IDEA-Research/grounding-dino-tiny"
print(f"load object detector pipeline: {detector_id}")
object_detector = pipeline(model=detector_id, task="zero-shot-object-detection", device=device)

segmenter_id = "facebook/sam-vit-base"
print(f"load segmentator: {segmenter_id}")
segmentator = AutoModelForMaskGeneration.from_pretrained(segmenter_id).to(device)
print(f"load processor: {segmenter_id}")
processor = AutoProcessor.from_pretrained(segmenter_id)
|
||||
|
||||
|
||||
def detect(image: Image.Image, labels: List[str], threshold: float = 0.3) -> List[Dict[str, Any]]:
    """Use Grounding DINO to detect a set of labels in an image in a zero-shot fashion."""
    global object_detector, device
    # Grounding DINO expects every candidate label to be period-terminated.
    terminated = [label if label.endswith(".") else f"{label}." for label in labels]

    raw_results = object_detector(image, candidate_labels=terminated, threshold=threshold)
    return [DetectionResult.from_dict(raw) for raw in raw_results]
|
||||
|
||||
|
||||
def segment(
    image: Image.Image, detection_results: List[Dict[str, Any]], polygon_refinement: bool = False
) -> List[DetectionResult]:
    """Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes."""
    global segmentator, processor, device
    input_boxes = _get_boxes(detection_results)
    inputs = processor(images=image, input_boxes=input_boxes, return_tensors="pt").to(device)

    outputs = segmentator(**inputs)
    # Resize/threshold the predicted masks back to the original image geometry.
    raw_masks = processor.post_process_masks(
        masks=outputs.pred_masks, original_sizes=inputs.original_sizes, reshaped_input_sizes=inputs.reshaped_input_sizes
    )[0]

    # Attach one refined mask per detection, in order.
    for detection, refined in zip(detection_results, _refine_masks(raw_masks, polygon_refinement)):
        detection.mask = refined

    return detection_results
|
||||
|
||||
|
||||
def grounded_segmentation(
    image: Union[Image.Image, str], labels: List[str], threshold: float = 0.3, polygon_refinement: bool = False
) -> Tuple[torch.Tensor, List[DetectionResult]]:
    """Segment out the objects in an image given a set of labels.

    Args:
        image (Union[Image.Image, str]): Image to load/work on.
        labels (List[str]): Object labels to segment.
        threshold (float, optional): Detection threshold. Defaults to 0.3.
        polygon_refinement (bool, optional): Use polygon refinement on the segmented mask? Defaults to False.

    Returns:
        Tuple[torch.Tensor, List[DetectionResult]]: Image tensor and list of
            detection results (empty list when nothing was detected).

    """
    if isinstance(image, str):
        image = load_image(image)

    detections = detect(image, labels, threshold)
    # Skip SAM entirely when Grounding DINO found nothing.
    if not detections:
        return ToTensor()(image), []
    return ToTensor()(image), segment(image, detections, polygon_refinement)
|
||||
271
AAAI Supplementary Material/ForNet Creation Code/infill_lama.py
Normal file
@@ -0,0 +1,271 @@
|
||||
"""Use the LaMa (*LaMa: Resolution-robust Large Mask Inpainting with Fourier Convolutions*) model to infill the background of an image.
|
||||
|
||||
Based on: https://github.com/Sanster/IOPaint/blob/main/iopaint/model/lama.py
|
||||
"""
|
||||
|
||||
import abc
|
||||
import hashlib
|
||||
import os
|
||||
import sys
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.hub import download_url_to_file
|
||||
|
||||
# Checkpoint URL and expected md5 of the big-lama weights; both can be
# overridden via environment variables (e.g. to point at a local mirror).
LAMA_MODEL_URL = os.environ.get(
    "LAMA_MODEL_URL",
    "https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
)
LAMA_MODEL_MD5 = os.environ.get("LAMA_MODEL_MD5", "e3aa4aaa15225a33ec84f9f4bc47e500")
|
||||
|
||||
|
||||
class LDMSampler(str, Enum):
    """Sampler names accepted by the LDM inpainting model."""

    ddim = "ddim"
    plms = "plms"
||||
|
||||
|
||||
class SDSampler(str, Enum):
    """Sampler names accepted by the Stable Diffusion inpainting models."""

    ddim = "ddim"
    pndm = "pndm"
    k_lms = "k_lms"
    k_euler = "k_euler"
    k_euler_a = "k_euler_a"
    dpm_plus_plus = "dpm++"
    uni_pc = "uni_pc"
||||
|
||||
|
||||
def ceil_modulo(x, mod):
    """Round ``x`` up to the nearest multiple of ``mod``.

    Exact multiples are returned unchanged.
    """
    quotient, remainder = divmod(x, mod)
    if remainder == 0:
        return x
    return (quotient + 1) * mod
|
||||
|
||||
|
||||
def pad_img_to_modulo(img: np.ndarray, mod: int, square: bool = False, min_size: Optional[int] = None):
    """Pad an image so both spatial dimensions are multiples of ``mod``.

    Padding is applied on the bottom/right edges only, using symmetric
    reflection of the image content.

    Args:
        img: Image array of shape [H, W, C]; a 2-D array gains a channel axis.
        mod: Output height and width become multiples of this value.
        square: If True, pad the output to a square.
        min_size: Optional lower bound on both output dimensions;
            must itself be a multiple of ``mod``.

    Returns:
        The padded array of shape [H', W', C].
    """
    if img.ndim == 2:
        img = img[:, :, np.newaxis]
    height, width = img.shape[:2]
    # Ceil-divide to the next multiple of mod (inlined ceil_modulo).
    target_h = ((height + mod - 1) // mod) * mod
    target_w = ((width + mod - 1) // mod) * mod

    if min_size is not None:
        assert min_size % mod == 0
        target_w = max(min_size, target_w)
        target_h = max(min_size, target_h)

    if square:
        side = max(target_h, target_w)
        target_h = side
        target_w = side

    pad_spec = ((0, target_h - height), (0, target_w - width), (0, 0))
    return np.pad(img, pad_spec, mode="symmetric")
|
||||
|
||||
|
||||
def undo_pad_to_mod(img: np.ndarray, height: int, width: int):
    """Crop a padded image back to its original spatial size.

    Inverse of :func:`pad_img_to_modulo`: keeps the top-left
    ``height`` x ``width`` region of a [H, W, C] array, all channels.
    """
    cropped = img[0:height, 0:width, :]
    return cropped
|
||||
|
||||
|
||||
class InpaintModel:
    """Abstract base class for inpainting models.

    Subclasses implement :meth:`init_model`, :meth:`is_downloaded` and
    :meth:`forward`; :meth:`_pad_forward` pads the inputs to the model's
    required modulo, runs ``forward`` and composites the result back onto
    the original image.
    """

    # Human-readable model identifier.
    name = "base"
    # Minimum spatial size accepted by the model (a multiple of pad_mod), if any.
    min_size: Optional[int] = None
    # Inputs are padded so height/width are multiples of this value.
    pad_mod = 8
    # Whether inputs must additionally be padded to a square.
    pad_to_square = False

    def __init__(self, device, **kwargs):
        """
        Args:
            device: Torch device the model should run on.
            **kwargs: Forwarded to :meth:`init_model`.
        """
        self.device = device
        self.init_model(device, **kwargs)

    @abc.abstractmethod
    def init_model(self, device, **kwargs): ...

    @staticmethod
    @abc.abstractmethod
    def is_downloaded() -> bool: ...

    @abc.abstractmethod
    def forward(self, image, mask):
        """Input images and output images have same size.

        images: [H, W, C] RGB
        masks: [H, W, 1], value 255 marks the mask region
        return: BGR IMAGE
        """
        ...

    def _pad_forward(self, image, mask):
        """Pad inputs, run :meth:`forward`, and composite the result.

        Only the masked region is taken from the model output; all other
        pixels are copied from the original image.
        """
        origin_height, origin_width = image.shape[:2]
        pad_image = pad_img_to_modulo(image, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size)
        pad_mask = pad_img_to_modulo(mask, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size)

        result = self.forward(pad_image, pad_mask)
        # result = result[0:origin_height, 0:origin_width, :]

        # result, image, mask = self.forward_post_process(result, image, mask)

        mask = mask[:, :, np.newaxis]
        # Take the first item of the model's CHW output batch as a HWC numpy array.
        result = result[0].permute(1, 2, 0).cpu().numpy()  # 400 x 504 x 3
        result = undo_pad_to_mod(result, origin_height, origin_width)
        assert result.shape[:2] == image.shape[:2], f"{result.shape[:2]} != {image.shape[:2]}"
        # result = result * 255
        # Binarize the mask to {0, 255}. NOTE(review): the model output appears
        # to be in [0, 1]; multiplying by the 0/255 mask both rescales it to
        # [0, 255] and zeroes the unmasked region in one step — confirm against
        # the model's actual output range.
        mask = (mask > 0) * 255
        result = result * mask + image * (1 - (mask / 255))
        result = np.clip(result, 0, 255).astype("uint8")
        return result

    def forward_post_process(self, result, image, mask):
        """Hook for subclasses to post-process the raw result; identity by default."""
        return result, image, mask
|
||||
|
||||
|
||||
def resize_np_img(np_img, size, interpolation="bicubic"):
    """Resize a HWC uint8 image via ``torch.nn.functional.interpolate``.

    Args:
        np_img: Image array of shape [H, W, C], uint8.
        size: Target spatial size passed straight to ``interpolate``.
        interpolation: One of "nearest", "bilinear" or "bicubic".

    Returns:
        np.ndarray: Resized uint8 image of shape [size[0], size[1], C].
    """
    assert interpolation in [
        "nearest",
        "bilinear",
        "bicubic",
    ], f"Unsupported interpolation: {interpolation}, use nearest, bilinear or bicubic."
    torch_img = torch.from_numpy(np_img).permute(2, 0, 1).unsqueeze(0).float() / 255
    # Bug fix: "nearest" does not accept align_corners — passing True for it
    # raises a ValueError in torch. Only set it for the interpolating modes.
    align_corners = True if interpolation in ("bilinear", "bicubic") else None
    interp_img = torch.nn.functional.interpolate(torch_img, size=size, mode=interpolation, align_corners=align_corners)
    return (interp_img.squeeze(0).permute(1, 2, 0).numpy() * 255).astype(np.uint8)
|
||||
|
||||
|
||||
def resize_max_size(np_img, size_limit: int, interpolation="bicubic") -> np.ndarray:
    """Shrink an image so its longer side is at most ``size_limit``.

    Images that already fit are returned unchanged (no copy); otherwise the
    image is scaled down preserving aspect ratio (rounded to nearest pixel).
    """
    height, width = np_img.shape[:2]
    longer = max(height, width)
    if longer <= size_limit:
        return np_img
    scale = size_limit / longer
    target = (int(width * scale + 0.5), int(height * scale + 0.5))
    return resize_np_img(np_img, size=target, interpolation=interpolation)
|
||||
|
||||
|
||||
def norm_img(np_img):
    """Convert a HWC (or HW) uint8 image to a CHW float32 array in [0, 1]."""
    if np_img.ndim == 2:
        np_img = np_img[:, :, np.newaxis]
    chw = np.transpose(np_img, (2, 0, 1))
    return chw.astype("float32") / 255
|
||||
|
||||
|
||||
def get_cache_path_by_url(url):
    """Return the local cache path for a model URL.

    The cache lives under ``~/.TORCH_HUB_CACHE/checkpoints``; the directory is
    created if missing. The filename is taken from the URL path component.

    Args:
        url: Model download URL.

    Returns:
        str: Path of the (possibly not yet downloaded) cached file.
    """
    parts = urlparse(url)
    # Bug fix: the literal "~" is not expanded by os.path/os.makedirs, which
    # would create a directory actually named "~" in the working directory.
    hub_dir = os.path.expanduser("~/.TORCH_HUB_CACHE")
    model_dir = os.path.join(hub_dir, "checkpoints")
    # exist_ok avoids the racy isdir() pre-check under concurrent processes.
    os.makedirs(model_dir, exist_ok=True)
    filename = os.path.basename(parts.path)
    cached_file = os.path.join(model_dir, filename)
    return cached_file
|
||||
|
||||
|
||||
def md5sum(filename):
    """Return the hex MD5 digest of a file, reading it in chunks."""
    digest = hashlib.md5()
    chunk_size = 128 * digest.block_size
    with open(filename, "rb") as fh:
        while True:
            chunk = fh.read(chunk_size)
            if not chunk:
                break
            digest.update(chunk)
    return digest.hexdigest()
|
||||
|
||||
|
||||
def download_model(url, model_md5: str = None):
    """Download a model file into the local cache if not already present.

    If ``model_md5`` is given, the freshly downloaded file is verified; on a
    mismatch the bad file is deleted (best effort) and the process exits.

    Args:
        url: Model download URL.
        model_md5: Expected MD5 hex digest; verification is skipped when falsy.

    Returns:
        str: Path to the cached model file.
    """
    cached_file = get_cache_path_by_url(url)
    if not os.path.exists(cached_file):
        sys.stderr.write(f'Downloading: "{url}" to {cached_file}\n')
        hash_prefix = None
        download_url_to_file(url, cached_file, hash_prefix, progress=True)
        if model_md5:
            _md5 = md5sum(cached_file)
            if model_md5 == _md5:
                print(f"Download model success, md5: {_md5}")
            else:
                # Delete the corrupt download so the next run retries cleanly.
                try:
                    os.remove(cached_file)
                    print(
                        f"Model md5: {_md5}, expected md5: {model_md5}, wrong model deleted. Please restart"
                        " lama-cleaner.If you still have errors, please try download model manually first"
                        " https://lama-cleaner-docs.vercel.app/install/download_model_manually.\n"
                    )
                except OSError:  # was a bare except: — only removal failures belong here
                    print(
                        f"Model md5: {_md5}, expected md5: {model_md5}, please delete {cached_file} and restart"
                        " lama-cleaner."
                    )
                # sys.exit instead of the interactive-only exit() builtin.
                sys.exit(-1)

    return cached_file
|
||||
|
||||
|
||||
def load_jit_model(url_or_path, device, model_md5: str):
    """Load a TorchScript model from a local path or (cached) download URL.

    The checkpoint is loaded on CPU first, moved to ``device`` and switched
    to eval mode. Loading errors are reported and re-raised.
    """
    model_path = url_or_path if os.path.exists(url_or_path) else download_model(url_or_path, model_md5)

    print(f"Loading model from: {model_path}")
    try:
        model = torch.jit.load(model_path, map_location="cpu").to(device)
    except Exception as e:
        print(f"Error loading model: {e}")
        raise
    model.eval()
    return model
|
||||
|
||||
|
||||
class LaMa(InpaintModel):
    """LaMa inpainting model, loaded as a TorchScript module from the cache."""

    name = "lama"
    pad_mod = 8

    @torch.no_grad()
    def __call__(self, image, mask):
        """
        images: [H, W, C] RGB, not normalized
        masks: [H, W]
        return: BGR IMAGE
        """
        # boxes = boxes_from_mask(mask)
        inpaint_result = self._pad_forward(image, mask)

        return inpaint_result

    def init_model(self, device, **kwargs):
        # Download (if needed) and load the TorchScript checkpoint in eval mode.
        self.model = load_jit_model(LAMA_MODEL_URL, device, LAMA_MODEL_MD5).eval()

    @staticmethod
    def is_downloaded() -> bool:
        # True when the checkpoint file already exists in the local cache.
        return os.path.exists(get_cache_path_by_url(LAMA_MODEL_URL))

    def forward(self, image, mask):
        """Input image and output image have same size.

        image: [H, W, C] RGB
        mask: [H, W]
        return: BGR IMAGE
        """
        # Normalize both inputs to CHW float32 in [0, 1].
        image = norm_img(image)
        mask = norm_img(mask)

        # Binarize the mask: any non-zero pixel is treated as "to inpaint".
        mask = (mask > 0) * 1
        image = torch.from_numpy(image).unsqueeze(0).to(self.device)
        mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)

        inpainted_image = self.model(image, mask)
        return inpainted_image
|
||||
@@ -0,0 +1,65 @@
|
||||
import argparse
import os
from random import shuffle

import numpy as np
from matplotlib import pyplot as plt
from PIL import Image

# CLI: point at a ForNet split folder (containing backgrounds/ and
# foregrounds/) and interactively inspect all mask versions of each image.
parser = argparse.ArgumentParser(description="Inspect image versions")
parser.add_argument("-f", "--base_folder", type=str, required=True, help="Base folder to inspect")
parser.add_argument("-opt_fg_ratio", type=float, default=0.3, help="Optimal foreground ratio")

args = parser.parse_args()

bg_folder = os.path.join(args.base_folder, "backgrounds")
fg_folder = os.path.join(args.base_folder, "foregrounds")

# Class folders are WordNet ids ("n<offset>"); sort by the numeric offset.
classes = os.listdir(fg_folder)
classes = sorted(classes, key=lambda x: int(x[1:]))
assert len(classes) in [200, 1_000], f"Expected 200 or 1_000 classes, got {len(classes)}"

# Collect every image id ("<class>/<base_name>") across all classes,
# dropping the extension and keeping at most the first 3 "_"-separated fields
# of the file name (strips version suffixes).
total_images = set()
for in_cls in classes:
    cls_images = {
        os.path.join(in_cls, "_".join(img.split(".")[0].split("_")[:3]))
        for img in os.listdir(os.path.join(fg_folder, in_cls))
    }
    total_images.update(cls_images)
total_images = list(total_images)
shuffle(total_images)
print(total_images[:5], "...")

# in_cls to print name/lemma
with open("wordnet_data/tinyimagenet_synset_names.txt", "r") as f:
    in_cls_to_name = {line.split(":")[0].strip(): line.split(":")[1].strip() for line in f.readlines() if len(line) > 2}

for image_name in total_images:
    in_cls, img_name = image_name.split("/")
    # All foreground files belonging to this image (different mask versions).
    versions = set()
    for img in os.listdir(os.path.join(fg_folder, in_cls)):
        if img.startswith(img_name):
            versions.add(img)
    if len(versions) <= 1:
        print(f"Image {image_name} has only one version")
        continue
    versions = sorted(list(versions))

    # Load the matching background (stored as JPEG under the same base name).
    bg_img = Image.open(os.path.join(bg_folder, in_cls, f"{versions[0].split('.')[0]}.JPEG"))

    # plot all versions at once
    fig, axs = plt.subplots(1, len(versions) + 1, figsize=(15, 5))
    axs[0].imshow(bg_img)
    axs[0].axis("off")
    axs[0].set_title("Background")
    for version, ax in zip(versions, axs[1:]):
        img = Image.open(os.path.join(fg_folder, in_cls, version))
        # The alpha channel of the foreground acts as the segmentation mask.
        img_mask = np.array(img.convert("RGBA").split()[-1])
        # Mask density within the (cropped) foreground image itself.
        fg_dens_fact = np.sum(img_mask) / (255 * img_mask.size)
        # Mask area relative to the full background image.
        fg_ratio = np.sum(img_mask) / (255 * bg_img.size[0] * bg_img.size[1])
        ax.imshow(img)
        ax.axis("off")
        ax.set_title(f"{version}\nfg_rat: {fg_ratio:.2f} dist to optimal: {abs(fg_ratio - args.opt_fg_ratio):.2f}")
    fig.suptitle(in_cls_to_name[in_cls])
    plt.show()
    plt.close()
|
||||
@@ -0,0 +1,11 @@
|
||||
tqdm
|
||||
einops
|
||||
omegaconf
|
||||
diffusers
|
||||
opencv-python
|
||||
transformers
|
||||
accelerate
|
||||
torchvision
|
||||
datadings
|
||||
numpy
|
||||
nltk
|
||||
@@ -0,0 +1,20 @@
|
||||
#!/bin/bash
# SLURM job-array script: segment the ImageNet validation set.
# Each of the 30 array tasks processes one shard (-id) of 30 (-p);
# --wait keeps sbatch blocked until the whole array has finished.

#SBATCH --array=0-29%30
#SBATCH --time=1-0
#SBATCH --mem=64G
#SBATCH --gpus=1
#SBATCH --partition=H200,H100,H100-PCI,A100-40GB,A100-80GB
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=16
#SBATCH --export="HF_HOME=/PATH/TO/HF_HOME/,NLTK_DATA=/PATH/TO/NLTK_DATA/,TQDM_DISABLE=1"
#SBATCH --job-name="Segment ImageNet (val)"
#SBATCH --output=/SBATCH/OUT/FOLDER/%x-%j-%N-%a.out
#SBATCH --wait

# Run one shard inside the container; extra CLI args are forwarded.
srun -K \
    --container-image=/PATH/TO.sqsh \
    --container-workdir="$(pwd)" \
    --container-mounts=/SET/CONTAINER/MOUNTS,"$(pwd)":"$(pwd)" \
    python3 segment_imagenet.py -p 30 -id $SLURM_ARRAY_TASK_ID -o /OUTPUT/FOLDER/PATH/ "$@"
|
||||
@@ -0,0 +1,20 @@
|
||||
#!/bin/bash
# SLURM job-array script: segment the ImageNet training set.
# 320 shards total, at most 60 running concurrently; --wait blocks sbatch
# until the whole array has finished.

#SBATCH --array=0-319%60
#SBATCH --time=1-0
#SBATCH --mem=64G
#SBATCH --gpus=1
#SBATCH --partition=H200,H100,H100-PCI,A100-PCI,A100-40GB,A100-80GB
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=16
#SBATCH --export="HF_HOME=/PATH/TO/HF_HOME/,NLTK_DATA=/PATH/TO/NLTK_DATA/,TQDM_DISABLE=1"
#SBATCH --job-name="Segment ImageNet (train)"
#SBATCH --output=/SBATCH/OUT/FOLDER/%x-%j-%N-%a.out
#SBATCH --wait

# Run one shard inside the container; extra CLI args are forwarded.
srun -K \
    --container-image=/PATH/TO.sqsh \
    --container-workdir="$(pwd)" \
    --container-mounts=/SET/CONTAINER/MOUNTS,"$(pwd)":"$(pwd)" \
    python3 segment_imagenet.py -p 320 -id $SLURM_ARRAY_TASK_ID -o /OUTPUT_PATH/INSegment/ "$@"
|
||||
@@ -0,0 +1,19 @@
|
||||
#!/bin/bash
# SLURM job-array script: segment ImageNet in 5 shards (up to 10 concurrent).

#SBATCH --array=0-4%10
#SBATCH --time=1-0
#SBATCH --mem=64G
#SBATCH --gpus=1
#SBATCH --partition=RTX3090,RTXA6000,A100-PCI,A100-40GB,H100-PCI
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=16
#SBATCH --export="HF_HOME=/PATH/TO/HF_HOME/,NLTK_DATA=/PATH/TO/NLTK_DATA/,TQDM_DISABLE=1"
#SBATCH --job-name="Segment ImageNet"
#SBATCH --output=/SBATCH/OUT/FOLDER/%x-%j-%N-%a.out

# Run one shard inside the container; extra CLI args are forwarded.
srun -K \
    --container-image=/PATH/TO.sqsh \
    --container-workdir="$(pwd)" \
    --container-mounts=/SET/CONTAINER/MOUNTS,"$(pwd)":"$(pwd)" \
    python3 segment_imagenet.py -p 5 -id $SLURM_ARRAY_TASK_ID -o /PATH/TO/OUT/FOLDER/INSegment/ "$@"
|
||||
@@ -0,0 +1,19 @@
|
||||
#!/bin/bash
# SLURM job-array script: segment TinyImageNet in 20 shards.
# No -o flag: the script's default output directory is used.

#SBATCH --array=0-19%20
#SBATCH --time=1-0
#SBATCH --mem=64G
#SBATCH --gpus=1
#SBATCH --partition=RTX3090,RTXA6000,A100-PCI,A100-40GB,H100-PCI
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=16
#SBATCH --export="HF_HOME=/PATH/TO/HF_HOME/,NLTK_DATA=/PATH/TO/NLTK_DATA/,TQDM_DISABLE=1"
#SBATCH --job-name="Segment TinyImageNet"
#SBATCH --output=/SBATCH/OUT/FOLDER/%x-%j-%N-%a.out

# Run one shard inside the container; extra CLI args are forwarded.
srun -K \
    --container-image=/PATH/TO.sqsh \
    --container-workdir="$(pwd)" \
    --container-mounts=/SET/CONTAINER/MOUNTS,"$(pwd)":"$(pwd)" \
    python3 segment_imagenet.py -p 20 -id $SLURM_ARRAY_TASK_ID "$@"
|
||||
@@ -0,0 +1,20 @@
|
||||
#!/bin/bash
# SLURM job-array script: segment TinyImageNet in 30 shards; --wait keeps
# sbatch blocked until the whole array has finished.
# No -o flag: the script's default output directory is used.

#SBATCH --array=0-29%30
#SBATCH --time=1-0
#SBATCH --mem=64G
#SBATCH --gpus=1
#SBATCH --partition=H200,H100,H100-PCI,A100-PCI,A100-40GB,A100-80GB
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=16
#SBATCH --export="HF_HOME=/PATH/TO/HF_HOME/,NLTK_DATA=/PATH/TO/NLTK_DATA/,TQDM_DISABLE=1"
#SBATCH --job-name="Segment TinyImageNet"
#SBATCH --output=/SBATCH/OUT/FOLDER/%x-%j-%N-%a.out
#SBATCH --wait

# Run one shard inside the container; extra CLI args are forwarded.
srun -K \
    --container-image=/PATH/TO.sqsh \
    --container-workdir="$(pwd)" \
    --container-mounts=/SET/CONTAINER/MOUNTS,"$(pwd)":"$(pwd)" \
    python3 segment_imagenet.py -p 30 -id $SLURM_ARRAY_TASK_ID "$@"
|
||||
@@ -0,0 +1,338 @@
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
from math import ceil
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from datadings.reader import MsgpackReader
|
||||
from datadings.torch import Compose, CompressedToPIL, Dataset
|
||||
from PIL import Image, ImageFilter
|
||||
from torchvision.transforms import ToPILImage, ToTensor
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from attentive_eraser import AttentiveEraser
|
||||
from grounded_segmentation import grounded_segmentation
|
||||
from infill_lama import LaMa
|
||||
from utils import already_segmented, save_img
|
||||
from wordnet_tree import WNTree
|
||||
|
||||
|
||||
def _collate_single(data):
|
||||
return data[0]
|
||||
|
||||
|
||||
def _collate_multiple(datas):
|
||||
return {
|
||||
"images": [data["image"] for data in datas],
|
||||
"keys": [data["key"] for data in datas],
|
||||
"labels": [data["label"] for data in datas],
|
||||
}
|
||||
|
||||
|
||||
def smallest_crop(image: torch.Tensor, mask: torch.Tensor):
    """Crop ``image`` to the tight bounding box of the non-zero mask region.

    Mask and image have to be of the same spatial size: the image is
    [C, H, W] and the mask [H, W] (or [1, H, W]).

    Args:
        image (torch.Tensor): image to crop
        mask (torch.Tensor): mask whose non-zero extent defines the crop

    Returns:
        torch.Tensor: cropped image (a view into the input)
    """
    if mask.dim() == 3:
        assert mask.shape[0] == 1
        mask = mask.squeeze(0)
    assert (
        image.dim() == 3 and mask.dim() == 2 and image.shape[1:] == mask.shape
    ), f"Invalid shapes: {image.shape}, {mask.shape}"

    # Columns (summed over rows) and rows (summed over columns) with any mask.
    col_idx = mask.sum(dim=0).nonzero()
    row_idx = mask.sum(dim=1).nonzero()

    top = row_idx.min().item()
    bottom = row_idx.max().item() + 1
    left = col_idx.min().item()
    right = col_idx.max().item() + 1
    return image[:, top:bottom, left:right]
|
||||
|
||||
|
||||
def _synset_to_prompt(synset, tree, parent_in_prompt):
|
||||
prompt = f"an {synset.print_name}." if synset.print_name[0] in "aeiou" else f"a {synset.print_name}."
|
||||
if synset.parent_id is None or not parent_in_prompt:
|
||||
return prompt
|
||||
parent = tree[synset.parent_id]
|
||||
return prompt[:-1] + f", a type of {parent.print_name}."
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Pipeline: detect + segment each dataset image, split it into a cropped
    # foreground (WEBP with alpha mask) and an infilled background (JPEG),
    # sharded across `--processes` independent jobs selected by `-id`.
    parser = argparse.ArgumentParser(description="Grounded Segmentation")
    parser.add_argument(
        "-d",
        "--dataset",
        choices=["imagenet", "imagenet-val", "imagenet21k", "tinyimagenet", "tinyimagenet-val"],
        default="imagenet",
        help="Dataset to use",
    )
    parser.add_argument("-r", "--dataset_root", required=True, help="Root directory of the dataset")
    parser.add_argument("-b", "--batch_size", type=int, default=1, help="Batch size")
    parser.add_argument("-t", "--threshold", type=float, default=0.3, help="Detection threshold")
    parser.add_argument("--debug", action="store_true", help="Debug mode")
    parser.add_argument("-o", "--output", type=str, required=True, help="Output directory")
    parser.add_argument(
        "-p", "--processes", type=int, default=1, help="Number of processes that are used to process the data"
    )
    parser.add_argument("-id", type=int, default=0, help="ID of this process)")
    parser.add_argument("--overwrite", action="store_true", help="Overwrite existing files, instead of skipping them")
    parser.add_argument(
        "--parent_labels", type=int, default=2, help="Number of parent labels to use; steps to go up the tree"
    )
    parser.add_argument("--output_ims", choices=["best", "all"], default="best", help="Output all or best masks")
    parser.add_argument("--mask_merge_threshold", type=float, default=0.9, help="Threshold on IoU for merging masks")
    parser.add_argument("--parent_in_prompt", action="store_true", help="Include parent label in the prompt")
    parser.add_argument(
        "-model",
        choices=["LaMa", "AttErase"],
        default="LaMa",
        help="Model to use for erasing/infilling. Defaults to LaMa.",
    )

    args = parser.parse_args()

    # Select the msgpack reader for the requested split; `part` names the
    # output subfolder ("train" or "val").
    dataset = args.dataset.lower()
    part = "train"
    if dataset == "imagenet21k":
        reader = MsgpackReader(f"{args.dataset_root}imagenet21k/train.msgpack")
        dataset = Dataset(reader, transforms={"image": Compose([CompressedToPIL()])})
    elif dataset == "imagenet-val":
        reader = MsgpackReader(f"{args.dataset_root}imagenet/msgpack/val.msgpack")
        dataset = Dataset(reader, transforms={"image": Compose([CompressedToPIL()])})
        part = "val"
    elif dataset == "imagenet":
        reader = MsgpackReader(f"{args.dataset_root}imagenet/msgpack/train.msgpack")
        dataset = Dataset(reader, transforms={"image": Compose([CompressedToPIL()])})
    elif dataset == "tinyimagenet":
        reader = MsgpackReader(f"{args.dataset_root}TinyINSegment/TinyIN_train.msgpack")
        dataset = Dataset(reader, transforms={"image": Compose([CompressedToPIL()])})
    elif dataset == "tinyimagenet-val":
        reader = MsgpackReader(f"{args.dataset_root}TinyINSegment/TinyIN_val.msgpack")
        dataset = Dataset(reader, transforms={"image": Compose([CompressedToPIL()])})
        part = "val"
    else:
        raise ValueError(f"Unknown dataset: {dataset}")

    # Shard the dataset into `processes` contiguous subsets; this job handles
    # the `id`-th shard.
    assert 0 <= args.id < args.processes, "ID must be in the range [0, processes)"
    if args.processes > 1:
        partlen = ceil(len(dataset) / args.processes)
        start_id = args.id * partlen
        end_id = min((args.id + 1) * partlen, len(dataset))
        dataset = torch.utils.data.Subset(dataset, list(range(start_id, end_id)))

    assert args.batch_size == 1, "Batch size must be 1 for grounded segmentation (for now)"
    dataset = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=args.debug,
        drop_last=False,
        num_workers=10,
        collate_fn=_collate_single if args.batch_size == 1 else _collate_multiple,
    )

    infill_model = (
        LaMa(device="cuda" if torch.cuda.is_available() else "cpu") if args.model == "LaMa" else AttentiveEraser()
    )

    # Map dataset label ids to WordNet synset offsets.
    if args.dataset.startswith("imagenet"):
        with open("wordnet_data/imagenet1k_synsets.json", "r") as f:
            id_to_synset = json.load(f)
        id_to_synset = {int(k): v for k, v in id_to_synset.items()}
    elif args.dataset.startswith("tinyimagenet"):
        with open("wordnet_data/tinyimagenet_synset_names.txt", "r") as f:
            synsets = f.readlines()
        id_to_synset = [int(synset.split(":")[0].strip()[1:]) for synset in synsets]
        id_to_synset = sorted(id_to_synset)

    wordnet = WNTree.load("wordnet_data/imagenet21k+1k_masses_tree.json")

    # create the necessary folders
    os.makedirs(os.path.join(args.output, part, "foregrounds"), exist_ok=True)
    os.makedirs(os.path.join(args.output, part, "backgrounds"), exist_ok=True)
    os.makedirs(os.path.join(args.output, part, "no_detect"), exist_ok=True)
    os.makedirs(os.path.join(args.output, part, "error"), exist_ok=True)

    # When tqdm is disabled (batch jobs), print manual progress lines instead.
    man_print_progress = int(os.environ.get("TQDM_DISABLE", "0")) == 1

    for i, data in tqdm(enumerate(dataset), total=len(dataset)):
        if args.debug and i > 10:
            break

        if man_print_progress and i % 200 == 0:
            print(f"{datetime.now()} \t process {args.id}/{args.processes}: \t sample: {i}/{len(dataset)}", flush=True)

        image = data["image"]
        key = data["key"]

        if args.dataset.endswith("-val"):
            img_class = f"n{id_to_synset[data['label']]:08d}"  # overwrite the synset id on the val set
        else:
            img_class = key.split("/")[-1].split("_")[0]

        # Skip images that already have a saved foreground (resume support).
        if (
            already_segmented(key, os.path.join(args.output, part, "foregrounds"), img_class=img_class)
            and not args.overwrite
        ):
            continue

        # get the next 3 labels up in the wordnet tree
        label_id = data["label"]
        offset = id_to_synset[label_id]  # int(data["key"].split("_")[0][1:])
        synset = wordnet[offset]
        labels = [_synset_to_prompt(synset, wordnet, args.parent_in_prompt)]
        while len(labels) < 1 + args.parent_labels and synset.parent_id is not None:
            synset = wordnet[synset.parent_id]
            _label = _synset_to_prompt(synset, wordnet, args.parent_in_prompt)
            _label = _label.replace("_", " ")
            labels.append(_label)

        assert len(labels) > 0, f"No labels for {key}"

        # 1. detect and segment
        image_tensor, detections = grounded_segmentation(
            image, labels, threshold=args.threshold, polygon_refinement=True
        )

        # merge all detections for the same label
        masks = {}
        for detection in detections:
            if detection.label not in masks:
                masks[detection.label] = detection.mask
            else:
                masks[detection.label] |= detection.mask
        labels = list(masks.keys())
        masks = [masks[lbl] for lbl in labels]
        assert len(masks) == len(
            labels
        ), f"Number of masks ({len(masks)}) != number of labels ({len(labels)}); detections: {detections}"

        if args.debug:
            for detected_mask, label in zip(masks, labels):
                detected_mask = torch.from_numpy(detected_mask).unsqueeze(0) / 255
                mask_foreground = torch.cat((image_tensor * detected_mask, detected_mask), dim=0)
                ToPILImage()(mask_foreground).save(f"example_images/{key}_{label}_masked.png")

        if len(masks) == 0:
            # No detection at all: save the original for later inspection.
            tqdm.write(f"No detections for {key}; skipping")
            save_img(image, key, os.path.join(args.output, part, "no_detect"), img_class=img_class, format="JPEG")
            continue
        elif len(masks) == 1:
            # max_overlap_mask = masks[0]
            masks = [masks[0]]
        elif args.output_ims == "best":
            # find the 2 masks with the largest overlap
            # NOTE(review): the loop variables below shadow the sample index
            # `i`; harmless since `i` is reassigned at the top of the outer
            # loop, but worth renaming.
            max_overlap = 0
            max_overlap_mask = None
            using_labels = None
            for i, mask1 in enumerate(masks):
                for j, mask2 in enumerate(masks):
                    if i >= j:
                        continue
                    iou = (mask1 & mask2).sum() / (mask1 | mask2).sum()
                    if iou > max_overlap:
                        max_overlap = iou
                        max_overlap_mask = mask1 | mask2
                        using_labels = f"{labels[i]} & {labels[j]}"
            if args.debug:
                tqdm.write(f"{key}:\tMax overlap: {max_overlap}, using {using_labels}")
            if max_overlap_mask is None:
                # assert len(masks) > 0 and len(masks) == len(labels), f"No detections for {key}"
                max_overlap_mask = masks[0]
            masks = [max_overlap_mask]
        else:
            # merge masks that are too similar
            has_changed = True
            while has_changed:
                has_changed = False
                for i, mask1 in enumerate(masks):
                    for j, mask2 in enumerate(masks):
                        if i >= j:
                            continue
                        iou = (mask1 & mask2).sum() / (mask1 | mask2).sum()
                        if iou > args.mask_merge_threshold:
                            # Union the pair, drop the second, and restart the
                            # scan since indices changed.
                            masks[i] |= masks[j]
                            masks.pop(j)
                            labels.pop(j)
                            has_changed = True
                            break
                    if has_changed:
                        break

        for mask_idx, mask_array in enumerate(masks):
            mask = torch.from_numpy(mask_array).unsqueeze(0) / 255

            if args.debug:
                mask_image = ToPILImage()(mask)
                mask_image.save(f"example_images/{key}_mask.png")

            # foreground = image_tensor
            # Foreground: masked RGB plus the mask as alpha channel, cropped tight.
            foreground = torch.cat((image_tensor * mask, mask), dim=0)
            foreground = ToPILImage()(foreground)
            mask_img = foreground.split()[-1]
            foreground = smallest_crop(ToTensor()(foreground), ToTensor()(mask_img))  # TODO: fix smallest crop
            fg_img = ToPILImage()(foreground)

            background = (1 - mask) * image_tensor
            bg_image = ToPILImage()(background)

            # 2. infill background
            # Blur the mask so the infill blends over a soft boundary.
            mask_image = Image.fromarray(mask_array)
            mask_image = mask_image.filter(ImageFilter.GaussianBlur(radius=7))
            try:
                infilled_bg = infill_model(np.array(bg_image), np.array(mask_image))
                infilled_bg = Image.fromarray(np.uint8(infilled_bg))
            except RuntimeError as e:
                tqdm.write(
                    f"Error infilling {key}: bg_image.shape={np.array(bg_image).shape},"
                    f" mask_image.shape={np.array(mask_image).shape}\n{e}"
                )
                save_img(
                    image,
                    key,
                    os.path.join(args.output, part, "error"),
                    img_class=img_class,
                    img_version=mask_idx if len(masks) > 1 else None,
                    format="JPEG",
                )
                infilled_bg = None

            if args.debug:
                # NOTE(review): infilled_bg may be None here if infilling
                # failed — this would raise; confirm debug runs never hit the
                # error path.
                data["image"].save(f"example_images/{key}_orig.png")
                fg_img.save(f"example_images/{key}_fg.png")
                infilled_bg.save(f"example_images/{key}_bg.png")
            else:
                # save files
                # NOTE(review): class_name is unused below (img_class is used).
                class_name = key.split("_")[0]
                save_img(
                    fg_img,
                    key,
                    os.path.join(args.output, part, "foregrounds"),
                    img_class=img_class,
                    img_version=mask_idx if len(masks) > 1 else None,
                    format="WEBP",
                )
                if infilled_bg is not None:
                    save_img(
                        infilled_bg,
                        key,
                        os.path.join(args.output, part, "backgrounds"),
                        img_class=img_class,
                        img_version=mask_idx if len(masks) > 1 else None,
                        format="JPEG",
                    )
    if man_print_progress:
        print(
            f"{datetime.now()} \t process {args.id}/{args.processes}: \t done with all {len(dataset)} samples",
            flush=True,
        )
|
||||
@@ -0,0 +1,15 @@
|
||||
#!/bin/bash
# Interactive/single-job srun wrapper: run segment_imagenet.py once inside the
# container on one GPU; all CLI args are forwarded to the script.

srun -K \
    --partition=H200,H100,A100-40GB,A100-80GB \
    --job-name="Segment ImageNet" \
    --nodes=1 \
    --gpus=1 \
    --cpus-per-task=16 \
    --mem=64G \
    --time=1-0 \
    --export=HF_HOME=/PATH/TO/HF_HOME/,NLTK_DATA=/PATH/TO/NLTK_DATA/ \
    --container-image=/PATH/TO.sqsh \
    --container-workdir="$(pwd)" \
    --container-mounts=/SET/CONTAINER/MOUNTS,"$(pwd)":"$(pwd)" \
    python3 segment_imagenet.py "$@"
|
||||
@@ -0,0 +1,13 @@
|
||||
#!/bin/bash
# Generic srun wrapper: run an arbitrary command ("$@") inside the container
# (CPU-only: no --gpus request).

srun -K \
    --partition=RTX3090,A100-40GB,H100 \
    --job-name="Segment ImageNet" \
    --nodes=1 \
    --mem=64G \
    --time=1-0 \
    --export=HF_HOME=/PATH/TO/HF_HOME/,NLTK_DATA=/PATH/TO/NLTK_DATA/,MAX_WORKERS=16 \
    --container-image=/PATH/TO.sqsh \
    --container-workdir="$(pwd)" \
    --container-mounts=/SET/CONTAINER/MOUNTS,"$(pwd)":"$(pwd)" \
    "$@"
|
||||
49
AAAI Supplementary Material/ForNet Creation Code/utils.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import os
|
||||
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def save_img(img: Image, img_name: str, base_dir: str, img_class: str = None, format="PNG", img_version=None):
    """Save an image to a directory.

    Args:
        img (PIL.Image): Image to save.
        img_name (str): Relative path to the image.
        base_dir (str): Base directory to save images in.
        img_class (str, optional): Image class, if not given try to extract it from the image name in ImageNet train format. Defaults to None.
        format (str, optional): Format to save the image in. Defaults to "PNG".
        img_version (int, optional): Version of the image. Will be appended to the path. Defaults to None.
    """
    # Normalize the extension so it matches the requested format.
    if not img_name.endswith(f".{format}"):
        img_name = f"{img_name.split('.')[0]}.{format}"
    if img_class is None:
        # ImageNet train naming convention: "<class>_<idx>.<ext>".
        img_class = img_name.split("_")[0]
    # exist_ok makes this safe under concurrent processes; the previous
    # os.path.exists() pre-check was redundant and race-prone.
    os.makedirs(os.path.join(base_dir, img_class), exist_ok=True)
    if img_version is not None:
        img_name = f"{img_name.split('.')[0]}_v{img_version}.{format}"
    img.save(os.path.join(base_dir, img_class, img_name), format.lower())
|
||||
|
||||
|
||||
def already_segmented(img_name: str, base_dir: str, img_class: str = None):
    """Check if an image was already segmented.

    An image counts as segmented if any file in its class folder starts with
    its base name followed by "." or a "_v<version>" suffix.

    Args:
        img_name (str): Relative path to the image.
        base_dir (str): Base directory the segmented images are saved in.
        img_class (str, optional): Image class; extracted from the ImageNet-style
            file name when not given. Defaults to None.

    Returns:
        bool: Image was segmented already.
    """
    if "." in img_name:
        img_base_name = ".".join(img_name.split(".")[:-1])
    else:
        img_base_name = img_name
    if img_class is None:
        img_class = img_name.split("_")[0]
    class_dir = os.path.join(base_dir, img_class)
    if not os.path.exists(class_dir):
        return False
    prefixes = (img_base_name + "_v", img_base_name + ".")
    return any(entry.startswith(prefixes) for entry in os.listdir(class_dir))
|
||||
@@ -0,0 +1,200 @@
|
||||
n07695742: pretzel
|
||||
n03902125: pay-phone, pay-station
|
||||
n03980874: poncho
|
||||
n01644900: tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui
|
||||
n02730930: apron
|
||||
n02699494: altar
|
||||
n03201208: dining table, board
|
||||
n02056570: king penguin, Aptenodytes patagonica
|
||||
n04099969: rocking chair, rocker
|
||||
n04366367: suspension bridge
|
||||
n04067472: reel
|
||||
n02808440: bathtub, bathing tub, bath, tub
|
||||
n04540053: volleyball
|
||||
n02403003: ox
|
||||
n03100240: convertible
|
||||
n04562935: water tower
|
||||
n02788148: bannister, banister, balustrade, balusters, handrail
|
||||
n02988304: CD player
|
||||
n02423022: gazelle
|
||||
n03637318: lampshade, lamp shade
|
||||
n01774384: black widow, Latrodectus mactans
|
||||
n01768244: trilobite
|
||||
n07614500: ice cream, icecream
|
||||
n04254777: sock
|
||||
n02085620: Chihuahua
|
||||
n01443537: goldfish, Carassius auratus
|
||||
n01629819: European fire salamander, Salamandra salamandra
|
||||
n02099601: golden retriever
|
||||
n02321529: sea cucumber, holothurian
|
||||
n03837869: obelisk
|
||||
n02002724: black stork, Ciconia nigra
|
||||
n02841315: binoculars, field glasses, opera glasses
|
||||
n04560804: water jug
|
||||
n02364673: guinea pig, Cavia cobaya
|
||||
n03706229: magnetic compass
|
||||
n09256479: coral reef
|
||||
n09332890: lakeside, lakeshore
|
||||
n03544143: hourglass
|
||||
n02124075: Egyptian cat
|
||||
n02948072: candle, taper, wax light
|
||||
n01950731: sea slug, nudibranch
|
||||
n02791270: barbershop
|
||||
n03179701: desk
|
||||
n02190166: fly
|
||||
n04275548: spider web, spider's web
|
||||
n04417672: thatch, thatched roof
|
||||
n03930313: picket fence, paling
|
||||
n02236044: mantis, mantid
|
||||
n03976657: pole
|
||||
n01774750: tarantula
|
||||
n04376876: syringe
|
||||
n04133789: sandal
|
||||
n02099712: Labrador retriever
|
||||
n04532670: viaduct
|
||||
n04487081: trolleybus, trolley coach, trackless trolley
|
||||
n09428293: seashore, coast, seacoast, sea-coast
|
||||
n03160309: dam, dike, dyke
|
||||
n03250847: drumstick
|
||||
n02843684: birdhouse
|
||||
n07768694: pomegranate
|
||||
n03670208: limousine, limo
|
||||
n03085013: computer keyboard, keypad
|
||||
n02892201: brass, memorial tablet, plaque
|
||||
n02233338: cockroach, roach
|
||||
n03649909: lawn mower, mower
|
||||
n03388043: fountain
|
||||
n02917067: bullet train, bullet
|
||||
n02486410: baboon
|
||||
n04596742: wok
|
||||
n03255030: dumbbell
|
||||
n03937543: pill bottle
|
||||
n02113799: standard poodle
|
||||
n03977966: police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria
|
||||
n02906734: broom
|
||||
n07920052: espresso
|
||||
n01698640: American alligator, Alligator mississipiensis
|
||||
n02123394: Persian cat
|
||||
n03424325: gasmask, respirator, gas helmet
|
||||
n02129165: lion, king of beasts, Panthera leo
|
||||
n04008634: projectile, missile
|
||||
n03042490: cliff dwelling
|
||||
n02415577: bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis
|
||||
n02815834: beaker
|
||||
n02395406: hog, pig, grunter, squealer, Sus scrofa
|
||||
n01784675: centipede
|
||||
n03126707: crane
|
||||
n04399382: teddy, teddy bear
|
||||
n07875152: potpie
|
||||
n03733131: maypole
|
||||
n02802426: basketball
|
||||
n03891332: parking meter
|
||||
n01910747: jellyfish
|
||||
n03838899: oboe, hautboy, hautbois
|
||||
n03770439: miniskirt, mini
|
||||
n02281406: sulphur butterfly, sulfur butterfly
|
||||
n03970156: plunger, plumber's helper
|
||||
n09246464: cliff, drop, drop-off
|
||||
n02206856: bee
|
||||
n02074367: dugong, Dugong dugon
|
||||
n03584254: iPod
|
||||
n04179913: sewing machine
|
||||
n04328186: stopwatch, stop watch
|
||||
n07583066: guacamole
|
||||
n01917289: brain coral
|
||||
n03447447: gondola
|
||||
n02823428: beer bottle
|
||||
n03854065: organ, pipe organ
|
||||
n02793495: barn
|
||||
n04285008: sports car, sport car
|
||||
n02231487: walking stick, walkingstick, stick insect
|
||||
n04465501: tractor
|
||||
n02814860: beacon, lighthouse, beacon light, pharos
|
||||
n02883205: bow tie, bow-tie, bowtie
|
||||
n02165456: ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle
|
||||
n04149813: scoreboard
|
||||
n04023962: punching bag, punch bag, punching ball, punchball
|
||||
n02226429: grasshopper, hopper
|
||||
n02279972: monarch, monarch butterfly, milkweed butterfly, Danaus plexippus
|
||||
n02669723: academic gown, academic robe, judge's robe
|
||||
n04486054: triumphal arch
|
||||
n04070727: refrigerator, icebox
|
||||
n03444034: go-kart
|
||||
n02666196: abacus
|
||||
n01945685: slug
|
||||
n04251144: snorkel
|
||||
n03617480: kimono
|
||||
n03599486: jinrikisha, ricksha, rickshaw
|
||||
n02437312: Arabian camel, dromedary, Camelus dromedarius
|
||||
n01984695: spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish
|
||||
n04118538: rugby ball
|
||||
n01770393: scorpion
|
||||
n04356056: sunglasses, dark glasses, shades
|
||||
n03804744: nail
|
||||
n02132136: brown bear, bruin, Ursus arctos
|
||||
n03400231: frying pan, frypan, skillet
|
||||
n03983396: pop bottle, soda bottle
|
||||
n07734744: mushroom
|
||||
n02480495: orangutan, orang, orangutang, Pongo pygmaeus
|
||||
n02410509: bison
|
||||
n03404251: fur coat
|
||||
n04456115: torch
|
||||
n02123045: tabby, tabby cat
|
||||
n03026506: Christmas stocking
|
||||
n07715103: cauliflower
|
||||
n04398044: teapot
|
||||
n02927161: butcher shop, meat market
|
||||
n07749582: lemon
|
||||
n07615774: ice lolly, lolly, lollipop, popsicle
|
||||
n02795169: barrel, cask
|
||||
n04532106: vestment
|
||||
n02837789: bikini, two-piece
|
||||
n02814533: beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon
|
||||
n04265275: space heater
|
||||
n02481823: chimpanzee, chimp, Pan troglodytes
|
||||
n02509815: lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens
|
||||
n06596364: comic book
|
||||
n01983481: American lobster, Northern lobster, Maine lobster, Homarus americanus
|
||||
n02504458: African elephant, Loxodonta africana
|
||||
n03014705: chest
|
||||
n01944390: snail
|
||||
n04146614: school bus
|
||||
n01641577: bullfrog, Rana catesbeiana
|
||||
n07720875: bell pepper
|
||||
n02999410: chain
|
||||
n01855672: goose
|
||||
n02125311: cougar, puma, catamount, mountain lion, painter, panther, Felis concolor
|
||||
n07753592: banana
|
||||
n07871810: meat loaf, meatloaf
|
||||
n04501370: turnstile
|
||||
n04311004: steel arch bridge
|
||||
n02977058: cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM
|
||||
n04074963: remote control, remote
|
||||
n03662601: lifeboat
|
||||
n02106662: German shepherd, German shepherd dog, German police dog, alsatian
|
||||
n03089624: confectionery, confectionary, candy store
|
||||
n04259630: sombrero
|
||||
n03393912: freight car
|
||||
n04597913: wooden spoon
|
||||
n07711569: mashed potato
|
||||
n03355925: flagpole, flagstaff
|
||||
n02963159: cardigan
|
||||
n07579787: plate
|
||||
n02950826: cannon
|
||||
n01882714: koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus
|
||||
n02094433: Yorkshire terrier
|
||||
n02909870: bucket, pail
|
||||
n02058221: albatross, mollymawk
|
||||
n01742172: boa constrictor, Constrictor constrictor
|
||||
n09193705: alp
|
||||
n04371430: swimming trunks, bathing trunks
|
||||
n07747607: orange
|
||||
n03814639: neck brace
|
||||
n04507155: umbrella
|
||||
n02268443: dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk
|
||||
n02769748: backpack, back pack, knapsack, packsack, rucksack, haversack
|
||||
n03763968: military uniform
|
||||
n07873807: pizza, pizza pie
|
||||
n03992509: potter's wheel
|
||||
n03796401: moving van
|
||||
n12267677: acorn
|
||||
400
AAAI Supplementary Material/ForNet Creation Code/wordnet_tree.py
Normal file
@@ -0,0 +1,400 @@
|
||||
import argparse
|
||||
import contextlib
|
||||
import json
|
||||
import os
|
||||
from copy import copy
|
||||
|
||||
from nltk.corpus import wordnet as wn
|
||||
|
||||
|
||||
class bcolors:
    """Colors for terminal output."""

    # ANSI escape sequences: print one of these, then the text, then ENDC
    # to reset the terminal back to its default style.
    HEADER = "\033[95m"  # bright magenta
    OKBLUE = "\033[94m"  # bright blue
    OKCYAN = "\033[96m"  # bright cyan
    OKGREEN = "\033[92m"  # bright green
    WARNING = "\033[93m"  # bright yellow
    FAIL = "\033[91m"  # bright red
    ENDC = "\033[0m"  # reset all attributes
    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"
|
||||
|
||||
|
||||
def _lemmas_str(synset):
|
||||
return ", ".join([lemma.name() for lemma in synset.lemmas()])
|
||||
|
||||
|
||||
class WNEntry:
    """One wordnet synset.

    A node of a :class:`WNTree`. Nodes reference each other only by WordNet
    offset (``id`` / ``parent_id`` / ``child_ids``) and are resolved through
    the owning tree's ``nodes`` dict, which is why most methods take the
    tree as an argument.
    """

    def __init__(
        self,
        name: str,
        id: int,
        lemmas: str,
        parent_id: int,
        depth: int = None,
        in_image_net: bool = False,
        child_ids: list = None,
        in_main_tree: bool = True,
        _n_images: int = 0,
        _description: str = None,
        _name: str = None,
        _pruned: bool = False,
    ):
        """Create a synset node.

        Args:
            name (str): WordNet synset name (e.g. ``"dog.n.01"``).
            id (int): WordNet offset of the synset.
            lemmas (str): Comma-separated lemma names (see ``_lemmas_str``).
            parent_id (int): Offset of the hypernym node, or None for a root.
            depth (int, optional): Distance from the tree root. Defaults to None.
            in_image_net (bool, optional): Whether this synset is an ImageNet class. Defaults to False.
            child_ids (list, optional): Offsets of hyponym nodes. Defaults to None.
            in_main_tree (bool, optional): False for nodes below a parentless root. Defaults to True.
            _n_images (int, optional): Images assigned directly to this node. Defaults to 0.
            _description (str, optional): Cached WordNet definition (see ``description``). Defaults to None.
            _name (str, optional): Carried through to_dict/from_dict; not otherwise used here. Defaults to None.
            _pruned (bool, optional): Whether this node was pruned from the tree. Defaults to False.
        """
        self.name = name
        self.id = id
        self.lemmas = lemmas
        self.parent_id = parent_id
        self.depth = depth
        self.in_image_net = in_image_net
        self.child_ids = child_ids
        self.in_main_tree = in_main_tree
        self._n_images = _n_images
        self._description = _description
        self._name = _name
        self._pruned = _pruned

    # NOTE: unlike a normal __str__, this is also called explicitly with
    # arguments to render a whole subtree as an indented text listing.
    def __str__(self, tree=None, accumulate=True, colors=True, max_depth=0, max_children=None):
        """Render this node (and optionally its children) as indented lines.

        A leading green "+" marks ImageNet classes, a red "-" everything else.
        With ``accumulate`` and a tree, the direct image count is shown next
        to the accumulated subtree total (Σ). ``max_depth`` limits recursion
        (0 renders only this node); ``max_children`` caps children per node.
        """
        green = f"{bcolors.OKGREEN}" if colors else ""
        red = f"{bcolors.FAIL}" if colors else ""
        end = f"{bcolors.ENDC}" if colors else ""
        start_symb = f"{green}+{end}" if self.in_image_net else f"{red}-{end}"
        n_ims = f"{self._n_images} of Σ {self.n_images(tree)}" if accumulate and tree is not None else self._n_images
        if self.child_ids is None or tree is None or max_depth == 0:
            return f"{start_symb}{self.name} ({self.id}) > {n_ims}"

        children = self.child_ids
        if max_children is not None and len(children) > max_children:
            children = children[:max_children]
        # Indent every line of each child's rendering by one extra space.
        return f"{start_symb}{self.name} ({self.id}) > {n_ims}\n " + "\n ".join(
            [
                "\n ".join(
                    tree.nodes[child_id]
                    .__str__(tree=tree, accumulate=accumulate, colors=colors, max_depth=max_depth - 1)
                    .split("\n")
                )
                for child_id in children
            ]
        )

    def tree_diff(self, tree_1, tree_2):
        """Render this subtree annotated with the image-count change between two trees.

        Leading symbol: green "+" if tree_2 has more direct images for this
        node than tree_1, red "-" if fewer, blue "=" if equal. Recurses over
        the child ids of *this* node (children are looked up in tree_1).
        """
        if tree_2[self.id]._n_images > tree_1[self.id]._n_images:
            start_symb = f"{bcolors.OKGREEN}+{bcolors.ENDC}"
        elif tree_2[self.id]._n_images < tree_1[self.id]._n_images:
            start_symb = f"{bcolors.FAIL}-{bcolors.ENDC}"
        else:
            start_symb = f"{bcolors.OKBLUE}={bcolors.ENDC}"
        n_ims = (
            f"{tree_1[self.id]._n_images} + {tree_2[self.id]._n_images - tree_1[self.id]._n_images} of Σ"
            f" {tree_1[self.id].n_images(tree_2)}/{tree_2[self.id].n_images(tree_2)}"
        )

        if self.child_ids is None:
            return f"{start_symb}{self.name} ({self.id}) > {n_ims}"

        return f"{start_symb}{self.name} ({self.id}) > {n_ims}\n " + "\n ".join(
            ["\n ".join(tree_1.nodes[child_id].tree_diff(tree_1, tree_2).split("\n")) for child_id in self.child_ids]
        )

    def prune(self, tree):
        """Prune this node and its subtree from *tree*.

        Children are pruned first (their images bubble up to this node since
        it is not yet flagged), then this node is detached from its parent's
        child list and its direct image count is moved to the nearest ancestor
        that is not itself pruned. Nodes remain in ``tree.nodes``; they are
        only flagged ``_pruned``. No-op for already-pruned nodes and roots.
        """
        if self._pruned or self.parent_id is None:
            return

        if self.child_ids is not None:
            for child_id in self.child_ids:
                tree[child_id].prune(tree)

        self._pruned = True
        parent_node = tree.nodes[self.parent_id]
        try:
            parent_node.child_ids.remove(self.id)
        except ValueError as e:
            # Best effort: report inconsistent parent/child links but keep pruning.
            print(
                f"Error removing {self.name} from"
                f" {parent_node.name} ({[tree[cid].name for cid in parent_node.child_ids]}): {e}"
            )
        # Walk up past ancestors that were pruned earlier so the images land
        # on a node that is still part of the tree.
        while parent_node._pruned:
            parent_node = tree.nodes[parent_node.parent_id]
        parent_node._n_images += self._n_images
        self._n_images = 0

    @property
    def description(self):
        """WordNet definition of this synset, fetched lazily and cached."""
        if not self._description:
            self._description = wn.synset_from_pos_and_offset("n", self.id).definition()
        return self._description

    @property
    def print_name(self):
        """Synset name without the POS/sense suffix (``"dog.n.01"`` -> ``"dog"``)."""
        return self.name.split(".")[0]

    @property
    def is_leaf(self):
        """True if the node has no children."""
        return self.child_ids is None or len(self.child_ids) == 0

    def get_branch(self, tree=None):
        """Return the root-to-this-node path as a ``" > "``-joined string."""
        if self.parent_id is None or tree is None:
            return self.print_name

        parent = tree.nodes[self.parent_id]
        return parent.get_branch(tree) + " > " + self.print_name

    def get_branch_list(self, tree):
        """Return the root-to-this-node path as a list of nodes (root first)."""
        if self.parent_id is None:
            return [self]
        parent = tree.nodes[self.parent_id]
        return parent.get_branch_list(tree) + [self]

    def to_dict(self):
        """Serialize this node to a plain JSON-compatible dict."""
        return {
            "name": self.name,
            "id": self.id,
            "lemmas": self.lemmas,
            "parent_id": self.parent_id,
            "depth": self.depth,
            "in_image_net": self.in_image_net,
            "child_ids": self.child_ids,
            "in_main_tree": self.in_main_tree,
            "_n_images": self._n_images,
            "_description": self._description,
            "_name": self._name,
            "_pruned": self._pruned,
        }

    @classmethod
    def from_dict(cls, d):
        """Inverse of :meth:`to_dict`."""
        return cls(**d)

    def n_images(self, tree=None):
        """Return the direct image count, plus all descendants' when *tree* is given."""
        if tree is None or self.child_ids is None or len(self.child_ids) == 0:
            return self._n_images
        return sum([tree.nodes[child_id].n_images(tree) for child_id in self.child_ids]) + self._n_images

    def n_children(self, tree=None):
        """Return the number of direct children, or of all descendants when *tree* is given."""
        if self.child_ids is None:
            return 0
        if tree is None or len(self.child_ids) == 0:
            return len(self.child_ids)
        return len(self.child_ids) + sum([tree.nodes[child_id].n_children(tree) for child_id in self.child_ids])

    def get_examples(self, tree, n_examples=3):
        """Return up to *n_examples* representative children as "name (definition)" text.

        Children are ranked by accumulated image count; if no child has any
        images, descendant count is used as a tie-breaker ranking instead.
        Returns "" for leaves. Note: accesses ``description``, which may hit
        WordNet lazily.
        """
        if self.child_ids is None or len(self.child_ids) == 0:
            return ""
        child_images = {child_id: tree.nodes[child_id].n_images(tree) for child_id in self.child_ids}
        max_images = max(child_images.values())
        if max_images == 0:
            # go on number of child nodes
            child_images = {child_id: tree.nodes[child_id].n_children(tree) for child_id in self.child_ids}
        # sorted childids by number of images
        top_children = [
            child_id for child_id, n_images in sorted(child_images.items(), key=lambda x: x[1], reverse=True)
        ]
        top_children = top_children[: min(n_examples, len(top_children))]
        return ", ".join(
            [f"{tree.nodes[child_id].print_name} ({tree.nodes[child_id].description})" for child_id in top_children]
        )
|
||||
|
||||
|
||||
class WNTree:
    """A WordNet hypernym tree whose nodes are :class:`WNEntry` objects keyed by offset."""

    def __init__(self, root=1740, nodes=None):
        """Build a tree around a root node.

        Args:
            root (int | WNEntry): Either a WordNet offset to look up (the
                default 1740 is presumably the top-level "entity" synset —
                confirm against the WordNet database) or an existing node.
            nodes (dict, optional): Pre-built ``{offset: WNEntry}`` mapping;
                if None, the tree starts with only the root.
        """
        if isinstance(root, int):
            root_id = root
            root_synset = wn.synset_from_pos_and_offset("n", root)
            root_node = WNEntry(
                root_synset.name(),
                root_id,
                _lemmas_str(root_synset),
                parent_id=None,
                depth=0,
            )
        else:
            assert isinstance(root, WNEntry)
            root_id = root.id
            root_node = root

        self.root = root_node
        self.nodes = {root_id: self.root} if nodes is None else nodes
        self.parentless = []  # offsets whose synsets had no hypernyms
        self.label_index = None  # lazily built sorted list of label offsets
        self.pruned = set()  # offsets removed by the last prune() call

    def to_dict(self):
        """Serialize the tree to a JSON-compatible dict (see :meth:`from_dict`)."""
        return {
            "root": self.root.to_dict(),
            "nodes": {node_id: node.to_dict() for node_id, node in self.nodes.items()},
            "parentless": self.parentless,
            "pruned": list(self.pruned),
        }

    def prune(self, min_images):
        """Prune nodes with too few images; return the set of pruned offsets.

        Two passes: first remove every node whose accumulated subtree total is
        below *min_images*; then walk the remaining tree bottom-up and remove
        nodes whose own (direct) image count is still below the threshold
        after inheriting images from their pruned children.
        """
        pruned_nodes = set()

        # prune all nodes that have fewer than min_images below them
        for node_id, node in self.nodes.items():
            if node.n_images(self) < min_images:
                pruned_nodes.add(node_id)
                node.prune(self)

        # prune all nodes that have fewer than min_images inside them, after all nodes below have been pruned
        # Build a BFS order of the surviving tree (children of pruned nodes
        # were already detached from child_ids above).
        node_stack = [self.root]
        node_idx = 0
        while node_idx < len(node_stack):
            node = node_stack[node_idx]
            if node.child_ids is not None:
                for child_id in node.child_ids:
                    child = self.nodes[child_id]
                    node_stack.append(child)
            node_idx += 1

        # now prune the stack from the bottom up
        for node in node_stack[::-1]:
            # only look at images of that class, not of additional children
            if node.n_images() < min_images:
                pruned_nodes.add(node.id)
                node.prune(self)

        self.pruned = pruned_nodes
        return pruned_nodes

    @classmethod
    def from_dict(cls, d):
        """Rebuild a tree from :meth:`to_dict` output (e.g. after JSON round-trip)."""
        tree = cls()
        tree.root = WNEntry.from_dict(d["root"])
        # JSON keys are strings; restore the integer-offset keys.
        tree.nodes = {int(node_id): WNEntry.from_dict(node) for node_id, node in d["nodes"].items()}
        tree.parentless = d["parentless"]
        if "pruned" in d:  # backwards compatible with dumps lacking "pruned"
            tree.pruned = set(d["pruned"])
        return tree

    def add_node(self, node_id, in_in=True):
        """Add the synset with offset *node_id*, creating missing ancestors recursively.

        Args:
            node_id (int): WordNet offset to add.
            in_in (bool, optional): Whether the node is an ImageNet class;
                ancestors created on the way up are added with ``in_in=False``.
                Re-adding an existing node only upgrades its ImageNet flag.
        """
        if node_id in self.nodes:
            self.nodes[node_id].in_image_net = in_in or self.nodes[node_id].in_image_net
            return

        synset = wn.synset_from_pos_and_offset("n", node_id)

        # print(f"adding node {synset.name()} with id {node_id}")

        hypernyms = synset.hypernyms()
        if len(hypernyms) == 0:
            # Synset is disconnected from the main tree; track it separately.
            parent_id = None
            self.parentless.append(node_id)
            main_tree = False
            print(f"--------- no hypernyms for {synset.name()} ({synset.offset()}) ------------")
        else:
            # Only the first hypernym is used, turning WordNet's DAG into a tree.
            parent_id = synset.hypernyms()[0].offset()
            if parent_id not in self.nodes:
                self.add_node(parent_id, in_in=False)
            parent = self.nodes[parent_id]

            if parent.child_ids is None:
                parent.child_ids = []
            parent.child_ids.append(node_id)
            main_tree = parent.in_main_tree

        depth = self.nodes[parent_id].depth + 1 if parent_id is not None else 0
        node = WNEntry(
            synset.name(),
            node_id,
            _lemmas_str(synset),
            parent_id=parent_id,
            in_image_net=in_in,
            depth=depth,
            in_main_tree=main_tree,
        )

        self.nodes[node_id] = node

    def __len__(self):
        """Total number of nodes (including pruned and parentless ones)."""
        return len(self.nodes)

    def image_net_len(self, only_main_tree=False):
        """Number of ImageNet-class nodes, optionally restricted to the main tree."""
        return sum([node.in_image_net for node in self.nodes.values() if node.in_main_tree or not only_main_tree])

    def max_depth(self, only_main_tree=False):
        """Deepest node depth, optionally restricted to the main tree."""
        return max([node.depth for node in self.nodes.values() if node.in_main_tree or not only_main_tree])

    def __str__(self, colors=True):
        """Human-readable dump: summary line, main tree, then parentless subtrees."""
        return (
            f"WordNet Tree with {len(self)} nodes, {self.image_net_len()} in ImageNet21k;"
            f" {len(self.parentless)} parentless nodes:\n{self.root.__str__(tree=self, colors=colors)}\nParentless:\n"
            + "\n".join([self.nodes[node_id].__str__(tree=self, colors=colors) for node_id in self.parentless])
        )

    def save(self, path):
        """Write the tree as JSON to *path*."""
        with open(path, "w") as f:
            json.dump(self.to_dict(), f)

    @classmethod
    def load(cls, path):
        """Load a tree previously written by :meth:`save`."""
        with open(path, "r") as f:
            tree_dict = json.load(f)
        return cls.from_dict(tree_dict)

    def subtree(self, node):
        """Return a new WNTree rooted at *node* (any key accepted by ``__getitem__``).

        Nodes are shallow-copied and depths re-based so the new root has
        depth 0. NOTE(review): ``copy`` is shallow, so ``child_ids`` lists
        are shared with this tree — mutating them in the subtree would also
        mutate the original; confirm this is intended by callers.
        """
        node_id = self[node].id
        if node_id not in self.nodes:
            return None
        # Breadth-first collection of the subtree's node ids.
        node_queue = [self.nodes[node_id]]
        subtree_ids = set()
        while len(node_queue) > 0:
            node = node_queue.pop(0)
            subtree_ids.add(node.id)
            if node.child_ids is not None:
                node_queue += [self.nodes[child_id] for child_id in node.child_ids]
        subtree_nodes = {node_id: copy(self.nodes[node_id]) for node_id in subtree_ids}
        subtree_root = subtree_nodes[node_id]
        subtree_root.parent_id = None
        depth_diff = subtree_root.depth
        for node in subtree_nodes.values():
            node.depth -= depth_diff
        return WNTree(root=subtree_root, nodes=subtree_nodes)

    def _make_label_index(self):
        # Labels are the unpruned nodes that directly own images, in sorted
        # offset order so label ids are stable across runs.
        self.label_index = sorted(
            [node_id for node_id, node in self.nodes.items() if node.n_images(self) > 0 and not node._pruned]
        )

    def get_label(self, node_id):
        """Map a (possibly pruned) node offset to its dense classifier label index.

        Pruned nodes resolve to their nearest unpruned ancestor, mirroring
        where :meth:`WNEntry.prune` moved their images.
        """
        if self.label_index is None:
            self._make_label_index()
        while self.nodes[node_id]._pruned:
            node_id = self.nodes[node_id].parent_id
        return self.label_index.index(node_id)

    def n_labels(self):
        """Number of distinct classifier labels (see :meth:`_make_label_index`)."""
        if self.label_index is None:
            self._make_label_index()
        return len(self.label_index)

    def __contains__(self, item):
        """Membership by "n<offset>" string, int offset, or WNEntry."""
        if isinstance(item, str):
            if item[0] == "n":
                item = int(item[1:])
            else:
                return False
        if isinstance(item, int):
            return item in self.nodes
        if isinstance(item, WNEntry):
            return item.id in self.nodes
        return False

    def __getitem__(self, item):
        """Look up a node by "n<offset>" string, synset name, int offset, or WNEntry.

        Raises:
            KeyError: If no node matches *item*.
        """
        if isinstance(item, str) and item[0].startswith("n"):
            # "n02084071"-style ids; names like "dog.n.01" also start with a
            # letter, so the int() conversion is allowed to fail silently.
            with contextlib.suppress(ValueError):
                item = int(item[1:])
        if isinstance(item, str) and ".n." in item:
            # Synset-name lookup requires a linear scan over all nodes.
            for node in self.nodes.values():
                if item == node.name:
                    return node
            raise KeyError(f"Item {item} not found in tree")
        if isinstance(item, int):
            return self.nodes[item]
        if isinstance(item, WNEntry):
            return self.nodes[item.id]
        raise KeyError(f"Item {item} not found in tree")

    def __iter__(self):
        """Iterate over node offsets."""
        return iter(self.nodes.keys())
|
||||