AAAI Version

This commit is contained in:
Tobias Christian Nauen
2026-02-24 12:22:44 +01:00
parent 5c08f9d31a
commit ff34712155
378 changed files with 19844 additions and 4780 deletions

View File

@@ -0,0 +1,73 @@
# Exclude a variety of commonly ignored directories.
exclude = [
".bzr",
".direnv",
".eggs",
".git",
".git-rewrite",
".hg",
".ipynb_checkpoints",
".mypy_cache",
".nox",
".pants.d",
".pyenv",
".pytest_cache",
".pytype",
".ruff_cache",
".svn",
".tox",
".venv",
".vscode",
"__pypackages__",
"_build",
"buck-out",
"build",
"dist",
"node_modules",
"site-packages",
"venv",
]
# Same as Black.
line-length = 120
indent-width = 4
[lint]
# Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default.
# Unlike Flake8, Ruff doesn't enable pycodestyle warnings (`W`) or
# McCabe complexity (`C901`) by default.
select = ["E", "F", "D", "B", "NPY", "PD", "TD005", "TD006", "TD007", "SIM", "RET", "Q", "ICN", "I"]
ignore = ["D203", "D213", "E501", "D100", "NPY002", "D102", "B008", "PD011", "D105", "SIM118"]
# Allow fix for all enabled rules (when `--fix` is provided).
unfixable = ["B"]
# Allow unused variables when underscore-prefixed.
dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
[format]
# Like Black, use double quotes for strings.
quote-style = "double"
# Like Black, indent with spaces, rather than tabs.
indent-style = "space"
# Like Black, respect magic trailing commas.
skip-magic-trailing-comma = true
# Like Black, automatically detect the appropriate line ending.
line-ending = "auto"
# Enable auto-formatting of code examples in docstrings. Markdown,
# reStructuredText code/literal blocks and doctests are all supported.
#
# This is currently disabled by default, but it is planned for this
# to be opt-out in the future.
docstring-code-format = true
# Set the line length limit used when formatting code snippets in
# docstrings.
#
# This only has an effect when the `docstring-code-format` setting is
# enabled.
docstring-code-line-length = "dynamic"

View File

@@ -0,0 +1,61 @@
# Creating the ForNet Dataset
We can't just provide the ForNet dataset here, as it's too large to be part of the appendix and using a link will go against double-blind review rules.
After acceptance, the dataset will be downloadable online.
For now, we provide the scripts and steps to recreate the dataset.
In general, if you are unsure what arguments each script allows, run it using the `--help` flag.
## 1. Setup paths
Fill in the paths in `experiments/general_srun` in the `Model Training Code` folder, as well as in `srun-general.sh`, `slurm-segment-imnet.sh` and all the `sbatch-segment-...` files.
In particular the `--container-image`, `--container-mounts`, `--output` and `NLTK_DATA` and `HF_HOME` paths in `--export`.
## 2. Pretrain Filtering Models
Use the `Model Training Code` to pretrain an ensemble of models to use for filtering in a later step.
Train those models on either `TinyImageNet` or `ImageNet`, depending on if you want to create `TinyForNet` or `ForNet`.
Then fill in the relevant paths to the pretrained weights in `experiments/filter_segmentation_versions.py` lines 96/98.
## 3. Create the dataset
### Automatically: using slurm
You may just run the `create_dataset.py` file (on a slurm head node). That file will automatically run all the necessary steps one after another.
### Manually and step-by-step
If you want to run each step of the pipeline manually, follow these steps.
For default arguments and settings, see the `create_dataset.py` script, even though you may not want to run it directly, it can tell you how to run all the other scripts.
#### 3.1 Segment Objects and Backgrounds
Use the segmentation script (`segment_imagenet.py`) to segment each of the dataset images.
Watch out, as this script uses `datadings` for image loading, so you need to provide a `datadings` variant of your dataset.
You need to provide the root folder of the dataset.
Choose your segmentation model using the `-model` argument (LaMa or AttErase).
If you want to use the >general< prompting strategy, set the `--parent_in_prompt` flag.
Use `--output`/`-o` to set the output directory.
Use `--processes` and `-id` for splitting the task up into multiple parallelizable processes.
#### 3.2 Filter the segmented images
In this step, you use the pretrained ensemble of models (from step 2) for filtering the segmented images.
As this step is based on the training and model code, it's in the `Model Training Code` directory.
After setting the relevant paths to the pretrained weights (see step 2), you may run the `experiments/filter_segmentation_versions.py` script using that directory as the PWD.
#### 3.3 Zip the dataset
In distributed storage settings it might be useful to read from one large (uncompressed) zip file instead of reading millions of small single files.
To do this, run
```commandline
zip -r -0 backgrounds_train.zip train/backgrounds > /dev/null 2>&1
```
for the train and val backgrounds and foregrounds
#### 3.4 Compute the foreground size ratios
For the resizing step during recombination, the relative size of each object in each image is needed.
To compute it, run the `foreground_size_ratio.py` script on your filtered dataset.
It expects the zip files in the folder you provide as `-ds`.

View File

@@ -0,0 +1,98 @@
from copy import copy
import numpy as np
import torch
import torch.nn.functional as F
from diffusers import DDIMScheduler, DiffusionPipeline
from PIL import Image, ImageFilter
from torchvision.transforms.functional import gaussian_blur, to_tensor
def _preprocess_image(image, device, dtype=torch.float16):
    """Convert a PIL image into a normalized 1024x1024 tensor for the diffusion pipeline.

    Values are rescaled from [0, 1] to [-1, 1] and single-channel inputs are
    replicated to three channels.
    """
    tensor = to_tensor(image).unsqueeze_(0).float()
    tensor = tensor * 2 - 1  # rescale [0, 1] --> [-1, 1]
    if tensor.shape[1] != 3:
        # grayscale input: replicate the single channel across RGB
        tensor = tensor.expand(-1, 3, -1, -1)
    resized = F.interpolate(tensor, (1024, 1024))
    return resized.to(dtype).to(device)
def _preprocess_mask(mask, device, dtype=torch.float16):
    """Convert a PIL mask into a binarized, border-softened 1024x1024 tensor.

    The mask is blurred with a large Gaussian kernel to dilate its border, then
    re-binarized with a 0.1 threshold.
    """
    tensor = to_tensor(mask.convert("L")).unsqueeze_(0).float()  # values are 0 or 1
    tensor = F.interpolate(tensor, (1024, 1024))
    tensor = gaussian_blur(tensor, kernel_size=(77, 77))
    # threshold the blurred mask back to a hard {0, 1} mask
    tensor = (tensor >= 0.1).float()
    return tensor.to(dtype).to(device)
class AttentiveEraser:
    """Attentive Eraser Pipeline + Pre- and Post-Processing.

    Wraps the SDXL-based Attentive Eraser custom pipeline: converts inputs to
    tensors, runs the removal, and pastes the inpainted region back into the
    original image.
    """

    # empty prompt: object removal is not conditioned on text
    prompt = ""
    dtype = torch.float16

    def __init__(self, num_steps=50, device=None):
        """Create the attentive eraser.

        Args:
            num_steps (int, optional): Number of steps in the diffusion process. Will start at 20%. Defaults to 50.
            device (torch.device, optional): Device to run on. Defaults to 'cuda' if available, else 'cpu'.
        """
        self.num_steps = num_steps
        self.device = device if device else torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        scheduler = DDIMScheduler(
            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False
        )
        self.model_path = "stabilityai/stable-diffusion-xl-base-1.0"
        self.pipeline = DiffusionPipeline.from_pretrained(
            self.model_path,
            custom_pipeline="./pipelines/pipeline_stable_diffusion_xl_attentive_eraser.py",
            scheduler=scheduler,
            variant="fp16",
            use_safetensors=True,
            torch_dtype=self.dtype,
        ).to(self.device)
        # trade some speed for a much smaller GPU memory footprint
        self.pipeline.enable_attention_slicing()
        self.pipeline.enable_model_cpu_offload()

    def __call__(self, image, mask):
        """Remove the masked region from the image.

        Args:
            image (PIL.Image.Image | np.ndarray): Image to erase from.
            mask (PIL.Image.Image | np.ndarray): Mask of the region to remove (non-zero = remove).

        Returns:
            PIL.Image.Image: Image with the masked region inpainted.
        """
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        if isinstance(mask, np.ndarray):
            mask = Image.fromarray(mask)
        prep_img = _preprocess_image(image, self.device, self.dtype)
        prep_mask = _preprocess_mask(mask, self.device, self.dtype)
        orig_shape = image.size
        diff_img = (
            self.pipeline(
                prompt=self.prompt,
                image=prep_img,
                mask_image=prep_mask,
                height=1024,
                width=1024,
                AAS=True,  # enable AAS
                strength=0.8,  # inpainting strength
                rm_guidance_scale=9,  # removal guidance scale
                ss_steps=9,  # similarity suppression steps
                ss_scale=0.3,  # similarity suppression scale
                AAS_start_step=0,  # AAS start step
                AAS_start_layer=34,  # AAS start layer
                AAS_end_layer=70,  # AAS end layer
                num_inference_steps=self.num_steps,  # number of inference steps # AAS_end_step = int(strength*num_inference_steps)
                guidance_scale=1,
            )
            .images[0]
            .resize(orig_shape)
        )
        # BUG FIX: the original computed `np.array(mask) > 0 * 255`, which due to
        # operator precedence is `mask > 0` — a boolean array instead of the
        # intended uint8 0/255 paste mask.
        paste_mask = Image.fromarray((np.array(mask) > 0).astype(np.uint8) * 255)
        paste_mask = paste_mask.convert("RGB").filter(ImageFilter.GaussianBlur(radius=10)).convert("L")
        out_img = copy(image)
        out_img.paste(diff_img, (0, 0), paste_mask)
        return out_img

View File

@@ -0,0 +1,195 @@
import argparse
import os
import re
import subprocess
# ---- command line interface -------------------------------------------------
parser = argparse.ArgumentParser(description="Create Segment & Recombine dataset")
parser.add_argument("-d", "--dataset", choices=["imagenet", "tinyimagenet"], required=True, help="Dataset to use")
parser.add_argument("-r", "--dataset_root", required=True, help="Root directory of the dataset")
parser.add_argument("-t", "--threshold", type=float, default=0.3, help="Detection threshold")
parser.add_argument(
    "--out_root", type=str, required=True, help="Root directory where the output directory will be created."
)
parser.add_argument("-o", "--output", type=str, required=True, help="Output directory")
parser.add_argument(
    "--parent_labels", type=int, default=2, help="Number of parent labels to use; steps to go up the tree"
)
parser.add_argument("--output_ims", choices=["best", "all"], default="best", help="Output all or best masks")
parser.add_argument("--mask_merge_threshold", type=float, default=0.9, help="Threshold on IoU for merging masks")
parser.add_argument("--parent_in_prompt", action="store_true", help="Include parent label in the prompt")
parser.add_argument("-infill_model", choices=["LaMa", "AttErase"], default="LaMa", help="Infilling model to use")
parser.add_argument("--continue", dest="continue_", action="store_true", help="Continue from previous run")
args = parser.parse_args()

out_root = args.out_root
base_folder = os.path.dirname(__file__)
training_code_folder = os.path.join(base_folder, os.pardir, "Model Training Code")

# output folders must be named (Tiny)INSegment_v<version>[_f<filter-version>]
ds_name_re = re.compile(r"(Tiny)?INSegment_v(\d*)(_f\d*)?")
name_match = ds_name_re.match(args.output)
assert name_match, f"Output name {args.output} does not match the expected format: {ds_name_re.pattern}"
assert (
    args.output_ims == "best" or name_match.group(3) is None
), "For output_ims == 'all', the filter subversions will be automatically created."
assert args.continue_ or not os.path.exists(out_root + args.output), f"Output directory {args.output} already exists."

