Compare commits: 5c08f9d31a ... aaai26

1 commit: ff34712155
BIN  AAAI Supplementary Material/ForAug_Appendix.pdf (new file)
BIN  AAAI Supplementary Material/ForAug_Supplementary_Material.zip (new file)
73   AAAI Supplementary Material/ForNet Creation Code/.ruff.toml (new file)
@@ -0,0 +1,73 @@
# Exclude a variety of commonly ignored directories.
exclude = [
    ".bzr",
    ".direnv",
    ".eggs",
    ".git",
    ".git-rewrite",
    ".hg",
    ".ipynb_checkpoints",
    ".mypy_cache",
    ".nox",
    ".pants.d",
    ".pyenv",
    ".pytest_cache",
    ".pytype",
    ".ruff_cache",
    ".svn",
    ".tox",
    ".venv",
    ".vscode",
    "__pypackages__",
    "_build",
    "buck-out",
    "build",
    "dist",
    "node_modules",
    "site-packages",
    "venv",
]

# Same as Black.
line-length = 120
indent-width = 4

[lint]
# Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default.
# Unlike Flake8, Ruff doesn't enable pycodestyle warnings (`W`) or
# McCabe complexity (`C901`) by default.
select = ["E", "F", "D", "B", "NPY", "PD", "TD005", "TD006", "TD007", "SIM", "RET", "Q", "ICN", "I"]
ignore = ["D203", "D213", "E501", "D100", "NPY002", "D102", "B008", "PD011", "D105", "SIM118"]

# Allow fix for all enabled rules (when `--fix` is provided).
unfixable = ["B"]

# Allow unused variables when underscore-prefixed.
dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"

[format]
# Like Black, use double quotes for strings.
quote-style = "double"

# Like Black, indent with spaces, rather than tabs.
indent-style = "space"

# Like Black, respect magic trailing commas.
skip-magic-trailing-comma = true

# Like Black, automatically detect the appropriate line ending.
line-ending = "auto"

# Enable auto-formatting of code examples in docstrings. Markdown,
# reStructuredText code/literal blocks and doctests are all supported.
#
# This is currently disabled by default, but it is planned for this
# to be opt-out in the future.
docstring-code-format = true

# Set the line length limit used when formatting code snippets in
# docstrings.
#
# This only has an effect when the `docstring-code-format` setting is
# enabled.
docstring-code-line-length = "dynamic"
61   AAAI Supplementary Material/ForNet Creation Code/README.md (new file)
@@ -0,0 +1,61 @@
# Creating the ForNet Dataset

We cannot provide the ForNet dataset itself here, as it is too large to be part of the appendix, and linking to it would violate the double-blind review rules.
After acceptance, the dataset will be downloadable online.
For now, we provide the scripts and steps to recreate the dataset.
In general, if you are unsure which arguments a script accepts, run it with the `--help` flag.

## 1. Setup paths

Fill in the paths in `experiments/general_srun` in the `Model Training Code` folder, as well as in `srun-general.sh`, `slurm-segment-imnet.sh`, and all the `sbatch-segment-...` files.
In particular, set the `--container-image`, `--container-mounts`, and `--output` paths, as well as the `NLTK_DATA` and `HF_HOME` paths in `--export`, as shown in the example below.
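For example, the export line of each sbatch file contains placeholders that must be replaced with paths on your cluster (the paths below are invented placeholders):

```commandline
#SBATCH --export="HF_HOME=/netscratch/username/hf_home/,NLTK_DATA=/netscratch/username/nltk_data/,TQDM_DISABLE=1"
```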
## 2. Pretrain Filtering Models

Use the `Model Training Code` to pretrain an ensemble of models to use for filtering in a later step.
Train those models on either `TinyImageNet` or `ImageNet`, depending on whether you want to create `TinyForNet` or `ForNet`.
Then fill in the relevant paths to the pretrained weights in `experiments/filter_segmentation_versions.py` (lines 96/98).

## 3. Create the dataset

### Automatically: using slurm

You may simply run the `create_dataset.py` file (on a slurm head node); an example call is sketched below. That file automatically runs all the necessary steps one after another.
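A sketch of a full invocation (the data paths are placeholders; the output name must match the pattern `(Tiny)?INSegment_v<N>`):

```commandline
python create_dataset.py -d imagenet -r /data/datasets/ --out_root /data/out/ -o INSegment_v1 --output_ims all --parent_in_prompt
```

With `--output_ims all`, the script copies the result into `_f1`/`_f2` filter subversions and filters both; with the default `best`, the output folder is renamed to `..._f1`.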
### Manually and step-by-step

If you want to run each step of the pipeline manually, follow these steps.
For default arguments and settings, see the `create_dataset.py` script: even if you do not want to run it directly, it shows how to invoke all the other scripts.

#### 3.1 Segment Objects and Backgrounds

Use the segmentation script (`segment_imagenet.py`) to segment each of the dataset images.
Note that this script uses `datadings` for image loading, so you need to provide a `datadings` variant of your dataset.
You need to provide the root folder of the dataset.
Choose the infilling model using the `-model` argument (LaMa or AttErase).
If you want to use the *general* prompting strategy, set the `--parent_in_prompt` flag.
Use `--output`/`-o` to set the output directory.
Use `--processes` and `-id` to split the task into multiple parallelizable processes; an example invocation is shown below.
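A hypothetical single-process invocation (all paths are placeholders):

```commandline
python segment_imagenet.py -d imagenet -r /data/datadings/ -o /data/out/INSegment_v1 -model LaMa --parent_in_prompt --processes 320 -id 0
```

Each process handles one contiguous shard of the dataset, so with `--processes 320` you launch one job per `-id` in `0..319` (the provided sbatch scripts do this via slurm array jobs).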
#### 3.2 Filter the segmented images

In this step, you use the pretrained ensemble of models (from step 2) to filter the segmented images.
As this step is based on the training and model code, it lives in the `Model Training Code` directory.
After setting the relevant paths to the pretrained weights (see step 2), you may run the `experiments/filter_segmentation_versions.py` script using that directory as the working directory; example invocations are sketched below.
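A sketch of the two filtering invocations that `create_dataset.py` issues, run from inside `Model Training Code` (the output paths are placeholders; on a cluster these calls are wrapped in `./experiments/general_srun.sh`, and `-score_f_weights automatic` produces the second filter subversion):

```commandline
python experiments/filter_segmentation_versions.py -f /data/out/INSegment_v1_f1/train -d imagenet
python experiments/filter_segmentation_versions.py -f /data/out/INSegment_v1_f2/train -d imagenet -score_f_weights automatic
```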
#### 3.3 Zip the dataset

In distributed storage settings it can be useful to read from one large (uncompressed) zip file instead of millions of small single files.
To do this, run

```commandline
zip -r -0 backgrounds_train.zip train/backgrounds > /dev/null 2>&1
```

for each of the train and val backgrounds and foregrounds, as listed below.
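The full set of four zip commands, run from the dataset's output folder:

```commandline
zip -r -0 foregrounds_train.zip train/foregrounds > /dev/null 2>&1
zip -r -0 foregrounds_val.zip val/foregrounds > /dev/null 2>&1
zip -r -0 backgrounds_train.zip train/backgrounds > /dev/null 2>&1
zip -r -0 backgrounds_val.zip val/backgrounds > /dev/null 2>&1
```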
#### 3.4 Compute the foreground size ratios

For the resizing step during recombination, the relative size of each object in each image is needed.
To compute it, run the `foreground_size_ratio.py` script on your filtered dataset.
It expects the zip files (from step 3.3) in the folder you provide via `-ds`; see the example below.
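An example invocation for both splits (the `/data/out/` root is a placeholder; an absolute `-ds` path is also accepted):

```commandline
python foreground_size_ratio.py --root /data/out/ -ds INSegment_v1_f1 -mode train
python foreground_size_ratio.py --root /data/out/ -ds INSegment_v1_f1 -mode val
```

98   AAAI Supplementary Material/ForNet Creation Code/attentive_eraser.py (new file; name inferred from imports)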
@@ -0,0 +1,98 @@
from copy import copy

import numpy as np
import torch
import torch.nn.functional as F
from diffusers import DDIMScheduler, DiffusionPipeline
from PIL import Image, ImageFilter
from torchvision.transforms.functional import gaussian_blur, to_tensor


def _preprocess_image(image, device, dtype=torch.float16):
    image = to_tensor(image)
    image = image.unsqueeze_(0).float() * 2 - 1  # [0,1] --> [-1,1]
    if image.shape[1] != 3:
        image = image.expand(-1, 3, -1, -1)
    image = F.interpolate(image, (1024, 1024))
    return image.to(dtype).to(device)


def _preprocess_mask(mask, device, dtype=torch.float16):
    mask = to_tensor(mask.convert("L"))
    mask = mask.unsqueeze_(0).float()  # 0 or 1
    mask = F.interpolate(mask, (1024, 1024))
    mask = gaussian_blur(mask, kernel_size=(77, 77))
    mask[mask < 0.1] = 0
    mask[mask >= 0.1] = 1
    return mask.to(dtype).to(device)


class AttentiveEraser:
    """Attentive Eraser Pipeline + Pre- and Post-Processing."""

    prompt = ""
    dtype = torch.float16

    def __init__(self, num_steps=50, device=None):
        """Create the attentive eraser.

        Args:
            num_steps (int, optional): Number of steps in the diffusion process. Will start at 20%. Defaults to 50.
            device (torch.device, optional): Device to run on. Defaults to 'cuda' if available, else 'cpu'.

        """
        self.num_steps = num_steps
        self.device = device if device else torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        scheduler = DDIMScheduler(
            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False
        )
        self.model_path = "stabilityai/stable-diffusion-xl-base-1.0"
        self.pipeline = DiffusionPipeline.from_pretrained(
            self.model_path,
            custom_pipeline="./pipelines/pipeline_stable_diffusion_xl_attentive_eraser.py",
            scheduler=scheduler,
            variant="fp16",
            use_safetensors=True,
            torch_dtype=self.dtype,
        ).to(self.device)
        self.pipeline.enable_attention_slicing()
        self.pipeline.enable_model_cpu_offload()

    def __call__(self, image, mask):
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        if isinstance(mask, np.ndarray):
            mask = Image.fromarray(mask)

        prep_img = _preprocess_image(image, self.device, self.dtype)
        prep_mask = _preprocess_mask(mask, self.device, self.dtype)
        orig_shape = image.size

        diff_img = (
            self.pipeline(
                prompt=self.prompt,
                image=prep_img,
                mask_image=prep_mask,
                height=1024,
                width=1024,
                AAS=True,  # enable AAS
                strength=0.8,  # inpainting strength
                rm_guidance_scale=9,  # removal guidance scale
                ss_steps=9,  # similarity suppression steps
                ss_scale=0.3,  # similarity suppression scale
                AAS_start_step=0,  # AAS start step
                AAS_start_layer=34,  # AAS start layer
                AAS_end_layer=70,  # AAS end layer
                num_inference_steps=self.num_steps,  # AAS_end_step = int(strength * num_inference_steps)
                guidance_scale=1,
            )
            .images[0]
            .resize(orig_shape)
        )

        # Paste the inpainted region back into the original image, feathering the seam.
        paste_mask = Image.fromarray((np.array(mask) > 0).astype(np.uint8) * 255)  # binary paste mask (0 or 255)
        paste_mask = paste_mask.convert("RGB").filter(ImageFilter.GaussianBlur(radius=10)).convert("L")
        out_img = copy(image)
        out_img.paste(diff_img, (0, 0), paste_mask)
        return out_img
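195  AAAI Supplementary Material/ForNet Creation Code/create_dataset.py (new file; name inferred from the README)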
@@ -0,0 +1,195 @@
import argparse
import os
import re
import subprocess

parser = argparse.ArgumentParser(description="Create Segment & Recombine dataset")
parser.add_argument("-d", "--dataset", choices=["imagenet", "tinyimagenet"], required=True, help="Dataset to use")
parser.add_argument("-r", "--dataset_root", required=True, help="Root directory of the dataset")
parser.add_argument("-t", "--threshold", type=float, default=0.3, help="Detection threshold")
parser.add_argument(
    "--out_root", type=str, required=True, help="Root directory where the output directory will be created."
)
parser.add_argument("-o", "--output", type=str, required=True, help="Output directory")
parser.add_argument(
    "--parent_labels", type=int, default=2, help="Number of parent labels to use; steps to go up the tree"
)
parser.add_argument("--output_ims", choices=["best", "all"], default="best", help="Output all or best masks")
parser.add_argument("--mask_merge_threshold", type=float, default=0.9, help="Threshold on IoU for merging masks")
parser.add_argument("--parent_in_prompt", action="store_true", help="Include parent label in the prompt")
parser.add_argument("-infill_model", choices=["LaMa", "AttErase"], default="LaMa", help="Infilling model to use")
parser.add_argument("--continue", dest="continue_", action="store_true", help="Continue from previous run")

args = parser.parse_args()
out_root = args.out_root

base_folder = os.path.dirname(__file__)
training_code_folder = os.path.join(base_folder, os.pardir, "Model Training Code")

ds_name_re = re.compile(r"(Tiny)?INSegment_v(\d*)(_f\d*)?")
name_match = ds_name_re.match(args.output)
assert name_match, f"Output name {args.output} does not match the expected format: {ds_name_re.pattern}"
assert (
    args.output_ims == "best" or name_match.group(3) is None
), "For output_ims == 'all', the filter subversions will be automatically created."
assert args.continue_ or not os.path.exists(out_root + args.output), f"Output directory {args.output} already exists."

settings_file = {
    "dataset": args.dataset,
    "threshold": args.threshold,
    "output_ims": args.output_ims,
    "mask_merge_threshold": args.mask_merge_threshold,
    "parent_labels": args.parent_labels,
    "parent_in_prompt": args.parent_in_prompt,
    "infill_model": args.infill_model,
}
settings_file = [f"{k} = {str(settings_file[k])}" for k in sorted(list(settings_file.keys()))]
if os.path.exists(out_root + args.output) and os.path.exists(out_root + args.output + "/settings.txt"):
    with open(out_root + args.output + "/settings.txt", "r") as f:
        old_settings = f.read().split("\n")
    old_settings = [line.strip() for line in old_settings if len(line.strip()) > 0]
    assert old_settings == settings_file, (
        f"Settings file {out_root + args.output}/settings.txt does not match current settings: old: {old_settings} vs"
        f" new: {settings_file}"
    )
else:
    os.makedirs(out_root + args.output, exist_ok=True)
    with open(out_root + args.output + "/settings.txt", "w") as f:
        f.write("\n".join(settings_file) + "\n")

general_args = [
    "sbatch",
    "sbatch-segment-tinyimnet-wait" if args.dataset == "tinyimagenet" else "sbatch-segment-imagenet-wait",
    "-r",
    args.dataset_root,
    "-o",
    out_root + args.output,
    "--parent_labels",
    str(args.parent_labels),
    "--output_ims",
    args.output_ims,
    "--mask_merge_threshold",
    str(args.mask_merge_threshold),
    "-t",
    str(args.threshold),
    "-model",
    args.infill_model,
]

if args.parent_in_prompt:
    general_args.append("--parent_in_prompt")

print(f"Starting segmentation: {' '.join(general_args)} for {args.dataset}-val and {args.dataset}")
p_train = subprocess.Popen(general_args + ["-d", args.dataset], cwd=base_folder)
if args.dataset == "imagenet":
    general_args[1] = "sbatch-segment-imagenet-val-wait"
p_val = subprocess.Popen(general_args + ["-d", args.dataset + "-val"], cwd=base_folder)

# detect if exit in error
p_val.wait()
p_train.wait()
rcodes = (p_val.returncode, p_train.returncode)
if any(rcode != 0 for rcode in rcodes):
    print(f"Error in segmentation (val, train): {rcodes}")
    exit(1)
print("Segmentation done.")

if args.output_ims == "all":
    print("copy to subversions for filtering")
    p_1 = subprocess.Popen(["cp", "-rl", os.path.join(out_root, args.output), out_root + f"{args.output}_f1/"])
    p_2 = subprocess.Popen(["cp", "-rl", os.path.join(out_root, args.output), out_root + f"{args.output}_f2/"])
    p_1.wait()
    p_2.wait()
    print("Filtering subversions copied over.")

    print("Starting filtering")
    filtering_args = ["./experiments/general_srun.sh", "experiments/filter_segmentation_versions.py"]
    p_val_f1 = subprocess.Popen(
        filtering_args + ["-f", out_root + f"{args.output}_f1" + "/val", "-d", args.dataset], cwd=training_code_folder
    )
    p_train_f1 = subprocess.Popen(
        filtering_args + ["-f", out_root + f"{args.output}_f1" + "/train", "-d", args.dataset], cwd=training_code_folder
    )
    p_val_f2 = subprocess.Popen(
        filtering_args
        + ["-f", out_root + f"{args.output}_f2" + "/val", "-score_f_weights", "automatic", "-d", args.dataset],
        cwd=training_code_folder,
    )
    p_train_f2 = subprocess.Popen(
        filtering_args
        + ["-f", out_root + f"{args.output}_f2" + "/train", "-score_f_weights", "automatic", "-d", args.dataset],
        cwd=training_code_folder,
    )

    p_val_f1.wait()
    p_train_f1.wait()
    p_val_f2.wait()
    p_train_f2.wait()

    ds_folders = [f"{args.output}_f1", f"{args.output}_f2"]
else:
    if name_match.group(3) is None:  # no _fN filter suffix in the name yet
        new_name = args.output + "_f1"
        print(f"Renaming output folder: {args.output} -> {new_name}")
        os.rename(out_root + args.output, out_root + new_name)
    else:
        new_name = args.output
    ds_folders = [new_name]

for folder in ds_folders:
    print(f"Zipping up {folder}")
    # Silence zip via subprocess redirection; shell tokens like ">" in the argument
    # list would otherwise be passed to zip as file arguments.
    zip_kwargs = {"stdout": subprocess.DEVNULL, "stderr": subprocess.STDOUT, "cwd": out_root + folder}
    p_val_fg = subprocess.Popen(["zip", "-r", "-0", "foregrounds_val.zip", "val/foregrounds"], **zip_kwargs)
    p_train_fg = subprocess.Popen(["zip", "-r", "-0", "foregrounds_train.zip", "train/foregrounds"], **zip_kwargs)
    p_val_bg = subprocess.Popen(["zip", "-r", "-0", "backgrounds_val.zip", "val/backgrounds"], **zip_kwargs)
    p_train_bg = subprocess.Popen(["zip", "-r", "-0", "backgrounds_train.zip", "train/backgrounds"], **zip_kwargs)

    p_val_fg.wait()
    p_train_fg.wait()
    p_val_bg.wait()
    p_train_bg.wait()

    print(f"Gathering foreground size ratios for {folder}")
    p_val = subprocess.Popen(
        ["./srun-general.sh", "python", "foreground_size_ratio.py", "--root", args.dataset_root, "-ds", folder, "-mode", "val"],
        cwd=base_folder,
    )
    p_train = subprocess.Popen(
        ["./srun-general.sh", "python", "foreground_size_ratio.py", "--root", args.dataset_root, "-ds", folder],
        cwd=base_folder,
    )
    p_val.wait()
    p_train.wait()
BIN  7 new image files added (402 KiB, 46 KiB, 36 KiB, 49 KiB, 148 KiB, 136 KiB, 46 KiB)
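69   AAAI Supplementary Material/ForNet Creation Code/foreground_size_ratio.py (new file; name inferred from the README)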
@@ -0,0 +1,69 @@
import argparse
import json
import os
import zipfile
from io import BytesIO

import numpy as np
import PIL
from PIL import Image
from tqdm.auto import tqdm

parser = argparse.ArgumentParser(description="Foreground size ratio")
parser.add_argument("-mode", choices=["train", "val"], default="train", help="Train or val data?")
parser.add_argument("-ds", "--dataset", type=str, required=True, help="Dataset to use")
parser.add_argument("--root", type=str, required=True, help="Root folder of the dataset")
args = parser.parse_args()

train = args.mode == "train"
root = os.path.join(args.root, args.dataset) if not args.dataset.startswith("/") else args.dataset

cls_bg_ratios = {}
fg_bg_ratio_map = {}

with zipfile.ZipFile(f"{root}/backgrounds_{'train' if train else 'val'}.zip", "r") as bg_zip, zipfile.ZipFile(
    f"{root}/foregrounds_{'train' if train else 'val'}.zip", "r"
) as fg_zip:
    backgrounds = [f for f in bg_zip.namelist() if f.endswith(".JPEG")]
    foregrounds = [f for f in fg_zip.namelist() if f.endswith(".WEBP")]

    print(f"Bgs: {backgrounds[:5]}, ...\nFgs: {foregrounds[:5]}, ...")

    for bg_name in tqdm(backgrounds):
        fg_name = bg_name.replace("backgrounds", "foregrounds").replace("JPEG", "WEBP")
        if fg_name not in foregrounds:
            tqdm.write(f"Skipping {bg_name} as it has no corresponding foreground")
            fg_bg_ratio_map[bg_name] = 0.0
            continue

        with bg_zip.open(bg_name) as f:
            bg_data = BytesIO(f.read())
        try:
            bg_img = Image.open(bg_data)
        except PIL.UnidentifiedImageError as e:
            print(f"Error with file={bg_name}")
            raise e
        bg_img_size = bg_img.size

        with fg_zip.open(fg_name) as f:
            fg_data = BytesIO(f.read())
        try:
            fg_img = Image.open(fg_data)
        except PIL.UnidentifiedImageError as e:
            print(f"Error with file={fg_name}")
            raise e
        fg_img_size = fg_img.size
        # number of opaque pixels in the foreground's alpha channel
        fg_img_pixel_size = int(np.sum(fg_img.split()[-1]) / 255)

        img_cls = bg_name.split("/")[-2]
        if img_cls not in cls_bg_ratios:
            cls_bg_ratios[img_cls] = []

        cls_bg_ratios[img_cls].append(fg_img_pixel_size / (bg_img_size[0] * bg_img_size[1]))
        fg_bg_ratio_map[bg_name] = fg_img_pixel_size / (bg_img_size[0] * bg_img_size[1])

with open(f"{root}/fg_bg_ratios_{args.mode}.json", "w") as f:
    json.dump(fg_bg_ratio_map, f)
print(f"Saved fg_bg_ratios_{args.mode} to disk. Exiting.")
@@ -0,0 +1,42 @@
import argparse
import json

from datadings.reader import MsgpackReader
from datadings.torch import Dataset
from tqdm.auto import tqdm

parser = argparse.ArgumentParser(description="Map dataset keys to msgpack indices")
parser.add_argument("-mode", choices=["train", "val"], default="train", help="Train or val data?")
parser.add_argument("-ds", "--dataset", choices=["imagenet", "tinyimagenet"], required=True, help="Dataset to use")
parser.add_argument("-r", "--dataset_root", required=True, type=str, help="Root directory of the dataset")
parser.add_argument("-s", "--segment_root", required=True, type=str, help="Root directory of the segmentation dataset")
args = parser.parse_args()


if args.dataset == "imagenet":
    reader = MsgpackReader(f"{args.dataset_root}imagenet/msgpack/{args.mode}.msgpack")
    dataset = Dataset(reader)
elif args.dataset == "tinyimagenet":
    reader = MsgpackReader(f"{args.dataset_root}TinyINSegment/TinyIN_{args.mode}.msgpack")
    dataset = Dataset(reader)
else:
    raise ValueError(f"Unknown dataset: {args.dataset}")

if args.dataset.startswith("imagenet"):
    with open("wordnet_data/imagenet1k_synsets.json", "r") as f:
        id_to_synset = json.load(f)
    id_to_synset = {int(k): v for k, v in id_to_synset.items()}
elif args.dataset.startswith("tinyimagenet"):
    with open("wordnet_data/tinyimagenet_synset_names.txt", "r") as f:
        synsets = f.readlines()
    id_to_synset = [int(synset.split(":")[0].strip()[1:]) for synset in synsets]
    id_to_synset = sorted(id_to_synset)

key_to_idx = {
    f"n{id_to_synset[data['label']]:08d}/{data['key'].split('.')[0]}": i
    for i, data in enumerate(tqdm(dataset, leave=True))
}

print(len(key_to_idx))
with open(f"{args.segment_root}/{args.mode}_indices.json", "w") as f:
    json.dump(key_to_idx, f)
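223  AAAI Supplementary Material/ForNet Creation Code/grounded_segmentation.py (new file; name inferred from imports)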
@@ -0,0 +1,223 @@
"""Use grounding DINO + Segment Anything (SAM) to perform grounded segmentation on an image.

Based on: https://github.com/IDEA-Research/Grounded-Segment-Anything
"""

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import cv2
import numpy as np
import requests
import torch
from PIL import Image
from torchvision.transforms import ToTensor
from transformers import AutoModelForMaskGeneration, AutoProcessor, pipeline


@dataclass
class BoundingBox:
    """Bounding box representation."""

    xmin: int
    ymin: int
    xmax: int
    ymax: int

    @property
    def xyxy(self) -> List[float]:
        """Return bounding box coordinates.

        Returns:
            List[float]: coordinates: [xmin, ymin, xmax, ymax]

        """
        return [self.xmin, self.ymin, self.xmax, self.ymax]


@dataclass
class DetectionResult:
    """Detection result from Grounding DINO + Mask from SAM."""

    score: float
    label: str
    box: BoundingBox
    mask: Optional[np.ndarray] = None

    @classmethod
    def from_dict(cls, detection_dict: Dict) -> "DetectionResult":
        """Create a DetectionResult from a dictionary.

        Args:
            detection_dict (Dict): Detection result dictionary.

        Returns:
            DetectionResult: Detection result object.

        """
        return cls(
            score=detection_dict["score"],
            label=detection_dict["label"],
            box=BoundingBox(
                xmin=detection_dict["box"]["xmin"],
                ymin=detection_dict["box"]["ymin"],
                xmax=detection_dict["box"]["xmax"],
                ymax=detection_dict["box"]["ymax"],
            ),
        )


def mask_to_polygon(mask: np.ndarray) -> List[List[int]]:
    """Use OpenCV to refine a mask by turning it into a polygon.

    Args:
        mask (np.ndarray): Segmentation mask.

    Returns:
        List[List[int]]: List of (x, y) coordinates representing the vertices of the polygon.

    """
    # Find contours in the binary mask
    contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # Find the contour with the largest area
    largest_contour = max(contours, key=cv2.contourArea)

    # Extract the vertices of the contour
    return largest_contour.reshape(-1, 2).tolist()


def polygon_to_mask(polygon: List[Tuple[int, int]], image_shape: Tuple[int, int]) -> np.ndarray:
    """Convert a polygon to a segmentation mask.

    Args:
        polygon (list): List of (x, y) coordinates representing the vertices of the polygon.
        image_shape (tuple): Shape of the image (height, width) for the mask.

    Returns:
        np.ndarray: Segmentation mask with the polygon filled.

    """
    # Create an empty mask
    mask = np.zeros(image_shape, dtype=np.uint8)

    # Convert polygon to an array of points
    pts = np.array(polygon, dtype=np.int32)

    # Fill the polygon with white color (255)
    cv2.fillPoly(mask, [pts], color=(255,))

    return mask


def load_image(image_str: str) -> Image.Image:
    """Load an image from a URL or file path.

    Args:
        image_str (str): URL or file path to the image.

    Returns:
        PIL.Image: Image object.

    """
    if image_str.startswith("http"):
        image = Image.open(requests.get(image_str, stream=True).raw).convert("RGB")
    else:
        image = Image.open(image_str).convert("RGB")

    return image


def _get_boxes(results: List[DetectionResult]) -> List[List[List[float]]]:
    boxes = []
    for result in results:
        xyxy = result.box.xyxy
        boxes.append(xyxy)

    return [boxes]


def _refine_masks(masks: torch.BoolTensor, polygon_refinement: bool = False) -> List[np.ndarray]:
    masks = masks.cpu().float()
    masks = masks.permute(0, 2, 3, 1)
    masks = masks.mean(axis=-1)
    masks = (masks > 0).int()
    masks = masks.numpy().astype(np.uint8)
    masks = list(masks)

    if polygon_refinement:
        for idx, mask in enumerate(masks):
            shape = mask.shape
            polygon = mask_to_polygon(mask)
            mask = polygon_to_mask(polygon, shape)
            masks[idx] = mask

    return masks


device = "cuda" if torch.cuda.is_available() else "cpu"
detector_id = "IDEA-Research/grounding-dino-tiny"
print(f"load object detector pipeline: {detector_id}")
object_detector = pipeline(model=detector_id, task="zero-shot-object-detection", device=device)

segmenter_id = "facebook/sam-vit-base"
print(f"load segmentator: {segmenter_id}")
segmentator = AutoModelForMaskGeneration.from_pretrained(segmenter_id).to(device)
print(f"load processor: {segmenter_id}")
processor = AutoProcessor.from_pretrained(segmenter_id)


def detect(image: Image.Image, labels: List[str], threshold: float = 0.3) -> List[DetectionResult]:
    """Use Grounding DINO to detect a set of labels in an image in a zero-shot fashion."""
    global object_detector, device
    labels = [label if label.endswith(".") else label + "." for label in labels]

    results = object_detector(image, candidate_labels=labels, threshold=threshold)
    return [DetectionResult.from_dict(result) for result in results]


def segment(
    image: Image.Image, detection_results: List[DetectionResult], polygon_refinement: bool = False
) -> List[DetectionResult]:
    """Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes."""
    global segmentator, processor, device
    boxes = _get_boxes(detection_results)
    inputs = processor(images=image, input_boxes=boxes, return_tensors="pt").to(device)

    outputs = segmentator(**inputs)
    masks = processor.post_process_masks(
        masks=outputs.pred_masks, original_sizes=inputs.original_sizes, reshaped_input_sizes=inputs.reshaped_input_sizes
    )[0]

    masks = _refine_masks(masks, polygon_refinement)

    for detection_result, mask in zip(detection_results, masks):
        detection_result.mask = mask

    return detection_results


def grounded_segmentation(
    image: Union[Image.Image, str], labels: List[str], threshold: float = 0.3, polygon_refinement: bool = False
) -> Tuple[torch.Tensor, List[DetectionResult]]:
    """Segment out the objects in an image given a set of labels.

    Args:
        image (Union[Image.Image, str]): Image to load/work on.
        labels (List[str]): Object labels to segment.
        threshold (float, optional): Segmentation threshold. Defaults to 0.3.
        polygon_refinement (bool, optional): Use polygon refinement on the segmented mask? Defaults to False.

    Returns:
        Tuple[torch.Tensor, List[DetectionResult]]: Image tensor and list of detection results.

    """
    if isinstance(image, str):
        image = load_image(image)

    detections = detect(image, labels, threshold)
    if len(detections) == 0:
        return ToTensor()(image), []
    detections = segment(image, detections, polygon_refinement)

    return ToTensor()(image), detections
271  AAAI Supplementary Material/ForNet Creation Code/infill_lama.py (new file)
@@ -0,0 +1,271 @@
"""Use the LaMa (*LaMa: Resolution-robust Large Mask Inpainting with Fourier Convolutions*) model to infill the background of an image.

Based on: https://github.com/Sanster/IOPaint/blob/main/iopaint/model/lama.py
"""

import abc
import hashlib
import os
import sys
from enum import Enum
from typing import Optional
from urllib.parse import urlparse

import numpy as np
import torch
from torch.hub import download_url_to_file

LAMA_MODEL_URL = os.environ.get(
    "LAMA_MODEL_URL",
    "https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
)
LAMA_MODEL_MD5 = os.environ.get("LAMA_MODEL_MD5", "e3aa4aaa15225a33ec84f9f4bc47e500")


class LDMSampler(str, Enum):
    ddim = "ddim"
    plms = "plms"


class SDSampler(str, Enum):
    ddim = "ddim"
    pndm = "pndm"
    k_lms = "k_lms"
    k_euler = "k_euler"
    k_euler_a = "k_euler_a"
    dpm_plus_plus = "dpm++"
    uni_pc = "uni_pc"


def ceil_modulo(x, mod):
    if x % mod == 0:
        return x
    return (x // mod + 1) * mod


def pad_img_to_modulo(img: np.ndarray, mod: int, square: bool = False, min_size: Optional[int] = None):
    """Pad an image so its height and width are multiples of `mod`.

    Args:
        img: [H, W, C]
        mod: modulo to pad to
        square: whether to pad the image to a square
        min_size: minimum output size; must be a multiple of `mod`

    Returns:
        The symmetrically padded image.

    """
    if len(img.shape) == 2:
        img = img[:, :, np.newaxis]
    height, width = img.shape[:2]
    out_height = ceil_modulo(height, mod)
    out_width = ceil_modulo(width, mod)

    if min_size is not None:
        assert min_size % mod == 0
        out_width = max(min_size, out_width)
        out_height = max(min_size, out_height)

    if square:
        max_size = max(out_height, out_width)
        out_height = max_size
        out_width = max_size

    return np.pad(
        img,
        ((0, out_height - height), (0, out_width - width), (0, 0)),
        mode="symmetric",
    )


def undo_pad_to_mod(img: np.ndarray, height: int, width: int):
    return img[:height, :width, :]


class InpaintModel:
    name = "base"
    min_size: Optional[int] = None
    pad_mod = 8
    pad_to_square = False

    def __init__(self, device, **kwargs):
        """Create the inpainting model.

        Args:
            device: torch device to run on.

        """
        self.device = device
        self.init_model(device, **kwargs)

    @abc.abstractmethod
    def init_model(self, device, **kwargs): ...

    @staticmethod
    @abc.abstractmethod
    def is_downloaded() -> bool: ...

    @abc.abstractmethod
    def forward(self, image, mask):
        """Input images and output images have same size.

        images: [H, W, C] RGB
        masks: [H, W, 1], 255 marks the masked region
        return: BGR IMAGE
        """
        ...

    def _pad_forward(self, image, mask):
        origin_height, origin_width = image.shape[:2]
        pad_image = pad_img_to_modulo(image, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size)
        pad_mask = pad_img_to_modulo(mask, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size)

        result = self.forward(pad_image, pad_mask)

        mask = mask[:, :, np.newaxis]
        result = result[0].permute(1, 2, 0).cpu().numpy()
        result = undo_pad_to_mod(result, origin_height, origin_width)
        assert result.shape[:2] == image.shape[:2], f"{result.shape[:2]} != {image.shape[:2]}"
        # blend the inpainted result with the original image outside the mask
        mask = (mask > 0) * 255
        result = result * mask + image * (1 - (mask / 255))
        result = np.clip(result, 0, 255).astype("uint8")
        return result

    def forward_post_process(self, result, image, mask):
        return result, image, mask


def resize_np_img(np_img, size, interpolation="bicubic"):
    assert interpolation in [
        "nearest",
        "bilinear",
        "bicubic",
    ], f"Unsupported interpolation: {interpolation}, use nearest, bilinear or bicubic."
    torch_img = torch.from_numpy(np_img).permute(2, 0, 1).unsqueeze(0).float() / 255
    interp_img = torch.nn.functional.interpolate(torch_img, size=size, mode=interpolation, align_corners=True)
    return (interp_img.squeeze(0).permute(1, 2, 0).numpy() * 255).astype(np.uint8)


def resize_max_size(np_img, size_limit: int, interpolation="bicubic") -> np.ndarray:
    # Resize the image's longer side to size_limit if it exceeds size_limit
    h, w = np_img.shape[:2]
    if max(h, w) > size_limit:
        ratio = size_limit / max(h, w)
        new_w = int(w * ratio + 0.5)
        new_h = int(h * ratio + 0.5)
        return resize_np_img(np_img, size=(new_w, new_h), interpolation=interpolation)
    else:
        return np_img


def norm_img(np_img):
    if len(np_img.shape) == 2:
        np_img = np_img[:, :, np.newaxis]
    np_img = np.transpose(np_img, (2, 0, 1))
    np_img = np_img.astype("float32") / 255
    return np_img


def get_cache_path_by_url(url):
    parts = urlparse(url)
    hub_dir = os.path.expanduser("~/.TORCH_HUB_CACHE")  # expand ~ so the cache dir lands in the home directory
    model_dir = os.path.join(hub_dir, "checkpoints")
    if not os.path.isdir(model_dir):
        os.makedirs(model_dir)
    filename = os.path.basename(parts.path)
    cached_file = os.path.join(model_dir, filename)
    return cached_file


def md5sum(filename):
    md5 = hashlib.md5()
    with open(filename, "rb") as f:
        for chunk in iter(lambda: f.read(128 * md5.block_size), b""):
            md5.update(chunk)
    return md5.hexdigest()


def download_model(url, model_md5: str = None):
    cached_file = get_cache_path_by_url(url)
    if not os.path.exists(cached_file):
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        hash_prefix = None
        download_url_to_file(url, cached_file, hash_prefix, progress=True)
        if model_md5:
            _md5 = md5sum(cached_file)
            if model_md5 == _md5:
                print(f"Download model success, md5: {_md5}")
            else:
                try:
                    os.remove(cached_file)
                    print(
                        f"Model md5: {_md5}, expected md5: {model_md5}, wrong model deleted. Please restart"
                        " lama-cleaner. If you still have errors, please try to download the model manually first:"
                        " https://lama-cleaner-docs.vercel.app/install/download_model_manually.\n"
                    )
                except Exception:
                    print(
                        f"Model md5: {_md5}, expected md5: {model_md5}, please delete {cached_file} and restart"
                        " lama-cleaner."
                    )
                exit(-1)

    return cached_file


def load_jit_model(url_or_path, device, model_md5: str):
    if os.path.exists(url_or_path):
        model_path = url_or_path
    else:
        model_path = download_model(url_or_path, model_md5)

    print(f"Loading model from: {model_path}")
    try:
        model = torch.jit.load(model_path, map_location="cpu").to(device)
    except Exception as e:
        print(f"Error loading model: {e}")
        raise
    model.eval()
    return model


class LaMa(InpaintModel):
    name = "lama"
    pad_mod = 8

    @torch.no_grad()
    def __call__(self, image, mask):
        """Inpaint the masked region of an image.

        images: [H, W, C] RGB, not normalized
        masks: [H, W]
        return: BGR IMAGE
        """
        inpaint_result = self._pad_forward(image, mask)

        return inpaint_result

    def init_model(self, device, **kwargs):
        self.model = load_jit_model(LAMA_MODEL_URL, device, LAMA_MODEL_MD5).eval()

    @staticmethod
    def is_downloaded() -> bool:
        return os.path.exists(get_cache_path_by_url(LAMA_MODEL_URL))

    def forward(self, image, mask):
        """Input image and output image have same size.

        image: [H, W, C] RGB
        mask: [H, W]
        return: BGR IMAGE
        """
        image = norm_img(image)
        mask = norm_img(mask)

        mask = (mask > 0) * 1
        image = torch.from_numpy(image).unsqueeze(0).to(self.device)
        mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)

        inpainted_image = self.model(image, mask)
        return inpainted_image
@@ -0,0 +1,65 @@
import argparse
import os
from random import shuffle

import numpy as np
from matplotlib import pyplot as plt
from PIL import Image

parser = argparse.ArgumentParser(description="Inspect image versions")
parser.add_argument("-f", "--base_folder", type=str, required=True, help="Base folder to inspect")
parser.add_argument("-opt_fg_ratio", type=float, default=0.3, help="Optimal foreground ratio")

args = parser.parse_args()

bg_folder = os.path.join(args.base_folder, "backgrounds")
fg_folder = os.path.join(args.base_folder, "foregrounds")

classes = os.listdir(fg_folder)
classes = sorted(classes, key=lambda x: int(x[1:]))
assert len(classes) in [200, 1_000], f"Expected 200 or 1_000 classes, got {len(classes)}"

total_images = set()
for in_cls in classes:
    cls_images = {
        os.path.join(in_cls, "_".join(img.split(".")[0].split("_")[:3]))
        for img in os.listdir(os.path.join(fg_folder, in_cls))
    }
    total_images.update(cls_images)
total_images = list(total_images)
shuffle(total_images)
print(total_images[:5], "...")

# map in_cls to print name/lemma
with open("wordnet_data/tinyimagenet_synset_names.txt", "r") as f:
    in_cls_to_name = {line.split(":")[0].strip(): line.split(":")[1].strip() for line in f.readlines() if len(line) > 2}

for image_name in total_images:
    in_cls, img_name = image_name.split("/")
    versions = set()
    for img in os.listdir(os.path.join(fg_folder, in_cls)):
        if img.startswith(img_name):
            versions.add(img)
    if len(versions) <= 1:
        print(f"Image {image_name} has only one version")
        continue
    versions = sorted(list(versions))

    bg_img = Image.open(os.path.join(bg_folder, in_cls, f"{versions[0].split('.')[0]}.JPEG"))

    # plot all versions at once
    fig, axs = plt.subplots(1, len(versions) + 1, figsize=(15, 5))
    axs[0].imshow(bg_img)
    axs[0].axis("off")
    axs[0].set_title("Background")
    for version, ax in zip(versions, axs[1:]):
        img = Image.open(os.path.join(fg_folder, in_cls, version))
        img_mask = np.array(img.convert("RGBA").split()[-1])
        fg_dens_fact = np.sum(img_mask) / (255 * img_mask.size)
        fg_ratio = np.sum(img_mask) / (255 * bg_img.size[0] * bg_img.size[1])
        ax.imshow(img)
        ax.axis("off")
        ax.set_title(f"{version}\nfg_rat: {fg_ratio:.2f} dist to optimal: {abs(fg_ratio - args.opt_fg_ratio):.2f}")
    fig.suptitle(in_cls_to_name[in_cls])
    plt.show()
    plt.close()
@@ -0,0 +1,11 @@
tqdm
einops
omegaconf
diffusers
opencv-python
transformers
accelerate
torchvision
datadings
numpy
nltk
@@ -0,0 +1,20 @@
#!/bin/bash

#SBATCH --array=0-29%30
#SBATCH --time=1-0
#SBATCH --mem=64G
#SBATCH --gpus=1
#SBATCH --partition=H200,H100,H100-PCI,A100-40GB,A100-80GB
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=16
#SBATCH --export="HF_HOME=/PATH/TO/HF_HOME/,NLTK_DATA=/PATH/TO/NLTK_DATA/,TQDM_DISABLE=1"
#SBATCH --job-name="Segment ImageNet (val)"
#SBATCH --output=/SBATCH/OUT/FOLDER/%x-%j-%N-%a.out
#SBATCH --wait

srun -K \
    --container-image=/PATH/TO.sqsh \
    --container-workdir="$(pwd)" \
    --container-mounts=/SET/CONTAINER/MOUNTS,"$(pwd)":"$(pwd)" \
    python3 segment_imagenet.py -p 30 -id $SLURM_ARRAY_TASK_ID -o /OUTPUT/FOLDER/PATH/ "$@"
@@ -0,0 +1,20 @@
#!/bin/bash

#SBATCH --array=0-319%60
#SBATCH --time=1-0
#SBATCH --mem=64G
#SBATCH --gpus=1
#SBATCH --partition=H200,H100,H100-PCI,A100-PCI,A100-40GB,A100-80GB
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=16
#SBATCH --export="HF_HOME=/PATH/TO/HF_HOME/,NLTK_DATA=/PATH/TO/NLTK_DATA/,TQDM_DISABLE=1"
#SBATCH --job-name="Segment ImageNet (train)"
#SBATCH --output=/SBATCH/OUT/FOLDER/%x-%j-%N-%a.out
#SBATCH --wait

srun -K \
    --container-image=/PATH/TO.sqsh \
    --container-workdir="$(pwd)" \
    --container-mounts=/SET/CONTAINER/MOUNTS,"$(pwd)":"$(pwd)" \
    python3 segment_imagenet.py -p 320 -id $SLURM_ARRAY_TASK_ID -o /OUTPUT_PATH/INSegment/ "$@"
@@ -0,0 +1,19 @@
#!/bin/bash

#SBATCH --array=0-4%10
#SBATCH --time=1-0
#SBATCH --mem=64G
#SBATCH --gpus=1
#SBATCH --partition=RTX3090,RTXA6000,A100-PCI,A100-40GB,H100-PCI
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=16
#SBATCH --export="HF_HOME=/PATH/TO/HF_HOME/,NLTK_DATA=/PATH/TO/NLTK_DATA/,TQDM_DISABLE=1"
#SBATCH --job-name="Segment ImageNet"
#SBATCH --output=/SBATCH/OUT/FOLDER/%x-%j-%N-%a.out

srun -K \
    --container-image=/PATH/TO.sqsh \
    --container-workdir="$(pwd)" \
    --container-mounts=/SET/CONTAINER/MOUNTS,"$(pwd)":"$(pwd)" \
    python3 segment_imagenet.py -p 5 -id $SLURM_ARRAY_TASK_ID -o /PATH/TO/OUT/FOLDER/INSegment/ "$@"
@@ -0,0 +1,19 @@
#!/bin/bash

#SBATCH --array=0-19%20
#SBATCH --time=1-0
#SBATCH --mem=64G
#SBATCH --gpus=1
#SBATCH --partition=RTX3090,RTXA6000,A100-PCI,A100-40GB,H100-PCI
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=16
#SBATCH --export="HF_HOME=/PATH/TO/HF_HOME/,NLTK_DATA=/PATH/TO/NLTK_DATA/,TQDM_DISABLE=1"
#SBATCH --job-name="Segment TinyImageNet"
#SBATCH --output=/SBATCH/OUT/FOLDER/%x-%j-%N-%a.out

srun -K \
    --container-image=/PATH/TO.sqsh \
    --container-workdir="$(pwd)" \
    --container-mounts=/SET/CONTAINER/MOUNTS,"$(pwd)":"$(pwd)" \
    python3 segment_imagenet.py -p 20 -id $SLURM_ARRAY_TASK_ID "$@"
@@ -0,0 +1,20 @@
#!/bin/bash

#SBATCH --array=0-29%30
#SBATCH --time=1-0
#SBATCH --mem=64G
#SBATCH --gpus=1
#SBATCH --partition=H200,H100,H100-PCI,A100-PCI,A100-40GB,A100-80GB
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=16
#SBATCH --export="HF_HOME=/PATH/TO/HF_HOME/,NLTK_DATA=/PATH/TO/NLTK_DATA/,TQDM_DISABLE=1"
#SBATCH --job-name="Segment TinyImageNet"
#SBATCH --output=/SBATCH/OUT/FOLDER/%x-%j-%N-%a.out
#SBATCH --wait

srun -K \
    --container-image=/PATH/TO.sqsh \
    --container-workdir="$(pwd)" \
    --container-mounts=/SET/CONTAINER/MOUNTS,"$(pwd)":"$(pwd)" \
    python3 segment_imagenet.py -p 30 -id $SLURM_ARRAY_TASK_ID "$@"
@@ -0,0 +1,338 @@
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
from math import ceil
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from datadings.reader import MsgpackReader
|
||||
from datadings.torch import Compose, CompressedToPIL, Dataset
|
||||
from PIL import Image, ImageFilter
|
||||
from torchvision.transforms import ToPILImage, ToTensor
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from attentive_eraser import AttentiveEraser
|
||||
from grounded_segmentation import grounded_segmentation
|
||||
from infill_lama import LaMa
|
||||
from utils import already_segmented, save_img
|
||||
from wordnet_tree import WNTree
|
||||
|
||||
|
||||
def _collate_single(data):
|
||||
return data[0]
|
||||
|
||||
|
||||
def _collate_multiple(datas):
|
||||
return {
|
||||
"images": [data["image"] for data in datas],
|
||||
"keys": [data["key"] for data in datas],
|
||||
"labels": [data["label"] for data in datas],
|
||||
}
|
||||
|
||||
|
||||
def smallest_crop(image: torch.Tensor, mask: torch.Tensor):
|
||||
"""Crops the image to just so fit the mask given.
|
||||
|
||||
Mask and image have to be of the same size.
|
||||
|
||||
Args:
|
||||
image (torch.Tensor): image to crop
|
||||
mask (torch.Tensor): cropping to mask
|
||||
|
||||
Returns:
|
||||
torch.Tensor: cropped image
|
||||
|
||||
"""
|
||||
if len(mask.shape) == 3:
|
||||
assert mask.shape[0] == 1
|
||||
mask = mask[0]
|
||||
assert (
|
||||
len(image.shape) == 3 and len(mask.shape) == 2 and image.shape[1:] == mask.shape
|
||||
), f"Invalid shapes: {image.shape}, {mask.shape}"
|
||||
dim0_indices = mask.sum(dim=0).nonzero()
|
||||
dim1_indices = mask.sum(dim=1).nonzero()
|
||||
|
||||
return image[
|
||||
:,
|
||||
dim1_indices.min().item() : dim1_indices.max().item() + 1,
|
||||
dim0_indices.min().item() : dim0_indices.max().item() + 1,
|
||||
]
|
||||
|
||||
|
||||
def _synset_to_prompt(synset, tree, parent_in_prompt):
|
||||
prompt = f"an {synset.print_name}." if synset.print_name[0] in "aeiou" else f"a {synset.print_name}."
|
||||
if synset.parent_id is None or not parent_in_prompt:
|
||||
return prompt
|
||||
parent = tree[synset.parent_id]
|
||||
return prompt[:-1] + f", a type of {parent.print_name}."
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Grounded Segmentation")
|
||||
parser.add_argument(
|
||||
"-d",
|
||||
"--dataset",
|
||||
choices=["imagenet", "imagenet-val", "imagenet21k", "tinyimagenet", "tinyimagenet-val"],
|
||||
default="imagenet",
|
||||
help="Dataset to use",
|
||||
)
|
||||
parser.add_argument("-r", "--dataset_root", required=True, help="Root directory of the dataset")
|
||||
parser.add_argument("-b", "--batch_size", type=int, default=1, help="Batch size")
|
||||
parser.add_argument("-t", "--threshold", type=float, default=0.3, help="Detection threshold")
|
||||
parser.add_argument("--debug", action="store_true", help="Debug mode")
|
||||
parser.add_argument("-o", "--output", type=str, required=True, help="Output directory")
|
||||
parser.add_argument(
|
||||
"-p", "--processes", type=int, default=1, help="Number of processes that are used to process the data"
|
||||
)
|
||||
parser.add_argument("-id", type=int, default=0, help="ID of this process)")
|
||||
parser.add_argument("--overwrite", action="store_true", help="Overwrite existing files, instead of skipping them")
|
||||
parser.add_argument(
|
||||
"--parent_labels", type=int, default=2, help="Number of parent labels to use; steps to go up the tree"
|
||||
)
|
||||
parser.add_argument("--output_ims", choices=["best", "all"], default="best", help="Output all or best masks")
|
||||
parser.add_argument("--mask_merge_threshold", type=float, default=0.9, help="Threshold on IoU for merging masks")
|
||||
parser.add_argument("--parent_in_prompt", action="store_true", help="Include parent label in the prompt")
|
||||
parser.add_argument(
|
||||
"-model",
|
||||
choices=["LaMa", "AttErase"],
|
||||
default="LaMa",
|
||||
help="Model to use for erasing/infilling. Defaults to LaMa.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
dataset = args.dataset.lower()
|
||||
part = "train"
|
||||
if dataset == "imagenet21k":
|
||||
reader = MsgpackReader(f"{args.dataset_root}imagenet21k/train.msgpack")
|
||||
dataset = Dataset(reader, transforms={"image": Compose([CompressedToPIL()])})
|
||||
elif dataset == "imagenet-val":
|
||||
reader = MsgpackReader(f"{args.dataset_root}imagenet/msgpack/val.msgpack")
|
||||
dataset = Dataset(reader, transforms={"image": Compose([CompressedToPIL()])})
|
||||
part = "val"
|
||||
elif dataset == "imagenet":
|
||||
reader = MsgpackReader(f"{args.dataset_root}imagenet/msgpack/train.msgpack")
|
||||
dataset = Dataset(reader, transforms={"image": Compose([CompressedToPIL()])})
|
||||
elif dataset == "tinyimagenet":
|
||||
reader = MsgpackReader(f"{args.dataset_root}TinyINSegment/TinyIN_train.msgpack")
|
||||
dataset = Dataset(reader, transforms={"image": Compose([CompressedToPIL()])})
|
||||
elif dataset == "tinyimagenet-val":
|
||||
reader = MsgpackReader(f"{args.dataset_root}TinyINSegment/TinyIN_val.msgpack")
|
||||
dataset = Dataset(reader, transforms={"image": Compose([CompressedToPIL()])})
|
||||
part = "val"
|
||||
else:
|
||||
raise ValueError(f"Unknown dataset: {dataset}")
|
||||
|
||||
assert 0 <= args.id < args.processes, "ID must be in the range [0, processes)"
|
||||
if args.processes > 1:
|
||||
partlen = ceil(len(dataset) / args.processes)
|
||||
start_id = args.id * partlen
|
||||
end_id = min((args.id + 1) * partlen, len(dataset))
|
||||
dataset = torch.utils.data.Subset(dataset, list(range(start_id, end_id)))
|
||||
|
||||
assert args.batch_size == 1, "Batch size must be 1 for grounded segmentation (for now)"
|
||||
dataset = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=args.debug,
|
||||
drop_last=False,
|
||||
num_workers=10,
|
||||
collate_fn=_collate_single if args.batch_size == 1 else _collate_multiple,
|
||||
)
|
||||
|
||||
infill_model = (
|
||||
LaMa(device="cuda" if torch.cuda.is_available() else "cpu") if args.model == "LaMa" else AttentiveEraser()
|
||||
)
|
||||
|
||||
if args.dataset.startswith("imagenet"):
|
||||
with open("wordnet_data/imagenet1k_synsets.json", "r") as f:
|
||||
id_to_synset = json.load(f)
|
||||
id_to_synset = {int(k): v for k, v in id_to_synset.items()}
|
||||
elif args.dataset.startswith("tinyimagenet"):
|
||||
with open("wordnet_data/tinyimagenet_synset_names.txt", "r") as f:
|
||||
synsets = f.readlines()
|
||||
id_to_synset = [int(synset.split(":")[0].strip()[1:]) for synset in synsets]
|
||||
id_to_synset = sorted(id_to_synset)

wordnet = WNTree.load("wordnet_data/imagenet21k+1k_masses_tree.json")

# create the necessary folders
os.makedirs(os.path.join(args.output, part, "foregrounds"), exist_ok=True)
os.makedirs(os.path.join(args.output, part, "backgrounds"), exist_ok=True)
os.makedirs(os.path.join(args.output, part, "no_detect"), exist_ok=True)
os.makedirs(os.path.join(args.output, part, "error"), exist_ok=True)

man_print_progress = int(os.environ.get("TQDM_DISABLE", "0")) == 1

for i, data in tqdm(enumerate(dataset), total=len(dataset)):
    if args.debug and i > 10:
        break

    if man_print_progress and i % 200 == 0:
        print(f"{datetime.now()} \t process {args.id}/{args.processes}: \t sample: {i}/{len(dataset)}", flush=True)

    image = data["image"]
    key = data["key"]

    if args.dataset.endswith("-val"):
        img_class = f"n{id_to_synset[data['label']]:08d}"  # overwrite the synset id on the val set
    else:
        img_class = key.split("/")[-1].split("_")[0]

    if (
        already_segmented(key, os.path.join(args.output, part, "foregrounds"), img_class=img_class)
        and not args.overwrite
    ):
        continue

    # get the next labels up in the wordnet tree (up to args.parent_labels parents)
    label_id = data["label"]
    offset = id_to_synset[label_id]  # int(data["key"].split("_")[0][1:])
    synset = wordnet[offset]
    labels = [_synset_to_prompt(synset, wordnet, args.parent_in_prompt)]
    while len(labels) < 1 + args.parent_labels and synset.parent_id is not None:
        synset = wordnet[synset.parent_id]
        _label = _synset_to_prompt(synset, wordnet, args.parent_in_prompt)
        _label = _label.replace("_", " ")
        labels.append(_label)

    assert len(labels) > 0, f"No labels for {key}"

    # 1. detect and segment
    image_tensor, detections = grounded_segmentation(
        image, labels, threshold=args.threshold, polygon_refinement=True
    )

    # merge all detections for the same label
    masks = {}
    for detection in detections:
        if detection.label not in masks:
            masks[detection.label] = detection.mask
        else:
            masks[detection.label] |= detection.mask
    labels = list(masks.keys())
    masks = [masks[lbl] for lbl in labels]
    assert len(masks) == len(
        labels
    ), f"Number of masks ({len(masks)}) != number of labels ({len(labels)}); detections: {detections}"

    if args.debug:
        for detected_mask, label in zip(masks, labels):
            detected_mask = torch.from_numpy(detected_mask).unsqueeze(0) / 255
            mask_foreground = torch.cat((image_tensor * detected_mask, detected_mask), dim=0)
            ToPILImage()(mask_foreground).save(f"example_images/{key}_{label}_masked.png")

    if len(masks) == 0:
        tqdm.write(f"No detections for {key}; skipping")
        save_img(image, key, os.path.join(args.output, part, "no_detect"), img_class=img_class, format="JPEG")
        continue
    elif len(masks) == 1:
        masks = [masks[0]]
    elif args.output_ims == "best":
        # find the 2 masks with the largest overlap (use distinct loop variables,
        # so the dataset index i is not shadowed)
        max_overlap = 0
        max_overlap_mask = None
        using_labels = None
        for m_i, mask1 in enumerate(masks):
            for m_j, mask2 in enumerate(masks):
                if m_i >= m_j:
                    continue
                iou = (mask1 & mask2).sum() / (mask1 | mask2).sum()
                if iou > max_overlap:
                    max_overlap = iou
                    max_overlap_mask = mask1 | mask2
                    using_labels = f"{labels[m_i]} & {labels[m_j]}"
        if args.debug:
            tqdm.write(f"{key}:\tMax overlap: {max_overlap}, using {using_labels}")
        if max_overlap_mask is None:
            max_overlap_mask = masks[0]
        masks = [max_overlap_mask]
    else:
        # merge masks that are too similar
        has_changed = True
        while has_changed:
            has_changed = False
            for m_i, mask1 in enumerate(masks):
                for m_j, mask2 in enumerate(masks):
                    if m_i >= m_j:
                        continue
                    iou = (mask1 & mask2).sum() / (mask1 | mask2).sum()
                    if iou > args.mask_merge_threshold:
                        masks[m_i] |= masks[m_j]
                        masks.pop(m_j)
                        labels.pop(m_j)
                        has_changed = True
                        break
                if has_changed:
                    break

    for mask_idx, mask_array in enumerate(masks):
        mask = torch.from_numpy(mask_array).unsqueeze(0) / 255

        if args.debug:
            mask_image = ToPILImage()(mask)
            mask_image.save(f"example_images/{key}_mask.png")

        foreground = torch.cat((image_tensor * mask, mask), dim=0)
        foreground = ToPILImage()(foreground)
        mask_img = foreground.split()[-1]
        foreground = smallest_crop(ToTensor()(foreground), ToTensor()(mask_img))  # TODO: fix smallest crop
        fg_img = ToPILImage()(foreground)

        background = (1 - mask) * image_tensor
        bg_image = ToPILImage()(background)

        # 2. infill background
        mask_image = Image.fromarray(mask_array)
        mask_image = mask_image.filter(ImageFilter.GaussianBlur(radius=7))
        try:
            infilled_bg = infill_model(np.array(bg_image), np.array(mask_image))
            infilled_bg = Image.fromarray(np.uint8(infilled_bg))
        except RuntimeError as e:
            tqdm.write(
                f"Error infilling {key}: bg_image.shape={np.array(bg_image).shape},"
                f" mask_image.shape={np.array(mask_image).shape}\n{e}"
            )
            save_img(
                image,
                key,
                os.path.join(args.output, part, "error"),
                img_class=img_class,
                img_version=mask_idx if len(masks) > 1 else None,
                format="JPEG",
            )
            infilled_bg = None

        if args.debug:
            data["image"].save(f"example_images/{key}_orig.png")
            fg_img.save(f"example_images/{key}_fg.png")
            if infilled_bg is not None:  # infilling may have failed above
                infilled_bg.save(f"example_images/{key}_bg.png")
        else:
            # save files
            save_img(
                fg_img,
                key,
                os.path.join(args.output, part, "foregrounds"),
                img_class=img_class,
                img_version=mask_idx if len(masks) > 1 else None,
                format="WEBP",
            )
            if infilled_bg is not None:
                save_img(
                    infilled_bg,
                    key,
                    os.path.join(args.output, part, "backgrounds"),
                    img_class=img_class,
                    img_version=mask_idx if len(masks) > 1 else None,
                    format="JPEG",
                )
if man_print_progress:
    print(
        f"{datetime.now()} \t process {args.id}/{args.processes}: \t done with all {len(dataset)} samples",
        flush=True,
    )
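

# Illustrative sketch of the IoU computation used for the mask merging above
# (toy example, not part of the pipeline; masks are uint8 arrays with 0/255 values):
#
#   import numpy as np
#   m1 = np.array([[255, 255], [0, 0]], dtype=np.uint8)
#   m2 = np.array([[255, 0], [255, 0]], dtype=np.uint8)
#   iou = (m1 & m2).sum() / (m1 | m2).sum()  # 255 / (3 * 255) ≈ 0.33
#   merged = m1 | m2  # the union mask kept for "best" pairs and near-duplicates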
@@ -0,0 +1,15 @@
#!/bin/bash

srun -K \
  --partition=H200,H100,A100-40GB,A100-80GB \
  --job-name="Segment ImageNet" \
  --nodes=1 \
  --gpus=1 \
  --cpus-per-task=16 \
  --mem=64G \
  --time=1-0 \
  --export=HF_HOME=/PATH/TO/HF_HOME/,NLTK_DATA=/PATH/TO/NLTK_DATA/ \
  --container-image=/PATH/TO.sqsh \
  --container-workdir="$(pwd)" \
  --container-mounts=/SET/CONTAINER/MOUNTS,"$(pwd)":"$(pwd)" \
  python3 segment_imagenet.py "$@"
@@ -0,0 +1,13 @@
#!/bin/bash

srun -K \
  --partition=RTX3090,A100-40GB,H100 \
  --job-name="Segment ImageNet" \
  --nodes=1 \
  --mem=64G \
  --time=1-0 \
  --export=HF_HOME=/PATH/TO/HF_HOME/,NLTK_DATA=/PATH/TO/NLTK_DATA/,MAX_WORKERS=16 \
  --container-image=/PATH/TO.sqsh \
  --container-workdir="$(pwd)" \
  --container-mounts=/SET/CONTAINER/MOUNTS,"$(pwd)":"$(pwd)" \
  "$@"
49
AAAI Supplementary Material/ForNet Creation Code/utils.py
Normal file
@@ -0,0 +1,49 @@
import os

from PIL import Image


def save_img(img: Image, img_name: str, base_dir: str, img_class: str = None, format="PNG", img_version=None):
    """Save an image to a directory.

    Args:
        img (PIL.Image): Image to save.
        img_name (str): Relative path to the image.
        base_dir (str): Base directory to save images in.
        img_class (str, optional): Image class; if not given, try to extract it from the image name in ImageNet train format. Defaults to None.
        format (str, optional): Format to save the image in. Defaults to "PNG".
        img_version (int, optional): Version of the image. Will be appended to the path. Defaults to None.

    """
    if not img_name.endswith(f".{format}"):
        img_name = f"{img_name.split('.')[0]}.{format}"
    if img_class is None:
        img_class = img_name.split("_")[0]
    if not os.path.exists(os.path.join(base_dir, img_class)):
        os.makedirs(os.path.join(base_dir, img_class), exist_ok=True)
    if img_version is not None:
        img_name = f"{img_name.split('.')[0]}_v{img_version}.{format}"
    img.save(os.path.join(base_dir, img_class, img_name), format.lower())
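

# A minimal usage sketch for save_img (illustrative only; the file name below
# is a hypothetical ImageNet-train-style key):
#
#   save_img(Image.open("n01440764_18.JPEG"), "n01440764_18", "ForNet/train/foregrounds", format="WEBP")
#
# would write ForNet/train/foregrounds/n01440764/n01440764_18.WEBP.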


def already_segmented(img_name: str, base_dir: str, img_class: str = None):
    """Check if an image was already segmented.

    Args:
        img_name (str): Relative path to the image.
        base_dir (str): Base directory the images are saved in.
        img_class (str, optional): Image class; if not given, try to extract it from the image name in ImageNet train format. Defaults to None.

    Returns:
        bool: Image was segmented already.

    """
    img_base_name = ".".join(img_name.split(".")[:-1]) if "." in img_name else img_name
    if img_class is None:
        img_class = img_name.split("_")[0]
    if not os.path.exists(os.path.join(base_dir, img_class)):
        return False
    return any(
        file.startswith(img_base_name + "_v") or file.startswith(img_base_name + ".")
        for file in os.listdir(os.path.join(base_dir, img_class))
    )
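

# A minimal usage sketch (illustrative only; paths are hypothetical):
#
#   if already_segmented("n01440764_18", "ForNet/train/foregrounds"):
#       ...  # skip this sample; a matching .WEBP or _v<k>.WEBP file already exists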
@@ -0,0 +1,200 @@
n07695742: pretzel
n03902125: pay-phone, pay-station
n03980874: poncho
n01644900: tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui
n02730930: apron
n02699494: altar
n03201208: dining table, board
n02056570: king penguin, Aptenodytes patagonica
n04099969: rocking chair, rocker
n04366367: suspension bridge
n04067472: reel
n02808440: bathtub, bathing tub, bath, tub
n04540053: volleyball
n02403003: ox
n03100240: convertible
n04562935: water tower
n02788148: bannister, banister, balustrade, balusters, handrail
n02988304: CD player
n02423022: gazelle
n03637318: lampshade, lamp shade
n01774384: black widow, Latrodectus mactans
n01768244: trilobite
n07614500: ice cream, icecream
n04254777: sock
n02085620: Chihuahua
n01443537: goldfish, Carassius auratus
n01629819: European fire salamander, Salamandra salamandra
n02099601: golden retriever
n02321529: sea cucumber, holothurian
n03837869: obelisk
n02002724: black stork, Ciconia nigra
n02841315: binoculars, field glasses, opera glasses
n04560804: water jug
n02364673: guinea pig, Cavia cobaya
n03706229: magnetic compass
n09256479: coral reef
n09332890: lakeside, lakeshore
n03544143: hourglass
n02124075: Egyptian cat
n02948072: candle, taper, wax light
n01950731: sea slug, nudibranch
n02791270: barbershop
n03179701: desk
n02190166: fly
n04275548: spider web, spider's web
n04417672: thatch, thatched roof
n03930313: picket fence, paling
n02236044: mantis, mantid
n03976657: pole
n01774750: tarantula
n04376876: syringe
n04133789: sandal
n02099712: Labrador retriever
n04532670: viaduct
n04487081: trolleybus, trolley coach, trackless trolley
n09428293: seashore, coast, seacoast, sea-coast
n03160309: dam, dike, dyke
n03250847: drumstick
n02843684: birdhouse
n07768694: pomegranate
n03670208: limousine, limo
n03085013: computer keyboard, keypad
n02892201: brass, memorial tablet, plaque
n02233338: cockroach, roach
n03649909: lawn mower, mower
n03388043: fountain
n02917067: bullet train, bullet
n02486410: baboon
n04596742: wok
n03255030: dumbbell
n03937543: pill bottle
n02113799: standard poodle
n03977966: police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria
n02906734: broom
n07920052: espresso
n01698640: American alligator, Alligator mississipiensis
n02123394: Persian cat
n03424325: gasmask, respirator, gas helmet
n02129165: lion, king of beasts, Panthera leo
n04008634: projectile, missile
n03042490: cliff dwelling
n02415577: bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis
n02815834: beaker
n02395406: hog, pig, grunter, squealer, Sus scrofa
n01784675: centipede
n03126707: crane
n04399382: teddy, teddy bear
n07875152: potpie
n03733131: maypole
n02802426: basketball
n03891332: parking meter
n01910747: jellyfish
n03838899: oboe, hautboy, hautbois
n03770439: miniskirt, mini
n02281406: sulphur butterfly, sulfur butterfly
n03970156: plunger, plumber's helper
n09246464: cliff, drop, drop-off
n02206856: bee
n02074367: dugong, Dugong dugon
n03584254: iPod
n04179913: sewing machine
n04328186: stopwatch, stop watch
n07583066: guacamole
n01917289: brain coral
n03447447: gondola
n02823428: beer bottle
n03854065: organ, pipe organ
n02793495: barn
n04285008: sports car, sport car
n02231487: walking stick, walkingstick, stick insect
n04465501: tractor
n02814860: beacon, lighthouse, beacon light, pharos
n02883205: bow tie, bow-tie, bowtie
n02165456: ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle
n04149813: scoreboard
n04023962: punching bag, punch bag, punching ball, punchball
n02226429: grasshopper, hopper
n02279972: monarch, monarch butterfly, milkweed butterfly, Danaus plexippus
n02669723: academic gown, academic robe, judge's robe
n04486054: triumphal arch
n04070727: refrigerator, icebox
n03444034: go-kart
n02666196: abacus
n01945685: slug
n04251144: snorkel
n03617480: kimono
n03599486: jinrikisha, ricksha, rickshaw
n02437312: Arabian camel, dromedary, Camelus dromedarius
n01984695: spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish
n04118538: rugby ball
n01770393: scorpion
n04356056: sunglasses, dark glasses, shades
n03804744: nail
n02132136: brown bear, bruin, Ursus arctos
n03400231: frying pan, frypan, skillet
n03983396: pop bottle, soda bottle
n07734744: mushroom
n02480495: orangutan, orang, orangutang, Pongo pygmaeus
n02410509: bison
n03404251: fur coat
n04456115: torch
n02123045: tabby, tabby cat
n03026506: Christmas stocking
n07715103: cauliflower
n04398044: teapot
n02927161: butcher shop, meat market
n07749582: lemon
n07615774: ice lolly, lolly, lollipop, popsicle
n02795169: barrel, cask
n04532106: vestment
n02837789: bikini, two-piece
n02814533: beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon
n04265275: space heater
n02481823: chimpanzee, chimp, Pan troglodytes
n02509815: lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens
n06596364: comic book
n01983481: American lobster, Northern lobster, Maine lobster, Homarus americanus
n02504458: African elephant, Loxodonta africana
n03014705: chest
n01944390: snail
n04146614: school bus
n01641577: bullfrog, Rana catesbeiana
n07720875: bell pepper
n02999410: chain
n01855672: goose
n02125311: cougar, puma, catamount, mountain lion, painter, panther, Felis concolor
n07753592: banana
n07871810: meat loaf, meatloaf
n04501370: turnstile
n04311004: steel arch bridge
n02977058: cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM
n04074963: remote control, remote
n03662601: lifeboat
n02106662: German shepherd, German shepherd dog, German police dog, alsatian
n03089624: confectionery, confectionary, candy store
n04259630: sombrero
n03393912: freight car
n04597913: wooden spoon
n07711569: mashed potato
n03355925: flagpole, flagstaff
n02963159: cardigan
n07579787: plate
n02950826: cannon
n01882714: koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus
n02094433: Yorkshire terrier
n02909870: bucket, pail
n02058221: albatross, mollymawk
n01742172: boa constrictor, Constrictor constrictor
n09193705: alp
n04371430: swimming trunks, bathing trunks
n07747607: orange
n03814639: neck brace
n04507155: umbrella
n02268443: dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk
n02769748: backpack, back pack, knapsack, packsack, rucksack, haversack
n03763968: military uniform
n07873807: pizza, pizza pie
n03992509: potter's wheel
n03796401: moving van
n12267677: acorn
400
AAAI Supplementary Material/ForNet Creation Code/wordnet_tree.py
Normal file
@@ -0,0 +1,400 @@
import argparse
import contextlib
import json
import os
from copy import copy

from nltk.corpus import wordnet as wn


class bcolors:
    """Colors for terminal output."""

    HEADER = "\033[95m"
    OKBLUE = "\033[94m"
    OKCYAN = "\033[96m"
    OKGREEN = "\033[92m"
    WARNING = "\033[93m"
    FAIL = "\033[91m"
    ENDC = "\033[0m"
    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"


def _lemmas_str(synset):
    return ", ".join([lemma.name() for lemma in synset.lemmas()])


class WNEntry:
    """One wordnet synset."""

    def __init__(
        self,
        name: str,
        id: int,
        lemmas: str,
        parent_id: int,
        depth: int = None,
        in_image_net: bool = False,
        child_ids: list = None,
        in_main_tree: bool = True,
        _n_images: int = 0,
        _description: str = None,
        _name: str = None,
        _pruned: bool = False,
    ):
        self.name = name
        self.id = id
        self.lemmas = lemmas
        self.parent_id = parent_id
        self.depth = depth
        self.in_image_net = in_image_net
        self.child_ids = child_ids
        self.in_main_tree = in_main_tree
        self._n_images = _n_images
        self._description = _description
        self._name = _name
        self._pruned = _pruned

    def __str__(self, tree=None, accumulate=True, colors=True, max_depth=0, max_children=None):
        green = f"{bcolors.OKGREEN}" if colors else ""
        red = f"{bcolors.FAIL}" if colors else ""
        end = f"{bcolors.ENDC}" if colors else ""
        start_symb = f"{green}+{end}" if self.in_image_net else f"{red}-{end}"
        n_ims = f"{self._n_images} of Σ {self.n_images(tree)}" if accumulate and tree is not None else self._n_images
        if self.child_ids is None or tree is None or max_depth == 0:
            return f"{start_symb}{self.name} ({self.id}) > {n_ims}"

        children = self.child_ids
        if max_children is not None and len(children) > max_children:
            children = children[:max_children]
        return f"{start_symb}{self.name} ({self.id}) > {n_ims}\n " + "\n ".join(
            [
                "\n ".join(
                    tree.nodes[child_id]
                    .__str__(tree=tree, accumulate=accumulate, colors=colors, max_depth=max_depth - 1)
                    .split("\n")
                )
                for child_id in children
            ]
        )

    def tree_diff(self, tree_1, tree_2):
        if tree_2[self.id]._n_images > tree_1[self.id]._n_images:
            start_symb = f"{bcolors.OKGREEN}+{bcolors.ENDC}"
        elif tree_2[self.id]._n_images < tree_1[self.id]._n_images:
            start_symb = f"{bcolors.FAIL}-{bcolors.ENDC}"
        else:
            start_symb = f"{bcolors.OKBLUE}={bcolors.ENDC}"
        n_ims = (
            f"{tree_1[self.id]._n_images} + {tree_2[self.id]._n_images - tree_1[self.id]._n_images} of Σ"
            f" {tree_1[self.id].n_images(tree_2)}/{tree_2[self.id].n_images(tree_2)}"
        )

        if self.child_ids is None:
            return f"{start_symb}{self.name} ({self.id}) > {n_ims}"

        return f"{start_symb}{self.name} ({self.id}) > {n_ims}\n " + "\n ".join(
            ["\n ".join(tree_1.nodes[child_id].tree_diff(tree_1, tree_2).split("\n")) for child_id in self.child_ids]
        )

    def prune(self, tree):
        if self._pruned or self.parent_id is None:
            return

        if self.child_ids is not None:
            for child_id in self.child_ids:
                tree[child_id].prune(tree)

        self._pruned = True
        parent_node = tree.nodes[self.parent_id]
        try:
            parent_node.child_ids.remove(self.id)
        except ValueError as e:
            print(
                f"Error removing {self.name} from"
                f" {parent_node.name} ({[tree[cid].name for cid in parent_node.child_ids]}): {e}"
            )
        while parent_node._pruned:
            parent_node = tree.nodes[parent_node.parent_id]
        parent_node._n_images += self._n_images
        self._n_images = 0

    @property
    def description(self):
        if not self._description:
            self._description = wn.synset_from_pos_and_offset("n", self.id).definition()
        return self._description

    @property
    def print_name(self):
        return self.name.split(".")[0]

    @property
    def is_leaf(self):
        return self.child_ids is None or len(self.child_ids) == 0

    def get_branch(self, tree=None):
        if self.parent_id is None or tree is None:
            return self.print_name

        parent = tree.nodes[self.parent_id]
        return parent.get_branch(tree) + " > " + self.print_name

    def get_branch_list(self, tree):
        if self.parent_id is None:
            return [self]
        parent = tree.nodes[self.parent_id]
        return parent.get_branch_list(tree) + [self]

    def to_dict(self):
        return {
            "name": self.name,
            "id": self.id,
            "lemmas": self.lemmas,
            "parent_id": self.parent_id,
            "depth": self.depth,
            "in_image_net": self.in_image_net,
            "child_ids": self.child_ids,
            "in_main_tree": self.in_main_tree,
            "_n_images": self._n_images,
            "_description": self._description,
            "_name": self._name,
            "_pruned": self._pruned,
        }

    @classmethod
    def from_dict(cls, d):
        return cls(**d)

    def n_images(self, tree=None):
        if tree is None or self.child_ids is None or len(self.child_ids) == 0:
            return self._n_images
        return sum([tree.nodes[child_id].n_images(tree) for child_id in self.child_ids]) + self._n_images

    def n_children(self, tree=None):
        if self.child_ids is None:
            return 0
        if tree is None or len(self.child_ids) == 0:
            return len(self.child_ids)
        return len(self.child_ids) + sum([tree.nodes[child_id].n_children(tree) for child_id in self.child_ids])

    def get_examples(self, tree, n_examples=3):
        if self.child_ids is None or len(self.child_ids) == 0:
            return ""
        child_images = {child_id: tree.nodes[child_id].n_images(tree) for child_id in self.child_ids}
        max_images = max(child_images.values())
        if max_images == 0:
            # go on number of child nodes
            child_images = {child_id: tree.nodes[child_id].n_children(tree) for child_id in self.child_ids}
        # sort child ids by number of images
        top_children = [
            child_id for child_id, n_images in sorted(child_images.items(), key=lambda x: x[1], reverse=True)
        ]
        top_children = top_children[: min(n_examples, len(top_children))]
        return ", ".join(
            [f"{tree.nodes[child_id].print_name} ({tree.nodes[child_id].description})" for child_id in top_children]
        )


class WNTree:
    def __init__(self, root=1740, nodes=None):
        if isinstance(root, int):
            root_id = root
            root_synset = wn.synset_from_pos_and_offset("n", root)
            root_node = WNEntry(
                root_synset.name(),
                root_id,
                _lemmas_str(root_synset),
                parent_id=None,
                depth=0,
            )
        else:
            assert isinstance(root, WNEntry)
            root_id = root.id
            root_node = root

        self.root = root_node
        self.nodes = {root_id: self.root} if nodes is None else nodes
        self.parentless = []
        self.label_index = None
        self.pruned = set()

    def to_dict(self):
        return {
            "root": self.root.to_dict(),
            "nodes": {node_id: node.to_dict() for node_id, node in self.nodes.items()},
            "parentless": self.parentless,
            "pruned": list(self.pruned),
        }

    def prune(self, min_images):
        pruned_nodes = set()

        # prune all nodes that have fewer than min_images below them
        for node_id, node in self.nodes.items():
            if node.n_images(self) < min_images:
                pruned_nodes.add(node_id)
                node.prune(self)

        # prune all nodes that have fewer than min_images inside them, after all nodes below have been pruned
        node_stack = [self.root]
        node_idx = 0
        while node_idx < len(node_stack):
            node = node_stack[node_idx]
            if node.child_ids is not None:
                for child_id in node.child_ids:
                    child = self.nodes[child_id]
                    node_stack.append(child)
            node_idx += 1

        # now prune the stack from the bottom up
        for node in node_stack[::-1]:
            # only look at images of that class, not of additional children
            if node.n_images() < min_images:
                pruned_nodes.add(node.id)
                node.prune(self)

        self.pruned = pruned_nodes
        return pruned_nodes
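
    # Pruning sketch (illustrative; min_images=500 is an arbitrary example value):
    #
    #   pruned_ids = tree.prune(min_images=500)
    #
    # Phase 1 removes whole subtrees with too few images; phase 2 then walks the
    # tree bottom-up and folds remaining small nodes into their nearest unpruned
    # parent (see WNEntry.prune, which also moves the image counts upward).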

    @classmethod
    def from_dict(cls, d):
        tree = cls()
        tree.root = WNEntry.from_dict(d["root"])
        tree.nodes = {int(node_id): WNEntry.from_dict(node) for node_id, node in d["nodes"].items()}
        tree.parentless = d["parentless"]
        if "pruned" in d:
            tree.pruned = set(d["pruned"])
        return tree

    def add_node(self, node_id, in_in=True):
        if node_id in self.nodes:
            self.nodes[node_id].in_image_net = in_in or self.nodes[node_id].in_image_net
            return

        synset = wn.synset_from_pos_and_offset("n", node_id)

        hypernyms = synset.hypernyms()
        if len(hypernyms) == 0:
            parent_id = None
            self.parentless.append(node_id)
            main_tree = False
            print(f"--------- no hypernyms for {synset.name()} ({synset.offset()}) ------------")
        else:
            parent_id = synset.hypernyms()[0].offset()
            if parent_id not in self.nodes:
                self.add_node(parent_id, in_in=False)
            parent = self.nodes[parent_id]

            if parent.child_ids is None:
                parent.child_ids = []
            parent.child_ids.append(node_id)
            main_tree = parent.in_main_tree

        depth = self.nodes[parent_id].depth + 1 if parent_id is not None else 0
        node = WNEntry(
            synset.name(),
            node_id,
            _lemmas_str(synset),
            parent_id=parent_id,
            in_image_net=in_in,
            depth=depth,
            in_main_tree=main_tree,
        )

        self.nodes[node_id] = node

    def __len__(self):
        return len(self.nodes)

    def image_net_len(self, only_main_tree=False):
        return sum([node.in_image_net for node in self.nodes.values() if node.in_main_tree or not only_main_tree])

    def max_depth(self, only_main_tree=False):
        return max([node.depth for node in self.nodes.values() if node.in_main_tree or not only_main_tree])

    def __str__(self, colors=True):
        return (
            f"WordNet Tree with {len(self)} nodes, {self.image_net_len()} in ImageNet21k;"
            f" {len(self.parentless)} parentless nodes:\n{self.root.__str__(tree=self, colors=colors)}\nParentless:\n"
            + "\n".join([self.nodes[node_id].__str__(tree=self, colors=colors) for node_id in self.parentless])
        )

    def save(self, path):
        with open(path, "w") as f:
            json.dump(self.to_dict(), f)

    @classmethod
    def load(cls, path):
        with open(path, "r") as f:
            tree_dict = json.load(f)
        return cls.from_dict(tree_dict)

    def subtree(self, node):
        node_id = self[node].id
        if node_id not in self.nodes:
            return None
        node_queue = [self.nodes[node_id]]
        subtree_ids = set()
        while len(node_queue) > 0:
            node = node_queue.pop(0)
            subtree_ids.add(node.id)
            if node.child_ids is not None:
                node_queue += [self.nodes[child_id] for child_id in node.child_ids]
        subtree_nodes = {node_id: copy(self.nodes[node_id]) for node_id in subtree_ids}
        subtree_root = subtree_nodes[node_id]
        subtree_root.parent_id = None
        depth_diff = subtree_root.depth
        for node in subtree_nodes.values():
            node.depth -= depth_diff
        return WNTree(root=subtree_root, nodes=subtree_nodes)

    def _make_label_index(self):
        self.label_index = sorted(
            [node_id for node_id, node in self.nodes.items() if node.n_images(self) > 0 and not node._pruned]
        )

    def get_label(self, node_id):
        if self.label_index is None:
            self._make_label_index()
        while self.nodes[node_id]._pruned:
            node_id = self.nodes[node_id].parent_id
        return self.label_index.index(node_id)

    def n_labels(self):
        if self.label_index is None:
            self._make_label_index()
        return len(self.label_index)

    def __contains__(self, item):
        if isinstance(item, str):
            if item.startswith("n"):
                item = int(item[1:])
            else:
                return False
        if isinstance(item, int):
            return item in self.nodes
        if isinstance(item, WNEntry):
            return item.id in self.nodes
        return False

    def __getitem__(self, item):
        if isinstance(item, str) and item.startswith("n"):
            with contextlib.suppress(ValueError):
                item = int(item[1:])
        if isinstance(item, str) and ".n." in item:
            for node in self.nodes.values():
                if item == node.name:
                    return node
            raise KeyError(f"Item {item} not found in tree")
        if isinstance(item, int):
            return self.nodes[item]
        if isinstance(item, WNEntry):
            return self.nodes[item.id]
        raise KeyError(f"Item {item} not found in tree")

    def __iter__(self):
        return iter(self.nodes.keys())
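

# A minimal usage sketch (illustrative; the json path matches the one used in
# segment_imagenet.py):
#
#   tree = WNTree.load("wordnet_data/imagenet21k+1k_masses_tree.json")
#   node = tree["n02099601"]           # look up golden retriever by synset id
#   print(node.get_branch(tree))       # full hypernym chain, e.g. "entity > ... > golden_retriever"
#   label = tree.get_label(node.id)    # index into the (pruned) label set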
76
AAAI Supplementary Material/Model Training Code/.ruff.toml
Normal file
@@ -0,0 +1,76 @@
# Exclude a variety of commonly ignored directories.
exclude = [
    ".bzr",
    ".direnv",
    ".eggs",
    ".git",
    ".git-rewrite",
    ".hg",
    ".ipynb_checkpoints",
    ".mypy_cache",
    ".nox",
    ".pants.d",
    ".pyenv",
    ".pytest_cache",
    ".pytype",
    ".ruff_cache",
    ".svn",
    ".tox",
    ".venv",
    ".vscode",
    "__pypackages__",
    "_build",
    "buck-out",
    "build",
    "dist",
    "node_modules",
    "site-packages",
    "venv",
]

# Same as Black.
line-length = 120
indent-width = 4

[lint]
# Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default.
# Unlike Flake8, Ruff doesn't enable pycodestyle warnings (`W`) or
# McCabe complexity (`C901`) by default.
select = ["E", "F", "D", "B", "NPY", "PD", "TD005", "TD006", "TD007", "SIM", "RET", "Q", "ICN", "I"]
ignore = ["D203", "D213", "E501", "D100", "NPY002", "D102", "B008", "PD011", "D105", "SIM118", "D417"]

# Allow fix for all enabled rules (when `--fix` is provided).
unfixable = ["B"]

# Allow unused variables when underscore-prefixed.
dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"

[lint.pydocstyle]
convention = "google"

[format]
# Like Black, use double quotes for strings.
quote-style = "double"

# Like Black, indent with spaces, rather than tabs.
indent-style = "space"

# Unlike Black's default, ignore magic trailing commas.
skip-magic-trailing-comma = true

# Like Black, automatically detect the appropriate line ending.
line-ending = "auto"

# Enable auto-formatting of code examples in docstrings. Markdown,
# reStructuredText code/literal blocks and doctests are all supported.
#
# This is currently disabled by default, but it is planned for this
# to be opt-out in the future.
docstring-code-format = true

# Set the line length limit used when formatting code snippets in
# docstrings.
#
# This only has an effect when the `docstring-code-format` setting is
# enabled.
docstring-code-line-length = "dynamic"
107
AAAI Supplementary Material/Model Training Code/README.md
Normal file
@@ -0,0 +1,107 @@
# ForNet

This is the training code for the ForNet paper.
All our experiments and evaluations were run using this codebase.

## Requirements

This project heavily builds on [timm](https://github.com/huggingface/pytorch-image-models) and open-source implementations of the models that are tested.
All requirements are listed in [requirements.txt](./requirements.txt).
To install them, run

```commandline
pip install -r requirements.txt
```

## Usage

After **cloning this repository**, you can train and test a wide range of models.
By default, an `srun` command is executed to run the code on a slurm cluster.
To run on the local machine, append the `--local` flag to the command.

### General Preparation

After cloning the repository on a slurm cluster, make sure `main.py` is executable (e.g., via `chmod a+x main.py`).

To run the project on a slurm cluster, you need to create a docker image from the requirements file.
You will also want to adapt the default slurm parameters in `config.py`.

Next, adjust the paths in `paths_config.py` for your system, specifically `results_folder`, `slurm_output_folder`, and the dataset folders.

Finally, if you want to use Weights & Biases for tracking, create the file `.wandb.apikey` in this folder and paste your API key into it.

### Training

#### Pretraining

To pretrain a `ViT-S` on a given dataset, run

```commandline
./main.py --task pre-train --model ViT-S/16 --epochs 300 --run_name <name_or_description_of_the_run> --experiment_name recombine_imagenet --lr 3e-3 (--local)
```

This will save a checkpoint (`.pt` file) every `<save_epochs>` epochs (the default is 10), which contains all the model weights, along with the optimizer and scheduler state and the current training stats.
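
As a rough sketch, a saved checkpoint can be inspected like this (the key names below are assumptions for illustration; see `main.py` for the exact checkpoint layout):

```python
import torch

# Load a training checkpoint on the CPU and list its contents.
ckpt = torch.load("checkpoint_epoch_290.pt", map_location="cpu")
print(ckpt.keys())  # e.g. model / optimizer / scheduler state and training stats
```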

#### Finetuning

A model (checkpoint) can be finetuned on another dataset using the following command:

```commandline
./main.py --task fine-tune --model <model_checkpoint.pt> --epochs 300 --run_name <name_or_description_of_the_run> --experiment_name recombine_imagenet --lr 3e-3 (--local)
```

This will also save new checkpoints during training.

### Evaluation

It is also possible to evaluate the models.
To evaluate the model's accuracy on a specific dataset, run

```commandline
./main.py -t eval -ds <dataset name> -m <model_checkpoint.pt> --ntasks 1 -bs 512 --num-workers 10 --cpus-per-task 10 --time 10:00 (--local)
```

You can run our center-bias, size-bias, and foreground-focus evaluations using the `eval-attr`, `eval-center-bias`, and `eval-size-bias` tasks (`-t` or `--task` argument).

### Further Arguments

The scripts accept several further arguments and flags.
The most important ones are

| Arg                             | Description                                            |
| :------------------------------ | :----------------------------------------------------- |
| `--model <model>`               | Model name or checkpoint.                              |
| `--run_name <name for the run>` | Name or description of this training run.              |
| `--dataset <dataset>`           | Specifies a dataset to use.                            |
| `--task <task>`                 | Specifies a task. The default is `pre-train`.          |
| `--local`                       | Run on the local machine, not on a slurm cluster.      |
| `--epochs <epochs>`             | Epochs to train.                                       |
| `--lr <lr>`                     | Learning rate. Default is 3e-3.                        |
| `--batch_size <bs>`             | Batch size. Default is 2048.                           |
| `--weight_decay <wd>`           | Weight decay. Default is 0.02.                         |
| `--imsize <image resolution>`   | Resolution of the image to train with. Default is 224. |

For a list of all arguments, run

```commandline
./main.py --help
```

## Supported Models

These are the models we support. Links point to the original code sources. If no link is provided, we implemented the architecture from scratch, following the respective paper.

| Architecture | Versions |
| :--- | :--- |
| [DeiT](https://github.com/facebookresearch/deit) | `deit_tiny_patch16_LS`, `deit_small_patch16_LS`, `deit_medium_patch16_LS`, `deit_base_patch16_LS`, `deit_large_patch16_LS`, `deit_huge_patch14_LS`, `deit_huge_patch14_52_LS`, `deit_huge_patch14_26x2_LS`, `deit_Giant_48_patch14_LS`, `deit_giant_40_patch14_LS`, `deit_small_patch16_36_LS`, `deit_small_patch16_36`, `deit_small_patch16_18x2_LS`, `deit_small_patch16_18x2`, `deit_base_patch16_18x2_LS`, `deit_base_patch16_18x2`, `deit_base_patch16_36x1_LS`, `deit_base_patch16_36x1` |
| [ResNet](https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/resnet.py) | `resnet18`, `resnet34`, `resnet26`, `resnet50`, `resnet101`, `wide_resnet50_2` |
| [Swin](https://github.com/microsoft/Swin-Transformer) | `swin_tiny_patch4_window7`, `swin_small_patch4_window7`, `swin_base_patch4_window7`, `swin_large_patch4_window7` |
| [ViT](https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py) | `ViT-{Ti,S,B,L}/<patch_size>` |

## License

We release this code under the [MIT license](./LICENSE).

## Citation

If you use this codebase in your project, please cite:
@@ -0,0 +1,3 @@
albumentations==2.0.5
datasets==3.5.0
nvidia-dali-cuda120==1.47.0
@@ -0,0 +1,238 @@
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
import torch
import torch.nn as nn
from functools import partial

from timm.models.vision_transformer import VisionTransformer, _cfg
from timm.models.registry import register_model
from timm.models.layers import trunc_normal_

from architectures.vit import TimmViT

__all__ = [
    "deit_tiny_patch16",
    "deit_small_patch16",
    "deit_base_patch16",
    "deit_tiny_distilled_patch16",
    "deit_small_distilled_patch16",
    "deit_base_distilled_patch16",
]


class DistilledVisionTransformer(TimmViT):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        num_patches = self.patch_embed.num_patches
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.embed_dim))
        self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()

        trunc_normal_(self.dist_token, std=0.02)
        trunc_normal_(self.pos_embed, std=0.02)
        self.head_dist.apply(self._init_weights)

    def forward_features(self, x):
        # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
        # with slight modifications to add the dist_token
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        dist_token = self.dist_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, dist_token, x), dim=1)

        x = x + self.pos_embed
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)
        return x[:, 0], x[:, 1]

    def forward(self, x):
        x, x_dist = self.forward_features(x)
        x = self.head(x)
        x_dist = self.head_dist(x_dist)
        if self.training:
            return x, x_dist
        else:
            # during inference, return the average of both classifier predictions
            return (x + x_dist) / 2


def _clean_kwargs(kwargs):
    allowed_keys = {key for key in kwargs.keys() if not key.startswith("pretrain")}
    allowed_keys = {key for key in allowed_keys if not key.startswith("cache")}
    return {key: kwargs[key] for key in allowed_keys}
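

# A small illustrative example of what _clean_kwargs drops (not part of the
# original file): keys starting with "pretrain" or "cache" are filtered out,
# so timm-style extras do not reach the model constructor.
#
#   _clean_kwargs({"pretrained_cfg": None, "cache_dir": "/tmp", "img_size": 224})
#   # -> {"img_size": 224}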


@register_model
def deit_tiny_patch16(pretrained=False, img_size=224, drop_path_rate=0.1, num_classes=1000, drop_rate=0.0, **kwargs):
    kwargs = _clean_kwargs(kwargs)
    model = TimmViT(
        patch_size=16,
        embed_dim=192,
        depth=12,
        num_heads=3,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        img_size=img_size,
        drop_path_rate=drop_path_rate,
        num_classes=num_classes,
        drop_rate=drop_rate,
    )
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth",
            map_location="cpu",
            check_hash=True,
        )
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def deit_small_patch16(pretrained=False, img_size=224, drop_path_rate=0.1, num_classes=1000, drop_rate=0.0, **kwargs):
    kwargs = _clean_kwargs(kwargs)
    model = TimmViT(
        patch_size=16,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        img_size=img_size,
        drop_path_rate=drop_path_rate,
        num_classes=num_classes,
        drop_rate=drop_rate,
    )
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth",
            map_location="cpu",
            check_hash=True,
        )
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def deit_base_patch16(pretrained=False, img_size=224, drop_path_rate=0.1, num_classes=1000, drop_rate=0.0, **kwargs):
    kwargs = _clean_kwargs(kwargs)
    model = TimmViT(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        img_size=img_size,
        drop_path_rate=drop_path_rate,
        num_classes=num_classes,
        drop_rate=drop_rate,
    )
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
            map_location="cpu",
            check_hash=True,
        )
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def deit_tiny_distilled_patch16(
    pretrained=False, img_size=224, drop_path_rate=0.1, num_classes=1000, drop_rate=0.0, **kwargs
):
    kwargs = _clean_kwargs(kwargs)
    model = DistilledVisionTransformer(
        patch_size=16,
        embed_dim=192,
        depth=12,
        num_heads=3,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        img_size=img_size,
        drop_path_rate=drop_path_rate,
        num_classes=num_classes,
        drop_rate=drop_rate,
    )
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth",
            map_location="cpu",
            check_hash=True,
        )
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def deit_small_distilled_patch16(
    pretrained=False, img_size=224, drop_path_rate=0.1, num_classes=1000, drop_rate=0.0, **kwargs
):
    kwargs = _clean_kwargs(kwargs)
    model = DistilledVisionTransformer(
        patch_size=16,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        img_size=img_size,
        drop_path_rate=drop_path_rate,
        num_classes=num_classes,
        drop_rate=drop_rate,
    )
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth",
            map_location="cpu",
            check_hash=True,
        )
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def deit_base_distilled_patch16(
    pretrained=False, img_size=224, drop_path_rate=0.1, num_classes=1000, drop_rate=0.0, **kwargs
):
    kwargs = _clean_kwargs(kwargs)
    model = DistilledVisionTransformer(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        img_size=img_size,
        drop_path_rate=drop_path_rate,
        num_classes=num_classes,
        drop_rate=drop_rate,
    )
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth",
            map_location="cpu",
            check_hash=True,
        )
        model.load_state_dict(checkpoint["model"])
    return model
@@ -0,0 +1,850 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Taken from https://github.com/facebookresearch/deit with slight modifications

from loguru import logger

import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial

from timm.models.vision_transformer import Mlp, PatchEmbed, _cfg

from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from timm.models.registry import register_model

from resizing_interface import ResizingInterface


class Attention(nn.Module):
    # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        q = q * self.scale

        attn = q @ k.transpose(-2, -1)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Block(nn.Module):
    # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        Attention_block=Attention,
        Mlp_block=Mlp,
        init_values=1e-4,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention_block(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp_block(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class Layer_scale_init_Block(nn.Module):
    # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
    # with slight modifications
    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        Attention_block=Attention,
        Mlp_block=Mlp,
        init_values=1e-4,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention_block(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp_block(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )
        self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
        self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)

    def forward(self, x):
        x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
        x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
        return x


class Layer_scale_init_Block_paralx2(nn.Module):
    # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
    # with slight modifications
    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        Attention_block=Attention,
        Mlp_block=Mlp,
        init_values=1e-4,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.norm11 = norm_layer(dim)
        self.attn = Attention_block(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.attn1 = Attention_block(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.norm21 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp_block(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )
        self.mlp1 = Mlp_block(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )
        self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
        self.gamma_1_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
        self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
        self.gamma_2_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)

    def forward(self, x):
        x = (
            x
            + self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
            + self.drop_path(self.gamma_1_1 * self.attn1(self.norm11(x)))
        )
        x = (
            x
            + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
            + self.drop_path(self.gamma_2_1 * self.mlp1(self.norm21(x)))
        )
        return x


class Block_paralx2(nn.Module):
    # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
    # with slight modifications
    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        Attention_block=Attention,
        Mlp_block=Mlp,
        init_values=1e-4,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.norm11 = norm_layer(dim)
        self.attn = Attention_block(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.attn1 = Attention_block(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.norm21 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp_block(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )
        self.mlp1 = Mlp_block(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x))) + self.drop_path(self.attn1(self.norm11(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x))) + self.drop_path(self.mlp1(self.norm21(x)))
        return x


class hMLP_stem(nn.Module):
    """hMLP_stem: https://arxiv.org/pdf/2203.09795.pdf
    taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
    with slight modifications
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        norm_layer=nn.SyncBatchNorm,
    ):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches
        self.proj = torch.nn.Sequential(
            *[
                nn.Conv2d(in_chans, embed_dim // 4, kernel_size=4, stride=4),
                norm_layer(embed_dim // 4),
                nn.GELU(),
                nn.Conv2d(embed_dim // 4, embed_dim // 4, kernel_size=2, stride=2),
                norm_layer(embed_dim // 4),
                nn.GELU(),
                nn.Conv2d(embed_dim // 4, embed_dim, kernel_size=2, stride=2),
                norm_layer(embed_dim),
            ]
        )

    def forward(self, x):
        B, C, H, W = x.shape
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x


class vit_models(nn.Module, ResizingInterface):
    """Vision Transformer with LayerScale (https://arxiv.org/abs/2103.17239) support
    taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
    with slight modifications
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        num_classes=1000,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        norm_layer=nn.LayerNorm,
        global_pool=None,
        block_layers=Block,
        Patch_layer=PatchEmbed,
        act_layer=nn.GELU,
        Attention_block=Attention,
        Mlp_block=Mlp,
        dpr_constant=True,
        init_scale=1e-4,
        mlp_ratio_clstk=4.0,
        **kwargs,
    ):
        super().__init__()

        self.dropout_rate = drop_rate

        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim

        self.img_size = img_size
        self.patch_embed = Patch_layer(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
        num_patches = self.patch_embed.num_patches
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.embed_layer = Patch_layer
        self.pre_norm = False

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
        self.no_embed_class = True

        dpr = [drop_path_rate for i in range(depth)]
        self.blocks = nn.ModuleList(
            [
                block_layers(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=0.0,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[i],
                    norm_layer=norm_layer,
                    act_layer=act_layer,
                    Attention_block=Attention_block,
                    Mlp_block=Mlp_block,
                    init_values=init_scale,
                )
                for i in range(depth)
            ]
        )

        self.norm = norm_layer(embed_dim)

        self.feature_info = [dict(num_chs=embed_dim, reduction=0, module="head")]
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        trunc_normal_(self.pos_embed, std=0.02)
        trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init_weights)

    def set_num_classes(self, n_classes):
        super().set_num_classes(n_classes)
        self._init_weights(self.head)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {"pos_embed", "cls_token"}

    def get_classifier(self):
        return self.head

    def get_num_layers(self):
        return len(self.blocks)

    def reset_classifier(self, num_classes, global_pool=""):
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x, test=False):
        B = x.shape[0]
        x = self.patch_embed(x)

        if test and x.isnan().any().item():
            logger.error("patch embedded input has a nan value")

        cls_tokens = self.cls_token.expand(B, -1, -1)

        x = x + self.pos_embed

        if test and x.isnan().any().item():
            logger.error("position embedded input has a nan value")

        x = torch.cat((cls_tokens, x), dim=1)

        if test and x.isnan().any().item():
            logger.error("input with [CLS] has a nan value")

        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if test and x.isnan().any().item():
                logger.error(f"output of block {i} has a nan value")

        x = self.norm(x)
        return x[:, 0]

    def forward(self, x, test=False):
        x = self.forward_features(x, test=test)

        if self.dropout_rate:
            x = F.dropout(x, p=float(self.dropout_rate), training=self.training)
        x = self.head(x)

        return x
|
||||
|
||||
|
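

# Minimal forward-pass sketch (illustrative only): a small vit_models instance
# maps a batch of images to logits through the [CLS] token; the shape check
# assumes the default 1000 classes.
def _demo_vit_models_forward():
    model = vit_models(img_size=224, patch_size=16, embed_dim=192, depth=2, num_heads=3)
    logits = model(torch.randn(2, 3, 224, 224))
    assert logits.shape == (2, 1000)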


# DeiT III: Revenge of the ViT (https://arxiv.org/abs/2204.07118)


@register_model
def deit_tiny_patch16_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    model = vit_models(
        img_size=img_size,
        patch_size=16,
        embed_dim=192,
        depth=12,
        num_heads=3,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
        **kwargs,
    )

    return model


@register_model
def deit_small_patch16_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    if "pretrained_cfg" in kwargs:
        kwargs.pop("pretrained_cfg")
    model = vit_models(
        img_size=img_size,
        patch_size=16,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
        **kwargs,
    )
    model.default_cfg = _cfg()
    if pretrained:
        name = "https://dl.fbaipublicfiles.com/deit/deit_3_small_" + str(img_size) + "_"
        if pretrained_21k:
            name += "21k.pth"
        else:
            name += "1k.pth"

        checkpoint = torch.hub.load_state_dict_from_url(url=name, map_location="cpu", check_hash=True)
        model.load_state_dict(checkpoint["model"])

    return model
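

# Usage sketch (illustrative): because these factories are registered with
# timm's @register_model, they can also be built via timm.create_model, and
# pretrained=True downloads the official DeiT III checkpoint, e.g.
#
#   from timm import create_model
#   model = create_model("deit_small_patch16_LS", pretrained=True, img_size=224)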


@register_model
def deit_medium_patch16_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    model = vit_models(
        patch_size=16,
        embed_dim=512,
        depth=12,
        num_heads=8,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
        **kwargs,
    )
    model.default_cfg = _cfg()
    if pretrained:
        name = "https://dl.fbaipublicfiles.com/deit/deit_3_medium_" + str(img_size) + "_"
        if pretrained_21k:
            name += "21k.pth"
        else:
            name += "1k.pth"

        checkpoint = torch.hub.load_state_dict_from_url(url=name, map_location="cpu", check_hash=True)
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def deit_base_patch16_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    model = vit_models(
        img_size=img_size,
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
        **kwargs,
    )
    if pretrained:
        name = "https://dl.fbaipublicfiles.com/deit/deit_3_base_" + str(img_size) + "_"
        if pretrained_21k:
            name += "21k.pth"
        else:
            name += "1k.pth"

        checkpoint = torch.hub.load_state_dict_from_url(url=name, map_location="cpu", check_hash=True)
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def deit_large_patch16_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    model = vit_models(
        img_size=img_size,
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
        **kwargs,
    )
    if pretrained:
        name = "https://dl.fbaipublicfiles.com/deit/deit_3_large_" + str(img_size) + "_"
        if pretrained_21k:
            name += "21k.pth"
        else:
            name += "1k.pth"

        checkpoint = torch.hub.load_state_dict_from_url(url=name, map_location="cpu", check_hash=True)
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def deit_huge_patch14_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    model = vit_models(
        img_size=img_size,
        patch_size=14,
        embed_dim=1280,
        depth=32,
        num_heads=16,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
        **kwargs,
    )
    if pretrained:
        name = "https://dl.fbaipublicfiles.com/deit/deit_3_huge_" + str(img_size) + "_"
        if pretrained_21k:
            name += "21k_v1.pth"
        else:
            name += "1k_v1.pth"

        checkpoint = torch.hub.load_state_dict_from_url(url=name, map_location="cpu", check_hash=True)
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def deit_huge_patch14_52_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    model = vit_models(
        img_size=img_size,
        patch_size=14,
        embed_dim=1280,
        depth=52,
        num_heads=16,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
        **kwargs,
    )

    return model


@register_model
def deit_huge_patch14_26x2_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    model = vit_models(
        img_size=img_size,
        patch_size=14,
        embed_dim=1280,
        depth=26,
        num_heads=16,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block_paralx2,
        **kwargs,
    )

    return model


# @register_model
# def deit_Giant_48x2_patch14_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
#     model = vit_models(
#         img_size=img_size, patch_size=14, embed_dim=1664, depth=48, num_heads=16, mlp_ratio=4,
#         norm_layer=partial(nn.LayerNorm, eps=1e-6), block_layers=Block_paral_LS, **kwargs)
#
#     return model


# @register_model
# def deit_giant_40x2_patch14_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
#     model = vit_models(
#         img_size=img_size, patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=4,
#         norm_layer=partial(nn.LayerNorm, eps=1e-6), block_layers=Block_paral_LS, **kwargs)
#     return model


@register_model
def deit_Giant_48_patch14_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    model = vit_models(
        img_size=img_size,
        patch_size=14,
        embed_dim=1664,
        depth=48,
        num_heads=16,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
        **kwargs,
    )
    return model


@register_model
def deit_giant_40_patch14_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    model = vit_models(
        img_size=img_size,
        patch_size=14,
        embed_dim=1408,
        depth=40,
        num_heads=16,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
        **kwargs,
    )
    # model.default_cfg = _cfg()

    return model


# Models from Three things everyone should know about Vision Transformers (https://arxiv.org/pdf/2203.09795.pdf)


@register_model
def deit_small_patch16_36_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    model = vit_models(
        img_size=img_size,
        patch_size=16,
        embed_dim=384,
        depth=36,
        num_heads=6,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
        **kwargs,
    )

    return model


@register_model
def deit_small_patch16_36(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    model = vit_models(
        img_size=img_size,
        patch_size=16,
        embed_dim=384,
        depth=36,
        num_heads=6,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )

    return model


@register_model
def deit_small_patch16_18x2_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    model = vit_models(
        img_size=img_size,
        patch_size=16,
        embed_dim=384,
        depth=18,
        num_heads=6,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block_paralx2,
        **kwargs,
    )

    return model


@register_model
def deit_small_patch16_18x2(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    model = vit_models(
        img_size=img_size,
        patch_size=16,
        embed_dim=384,
        depth=18,
        num_heads=6,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Block_paralx2,
        **kwargs,
    )

    return model


@register_model
def deit_base_patch16_18x2_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    model = vit_models(
        img_size=img_size,
        patch_size=16,
        embed_dim=768,
        depth=18,
        num_heads=12,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block_paralx2,
        **kwargs,
    )

    return model


@register_model
def deit_base_patch16_18x2(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    model = vit_models(
        img_size=img_size,
        patch_size=16,
        embed_dim=768,
        depth=18,
        num_heads=12,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Block_paralx2,
        **kwargs,
    )

    return model


@register_model
def deit_base_patch16_36x1_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    model = vit_models(
        img_size=img_size,
        patch_size=16,
        embed_dim=768,
        depth=36,
        num_heads=12,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
        **kwargs,
    )

    return model


@register_model
def deit_base_patch16_36x1(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    model = vit_models(
        img_size=img_size,
        patch_size=16,
        embed_dim=768,
        depth=36,
        num_heads=12,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )

    return model
@@ -0,0 +1,111 @@
from timm.models import register_model
from timm.models.resnet import BasicBlock, Bottleneck
from timm.models.resnet import ResNet as ResNetTimm
from torch import nn

from resizing_interface import ResizingInterface


class ResNet(ResNetTimm, ResizingInterface):
    """The popular ResNet model with a ResizingInterface."""

    def __init__(self, *args, global_pool="avg", **kwargs):
        """Create a ResNet model with resizing capabilities.

        Args:
            *args: Arguments for the ResNet model.
            global_pool (str, optional): Type of global pooling. Defaults to "avg".
            **kwargs: Keyword arguments for the ResNet model (from timm).

        Keyword Args:
            block, layers, num_classes, in_chans, output_stride, cardinality, base_width, stem_width, stem_type, replace_stem_pool, block_reduce_first, down_kernel_size, avg_down, act_layer, norm_layer, aa_layer, drop_rate, drop_path_rate, drop_block_rate, zero_init_last, block_args

        """
        admissible_kwargs = [
            "block",
            "layers",
            "num_classes",
            "in_chans",
            "output_stride",
            "cardinality",
            "base_width",
            "stem_width",
            "stem_type",
            "replace_stem_pool",
            "block_reduce_first",
            "down_kernel_size",
            "avg_down",
            "act_layer",
            "norm_layer",
            "aa_layer",
            "drop_rate",
            "drop_path_rate",
            "drop_block_rate",
            "zero_init_last",
            "block_args",
        ]
        for key in list(kwargs.keys()):
            if key not in admissible_kwargs:
                kwargs.pop(key)
        super().__init__(*args, global_pool=global_pool, **kwargs)
        self.global_pool_str = global_pool

    def set_image_res(self, res):
        # resizing is not needed for CNNs with global pooling
        return

    def set_num_classes(self, n_classes):
        if self.num_classes == n_classes:
            return
        if n_classes > 0:
            self.reset_classifier(num_classes=n_classes, global_pool=self.global_pool_str)
        else:
            self.fc = nn.Identity()


@register_model
def resnet18(pretrained=False, **kwargs):
    """Construct a ResNet-18 model."""
    model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2])
    return ResNet(**model_args, **kwargs)


@register_model
def resnet34(pretrained=False, **kwargs):
    """Construct a ResNet-34 model."""
    model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3])
    return ResNet(**model_args, **kwargs)


@register_model
def resnet26(pretrained=False, **kwargs):
    """Construct a ResNet-26 model."""
    model_args = dict(block=Bottleneck, layers=[2, 2, 2, 2])
    return ResNet(**model_args, **kwargs)


@register_model
def resnet50(pretrained=False, **kwargs):
    """Construct a ResNet-50 model."""
    model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3])
    return ResNet(**model_args, **kwargs)


@register_model
def resnet101(pretrained=False, **kwargs):
    """Construct a ResNet-101 model."""
    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3])
    return ResNet(**model_args, **kwargs)


@register_model
def wide_resnet50_2(pretrained=False, **kwargs):
    """Construct a Wide ResNet-50-2 model.

    The model is the same as ResNet except for the bottleneck number of channels,
    which is twice as large in every block. The number of channels in the outer 1x1
    convolutions is the same, e.g. the last block in ResNet-50 has 2048-512-2048
    channels, while in Wide ResNet-50-2 it has 2048-1024-2048.
    """
    model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], base_width=128)
    return ResNet(**model_args, **kwargs)
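

# Usage sketch (illustrative): these entrypoints shadow timm's built-in ResNet
# names, so timm.create_model returns this resizable subclass once the module
# is imported, e.g.
#
#   import timm
#   model = timm.create_model("resnet50", num_classes=100)
#   model.set_num_classes(10)   # swaps in a fresh 10-way fc head
#   model.set_image_res(160)    # no-op: pooling makes CNNs resolution-agnostic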
@@ -0,0 +1,893 @@
# Taken from https://github.com/microsoft/Swin-Transformer with slight modifications

# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu
# --------------------------------------------------------
from copy import copy

import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from loguru import logger
from timm.models import register_model
from timm.models.layers import DropPath, to_2tuple, trunc_normal_

from resizing_interface import ResizingInterface

try:
    import os
    import sys

    kernel_path = os.path.abspath(os.path.join(".."))
    sys.path.append(kernel_path)
    from kernels.window_process.window_process import (
        WindowProcess,
        WindowProcessReverse,
    )

except ImportError:
    WindowProcess = None
    WindowProcessReverse = None
    logger.warning("Fused window process has not been installed. Please refer to get_started.md for installation.")


class Mlp(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x
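

def _demo_window_roundtrip():
    # Illustrative check (not used by the model code): window_partition and
    # window_reverse are exact inverses, so a feature map survives a round trip.
    x = torch.randn(2, 14, 14, 32)    # B, H, W, C with H and W divisible by 7
    windows = window_partition(x, 7)  # -> (2 * 2 * 2, 7, 7, 32)
    assert windows.shape == (8, 7, 7, 32)
    assert torch.equal(window_reverse(windows, 7, 14, 14), x)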


class WindowAttention(nn.Module):
    r"""Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both shifted and non-shifted windows.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(
        self,
        dim,
        window_size,
        num_heads,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
        )  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=0.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = (
            qkv[0],
            qkv[1],
            qkv[2],
        )  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = q @ k.transpose(-2, -1)

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1],
            self.window_size[0] * self.window_size[1],
            -1,
        )  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}"

    def flops(self, N):
        # calculate flops for 1 window with token length of N
        flops = 0
        # qkv = self.qkv(x)
        flops += N * self.dim * 3 * self.dim
        # attn = (q @ k.transpose(-2, -1))
        flops += self.num_heads * N * (self.dim // self.num_heads) * N
        # x = (attn @ v)
        flops += self.num_heads * N * N * (self.dim // self.num_heads)
        # x = self.proj(x)
        flops += N * self.dim * self.dim
        return flops
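

def _demo_relative_position_index():
    # Illustrative sketch (not used by the model code): for a 2x2 window the
    # relative offsets range over [-1, 1] on each axis, giving 3 * 3 = 9
    # distinct biases; each of the 4 x 4 token pairs indexes one of them.
    attn = WindowAttention(dim=8, window_size=(2, 2), num_heads=2)
    assert attn.relative_position_bias_table.shape == (9, 2)
    assert attn.relative_position_index.shape == (4, 4)
    assert int(attn.relative_position_index.max()) == 8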


class SwinTransformerBlock(nn.Module):
    r"""Swin Transformer Block.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        fused_window_process (bool, optional): If True, use one kernel to fuse window shift & window partition for acceleration, similar for the reversed part. Default: False
    """

    def __init__(
        self,
        dim,
        input_resolution,
        num_heads,
        window_size=7,
        shift_size=0,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        fused_window_process=False,
    ):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim,
            window_size=to_2tuple(self.window_size),
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )

        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

        if self.shift_size > 0:
            # calculate attention mask for SW-MSA
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
            h_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None),
            )
            w_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None),
            )
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None

        self.register_buffer("attn_mask", attn_mask)
        self.fused_window_process = fused_window_process

    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # cyclic shift
        if self.shift_size > 0:
            if not self.fused_window_process:
                shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
                # partition windows
                x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
            else:
                x_windows = WindowProcess.apply(x, B, H, W, C, -self.shift_size, self.window_size)
        else:
            shifted_x = x
            # partition windows
            x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C

        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)

        # reverse cyclic shift
        if self.shift_size > 0:
            if not self.fused_window_process:
                shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C
                x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
            else:
                x = WindowProcessReverse.apply(attn_windows, B, H, W, C, self.shift_size, self.window_size)
        else:
            shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C
            x = shifted_x
        x = x.view(B, H * W, C)
        x = shortcut + self.drop_path(x)

        # FFN
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x

    def extra_repr(self) -> str:
        return (
            f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, "
            f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
        )

    def flops(self):
        flops = 0
        H, W = self.input_resolution
        # norm1
        flops += self.dim * H * W
        # W-MSA/SW-MSA
        nW = H * W / self.window_size / self.window_size
        flops += nW * self.attn.flops(self.window_size * self.window_size)
        # mlp
        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
        # norm2
        flops += self.dim * H * W
        return flops
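

def _demo_shifted_window_mask():
    # Illustrative sketch (not used by the model code): with an 8x8 token grid,
    # window_size 4 and shift_size 2, the cyclic shift mixes tokens from
    # different image regions into the same window; the -100 entries in
    # attn_mask disable exactly those cross-region attention logits.
    blk = SwinTransformerBlock(dim=16, input_resolution=(8, 8), num_heads=2, window_size=4, shift_size=2)
    assert blk.attn_mask.shape == (4, 16, 16)  # nW, window_size**2, window_size**2
    assert (blk.attn_mask == -100.0).any()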


class PatchMerging(nn.Module):
    r"""Patch Merging Layer.

    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."

        x = x.view(B, H, W, C)

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)

        return x

    def extra_repr(self) -> str:
        return f"input_resolution={self.input_resolution}, dim={self.dim}"

    def flops(self):
        H, W = self.input_resolution
        flops = H * W * self.dim
        flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
        return flops
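

def _demo_patch_merging():
    # Illustrative sketch (not used by the model code): each 2x2 neighborhood
    # (4C channels) is concatenated and projected to 2C, halving the spatial
    # resolution: 56x56 tokens at C=96 become 28x28 tokens at C=192.
    merge = PatchMerging(input_resolution=(56, 56), dim=96)
    out = merge(torch.randn(2, 56 * 56, 96))
    assert out.shape == (2, 28 * 28, 192)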


class BasicLayer(nn.Module):
    """A basic Swin Transformer layer for one stage.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
        fused_window_process (bool, optional): If True, use one kernel to fuse window shift & window partition for acceleration, similar for the reversed part. Default: False
    """

    def __init__(
        self,
        dim,
        input_resolution,
        depth,
        num_heads,
        window_size,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        norm_layer=nn.LayerNorm,
        downsample=None,
        use_checkpoint=False,
        fused_window_process=False,
    ):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks
        self.blocks = nn.ModuleList(
            [
                SwinTransformerBlock(
                    dim=dim,
                    input_resolution=input_resolution,
                    num_heads=num_heads,
                    window_size=window_size,
                    shift_size=0 if (i % 2 == 0) else window_size // 2,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop,
                    attn_drop=attn_drop,
                    drop_path=(drop_path[i] if isinstance(drop_path, list) else drop_path),
                    norm_layer=norm_layer,
                    fused_window_process=fused_window_process,
                )
                for i in range(depth)
            ]
        )

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x):
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)
        if self.downsample is not None:
            x = self.downsample(x)
        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"

    def flops(self):
        flops = 0
        for blk in self.blocks:
            flops += blk.flops()
        if self.downsample is not None:
            flops += self.downsample.flops()
        return flops


class PatchEmbed(nn.Module):
    r"""Image to Patch Embedding

    Args:
        img_size (int): Image size. Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [
            img_size[0] // patch_size[0],
            img_size[1] // patch_size[1],
        ]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert (
            H == self.img_size[0] and W == self.img_size[1]
        ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
        if self.norm is not None:
            x = self.norm(x)
        return x

    def flops(self):
        Ho, Wo = self.patches_resolution
        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        if self.norm is not None:
            flops += Ho * Wo * self.embed_dim
        return flops


class SwinTransformer(nn.Module, ResizingInterface):
    r"""Swin Transformer
    A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
    https://arxiv.org/pdf/2103.14030

    Args:
        img_size (int | tuple(int)): Input image size. Default 224
        patch_size (int | tuple(int)): Patch size. Default: 4
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 1000
        embed_dim (int): Patch embedding dimension. Default: 96
        depths (tuple(int)): Depth of each Swin Transformer layer.
        num_heads (tuple(int)): Number of attention heads in different layers.
        window_size (int): Window size. Default: 7
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
        drop_rate (float): Dropout rate. Default: 0
        attn_drop_rate (float): Attention dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
        patch_norm (bool): If True, add normalization after patch embedding. Default: True
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
        fused_window_process (bool, optional): If True, use one kernel to fuse window shift & window partition for acceleration, similar for the reversed part. Default: False
    """

    def __init__(
        self,
        img_size=224,
        patch_size=4,
        in_chans=3,
        num_classes=1000,
        embed_dim=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        window_size=7,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.1,
        norm_layer=nn.LayerNorm,
        ape=False,
        patch_norm=True,
        use_checkpoint=False,
        fused_window_process=False,
        **kwargs,
    ):
        super().__init__()

        self.num_classes = num_classes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        self.mlp_ratio = mlp_ratio
        self.img_size = img_size
        self.fused_window_process = fused_window_process

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None,
        )
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution
        self.embed_layer = PatchEmbed
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.norm_layer = norm_layer

        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=0.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(
                dim=int(embed_dim * 2**i_layer),
                input_resolution=(
                    patches_resolution[0] // (2**i_layer),
                    patches_resolution[1] // (2**i_layer),
                ),
                depth=depths[i_layer],
                num_heads=num_heads[i_layer],
                window_size=window_size,
                mlp_ratio=self.mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
                norm_layer=norm_layer,
                downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                use_checkpoint=use_checkpoint,
                fused_window_process=fused_window_process,
            )
            self.layers.append(layer)

        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        self.apply(self._init_weights)

    def set_num_classes(self, n_classes):
        """Reset the classification head with a new number of classes.

        Parameters
        ----------
        n_classes : int
            new number of classes
        """
        if n_classes == self.num_classes:
            return
        self.head = nn.Linear(self.num_features, n_classes) if n_classes > 0 else nn.Identity()
        self.num_classes = n_classes

        if n_classes > 0:  # only a real linear head has weights to (re-)initialize
            nn.init.trunc_normal_(self.head.weight, std=0.02)
            nn.init.constant_(self.head.bias, 0)

    def set_image_res(self, res):
        if res == self.img_size:
            return

        old_patch_embed_state = copy(self.patch_embed.state_dict())
        self.patch_embed = self.embed_layer(
            img_size=res,
            patch_size=self.patch_size,
            in_chans=self.in_chans,
            embed_dim=self.embed_dim,
            norm_layer=self.norm_layer if self.patch_norm else None,
        )
        self.patch_embed.load_state_dict(old_patch_embed_state)
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        for i_layer, layer in enumerate(self.layers):
            input_resolution = (
                patches_resolution[0] // (2**i_layer),
                patches_resolution[1] // (2**i_layer),
            )
            layer.input_resolution = input_resolution
            downsample = PatchMerging if (i_layer < self.num_layers - 1) else None
            if downsample is not None:
                layer.downsample = downsample(input_resolution, dim=layer.dim, norm_layer=self.norm_layer)

            for block in layer.blocks:
                block.input_resolution = input_resolution

                if min(input_resolution) <= block.window_size:
                    # if window size is larger than input resolution, we don't partition windows
                    block.shift_size = 0
                    block.window_size = min(block.input_resolution)
                assert 0 <= block.shift_size < block.window_size, "shift_size must be in [0, window_size)"

                if block.shift_size > 0:
                    # calculate attention mask for SW-MSA
                    H, W = block.input_resolution
                    img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
                    h_slices = (
                        slice(0, -block.window_size),
                        slice(-block.window_size, -block.shift_size),
                        slice(-block.shift_size, None),
                    )
                    w_slices = (
                        slice(0, -block.window_size),
                        slice(-block.window_size, -block.shift_size),
                        slice(-block.shift_size, None),
                    )
                    cnt = 0
                    for h in h_slices:
                        for w in w_slices:
                            img_mask[:, h, w, :] = cnt
                            cnt += 1

                    mask_windows = window_partition(img_mask, block.window_size)  # nW, window_size, window_size, 1
                    mask_windows = mask_windows.view(-1, block.window_size * block.window_size)
                    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
                    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
                        attn_mask == 0, float(0.0)
                    )
                else:
                    attn_mask = None

                block.register_buffer("attn_mask", attn_mask)

        if self.ape:
            orig_size = int((self.absolute_pos_embed.shape[-2]) ** 0.5)
            new_size = int(self.patch_embed.num_patches**0.5)
            pos_tokens = self.absolute_pos_embed[:, :]
            # make it shape rest x embed_dim x orig_size x orig_size
            pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, self.embed_dim).permute(0, 3, 1, 2)
            pos_tokens = nn.functional.interpolate(
                pos_tokens,
                size=(new_size, new_size),
                mode="bicubic",
                align_corners=False,
            )
            # make it shape rest x new_size^2 x embed_dim
            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
            self.absolute_pos_embed = nn.Parameter(pos_tokens.contiguous())

        self.img_size = res

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {"absolute_pos_embed"}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {"relative_position_bias_table"}

    def forward_features(self, x):
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        for layer in self.layers:
            x = layer(x)

        x = self.norm(x)  # B L C
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.avgpool(x.transpose(1, 2))  # B C 1
        x = torch.flatten(x, 1)
        x = self.head(x)
        return x

    def flops(self):
        flops = 0
        flops += self.patch_embed.flops()
        for i, layer in enumerate(self.layers):
            flops += layer.flops()
        flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2**self.num_layers)
        flops += self.num_features * self.num_classes
        return flops
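

# Usage sketch (illustrative): set_image_res rebuilds the patch embedding and
# all per-stage attention masks for a new resolution while reusing the trained
# projection weights, e.g.
#
#   model = swin_tiny_patch4_window7(img_size=224)
#   model.set_image_res(448)   # model now expects 448x448 inputs
#   logits = model(torch.randn(1, 3, 448, 448))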


swin_sizes = {
    "Ti": dict(embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24]),
    "S": dict(embed_dim=96, depths=[2, 2, 18, 2], num_heads=[3, 6, 12, 24]),
    "B": dict(embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32]),
    "L": dict(embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48]),
}


@register_model
def swin_tiny_patch4_window7(pretrained=False, img_size=224, **kwargs):
    if "pretrained_cfg" in kwargs:
        kwargs.pop("pretrained_cfg")
    size = swin_sizes["Ti"]
    model = SwinTransformer(img_size=img_size, patch_size=4, window_size=7, **size, **kwargs)
    return model


@register_model
def swin_small_patch4_window7(pretrained=False, img_size=224, **kwargs):
    if "pretrained_cfg" in kwargs:
        kwargs.pop("pretrained_cfg")
    size = swin_sizes["S"]
    model = SwinTransformer(img_size=img_size, patch_size=4, window_size=7, **size, **kwargs)
    return model


@register_model
def swin_base_patch4_window7(pretrained=False, img_size=224, **kwargs):
    if "pretrained_cfg" in kwargs:
        kwargs.pop("pretrained_cfg")
    size = swin_sizes["B"]
    model = SwinTransformer(img_size=img_size, patch_size=4, window_size=7, **size, **kwargs)
    return model


@register_model
def swin_large_patch4_window7(pretrained=False, img_size=224, **kwargs):
    if "pretrained_cfg" in kwargs:
        kwargs.pop("pretrained_cfg")
    size = swin_sizes["L"]
    model = SwinTransformer(img_size=img_size, patch_size=4, window_size=7, **size, **kwargs)
    return model


@register_model
def swin_ashim(pretrained=False, img_size=112, **kwargs):
    if "pretrained_cfg" in kwargs:
        kwargs.pop("pretrained_cfg")
    size = dict(embed_dim=384, depths=[12], num_heads=[12])
    if "num_heads" in kwargs:
        kwargs["num_heads"] = [kwargs["num_heads"]]
    return SwinTransformer(img_size=img_size, in_chans=3, patch_size=2, window_size=7, **{**size, **kwargs})
@@ -0,0 +1,225 @@
from functools import partial

import torch
from loguru import logger
from timm.models.vision_transformer import (
    Attention,
    Block,
    PatchEmbed,
    VisionTransformer,
)

from resizing_interface import ResizingInterface


class _MatrixSaveAttn(Attention):
    attn_mat = None

    @classmethod
    def cast(cls, attn: Attention):
        assert isinstance(attn, Attention), "Can only save attention from timm's Attention class"
        attn.__class__ = cls
        assert isinstance(attn, _MatrixSaveAttn)
        return attn

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        self.attn_mat = attn.detach()
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        return self.proj_drop(x)


def _index_picker(tensor, idx=-1):
    """Pick a specific index from a tensor.

    tensor: B x N x D -> B x D, by picking idx from N.

    Args:
        tensor (torch.Tensor): tensor to pick from.
        idx (int, optional): index to pick. Defaults to -1.

    Returns:
        torch.Tensor: the selected slice of the tensor.

    """
    return tensor[..., idx, :]  # B x N x D -> B x D
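

def _demo_index_picker():
    # Illustrative sketch (not used by the model code): picking one index from
    # the token axis reduces B x N x D to B x D, which is how an integer
    # global_pool selects a single token as the pooled representation below.
    tokens = torch.randn(2, 5, 8)
    assert _index_picker(tokens, idx=0).shape == (2, 8)
    assert torch.equal(_index_picker(tokens, idx=-1), tokens[:, -1, :])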


class TimmViT(VisionTransformer, ResizingInterface):
    """Wrapper for *VisionTransformer* from the *timm* library (https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py)."""

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        num_classes=1000,
        global_pool="token",
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_norm=False,
        init_values=None,
        class_token=True,
        no_embed_class=True,
        pre_norm=False,
        fc_norm=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        weight_init="",
        embed_layer=PatchEmbed,
        norm_layer=None,
        act_layer=None,
        block_fn=Block,
        save_attention_maps=False,
        fused_attn=True,
        **kwargs,
    ):
        """Create a Vision Transformer model.

        Parameters
        ----------
        img_size : int
            image dimensions -> img_size x img_size
        patch_size : int
            patch size
        in_chans : int
            number of image channels
        num_classes : int
            number of classes for the classification head
        global_pool : str, int
            type of global pooling for the final sequence (default: 'token'), or the index of the token to be taken
        embed_dim : int
            embedding dimension
        depth : int
            number of transformer layers
        num_heads : int
            number of transformer heads
        mlp_ratio : float
            ratio of feed forward (mlp) hidden dimension to embedding dimension
        qkv_bias : bool
            enable bias for query, key, and value (qkv) embeddings
        qk_norm : bool
            normalize query and key embeddings
        init_values : float
            layer scale initial values
        class_token : bool
            use a class token [CLS]
        no_embed_class : bool
            no positional embedding for the class token
        pre_norm : bool
            use pre-norm architecture (norm before the blocks, not after)
        fc_norm : bool
            norm after pool (used when global_pool == 'avg')
        drop_rate : float
            dropout rate
        attn_drop_rate : float
            dropout rate in the attention module
        drop_path_rate : float
            drop path rate (stochastic depth)
        weight_init : str
            scheme for weight initialization
        embed_layer : nn.Module
            patch embedding layer
        norm_layer : nn.Module
            normalization layer
        act_layer : nn.Module
            activation function
        block_fn : nn.Module
            which block structure to use; for parallel attention layers, ...
        save_attention_maps : bool
            save attention maps for each block
        fused_attn : bool
            use fused attention
        kwargs : dict
            additional arguments (will be ignored)

        """
        init_kwargs = dict(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            num_classes=num_classes,
            global_pool=global_pool if isinstance(global_pool, str) else "avg",
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            init_values=init_values,
            class_token=class_token,
            no_embed_class=no_embed_class,
            pre_norm=pre_norm,
            fc_norm=fc_norm,
            drop_rate=drop_rate,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=drop_path_rate,
            weight_init=weight_init,
            embed_layer=embed_layer,
            norm_layer=norm_layer,
            act_layer=act_layer,
            block_fn=block_fn,
        )

        self.new_version = False  # TODO: check based on the timm version
        self.global_pool = global_pool
        if isinstance(global_pool, int):
            self.attn_pool = partial(_index_picker, idx=global_pool)

        if self.new_version:
            init_kwargs["qk_norm"] = qk_norm
            init_kwargs["proj_drop_rate"] = drop_rate
        super(TimmViT, self).__init__(**init_kwargs)
        self.embed_layer = embed_layer
        self.img_size = img_size
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.pre_norm = pre_norm
        self.class_token = class_token
        self.no_embed_class = no_embed_class
        self.num_classes = num_classes
        self.num_heads = num_heads
        self.depth = depth
        self.save_attention_maps = save_attention_maps
        if save_attention_maps:
            self.do_save_attention_maps()
        try:
            for block in self.blocks:
                block.attn.fused_attn = fused_attn and self.blocks[0].attn.fused_attn
            use_fused = self.blocks[0].attn.fused_attn
            logger.info(f"Use fused attention: {use_fused}")
        except:  # I'm lazy for now # noqa: E722
            pass

    def do_save_attention_maps(self):
        self.save_attention_maps = True
        for block in self.blocks:
            block.attn = _MatrixSaveAttn.cast(block.attn)

    def attention_maps(self):
        assert self.save_attention_maps, "Have to save attention maps first"
        return [getattr(block.attn, "attn_mat", None) for block in self.blocks]

    def forward_features(self, x):
        x = self.patch_embed(x)
        x = self._pos_embed(x)
        if self.new_version:
            x = self.patch_drop(x)
        x = self.norm_pre(x)
        x = self.blocks(x)
        return self.norm(x)

    def forward(self, x):
        x = self.forward_features(x)
        return self.forward_head(x)
||||
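A minimal construction sketch for the wrapper above (argument values are illustrative; only the class and method names defined in this file are assumed):

    # assuming this module is importable and timm is installed
    model = TimmViT(
        img_size=224, patch_size=16, in_chans=3, num_classes=1000,
        embed_dim=384, depth=12, num_heads=6, save_attention_maps=True,
    )
    # logits = model(images)        # after a forward pass ...
    # maps = model.attention_maps() # ... per-block attention matrices can be read out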
118
AAAI Supplementary Material/Model Training Code/config.py
Normal file
@@ -0,0 +1,118 @@
"""Default configuration parameters.

Attributes:
    default_kwargs (dict): Default hyperparameters for the training process.
    slurm_defaults (dict): Default values for SLURM batch job settings.

"""

from paths_config import user

default_kwargs = {
    "amp": True,
    "aug_color_jitter_factor": 0.3,
    "aug_crop": True,
    "aug_cutmix_alpha": 1.0,
    "aug_flip": True,
    "aug_gauss_blur": True,
    "aug_grayscale": True,
    "aug_mixup_alpha": 0.0,
    "aug_normalize": True,
    "aug_rand_rot": 0,
    "aug_random_erase_count": 1,
    "aug_random_erase_mode": "pixel",
    "aug_random_erase_prob": 0.0,
    "aug_repeated_augment_repeats": 1,
    "aug_resize": True,
    "aug_solarize": True,
    "augment_engine": "torchvision",
    "augment_strategy": "3-augment",
    "auto_augment_strategy": "rand-m9-mstd0.5-inc1",
    "batch_size": 2048,
    "compile_model": False,
    "cuda": True,
    "custom_dataset_path": None,
    "debug": False,
    "drop_path_rate": 0.05,
    "dropout": 0.0,
    "eval_amp": True,
    "experiment_name": "none",
    "fused_attn": True,
    "gather_stats_during_training": True,
    "imsize": 224,
    "input_dim": None,
    "keep_interm_states": 2,
    "label_smoothing": 0.1,
    "layer_scale": True,
    "layer_scale_init_values": 1e-4,
    "log_level": "info",
    "loss": "ce",
    "loss_weight": "none",
    "lr": 3e-3,
    "max_grad_norm": 1.0,
    "max_seq_len": None,
    "min_lr": 1e-5,
    "momentum": 0.0,
    "num_heads": None,
    "num_workers": 44,
    "opt": "fusedlamb",
    "opt_eps": 1e-7,
    "pin_memory": False,
    "pre_norm": False,
    "prefetch_factor": 2,
    "qkv_bias": True,
    "run_name": None,
    "save_epochs": 10,
    "sched": "cosine",
    "seed": None,
    "shuffle": True,
    "tqdm": True,
    "wandb": True,
    "warmup_epochs": 5,
    "warmup_lr": 1e-6,
    "warmup_sched": "linear",
    "weight_decay": 0.02,
    "weighted_sampler": False,
}
# , 'model_ema': True, 'model_ema_decay': 0.99996}


deit_kwargs = {
    "aug_mixup_alpha": 0.8,
    "aug_repeated_augment_repeats": 3,
    "augment_strategy": "deit",
    "aug_random_erase_prob": 0.25,
    "batch_size": 1024,
    "lr": 1e-3,
    "max_grad_norm": 0.0,
    "num_workers": 10,
    "opt": "adamw",
    "opt_eps": 1e-8,
    "weight_decay": 0.05,
}


def get_default_kwargs(settings="deitiii"):
    """Return the default hyperparameters for the given settings name."""
    if settings.lower() == "deitiii":
        return default_kwargs
    if settings.lower() == "deit":
        return {**default_kwargs, **deit_kwargs}
    raise NotImplementedError(f"No such defaults setting: {settings}")


slurm_defaults = {
    "after_job": None,
    "container_image": "PATH/TO/ENROOT/IMAGE",
    "container_mounts": 'MOUNT_ALL_IMPORTANT_STORAGE_SERVERS_HERE,"`pwd`":"`pwd`"',
    "container_workdir": '"`pwd`"',
    "cpus_per_task": 24,
    "exclude": None,
    "export": "ALL,TQDM_DISABLE=1",
    "job_name": None,
    "mem_per_gpu": 90,
    "nodes": 1,
    "ntasks": 4,
    "partition": ["A100", "H100", "H200"],
    "task_prolog": None,
    "time": "1-0",
}
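A minimal usage sketch for these defaults (assuming the file is importable as `config`; the override values are illustrative):

    from config import get_default_kwargs

    kwargs = get_default_kwargs("deit")             # DeiT recipe layered over the DeiT-III base
    kwargs.update({"batch_size": 256, "lr": 5e-4})  # hypothetical per-run overrides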
@@ -0,0 +1,142 @@
import albumentations as A
import cv2
import numpy as np
from albumentations.pytorch import ToTensorV2
from datadings.torch import CompressedToPIL


class AlbumTorchCompose(A.Compose):
    """Compose albumentation augmentations in a way that works with PIL images and datadings."""

    def __init__(self, *args, **kwargs):
        """Pass to A.Compose."""
        super().__init__(*args, **kwargs)
        self.to_pil = CompressedToPIL()

    def __call__(self, image, mask=None, **kwargs):
        if isinstance(image, bytes):
            image = self.to_pil(image)
        if mask is not None and len(mask) == 0:
            mask = None
        if not isinstance(image, np.ndarray):
            image = np.array(image)
        if mask is not None and not isinstance(mask, np.ndarray):
            mask = np.array(mask)
        if mask is None:
            return super().__call__(image=image, **kwargs)["image"]
        return super().__call__(image=image, mask=mask, **kwargs)


class PILToNP(A.DualTransform):
    """Convert PIL image to numpy array."""

    def apply(self, image, **params):
        return np.array(image)

    def apply_to_mask(self, image, **params):
        return np.array(image)

    def get_transform_init_args_names(self):
        return ()


class AlbumCompressedToPIL(A.DualTransform):
    """Convert compressed image to PIL image."""

    def __init__(self, *args, **kwargs):
        """Pass to A.DualTransform."""
        super().__init__(*args, **kwargs)
        self.to_pil = CompressedToPIL()  # decodes compressed bytes to PIL; used by apply/apply_to_mask

    def apply(self, img, **params):
        return self.to_pil(img)

    def apply_to_mask(self, img, **params):
        return self.to_pil(img)

    def get_transform_init_args_names(self):
        return ()


def minimal_augment(args, test=False):
    """Get minimal augmentations for training or testing.

    Args:
        args (argparse.Namespace): arguments
        test (bool, optional): if True, return test augmentations. Defaults to False.

    Returns:
        List: Augmentation list
    """
    augs = []

    if args.aug_resize:
        augs.append(A.SmallestMaxSize(args.imsize, interpolation=cv2.INTER_CUBIC))

    if test and args.aug_crop:
        augs.append(A.CenterCrop(args.imsize, args.imsize))
    elif args.aug_crop:
        augs.append(A.RandomCrop(args.imsize, args.imsize))

    if not test and args.aug_flip:
        augs.append(A.HorizontalFlip(p=0.5))

    if args.aug_normalize:
        augs.append(A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)))

    augs.append(ToTensorV2())
    return augs


def three_augment(args, as_list=False, test=False):
    """Create the data augmentation.

    Args:
        args (Namespace): arguments
        as_list (bool): return list of transformations, not composed transformation
        test (bool): In eval mode? If False => train mode

    Returns:
        torch.nn.Module | list[torch.nn.Module]: composed transformation or list of transformations

    """
    augs = []

    if args.aug_resize:
        augs.append(A.SmallestMaxSize(args.imsize, interpolation=cv2.INTER_CUBIC))

    if test and args.aug_crop:
        augs.append(A.CenterCrop(args.imsize, args.imsize))
    elif args.aug_crop:
        augs.append(A.RandomCrop(args.imsize, args.imsize, pad_if_needed=True, border_mode=cv2.BORDER_REFLECT))

    if not test:
        if args.aug_flip:
            augs.append(A.HorizontalFlip(p=0.5))

        augs_choice = []
        if args.aug_grayscale:
            augs_choice.append(A.ToGray(p=1, num_output_channels=3))

        if args.aug_solarize:
            augs_choice.append(A.Solarize(p=1, threshold_range=(0.5, 0.5)))

        if args.aug_gauss_blur:
            augs_choice.append(A.GaussianBlur(p=1, sigma_limit=(0.2, 2.0), blur_limit=(7, 7)))

        if len(augs_choice) > 0:
            augs.append(A.OneOf(augs_choice))

        if args.aug_color_jitter_factor > 0.0:
            augs.append(
                A.ColorJitter(
                    brightness=args.aug_color_jitter_factor,
                    contrast=args.aug_color_jitter_factor,
                    saturation=args.aug_color_jitter_factor,
                    hue=0.0,
                )
            )

    if args.aug_normalize:
        augs.append(A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)))

    augs.append(ToTensorV2())

    if as_list:
        return augs
    return AlbumTorchCompose(augs)
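A usage sketch for the composed pipeline above (the namespace fields mirror the aug_* defaults from config.py; the values here are illustrative):

    from types import SimpleNamespace

    args = SimpleNamespace(
        imsize=224, aug_resize=True, aug_crop=True, aug_flip=True,
        aug_grayscale=True, aug_solarize=True, aug_gauss_blur=True,
        aug_color_jitter_factor=0.3, aug_normalize=True,
    )
    train_tf = three_augment(args)  # AlbumTorchCompose; accepts PIL images or datadings bytes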
@@ -0,0 +1,63 @@
import os

from PIL import Image
from torch.utils.data import Dataset


class CounterAnimal(Dataset):
    """Dataset to load the CounterAnimal dataset with ImageNet labels."""

    def __init__(self, base_path, mode, transform=None, target_transform=None, train=False):
        """Create the dataset.

        Args:
            base_path (str): path to the base folder (the one the class folders are in)
            mode (str): mode/variant of the dataset (common/counter)
            transform: image augmentation
            target_transform: label augmentation
            train: train or test set. Train set is not supported
        """
        super().__init__()
        self.base = base_path
        assert mode in ["counter", "common"], f"Supported modes are counter and common, but got '{mode}'"
        assert not train, "CounterAnimal only consists of test data, not training data."
        self.transform = transform
        self.target_transform = target_transform

        self.index = []
        for class_folder in os.listdir(self.base):
            if not os.path.isdir(os.path.join(self.base, class_folder)):
                continue
            class_idx = int(class_folder.split(" ")[0])
            for variant_folder in os.listdir(os.path.join(self.base, class_folder)):
                if not variant_folder.startswith(mode):
                    continue

                _folder = os.path.join(self.base, class_folder, variant_folder)
                for file in os.listdir(_folder):
                    if file.lower().split(".")[-1] in ["jpg", "jpeg", "png"]:
                        self.index.append((os.path.join(_folder, file), class_idx))

        assert len(self.index) > 0, "did not find any images :("

    def __len__(self):
        return len(self.index)

    def __getitem__(self, idx):
        path, label = self.index[idx]

        img = Image.open(path).convert("RGB")

        if self.transform:
            img = self.transform(img)

        if self.target_transform:
            label = self.target_transform(label)

        return img, label
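A usage sketch (the dataset path is a placeholder):

    from torch.utils.data import DataLoader
    from torchvision import transforms as T

    ds = CounterAnimal("/path/to/CounterAnimal", mode="counter", transform=T.ToTensor())
    loader = DataLoader(ds, batch_size=64, shuffle=False, num_workers=4)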
@@ -0,0 +1,145 @@
from nvidia.dali import fn, pipeline_def, types

# see https://docs.nvidia.com/deeplearning/dali/user-guide/docs/plugins/pytorch_dali_proxy.html


@pipeline_def
def minimal_augment(args, test=False):
    """Minimal augmentation set for images.

    Contains only resize, crop, flip, to tensor and normalize.

    Args:
        args (DotDict): Arguments: aug_resize, aug_crop, aug_flip, aug_normalize turn the respective augmentation on/off.
        test (bool, optional): On the test set? Defaults to False.

    Returns:
        images: augmented images.

    """
    images = fn.external_source(name="images", no_copy=True)

    if args.aug_resize:
        images = fn.resize(images, size=args.imsize, mode="not_smaller")

    if test and args.aug_crop:
        images = fn.crop(images, crop=(args.imsize, args.imsize), crop_pos_x=0.5, crop_pos_y=0.5)
    elif args.aug_crop:
        images = fn.crop(
            images,
            crop=(args.imsize, args.imsize),
            crop_pos_x=fn.random.uniform(range=(0, 1)),
            crop_pos_y=fn.random.uniform(range=(0, 1)),
        )

    # flip and normalization are fused into crop_mirror_normalize below
    return fn.crop_mirror_normalize(
        images,
        dtype=types.FLOAT,
        output_layout="CHW",
        crop=(args.imsize, args.imsize),
        mean=[0.485 * 255, 0.456 * 255, 0.406 * 255] if args.aug_normalize else [0, 0, 0],
        std=[0.229 * 255, 0.224 * 255, 0.225 * 255] if args.aug_normalize else [1, 1, 1],
        mirror=fn.random.coin_flip(probability=0.5) if args.aug_flip and not test else False,
    )


def dali_solarize(images, threshold=128):
    """Solarize implementation for nvidia DALI.

    Args:
        images (DALI Tensor): Images to solarize.
        threshold (int, optional): Threshold for solarization. Defaults to 128.

    Returns:
        images: solarized images.

    """
    inv_images = types.Constant(255).uint8() - images
    mask = (images >= threshold) * types.Constant(1).uint8()
    return mask * inv_images + (types.Constant(1).uint8() ^ mask) * images
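The mask arithmetic above can be hard to read in DALI's expression syntax; this NumPy sketch computes the same thing on uint8 arrays (illustration only, not used by the pipeline):

    import numpy as np

    def np_solarize(images: np.ndarray, threshold: int = 128) -> np.ndarray:
        inv = 255 - images                             # inverted image
        mask = (images >= threshold).astype(np.uint8)  # 1 where the pixel gets solarized
        return mask * inv + (1 ^ mask) * images        # blend inverted and original pixels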

@pipeline_def(enable_conditionals=True)
def three_augment(args, test=False):
    """3-augment data augmentation pipeline for nvidia DALI.

    Args:
        args (namespace): augmentation arguments.
        test (bool, optional): Test (or train) split. Defaults to False.

    Returns:
        images: augmented images.

    """
    images = fn.external_source(name="images", no_copy=True)

    if args.aug_resize:
        images = fn.resize(images, size=args.imsize, mode="not_smaller")

    if test and args.aug_crop:
        images = fn.crop(images, crop=(args.imsize, args.imsize), crop_pos_x=0.5, crop_pos_y=0.5)
    elif args.aug_crop:
        images = fn.crop(
            images,
            crop=(args.imsize, args.imsize),
            crop_pos_x=fn.random.uniform(range=(0, 1)),
            crop_pos_y=fn.random.uniform(range=(0, 1)),
        )

    if not test:
        # draw one of grayscale / solarize / gaussian blur, weighted by the enabled flags
        choice_ps = [1 * args.aug_grayscale, 1 * args.aug_solarize, 1 * args.aug_gauss_blur]
        choice_ps = [c / sum(choice_ps) for c in choice_ps]
        choice = fn.random.choice(
            [0, 1, 2],
            p=choice_ps,
        )

        if choice == 0:
            images = fn.color_space_conversion(
                fn.color_space_conversion(images, image_type=types.RGB, output_type=types.GRAY),
                image_type=types.GRAY,
                output_type=types.RGB,
            )
        elif choice == 1:
            images = dali_solarize(images, threshold=128)
        elif choice == 2:
            images = fn.gaussian_blur(images, window_size=7, sigma=fn.random.uniform(range=(0.2, 2.0)))

        if args.aug_color_jitter_factor > 0.0:
            images = fn.color_twist(
                images,
                brightness=fn.random.uniform(
                    range=(1 - args.aug_color_jitter_factor, 1 + args.aug_color_jitter_factor)
                ),
                contrast=fn.random.uniform(range=(1 - args.aug_color_jitter_factor, 1 + args.aug_color_jitter_factor)),
                saturation=fn.random.uniform(
                    range=(1 - args.aug_color_jitter_factor, 1 + args.aug_color_jitter_factor)
                ),
            )

    return fn.crop_mirror_normalize(
        images,
        dtype=types.FLOAT,
        output_layout="CHW",
        crop=(args.imsize, args.imsize),
        mean=[0.485 * 255, 0.456 * 255, 0.406 * 255] if args.aug_normalize else [0, 0, 0],
        std=[0.229 * 255, 0.224 * 255, 0.225 * 255] if args.aug_normalize else [1, 1, 1],
        mirror=fn.random.coin_flip(probability=0.5) if args.aug_flip and not test else False,
    )
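Since these are `@pipeline_def` functions, they take the usual DALI pipeline keyword arguments on top of their own; an instantiation sketch (values illustrative):

    pipe = three_augment(args, test=False, batch_size=256, num_threads=8, device_id=0)
    pipe.build()  # batches are then fed through the "images" external source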
@@ -0,0 +1,407 @@
from random import uniform

import msgpack
import torch
import torchvision
from datadings.torch import CompressedToPIL
from datadings.torch import Dataset as DDDataset
from PIL import ImageFilter
from torchvision.transforms import (
    CenterCrop,
    ColorJitter,
    Compose,
    GaussianBlur,
    Grayscale,
    InterpolationMode,
    Normalize,
    RandomChoice,
    RandomCrop,
    RandomHorizontalFlip,
    RandomResizedCrop,
    RandomSolarize,
    Resize,
    ToTensor,
)
from torchvision.transforms import functional as F

_image_and_target_transforms = [
    torchvision.transforms.RandomCrop,
    torchvision.transforms.RandomHorizontalFlip,
    torchvision.transforms.CenterCrop,
    torchvision.transforms.RandomRotation,
    torchvision.transforms.RandomAffine,
    torchvision.transforms.RandomResizedCrop,
]


def apply_dense_transforms(x, y, transforms: torchvision.transforms.transforms.Compose):
    """Apply transformations to both image and target.

    Args:
        x (torch.Tensor): image
        y (torch.Tensor): target (image)
        transforms (torchvision.transforms.transforms.Compose): transformations to apply

    Returns:
        tuple[torch.Tensor, torch.Tensor]: (x, y) with the transformations applied

    """
    for trans in transforms.transforms:
        if isinstance(trans, torchvision.transforms.RandomResizedCrop):
            params = trans.get_params(x, trans.scale, trans.ratio)
            x = F.resized_crop(x, *params, trans.size, trans.interpolation, antialias=trans.antialias)
            y = F.resized_crop(y.unsqueeze(0), *params, trans.size, 0).squeeze(0)  # nearest neighbor interpolation
        elif isinstance(trans, Resize):
            pre_shape = x.shape
            x = trans(x)
            if x.shape != pre_shape:
                y = F.resize(y.unsqueeze(0), trans.size, 0, trans.max_size, trans.antialias).squeeze(
                    0
                )  # nearest neighbor interpolation
        elif any(isinstance(trans, simul_transform) for simul_transform in _image_and_target_transforms):
            xy = torch.cat([x, y.unsqueeze(0).float()], dim=0)
            xy = trans(xy)
            x, y = xy[:-1], xy[-1].long()
        elif isinstance(trans, torchvision.transforms.ToTensor):
            if not isinstance(x, torch.Tensor):
                x = trans(x)
        else:
            x = trans(x)

    return x, y

def get_hf_transform(transform_f, trgt_transform_f=None, image_key="image"):
    """Convert the transform function to a huggingface compatible transform function.

    Args:
        transform_f (callable): Image transform.
        trgt_transform_f (callable, optional): Target transform. Defaults to None.
        image_key (str, optional): Key for the image in the hf ds return dict. Defaults to "image".
    """

    def _transform(samples):
        try:
            samples[image_key] = [transform_f(im) for im in samples[image_key]]
            if trgt_transform_f is not None:
                samples["label"] = [trgt_transform_f(tgt) for tgt in samples["label"]]
        except TypeError as e:
            print(f"Type error when transforming samples: {samples}")
            raise e
        return samples

    return _transform
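A wiring sketch for a Hugging Face dataset (the dataset name and the transform passed in are illustrative):

    import datasets

    ds = datasets.load_dataset("imagenet-1k", split="train")
    ds.set_transform(get_hf_transform(three_augment(args)))  # applied lazily per batch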

class DDDecodeDataset(DDDataset):
    """Datadings dataset with image decoding before transform."""

    def __init__(self, *args, transform=None, target_transform=None, transforms=None, **kwargs):
        """Create datadings dataset.

        Args:
            transform (callable, optional): Image transform. Overrides transforms['image']. Defaults to None.
            target_transform (callable, optional): Label transform. Overrides transforms['label']. Defaults to None.
            transforms (dict[str, callable], optional): Dict of transforms for each key. Defaults to None.
        """
        super().__init__(*args, **kwargs)
        if transforms is None:
            transforms = {}
        self._decode_transform = transform if transform is not None else transforms.get("image", None)
        self._decode_target_transform = (
            target_transform if target_transform is not None else transforms.get("label", None)
        )

        self.ctp = CompressedToPIL()

    def __getitem__(self, idx):
        sample = super().__getitem__(idx)
        img, lbl = sample["image"], sample["label"]
        if isinstance(img, bytes):
            img = self.ctp(img)
        if self._decode_transform is not None:
            img = self._decode_transform(img)
        if self._decode_target_transform is not None:
            lbl = self._decode_target_transform(lbl)
        return img, lbl


def minimal_augment(args, test=False):
    """Minimal augmentation set for images.

    Contains only resize, crop, flip, to tensor and normalize.

    Args:
        args (DotDict): Arguments: aug_resize, aug_crop, aug_flip, aug_normalize turn the respective augmentation on/off.
        test (bool, optional): On the test set? Defaults to False.

    Returns:
        List: Augmentation list

    """
    augs = []
    augs.append(ToTensor())

    if args.aug_resize:
        augs.append(Resize(args.imsize, interpolation=InterpolationMode.BICUBIC))

    if test and args.aug_crop:
        augs.append(CenterCrop(args.imsize))
    elif args.aug_crop:
        augs.append(RandomCrop(args.imsize, padding=4, padding_mode="reflect"))

    if not test and args.aug_flip:
        augs.append(RandomHorizontalFlip(p=0.5))

    if args.aug_normalize:
        augs.append(
            Normalize(
                mean=torch.tensor([0.485, 0.456, 0.406]),
                std=torch.tensor([0.229, 0.224, 0.225]),
            )
        )

    return augs


def three_augment(args, as_list=False, test=False):
    """Create the data augmentation.

    Args:
        args (Namespace): arguments
        as_list (bool): return list of transformations, not composed transformation
        test (bool): In eval mode? If False => train mode

    Returns:
        torch.nn.Module | list[torch.nn.Module]: composed transformation or list of transformations

    """
    augs = []
    augs.append(ToTensor())

    if args.aug_resize:
        augs.append(Resize(args.imsize, interpolation=InterpolationMode.BICUBIC))

    if test and args.aug_crop:
        augs.append(CenterCrop(args.imsize))
    elif args.aug_crop:
        augs.append(RandomCrop(args.imsize, padding=4, padding_mode="reflect"))

    if not test:
        if args.aug_flip:
            augs.append(RandomHorizontalFlip(p=0.5))

        augs_choice = []
        if args.aug_grayscale:
            augs_choice.append(Grayscale(num_output_channels=3))
        if args.aug_solarize:
            augs_choice.append(RandomSolarize(threshold=0.5, p=1.0))
        if args.aug_gauss_blur:
            # TODO: check kernel size?
            augs_choice.append(GaussianBlur(kernel_size=7, sigma=(0.2, 2.0)))
            # augs_choice.append(QuickGaussBlur())

        if len(augs_choice) > 0:
            augs.append(RandomChoice(augs_choice))

        if args.aug_color_jitter_factor > 0.0:
            augs.append(
                ColorJitter(
                    args.aug_color_jitter_factor,
                    args.aug_color_jitter_factor,
                    args.aug_color_jitter_factor,
                )
            )

    if args.aug_normalize:
        augs.append(
            Normalize(
                mean=torch.tensor([0.485, 0.456, 0.406]),
                std=torch.tensor([0.229, 0.224, 0.225]),
            )
        )

    if as_list:
        return augs
    return Compose(augs)


def segment_augment(args, test=False):
    """Create the data augmentation for segmentation.

    No cropping in this part, as cropping has to be done for the image and labels simultaneously.

    Args:
        args (DotDict): arguments
        test (bool, optional): In eval mode? If False => train mode. Defaults to False.

    Returns:
        list[torch.nn.Module]: list of transformations

    """
    augs = []

    if test:
        augs.append(ResizeUp(args.imsize))
        augs.append(CenterCrop(args.imsize))
    else:
        augs.append(RandomResizedCrop(args.imsize, scale=(0.08, 1.0), ratio=(3.0 / 4.0, 4.0 / 3.0)))

    if not test and args.aug_flip:
        augs.append(RandomHorizontalFlip(p=0.5))

    if args.aug_color_jitter_factor > 0.0:
        augs.append(
            ColorJitter(
                args.aug_color_jitter_factor,
                args.aug_color_jitter_factor,
                args.aug_color_jitter_factor,
            )
        )

    if args.aug_normalize:
        augs.append(
            Normalize(
                mean=torch.tensor([0.485, 0.456, 0.406]),
                std=torch.tensor([0.229, 0.224, 0.225]),
            )
        )

    return augs


class QuickGaussBlur:
    """Gaussian blur transformation using PIL ImageFilter."""

    def __init__(self, sigma=(0.2, 2.0)):
        """Create Gaussian blur operator.

        Args:
            sigma (tuple[float, float]): range of sigma for the blur

        """
        self.sigma_min, self.sigma_max = sigma

    def __call__(self, img):
        return img.filter(ImageFilter.GaussianBlur(radius=uniform(self.sigma_min, self.sigma_max)))


class RemoveTransform:
    """Remove data from transformation.

    To use with default collate function.
    """

    def __call__(self, x, y=None):
        if y is None:
            return [1]
        return [1], y

    def __repr__(self):
        return f"{self.__class__.__name__}()"


def collate_imnet(data, image_key="image"):
    """Collate function for imagenet(1k / 21k) with datadings.

    Args:
        data (list[dict[str, Any]]): images for a batch
        image_key (str, optional): key of the image in each sample dict. Defaults to "image".

    Returns:
        tuple[torch.Tensor, torch.Tensor]: images, labels

    """
    if isinstance(data[0][image_key], torch.Tensor):
        ims = torch.stack([d[image_key] for d in data], dim=0)
    else:
        ims = [d[image_key] for d in data]
    labels = torch.tensor([d["label"] for d in data])
    # keys = [d['key'] for d in data]
    return ims, labels  # , keys
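A usage sketch with a datadings-backed loader (dataset construction elided):

    from torch.utils.data import DataLoader

    loader = DataLoader(ds, batch_size=64, collate_fn=collate_imnet)
    images, labels = next(iter(loader))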

def collate_listops(data):
    """Collate function for ListOps with datadings.

    Args:
        data (list[tuple[torch.Tensor, torch.Tensor]]): list of samples

    Returns:
        tuple[torch.Tensor, torch.Tensor]: images, labels

    """
    return data[0][0], data[0][1]


def no_param_transf(self, sample):
    """Call transformation without extra parameter.

    To use with datadings QuasiShuffler.

    Args:
        self (object): use this as a method ( <obj>.<method_name> = MethodType(no_param_transf, <obj>) )
        sample (Any): sample to transform

    Returns:
        Any: transformed sample

    """
    if isinstance(sample, tuple):
        # sample of type (name (str), data (bytes encoded))
        sample = sample[1]
    if isinstance(sample, bytes):
        # decode msgpack bytes
        sample = msgpack.loads(sample)
    params = self._rng(sample)
    for k, f in self._transforms.items():
        sample[k] = f(sample[k], params)
    return sample
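The binding described in the docstring looks like this in practice (the shuffler object and its method name are stand-ins for any datadings wrapper carrying `_rng` and `_transforms` attributes):

    from types import MethodType

    # shuffler = QuasiShuffler(...)                            # assumed datadings object
    # shuffler.transform = MethodType(no_param_transf, shuffler)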

class ToOneHotSequence:
    """Convert a sequence of grayscale values (range 0 to 1) to a one-hot encoded sequence based on 8-bit rounded values."""

    def __call__(self, x, y=None):
        # x is 1 x 32 x 32
        x = (255 * x).round().to(torch.int64).view(-1)
        assert x.max() < 256, f"Found max value {x.max()} in {x}."
        x = torch.nn.functional.one_hot(x, num_classes=256).float()
        if y is None:
            return x
        return x, y

    def __repr__(self):
        return f"{self.__class__.__name__}()"
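For example, a 1x32x32 grayscale image becomes a sequence of 1024 one-hot vectors over the 256 possible 8-bit values:

    import torch

    x = torch.rand(1, 32, 32)
    seq = ToOneHotSequence()(x)
    assert seq.shape == (32 * 32, 256)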

class ResizeUp(Resize):
    """Resize up if the image is smaller than the target size."""

    def forward(self, img):
        h, w = img.shape[-2], img.shape[-1]
        if h < self.size or w < self.size:
            img = super().forward(img)
        return img
484
AAAI Supplementary Material/Model Training Code/data/fornet.py
Normal file
@@ -0,0 +1,484 @@
import json
import os
import zipfile
from io import BytesIO
from math import floor

import numpy as np
import PIL
from datadings.torch import Compose
from loguru import logger
from PIL import Image, ImageFilter
from torch.utils.data import Dataset, get_worker_info
from torchvision import transforms as T
from tqdm.auto import tqdm

from data.data_utils import apply_dense_transforms


class ForNet(Dataset):
    """Recombine ImageNet foregrounds and backgrounds.

    Note:
        This dataset has exactly the ImageNet classes.

    """

    _back_combs = ["same", "all", "original"]
    _bg_transforms = {T.RandomCrop, T.CenterCrop, T.Resize, T.RandomResizedCrop}

    def __init__(
        self,
        root,
        transform=None,
        train=True,
        target_transform=None,
        background_combination="all",
        fg_scale_jitter=0.3,
        fg_transform=None,
        pruning_ratio=0.8,
        return_fg_masks=False,
        fg_size_mode="range",
        fg_bates_n=1,
        paste_pre_transform=True,
        mask_smoothing_sigma=4.0,
        rel_jut_out=0.0,
        fg_in_nonant=None,
        size_fact=1.0,
        orig_img_prob=0.0,
        orig_ds=None,
        _orig_ds_file_type="JPEG",
        epochs=0,
        _album_compose=False,
    ):
        """Create ForNet dataset.

        Args:
            root (str): Root folder for the dataset.
            transform (T.Compose | list, optional): Transform to apply to the image. Defaults to None.
            train (bool, optional): On the train set (False -> val set). Defaults to True.
            target_transform (T.Compose | list, optional): Transform to apply to the target values. Defaults to None.
            background_combination (str, optional): Which backgrounds to combine with foregrounds. Defaults to "all".
            fg_scale_jitter (float, optional): How much the size of the foreground may be changed (random ratio). Defaults to 0.3.
            fg_transform (_type_, optional): Transform to apply to the foreground before pasting it onto the background. This is supposed to be a random rotation, mainly. Defaults to None.
            pruning_ratio (float, optional): Prune backgrounds with (foreground size/background size) >= <pruning_ratio>. Backgrounds from images that contain very large foreground objects are mostly computer generated and therefore relatively unnatural. Defaults to 0.8 (1.0 keeps the full dataset).
            return_fg_masks (bool, optional): Return the foreground masks. Defaults to False.
            fg_size_mode (str, optional): How to determine the size of the foreground, based on the foreground sizes of the foreground and background images. Defaults to "range".
            fg_bates_n (int, optional): Bates parameter for the distribution of the object position in the foreground. Defaults to 1 (uniform distribution). The higher the value, the more likely the object is in the center. For fg_bates_n = 0, the object is always in the center.
            paste_pre_transform (bool, optional): Paste the foreground onto the background before applying the transform. If False, the background will be cropped and resized before pasting the foreground. Defaults to True.
            mask_smoothing_sigma (float, optional): Sigma for the Gaussian blur of the mask edge. Defaults to 4.0. Try 2.0 or 4.0?
            rel_jut_out (float, optional): How much the foreground is allowed to jut out of the background (and then be cut off). Defaults to 0.0.
            fg_in_nonant (int, optional): If not None, the foreground will be placed in a specific nonant (0-8) of the image. Defaults to None.
            size_fact (float, optional): Factor to multiply the size of the foreground with. Defaults to 1.0.
            orig_img_prob (float | str, optional): Probability to use the original image, instead of the fg-bg recombinations. Defaults to 0.0.
            orig_ds (Dataset, optional): Original dataset (without transforms) to use for the original images. Defaults to None.
            _orig_ds_file_type (str, optional): File type of the original dataset. Defaults to "JPEG".
            epochs (int, optional): Number of epochs to train on. Used for the linear increase of orig_img_prob.

        Note:
            For more information on the Bates distribution, see https://en.wikipedia.org/wiki/Bates_distribution.
            For fg_bates_n < 0, we extend the Bates distribution to focus more and more on the edges. This is done by sampling B ~ Bates(|fg_bates_n|) and then passing it through f(x) = x + 0.5 - floor(x + 0.5).

            For the list of transformations that will be applied to the background only (if paste_pre_transform=False), see ForNet._bg_transforms.

            A nonant in this case refers to a square in a 3x3 grid dividing the image.

        """
        assert (
            background_combination in self._back_combs
        ), f"background_combination={background_combination} is not supported. Use one of {self._back_combs}"

        if (not os.path.exists(f"{root}/backgrounds_{'train' if train else 'val'}.zip")) and os.path.exists(
            os.path.join(root, "train" if train else "val", "backgrounds")
        ):
            self._mode = "folder"
        else:
            self._mode = "zip"

        if self._mode == "zip":
            try:
                with zipfile.ZipFile(f"{root}/backgrounds_{'train' if train else 'val'}.zip", "r") as bg_zip:
                    self.backgrounds = [f for f in bg_zip.namelist() if f.endswith(".JPEG")]
                with zipfile.ZipFile(f"{root}/foregrounds_{'train' if train else 'val'}.zip", "r") as fg_zip:
                    self.foregrounds = [f for f in fg_zip.namelist() if f.endswith(".WEBP")]
            except FileNotFoundError as e:
                logger.error(
                    f"ForNet: {e}. Make sure to have the background and foreground zips in the root"
                    f" directory: found {os.listdir(root)}"
                )
                raise e
            classes = {f.split("/")[-2] for f in self.foregrounds}
        else:
            logger.info("ForNet folder mode: loading classes")
            classes = set(os.listdir(os.path.join(root, "train" if train else "val", "foregrounds")))
            foregrounds = []
            backgrounds = []
            for cls in tqdm(classes, desc="Loading files"):
                foregrounds.extend(
                    [
                        f"{cls}/{f}"
                        for f in os.listdir(os.path.join(root, "train" if train else "val", "foregrounds", cls))
                    ]
                )
                backgrounds.extend(
                    [
                        f"{cls}/{f}"
                        for f in os.listdir(os.path.join(root, "train" if train else "val", "backgrounds", cls))
                    ]
                )
            self.foregrounds = foregrounds
            self.backgrounds = backgrounds

        self.classes = sorted(list(classes), key=lambda x: int(x[1:]))

        assert os.path.exists(f"{root}/fg_bg_ratios_{'train' if train else 'val'}.json"), (
            f"{root}/fg_bg_ratios_{'train' if train else 'val'}.json not found, provide the information or set"
            " pruning_ratio=1.0"
        )
        with open(f"{root}/fg_bg_ratios_{'train' if train else 'val'}.json", "r") as f:
            self.fg_bg_ratios = json.load(f)
        if self._mode == "folder":
            self.fg_bg_ratios = {"/".join(key.split("/")[-2:]): val for key, val in self.fg_bg_ratios.items()}
            logger.info(f"Renamed fg_bg_ratios keys to {list(self.fg_bg_ratios.keys())[:3]}...")

        if pruning_ratio <= 1.0:
            backup_backgrounds = {}
            for bg_file in self.backgrounds:
                bg_cls = bg_file.split("/")[-2]
                backup_backgrounds[bg_cls] = bg_file
            self.backgrounds = [
                bg for bg in self.backgrounds if bg in self.fg_bg_ratios and self.fg_bg_ratios[bg] < pruning_ratio
            ]

        self.root = root
        self.train = train
        self.background_combination = background_combination
        self.fg_scale_jitter = fg_scale_jitter
        self.fg_transform = fg_transform
        self.return_fg_masks = return_fg_masks
        self.paste_pre_transform = paste_pre_transform
        self.mask_smoothing_sigma = mask_smoothing_sigma
        self.rel_jut_out = rel_jut_out
        self.size_fact = size_fact
        self.fg_in_nonant = fg_in_nonant
        assert fg_in_nonant is None or -1 <= fg_in_nonant < 9, f"fg_in_nonant={fg_in_nonant} not in [0, 8] or None"

        self.orig_img_prob = orig_img_prob
        if orig_img_prob != 0.0:
            assert (isinstance(orig_img_prob, float) and orig_img_prob > 0.0) or orig_img_prob in [
                "linear",
                "cos",
                "revlinear",
            ]
            assert orig_ds is not None, "orig_ds must be provided if orig_img_prob > 0.0"
            assert not return_fg_masks, "can't provide fg masks for original images (yet)"
            assert os.path.exists(os.path.join(root, f"{'train' if train else 'val'}_indices.json")) or isinstance(
                orig_ds, str
            ), f"{root}/{'train' if train else 'val'}_indices.json must be provided if orig_ds is a dataset"
            if not isinstance(orig_ds, str):
                with open(os.path.join(root, f"{'train' if train else 'val'}_indices.json"), "r") as f:
                    self.key_to_orig_idx = json.load(f)
            else:
                if not (
                    orig_ds.endswith("train" if train else "val") or orig_ds.endswith("train/" if train else "val/")
                ):
                    orig_ds = f"{orig_ds}/{'train' if train else 'val'}"
                self.key_to_orig_idx = None
        self._orig_ds_file_type = _orig_ds_file_type
        self.orig_ds = orig_ds
        self.epochs = epochs
        self._epoch = 0

        assert fg_size_mode in [
            "max",
            "min",
            "mean",
            "range",
        ], f"fg_size_mode={fg_size_mode} not supported; use one of ['max', 'min', 'mean', 'range']"
        self.fg_size_mode = fg_size_mode
        self.fg_bates_n = fg_bates_n

        if not paste_pre_transform:
            if isinstance(transform, (T.Compose, Compose)):
                transform = transform.transforms

            # do cropping and resizing mainly on background; paste foreground on top later
            self.bg_transform = []
            self.join_transform = []
            for tf in transform:
                if isinstance(tf, tuple(self._bg_transforms)):
                    self.bg_transform.append(tf)
                else:
                    self.join_transform.append(tf)

            if _album_compose:
                from data.album_transf import AlbumTorchCompose

                self.bg_transform = AlbumTorchCompose(self.bg_transform)
                self.join_transform = AlbumTorchCompose(self.join_transform)
            else:
                self.bg_transform = T.Compose(self.bg_transform)
                self.join_transform = T.Compose(self.join_transform)

        else:
            if isinstance(transform, list):
                if _album_compose:
                    from data.album_transf import AlbumTorchCompose

                    self.join_transform = AlbumTorchCompose(transform)
                else:
                    self.join_transform = T.Compose(transform)
            else:
                self.join_transform = transform
            self.bg_transform = None

        self.trgt_map = {cls: i for i, cls in enumerate(self.classes)}

        self.target_transform = target_transform

        self.cls_to_allowed_bg = {}
        for bg_file in self.backgrounds:
            if background_combination == "same":
                bg_cls = bg_file.split("/")[-2]
                if bg_cls not in self.cls_to_allowed_bg:
                    self.cls_to_allowed_bg[bg_cls] = []
                self.cls_to_allowed_bg[bg_cls].append(bg_file)

        if background_combination == "same":
            for cls_code in classes:
                if cls_code not in self.cls_to_allowed_bg or len(self.cls_to_allowed_bg[cls_code]) == 0:
                    self.cls_to_allowed_bg[cls_code] = [backup_backgrounds[cls_code]]
                    logger.warning(f"No background for class {cls_code}, using {backup_backgrounds[cls_code]}")

        self._zf = {}

    @property
    def epoch(self):
        return self._epoch

    @epoch.setter
    def epoch(self, value):
        self._epoch = value

    def __len__(self):
        """Size of the dataset.

        Returns:
            int: number of foregrounds

        """
        return len(self.foregrounds)

    def num_classes(self):
        return len(self.classes)

    def _get_fg(self, idx):
        worker_id = self._wrkr_info()

        fg_file = self.foregrounds[idx]
        with self._zf[worker_id]["fg"].open(fg_file) as f:
            fg_data = BytesIO(f.read())
        return Image.open(fg_data)

    def _wrkr_info(self):
        worker_id = get_worker_info().id if get_worker_info() else 0

        if worker_id not in self._zf and self._mode == "zip":
            self._zf[worker_id] = {
                "bg": zipfile.ZipFile(f"{self.root}/backgrounds_{'train' if self.train else 'val'}.zip", "r"),
                "fg": zipfile.ZipFile(f"{self.root}/foregrounds_{'train' if self.train else 'val'}.zip", "r"),
            }
        return worker_id

    def __getitem__(self, idx):
        """Get the foreground at index idx and combine it with a (random) background.

        Args:
            idx (int): foreground index

        Returns:
            torch.Tensor, torch.Tensor: image, target

        """
        worker_id = self._wrkr_info()
        fg_file = self.foregrounds[idx]
        trgt_cls = fg_file.split("/")[-2]

        if (
            (self.orig_img_prob == "linear" and np.random.rand() < self._epoch / self.epochs)
            or (self.orig_img_prob == "revlinear" and np.random.rand() < (self.epochs - self._epoch) / self.epochs)
            or (self.orig_img_prob == "cos" and np.random.rand() > np.cos(np.pi * self._epoch / (2 * self.epochs)))
            or (
                isinstance(self.orig_img_prob, float)
                and self.orig_img_prob > 0.0
                and np.random.rand() < self.orig_img_prob
            )
        ):
            data_key = f"{trgt_cls}/{fg_file.split('/')[-1].split('.')[0]}"
            if isinstance(self.orig_ds, str):
                image_file = os.path.join(self.orig_ds, f"{data_key}.{self._orig_ds_file_type}")
                orig_img = Image.open(image_file).convert("RGB")
            else:
                orig_data = self.orig_ds[self.key_to_orig_idx[data_key]]
                orig_img = orig_data["image"] if isinstance(orig_data, dict) else orig_data[0]

            if self.bg_transform:
                orig_img = self.bg_transform(orig_img)
            if self.join_transform:
                orig_img = self.join_transform(orig_img)
            trgt = self.trgt_map[trgt_cls]
            if self.target_transform:
                trgt = self.target_transform(trgt)
            return orig_img, trgt

        if self._mode == "zip":
            with self._zf[worker_id]["fg"].open(fg_file) as f:
                fg_data = BytesIO(f.read())
            try:
                fg_img = Image.open(fg_data).convert("RGBA")
            except PIL.UnidentifiedImageError as e:
                logger.error(f"Error with idx={idx}, file={self.foregrounds[idx]}")
                raise e
        else:
            fg_img = Image.open(
                os.path.join(self.root, "train" if self.train else "val", "foregrounds", fg_file)
            ).convert("RGBA")

        if self.fg_transform:
            fg_img = self.fg_transform(fg_img)
        fg_size_factor = T.ToTensor()(fg_img.split()[-1]).mean().item()

        if self.background_combination == "all":
            bg_idx = np.random.randint(len(self.backgrounds))
            bg_file = self.backgrounds[bg_idx]
        elif self.background_combination == "original":
            bg_file = fg_file.replace("foregrounds", "backgrounds").replace("WEBP", "JPEG")
        else:
            bg_idx = np.random.randint(len(self.cls_to_allowed_bg[trgt_cls]))
            bg_file = self.cls_to_allowed_bg[trgt_cls][bg_idx]

        if self._mode == "zip":
            with self._zf[worker_id]["bg"].open(bg_file) as f:
                bg_data = BytesIO(f.read())
            bg_img = Image.open(bg_data).convert("RGB")
        else:
            bg_img = Image.open(
                os.path.join(self.root, "train" if self.train else "val", "backgrounds", bg_file)
            ).convert("RGB")

        if not self.paste_pre_transform:
            bg_img = self.bg_transform(bg_img)

        bg_size = bg_img.size

        # choose scale factor, such that the relative area is in fg_scale
        bg_area = bg_size[0] * bg_size[1]
        if self.fg_in_nonant is not None:
            bg_area = bg_area / 9

        orig_fg_ratio = self.fg_bg_ratios[fg_file.replace("foregrounds", "backgrounds").replace("WEBP", "JPEG")]
        bg_fg_ratio = self.fg_bg_ratios[bg_file]

        if self.fg_size_mode == "max":
            goal_fg_ratio_lower = goal_fg_ratio_upper = max(orig_fg_ratio, bg_fg_ratio)
        elif self.fg_size_mode == "min":
            goal_fg_ratio_lower = goal_fg_ratio_upper = min(orig_fg_ratio, bg_fg_ratio)
        elif self.fg_size_mode == "mean":
            goal_fg_ratio_lower = goal_fg_ratio_upper = (orig_fg_ratio + bg_fg_ratio) / 2
        else:
            # range
            goal_fg_ratio_lower = min(orig_fg_ratio, bg_fg_ratio)
            goal_fg_ratio_upper = max(orig_fg_ratio, bg_fg_ratio)

        fg_scale = (
            np.random.uniform(
                goal_fg_ratio_lower * (1 - self.fg_scale_jitter), goal_fg_ratio_upper * (1 + self.fg_scale_jitter)
            )
            / fg_size_factor
            * self.size_fact
        )

        goal_shape_y = round(np.sqrt(bg_area * fg_scale * fg_img.size[1] / fg_img.size[0]))
        goal_shape_x = round(np.sqrt(bg_area * fg_scale * fg_img.size[0] / fg_img.size[1]))

        fg_img = fg_img.resize((goal_shape_x, goal_shape_y))

        if fg_img.size[0] > bg_size[0] or fg_img.size[1] > bg_size[1]:
            # random crop to fit
            goal_w, goal_h = (min(fg_img.size[0], bg_size[0]), min(fg_img.size[1], bg_size[1]))
            fg_img = T.RandomCrop((goal_h, goal_w))(fg_img) if self.train else T.CenterCrop((goal_h, goal_w))(fg_img)

        # paste fg on bg
        z1, z2 = (
            (
                np.random.uniform(0, 1, abs(self.fg_bates_n)).mean(),  # bates distribution, n=1 => uniform
                np.random.uniform(0, 1, abs(self.fg_bates_n)).mean(),
            )
            if self.fg_bates_n != 0
            else (0.5, 0.5)
        )
        if self.fg_bates_n < 0:
            z1 = z1 + 0.5 - floor(z1 + 0.5)
            z2 = z2 + 0.5 - floor(z2 + 0.5)

        x_min = -self.rel_jut_out * fg_img.size[0]
        x_max = bg_size[0] - fg_img.size[0] * (1 - self.rel_jut_out)
        y_min = -self.rel_jut_out * fg_img.size[1]
        y_max = bg_size[1] - fg_img.size[1] * (1 - self.rel_jut_out)

        if self.fg_in_nonant is not None and self.fg_in_nonant >= 0:
            x_min = (self.fg_in_nonant % 3) * bg_size[0] / 3
            x_max = ((self.fg_in_nonant % 3) + 1) * bg_size[0] / 3 - fg_img.size[0]
            y_min = (self.fg_in_nonant // 3) * bg_size[1] / 3
            y_max = ((self.fg_in_nonant // 3) + 1) * bg_size[1] / 3 - fg_img.size[1]

        if x_min > x_max:
            x_min = x_max = (x_min + x_max) / 2
        if y_min > y_max:
            y_min = y_max = (y_min + y_max) / 2

        offs_x = round(z1 * (x_max - x_min) + x_min)
        offs_y = round(z2 * (y_max - y_min) + y_min)

        paste_mask = fg_img.split()[-1]
        if self.mask_smoothing_sigma > 0.0:
            sigma = (np.random.rand() * 0.9 + 0.1) * self.mask_smoothing_sigma
            paste_mask = paste_mask.filter(ImageFilter.GaussianBlur(radius=sigma))
            paste_mask = paste_mask.point(lambda p: 2 * p - 255 if p > 128 else 0)

        bg_img.paste(fg_img.convert("RGB"), (offs_x, offs_y), paste_mask)
        bg_img = bg_img.convert("RGB")

        if self.return_fg_masks:
            fg_mask = Image.new("L", bg_size, 0)
            fg_mask.paste(paste_mask, (offs_x, offs_y))

            fg_mask = T.ToTensor()(fg_mask)[0]

            bg_img = T.ToTensor()(bg_img)

            if self.join_transform:
                bg_img, fg_mask = apply_dense_transforms(bg_img, fg_mask, self.join_transform)
        else:
            bg_img = self.join_transform(bg_img)

        if trgt_cls not in self.trgt_map:
            raise ValueError(f"trgt_cls={trgt_cls} not in trgt_map: {self.trgt_map}")
        trgt = self.trgt_map[trgt_cls]
        if self.target_transform:
            trgt = self.target_transform(trgt)

        if self.return_fg_masks:
            return bg_img, trgt, fg_mask

        return bg_img, trgt
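A construction sketch for the dataset above (the root layout is the one produced by the creation scripts; argument values are illustrative):

    from torchvision import transforms as T

    ds = ForNet(
        "/path/to/ForNet",
        transform=[T.RandomResizedCrop(224), T.ToTensor()],
        train=True,
        background_combination="all",
        paste_pre_transform=False,  # crop/resize the background first, then paste
    )
    img, target = ds[0]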
@@ -0,0 +1,151 @@
import argparse
import shutil
import zipfile
from os import listdir, makedirs, path
from random import choice

from datadings.reader import MsgpackReader
from datadings.writer import FileWriter
from tqdm.auto import tqdm

parser = argparse.ArgumentParser()
parser.add_argument("-tiny_imagenet_zip", type=str, required=True, help="Path to the Tiny ImageNet zip file")
parser.add_argument("-output_dir", type=str, required=True, help="Directory to extract the image names to")
parser.add_argument("-in_segment_dir", type=str, required=True, help="Directory that holds the segmented ImageNet")
parser.add_argument(
    "-imagenet_path", type=str, nargs="?", required=True, help="Path to the original ImageNet dataset (datadings)"
)
args = parser.parse_args()


images = {"train": set(), "val": set()}

with zipfile.ZipFile(args.tiny_imagenet_zip, "r") as zip_ref:
    for info in tqdm(zip_ref.infolist(), desc="Gathering Images"):
        if info.filename.endswith(".JPEG"):
            if "/val/" in info.filename:
                images["val"].add(info.filename.split("/")[-1])
            elif "/train/" in info.filename:
                images["train"].add(info.filename.split("/")[-1])

with open(path.join(args.output_dir, "tiny_imagenet_train_images.txt"), "w+") as f:
    f.write("\n".join(images["train"]))
with open(path.join(args.output_dir, "tiny_imagenet_val_images.txt"), "w+") as f:
    f.write("\n".join(images["val"]))

print(f"Found {len(images['train'])} training images and {len(images['val'])} validation images")
classes = {img_name.split("_")[0] for img_name in images["train"]}

classes = sorted(list(classes), key=lambda x: int(x[1:]))
assert len(classes) == 200, f"Expected 200 classes, found {len(classes)}"
assert len(images["train"]) == len(classes) * 500, f"Expected 100000 training images, found {len(images['train'])}"
assert len(images["val"]) == len(classes) * 50, f"Expected 10000 validation images, found {len(images['val'])}"
with open(path.join(args.output_dir, "tiny_imagenet_classes.txt"), "w+") as f:
    f.write("\n".join(classes))

# copy over the relevant images
for split in ["train", "val"]:
    ipc = 500 if split == "train" else 50
    part = "foregrounds_WEBP"
    with tqdm(total=len(images[split]), desc=f"Copying images for {split} > {part}") as pbar:
        for synset in classes:
            makedirs(path.join(args.output_dir, split, part, synset), exist_ok=True)
            if len(listdir(path.join(args.output_dir, split, part, synset))) >= ipc:
                tqdm.write(
                    f"skip {split} > {part} > {synset} with"
                    f" {len(listdir(path.join(args.output_dir, split, part, synset)))} ims"
                )
                pbar.update(ipc)
                continue
            for img in listdir(path.join(args.in_segment_dir, split, part, synset)):
                orig_name = (
                    img.split(".")[0] + ".JPEG"
                    if split == "train"
                    else f"val_{int(img.split('_')[-1].split('.')[0])}.JPEG"
                )
                if orig_name in images[split]:
                    shutil.copy(
                        path.join(args.in_segment_dir, split, part, synset, img),
                        path.join(args.output_dir, split, part, synset, img),
                    )
                    pbar.update(1)

            while len(listdir(path.join(args.output_dir, split, part, synset))) < min(
                ipc, len(listdir(path.join(args.in_segment_dir, split, part, synset)))
            ):
                # copy over more random images
                image_names = [
                    (
                        img,
                        (
                            img.split(".")[0] + ".JPEG"
                            if split == "train"
                            else f"val_{int(img.split('_')[-1].split('.')[0])}.JPEG"
                        ),
                    )
                    for img in listdir(path.join(args.in_segment_dir, split, part, synset))
                ]
                image_names = [
                    img for img in image_names if img[0] not in listdir(path.join(args.output_dir, split, part, synset))
                ]
                img = choice(image_names)[0]
                shutil.copy(
                    path.join(args.in_segment_dir, split, part, synset, img),
                    path.join(args.output_dir, split, part, synset, img),
                )
                pbar.update(1)
                tqdm.write(f"Extra image: {img} to {split}/{part}/{synset}")

    # copy over the background images corresponding to those foregrounds
    part = "backgrounds_JPEG"
    with tqdm(total=len(images[split]), desc=f"Copying images for {split} > {part}") as pbar:
        for synset in classes:
            makedirs(path.join(args.output_dir, split, part, synset), exist_ok=True)
            if len(listdir(path.join(args.output_dir, split, part, synset))) >= ipc:
                tqdm.write(f"skip {split} > {part} > {synset}")
                pbar.update(ipc)
                continue
            for img in listdir(path.join(args.output_dir, split, "foregrounds_WEBP", synset)):
                bg_name = img.replace(".WEBP", ".JPEG")
                shutil.copy(
                    path.join(args.in_segment_dir, split, part, synset, bg_name),
                    path.join(args.output_dir, split, part, synset, bg_name),
                )
                pbar.update(1)

            assert len(listdir(path.join(args.output_dir, split, part, synset))) == len(
                listdir(path.join(args.output_dir, split, "foregrounds_WEBP", synset))
            ), (
                f"Expected {len(listdir(path.join(args.output_dir, split, 'foregrounds_WEBP', synset)))} background"
                f" images, found {len(listdir(path.join(args.output_dir, split, part, synset)))}"
            )

# write the original dataset to datadings
for part in ["train", "val"]:
    reader = MsgpackReader(path.join(args.imagenet_path, f"{part}.msgpack"))
    with FileWriter(path.join(args.output_dir, f"TinyIN_{part}.msgpack")) as writer:
        for data in tqdm(reader, desc=f"Writing {part} to datadings"):
            key = data["key"].split("/")[-1]
            allowed_synsets = [key.split("_")[0]] if part == "train" else classes

            if part == "train" and allowed_synsets[0] not in classes:
                continue

            keep_image = False
            label_synset = None
            for synset in allowed_synsets:
                for img in listdir(path.join(args.output_dir, part, "foregrounds_WEBP", synset)):
                    if img.split(".")[0] == key.split(".")[0]:
                        keep_image = True
                        label_synset = synset
                        break

            if not keep_image:
                continue

            data["label"] = classes.index(label_synset)

            writer.write(data)
@@ -0,0 +1,48 @@
import argparse
import json
import os

import torch

parser = argparse.ArgumentParser("Script to convert ImageNet trained models to ImageNet-9")
parser.add_argument("-m", "--model", type=str, required=True, help="Model weights (.pt file).")
parser.add_argument(
    "--in_to_in9", type=str, default="/ds-sds/images/ImageNet-9/in_to_in9.json", help="Path to in_to_in9.json"
)
args = parser.parse_args()


checkpoint = torch.load(args.model, map_location="cpu")

model_state = checkpoint["model_state"]
head_keys = [k for k in model_state.keys() if ".head." in k or ".fc." in k]
print("weights that will be modified:", head_keys)
assert len(head_keys) > 0, "no head keys found :("

with open(args.in_to_in9, "r") as f:
    in_to_in9_classes = json.load(f)
print(f"{len([k for k, v in in_to_in9_classes.items() if v == -1])} classes get mapped to -1")

print("map", len(in_to_in9_classes), "classes to", set(in_to_in9_classes.values()))

print("Building conversion matrix")
conversion_matrix = torch.zeros((9, 1000))
for in_idx, in9_idx in in_to_in9_classes.items():
    if in9_idx == -1:
        continue
    in_idx = int(in_idx)
    conversion_matrix[in9_idx, in_idx] = 1
print(f"Conversion matrix ({conversion_matrix.shape}) has {int(conversion_matrix.sum().item())} non-zero values")
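# conversion_matrix @ W sums the head rows of all ImageNet classes mapped to each IN-9 class;
# the same matmul works for both the (1000, d) weight and the (1000,) bias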
for head_key in head_keys:
    print(f"converting {head_key} of shape {model_state[head_key].shape}", end=" ")
    model_state[head_key] = conversion_matrix @ model_state[head_key]
    print(f"\tto {model_state[head_key].shape}")

checkpoint["model_state"] = model_state
checkpoint["args"]["n_classes"] = 9
save_folder = os.path.dirname(args.model)
orig_model_name = args.model.split(os.sep)[-1]
new_model_name = ".".join(orig_model_name.split(".")[:-1]) + "_to_in9." + orig_model_name.split(".")[-1]
print(f"saving model as {new_model_name} in {save_folder}")
torch.save(checkpoint, os.path.join(save_folder, new_model_name))
@@ -0,0 +1,200 @@
n07695742: pretzel
n03902125: pay-phone, pay-station
n03980874: poncho
n01644900: tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui
n02730930: apron
n02699494: altar
n03201208: dining table, board
n02056570: king penguin, Aptenodytes patagonica
n04099969: rocking chair, rocker
n04366367: suspension bridge
n04067472: reel
n02808440: bathtub, bathing tub, bath, tub
n04540053: volleyball
n02403003: ox
n03100240: convertible
n04562935: water tower
n02788148: bannister, banister, balustrade, balusters, handrail
n02988304: CD player
n02423022: gazelle
n03637318: lampshade, lamp shade
n01774384: black widow, Latrodectus mactans
n01768244: trilobite
n07614500: ice cream, icecream
n04254777: sock
n02085620: Chihuahua
n01443537: goldfish, Carassius auratus
n01629819: European fire salamander, Salamandra salamandra
n02099601: golden retriever
n02321529: sea cucumber, holothurian
n03837869: obelisk
n02002724: black stork, Ciconia nigra
n02841315: binoculars, field glasses, opera glasses
n04560804: water jug
n02364673: guinea pig, Cavia cobaya
n03706229: magnetic compass
n09256479: coral reef
n09332890: lakeside, lakeshore
n03544143: hourglass
n02124075: Egyptian cat
n02948072: candle, taper, wax light
n01950731: sea slug, nudibranch
n02791270: barbershop
n03179701: desk
n02190166: fly
n04275548: spider web, spider's web
n04417672: thatch, thatched roof
n03930313: picket fence, paling
n02236044: mantis, mantid
n03976657: pole
n01774750: tarantula
n04376876: syringe
n04133789: sandal
n02099712: Labrador retriever
n04532670: viaduct
n04487081: trolleybus, trolley coach, trackless trolley
n09428293: seashore, coast, seacoast, sea-coast
n03160309: dam, dike, dyke
n03250847: drumstick
n02843684: birdhouse
n07768694: pomegranate
n03670208: limousine, limo
n03085013: computer keyboard, keypad
n02892201: brass, memorial tablet, plaque
n02233338: cockroach, roach
n03649909: lawn mower, mower
n03388043: fountain
n02917067: bullet train, bullet
n02486410: baboon
n04596742: wok
n03255030: dumbbell
n03937543: pill bottle
n02113799: standard poodle
n03977966: police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria
n02906734: broom
n07920052: espresso
n01698640: American alligator, Alligator mississipiensis
n02123394: Persian cat
n03424325: gasmask, respirator, gas helmet
n02129165: lion, king of beasts, Panthera leo
n04008634: projectile, missile
n03042490: cliff dwelling
n02415577: bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis
n02815834: beaker
n02395406: hog, pig, grunter, squealer, Sus scrofa
n01784675: centipede
n03126707: crane
n04399382: teddy, teddy bear
n07875152: potpie
n03733131: maypole
n02802426: basketball
n03891332: parking meter
n01910747: jellyfish
n03838899: oboe, hautboy, hautbois
n03770439: miniskirt, mini
n02281406: sulphur butterfly, sulfur butterfly
n03970156: plunger, plumber's helper
n09246464: cliff, drop, drop-off
n02206856: bee
n02074367: dugong, Dugong dugon
n03584254: iPod
n04179913: sewing machine
n04328186: stopwatch, stop watch
n07583066: guacamole
n01917289: brain coral
n03447447: gondola
n02823428: beer bottle
n03854065: organ, pipe organ
n02793495: barn
n04285008: sports car, sport car
n02231487: walking stick, walkingstick, stick insect
n04465501: tractor
n02814860: beacon, lighthouse, beacon light, pharos
n02883205: bow tie, bow-tie, bowtie
n02165456: ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle
n04149813: scoreboard
n04023962: punching bag, punch bag, punching ball, punchball
n02226429: grasshopper, hopper
n02279972: monarch, monarch butterfly, milkweed butterfly, Danaus plexippus
n02669723: academic gown, academic robe, judge's robe
n04486054: triumphal arch
n04070727: refrigerator, icebox
n03444034: go-kart
n02666196: abacus
n01945685: slug
n04251144: snorkel
n03617480: kimono
n03599486: jinrikisha, ricksha, rickshaw
n02437312: Arabian camel, dromedary, Camelus dromedarius
n01984695: spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish
n04118538: rugby ball
n01770393: scorpion
n04356056: sunglasses, dark glasses, shades
n03804744: nail
n02132136: brown bear, bruin, Ursus arctos
n03400231: frying pan, frypan, skillet
n03983396: pop bottle, soda bottle
n07734744: mushroom
n02480495: orangutan, orang, orangutang, Pongo pygmaeus
n02410509: bison
n03404251: fur coat
n04456115: torch
n02123045: tabby, tabby cat
n03026506: Christmas stocking
n07715103: cauliflower
n04398044: teapot
n02927161: butcher shop, meat market
n07749582: lemon
n07615774: ice lolly, lolly, lollipop, popsicle
n02795169: barrel, cask
n04532106: vestment
n02837789: bikini, two-piece
n02814533: beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon
n04265275: space heater
n02481823: chimpanzee, chimp, Pan troglodytes
n02509815: lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens
n06596364: comic book
n01983481: American lobster, Northern lobster, Maine lobster, Homarus americanus
n02504458: African elephant, Loxodonta africana
n03014705: chest
n01944390: snail
n04146614: school bus
n01641577: bullfrog, Rana catesbeiana
n07720875: bell pepper
n02999410: chain
n01855672: goose
n02125311: cougar, puma, catamount, mountain lion, painter, panther, Felis concolor
n07753592: banana
n07871810: meat loaf, meatloaf
n04501370: turnstile
n04311004: steel arch bridge
n02977058: cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM
n04074963: remote control, remote
n03662601: lifeboat
n02106662: German shepherd, German shepherd dog, German police dog, alsatian
n03089624: confectionery, confectionary, candy store
n04259630: sombrero
n03393912: freight car
n04597913: wooden spoon
n07711569: mashed potato
n03355925: flagpole, flagstaff
n02963159: cardigan
n07579787: plate
n02950826: cannon
n01882714: koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus
n02094433: Yorkshire terrier
n02909870: bucket, pail
n02058221: albatross, mollymawk
n01742172: boa constrictor, Constrictor constrictor
n09193705: alp
n04371430: swimming trunks, bathing trunks
n07747607: orange
n03814639: neck brace
n04507155: umbrella
n02268443: dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk
n02769748: backpack, back pack, knapsack, packsack, rucksack, haversack
n03763968: military uniform
n07873807: pizza, pizza pie
n03992509: potter's wheel
n03796401: moving van
n12267677: acorn
@@ -0,0 +1,73 @@
# Repeat Augment sampler taken from DeiT: https://github.com/facebookresearch/deit/blob/main/samplers.py
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
import math

import torch
import torch.distributed as dist


class RASampler(torch.utils.data.Sampler):
    """Sampler that restricts data loading to a subset of the dataset for distributed, with repeated augmentation.

    It ensures that each augmented version of a sample will be visible to a
    different process (GPU).
    Heavily based on torch.utils.data.DistributedSampler
    """

    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, num_repeats: int = 3):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        if num_repeats < 1:
            raise ValueError("num_repeats should be greater than 0")
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.num_repeats = num_repeats
        self.epoch = 0
        self.num_samples = int(math.ceil(len(self.dataset) * self.num_repeats / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        # self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas))
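        # each process still only yields ~len(dataset) / num_replicas samples per epoch (floored to a
        # global multiple of 256), so the repeated augmentations replace, rather than lengthen, an epoch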
        self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
        self.shuffle = shuffle

    def __iter__(self):
        if self.shuffle:
            # deterministically shuffle based on epoch
            g = torch.Generator()
            g.manual_seed(self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g)
        else:
            indices = torch.arange(start=0, end=len(self.dataset))

        # add extra samples to make it evenly divisible
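        # repeat each index num_repeats times; the strided subsample below then routes the copies to different replicas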
        indices = torch.repeat_interleave(indices, repeats=self.num_repeats, dim=0).tolist()
        padding_size: int = self.total_size - len(indices)
        if padding_size > 0:
            indices += indices[:padding_size]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank : self.total_size : self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices[: self.num_selected_samples])

    def __len__(self):
        return self.num_selected_samples

    def set_epoch(self, epoch):
        self.epoch = epoch

    def __str__(self) -> str:
        return (
            f"{type(self).__name__}(num_replicas: {self.num_replicas}, rank: {self.rank}, num_repeats:"
            f" {self.num_repeats}, epoch: {self.epoch}, num_samples: {self.num_samples}, total_size: {self.total_size},"
            f" num_selected_samples: {self.num_selected_samples}, shuffle: {self.shuffle})"
        )
@@ -0,0 +1,381 @@
import argparse
import json
import os
from copy import copy

from loguru import logger
from nltk.corpus import wordnet as wn


class bcolors:
    HEADER = "\033[95m"
    OKBLUE = "\033[94m"
    OKCYAN = "\033[96m"
    OKGREEN = "\033[92m"
    WARNING = "\033[93m"
    FAIL = "\033[91m"
    ENDC = "\033[0m"
    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"


def _lemmas_str(synset):
    return ", ".join([lemma.name() for lemma in synset.lemmas()])


class WNEntry:
    def __init__(
        self,
        name: str,
        id: int,
        lemmas: str,
        parent_id: int,
        depth: int = None,
        in_image_net: bool = False,
        child_ids: list = None,
        in_main_tree: bool = True,
        _n_images: int = 0,
        _description: str = None,
        _name: str = None,
        _pruned: bool = False,
    ):
        self.name = name
        self.id = id
        self.lemmas = lemmas
        self.parent_id = parent_id
        self.depth = depth
        self.in_image_net = in_image_net
        self.child_ids = child_ids
        self.in_main_tree = in_main_tree
        self._n_images = _n_images
        self._description = _description
        self._name = _name
        self._pruned = _pruned

    def __str__(self, tree=None, accumulate=True):
        start_symb = f"{bcolors.OKGREEN}+{bcolors.ENDC}" if self.in_image_net else f"{bcolors.FAIL}-{bcolors.ENDC}"
        n_ims = f"{self._n_images} of Σ {self.n_images(tree)}" if accumulate and tree is not None else self._n_images
        if self.child_ids is None or tree is None:
            return f"{start_symb}{self.name} ({self.id}) > {n_ims}"
        else:
            return f"{start_symb}{self.name} ({self.id}) > {n_ims}\n " + "\n ".join(
                ["\n ".join(tree.nodes[child_id].__str__(tree).split("\n")) for child_id in self.child_ids]
            )

    def tree_diff(self, tree_1, tree_2):
        if tree_2[self.id]._n_images > tree_1[self.id]._n_images:
            start_symb = f"{bcolors.OKGREEN}+{bcolors.ENDC}"
        elif tree_2[self.id]._n_images < tree_1[self.id]._n_images:
            start_symb = f"{bcolors.FAIL}-{bcolors.ENDC}"
        else:
            start_symb = f"{bcolors.OKBLUE}={bcolors.ENDC}"
        n_ims = (
            f"{tree_1[self.id]._n_images} + {tree_2[self.id]._n_images - tree_1[self.id]._n_images} of Σ"
            f" {tree_1[self.id].n_images(tree_2)}/{tree_2[self.id].n_images(tree_2)}"
        )

        if self.child_ids is None:
            return f"{start_symb}{self.name} ({self.id}) > {n_ims}"

        return f"{start_symb}{self.name} ({self.id}) > {n_ims}\n " + "\n ".join(
            ["\n ".join(tree_1.nodes[child_id].tree_diff(tree_1, tree_2).split("\n")) for child_id in self.child_ids]
        )

    def prune(self, tree):
        if self._pruned:
            return

        if self.child_ids is not None:
            for child_id in self.child_ids:
                tree[child_id].prune(tree)

        self._pruned = True
        parent_node = tree.nodes[self.parent_id]
        try:
            parent_node.child_ids.remove(self.id)
        except ValueError as e:
            print(
                f"Error removing {self.name} from"
                f" {parent_node.name} ({[tree[cid].name for cid in parent_node.child_ids]}): {e}"
            )
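        # images of a pruned node are credited to its nearest ancestor that is still in the tree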
        while parent_node._pruned:
            parent_node = tree.nodes[parent_node.parent_id]
        parent_node._n_images += self._n_images
        self._n_images = 0

    @property
    def description(self):
        if not self._description:
            self._description = wn.synset_from_pos_and_offset("n", self.id).definition()
        return self._description

    @property
    def print_name(self):
        return self.name.split(".")[0]

    def get_branch(self, tree=None):
        if self.parent_id is None or tree is None:
            return self.print_name

        parent = tree.nodes[self.parent_id]
        return parent.get_branch(tree) + " > " + self.print_name

    def get_branch_list(self, tree):
        if self.parent_id is None:
            return [self]
        parent = tree.nodes[self.parent_id]
        return parent.get_branch_list(tree) + [self]

    def to_dict(self):
        return {
            "name": self.name,
            "id": self.id,
            "lemmas": self.lemmas,
            "parent_id": self.parent_id,
            "depth": self.depth,
            "in_image_net": self.in_image_net,
            "child_ids": self.child_ids,
            "in_main_tree": self.in_main_tree,
            "_n_images": self._n_images,
            "_description": self._description,
            "_name": self._name,
            "_pruned": self._pruned,
        }

    @classmethod
    def from_dict(cls, d):
        return cls(**d)

    def n_images(self, tree=None):
        if tree is None or self.child_ids is None or len(self.child_ids) == 0:
            return self._n_images
        return sum([tree.nodes[child_id].n_images(tree) for child_id in self.child_ids]) + self._n_images

    def n_children(self, tree=None):
        if self.child_ids is None:
            return 0
        if tree is None or len(self.child_ids) == 0:
            return len(self.child_ids)
        return len(self.child_ids) + sum([tree.nodes[child_id].n_children(tree) for child_id in self.child_ids])

    def get_examples(self, tree, n_examples=3):
        if self.child_ids is None or len(self.child_ids) == 0:
            return ""
        child_images = {child_id: tree.nodes[child_id].n_images(tree) for child_id in self.child_ids}
        max_images = max(child_images.values())
        if max_images == 0:
            # go on number of child nodes
            child_images = {child_id: tree.nodes[child_id].n_children(tree) for child_id in self.child_ids}
        # sort child ids by number of images
        top_children = [
            child_id for child_id, n_images in sorted(child_images.items(), key=lambda x: x[1], reverse=True)
        ]
        top_children = top_children[: min(n_examples, len(top_children))]
        return ", ".join(
            [f"{tree.nodes[child_id].print_name} ({tree.nodes[child_id].description})" for child_id in top_children]
        )


class WNTree:
    def __init__(self, root=1740, nodes=None):
        if isinstance(root, int):
            root_id = root
            root_synset = wn.synset_from_pos_and_offset("n", root)
            root_node = WNEntry(
                root_synset.name(),
                root_id,
                _lemmas_str(root_synset),
                parent_id=None,
                depth=0,
            )
        else:
            assert isinstance(root, WNEntry)
            root_id = root.id
            root_node = root

        self.root = root_node
        self.nodes = {root_id: self.root} if nodes is None else nodes
        self.parentless = []
        self.label_index = None
        self.pruned = set()

    def to_dict(self):
        return {
            "root": self.root.to_dict(),
            "nodes": {node_id: node.to_dict() for node_id, node in self.nodes.items()},
            "parentless": self.parentless,
            "pruned": list(self.pruned),
        }

    def prune(self, min_images):
        pruned_nodes = set()

        # prune all nodes that have fewer than min_images below them
        for node_id, node in self.nodes.items():
            if node.n_images(self) < min_images:
                pruned_nodes.add(node_id)
                node.prune(self)

        # prune all nodes that have fewer than min_images inside them, after all nodes below have been pruned
        node_stack = [self.root]
        node_idx = 0
        while node_idx < len(node_stack):
            node = node_stack[node_idx]
            if node.child_ids is not None:
                for child_id in node.child_ids:
                    child = self.nodes[child_id]
                    node_stack.append(child)
            node_idx += 1

        # now prune the stack from the bottom up
        for node in node_stack[::-1]:
            # only look at images of that class, not of additional children
            if node.n_images() < min_images:
                pruned_nodes.add(node.id)
                node.prune(self)

        self.pruned = pruned_nodes
        return pruned_nodes

    @classmethod
    def from_dict(cls, d):
        tree = cls()
        tree.root = WNEntry.from_dict(d["root"])
        tree.nodes = {int(node_id): WNEntry.from_dict(node) for node_id, node in d["nodes"].items()}
        tree.parentless = d["parentless"]
        if "pruned" in d:
            tree.pruned = set(d["pruned"])
        return tree

    def add_node(self, node_id, in_in=True):
        if node_id in self.nodes:
            self.nodes[node_id].in_image_net = in_in or self.nodes[node_id].in_image_net
            return

        synset = wn.synset_from_pos_and_offset("n", node_id)

        # print(f"adding node {synset.name()} with id {node_id}")

        hypernyms = synset.hypernyms()
        if len(hypernyms) == 0:
            parent_id = None
            self.parentless.append(node_id)
            main_tree = False
            print(f"--------- no hypernyms for {synset.name()} ({synset.offset()}) ------------")
        else:
            parent_id = synset.hypernyms()[0].offset()
            if parent_id not in self.nodes:
                self.add_node(parent_id, in_in=False)
            parent = self.nodes[parent_id]

            if parent.child_ids is None:
                parent.child_ids = []
            parent.child_ids.append(node_id)
            main_tree = parent.in_main_tree

        depth = self.nodes[parent_id].depth + 1 if parent_id is not None else 0
        node = WNEntry(
            synset.name(),
            node_id,
            _lemmas_str(synset),
            parent_id=parent_id,
            in_image_net=in_in,
            depth=depth,
            in_main_tree=main_tree,
        )

        self.nodes[node_id] = node

    def __len__(self):
        return len(self.nodes)

    def image_net_len(self, only_main_tree=False):
        return sum([node.in_image_net for node in self.nodes.values() if node.in_main_tree or not only_main_tree])

    def max_depth(self, only_main_tree=False):
        return max([node.depth for node in self.nodes.values() if node.in_main_tree or not only_main_tree])

    def __str__(self):
        return (
            f"WordNet Tree with {len(self)} nodes, {self.image_net_len()} in ImageNet21k;"
            f" {len(self.parentless)} parentless nodes:\n{self.root.__str__(tree=self)}\nParentless:\n"
            + "\n".join([self.nodes[node_id].__str__(tree=self) for node_id in self.parentless])
        )

    def save(self, path):
        with open(path, "w") as f:
            json.dump(self.to_dict(), f)

    @classmethod
    def load(cls, path):
        with open(path, "r") as f:
            tree_dict = json.load(f)
        return cls.from_dict(tree_dict)

    def subtree(self, node_id):
        if node_id not in self.nodes:
            return None
        node_queue = [self.nodes[node_id]]
        subtree_ids = set()
        while len(node_queue) > 0:
            node = node_queue.pop(0)
            subtree_ids.add(node.id)
            if node.child_ids is not None:
                node_queue += [self.nodes[child_id] for child_id in node.child_ids]
        subtree_nodes = {node_id: copy(self.nodes[node_id]) for node_id in subtree_ids}
        subtree_root = subtree_nodes[node_id]
        subtree_root.parent_id = None
        depth_diff = subtree_root.depth
        for node in subtree_nodes.values():
            node.depth -= depth_diff
        return WNTree(root=subtree_root, nodes=subtree_nodes)

    def _make_label_index(self, include_merged=False):
        self.label_index = sorted(
            [
                node_id
                for node_id, node in self.nodes.items()
                if node.n_images(self if include_merged else None) > 0 and not node._pruned
            ]
        )

    def get_label(self, node_id):
        if self.label_index is None:
            self._make_label_index()
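        # pruned synsets are relabelled as their closest surviving ancestor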
        while self.nodes[node_id]._pruned:
            node_id = self.nodes[node_id].parent_id
        return self.label_index.index(node_id)

    def n_labels(self):
        if self.label_index is None:
            self._make_label_index()
        return len(self.label_index)

    def __contains__(self, item):
        if isinstance(item, str):
            if item[0] == "n":
                item = int(item[1:])
            else:
                return False
        if isinstance(item, int):
            return item in self.nodes
        if isinstance(item, WNEntry):
            return item.id in self.nodes
        return False

    def __getitem__(self, item):
        if isinstance(item, str) and item.startswith("n"):
            try:
                item = int(item[1:])
            except ValueError:
                pass
        if isinstance(item, str) and ".n." in item:
            for node in self.nodes.values():
                if item == node.name:
                    return node
            raise KeyError(f"Item {item} not found in tree")
        if isinstance(item, int):
            return self.nodes[item]
        if isinstance(item, WNEntry):
            return self.nodes[item.id]
        raise KeyError(f"Item {item} not found in tree")
834
AAAI Supplementary Material/Model Training Code/engine.py
Normal file
@@ -0,0 +1,834 @@
import json
import os
import sys
from datetime import datetime
from functools import partial
from math import isfinite
from time import time

import torch
import torch.nn as nn
import torch.optim as optim
from loguru import logger
from timm.data import Mixup
from timm.optim import create_optimizer
from timm.scheduler import create_scheduler
from torch import distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from tqdm.auto import tqdm

from metrics import calculate_metrics, per_class_counts
from utils import (
    NoScaler,
    ScalerGradNormReturn,
    SchedulerArgs,
    log_formatter,
    save_model_state,
)

try:
    from apex.optimizers import FusedLAMB  # noqa: F401

    apex_available = True
except ImportError:
    logger.error("Nvidia apex not available")
    apex_available = False
try:
    from lion_pytorch import Lion

    lion_available = True
except ImportError:
    logger.error("Lion not available")
    lion_available = False


WANDB_AVAILABLE = False
try:
    import wandb

    WANDB_AVAILABLE = True
except ImportError:
    logger.error("wandb not available")


def wandb_available(turn_off=False):
    """If wandb is available.

    Args:
        turn_off (bool, optional): set wandb to be unavailable manually.

    Returns:
        bool: wandb is available
    """
    global WANDB_AVAILABLE
    if turn_off:
        WANDB_AVAILABLE = False
    return WANDB_AVAILABLE


tqdm = partial(tqdm, leave=True, position=0)  # noqa: F405


def setup_tracking_and_logging(args, rank, append_model_path=None, log_wandb=True):
    """Set up logging and tracking for an experiment.

    Args:
        args (DotDict): Parsed command-line arguments
        rank (int): The rank of the current process
        append_model_path (str, optional): Path of an existing model, by default None
        log_wandb (bool, optional): Whether to log to wandb, by default True

    Returns:
        str: folder, where all the run data is saved.

    Raises:
        AssertionError: If `dataset` or `model` is `None`.

    Notes:
        This function sets up the logger to stdout and file, as well as wandb tracking for an experiment.
        For the wandb logger, provide .wandb.apikey in the current directory.
    """
    dataset, model, epochs = args.dataset.replace(os.sep, "_").lower(), args.model.replace(os.sep, "_"), args.epochs
    _base_folder = (
        os.path.join(args.results_folder, args.experiment_name, args.task.replace("-", ""), dataset)
        if args.out_dir is None
        else args.out_dir
    )
    run_folder = os.path.join(
        _base_folder,
        f"{args.run_name.replace(os.sep, '_')}_{model}_{datetime.now().strftime('%d.%m.%Y_%H:%M:%S')}",
    )
    assert dataset is not None and model is not None

    if os.name == "nt":
        run_folder = run_folder.replace("@", "_").replace(" ", "_").replace(":", ".")

    if append_model_path is not None:
        run_folder = os.path.dirname(append_model_path)
        if "run_name" not in args or args.run_name is None:
            args.run_name = run_folder.split(os.sep)[-1].split("_")[0]
    elif args.distributed:
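        # rank 0 decides the (timestamped) run folder and broadcasts it so that all ranks write to the same place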
        obj_list = [None]
        if rank == 0:
            obj_list[0] = run_folder
        dist.broadcast_object_list(obj_list, src=0)
        run_folder = obj_list[0]
        if rank == 0:
            os.makedirs(run_folder, exist_ok=True)
        dist.barrier()
    elif rank == 0:
        os.makedirs(run_folder, exist_ok=True)

    assert "%" not in args.run_name, f"found '%' in run_name '{args.run_name}'. This messes with string formatting..."

    if args.debug:
        args.log_level = "debug"

    # logger to stdout & file
    log_name = args.task.replace("-", "")
    if args.task not in ["pre-train", "fine-tune", "fine-tune-head"]:
        log_name += f"_{dataset}_{datetime.now().strftime('%d.%m.%Y_%H:%M:%S')}"
    log_file = os.path.join(run_folder, f"{log_name}.log")
    logger.remove()
    logger.configure(extra=dict(run_name=args.run_name, rank=rank, world_size=args.world_size))
    logger.add(sys.stderr, format=log_formatter, enqueue=True, colorize=True, level=args.log_level.upper())
    logger.add(log_file, format=log_formatter, enqueue=True, colorize=True, level=args.log_level.upper())
    logger.info(f"Run folder '{run_folder}'")

    if rank == 0:
        logger.info(f"{args.task.replace('-', '').capitalize()} {model} on {dataset} for {epochs} epochs")

    global WANDB_AVAILABLE
    WANDB_AVAILABLE = WANDB_AVAILABLE and log_wandb and os.path.isfile(".wandb.apikey") and args.wandb
    if WANDB_AVAILABLE:
        with open(".wandb.apikey", "r") as f:
            __wandb_api_key = f.read().strip()
        wandb.login(key=__wandb_api_key)
        if args.wandb_run_id is not None:
            wandb_args = dict(project=args.experiment_name, resume="must", id=args.wandb_run_id)
        else:
            wandb_args = dict(
                project=args.experiment_name,
                name=args.run_name.replace("_", "-").replace(" ", "-"),
                config={"logfile": log_file, **dict(args)},
                job_type=args.task,
                tags=[model, dataset],
                resume="allow",
                id=args.wandb_run_id,
            )
        wandb.init(**wandb_args)
        args["wandb_run_id"] = wandb.run.id
        if rank == 0:
            logger.info(f"wandb run initialized with id {args['wandb_run_id']}.")
    else:
        logger.info(
            f"Not using wandb. (args.wandb={args.wandb}, .wandb.apikey exists={os.path.isfile('.wandb.apikey')},"
            f" function declaration log_wandb={log_wandb})"
        )

    if args.distributed:
        dist.barrier()

    if args.debug:
        torch.autograd.set_detect_anomaly(True)
        logger.warning("torch.autograd anomaly detection enabled. Will slow down model.")

    return run_folder


def setup_model_optim_sched_scaler(model, device, epochs, args, head_only=False):
    """Set up model, optimizer, and scheduler with automatic mixed precision (amp) and distributed data parallel (DDP).

    Args:
        model (nn.Module): the loaded model
        device (torch.device): the current device
        epochs (int): total number of epochs to learn for (for scheduler)
        args: further arguments
        head_only (bool, optional): train only the linear head. Default: False

    Returns:
        tuple[nn.Module, optim.Optimizer, optim.lr_scheduler._LRScheduler, ScalerGradNormReturn]: model, optimizer, scheduler, scaler

    """
    model = model.to(device)

    if head_only:
        for param in model.parameters():
            param.requires_grad = False
        for param in model.head.parameters():
            param.requires_grad = True
        for name, param in model.named_parameters():
            if "head" in name:
                param.requires_grad = True
            else:
                param.requires_grad = False
        params = model.head.parameters()
    else:
        params = model  # model.named_parameters() use model itself for now and let timm do the work...

    if args.opt == "lion" and not lion_available:
        args.opt = "fusedlamb"
        logger.warning("Falling back from lion to fusedlamb")
    if args.opt == "fusedlamb" and not apex_available:
        args.opt = "adamw"
        logger.warning("Falling back from fusedlamb to adamw")
    if args.opt == "lion":
        optimizer = Lion(params, lr=args["lr"], weight_decay=args["weight_decay"])
    else:
        optimizer = create_optimizer(args, params)

    scaler = ScalerGradNormReturn() if args.amp else NoScaler()

    # if args.model_ema:
    #     # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
    #     ema_model = ModelEma(model, decay=args.model_ema_decay, resume='')

    if args.distributed:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = DDP(model, device_ids=[device])

    if args.compile_model:
        model = torch.compile(model)

    # scheduler = optim.lr_scheduler.LambdaLR(optimizer,
    #                                         lr_lambda=scheduler_function_factory(**args))
    sched_args = SchedulerArgs(args.sched, args.epochs, args.min_lr, args.warmup_lr, args.warmup_epochs)
    scheduler, _ = create_scheduler(sched_args, optimizer)

    return model, optimizer, scheduler, scaler


def setup_criteria_mixup(args, dataset=None, **criterion_kwargs):
    """Set up further objects that are needed for training.

    Args:
        args: arguments
        dataset (torch.utils.data.Dataset, optional): dataset that implements images_per_class, for class weights (Default value = None)
        **criterion_kwargs: further arguments for the criterion

    Returns:
        tuple[nn.Module, nn.Module, Mixup]: criterion, val_criterion, mixup

    """
    weight = None
    if args.loss_weight != "none":
        if dataset is not None and hasattr(dataset, "images_per_class"):
            ipc = dataset.images_per_class
            total_ims = sum(ipc)

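            # "linear" = inverse class frequency, "log" = entropy-normalised negative log frequency,
            # "sqrt" = inverse square-root frequency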
            if args.loss_weight == "linear":
                weight = torch.tensor([total_ims / (ims * args.n_classes) for ims in ipc])
            elif args.loss_weight == "log":
                p_c = torch.tensor([ims / total_ims for ims in ipc])
                log_p_c = torch.where(p_c > 0, p_c.log(), torch.zeros_like(p_c))
                entr = -(p_c * log_p_c).sum()
                weight = -log_p_c / entr
            elif args.loss_weight == "sqrt":
                p_c = torch.tensor([ims / total_ims for ims in ipc])
                weight = 1 / (p_c.sqrt() * p_c.sqrt().sum())

        else:
            logger.warning("Could not find images_per_class in dataset. Using uniform weights.")

    if args.aug_cutmix or args.multi_label:
        # criterion = SoftTargetCrossEntropy()
        if args.ignore_index >= 0:
            if weight is None:
                weight = torch.ones(args.n_classes)
            weight[args.ignore_index] = 0
        if args.multi_label:
            if args.loss == "ce":
                criterion = nn.BCEWithLogitsLoss(pos_weight=weight, **criterion_kwargs)
                val_criterion = nn.BCEWithLogitsLoss(pos_weight=weight, **criterion_kwargs)
            else:
                raise NotImplementedError(
                    f"Only BCEWithLogitsLoss (ce) is implemented for multi-label classification, not {args.loss}."
                )
        else:
            if args.loss == "ce":
                loss_cls = nn.CrossEntropyLoss
            elif args.loss == "baikal":
                loss_cls = BaikalLoss
            else:
                raise NotImplementedError(f"'{args.loss}'-loss is not implemented.")
            criterion = loss_cls(weight=weight, **criterion_kwargs)
            val_criterion = loss_cls(weight=weight, **criterion_kwargs)
    else:
        if args.loss == "ce":
            loss_cls = nn.CrossEntropyLoss
        elif args.loss == "baikal":
            loss_cls = BaikalLoss
        else:
            raise NotImplementedError(f"'{args.loss}'-loss is not implemented.")
        criterion = loss_cls(
            label_smoothing=args.label_smoothing,
            ignore_index=args.ignore_index if weight is None else -100,
            weight=weight,
            **criterion_kwargs,
        )  # LabelSmoothingCrossEntropy(smoothing=args.label_smoothing)
        # criterion = LabelSmoothingCrossEntropy(smoothing=args.label_smoothing)
        val_criterion = loss_cls(
            label_smoothing=args.label_smoothing,
            ignore_index=args.ignore_index if weight is None else -100,
            weight=weight,
            **criterion_kwargs,
        )  # LabelSmoothingCrossEntropy(smoothing=0.)

    mixup_kwargs = dict(
        mixup_alpha=args.aug_mixup_alpha,
        cutmix_alpha=args.aug_cutmix_alpha,
        label_smoothing=args.label_smoothing,
        num_classes=args.n_classes,
    )
    mixup = Mixup(**mixup_kwargs) if abs(args.aug_cutmix_alpha) + abs(args.aug_mixup_alpha) > 0.0 else None

    return criterion, val_criterion, mixup


def _train(
    model,
    train_loader,
    optimizer,
    rank,
    epochs,
    device,
    mixup,
    criterion,
    world_size,
    scheduler,
    args,
    val_loader,
    val_criterion,
    model_folder,
    scaler,
    do_metrics_calculation=True,
    start_epoch=0,
    show_tqdm=True,
    topk=(1, 5),
    acc_dict_key=None,
    train_dali_server=None,
    val_dali_server=None,
):
    """Train the model.

    Args:
        model:
        train_loader:
        optimizer:
        rank:
        epochs:
        device:
        mixup:
        criterion:
        world_size:
        scheduler:
        args:
        val_loader:
        val_criterion:
        model_folder:
        scaler:
        do_metrics_calculation: (Default value = True)
        start_epoch: (Default value = 0)
        show_tqdm: (Default value = True)
        topk: (Default value = (1, 5))
        acc_dict_key: (Default value = None)
        train_dali_server: (Default value = None)
        val_dali_server: (Default value = None)

    Returns:
        dict: evaluation metrics at the end of training

    """
    if acc_dict_key is None:
        acc_dict_key = "acc{}"
    training_start = time()
    topk = tuple(k for k in topk if k <= args.n_classes)
    time_spend_training = time_spend_validating = 0
    current_best_acc = 0.0
    if rank == 0:
        logger.info(f"Dataloader has {len(train_loader)} batches")

    logger.debug("Starting training with the following settings:")
    logger.debug(f"criterion: {criterion}")
    logger.debug(f"train_loader: {train_loader}, sampler: {train_loader.sampler}")
    logger.debug(f"dataset: {train_loader.dataset}")
    logger.debug(f"optimizer: {optimizer}")
    logger.debug(f"device: {device}")
    logger.debug(f"start epoch: {start_epoch}, epochs: {epochs}")
    logger.debug(f"scaler: {scaler}")
    logger.debug(f"max_grad_norm: {args.max_grad_norm}")
    # logger.debug(f"model_ema:\n{model_ema}\n{model_ema.decay}\n{model_ema.device}")
    if mixup:
        logger.debug(
            f"mixup: {mixup}; mixup_alpha: {mixup.mixup_alpha}, cutmix_alpha: {mixup.cutmix_alpha},"
            f" cutmix_minmax: {mixup.cutmix_minmax}, prob: {mixup.mix_prob}, switch_prob: {mixup.switch_prob},"
            f" label_smoothing: {mixup.label_smoothing}, num_classes: {mixup.num_classes}, correct_lam:"
            f" {mixup.correct_lam}, mixup_enabled: {mixup.mixup_enabled}"
        )
    else:
        logger.debug(f"mixup: {mixup}")

    for epoch in range(start_epoch, epochs):
        with logger.contextualize(epoch=str(epoch + 1)):
            if args.distributed:
                train_loader.sampler.set_epoch(epoch)

            set_ep_func = getattr(train_loader.dataset, "set_epoch", None)
            if callable(set_ep_func):
                train_loader.dataset.set_epoch(epoch)
                val_loader.dataset.set_epoch(epoch)

            if train_dali_server:
                train_dali_server.start_thread()
                logger.info("started train dali server")

            epoch_time, epoch_stats = _train_one_epoch(
                model,
                train_loader,
                optimizer,
                rank,
                epoch,
                device,
                mixup,
                criterion,
                scheduler,
                scaler,
                args,
                topk,
                "train/" + acc_dict_key,
                show_tqdm,
            )
            time_spend_training += epoch_time

            if train_dali_server:
                train_dali_server.stop_thread()

            val_time, val_stats = _evaluate(
                model,
                val_loader,
                epoch,
                rank,
                device,
                val_criterion,
                args,
                topk,
                "val/" + acc_dict_key,
                dali_server=val_dali_server,
            )
            time_spend_validating += val_time

            if rank == 0:
                logger.info(f"total_time={time() - training_start}s")

            if rank == 0:
                top1_val_acc = val_stats["val/" + acc_dict_key.format(1)]
                # print metadata for grafana
                metadata = {
                    "epoch": epoch + 1,
                    "progress": (epoch + 1) / args.epochs,
                    **val_stats,
                    **epoch_stats,
                }
                # filter out NaN and infinity values
                metadata = {k: v for k, v in metadata.items() if isfinite(v)}
                print(json.dumps(metadata), flush=True)
                logger.debug(f"printed metadata: {json.dumps(metadata)}")
                if WANDB_AVAILABLE:
                    wandb.log(metadata, step=epoch + 1)

                # saving current state
                if top1_val_acc > current_best_acc or (epoch + 1) % args.save_epochs == 0:
                    reason = "top" if top1_val_acc > current_best_acc else ""  # min(...) will be the top-1 accuracy
                    if reason == "top":
                        current_best_acc = top1_val_acc
                        logger.info(f"found a new best model with acc: {current_best_acc}")
                    kwargs = dict(
                        model_state=model.state_dict(),
                        stats=metadata,
                        optimizer_state=optimizer.state_dict(),
                        additional_reason=reason,
                        regular_save=(epoch + 1) % args.save_epochs == 0,
                    )
                    if scheduler:
                        kwargs["scheduler_state"] = scheduler.state_dict()
                    save_model_state(
                        model_folder, epoch + 1, args, **kwargs, max_interm_ep_states=args.keep_interm_states
                    )

    if rank == 0:
        end_time = time()
        logger.info(
            f"training done: total time={end_time - training_start}, "
            f"time spent training={time_spend_training}, "
            f"time spent validating={time_spend_validating}"
        )

    results = {**val_stats, **epoch_stats, f"val/best_{acc_dict_key.format(1)}": current_best_acc}

    if rank == 0:
        save_model_state(
            model_folder,
            epoch + 1,
            args,
            model_state=model.state_dict(),
            stats=results,
            additional_reason="final",
            regular_save=False,
            max_interm_ep_states=args.keep_interm_states,
        )

    if do_metrics_calculation:
        # Calculate efficiency metrics
        inp = next(iter(train_loader))[0].to(device)
        metrics = calculate_metrics(
            args,
            model,
            rank=rank,
            input=inp,
            device=device,
            did_training=True,
            all_metrics=False,
            world_size=world_size,
            key_start="eval/",
        )

        if rank == 0:
            logger.info(f"Efficiency metrics: {json.dumps(metrics)}")
    return results


def _mask_preds(preds, cls_masks, mask_val=-100):
    """Mask the predictions by the mask.

    Args:
        preds: model predictions
        cls_masks: class masks
        mask_val: (Default value = -100)

    Returns:
        torch.Tensor: masked predictions

    """
    if cls_masks is None:
        return preds
    return torch.where(cls_masks.bool(), mask_val, preds)


def _evaluate(
    model,
    val_loader,
    epoch,
    rank,
    device,
    val_criterion,
    args,
    topk=(1, 5),
    acc_dict_key=None,
    dali_server=None,
):
    """Evaluate the model.

    Args:
        model (nn.Module): the model to evaluate
        val_loader (DataLoader): loader for evaluation data
        epoch (int): the current epoch (for logger & tracking)
        rank (int): this process's rank (don't log n times)
        device (torch.device): device to evaluate on
        val_criterion (nn.Module): validation loss
        args (DotDict): further arguments
        topk (tuple[int], optional): top-k accuracy, by default (1, 5)
        acc_dict_key (str, optional): key for the accuracy dictionary, by default name of the performance metric. 'val_' will be prepended.

    Returns:
        tuple[float, dict]: validation time, validation metrics (loss and accuracies)

    """
    if not acc_dict_key:
        acc_dict_key = "acc{}"

    if dali_server:
        dali_server.start_thread()
    topk = tuple(k for k in topk if k <= args.n_classes)
    model.eval()
    val_loss = 0
    val_accs = {acc_dict_key.format(k): 0.0 for k in topk}
    val_start = time()
    n_iters = 0
    iterator = (
        tqdm(val_loader, total=len(val_loader), desc=f"Validating epoch {epoch + 1}")
        if rank == 0 and args.tqdm
        else val_loader
    )
    class_counts = torch.zeros(1 if isinstance(topk, int) else len(topk), args.n_classes, 2)
    for batch_data in iterator:
        xs, ys = batch_data[:2]
        cls_masks = batch_data[2].to(device, non_blocking=True) if len(batch_data) == 3 else None

        if args.debug:
            logger.debug(f"y_max = {ys.max()}, y_min = {ys.min()}, num_classes={args.n_classes}")

        xs, ys = xs.to(device, non_blocking=True), ys.to(device, non_blocking=True)
        with torch.no_grad(), torch.amp.autocast("cuda", enabled=args.eval_amp):
            preds = model(xs)
            preds = _mask_preds(preds, cls_masks)

        if args.multi_label:
            # labels are float for BCELoss
            ys = ys.float()
        val_loss += val_criterion(preds.transpose(1, -1), ys.transpose(1, -1) if len(ys.shape) > 1 else ys).item()
        class_counts += per_class_counts(preds, ys, args.n_classes, topk=topk, ignore_index=args.ignore_index)
        n_iters += 1

    if args.distributed:
        dist.barrier()

    if dali_server:
        dali_server.stop_thread()
    val_end = time()
    iterations = n_iters

    if args.distributed:
        gather_tensor = torch.Tensor([val_loss]).to(device)
        dist.barrier()
        dist.all_reduce(gather_tensor, op=dist.ReduceOp.AVG)
        gather_tensor = gather_tensor.tolist()
        val_loss = gather_tensor[0]
        class_counts = class_counts.to(device)
        dist.all_reduce(class_counts, op=dist.ReduceOp.SUM)
        class_counts = class_counts.cpu()

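    # class_counts[i, c] = (#correct, #wrong) for top-k_i; "acc" pools counts over all classes,
    # "m-acc" is the mean of the per-class accuracies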
    for i, k in enumerate(topk):
        key = acc_dict_key.format(k)
        mkey = key.replace("acc", "m-acc")
        val_accs[key] = class_counts[i].sum(dim=0)[0].item() / class_counts[i].sum(dim=0).sum(dim=-1).item()
        val_accs[mkey] = (class_counts[i, :, 0] / class_counts[i].sum(dim=-1)).mean().item()

    val_accs["val/loss"] = val_loss

    if rank == 0:
        log_s = f"val/time={val_end - val_start}s"
        for key, val in val_accs.items():
            log_s += f", {key}={val:.4f}"
        logger.info(log_s)

    return val_end - val_start, val_accs


def _train_one_epoch(
    model,
    train_loader,
    optimizer,
    rank,
    epoch,
    device,
    mixup,
    criterion,
    scheduler,
    scaler,
    args,
    topk=(1, 5),
    acc_dict_key=None,
    show_tqdm=True,
):
    """Train the model for one epoch.

    Args:
        model:
        train_loader:
        optimizer:
        rank:
        epoch:
        device:
        mixup:
        criterion:
        scheduler:
        scaler:
        args:
        topk: (Default value = (1, 5))
        acc_dict_key: (Default value = None)
        show_tqdm: (Default value = True)

    Returns:
        tuple[float, dict]: time spent training, epoch statistics (lr and loss)

    """
    if not acc_dict_key:
        acc_dict_key = "acc{}"

    model.train()
    iterator = (
        tqdm(train_loader, total=len(train_loader), desc=f"Training epoch {epoch + 1}")
        if rank == 0 and show_tqdm
        else train_loader
    )

    if not args.amp:
        scaler = NoScaler()

    epoch_loss = 0
    epoch_accs = {}
    epoch_start = time()
    grad_norms = []
    n_iters = 0
    if hasattr(train_loader.dataset, "epoch"):
        train_loader.dataset.epoch = epoch
    for i, batch_data in enumerate(iterator):
        xs, ys = batch_data[:2]
        cls_masks = batch_data[2].to(device, non_blocking=True) if len(batch_data) == 3 else None
        optimizer.zero_grad()
        n_iters += 1
        xs = xs.to(device, non_blocking=True)
        ys = ys.to(device, non_blocking=True)

        if args.debug and i == 0:
            logger.debug(f"y_max = {ys.max()}, y_min = {ys.min()}, num_classes={args.n_classes}")

        if mixup:
            if args.multi_label:
                xs, ys = mixup(xs, ys, cls_masks)
            else:
                xs, ys = mixup(xs, ys)

        if args.debug and i == 0:
            logger.debug(f"input x: {type(xs)}; {xs.shape}, y: {type(ys)}; {ys.shape}")

        with torch.amp.autocast("cuda", enabled=args.amp):
            preds = model(xs)
            preds = _mask_preds(preds, cls_masks)
            if args.multi_label:
                # labels are float for BCELoss
                ys = ys.float()
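            # some models expose an auxiliary internal loss; under DDP the underlying module is reached via model.module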
            loss = criterion(preds.transpose(1, -1), ys.transpose(1, -1) if len(ys.shape) > 1 else ys) + (
                model.get_internal_loss() if hasattr(model, "get_internal_loss") else model.module.get_internal_loss()
            )

        if not isfinite(loss.item()):
            logger.error(f"Got loss value {loss.item()}. Stopping training.")
            logger.info(f"input has nan: {xs.isnan().any().item()}")
            logger.info(f"target has nan: {ys.isnan().any().item()}")
            logger.info(f"output has nan: {preds.isnan().any().item()}")
            for name, param in model.named_parameters():
                if param.isnan().any().item():
                    logger.error(f"parameter {name} has a nan value")
            if len(grad_norms) > 0:
                grad_norms = torch.Tensor(grad_norms)
                logger.info(
                    f"Gradient norms until now: min={grad_norms.min().item()}, 20th"
                    f" %tile={torch.quantile(grad_norms, .2).item()}, mean={torch.mean(grad_norms)}, 80th"
                    f" %tile={torch.quantile(grad_norms, .8).item()}, max={grad_norms.max()}"
                )
            sys.exit(1)

        iter_grad_norm = scaler(
            loss,
            optimizer,
            parameters=model.parameters(),
            clip_grad=args.max_grad_norm if args.max_grad_norm > 0.0 else None,
        ).cpu()

        if args.gather_stats_during_training and isfinite(iter_grad_norm):
            grad_norms.append(iter_grad_norm)

        # if args.aug_cutmix:
        #     ys = ys.argmax(dim=-1)  # for accuracy with CutMix, just use the argmax for both
        #
        epoch_loss += loss.item()
        # accuracies = accuracy(preds, ys, topk=topk, dict_key=acc_dict_key, ignore_index=args.ignore_index)
        # for key in accuracies:
        #     epoch_accs[key] += accuracies[key]

    if args.distributed:
        dist.barrier()
    epoch_end = time()

    iterations = n_iters
    # epoch_accs = {key: val / iterations for key, val in epoch_accs.items()}
    epoch_loss = epoch_loss / iterations
    grad_norm_avrg = -1
    inf_grads = iterations - len(grad_norms)
    if len(grad_norms) > 0 and args.gather_stats_during_training:
        grad_norm_max = max(grad_norms)
        grad_norms = torch.Tensor(grad_norms)
        grad_norm_20 = torch.quantile(grad_norms, 0.2).item()
        grad_norm_80 = torch.quantile(grad_norms, 0.8).item()
        grad_norm_avrg = torch.mean(grad_norms)

    if args.distributed:
        # grad norm is already synchronized
        # gather_tensor = torch.Tensor([epoch_loss, *[epoch_accs[acc_dict_key.format(k)] for k in topk]]).to(device)
        gather_tensor = torch.Tensor([epoch_loss]).to(device)
        dist.barrier()
        dist.all_reduce(gather_tensor, op=dist.ReduceOp.AVG)
        # gather_tensor = (gather_tensor / world_size).tolist()
        epoch_loss = gather_tensor.item()
        # for i, k in enumerate(topk):
        #     epoch_accs[acc_dict_key.format(k)] = gather_tensor[i + 1]

    lr = optimizer.param_groups[0]["lr"]
    epoch_accs["train/lr"] = lr
    epoch_accs["train/loss"] = epoch_loss

    if rank == 0:
        if args.gather_stats_during_training:
            print_s = f"train/time={epoch_end - epoch_start}s"
            logger.info(print_s)
            if len(grad_norms) > 0:
                logger.info(
                    f"grad norm avrg={grad_norm_avrg}, grad norm max={grad_norm_max}, "
                    f"inf grad norm={inf_grads}, grad norm 20%={grad_norm_20}, grad norm 80%={grad_norm_80}"
                )
            else:
                logger.warning(f"inf grad norm={inf_grads}")
                logger.error("100% of update steps with infinite grad norms!")
        else:
            logger.info(f"train/time={epoch_end - epoch_start}s")

    if scheduler:
        if isinstance(scheduler, optim.lr_scheduler.LambdaLR):
            scheduler.step()
        else:
            scheduler.step(epoch)

    if args.gather_stats_during_training:
        return epoch_end - epoch_start, epoch_accs
    return epoch_end - epoch_start, {}
710
AAAI Supplementary Material/Model Training Code/evaluate.py
Normal file
@@ -0,0 +1,710 @@
|
||||
"""Module to evaluate trained models."""
|
||||
|
||||
import os
|
||||
from contextlib import nullcontext
|
||||
from datetime import datetime
|
||||
from math import sqrt
|
||||
from time import time
|
||||
|
||||
import torch
|
||||
from loguru import logger
|
||||
from timm.loss import LabelSmoothingCrossEntropy
|
||||
from timm.models.resnet import ResNet as TimmResNet
|
||||
from torch import distributed as dist
|
||||
from torch.nn import functional as F
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
|
||||
from engine import (
|
||||
_evaluate,
|
||||
setup_criteria_mixup,
|
||||
setup_model_optim_sched_scaler,
|
||||
setup_tracking_and_logging,
|
||||
wandb_available,
|
||||
)
|
||||
from load_dataset import prepare_dataset
|
||||
from metrics import calculate_metrics
|
||||
from models import load_pretrained
|
||||
from utils import (
|
||||
RepeatedDataset,
|
||||
ddp_cleanup,
|
||||
ddp_setup,
|
||||
denormalize,
|
||||
get_cpu_name,
|
||||
grad_cam_reshape_transform,
|
||||
prep_kwargs,
|
||||
set_filter_warnings,
|
||||
)
|
||||
|
||||
|
||||
def evaluate_metrics(model, dataset, **kwargs):
    """Evaluate efficiency metrics for a given model.

    Args:
        model (str): path to model state .tar
        dataset (str): name of the dataset to evaluate on
        **kwargs: further arguments

    """
    set_filter_warnings()
    model_path = model
    args = prep_kwargs(kwargs)
    if args.cuda:
        args.distributed, device, world_size, rank, _ = ddp_setup()
        torch.cuda.set_device(device)
    else:
        args.distributed = False
        device = torch.device("cpu")
        rank = 0
        args.compile_model = False

    save_state = torch.load(model_path, map_location="cpu")
    old_args = prep_kwargs(save_state["args"])
    args.model = old_args.model
    args.dataset = dataset
    args.run_name = old_args.run_name
    args.experiment_name = old_args.experiment_name
    args.wandb_run_id = old_args.wandb_run_id
    setup_tracking_and_logging(args, rank=rank, append_model_path=model_path, log_wandb=args.wandb_run_id is not None)

    train_loader, args.n_classes, args.ignore_index, args.multi_label, _ = prepare_dataset(dataset, args)

    model, args, old_args, save_state = load_pretrained(model, args, new_dataset_params=True)
    old_args["eval_imsize"] = args.imsize
    args.model = model_name = old_args.model
    args.dataset = dataset
    args.epochs = 5

    model, optim, _, scaler = setup_model_optim_sched_scaler(model, device, epochs=10, args=args, head_only=False)

    if rank == 0:
        logger.info(
            f"Evaluate metrics for model {model_name} on {dataset}. "
            f"It was {old_args.task.replace('-','')}d on {old_args.dataset} for {save_state['epoch']} "
            "epochs."
        )
        # logger.info(f"full set of arguments: {args}")
        logger.info(f"full set of training arguments: {old_args}")
        logger.info(f"full set of eval-metrics arguments: {args}")

    logger.info(
        f"evaluating on {device} -> {torch.cuda.get_device_name(device) if device.type != 'cpu' else get_cpu_name()}"
    )
    metrics = calculate_metrics(
        args, model, rank=rank, device=device, optim=optim, scaler=scaler, train_loader=train_loader, key_start="eval/"
    )
    if rank == 0:
        logger.info(f"Metrics: {metrics}")
        if wandb_available():
            import wandb

            wandb.log(metrics)

def evaluate(model, dataset=None, val_dataset=None, **kwargs):
    """Evaluate model accuracy.

    Args:
        model (str): path to model state .tar
        dataset (str, optional): name of the dataset to evaluate on (Default value = None)
        val_dataset (str, optional): name of the validation dataset to evaluate on (Default value = None)
        **kwargs: further arguments

    Note:
        If `val_dataset` is not provided, the model will be evaluated on `dataset`.

    """
    set_filter_warnings()
    model_path = model
    args = prep_kwargs(kwargs)
    if val_dataset is None:
        val_dataset = dataset
    args.dataset = dataset
    args.val_dataset = val_dataset
    if args.cuda:
        args.distributed, device, world_size, rank, _ = ddp_setup()
        torch.cuda.set_device(device)
    else:
        args.distributed = False
        device = torch.device("cpu")
        world_size = 1
        rank = 0
        args.compile_model = False
    args.batch_size = int(args.batch_size / world_size)

    save_state = torch.load(model_path, map_location="cpu")
    old_args = prep_kwargs(save_state["args"])
    args.model = old_args.model
    args.dataset = dataset
    args.run_name = old_args.run_name
    args.experiment_name = old_args.experiment_name
    args.wandb_run_id = old_args.wandb_run_id
    run_folder = setup_tracking_and_logging(
        args, rank, append_model_path=model_path, log_wandb=args.wandb_run_id is not None
    )

    val_loader, args.n_classes, args.ignore_index, args.multi_label, dali_server = prepare_dataset(
        val_dataset, args, train=False
    )

    model, args, old_args, save_state = load_pretrained(model, args, new_dataset_params=True)
    model = model.to(device)
    args.model = model_name = old_args.model
    args.dataset = dataset

    if rank == 0:
        logger.info(
            f"Evaluate model {model_name} on {val_dataset}. "
            f"It was pretrained on {old_args.dataset} for {save_state['epoch']} epochs."
        )

    if args.distributed:
        model = DDP(model)

    if args.compile_model:
        model = torch.compile(model)

    # log all devices
    logger.info(
        f"evaluating on {device} -> {torch.cuda.get_device_name(device) if device.type != 'cpu' else get_cpu_name()}"
    )
    if rank == 0:
        logger.info(f"torch version {torch.__version__}")
        logger.info(f"full set of arguments: {args}")
        logger.info(f"full set of old arguments: {old_args}")

    if args.seed:
        torch.manual_seed(args.seed)

    val_criterion = LabelSmoothingCrossEntropy(smoothing=args.label_smoothing)
    if rank == 0:
        logger.info("start evaluation")
        logger.info(f"Run info at: '{run_folder}'")

    if rank == 0:
        val_time, val_stats = _evaluate(
            model.to(device),
            val_loader,
            epoch=save_state["epoch"] - 1,
            rank=rank,
            device=device,
            val_criterion=val_criterion,
            args=args,
            dali_server=dali_server,
            acc_dict_key=f"eval_{val_dataset}/acc{{}}",
        )
        log_s = f"Evaluation done in {val_time}s"
        for key, val in val_stats.items():
            log_s += f", {key}={val:.4f}"
        logger.info(log_s)
        if wandb_available():
            import wandb

            wandb.log(val_stats)
    else:
        _evaluate(
            model.to(device),
            val_loader,
            epoch=save_state["epoch"] - 1,
            rank=rank,
            device=device,
            val_criterion=val_criterion,
            args=args,
            dali_server=dali_server,
            acc_dict_key=f"eval_{val_dataset}/acc{{}}",
        )

    ddp_cleanup(args=args, rank=rank)

def evaluate_center_bias(model, dataset=None, val_dataset=None, **kwargs):
    """Evaluate model accuracy in different nonants (the cells of a 3x3 grid).

    Args:
        model (str): path to model state .tar
        dataset (str, optional): name of the dataset to evaluate on (Default value = None)
        val_dataset (str, optional): name of the validation dataset to evaluate on (Default value = None)
        **kwargs: further arguments

    Note:
        If `val_dataset` is not provided, the model will be evaluated on `dataset`.

    """
    set_filter_warnings()
    model_path = model
    args = prep_kwargs(kwargs)
    if val_dataset is None:
        val_dataset = dataset
    if dataset is None:
        dataset = val_dataset
    assert dataset is not None, "Specify validation dataset (-valds) or dataset (-ds)."
    args.dataset = dataset
    args.val_dataset = val_dataset
    if args.cuda:
        args.distributed, device, world_size, rank, _ = ddp_setup()
        torch.cuda.set_device(device)
    else:
        args.distributed = False
        device = torch.device("cpu")
        world_size = 1
        rank = 0
        args.compile_model = False
    args.batch_size = int(args.batch_size / world_size)

    save_state = torch.load(model_path, map_location="cpu")
    old_args = prep_kwargs(save_state["args"])
    args.model = old_args.model
    args.dataset = dataset
    args.run_name = old_args.run_name
    args.experiment_name = old_args.experiment_name
    args.wandb_run_id = old_args.wandb_run_id
    run_folder = setup_tracking_and_logging(
        args, rank, append_model_path=model_path, log_wandb=args.wandb_run_id is not None
    )

    assert "fornet" in val_dataset.lower(), "Only ForNet supported for center bias evaluation."
    _, args.n_classes, args.ignore_index, args.multi_label, __ = prepare_dataset(val_dataset, args, train=False)

    model, args, old_args, save_state = load_pretrained(model, args, new_dataset_params=True)
    model = model.to(device)
    args.model = model_name = old_args.model
    args.dataset = dataset

    if rank == 0:
        logger.info(
            f"Evaluate model {model_name} on {val_dataset}. "
            f"It was pretrained on {old_args.dataset} for {save_state['epoch']} epochs."
        )

    if args.distributed:
        model = DDP(model)

    if args.compile_model:
        model = torch.compile(model)

    # log all devices
    logger.info(
        f"evaluating on {device} -> {torch.cuda.get_device_name(device) if device.type != 'cpu' else get_cpu_name()}"
    )
    if rank == 0:
        logger.info(f"torch version {torch.__version__}")
        logger.info(f"full set of arguments: {args}")
        logger.info(f"full set of old arguments: {old_args}")

    if args.seed:
        torch.manual_seed(args.seed)

    val_criterion = LabelSmoothingCrossEntropy(smoothing=args.label_smoothing)
    if rank == 0:
        logger.info("start evaluation")
        logger.info(f"Run info at: '{run_folder}'")

    if rank == 0:
        nonant_accs = []
        for nonant in range(-1, 9):
            val_loader, _, __, ___, dali_server = prepare_dataset(val_dataset, args, train=False)
            val_loader.dataset.fg_in_nonant = nonant
            logger.info(f"Evaluate nonant {nonant} for 5 rounds.")
            round_accs = []
            for _ in range(5):
                val_time, val_stats = _evaluate(
                    model.to(device),
                    val_loader,
                    epoch=save_state["epoch"] - 1,
                    rank=rank,
                    device=device,
                    val_criterion=val_criterion,
                    args=args,
                    dali_server=dali_server,
                )
                round_accs.append(val_stats["acc1"])
            nonant_accs.append(sum(round_accs) / len(round_accs))
        log_s = f"Evaluation done in {val_time}s: "
        for nonant, val in enumerate(nonant_accs[1:]):
            log_s += f", nonant {nonant}={val}% acc ({val / nonant_accs[0]} rel acc)"
        center_bias_val = 1 - (
            min([nonant_accs[1], nonant_accs[3], nonant_accs[7], nonant_accs[9]])
            + min([nonant_accs[2], nonant_accs[4], nonant_accs[6], nonant_accs[8]])
        ) / (2 * nonant_accs[5])
        log_s += f", center_bias={center_bias_val:.4f}"
        logger.info(log_s)
        if wandb_available():
            import wandb

            wandb.log({f"eval_{args.val_dataset}/center_bias": center_bias_val})
    else:
        raise NotImplementedError("Center bias evaluation not supported in distributed mode.")

    ddp_cleanup(args=args, rank=rank)

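To make the nonant indexing above concrete: `nonant_accs[0]` holds the reference run with `fg_in_nonant = -1`, and indices 1-9 hold nonants 0-8 of the 3x3 grid (row-major). The corners are therefore at indices 1, 3, 7, 9, the edges at 2, 4, 6, 8, and the center at index 5. A small worked example with hypothetical accuracies:

```python
# Worked example of the center-bias formula above (made-up accuracies).
nonant_accs = [80.0, 70.0, 72.0, 71.0, 74.0, 78.0, 73.0, 69.0, 75.0, 70.5]
corners = [nonant_accs[i] for i in (1, 3, 7, 9)]  # nonants 0, 2, 6, 8
edges = [nonant_accs[i] for i in (2, 4, 6, 8)]    # nonants 1, 3, 5, 7
center_bias = 1 - (min(corners) + min(edges)) / (2 * nonant_accs[5])
# (69.0 + 72.0) / (2 * 78.0) = 141.0 / 156.0, so center_bias ≈ 0.096;
# a model with no positional preference would score close to 0.
```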
def evaluate_size_bias(model, dataset=None, val_dataset=None, **kwargs):
    """Evaluate model accuracy for differently scaled foregrounds.

    Args:
        model (str): path to model state .tar
        dataset (str, optional): name of the dataset to evaluate on (Default value = None)
        val_dataset (str, optional): name of the validation dataset to evaluate on (Default value = None)
        **kwargs: further arguments

    Note:
        If `val_dataset` is not provided, the model will be evaluated on `dataset`.

    """
    set_filter_warnings()
    model_path = model
    args = prep_kwargs(kwargs)
    if val_dataset is None:
        val_dataset = dataset
    if dataset is None:
        dataset = val_dataset
    assert val_dataset is not None and dataset is not None
    args.dataset = dataset
    args.val_dataset = val_dataset
    if args.cuda:
        args.distributed, device, world_size, rank, _ = ddp_setup()
        torch.cuda.set_device(device)
    else:
        args.distributed = False
        device = torch.device("cpu")
        world_size = 1
        rank = 0
        args.compile_model = False
    args.batch_size = int(args.batch_size / world_size)

    save_state = torch.load(model_path, map_location="cpu")
    old_args = prep_kwargs(save_state["args"])
    args.model = old_args.model
    args.dataset = dataset
    args.run_name = old_args.run_name
    args.experiment_name = old_args.experiment_name
    args.wandb_run_id = old_args.wandb_run_id
    run_folder = setup_tracking_and_logging(args, rank, append_model_path=model_path, log_wandb=False)

    assert "fornet" in val_dataset.lower(), "Only ForNet supported for size bias evaluation."
    _, args.n_classes, args.ignore_index, args.multi_label, __ = prepare_dataset(val_dataset, args, train=False)

    model, args, old_args, save_state = load_pretrained(model, args, new_dataset_params=True)
    model = model.to(device)
    args.model = model_name = old_args.model
    args.dataset = dataset

    if rank == 0:
        logger.info(
            f"Evaluate model {model_name} on {val_dataset}. "
            f"It was pretrained on {old_args.dataset} for {save_state['epoch']} epochs."
        )

    if args.distributed:
        model = DDP(model)

    if args.compile_model:
        model = torch.compile(model)

    # log all devices
    logger.info(
        f"evaluating on {device} -> {torch.cuda.get_device_name(device) if device.type != 'cpu' else get_cpu_name()}"
    )
    if rank == 0:
        logger.info(f"torch version {torch.__version__}")
        logger.info(f"full set of arguments: {args}")
        logger.info(f"full set of old arguments: {old_args}")

    if args.seed:
        torch.manual_seed(args.seed)

    val_criterion = LabelSmoothingCrossEntropy(smoothing=args.label_smoothing)
    if rank == 0:
        logger.info("start evaluation")
        logger.info(f"Run info at: '{run_folder}'")

    sizes = [0.1, 0.2, 0.4, 0.6, 0.8, 1.0, 1.25, 1.5, 1.75, 2.0]
    if rank == 0:
        size_accs = []
        val_times = 0
        for size in sizes:
            val_loader, _, __, ___, dali_server = prepare_dataset(val_dataset, args, train=False)
            val_loader.dataset.size_fact = size
            val_loader.dataset.fg_scale_jitter = 0.0
            logger.info(f"Evaluate size factor {size} for 5 rounds.")
            round_accs = []
            for _ in range(5):
                val_time, val_stats = _evaluate(
                    model.to(device),
                    val_loader,
                    epoch=save_state["epoch"] - 1,
                    rank=rank,
                    device=device,
                    val_criterion=val_criterion,
                    args=args,
                    dali_server=dali_server,
                )
                round_accs.append(val_stats["acc1"])
                val_times += val_time
            size_accs.append(sum(round_accs) / len(round_accs))
        log_s = f"Evaluation done in {val_times}s: "
        for size, val in zip(sizes, size_accs):
            log_s += f", rel_size {size}={val}% acc ({val / size_accs[sizes.index(1.0)]} rel acc)"
        logger.info(log_s)
    else:
        raise NotImplementedError("Size bias evaluation not supported in distributed mode.")

    ddp_cleanup(args=args, rank=rank)

def evaluate_attributions(model, dataset=None, val_dataset=None, **kwargs):
    """Evaluate model attributions using captum.

    Args:
        model (str): path to model state .tar
        dataset (str, optional): name of the dataset to evaluate on (Default value = None)
        val_dataset (str, optional): name of the validation dataset to evaluate on (Default value = None)
        **kwargs: further arguments

    Note:
        If `val_dataset` is not provided, the model will be evaluated on `dataset`.
        The `captum` package is required.

    """
    from captum.attr import IntegratedGradients
    from pytorch_grad_cam import GradCAM, GradCAMPlusPlus
    from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

    set_filter_warnings()
    model_path = model
    args = prep_kwargs(kwargs)
    if val_dataset is None:
        val_dataset = dataset
    assert val_dataset is not None, "Please set dataset (-ds) or validation dataset (-valds)"
    args.dataset = val_dataset
    args.val_dataset = val_dataset
    if args.cuda:
        args.distributed, device, world_size, rank, _ = ddp_setup()
        torch.cuda.set_device(device)
    else:
        args.distributed = False
        device = torch.device("cpu")
        world_size = 1
        rank = 0
        args.compile_model = False
    args.batch_size = int(args.batch_size / world_size)

    save_state = torch.load(model_path, map_location="cpu")
    old_args = prep_kwargs(save_state["args"])
    args.model = old_args.model
    args.dataset = val_dataset
    args.run_name = old_args.run_name
    args.experiment_name = old_args.experiment_name
    args.wandb_run_id = old_args.wandb_run_id
    run_folder = setup_tracking_and_logging(
        args, rank, append_model_path=model_path, log_wandb=args.wandb_run_id is not None
    )

    assert "fornet" in val_dataset.lower(), "Only ForNet supported for attribution evaluation."
    val_loader, args.n_classes, args.ignore_index, args.multi_label, dali_server = prepare_dataset(
        val_dataset, args, train=False
    )
    val_loader.dataset.return_fg_masks = True

    model, args, old_args, save_state = load_pretrained(model, args, new_dataset_params=True)
    model = model.to(device)
    args.model = model_name = old_args.model
    args.dataset = dataset
    # assert (
    #     args.imsize == old_args.imsize
    # ), f"Model was trained on {old_args.imsize}x{old_args.imsize} images. Not {args.imsize}x{args.imsize}."
    epoch = save_state["epoch"]

    if rank == 0:
        logger.info(
            f"Evaluate attributions of model {model_name} on {val_dataset}. "
            f"It was pretrained on {old_args.dataset} for {epoch} epochs."
        )

    if args.distributed:
        model = DDP(model)

    if args.compile_model:
        model = torch.compile(model)

    # log all devices
    logger.info(
        f"evaluating on {device} -> {torch.cuda.get_device_name(device) if device.type != 'cpu' else get_cpu_name()}"
    )
    if rank == 0:
        logger.info(f"torch version {torch.__version__}")
        if args.new_log:
            logger.info(f"full set of arguments: {args}")
            logger.info(f"full set of old arguments: {old_args}")
        else:
            logger.info(f"full set of attribution evaluation arguments: {old_args}")

    if args.seed:
        torch.manual_seed(args.seed)

    if rank == 0:
        logger.info(f"Run info at: '{run_folder}'")

    iterator = (
        tqdm(val_loader, total=len(val_loader), desc=f"Validating epoch {epoch}")
        if rank == 0 and args.tqdm
        else val_loader
    )

    if args.debug:
        from matplotlib import pyplot as plt

    eval_attn_importance = False
    if isinstance(model, TimmResNet):
        reshape_transform = None
        target_layers = [model.layer4[-1]]
    elif model_name.lower().startswith("vit-"):
        reshape_transform = grad_cam_reshape_transform
        target_layers = [model.blocks[-1].norm1]
        eval_attn_importance = True
        from architectures.vit import _MatrixSaveAttn

        model.blocks[-1].attn = _MatrixSaveAttn.cast(model.blocks[-1].attn)
    elif model_name.lower().startswith("swin_"):
        reshape_transform = grad_cam_reshape_transform
        target_layers = [model.layers[-1].blocks[-1].norm1]
    else:
        raise NotImplementedError(f"Model {model_name} not supported for attribution evaluation.")

    model.eval()
    val_start = time()
    rel_ig_weights = 0.0
    rel_attn_weights = 0.0
    rel_cam_weights = {"GradCAM": 0.0, "GradCAM++": 0.0}
    if rank == 0:
        logger.info("Start attribution evaluation")
    if dali_server:
        dali_server.start_thread()
    for batch_data in iterator:
        xs, ys, fg_masks = batch_data

        xs, ys, fg_masks = (
            xs.to(device, non_blocking=True),
            ys.to(device, non_blocking=True),
            fg_masks.float().to(device, non_blocking=True),
        )

        with torch.amp.autocast("cuda") if args.eval_amp else nullcontext():
            model.zero_grad()
            ig = IntegratedGradients(model)
            # we use attention temperature of 10 to make differences more apparent after exp
            attr_ig = (
                ig.attribute(xs, target=ys, baselines=0.0, internal_batch_size=args.batch_size * 4).sum(dim=1) * 10
            )  # B x W x H
        attr_probs = attr_ig.view(xs.shape[0], -1).softmax(dim=-1).view(xs.shape[0], *xs.shape[2:])
        fg_masks = fg_masks.view(attr_probs.shape)

        fg_attrs = (attr_probs * fg_masks).sum(dim=(-1, -2))
        rel_attr_weight = fg_attrs / fg_masks.mean(dim=(-1, -2))
        rel_attr_weight = torch.where(fg_masks.mean(dim=(-1, -2)) > 0, rel_attr_weight, 1.0)
        if rel_attr_weight.isnan().any():
            logger.error(f"NaNs in rel_attr_weight: {rel_attr_weight}, fg_mask_weights: {fg_masks.mean(dim=(-1, -2))}")
            break
        rel_ig_weights += rel_attr_weight.mean().item()

        cam_targets = [ClassifierOutputTarget(int(trgt)) for trgt in ys.tolist()]
        for method, name in zip([GradCAM, GradCAMPlusPlus], ["GradCAM", "GradCAM++"]):
            with method(model=model, target_layers=target_layers, reshape_transform=reshape_transform) as cam, (
                torch.amp.autocast("cuda") if args.eval_amp else nullcontext()
            ):
                cam_attr = cam(input_tensor=xs, targets=cam_targets)

            cam_attr = torch.from_numpy(cam_attr).to(device)
            rel_cam_attr = (cam_attr * fg_masks).sum(dim=(-1, -2)) / cam_attr.sum(dim=(-1, -2))
            cam_attr_weight = rel_cam_attr / fg_masks.mean(dim=(-1, -2))
            cam_attr_weight = torch.where(
                (fg_masks.mean(dim=(-1, -2)) > 0) & (cam_attr.sum(dim=(-1, -2)) > 0), cam_attr_weight, 1.0
            )
            rel_cam_weights[name] += cam_attr_weight.mean().item()
            if cam_attr_weight.isnan().any():
                logger.error(
                    f"NaNs in cam_attr_weight ({name}): {cam_attr_weight}, fg_mask_weights:"
                    f" {fg_masks.mean(dim=(-1, -2))}"
                )
                break

        if eval_attn_importance:
            with torch.amp.autocast("cuda") if args.eval_amp else nullcontext():
                pred = model(xs)  # noqa: F841
            last_attn_mat = model.blocks[-1].attn.attn_mat
            cls_tkn_attn = last_attn_mat[:, :, 0, 1:].mean(dim=1).squeeze(dim=1)  # B x H x 1(CLS Token) x N -> B x N
            B, N = cls_tkn_attn.shape
            att_HW = int(sqrt(N))
            cls_tkn_attn = cls_tkn_attn.view(B, 1, att_HW, att_HW)
            attn_attr = F.interpolate(
                cls_tkn_attn, size=(xs.shape[-2], xs.shape[-1]), mode="bilinear", align_corners=False
            ).view(B, xs.shape[-2], xs.shape[-1])
            rel_attn_attr = (attn_attr * fg_masks).sum(dim=(-1, -2)) / attn_attr.sum(dim=(-1, -2))
            attn_attr_weight = rel_attn_attr / fg_masks.mean(dim=(-1, -2))
            attn_attr_weight = torch.where(
                (fg_masks.mean(dim=(-1, -2)) > 0) & (attn_attr.sum(dim=(-1, -2)) > 0), attn_attr_weight, 1.0
            )
            rel_attn_weights += attn_attr_weight.mean().item()

        if args.debug:
            logger.debug(f"Attribution scores: IG: {rel_attr_weight[:5]}, GradCAM(++): {cam_attr_weight[:5]}")
            num_subplots = 5 if eval_attn_importance else 4
            fig, axs = plt.subplots(num_subplots, 4)
            for plt_i in range(4):
                axs[0][plt_i].imshow(denormalize(xs[plt_i]).permute(1, 2, 0).cpu().numpy())
                axs[1][plt_i].imshow(fg_masks[plt_i].cpu().numpy())
                axs[2][plt_i].imshow(attr_probs[plt_i].cpu().numpy())
                axs[3][plt_i].imshow(cam_attr[plt_i].cpu().numpy())
                if eval_attn_importance:
                    axs[4][plt_i].imshow(attn_attr[plt_i].cpu().numpy())
            plt.show()

        if rank == 0 and args.tqdm:  # only the tqdm iterator has .n and set_description
            iterator_desc = (
                f"IG weights: {rel_ig_weights / (iterator.n + 1):.4f}, GradCAM weights:"
                f" {rel_cam_weights['GradCAM'] / (iterator.n + 1):.4f}, GradCAM++ weights:"
                f" {rel_cam_weights['GradCAM++'] / (iterator.n + 1):.4f}"
            )
            if eval_attn_importance:
                iterator_desc += f", Attn weights: {rel_attn_weights / (iterator.n + 1):.4f}"

            iterator.set_description(iterator_desc)

    if args.distributed:
        dist.barrier()

    val_end = time()
    rel_ig_weights /= len(iterator)
    rel_grad_cam = rel_cam_weights["GradCAM"] / len(iterator)
    rel_grad_cam_pp = rel_cam_weights["GradCAM++"] / len(iterator)
    rel_attn_weights /= len(iterator)

    if dali_server:
        dali_server.stop_thread()

    if args.distributed:
        gather_tensor = torch.Tensor([rel_ig_weights, rel_grad_cam, rel_grad_cam_pp, rel_attn_weights]).to(device)
        dist.barrier()
        dist.all_reduce(gather_tensor)
        gather_tensor = (gather_tensor / world_size).tolist()
        rel_ig_weights, rel_grad_cam, rel_grad_cam_pp, rel_attn_weights = gather_tensor

    if rank == 0:
        output_text = (
            f"epoch {epoch}: eval_{args.val_dataset}/rel_ig_weights={rel_ig_weights},"
            f" eval_{args.val_dataset}/rel_grad_cam={rel_grad_cam},"
            f" eval_{args.val_dataset}/rel_grad_cam_pp={rel_grad_cam_pp}"
        )
        if eval_attn_importance:
            output_text += f", eval_{args.val_dataset}/rel_attn_weights={rel_attn_weights}"
        output_text += f", eval_{args.val_dataset}/attribution_eval_time={val_end - val_start}s"
        logger.info(output_text)
        if wandb_available():
            import wandb

            wandb_data = {
                f"eval_{args.val_dataset}/importance_ig": rel_ig_weights,
                f"eval_{args.val_dataset}/importance_grad_cam": rel_grad_cam,
                f"eval_{args.val_dataset}/importance_grad_cam_pp": rel_grad_cam_pp,
            }
            if eval_attn_importance:
                wandb_data[f"eval_{args.val_dataset}/importance_attn"] = rel_attn_weights
            wandb.log(wandb_data)

    ddp_cleanup(args=args, rank=rank)
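The relative attribution weight computed above normalizes the attribution mass that falls on the foreground by the foreground's area fraction: a value of 1.0 corresponds to a uniformly spread attribution map, and values above 1 indicate a foreground-focused model. A toy sketch of the IG variant of this metric (my illustration, not repository code):

```python
import torch

# Random "attribution map" for one 8x8 image, normalized to a distribution.
attr = torch.rand(1, 8, 8)
attr_probs = attr.view(1, -1).softmax(dim=-1).view(1, 8, 8)  # sums to 1 per image

fg_mask = torch.zeros(1, 8, 8)
fg_mask[:, 2:6, 2:6] = 1.0  # foreground covers 16/64 = 25% of the image

fg_mass = (attr_probs * fg_mask).sum(dim=(-1, -2))   # attribution mass on the fg
rel_weight = fg_mass / fg_mask.mean(dim=(-1, -2))    # ≈ 1.0 for this random map
```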
@@ -0,0 +1,211 @@
import argparse
import os
from functools import partial
from math import log

import numpy as np
import torch
from PIL import Image
from torchvision.transforms import Compose, RandomCrop, Resize, ToTensor
from tqdm.auto import tqdm

try:
    from models import load_pretrained
except ModuleNotFoundError:
    import sys

    sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
    from models import load_pretrained

from utils import prep_kwargs


def score_5(  # noqa: D103
    idx,
    bg_probs,
    mean_probs,
    fg_ratio,
    max_idx,
    fg_ratio_max=0.9,
    fg_ratio_min=0.002,
    fg_ratio_exp=0.4,  # learned: 0.487
    idx_exp=0.01,  # learned: 0.043
    bg_probs_exp=0.2,  # learned: 0.24
    opt_fg_ratio=0.1,  # learned: 0.1
    mean_probs_exp=0.2,  # learned: 0.2
    fg_ratio_penalty=1,  # learned: -0.446 ???
):
    return (
        log(mean_probs) * mean_probs_exp
        + log(1 - bg_probs) * bg_probs_exp
        + log(1 - abs(fg_ratio - opt_fg_ratio)) * fg_ratio_exp
        + log(1 - idx / (max_idx + 1)) * idx_exp
        + (fg_ratio_min < fg_ratio < fg_ratio_max) * fg_ratio_penalty
    )  # TODO: KEEP IT LIKE IT IS

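For orientation, a hypothetical call to `score_5` with made-up statistics (the second of three versions, whose foreground covers 12% of the image). Note that with the default `fg_ratio_penalty=1` the last term acts as a bonus whenever the foreground ratio lies strictly between `fg_ratio_min` and `fg_ratio_max`, despite its name:

```python
# Illustrative only: bg_probs and mean_probs are made-up model probabilities
# (background-only image vs. mean over the monochrome-background pastes).
example_score = score_5(idx=1, bg_probs=0.05, mean_probs=0.6, fg_ratio=0.12, max_idx=2)
```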
parser = argparse.ArgumentParser(description="Inspect image versions")
parser.add_argument("-f", "--base_folder", type=str, required=True, help="Base folder to inspect")
parser.add_argument(
    "-batch_size",
    type=int,
    default=32,
    help="Batch size for model inspection. Will be 1 background and batch_size - 1 foregrounds",
)
parser.add_argument("-imsize", type=int, default=224, help="Image size")
parser.add_argument(
    "-score_f_weights", choices=["manual", "automatic"], default="manual", help="Score function hyperparameters"
)
parser.add_argument("-auto_fg_pen_val", type=float, default=-0.446, help="Automatic foreground penalty value")
parser.add_argument("-d", "--dataset", choices=["tinyimagenet", "imagenet"], default="tinyimagenet", help="Dataset")

args = parser.parse_args()

score_f = (
    score_5
    if args.score_f_weights == "manual"
    else partial(
        score_5, **dict(fg_ratio_exp=0.487, idx_exp=0.043, bg_probs_exp=0.24, fg_ratio_penalty=args.auto_fg_pen_val)
    )
)

bg_folder = os.path.join(args.base_folder, "backgrounds")
fg_folder = os.path.join(args.base_folder, "foregrounds")

classes = os.listdir(fg_folder)
classes = sorted(classes, key=lambda x: int(x[1:]))
assert len(classes) in [200, 1_000], f"Expected 200 or 1_000 classes, got {len(classes)}"

total_images = set()
for in_cls in classes:
    cls_images = {
        os.path.join(in_cls, "_".join(img.split(".")[0].split("_")[:-1]))
        for img in os.listdir(os.path.join(fg_folder, in_cls))
        if img.split(".")[0].split("_")[-1].startswith("v")
    }
    total_images.update(cls_images)
total_images = list(total_images)

# base_folder = os.path.join(*(os.path.dirname(__file__).split("/")[:-1]))

# in_cls to print name/lemma
with open(os.path.join("data", "misc_dataset_files", "tinyimagenet_synset_names.txt"), "r") as f:
    in_cls_to_name = {line.split(":")[0].strip(): line.split(":")[1].strip() for line in f.readlines() if len(line) > 2}

if args.dataset == "tinyimagenet":
    inspection_model_paths = []  # PATHS TO MODEL WEIGHTS (.pt) PRETRAINED ON TinyImageNet
elif args.dataset == "imagenet":
    inspection_model_paths = []  # PATHS TO MODEL WEIGHTS (.pt) PRETRAINED ON ImageNet
else:
    raise ValueError(f"Unknown dataset {args.dataset}")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
inspection_models = [
    load_pretrained(path, prep_kwargs({}), new_dataset_params=False)[0].to(device) for path in inspection_model_paths
]
img_transform = Compose([Resize((args.imsize, args.imsize)), RandomCrop(args.imsize), ToTensor()])

total_versions = []

for img_name in tqdm(total_images, desc="Image version computation"):
    in_cls, img_name = img_name.split("/")
    versions = set()
    for img in os.listdir(os.path.join(fg_folder, in_cls)):
        if "_".join(img.split("_")[: len(img_name.split("_"))]) == img_name:
            versions.add(img)
    if len(versions) == 1:
        version = list(versions)[0]
        if version.split(".")[0].split("_")[-1].startswith("v"):
            tqdm.write(f"renaming single version image {version} to {img_name}.WEBP")
            os.rename(os.path.join(fg_folder, in_cls, version), os.path.join(fg_folder, in_cls, f"{img_name}.WEBP"))
            os.rename(
                os.path.join(bg_folder, in_cls, version.replace(".WEBP", ".JPEG")),
                os.path.join(bg_folder, in_cls, f"{img_name}.JPEG"),
            )
        continue
    elif len(versions) == 0:
        tqdm.write(f"Image {img_name} has no versions")
        continue
    versions = sorted(list(versions))
    assert all(
        [version.split(".")[0].split("_")[-1].startswith("v") for version in versions]
    ), f"Weird Versions: {versions} for image {img_name}"
    assert len(versions) <= 3, f"Too many versions for image {img_name}: {versions}"

    version_scores = []
    for v_idx, version in enumerate(versions):
        img = Image.open(os.path.join(fg_folder, in_cls, version))
        bg_img = Image.open(os.path.join(bg_folder, in_cls, f"{version.split('.')[0]}.JPEG"))
        img_mask = np.array(img.convert("RGBA").split()[-1])

        fg_ratio = np.sum(img_mask) / (255 * bg_img.size[0] * bg_img.size[1])

        fg_size = img.size
        monochrome_backgrounds = [
            Image.new(
                "RGB",
                (max(args.imsize, fg_size[0]), max(args.imsize, fg_size[1])),
                (255 * i // (args.batch_size - 2), 255 * i // (args.batch_size - 2), 255 * i // (args.batch_size - 2)),
            )
            for i in range(args.batch_size - 1)
        ]
        pasting_error = False
        for mc_bg in monochrome_backgrounds:
            try:
                mc_bg.paste(img, ((args.imsize - fg_size[0]) // 2, (args.imsize - fg_size[1]) // 2), img)
            except ValueError as e:
                tqdm.write(f"Image {img_name} could not be pasted into background: {e}")
                pasting_error = True
                break

        inp_batch = torch.stack(
            [img_transform(bg_img)] + [img_transform(mc_bg) for mc_bg in monochrome_backgrounds], dim=0
        ).to(device)

        cls_idx = classes.index(in_cls)
        bg_probs = []
        mean_probs = []
        for model in inspection_models:
            model.eval()
            with torch.no_grad():
                out_probs = model(inp_batch).softmax(dim=-1)[:, cls_idx].cpu().numpy()
            bg_probs.append(out_probs[0])
            mean_probs.append(np.mean(out_probs[1:]))

        # average the lists
        bg_probs = np.mean(bg_probs)
        mean_probs = np.mean(mean_probs)

        version_score = (
            score_f(
                idx=v_idx,
                bg_probs=float(bg_probs),
                mean_probs=float(mean_probs),
                fg_ratio=float(fg_ratio),
                max_idx=len(versions) - 1,
            )
            if not pasting_error
            else -100
        )
        version_scores.append(version_score)

    assert len(versions) == len(version_scores), f"Expected {len(versions)} scores, got {len(version_scores)}"

    if max(version_scores) > min(version_scores):
        # find best version
        best_version_idx = int(np.argmax(version_scores))
        best_version = versions[best_version_idx]

        # delete all other versions
        for version in versions:
            if version != best_version:
                os.remove(os.path.join(fg_folder, in_cls, version))
                os.remove(os.path.join(bg_folder, in_cls, f"{version.split('.')[0]}.JPEG"))
        # remove version tag in name
        new_version_name = "_".join(best_version.split("_")[:-1]) + "." + best_version.split(".")[-1]
        os.rename(os.path.join(fg_folder, in_cls, best_version), os.path.join(fg_folder, in_cls, new_version_name))
        os.rename(
            os.path.join(bg_folder, in_cls, f"{best_version.split('.')[0]}.JPEG"),
            os.path.join(bg_folder, in_cls, f"{new_version_name.split('.')[0]}.JPEG"),
        )
    else:
        tqdm.write(f"All versions have the same score for image {img_name}")
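For readability, the file-naming convention that the string splits above manipulate can be summarized as follows (the synset and image numbers here are hypothetical, the structure follows the code):

```python
# Hypothetical layout before pruning: candidate cut-outs carry a version tag,
# foregrounds are WEBP files (alpha channel = mask), backgrounds are the
# corresponding inpainted JPEGs:
#   foregrounds/n01440764/n01440764_18_v0.WEBP
#   foregrounds/n01440764/n01440764_18_v1.WEBP
#   backgrounds/n01440764/n01440764_18_v0.JPEG
#   backgrounds/n01440764/n01440764_18_v1.JPEG
# After scoring, the losing versions are deleted and the winner is renamed to
# drop the version tag:
#   foregrounds/n01440764/n01440764_18.WEBP
#   backgrounds/n01440764/n01440764_18.JPEG
```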
@@ -0,0 +1,16 @@
#!/bin/bash

srun -K \
  --container-image=PATH/TO/SLURM/IMAGE \
  --container-workdir="$(pwd)" \
  --container-mounts=/ALL/IMPORTANT/MOUNTS,"$(pwd)":"$(pwd)" \
  --partition=RTXA6000,RTX3090,A100-40GB,A100-80GB,H100,H200 \
  --job-name="python" \
  --nodes=1 \
  --gpus=1 \
  --ntasks=1 \
  --cpus-per-task=24 \
  --mem=64G \
  --time=1-0 \
  --export="NLTK_DATA=/PATH/TO/NLTK_DATA" \
  python3 "$@"
346
AAAI Supplementary Material/Model Training Code/load_dataset.py
Normal file
@@ -0,0 +1,346 @@
"""Module to load the datasets, using torch and datadings."""

import contextlib
import os
from functools import partial

import torchvision.transforms as tv_transforms
from datadings.reader import MsgpackReader
from timm.data import create_transform
from torch.utils.data import DataLoader, DistributedSampler, WeightedRandomSampler
from torchvision.datasets import (
    CIFAR10,
    CIFAR100,
    FGVCAircraft,
    Flowers102,
    Food101,
    ImageFolder,
    OxfordIIITPet,
    StanfordCars,
)

from data.counter_animal import CounterAnimal
from data.data_utils import (
    DDDecodeDataset,
    ToOneHotSequence,
    collate_imnet,
    collate_listops,
    get_hf_transform,
    minimal_augment,
    segment_augment,
    three_augment,
)
from data.fornet import ForNet
from data.samplers import RASampler
from paths_config import ds_path


def prepare_dataset(dataset_name, args, transform=None, train=True, rank=None):
    """Load a dataset from disk; different formats are used for different datasets.

    Supported datasets include CIFAR10, the fine-grained sets (Stanford Cars, Oxford-IIIT Pet,
    Flowers102, Food-101, FGVC-Aircraft), ImageNet, TinyImageNet, ForNet/TinyForNet,
    CounterAnimal, and ImageNet-9.

    Args:
        dataset_name (str): name of the dataset
        args: further arguments
        transform (list[Module] | str, optional): transformations to use on the data; a list gets composed. If None, the strategy from args.augment_strategy is used (Default value = None)
        train (bool, optional): use the training split (or test/validation split) (Default value = True)
        rank (int, optional): global rank of this process in distributed training (Default value = None)

    Returns:
        DataLoader: data loader for the dataset
        int: number of classes in the dataset
        int: ignore index for the dataset
        bool: whether the dataset is multi-label
        DALIServer | None: DALI server, if the DALI augmentation engine is used, else None

    """
    compose = tv_transforms.Compose
    dali_server = None
    if transform is None:
        if args.augment_engine == "torchvision":
            if args.augment_strategy == "3-augment":
                transform = three_augment(args, as_list=False, test=not train)
            elif args.augment_strategy == "differentiable-transform":
                from data.distilled_dataset import differentiable_augment

                transform = differentiable_augment(args, as_list=False, test=not train)
            elif args.augment_strategy == "none":
                transform = []
            elif args.augment_strategy == "lm_one_hot":
                transform = [
                    tv_transforms.Grayscale(num_output_channels=1),
                    tv_transforms.ToTensor(),
                    ToOneHotSequence(),
                ]
            elif args.augment_strategy == "segment-augment":
                transform = segment_augment(args, test=not train)
            elif args.augment_strategy == "minimal":
                transform = minimal_augment(args, test=not train)
            elif args.augment_strategy == "deit":
                if train:
                    transform = create_transform(
                        input_size=args.imsize,
                        is_training=True,
                        color_jitter=args.aug_color_jitter_factor,
                        auto_augment=args.auto_augment_strategy,
                        interpolation="bicubic",
                        re_prob=args.aug_random_erase_prob,
                        re_mode=args.aug_random_erase_mode,
                        re_count=args.aug_random_erase_count,
                    )
                else:
                    transform = three_augment(args, test=True)  # only do resize, centercrop, and normalize
            else:
                raise NotImplementedError(
                    f"Augmentation strategy {args.augment_strategy} is not implemented for {args.augment_engine} (yet)."
                )
        elif args.augment_engine == "albumentations":
            from data import album_transf as ATf

            compose = ATf.AlbumTorchCompose

            if args.augment_strategy == "3-augment":
                transform = ATf.three_augment(args, as_list=False, test=not train)
            elif args.augment_strategy == "minimal":
                transform = ATf.minimal_augment(args, test=not train)
            else:
                raise NotImplementedError(
                    f"Augmentation strategy {args.augment_strategy} is not implemented for {args.augment_engine} (yet)."
                )
        elif args.augment_engine == "dali":
            from nvidia.dali.plugin.pytorch.experimental import proxy as dali_proxy

            from data import dali_transf as DTf

            dev_id = int(os.environ.get("LOCAL_RANK", 0))

            if args.augment_strategy == "3-augment":
                pipe = DTf.three_augment(
                    args,
                    test=not train,
                    batch_size=args.batch_size,
                    num_threads=args.num_workers,
                    device_id=dev_id,
                )
            elif args.augment_strategy == "minimal":
                pipe = DTf.minimal_augment(
                    args,
                    test=not train,
                    batch_size=args.batch_size,
                    num_threads=args.num_workers,
                    device_id=dev_id,
                )
            else:
                raise NotImplementedError(
                    f"Augmentation strategy {args.augment_strategy} is not implemented for {args.augment_engine} (yet)."
                )

            dali_server = dali_proxy.DALIServer(pipe)
            transform = dali_server.proxy

    dataset_name_case_sensitive = dataset_name  # keep the original name for AnimalNet folder
    dataset_name = dataset_name.lower()
    ignore_index = -100
    multi_label = False

    if isinstance(transform, list):
        transform = compose(transform)

    if dataset_name == "cifar10":
        dataset = CIFAR10(root=ds_path("cifar"), train=train, download=False, transform=transform)
        n_classes, collate = 10, None

    elif dataset_name == "stanford-cars":
        dataset = StanfordCars(
            root=ds_path("stanford_cars"),
            split="train" if train else "test",
            download=False,
            transform=transform,
        )
        n_classes, collate = 196, None

    elif dataset_name == "oxford-pet":
        dataset = OxfordIIITPet(
            root=ds_path("oxford_pet"),
            split="trainval" if train else "test",
            download=False,
            transform=transform,
        )
        n_classes, collate = 37, None

    elif dataset_name == "flowers102":
        dataset = Flowers102(
            root=ds_path("flowers102"),
            split="train" if train else "test",
            download=False,
            transform=transform,
        )
        n_classes, collate = 102, None

    elif dataset_name == "food-101":
        dataset = Food101(
            root=ds_path("food101"),
            split="train" if train else "test",
            download=False,
            transform=transform,
        )
        n_classes, collate = 101, None

    elif dataset_name == "fgvc-aircraft":
        dataset = FGVCAircraft(
            root=ds_path("aircraft"),
            split="train" if train else "test",
            annotation_level="variant",
            download=False,
            transform=transform,
        )
        n_classes, collate = 100, None

    elif dataset_name == "imagenet":
        dataset = ImageFolder(os.path.join(ds_path("imagenet1k"), "train" if train else "val"), transform=transform)
        n_classes, collate = 1000, None

    elif dataset_name == "tinyimagenet":
        dataset = ImageFolder(os.path.join(ds_path("tinyimagenet"), "train" if train else "val"), transform=transform)
        n_classes, collate = 200, None

    elif dataset_name.startswith("fornet"):
        ds_def = dataset_name.split("/")
        comb_scheme = ds_def[1] if len(ds_def) > 1 else "same"
        pruning_ratio = 0.8 if len(ds_def) < 3 else float(ds_def[2])
        fg_size_mode = ("range" if train else "max") if len(ds_def) < 4 else ds_def[3]
        paste_pre_transform = True if len(ds_def) < 5 else ds_def[4] in ["y", "t"]
        orig_img_prob = (
            0.0 if len(ds_def) < 6 else (ds_def[5] if ds_def[5] in ["linear", "revlinear", "cos"] else float(ds_def[5]))
        )
        mask_smoothing_sigma = 0.0 if len(ds_def) < 7 else float(ds_def[6])
        assert len(ds_def) < 5 or ds_def[4] in [
            "y",
            "t",
            "n",
            "f",
        ], f"Invalid dataset definition: {ds_def[4]}; paste pre transform must be 'y'/'t' or 'n'/'f'"

        orig_ds = ds_path("imagenet1k")

        dataset = ForNet(
            ds_path("fornet"),
            train=train,
            background_combination=comb_scheme,
            pruning_ratio=pruning_ratio,
            transform=transform,
            fg_transform=(
                None if args.aug_rand_rot == 0 else tv_transforms.RandomRotation(args.aug_rand_rot, expand=True)
            ),
            fg_size_mode=fg_size_mode,
            paste_pre_transform=paste_pre_transform,
            orig_img_prob=orig_img_prob,
            orig_ds=orig_ds,
            mask_smoothing_sigma=mask_smoothing_sigma,
            epochs=args.epochs,
            _album_compose=args.augment_engine == "albumentations",
        )
        n_classes, collate = 1000, None

    elif dataset_name.startswith("tinyfornet"):
        ds_def = dataset_name.split("/")
        comb_scheme = ds_def[1] if len(ds_def) > 1 else "same"
        pruning_ratio = 1.1 if len(ds_def) < 3 else float(ds_def[2])
        fg_size_mode = "range" if len(ds_def) < 4 else ds_def[3]
        fg_bates_n = 1 if len(ds_def) < 5 else int(ds_def[4])
        paste_pre_transform = False if len(ds_def) < 6 else ds_def[5] in ["y", "t"]
        orig_img_prob = (
            0.0 if len(ds_def) < 7 else (ds_def[6] if ds_def[6] in ["linear", "revlinear", "cos"] else float(ds_def[6]))
        )
        mask_smoothing_sigma = 0.0 if len(ds_def) < 8 else float(ds_def[7])
        assert len(ds_def) < 6 or ds_def[5] in [
            "y",
            "t",
            "n",
            "f",
        ], f"Invalid dataset definition: {ds_def[5]}; paste pre transform must be 'y'/'t' or 'n'/'f'"
        assert "-" not in ds_def[0] or len(ds_def[0].split("-")) == 3, f"Invalid dataset definition: {ds_def[0]}"
        version = "" if "-" not in ds_def[0] else f"_v{ds_def[0].split('-')[1]}_f{ds_def[0].split('-')[2]}"

        orig_ds = ds_path("tinyimagenet")

        dataset = ForNet(
            f"{ds_path('tinyimagenet')}{version}",
            train=train,
            background_combination=comb_scheme,
            pruning_ratio=pruning_ratio,
            transform=transform,
            fg_transform=(
                None if args.aug_rand_rot == 0 else tv_transforms.RandomRotation(args.aug_rand_rot, expand=True)
            ),
            fg_size_mode=fg_size_mode,
            fg_bates_n=fg_bates_n,
            paste_pre_transform=paste_pre_transform,
            orig_img_prob=orig_img_prob,
            orig_ds=orig_ds,
            mask_smoothing_sigma=mask_smoothing_sigma,
            epochs=args.epochs,
            _album_compose=args.augment_engine == "albumentations",
        )
        n_classes, collate = 200, None

    elif dataset_name.startswith("counteranimal/"):
        mode = dataset_name.split("/")[1]

        dataset = CounterAnimal(ds_path("counteranimal"), mode=mode, transform=transform, train=train)
        n_classes, collate = 1000, None

    elif dataset_name.startswith("imagenet9/"):
        variant = dataset_name.split("/")[1]
        assert variant in [
            "next",
            "same",
            "rand",
        ], f"ImageNet-9 has possible variants next, same, and rand, but not '{variant}'."

        dataset = ImageFolder(os.path.join(ds_path("imagenet9"), f"mixed_{variant}", "val"), transform=transform)
        n_classes, collate = 9, None

    else:
        raise NotImplementedError(f"Dataset {dataset_name} is not implemented (yet).")

    if args.aug_repeated_augment_repeats > 1 and train:
        # use repeated augment sampler from DeiT
        sampler = RASampler(
            dataset,
            num_replicas=args.world_size,
            rank=rank,
            shuffle=args.shuffle,
            num_repeats=args.aug_repeated_augment_repeats,
        )
    elif args.weighted_sampler:
        assert hasattr(
            dataset, "per_sample_weights"
        ), f"Dataset {type(dataset)} should implement a per_sample_weights function, but does not."

        sampler = WeightedRandomSampler(dataset.per_sample_weights(), num_samples=len(dataset) // args.world_size)
    elif args.distributed:
        sampler = DistributedSampler(dataset, num_replicas=args.world_size, rank=rank, shuffle=train and args.shuffle)
    else:
        sampler = None

    loader_batch_size = 1 if dataset_name.startswith("listops") else args.batch_size

    loader_kwargs = dict(
        batch_size=loader_batch_size,
        pin_memory=args.pin_memory,
        num_workers=args.num_workers,
        drop_last=train,
        prefetch_factor=args.prefetch_factor if args.num_workers > 0 else None,
        persistent_workers=False,
        collate_fn=collate,
        shuffle=None if sampler else train and args.shuffle,
        sampler=sampler,
    )

    if args.augment_engine == "dali":
        data_loader = dali_proxy.DataLoader(dali_server, dataset, **loader_kwargs)
    else:
        data_loader = DataLoader(dataset, **loader_kwargs)

    return data_loader, n_classes, ignore_index, multi_label, dali_server
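The `fornet` branch above packs all dataset options into a single slash-separated name. A concrete example of how such a string decomposes, mirroring the parsing code (the values are illustrative; valid `background_combination` schemes depend on the `ForNet` implementation):

```python
# How a slash-separated ForNet name decomposes in prepare_dataset() above:
ds_def = "fornet/same/0.8/range/y/0.5/1.0".split("/")
background_combination = ds_def[1]               # "same"
pruning_ratio = float(ds_def[2])                 # 0.8
fg_size_mode = ds_def[3]                         # "range"
paste_pre_transform = ds_def[4] in ["y", "t"]    # True ("n"/"f" mean no)
orig_img_prob = float(ds_def[5])                 # 0.5, or "linear"/"revlinear"/"cos"
mask_smoothing_sigma = float(ds_def[6])          # 1.0
```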
679
AAAI Supplementary Material/Model Training Code/main.py
Normal file
@@ -0,0 +1,679 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
"""Parse args and call the correct script inside slurm container.
|
||||
|
||||
Outside the container, on the head-node, create and call the correct srun command.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import subprocess
|
||||
from datetime import datetime
|
||||
|
||||
from config import default_kwargs, slurm_defaults
|
||||
from paths_config import results_folder, slurm_output_folder
|
||||
|
||||
_EXPNAMES = ["EfficientCVBench", "test", "recombine_imagenet"]
|
||||
|
||||
|
||||
def base_parser():
|
||||
"""Create the argument parser with all the choices for the training / evaluation scripts."""
|
||||
parser = argparse.ArgumentParser("Transformer training and evaluation.")
|
||||
|
||||
# Main
|
||||
group = parser.add_argument_group("Main")
|
||||
group.add_argument(
|
||||
"-t",
|
||||
"--task",
|
||||
nargs="?",
|
||||
choices=[
|
||||
"pre-train",
|
||||
"fine-tune",
|
||||
"fine-tune-head",
|
||||
"eval",
|
||||
"parser-test",
|
||||
"eval-metrics",
|
||||
"eval-attr",
|
||||
"continue",
|
||||
"eval-center-bias",
|
||||
"eval-size-bias",
|
||||
"load-images",
|
||||
"save-images",
|
||||
],
|
||||
required=True,
|
||||
help="Task to perform.",
|
||||
)
|
||||
group.add_argument(
|
||||
"-m",
|
||||
"--model",
|
||||
nargs="?",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Model to use. Either model name for a new model or weights and dicts to load for fine-tuning.",
|
||||
)
|
||||
group.add_argument("-ds", "--dataset", nargs="?", type=str, help="Dataset to train on.")
|
||||
group.add_argument(
|
||||
"-valds", "--val-dataset", nargs="?", type=str, help="Validation dataset. Defaults to same as training."
|
||||
)
|
||||
group.add_argument("-ep", "--epochs", nargs="?", type=int, help="Number of epochs to train.")
|
||||
group.add_argument(
|
||||
"-run",
|
||||
"--run-name",
|
||||
nargs="?",
|
||||
type=str,
|
||||
help="A name for the run. If not give, the model name is used instead.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--defaults", nargs="?", choices=["DeiT", "DeiTIII"], default="DeiTIII", help="Default settings to use."
|
||||
)
|
||||
|
||||
# Further model parameters
|
||||
group = parser.add_argument_group("Further model parameters")
|
||||
group.add_argument("--drop-path-rate", nargs="?", type=float, help="Drop path rate for ViT models.")
|
||||
group.add_argument("--layer-scale-init-values", nargs="?", type=float, help="LayerScale initial values.")
|
||||
group.add_argument("--layer-scale", action=argparse.BooleanOptionalAction, help="Use layer scale?")
|
||||
group.add_argument(
|
||||
"--qkv-bias",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Use bias in linear transformation to queries, keys, and values?",
|
||||
)
|
||||
group.add_argument("--pre-norm", action=argparse.BooleanOptionalAction, help="Use norm first architecture?")
|
||||
group.add_argument("--dropout", nargs="?", type=float, help="Model dropout.")
|
||||
group.add_argument("-heads", "--num-heads", nargs="?", type=int, help="Number of parallel attention heads.")
|
||||
group.add_argument("--input-dim", nargs="?", type=int, help="Dimensionality of text encoding.")
|
||||
group.add_argument("--max-seq-len", nargs="?", type=int, help="Maximum sequence length for text data.")
|
||||
group.add_argument(
|
||||
"--fused-attn",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Use fused attention (for ViT with Timm's attention only)?",
|
||||
)
|
||||
# group.add_argument(
|
||||
# "--perf-metric", nargs="?", choices=["acc", "mIoU"], help="Performance metric to use for evaluation."
|
||||
# )
|
||||
# group.add_argument("-no_model_ema", action="store_true",
|
||||
# help="Don't use an exponential moving average for model parameters")
|
||||
# group.add_argument("-model_ema_decay", nargs='?', type=float, default=default_kwargs["model_ema_decay"],
|
||||
# help="Decay rate for exponential moving average of model parameters")
|
||||
|
||||
# Experiment management
|
||||
group = parser.add_argument_group("Experiment management")
|
||||
group.add_argument("--seed", nargs="?", type=int, help="Manual RNG seed.")
|
||||
group.add_argument(
|
||||
"-exp",
|
||||
"--experiment-name",
|
||||
nargs="?",
|
||||
choices=_EXPNAMES,
|
||||
help="Name for the experiment. Is used for grouping of runs.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--save-epochs", nargs="?", type=int, help="Number of epochs after which to save the full training state."
|
||||
)
|
||||
group.add_argument(
|
||||
"--keep-interm-states",
|
||||
nargs="?",
|
||||
type=int,
|
||||
help="Number of intermediate states to keep. All others (earlier ones) will be deleted automatically.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--custom-dataset-path", nargs="?", type=str, help="Overwrite the path to any dataset to this path."
|
||||
)
|
||||
group.add_argument(
|
||||
"--results-folder",
|
||||
nargs="?",
|
||||
default=results_folder,
|
||||
type=str,
|
||||
help="Folder to put script results (mlflow data, models, etc.).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--gather-stats-during-training",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Gather training statistics from all GPUs?",
|
||||
)
|
||||
group.add_argument("--tqdm", action=argparse.BooleanOptionalAction, help="Show tqdm for every epoch?")
|
||||
group.add_argument(
|
||||
"--debug", action=argparse.BooleanOptionalAction, help="Debug mode: lots of intermediate prints."
|
||||
)
|
||||
group.add_argument("--wandb", action=argparse.BooleanOptionalAction, help="Use external logging via Wandb?")
|
||||
group.add_argument("--log-level", choices=["info", "debug"], help="Log level", metavar="LEVEL")
|
||||
group.add_argument("-out", "--out-dir", type=str, help="Output directory for additional outputs.")
|
||||
|
||||
# Speedup
|
||||
group = parser.add_argument_group("Speedup")
|
||||
group.add_argument("--amp", action=argparse.BooleanOptionalAction, help="Use automatic mixed precision?")
|
||||
group.add_argument(
|
||||
"--eval-amp", action=argparse.BooleanOptionalAction, help="Use automatic mixed precision during evaluation?"
|
||||
)
|
||||
group.add_argument("--compile-model", action=argparse.BooleanOptionalAction, help="Use torch.compile?")
|
||||
group.add_argument("--cuda", action=argparse.BooleanOptionalAction, help="Use cuda?")
|
||||
|
||||
# Data loading
|
||||
group = parser.add_argument_group("Data loading")
|
||||
group.add_argument("-bs", "--batch-size", nargs="?", type=int, help="Batch size over all graphics cards (togeter).")
|
||||
group.add_argument("--num-workers", nargs="?", type=int, help="Number of dataloader worker threads. Should be >0.")
|
||||
group.add_argument(
|
||||
"--pin-memory", action=argparse.BooleanOptionalAction, help="Use pin_memory of torch Dataloader?"
|
||||
)
|
||||
group.add_argument(
|
||||
"--prefetch-factor",
|
||||
nargs="?",
|
||||
type=int,
|
||||
help="Prefetch factor for dataloader workers (how many batches to fetch)",
|
||||
)
|
||||
group.add_argument("--shuffle", action=argparse.BooleanOptionalAction, help="Shuffle the training data?")
|
||||
group.add_argument(
|
||||
"--weighted-sampler",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Use a class-weighted sampler to sample evenly from all classes (train and val)?",
|
||||
)
|
||||
group.add_argument("--ipc", type=int, help="How many images per class to load and save.")
|
||||
|
||||
# Optimizer
|
||||
group = parser.add_argument_group("Optimizer")
|
||||
group.add_argument("--opt", nargs="?", type=str, help="Optimizer to use.")
|
||||
group.add_argument("--weight-decay", nargs="?", type=float, help="Weight decay factor for use in optimizer.")
|
||||
group.add_argument("-lr", "--lr", nargs="?", type=float, help="Initial learning rate.")
|
||||
group.add_argument(
|
||||
"--max-grad-norm", nargs="?", type=float, help="Maximum norm for the gradients (used for cutoff)."
|
||||
)
|
||||
group.add_argument("--warmup-epochs", nargs="?", type=int, help="Number of epochs of linear warmup.")
|
||||
group.add_argument("--label-smoothing", nargs="?", type=float, help="Label smoothing factor.")
|
||||
group.add_argument("--loss", nargs="?", choices=["ce", "baikal"], type=str, help="Loss function to use.")
|
||||
group.add_argument(
|
||||
"--loss-weight", nargs="?", type=str, choices=["none", "linear", "log", "sqrt"], help="Per class loss weight."
|
||||
)
|
||||
group.add_argument("--sched", nargs="?", choices=["cosine", "const"], help="Learning rate schedule.")
|
||||
group.add_argument("--min-lr", nargs="?", type=float, help="Minimum learning rate to be hit by scheduler.")
|
||||
group.add_argument("--warmup-lr", nargs="?", type=float, help="Warmup learning rate.")
|
||||
group.add_argument("--warmup-sched", nargs="?", choices=["linear", "const"], help="Schedule for warmup")
|
||||
group.add_argument(
|
||||
"--opt-eps", nargs="?", type=float, help="Epsilon value added in the optimizer to stabilize training."
|
||||
)
|
||||
group.add_argument("--momentum", nargs="?", type=float, help="Optimizer momentum.")
|
||||
|
||||
# Data augmentation
|
||||
group = parser.add_argument_group("Data augmentation")
|
||||
group.add_argument("--augment-strategy", nargs="?", type=str, help="Data augmentation strategy.")
|
||||
group.add_argument("--aug-rand-rot", nargs="?", type=int, help="Random rotation limit.")
|
||||
group.add_argument(
|
||||
"--aug-flip", action=argparse.BooleanOptionalAction, help="Use data augmentation: horizontal flip?"
|
||||
)
|
||||
group.add_argument(
|
||||
"--aug-crop",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Use data augmentation: cropping. This may break the skript?",
|
||||
)
|
||||
group.add_argument("--aug-resize", action=argparse.BooleanOptionalAction, help="Use data augmentation: resize?")
|
||||
group.add_argument(
|
||||
"--aug-grayscale", action=argparse.BooleanOptionalAction, help="Use data augmentation: grayscale?"
|
||||
)
|
||||
group.add_argument("--aug-solarize", action=argparse.BooleanOptionalAction, help="Use data augmentation: solarize?")
|
||||
group.add_argument(
|
||||
"--aug-gauss-blur", action=argparse.BooleanOptionalAction, help="Use data augmentation: gaussian blur?"
|
||||
)
|
||||
group.add_argument(
|
||||
"--aug-cutmix-alpha",
|
||||
type=float,
|
||||
help="Alpha value for using CutMix. CutMix is active when aug_cutmix_alpha > 0.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--aug-mixup-alpha", type=float, help="Alpha value for using Mixup. Mixup is active when aug_mixup_alpha > 0."
|
||||
)
|
||||
group.add_argument(
|
||||
"--aug-color-jitter-factor",
|
||||
nargs="?",
|
||||
type=float,
|
||||
help="Factor to use for the data augmentation: color jitter.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--aug-normalize", action=argparse.BooleanOptionalAction, help="Use data augmentation: Normalization?"
|
||||
)
|
||||
group.add_argument(
|
||||
"--aug-repeated-augment-repeats",
|
||||
type=int,
|
||||
help="Number of image repeats with repeat-augment from DeiT. 1 is not using repeat-augment.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--aug-random-erase-prob", type=float, help="For DeiT augment: Probabiliy of RandomErase augmentation."
|
||||
)
|
||||
group.add_argument("--auto-augment-strategy", type=str, help="For DeiT augment: AutoAugment Policy to use.")
|
||||
group.add_argument("--imsize", nargs="?", type=int, help="Image size given to the model -> imsize x imsize.")
|
||||
group.add_argument(
|
||||
"--augment-engine",
|
||||
nargs="?",
|
||||
choices=["torchvision", "albumentations", "dali"],
|
||||
help="Which data augmentation engine to use.",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def partition_choices():
|
||||
"""Automatically create a list of all possible slurm partitions."""
|
||||
potential = list(set([l.split(" ")[0] for l in os.popen("sinfo")])) # noqa: E741
|
||||
if len(potential) <= 2:
|
||||
return slurm_defaults["partition"]
|
||||
return [p[:-1] if "*" in p else p for p in potential if p != "PARTITION"]
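

# Illustrative behavior (assumed `sinfo` output, not from this repo): if `sinfo` prints a
# header row "PARTITION AVAIL TIMELIMIT ..." followed by rows starting with "batch" and
# "gpu*", partition_choices() returns ["batch", "gpu"] -- the header row is dropped and the
# "*" marking the default partition is stripped. With fewer than two parsed rows (e.g. no
# slurm available), it falls back to slurm_defaults["partition"].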


def slurm_parser(parser=None):
    """Add srun arguments to the given parser.

    Args:
        parser (argparse.ArgumentParser, optional): base parser to extend; default is parser from *base_parser*

    Returns:
        parser (argparse.ArgumentParser): extended parser

    """
    if parser is None:
        parser = base_parser()
    group = parser.add_argument_group("Slurm arguments")
    group.add_argument(
        "--partition",
        nargs="*",
        default=slurm_defaults["partition"],
        choices=partition_choices(),
        help="Slurm partition to use",
    )
    group.add_argument(
        "--container-image",
        nargs="?",
        default=slurm_defaults["container_image"],
        type=str,
        help="Path to slurm container image (.sqsh)",
    )
    group.add_argument(
        "--container-workdir",
        nargs="?",
        default=slurm_defaults["container_workdir"],
        type=str,
        help="Working directory in container",
    )
    group.add_argument(
        "--container-mounts",
        nargs="?",
        default=slurm_defaults["container_mounts"],
        type=str,
        help="All slurm mounts separated by ','.",
    )
    group.add_argument(
        "--job-name",
        nargs="?",
        default=slurm_defaults["job_name"],
        type=str,
        help="Slurm job name. Will default to '<model> <task> <dataset>'.",
    )
    group.add_argument(
        "--nodes", nargs="?", default=slurm_defaults["nodes"], type=int, help="Number of cluster nodes to use."
    )
    group.add_argument(
        "--ntasks", nargs="?", default=slurm_defaults["ntasks"], type=int, help="Number of GPUs to use for the job."
    )
    group.add_argument("--gpus", action=argparse.BooleanOptionalAction, default=True, help="Use gpus for this job?")
    group.add_argument(
        "-cpus",
        "--cpus-per-task",
        "--cpus-per-gpu",
        nargs="?",
        default=slurm_defaults["cpus_per_task"],
        type=int,
        help="Number of CPUs per task/GPU.",
    )
    group.add_argument(
        "-mem",
        "--mem-per-gpu",
        "--mem-per-task",
        nargs="?",
        default=slurm_defaults["mem_per_gpu"],
        type=int,
        help="RAM per GPU (in GB) to use. Will be given as total mem in srun command.",
    )
    group.add_argument(
        "--task-prolog",
        nargs="?",
        default=slurm_defaults["task_prolog"],
        type=str,
        help="Shell script for task prolog (installing packages, etc.).",
    )
    group.add_argument("--time", nargs="?", default=slurm_defaults["time"], type=str, help="Slurm time limit.")
    group.add_argument(
        "--export",
        nargs="?",
        default=slurm_defaults["export"],
        type=str,
        help="Additional environment variables to export.",
    )
    group.add_argument("--exclude", nargs="?", default=slurm_defaults["exclude"], type=str, help="Nodes to exclude.")
    group.add_argument(
        "--after-job", nargs="?", default=slurm_defaults["after_job"], type=int, help="Job ID to wait for."
    )
    group.add_argument(
        "--interactive",
        action="store_true",
        help=(
            "Run using srun instead of sbatch. This will print the output into the terminal, not the slurm output file."
            " The logfile will still be created as usual."
        ),
        default=False,
    )

    group = parser.add_argument_group("Run locally")
    group.add_argument("--local", action="store_true", help="Run locally; not in slurm", default=False)

    return parser


def parse_args(args=None, parser=None):
    """Parse args from *base_parser* and insert defaults.

    Args:
        args (argparse.Namespace or dict, optional): already-parsed arguments; parsed from the CLI if None. (Default value = None)
        parser (argparse.ArgumentParser, optional): parser used for error reporting. (Default value = None)

    Returns:
        dict: parsed arguments

    """
    if args is None:
        if parser is None:
            parser = base_parser()
        args = parser.parse_args()
    # accept both a Namespace and an already-converted dict
    args = dict(args) if isinstance(args, dict) else dict(vars(args))

    check_arg_completeness(args, parser)

    return args


def check_arg_completeness(args, parser):
    """Check completeness of arguments.

    Args:
        args (dict): arguments to check
        parser (argparse.ArgumentParser): for raising the parser error

    Note:
        Will raise a parser error if the arguments are not complete.

    """
    if args["task"] in ["pre-train", "fine-tune", "fine-tune-head"]:
        if "run_name" not in args or args["run_name"] is None or len(args["run_name"]) == 0:
            parser.error(f"-run_name is required for task {args['task']}")

        if "experiment_name" not in args or args["experiment_name"] is None or len(args["experiment_name"]) == 0:
            parser.error(f"-experiment_name is required for task {args['task']}. Choose from {_EXPNAMES}")

        if "epochs" not in args or args["epochs"] is None:
            parser.error(f"-epochs is required for task {args['task']}")

    if ("dataset" not in args or args["dataset"] is None) and args["task"] in [
        "pre-train",
        "fine-tune",
        "fine-tune-head",
        "eval-metrics",
    ]:
        parser.error(f"-dataset is required for task {args['task']}")

    if (
        ("val_dataset" not in args or args["val_dataset"] is None)
        and ("dataset" not in args or args["dataset"] is None)
        and args["task"] in ["eval"]
    ):
        parser.error(f"-dataset or -val_dataset is required for task {args['task']}")

    if args["aug_repeated_augment_repeats"] is not None and args["aug_repeated_augment_repeats"] < 1:
        parser.error(
            "number of repeats for repeated augment has to be >= 1, but got -aug_repeated_augment_repeats ="
            f" {args['aug_repeated_augment_repeats']}"
        )

    if args["task"] == "save-images" and ("out_dir" not in args or args["out_dir"] is None):
        parser.error("Need to set save directory (--out-dir) to save the images in.")


def inside_slurm():
    """Test for being inside a slurm container.

    Works by testing for environment variable 'RANK'.
    """
    return "RANK" in os.environ


# TODO: fix ./runscript.tmp: 18: Syntax error: Unterminated quoted string
def create_runscript(args, file_name=None):
    """Create a run script for a distributed training job using SLURM.

    Args:
        args (dict): A dictionary containing various arguments for the job, including parameters for SLURM and for training.
        file_name (str, optional): The name of the file to create. Defaults to "runscript.tmp".

    Returns:
        str: The name of the created file.
        str: Additional command line arguments for sbatch.

    Example:
        >>> args = {"model": "vit_large_patch16_384", "task": "pre-train", "batch_size": 256, ...}
        >>> file_name = "my_run_script.sh"
        >>> create_runscript(args, file_name)

    """
    for key, val in slurm_defaults.items():
        if key not in args and val is not None:
            args[key] = val

    if "run_name" not in args or args["run_name"] is None:
        model_str = args["model"]
        if model_str.endswith(".pt"):
            model_str = os.path.dirname(model_str)
        run_name = args["task"] + " " + model_str.split(os.sep)[-1].split("_")[0]
    else:
        run_name = args["run_name"]
    job_name = run_name.replace(" ", "_").replace("/", "_").replace(">", "_").replace("<", "_")
    if file_name is None:
        file_name = (
            f"experiments/sbatch/run_{args['task']}_{job_name}_at_{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}.sbatch"
        )

    task_args = ""
    srun_command = "\nsrun -K \\\n"
    sbatch_commands = (  # outfile name is job name, date, job id, node name
        "#!/bin/bash\n\n#SBATCH"
        f" --output={slurm_output_folder}/%x-{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}-%j-%N.out\n"
    )
    sbatch_cmd_args = ""  # additional command line arguments for sbatch
    python_command = " python3 main.py {0}\n"
    for key, val in args.items():
        if key == "local":
            continue
        if key == "interactive":
            continue
        if key == "gpus":
            continue
        if key in slurm_defaults:
            # it's a parameter for srun
            # slurm has - instead of _
            key = key.replace("_", "-")
            if key == "mem-per-gpu":
                # convert mem-per-gpu to mem
                key = "mem"
                val = f"{val * args['ntasks'] // args['nodes']}G"  # that amount of memory is assigned on each node
            if key == "job-name" and val is None:
                # default job name is the run name ('<task> <model>')
                val = run_name
            if key == "job-name" and not val.startswith('"'):
                val = f'"{val}"'
            if key in ["task-prolog", "nodes", "exclude", "after-job"] and val is None:
                continue
            if key == "task-prolog":
                srun_command += f' --{key}="{val}" \\\n'
                continue
            if key == "after-job":
                sbatch_cmd_args += f"--dependency=afterany:{val} "
                continue
            if key == "partition" and isinstance(val, list):
                val = ",".join(val)
            if key == "ntasks":
                if args["nodes"] == 1:
                    gpus = val if args["gpus"] else 0
                    sbatch_commands += f"#SBATCH --gpus={gpus}\n"
                else:
                    assert (
                        val % args["nodes"] == 0
                    ), f"Number of tasks ({val}) must be a multiple of the number of nodes ({args['nodes']})."
                    sbatch_commands += f"#SBATCH --gpus-per-node={val // args['nodes']}\n"
                    sbatch_commands += "#SBATCH --ntasks-per-node=8\n"
            if "container" in key:
                srun_command += f" --{key}={val} \\\n"
            else:
                sbatch_commands += f"#SBATCH --{key}={val}\n"
        else:
            # it's a parameter for the training
            if val is None:
                continue
            if key in ["results_folder"] and val == globals()[key]:
                continue
            key = key.replace("_", "-")
            if isinstance(val, bool):
                if val:
                    task_args += f"--{key} "
                else:
                    task_args += f"--no-{key} "
                continue
            if isinstance(val, str):
                task_args += f'--{key} "{val}" '
            else:
                task_args += f"--{key} {val} "

    with open(file_name, "w+") as f:
        f.write(sbatch_commands + srun_command + python_command.format(task_args))

    # delete all runscripts older than a month
    n_old_files = int(os.popen("find experiments/sbatch/ -type f -mtime +30 | wc -l").read())
    if n_old_files > 0:
        print(f"Deleting {n_old_files} old runscripts.")
        os.system("find experiments/sbatch/ -type f -mtime +30 -delete")

    return file_name, sbatch_cmd_args
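

# Illustrative output (values are assumptions, not repo defaults): for a single-node job
# with ntasks=8 and mem_per_gpu=64, the generated file looks roughly like
#
#   #!/bin/bash
#
#   #SBATCH --output=<slurm_output_folder>/%x-<timestamp>-%j-%N.out
#   #SBATCH --partition=batch
#   #SBATCH --job-name="pre-train ViT"
#   #SBATCH --gpus=8
#   #SBATCH --mem=512G
#
#   srun -K \
#    --container-image=/path/to/image.sqsh \
#    --container-workdir=/workdir \
#    --container-mounts=/data:/data \
#    python3 main.py --task pre-train --model "ViT-B/16" ...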


if __name__ == "__main__":
    if not inside_slurm():
        # Make execution script and execute it
        parser = slurm_parser()
        args = vars(parser.parse_args())

        if not args["local"]:
            script_name, cmd_args = create_runscript(args)
            # run srun/sbatch to execute this script in the slurm cluster
            # -> the following lines will be executed there
            if args["interactive"]:
                os.system(f"python3 srun-sbatch.py {script_name}")
            else:
                os.system(f"sbatch {cmd_args} {script_name}")  # sbatch to queue the job on the cluster
            exit(0)

        # local execution is wanted
        for key in list(args.keys()):
            if key.replace("_", "-") in slurm_defaults:
                args.pop(key)
        args = parse_args(args, parser)

    else:
        args = parse_args()

    args["branch"] = subprocess.check_output(["git", "branch", "--show-current"]).strip().decode("utf-8")
    args["commit"] = subprocess.check_output(["git", "rev-parse", "HEAD"]).strip().decode("utf-8")

    if args["task"] == "pre-train":
        from train import pretrain

        pretrain(**args)

    elif args["task"] == "fine-tune":
        from train import finetune

        finetune(**args)

    elif args["task"] == "fine-tune-head":
        from train import finetune

        finetune(**args, head_only=True)

    elif args["task"] == "parser-test":
        from copy import copy

        from utils import log_args, prep_kwargs

        kwargs = prep_kwargs(copy(args))
        log_args(kwargs)

    elif args["task"] == "eval-metrics":
        from evaluate import evaluate_metrics

        evaluate_metrics(**args)

    elif args["task"] == "eval":
        from evaluate import evaluate

        evaluate(**args)

    elif args["task"] == "eval-attr":
        from evaluate import evaluate_attributions

        evaluate_attributions(**args)

    elif args["task"] == "continue":
        from recover import continue_training

        continue_training(**args)

    elif args["task"] == "eval-center-bias":
        from evaluate import evaluate_center_bias

        evaluate_center_bias(**args)

    elif args["task"] == "eval-size-bias":
        from evaluate import evaluate_size_bias

        evaluate_size_bias(**args)

    elif args["task"] == "load-images":
        from test import load_images

        load_images(**args)

    elif args["task"] == "save-images":
        from test import save_images

        save_images(**args)

    else:
        raise NotImplementedError(f"Task {args['task']} is not implemented.")
1032
AAAI Supplementary Material/Model Training Code/metrics.py
Normal file
154
AAAI Supplementary Material/Model Training Code/models.py
Normal file
@@ -0,0 +1,154 @@
"""Model loading and preparation."""

import importlib
import os
import warnings
from functools import partial

import timm
import torch
import torch.nn as nn
from loguru import logger

import utils
from architectures.vit import TimmViT
from resizing_interface import vit_sizes

_ARCHITECTURES_IMPORTED = False


def _import_architectures():
    global _ARCHITECTURES_IMPORTED
    if not _ARCHITECTURES_IMPORTED:
        model_file_path = os.path.dirname(os.path.abspath(__file__))
        for file in os.listdir(os.path.join(model_file_path, "architectures")):
            if not file.endswith(".py"):
                continue
            try:
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore")
                    importlib.import_module(f"architectures.{file[:-3]}")
                    logger.debug(f"Imported architectures.{file[:-3]}")
            except Exception as e:
                logger.error(f"\033[93mCould not import \033[0m\033[91m{file}\033[0m")
                logger.error(e)
        _ARCHITECTURES_IMPORTED = True


def prepare_model(model_str, args):
    """Prepare a new model.

    If the name is of the format ViT-<size>/<patch_size>, use a *TimmViT*, else fall back to timm model loading.

    Args:
        model_str (str): model name
        args (utils.DotDict): further arguments, needs to have keys n_classes, drop_path_rate; key imsize or '_<imsize>' at the end of the ViT specification

    Returns:
        torch.nn.Module: model

    """
    _import_architectures()

    kwargs = dict(args)
    for key in [key for key, val in kwargs.items() if val is None]:
        kwargs.pop(key)

    if args.layer_scale_init_values:
        kwargs["init_values"] = kwargs["init_scale"] = args.layer_scale_init_values
    if args.dropout and args.dropout > 0.0:
        kwargs["drop"] = kwargs["drop_rate"] = args.dropout
    if args.drop_path_rate and args.drop_path_rate > 0.0:
        kwargs["drop_block_rate"] = args.drop_path_rate
    kwargs["num_classes"] = args.n_classes
    kwargs["img_size"] = args.imsize
    if model_str.startswith("ViT"):
        # Format: ViT-{Ti,S,B,L}/<patch_size>[_<image_res>]
        h1, h2 = model_str.split("/")
        _, model_size = h1.split("-")
        if "_" in h2:
            patch_size, image_res = h2.split("_")
            assert args.imsize is None or args.imsize == int(
                image_res
            ), f"Got two different image sizes: {args.imsize} vs {image_res}"
        else:
            patch_size = h2

        kwargs = {**vit_sizes[model_size], **kwargs}
        model = TimmViT(patch_size=int(patch_size), in_chans=3, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)

    else:
        logger.debug(f"Loading model via timm api {model_str} with args {kwargs}")
        model = timm.create_model(model_str, pretrained=False, **kwargs)
    return model
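

# Illustrative usage (argument values are assumptions, not repo defaults):
#   args = utils.prep_kwargs({"n_classes": 100, "imsize": 224})
#   model = prepare_model("ViT-S/16_224", args)  # TimmViT with embed_dim=384, depth=12, patch size 16
#   model = prepare_model("resnet50", args)      # anything else falls back to timm.create_model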


def load_pretrained(model_path, args, new_dataset_params=False):
    """Load a pretrained model from a .tar file.

    Args:
        model_path (str): path to .tar file
        args: new model parameters
        new_dataset_params (bool, optional): change model parameters (imsize, n_classes) to the ones from args. (Default value = False)

    Returns:
        tuple: model, args, old_args, save_state

    """
    _import_architectures()

    save_state = torch.load(model_path, map_location="cpu")
    old_args = utils.prep_kwargs(save_state["args"])
    args.model = old_args.model
    old_args.cuda = args.cuda

    if old_args.model.startswith("flash_vit"):
        args.pop("layer_scale_init_values", None)
        old_args.pop("layer_scale_init_values", None)

    # load the model (the old one first)
    model = prepare_model(old_args.model, old_args)
    logger.debug(f"loading model {old_args.model} from {model_path} with args {old_args}")
    file_save_state = utils.remove_prefix(save_state["model_state"], prefix="_orig_mod.")
    file_save_state = utils.remove_prefix(file_save_state)
    try:
        model.load_state_dict(file_save_state)
    except (UnboundLocalError, RuntimeError) as e:
        model_keys = set(model.state_dict().keys())
        file_keys = set(file_save_state.keys())
        logger.warning(f"Error loading state dict: {e}")
        model_minus_file = model_keys.difference(file_keys)
        file_minus_model = file_keys.difference(model_keys)
        logger.warning(f"model-file: {model_minus_file}\nfile-model: {file_minus_model}")
        if len(file_minus_model) == 0 and all([".ls" in key and key.endswith(".gamma") for key in model_minus_file]):
            logger.info("Old model was without LayerScale -> replicating")
            try:
                args.pop("layer_scale_init_values")
                old_args.pop("layer_scale_init_values")
                model = prepare_model(old_args.model, old_args)
                model.load_state_dict(file_save_state)
            except (UnboundLocalError, RuntimeError) as e:
                logger.error("Could not resolve conflict")
                logger.error(f"Still got error {e}")
                exit(-1)
        elif any("head.0." in key for key in file_minus_model):
            logger.info("Old model used nn.Sequential for head. Trying to fix -> nn.Linear")
            file_save_state = {key.replace("head.0.", "head."): val for key, val in file_save_state.items()}
            try:
                model.load_state_dict(file_save_state)
            except (UnboundLocalError, RuntimeError) as e:
                logger.error("Could not resolve conflict")
                logger.error(f"Still got error {e}")
                exit(-1)
        else:
            exit(-1)

    if new_dataset_params:
        # setup for finetuning parameters
        model.set_image_res(args.imsize)
        model.set_num_classes(args.n_classes)

    if args.max_seq_len is not None:
        model.set_max_seq_len(args.max_seq_len)

    return model, args, old_args, save_state
@@ -0,0 +1,36 @@
import os

user = os.environ.get("USER")
results_folder = os.path.join("/BASE/FOLDER/TO/STORE/WEIGHTS/AND/LOGS", "EfficientCVBench")
# PATH: /netscratch/<user>/slurm
slurm_output_folder = os.path.join("/FOLDER/FOR/SLURM/TO/WRITE/LOGS/TO", "slurm")


_ds_paths = {
    "cifar": "/PATH/TO/CIFAR",
    "tinyimagenet": "/PATH/TO/TINYIMAGENET",
    "stanford_cars": "/PATH/TO/CARS/SUPERFOLDER",
    "oxford_pet": "/PATH/TO/PET/SUPERFOLDER",
    "flowers102": "/PATH/TO/FLOWERS/SUPERFOLDER",
    "food101": "/PATH/TO/FOOD/SUPERFOLDER",
    "aircraft": "/PATH/TO/AIRCRAFT/SUPERFOLDER",
    "fornet": "/PATH/TO/FORNET",
    "counteranimal": "/PATH/TO/CounterAnimal/LAION-final",
}


def ds_path(dataset, args=None):
    """Get the (base) path for any dataset.

    Args:
        dataset (str): name of the dataset to look up
        args (DotDict, optional): run args; if args.custom_dataset_path is set, that path is always returned

    Returns:
        str: Path to the dataset root folder.

    """
    if args is not None and "custom_dataset_path" in args and args.custom_dataset_path is not None:
        return args.custom_dataset_path
    return _ds_paths[dataset]
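

# Example (with the placeholder paths above): ds_path("fornet") -> "/PATH/TO/FORNET".
# If args.custom_dataset_path is set, that path is returned regardless of `dataset`.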
136
AAAI Supplementary Material/Model Training Code/recover.py
Normal file
@@ -0,0 +1,136 @@
"""Continue pretraining / finetuning after something went wrong."""

import torch
from loguru import logger

from engine import _train, setup_criteria_mixup, setup_model_optim_sched_scaler, setup_tracking_and_logging
from load_dataset import prepare_dataset
from models import load_pretrained
from utils import ddp_cleanup, ddp_setup, log_args, prep_kwargs


def continue_training(model, **kwargs):
    """Continue training a model from a saved state.

    Args:
        model (str): path to saved state.
        **kwargs: additional keyword arguments.

    """
    model_path = model
    save_state = torch.load(model, map_location="cpu")

    # state is of the form
    #
    # state = {'epoch': epochs,
    #          'model_state': model.state_dict(),
    #          'optimizer_state': optimizer.state_dict(),
    #          'scheduler_state': scheduler.state_dict(),
    #          'args': dict(args),
    #          'run_name': run_name,
    #          'stats': metrics}

    args = prep_kwargs(save_state["args"])

    args.distributed, device, world_size, rank, gpu_id = ddp_setup()
    torch.cuda.set_device(device)

    if "world_size" in args and args.world_size is not None:
        global_bs = args.batch_size * args.world_size
    else:
        # assume global bs is given in kwargs
        global_bs = kwargs["batch_size"]
    args.batch_size = int(global_bs / world_size)
    args.world_size = world_size

    if "dataset" in args and args.dataset is not None:
        dataset = args.dataset
    else:
        # get default dataset for the task
        dataset = "ImageNet21k" if args.task == "pre-train" else "ImageNet"
        args.dataset = dataset

    if "val_dataset" in args and args.val_dataset is not None:
        val_dataset = args.val_dataset
    else:
        val_dataset = dataset
        args.val_dataset = val_dataset

    start_epoch = save_state["epoch"]
    if "epochs" in args and args.epochs is not None and args.epochs != start_epoch:
        epochs = args.epochs
    else:
        epochs = kwargs["epochs"]

    run_folder = setup_tracking_and_logging(args, rank, append_model_path=model_path)
    logger.info(f"Logging run information to '{run_folder}'")

    # get the datasets & dataloaders
    train_loader, args.n_classes, args.ignore_index, args.multi_label, train_dali_server = prepare_dataset(
        dataset, args, rank=rank
    )
    val_loader, _, __, ___, val_dali_server = prepare_dataset(val_dataset, args, train=False, rank=rank)

    model, args, _, __ = load_pretrained(model_path, args)

    model, optimizer, scheduler, scaler = setup_model_optim_sched_scaler(model, device, epochs, args)

    try:
        optimizer.load_state_dict(save_state["optimizer_state"])
    except ValueError as e:
        logger.error(f"Could not load optimizer state: {e}")
        logger.error(
            f"optimizer state: {optimizer.state_dict().keys()}, param groups: {optimizer.state_dict()['param_groups']}"
        )
        logger.error(
            f"saved state: {save_state['optimizer_state'].keys()}, param groups:"
            f" {save_state['optimizer_state']['param_groups']}"
        )
        raise e

    scheduler.load_state_dict(save_state["scheduler_state"])

    # log all devices
    logger.info(f"training on {device} -> {torch.cuda.get_device_name(device) if args.device != 'cpu' else ''}")
    if rank == 0:
        logger.info(f"torch version {torch.__version__}")
        log_args(args)

    if args.seed:
        torch.manual_seed(args.seed)

    criterion, val_criterion, mixup = setup_criteria_mixup(args)

    if rank == 0:
        logger.info(f"start training at epoch {start_epoch}")
        logger.info(f"Run info at: '{run_folder}'")

    res = _train(
        model,
        train_loader,
        optimizer,
        rank,
        epochs,
        device,
        mixup,
        criterion,
        world_size,
        scheduler,
        args,
        val_loader,
        val_criterion,
        run_folder,
        scaler=scaler,
        do_metrics_calculation=True,
        start_epoch=start_epoch,
        show_tqdm=args.tqdm,
        train_dali_server=train_dali_server,
        val_dali_server=val_dali_server,
    )

    if rank == 0:
        best_acc_key = [key for key in res.keys() if key.startswith("val/best_")][0]
        logger.info(f"Run '{args.run_name}' is done. Top-1 validation accuracy: {res[best_acc_key] * 100:.2f}%")
    ddp_cleanup(args=args, rank=rank)
@@ -0,0 +1,25 @@
captum==0.8.0
# datadings==3.4.6
einops==0.8.1
fvcore==0.1.5.post20221221
grad-cam==1.5.4
halonet-pytorch==0.0.4
numpy==1.26.4
opencv-python==4.11.0.86
pytorch_wavelets==1.3.0
pywavelets==1.8.0
reformer-pytorch==1.4.4
routing-transformer==1.6.1
sinkhorn-transformer==0.11.4
timm==1.0.15
torch==2.6.0
torcheval==0.0.7
torchprofile==0.0.4
torchvision==0.21.0
tqdm==4.67.1
nltk==3.9.1
Pillow==11.1.0
psutils==3.3.9
wandb==0.19.9
psutil==7.0.0
@@ -0,0 +1,108 @@
from copy import copy

import torch
from loguru import logger
from torch import nn

vit_sizes = {
    "na": dict(embed_dim=96, depth=2, num_heads=2),
    "mu": dict(embed_dim=144, depth=6, num_heads=3),
    "Ti": dict(embed_dim=192, depth=12, num_heads=3),
    "S": dict(embed_dim=384, depth=12, num_heads=6),
    "B": dict(embed_dim=768, depth=12, num_heads=12),
    "L": dict(embed_dim=1024, depth=24, num_heads=16),
    "LRA_CIFAR": dict(embed_dim=256, depth=1, num_heads=4, mlp_ratio=1.0),
    "LRA_IMDB": dict(embed_dim=256, depth=4, num_heads=4, mlp_ratio=4.0),
    "LRA_ListOps": dict(embed_dim=512, depth=4, num_heads=8, mlp_ratio=4.0),
}


class ResizingInterface:
    """Interface for resizing parts of a Vision Transformer model."""

    def get_internal_loss(self):
        """Add a term to the loss."""
        return 0.0

    def set_image_res(self, res):
        """Set a new image resolution.

        Resets the (learned) patch embedding.

        Args:
            res (int): new image resolution

        """
        self._set_input_strand(res=res)

    def _set_input_strand(self, res=None, patch_size=None):
        """Set a new image resolution and patch size.

        Args:
            res (int): (Default value = None)
            patch_size (int): (Default value = None)

        """
        if res is None:
            res = self.img_size

        if patch_size is None:
            patch_size = self.patch_size
        else:
            # TODO: implement interpolation of patch_embed weights to new patch size/input shape
            raise NotImplementedError("Interpolation of patch_embed weights to new patch size not implemented yet.")

        if res == self.img_size and patch_size == self.patch_size:
            return  # nothing to do here

        logger.info(
            f"Resizing input resolution from {self.img_size} to {res} and patch size from {self.patch_size} to {patch_size}."
        )

        old_patch_embed_state = copy(self.patch_embed.state_dict())
        self.patch_embed = self.embed_layer(
            img_size=res,
            patch_size=patch_size,
            in_chans=self.in_chans,
            embed_dim=self.embed_dim,
            bias=not self.pre_norm,  # disable bias if pre-norm is used (e.g. CLIP)
        )

        self.patch_embed.load_state_dict(old_patch_embed_state)

        num_extra_tokens = 0 if self.no_embed_class else self.num_prefix_tokens
        orig_size = int((self.pos_embed.shape[-2] - num_extra_tokens) ** 0.5)
        new_size = int(self.patch_embed.num_patches**0.5)
        extra_tokens = self.pos_embed[:, :num_extra_tokens]
        pos_tokens = self.pos_embed[:, num_extra_tokens:]
        # make it shape rest x embed_dim x orig_size x orig_size
        pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, self.embed_dim).permute(0, 3, 1, 2)
        pos_tokens = nn.functional.interpolate(
            pos_tokens, size=(new_size, new_size), mode="bicubic", align_corners=False
        )
        # make it shape rest x new_size^2 x embed_dim
        pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
        if num_extra_tokens > 0:
            pos_tokens = torch.cat((extra_tokens, pos_tokens), dim=1)
        self.pos_embed = nn.Parameter(pos_tokens.contiguous())

        self.img_size = res
        self.patch_size = patch_size
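
    # Illustrative shape walk-through (assumed ViT-B/16 with a class token, no_embed_class=False):
    # resizing 224 -> 384 takes num_patches from 196 to 576, so the grid part of pos_embed is
    # reshaped to (1, 768, 14, 14), bicubically interpolated to (1, 768, 24, 24), flattened
    # back to (1, 576, 768), and re-concatenated with the class-token embedding -> (1, 577, 768).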

    def set_num_classes(self, n_classes):
        """Reset the classification head with a new number of classes.

        Args:
            n_classes (int): new number of classes

        """
        if n_classes == self.num_classes:
            return
        logger.info(f"Resizing classification head from {self.num_classes} to {n_classes}.")
        self.head = nn.Linear(self.embed_dim, n_classes) if n_classes > 0 else nn.Identity()
        self.num_classes = n_classes

        # init weight + bias
        # nn.init.zeros_(self.head.weight)
        # nn.init.constant_(self.head.bias, -log(self.num_classes))

        nn.init.trunc_normal_(self.head.weight, std=0.02)
        nn.init.constant_(self.head.bias, 0)
@@ -0,0 +1,46 @@
import os
import sys

assert len(sys.argv) >= 2, "Add an sbatch script as the first argument"
assert os.path.isfile(
    sys.argv[1]
), f"First argument has to be an executable script (a file that exists), but got '{sys.argv[1]}'"

with open(sys.argv[1], "r") as f:
    script = f.readlines()

script = [line.strip() for line in script if len(line.strip()) > 0]

# join lines ending with \
joined_script = []
current_line = ""
for line in script:
    current_line += line
    if current_line.endswith("\\"):
        current_line = current_line[:-1]
    else:
        joined_script.append(current_line)
        current_line = ""
script = joined_script

additional_srun_params = []
srun_lines = []
for line in script:
    if line.upper().startswith("#SBATCH "):
        param = line[len("#SBATCH ") :]
        if param.startswith("--output="):
            continue
        additional_srun_params.append(param)
    elif line.startswith("srun "):
        srun_lines.append(line)

further_args = " ".join(sys.argv[2:])

sruns = [
    line.replace("srun ", "srun " + " ".join(additional_srun_params) + " ").replace('"$@"', further_args)
    for line in srun_lines
]

for srun_line in sruns:
    print(f"I will run:\n{srun_line}", flush=True)
    os.system(srun_line)
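
# Illustrative transformation (assumed input, not a file from this repo): for a script
#   #SBATCH --partition=batch
#   #SBATCH --mem=64G
#   srun -K \
#     python3 main.py "$@"
# invoking `python3 srun-sbatch.py script.sbatch --epochs 100` prints and executes
#   srun --partition=batch --mem=64G -K python3 main.py --epochs 100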
97
AAAI Supplementary Material/Model Training Code/test.py
Normal file
@@ -0,0 +1,97 @@
import math
import os
import sys

import numpy as np
from loguru import logger
from matplotlib import pyplot as plt
from PIL import Image
from tqdm.auto import tqdm

from load_dataset import prepare_dataset
from utils import log_args, log_formatter, prep_kwargs


def load_images(dataset, **kwargs):
    args = prep_kwargs(kwargs)
    args.dataset = dataset

    args.aug_normalize = False

    loader, args.n_classes, args.ignore_index, args.multi_label, _ = prepare_dataset(dataset, args)
    images = next(iter(loader))[0]

    images = images.permute(0, 2, 3, 1).numpy()
    images = [images[i] for i in range(images.shape[0])]

    rows = math.ceil(math.sqrt(len(images) / 2))
    ims_per_row = len(images) // rows

    fig, axs = plt.subplots(rows, ims_per_row)
    axs = np.atleast_1d(axs).ravel()  # works for both 1-D and 2-D subplot grids
    for img, ax in zip(images, axs):
        ax.imshow(img)
    fig.suptitle(f"Examples from {dataset}")
    fig.tight_layout(pad=0)

    plt.show()


def save_images(dataset, out_dir, ipc=None, **kwargs):
    args = prep_kwargs(kwargs)
    args.dataset = dataset
    args.out_dir = out_dir
    args.ipc = ipc
    args.aug_normalize = False

    log_file = os.path.join(out_dir, "save_images.log")
    logger.remove()
    logger.configure(extra=dict(run_name=f"Save images of {dataset}", rank=0, world_size=-1))
    logger.add(sys.stderr, format=log_formatter, enqueue=True, colorize=True, level=args.log_level.upper())
    logger.add(log_file, format=log_formatter, enqueue=True, colorize=True, level=args.log_level.upper())
    logger.info(f"Out dir '{out_dir}'")
    log_args(args)

    loader, args.n_classes, args.ignore_index, args.multi_label, _ = prepare_dataset(dataset, args)
    n_ims = [0 for i in range(args.n_classes)]

    if args.n_classes == 1000:
        # assume it's ImageNet classes
        logger.info("1000 classes => assuming ImageNet class names")
        with open("data/misc_dataset_files/imagenet_labels.txt", "r") as f:
            lines = f.readlines()
        labels = [line.strip().split(" ")[0] for line in lines]
        lbl_to_cls_name = sorted(labels)
    else:
        lbl_to_cls_name = [str(i) for i in range(args.n_classes)]

    for cls_name in lbl_to_cls_name:
        os.makedirs(os.path.join(args.out_dir, cls_name), exist_ok=True)

    skipped_ims = 0
    tqdm_is_disabled = (not args.tqdm) or os.environ.get("TQDM_DISABLE", "0") != "0"
    for i, (images, labels) in (
        pbar := tqdm(enumerate(loader), desc="Loading and saving images", disable=tqdm_is_disabled, total=len(loader))
    ):
        images = (images.permute(0, 2, 3, 1).numpy() * 255).astype(np.uint8)
        images = [images[i] for i in range(images.shape[0])]
        labels = labels.tolist()

        for img, lbl in zip(images, labels):
            if ipc is not None and n_ims[lbl] >= ipc:
                skipped_ims += 1
                continue

            Image.fromarray(img).save(
                os.path.join(args.out_dir, lbl_to_cls_name[lbl], f"{lbl_to_cls_name[lbl]}_{n_ims[lbl]}.JPEG")
            )
            n_ims[lbl] += 1

        if ipc is not None and sum(n_ims) >= args.n_classes * ipc:
            break
        if tqdm_is_disabled:
            if i % 1000 == 0:
                logger.info(f"Batch [{i+1}/{len(loader)}]: Saved {sum(n_ims)}, skipped {skipped_ims}")
        else:
            pbar.set_description(f"Loading and saving (saved {sum(n_ims)}, skipped {skipped_ims})")
    logger.success(f"Extracted {sum(n_ims)} images.")
303
AAAI Supplementary Material/Model Training Code/train.py
Normal file
@@ -0,0 +1,303 @@
import os
import random
import sys

import numpy as np
import timm
import torch
from loguru import logger

from engine import (
    _train,
    setup_criteria_mixup,
    setup_model_optim_sched_scaler,
    setup_tracking_and_logging,
    wandb_available,
)
from load_dataset import prepare_dataset
from models import load_pretrained, prepare_model
from utils import ddp_cleanup, ddp_setup, log_args, prep_kwargs, set_filter_warnings


def finetune(model, dataset, epochs, val_dataset=None, head_only=False, **kwargs):
    """Finetune a pretrained model on a given dataset for a specified number of epochs.

    Args:
        model (str): Path to the pretrained model state file (in .tar format).
        dataset (str): Name of the dataset to finetune on.
        epochs (int): Number of epochs to train for.
        val_dataset (str, optional): Name of the validation dataset. (Default value = None)
        head_only (bool, optional): Whether to train only the head of the model. Default: False.
        **kwargs (dict): Further arguments for model setup, training, evaluation, ...

    Notes:
        This function assumes that the model was pretrained on a different dataset using a different set of hyperparameters.
        It fine-tunes the model on a new dataset by loading the pretrained weights and training for the specified number of
        epochs. The function supports distributed training using the PyTorch DistributedDataParallel module.
    """
    set_filter_warnings()

    # Add defaults & make keys properties
    args = prep_kwargs(kwargs)

    if val_dataset is None:
        val_dataset = dataset

    args.val_dataset = val_dataset
    args.dataset = dataset
    args.epochs = epochs

    args.distributed, device, world_size, rank, gpu_id = ddp_setup()
    args.world_size = world_size
    try:
        torch.cuda.set_device(device)
    except RuntimeError as e:
        logger.error(
            f"Could not set device {device} as current device; "
            f"environment parameters: RANK={os.environ.get('RANK')}, LOCAL_RANK={os.environ.get('LOCAL_RANK')}, "
            f"WORLD_SIZE={os.environ.get('WORLD_SIZE')}; SLURM_STEP_GPUS={os.environ.get('SLURM_STEP_GPUS')}, "
            f"GPU_DEVICE_ORDINAL={os.environ.get('GPU_DEVICE_ORDINAL')}, "
            f"CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')}, "
            f"SLURM_STEP_NODELIST={os.environ.get('SLURM_STEP_NODELIST')}, "
            f"SLURMD_NODENAME={os.environ.get('SLURMD_NODENAME')}"
        )
        raise e
    args.batch_size = int(args.batch_size / world_size)

    if args.seed is not None:
        # fix the seed for reproducibility
        seed = args.seed + rank
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)

    # get the datasets & dataloaders
    # transform only contains resize & crop here; everything else is handled on the GPU / in the training loop
    train_loader, args.n_classes, args.ignore_index, args.multi_label, train_dali_server = prepare_dataset(
        dataset, args, rank=rank
    )
    val_loader, _val_classes, _, __, val_dali_server = prepare_dataset(val_dataset, args, train=False, rank=rank)
    assert (
        args.n_classes == _val_classes
    ), f"Training and validation datasets have different numbers of classes: {args.n_classes} vs {_val_classes}"

    save_state = torch.load(model, map_location="cpu")
    old_args = prep_kwargs(save_state["args"])
    parent_folder = os.path.dirname(model)
    args.model = old_args.model
    run_folder = setup_tracking_and_logging(args, rank)
    if rank == 0:
        if not os.path.exists(os.path.join(run_folder, "parent_run")):
            os.symlink(parent_folder, os.path.join(run_folder, "parent_run"), target_is_directory=True)
        logger.info(
            f"environment parameters: RANK={os.environ.get('RANK')}, LOCAL_RANK={os.environ.get('LOCAL_RANK')}, "
            f"WORLD_SIZE={os.environ.get('WORLD_SIZE')}; SLURM_STEP_GPUS={os.environ.get('SLURM_STEP_GPUS')}, "
            f"GPU_DEVICE_ORDINAL={os.environ.get('GPU_DEVICE_ORDINAL')}, "
            f"CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')}, "
            f"SLURM_STEP_NODELIST={os.environ.get('SLURM_STEP_NODELIST')}, "
            f"SLURMD_NODENAME={os.environ.get('SLURMD_NODENAME')}"
        )

        if args.seed:
            logger.info(f"setting manual seed '{seed}' (arg: {args.seed} + rank: {rank})")

    model, args, old_args, save_state = load_pretrained(model, args, new_dataset_params=True)

    if rank == 0:
        logger.info(f"The model was pretrained on {old_args.dataset} for {save_state['epoch']} epochs.")

    model, optimizer, scheduler, scaler = setup_model_optim_sched_scaler(
        model, device, epochs, args, head_only=head_only
    )

    # log all devices
    logger.info(f"training on {device} -> {torch.cuda.get_device_name(device) if args.device != 'cpu' else ''}")
    if rank == 0:
        logger.info(f"torch version {torch.__version__}")
        logger.info(f"timm version {timm.__version__}")
        logger.info(f"full set of old arguments: {old_args}")
        log_args(args)

    if args.seed:
        torch.manual_seed(seed)

    criterion, val_criterion, mixup = setup_criteria_mixup(args)
    if rank == 0:
        logger.info(f"Run info at: '{run_folder}'")

    res = _train(
        model,
        train_loader,
        optimizer,
        rank,
        epochs,
        device,
        mixup,
        criterion,
        world_size,
        scheduler,
        args,
        val_loader,
        val_criterion,
        run_folder,
        scaler,
        do_metrics_calculation=True,
        show_tqdm=args.tqdm,
        train_dali_server=train_dali_server,
        val_dali_server=val_dali_server,
    )

    if rank == 0:
        best_acc_key = sorted([key for key in res.keys() if key.startswith("val/best_")])[0]
        logger.info(
            f"Run '{run_folder.split(os.sep)[-1]}' is done. Top-1 validation accuracy: {res[best_acc_key] * 100:.2f}%"
        )

    ddp_cleanup(args=args, sync_old_wandb=wandb_available(), rank=rank)


def pretrain(model, dataset, epochs, val_dataset=None, **kwargs):
    """Train or pretrain a model.

    Args:
        model (str): Name of the model to train.
        dataset (str): Name of the dataset to train the model on.
        epochs (int): Number of training epochs.
        val_dataset (str, optional): Name of the validation dataset, by default None
        **kwargs (dict): Additional keyword arguments.

    Notes:
        This function sets up the logger, prepares the model, and trains the model on the given dataset.
    """
    set_filter_warnings()

    # Add defaults & make args properties
    args = prep_kwargs(kwargs)

    if val_dataset is None:
        val_dataset = dataset

    args.val_dataset = val_dataset
    args.dataset = dataset
    args.model = model
    args.epochs = epochs

    args.distributed, device, world_size, rank, gpu_id = ddp_setup(args.cuda)
    args.world_size = world_size

    # sleep(rank * 5)
    # logger.debug(f'running environment commands for rank {rank}')
    # os.system('env')
    # os.system('nvidia-smi')
    # sleep((world_size - rank) * 5)

    logger.debug(
        f"rank params: RANK={os.environ.get('RANK')}, LOCAL_RANK={os.environ.get('LOCAL_RANK')}, "
        f"WORLD_SIZE={os.environ.get('WORLD_SIZE')}; gpu params: "
        f"SLURM_STEP_GPUS={os.environ.get('SLURM_STEP_GPUS')}, "
        f"GPU_DEVICE_ORDINAL={os.environ.get('GPU_DEVICE_ORDINAL')}, "
        f"CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')}"
    )

    if args.cuda:
        try:
            torch.cuda.set_device(device)
        except RuntimeError as e:
            logger.error(f"Could not set device {device} as current device: {e}")
            logger.error(
                f"environment parameters: RANK={os.environ.get('RANK')}, LOCAL_RANK={os.environ.get('LOCAL_RANK')}, "
                f"WORLD_SIZE={os.environ.get('WORLD_SIZE')}; SLURM_STEP_GPUS={os.environ.get('SLURM_STEP_GPUS')}, "
                f"GPU_DEVICE_ORDINAL={os.environ.get('GPU_DEVICE_ORDINAL')}, "
                f"CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')}, "
                f"SLURM_STEP_NODELIST={os.environ.get('SLURM_STEP_NODELIST')}, "
                f"SLURMD_NODENAME={os.environ.get('SLURMD_NODENAME')}"
            )
            raise e

    args.batch_size = int(args.batch_size / world_size)

    run_folder = setup_tracking_and_logging(args, rank)
    if rank % world_size == 0:
        logger.info(
            f"environment parameters: RANK={os.environ.get('RANK')}, LOCAL_RANK={os.environ.get('LOCAL_RANK')}, "
            f"WORLD_SIZE={os.environ.get('WORLD_SIZE')}; SLURM_STEP_GPUS={os.environ.get('SLURM_STEP_GPUS')}, "
            f"GPU_DEVICE_ORDINAL={os.environ.get('GPU_DEVICE_ORDINAL')}, "
            f"CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')}, "
            f"SLURM_STEP_NODELIST={os.environ.get('SLURM_STEP_NODELIST')}, "
            f"SLURMD_NODENAME={os.environ.get('SLURMD_NODENAME')}"
        )

    if args.seed is not None:
        # fix the seed for reproducibility
        seed = args.seed + rank
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)
        logger.info(f"setting manual seed '{seed}' (arg: {args.seed} + rank: {rank})")

    # get the datasets & dataloaders
    train_loader, args.n_classes, args.ignore_index, args.multi_label, train_dali_server = prepare_dataset(
        dataset, args, rank=rank
    )
    val_loader, _val_classes, _, __, val_dali_server = prepare_dataset(val_dataset, args, train=False, rank=rank)
    assert (
        args.n_classes == _val_classes
    ), f"Training and validation datasets have different numbers of classes: {args.n_classes} vs {_val_classes}"

    # setup model with amp & DDP
    model_name = None
    if isinstance(model, str):
        if model.startswith("ViT") and "_" not in model:
            model += f"_{args.imsize}"
        model_name = model
        model = prepare_model(model, args)
    if not model_name:
        model_name = type(model).__name__

    model, optimizer, scheduler, scaler = setup_model_optim_sched_scaler(model, device, epochs, args)

    # log all devices
    logger.info(f"training on {device} -> {torch.cuda.get_device_name(device) if device != 'cpu' else ''}")
    if rank == 0:
        logger.info(f"python version {sys.version}")
        logger.info(f"torch version {torch.__version__}")
        logger.info(f"timm version {timm.__version__}")
        log_args(args)

    if args.seed:
        torch.manual_seed(seed)

    criterion, val_criterion, mixup = setup_criteria_mixup(args)

    if rank == 0:
        logger.info(f"Run info at: '{run_folder}'")

    res = _train(
        model,
        train_loader,
        optimizer,
        rank,
        epochs,
        device,
        mixup,
        criterion,
        world_size,
        scheduler,
        args,
        val_loader,
        val_criterion,
        run_folder,
        scaler,
        do_metrics_calculation=True,
        show_tqdm=args.tqdm,
        train_dali_server=train_dali_server,
        val_dali_server=val_dali_server,
    )

    if rank == 0:
        best_acc_key = [key for key in res.keys() if key.startswith("val/best_")][0]
        logger.info(
            f"Run '{run_folder.split(os.sep)[-1]}' is done. Top-1 validation accuracy: {res[best_acc_key] * 100:.2f}%"
        )

    ddp_cleanup(args=args, sync_old_wandb=wandb_available(), rank=rank)
594
AAAI Supplementary Material/Model Training Code/utils.py
Normal file
@@ -0,0 +1,594 @@
"""Utils and small helper functions."""

import collections.abc
import json
import os
import shutil
import sys
import warnings
from dataclasses import dataclass
from itertools import repeat
from math import cos, pi, sqrt

import numpy as np
import torch
import torch.distributed as dist
from loguru import logger
from timm.data import Mixup
from timm.utils import NativeScaler, dispatch_clip_grad
from torch.nn.modules.loss import _WeightedLoss
from torch.utils.data import Dataset
from torchvision.transforms import transforms

import paths_config
from config import default_kwargs, get_default_kwargs  # noqa: F401 # is used in prep_kwargs


class RepeatedDataset(Dataset):
    """Dataset that repeats the given dataset a number of times."""

    def __init__(self, dataset, num_repeats):
        """Create repeated dataset.

        Args:
            dataset (Dataset): dataset to repeat.
            num_repeats (int): number of repeats.

        """
        self.dataset = dataset
        self.num_repeats = num_repeats

    def __getitem__(self, idx):
        return self.dataset[idx // self.num_repeats]

    def __len__(self):
        return len(self.dataset) * self.num_repeats
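

# Example: RepeatedDataset(ds, 3) has length 3 * len(ds), and indices 0, 1, 2 all map to
# ds[0]. Combined with a shuffling sampler and per-access random transforms, each image can
# then appear several times per epoch with independent augmentations (repeat-augment-style
# sampling; exact usage depends on the sampler this repo wires it to).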


@dataclass
class SchedulerArgs:
    """Class for scheduler arguments."""

    sched: str
    epochs: int
    min_lr: float
    warmup_lr: float
    warmup_epochs: int
    cooldown_epochs: int = 0


def scheduler_function_factory(
    epochs, sched, warmup_epochs=0, lr=None, min_lr=0.0, warmup_sched=None, warmup_lr=None, offset=0, **kwargs
):
    """Create a function mapping an epoch to a multiplicative learning-rate factor.

    Args:
        epochs (int): length of the full schedule
        sched (str): the learning rate schedule type
        warmup_epochs (int, optional): number of epochs reserved for warmup (Default value = 0)
        lr (float, optional): learning rate (has to be given when warmup or min_lr are set) (Default value = None)
        min_lr (float, optional): minimum learning rate (Default value = 0.0)
        warmup_sched (str, optional): the type of schedule during warmup (Default value = None)
        warmup_lr (float, optional): (starting) learning rate during warmup (Default value = None)
        offset (int, optional): offset for the schedule (to be the same as the timm scheduler) (Default value = 0)
        **kwargs: unused

    Returns:
        function: scheduler function

    """
    sched = sched.lower()

    def warmup_f(ep):
        return 1.0

    if warmup_epochs > 0:
        assert warmup_lr is not None, "Need warmup_lr, but got None"
        warmup_lr_factor = warmup_lr / lr
        if warmup_sched == "linear":

            def warmup_f(ep):
                return warmup_lr_factor + (1 - warmup_lr_factor) * max(ep, 0.0) / warmup_epochs

        elif warmup_sched == "const":

            def warmup_f(ep):
                return warmup_lr_factor

        else:
            raise NotImplementedError(f"Warmup schedule {warmup_sched} not implemented")

    epochs = epochs - warmup_epochs + offset
    if sched == "cosine":
        # cos from 0 to pi
        def main_f(ep):
            return cos(pi * ep / epochs) / 2 + 0.5

    elif sched == "const":

        def main_f(ep):
            return 1.0

    else:
        raise NotImplementedError(f"Schedule {sched} is not implemented.")

    # rescale and add min_lr
    min_lr_fact = min_lr / lr if min_lr else 0.0  # lr may be None when no min_lr is requested

    def main_f_with_min_lr(ep):
        return (1 - min_lr_fact) * main_f(ep) + min_lr_fact

    return lambda ep: (
        warmup_f(ep + offset) if ep + offset < warmup_epochs else main_f_with_min_lr(ep + offset - warmup_epochs)
    )
|
||||
|
||||
|
||||
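# Usage sketch (illustrative only; `optimizer` is assumed to exist). The returned
# function maps an epoch index to a multiplicative LR factor, so it can be plugged
# directly into torch.optim.lr_scheduler.LambdaLR:
#
#   sched_fn = scheduler_function_factory(
#       epochs=100, sched="cosine", warmup_epochs=5, lr=1e-3,
#       min_lr=1e-5, warmup_sched="linear", warmup_lr=1e-6,
#   )
#   scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=sched_fn)
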
class DotDict(dict):
    """Extension of a Python dictionary to access its keys using dot notation."""

    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

    def __getattr__(self, item, default=None):
        """Get an item from the dict, returning a default for missing keys.

        Args:
            item: key
            default (optional): default value. Defaults to None.

        Returns:
            value

        """
        if item not in self:
            return default
        return self.get(item)

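# Usage sketch (illustrative only): keys become attributes, and missing keys
# return None instead of raising, mirroring dict.get:
#
#   cfg = DotDict({"lr": 1e-3})
#   cfg.lr            # 1e-3
#   cfg.momentum      # None (missing key)
#   cfg.epochs = 100  # equivalent to cfg["epochs"] = 100
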
def prep_kwargs(kwargs):
    """Prepare the arguments and add defaults.

    Args:
        kwargs (dict[str, Any]): dict of kwargs

    Returns:
        DotDict: prepared kwargs

    """
    if "defaults" not in kwargs:
        kwargs["defaults"] = "DeiTIII"
    defaults = get_default_kwargs(kwargs["defaults"])
    for k, v in defaults.items():
        if k not in kwargs or kwargs[k] is None:
            kwargs[k] = v

    if "results_folder" not in kwargs:
        kwargs["results_folder"] = paths_config.results_folder

    if kwargs["results_folder"].endswith("/"):
        kwargs["results_folder"] = kwargs["results_folder"][:-1]

    if "val_dataset" not in kwargs and "dataset" in kwargs:
        kwargs["val_dataset"] = kwargs["dataset"]

    return DotDict(kwargs)

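# Usage sketch (illustrative only; the exact keys filled in depend on the
# presets defined in config.py):
#
#   args = prep_kwargs({"dataset": "imagenet", "lr": 1e-3})
#   args.defaults     # "DeiTIII", since no preset was given
#   args.val_dataset  # "imagenet", copied from "dataset"
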
def denormalize(x, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Invert the normalize operation.

    Args:
        x (torch.Tensor): images to de-normalize
        mean (tuple, optional): normalization mean. Defaults to (0.485, 0.456, 0.406).
        std (tuple, optional): normalization std. Defaults to (0.229, 0.224, 0.225).

    Returns:
        torch.Tensor: de-normalized images

    """
    operation = transforms.Normalize(
        mean=[-mu / sigma for mu, sigma in zip(mean, std)], std=[1 / sigma for sigma in std]
    )
    return operation(x)

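# Usage sketch (illustrative only): round-tripping a normalized batch recovers
# the original pixel values up to floating point error:
#
#   norm = transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
#   x = torch.rand(4, 3, 224, 224)
#   assert torch.allclose(denormalize(norm(x)), x, atol=1e-5)
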
def log_formatter(record):
    """Format a loguru record, including run name, rank/world size, and epoch if available."""
    if "run_name" not in record["extra"]:
        return (
            "<g>{time:YYYY-MM-DD HH:mm:ss.SSS}</g> <c>|</c> <y>name TBD</y> > <y>?</y>/<y>?</y> <c>|</c> <level>{level:"
            " <8}</level> <c>|</c> {message}\n"
        )

    epoch_str = "@ epoch <y>{extra[epoch]: >3}</y> " if "epoch" in record["extra"] else ""
    code_loc_str = "<r>{name}</r>.<r>{function}</r>:<r>{line}</r> - " if record["level"].no >= 30 else ""

    return (
        "<g>{time:YYYY-MM-DD HH:mm:ss.SSS}</g> <c>|</c> <m>{extra[run_name]}</m> >"
        " <m>{extra[rank]}</m>/<m>{extra[world_size]}</m> "
        + epoch_str
        + "<c>|</c> <level>{level: <8}</level> <c>|</c> "
        + code_loc_str
        + "{message}\n"
    )

def ddp_setup(use_cuda=True):
    """Set up the distributed environment.

    Args:
        use_cuda (bool, optional): whether to train on CUDA devices (Default value = True)

    Returns:
        tuple: A tuple containing the following elements:
            * bool: Whether the training is distributed.
            * torch.device: The device to use for distributed training.
            * int: The total number of processes in the distributed setup.
            * int: The global rank of the current process in the distributed setup.
            * int: The local rank of the current process on its node.

    Notes:
        The 'nccl' backend is used.

    """
    logger.remove()
    rank = int(os.getenv("RANK", 0))
    local_rank = int(os.getenv("LOCAL_RANK", 0))
    num_gpus = int(os.getenv("WORLD_SIZE", 1))
    distributed = "RANK" in os.environ and num_gpus > 1
    logger.add(sys.stderr, format=log_formatter, colorize=True, enqueue=True)
    if distributed:
        assert use_cuda, "Only use distributed mode with cuda."
        try:
            dist.init_process_group("nccl")
        except ValueError as e:
            logger.critical(f"Value error while setting up nccl process group: {e}")
            logger.info(
                f" CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')},"
                f" SLURM_STEP_GPUS={os.environ.get('SLURM_STEP_GPUS')},"
                f" GPU_DEVICE_ORDINAL={os.environ.get('GPU_DEVICE_ORDINAL')} for process:"
                f" RANK={rank} (LOCAL_RANK={local_rank}) of WORLD_SIZE={num_gpus}. Shutting down now."
            )
            raise e

    assert torch.cuda.is_available() or not use_cuda, "CUDA is not available"
    assert (
        len(str(os.environ.get("SLURM_STEP_GPUS")).split(","))
        == len(str(os.environ.get("CUDA_VISIBLE_DEVICES")).split(","))
        == len(str(os.environ.get("GPU_DEVICE_ORDINAL")).split(","))
        == num_gpus
    ) or not use_cuda, (
        f"SLURM GPU setup is incorrect: CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')},"
        f" SLURM_STEP_GPUS={os.environ.get('SLURM_STEP_GPUS')},"
        f" GPU_DEVICE_ORDINAL={os.environ.get('GPU_DEVICE_ORDINAL')} for process:"
        f" RANK={rank} (LOCAL_RANK={local_rank}) of WORLD_SIZE={num_gpus}"
    )
    return distributed, torch.device(f"cuda:{local_rank}") if use_cuda else "cpu", num_gpus, rank, local_rank

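# Usage sketch (illustrative only): typically called once at the top of a training
# script launched via torchrun or SLURM, which set RANK, LOCAL_RANK and WORLD_SIZE:
#
#   distributed, device, world_size, rank, local_rank = ddp_setup()
#   model = model.to(device)
#   if distributed:
#       model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
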
def ddp_cleanup(args, sync_old_wandb=False, rank=0):
    """Clean up the distributed setup after use.

    Args:
        args (DotDict): arguments
        sync_old_wandb (bool, optional): Whether to sync and remove wandb runs older than 100 hours (just over 4 days). Defaults to False.
        rank (int, optional): The rank of the current process, so only one process syncs wandb. Defaults to 0.

    """
    if sync_old_wandb and rank == 0:
        os.system("wandb sync --clean --clean-old-hours 100 --clean-force")

    if args.distributed:
        logger.info("waiting for all processes to finish")
        dist.barrier()
        dist.destroy_process_group()
    logger.info("exiting now")
    exit(0)

def set_filter_warnings():
    """Filter out some warnings to reduce spam."""
    # filter DataLoader number of workers warning
    warnings.filterwarnings(
        "ignore", ".*worker processes in total. Our suggested max number of worker in current system is.*"
    )

    # Filter datadings only varargs warning
    warnings.filterwarnings("ignore", ".*only accepts varargs so.*")

    # Filter warnings from calculation of MACs & FLOPs
    # warnings.filterwarnings("ignore", ".*No handlers found:.*")

    # Filter warnings from gather
    warnings.filterwarnings("ignore", ".*is_namedtuple is deprecated, please use the python checks instead.*")

    # Filter warnings from meshgrid
    warnings.filterwarnings("ignore", ".*in an upcoming release, it will be required to pass the indexing.*")

    # Filter warnings from timm when overwriting models
    warnings.filterwarnings("ignore", ".*UserWarning: Overwriting .*")

def remove_prefix(state_dict, prefix="module."):
    """Remove a prefix from the keys in a state dictionary.

    Args:
        state_dict (dict[str, Any]): The state dictionary to remove the prefix from.
        prefix (str, optional): The prefix to remove from the keys. Default is 'module.'.

    Returns:
        dict[str, Any]: A new dictionary with the prefix removed from the keys.

    Examples:
        >>> state_dict = {'module.layer1.weight': 1, 'module.layer1.bias': 2}
        >>> remove_prefix(state_dict)
        {'layer1.weight': 1, 'layer1.bias': 2}

    """
    return {k[len(prefix) :] if k.startswith(prefix) else k: v for k, v in state_dict.items()}

def prime_factors(n):
    """Calculate the prime factors of a given integer.

    Args:
        n (int): The integer to find the prime factors of.

    Returns:
        list[int]: The prime factors of n.

    """
    i = 2
    factors = []
    while i * i <= n:
        if n % i:
            i += 1
        else:
            n //= i
            factors.append(i)
    if n > 1:
        factors.append(n)
    return factors

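# Example (illustrative only): factors are returned in ascending order with
# multiplicity, e.g. prime_factors(360) == [2, 2, 2, 3, 3, 5].
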
def linear_regession(points):
    """Fit a least-squares regression line through the points.

    Args:
        points (dict[float, float]): points to fit, in the format points[x] = y

    Returns:
        function: the fitted line as a function of x.

    """
    N = len(points)
    x = []
    y = []
    for x_i, y_i in points.items():
        x.append(x_i)
        y.append(y_i)
    x = np.array(x)
    y = np.array(y)

    a = (N * (x * y).sum() - x.sum() * y.sum()) / (N * (x * x).sum() - x.sum() ** 2)
    b = (y.sum() - a * x.sum()) / N
    return lambda z: a * z + b

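# Usage sketch (illustrative only): fitting points on y = 2x + 1 recovers the line:
#
#   f = linear_regession({0.0: 1.0, 1.0: 3.0, 2.0: 5.0})
#   f(10.0)  # 21.0
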
def save_model_state(
    model_folder,
    epoch,
    args,
    model_state,
    regular_save=True,
    stats=None,
    val_accs=None,
    epoch_accs=None,
    additional_reason="",
    max_interm_ep_states=2,
    **kwargs,
):
    """Save the model state.

    Args:
        model_folder: Folder to save the model in
        epoch: current epoch
        args: arguments to guide the save and to store alongside the state
        model_state: state of the model
        regular_save: Is this a regular or a special save? (Default value = True)
        stats: model stats to save (Default value = None)
        val_accs: model accuracy to save (Default value = None)
        epoch_accs: training accuracy to save (Default value = None)
        additional_reason: reason for the save, in case it is not just a regular save interval; e.g. "top" or "final". (Default value = "")
        max_interm_ep_states: Number of regular epoch states to keep (Default value = 2)
        **kwargs: Further arguments to save

    """
    # make args dict, not DotDict, to be able to save it
    state = {"epoch": epoch, "model_state": model_state, "run_name": args.run_name, "args": dict(args)}
    if stats is None:
        stats = {}
    if val_accs is not None:
        stats = {**stats, **val_accs}
    if epoch_accs is not None:
        stats = {**stats, **epoch_accs}
    state["stats"] = stats
    state = {**state, **kwargs}
    logger.info(f"saving model state at epoch {epoch} ({additional_reason})")
    regular_file_name = f"ep_{epoch}.pt"
    save_name = additional_reason + ".pt" if len(additional_reason) > 0 else regular_file_name
    outfile = os.path.join(model_folder, save_name)
    torch.save(state, outfile)
    if len(additional_reason) > 0 and regular_save:
        shutil.copyfile(outfile, os.path.join(model_folder, regular_file_name))

    # remove intermediate epoch states (all but the last max_interm_ep_states)
    if max_interm_ep_states > 0:
        epoch_states = [f for f in os.listdir(model_folder) if f.startswith("ep_") and f.endswith(".pt")]
        epoch_states = sorted(epoch_states, key=lambda x: int(x.split("_")[1].split(".")[0]))
        if len(epoch_states) > max_interm_ep_states:
            for f in epoch_states[:-max_interm_ep_states]:
                os.remove(os.path.join(model_folder, f))
                logger.debug(f"removed intermediate epoch state {f}")

def log_args(args, rank=0):
    """Log the full set of arguments as JSON (only on rank 0)."""
    if rank == 0:
        logger.info("full set of arguments: " + json.dumps(dict(args), sort_keys=True))
        # keys = sorted(list(args.keys()))
        # for key in keys:
        #     logger.info(f"arg: {key} = {args[key]}")

class ScalerGradNormReturn(NativeScaler):
    """A wrapper around PyTorch's NativeScaler that returns the gradient norm."""

    def __str__(self) -> str:
        return f"{type(self).__name__}(_scaler: {self._scaler})"

    def __call__(self, loss, optimizer, clip_grad=None, clip_mode="norm", parameters=None, create_graph=False):
        """Scale and backpropagate through the loss tensor, and return the gradient norm of the selected parameters.

        Does an optimizer step.

        Args:
            loss (torch.Tensor): The loss tensor to scale and backpropagate through.
            optimizer (torch.optim.Optimizer): The optimizer to use for the optimization step.
            clip_grad (float, optional): The maximum allowed norm of the gradients. If None, no clipping is performed.
            clip_mode (str, optional): The mode used for clipping the gradients. Only used if `clip_grad` is not None. Possible values are 'norm'
                (clipping the norm of the gradients) and 'value' (clipping the value of the gradients). (default='norm')
            parameters (iterable[torch.nn.Parameter], optional): The parameters to compute the gradient norm for. If None, the gradient norm is not computed.
            create_graph (bool, optional): Whether to create a computation graph for computing second-order gradients. (default=False)

        Returns:
            float: The gradient norm of the selected parameters, or -1 if parameters is None.

        """
        self._scaler.scale(loss).backward(create_graph=create_graph)

        # always unscale the gradients, since it's being done anyway
        self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
        if parameters is not None:
            parameters = list(parameters)  # materialize, since parameters is iterated again for clipping below
            grads = [p.grad for p in parameters if p.grad is not None]
            device = grads[0].device
            grad_norm = torch.norm(torch.stack([torch.norm(g.detach(), 2).to(device) for g in grads]), 2)
        else:
            grad_norm = -1
        if clip_grad is not None:
            dispatch_clip_grad(parameters, clip_grad, mode=clip_mode)
        self._scaler.step(optimizer)
        self._scaler.update()
        return grad_norm

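# Usage sketch (illustrative only; `model`, `optimizer` and `loss` come from the
# surrounding training loop):
#
#   scaler = ScalerGradNormReturn()
#   optimizer.zero_grad()
#   grad_norm = scaler(loss, optimizer, clip_grad=1.0, parameters=model.parameters())
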
class NoScaler:
    """Dummy gradient scaler that doesn't scale gradients.

    This scaler performs a simple backward pass with the given loss, and then updates the model's parameters
    with the given optimizer. The resulting gradient norm is computed and returned.

    """

    def __str__(self) -> str:
        return f"{type(self).__name__}()"

    def __call__(self, loss, optimizer, parameters=None, **kwargs):
        """Perform a backward pass with the given loss, update the model's parameters with the given optimizer, and compute the resulting gradient norm.

        Args:
            loss (torch.Tensor): The loss tensor that the gradients will be computed from.
            optimizer (torch.optim.Optimizer): The optimizer that will be used to update the model's parameters.
            parameters (iterable[torch.Tensor], optional): An iterable of model parameters to compute gradients for. If None, returns -1.
            **kwargs: Additional keyword arguments; nothing will be done with these.

        Returns:
            float: The gradient norm of the selected parameters, computed before the optimizer step, or -1 if parameters is None.

        """
        loss.backward()
        if parameters is not None:
            grads = [p.grad for p in parameters if p.grad is not None]
            device = grads[0].device
            grad_norm = torch.norm(torch.stack([torch.norm(g.detach(), 2).to(device) for g in grads]), 2)
        else:
            grad_norm = -1
        optimizer.step()
        return grad_norm

def get_cpu_name():
    """Get the name of the CPU."""
    with open("/proc/cpuinfo", "r") as f:
        for line in f:
            if line.startswith("model name"):
                return line.split(":")[1].strip()
    return "unknown"

# From PyTorch internals
def _ntuple(n):
    """Make a function to create n-tuples.

    Args:
        n (int): tuple length

    Returns:
        function: function to create n-tuples

    """

    def parse(x):
        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
            return tuple(x)
        return tuple(repeat(x, n))

    return parse


to_2tuple = _ntuple(2)

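# Example (illustrative only): to_2tuple(7) == (7, 7), while an iterable is
# passed through unchanged: to_2tuple((3, 5)) == (3, 5).
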
def make_divisible(v, divisor=8, min_value=None, round_limit=0.9):
    """Round v to the nearest multiple of divisor, without dropping below round_limit * v.

    This function is primarily used to ensure that the output of a layer
    is divisible by a certain number, typically to align with hardware
    optimizations or memory layouts.

    Args:
        v (int): The input value.
        divisor (int, optional): The divisor. Defaults to 8.
        min_value (int, optional): The minimum value to return. If None, defaults to the divisor.
        round_limit (float, optional): A threshold for rounding down. If rounding down would yield less than round_limit * v, the next multiple of the divisor is returned instead. Defaults to 0.9.

    Returns:
        int: A multiple of divisor close to v, at least min_value and at least round_limit * v.

    """
    min_value = min_value or divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that rounding down does not go below round_limit * v (10% by default).
    if new_v < round_limit * v:
        new_v += divisor
    return new_v

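# Example (illustrative only): make_divisible(30) == 32 (nearest multiple of 8),
# while make_divisible(17) == 16, since 16 >= 0.9 * 17 stays within round_limit.
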
def grad_cam_reshape_transform(tensor):
    """Transform the tensor for Grad-CAM calculation.

    Args:
        tensor (torch.Tensor): input tensor

    Returns:
        torch.Tensor: reshaped tensor without [CLS] token.

    """
    n_squ = tensor.shape[1]
    result = tensor[:, 1:] if int(sqrt(n_squ)) ** 2 != n_squ else tensor
    bs, n, dim = result.shape
    result = result.reshape(bs, int(sqrt(n)), int(sqrt(n)), dim)

    # Bring the channels to the first dimension,
    # like in CNNs.
    return result.transpose(2, 3).transpose(1, 2)

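# Usage sketch (illustrative only): for a ViT with a [CLS] token and 196 patch
# tokens, a (B, 197, D) activation becomes a (B, D, 14, 14) feature map:
#
#   x = torch.randn(2, 197, 768)
#   grad_cam_reshape_transform(x).shape  # torch.Size([2, 768, 14, 14])
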
BIN
ForAug-supplementary.pdf
Normal file
BIN
ForAug.pdf
Normal file
111
aaai2026.bib
Normal file
@@ -0,0 +1,111 @@
@book{em:86,
  editor    = "Engelmore, Robert and Morgan, Anthony",
  title     = "Blackboard Systems",
  year      = 1986,
  address   = "Reading, Mass.",
  publisher = "Addison-Wesley",
}

@inproceedings{c:83,
  author    = "Clancey, William J.",
  year      = 1983,
  title     = "{Communication, Simulation, and Intelligent Agents: Implications of Personal Intelligent Machines for Medical Education}",
  booktitle = "Proceedings of the Eighth International Joint Conference on Artificial Intelligence {(IJCAI-83)}",
  pages     = "556-560",
  address   = "Menlo Park, Calif.",
  publisher = "{IJCAI Organization}",
}

@inproceedings{c:84,
  author    = "Clancey, William J.",
  year      = 1984,
  title     = "{Classification Problem Solving}",
  booktitle = "Proceedings of the Fourth National Conference on Artificial Intelligence",
  pages     = "45-54",
  address   = "Menlo Park, Calif.",
  publisher = "AAAI Press",
}

@article{r:80,
  author    = {Robinson, Arthur L.},
  title     = {New Ways to Make Microcircuits Smaller},
  journal   = {Science},
  volume    = {208},
  number    = {4447},
  pages     = {1019--1022},
  year      = {1980},
  doi       = {10.1126/science.208.4447.1019},
  publisher = {American Association for the Advancement of Science},
  issn      = {0036-8075},
  url       = {https://science.sciencemag.org/content/208/4447/1019},
  eprint    = {https://science.sciencemag.org/content/208/4447/1019.full.pdf},
}

@article{r:80x,
  author  = "Robinson, Arthur L.",
  year    = 1980,
  title   = "{New Ways to Make Microcircuits Smaller---Duplicate Entry}",
  journal = "Science",
  volume  = 208,
  pages   = "1019-1026",
}

@article{hcr:83,
  author   = {Diane Warner Hasling and William J. Clancey and Glenn Rennels},
  title    = {Strategic explanations for a diagnostic consultation system},
  journal  = {International Journal of Man-Machine Studies},
  volume   = {20},
  number   = {1},
  pages    = {3-19},
  year     = {1984},
  issn     = {0020-7373},
  doi      = {https://doi.org/10.1016/S0020-7373(84)80003-6},
  url      = {https://www.sciencedirect.com/science/article/pii/S0020737384800036},
  abstract = {This article examines the problem of automatic explanation of reasoning, especially as it relates to expert systems. By explanation we mean the ability of a program to discuss what it is doing in some understandable way. We first present a general framework in which to view explanation and review some of the research done in this area. We then focus on the explanation system for NEOMYCIN, a medical consultation program. A consultation program interactively helps a user to solve a problem. Our goal is to have NEOMYCIN explain its problem-solving strategies. An explanation of strategy describes the plan the program is using to reach a solution. Such an explanation is usually concrete, referring to aspects of the current problem situation. Abstract explanations articulate a general principle, which can be applied in different situations; such explanations are useful in teaching and in explaining by analogy. We describe the aspects of NEOMYCIN that make abstract strategic explanations possible---the representation of strategic knowledge explicitly and separately from domain knowledge---and demonstrate how this representation can be used to generate explanations.}
}

@article{hcrt:83,
  author  = "Hasling, Diane Warner and Clancey, William J. and Rennels, Glenn R. and Test, Thomas",
  year    = 1983,
  title   = "{Strategic Explanations in Consultation---Duplicate}",
  journal = "The International Journal of Man-Machine Studies",
  volume  = 20,
  number  = 1,
  pages   = "3-19",
}

@techreport{r:86,
  author      = "Rice, James",
  year        = 1986,
  title       = "{Poligon: A System for Parallel Problem Solving}",
  type        = "Technical Report",
  number      = "KSL-86-19",
  institution = "Dept.\ of Computer Science, Stanford Univ.",
}

@phdthesis{c:79,
  author  = "Clancey, William J.",
  year    = 1979,
  title   = "{Transfer of Rule-Based Expertise through a Tutorial Dialogue}",
  type    = "{Ph.D.} diss.",
  school  = "Dept.\ of Computer Science, Stanford Univ.",
  address = "Stanford, Calif.",
}

@unpublished{c:21,
  author = "Clancey, William J.",
  title  = "{The Engineering of Qualitative Models}",
  year   = 2021,
  note   = "Forthcoming",
}

@misc{c:22,
  title         = {Attention Is All You Need},
  author        = {Ashish Vaswani and Noam Shazeer and Niki Parmar and Jakob Uszkoreit and Llion Jones and Aidan N. Gomez and Lukasz Kaiser and Illia Polosukhin},
  year          = {2017},
  eprint        = {1706.03762},
  archivePrefix = {arXiv},
  primaryClass  = {cs.CL}
}

@misc{c:23,
  title        = "Pluto: The 'Other' Red Planet",
  author       = "{NASA}",
  howpublished = "\url{https://www.nasa.gov/nh/pluto-the-other-red-planet}",
  year         = 2015,
  note         = "Accessed: 2018-12-06"
}
315
aaai2026.sty
Normal file
@@ -0,0 +1,315 @@
\NeedsTeXFormat{LaTeX2e}%
\ProvidesPackage{aaai2026}[2026/04/29 AAAI 2026 Submission format]%
\def\year{2026}%
\typeout{Conference Style for AAAI for LaTeX 2e -- version for submission}%
%
\def\copyright@on{T}
\def\showauthors@on{T}
\def\nocopyright{\gdef\copyright@on{}} % Copyright notice is required for camera-ready only.
\DeclareOption{submission}{%
\gdef\copyright@on{}%
\gdef\showauthors@on{}%
\long\gdef\pdfinfo #1{\relax}%
}%
\DeclareOption{draft}{%
\gdef\copyright@on{}%
}%
\ProcessOptions\relax%
% WARNING: IF YOU ARE USING THIS STYLE SHEET FOR AN AAAI PUBLICATION, YOU
% MAY NOT MODIFY IT FOR ANY REASON. MODIFICATIONS (IN YOUR SOURCE
% OR IN THIS STYLE SHEET) WILL RESULT IN REJECTION OF YOUR PAPER.
%
% WARNING: This style is NOT guaranteed to work. It is provided in the
% hope that it might make the preparation of papers easier, but this style
% file is provided "as is" without warranty of any kind, either express or
% implied, including but not limited to the implied warranties of
% merchantability, fitness for a particular purpose, or noninfringement.
% You use this style file at your own risk. Standard disclaimers apply.
% There are undoubtedly bugs in this style. If you would like to submit
% bug fixes, improvements, etc. please let us know. Please use the contact form
% at www.aaai.org.
%
% Do not use this file unless you are an experienced LaTeX user.
%
% PHYSICAL PAGE LAYOUT
\setlength\topmargin{-0.25in} \setlength\oddsidemargin{-0.25in}
\setlength\textheight{9.0in} \setlength\textwidth{7.0in}
\setlength\columnsep{0.375in} \newlength\titlebox \setlength\titlebox{2.25in}
\setlength\headheight{0pt} \setlength\headsep{0pt}
%\setlength\footheight{0pt} \setlength\footskip{0pt}
\thispagestyle{empty} \pagestyle{empty}
\flushbottom \twocolumn \sloppy
% We're never going to need a table of contents, so just flush it to
% save space --- suggested by drstrip@sandia-2
\def\addcontentsline#1#2#3{}
% gf: PRINT COPYRIGHT NOTICE
\def\copyright@year{\number\year}
\def\copyright@text{Copyright \copyright\space \copyright@year,
Association for the Advancement of Artificial Intelligence (www.aaai.org).
All rights reserved.}
\def\copyrighttext#1{\gdef\copyright@on{T}\gdef\copyright@text{#1}}
\def\copyrightyear#1{\gdef\copyright@on{T}\gdef\copyright@year{#1}}
% gf: End changes for copyright notice (used in \maketitle, below)
% Title stuff, taken from deproc.
%
\def\maketitle{%
\par%
\begingroup % to make the footnote style local to the title
\def\thefootnote{\fnsymbol{footnote}}
\twocolumn[\@maketitle] \@thanks%
\endgroup%
% Insert copyright slug unless turned off
\if T\copyright@on\insert\footins{\noindent\footnotesize\copyright@text}\fi%
%
\setcounter{footnote}{0}%
\let\maketitle\relax%
\let\@maketitle\relax%
\gdef\@thanks{}%
\gdef\@author{}%
\gdef\@title{}%
\let\thanks\relax%
}%
\long\gdef\affiliations #1{ \def \affiliations_{\if T\showauthors@on#1\fi}}%
%
\def\@maketitle{%
\def\theauthors{\if T\showauthors@on\@author\else Anonymous submission\fi}
\newcounter{eqfn}\setcounter{eqfn}{0}%
\newsavebox{\titlearea}
\sbox{\titlearea}{
\let\footnote\relax\let\thanks\relax%
\setcounter{footnote}{0}%
\def\equalcontrib{%
\ifnum\value{eqfn}=0%
\footnote{These authors contributed equally.}%
\setcounter{eqfn}{\value{footnote}}%
\else%
\footnotemark[\value{eqfn}]%
\fi%
}%
\vbox{%
\hsize\textwidth%
\linewidth\hsize%
\vskip 0.625in minus 0.125in%
\centering%
{\LARGE\bf \@title \par}%
\vskip 0.1in plus 0.5fil minus 0.05in%
{\Large{\textbf{\theauthors\ifhmode\\\fi}}}%
\vskip .2em plus 0.25fil%
{\normalsize \affiliations_\ifhmode\\\fi}%
\vskip 1em plus 2fil%
}%
}%
%
\newlength\actualheight%
\settoheight{\actualheight}{\usebox{\titlearea}}%
\ifdim\actualheight>\titlebox%
\setlength{\titlebox}{\actualheight}%
\fi%
%
\vbox to \titlebox {%
\let\footnote\thanks\relax%
\setcounter{footnote}{0}%
\def\equalcontrib{%
\ifnum\value{eqfn}=0%
\footnote{These authors contributed equally.}%
\setcounter{eqfn}{\value{footnote}}%
\else%
\footnotemark[\value{eqfn}]%
\fi%
}%
\hsize\textwidth%
\linewidth\hsize%
\vskip 0.625in minus 0.125in%
\centering%
{\LARGE\bf \@title \par}%
\vskip 0.1in plus 0.5fil minus 0.05in%
{\Large{\textbf{\theauthors\ifhmode\\\fi}}}%
\vskip .2em plus 0.25fil%
{\normalsize \affiliations_\ifhmode\\\fi}%
\vskip 1em plus 2fil%
}%
}%
%
\renewenvironment{abstract}{%
\centerline{\bf Abstract}%
\vspace{0.5ex}%
\setlength{\leftmargini}{10pt}%
\begin{quote}%
\small%
}{%
\par%
\end{quote}%
\vskip 1ex%
}%
\newenvironment{links}{%
\newcommand{\link}[2]{\par\textbf{##1} --- \url{##2}}%
\setlength{\hangindent}{10pt}%
\setlength{\parskip}{2pt}%
\begin{flushleft}%
}{%
\end{flushleft}%
\vskip 1ex%
}%
% jsp added:
\def\pubnote#1{
\thispagestyle{myheadings}%
\pagestyle{myheadings}%
\markboth{#1}{#1}%
\setlength\headheight{10pt}%
\setlength\headsep{10pt}%
}%
%
% SECTIONS with less space
\def\section{\@startsection {section}{1}{\z@}{-2.0ex plus
-0.5ex minus -.2ex}{3pt plus 2pt minus 1pt}{\Large\bf\centering}}
\def\subsection{\@startsection{subsection}{2}{\z@}{-2.0ex plus
-0.5ex minus -.2ex}{3pt plus 2pt minus 1pt}{\large\bf\raggedright}}
\def\subsubsection{\@startsection{subparagraph}{3}{\z@}{-6pt plus
%%% DIEGO changed: 29/11/2009
%% 2pt minus 1pt}{-1em}{\normalsize\bf}}
-2pt minus -1pt}{-1em}{\normalsize\bf}}
%%% END changed
\renewcommand\paragraph{\@startsection{paragraph}{4}{\z@}{-6pt plus -2pt minus -1pt}{-1em}{\normalsize\bf}}%
\setcounter{secnumdepth}{0}
% add period to section (but not subsection) numbers, reduce space after
%\renewcommand{\thesection}
% {\arabic{section}.\hskip-0.6em}
%\renewcommand{\thesubsection}
% {\arabic{section}.\arabic{subsection}\hskip-0.6em}
% FOOTNOTES
\footnotesep 6.65pt %
\skip\footins 9pt plus 4pt minus 2pt
\def\footnoterule{\kern-3pt \hrule width 5pc \kern 2.6pt }
\setcounter{footnote}{0}
% LISTS AND PARAGRAPHS
\parindent 10pt
\topsep 4pt plus 1pt minus 2pt
\partopsep 1pt plus 0.5pt minus 0.5pt
\itemsep 0.5pt plus 1pt minus 0.5pt
\parsep 2pt plus 1pt minus 0.5pt
\leftmargin 10pt \leftmargini 13pt \leftmarginii 10pt \leftmarginiii 5pt \leftmarginiv 5pt \leftmarginv 5pt \leftmarginvi 5pt
\labelwidth\leftmargini\advance\labelwidth-\labelsep \labelsep 5pt
\def\@listi{\leftmargin\leftmargini}
\def\@listii{\leftmargin\leftmarginii
\labelwidth\leftmarginii\advance\labelwidth-\labelsep
\topsep 2pt plus 1pt minus 0.5pt
\parsep 1pt plus 0.5pt minus 0.5pt
\itemsep \parsep}
\def\@listiii{\leftmargin\leftmarginiii
\labelwidth\leftmarginiii\advance\labelwidth-\labelsep
\topsep 1pt plus 0.5pt minus 0.5pt
\parsep \z@
\partopsep 0.5pt plus 0pt minus 0.5pt
\itemsep \topsep}
\def\@listiv{\leftmargin\leftmarginiv
\labelwidth\leftmarginiv\advance\labelwidth-\labelsep}
\def\@listv{\leftmargin\leftmarginv
\labelwidth\leftmarginv\advance\labelwidth-\labelsep}
\def\@listvi{\leftmargin\leftmarginvi
\labelwidth\leftmarginvi\advance\labelwidth-\labelsep}
\abovedisplayskip 7pt plus2pt minus5pt%
\belowdisplayskip \abovedisplayskip
\abovedisplayshortskip 0pt plus3pt%
\belowdisplayshortskip 4pt plus3pt minus3pt%
% Less leading in most fonts (due to the narrow columns)
% The choices were between 1-pt and 1.5-pt leading
\def\normalsize{\@setfontsize\normalsize\@xpt{11}} % 10 point on 11
\def\small{\@setfontsize\small\@ixpt{10}} % 9 point on 10
\def\footnotesize{\@setfontsize\footnotesize\@ixpt{10}} % 9 point on 10
\def\scriptsize{\@setfontsize\scriptsize\@viipt{10}} % 7 point on 8
\def\tiny{\@setfontsize\tiny\@vipt{7}} % 6 point on 7
\def\large{\@setfontsize\large\@xipt{12}} % 11 point on 12
\def\Large{\@setfontsize\Large\@xiipt{14}} % 12 point on 14
\def\LARGE{\@setfontsize\LARGE\@xivpt{16}} % 14 point on 16
\def\huge{\@setfontsize\huge\@xviipt{20}} % 17 point on 20
\def\Huge{\@setfontsize\Huge\@xxpt{23}} % 20 point on 23

\AtBeginDocument{%
\@ifpackageloaded{natbib}%
{%
% When natbib is in use, set the proper style and fix a few things
\let\cite\citep
\let\shortcite\citeyearpar
\setcitestyle{aysep={}}
\setlength\bibhang{0pt}
\bibliographystyle{aaai2026}
}{}%
\@ifpackageloaded{hyperref}%
{%
\PackageError{aaai}{You must not use hyperref in AAAI papers.}{You (or one of the packages you imported) are importing the hyperref package, which is forbidden in AAAI papers. You must remove it from the paper to proceed.}
}{}%
\@ifpackageloaded{bbm}%
{%
\PackageError{aaai}{You must not use the bbm package in AAAI papers because it introduces Type 3 fonts, which are forbidden.}{See https://tex.stackexchange.com/questions/479160/a-replacement-to-mathbbm1-with-type-1-fonts for possible alternatives.}
}{}%
\@ifpackageloaded{authblk}%
{%
\PackageError{aaai}{Package authblk is forbidden.}{Package authblk is forbidden. You must find an alternative.}
}{}%
\@ifpackageloaded{balance}%
{%
\PackageError{aaai}{Package balance is forbidden.}{Package balance is forbidden. You must find an alternative.}
}{}%
\@ifpackageloaded{CJK}%
{%
\PackageError{aaai}{Package CJK is forbidden.}{Package CJK is forbidden. You must find an alternative.}
}{}%
\@ifpackageloaded{flushend}%
{%
\PackageError{aaai}{Package flushend is forbidden.}{Package flushend is forbidden. You must find an alternative.}
}{}%
\@ifpackageloaded{fontenc}%
{%
\PackageError{aaai}{Package fontenc is forbidden.}{Package fontenc is forbidden. You must find an alternative.}
}{}%
\@ifpackageloaded{fullpage}%
{%
\PackageError{aaai}{Package fullpage is forbidden.}{Package fullpage is forbidden. You must find an alternative.}
}{}%
\@ifpackageloaded{geometry}%
{%
\PackageError{aaai}{Package geometry is forbidden.}{Package geometry is forbidden. You must find an alternative.}
}{}%
\@ifpackageloaded{grffile}%
{%
\PackageError{aaai}{Package grffile is forbidden.}{Package grffile is forbidden. You must find an alternative.}
}{}%
\@ifpackageloaded{navigator}%
{%
\PackageError{aaai}{Package navigator is forbidden.}{Package navigator is forbidden. You must find an alternative.}
}{}%
\@ifpackageloaded{savetrees}%
{%
\PackageError{aaai}{Package savetrees is forbidden.}{Package savetrees is forbidden. You must find an alternative.}
}{}%
\@ifpackageloaded{setspace}%
{%
\PackageError{aaai}{Package setspace is forbidden.}{Package setspace is forbidden. You must find an alternative.}
}{}%
\@ifpackageloaded{stfloats}%
{%
\PackageError{aaai}{Package stfloats is forbidden.}{Package stfloats is forbidden. You must find an alternative.}
}{}%
\@ifpackageloaded{tabu}%
{%
\PackageError{aaai}{Package tabu is forbidden.}{Package tabu is forbidden. You must find an alternative.}
}{}%
\@ifpackageloaded{titlesec}%
{%
\PackageError{aaai}{Package titlesec is forbidden.}{Package titlesec is forbidden. You must find an alternative.}
}{}%
\@ifpackageloaded{tocbibind}%
{%
\PackageError{aaai}{Package tocbibind is forbidden.}{Package tocbibind is forbidden. You must find an alternative.}
}{}%
\@ifpackageloaded{ulem}%
{%
\PackageError{aaai}{Package ulem is forbidden.}{Package ulem is forbidden. You must find an alternative.}
}{}%
\@ifpackageloaded{wrapfig}%
{%
\PackageError{aaai}{Package wrapfig is forbidden.}{Package wrapfig is forbidden. You must find an alternative.}
}{}%
}

\let\endthebibliography=\endlist
817
anonymous-submission-latex-2026.tex
Normal file
@@ -0,0 +1,817 @@
%File: anonymous-submission-latex-2026.tex
\documentclass[letterpaper]{article} % DO NOT CHANGE THIS
\usepackage[submission]{aaai2026}  % DO NOT CHANGE THIS
\usepackage{times}  % DO NOT CHANGE THIS
\usepackage{helvet}  % DO NOT CHANGE THIS
\usepackage{courier}  % DO NOT CHANGE THIS
\usepackage[hyphens]{url}  % DO NOT CHANGE THIS
\usepackage{graphicx} % DO NOT CHANGE THIS
\urlstyle{rm} % DO NOT CHANGE THIS
\def\UrlFont{\rm}  % DO NOT CHANGE THIS
\usepackage{natbib}  % DO NOT CHANGE THIS AND DO NOT ADD ANY OPTIONS TO IT
\usepackage{caption} % DO NOT CHANGE THIS AND DO NOT ADD ANY OPTIONS TO IT
\frenchspacing  % DO NOT CHANGE THIS
\setlength{\pdfpagewidth}{8.5in}  % DO NOT CHANGE THIS
\setlength{\pdfpageheight}{11in}  % DO NOT CHANGE THIS
%
% These are recommended to typeset algorithms but not required. See the subsubsection on algorithms. Remove them if you don't have algorithms in your paper.
\usepackage{algorithm}
\usepackage{algorithmic}

%
% These are recommended to typeset listings but not required. See the subsubsection on listing. Remove this block if you don't have listings in your paper.
\usepackage{newfloat}
\usepackage{listings}
\DeclareCaptionStyle{ruled}{labelfont=normalfont,labelsep=colon,strut=off} % DO NOT CHANGE THIS
\lstset{%
  basicstyle={\footnotesize\ttfamily},% footnotesize acceptable for monospace
  numbers=left,numberstyle=\footnotesize,xleftmargin=2em,% show line numbers, remove this entire line if you don't want the numbers.
  aboveskip=0pt,belowskip=0pt,%
  showstringspaces=false,tabsize=2,breaklines=true}
\floatstyle{ruled}
\newfloat{listing}{tb}{lst}{}
\floatname{listing}{Listing}
%
% Keep the \pdfinfo as shown here. There's no need
% for you to add the /Title and /Author tags.
\pdfinfo{
/TemplateVersion (2026.1)
}
% DISALLOWED PACKAGES
% \usepackage{authblk} -- This package is specifically forbidden
% \usepackage{balance} -- This package is specifically forbidden
% \usepackage{color} (if used in text)
% \usepackage{CJK} -- This package is specifically forbidden
% \usepackage{float} -- This package is specifically forbidden
% \usepackage{flushend} -- This package is specifically forbidden
% \usepackage{fontenc} -- This package is specifically forbidden
% \usepackage{fullpage} -- This package is specifically forbidden
% \usepackage{geometry} -- This package is specifically forbidden
% \usepackage{grffile} -- This package is specifically forbidden
% \usepackage{hyperref} -- This package is specifically forbidden
% \usepackage{navigator} -- This package is specifically forbidden
% (or any other package that embeds links such as navigator or hyperref)
% \indentfirst -- This package is specifically forbidden
% \layout -- This package is specifically forbidden
% \multicol -- This package is specifically forbidden
% \nameref -- This package is specifically forbidden
% \usepackage{savetrees} -- This package is specifically forbidden
% \usepackage{setspace} -- This package is specifically forbidden
% \usepackage{stfloats} -- This package is specifically forbidden
% \usepackage{tabu} -- This package is specifically forbidden
% \usepackage{titlesec} -- This package is specifically forbidden
% \usepackage{tocbibind} -- This package is specifically forbidden
% \usepackage{ulem} -- This package is specifically forbidden
% \usepackage{wrapfig} -- This package is specifically forbidden
% DISALLOWED COMMANDS
% \nocopyright -- Your paper will not be published if you use this command
% \addtolength -- This command may not be used
% \balance -- This command may not be used
% \baselinestretch -- Your paper will not be published if you use this command
% \clearpage -- No page breaks of any kind may be used for the final version of your paper
% \columnsep -- This command may not be used
% \newpage -- No page breaks of any kind may be used for the final version of your paper
% \pagebreak -- No page breaks of any kind may be used for the final version of your paper
% \pagestyle -- This command may not be used
% \tiny -- This is not an acceptable font size.
% \vspace{- -- No negative value may be used in proximity of a caption, figure, table, section, subsection, subsubsection, or reference
% \vskip{- -- No negative value may be used to alter spacing above or below a caption, figure, table, section, subsection, subsubsection, or reference

\setcounter{secnumdepth}{0} %May be changed to 1 or 2 if section numbers are desired.
% The file aaai2026.sty is the style file for AAAI Press
% proceedings, working notes, and technical reports.
%

% Title

% Your title must be in mixed case, not sentence case.
% That means all verbs (including short verbs like be, is, using, and go),
% nouns, adverbs, adjectives should be capitalized, including both words in hyphenated terms, while
% articles, conjunctions, and prepositions are lower case unless they
% directly follow a colon or long dash
\title{AAAI Press Anonymous Submission\\Instructions for Authors Using \LaTeX{}}
\author{
    %Authors
    % All authors must be in the same font size and format.
    Written by AAAI Press Staff\textsuperscript{\rm 1}\thanks{With help from the AAAI Publications Committee.}\\
    AAAI Style Contributions by Pater Patel Schneider,
    Sunil Issar,\\
    J. Scott Penberthy,
    George Ferguson,
    Hans Guesgen,
    Francisco Cruz\equalcontrib,
    Marc Pujol-Gonzalez\equalcontrib
}
\affiliations{
    %Affiliations
    \textsuperscript{\rm 1}Association for the Advancement of Artificial Intelligence\\
    % If you have multiple authors and multiple affiliations
    % use superscripts in text and roman font to identify them.
    % For example,

    % Sunil Issar\textsuperscript{\rm 2},
    % J. Scott Penberthy\textsuperscript{\rm 3},
    % George Ferguson\textsuperscript{\rm 4},
    % Hans Guesgen\textsuperscript{\rm 5}
    % Note that the comma should be placed after the superscript

    1101 Pennsylvania Ave, NW Suite 300\\
    Washington, DC 20004 USA\\
    % email address must be in roman text type, not monospace or sans serif
    proceedings-questions@aaai.org
%
% See more examples next
}

%Example, Single Author, ->> remove \iffalse,\fi and place them surrounding AAAI title to use it
\iffalse
\title{My Publication Title --- Single Author}
\author {
    Author Name
}
\affiliations{
    Affiliation\\
    Affiliation Line 2\\
    name@example.com
}
\fi

\iffalse
%Example, Multiple Authors, ->> remove \iffalse,\fi and place them surrounding AAAI title to use it
\title{My Publication Title --- Multiple Authors}
\author {
    % Authors
    First Author Name\textsuperscript{\rm 1},
    Second Author Name\textsuperscript{\rm 2},
    Third Author Name\textsuperscript{\rm 1}
}
\affiliations {
    % Affiliations
    \textsuperscript{\rm 1}Affiliation 1\\
    \textsuperscript{\rm 2}Affiliation 2\\
    firstAuthor@affiliation1.com, secondAuthor@affiliation2.com, thirdAuthor@affiliation1.com
}
\fi


% REMOVE THIS: bibentry
% This is only needed to show inline citations in the guidelines document. You should not need it and can safely delete it.
\usepackage{bibentry}
% END REMOVE bibentry
\begin{document}

\maketitle

\begin{abstract}
AAAI creates proceedings, working notes, and technical reports directly from electronic source furnished by the authors. To ensure that all papers in the publication have a uniform appearance, authors must adhere to the following instructions.
\end{abstract}

% Uncomment the following to link to your code, datasets, an extended version or similar.
% You must keep this block between (not within) the abstract and the main body of the paper.
% \begin{links}
%     \link{Code}{https://aaai.org/example/code}
%     \link{Datasets}{https://aaai.org/example/datasets}
%     \link{Extended version}{https://aaai.org/example/extended-version}
% \end{links}

\section{Preparing an Anonymous Submission}

This document details the formatting requirements for anonymous submissions. The requirements are the same as for camera-ready papers but with a few notable differences:

\begin{itemize}
    \item Anonymous submissions must not include the author names and affiliations. Write ``Anonymous Submission'' as the ``sole author'' and leave the affiliations empty.
    \item The PDF document's metadata should be cleared with a metadata-cleaning tool before submitting it. This is to prevent leaked information from revealing your identity.
    \item References must be anonymized whenever the reader can infer that they are to the authors' previous work.
    \item AAAI's copyright notice should not be included as a footer in the first page.
    \item Only the PDF version is required at this stage. No source versions will be requested, nor any copyright transfer form.
\end{itemize}

You can remove the copyright notice and ensure that your names aren't shown by including the \texttt{submission} option when loading the \texttt{aaai2026} package:

\begin{quote}\begin{scriptsize}\begin{verbatim}
\documentclass[letterpaper]{article}
\usepackage[submission]{aaai2026}
\end{verbatim}\end{scriptsize}\end{quote}

The remainder of this document contains the original camera-ready instructions. Any contradiction of the above points ought to be ignored while preparing anonymous submissions.
\section{Camera-Ready Guidelines}

Congratulations on having a paper selected for inclusion in an AAAI Press proceedings or technical report! This document details the requirements necessary to get your accepted paper published using PDF\LaTeX{}. If you are using Microsoft Word, instructions are provided in a different document. AAAI Press does not support any other formatting software.

The instructions herein are provided as a general guide for experienced \LaTeX{} users. If you do not know how to use \LaTeX{}, please obtain assistance locally. AAAI cannot provide you with support and the accompanying style files are \textbf{not} guaranteed to work. If the results you obtain are not in accordance with the specifications you received, you must correct your source file to achieve the correct result.

These instructions are generic. Consequently, they do not include specific dates, page charges, and so forth. Please consult your specific written conference instructions for details regarding your submission. Please review the entire document for specific instructions that might apply to your particular situation. All authors must comply with the following:

\begin{itemize}
    \item You must use the 2026 AAAI Press \LaTeX{} style file and the aaai2026.bst bibliography style files, which are located in the 2026 AAAI Author Kit (aaai2026.sty, aaai2026.bst).
    \item You must complete, sign, and return by the deadline the AAAI copyright form (unless directed by AAAI Press to use the AAAI Distribution License instead).
    \item You must read and format your paper source and PDF according to the formatting instructions for authors.
    \item You must submit your electronic files and abstract using our electronic submission form \textbf{on time.}
    \item You must pay any required page or formatting charges to AAAI Press so that they are received by the deadline.
    \item You must check your paper before submitting it, ensuring that it compiles without error, and complies with the guidelines found in the AAAI Author Kit.
\end{itemize}

\section{Copyright}
All papers submitted for publication by AAAI Press must be accompanied by a valid signed copyright form. They must also contain the AAAI copyright notice at the bottom of the first page of the paper. There are no exceptions to these requirements. If you fail to provide us with a signed copyright form or disable the copyright notice, we will be unable to publish your paper. There are \textbf{no exceptions} to this policy. You will find a PDF version of the AAAI copyright form in the AAAI AuthorKit. Please see the specific instructions for your conference for submission details.

\section{Formatting Requirements in Brief}
We need source and PDF files that can be used in a variety of ways and can be output on a variety of devices. The design and appearance of the paper is strictly governed by the aaai style file (aaai2026.sty).
\textbf{You must not make any changes to the aaai style file, nor use any commands, packages, style files, or macros within your own paper that alter that design, including, but not limited to spacing, floats, margins, fonts, font size, and appearance.} AAAI imposes requirements on your source and PDF files that must be followed. Most of these requirements are based on our efforts to standardize conference manuscript properties and layout. All papers submitted to AAAI for publication will be recompiled for standardization purposes. Consequently, every paper submission must comply with the following requirements:

\begin{itemize}
    \item Your .tex file must compile in PDF\LaTeX{} --- (you may not include .ps or .eps figure files.)
    \item All fonts must be embedded in the PDF file --- including your figures.
    \item Modifications to the style file, whether directly or via commands in your document, may not ever be made, most especially when made in an effort to avoid extra page charges or make your paper fit in a specific number of pages.
    \item No type 3 fonts may be used (even in illustrations).
    \item You may not alter the spacing above and below captions, figures, headings, and subheadings.
    \item You may not alter the font sizes of text elements, footnotes, heading elements, captions, or title information (for references and mathematics, please see the limited exceptions provided herein).
    \item You may not alter the line spacing of text.
    \item Your title must follow Title Case capitalization rules (not sentence case).
    \item \LaTeX{} documents must use the Times or Nimbus font package (you may not use Computer Modern for the text of your paper).
    \item No \LaTeX{} 209 documents may be used or submitted.
    \item Your source must not require use of fonts for non-Roman alphabets within the text itself. If your paper includes symbols in other languages (such as, but not limited to, Arabic, Chinese, Hebrew, Japanese, Thai, Russian and other Cyrillic languages), you must restrict their use to bit-mapped figures. Fonts that require non-English language support (CID and Identity-H) must be converted to outlines or 300 dpi bitmap or removed from the document (even if they are in a graphics file embedded in the document).
    \item Two-column format in AAAI style is required for all papers.
    \item The paper size for final submission must be US letter without exception.
    \item The source file must exactly match the PDF.
    \item The document margins may not be exceeded (no overfull boxes).
    \item The number of pages and the file size must be as specified for your event.
    \item No document may be password protected.
    \item Neither the PDFs nor the source may contain any embedded links or bookmarks (no hyperref or navigator packages).
    \item Your source and PDF must not have any page numbers, footers, or headers (no pagestyle commands).
    \item Your PDF must be compatible with Acrobat 5 or higher.
    \item Your \LaTeX{} source file (excluding references) must consist of a \textbf{single} file (use of the ``input'' command is not allowed).
    \item Your graphics must be sized appropriately outside of \LaTeX{} (do not use the ``clip'' or ``trim'' commands).
\end{itemize}

If you do not follow these requirements, your paper will be returned to you to correct the deficiencies.
\section{What Files to Submit}
You must submit the following items to ensure that your paper is published:
\begin{itemize}
    \item A fully-compliant PDF file.
    \item Your \LaTeX{} source file submitted as a \textbf{single} .tex file (do not use the ``input'' command to include sections of your paper --- every section must be in the single source file). (The only allowable exception is the .bib file, which should be included separately.)
    \item The bibliography (.bib) file(s).
    \item Your source must compile on our system, which includes only standard \LaTeX{} 2020 TeXLive support files.
    \item Only the graphics files used in compiling the paper.
    \item The \LaTeX{}-generated files (e.g. .aux, .bbl file, PDF, etc.).
\end{itemize}

Your \LaTeX{} source will be reviewed and recompiled on our system (if it does not compile, your paper will be returned to you). \textbf{Do not submit your source in multiple text files.} Your single \LaTeX{} source file must include all your text, your bibliography (formatted using aaai2026.bst), and any custom macros.

Your files should work without any supporting files (other than the program itself) on any computer with a standard \LaTeX{} distribution.

\textbf{Do not send files that are not actually used in the paper.} Avoid including any files not needed for compiling your paper, including, for example, this instructions file, unused graphics files, style files, additional material sent for the purpose of the paper review, intermediate build files and so forth.

\textbf{Obsolete style files.} The commands for some common packages (such as some used for algorithms) may have changed. Please be certain that you are not compiling your paper using old or obsolete style files.

\textbf{Final Archive.} Place your source files in a single archive which should be compressed using .zip. The final file size may not exceed 10 MB.
Name your source file with the last (family) name of the first author, even if that is not you.
\section{Using \LaTeX{} to Format Your Paper}

The latest version of the AAAI style file is available on AAAI's website. Download this file and place it in the \TeX\ search path. Placing it in the same directory as the paper should also work. You must download the latest version of the complete AAAI Author Kit so that you will have the latest instruction set and style file.

\subsection{Document Preamble}

In the \LaTeX{} source for your paper, you \textbf{must} place the following lines as shown in the example in this subsection. This command set-up is for three authors. Add or subtract author and address lines as necessary, and uncomment the portions that apply to you. In most instances, this is all you need to do to format your paper in the Times font. The helvet package will cause Helvetica to be used for sans serif. These files are part of the PSNFSS2e package, which is freely available from many Internet sites (and is often part of a standard installation).

Leave the setcounter for section number depth commented out and set at 0 unless you want to add section numbers to your paper. If you do add section numbers, you must uncomment this line and change the number to 1 (for section numbers) or 2 (for section and subsection numbers). The style file will not work properly with numbering of subsubsections, so do not use a number higher than 2.

\subsubsection{The Following Must Appear in Your Preamble}
\begin{quote}
\begin{scriptsize}\begin{verbatim}
\documentclass[letterpaper]{article}
% DO NOT CHANGE THIS
\usepackage[submission]{aaai2026} % DO NOT CHANGE THIS
\usepackage{times}                % DO NOT CHANGE THIS
\usepackage{helvet}               % DO NOT CHANGE THIS
\usepackage{courier}              % DO NOT CHANGE THIS
\usepackage[hyphens]{url}         % DO NOT CHANGE THIS
\usepackage{graphicx}             % DO NOT CHANGE THIS
\urlstyle{rm}                     % DO NOT CHANGE THIS
\def\UrlFont{\rm}                 % DO NOT CHANGE THIS
\usepackage{natbib}               % DO NOT CHANGE THIS
\usepackage{caption}              % DO NOT CHANGE THIS
\frenchspacing                    % DO NOT CHANGE THIS
\setlength{\pdfpagewidth}{8.5in}  % DO NOT CHANGE THIS
\setlength{\pdfpageheight}{11in}  % DO NOT CHANGE THIS
%
% Keep the \pdfinfo as shown here. There's no need
% for you to add the /Title and /Author tags.
\pdfinfo{
/TemplateVersion (2026.1)
}
\end{verbatim}\end{scriptsize}
\end{quote}
\subsection{Preparing Your Paper}

After the preamble above, you should prepare your paper as follows:
\begin{quote}
\begin{scriptsize}\begin{verbatim}
\begin{document}
\maketitle
\begin{abstract}
%...
\end{abstract}\end{verbatim}\end{scriptsize}
\end{quote}
\noindent If you want to add links to the paper's code, dataset(s), extended version, or similar, this is the place to add them, within a \emph{links} environment:
\begin{quote}%
\begin{scriptsize}\begin{verbatim}
\begin{links}
    \link{Code}{https://aaai.org/example/guidelines}
    \link{Datasets}{https://aaai.org/example/datasets}
    \link{Extended version}{https://aaai.org/example}
\end{links}\end{verbatim}\end{scriptsize}
\end{quote}
\noindent Make sure that you do not de-anonymize yourself with these links.

\noindent You should then continue with the body of your paper. Your paper must conclude with the references, which should be inserted as follows:
\begin{quote}
\begin{scriptsize}\begin{verbatim}
% References and End of Paper
% These lines must be placed at the end of your paper
\bibliography{Bibliography-File}
\end{document}
\end{verbatim}\end{scriptsize}
\end{quote}
\subsection{Commands and Packages That May Not Be Used}
\begin{table*}[t]
\centering
\begin{tabular}{l|l|l|l}
\textbackslash abovecaption & \textbackslash abovedisplay & \textbackslash evensidemargin & \textbackslash oddsidemargin \\
\textbackslash addtolength & \textbackslash baselinestretch & \textbackslash belowcaption & \textbackslash belowdisplay \\
\textbackslash break & \textbackslash clearpage & \textbackslash clip & \textbackslash columnsep \\
\textbackslash float & \textbackslash input & \textbackslash linespread & \textbackslash newpage \\
\textbackslash pagebreak & \textbackslash renewcommand & \textbackslash setlength & \textbackslash textheight \\
\textbackslash tiny & \textbackslash topmargin & \textbackslash trim & \textbackslash vskip\{- \\
\textbackslash vspace\{-
\end{tabular}
\caption{Commands that must not be used.}
\label{table1}
\end{table*}

\begin{table}[t]
\centering
\begin{tabular}{l|l|l|l}
authblk & babel & cjk & dvips \\
epsf & epsfig & euler & float \\
fullpage & geometry & graphics & hyperref \\
layout & linespread & lmodern & maltepaper \\
navigator & pdfcomment & pgfplots & psfig \\
pstricks & t1enc & titlesec & tocbind \\
ulem
\end{tabular}
\caption{\LaTeX{} style packages that must not be used.}
\label{table2}
\end{table}

There are a number of packages, commands, scripts, and macros that are incompatible with aaai2026.sty. The common ones are listed in Tables \ref{table1} and \ref{table2}. Generally, if a command, package, script, or macro alters floats, margins, fonts, sizing, line spacing, or the presentation of the references and citations, it is unacceptable. Note that negative vskip and vspace may not be used except in certain rare occurrences, and may never be used around tables, figures, captions, sections, subsections, subsubsections, or references.
\subsection{Page Breaks}
For your final camera-ready copy, you must not use any page break commands. References must flow directly after the text without breaks. Note that some conferences require references to be on a separate page during the review process. AAAI Press, however, does not require this condition for the final paper.

\subsection{Paper Size, Margins, and Column Width}
Papers must be formatted to print in two-column format on 8.5 x 11 inch US letter-sized paper. The margins must be exactly as follows:
\begin{itemize}
\item Top margin: 1.25 inches (first page), .75 inches (others)
\item Left margin: .75 inches
\item Right margin: .75 inches
\item Bottom margin: 1.25 inches
\end{itemize}

The default paper size in most installations of \LaTeX{} is A4. However, because we require that your electronic paper be formatted in US letter size, the preamble we have provided includes commands that alter the default to US letter size. Please note that using any other package to alter page size (such as, but not limited to, the geometry package) will result in your final paper being returned to you for correction.

\subsubsection{Column Width and Margins.}
To ensure maximum readability, your paper must include two columns. Each column should be 3.3 inches wide (slightly more than 3.25 inches), with a .375 inch (.952 cm) gutter of white space between the two columns. The aaai2026.sty file will automatically create these columns for you.
\subsection{Overlength Papers}
If your paper is too long and you resort to formatting tricks to make it fit, it is quite likely that it will be returned to you. The best way to retain readability if the paper is overlength is to cut text, figures, or tables. There are a few acceptable ways to reduce paper size that don't affect readability. First, turn on \textbackslash frenchspacing, which will reduce the space after periods. Next, move all your figures and tables to the top of the page. Consider removing less important portions of a figure. If you use \textbackslash centering instead of \textbackslash begin\{center\} in your figure environment, you can also buy some space, as in the sketch below. For mathematical environments, you may reduce the font size {\bf but not below 6.5 point}.
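
For example (a minimal sketch, mirroring the figure set-up used later in these instructions):
\begin{quote}\begin{scriptsize}\begin{verbatim}
\begin{figure}[t]  % float to the top of the page
\centering         % instead of \begin{center}...\end{center}
\includegraphics[width=0.9\columnwidth]{figure1}
\caption{...}
\end{figure}
\end{verbatim}\end{scriptsize}\end{quote}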

Commands that alter page layout are forbidden. These include \textbackslash columnsep, \textbackslash float, \textbackslash topmargin, \textbackslash topskip, \textbackslash textheight, \textbackslash textwidth, \textbackslash oddsidemargin, and \textbackslash evensidemargin (this list is not exhaustive). If you alter page layout, you will be required to pay the page fee. Other commands that are questionable and may cause your paper to be rejected include \textbackslash parindent and \textbackslash parskip. Commands that alter the space between sections are forbidden. The titlesec package is not allowed. Regardless of the above, if your paper is obviously ``squeezed'' it is not going to be accepted. Options for reducing the length of a paper include reducing the size of your graphics, cutting text, or paying the extra page charge (if it is offered).
\subsection{Type Font and Size}
Your paper must be formatted in Times Roman or Nimbus. We will not accept papers formatted using Computer Modern or Palatino or some other font as the text or heading typeface. Sans serif, when used, should be Helvetica; fixed-width text should be set in Courier. Use Symbol or Lucida or Computer Modern for \textit{mathematics only}.

Do not use type 3 fonts for any portion of your paper, including graphics. Type 3 bitmapped fonts are designed for fixed-resolution printers. Most print at 300 dpi even if the printer resolution is 1200 dpi or higher. They also often cause high-resolution imagesetter devices to crash. Consequently, AAAI will not accept electronic files containing obsolete type 3 fonts. Files containing those fonts (even in graphics) will be rejected. (Authors using blackboard symbols must avoid packages that use type 3 fonts.)

Fortunately, there are effective workarounds that will prevent your file from embedding type 3 bitmapped fonts. The easiest workaround is to use the required times, helvet, and courier packages with \LaTeX{}2e. (Note that papers formatted in this way will still use Computer Modern for the mathematics. To make the math look good, you'll either have to use Symbol or Lucida, or you will need to install type 1 Computer Modern fonts --- for more on these fonts, see the section ``Obtaining Type 1 Computer Modern.'')

If you are unsure if your paper contains type 3 fonts, view the PDF in Acrobat Reader. The Properties/Fonts window will display the font name, font type, and encoding properties of all the fonts in the document. If you are unsure if your graphics contain type 3 fonts (and they are PostScript or encapsulated PostScript documents), create PDF versions of them, and consult the properties window in Acrobat Reader.

The default size for your type must be ten-point with twelve-point leading (line spacing). Start all pages (except the first) directly under the top margin. (See the next section for instructions on formatting the title page.) Indent ten points when beginning a new paragraph, unless the paragraph begins directly below a heading or subheading.

\subsubsection{Obtaining Type 1 Computer Modern for \LaTeX{}.}

If you use Computer Modern for the mathematics in your paper (you cannot use it for the text) you may need to download type 1 Computer Modern fonts. They are available without charge from the American Mathematical Society:
\url{http://www.ams.org/tex/type1-fonts.html}.
\subsubsection{Nonroman Fonts.}
If your paper includes symbols in other languages (such as, but not limited to, Arabic, Chinese, Hebrew, Japanese, Thai, Russian and other Cyrillic languages), you must restrict their use to bit-mapped figures.

\subsection{Title and Authors}
Your title must appear centered over both text columns in sixteen-point bold type (twenty-four point leading). The title must be written in Title Case according to the Chicago Manual of Style rules. The rules are a bit involved, but in general verbs (including short verbs like be, is, using, and go), nouns, adverbs, adjectives, and pronouns should be capitalized (including both words in hyphenated terms), while articles, conjunctions, and prepositions are lower case unless they directly follow a colon or long dash. You can use the online tool \url{https://titlecaseconverter.com/} to double-check the proper capitalization (select the ``Chicago'' style and mark the ``Show explanations'' checkbox).

Authors' names should appear below the title of the paper, centered in twelve-point type (with fifteen-point leading), along with affiliation(s) and complete address(es) (including electronic mail address if available) in nine-point roman type (with twelve-point leading). You should begin the two-column format when you come to the abstract.

\subsubsection{Formatting Author Information.}
Author information has to be set according to the following specification, depending on whether you have one or more than one affiliation. You may not use a table, nor may you employ the authblk.sty package. For one or several authors from the same institution, separate them with commas and write all affiliations directly below (one affiliation per line) using the macros \textbackslash author and \textbackslash affiliations:
\begin{quote}\begin{scriptsize}\begin{verbatim}
\author{
Author 1, ..., Author n\\
}
\affiliations {
Address line\\
... \\
Address line\\
}
\end{verbatim}\end{scriptsize}\end{quote}

\noindent For authors from different institutions, use \textbackslash textsuperscript\{\textbackslash rm x\} to match authors and affiliations. Notice that there should not be any spaces between the author name (or the comma following it) and the superscript.
\begin{quote}\begin{scriptsize}\begin{verbatim}
\author{
AuthorOne\equalcontrib\textsuperscript{\rm 1,\rm 2},
AuthorTwo\equalcontrib\textsuperscript{\rm 2},
AuthorThree\textsuperscript{\rm 3},\\
AuthorFour\textsuperscript{\rm 4},
AuthorFive\textsuperscript{\rm 5}
}
\affiliations {
\textsuperscript{\rm 1}AffiliationOne,\\
\textsuperscript{\rm 2}AffiliationTwo,\\
\textsuperscript{\rm 3}AffiliationThree,\\
\textsuperscript{\rm 4}AffiliationFour,\\
\textsuperscript{\rm 5}AffiliationFive\\
\{email, email\}@affiliation.com,
email@affiliation.com,
email@affiliation.com,
email@affiliation.com
}
\end{verbatim}\end{scriptsize}\end{quote}

You can indicate that some authors contributed equally using the \textbackslash equalcontrib command. This will add a marker after the author names and a footnote on the first page.

Note that you may want to break the author list for better visualization. You can achieve this using a simple line break (\textbackslash\textbackslash).
\subsection{\LaTeX{} Copyright Notice}
The copyright notice automatically appears if you use aaai2026.sty. It has been hardcoded and may not be disabled.

\subsection{Credits}
Any credits to a sponsoring agency should appear in the acknowledgments section, unless the agency requires different placement. If it is necessary to include this information on the front page, use \textbackslash thanks in either the \textbackslash author or \textbackslash title commands. For example:
\begin{quote}
\begin{small}
\textbackslash title\{Very Important Results in AI\textbackslash thanks\{This work is supported by everybody.\}\}
\end{small}
\end{quote}
Multiple \textbackslash thanks commands can be given. Each will result in a separate footnote indication in the author or title with the corresponding text at the bottom of the first column of the document. Note that the \textbackslash thanks command is fragile. You will need to use \textbackslash protect.
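In a moving argument such as the title, this typically means prefixing the command with \textbackslash protect, as in this sketch (the \{...\} placeholder stands for your own text):
\begin{quote}
\begin{small}
\textbackslash title\{Very Important Results in AI\textbackslash protect\textbackslash thanks\{...\}\}
\end{small}
\end{quote}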

Please do not include \textbackslash pubnote commands in your document.
\subsection{Abstract}
Follow the example commands in this document for creation of your abstract. The command \textbackslash begin\{abstract\} will automatically indent the text block. Please do not indent it further. {Do not include references in your abstract!}

\subsection{Page Numbers}

Do not print any page numbers on your paper. The use of \textbackslash pagestyle is forbidden.

\subsection{Text}
The main body of the paper must be formatted in black, ten-point Times Roman with twelve-point leading (line spacing). You may not reduce font size or line spacing. Commands that alter font size or line spacing (including, but not limited to, baselinestretch, baselineshift, linespread, and others) are expressly forbidden. In addition, you may not use color in the text.
\subsection{Citations}
Citations within the text should include the author's last name and year, for example (Newell 1980). Append lower-case letters to the year in cases of ambiguity. Multiple authors should be treated as follows: (Feigenbaum and Engelmore 1988) or (Ford, Hayes, and Glymour 1992). In the case of four or more authors, list only the first author, followed by et al. (Ford et al. 1997).
\subsection{Extracts}
Long quotations and extracts should be indented ten points from the left and right margins.

\begin{quote}
This is an example of an extract or quotation. Note the indent on both sides. Quotation marks are not necessary if you offset the text in a block like this, and properly identify and cite the quotation in the text.
\end{quote}
\subsection{Footnotes}
Use footnotes judiciously, taking into account that they interrupt the reading of the text. When required, they should be consecutively numbered throughout with superscript Arabic numbers. Footnotes should appear at the bottom of the page, separated from the text by a blank line space and a thin, half-point rule.
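
In source form, a footnote is simply the following (a minimal sketch):
\begin{quote}\begin{scriptsize}\begin{verbatim}
... as discussed above.\footnote{Footnote text; the
numbering and placement are handled automatically.}
\end{verbatim}\end{scriptsize}\end{quote}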
\subsection{Headings and Sections}
When necessary, headings should be used to separate major sections of your paper. Remember, you are writing a short paper, not a lengthy book! An overabundance of headings will tend to make your paper look more like an outline than a paper. The aaai2026.sty package will create headings for you. Do not alter their size nor their spacing above or below.
\subsubsection{Section Numbers.}
The use of section numbers in AAAI Press papers is optional. To use section numbers in \LaTeX{}, uncomment the setcounter line in your document preamble and change the 0 to a 1, as shown in the sketch below. Section numbers should not be used in short poster papers and/or extended abstracts.
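
For example, the preamble line changes as follows (a minimal sketch; the comment wording in your copy of the template may differ):
\begin{quote}\begin{scriptsize}\begin{verbatim}
%\setcounter{secnumdepth}{0} % default: no numbers
\setcounter{secnumdepth}{1}  % sections numbered
\end{verbatim}\end{scriptsize}\end{quote}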
\subsubsection{Section Headings.}
Sections should be arranged and headed as follows:
\begin{enumerate}
\item Main content sections
\item Appendices (optional)
\item Ethical Statement (optional, unnumbered)
\item Acknowledgments (optional, unnumbered)
\item References (unnumbered)
\end{enumerate}
\subsubsection{Appendices.}
Any appendices must appear after the main content. If your main sections are numbered, appendix sections must use letters instead of arabic numerals. In \LaTeX{} you can use the \texttt{\textbackslash appendix} command to achieve this effect and then use \texttt{\textbackslash section\{Heading\}} normally for your appendix sections, as in the sketch below.
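
A minimal sketch (the heading text is only a placeholder):
\begin{quote}\begin{scriptsize}\begin{verbatim}
% ... last numbered main section ends here
\appendix
\section{Additional Details} % lettered A, B, ...
\end{verbatim}\end{scriptsize}\end{quote}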
\subsubsection{Ethical Statement.}
You can write a statement about the potential ethical impact of your work, including its broad societal implications, both positive and negative. If included, such a statement must be written in an unnumbered section titled \emph{Ethical Statement}.

\subsubsection{Acknowledgments.}
The acknowledgments section, if included, appears right before the references and is headed ``Acknowledgments''. It must not be numbered even if other sections are (use \texttt{\textbackslash section*\{Acknowledgments\}} in \LaTeX{}). This section includes acknowledgments of help from associates and colleagues, credits to sponsoring agencies, financial support, and permission to publish. Please acknowledge other contributors, grant support, and so forth, in this section. Do not put acknowledgments in a footnote on the first page. If your grant agency requires acknowledgment of the grant on page 1, limit the footnote to the required statement, and put the remaining acknowledgments at the back. Please try to limit acknowledgments to no more than three sentences.
\subsubsection{References.}
The references section should be labeled ``References'' and must appear at the very end of the paper (don't end the paper with references, and then put a figure by itself on the last page). A sample list of references is given later on in these instructions. Please use a consistent format for references. Poorly prepared or sloppy references reflect badly on the quality of your paper and your research. Please prepare complete and accurate citations.

\subsection{Illustrations and Figures}
\begin{figure}[t]
\centering
\includegraphics[width=0.9\columnwidth]{figure1} % Reduce the figure size so that it is slightly narrower than the column. Don't use precise values for figure width. This setup will avoid overfull boxes.
\caption{Using the trim and clip commands produces fragile layers that can result in disasters (like this one from an actual paper) when the color space is corrected or the PDF combined with others for the final proceedings. Crop your figures properly in a graphics program -- not in \LaTeX{}.}
\label{fig1}
\end{figure}

\begin{figure*}[t]
\centering
\includegraphics[width=0.8\textwidth]{figure2} % Reduce the figure size so that it is slightly narrower than the column.
\caption{Adjusting the bounding box instead of actually removing the unwanted data resulted in multiple layers in this paper. It also needlessly increased the PDF size. In this case, the size of the unwanted layer doubled the paper's size, and produced the following surprising results in final production. Crop your figures properly in a graphics program. Don't just alter the bounding box.}
\label{fig2}
\end{figure*}

% Using the \centering command instead of \begin{center} ... \end{center} will save space
% Positioning your figure at the top of the page will save space and make the paper more readable
% Using 0.95\columnwidth in conjunction with the
Your paper must compile in PDF\LaTeX{}. Consequently, all your figures must be .jpg, .png, or .pdf. You may not use the .gif (the resolution is too low), .ps, or .eps file format for your figures.

Figures, drawings, tables, and photographs should be placed throughout the paper on the page (or the subsequent page) where they are first discussed. Do not group them together at the end of the paper. If placed at the top of the paper, illustrations may run across both columns. Figures must not invade the top, bottom, or side margin areas. Figures must be inserted using the graphicx package (\textbackslash usepackage\{graphicx\}). Number figures sequentially, for example, Figure 1, and so on. Do not use minipage to group figures.

If you normally create your figures using pgfplots, please create the figures first, and then import them as PDFs with proper bounding boxes, as the bounding and trim boxes created by pgfplots are fragile and not valid.

When you include your figures, you must crop them \textbf{outside} of \LaTeX{}. The command \textbackslash includegraphics*[clip=true, viewport=0 0 10 10]\{...\} might result in a PDF that looks great, but the image is \textbf{not really cropped.} The full image can reappear (and obscure whatever it is overlapping) when page numbers are applied or the color space is standardized. Figures \ref{fig1} and \ref{fig2} display some unwanted results that often occur.

If your paper includes illustrations that are not compatible with PDF\TeX{} (such as .eps or .ps documents), you will need to convert them. The epstopdf package will usually work for eps files. You will need to convert your ps files to PDF in either case.
\subsubsection{Figure Captions.}
The illustration number and caption must appear \textit{under} the illustration. Labels and other text with the actual illustration must be at least nine-point type. However, the font and size of figure captions must be 10 point roman. Do not make them smaller, bold, or italic. (Individual words may be italicized if the context requires differentiation.)

\subsection{Tables}
Tables should be presented in 10 point roman type. If necessary, they may be altered to 9 point type. You must not use \texttt{\textbackslash resizebox} or other commands that resize the entire table to make it smaller, because you can't control the final font size this way.
If your table is too large, you can use \texttt{\textbackslash setlength\{\textbackslash tabcolsep\}\{1mm\}} to compress the columns a bit, or you can adapt the content (e.g., reduce the decimal precision when presenting numbers, use shortened column titles, or let some column titles span two lines to make the column narrower), as in the sketch below.
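
A minimal sketch of such a compressed table (all contents are placeholders):
\begin{quote}\begin{scriptsize}\begin{verbatim}
\begin{table}[t]
\centering
\setlength{\tabcolsep}{1mm} % tighter column padding
\begin{tabular}{l|r}
Method & Acc. \\ % shortened column titles
Ours   & 81.2 \\
\end{tabular}
\caption{A placeholder table.}
\label{tab:placeholder}
\end{table}
\end{verbatim}\end{scriptsize}\end{quote}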

Tables that do not fit in a single column must be placed across double columns. If your table won't fit within the margins even when spanning both columns and using the above techniques, you must split it into two separate tables.
\subsubsection{Table Captions.}
The number and caption for your table must appear \textit{under} (not above) the table. Additionally, the font and size of table captions must be 10 point roman and must be placed beneath the table. Do not make them smaller, bold, or italic. (Individual words may be italicized if the context requires differentiation.)

\subsubsection{Low-Resolution Bitmaps.}
You may not use low-resolution (such as 72 dpi) screen-dumps and GIF files---these files contain so few pixels that they are always blurry, and illegible when printed. If they are color, they will become an indecipherable mess when converted to black and white. This is always the case with gif files, which should never be used. The resolution of screen dumps can be increased by reducing the print size of the original file while retaining the same number of pixels. You can also enlarge files by manipulating them in software such as Photoshop. Your figures should be 300 dpi when incorporated into your document.
\subsubsection{\LaTeX{} Overflow.}
\LaTeX{} users please beware: \LaTeX{} will sometimes put portions of the figure or table or an equation in the margin. If this happens, you need to make the figure or table span both columns. If absolutely necessary, you may reduce the figure, or reformat the equation, or reconfigure the table. {\bf Check your log file!} You must fix any overflow into the margin (that means no overfull boxes in \LaTeX{}). \textbf{Nothing is permitted to intrude into the margin or gutter.}
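
In the log file, such problems appear as overfull-box messages like the following (a representative excerpt; the exact numbers will differ):
\begin{quote}\begin{scriptsize}\begin{verbatim}
Overfull \hbox (11.66806pt too wide) in paragraph
at lines 123--125
\end{verbatim}\end{scriptsize}\end{quote}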
\subsubsection{Using Color.}
Use of color is restricted to figures only. It must be WCAG 2.0 compliant. (That is, the contrast ratio must be greater than 4.5:1 no matter the font size.) It must be CMYK, NOT RGB. It may never be used for any portion of the text of your paper. The archival version of your paper will be printed in black and white and grayscale. The web version must be readable by persons with disabilities. Consequently, because conversion to grayscale can cause undesirable effects (red changes to black, yellow can disappear, and so forth), we strongly suggest you avoid placing color figures in your document. If you do include color figures, you must (1) use the CMYK (not RGB) colorspace and (2) be mindful of readers who may happen to have trouble distinguishing colors. Your paper must be decipherable without using color for distinction.

\subsubsection{Drawings.}
We suggest you use computer drawing software (such as Adobe Illustrator or, if unavoidable, the drawing tools in Microsoft Word) to create your illustrations. Do not use Microsoft Publisher. These illustrations will look best if all line widths are uniform (half- to two-point in size), and you do not create labels over shaded areas. Shading should be 133 lines per inch if possible. Use Times Roman or Helvetica for all figure call-outs. \textbf{Do not use hairline width lines} --- be sure that the stroke width of all lines is at least .5 pt. Zero point lines will print on a laser printer, but will completely disappear on the high-resolution devices used by our printers.

\subsubsection{Photographs and Images.}
Photographs and other images should be in grayscale (color photographs will not reproduce well; for example, red tones will reproduce as black, yellow may turn to white, and so forth) and set to a minimum of 300 dpi. Do not prescreen images.

\subsubsection{Resizing Graphics.}
Resize your graphics \textbf{before} you include them with \LaTeX{}. You may \textbf{not} use trim or clip options as part of your \textbackslash includegraphics command. Resize the media box of your PDF using a graphics program instead.

\subsubsection{Fonts in Your Illustrations.}
You must embed all fonts in your graphics before including them in your \LaTeX{} document.
\subsubsection{Algorithms.}
Algorithms and/or programs are a special kind of figure. Like all illustrations, they should appear floated to the top (preferably) or bottom of the page. However, their caption should appear in the header, left-justified and enclosed between horizontal lines, as shown in Algorithm~\ref{alg:algorithm}. The algorithm body should be terminated with another horizontal line. It is up to the authors to decide whether to show line numbers or not, how to format comments, etc.

In \LaTeX{} algorithms may be typeset using the {\tt algorithm} and {\tt algorithmic} packages, but you can also use one of the many other packages for the task.
\begin{algorithm}[tb]
\caption{Example algorithm}
\label{alg:algorithm}
\textbf{Input}: Your algorithm's input\\
\textbf{Parameter}: Optional list of parameters\\
\textbf{Output}: Your algorithm's output
\begin{algorithmic}[1] %[1] enables line numbers
\STATE Let $t=0$.
\WHILE{condition}
\STATE Do some action.
\IF {conditional}
\STATE Perform task A.
\ELSE
\STATE Perform task B.
\ENDIF
\ENDWHILE
\STATE \textbf{return} solution
\end{algorithmic}
\end{algorithm}

\subsubsection{Listings.}
Listings are much like algorithms and programs. They should also appear floated to the top (preferably) or bottom of the page. Listing captions should appear in the header, left-justified and enclosed between horizontal lines as shown in Listing~\ref{lst:listing}. Terminate the body with another horizontal line and avoid any background color. Line numbers, if included, must appear within the text column.
\begin{listing}[tb]%
\caption{Example listing {\tt quicksort.hs}}%
\label{lst:listing}%
\begin{lstlisting}[language=Haskell]
quicksort :: Ord a => [a] -> [a]
quicksort [] = []
quicksort (p:xs) = (quicksort lesser) ++ [p] ++ (quicksort greater)
  where
    lesser = filter (< p) xs
    greater = filter (>= p) xs
\end{lstlisting}
\end{listing}

\subsection{References}
The AAAI style includes a set of definitions for use in formatting references with BibTeX. These definitions make the bibliography style fairly close to the ones specified in the Reference Examples appendix below. To use these definitions, you also need the BibTeX style file ``aaai2026.bst,'' available in the AAAI Author Kit on the AAAI web site. Then, at the end of your paper but before \textbackslash end\{document\}, you need to put the following lines:
\begin{quote}
\begin{small}
\textbackslash bibliography\{bibfile1,bibfile2,...\}
\end{small}
\end{quote}

Please note that the aaai2026.sty file already sets the bibliography style for you, so you do not have to place any \textbackslash bibliographystyle command in the document yourself. The aaai2026.sty file is incompatible with the hyperref and navigator packages. If you use either, your references will be garbled and your paper will be returned to you.

References may be the same size as surrounding text. However, in this section (only), you may reduce the size to {\em \textbackslash small} (9pt) if your paper exceeds the allowable number of pages. Making it any smaller than 9 point with 10 point linespacing, however, is not allowed.

The list of files in the \textbackslash bibliography command should be the names of your BibTeX source files (that is, the .bib files referenced in your paper).
The following commands are available for your use in citing references:
\begin{quote}
{\em \textbackslash cite:} Cites the given reference(s) with a full citation. This appears as ``(Author Year)'' for one reference, or ``(Author Year; Author Year)'' for multiple references.\smallskip\\
{\em \textbackslash shortcite:} Cites the given reference(s) with just the year. This appears as ``(Year)'' for one reference, or ``(Year; Year)'' for multiple references.\smallskip\\
{\em \textbackslash citeauthor:} Cites the given reference(s) with just the author name(s) and no parentheses.\smallskip\\
{\em \textbackslash citeyear:} Cites the given reference(s) with just the date(s) and no parentheses.
\end{quote}
You may also use any of the \emph{natbib} citation commands.
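
For instance, assuming a hypothetical BibTeX key \texttt{newell:80} for the (Newell 1980) reference used as an example earlier, the source would read (a sketch):
\begin{quote}\begin{scriptsize}\begin{verbatim}
Problem spaces \cite{newell:80} ...  % (Newell 1980)
\citeauthor{newell:80} proposed ...  % Newell
Newell (\citeyear{newell:80}) ...    % (1980)
\end{verbatim}\end{scriptsize}\end{quote}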
\section{Proofreading Your PDF}
Please check all the pages of your PDF file. The most commonly forgotten element is the acknowledgments --- especially the correct grant number. Authors also commonly forget to add the metadata to the source, use the wrong reference style file, or don't follow the capitalization rules or comma placement for their author-title information properly. A final common problem is text (especially equations) that runs into the margin. You will need to fix these common errors before submitting your file.

\section{Improperly Formatted Files}
In the past, AAAI has corrected improperly formatted files submitted by the authors. Unfortunately, this has become an increasingly burdensome expense that we can no longer absorb. Consequently, if your file is improperly formatted, it will be returned to you for correction.

\section{Naming Your Electronic File}
We require that you name your \LaTeX{} source file with the last name (family name) of the first author so that it can easily be differentiated from other submissions. Complete file-naming instructions will be provided to you in the submission instructions.

\section{Submitting Your Electronic Files to AAAI}
Instructions on paper submittal will be provided to you in your acceptance letter.
\section{Inquiries}
If you have any questions about the preparation or submission of your paper as instructed in this document, please contact AAAI Press at the address given below. If you have technical questions about implementation of the aaai style file, please contact an expert at your site. We do not provide technical support for \LaTeX{} or any other software package. To avoid problems, please keep your paper simple, and do not incorporate complicated macros and style files.

\begin{quote}
\noindent AAAI Press\\
1101 Pennsylvania Ave, NW Suite 300\\
Washington, DC 20004 USA\\
\textit{Telephone:} 1-202-360-4062\\
\textit{E-mail:} See the submission instructions for your particular conference or event.
\end{quote}

\section{Additional Resources}
\LaTeX{} is a difficult program to master. If you've used that software and this document didn't help, or some items were not explained clearly, we recommend you read Michael Shell's excellent document (testflow doc.txt V1.0a 2002/08/13) about obtaining correct PS/PDF output on \LaTeX{} systems. (It was written for another purpose, but it has general application as well.) It is available at www.ctan.org in the tex-archive.
\appendix
\section{Reference Examples}
\label{sec:reference_examples}

\nobibliography*
Formatted bibliographies should look like the following examples. You should use BibTeX to generate the references. Missing fields are unacceptable when compiling references, and usually indicate that you are using the wrong type of entry (BibTeX class).
\paragraph{Book with multiple authors~\nocite{em:86}} Use the \texttt{@book} class.\\[.2em]
\bibentry{em:86}.
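
The corresponding entry skeleton, in the same style as the \texttt{@misc} skeletons shown below (all fields are placeholders):
\begin{quote}
\begin{footnotesize}
\begin{verbatim}
@book(em:86,
  author="...", title="...",
  publisher="...", year="...",
)
\end{verbatim}
\end{footnotesize}
\end{quote}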

\paragraph{Journal and magazine articles~\nocite{r:80, hcr:83}} Use the \texttt{@article} class.\\[.2em]
\bibentry{r:80}.\\[.2em]
\bibentry{hcr:83}.

\paragraph{Proceedings paper published by a society, press or publisher~\nocite{c:83, c:84}} Use the \texttt{@inproceedings} class. You may abbreviate the \emph{booktitle} field, but make sure that the conference edition is clear.\\[.2em]
\bibentry{c:84}.\\[.2em]
\bibentry{c:83}.

\paragraph{University technical report~\nocite{r:86}} Use the \texttt{@techreport} class.\\[.2em]
\bibentry{r:86}.

\paragraph{Dissertation or thesis~\nocite{c:79}} Use the \texttt{@phdthesis} class.\\[.2em]
\bibentry{c:79}.
\paragraph{Forthcoming publication~\nocite{c:21}} Use the \texttt{@misc} class with a \texttt{note="Forthcoming"} annotation.
\begin{quote}
\begin{footnotesize}
\begin{verbatim}
@misc(key,
  [...]
  note="Forthcoming",
)
\end{verbatim}
\end{footnotesize}
\end{quote}
\bibentry{c:21}.
\paragraph{ArXiv paper~\nocite{c:22}} Fetch the BibTeX entry from the ``Export BibTeX Citation'' link on the arXiv website. Notice it uses the \texttt{@misc} class instead of the \texttt{@article} one, and that it includes the \texttt{eprint} and \texttt{archivePrefix} keys.
\begin{quote}
\begin{footnotesize}
\begin{verbatim}
@misc(key,
  [...]
  eprint="xxxx.yyyy",
  archivePrefix="arXiv",
)
\end{verbatim}
\end{footnotesize}
\end{quote}
\bibentry{c:22}.
\paragraph{Website or online resource~\nocite{c:23}} Use the \texttt{@misc} class. Add the URL in the \texttt{howpublished} field and the date of access in the \texttt{note} field:
\begin{quote}
\begin{footnotesize}
\begin{verbatim}
@misc(key,
  [...]
  howpublished="\url{http://...}",
  note="Accessed: YYYY-mm-dd",
)
\end{verbatim}
\end{footnotesize}
\end{quote}
\bibentry{c:23}.
\vspace{.2em}
For the most up-to-date version of the AAAI reference style, please consult the \textit{AI Magazine} Author Guidelines at \url{https://aaai.org/ojs/index.php/aimagazine/about/submissions#authorGuidelines}.
\section{Acknowledgments}
AAAI is especially grateful to Peter Patel Schneider for his work in implementing the original aaai.sty file, liberally using the ideas of other style hackers, including Barbara Beeton. We also acknowledge with thanks the work of George Ferguson for his guide to using the style and Bib\TeX{} files --- which has been incorporated into this document --- and Hans Guesgen, who provided several timely modifications, as well as the many others who have, from time to time, sent in suggestions on improvements to the AAAI style. We are especially grateful to Francisco Cruz, Marc Pujol-Gonzalez, and Mico Loretan for the improvements to the Bib\TeX{} and \LaTeX{} files made in 2020.

The preparation of the \LaTeX{} and Bib\TeX{} files that implement these instructions was supported by Schlumberger Palo Alto Research, AT\&T Bell Laboratories, Morgan Kaufmann Publishers, The Live Oak Press, LLC, and AAAI Press. Bibliography style changes were added by Sunil Issar. \verb+\+pubnote was added by J. Scott Penberthy. George Ferguson added support for printing the AAAI copyright slug. Additional changes to aaai2026.sty and aaai2026.bst have been made by Francisco Cruz and Marc Pujol-Gonzalez.

\bigskip
\noindent Thank you for reading these instructions carefully. We look forward to receiving your electronic files!

\bibliography{aaai2026}

\end{document}
255
eccv.sty
@@ -1,255 +0,0 @@
% ---------------------------------------------------------------
%
% Formatting Package for ECCV Submissions
%
% initially created for ECCV 2024
% by Stefan Roth
%
% based on previous ECCV templates:
% updated April 2002 by Antje Endemann
% Based on CVPR 07 and LNCS, with modifications by DAF, AZ and elle, 2008 and AA, 2010, and CC, 2011; TT, 2014; AAS, 2016; AAS, 2020; TH, 2022
%
% and the CVPR templates:
% https://github.com/cvpr-org/author-kit
%
% No guarantee is given that the format corresponds perfectly to
% LNCS Proceedings, but most features should be ok.
%
% ---------------------------------------------------------------
%
% use as
% \documentclass[runningheads]{llncs}
% \usepackage[options]{eccv}
%
% "options" include
% * "review" for submitting a paper for review and
% * "final" for the camera ready (default).
% * "mobile" for camera ready on small-screen devices
% * "year=20??" allows to specify the conference year (default current year).
% * "ID=12345" allows to specify the paper ID (default `none').
%
% specify references as
% \bibliographystyle{splncs04}
% \bibliography{...your files...}
% ---------------------------------------------------------------

\NeedsTeXFormat{LaTeX2e}[1999/12/01]
\ProvidesPackage{eccv}[LaTeX style for ECCV]


% ---------------------------------------------------------------
% Suppress unwanted warnings

\RequirePackage{silence}
\WarningFilter{amsmath}{Unable to redefine math accent \vec}
\WarningFilter{caption}{Unknown document class (or package)}
\RequirePackage{etoolbox}


% ---------------------------------------------------------------
% Basic packages

\RequirePackage[T1]{fontenc}            % Required to avoid font issues
\RequirePackage[left,mathlines]{lineno} % Support for line numbers
\RequirePackage[dvipsnames]{xcolor}     % Color for line numbers
\RequirePackage{amsmath}                % Need AMS packages to bug fix
\RequirePackage{amssymb}                % line numbers in equations
\RequirePackage{cite}                   % Sort citations
\RequirePackage{xspace}

% Breaking lines for URLs in the bib
\RequirePackage[hyphens]{url}
\Urlmuskip=0mu plus 1mu\relax

% Color for links and line numbers
\definecolor{eccvblue}{rgb}{0.12,0.49,0.85}

% ---------------------------------------------------------------
% Use modern caption package to allow for sub-figures etc.
% Reproduces the original LNCS style as closely as possible.

\RequirePackage[labelfont=bf,font=small,tableposition=bottom]{caption}
\RequirePackage[skip=3pt]{subcaption}


% ---------------------------------------------------------------
% Process ECCV package options

% Key value options
\RequirePackage{kvoptions}
\SetupKeyvalOptions{
  family=eccv,
  prefix=eccv@
}

\DeclareBoolOption{review}
\DeclareComplementaryOption{final}{review}
\DeclareBoolOption{mobile}
\DeclareStringOption[\the\year]{year}
\DeclareStringOption[none]{ID}
\DeclareDefaultOption{\PackageWarning{eccv}{Unknown option `\CurrentOption'}}
\ProcessKeyvalOptions*

% Enable processing options also in main paper with \eccvsetup{ key=value, ... }
\newcommand*{\eccvsetup}
{\setkeys{eccv}%
}

% Warn if ECCV package for review version is not loaded with paper ID option
\ifeccv@review
  \ifdefstring{\eccv@ID}{none}{%
    \PackageWarningNoLine{eccv}{Review version requires a paper ID. Please load `eccv' package with `ID=*****' option and replace `*****' with your paper ID}
  }{}
\fi


% ---------------------------------------------------------------
% Basic error handling

\AtBeginDocument{%
  % Print an error if document class other than llncs is used
  \@ifclassloaded{llncs}{}{%
    \PackageError{eccv}{Package only meant to be used with document class `llncs'}{Change document class to `llncs'.}
  }
  % Print a warning if incorrect options for llncs are specified
  \@ifclasswith{llncs}{runningheads}{}{%
    \PackageWarningNoLine{eccv}{Running heads incorrectly suppressed - ECCV requires running heads. Please load document class `llncs' with `runningheads' option}
  }
  % Print a warning if hyperref is not loaded and/or if the pagebackref option is missing
  \ifeccv@review
    \@ifpackageloaded{hyperref}{%
      \@ifpackagewith{hyperref}{pagebackref}{}{%
        \PackageWarningNoLine{eccv}{Package `hyperref' is not loaded with option `pagebackref', which is strongly recommended for review version}
      }
    }{%
      \PackageWarningNoLine{eccv}{Package `hyperref' is not loaded, but strongly recommended for review version}
    }
  \else
    \@ifpackageloaded{hyperref}{%
      \@ifpackagewith{hyperref}{pagebackref}{%
        \PackageWarningNoLine{eccv}{Package `hyperref' is loaded with option `pagebackref', which is *not* recommended for camera-ready version}{}
      }{}
    }{%
      \PackageWarningNoLine{eccv}{Package `hyperref' is not loaded, but highly recommended for camera-ready version}
    }
  \fi
}


% ---------------------------------------------------------------
% Line number support for the review version

% NUMBER with left flushed zeros \fillzeros[<WIDTH>]<NUMBER>
% from CVPR template
\newcount\cv@tmpc@ \newcount\cv@tmpc
\def\fillzeros[#1]#2{\cv@tmpc@=#2\relax\ifnum\cv@tmpc@<0\cv@tmpc@=-\cv@tmpc@\fi
\cv@tmpc=1 %
\loop\ifnum\cv@tmpc@<10 \else \divide\cv@tmpc@ by 10 \advance\cv@tmpc by 1 \fi
\ifnum\cv@tmpc@=10\relax\cv@tmpc@=11\relax\fi \ifnum\cv@tmpc@>10 \repeat
\ifnum#2<0\advance\cv@tmpc1\relax-\fi
\loop\ifnum\cv@tmpc<#1\relax0\advance\cv@tmpc1\relax\fi \ifnum\cv@tmpc<#1 \repeat
\cv@tmpc@=#2\relax\ifnum\cv@tmpc@<0\cv@tmpc@=-\cv@tmpc@\fi \relax\the\cv@tmpc@}%

% colored, bold, sans serif line numbers
\renewcommand\thelinenumber{\color{eccvblue}\normalfont\sffamily\scriptsize\fillzeros[3]{\arabic{linenumber}}\color[rgb]{0,0,0}}
% on both sides
\renewcommand\makeLineNumber{\hss\thelinenumber\ \hspace{4.5mm} \rlap{\hskip\textwidth\ \hspace{5mm}\thelinenumber}}

% Bug: An equation with $$ ... $$ isn't numbered, nor is the previous line.
% Patch amsmath commands so that the previous line and the equation itself
% are numbered. Bug: multiline has an extra line number.
% https://tex.stackexchange.com/questions/461186/how-to-use-lineno-with-amsmath-align

%% Patch 'normal' math environments:
\newcommand*\linenomathpatch[1]{%
  \cspreto{#1}{\linenomath}%
  \cspreto{#1*}{\linenomath}%
  \csappto{end#1}{\endlinenomath}%
  \csappto{end#1*}{\endlinenomath}%
}
%% Patch AMS math environments:
\newcommand*\linenomathpatchAMS[1]{%
  \cspreto{#1}{\linenomathAMS}%
  \cspreto{#1*}{\linenomathAMS}%
  \csappto{end#1}{\endlinenomath}%
  \csappto{end#1*}{\endlinenomath}%
}

%% Definition of \linenomathAMS depends on whether the mathlines option is provided
\expandafter\ifx\linenomath\linenomathWithnumbers
  \let\linenomathAMS\linenomathWithnumbers
  %% The following line gets rid of an extra line numbers at the bottom:
  \patchcmd\linenomathAMS{\advance\postdisplaypenalty\linenopenalty}{}{}{}
\else
  \let\linenomathAMS\linenomathNonumbers
\fi

\linenomathpatch{equation}
\linenomathpatchAMS{gather}
\linenomathpatchAMS{multline}
\linenomathpatchAMS{align}
\linenomathpatchAMS{alignat}
\linenomathpatchAMS{flalign}

% Disable line numbering during measurement step of multline
\makeatletter
\patchcmd{\mmeasure@}{\measuring@true}{
  \measuring@true
  \ifnum-\linenopenaltypar>\interdisplaylinepenalty
    \advance\interdisplaylinepenalty-\linenopenalty
  \fi
}{}{}
\makeatother


% ---------------------------------------------------------------
% Modifications to LNCS template for review version

\makeatletter
\ifeccv@review
  % Display line numbers
  \AtBeginDocument{%
    \linenumbers
    \linenomathpatch{equation}%
    \linenomathpatchAMS{gather}%
    \linenomathpatchAMS{multline}%
    \linenomathpatchAMS{align}%
    \linenomathpatchAMS{alignat}%
    \linenomathpatchAMS{flalign}%
  }

  % Crop the page for review version
  \RequirePackage[width=122mm,left=12mm,paperwidth=146mm,height=193mm,top=12mm,paperheight=217mm]{geometry}

  % Replace authors, institute, and running title with review placeholders
  \let\maketitleold\maketitle
  \renewcommand{\maketitle}{\author{Anonymous ECCV \eccv@year{} Submission}%
    \titlerunning{ECCV \eccv@year{} Submission \#\eccv@ID}%
    \authorrunning{ECCV \eccv@year{} Submission \#\eccv@ID}%
    \institute{Paper ID \#\eccv@ID}%
    \maketitleold}
\fi

\ifeccv@mobile
  % Crop the page for mobile version
  \RequirePackage[width=122mm,left=12mm,paperwidth=146mm,height=193mm,top=12mm,paperheight=217mm]{geometry}
\fi

% Macro for ECCV year in main text
\newcommand{\ECCVyear}{\eccv@year\xspace}
\makeatother


% ---------------------------------------------------------------
% Support for easy cross-referencing (e.g., \cref{eq:loss}, \cref{sec:intro})
% configured with \AtEndPreamble as it needs to be called after hyperref

\AtEndPreamble{
  \usepackage[capitalize]{cleveref}
  \crefname{section}{Sec.}{Secs.}
  \Crefname{section}{Section}{Sections}
  \crefname{table}{Tab.}{Tabs.}
  \Crefname{table}{Table}{Tables}
}
@@ -1,43 +0,0 @@
% ---------------------------------------------------------------
%
% Formatting Package for ECCV Submissions
%
% initially created for ECCV 2024
% by Stefan Roth
%
% based on previous ECCV templates:
% updated April 2002 by Antje Endemann
% Based on CVPR 07 and LNCS, with modifications by DAF, AZ and elle, 2008 and AA, 2010, and CC, 2011; TT, 2014; AAS, 2016; AAS, 2020; TH, 2022
%
% and the CVPR templates:
% https://github.com/cvpr-org/author-kit
%
% No guarantee is given that the format corresponds perfectly to
% LNCS Proceedings, but most features should be ok.
%
% ---------------------------------------------------------------

\NeedsTeXFormat{LaTeX2e}[1999/12/01]
\ProvidesPackage{eccvabbrv}[Common abbreviations for ECCV]

% Add a period to the end of an abbreviation unless there's one
% already, then \xspace.
\RequirePackage{xspace}
\makeatletter
\DeclareRobustCommand\onedot{\futurelet\@let@token\@onedot}
\def\@onedot{\ifx\@let@token.\else.\null\fi\xspace}

\def\eg{\emph{e.g}\onedot}
\def\Eg{\emph{E.g}\onedot}
\def\ie{\emph{i.e}\onedot}
\def\Ie{\emph{I.e}\onedot}
\def\cf{\emph{cf}\onedot}
\def\Cf{\emph{Cf}\onedot}
\def\etc{\emph{etc}\onedot}
\def\vs{\emph{vs}\onedot}
\def\wrt{w.r.t\onedot}
\def\dof{d.o.f\onedot}
\def\iid{i.i.d\onedot}
\def\wolog{w.l.o.g\onedot}
\def\etal{\emph{et al}\onedot}
\makeatother
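
% Usage sketch (not part of the original file): in running text,
%   "features, \eg, color; see Smith \etal" renders as
%   "features, e.g., color; see Smith et al."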
493
eijkel2.eps
@@ -1,493 +0,0 @@
[493 lines of MATLAB-generated Encapsulated PostScript figure data (parz_sym.eps; Helvetica labels, BoundingBox 59 192 549 590) omitted]
16 8 15 5 15 4 15 1 16 -1 15 -4 15 -5 16 -8
|
||||
15 -11 15 -12 16 -15 15 -17 15 -20 15 -22 16 -24 15 -27
|
||||
15 -29 16 -32 15 -34 15 -37 16 -39 15 -42 15 -44 16 -47
|
||||
15 -49 15 -52 15 -54 16 -56 15 -59 15 -60 16 -63 15 -65
|
||||
15 -66 16 -68 15 -70 15 -71 15 -73 16 -73 15 -75 15 -76
|
||||
16 -76 15 -77 15 -77 3209 2426 100 MP stroke
|
||||
16 -77 15 -77 15 -77 15 -76 16 -77 15 -75 15 -74 16 -74
|
||||
15 -72 15 -71 16 -70 15 -68 15 -67 16 -65 15 -63 15 -61
|
||||
15 -59 16 -58 15 -55 15 -53 16 -51 15 -49 15 -47 16 -45
|
||||
15 -43 15 -40 15 -39 16 -36 15 -35 15 -32 16 -31 15 -29
|
||||
15 -27 16 -25 15 -24 15 -22 15 -21 16 -19 15 -18 15 -16
|
||||
16 -16 15 -14 15 -13 16 -12 15 -11 15 -10 16 -9 15 -8
|
||||
15 -8 15 -7 16 -6 15 -6 15 -5 16 -5 15 -4 15 -4
|
||||
16 -3 15 -3 15 -3 15 -3 16 -2 15 -2 15 -2 16 -1
|
||||
15 -1 15 -2 16 -1 15 -1 15 0 15 -1 16 -1 15 0
|
||||
15 -1 16 0 15 -1 15 0 16 0 15 -1 15 0 15 0
|
||||
16 0 15 0 15 0 16 0 15 -1 15 0 16 0 15 0
|
||||
15 0 16 0 15 0 15 0 15 0 16 0 15 0 15 0
|
||||
16 0 15 0 15 0 1694 4612 100 MP stroke
|
||||
16 0 15 0 15 0 1648 4612 4 MP stroke
|
||||
SO
|
||||
16 0 15 0 15 0 16 0 15 0 15 0 15 0 16 0
|
||||
15 0 15 0 16 0 15 0 15 0 16 0 15 0 15 0
|
||||
15 0 16 0 15 0 15 0 16 0 15 0 15 0 16 0
|
||||
15 0 15 0 15 0 16 0 15 0 15 0 16 0 15 0
|
||||
15 0 16 0 15 0 15 0 16 0 15 0 15 0 15 0
|
||||
16 0 15 0 15 0 16 0 15 0 15 0 16 0 15 0
|
||||
15 0 15 0 16 0 15 0 15 0 16 0 15 0 15 0
|
||||
16 0 15 0 15 0 15 0 16 0 15 0 15 0 16 0
|
||||
15 0 15 0 16 0 15 0 15 0 16 0 15 0 15 0
|
||||
15 0 16 0 15 0 15 1 16 0 15 0 15 0 16 0
|
||||
15 0 15 0 15 1 16 0 15 0 15 1 16 0 15 1
|
||||
15 0 16 1 15 0 15 1 15 1 16 1 15 2 15 1
|
||||
16 1 15 2 15 2 4724 4596 100 MP stroke
|
||||
16 2 15 3 15 2 15 4 16 3 15 4 15 4 16 5
|
||||
15 5 15 5 16 7 15 7 15 7 16 9 15 9 15 10
|
||||
15 11 16 12 15 12 15 14 16 15 15 17 15 17 16 19
|
||||
15 21 15 22 15 23 16 25 15 27 15 28 16 30 15 32
|
||||
15 34 16 35 15 38 15 39 15 41 16 43 15 46 15 47
|
||||
16 49 15 50 15 53 16 54 15 56 15 57 16 59 15 60
|
||||
15 62 15 62 16 64 15 64 15 65 16 65 15 65 15 66
|
||||
16 65 15 65 15 64 15 63 16 62 15 61 15 59 16 57
|
||||
15 55 15 53 16 50 15 48 15 44 15 42 16 38 15 35
|
||||
15 31 16 27 15 23 15 19 16 15 15 11 15 6 16 2
|
||||
15 -2 15 -6 15 -11 16 -15 15 -19 15 -23 16 -27 15 -31
|
||||
15 -35 16 -38 15 -42 15 -44 15 -48 16 -50 15 -53 15 -55
|
||||
16 -57 15 -59 15 -61 3209 2592 100 MP stroke
|
||||
16 -62 15 -63 15 -64 15 -65 16 -65 15 -66 15 -65 16 -65
|
||||
15 -65 15 -64 16 -64 15 -62 15 -62 16 -60 15 -59 15 -57
|
||||
15 -56 16 -54 15 -53 15 -50 16 -49 15 -47 15 -46 16 -43
|
||||
15 -41 15 -39 15 -38 16 -35 15 -34 15 -32 16 -30 15 -28
|
||||
15 -27 16 -25 15 -23 15 -22 15 -21 16 -19 15 -17 15 -17
|
||||
16 -15 15 -14 15 -12 16 -12 15 -11 15 -10 16 -9 15 -9
|
||||
15 -7 15 -7 16 -7 15 -5 15 -5 16 -5 15 -4 15 -4
|
||||
16 -3 15 -4 15 -2 15 -3 16 -2 15 -2 15 -2 16 -1
|
||||
15 -1 15 -2 16 -1 15 -1 15 -1 15 0 16 -1 15 0
|
||||
15 -1 16 0 15 -1 15 0 16 0 15 -1 15 0 15 0
|
||||
16 0 15 0 15 0 16 0 15 -1 15 0 16 0 15 0
|
||||
15 0 16 0 15 0 15 0 15 0 16 0 15 0 15 0
|
||||
16 0 15 0 15 0 1694 4612 100 MP stroke
|
||||
16 0 15 0 15 0 1648 4612 4 MP stroke
|
||||
16 0 15 0 15 0 16 0 15 0 15 0 15 0 16 0
|
||||
15 0 15 0 16 0 15 0 15 0 16 0 15 0 15 0
|
||||
15 0 16 0 15 0 15 0 16 0 15 0 15 0 16 0
|
||||
15 0 15 0 15 0 16 0 15 0 15 0 16 0 15 0
|
||||
15 0 16 0 15 0 15 1 16 0 15 0 15 0 15 0
|
||||
16 0 15 0 15 1 16 0 15 0 15 1 16 0 15 1
|
||||
15 0 15 1 16 0 15 1 15 1 16 1 15 2 15 1
|
||||
16 1 15 2 15 2 15 2 16 3 15 2 15 4 16 3
|
||||
15 4 15 4 16 5 15 5 15 5 16 7 15 7 15 7
|
||||
15 9 16 9 15 10 15 11 16 12 15 12 15 14 16 15
|
||||
15 17 15 17 15 19 16 21 15 22 15 23 16 25 15 27
|
||||
15 28 16 30 15 32 15 34 15 35 16 38 15 39 15 41
|
||||
16 43 15 46 15 47 4724 3862 100 MP stroke
|
||||
16 49 15 50 15 53 15 54 16 56 15 57 15 59 16 60
|
||||
15 62 15 62 16 64 15 64 15 65 16 65 15 65 15 66
|
||||
15 65 16 65 15 64 15 63 16 62 15 61 15 59 16 57
|
||||
15 55 15 53 15 50 16 48 15 44 15 42 16 38 15 35
|
||||
15 31 16 27 15 23 15 19 15 15 16 11 15 6 15 2
|
||||
16 -2 15 -6 15 -11 16 -15 15 -19 15 -23 16 -27 15 -31
|
||||
15 -35 15 -38 16 -42 15 -44 15 -48 16 -50 15 -53 15 -55
|
||||
16 -57 15 -59 15 -61 15 -62 16 -63 15 -64 15 -65 16 -65
|
||||
15 -66 15 -65 16 -65 15 -65 15 -64 15 -64 16 -62 15 -62
|
||||
15 -60 16 -59 15 -57 15 -56 16 -54 15 -53 15 -50 16 -49
|
||||
15 -47 15 -46 15 -43 16 -41 15 -39 15 -38 16 -35 15 -34
|
||||
15 -32 16 -30 15 -28 15 -27 15 -25 16 -23 15 -22 15 -21
|
||||
16 -19 15 -17 15 -17 3209 4446 100 MP stroke
|
||||
16 -15 15 -14 15 -12 15 -12 16 -11 15 -10 15 -9 16 -9
|
||||
15 -7 15 -7 16 -7 15 -5 15 -5 16 -5 15 -4 15 -4
|
||||
15 -3 16 -4 15 -2 15 -3 16 -2 15 -2 15 -2 16 -1
|
||||
15 -1 15 -2 15 -1 16 -1 15 -1 15 0 16 -1 15 0
|
||||
15 -1 16 0 15 -1 15 0 15 0 16 -1 15 0 15 0
|
||||
16 0 15 0 15 0 16 0 15 -1 15 0 16 0 15 0
|
||||
15 0 15 0 16 0 15 0 15 0 16 0 15 0 15 0
|
||||
16 0 15 0 15 0 15 0 16 0 15 0 15 0 16 0
|
||||
15 0 15 0 16 0 15 0 15 0 15 0 16 0 15 0
|
||||
15 0 16 0 15 0 15 0 16 0 15 0 15 0 15 0
|
||||
16 0 15 0 15 0 16 0 15 0 15 0 16 0 15 0
|
||||
15 0 16 0 15 0 15 0 15 0 16 0 15 0 15 0
|
||||
16 0 15 0 15 0 1694 4612 100 MP stroke
|
||||
16 0 15 0 15 0 1648 4612 4 MP stroke
|
||||
DO
|
||||
16 0 15 0 15 0 16 0 15 0 15 0 15 0 16 0
|
||||
15 0 15 0 16 0 15 0 15 0 16 0 15 0 15 0
|
||||
15 0 16 0 15 0 15 0 16 0 15 0 15 0 16 0
|
||||
15 0 15 0 15 0 16 0 15 0 15 0 16 0 15 0
|
||||
15 0 16 0 15 0 15 0 16 0 15 0 15 0 15 0
|
||||
16 0 15 0 15 0 16 0 15 0 15 1 16 0 15 0
|
||||
15 0 15 0 16 0 15 0 15 1 16 0 15 0 15 1
|
||||
16 0 15 1 15 0 15 1 16 0 15 1 15 1 16 1
|
||||
15 2 15 1 16 1 15 2 15 2 16 2 15 3 15 2
|
||||
15 4 16 3 15 4 15 4 16 5 15 5 15 5 16 7
|
||||
15 7 15 7 15 9 16 9 15 10 15 11 16 12 15 12
|
||||
15 14 16 15 15 17 15 17 15 19 16 21 15 22 15 23
|
||||
16 25 15 27 15 28 4724 4247 100 MP stroke
|
||||
16 30 15 32 15 34 15 35 16 38 15 39 15 41 16 43
|
||||
15 46 15 47 16 49 15 50 15 53 16 54 15 56 15 57
|
||||
15 59 16 60 15 62 15 62 16 64 15 64 15 65 16 65
|
||||
15 65 15 66 15 65 16 65 15 64 15 63 16 62 15 61
|
||||
15 59 16 57 15 55 15 53 15 50 16 48 15 44 15 42
|
||||
16 38 15 35 15 31 16 27 15 23 15 19 16 15 15 11
|
||||
15 6 15 2 16 -2 15 -6 15 -11 16 -15 15 -19 15 -23
|
||||
16 -27 15 -31 15 -35 15 -38 16 -42 15 -44 15 -48 16 -50
|
||||
15 -53 15 -55 16 -57 15 -59 15 -61 15 -62 16 -63 15 -64
|
||||
15 -65 16 -65 15 -66 15 -65 16 -65 15 -65 15 -64 16 -64
|
||||
15 -62 15 -62 15 -60 16 -59 15 -57 15 -56 16 -54 15 -53
|
||||
15 -50 16 -49 15 -47 15 -46 15 -43 16 -41 15 -39 15 -38
|
||||
16 -35 15 -34 15 -32 3209 4217 100 MP stroke
|
||||
16 -30 15 -28 15 -27 15 -25 16 -23 15 -22 15 -21 16 -19
|
||||
15 -17 15 -17 16 -15 15 -14 15 -12 16 -12 15 -11 15 -10
|
||||
15 -9 16 -9 15 -7 15 -7 16 -7 15 -5 15 -5 16 -5
|
||||
15 -4 15 -4 15 -3 16 -4 15 -2 15 -3 16 -2 15 -2
|
||||
15 -2 16 -1 15 -1 15 -2 15 -1 16 -1 15 -1 15 0
|
||||
16 -1 15 0 15 -1 16 0 15 -1 15 0 16 0 15 -1
|
||||
15 0 15 0 16 0 15 0 15 0 16 0 15 -1 15 0
|
||||
16 0 15 0 15 0 15 0 16 0 15 0 15 0 16 0
|
||||
15 0 15 0 16 0 15 0 15 0 15 0 16 0 15 0
|
||||
15 0 16 0 15 0 15 0 16 0 15 0 15 0 15 0
|
||||
16 0 15 0 15 0 16 0 15 0 15 0 16 0 15 0
|
||||
15 0 16 0 15 0 15 0 15 0 16 0 15 0 15 0
|
||||
16 0 15 0 15 0 1694 4612 100 MP stroke
|
||||
16 0 15 0 15 0 1648 4612 4 MP stroke
|
||||
0 -2703 4112 4612 2 MP stroke
|
||||
0 -2703 3499 4612 2 MP stroke
|
||||
0 -3823 3959 4612 2 MP stroke
|
||||
SO
|
||||
|
||||
gr
|
||||
3463 3236 mt 3535 3236 L
|
||||
3499 3200 mt 3499 3272 L
|
||||
gs 898 388 5357 4225 MR c np
|
||||
|
||||
gr
|
||||
3923 3236 mt 3995 3236 L
|
||||
3959 3200 mt 3959 3272 L
|
||||
gs 898 388 5357 4225 MR c np
|
||||
|
||||
gr
|
||||
3923 789 mt 3995 789 L
|
||||
3959 753 mt 3959 825 L
|
||||
3923 753 mt 3995 825 L
|
||||
3995 753 mt 3923 825 L
|
||||
gs 898 388 5357 4225 MR c np
|
||||
|
||||
gr
|
||||
4076 2129 mt 4148 2201 L
|
||||
4148 2129 mt 4076 2201 L
|
||||
gs 898 388 5357 4225 MR c np
|
||||
|
||||
gr
|
||||
3923 2129 mt 3995 2201 L
|
||||
3995 2129 mt 3923 2201 L
|
||||
gs 898 388 5357 4225 MR c np
|
||||
|
||||
gr
|
||||
3423 5003 mt
|
||||
(Xi) s
|
||||
3867 5003 mt
|
||||
(Xs) s
|
||||
4050 5003 mt
|
||||
(Xj) s
|
||||
|
||||
end
|
||||
|
||||
eplot
|
||||
%%EndObject graph 1
|
||||
|
||||
epage
|
||||
end
|
||||
|
||||
showpage
|
||||
|
||||
%%Trailer
|
||||
%%EOF
|
||||