# persist the run settings so a --continue run can verify that it matches
settings_file = {
    "dataset": args.dataset,
    "threshold": args.threshold,
    "output_ims": args.output_ims,
    "mask_merge_threshold": args.mask_merge_threshold,
    "parent_labels": args.parent_labels,
    "parent_in_prompt": args.parent_in_prompt,
    "infill_model": args.infill_model,
}
settings_file = [f"{k} = {str(settings_file[k])}" for k in sorted(list(settings_file.keys()))]
if os.path.exists(out_root + args.output) and os.path.exists(out_root + args.output + "/settings.txt"):
    # continued run: the stored settings must match the current invocation exactly
    with open(out_root + args.output + "/settings.txt", "r") as f:
        old_settings = f.read().split("\n")
    old_settings = [line.strip() for line in old_settings if line.strip()]
    assert old_settings == settings_file, (
        f"Settings file {out_root + args.output}/settings.txt does not match current settings: old: {old_settings} vs"
        f" new: {settings_file}"
    )
else:
    os.makedirs(out_root + args.output, exist_ok=True)
    with open(out_root + args.output + "/settings.txt", "w") as f:
        f.write("\n".join(settings_file) + "\n")

# ---- step 1: segmentation (train + val splits in parallel via sbatch) -------
general_args = [
    "sbatch",
    "sbatch-segment-tinyimnet-wait" if args.dataset == "tinyimagenet" else "sbatch-segment-imagenet-wait",
    "-r",
    args.dataset_root,
    "-o",
    out_root + args.output,
    "--parent_labels",
    str(args.parent_labels),
    "--output_ims",
    args.output_ims,
    "--mask_merge_threshold",
    str(args.mask_merge_threshold),
    "-t",
    str(args.threshold),
    "-model",
    args.infill_model,
]
if args.parent_in_prompt:
    general_args.append("--parent_in_prompt")

print(f"Starting segmentation: {' '.join(general_args)} for {args.dataset}-val and {args.dataset}")
p_train = subprocess.Popen(general_args + ["-d", args.dataset], cwd=base_folder)
if args.dataset == "imagenet":
    # ImageNet uses a dedicated sbatch script for the validation split
    general_args[1] = "sbatch-segment-imagenet-val-wait"
p_val = subprocess.Popen(general_args + ["-d", args.dataset + "-val"], cwd=base_folder)

# detect if exit in error
p_val.wait()
p_train.wait()
rcodes = (p_val.returncode, p_train.returncode)
if any(rcode != 0 for rcode in rcodes):
    print(f"Error in segmentation (val, train): {rcodes}")
    exit(1)
print("Segmentation done.")

# ---- step 2: filtering -------------------------------------------------------
if args.output_ims == "all":
    print("copy to subversions for filtering")
    p_1 = subprocess.Popen(["cp", "-rl", os.path.join(out_root, args.output), out_root + f"{args.output}_f1/"])
    p_2 = subprocess.Popen(["cp", "-rl", os.path.join(out_root, args.output), out_root + f"{args.output}_f2/"])
    p_1.wait()
    p_2.wait()
    print("Filtering subversions copied over.")
    print("Starting filtering")
    filtering_args = ["./experiments/general_srun.sh", "experiments/filter_segmentation_versions.py"]
    p_val_f1 = subprocess.Popen(
        filtering_args + ["-f", out_root + f"{args.output}_f1" + "/val", "-d", args.dataset], cwd=training_code_folder
    )
    p_train_f1 = subprocess.Popen(
        filtering_args + ["-f", out_root + f"{args.output}_f1" + "/train", "-d", args.dataset], cwd=training_code_folder
    )
    p_val_f2 = subprocess.Popen(
        filtering_args
        + ["-f", out_root + f"{args.output}_f2" + "/val", "-score_f_weights", "automatic", "-d", args.dataset],
        cwd=training_code_folder,
    )
    p_train_f2 = subprocess.Popen(
        filtering_args
        + ["-f", out_root + f"{args.output}_f2" + "/train", "-score_f_weights", "automatic", "-d", args.dataset],
        cwd=training_code_folder,
    )
    p_val_f1.wait()
    p_train_f1.wait()
    p_val_f2.wait()
    p_train_f2.wait()
    ds_folders = [f"{args.output}_f1", f"{args.output}_f2"]
else:
    # BUG FIX: the original tested `name_match.group(2) is None`, but group 2
    # (`\d*`) always matches and is never None, so the rename never ran; the
    # filter-suffix group (3) is the optional one.
    if name_match.group(3) is None:
        new_name = args.output + "_f1"
        print(f"Renaming output folder: {args.output} -> {new_name}")
        os.rename(out_root + args.output, out_root + new_name)
    else:
        new_name = args.output
    ds_folders = [new_name]


def _zip_dir(zip_name, src_dir, cwd):
    """Start a background `zip -r -0` of `src_dir` into `zip_name` inside `cwd`.

    BUG FIX: the original passed ">", "/dev/null", "2>&1" as extra argv entries;
    with shell=False those are literal file names (zip tries to archive them),
    not redirections. Output is silenced via subprocess instead.
    """
    return subprocess.Popen(
        ["zip", "-r", "-0", zip_name, src_dir],
        cwd=cwd,
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    )


# ---- step 3: zip the outputs and gather foreground size ratios ---------------
for folder in ds_folders:
    print(f"Zipping up {folder}")
    zip_jobs = [
        _zip_dir("foregrounds_val.zip", "val/foregrounds", out_root + folder),
        _zip_dir("foregrounds_train.zip", "train/foregrounds", out_root + folder),
        _zip_dir("backgrounds_val.zip", "val/backgrounds", out_root + folder),
        _zip_dir("backgrounds_train.zip", "train/backgrounds", out_root + folder),
    ]
    for job in zip_jobs:
        job.wait()

    print(f"Gathering foreground size ratios for {folder}")
    p_val = subprocess.Popen(
        [
            "./srun-general.sh",
            "python",
            "foreground_size_ratio.py",
            "--root",
            args.dataset_root,
            "-ds",
            folder,
            "-mode",
            "val",
        ],
        cwd=base_folder,
    )
    p_train = subprocess.Popen(
        ["./srun-general.sh", "python", "foreground_size_ratio.py", "--root", args.dataset_root, "-ds", folder],
        cwd=base_folder,
    )
    p_val.wait()
    p_train.wait()

Binary file not shown.

After

Width:  |  Height:  |  Size: 402 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 46 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 36 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 49 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 148 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 136 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 46 KiB

View File

@@ -0,0 +1,69 @@
import argparse
import json
import os
import zipfile
from io import BytesIO
import numpy as np
import PIL
from PIL import Image
from tqdm.auto import tqdm
# Compute, for every background image, the fraction of its pixels covered by the
# corresponding foreground object, and store the mapping as JSON.
parser = argparse.ArgumentParser(description="Foreground size ratio")
parser.add_argument("-mode", choices=["train", "val"], default="train", help="Train or val data?")
parser.add_argument("-ds", "--dataset", type=str, required=True, help="Dataset to use")
parser.add_argument("--root", type=str, required=True, help="Root folder of the dataset")
args = parser.parse_args()

train = args.mode == "train"
# absolute -ds paths are used as-is; relative ones are resolved against --root
root = os.path.join(args.root, args.dataset) if not args.dataset.startswith("/") else args.dataset

fg_bg_ratio_map = {}
with zipfile.ZipFile(f"{root}/backgrounds_{'train' if train else 'val'}.zip", "r") as bg_zip, zipfile.ZipFile(
    f"{root}/foregrounds_{'train' if train else 'val'}.zip", "r"
) as fg_zip:
    backgrounds = [f for f in bg_zip.namelist() if f.endswith(".JPEG")]
    foregrounds = [f for f in fg_zip.namelist() if f.endswith(".WEBP")]
    # PERF FIX: membership tests against a list are O(n) per lookup; for
    # ImageNet-scale archives that makes the loop quadratic overall.
    foreground_set = set(foregrounds)
    print(f"Bgs: {backgrounds[:5]}, ...\nFgs: {foregrounds[:5]}, ...")
    for bg_name in tqdm(backgrounds):
        fg_name = bg_name.replace("backgrounds", "foregrounds").replace("JPEG", "WEBP")
        if fg_name not in foreground_set:
            tqdm.write(f"Skipping {bg_name} as it has no corresponding foreground")
            fg_bg_ratio_map[bg_name] = 0.0
            continue
        with bg_zip.open(bg_name) as f:
            bg_data = BytesIO(f.read())
        try:
            bg_img = Image.open(bg_data)
        except PIL.UnidentifiedImageError as e:
            print(f"Error with file={bg_name}")
            raise e
        bg_img_size = bg_img.size
        with fg_zip.open(fg_name) as f:
            fg_data = BytesIO(f.read())
        try:
            fg_img = Image.open(fg_data)
        except PIL.UnidentifiedImageError as e:
            print(f"Error with file={fg_name}")
            raise e
        # object size = number of non-transparent pixels (sum of the alpha band / 255)
        fg_img_pixel_size = int(np.sum(fg_img.split()[-1]) / 255)
        fg_bg_ratio_map[bg_name] = fg_img_pixel_size / (bg_img_size[0] * bg_img_size[1])

with open(f"{root}/fg_bg_ratios_{args.mode}.json", "w") as f:
    json.dump(fg_bg_ratio_map, f)
print(f"Saved fg_bg_ratios_{args.mode} to disk. Exiting.")

View File

@@ -0,0 +1,42 @@
import argparse
import json
from datadings.reader import MsgpackReader
from datadings.torch import Dataset
from tqdm.auto import tqdm
# Build a JSON index mapping "n<synset>/<file-stem>" keys to positions in the
# datadings msgpack dataset.
parser = argparse.ArgumentParser(description="Foreground size ratio")
parser.add_argument("-mode", choices=["train", "val"], default="train", help="Train or val data?")
parser.add_argument("-ds", "--dataset", choices=["imagenet", "tinyimagenet"], required=True, help="Dataset to use")
parser.add_argument("-r", "--dataset_root", required=True, type=str, help="Root directory of the dataset")
parser.add_argument("-s", "--segment_root", required=True, type=str, help="Root directory of the segmentation dataset")
args = parser.parse_args()

# open the datadings reader for the requested split
if args.dataset == "imagenet":
    reader = MsgpackReader(f"{args.dataset_root}imagenet/msgpack/{args.mode}.msgpack")
elif args.dataset == "tinyimagenet":
    reader = MsgpackReader(f"{args.dataset_root}TinyINSegment/TinyIN_{args.mode}.msgpack")
else:
    raise ValueError(f"Unknown dataset: {args.dataset}")
dataset = Dataset(reader)

# label id -> wordnet synset id lookup
# NOTE(review): the key below formats the synset with ":08d", which requires an
# int — confirm the imagenet JSON maps ids to integers, not "nXXXXXXXX" strings.
if args.dataset.startswith("imagenet"):
    with open("wordnet_data/imagenet1k_synsets.json", "r") as f:
        id_to_synset = json.load(f)
    id_to_synset = {int(k): v for k, v in id_to_synset.items()}
elif args.dataset.startswith("tinyimagenet"):
    with open("wordnet_data/tinyimagenet_synset_names.txt", "r") as f:
        synsets = f.readlines()
    id_to_synset = sorted(int(synset.split(":")[0].strip()[1:]) for synset in synsets)

# map "n<synset>/<file-stem>" -> index into the msgpack dataset
key_to_idx = {
    f"n{id_to_synset[data['label']]:08d}/{data['key'].split('.')[0]}": i
    for i, data in enumerate(tqdm(dataset, leave=True))
}
print(len(key_to_idx))
with open(f"{args.segment_root}/{args.mode}_indices.json", "w") as f:
    json.dump(key_to_idx, f)

View File

@@ -0,0 +1,223 @@
"""Use grounding DINO + Segment Anything (SAM) to perform grounded segmentation on an image.
Based on: https://github.com/IDEA-Research/Grounded-Segment-Anything
"""
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
import cv2
import numpy as np
import requests
import torch
from PIL import Image
from torchvision.transforms import ToTensor
from transformers import AutoModelForMaskGeneration, AutoProcessor, pipeline
@dataclass
class BoundingBox:
    """Bounding box representation."""

    # pixel coordinates, origin at the top-left corner
    xmin: int
    ymin: int
    xmax: int
    ymax: int

    @property
    def xyxy(self) -> List[float]:
        """Return bounding box coordinates.

        Returns:
            List[float]: coordinates: [xmin, ymin, xmax, ymax]
        """
        return [self.xmin, self.ymin, self.xmax, self.ymax]
@dataclass
class DetectionResult:
    """Detection result from Grounding DINO + Mask from SAM."""

    # detection confidence reported by Grounding DINO
    score: float
    # text label the detection was matched against
    label: str
    box: BoundingBox
    # segmentation mask, filled in later by SAM; None until then
    mask: Optional[np.ndarray] = None

    @classmethod
    def from_dict(cls, detection_dict: Dict) -> "DetectionResult":
        """Create a DetectionResult from a dictionary.

        Args:
            detection_dict (Dict): Detection result dictionary.

        Returns:
            DetectionResult: Detection result object.
        """
        return cls(
            score=detection_dict["score"],
            label=detection_dict["label"],
            box=BoundingBox(
                xmin=detection_dict["box"]["xmin"],
                ymin=detection_dict["box"]["ymin"],
                xmax=detection_dict["box"]["xmax"],
                ymax=detection_dict["box"]["ymax"],
            ),
        )
def mask_to_polygon(mask: np.ndarray) -> List[List[int]]:
    """Use OpenCV to refine a mask by turning it into a polygon.

    Args:
        mask (np.ndarray): Segmentation mask.

    Returns:
        List[List[int]]: List of (x, y) coordinates representing the vertices of
            the largest-area polygon, or an empty list for an all-zero mask.
    """
    # Find contours in the binary mask
    contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    # BUG FIX: `max` raises ValueError on an empty sequence; an all-zero mask has
    # no contours, so return an empty polygon instead of crashing.
    if not contours:
        return []
    # Keep only the contour with the largest area
    largest_contour = max(contours, key=cv2.contourArea)
    # Extract the vertices of the contour
    return largest_contour.reshape(-1, 2).tolist()
def polygon_to_mask(polygon: List[Tuple[int, int]], image_shape: Tuple[int, int]) -> np.ndarray:
    """Convert a polygon to a segmentation mask.

    Args:
        polygon (list): List of (x, y) coordinates representing the vertices of the polygon.
        image_shape (tuple): Shape of the image (height, width) for the mask.

    Returns:
        np.ndarray: uint8 mask with the polygon region filled with 255.
    """
    # start from an all-black canvas and rasterize the polygon onto it
    canvas = np.zeros(image_shape, dtype=np.uint8)
    vertices = np.array(polygon, dtype=np.int32)
    cv2.fillPoly(canvas, [vertices], color=(255,))
    return canvas
def load_image(image_str: str) -> Image.Image:
    """Load an image from a URL or file path.

    Args:
        image_str (str): URL or file path to the image.

    Returns:
        PIL.Image: Image object, converted to RGB.
    """
    if image_str.startswith("http"):
        source = requests.get(image_str, stream=True).raw
    else:
        source = image_str
    return Image.open(source).convert("RGB")
def _get_boxes(results: DetectionResult) -> List[List[List[float]]]:
boxes = []
for result in results:
xyxy = result.box.xyxy
boxes.append(xyxy)
return [boxes]
def _refine_masks(masks: torch.BoolTensor, polygon_refinement: bool = False) -> List[np.ndarray]:
masks = masks.cpu().float()
masks = masks.permute(0, 2, 3, 1)
masks = masks.mean(axis=-1)
masks = (masks > 0).int()
masks = masks.numpy().astype(np.uint8)
masks = list(masks)
if polygon_refinement:
for idx, mask in enumerate(masks):
shape = mask.shape
polygon = mask_to_polygon(mask)
mask = polygon_to_mask(polygon, shape)
masks[idx] = mask
return masks
# Shared models, loaded once at import time.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Grounding DINO: zero-shot object detection from text prompts
detector_id = "IDEA-Research/grounding-dino-tiny"
print(f"load object detector pipeline: {detector_id}")
object_detector = pipeline(model=detector_id, task="zero-shot-object-detection", device=device)
# SAM: mask generation prompted with bounding boxes
segmenter_id = "facebook/sam-vit-base"
print(f"load segmentator: {segmenter_id}")
segmentator = AutoModelForMaskGeneration.from_pretrained(segmenter_id).to(device)
print(f"load processor: {segmenter_id}")
processor = AutoProcessor.from_pretrained(segmenter_id)
def detect(image: Image.Image, labels: List[str], threshold: float = 0.3) -> List[Dict[str, Any]]:
    """Use Grounding DINO to detect a set of labels in an image in a zero-shot fashion."""
    # Grounding DINO expects each text query to be terminated by a period
    queries = [label if label.endswith(".") else f"{label}." for label in labels]
    raw_results = object_detector(image, candidate_labels=queries, threshold=threshold)
    return [DetectionResult.from_dict(raw) for raw in raw_results]
def segment(
    image: Image.Image, detection_results: List[Dict[str, Any]], polygon_refinement: bool = False
) -> List[DetectionResult]:
    """Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes."""
    input_boxes = _get_boxes(detection_results)
    inputs = processor(images=image, input_boxes=input_boxes, return_tensors="pt").to(device)
    outputs = segmentator(**inputs)
    raw_masks = processor.post_process_masks(
        masks=outputs.pred_masks, original_sizes=inputs.original_sizes, reshaped_input_sizes=inputs.reshaped_input_sizes
    )[0]
    refined_masks = _refine_masks(raw_masks, polygon_refinement)
    # attach each produced mask to its corresponding detection, in order
    for detection, mask in zip(detection_results, refined_masks):
        detection.mask = mask
    return detection_results
def grounded_segmentation(
    image: Union[Image.Image, str], labels: List[str], threshold: float = 0.3, polygon_refinement: bool = False
) -> Tuple[torch.Tensor, List[DetectionResult]]:
    """Segment out the objects in an image given a set of labels.

    Args:
        image (Union[Image.Image, str]): Image to load/work on.
        labels (List[str]): Object labels to segment.
        threshold (float, optional): Segmentation threshold. Defaults to 0.3.
        polygon_refinement (bool, optional): Use polygon refinement on the segmented mask? Defaults to False.

    Returns:
        Tuple[torch.Tensor, List[DetectionResult]]: Image tensor and list of detection results.
    """
    if isinstance(image, str):
        image = load_image(image)
    detections = detect(image, labels, threshold)
    if not detections:
        # nothing detected: return the image tensor with an empty result list
        return ToTensor()(image), []
    return ToTensor()(image), segment(image, detections, polygon_refinement)

View File

@@ -0,0 +1,271 @@
"""Use the LaMa (*LaMa: Resolution-robust Large Mask Inpainting with Fourier Convolutions*) model to infill the background of an image.
Based on: https://github.com/Sanster/IOPaint/blob/main/iopaint/model/lama.py
"""
import abc
import hashlib
import os
import sys
from enum import Enum
from typing import Optional
from urllib.parse import urlparse
import numpy as np
import torch
from torch.hub import download_url_to_file
# Location and checksum of the big-lama TorchScript checkpoint; both can be
# overridden via environment variables.
LAMA_MODEL_URL = os.environ.get(
    "LAMA_MODEL_URL",
    "https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
)
LAMA_MODEL_MD5 = os.environ.get("LAMA_MODEL_MD5", "e3aa4aaa15225a33ec84f9f4bc47e500")
class LDMSampler(str, Enum):
    """Sampler names for latent-diffusion inpainting."""

    ddim = "ddim"
    plms = "plms"
class SDSampler(str, Enum):
    """Sampler names for Stable-Diffusion inpainting."""

    ddim = "ddim"
    pndm = "pndm"
    k_lms = "k_lms"
    k_euler = "k_euler"
    k_euler_a = "k_euler_a"
    dpm_plus_plus = "dpm++"
    uni_pc = "uni_pc"
def ceil_modulo(x, mod):
    """Round `x` up to the nearest multiple of `mod`."""
    remainder = x % mod
    if remainder == 0:
        return x
    return x + (mod - remainder)
def pad_img_to_modulo(img: np.ndarray, mod: int, square: bool = False, min_size: Optional[int] = None):
    """Symmetrically pad an image so both sides are multiples of `mod`.

    Args:
        img: Image of shape [H, W, C]; a 2-D array gains a trailing channel axis.
        mod: Pad height and width up to a multiple of this value.
        square: Whether to additionally pad to a square output.
        min_size: Minimum output side length; must itself be a multiple of `mod`.

    Returns:
        The padded [H', W', C] array (bottom/right symmetric padding).
    """
    if img.ndim == 2:
        img = img[:, :, np.newaxis]
    height, width = img.shape[:2]
    # round each side up to the next multiple of `mod`
    out_height = -(-height // mod) * mod
    out_width = -(-width // mod) * mod
    if min_size is not None:
        assert min_size % mod == 0
        out_width = max(min_size, out_width)
        out_height = max(min_size, out_height)
    if square:
        out_height = out_width = max(out_height, out_width)
    pad_spec = ((0, out_height - height), (0, out_width - width), (0, 0))
    return np.pad(img, pad_spec, mode="symmetric")
def undo_pad_to_mod(img: np.ndarray, height: int, width: int):
    """Crop a padded [H, W, C] image back to its original height and width."""
    cropped = img[:height, :width, :]
    return cropped
class InpaintModel:
    """Abstract base for inpainting models: pad input, run forward, composite."""

    name = "base"
    # minimum side length the model accepts; inputs are padded up to this if set
    min_size: Optional[int] = None
    # input sides are padded up to a multiple of this value
    pad_mod = 8
    pad_to_square = False

    def __init__(self, device, **kwargs):
        """
        Args:
            device: Device the model runs on; forwarded to `init_model`.
        """
        self.device = device
        self.init_model(device, **kwargs)

    @abc.abstractmethod
    def init_model(self, device, **kwargs): ...

    @staticmethod
    @abc.abstractmethod
    def is_downloaded() -> bool: ...

    @abc.abstractmethod
    def forward(self, image, mask):
        """Input images and output images have same size
        images: [H, W, C] RGB
        masks: [H, W, 1] 255 marks the masked region
        return: BGR IMAGE
        """
        ...

    def _pad_forward(self, image, mask):
        # Pad so the model's modulo/min-size constraints hold, run the model,
        # then crop back and composite the result with the original image.
        origin_height, origin_width = image.shape[:2]
        pad_image = pad_img_to_modulo(image, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size)
        pad_mask = pad_img_to_modulo(mask, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size)
        result = self.forward(pad_image, pad_mask)
        # result = result[0:origin_height, 0:origin_width, :]
        # result, image, mask = self.forward_post_process(result, image, mask)
        mask = mask[:, :, np.newaxis]
        # forward returns a batched tensor; take item 0 as an HWC numpy array
        result = result[0].permute(1, 2, 0).cpu().numpy()  # 400 x 504 x 3
        result = undo_pad_to_mod(result, origin_height, origin_width)
        assert result.shape[:2] == image.shape[:2], f"{result.shape[:2]} != {image.shape[:2]}"
        # result = result * 255
        mask = (mask > 0) * 255
        # Composite: model output inside the mask, original pixels outside.
        # NOTE(review): multiplying by the 0/255 mask also appears to rescale the
        # model output from [0, 1] to [0, 255] — confirm the model's output range.
        result = result * mask + image * (1 - (mask / 255))
        result = np.clip(result, 0, 255).astype("uint8")
        return result

    def forward_post_process(self, result, image, mask):
        # hook for subclasses; the default is a no-op pass-through
        return result, image, mask
def resize_np_img(np_img, size, interpolation="bicubic"):
    """Resize an [H, W, C] uint8 image via torch interpolation.

    Args:
        np_img: Image array of shape [H, W, C], values in 0..255.
        size: Target (height, width), as expected by `torch.nn.functional.interpolate`.
        interpolation: One of "nearest", "bilinear", "bicubic".

    Returns:
        The resized [H', W', C] uint8 array.
    """
    assert interpolation in [
        "nearest",
        "bilinear",
        "bicubic",
    ], f"Unsupported interpolation: {interpolation}, use nearest, bilinear or bicubic."
    torch_img = torch.from_numpy(np_img).permute(2, 0, 1).unsqueeze(0).float() / 255
    # BUG FIX: `align_corners` is only valid for (bi)linear/bicubic modes; the
    # original always passed align_corners=True, so mode="nearest" raised a
    # ValueError despite being an advertised option.
    align = True if interpolation in ("bilinear", "bicubic") else None
    interp_img = torch.nn.functional.interpolate(torch_img, size=size, mode=interpolation, align_corners=align)
    return (interp_img.squeeze(0).permute(1, 2, 0).numpy() * 255).astype(np.uint8)
def resize_max_size(np_img, size_limit: int, interpolation="bicubic") -> np.ndarray:
    """Downscale so the longer side is at most `size_limit`, keeping aspect ratio.

    Args:
        np_img: Image array of shape [H, W, C].
        size_limit: Maximum allowed length of the longer side.
        interpolation: Interpolation mode forwarded to `resize_np_img`.

    Returns:
        The resized array, or the input unchanged if already small enough.
    """
    h, w = np_img.shape[:2]
    if max(h, w) <= size_limit:
        return np_img
    ratio = size_limit / max(h, w)
    new_w = int(w * ratio + 0.5)
    new_h = int(h * ratio + 0.5)
    # BUG FIX: `resize_np_img` forwards `size` to torch interpolate, which takes
    # (height, width); the original passed (new_w, new_h), swapping the axes for
    # every non-square image.
    return resize_np_img(np_img, size=(new_h, new_w), interpolation=interpolation)
def norm_img(np_img):
    """Convert an [H, W(, C)] uint8 image to a [C, H, W] float32 array in [0, 1]."""
    if np_img.ndim == 2:
        # grayscale input: add a trailing channel axis before transposing
        np_img = np_img[:, :, np.newaxis]
    chw = np.transpose(np_img, (2, 0, 1))
    return chw.astype("float32") / 255
def get_cache_path_by_url(url):
    """Return the local cache path for `url`, creating the cache dir if needed.

    Args:
        url: Download URL of the model file.

    Returns:
        Path to `~/.TORCH_HUB_CACHE/checkpoints/<filename>`.
    """
    parts = urlparse(url)
    # BUG FIX: "~" is not expanded automatically by os.path; without expanduser
    # a literal "~" directory would be created relative to the working directory.
    hub_dir = os.path.expanduser("~/.TORCH_HUB_CACHE")
    model_dir = os.path.join(hub_dir, "checkpoints")
    os.makedirs(model_dir, exist_ok=True)
    filename = os.path.basename(parts.path)
    cached_file = os.path.join(model_dir, filename)
    return cached_file
def md5sum(filename):
    """Compute the MD5 hex digest of a file, reading it in chunks."""
    digest = hashlib.md5()
    with open(filename, "rb") as fh:
        for chunk in iter(lambda: fh.read(128 * digest.block_size), b""):
            digest.update(chunk)
    return digest.hexdigest()
def download_model(url, model_md5: str = None):
    """Download `url` into the cache directory (if missing) and verify its MD5.

    Args:
        url: Model download URL.
        model_md5: Expected MD5 hex digest; verification is skipped if None.

    Returns:
        Path to the cached model file.
    """
    cached_file = get_cache_path_by_url(url)
    if not os.path.exists(cached_file):
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        hash_prefix = None
        download_url_to_file(url, cached_file, hash_prefix, progress=True)
        if model_md5:
            _md5 = md5sum(cached_file)
            if model_md5 == _md5:
                print(f"Download model success, md5: {_md5}")
            else:
                # corrupted download: delete it so the next run retries
                try:
                    os.remove(cached_file)
                    print(
                        f"Model md5: {_md5}, expected md5: {model_md5}, wrong model deleted. Please restart"
                        " lama-cleaner.If you still have errors, please try download model manually first"
                        " https://lama-cleaner-docs.vercel.app/install/download_model_manually.\n"
                    )
                # BUG FIX: the bare `except:` also swallowed KeyboardInterrupt and
                # SystemExit; only deletion failures should reach this message.
                except OSError:
                    print(
                        f"Model md5: {_md5}, expected md5: {model_md5}, please delete {cached_file} and restart"
                        " lama-cleaner."
                    )
                exit(-1)
    return cached_file
def load_jit_model(url_or_path, device, model_md5: str):
    """Load a TorchScript model from a local path or URL (downloading if needed).

    The model is loaded onto CPU first, moved to `device`, and set to eval mode.
    """
    model_path = url_or_path if os.path.exists(url_or_path) else download_model(url_or_path, model_md5)
    print(f"Loading model from: {model_path}")
    try:
        model = torch.jit.load(model_path, map_location="cpu").to(device)
    except Exception as e:
        print(f"Error loading model: {e}")
        raise
    model.eval()
    return model
class LaMa(InpaintModel):
    """LaMa large-mask inpainting model, backed by the TorchScript big-lama checkpoint."""

    name = "lama"
    pad_mod = 8

    @torch.no_grad()
    def __call__(self, image, mask):
        """
        images: [H, W, C] RGB, not normalized
        masks: [H, W]
        return: BGR IMAGE
        """
        # boxes = boxes_from_mask(mask)
        inpaint_result = self._pad_forward(image, mask)
        return inpaint_result

    def init_model(self, device, **kwargs):
        # downloads the checkpoint on first use; eval mode disables training behavior
        self.model = load_jit_model(LAMA_MODEL_URL, device, LAMA_MODEL_MD5).eval()

    @staticmethod
    def is_downloaded() -> bool:
        return os.path.exists(get_cache_path_by_url(LAMA_MODEL_URL))

    def forward(self, image, mask):
        """Input image and output image have same size
        image: [H, W, C] RGB
        mask: [H, W]
        return: BGR IMAGE
        """
        image = norm_img(image)
        mask = norm_img(mask)
        # binarize the mask to {0, 1}
        # NOTE(review): `(mask > 0) * 1` yields an integer array, so the mask
        # tensor is int64 — presumably the jit model casts internally; confirm.
        mask = (mask > 0) * 1
        image = torch.from_numpy(image).unsqueeze(0).to(self.device)
        mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
        inpainted_image = self.model(image, mask)
        return inpainted_image

View File

@@ -0,0 +1,65 @@
"""Visual inspection tool: show each image's background next to all its foreground versions."""
import argparse
import os
from random import shuffle

import numpy as np
from matplotlib import pyplot as plt
from PIL import Image

parser = argparse.ArgumentParser(description="Inspect image versions")
parser.add_argument("-f", "--base_folder", type=str, required=True, help="Base folder to inspect")
parser.add_argument("-opt_fg_ratio", type=float, default=0.3, help="Optimal foreground ratio")
args = parser.parse_args()

bg_folder = os.path.join(args.base_folder, "backgrounds")
fg_folder = os.path.join(args.base_folder, "foregrounds")

# Class folders are synset ids like "n01443537"; sort by the numeric offset.
classes = os.listdir(fg_folder)
classes = sorted(classes, key=lambda x: int(x[1:]))
assert len(classes) in [200, 1_000], f"Expected 200 or 1_000 classes, got {len(classes)}"

# Collect unique image stems ("<class>/<name>") across all version suffixes.
total_images = set()
for in_cls in classes:
    cls_images = {
        os.path.join(in_cls, "_".join(img.split(".")[0].split("_")[:3]))
        for img in os.listdir(os.path.join(fg_folder, in_cls))
    }
    total_images.update(cls_images)
total_images = list(total_images)
shuffle(total_images)
print(total_images[:5], "...")

# in_cls to print name/lemma
with open("wordnet_data/tinyimagenet_synset_names.txt", "r") as f:
    in_cls_to_name = {line.split(":")[0].strip(): line.split(":")[1].strip() for line in f.readlines() if len(line) > 2}

for image_name in total_images:
    in_cls, img_name = image_name.split("/")
    versions = {img for img in os.listdir(os.path.join(fg_folder, in_cls)) if img.startswith(img_name)}
    if len(versions) <= 1:
        print(f"Image {image_name} has only one version")
        continue
    versions = sorted(versions)  # sorted() already returns a list
    bg_img = Image.open(os.path.join(bg_folder, in_cls, f"{versions[0].split('.')[0]}.JPEG"))
    # plot all versions at once: background first, then each foreground version
    fig, axs = plt.subplots(1, len(versions) + 1, figsize=(15, 5))
    axs[0].imshow(bg_img)
    axs[0].axis("off")
    axs[0].set_title("Background")
    for version, ax in zip(versions, axs[1:]):
        img = Image.open(os.path.join(fg_folder, in_cls, version))
        # Alpha channel marks the foreground pixels.
        img_mask = np.array(img.convert("RGBA").split()[-1])
        # Foreground area relative to the full background resolution
        # (the foreground crop may be smaller than the background).
        fg_ratio = np.sum(img_mask) / (255 * bg_img.size[0] * bg_img.size[1])
        ax.imshow(img)
        ax.axis("off")
        ax.set_title(f"{version}\nfg_rat: {fg_ratio:.2f} dist to optimal: {abs(fg_ratio - args.opt_fg_ratio):.2f}")
    fig.suptitle(in_cls_to_name[in_cls])
    plt.show()
    plt.close()

View File

@@ -0,0 +1,11 @@
tqdm
einops
omegaconf
diffusers
opencv-python
transformers
accelerate
torchvision
datadings
numpy
nltk

View File

@@ -0,0 +1,20 @@
#!/bin/bash
# Slurm array job: run segment_imagenet.py on the ImageNet *val* split,
# sharded into 30 array tasks (all 30 may run concurrently).
# All /PATH/TO/... and /SBATCH/OUT/... placeholders must be filled in before use.
#SBATCH --array=0-29%30
#SBATCH --time=1-0
#SBATCH --mem=64G
#SBATCH --gpus=1
#SBATCH --partition=H200,H100,H100-PCI,A100-40GB,A100-80GB
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=16
#SBATCH --export="HF_HOME=/PATH/TO/HF_HOME/,NLTK_DATA=/PATH/TO/NLTK_DATA/,TQDM_DISABLE=1"
#SBATCH --job-name="Segment ImageNet (val)"
#SBATCH --output=/SBATCH/OUT/FOLDER/%x-%j-%N-%a.out
#SBATCH --wait
# Run inside a container; -p gives the total shard count and -id maps this
# array task to its shard. Extra CLI args are forwarded via "$@".
srun -K \
    --container-image=/PATH/TO.sqsh \
    --container-workdir="$(pwd)" \
    --container-mounts=/SET/CONTAINER/MOUNTS,"$(pwd)":"$(pwd)" \
    python3 segment_imagenet.py -p 30 -id $SLURM_ARRAY_TASK_ID -o /OUTPUT/FOLDER/PATH/ "$@"

View File

@@ -0,0 +1,20 @@
#!/bin/bash
# Slurm array job: run segment_imagenet.py on the ImageNet *train* split,
# sharded into 320 array tasks with at most 60 running concurrently.
# All /PATH/TO/... and /OUTPUT_PATH/... placeholders must be filled in before use.
#SBATCH --array=0-319%60
#SBATCH --time=1-0
#SBATCH --mem=64G
#SBATCH --gpus=1
#SBATCH --partition=H200,H100,H100-PCI,A100-PCI,A100-40GB,A100-80GB
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=16
#SBATCH --export="HF_HOME=/PATH/TO/HF_HOME/,NLTK_DATA=/PATH/TO/NLTK_DATA/,TQDM_DISABLE=1"
#SBATCH --job-name="Segment ImageNet (train)"
#SBATCH --output=/SBATCH/OUT/FOLDER/%x-%j-%N-%a.out
#SBATCH --wait
# -p/-id select this task's shard of the dataset; extra args forwarded via "$@".
srun -K \
    --container-image=/PATH/TO.sqsh \
    --container-workdir="$(pwd)" \
    --container-mounts=/SET/CONTAINER/MOUNTS,"$(pwd)":"$(pwd)" \
    python3 segment_imagenet.py -p 320 -id $SLURM_ARRAY_TASK_ID -o /OUTPUT_PATH/INSegment/ "$@"

View File

@@ -0,0 +1,19 @@
#!/bin/bash
# Slurm array job: run segment_imagenet.py in 5 shards (no --wait, so sbatch
# returns immediately). Placeholder paths must be filled in before use.
#SBATCH --array=0-4%10
#SBATCH --time=1-0
#SBATCH --mem=64G
#SBATCH --gpus=1
#SBATCH --partition=RTX3090,RTXA6000,A100-PCI,A100-40GB,H100-PCI
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=16
#SBATCH --export="HF_HOME=/PATH/TO/HF_HOME/,NLTK_DATA=/PATH/TO/NLTK_DATA/,TQDM_DISABLE=1"
#SBATCH --job-name="Segment ImageNet"
#SBATCH --output=/SBATCH/OUT/FOLDER/%x-%j-%N-%a.out
# -p/-id select this task's shard; extra args forwarded via "$@".
srun -K \
    --container-image=/PATH/TO.sqsh \
    --container-workdir="$(pwd)" \
    --container-mounts=/SET/CONTAINER/MOUNTS,"$(pwd)":"$(pwd)" \
    python3 segment_imagenet.py -p 5 -id $SLURM_ARRAY_TASK_ID -o /PATH/TO/OUT/FOLDER/INSegment/ "$@"

View File

@@ -0,0 +1,19 @@
#!/bin/bash
# Slurm array job: run segment_imagenet.py on TinyImageNet in 20 shards.
# No -o flag: the script's own default output directory is used.
# Placeholder paths must be filled in before use.
#SBATCH --array=0-19%20
#SBATCH --time=1-0
#SBATCH --mem=64G
#SBATCH --gpus=1
#SBATCH --partition=RTX3090,RTXA6000,A100-PCI,A100-40GB,H100-PCI
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=16
#SBATCH --export="HF_HOME=/PATH/TO/HF_HOME/,NLTK_DATA=/PATH/TO/NLTK_DATA/,TQDM_DISABLE=1"
#SBATCH --job-name="Segment TinyImageNet"
#SBATCH --output=/SBATCH/OUT/FOLDER/%x-%j-%N-%a.out
# -p/-id select this task's shard; extra args forwarded via "$@".
srun -K \
    --container-image=/PATH/TO.sqsh \
    --container-workdir="$(pwd)" \
    --container-mounts=/SET/CONTAINER/MOUNTS,"$(pwd)":"$(pwd)" \
    python3 segment_imagenet.py -p 20 -id $SLURM_ARRAY_TASK_ID "$@"

View File

@@ -0,0 +1,20 @@
#!/bin/bash
# Slurm array job: run segment_imagenet.py on TinyImageNet in 30 shards;
# --wait keeps sbatch blocked until all array tasks finish.
# Placeholder paths must be filled in before use.
#SBATCH --array=0-29%30
#SBATCH --time=1-0
#SBATCH --mem=64G
#SBATCH --gpus=1
#SBATCH --partition=H200,H100,H100-PCI,A100-PCI,A100-40GB,A100-80GB
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=16
#SBATCH --export="HF_HOME=/PATH/TO/HF_HOME/,NLTK_DATA=/PATH/TO/NLTK_DATA/,TQDM_DISABLE=1"
#SBATCH --job-name="Segment TinyImageNet"
#SBATCH --output=/SBATCH/OUT/FOLDER/%x-%j-%N-%a.out
#SBATCH --wait
# -p/-id select this task's shard; extra args forwarded via "$@".
srun -K \
    --container-image=/PATH/TO.sqsh \
    --container-workdir="$(pwd)" \
    --container-mounts=/SET/CONTAINER/MOUNTS,"$(pwd)":"$(pwd)" \
    python3 segment_imagenet.py -p 30 -id $SLURM_ARRAY_TASK_ID "$@"

View File

@@ -0,0 +1,338 @@
import argparse
import json
import os
from datetime import datetime
from math import ceil
import numpy as np
import torch
from datadings.reader import MsgpackReader
from datadings.torch import Compose, CompressedToPIL, Dataset
from PIL import Image, ImageFilter
from torchvision.transforms import ToPILImage, ToTensor
from tqdm.auto import tqdm
from attentive_eraser import AttentiveEraser
from grounded_segmentation import grounded_segmentation
from infill_lama import LaMa
from utils import already_segmented, save_img
from wordnet_tree import WNTree
def _collate_single(data):
return data[0]
def _collate_multiple(datas):
return {
"images": [data["image"] for data in datas],
"keys": [data["key"] for data in datas],
"labels": [data["label"] for data in datas],
}
def smallest_crop(image: torch.Tensor, mask: torch.Tensor):
    """Crops the image to just so fit the mask given.

    Mask and image have to be of the same size.

    Args:
        image (torch.Tensor): image to crop
        mask (torch.Tensor): cropping to mask

    Returns:
        torch.Tensor: cropped image
    """
    if len(mask.shape) == 3:
        assert mask.shape[0] == 1
        mask = mask[0]
    assert (
        len(image.shape) == 3 and len(mask.shape) == 2 and image.shape[1:] == mask.shape
    ), f"Invalid shapes: {image.shape}, {mask.shape}"
    # Nonzero row/column sums delimit the tight bounding box of the mask.
    col_hits = mask.sum(dim=0).nonzero()
    row_hits = mask.sum(dim=1).nonzero()
    top = row_hits.min().item()
    bottom = row_hits.max().item() + 1
    left = col_hits.min().item()
    right = col_hits.max().item() + 1
    return image[:, top:bottom, left:right]
def _synset_to_prompt(synset, tree, parent_in_prompt):
prompt = f"an {synset.print_name}." if synset.print_name[0] in "aeiou" else f"a {synset.print_name}."
if synset.parent_id is None or not parent_in_prompt:
return prompt
parent = tree[synset.parent_id]
return prompt[:-1] + f", a type of {parent.print_name}."
# Main pipeline: for every dataset sample, detect + segment the labeled object,
# split it into a cropped foreground (with alpha) and an infilled background,
# and save both. Designed to run as one shard (-id) of -p parallel processes.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Grounded Segmentation")
    parser.add_argument(
        "-d",
        "--dataset",
        choices=["imagenet", "imagenet-val", "imagenet21k", "tinyimagenet", "tinyimagenet-val"],
        default="imagenet",
        help="Dataset to use",
    )
    parser.add_argument("-r", "--dataset_root", required=True, help="Root directory of the dataset")
    parser.add_argument("-b", "--batch_size", type=int, default=1, help="Batch size")
    parser.add_argument("-t", "--threshold", type=float, default=0.3, help="Detection threshold")
    parser.add_argument("--debug", action="store_true", help="Debug mode")
    parser.add_argument("-o", "--output", type=str, required=True, help="Output directory")
    parser.add_argument(
        "-p", "--processes", type=int, default=1, help="Number of processes that are used to process the data"
    )
    parser.add_argument("-id", type=int, default=0, help="ID of this process)")
    parser.add_argument("--overwrite", action="store_true", help="Overwrite existing files, instead of skipping them")
    parser.add_argument(
        "--parent_labels", type=int, default=2, help="Number of parent labels to use; steps to go up the tree"
    )
    parser.add_argument("--output_ims", choices=["best", "all"], default="best", help="Output all or best masks")
    parser.add_argument("--mask_merge_threshold", type=float, default=0.9, help="Threshold on IoU for merging masks")
    parser.add_argument("--parent_in_prompt", action="store_true", help="Include parent label in the prompt")
    parser.add_argument(
        "-model",
        choices=["LaMa", "AttErase"],
        default="LaMa",
        help="Model to use for erasing/infilling. Defaults to LaMa.",
    )
    args = parser.parse_args()
    dataset = args.dataset.lower()
    part = "train"
    # Select the msgpack shard matching the requested dataset/split; "-val"
    # variants also switch the output subfolder to "val".
    if dataset == "imagenet21k":
        reader = MsgpackReader(f"{args.dataset_root}imagenet21k/train.msgpack")
        dataset = Dataset(reader, transforms={"image": Compose([CompressedToPIL()])})
    elif dataset == "imagenet-val":
        reader = MsgpackReader(f"{args.dataset_root}imagenet/msgpack/val.msgpack")
        dataset = Dataset(reader, transforms={"image": Compose([CompressedToPIL()])})
        part = "val"
    elif dataset == "imagenet":
        reader = MsgpackReader(f"{args.dataset_root}imagenet/msgpack/train.msgpack")
        dataset = Dataset(reader, transforms={"image": Compose([CompressedToPIL()])})
    elif dataset == "tinyimagenet":
        reader = MsgpackReader(f"{args.dataset_root}TinyINSegment/TinyIN_train.msgpack")
        dataset = Dataset(reader, transforms={"image": Compose([CompressedToPIL()])})
    elif dataset == "tinyimagenet-val":
        reader = MsgpackReader(f"{args.dataset_root}TinyINSegment/TinyIN_val.msgpack")
        dataset = Dataset(reader, transforms={"image": Compose([CompressedToPIL()])})
        part = "val"
    else:
        raise ValueError(f"Unknown dataset: {dataset}")
    assert 0 <= args.id < args.processes, "ID must be in the range [0, processes)"
    # Shard the dataset into `processes` contiguous parts; this job handles part `id`.
    if args.processes > 1:
        partlen = ceil(len(dataset) / args.processes)
        start_id = args.id * partlen
        end_id = min((args.id + 1) * partlen, len(dataset))
        dataset = torch.utils.data.Subset(dataset, list(range(start_id, end_id)))
    assert args.batch_size == 1, "Batch size must be 1 for grounded segmentation (for now)"
    dataset = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=args.debug,
        drop_last=False,
        num_workers=10,
        collate_fn=_collate_single if args.batch_size == 1 else _collate_multiple,
    )
    infill_model = (
        LaMa(device="cuda" if torch.cuda.is_available() else "cpu") if args.model == "LaMa" else AttentiveEraser()
    )
    # Map integer dataset labels to wordnet synset offsets.
    if args.dataset.startswith("imagenet"):
        with open("wordnet_data/imagenet1k_synsets.json", "r") as f:
            id_to_synset = json.load(f)
        id_to_synset = {int(k): v for k, v in id_to_synset.items()}
    elif args.dataset.startswith("tinyimagenet"):
        with open("wordnet_data/tinyimagenet_synset_names.txt", "r") as f:
            synsets = f.readlines()
        id_to_synset = [int(synset.split(":")[0].strip()[1:]) for synset in synsets]
        id_to_synset = sorted(id_to_synset)
    wordnet = WNTree.load("wordnet_data/imagenet21k+1k_masses_tree.json")
    # create the necessary folders
    os.makedirs(os.path.join(args.output, part, "foregrounds"), exist_ok=True)
    os.makedirs(os.path.join(args.output, part, "backgrounds"), exist_ok=True)
    os.makedirs(os.path.join(args.output, part, "no_detect"), exist_ok=True)
    os.makedirs(os.path.join(args.output, part, "error"), exist_ok=True)
    # When tqdm is disabled (batch jobs), emit plain progress lines instead.
    man_print_progress = int(os.environ.get("TQDM_DISABLE", "0")) == 1
    for i, data in tqdm(enumerate(dataset), total=len(dataset)):
        if args.debug and i > 10:
            break
        if man_print_progress and i % 200 == 0:
            print(f"{datetime.now()} \t process {args.id}/{args.processes}: \t sample: {i}/{len(dataset)}", flush=True)
        image = data["image"]
        key = data["key"]
        if args.dataset.endswith("-val"):
            img_class = f"n{id_to_synset[data['label']]:08d}"  # overwrite the synset id on the val set
        else:
            img_class = key.split("/")[-1].split("_")[0]
        # Skip samples that already have a saved foreground (unless --overwrite).
        if (
            already_segmented(key, os.path.join(args.output, part, "foregrounds"), img_class=img_class)
            and not args.overwrite
        ):
            continue
        # get the next 3 labels up in the wordnet tree
        label_id = data["label"]
        offset = id_to_synset[label_id]  # int(data["key"].split("_")[0][1:])
        synset = wordnet[offset]
        labels = [_synset_to_prompt(synset, wordnet, args.parent_in_prompt)]
        while len(labels) < 1 + args.parent_labels and synset.parent_id is not None:
            synset = wordnet[synset.parent_id]
            _label = _synset_to_prompt(synset, wordnet, args.parent_in_prompt)
            _label = _label.replace("_", " ")
            labels.append(_label)
        assert len(labels) > 0, f"No labels for {key}"
        # 1. detect and segment
        image_tensor, detections = grounded_segmentation(
            image, labels, threshold=args.threshold, polygon_refinement=True
        )
        # merge all detections for the same label
        masks = {}
        for detection in detections:
            if detection.label not in masks:
                masks[detection.label] = detection.mask
            else:
                masks[detection.label] |= detection.mask
        labels = list(masks.keys())
        masks = [masks[lbl] for lbl in labels]
        assert len(masks) == len(
            labels
        ), f"Number of masks ({len(masks)}) != number of labels ({len(labels)}); detections: {detections}"
        if args.debug:
            for detected_mask, label in zip(masks, labels):
                detected_mask = torch.from_numpy(detected_mask).unsqueeze(0) / 255
                mask_foreground = torch.cat((image_tensor * detected_mask, detected_mask), dim=0)
                ToPILImage()(mask_foreground).save(f"example_images/{key}_{label}_masked.png")
        # Choose the mask(s) to keep: none -> record under no_detect; exactly one
        # -> keep it; "best" -> union of the two most-overlapping per-label masks;
        # "all" -> keep all, merging near-duplicates by IoU.
        if len(masks) == 0:
            tqdm.write(f"No detections for {key}; skipping")
            save_img(image, key, os.path.join(args.output, part, "no_detect"), img_class=img_class, format="JPEG")
            continue
        elif len(masks) == 1:
            # max_overlap_mask = masks[0]
            masks = [masks[0]]
        elif args.output_ims == "best":
            # find the 2 masks with the largest overlap
            max_overlap = 0
            max_overlap_mask = None
            using_labels = None
            # NOTE(review): this loop index shadows the outer sample index `i`;
            # harmless today (the outer for reassigns it), but worth renaming.
            for i, mask1 in enumerate(masks):
                for j, mask2 in enumerate(masks):
                    if i >= j:
                        continue
                    iou = (mask1 & mask2).sum() / (mask1 | mask2).sum()
                    if iou > max_overlap:
                        max_overlap = iou
                        max_overlap_mask = mask1 | mask2
                        using_labels = f"{labels[i]} & {labels[j]}"
            if args.debug:
                tqdm.write(f"{key}:\tMax overlap: {max_overlap}, using {using_labels}")
            if max_overlap_mask is None:
                # assert len(masks) > 0 and len(masks) == len(labels), f"No detections for {key}"
                max_overlap_mask = masks[0]
            masks = [max_overlap_mask]
        else:
            # merge masks that are too similar
            has_changed = True
            while has_changed:
                has_changed = False
                for i, mask1 in enumerate(masks):
                    for j, mask2 in enumerate(masks):
                        if i >= j:
                            continue
                        iou = (mask1 & mask2).sum() / (mask1 | mask2).sum()
                        if iou > args.mask_merge_threshold:
                            # Restart the scan after each merge since indices shift.
                            masks[i] |= masks[j]
                            masks.pop(j)
                            labels.pop(j)
                            has_changed = True
                            break
                    if has_changed:
                        break
        for mask_idx, mask_array in enumerate(masks):
            mask = torch.from_numpy(mask_array).unsqueeze(0) / 255
            if args.debug:
                mask_image = ToPILImage()(mask)
                mask_image.save(f"example_images/{key}_mask.png")
            # foreground = image_tensor
            # Foreground: RGBA image where the mask becomes the alpha channel,
            # then cropped to the mask's bounding box.
            foreground = torch.cat((image_tensor * mask, mask), dim=0)
            foreground = ToPILImage()(foreground)
            mask_img = foreground.split()[-1]
            foreground = smallest_crop(ToTensor()(foreground), ToTensor()(mask_img))  # TODO: fix smallest crop
            fg_img = ToPILImage()(foreground)
            background = (1 - mask) * image_tensor
            bg_image = ToPILImage()(background)
            # 2. infill background
            # Blur the mask edge so the infill blends with the surroundings.
            mask_image = Image.fromarray(mask_array)
            mask_image = mask_image.filter(ImageFilter.GaussianBlur(radius=7))
            try:
                infilled_bg = infill_model(np.array(bg_image), np.array(mask_image))
                infilled_bg = Image.fromarray(np.uint8(infilled_bg))
            except RuntimeError as e:
                tqdm.write(
                    f"Error infilling {key}: bg_image.shape={np.array(bg_image).shape},"
                    f" mask_image.shape={np.array(mask_image).shape}\n{e}"
                )
                # Keep the original image under "error" so failures can be retried.
                save_img(
                    image,
                    key,
                    os.path.join(args.output, part, "error"),
                    img_class=img_class,
                    img_version=mask_idx if len(masks) > 1 else None,
                    format="JPEG",
                )
                infilled_bg = None
            if args.debug:
                # NOTE(review): if infilling failed above, infilled_bg is None here
                # and the .save call raises — debug mode assumes infilling succeeded.
                data["image"].save(f"example_images/{key}_orig.png")
                fg_img.save(f"example_images/{key}_fg.png")
                infilled_bg.save(f"example_images/{key}_bg.png")
            else:
                # save files
                class_name = key.split("_")[0]  # NOTE(review): unused
                save_img(
                    fg_img,
                    key,
                    os.path.join(args.output, part, "foregrounds"),
                    img_class=img_class,
                    img_version=mask_idx if len(masks) > 1 else None,
                    format="WEBP",
                )
                if infilled_bg is not None:
                    save_img(
                        infilled_bg,
                        key,
                        os.path.join(args.output, part, "backgrounds"),
                        img_class=img_class,
                        img_version=mask_idx if len(masks) > 1 else None,
                        format="JPEG",
                    )
    if man_print_progress:
        print(
            f"{datetime.now()} \t process {args.id}/{args.processes}: \t done with all {len(dataset)} samples",
            flush=True,
        )

View File

@@ -0,0 +1,15 @@
#!/bin/bash
# Interactive/single-job launcher: run segment_imagenet.py once on one GPU node
# inside a container (no array sharding). Extra CLI args forwarded via "$@".
# Placeholder paths must be filled in before use.
srun -K \
    --partition=H200,H100,A100-40GB,A100-80GB \
    --job-name="Segment ImageNet" \
    --nodes=1 \
    --gpus=1 \
    --cpus-per-task=16 \
    --mem=64G \
    --time=1-0 \
    --export=HF_HOME=/PATH/TO/HF_HOME/,NLTK_DATA=/PATH/TO/NLTK_DATA/ \
    --container-image=/PATH/TO.sqsh \
    --container-workdir="$(pwd)" \
    --container-mounts=/SET/CONTAINER/MOUNTS,"$(pwd)":"$(pwd)" \
    python3 segment_imagenet.py "$@"

View File

@@ -0,0 +1,13 @@
#!/bin/bash
# Generic container wrapper: run an arbitrary command ("$@") inside the project
# container on a cluster node (no GPU requested). Placeholder paths must be
# filled in before use.
srun -K \
    --partition=RTX3090,A100-40GB,H100 \
    --job-name="Segment ImageNet" \
    --nodes=1 \
    --mem=64G \
    --time=1-0 \
    --export=HF_HOME=/PATH/TO/HF_HOME/,NLTK_DATA=/PATH/TO/NLTK_DATA/,MAX_WORKERS=16 \
    --container-image=/PATH/TO.sqsh \
    --container-workdir="$(pwd)" \
    --container-mounts=/SET/CONTAINER/MOUNTS,"$(pwd)":"$(pwd)" \
    "$@"

View File

@@ -0,0 +1,49 @@
import os
from PIL import Image
def save_img(img: "Image", img_name: str, base_dir: str, img_class: str = None, format="PNG", img_version=None):
    """Save an image to a directory.

    Args:
        img (PIL.Image): Image to save.
        img_name (str): Relative path to the image.
        base_dir (str): Base directory to save images in.
        img_class (str, optional): Image class, if not given try to extract it from the image name in ImageNet train format. Defaults to None.
        format (str, optional): Format to save the image in. Defaults to "PNG".
        img_version (int, optional): Version of the image. Will be appended to the path. Defaults to None.
    """
    # Force the file extension to match the requested format.
    if not img_name.endswith(f".{format}"):
        img_name = f"{img_name.split('.')[0]}.{format}"
    if img_class is None:
        img_class = img_name.split("_")[0]
    # exist_ok=True makes a separate existence check redundant (and avoids the
    # check-then-create race the old code had).
    os.makedirs(os.path.join(base_dir, img_class), exist_ok=True)
    if img_version is not None:
        img_name = f"{img_name.split('.')[0]}_v{img_version}.{format}"
    img.save(os.path.join(base_dir, img_class, img_name), format.lower())
def already_segmented(img_name: str, base_dir: str, img_class: str = None):
    """Check if an image was already segmented.

    Args:
        img_name (str): Relative path to the image.
        base_dir (str): Base directory to save images in.
        img_class (str, optional): Image class, if not given try to extract it from the image name in ImageNet train format. Defaults to None.

    Returns:
        bool: Image was segmented already.
    """
    # Strip only the last extension, if any.
    stem = img_name.rsplit(".", 1)[0] if "." in img_name else img_name
    cls = img_class if img_class is not None else img_name.split("_")[0]
    cls_dir = os.path.join(base_dir, cls)
    if not os.path.exists(cls_dir):
        return False
    # A match is either a versioned file ("<stem>_v0.…") or a plain "<stem>.…".
    prefixes = (stem + "_v", stem + ".")
    return any(entry.startswith(prefixes) for entry in os.listdir(cls_dir))

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,200 @@
n07695742: pretzel
n03902125: pay-phone, pay-station
n03980874: poncho
n01644900: tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui
n02730930: apron
n02699494: altar
n03201208: dining table, board
n02056570: king penguin, Aptenodytes patagonica
n04099969: rocking chair, rocker
n04366367: suspension bridge
n04067472: reel
n02808440: bathtub, bathing tub, bath, tub
n04540053: volleyball
n02403003: ox
n03100240: convertible
n04562935: water tower
n02788148: bannister, banister, balustrade, balusters, handrail
n02988304: CD player
n02423022: gazelle
n03637318: lampshade, lamp shade
n01774384: black widow, Latrodectus mactans
n01768244: trilobite
n07614500: ice cream, icecream
n04254777: sock
n02085620: Chihuahua
n01443537: goldfish, Carassius auratus
n01629819: European fire salamander, Salamandra salamandra
n02099601: golden retriever
n02321529: sea cucumber, holothurian
n03837869: obelisk
n02002724: black stork, Ciconia nigra
n02841315: binoculars, field glasses, opera glasses
n04560804: water jug
n02364673: guinea pig, Cavia cobaya
n03706229: magnetic compass
n09256479: coral reef
n09332890: lakeside, lakeshore
n03544143: hourglass
n02124075: Egyptian cat
n02948072: candle, taper, wax light
n01950731: sea slug, nudibranch
n02791270: barbershop
n03179701: desk
n02190166: fly
n04275548: spider web, spider's web
n04417672: thatch, thatched roof
n03930313: picket fence, paling
n02236044: mantis, mantid
n03976657: pole
n01774750: tarantula
n04376876: syringe
n04133789: sandal
n02099712: Labrador retriever
n04532670: viaduct
n04487081: trolleybus, trolley coach, trackless trolley
n09428293: seashore, coast, seacoast, sea-coast
n03160309: dam, dike, dyke
n03250847: drumstick
n02843684: birdhouse
n07768694: pomegranate
n03670208: limousine, limo
n03085013: computer keyboard, keypad
n02892201: brass, memorial tablet, plaque
n02233338: cockroach, roach
n03649909: lawn mower, mower
n03388043: fountain
n02917067: bullet train, bullet
n02486410: baboon
n04596742: wok
n03255030: dumbbell
n03937543: pill bottle
n02113799: standard poodle
n03977966: police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria
n02906734: broom
n07920052: espresso
n01698640: American alligator, Alligator mississipiensis
n02123394: Persian cat
n03424325: gasmask, respirator, gas helmet
n02129165: lion, king of beasts, Panthera leo
n04008634: projectile, missile
n03042490: cliff dwelling
n02415577: bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis
n02815834: beaker
n02395406: hog, pig, grunter, squealer, Sus scrofa
n01784675: centipede
n03126707: crane
n04399382: teddy, teddy bear
n07875152: potpie
n03733131: maypole
n02802426: basketball
n03891332: parking meter
n01910747: jellyfish
n03838899: oboe, hautboy, hautbois
n03770439: miniskirt, mini
n02281406: sulphur butterfly, sulfur butterfly
n03970156: plunger, plumber's helper
n09246464: cliff, drop, drop-off
n02206856: bee
n02074367: dugong, Dugong dugon
n03584254: iPod
n04179913: sewing machine
n04328186: stopwatch, stop watch
n07583066: guacamole
n01917289: brain coral
n03447447: gondola
n02823428: beer bottle
n03854065: organ, pipe organ
n02793495: barn
n04285008: sports car, sport car
n02231487: walking stick, walkingstick, stick insect
n04465501: tractor
n02814860: beacon, lighthouse, beacon light, pharos
n02883205: bow tie, bow-tie, bowtie
n02165456: ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle
n04149813: scoreboard
n04023962: punching bag, punch bag, punching ball, punchball
n02226429: grasshopper, hopper
n02279972: monarch, monarch butterfly, milkweed butterfly, Danaus plexippus
n02669723: academic gown, academic robe, judge's robe
n04486054: triumphal arch
n04070727: refrigerator, icebox
n03444034: go-kart
n02666196: abacus
n01945685: slug
n04251144: snorkel
n03617480: kimono
n03599486: jinrikisha, ricksha, rickshaw
n02437312: Arabian camel, dromedary, Camelus dromedarius
n01984695: spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish
n04118538: rugby ball
n01770393: scorpion
n04356056: sunglasses, dark glasses, shades
n03804744: nail
n02132136: brown bear, bruin, Ursus arctos
n03400231: frying pan, frypan, skillet
n03983396: pop bottle, soda bottle
n07734744: mushroom
n02480495: orangutan, orang, orangutang, Pongo pygmaeus
n02410509: bison
n03404251: fur coat
n04456115: torch
n02123045: tabby, tabby cat
n03026506: Christmas stocking
n07715103: cauliflower
n04398044: teapot
n02927161: butcher shop, meat market
n07749582: lemon
n07615774: ice lolly, lolly, lollipop, popsicle
n02795169: barrel, cask
n04532106: vestment
n02837789: bikini, two-piece
n02814533: beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon
n04265275: space heater
n02481823: chimpanzee, chimp, Pan troglodytes
n02509815: lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens
n06596364: comic book
n01983481: American lobster, Northern lobster, Maine lobster, Homarus americanus
n02504458: African elephant, Loxodonta africana
n03014705: chest
n01944390: snail
n04146614: school bus
n01641577: bullfrog, Rana catesbeiana
n07720875: bell pepper
n02999410: chain
n01855672: goose
n02125311: cougar, puma, catamount, mountain lion, painter, panther, Felis concolor
n07753592: banana
n07871810: meat loaf, meatloaf
n04501370: turnstile
n04311004: steel arch bridge
n02977058: cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM
n04074963: remote control, remote
n03662601: lifeboat
n02106662: German shepherd, German shepherd dog, German police dog, alsatian
n03089624: confectionery, confectionary, candy store
n04259630: sombrero
n03393912: freight car
n04597913: wooden spoon
n07711569: mashed potato
n03355925: flagpole, flagstaff
n02963159: cardigan
n07579787: plate
n02950826: cannon
n01882714: koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus
n02094433: Yorkshire terrier
n02909870: bucket, pail
n02058221: albatross, mollymawk
n01742172: boa constrictor, Constrictor constrictor
n09193705: alp
n04371430: swimming trunks, bathing trunks
n07747607: orange
n03814639: neck brace
n04507155: umbrella
n02268443: dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk
n02769748: backpack, back pack, knapsack, packsack, rucksack, haversack
n03763968: military uniform
n07873807: pizza, pizza pie
n03992509: potter's wheel
n03796401: moving van
n12267677: acorn

View File

@@ -0,0 +1,400 @@
import argparse
import contextlib
import json
import os
from copy import copy
from nltk.corpus import wordnet as wn
class bcolors:
    """Colors for terminal output."""

    # ANSI escape sequences used to colorize the tree printouts
    # (see WNEntry.__str__ / tree_diff in this file). ENDC resets formatting.
    HEADER = "\033[95m"
    OKBLUE = "\033[94m"
    OKCYAN = "\033[96m"
    OKGREEN = "\033[92m"
    WARNING = "\033[93m"
    FAIL = "\033[91m"
    ENDC = "\033[0m"
    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"
def _lemmas_str(synset):
return ", ".join([lemma.name() for lemma in synset.lemmas()])
class WNEntry:
    """One wordnet synset."""

    def __init__(
        self,
        name: str,
        id: int,
        lemmas: str,
        parent_id: int,
        depth: int = None,
        in_image_net: bool = False,
        child_ids: list = None,
        in_main_tree: bool = True,
        _n_images: int = 0,
        _description: str = None,
        _name: str = None,
        _pruned: bool = False,
    ):
        self.name = name  # wordnet synset name, e.g. "dog.n.01"
        self.id = id  # wordnet offset; also the key in WNTree.nodes
        self.lemmas = lemmas  # comma-separated lemma string (see _lemmas_str)
        self.parent_id = parent_id  # hypernym node id; None for the root
        self.depth = depth
        self.in_image_net = in_image_net
        self.child_ids = child_ids  # hyponym ids; None (or empty) for leaves
        self.in_main_tree = in_main_tree
        self._n_images = _n_images  # images assigned directly to this node (children excluded)
        self._description = _description  # lazy cache for the `description` property
        self._name = _name
        self._pruned = _pruned  # set by prune(); pruned nodes hand their images to an ancestor

    def __str__(self, tree=None, accumulate=True, colors=True, max_depth=0, max_children=None):
        """Render this node (and, given a tree and max_depth > 0, its subtree)
        as an indented text tree; +/- marks ImageNet membership."""
        green = f"{bcolors.OKGREEN}" if colors else ""
        red = f"{bcolors.FAIL}" if colors else ""
        end = f"{bcolors.ENDC}" if colors else ""
        start_symb = f"{green}+{end}" if self.in_image_net else f"{red}-{end}"
        # With a tree and accumulate=True, show "direct of Σ subtree-total".
        n_ims = f"{self._n_images} of Σ {self.n_images(tree)}" if accumulate and tree is not None else self._n_images
        if self.child_ids is None or tree is None or max_depth == 0:
            return f"{start_symb}{self.name} ({self.id}) > {n_ims}"
        children = self.child_ids
        if max_children is not None and len(children) > max_children:
            children = children[:max_children]
        # Recurse into children and indent each of their lines by one space.
        return f"{start_symb}{self.name} ({self.id}) > {n_ims}\n " + "\n ".join(
            [
                "\n ".join(
                    tree.nodes[child_id]
                    .__str__(tree=tree, accumulate=accumulate, colors=colors, max_depth=max_depth - 1)
                    .split("\n")
                )
                for child_id in children
            ]
        )

    def tree_diff(self, tree_1, tree_2):
        """Render the subtree with per-node image-count differences between two
        trees (+/-/= prefix shows whether tree_2 gained/lost/kept images)."""
        if tree_2[self.id]._n_images > tree_1[self.id]._n_images:
            start_symb = f"{bcolors.OKGREEN}+{bcolors.ENDC}"
        elif tree_2[self.id]._n_images < tree_1[self.id]._n_images:
            start_symb = f"{bcolors.FAIL}-{bcolors.ENDC}"
        else:
            start_symb = f"{bcolors.OKBLUE}={bcolors.ENDC}"
        n_ims = (
            f"{tree_1[self.id]._n_images} + {tree_2[self.id]._n_images - tree_1[self.id]._n_images} of Σ"
            f" {tree_1[self.id].n_images(tree_2)}/{tree_2[self.id].n_images(tree_2)}"
        )
        if self.child_ids is None:
            return f"{start_symb}{self.name} ({self.id}) > {n_ims}"
        return f"{start_symb}{self.name} ({self.id}) > {n_ims}\n " + "\n ".join(
            ["\n ".join(tree_1.nodes[child_id].tree_diff(tree_1, tree_2).split("\n")) for child_id in self.child_ids]
        )

    def prune(self, tree):
        """Remove this node (and its whole subtree) from *tree*, reassigning its
        direct image count to the nearest non-pruned ancestor. The root
        (parent_id is None) is never pruned."""
        if self._pruned or self.parent_id is None:
            return
        if self.child_ids is not None:
            for child_id in self.child_ids:
                tree[child_id].prune(tree)
        self._pruned = True
        parent_node = tree.nodes[self.parent_id]
        try:
            parent_node.child_ids.remove(self.id)
        except ValueError as e:
            # Already detached from the parent's child list; report and continue.
            print(
                f"Error removing {self.name} from"
                f" {parent_node.name} ({[tree[cid].name for cid in parent_node.child_ids]}): {e}"
            )
        # Walk up to the first ancestor that was not itself pruned.
        while parent_node._pruned:
            parent_node = tree.nodes[parent_node.parent_id]
        parent_node._n_images += self._n_images
        self._n_images = 0

    @property
    def description(self):
        # Lazily fetched from nltk's wordnet corpus and cached.
        if not self._description:
            self._description = wn.synset_from_pos_and_offset("n", self.id).definition()
        return self._description

    @property
    def print_name(self):
        # "dog.n.01" -> "dog"
        return self.name.split(".")[0]

    @property
    def is_leaf(self):
        return self.child_ids is None or len(self.child_ids) == 0

    def get_branch(self, tree=None):
        """Return the root-to-this-node path as a " > "-joined string."""
        if self.parent_id is None or tree is None:
            return self.print_name
        parent = tree.nodes[self.parent_id]
        return parent.get_branch(tree) + " > " + self.print_name

    def get_branch_list(self, tree):
        """Return the root-to-this-node path as a list of WNEntry objects."""
        if self.parent_id is None:
            return [self]
        parent = tree.nodes[self.parent_id]
        return parent.get_branch_list(tree) + [self]

    def to_dict(self):
        """Serialize to a plain dict (inverse of from_dict)."""
        return {
            "name": self.name,
            "id": self.id,
            "lemmas": self.lemmas,
            "parent_id": self.parent_id,
            "depth": self.depth,
            "in_image_net": self.in_image_net,
            "child_ids": self.child_ids,
            "in_main_tree": self.in_main_tree,
            "_n_images": self._n_images,
            "_description": self._description,
            "_name": self._name,
            "_pruned": self._pruned,
        }

    @classmethod
    def from_dict(cls, d):
        """Reconstruct a WNEntry from a dict produced by to_dict."""
        return cls(**d)

    def n_images(self, tree=None):
        """Total images in this node's subtree (direct only when no tree given)."""
        if tree is None or self.child_ids is None or len(self.child_ids) == 0:
            return self._n_images
        return sum([tree.nodes[child_id].n_images(tree) for child_id in self.child_ids]) + self._n_images

    def n_children(self, tree=None):
        """Number of children (direct only without a tree, else whole subtree)."""
        if self.child_ids is None:
            return 0
        if tree is None or len(self.child_ids) == 0:
            return len(self.child_ids)
        return len(self.child_ids) + sum([tree.nodes[child_id].n_children(tree) for child_id in self.child_ids])

    def get_examples(self, tree, n_examples=3):
        """Return up to *n_examples* children as "name (definition)" strings,
        preferring the children with the most images (falling back to the
        largest subtrees when no child has images)."""
        if self.child_ids is None or len(self.child_ids) == 0:
            return ""
        child_images = {child_id: tree.nodes[child_id].n_images(tree) for child_id in self.child_ids}
        max_images = max(child_images.values())
        if max_images == 0:
            # go on number of child nodes
            child_images = {child_id: tree.nodes[child_id].n_children(tree) for child_id in self.child_ids}
        # sorted childids by number of images
        top_children = [
            child_id for child_id, n_images in sorted(child_images.items(), key=lambda x: x[1], reverse=True)
        ]
        top_children = top_children[: min(n_examples, len(top_children))]
        return ", ".join(
            [f"{tree.nodes[child_id].print_name} ({tree.nodes[child_id].description})" for child_id in top_children]
        )
class WNTree:
    """A tree of WordNet noun synsets rooted at a single entry."""

    def __init__(self, root=1740, nodes=None):
        """Build a tree from a WordNet offset (int) or an existing ``WNEntry`` root.

        Args:
            root: WordNet noun-synset offset, or a ready-made ``WNEntry``.
            nodes: optional pre-built id → entry mapping; defaults to just the root.
        """
        if isinstance(root, int):
            # Look the synset up by offset and wrap it in a fresh root entry.
            synset = wn.synset_from_pos_and_offset("n", root)
            self.root = WNEntry(synset.name(), root, _lemmas_str(synset), parent_id=None, depth=0)
            root_key = root
        else:
            assert isinstance(root, WNEntry)
            self.root = root
            root_key = root.id
        self.nodes = nodes if nodes is not None else {root_key: self.root}
        self.parentless = []  # node ids whose synset has no hypernyms
        self.label_index = None  # lazily built by _make_label_index()
        self.pruned = set()
def to_dict(self):
return {
"root": self.root.to_dict(),
"nodes": {node_id: node.to_dict() for node_id, node in self.nodes.items()},
"parentless": self.parentless,
"pruned": list(self.pruned),
}
    def prune(self, min_images):
        """Prune nodes whose image counts fall below *min_images*.

        Two phases: first every node whose whole subtree (itself plus all
        descendants) holds fewer than *min_images* images is flagged via its
        entry's ``prune``; then the remaining tree is walked bottom-up and
        nodes whose OWN image count (excluding descendants) is below the
        threshold are flagged as well.

        Args:
            min_images: minimum number of images required to keep a node.

        Returns:
            The set of pruned node ids (also stored on ``self.pruned``).
        """
        pruned_nodes = set()
        # prune all nodes that have fewer than min_images below them
        for node_id, node in self.nodes.items():
            # n_images(self) counts the node plus its whole subtree.
            if node.n_images(self) < min_images:
                pruned_nodes.add(node_id)
                node.prune(self)
        # prune all nodes that have fewer than min_images inside them, after all nodes below have been pruned
        # Breadth-first listing of all nodes reachable from the root; an index
        # pointer is used instead of pop(0) to keep the scan linear.
        node_stack = [self.root]
        node_idx = 0
        while node_idx < len(node_stack):
            node = node_stack[node_idx]
            if node.child_ids is not None:
                for child_id in node.child_ids:
                    child = self.nodes[child_id]
                    node_stack.append(child)
            node_idx += 1
        # now prune the stack from the bottom up
        for node in node_stack[::-1]:
            # only look at images of that class, not of additional children
            # (n_images() with no tree returns just the node's own count)
            if node.n_images() < min_images:
                pruned_nodes.add(node.id)
                node.prune(self)
        self.pruned = pruned_nodes
        return pruned_nodes
@classmethod
def from_dict(cls, d):
tree = cls()
tree.root = WNEntry.from_dict(d["root"])
tree.nodes = {int(node_id): WNEntry.from_dict(node) for node_id, node in d["nodes"].items()}
tree.parentless = d["parentless"]
if "pruned" in d:
tree.pruned = set(d["pruned"])
return tree
    def add_node(self, node_id, in_in=True):
        """Insert the synset with WordNet offset *node_id*, creating its ancestors as needed.

        Recursively materializes the first-hypernym chain up to an existing
        node. Synsets without hypernyms are recorded in ``self.parentless``
        and marked as outside the main tree.

        Args:
            node_id: WordNet noun-synset offset to insert.
            in_in: True when the synset itself is an ImageNet class; ancestors
                created on the way up are added with ``in_in=False``.
        """
        if node_id in self.nodes:
            # Already present: only upgrade the ImageNet flag, never downgrade it.
            self.nodes[node_id].in_image_net = in_in or self.nodes[node_id].in_image_net
            return
        synset = wn.synset_from_pos_and_offset("n", node_id)
        # print(f"adding node {synset.name()} with id {node_id}")
        hypernyms = synset.hypernyms()
        if len(hypernyms) == 0:
            parent_id = None
            self.parentless.append(node_id)
            main_tree = False
            print(f"--------- no hypernyms for {synset.name()} ({synset.offset()}) ------------")
        else:
            # Only the FIRST hypernym is used, flattening the WordNet DAG into a tree.
            parent_id = synset.hypernyms()[0].offset()
            if parent_id not in self.nodes:
                # Recursively create the ancestor chain before attaching this node.
                self.add_node(parent_id, in_in=False)
            parent = self.nodes[parent_id]
            if parent.child_ids is None:
                parent.child_ids = []
            parent.child_ids.append(node_id)
            # A node is in the main tree iff its parent is.
            main_tree = parent.in_main_tree
        depth = self.nodes[parent_id].depth + 1 if parent_id is not None else 0
        node = WNEntry(
            synset.name(),
            node_id,
            _lemmas_str(synset),
            parent_id=parent_id,
            in_image_net=in_in,
            depth=depth,
            in_main_tree=main_tree,
        )
        self.nodes[node_id] = node
def __len__(self):
return len(self.nodes)
def image_net_len(self, only_main_tree=False):
return sum([node.in_image_net for node in self.nodes.values() if node.in_main_tree or not only_main_tree])
def max_depth(self, only_main_tree=False):
return max([node.depth for node in self.nodes.values() if node.in_main_tree or not only_main_tree])
def __str__(self, colors=True):
return (
f"WordNet Tree with {len(self)} nodes, {self.image_net_len()} in ImageNet21k;"
f" {len(self.parentless)} parentless nodes:\n{self.root.__str__(tree=self, colors=colors)}\nParentless:\n"
+ "\n".join([self.nodes[node_id].__str__(tree=self, colors=colors) for node_id in self.parentless])
)
def save(self, path):
with open(path, "w") as f:
json.dump(self.to_dict(), f)
@classmethod
def load(cls, path):
with open(path, "r") as f:
tree_dict = json.load(f)
return cls.from_dict(tree_dict)
def subtree(self, node):
node_id = self[node].id
if node_id not in self.nodes:
return None
node_queue = [self.nodes[node_id]]
subtree_ids = set()
while len(node_queue) > 0:
node = node_queue.pop(0)
subtree_ids.add(node.id)
if node.child_ids is not None:
node_queue += [self.nodes[child_id] for child_id in node.child_ids]
subtree_nodes = {node_id: copy(self.nodes[node_id]) for node_id in subtree_ids}
subtree_root = subtree_nodes[node_id]
subtree_root.parent_id = None
depth_diff = subtree_root.depth
for node in subtree_nodes.values():
node.depth -= depth_diff
return WNTree(root=subtree_root, nodes=subtree_nodes)
def _make_label_index(self):
self.label_index = sorted(
[node_id for node_id, node in self.nodes.items() if node.n_images(self) > 0 and not node._pruned]
)
def get_label(self, node_id):
if self.label_index is None:
self._make_label_index()
while self.nodes[node_id]._pruned:
node_id = self.nodes[node_id].parent_id
return self.label_index.index(node_id)
def n_labels(self):
if self.label_index is None:
self._make_label_index()
return len(self.label_index)
def __contains__(self, item):
if isinstance(item, str):
if item[0] == "n":
item = int(item[1:])
else:
return False
if isinstance(item, int):
return item in self.nodes
if isinstance(item, WNEntry):
return item.id in self.nodes
return False
def __getitem__(self, item):
if isinstance(item, str) and item[0].startswith("n"):
with contextlib.suppress(ValueError):
item = int(item[1:])
if isinstance(item, str) and ".n." in item:
for node in self.nodes.values():
if item == node.name:
return node
raise KeyError(f"Item {item} not found in tree")
if isinstance(item, int):
return self.nodes[item]
if isinstance(item, WNEntry):
return self.nodes[item.id]
raise KeyError(f"Item {item} not found in tree")
def __iter__(self):
return iter(self.nodes.keys())