1 Commits

Author SHA1 Message Date
Tobias Christian Nauen
ff34712155 AAAI Version 2026-02-24 12:22:44 +01:00
377 changed files with 20390 additions and 5022 deletions

Binary file not shown.

View File

@@ -0,0 +1,73 @@
# Exclude a variety of commonly ignored directories.
exclude = [
".bzr",
".direnv",
".eggs",
".git",
".git-rewrite",
".hg",
".ipynb_checkpoints",
".mypy_cache",
".nox",
".pants.d",
".pyenv",
".pytest_cache",
".pytype",
".ruff_cache",
".svn",
".tox",
".venv",
".vscode",
"__pypackages__",
"_build",
"buck-out",
"build",
"dist",
"node_modules",
"site-packages",
"venv",
]
# Same as Black.
line-length = 120
indent-width = 4
[lint]
# Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default.
# Unlike Flake8, Ruff doesn't enable pycodestyle warnings (`W`) or
# McCabe complexity (`C901`) by default.
select = ["E", "F", "D", "B", "NPY", "PD", "TD005", "TD006", "TD007", "SIM", "RET", "Q", "ICN", "I"]
ignore = ["D203", "D213", "E501", "D100", "NPY002", "D102", "B008", "PD011", "D105", "SIM118"]
# Allow fix for all enabled rules (when `--fix` is provided).
unfixable = ["B"]
# Allow unused variables when underscore-prefixed.
dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
[format]
# Like Black, use double quotes for strings.
quote-style = "double"
# Like Black, indent with spaces, rather than tabs.
indent-style = "space"
# Unlike Black, do not respect magic trailing commas (they are collapsed).
skip-magic-trailing-comma = true
# Like Black, automatically detect the appropriate line ending.
line-ending = "auto"
# Enable auto-formatting of code examples in docstrings. Markdown,
# reStructuredText code/literal blocks and doctests are all supported.
#
# This is currently disabled by default, but it is planned for this
# to be opt-out in the future.
docstring-code-format = true
# Set the line length limit used when formatting code snippets in
# docstrings.
#
# This only has an effect when the `docstring-code-format` setting is
# enabled.
docstring-code-line-length = "dynamic"

View File

@@ -0,0 +1,61 @@
# Creating the ForNet Dataset
We can't just provide the ForNet dataset here, as it's too large to be part of the appendix and using a link will go against double-blind review rules.
After acceptance, the dataset will be downloadable online.
For now, we provide the scripts and steps to recreate the dataset.
In general, if you are unsure what arguments each script allows, run it using the `--help` flag.
## 1. Setup paths
Fill in the paths in `experiments/general_srun` in the `Model Training Code` folder, as well as in `srun-general.sh`, `slurm-segment-imnet.sh` and all the `sbatch-segment-...` files.
In particular the `--container-image`, `--container-mounts`, `--output` and `NLTK_DATA` and `HF_HOME` paths in `--export`.
## 2. Pretrain Filtering Models
Use the `Model Training Code` to pretrain an ensemble of models to use for filtering in a later step.
Train those models on either `TinyImageNet` or `ImageNet`, depending on if you want to create `TinyForNet` or `ForNet`.
Then fill in the relevant paths to the pretrained weights in `experiments/filter_segmentation_versions.py` lines 96/98.
## 3. Create the dataset
### Automatically: using slurm
You may just run the `create_dataset.py` file (on a slurm head node). That file will automatically run all the necessary steps one after another.
### Manually and step-by-step
If you want to run each step of the pipeline manually, follow these steps.
For default arguments and settings, see the `create_dataset.py` script, even though you may not want to run it directly, it can tell you how to run all the other scripts.
#### 3.1 Segment Objects and Backgrounds
Use the segmentation script (`segment_imagenet.py`) to segment each of the dataset images.
Watch out, as this script uses `datadings` for image loading, so you need to provide a `datadings` variant of your dataset.
You need to provide the root folder of the dataset.
Choose your segmentation model using the `-model` argument (LaMa or AttErase).
If you want to use the >general< prompting strategy, set the `--parent_in_prompt` flag.
Use `--output`/`-o` to set the output directory.
Use `--processes` and `-id` for splitting the task up into multiple parallelizable processes.
#### 3.2 Filter the segmented images
In this step, you use the pretrained ensemble of models (from step 2) for filtering the segmented images.
As this step is based on the training and model code, it's in the `Model Training Code` directory.
After setting the relevant paths to the pretrained weights (see step 2), you may run the `experiments/filter_segmentation_versions.py` script using that directory as the PWD.
#### 3.3 Zip the dataset
In distributed storage settings it might be useful to read from one large (uncompressed) zip file instead of reading millions of small single files.
To do this, run
```commandline
zip -r -0 backgrounds_train.zip train/backgrounds > /dev/null 2>&1
```
for the train and val backgrounds and foregrounds
#### 3.4 Compute the foreground size ratios
For the resizing step during recombination, the relative size of each object in each image is needed.
To compute it, run the `foreground_size_ratio.py` script on your filtered dataset.
It expects the zip files to be in the folder you provide as `-ds`.

View File

@@ -0,0 +1,98 @@
from copy import copy
import numpy as np
import torch
import torch.nn.functional as F
from diffusers import DDIMScheduler, DiffusionPipeline
from PIL import Image, ImageFilter
from torchvision.transforms.functional import gaussian_blur, to_tensor
def _preprocess_image(image, device, dtype=torch.float16):
    """Convert a PIL image into a normalized 1x3x1024x1024 tensor on *device*."""
    tensor = to_tensor(image).unsqueeze_(0).float()
    tensor = tensor * 2 - 1  # [0,1] --> [-1,1]
    if tensor.shape[1] != 3:
        # Replicate a single-channel image across three channels.
        tensor = tensor.expand(-1, 3, -1, -1)
    tensor = F.interpolate(tensor, (1024, 1024))
    return tensor.to(dtype).to(device)
def _preprocess_mask(mask, device, dtype=torch.float16):
    """Binarize a PIL mask, resize to 1024x1024, and slightly grow it via blur."""
    tensor = to_tensor(mask.convert("L")).unsqueeze_(0).float()  # values are 0 or 1
    tensor = F.interpolate(tensor, (1024, 1024))
    tensor = gaussian_blur(tensor, kernel_size=(77, 77))
    # Re-binarize after blurring; the low 0.1 cutoff dilates the mask outward.
    tensor = torch.where(tensor < 0.1, torch.zeros_like(tensor), torch.ones_like(tensor))
    return tensor.to(dtype).to(device)
class AttentiveEraser:
    """Attentive Eraser Pipeline + Pre- and Post-Processing.

    Wraps the SDXL-based Attentive Eraser custom pipeline: preprocesses the
    image and mask to 1024x1024 tensors, runs mask-guided object removal, and
    pastes the inpainted region back onto the original image.
    """

    # Empty prompt: removal is driven by the mask / attention suppression, not text.
    prompt = ""
    dtype = torch.float16

    def __init__(self, num_steps=50, device=None):
        """Create the attentive eraser.

        Args:
            num_steps (int, optional): Number of steps in the diffusion process. Will start at 20%. Defaults to 50.
            device (torch.device, optional): Device to run on. Defaults to 'cuda' if available, else 'cpu'.
        """
        self.num_steps = num_steps
        self.device = device if device else torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        scheduler = DDIMScheduler(
            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False
        )
        self.model_path = "stabilityai/stable-diffusion-xl-base-1.0"
        self.pipeline = DiffusionPipeline.from_pretrained(
            self.model_path,
            custom_pipeline="./pipelines/pipeline_stable_diffusion_xl_attentive_eraser.py",
            scheduler=scheduler,
            variant="fp16",
            use_safetensors=True,
            torch_dtype=self.dtype,
        ).to(self.device)
        # Trade speed for memory so SDXL fits on smaller GPUs.
        self.pipeline.enable_attention_slicing()
        self.pipeline.enable_model_cpu_offload()

    def __call__(self, image, mask):
        """Remove the masked object from *image*; returns the inpainted PIL image.

        Args:
            image: PIL image or HWC numpy array.
            mask: PIL image or numpy array; non-zero pixels mark the object to erase.
        """
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        if isinstance(mask, np.ndarray):
            mask = Image.fromarray(mask)
        prep_img = _preprocess_image(image, self.device, self.dtype)
        prep_mask = _preprocess_mask(mask, self.device, self.dtype)
        orig_shape = image.size
        diff_img = (
            self.pipeline(
                prompt=self.prompt,
                image=prep_img,
                mask_image=prep_mask,
                height=1024,
                width=1024,
                AAS=True,  # enable AAS
                strength=0.8,  # inpainting strength
                rm_guidance_scale=9,  # removal guidance scale
                ss_steps=9,  # similarity suppression steps
                ss_scale=0.3,  # similarity suppression scale
                AAS_start_step=0,  # AAS start step
                AAS_start_layer=34,  # AAS start layer
                AAS_end_layer=70,  # AAS end layer
                num_inference_steps=self.num_steps,  # number of inference steps # AAS_end_step = int(strength*num_inference_steps)
                guidance_scale=1,
            )
            .images[0]
            .resize(orig_shape)
        )
        # BUG FIX: the original `np.array(mask) > 0 * 255` parses as `> 0` and
        # yields a bool array, which Image.fromarray cannot handle; the intent
        # was a 0/255 uint8 mask: `(mask > 0) * 255`.
        paste_mask = Image.fromarray(((np.array(mask) > 0) * 255).astype(np.uint8))
        # Feather the paste edge so the inpainted patch blends smoothly.
        paste_mask = paste_mask.convert("RGB").filter(ImageFilter.GaussianBlur(radius=10)).convert("L")
        out_img = copy(image)
        out_img.paste(diff_img, (0, 0), paste_mask)
        return out_img

View File

@@ -0,0 +1,195 @@
import argparse
import os
import re
import subprocess
# Command-line interface for the full dataset-creation pipeline: dataset choice,
# I/O locations, and segmentation hyper-parameters.
parser = argparse.ArgumentParser(description="Create Segment & Recombine dataset")
parser.add_argument("-d", "--dataset", choices=["imagenet", "tinyimagenet"], required=True, help="Dataset to use")
parser.add_argument("-r", "--dataset_root", required=True, help="Root directory of the dataset")
parser.add_argument("-t", "--threshold", type=float, default=0.3, help="Detection threshold")
parser.add_argument(
    "--out_root", type=str, required=True, help="Root directory where the output directory will be created."
)
parser.add_argument("-o", "--output", type=str, required=True, help="Output directory")
parser.add_argument(
    "--parent_labels", type=int, default=2, help="Number of parent labels to use; steps to go up the tree"
)
parser.add_argument("--output_ims", choices=["best", "all"], default="best", help="Output all or best masks")
parser.add_argument("--mask_merge_threshold", type=float, default=0.9, help="Threshold on IoU for merging masks")
parser.add_argument("--parent_in_prompt", action="store_true", help="Include parent label in the prompt")
parser.add_argument("-infill_model", choices=["LaMa", "AttErase"], default="LaMa", help="Infilling model to use")
# "continue" is a keyword, so the parsed attribute is args.continue_
parser.add_argument("--continue", dest="continue_", action="store_true", help="Continue from previous run")
args = parser.parse_args()
out_root = args.out_root
base_folder = os.path.dirname(__file__)
# The model-training code lives in a sibling directory; filtering scripts run from there.
training_code_folder = os.path.join(base_folder, os.pardir, "Model Training Code")
# Expected output-folder naming scheme: [Tiny]INSegment_v<version>[_f<filter>]
ds_name_re = re.compile(r"(Tiny)?INSegment_v(\d*)(_f\d*)?")
name_match = ds_name_re.match(args.output)
assert name_match, f"Output name {args.output} does not match the expected format: {ds_name_re.pattern}"
assert (
    args.output_ims == "best" or name_match.group(3) is None
), "For output_ims == 'all', the filter subversions will be automatically created."
# NOTE(review): out_root is concatenated, not os.path.join-ed — assumes it ends
# with a path separator; confirm against how the script is invoked.
assert args.continue_ or not os.path.exists(out_root + args.output), f"Output directory {args.output} already exists."
# Persist run settings so a --continue run can verify it matches the original run.
settings_file = {
    "dataset": args.dataset,
    "threshold": args.threshold,
    "output_ims": args.output_ims,
    "mask_merge_threshold": args.mask_merge_threshold,
    "parent_labels": args.parent_labels,
    "parent_in_prompt": args.parent_in_prompt,
    "infill_model": args.infill_model,
}
# Sorted "key = value" lines give a canonical, comparable representation.
settings_file = [f"{k} = {str(settings_file[k])}" for k in sorted(list(settings_file.keys()))]
if os.path.exists(out_root + args.output) and os.path.exists(out_root + args.output + "/settings.txt"):
    with open(out_root + args.output + "/settings.txt", "r") as f:
        old_settings = f.read().split("\n")
    old_settings = [line.strip() for line in old_settings if line if len(line.strip()) > 0]
    assert old_settings == settings_file, (
        f"Settings file {out_root + args.output}/settings.txt does not match current settings: old: {old_settings} vs"
        f" new: {settings_file}"
    )
else:
    os.makedirs(out_root + args.output, exist_ok=True)
    with open(out_root + args.output + "/settings.txt", "w") as f:
        f.write("\n".join(settings_file) + "\n")
# Arguments shared by the train and val segmentation sbatch jobs.
general_args = [
    "sbatch",
    "sbatch-segment-tinyimnet-wait" if args.dataset == "tinyimagenet" else "sbatch-segment-imagenet-wait",
    "-r",
    args.dataset_root,
    "-o",
    out_root + args.output,
    "--parent_labels",
    str(args.parent_labels),
    "--output_ims",
    args.output_ims,
    "--mask_merge_threshold",
    str(args.mask_merge_threshold),
    "-t",
    str(args.threshold),
    "-model",
    args.infill_model,
]
if args.parent_in_prompt:
    general_args.append("--parent_in_prompt")
print(f"Starting segmentation: {' '.join(general_args)} for {args.dataset}-val and {args.dataset}")
# Launch train and val segmentation concurrently; imagenet-val uses its own sbatch script.
p_train = subprocess.Popen(general_args + ["-d", args.dataset], cwd=base_folder)
if args.dataset == "imagenet":
    general_args[1] = "sbatch-segment-imagenet-val-wait"
p_val = subprocess.Popen(general_args + ["-d", args.dataset + "-val"], cwd=base_folder)
# detect if exit in error
p_val.wait()
p_train.wait()
rcodes = (p_val.returncode, p_train.returncode)
if any(rcode != 0 for rcode in rcodes):
    print(f"Error in segmentation (val, train): {rcodes}")
    exit(1)
print("Segmentation done.")
if args.output_ims == "all":
    # Hard-link copies (-l) of the output into two subversions, one per filter strategy.
    print("copy to subversions for filtering")
    p_1 = subprocess.Popen(
        [
            "cp",
            "-rl",
            os.path.join(out_root, args.output),
            out_root + f"{args.output}_f1/",
        ]
    )
    p_2 = subprocess.Popen(
        [
            "cp",
            "-rl",
            os.path.join(out_root, args.output),
            out_root + f"{args.output}_f2/",
        ]
    )
    p_1.wait()
    p_2.wait()
    print("Filtering subversions copied over.")
    print("Starting filtering")
    # _f1 uses default score weights, _f2 uses automatically determined ones.
    filtering_args = ["./experiments/general_srun.sh", "experiments/filter_segmentation_versions.py"]
    p_val_f1 = subprocess.Popen(
        filtering_args + ["-f", out_root + f"{args.output}_f1" + "/val", "-d", args.dataset], cwd=training_code_folder
    )
    p_train_f1 = subprocess.Popen(
        filtering_args + ["-f", out_root + f"{args.output}_f1" + "/train", "-d", args.dataset], cwd=training_code_folder
    )
    p_val_f2 = subprocess.Popen(
        filtering_args
        + ["-f", out_root + f"{args.output}_f2" + "/val", "-score_f_weights", "automatic", "-d", args.dataset],
        cwd=training_code_folder,
    )
    p_train_f2 = subprocess.Popen(
        filtering_args
        + ["-f", out_root + f"{args.output}_f2" + "/train", "-score_f_weights", "automatic", "-d", args.dataset],
        cwd=training_code_folder,
    )
    p_val_f1.wait()
    p_train_f1.wait()
    p_val_f2.wait()
    p_train_f2.wait()
    ds_folders = [f"{args.output}_f1", f"{args.output}_f2"]
else:
    # NOTE(review): group(2) captures `(\d*)`, which matches the empty string and
    # is therefore never None — this rename branch can never trigger as written.
    # group(3) (the optional "_f<N>" suffix) was presumably intended; confirm.
    if name_match.group(2) is None:
        new_name = args.output + "_f1"
        print(f"Renaming output folder: {args.output} -> {new_name}")
        os.rename(out_root + args.output, out_root + new_name)
    else:
        new_name = args.output
    ds_folders = [new_name]
# Zip each dataset subversion (uncompressed, for distributed storage) and then
# compute per-image foreground/background size ratios used during recombination.
for folder in ds_folders:
    print(f"Zipping up {folder}")
    # BUG FIX: the original argv contained ">", "/dev/null" and "2>&1". Popen
    # without shell=True performs no redirection, so zip received those tokens
    # as literal file arguments. Silence output via stdout/stderr instead.
    quiet = {"stdout": subprocess.DEVNULL, "stderr": subprocess.DEVNULL}
    p_val_fg = subprocess.Popen(
        ["zip", "-r", "-0", "foregrounds_val.zip", "val/foregrounds"], cwd=out_root + folder, **quiet
    )
    p_train_fg = subprocess.Popen(
        ["zip", "-r", "-0", "foregrounds_train.zip", "train/foregrounds"], cwd=out_root + folder, **quiet
    )
    p_val_bg = subprocess.Popen(
        ["zip", "-r", "-0", "backgrounds_val.zip", "val/backgrounds"], cwd=out_root + folder, **quiet
    )
    p_train_bg = subprocess.Popen(
        ["zip", "-r", "-0", "backgrounds_train.zip", "train/backgrounds"], cwd=out_root + folder, **quiet
    )
    p_val_fg.wait()
    p_train_fg.wait()
    p_val_bg.wait()
    p_train_bg.wait()
    print(f"Gathering foreground size ratios for {folder}")
    # foreground_size_ratio.py defaults to train mode when -mode is omitted.
    p_val = subprocess.Popen(
        [
            "./srun-general.sh",
            "python",
            "foreground_size_ratio.py",
            "--root",
            args.dataset_root,
            "-ds",
            folder,
            "-mode",
            "val",
        ],
        cwd=base_folder,
    )
    p_train = subprocess.Popen(
        ["./srun-general.sh", "python", "foreground_size_ratio.py", "--root", args.dataset_root, "-ds", folder],
        cwd=base_folder,
    )
    p_val.wait()
    p_train.wait()

Binary file not shown.

After

Width:  |  Height:  |  Size: 402 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 46 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 36 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 49 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 148 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 136 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 46 KiB

View File

@@ -0,0 +1,69 @@
import argparse
import json
import os
import zipfile
from io import BytesIO
import numpy as np
import PIL
from PIL import Image
from tqdm.auto import tqdm
# Compute, for every background image, the fraction of its area covered by the
# matching foreground object (opaque alpha pixels / background area), and store
# the mapping as JSON for the recombination resizing step.
parser = argparse.ArgumentParser(description="Foreground size ratio")
parser.add_argument("-mode", choices=["train", "val"], default="train", help="Train or val data?")
parser.add_argument("-ds", "--dataset", type=str, required=True, help="Dataset to use")
parser.add_argument("--root", type=str, required=True, help="Root folder of the dataset")
args = parser.parse_args()
# if not os.path.exists("foreground_size_ratios.json"):
train = args.mode == "train"
# An absolute -ds path is used verbatim; otherwise it is resolved under --root.
root = os.path.join(args.root, args.dataset) if not args.dataset.startswith("/") else args.dataset
cls_bg_ratios = {}  # per-class list of foreground/background area ratios
fg_bg_ratio_map = {}  # per-image ratio, keyed by background archive member name
with zipfile.ZipFile(f"{root}/backgrounds_{'train' if train else 'val'}.zip", "r") as bg_zip, zipfile.ZipFile(
    f"{root}/foregrounds_{'train' if train else 'val'}.zip", "r"
) as fg_zip:
    backgrounds = [f for f in bg_zip.namelist() if f.endswith(".JPEG")]
    foregrounds = [f for f in fg_zip.namelist() if f.endswith(".WEBP")]
    # PERF FIX: `fg_name not in foregrounds` on a list is O(n) per lookup, making
    # the loop below quadratic over millions of images; use a set for O(1) checks.
    foreground_names = set(foregrounds)
    print(f"Bgs: {backgrounds[:5]}, ...\nFgs: {foregrounds[:5]}, ...")
    for bg_name in tqdm(backgrounds):
        fg_name = bg_name.replace("backgrounds", "foregrounds").replace("JPEG", "WEBP")
        if fg_name not in foreground_names:
            tqdm.write(f"Skipping {bg_name} as it has no corresponding foreground")
            fg_bg_ratio_map[bg_name] = 0.0
            continue
        with bg_zip.open(bg_name) as f:
            bg_data = BytesIO(f.read())
        try:
            bg_img = Image.open(bg_data)
        except PIL.UnidentifiedImageError as e:
            print(f"Error with file={bg_name}")
            raise e
        bg_img_size = bg_img.size
        with fg_zip.open(fg_name) as f:
            fg_data = BytesIO(f.read())
        try:
            fg_img = Image.open(fg_data)
        except PIL.UnidentifiedImageError as e:
            print(f"Error with file={fg_name}")
            raise e
        # Foreground area = number of opaque pixels (sum of the alpha channel / 255).
        fg_img_pixel_size = int(np.sum(fg_img.split()[-1]) / 255)
        img_cls = bg_name.split("/")[-2]
        ratio = fg_img_pixel_size / (bg_img_size[0] * bg_img_size[1])
        cls_bg_ratios.setdefault(img_cls, []).append(ratio)
        fg_bg_ratio_map[bg_name] = ratio
with open(f"{root}/fg_bg_ratios_{args.mode}.json", "w") as f:
    json.dump(fg_bg_ratio_map, f)
print(f"Saved fg_bg_ratios_{args.mode} to disk. Exiting.")

View File

@@ -0,0 +1,42 @@
import argparse
import json
from datadings.reader import MsgpackReader
from datadings.torch import Dataset
from tqdm.auto import tqdm
# NOTE(review): the description below is copied from foreground_size_ratio.py;
# this script actually builds a "synset/key -> msgpack index" mapping.
parser = argparse.ArgumentParser(description="Foreground size ratio")
parser.add_argument("-mode", choices=["train", "val"], default="train", help="Train or val data?")
parser.add_argument("-ds", "--dataset", choices=["imagenet", "tinyimagenet"], required=True, help="Dataset to use")
parser.add_argument("-r", "--dataset_root", required=True, type=str, help="Root directory of the dataset")
parser.add_argument("-s", "--segment_root", required=True, type=str, help="Root directory of the segmentation dataset")
args = parser.parse_args()
# Open the datadings msgpack reader for the chosen dataset and split.
# NOTE(review): dataset_root is concatenated, not joined — assumes trailing slash.
if args.dataset == "imagenet":
    reader = MsgpackReader(f"{args.dataset_root}imagenet/msgpack/{args.mode}.msgpack")
    dataset = Dataset(reader)
elif args.dataset == "tinyimagenet":
    reader = MsgpackReader(f"{args.dataset_root}TinyINSegment/TinyIN_{args.mode}.msgpack")
    dataset = Dataset(reader)
else:
    # Unreachable: argparse `choices` already restricts the accepted values.
    raise ValueError(f"Unknown dataset: {args.dataset}")
# Map integer class labels to WordNet synset ids.
if args.dataset.startswith("imagenet"):
    with open("wordnet_data/imagenet1k_synsets.json", "r") as f:
        id_to_synset = json.load(f)
    id_to_synset = {int(k): v for k, v in id_to_synset.items()}
elif args.dataset.startswith("tinyimagenet"):
    with open("wordnet_data/tinyimagenet_synset_names.txt", "r") as f:
        synsets = f.readlines()
    # Lines look like "n<digits>: <name>"; sorted numeric ids are indexed by label.
    id_to_synset = [int(synset.split(":")[0].strip()[1:]) for synset in synsets]
    id_to_synset = sorted(id_to_synset)
# Key format: "n<synset:08d>/<file stem>" -> position in the msgpack file.
key_to_idx = {
    f"n{id_to_synset[data['label']]:08d}/{data['key'].split('.')[0]}": i
    for i, data in enumerate(tqdm(dataset, leave=True))
}
print(len(key_to_idx))
with open(f"{args.segment_root}/{args.mode}_indices.json", "w") as f:
    json.dump(key_to_idx, f)

View File

@@ -0,0 +1,223 @@
"""Use grounding DINO + Segment Anything (SAM) to perform grounded segmentation on an image.
Based on: https://github.com/IDEA-Research/Grounded-Segment-Anything
"""
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
import cv2
import numpy as np
import requests
import torch
from PIL import Image
from torchvision.transforms import ToTensor
from transformers import AutoModelForMaskGeneration, AutoProcessor, pipeline
@dataclass
class BoundingBox:
    """Axis-aligned bounding box in pixel coordinates."""

    xmin: int
    ymin: int
    xmax: int
    ymax: int

    @property
    def xyxy(self) -> List[float]:
        """Return the corners as a flat ``[xmin, ymin, xmax, ymax]`` list."""
        return list((self.xmin, self.ymin, self.xmax, self.ymax))
@dataclass
class DetectionResult:
    """Detection result from Grounding DINO + Mask from SAM."""

    score: float
    label: str
    box: BoundingBox
    # BUG FIX: np.ndarray is the array *type*; np.array (used originally) is the
    # factory function and is not a valid type annotation.
    mask: Optional[np.ndarray] = None

    @classmethod
    def from_dict(cls, detection_dict: Dict) -> "DetectionResult":
        """Create a DetectionResult from a dictionary.

        Args:
            detection_dict (Dict): Detection result dictionary with "score",
                "label" and "box" (xmin/ymin/xmax/ymax) entries, as produced by
                the zero-shot object-detection pipeline.

        Returns:
            DetectionResult: Detection result object (mask left unset).
        """
        box_dict = detection_dict["box"]
        return cls(
            score=detection_dict["score"],
            label=detection_dict["label"],
            box=BoundingBox(
                xmin=box_dict["xmin"],
                ymin=box_dict["ymin"],
                xmax=box_dict["xmax"],
                ymax=box_dict["ymax"],
            ),
        )
def mask_to_polygon(mask: np.ndarray) -> List[List[int]]:
    """Use OpenCV to refine a mask by turning it into a polygon.

    Args:
        mask (np.ndarray): Segmentation mask.

    Returns:
        List[List[int]]: (x, y) vertices of the largest contour's polygon;
            empty when the mask contains no foreground pixels.
    """
    # Find contours in the binary mask
    contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    # BUG FIX: an all-zero mask yields no contours and max() on an empty
    # sequence raises ValueError; return an empty polygon instead.
    if not contours:
        return []
    # Find the contour with the largest area
    largest_contour = max(contours, key=cv2.contourArea)
    # Extract the vertices of the contour
    return largest_contour.reshape(-1, 2).tolist()
def polygon_to_mask(polygon: List[Tuple[int, int]], image_shape: Tuple[int, int]) -> np.ndarray:
    """Rasterize *polygon* into a binary (0/255) mask.

    Args:
        polygon (list): (x, y) vertices of the polygon outline.
        image_shape (tuple): (height, width) of the target mask.

    Returns:
        np.ndarray: uint8 mask with the polygon interior filled with 255.
    """
    canvas = np.zeros(image_shape, dtype=np.uint8)
    vertices = np.array(polygon, dtype=np.int32)
    cv2.fillPoly(canvas, [vertices], color=(255,))
    return canvas
def load_image(image_str: str) -> Image.Image:
    """Load an image from a URL or file path.

    Args:
        image_str (str): URL (http/https) or local file path.

    Returns:
        PIL.Image: RGB image object.

    Raises:
        requests.HTTPError: if the URL download returns an error status.
    """
    if image_str.startswith("http"):
        from io import BytesIO  # local import: only needed on the URL branch

        # BUG FIX: reading ``response.raw`` bypasses HTTP status checking and
        # transfer decoding (e.g. gzip); fetch the decoded body and fail loudly.
        response = requests.get(image_str, timeout=60)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content)).convert("RGB")
    else:
        image = Image.open(image_str).convert("RGB")
    return image
def _get_boxes(results: DetectionResult) -> List[List[List[float]]]:
    # The SAM processor expects one list of boxes per image; only a single
    # image is ever passed, hence the extra nesting level.
    return [[detection.box.xyxy for detection in results]]
def _refine_masks(masks: torch.BoolTensor, polygon_refinement: bool = False) -> List[np.ndarray]:
    # Collapse SAM's per-box mask channels into one binary uint8 mask each:
    # NCHW -> NHWC, average over channels, then threshold at zero.
    stacked = masks.cpu().float().permute(0, 2, 3, 1).mean(axis=-1)
    binary = (stacked > 0).int().numpy().astype(np.uint8)
    refined = list(binary)
    if polygon_refinement:
        # Snap each mask to its largest contour to smooth ragged edges.
        for idx, single_mask in enumerate(refined):
            refined[idx] = polygon_to_mask(mask_to_polygon(single_mask), single_mask.shape)
    return refined
# Module-level model handles, loaded once at import time so repeated
# detect()/segment() calls reuse the same weights.
device = "cuda" if torch.cuda.is_available() else "cpu"
detector_id = "IDEA-Research/grounding-dino-tiny"
print(f"load object detector pipeline: {detector_id}")
object_detector = pipeline(model=detector_id, task="zero-shot-object-detection", device=device)
segmenter_id = "facebook/sam-vit-base"
print(f"load segmentator: {segmenter_id}")
segmentator = AutoModelForMaskGeneration.from_pretrained(segmenter_id).to(device)
print(f"load processor: {segmenter_id}")
processor = AutoProcessor.from_pretrained(segmenter_id)
def detect(image: Image.Image, labels: List[str], threshold: float = 0.3) -> List["DetectionResult"]:
    """Use Grounding DINO to detect a set of labels in an image in a zero-shot fashion.

    Args:
        image (Image.Image): Image to run detection on.
        labels (List[str]): Text labels to ground; a trailing "." is appended
            when missing, since the prompts are period-terminated.
        threshold (float, optional): Detection score threshold. Defaults to 0.3.

    Returns:
        List[DetectionResult]: One result per detection above the threshold.
    """
    # FIX: removed the useless `global` statement (reading module globals needs
    # no declaration) and corrected the return annotation, which claimed
    # List[Dict[str, Any]] while the function returns DetectionResult objects.
    labels = [label if label.endswith(".") else label + "." for label in labels]
    results = object_detector(image, candidate_labels=labels, threshold=threshold)
    return [DetectionResult.from_dict(result) for result in results]
def segment(
    image: Image.Image, detection_results: List["DetectionResult"], polygon_refinement: bool = False
) -> List["DetectionResult"]:
    """Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes.

    Args:
        image (Image.Image): Source image.
        detection_results (List[DetectionResult]): Grounding DINO detections;
            their ``mask`` fields are filled in place.
        polygon_refinement (bool, optional): Snap each mask to its largest
            contour. Defaults to False.

    Returns:
        List[DetectionResult]: The same detection objects, now carrying masks.
    """
    # FIX: removed the useless `global` statement (read-only access) and
    # corrected the parameter annotation (was List[Dict[str, Any]]).
    boxes = _get_boxes(detection_results)
    inputs = processor(images=image, input_boxes=boxes, return_tensors="pt").to(device)
    outputs = segmentator(**inputs)
    masks = processor.post_process_masks(
        masks=outputs.pred_masks, original_sizes=inputs.original_sizes, reshaped_input_sizes=inputs.reshaped_input_sizes
    )[0]
    masks = _refine_masks(masks, polygon_refinement)
    for detection_result, mask in zip(detection_results, masks):
        detection_result.mask = mask
    return detection_results
def grounded_segmentation(
    image: Union[Image.Image, str], labels: List[str], threshold: float = 0.3, polygon_refinement: bool = False
) -> Tuple[torch.Tensor, List[DetectionResult]]:
    """Segment out the objects in an image given a set of labels.

    Args:
        image (Union[Image.Image, str]): Image to load/work on.
        labels (List[str]): Object labels to segment.
        threshold (float, optional): Segmentation threshold. Defaults to 0.3.
        polygon_refinement (bool, optional): Use polygon refinement on the segmented mask? Defaults to False.

    Returns:
        Tuple[torch.Tensor, List[DetectionResult]]: Image tensor and list of detection results.
    """
    if isinstance(image, str):
        image = load_image(image)
    detections = detect(image, labels, threshold)
    # Only invoke SAM when Grounding DINO actually found something.
    if detections:
        detections = segment(image, detections, polygon_refinement)
    return ToTensor()(image), detections

View File

@@ -0,0 +1,271 @@
"""Use the LaMa (*LaMa: Resolution-robust Large Mask Inpainting with Fourier Convolutions*) model to infill the background of an image.
Based on: https://github.com/Sanster/IOPaint/blob/main/iopaint/model/lama.py
"""
import abc
import hashlib
import os
import sys
from enum import Enum
from typing import Optional
from urllib.parse import urlparse
import numpy as np
import torch
from torch.hub import download_url_to_file
# Pretrained big-lama TorchScript checkpoint; both URL and expected MD5 can be
# overridden via environment variables (e.g. to point at a local mirror).
LAMA_MODEL_URL = os.environ.get(
    "LAMA_MODEL_URL",
    "https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
)
LAMA_MODEL_MD5 = os.environ.get("LAMA_MODEL_MD5", "e3aa4aaa15225a33ec84f9f4bc47e500")
class LDMSampler(str, Enum):
    """Sampler names for the LDM inpainting model.

    NOTE(review): unused in this file — presumably kept for parity with the
    upstream iopaint code this module is based on; confirm before removing.
    """

    ddim = "ddim"
    plms = "plms"
class SDSampler(str, Enum):
    """Sampler names for Stable-Diffusion based inpainting.

    NOTE(review): unused in this file — presumably kept for parity with the
    upstream iopaint code this module is based on; confirm before removing.
    """

    ddim = "ddim"
    pndm = "pndm"
    k_lms = "k_lms"
    k_euler = "k_euler"
    k_euler_a = "k_euler_a"
    dpm_plus_plus = "dpm++"
    uni_pc = "uni_pc"
def ceil_modulo(x, mod):
    """Round *x* up to the nearest multiple of *mod* (returns *x* if already aligned)."""
    return x if x % mod == 0 else (x // mod + 1) * mod
def pad_img_to_modulo(img: np.ndarray, mod: int, square: bool = False, min_size: Optional[int] = None):
    """Symmetrically pad *img* so both spatial dims are multiples of *mod*.

    Args:
        img: array of shape [H, W, C]; a 2-D array is promoted to single-channel.
        mod: stride the output height/width must be divisible by.
        square: if True, pad to a square of the larger output side.
        min_size: optional lower bound for both sides; must itself be divisible by mod.

    Returns:
        The padded [H', W', C] array (padding added at the bottom/right, mirrored).
    """
    if img.ndim == 2:
        img = img[:, :, np.newaxis]
    height, width = img.shape[:2]
    # Ceil-division rounding up to the next multiple of mod.
    out_height = -(-height // mod) * mod
    out_width = -(-width // mod) * mod
    if min_size is not None:
        assert min_size % mod == 0
        out_width = max(min_size, out_width)
        out_height = max(min_size, out_height)
    if square:
        side = max(out_height, out_width)
        out_height = out_width = side
    pad_spec = ((0, out_height - height), (0, out_width - width), (0, 0))
    return np.pad(img, pad_spec, mode="symmetric")
def undo_pad_to_mod(img: np.ndarray, height: int, width: int):
    """Crop a padded [H, W, C] array back to its original *height* and *width*."""
    return img[0:height, 0:width, :]
class InpaintModel:
    """Abstract base for inpainting backends.

    Handles padding the input to the model's stride and compositing the
    inpainted result back onto the original image.
    """

    name = "base"
    # Optional minimum side length; must be divisible by pad_mod when set.
    min_size: Optional[int] = None
    # Both spatial dims are padded to a multiple of this stride before forward().
    pad_mod = 8
    pad_to_square = False

    def __init__(self, device, **kwargs):
        """
        Args:
            device: torch device the model weights are loaded onto.
        """
        self.device = device
        self.init_model(device, **kwargs)

    @abc.abstractmethod
    def init_model(self, device, **kwargs): ...

    @staticmethod
    @abc.abstractmethod
    def is_downloaded() -> bool: ...

    @abc.abstractmethod
    def forward(self, image, mask):
        """Input images and output images have same size.

        images: [H, W, C] RGB
        masks: [H, W, 1], value 255 marks the region to inpaint
        return: BGR IMAGE
        """
        ...

    def _pad_forward(self, image, mask):
        # Pad both inputs to the model stride, run the model, then crop back.
        origin_height, origin_width = image.shape[:2]
        pad_image = pad_img_to_modulo(image, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size)
        pad_mask = pad_img_to_modulo(mask, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size)
        result = self.forward(pad_image, pad_mask)
        # result = result[0:origin_height, 0:origin_width, :]
        # result, image, mask = self.forward_post_process(result, image, mask)
        mask = mask[:, :, np.newaxis]
        # forward() returns a batched tensor; take sample 0 as an HWC numpy array.
        result = result[0].permute(1, 2, 0).cpu().numpy()  # 400 x 504 x 3
        result = undo_pad_to_mod(result, origin_height, origin_width)
        assert result.shape[:2] == image.shape[:2], f"{result.shape[:2]} != {image.shape[:2]}"
        # result = result * 255
        # Composite: model output inside the mask, original image outside.
        # NOTE(review): mask is 0/255 here, so `result * mask` scales the masked
        # region by 255 rather than by 1 — `result * (mask / 255)` may have been
        # intended; confirm against the model's output range before changing.
        mask = (mask > 0) * 255
        result = result * mask + image * (1 - (mask / 255))
        result = np.clip(result, 0, 255).astype("uint8")
        return result

    def forward_post_process(self, result, image, mask):
        # Hook for subclasses; identity by default (currently commented out above).
        return result, image, mask
def resize_np_img(np_img, size, interpolation="bicubic"):
    """Resize an [H, W, C] uint8 array via torch.nn.functional.interpolate.

    Args:
        np_img: input image array, [H, W, C], values 0-255.
        size: target size passed to interpolate. NOTE(review): torch expects
            (height, width); callers passing (width, height) get transposed
            output — confirm against call sites before relying on orientation.
        interpolation: one of "nearest", "bilinear", "bicubic".

    Returns:
        Resized uint8 array of shape [*size, C].
    """
    assert interpolation in [
        "nearest",
        "bilinear",
        "bicubic",
    ], f"Unsupported interpolation: {interpolation}, use nearest, bilinear or bicubic."
    torch_img = torch.from_numpy(np_img).permute(2, 0, 1).unsqueeze(0).float() / 255
    # BUG FIX: align_corners is only valid for (bi)linear/bicubic modes; passing
    # align_corners=True with mode="nearest" raises in torch, yet "nearest" is
    # explicitly allowed by the assert above.
    align = True if interpolation != "nearest" else None
    interp_img = torch.nn.functional.interpolate(torch_img, size=size, mode=interpolation, align_corners=align)
    return (interp_img.squeeze(0).permute(1, 2, 0).numpy() * 255).astype(np.uint8)
def resize_max_size(np_img, size_limit: int, interpolation="bicubic") -> np.ndarray:
    """Shrink *np_img* so its longer side is at most *size_limit*; never upscales."""
    h, w = np_img.shape[:2]
    longer_side = max(h, w)
    if longer_side <= size_limit:
        return np_img
    scale = size_limit / longer_side
    target = (int(w * scale + 0.5), int(h * scale + 0.5))
    return resize_np_img(np_img, size=target, interpolation=interpolation)
def norm_img(np_img):
    """Convert an [H, W(, C)] uint8 image to a float32 [C, H, W] array in [0, 1]."""
    if np_img.ndim == 2:
        np_img = np_img[:, :, np.newaxis]
    chw = np.transpose(np_img, (2, 0, 1))
    return chw.astype("float32") / 255
def get_cache_path_by_url(url):
    """Return the local checkpoint path for *url*, creating the cache dir if needed.

    Args:
        url: download URL; only its basename is used for the cached filename.

    Returns:
        Absolute path inside ``~/.TORCH_HUB_CACHE/checkpoints``.
    """
    parts = urlparse(url)
    # BUG FIX: "~" is not expanded automatically; without expanduser() the
    # isdir() check always failed and a literal "~" directory was created
    # relative to the current working directory.
    hub_dir = os.path.expanduser("~/.TORCH_HUB_CACHE")
    model_dir = os.path.join(hub_dir, "checkpoints")
    if not os.path.isdir(model_dir):
        os.makedirs(model_dir)
    filename = os.path.basename(parts.path)
    return os.path.join(model_dir, filename)
def md5sum(filename):
    """Return the hex MD5 digest of the file at *filename*, read in chunks."""
    digest = hashlib.md5()
    with open(filename, "rb") as handle:
        for chunk in iter(lambda: handle.read(128 * digest.block_size), b""):
            digest.update(chunk)
    return digest.hexdigest()
def download_model(url, model_md5: str = None):
    """Download *url* into the cache (if absent) and verify its MD5.

    Args:
        url: checkpoint URL.
        model_md5: expected MD5 hex digest; on mismatch the file is deleted
            and the process exits with code -1.

    Returns:
        Path to the cached checkpoint file.
    """
    cached_file = get_cache_path_by_url(url)
    if not os.path.exists(cached_file):
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        hash_prefix = None
        download_url_to_file(url, cached_file, hash_prefix, progress=True)
        if model_md5:
            _md5 = md5sum(cached_file)
            if model_md5 == _md5:
                print(f"Download model success, md5: {_md5}")
            else:
                # BUG FIX: the original bare `except:` also swallowed
                # KeyboardInterrupt/SystemExit; only file-removal failures
                # should be caught here.
                try:
                    os.remove(cached_file)
                    print(
                        f"Model md5: {_md5}, expected md5: {model_md5}, wrong model deleted. Please restart"
                        " lama-cleaner.If you still have errors, please try download model manually first"
                        " https://lama-cleaner-docs.vercel.app/install/download_model_manually.\n"
                    )
                except OSError:
                    print(
                        f"Model md5: {_md5}, expected md5: {model_md5}, please delete {cached_file} and restart"
                        " lama-cleaner."
                    )
                exit(-1)
    return cached_file
def load_jit_model(url_or_path, device, model_md5: str):
    """Load a TorchScript model from a local path, downloading by URL if needed.

    Args:
        url_or_path: filesystem path or download URL of the scripted model.
        device: target torch device.
        model_md5: expected MD5 of the downloaded file.

    Returns:
        The loaded model in eval mode, moved to *device*.
    """
    if os.path.exists(url_or_path):
        model_path = url_or_path
    else:
        model_path = download_model(url_or_path, model_md5)
    print(f"Loading model from: {model_path}")
    try:
        # Load on CPU first so a CUDA-serialized checkpoint works on any host.
        model = torch.jit.load(model_path, map_location="cpu").to(device)
    except Exception as e:
        print(f"Error loading model: {e}")
        raise
    model.eval()
    return model
class LaMa(InpaintModel):
    """LaMa TorchScript inpainting backend (big-lama checkpoint)."""

    name = "lama"
    pad_mod = 8

    @torch.no_grad()
    def __call__(self, image, mask):
        """
        images: [H, W, C] RGB, not normalized
        masks: [H, W]
        return: BGR IMAGE
        """
        # boxes = boxes_from_mask(mask)
        inpaint_result = self._pad_forward(image, mask)
        return inpaint_result

    def init_model(self, device, **kwargs):
        # Download (if needed) and load the scripted big-lama checkpoint.
        self.model = load_jit_model(LAMA_MODEL_URL, device, LAMA_MODEL_MD5).eval()

    @staticmethod
    def is_downloaded() -> bool:
        return os.path.exists(get_cache_path_by_url(LAMA_MODEL_URL))

    def forward(self, image, mask):
        """Input image and output image have same size.

        image: [H, W, C] RGB, values 0-255 (normalized to [0, 1] internally)
        mask: [H, W], value 255 marks the region to inpaint
        return: BGR IMAGE
        """
        image = norm_img(image)
        mask = norm_img(mask)
        # Binarize the normalized mask to exact 0/1 values for the model.
        mask = (mask > 0) * 1
        image = torch.from_numpy(image).unsqueeze(0).to(self.device)
        mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
        inpainted_image = self.model(image, mask)
        return inpainted_image

View File

@@ -0,0 +1,65 @@
import argparse
import os
from random import shuffle
import numpy as np
from matplotlib import pyplot as plt
from PIL import Image
# --- CLI ---------------------------------------------------------------------
parser = argparse.ArgumentParser(description="Inspect image versions")
parser.add_argument("-f", "--base_folder", type=str, required=True, help="Base folder to inspect")
parser.add_argument("-opt_fg_ratio", type=float, default=0.3, help="Optimal foreground ratio")
args = parser.parse_args()

bg_folder = os.path.join(args.base_folder, "backgrounds")
fg_folder = os.path.join(args.base_folder, "foregrounds")

# Class folders are named like "n01443537"; sort numerically by synset offset.
classes = os.listdir(fg_folder)
classes = sorted(classes, key=lambda x: int(x[1:]))
assert len(classes) in [200, 1_000], f"Expected 200 or 1_000 classes, got {len(classes)}"

# Collect the base image names ("<class>/<prefix>_<id>_<part>") over all classes,
# stripping extension and any version suffix.
total_images = set()
for in_cls in classes:
    cls_images = {
        os.path.join(in_cls, "_".join(img.split(".")[0].split("_")[:3]))
        for img in os.listdir(os.path.join(fg_folder, in_cls))
    }
    total_images.update(cls_images)
total_images = list(total_images)
shuffle(total_images)  # inspect in random order
print(total_images[:5], "...")

# Map synset id (e.g. "n01443537") to its human-readable lemma string.
with open("wordnet_data/tinyimagenet_synset_names.txt", "r") as f:
    in_cls_to_name = {line.split(":")[0].strip(): line.split(":")[1].strip() for line in f.readlines() if len(line) > 2}

for image_name in total_images:
    in_cls, img_name = image_name.split("/")
    # All stored foreground versions (different masks) of this image.
    versions = {img for img in os.listdir(os.path.join(fg_folder, in_cls)) if img.startswith(img_name)}
    if len(versions) <= 1:
        print(f"Image {image_name} has only one version")
        continue
    versions = sorted(versions)
    bg_img = Image.open(os.path.join(bg_folder, in_cls, f"{versions[0].split('.')[0]}.JPEG"))
    # Plot the background next to every foreground version.
    fig, axs = plt.subplots(1, len(versions) + 1, figsize=(15, 5))
    axs[0].imshow(bg_img)
    axs[0].axis("off")
    axs[0].set_title("Background")
    for version, ax in zip(versions, axs[1:]):
        img = Image.open(os.path.join(fg_folder, in_cls, version))
        img_mask = np.array(img.convert("RGBA").split()[-1])  # alpha channel = mask
        # Foreground area relative to the full background image.
        fg_ratio = np.sum(img_mask) / (255 * bg_img.size[0] * bg_img.size[1])
        ax.imshow(img)
        ax.axis("off")
        ax.set_title(f"{version}\nfg_rat: {fg_ratio:.2f} dist to optimal: {abs(fg_ratio - args.opt_fg_ratio):.2f}")
    fig.suptitle(in_cls_to_name[in_cls])
    plt.show()
    plt.close()

View File

@@ -0,0 +1,11 @@
tqdm
einops
omegaconf
diffusers
opencv-python
transformers
accelerate
torchvision
datadings
numpy
nltk

View File

@@ -0,0 +1,20 @@
#!/bin/bash
# SLURM batch launcher: segment the ImageNet validation set in 30 array shards,
# one GPU per shard, running segment_imagenet.py inside the project container.
#SBATCH --array=0-29%30
#SBATCH --time=1-0
#SBATCH --mem=64G
#SBATCH --gpus=1
#SBATCH --partition=H200,H100,H100-PCI,A100-40GB,A100-80GB
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=16
#SBATCH --export="HF_HOME=/PATH/TO/HF_HOME/,NLTK_DATA=/PATH/TO/NLTK_DATA/,TQDM_DISABLE=1"
#SBATCH --job-name="Segment ImageNet (val)"
#SBATCH --output=/SBATCH/OUT/FOLDER/%x-%j-%N-%a.out
#SBATCH --wait

# Quote $SLURM_ARRAY_TASK_ID so a missing value fails loudly instead of
# silently dropping the -id argument. Extra CLI args are forwarded.
srun -K \
	--container-image=/PATH/TO.sqsh \
	--container-workdir="$(pwd)" \
	--container-mounts=/SET/CONTAINER/MOUNTS,"$(pwd)":"$(pwd)" \
	python3 segment_imagenet.py -p 30 -id "$SLURM_ARRAY_TASK_ID" -o /OUTPUT/FOLDER/PATH/ "$@"

View File

@@ -0,0 +1,20 @@
#!/bin/bash
# SLURM batch launcher: segment the ImageNet training set in 320 array shards
# (60 concurrent), one GPU per shard, inside the project container.
#SBATCH --array=0-319%60
#SBATCH --time=1-0
#SBATCH --mem=64G
#SBATCH --gpus=1
#SBATCH --partition=H200,H100,H100-PCI,A100-PCI,A100-40GB,A100-80GB
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=16
#SBATCH --export="HF_HOME=/PATH/TO/HF_HOME/,NLTK_DATA=/PATH/TO/NLTK_DATA/,TQDM_DISABLE=1"
#SBATCH --job-name="Segment ImageNet (train)"
#SBATCH --output=/SBATCH/OUT/FOLDER/%x-%j-%N-%a.out
#SBATCH --wait

# Quote $SLURM_ARRAY_TASK_ID so a missing value fails loudly instead of
# silently dropping the -id argument. Extra CLI args are forwarded.
srun -K \
	--container-image=/PATH/TO.sqsh \
	--container-workdir="$(pwd)" \
	--container-mounts=/SET/CONTAINER/MOUNTS,"$(pwd)":"$(pwd)" \
	python3 segment_imagenet.py -p 320 -id "$SLURM_ARRAY_TASK_ID" -o /OUTPUT_PATH/INSegment/ "$@"

View File

@@ -0,0 +1,19 @@
#!/bin/bash
# SLURM batch launcher: segment ImageNet in 5 array shards, one GPU per shard,
# inside the project container.
#SBATCH --array=0-4%10
#SBATCH --time=1-0
#SBATCH --mem=64G
#SBATCH --gpus=1
#SBATCH --partition=RTX3090,RTXA6000,A100-PCI,A100-40GB,H100-PCI
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=16
#SBATCH --export="HF_HOME=/PATH/TO/HF_HOME/,NLTK_DATA=/PATH/TO/NLTK_DATA/,TQDM_DISABLE=1"
#SBATCH --job-name="Segment ImageNet"
#SBATCH --output=/SBATCH/OUT/FOLDER/%x-%j-%N-%a.out

# Quote $SLURM_ARRAY_TASK_ID so a missing value fails loudly instead of
# silently dropping the -id argument. Extra CLI args are forwarded.
srun -K \
	--container-image=/PATH/TO.sqsh \
	--container-workdir="$(pwd)" \
	--container-mounts=/SET/CONTAINER/MOUNTS,"$(pwd)":"$(pwd)" \
	python3 segment_imagenet.py -p 5 -id "$SLURM_ARRAY_TASK_ID" -o /PATH/TO/OUT/FOLDER/INSegment/ "$@"

View File

@@ -0,0 +1,19 @@
#!/bin/bash
# SLURM batch launcher: segment TinyImageNet in 20 array shards, one GPU per
# shard, inside the project container.
#SBATCH --array=0-19%20
#SBATCH --time=1-0
#SBATCH --mem=64G
#SBATCH --gpus=1
#SBATCH --partition=RTX3090,RTXA6000,A100-PCI,A100-40GB,H100-PCI
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=16
#SBATCH --export="HF_HOME=/PATH/TO/HF_HOME/,NLTK_DATA=/PATH/TO/NLTK_DATA/,TQDM_DISABLE=1"
#SBATCH --job-name="Segment TinyImageNet"
#SBATCH --output=/SBATCH/OUT/FOLDER/%x-%j-%N-%a.out

# Quote $SLURM_ARRAY_TASK_ID so a missing value fails loudly instead of
# silently dropping the -id argument. Extra CLI args are forwarded.
srun -K \
	--container-image=/PATH/TO.sqsh \
	--container-workdir="$(pwd)" \
	--container-mounts=/SET/CONTAINER/MOUNTS,"$(pwd)":"$(pwd)" \
	python3 segment_imagenet.py -p 20 -id "$SLURM_ARRAY_TASK_ID" "$@"

View File

@@ -0,0 +1,20 @@
#!/bin/bash
# SLURM batch launcher: segment TinyImageNet in 30 array shards, one GPU per
# shard, inside the project container; blocks until the array finishes (--wait).
#SBATCH --array=0-29%30
#SBATCH --time=1-0
#SBATCH --mem=64G
#SBATCH --gpus=1
#SBATCH --partition=H200,H100,H100-PCI,A100-PCI,A100-40GB,A100-80GB
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=16
#SBATCH --export="HF_HOME=/PATH/TO/HF_HOME/,NLTK_DATA=/PATH/TO/NLTK_DATA/,TQDM_DISABLE=1"
#SBATCH --job-name="Segment TinyImageNet"
#SBATCH --output=/SBATCH/OUT/FOLDER/%x-%j-%N-%a.out
#SBATCH --wait

# Quote $SLURM_ARRAY_TASK_ID so a missing value fails loudly instead of
# silently dropping the -id argument. Extra CLI args are forwarded.
srun -K \
	--container-image=/PATH/TO.sqsh \
	--container-workdir="$(pwd)" \
	--container-mounts=/SET/CONTAINER/MOUNTS,"$(pwd)":"$(pwd)" \
	python3 segment_imagenet.py -p 30 -id "$SLURM_ARRAY_TASK_ID" "$@"

View File

@@ -0,0 +1,338 @@
import argparse
import json
import os
from datetime import datetime
from math import ceil
import numpy as np
import torch
from datadings.reader import MsgpackReader
from datadings.torch import Compose, CompressedToPIL, Dataset
from PIL import Image, ImageFilter
from torchvision.transforms import ToPILImage, ToTensor
from tqdm.auto import tqdm
from attentive_eraser import AttentiveEraser
from grounded_segmentation import grounded_segmentation
from infill_lama import LaMa
from utils import already_segmented, save_img
from wordnet_tree import WNTree
def _collate_single(data):
return data[0]
def _collate_multiple(datas):
return {
"images": [data["image"] for data in datas],
"keys": [data["key"] for data in datas],
"labels": [data["label"] for data in datas],
}
def smallest_crop(image: torch.Tensor, mask: torch.Tensor):
    """Crops the image to just so fit the mask given.

    Mask and image have to be of the same size.

    Args:
        image (torch.Tensor): image to crop, shape [C, H, W]
        mask (torch.Tensor): cropping to mask, shape [H, W] or [1, H, W]

    Returns:
        torch.Tensor: image cropped to the bounding box of the mask's nonzero
            region; the full image when the mask is empty.
    """
    if len(mask.shape) == 3:
        assert mask.shape[0] == 1
        mask = mask[0]
    assert (
        len(image.shape) == 3 and len(mask.shape) == 2 and image.shape[1:] == mask.shape
    ), f"Invalid shapes: {image.shape}, {mask.shape}"
    # sum over rows -> per-column totals -> column (width) indices with mask pixels;
    # sum over columns -> per-row totals -> row (height) indices with mask pixels.
    col_indices = mask.sum(dim=0).nonzero()
    row_indices = mask.sum(dim=1).nonzero()
    if col_indices.numel() == 0 or row_indices.numel() == 0:
        # Empty mask: no bounding box exists. Return the image unchanged instead
        # of crashing on min()/max() of an empty tensor.
        return image
    return image[
        :,
        row_indices.min().item() : row_indices.max().item() + 1,
        col_indices.min().item() : col_indices.max().item() + 1,
    ]
def _synset_to_prompt(synset, tree, parent_in_prompt):
prompt = f"an {synset.print_name}." if synset.print_name[0] in "aeiou" else f"a {synset.print_name}."
if synset.parent_id is None or not parent_in_prompt:
return prompt
parent = tree[synset.parent_id]
return prompt[:-1] + f", a type of {parent.print_name}."
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Grounded Segmentation")
    parser.add_argument(
        "-d",
        "--dataset",
        choices=["imagenet", "imagenet-val", "imagenet21k", "tinyimagenet", "tinyimagenet-val"],
        default="imagenet",
        help="Dataset to use",
    )
    parser.add_argument("-r", "--dataset_root", required=True, help="Root directory of the dataset")
    parser.add_argument("-b", "--batch_size", type=int, default=1, help="Batch size")
    parser.add_argument("-t", "--threshold", type=float, default=0.3, help="Detection threshold")
    parser.add_argument("--debug", action="store_true", help="Debug mode")
    parser.add_argument("-o", "--output", type=str, required=True, help="Output directory")
    parser.add_argument(
        "-p", "--processes", type=int, default=1, help="Number of processes that are used to process the data"
    )
    parser.add_argument("-id", type=int, default=0, help="ID of this process")
    parser.add_argument("--overwrite", action="store_true", help="Overwrite existing files, instead of skipping them")
    parser.add_argument(
        "--parent_labels", type=int, default=2, help="Number of parent labels to use; steps to go up the tree"
    )
    parser.add_argument("--output_ims", choices=["best", "all"], default="best", help="Output all or best masks")
    parser.add_argument("--mask_merge_threshold", type=float, default=0.9, help="Threshold on IoU for merging masks")
    parser.add_argument("--parent_in_prompt", action="store_true", help="Include parent label in the prompt")
    parser.add_argument(
        "-model",
        choices=["LaMa", "AttErase"],
        default="LaMa",
        help="Model to use for erasing/infilling. Defaults to LaMa.",
    )
    args = parser.parse_args()

    # --- dataset selection: every variant is a msgpack reader with a PIL decode ---
    dataset = args.dataset.lower()
    part = "train"
    if dataset == "imagenet21k":
        reader = MsgpackReader(f"{args.dataset_root}imagenet21k/train.msgpack")
        dataset = Dataset(reader, transforms={"image": Compose([CompressedToPIL()])})
    elif dataset == "imagenet-val":
        reader = MsgpackReader(f"{args.dataset_root}imagenet/msgpack/val.msgpack")
        dataset = Dataset(reader, transforms={"image": Compose([CompressedToPIL()])})
        part = "val"
    elif dataset == "imagenet":
        reader = MsgpackReader(f"{args.dataset_root}imagenet/msgpack/train.msgpack")
        dataset = Dataset(reader, transforms={"image": Compose([CompressedToPIL()])})
    elif dataset == "tinyimagenet":
        reader = MsgpackReader(f"{args.dataset_root}TinyINSegment/TinyIN_train.msgpack")
        dataset = Dataset(reader, transforms={"image": Compose([CompressedToPIL()])})
    elif dataset == "tinyimagenet-val":
        reader = MsgpackReader(f"{args.dataset_root}TinyINSegment/TinyIN_val.msgpack")
        dataset = Dataset(reader, transforms={"image": Compose([CompressedToPIL()])})
        part = "val"
    else:
        raise ValueError(f"Unknown dataset: {dataset}")

    # --- shard the dataset across SLURM array processes ---
    assert 0 <= args.id < args.processes, "ID must be in the range [0, processes)"
    if args.processes > 1:
        partlen = ceil(len(dataset) / args.processes)
        start_id = args.id * partlen
        end_id = min((args.id + 1) * partlen, len(dataset))
        dataset = torch.utils.data.Subset(dataset, list(range(start_id, end_id)))
    assert args.batch_size == 1, "Batch size must be 1 for grounded segmentation (for now)"
    dataset = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=args.debug,
        drop_last=False,
        num_workers=10,
        collate_fn=_collate_single if args.batch_size == 1 else _collate_multiple,
    )

    # Model that infills the background after the foreground has been cut out.
    infill_model = (
        LaMa(device="cuda" if torch.cuda.is_available() else "cpu") if args.model == "LaMa" else AttentiveEraser()
    )

    # --- label-id -> wordnet synset-offset lookup ---
    if args.dataset.startswith("imagenet"):
        with open("wordnet_data/imagenet1k_synsets.json", "r") as f:
            id_to_synset = json.load(f)
        id_to_synset = {int(k): v for k, v in id_to_synset.items()}
    elif args.dataset.startswith("tinyimagenet"):
        with open("wordnet_data/tinyimagenet_synset_names.txt", "r") as f:
            synsets = f.readlines()
        id_to_synset = [int(synset.split(":")[0].strip()[1:]) for synset in synsets]
        id_to_synset = sorted(id_to_synset)
    wordnet = WNTree.load("wordnet_data/imagenet21k+1k_masses_tree.json")

    # create the necessary folders
    os.makedirs(os.path.join(args.output, part, "foregrounds"), exist_ok=True)
    os.makedirs(os.path.join(args.output, part, "backgrounds"), exist_ok=True)
    os.makedirs(os.path.join(args.output, part, "no_detect"), exist_ok=True)
    os.makedirs(os.path.join(args.output, part, "error"), exist_ok=True)

    # Print progress manually when tqdm is disabled (e.g. under SLURM).
    man_print_progress = int(os.environ.get("TQDM_DISABLE", "0")) == 1
    for i, data in tqdm(enumerate(dataset), total=len(dataset)):
        if args.debug and i > 10:
            break
        if man_print_progress and i % 200 == 0:
            print(f"{datetime.now()} \t process {args.id}/{args.processes}: \t sample: {i}/{len(dataset)}", flush=True)
        image = data["image"]
        key = data["key"]
        if args.dataset.endswith("-val"):
            img_class = f"n{id_to_synset[data['label']]:08d}"  # overwrite the synset id on the val set
        else:
            img_class = key.split("/")[-1].split("_")[0]
        if (
            already_segmented(key, os.path.join(args.output, part, "foregrounds"), img_class=img_class)
            and not args.overwrite
        ):
            continue

        # Build prompts: the class itself plus up to `parent_labels` ancestors.
        label_id = data["label"]
        offset = id_to_synset[label_id]
        synset = wordnet[offset]
        labels = [_synset_to_prompt(synset, wordnet, args.parent_in_prompt)]
        while len(labels) < 1 + args.parent_labels and synset.parent_id is not None:
            synset = wordnet[synset.parent_id]
            _label = _synset_to_prompt(synset, wordnet, args.parent_in_prompt)
            _label = _label.replace("_", " ")
            labels.append(_label)
        assert len(labels) > 0, f"No labels for {key}"

        # 1. detect and segment
        image_tensor, detections = grounded_segmentation(
            image, labels, threshold=args.threshold, polygon_refinement=True
        )
        # merge all detections for the same label
        masks = {}
        for detection in detections:
            if detection.label not in masks:
                masks[detection.label] = detection.mask
            else:
                masks[detection.label] |= detection.mask
        labels = list(masks.keys())
        masks = [masks[lbl] for lbl in labels]
        assert len(masks) == len(
            labels
        ), f"Number of masks ({len(masks)}) != number of labels ({len(labels)}); detections: {detections}"
        if args.debug:
            for detected_mask, label in zip(masks, labels):
                detected_mask = torch.from_numpy(detected_mask).unsqueeze(0) / 255
                mask_foreground = torch.cat((image_tensor * detected_mask, detected_mask), dim=0)
                ToPILImage()(mask_foreground).save(f"example_images/{key}_{label}_masked.png")

        if len(masks) == 0:
            tqdm.write(f"No detections for {key}; skipping")
            save_img(image, key, os.path.join(args.output, part, "no_detect"), img_class=img_class, format="JPEG")
            continue
        elif len(masks) == 1:
            masks = [masks[0]]
        elif args.output_ims == "best":
            # Keep only the union of the two masks with the largest IoU.
            # NOTE: indices renamed from `i`/`j` so they no longer shadow the
            # dataset loop variable `i` above.
            max_overlap = 0
            max_overlap_mask = None
            using_labels = None
            for a, mask1 in enumerate(masks):
                for b, mask2 in enumerate(masks):
                    if a >= b:
                        continue
                    iou = (mask1 & mask2).sum() / (mask1 | mask2).sum()
                    if iou > max_overlap:
                        max_overlap = iou
                        max_overlap_mask = mask1 | mask2
                        using_labels = f"{labels[a]} & {labels[b]}"
            if args.debug:
                tqdm.write(f"{key}:\tMax overlap: {max_overlap}, using {using_labels}")
            if max_overlap_mask is None:
                # No pair overlaps at all: fall back to the main-label mask.
                max_overlap_mask = masks[0]
            masks = [max_overlap_mask]
        else:
            # Output all masks, but merge those that are nearly identical
            # (IoU above the merge threshold); restart scanning after each merge.
            has_changed = True
            while has_changed:
                has_changed = False
                for a, mask1 in enumerate(masks):
                    for b, mask2 in enumerate(masks):
                        if a >= b:
                            continue
                        iou = (mask1 & mask2).sum() / (mask1 | mask2).sum()
                        if iou > args.mask_merge_threshold:
                            masks[a] |= masks[b]
                            masks.pop(b)
                            labels.pop(b)
                            has_changed = True
                            break
                    if has_changed:
                        break

        for mask_idx, mask_array in enumerate(masks):
            mask = torch.from_numpy(mask_array).unsqueeze(0) / 255
            if args.debug:
                mask_image = ToPILImage()(mask)
                mask_image.save(f"example_images/{key}_mask.png")
            # Foreground = masked RGB plus the mask itself as alpha channel.
            foreground = torch.cat((image_tensor * mask, mask), dim=0)
            foreground = ToPILImage()(foreground)
            mask_img = foreground.split()[-1]
            foreground = smallest_crop(ToTensor()(foreground), ToTensor()(mask_img))  # TODO: fix smallest crop
            fg_img = ToPILImage()(foreground)
            background = (1 - mask) * image_tensor
            bg_image = ToPILImage()(background)

            # 2. infill background (blur the mask edge so the infill blends in)
            mask_image = Image.fromarray(mask_array)
            mask_image = mask_image.filter(ImageFilter.GaussianBlur(radius=7))
            try:
                infilled_bg = infill_model(np.array(bg_image), np.array(mask_image))
                infilled_bg = Image.fromarray(np.uint8(infilled_bg))
            except RuntimeError as e:
                tqdm.write(
                    f"Error infilling {key}: bg_image.shape={np.array(bg_image).shape},"
                    f" mask_image.shape={np.array(mask_image).shape}\n{e}"
                )
                save_img(
                    image,
                    key,
                    os.path.join(args.output, part, "error"),
                    img_class=img_class,
                    img_version=mask_idx if len(masks) > 1 else None,
                    format="JPEG",
                )
                infilled_bg = None
            if args.debug:
                data["image"].save(f"example_images/{key}_orig.png")
                fg_img.save(f"example_images/{key}_fg.png")
                if infilled_bg is not None:  # infill may have failed above
                    infilled_bg.save(f"example_images/{key}_bg.png")
            else:
                # save files
                save_img(
                    fg_img,
                    key,
                    os.path.join(args.output, part, "foregrounds"),
                    img_class=img_class,
                    img_version=mask_idx if len(masks) > 1 else None,
                    format="WEBP",
                )
                if infilled_bg is not None:
                    save_img(
                        infilled_bg,
                        key,
                        os.path.join(args.output, part, "backgrounds"),
                        img_class=img_class,
                        img_version=mask_idx if len(masks) > 1 else None,
                        format="JPEG",
                    )
    if man_print_progress:
        print(
            f"{datetime.now()} \t process {args.id}/{args.processes}: \t done with all {len(dataset)} samples",
            flush=True,
        )

View File

@@ -0,0 +1,15 @@
#!/bin/bash
# Interactive SLURM launcher: run segment_imagenet.py once on a single-GPU node
# inside the project container. Extra CLI arguments are forwarded to the script.
# Replace the /PATH/TO placeholders with site-specific paths before use.
srun -K \
	--partition=H200,H100,A100-40GB,A100-80GB \
	--job-name="Segment ImageNet" \
	--nodes=1 \
	--gpus=1 \
	--cpus-per-task=16 \
	--mem=64G \
	--time=1-0 \
	--export=HF_HOME=/PATH/TO/HF_HOME/,NLTK_DATA=/PATH/TO/NLTK_DATA/ \
	--container-image=/PATH/TO.sqsh \
	--container-workdir="$(pwd)" \
	--container-mounts=/SET/CONTAINER/MOUNTS,"$(pwd)":"$(pwd)" \
	python3 segment_imagenet.py "$@"

View File

@@ -0,0 +1,13 @@
#!/bin/bash
# Generic interactive SLURM launcher: run an arbitrary command ("$@") inside the
# project container on a cluster node (no GPU requested here).
# Replace the /PATH/TO placeholders with site-specific paths before use.
srun -K \
	--partition=RTX3090,A100-40GB,H100 \
	--job-name="Segment ImageNet" \
	--nodes=1 \
	--mem=64G \
	--time=1-0 \
	--export=HF_HOME=/PATH/TO/HF_HOME/,NLTK_DATA=/PATH/TO/NLTK_DATA/,MAX_WORKERS=16 \
	--container-image=/PATH/TO.sqsh \
	--container-workdir="$(pwd)" \
	--container-mounts=/SET/CONTAINER/MOUNTS,"$(pwd)":"$(pwd)" \
	"$@"

View File

@@ -0,0 +1,49 @@
import os
from PIL import Image
def save_img(img: "Image.Image", img_name: str, base_dir: str, img_class: str = None, format="PNG", img_version=None):
    """Save an image to a directory.

    Args:
        img (PIL.Image.Image): Image to save (any object exposing ``save(path, format)``).
        img_name (str): Relative path to the image.
        base_dir (str): Base directory to save images in.
        img_class (str, optional): Image class, if not given try to extract it from the image name in ImageNet train format. Defaults to None.
        format (str, optional): Format to save the image in. Defaults to "PNG".
        img_version (int, optional): Version of the image. Will be appended to the path. Defaults to None.
    """
    # Normalize the file extension to the requested format.
    if not img_name.endswith(f".{format}"):
        img_name = f"{img_name.split('.')[0]}.{format}"
    if img_class is None:
        # ImageNet train convention: "<class>_<number>.<ext>"
        img_class = img_name.split("_")[0]
    # exist_ok=True makes the previous os.path.exists pre-check unnecessary (and race-free).
    os.makedirs(os.path.join(base_dir, img_class), exist_ok=True)
    if img_version is not None:
        img_name = f"{img_name.split('.')[0]}_v{img_version}.{format}"
    img.save(os.path.join(base_dir, img_class, img_name), format.lower())
def already_segmented(img_name: str, base_dir: str, img_class: str = None):
    """Check if an image was already segmented.

    Args:
        img_name (str): Relative path to the image.
        base_dir (str): Base directory to save images in.
        img_class (str, optional): Image class, if not given try to extract it from the image name in ImageNet train format. Defaults to None.

    Returns:
        bool: Image was segmented already.
    """
    # Strip the extension (if any); versions share this base name.
    base_name = ".".join(img_name.split(".")[:-1]) if "." in img_name else img_name
    if img_class is None:
        img_class = img_name.split("_")[0]
    class_dir = os.path.join(base_dir, img_class)
    if not os.path.exists(class_dir):
        return False
    # A match is either a versioned file ("<base>_v...") or a direct one ("<base>.<ext>").
    prefixes = (base_name + "_v", base_name + ".")
    for existing in os.listdir(class_dir):
        if existing.startswith(prefixes):
            return True
    return False

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,200 @@
n07695742: pretzel
n03902125: pay-phone, pay-station
n03980874: poncho
n01644900: tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui
n02730930: apron
n02699494: altar
n03201208: dining table, board
n02056570: king penguin, Aptenodytes patagonica
n04099969: rocking chair, rocker
n04366367: suspension bridge
n04067472: reel
n02808440: bathtub, bathing tub, bath, tub
n04540053: volleyball
n02403003: ox
n03100240: convertible
n04562935: water tower
n02788148: bannister, banister, balustrade, balusters, handrail
n02988304: CD player
n02423022: gazelle
n03637318: lampshade, lamp shade
n01774384: black widow, Latrodectus mactans
n01768244: trilobite
n07614500: ice cream, icecream
n04254777: sock
n02085620: Chihuahua
n01443537: goldfish, Carassius auratus
n01629819: European fire salamander, Salamandra salamandra
n02099601: golden retriever
n02321529: sea cucumber, holothurian
n03837869: obelisk
n02002724: black stork, Ciconia nigra
n02841315: binoculars, field glasses, opera glasses
n04560804: water jug
n02364673: guinea pig, Cavia cobaya
n03706229: magnetic compass
n09256479: coral reef
n09332890: lakeside, lakeshore
n03544143: hourglass
n02124075: Egyptian cat
n02948072: candle, taper, wax light
n01950731: sea slug, nudibranch
n02791270: barbershop
n03179701: desk
n02190166: fly
n04275548: spider web, spider's web
n04417672: thatch, thatched roof
n03930313: picket fence, paling
n02236044: mantis, mantid
n03976657: pole
n01774750: tarantula
n04376876: syringe
n04133789: sandal
n02099712: Labrador retriever
n04532670: viaduct
n04487081: trolleybus, trolley coach, trackless trolley
n09428293: seashore, coast, seacoast, sea-coast
n03160309: dam, dike, dyke
n03250847: drumstick
n02843684: birdhouse
n07768694: pomegranate
n03670208: limousine, limo
n03085013: computer keyboard, keypad
n02892201: brass, memorial tablet, plaque
n02233338: cockroach, roach
n03649909: lawn mower, mower
n03388043: fountain
n02917067: bullet train, bullet
n02486410: baboon
n04596742: wok
n03255030: dumbbell
n03937543: pill bottle
n02113799: standard poodle
n03977966: police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria
n02906734: broom
n07920052: espresso
n01698640: American alligator, Alligator mississipiensis
n02123394: Persian cat
n03424325: gasmask, respirator, gas helmet
n02129165: lion, king of beasts, Panthera leo
n04008634: projectile, missile
n03042490: cliff dwelling
n02415577: bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis
n02815834: beaker
n02395406: hog, pig, grunter, squealer, Sus scrofa
n01784675: centipede
n03126707: crane
n04399382: teddy, teddy bear
n07875152: potpie
n03733131: maypole
n02802426: basketball
n03891332: parking meter
n01910747: jellyfish
n03838899: oboe, hautboy, hautbois
n03770439: miniskirt, mini
n02281406: sulphur butterfly, sulfur butterfly
n03970156: plunger, plumber's helper
n09246464: cliff, drop, drop-off
n02206856: bee
n02074367: dugong, Dugong dugon
n03584254: iPod
n04179913: sewing machine
n04328186: stopwatch, stop watch
n07583066: guacamole
n01917289: brain coral
n03447447: gondola
n02823428: beer bottle
n03854065: organ, pipe organ
n02793495: barn
n04285008: sports car, sport car
n02231487: walking stick, walkingstick, stick insect
n04465501: tractor
n02814860: beacon, lighthouse, beacon light, pharos
n02883205: bow tie, bow-tie, bowtie
n02165456: ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle
n04149813: scoreboard
n04023962: punching bag, punch bag, punching ball, punchball
n02226429: grasshopper, hopper
n02279972: monarch, monarch butterfly, milkweed butterfly, Danaus plexippus
n02669723: academic gown, academic robe, judge's robe
n04486054: triumphal arch
n04070727: refrigerator, icebox
n03444034: go-kart
n02666196: abacus
n01945685: slug
n04251144: snorkel
n03617480: kimono
n03599486: jinrikisha, ricksha, rickshaw
n02437312: Arabian camel, dromedary, Camelus dromedarius
n01984695: spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish
n04118538: rugby ball
n01770393: scorpion
n04356056: sunglasses, dark glasses, shades
n03804744: nail
n02132136: brown bear, bruin, Ursus arctos
n03400231: frying pan, frypan, skillet
n03983396: pop bottle, soda bottle
n07734744: mushroom
n02480495: orangutan, orang, orangutang, Pongo pygmaeus
n02410509: bison
n03404251: fur coat
n04456115: torch
n02123045: tabby, tabby cat
n03026506: Christmas stocking
n07715103: cauliflower
n04398044: teapot
n02927161: butcher shop, meat market
n07749582: lemon
n07615774: ice lolly, lolly, lollipop, popsicle
n02795169: barrel, cask
n04532106: vestment
n02837789: bikini, two-piece
n02814533: beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon
n04265275: space heater
n02481823: chimpanzee, chimp, Pan troglodytes
n02509815: lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens
n06596364: comic book
n01983481: American lobster, Northern lobster, Maine lobster, Homarus americanus
n02504458: African elephant, Loxodonta africana
n03014705: chest
n01944390: snail
n04146614: school bus
n01641577: bullfrog, Rana catesbeiana
n07720875: bell pepper
n02999410: chain
n01855672: goose
n02125311: cougar, puma, catamount, mountain lion, painter, panther, Felis concolor
n07753592: banana
n07871810: meat loaf, meatloaf
n04501370: turnstile
n04311004: steel arch bridge
n02977058: cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM
n04074963: remote control, remote
n03662601: lifeboat
n02106662: German shepherd, German shepherd dog, German police dog, alsatian
n03089624: confectionery, confectionary, candy store
n04259630: sombrero
n03393912: freight car
n04597913: wooden spoon
n07711569: mashed potato
n03355925: flagpole, flagstaff
n02963159: cardigan
n07579787: plate
n02950826: cannon
n01882714: koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus
n02094433: Yorkshire terrier
n02909870: bucket, pail
n02058221: albatross, mollymawk
n01742172: boa constrictor, Constrictor constrictor
n09193705: alp
n04371430: swimming trunks, bathing trunks
n07747607: orange
n03814639: neck brace
n04507155: umbrella
n02268443: dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk
n02769748: backpack, back pack, knapsack, packsack, rucksack, haversack
n03763968: military uniform
n07873807: pizza, pizza pie
n03992509: potter's wheel
n03796401: moving van
n12267677: acorn

View File

@@ -0,0 +1,400 @@
import argparse
import contextlib
import json
import os
from copy import copy
from nltk.corpus import wordnet as wn
class bcolors:
    """Colors for terminal output."""

    # ANSI escape sequences; print ENDC to reset the terminal style afterwards.
    HEADER = "\033[95m"
    OKBLUE = "\033[94m"
    OKCYAN = "\033[96m"
    OKGREEN = "\033[92m"
    WARNING = "\033[93m"
    FAIL = "\033[91m"
    ENDC = "\033[0m"
    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"
def _lemmas_str(synset):
return ", ".join([lemma.name() for lemma in synset.lemmas()])
class WNEntry:
    """One wordnet synset."""

    def __init__(
        self,
        name: str,
        id: int,
        lemmas: str,
        parent_id: int,
        depth: int = None,
        in_image_net: bool = False,
        child_ids: list = None,
        in_main_tree: bool = True,
        _n_images: int = 0,
        _description: str = None,
        _name: str = None,
        _pruned: bool = False,
    ):
        """Store one synset node of the wordnet tree.

        Args:
            name: Full synset name, e.g. "dog.n.01".
            id: Wordnet offset of the synset.
            lemmas: Comma-separated lemma names.
            parent_id: Offset of the parent synset; None for the root.
            depth: Depth in the tree (root = 0).
            in_image_net: Whether this synset is an ImageNet class.
            child_ids: Offsets of child synsets; None when unknown / leaf.
            in_main_tree: Whether the node is attached to the main tree.
            _n_images: Images assigned directly to this node (not descendants).
            _description: Cached wordnet definition (filled lazily).
            _name: Cached alternative name.
            _pruned: Whether this node has been pruned from the tree.
        """
        self.name = name
        self.id = id
        self.lemmas = lemmas
        self.parent_id = parent_id
        self.depth = depth
        self.in_image_net = in_image_net
        self.child_ids = child_ids
        self.in_main_tree = in_main_tree
        self._n_images = _n_images
        self._description = _description
        self._name = _name
        self._pruned = _pruned

    def __str__(self, tree=None, accumulate=True, colors=True, max_depth=0, max_children=None):
        """Render the node (and up to ``max_depth`` levels of children) as text.

        NOTE(review): nonstandard ``__str__`` signature — the extra arguments
        only take effect when the method is called explicitly; ``str(node)``
        uses the defaults.
        """
        green = f"{bcolors.OKGREEN}" if colors else ""
        red = f"{bcolors.FAIL}" if colors else ""
        end = f"{bcolors.ENDC}" if colors else ""
        # green "+" marks ImageNet classes, red "-" everything else
        start_symb = f"{green}+{end}" if self.in_image_net else f"{red}-{end}"
        # with a tree: "own of Σ subtree-total"; otherwise just the own count
        n_ims = f"{self._n_images} of Σ {self.n_images(tree)}" if accumulate and tree is not None else self._n_images
        if self.child_ids is None or tree is None or max_depth == 0:
            return f"{start_symb}{self.name} ({self.id}) > {n_ims}"
        children = self.child_ids
        if max_children is not None and len(children) > max_children:
            children = children[:max_children]
        # recurse into the (capped) children, indenting every nested level by one space
        return f"{start_symb}{self.name} ({self.id}) > {n_ims}\n " + "\n ".join(
            [
                "\n ".join(
                    tree.nodes[child_id]
                    .__str__(tree=tree, accumulate=accumulate, colors=colors, max_depth=max_depth - 1)
                    .split("\n")
                )
                for child_id in children
            ]
        )

    def tree_diff(self, tree_1, tree_2):
        """Render the subtree with per-node image-count differences between two trees."""
        # marker: green "+" gained images, red "-" lost images, blue "=" unchanged
        if tree_2[self.id]._n_images > tree_1[self.id]._n_images:
            start_symb = f"{bcolors.OKGREEN}+{bcolors.ENDC}"
        elif tree_2[self.id]._n_images < tree_1[self.id]._n_images:
            start_symb = f"{bcolors.FAIL}-{bcolors.ENDC}"
        else:
            start_symb = f"{bcolors.OKBLUE}={bcolors.ENDC}"
        # "old + delta of Σ old-total/new-total"
        n_ims = (
            f"{tree_1[self.id]._n_images} + {tree_2[self.id]._n_images - tree_1[self.id]._n_images} of Σ"
            f" {tree_1[self.id].n_images(tree_2)}/{tree_2[self.id].n_images(tree_2)}"
        )
        if self.child_ids is None:
            return f"{start_symb}{self.name} ({self.id}) > {n_ims}"
        return f"{start_symb}{self.name} ({self.id}) > {n_ims}\n " + "\n ".join(
            ["\n ".join(tree_1.nodes[child_id].tree_diff(tree_1, tree_2).split("\n")) for child_id in self.child_ids]
        )

    def prune(self, tree):
        """Remove this node and its subtree from ``tree``, folding its image
        count into the nearest ancestor that is still part of the tree."""
        if self._pruned or self.parent_id is None:
            return
        # prune the children first so their counts are folded into this node
        if self.child_ids is not None:
            for child_id in self.child_ids:
                tree[child_id].prune(tree)
        self._pruned = True
        parent_node = tree.nodes[self.parent_id]
        try:
            parent_node.child_ids.remove(self.id)
        except ValueError as e:
            # the parent did not list this node as a child; report and continue
            print(
                f"Error removing {self.name} from"
                f" {parent_node.name} ({[tree[cid].name for cid in parent_node.child_ids]}): {e}"
            )
        # climb to the first non-pruned ancestor before handing over the images
        while parent_node._pruned:
            parent_node = tree.nodes[parent_node.parent_id]
        parent_node._n_images += self._n_images
        self._n_images = 0

    @property
    def description(self):
        """Wordnet definition of the synset, fetched lazily and cached."""
        if not self._description:
            self._description = wn.synset_from_pos_and_offset("n", self.id).definition()
        return self._description

    @property
    def print_name(self):
        """Human-readable name: the synset name without the ".n.01"-style suffix."""
        return self.name.split(".")[0]

    @property
    def is_leaf(self):
        """True when the node has no children."""
        return self.child_ids is None or len(self.child_ids) == 0

    def get_branch(self, tree=None):
        """Return the "root > ... > this" path as a single string."""
        if self.parent_id is None or tree is None:
            return self.print_name
        parent = tree.nodes[self.parent_id]
        return parent.get_branch(tree) + " > " + self.print_name

    def get_branch_list(self, tree):
        """Return the list of nodes from the root down to (and including) this one."""
        if self.parent_id is None:
            return [self]
        parent = tree.nodes[self.parent_id]
        return parent.get_branch_list(tree) + [self]

    def to_dict(self):
        """Serialize the node to a JSON-compatible dict (inverse of ``from_dict``)."""
        return {
            "name": self.name,
            "id": self.id,
            "lemmas": self.lemmas,
            "parent_id": self.parent_id,
            "depth": self.depth,
            "in_image_net": self.in_image_net,
            "child_ids": self.child_ids,
            "in_main_tree": self.in_main_tree,
            "_n_images": self._n_images,
            "_description": self._description,
            "_name": self._name,
            "_pruned": self._pruned,
        }

    @classmethod
    def from_dict(cls, d):
        """Rebuild a node from a ``to_dict`` dictionary."""
        return cls(**d)

    def n_images(self, tree=None):
        """Images directly on this node, plus all descendants when ``tree`` is given."""
        if tree is None or self.child_ids is None or len(self.child_ids) == 0:
            return self._n_images
        return sum([tree.nodes[child_id].n_images(tree) for child_id in self.child_ids]) + self._n_images

    def n_children(self, tree=None):
        """Number of children; counts all descendants recursively when ``tree`` is given."""
        if self.child_ids is None:
            return 0
        if tree is None or len(self.child_ids) == 0:
            return len(self.child_ids)
        return len(self.child_ids) + sum([tree.nodes[child_id].n_children(tree) for child_id in self.child_ids])

    def get_examples(self, tree, n_examples=3):
        """Return up to ``n_examples`` child "name (definition)" strings,
        picking the children with the most images (or, if none, most children)."""
        if self.child_ids is None or len(self.child_ids) == 0:
            return ""
        child_images = {child_id: tree.nodes[child_id].n_images(tree) for child_id in self.child_ids}
        max_images = max(child_images.values())
        if max_images == 0:
            # go on number of child nodes
            child_images = {child_id: tree.nodes[child_id].n_children(tree) for child_id in self.child_ids}
        # sorted childids by number of images
        top_children = [
            child_id for child_id, n_images in sorted(child_images.items(), key=lambda x: x[1], reverse=True)
        ]
        top_children = top_children[: min(n_examples, len(top_children))]
        return ", ".join(
            [f"{tree.nodes[child_id].print_name} ({tree.nodes[child_id].description})" for child_id in top_children]
        )
class WNTree:
    """Tree over WordNet noun synsets, with ``WNEntry`` nodes keyed by synset offset.

    Built by repeatedly calling :meth:`add_node`; each synset's FIRST hypernym is
    used as its parent, flattening the WordNet DAG into a tree.  Supports JSON
    round-tripping (``to_dict``/``from_dict``), pruning by image count, and
    mapping nodes to dense label indices.
    """
    def __init__(self, root=1740, nodes=None):
        """Create a tree rooted at ``root``.

        Args:
            root: a WordNet noun offset (int; 1740 is "entity") or an
                already-built ``WNEntry`` to use as the root node.
            nodes: optional pre-built ``{offset: WNEntry}`` registry; when
                ``None``, the registry starts with just the root.
        """
        if isinstance(root, int):
            root_id = root
            root_synset = wn.synset_from_pos_and_offset("n", root)
            root_node = WNEntry(
                root_synset.name(),
                root_id,
                _lemmas_str(root_synset),
                parent_id=None,
                depth=0,
            )
        else:
            assert isinstance(root, WNEntry)
            root_id = root.id
            root_node = root
        self.root = root_node
        self.nodes = {root_id: self.root} if nodes is None else nodes
        # offsets of added synsets that have no hypernyms (detached from the main tree)
        self.parentless = []
        # sorted list of label-carrying node ids; built lazily by _make_label_index()
        self.label_index = None
        # node ids removed by the last prune() call
        self.pruned = set()
    def to_dict(self):
        """Serialize the tree to JSON-compatible plain types (inverse of ``from_dict``)."""
        return {
            "root": self.root.to_dict(),
            "nodes": {node_id: node.to_dict() for node_id, node in self.nodes.items()},
            "parentless": self.parentless,
            "pruned": list(self.pruned),
        }
    def prune(self, min_images):
        """Mark nodes with fewer than ``min_images`` images as pruned.

        Two passes: first any node whose subtree total (``n_images(self)``) is
        below the threshold, then — bottom-up over the root's subtree — any node
        whose OWN image count (``n_images()`` without a tree) is below it.

        Returns:
            The set of pruned node ids (also stored on ``self.pruned``).
        """
        pruned_nodes = set()
        # prune all nodes that have fewer than min_images below them
        for node_id, node in self.nodes.items():
            if node.n_images(self) < min_images:
                pruned_nodes.add(node_id)
                node.prune(self)
        # prune all nodes that have fewer than min_images inside them, after all nodes below have been pruned
        node_stack = [self.root]
        node_idx = 0
        while node_idx < len(node_stack):
            node = node_stack[node_idx]
            if node.child_ids is not None:
                for child_id in node.child_ids:
                    child = self.nodes[child_id]
                    node_stack.append(child)
            node_idx += 1
        # now prune the stack from the bottom up
        for node in node_stack[::-1]:
            # only look at images of that class, not of additional children
            if node.n_images() < min_images:
                pruned_nodes.add(node.id)
                node.prune(self)
        self.pruned = pruned_nodes
        return pruned_nodes
    @classmethod
    def from_dict(cls, d):
        """Rebuild a tree from a ``to_dict`` dict (e.g. after a JSON round trip)."""
        tree = cls()
        tree.root = WNEntry.from_dict(d["root"])
        # JSON turns int keys into strings; convert offsets back to ints
        tree.nodes = {int(node_id): WNEntry.from_dict(node) for node_id, node in d["nodes"].items()}
        tree.parentless = d["parentless"]
        if "pruned" in d:
            tree.pruned = set(d["pruned"])
        return tree
    def add_node(self, node_id, in_in=True):
        """Add the synset with offset ``node_id`` to the tree.

        Missing hypernym ancestors are added recursively with ``in_in=False``.
        Only the first hypernym is used as the parent.  Re-adding an existing
        node only upgrades its ``in_image_net`` flag.
        """
        if node_id in self.nodes:
            self.nodes[node_id].in_image_net = in_in or self.nodes[node_id].in_image_net
            return
        synset = wn.synset_from_pos_and_offset("n", node_id)
        # print(f"adding node {synset.name()} with id {node_id}")
        hypernyms = synset.hypernyms()
        if len(hypernyms) == 0:
            parent_id = None
            self.parentless.append(node_id)
            main_tree = False
            print(f"--------- no hypernyms for {synset.name()} ({synset.offset()}) ------------")
        else:
            # only the first hypernym is used, turning the WordNet DAG into a tree
            parent_id = synset.hypernyms()[0].offset()
            if parent_id not in self.nodes:
                self.add_node(parent_id, in_in=False)
            parent = self.nodes[parent_id]
            if parent.child_ids is None:
                parent.child_ids = []
            parent.child_ids.append(node_id)
            main_tree = parent.in_main_tree
        depth = self.nodes[parent_id].depth + 1 if parent_id is not None else 0
        node = WNEntry(
            synset.name(),
            node_id,
            _lemmas_str(synset),
            parent_id=parent_id,
            in_image_net=in_in,
            depth=depth,
            in_main_tree=main_tree,
        )
        self.nodes[node_id] = node
    def __len__(self):
        """Total number of nodes in the tree."""
        return len(self.nodes)
    def image_net_len(self, only_main_tree=False):
        """Count nodes flagged as ImageNet classes (optionally main tree only)."""
        return sum([node.in_image_net for node in self.nodes.values() if node.in_main_tree or not only_main_tree])
    def max_depth(self, only_main_tree=False):
        """Deepest node depth (optionally restricted to the main tree)."""
        return max([node.depth for node in self.nodes.values() if node.in_main_tree or not only_main_tree])
    def __str__(self, colors=True):
        """Multi-line, optionally colored rendering of the tree plus parentless nodes."""
        return (
            f"WordNet Tree with {len(self)} nodes, {self.image_net_len()} in ImageNet21k;"
            f" {len(self.parentless)} parentless nodes:\n{self.root.__str__(tree=self, colors=colors)}\nParentless:\n"
            + "\n".join([self.nodes[node_id].__str__(tree=self, colors=colors) for node_id in self.parentless])
        )
    def save(self, path):
        """Write the tree to ``path`` as JSON."""
        with open(path, "w") as f:
            json.dump(self.to_dict(), f)
    @classmethod
    def load(cls, path):
        """Read a tree previously written by :meth:`save`."""
        with open(path, "r") as f:
            tree_dict = json.load(f)
        return cls.from_dict(tree_dict)
    def subtree(self, node):
        """Return a new ``WNTree`` rooted at ``node`` (copied nodes, depths re-based to 0)."""
        node_id = self[node].id
        if node_id not in self.nodes:
            return None
        # breadth-first collection of all descendant ids
        node_queue = [self.nodes[node_id]]
        subtree_ids = set()
        while len(node_queue) > 0:
            node = node_queue.pop(0)
            subtree_ids.add(node.id)
            if node.child_ids is not None:
                node_queue += [self.nodes[child_id] for child_id in node.child_ids]
        # shallow-copy entries so the new tree can be modified independently
        subtree_nodes = {node_id: copy(self.nodes[node_id]) for node_id in subtree_ids}
        subtree_root = subtree_nodes[node_id]
        subtree_root.parent_id = None
        depth_diff = subtree_root.depth
        for node in subtree_nodes.values():
            node.depth -= depth_diff
        return WNTree(root=subtree_root, nodes=subtree_nodes)
    def _make_label_index(self):
        """Build the sorted id list of all unpruned nodes that carry at least one image."""
        self.label_index = sorted(
            [node_id for node_id, node in self.nodes.items() if node.n_images(self) > 0 and not node._pruned]
        )
    def get_label(self, node_id):
        """Dense label index for ``node_id``; pruned nodes map to their nearest unpruned ancestor."""
        if self.label_index is None:
            self._make_label_index()
        # NOTE(review): if every ancestor up to the root were pruned, this walks to
        # parent_id=None and raises KeyError — confirm that cannot happen in practice.
        while self.nodes[node_id]._pruned:
            node_id = self.nodes[node_id].parent_id
        return self.label_index.index(node_id)
    def n_labels(self):
        """Number of distinct labels (size of the label index)."""
        if self.label_index is None:
            self._make_label_index()
        return len(self.label_index)
    def __contains__(self, item):
        """Membership test for an offset (int), "n<offset>" string, or ``WNEntry``."""
        if isinstance(item, str):
            # NOTE(review): item[0] raises IndexError on an empty string — confirm callers never pass "".
            if item[0] == "n":
                item = int(item[1:])
            else:
                return False
        if isinstance(item, int):
            return item in self.nodes
        if isinstance(item, WNEntry):
            return item.id in self.nodes
        return False
    def __getitem__(self, item):
        """Look up a node by offset (int), "n<offset>" id, synset name ("dog.n.01"), or ``WNEntry``."""
        # item[0].startswith("n") on a single character is equivalent to item[0] == "n"
        if isinstance(item, str) and item[0].startswith("n"):
            with contextlib.suppress(ValueError):
                item = int(item[1:])
        if isinstance(item, str) and ".n." in item:
            # linear scan over all nodes when given a full synset name
            for node in self.nodes.values():
                if item == node.name:
                    return node
            raise KeyError(f"Item {item} not found in tree")
        if isinstance(item, int):
            return self.nodes[item]
        if isinstance(item, WNEntry):
            return self.nodes[item.id]
        raise KeyError(f"Item {item} not found in tree")
    def __iter__(self):
        """Iterate over node ids (WordNet offsets)."""
        return iter(self.nodes.keys())

View File

@@ -0,0 +1,76 @@
# Exclude a variety of commonly ignored directories.
exclude = [
".bzr",
".direnv",
".eggs",
".git",
".git-rewrite",
".hg",
".ipynb_checkpoints",
".mypy_cache",
".nox",
".pants.d",
".pyenv",
".pytest_cache",
".pytype",
".ruff_cache",
".svn",
".tox",
".venv",
".vscode",
"__pypackages__",
"_build",
"buck-out",
"build",
"dist",
"node_modules",
"site-packages",
"venv",
]
# Same as Black.
line-length = 120
indent-width = 4
[lint]
# Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default.
# Unlike Flake8, Ruff doesn't enable pycodestyle warnings (`W`) or
# McCabe complexity (`C901`) by default.
select = ["E", "F", "D", "B", "NPY", "PD", "TD005", "TD006", "TD007", "SIM", "RET", "Q", "ICN", "I"]
ignore = ["D203", "D213", "E501", "D100", "NPY002", "D102", "B008", "PD011", "D105", "SIM118", "D417"]
# Allow fix for all enabled rules (when `--fix` is provided).
unfixable = ["B"]
# Allow unused variables when underscore-prefixed.
dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
[lint.pydocstyle]
convention = "google"
[format]
# Like Black, use double quotes for strings.
quote-style = "double"
# Like Black, indent with spaces, rather than tabs.
indent-style = "space"
# Like Black, respect magic trailing commas.
skip-magic-trailing-comma = true
# Like Black, automatically detect the appropriate line ending.
line-ending = "auto"
# Enable auto-formatting of code examples in docstrings. Markdown,
# reStructuredText code/literal blocks and doctests are all supported.
#
# This is currently disabled by default, but it is planned for this
# to be opt-out in the future.
docstring-code-format = true
# Set the line length limit used when formatting code snippets in
# docstrings.
#
# This only has an effect when the `docstring-code-format` setting is
# enabled.
docstring-code-line-length = "dynamic"

View File

@@ -0,0 +1,107 @@
# ForNet
This is the training code for the ForNet paper.
All our experiments and evaluations were run using this codebase.
## Requirements
This project heavily builds on [timm](https://github.com/huggingface/pytorch-image-models) and open source implementations of the models that are tested.
All requirements are listed in [requirements.txt](./requirements.txt).
To install those, run
```commandline
pip install -r requirements.txt
```
## Usage
After **cloning this repository**, you can train and test a lot of different models.
By default, a `srun` command is executed to run the code on a slurm cluster.
To run on the local machine, append the `-local` flag to the command.
### General Preparation
After cloning the repository on a slurm cluster, make sure main.py is executable (by using "chmod a+x main.py").
To run the project on a slurm cluster, you need to create a docker image from the requirements file.
You will also want to adapt the default slurm parameters in `config.py`.
Next, adjust the paths in paths_config.py for your system, specifically results_folder, slurm_output_folder and dataset folders.
Finally, if you want to use Weights and Biases for Tracking, create the file ".wandb.apikey" in this folder and paste your API Key into it.
### Training
#### Pretraining
To pretrain a `ViT-S` on a given dataset, run
```commandline
./main.py --task pre-train --model ViT-S/16 --epochs 300 --run_name <name_or_description_of_the_run> --experiment_name recombine_imagenet --lr 3e-3 (--local)
```
This will save a checkpoint (`.pt` file) every `<save_epochs>` epochs (the default is 10), which contains all the model weights, along with the optimizer and scheduler state, and the current training stats.
#### Finetuning
A model (checkpoint) can be finetuned on another dataset using the following command:
```commandline
./main.py --task fine-tune --model <model_checkpoint.pt> --epochs 300 --run_name <name_or_description_of_the_run> --experiment_name recombine_imagenet --lr 3e-3 (--local)
```
This will also save new checkpoints during training.
### Evaluation
It is also possible to evaluate the models.
To evaluate the model's accuracy on a specific dataset, run
```commandline
./main.py -t eval -ds <dataset name> -m <model_checkpoint.pt> --ntasks 1 -bs 512 --num-workers 10 --cpus-per-task 10 --time 10:00 (--local)
```
You can run our center-bias, size-bias, and foreground-focus evaluations using the `eval-attr`, `eval-center-bias`, and `eval-size-bias` tasks (`-t` or `--task` argument).
### Further Arguments
There can be multiple further arguments and flags given to the scripts.
The most important ones are
| Arg | Description |
| :------------------------------ | :----------------------------------------------------- |
| `--model <model>` | Model name or checkpoint. |
| `--run_name <name for the run>` | Name or description of this training run. |
| `--dataset <dataset>` | Specifies a dataset to use. |
| `--task <task>` | Specifies a task. The default is `pre-train`. |
| `--local` | Run on the local machine, not on a slurm cluster. |
| `--epochs <epochs>` | Epochs to train. |
| `--lr <lr>` | Learning rate. Default is 3e-3. |
| `--batch_size <bs>` | Batch size. Default is 2048. |
| `--weight_decay <wd>` | Weight decay. Default is 0.02. |
| `--imsize <image resolution>`   | Resolution of the image to train with. Default is 224. |
For a list of all arguments, run
```commandline
./main.py --help
```
## Supported Models
These are the models we support. Links are to original code sources. If no link is provided, we implemented the architecture from scratch, following the specific paper.
| Architecture | Versions |
| :----------------------------------------------------------------------------------------------------- | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| [DeiT](https://github.com/facebookresearch/deit) | `deit_tiny_patch16_LS`, `deit_small_patch16_LS`, `deit_medium_patch16_LS`, `deit_base_patch16_LS`, `deit_large_patch16_LS`, `deit_huge_patch14_LS`, `deit_huge_patch14_52_LS`, `deit_huge_patch14_26x2_LS`, `deit_Giant_48_patch14_LS`, `deit_giant_40_patch14_LS`, `deit_small_patch16_36_LS`, `deit_small_patch16_36`, `deit_small_patch16_18x2_LS`, `deit_small_patch16_18x2`, `deit_base_patch16_18x2_LS`, `deit_base_patch16_18x2`, `deit_base_patch16_36x1_LS`, `deit_base_patch16_36x1` |
| [ResNet](https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/resnet.py) | `resnet18`, `resnet34`, `resnet26`, `resnet50`, `resnet101`, `wide_resnet50_2` |
| [Swin](https://github.com/microsoft/Swin-Transformer) | `swin_tiny_patch4_window7`, `swin_small_patch4_window7`, `swin_base_patch4_window7`, `swin_large_patch4_window7` |
| [ViT](https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py) | `ViT-{Ti,S,B,L}/<patch_size>` |
## License
We release this code under the [MIT license](./LICENSE).
## Citation
If you use this codebase in your project, please cite:

View File

@@ -0,0 +1,3 @@
albumentations==2.0.5
datasets==3.5.0
nvidia-dali-cuda120==1.47.0

View File

@@ -0,0 +1,238 @@
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
import torch
import torch.nn as nn
from functools import partial
from timm.models.vision_transformer import VisionTransformer, _cfg
from timm.models.registry import register_model
from timm.models.layers import trunc_normal_
from architectures.vit import TimmViT
# NOTE(review): these names carry "_224"/"_384" suffixes, but the factories
# registered in this module are named without them (e.g. ``deit_tiny_patch16``).
# Confirm whether __all__ is stale or the suffixed variants are defined elsewhere.
__all__ = [
    "deit_tiny_patch16_224",
    "deit_small_patch16_224",
    "deit_base_patch16_224",
    "deit_tiny_distilled_patch16_224",
    "deit_small_distilled_patch16_224",
    "deit_base_distilled_patch16_224",
    "deit_base_patch16_384",
    "deit_base_distilled_patch16_384",
]
class DistilledVisionTransformer(TimmViT):
    """DeiT-style ViT with an extra distillation token and a second head.

    A ``dist_token`` is added next to the class token, the positional embedding
    is widened by one slot, and a ``head_dist`` classifier is attached.  In
    training mode both head outputs are returned; at inference their mean is.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        num_patches = self.patch_embed.num_patches
        # one extra position each for the class token and the distillation token
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.embed_dim))
        self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()
        trunc_normal_(self.dist_token, std=0.02)
        trunc_normal_(self.pos_embed, std=0.02)
        self.head_dist.apply(self._init_weights)
    def forward_features(self, x):
        # based on timm's VisionTransformer.forward_features, extended with the dist_token
        batch = x.shape[0]
        x = self.patch_embed(x)
        cls_tok = self.cls_token.expand(batch, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        dist_tok = self.dist_token.expand(batch, -1, -1)
        x = torch.cat((cls_tok, dist_tok, x), dim=1)
        x = self.pos_drop(x + self.pos_embed)
        for block in self.blocks:
            x = block(x)
        x = self.norm(x)
        return x[:, 0], x[:, 1]
    def forward(self, x):
        cls_out, dist_out = self.forward_features(x)
        cls_out = self.head(cls_out)
        dist_out = self.head_dist(dist_out)
        if not self.training:
            # during inference, return the average of both classifier predictions
            return (cls_out + dist_out) / 2
        return cls_out, dist_out
def _clean_kwargs(kwargs):
allowed_keys = {key for key in kwargs.keys() if not key.startswith("pretrain")}
allowed_keys = {key for key in allowed_keys if not key.startswith("cache")}
return {key: kwargs[key] for key in allowed_keys}
@register_model
def deit_tiny_patch16(pretrained=False, img_size=224, drop_path_rate=0.1, num_classes=1000, drop_rate=0.0, **kwargs):
    """DeiT-Tiny/16 backbone; optionally loads the official pretrained weights.

    NOTE(review): ``kwargs`` is cleaned but never forwarded to ``TimmViT`` —
    extra keyword arguments are silently dropped; confirm this is intended.
    """
    kwargs = _clean_kwargs(kwargs)
    model = TimmViT(
        patch_size=16,
        embed_dim=192,
        depth=12,
        num_heads=3,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        img_size=img_size,
        drop_path_rate=drop_path_rate,
        num_classes=num_classes,
        drop_rate=drop_rate,
    )
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth",
            map_location="cpu",
            check_hash=True,
        )
        model.load_state_dict(checkpoint["model"])
    return model
@register_model
def deit_small_patch16(pretrained=False, img_size=224, drop_path_rate=0.1, num_classes=1000, drop_rate=0.0, **kwargs):
    """DeiT-Small/16 backbone; optionally loads the official pretrained weights.

    NOTE(review): ``kwargs`` is cleaned but never forwarded to ``TimmViT`` —
    extra keyword arguments are silently dropped; confirm this is intended.
    """
    kwargs = _clean_kwargs(kwargs)
    model = TimmViT(
        patch_size=16,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        img_size=img_size,
        drop_path_rate=drop_path_rate,
        num_classes=num_classes,
        drop_rate=drop_rate,
    )
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth",
            map_location="cpu",
            check_hash=True,
        )
        model.load_state_dict(checkpoint["model"])
    return model
@register_model
def deit_base_patch16(pretrained=False, img_size=224, drop_path_rate=0.1, num_classes=1000, drop_rate=0.0, **kwargs):
    """DeiT-Base/16 backbone; optionally loads the official pretrained weights.

    NOTE(review): ``kwargs`` is cleaned but never forwarded to ``TimmViT`` —
    extra keyword arguments are silently dropped; confirm this is intended.
    """
    kwargs = _clean_kwargs(kwargs)
    model = TimmViT(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        img_size=img_size,
        drop_path_rate=drop_path_rate,
        num_classes=num_classes,
        drop_rate=drop_rate,
    )
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
            map_location="cpu",
            check_hash=True,
        )
        model.load_state_dict(checkpoint["model"])
    return model
@register_model
def deit_tiny_distilled_patch16(
    pretrained=False, img_size=224, drop_path_rate=0.1, num_classes=1000, drop_rate=0.0, **kwargs
):
    """Distilled DeiT-Tiny/16 (extra distillation token/head); optionally loads official weights.

    NOTE(review): ``kwargs`` is cleaned but never forwarded to the model —
    extra keyword arguments are silently dropped; confirm this is intended.
    """
    kwargs = _clean_kwargs(kwargs)
    model = DistilledVisionTransformer(
        patch_size=16,
        embed_dim=192,
        depth=12,
        num_heads=3,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        img_size=img_size,
        drop_path_rate=drop_path_rate,
        num_classes=num_classes,
        drop_rate=drop_rate,
    )
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth",
            map_location="cpu",
            check_hash=True,
        )
        model.load_state_dict(checkpoint["model"])
    return model
@register_model
def deit_small_distilled_patch16(
    pretrained=False, img_size=224, drop_path_rate=0.1, num_classes=1000, drop_rate=0.0, **kwargs
):
    """Distilled DeiT-Small/16 (extra distillation token/head); optionally loads official weights.

    NOTE(review): ``kwargs`` is cleaned but never forwarded to the model —
    extra keyword arguments are silently dropped; confirm this is intended.
    """
    kwargs = _clean_kwargs(kwargs)
    model = DistilledVisionTransformer(
        patch_size=16,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        img_size=img_size,
        drop_path_rate=drop_path_rate,
        num_classes=num_classes,
        drop_rate=drop_rate,
    )
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth",
            map_location="cpu",
            check_hash=True,
        )
        model.load_state_dict(checkpoint["model"])
    return model
@register_model
def deit_base_distilled_patch16(
    pretrained=False, img_size=224, drop_path_rate=0.1, num_classes=1000, drop_rate=0.0, **kwargs
):
    """Distilled DeiT-Base/16 (extra distillation token/head); optionally loads official weights.

    NOTE(review): ``kwargs`` is cleaned but never forwarded to the model —
    extra keyword arguments are silently dropped; confirm this is intended.
    """
    kwargs = _clean_kwargs(kwargs)
    model = DistilledVisionTransformer(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        img_size=img_size,
        drop_path_rate=drop_path_rate,
        num_classes=num_classes,
        drop_rate=drop_rate,
    )
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth",
            map_location="cpu",
            check_hash=True,
        )
        model.load_state_dict(checkpoint["model"])
    return model

View File

@@ -0,0 +1,850 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Taken from https://github.com/facebookresearch/deit with slight modifications
from loguru import logger
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
from timm.models.vision_transformer import Mlp, PatchEmbed, _cfg
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from timm.models.registry import register_model
from resizing_interface import ResizingInterface
class Attention(nn.Module):
    """Standard multi-head self-attention (timm-style) over (B, N, C) token tensors."""
    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # default scaling is 1/sqrt(head_dim) unless an explicit qk_scale is given
        self.scale = qk_scale or head_dim**-0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
    def forward(self, x):
        batch, tokens, channels = x.shape
        # project to q/k/v and split heads: (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, channels // self.num_heads)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)
        scores = (q * self.scale) @ k.transpose(-2, -1)
        weights = self.attn_drop(scores.softmax(dim=-1))
        out = (weights @ v).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(out))
class Block(nn.Module):
    """Pre-norm transformer block: attention and MLP, each on a residual path.

    ``init_values`` is accepted for signature compatibility with the LayerScale
    variants but is unused here.
    """
    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        Attention_block=Attention,
        Mlp_block=Mlp,
        init_values=1e-4,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention_block(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # stochastic depth on both residual branches (identity when drop_path == 0)
        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        hidden = int(dim * mlp_ratio)
        self.mlp = Mlp_block(
            in_features=dim,
            hidden_features=hidden,
            act_layer=act_layer,
            drop=drop,
        )
    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        return x + self.drop_path(self.mlp(self.norm2(x)))
class Layer_scale_init_Block(nn.Module):
    """Pre-norm transformer block with LayerScale (learnable per-channel residual gains)."""
    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        Attention_block=Attention,
        Mlp_block=Mlp,
        init_values=1e-4,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention_block(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # stochastic depth on both residual branches (identity when drop_path == 0)
        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        hidden = int(dim * mlp_ratio)
        self.mlp = Mlp_block(
            in_features=dim,
            hidden_features=hidden,
            act_layer=act_layer,
            drop=drop,
        )
        # LayerScale: small learnable gains on each residual branch
        self.gamma_1 = nn.Parameter(init_values * torch.ones(dim), requires_grad=True)
        self.gamma_2 = nn.Parameter(init_values * torch.ones(dim), requires_grad=True)
    def forward(self, x):
        x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
        return x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
class Layer_scale_init_Block_paralx2(nn.Module):
    """LayerScale block with two parallel attention branches and two parallel MLP branches.

    Submodules are created in the same order as the original implementation so
    parameter registration (state_dict keys) and RNG-based init are unchanged.
    """
    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        Attention_block=Attention,
        Mlp_block=Mlp,
        init_values=1e-4,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.norm11 = norm_layer(dim)
        attn_kwargs = dict(
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.attn = Attention_block(dim, **attn_kwargs)
        self.attn1 = Attention_block(dim, **attn_kwargs)
        # stochastic depth on every residual branch (identity when drop_path == 0)
        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.norm21 = norm_layer(dim)
        hidden = int(dim * mlp_ratio)
        mlp_kwargs = dict(in_features=dim, hidden_features=hidden, act_layer=act_layer, drop=drop)
        self.mlp = Mlp_block(**mlp_kwargs)
        self.mlp1 = Mlp_block(**mlp_kwargs)
        # LayerScale gains, one per residual branch
        self.gamma_1 = nn.Parameter(init_values * torch.ones(dim), requires_grad=True)
        self.gamma_1_1 = nn.Parameter(init_values * torch.ones(dim), requires_grad=True)
        self.gamma_2 = nn.Parameter(init_values * torch.ones(dim), requires_grad=True)
        self.gamma_2_1 = nn.Parameter(init_values * torch.ones(dim), requires_grad=True)
    def forward(self, x):
        attn_a = self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
        attn_b = self.drop_path(self.gamma_1_1 * self.attn1(self.norm11(x)))
        x = x + attn_a + attn_b
        mlp_a = self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
        mlp_b = self.drop_path(self.gamma_2_1 * self.mlp1(self.norm21(x)))
        return x + mlp_a + mlp_b
class Block_paralx2(nn.Module):
    """Transformer block with two parallel attention branches and two parallel MLP branches (no LayerScale).

    ``init_values`` is accepted for signature compatibility but unused here.
    Submodule creation order matches the original implementation.
    """
    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        Attention_block=Attention,
        Mlp_block=Mlp,
        init_values=1e-4,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.norm11 = norm_layer(dim)
        attn_kwargs = dict(
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.attn = Attention_block(dim, **attn_kwargs)
        self.attn1 = Attention_block(dim, **attn_kwargs)
        # stochastic depth on every residual branch (identity when drop_path == 0)
        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.norm21 = norm_layer(dim)
        hidden = int(dim * mlp_ratio)
        mlp_kwargs = dict(in_features=dim, hidden_features=hidden, act_layer=act_layer, drop=drop)
        self.mlp = Mlp_block(**mlp_kwargs)
        self.mlp1 = Mlp_block(**mlp_kwargs)
    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x))) + self.drop_path(self.attn1(self.norm11(x)))
        return x + self.drop_path(self.mlp(self.norm2(x))) + self.drop_path(self.mlp1(self.norm21(x)))
class hMLP_stem(nn.Module):
    """hMLP patch-embedding stem (https://arxiv.org/pdf/2203.09795.pdf).

    Three strided convolutions (4x4, 2x2, 2x2) with normalization and GELU in
    between embed an image into a sequence of patch tokens.
    """
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        norm_layer=nn.SyncBatchNorm,
    ):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        quarter = embed_dim // 4
        self.proj = nn.Sequential(
            nn.Conv2d(in_chans, quarter, kernel_size=4, stride=4),
            norm_layer(quarter),
            nn.GELU(),
            nn.Conv2d(quarter, quarter, kernel_size=2, stride=2),
            norm_layer(quarter),
            nn.GELU(),
            nn.Conv2d(quarter, embed_dim, kernel_size=2, stride=2),
            norm_layer(embed_dim),
        )
    def forward(self, x):
        # 4-D unpack kept on purpose: it validates the (B, C, H, W) input rank
        batch, channels, height, width = x.shape
        return self.proj(x).flatten(2).transpose(1, 2)
class vit_models(nn.Module, ResizingInterface):
    """Vision Transformer with LayerScale (https://arxiv.org/abs/2103.17239) support
    taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
    with slight modifications
    """
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        num_classes=1000,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        norm_layer=nn.LayerNorm,
        global_pool=None,
        block_layers=Block,
        Patch_layer=PatchEmbed,
        act_layer=nn.GELU,
        Attention_block=Attention,
        Mlp_block=Mlp,
        dpr_constant=True,
        init_scale=1e-4,
        mlp_ratio_clstk=4.0,
        **kwargs,
    ):
        # NOTE(review): global_pool, dpr_constant, mlp_ratio_clstk and **kwargs are
        # accepted but never used below — confirm they are interface-compat only.
        super().__init__()
        self.dropout_rate = drop_rate
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim
        self.img_size = img_size
        self.patch_embed = Patch_layer(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
        num_patches = self.patch_embed.num_patches
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.embed_layer = Patch_layer
        self.pre_norm = False
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # positional embedding covers patch tokens only; the class token is
        # concatenated after positions are added (see forward_features)
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
        self.no_embed_class = True
        # constant (non-decaying) stochastic-depth rate for every block
        dpr = [drop_path_rate for i in range(depth)]
        self.blocks = nn.ModuleList(
            [
                block_layers(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=0.0,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[i],
                    norm_layer=norm_layer,
                    act_layer=act_layer,
                    Attention_block=Attention_block,
                    Mlp_block=Mlp_block,
                    init_values=init_scale,
                )
                for i in range(depth)
            ]
        )
        self.norm = norm_layer(embed_dim)
        self.feature_info = [dict(num_chs=embed_dim, reduction=0, module="head")]
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        trunc_normal_(self.pos_embed, std=0.02)
        trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init_weights)
    def set_num_classes(self, n_classes):
        """Swap the classifier for ``n_classes`` outputs (via ResizingInterface), then re-init it."""
        super().set_num_classes(n_classes)
        self._init_weights(self.head)
    def _init_weights(self, m):
        """timm-style init: trunc-normal linear weights, zero biases, unit LayerNorm."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
    @torch.jit.ignore
    def no_weight_decay(self):
        """Parameter names excluded from weight decay."""
        return {"pos_embed", "cls_token"}
    def get_classifier(self):
        """Return the classification head module."""
        return self.head
    def get_num_layers(self):
        """Number of transformer blocks."""
        return len(self.blocks)
    def reset_classifier(self, num_classes, global_pool=""):
        """Replace the head with a fresh linear layer (identity when ``num_classes == 0``)."""
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
    def forward_features(self, x, test=False):
        """Embed patches, add positions, prepend [CLS], run all blocks; return the [CLS] feature.

        With ``test=True``, NaNs are checked after every stage and logged.
        """
        B = x.shape[0]
        x = self.patch_embed(x)
        if test and x.isnan().any().item():
            logger.error("patch embedded input has nan value")
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = x + self.pos_embed
        if test and x.isnan().any().item():
            logger.error("position embedded input has a nan value")
        x = torch.cat((cls_tokens, x), dim=1)
        if test and x.isnan().any().item():
            logger.error("input with [CLS] has a nan value")
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if test and x.isnan().any().item():
                logger.error(f"output of block {i} has a nan value")
        x = self.norm(x)
        return x[:, 0]
    def forward(self, x, test=False):
        """[CLS] feature -> (optional dropout) -> classifier logits."""
        x = self.forward_features(x, test=test)
        if self.dropout_rate:
            x = F.dropout(x, p=float(self.dropout_rate), training=self.training)
        x = self.head(x)
        return x
# DeiT III: Revenge of the ViT (https://arxiv.org/abs/2204.07118)
@register_model
def deit_tiny_patch16_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """DeiT-III Tiny/16 with LayerScale blocks.

    NOTE(review): unlike the other *_LS factories, ``pretrained`` and
    ``pretrained_21k`` are accepted but ignored — no checkpoint is loaded here.
    Confirm whether official tiny weights simply do not exist.
    """
    model = vit_models(
        img_size=img_size,
        patch_size=16,
        embed_dim=192,
        depth=12,
        num_heads=3,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
        **kwargs,
    )
    return model
@register_model
def deit_small_patch16_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """DeiT-III Small/16 with LayerScale; optionally loads the official checkpoint.

    ``pretrained_21k`` selects the ImageNet-21k weights; otherwise the 1k ones.
    """
    # a stray pretrained_cfg from timm's factory machinery is not a vit_models kwarg
    kwargs.pop("pretrained_cfg", None)
    model = vit_models(
        img_size=img_size,
        patch_size=16,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
        **kwargs,
    )
    model.default_cfg = _cfg()
    if pretrained:
        url = "https://dl.fbaipublicfiles.com/deit/deit_3_small_" + str(img_size) + "_"
        url += "21k.pth" if pretrained_21k else "1k.pth"
        state = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
        model.load_state_dict(state["model"])
    return model
@register_model
def deit_medium_patch16_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """DeiT-III Medium/16 with LayerScale; optionally loads the official checkpoint.

    Args:
        pretrained: download and load the official DeiT-III weights.
        img_size: input resolution the model (and checkpoint URL) is built for.
        pretrained_21k: load the ImageNet-21k variant instead of the 1k one.
        **kwargs: forwarded to ``vit_models``.
    """
    model = vit_models(
        # BUG FIX: img_size was previously not forwarded, so the model was always
        # built at 224 even though the checkpoint URL below varies with img_size
        # (every sibling *_LS factory forwards it).
        img_size=img_size,
        patch_size=16,
        embed_dim=512,
        depth=12,
        num_heads=8,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
        **kwargs,
    )
    model.default_cfg = _cfg()
    if pretrained:
        name = "https://dl.fbaipublicfiles.com/deit/deit_3_medium_" + str(img_size) + "_"
        if pretrained_21k:
            name += "21k.pth"
        else:
            name += "1k.pth"
        checkpoint = torch.hub.load_state_dict_from_url(url=name, map_location="cpu", check_hash=True)
        model.load_state_dict(checkpoint["model"])
    return model
@register_model
def deit_base_patch16_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """DeiT III base (patch 16, embed dim 768, depth 12, 12 heads, LayerScale).

    Args:
        pretrained (bool): If True, load DeiT III weights from fbaipublicfiles.
        img_size (int): Input resolution; also selects the checkpoint file.
        pretrained_21k (bool): Load the ImageNet-21k checkpoint instead of the 1k one.
        **kwargs: Extra arguments forwarded to ``vit_models``.
    """
    # timm's create_model can inject "pretrained_cfg"; vit_models does not accept it
    # (same guard as deit_small_patch16_LS).
    kwargs.pop("pretrained_cfg", None)
    model = vit_models(
        img_size=img_size,
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
        **kwargs,
    )
    # set default_cfg for consistency with the small/medium factories
    model.default_cfg = _cfg()
    if pretrained:
        variant = "21k" if pretrained_21k else "1k"
        url = f"https://dl.fbaipublicfiles.com/deit/deit_3_base_{img_size}_{variant}.pth"
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
        model.load_state_dict(checkpoint["model"])
    return model
@register_model
def deit_large_patch16_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """DeiT III large (patch 16, embed dim 1024, depth 24, 16 heads, LayerScale).

    Args:
        pretrained (bool): If True, load DeiT III weights from fbaipublicfiles.
        img_size (int): Input resolution; also selects the checkpoint file.
        pretrained_21k (bool): Load the ImageNet-21k checkpoint instead of the 1k one.
        **kwargs: Extra arguments forwarded to ``vit_models``.
    """
    # timm's create_model can inject "pretrained_cfg"; vit_models does not accept it
    # (same guard as deit_small_patch16_LS).
    kwargs.pop("pretrained_cfg", None)
    model = vit_models(
        img_size=img_size,
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
        **kwargs,
    )
    # set default_cfg for consistency with the small/medium factories
    model.default_cfg = _cfg()
    if pretrained:
        variant = "21k" if pretrained_21k else "1k"
        url = f"https://dl.fbaipublicfiles.com/deit/deit_3_large_{img_size}_{variant}.pth"
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
        model.load_state_dict(checkpoint["model"])
    return model
@register_model
def deit_huge_patch14_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """DeiT III huge (patch 14, embed dim 1280, depth 32, 16 heads, LayerScale).

    Args:
        pretrained (bool): If True, load DeiT III weights from fbaipublicfiles.
        img_size (int): Input resolution; also selects the checkpoint file.
        pretrained_21k (bool): Load the ImageNet-21k checkpoint instead of the 1k one.
        **kwargs: Extra arguments forwarded to ``vit_models``.
    """
    # timm's create_model can inject "pretrained_cfg"; vit_models does not accept it
    # (same guard as deit_small_patch16_LS).
    kwargs.pop("pretrained_cfg", None)
    model = vit_models(
        img_size=img_size,
        patch_size=14,
        embed_dim=1280,
        depth=32,
        num_heads=16,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
        **kwargs,
    )
    # set default_cfg for consistency with the small/medium factories
    model.default_cfg = _cfg()
    if pretrained:
        # note: the huge checkpoints carry a "_v1" suffix upstream
        variant = "21k" if pretrained_21k else "1k"
        url = f"https://dl.fbaipublicfiles.com/deit/deit_3_huge_{img_size}_{variant}_v1.pth"
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
        model.load_state_dict(checkpoint["model"])
    return model
@register_model
def deit_huge_patch14_52_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """DeiT III huge variant with depth 52 (patch 14, embed dim 1280, 16 heads, LayerScale)."""
    config = dict(
        img_size=img_size,
        patch_size=14,
        embed_dim=1280,
        depth=52,
        num_heads=16,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
    )
    return vit_models(**config, **kwargs)
@register_model
def deit_huge_patch14_26x2_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """DeiT III huge 26x2: depth 26 with parallel (x2) LayerScale blocks, patch 14."""
    config = dict(
        img_size=img_size,
        patch_size=14,
        embed_dim=1280,
        depth=26,
        num_heads=16,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block_paralx2,
    )
    return vit_models(**config, **kwargs)
# @register_model
# def deit_Giant_48x2_patch14_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
# model = vit_models(
# img_size=img_size, patch_size=14, embed_dim=1664, depth=48, num_heads=16, mlp_ratio=4,
# norm_layer=partial(nn.LayerNorm, eps=1e-6), block_layers=Block_paral_LS, **kwargs)
#
# return model
# @register_model
# def deit_giant_40x2_patch14_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
# model = vit_models(
# img_size=img_size, patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=4,
# norm_layer=partial(nn.LayerNorm, eps=1e-6), block_layers=Block_paral_LS, **kwargs)
# return model
@register_model
def deit_Giant_48_patch14_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """DeiT III Giant: patch 14, embed dim 1664, depth 48, 16 heads, LayerScale blocks."""
    config = dict(
        img_size=img_size,
        patch_size=14,
        embed_dim=1664,
        depth=48,
        num_heads=16,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
    )
    return vit_models(**config, **kwargs)
@register_model
def deit_giant_40_patch14_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """DeiT III giant: patch 14, embed dim 1408, depth 40, 16 heads, LayerScale blocks."""
    config = dict(
        img_size=img_size,
        patch_size=14,
        embed_dim=1408,
        depth=40,
        num_heads=16,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
    )
    return vit_models(**config, **kwargs)
# Models from Three things everyone should know about Vision Transformers (https://arxiv.org/pdf/2203.09795.pdf)
@register_model
def deit_small_patch16_36_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """Deep DeiT small: patch 16, embed dim 384, depth 36, 6 heads, LayerScale blocks."""
    config = dict(
        img_size=img_size,
        patch_size=16,
        embed_dim=384,
        depth=36,
        num_heads=6,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
    )
    return vit_models(**config, **kwargs)
@register_model
def deit_small_patch16_36(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """Deep DeiT small without LayerScale: patch 16, embed dim 384, depth 36, 6 heads."""
    config = dict(
        img_size=img_size,
        patch_size=16,
        embed_dim=384,
        depth=36,
        num_heads=6,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    return vit_models(**config, **kwargs)
@register_model
def deit_small_patch16_18x2_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """DeiT small 18x2: depth 18 with parallel (x2) LayerScale blocks, patch 16."""
    config = dict(
        img_size=img_size,
        patch_size=16,
        embed_dim=384,
        depth=18,
        num_heads=6,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block_paralx2,
    )
    return vit_models(**config, **kwargs)
@register_model
def deit_small_patch16_18x2(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """DeiT small 18x2 without LayerScale: depth 18 with parallel (x2) blocks, patch 16."""
    config = dict(
        img_size=img_size,
        patch_size=16,
        embed_dim=384,
        depth=18,
        num_heads=6,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Block_paralx2,
    )
    return vit_models(**config, **kwargs)
@register_model
def deit_base_patch16_18x2_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """DeiT base 18x2: depth 18 with parallel (x2) LayerScale blocks, patch 16."""
    config = dict(
        img_size=img_size,
        patch_size=16,
        embed_dim=768,
        depth=18,
        num_heads=12,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block_paralx2,
    )
    return vit_models(**config, **kwargs)
@register_model
def deit_base_patch16_18x2(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """DeiT base 18x2 without LayerScale: depth 18 with parallel (x2) blocks, patch 16."""
    config = dict(
        img_size=img_size,
        patch_size=16,
        embed_dim=768,
        depth=18,
        num_heads=12,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Block_paralx2,
    )
    return vit_models(**config, **kwargs)
@register_model
def deit_base_patch16_36x1_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """Deep DeiT base: patch 16, embed dim 768, depth 36, 12 heads, LayerScale blocks."""
    config = dict(
        img_size=img_size,
        patch_size=16,
        embed_dim=768,
        depth=36,
        num_heads=12,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
    )
    return vit_models(**config, **kwargs)
@register_model
def deit_base_patch16_36x1(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """Deep DeiT base without LayerScale: patch 16, embed dim 768, depth 36, 12 heads."""
    config = dict(
        img_size=img_size,
        patch_size=16,
        embed_dim=768,
        depth=36,
        num_heads=12,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    return vit_models(**config, **kwargs)

View File

@@ -0,0 +1,111 @@
from timm.models import register_model
from timm.models.resnet import BasicBlock, Bottleneck
from timm.models.resnet import ResNet as ResNetTimm
from torch import nn
from resizing_interface import ResizingInterface
class ResNet(ResNetTimm, ResizingInterface):
    """Timm's ResNet combined with the project's ``ResizingInterface``."""

    def __init__(self, *args, global_pool="avg", **kwargs):
        """Build the ResNet, silently discarding kwargs timm's ResNet does not accept.

        Args:
            *args: Positional arguments passed through to the timm ResNet.
            global_pool (str, optional): Global pooling mode. Defaults to "avg".
            **kwargs: Keyword arguments for the timm ResNet; unknown keys are dropped.
        """
        # only forward the keyword arguments timm's ResNet constructor accepts
        allowed = {
            "block",
            "layers",
            "num_classes",
            "in_chans",
            "output_stride",
            "cardinality",
            "base_width",
            "stem_width",
            "stem_type",
            "replace_stem_pool",
            "block_reduce_first",
            "down_kernel_size",
            "avg_down",
            "act_layer",
            "norm_layer",
            "aa_layer",
            "drop_rate",
            "drop_path_rate",
            "drop_block_rate",
            "zero_init_last",
            "block_args",
        }
        filtered = {key: value for key, value in kwargs.items() if key in allowed}
        super().__init__(*args, global_pool=global_pool, **filtered)
        self.global_pool_str = global_pool

    def set_image_res(self, res):
        """No-op: CNNs with global pooling need no changes for a new resolution."""
        return

    def set_num_classes(self, n_classes):
        """Swap the classification head for one with ``n_classes`` outputs."""
        if self.num_classes == n_classes:
            return
        if n_classes <= 0:
            self.fc = nn.Identity()
        else:
            self.reset_classifier(num_classes=n_classes, global_pool=self.global_pool_str)
@register_model
def resnet18(pretrained=False, **kwargs):
    """Construct a ResNet-18 model (BasicBlock, layers [2, 2, 2, 2])."""
    return ResNet(block=BasicBlock, layers=[2, 2, 2, 2], **kwargs)
@register_model
def resnet34(pretrained=False, **kwargs):
    """Construct a ResNet-34 model (BasicBlock, layers [3, 4, 6, 3])."""
    return ResNet(block=BasicBlock, layers=[3, 4, 6, 3], **kwargs)
@register_model
def resnet26(pretrained=False, **kwargs):
    """Construct a ResNet-26 model (Bottleneck, layers [2, 2, 2, 2])."""
    return ResNet(block=Bottleneck, layers=[2, 2, 2, 2], **kwargs)
@register_model
def resnet50(pretrained=False, **kwargs):
    """Construct a ResNet-50 model (Bottleneck, layers [3, 4, 6, 3])."""
    return ResNet(block=Bottleneck, layers=[3, 4, 6, 3], **kwargs)
@register_model
def resnet101(pretrained=False, **kwargs):
    """Construct a ResNet-101 model (Bottleneck, layers [3, 4, 23, 3])."""
    return ResNet(block=Bottleneck, layers=[3, 4, 23, 3], **kwargs)
@register_model
def wide_resnet50_2(pretrained=False, **kwargs):
    """Construct a Wide ResNet-50-2 model.

    Identical to ResNet-50 except the bottleneck channel count is doubled in
    every block (base_width 128 instead of 64). The outer 1x1 convolutions keep
    their widths, e.g. the last block is 2048-1024-2048 instead of 2048-512-2048.
    """
    return ResNet(block=Bottleneck, layers=[3, 4, 6, 3], base_width=128, **kwargs)

View File

@@ -0,0 +1,893 @@
# Taken from https://github.com/microsoft/Swin-Transformer with slight modifications
# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu
# --------------------------------------------------------
from copy import copy
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from loguru import logger
from timm.models import register_model
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from resizing_interface import ResizingInterface
# Optionally load the compiled fused window-shift/partition kernels; when the
# extension is missing, the sentinels stay None and callers fall back to the
# pure-PyTorch path (see SwinTransformerBlock.forward).
try:
    import os
    import sys
    # the kernel package is expected one directory above the working directory
    kernel_path = os.path.abspath(os.path.join(".."))
    sys.path.append(kernel_path)
    from kernels.window_process.window_process import (
        WindowProcess,
        WindowProcessReverse,
    )
except ImportError:
    # sentinels: fused path must not be enabled when these are None
    WindowProcess = None
    WindowProcessReverse = None
    logger.warning("Fused window process have not been installed. Please refer to get_started.md for installation.")
class Mlp(nn.Module):
    """Two-layer feed-forward network: Linear -> activation -> dropout -> Linear -> dropout."""

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        """Build the MLP; hidden/output widths default to ``in_features`` when falsy."""
        super().__init__()
        out_features = in_features if not out_features else out_features
        hidden_features = in_features if not hidden_features else hidden_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Apply the feed-forward block to ``x`` (last dim = in_features)."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
def window_partition(x, window_size):
    """Split a feature map into non-overlapping square windows.

    Args:
        x: Tensor of shape (B, H, W, C); H and W must be divisible by window_size.
        window_size (int): Side length of each window.

    Returns:
        windows: Tensor of shape (num_windows*B, window_size, window_size, C).
    """
    batch, height, width, channels = x.shape
    grid = x.view(
        batch,
        height // window_size,
        window_size,
        width // window_size,
        window_size,
        channels,
    )
    # bring the two window axes together, then flatten windows into the batch dim
    return grid.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, channels)
def window_reverse(windows, window_size, H, W):
    """Reassemble windows produced by ``window_partition`` into a feature map.

    Args:
        windows: Tensor of shape (num_windows*B, window_size, window_size, C).
        window_size (int): Window size.
        H (int): Height of image.
        W (int): Width of image.

    Returns:
        x: Tensor of shape (B, H, W, C).
    """
    windows_per_image = (H // window_size) * (W // window_size)
    batch = windows.shape[0] // windows_per_image
    grid = windows.view(batch, H // window_size, W // window_size, window_size, window_size, -1)
    # undo the permutation applied by window_partition and flatten back to (B, H, W, C)
    return grid.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch, H, W, -1)
class WindowAttention(nn.Module):
    r"""Window based multi-head self attention (W-MSA) module with relative position bias.

    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(
        self,
        dim,
        window_size,
        num_heads,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5
        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
        )  # 2*Wh-1 * 2*Ww-1, nH
        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        # indexing="ij" makes the historically-implicit behavior explicit and
        # silences the torch.meshgrid deprecation warning (identical values).
        coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        trunc_normal_(self.relative_position_bias_table, std=0.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = (
            qkv[0],
            qkv[1],
            qkv[2],
        )  # make torchscript happy (cannot use tensor as tuple)
        q = q * self.scale
        attn = q @ k.transpose(-2, -1)
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1],
            self.window_size[0] * self.window_size[1],
            -1,
        )  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)
        if mask is not None:
            # broadcast the per-window mask over batch and heads
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
        # softmax applied once for both masked and unmasked paths
        # (was duplicated across the if/else branches)
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}"

    def flops(self, N):
        """Count FLOPs for one window of ``N`` tokens (qkv, attention, projection)."""
        flops = 0
        # qkv = self.qkv(x)
        flops += N * self.dim * 3 * self.dim
        # attn = (q @ k.transpose(-2, -1))
        flops += self.num_heads * N * (self.dim // self.num_heads) * N
        # x = (attn @ v)
        flops += self.num_heads * N * N * (self.dim // self.num_heads)
        # x = self.proj(x)
        flops += N * self.dim * self.dim
        return flops
class SwinTransformerBlock(nn.Module):
    r"""Swin Transformer Block.

    One block computes LN -> W-MSA/SW-MSA -> residual, then LN -> MLP -> residual,
    on (optionally cyclically shifted) local windows. The attention mask for the
    shifted variant is precomputed in ``__init__`` and registered as a buffer.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        fused_window_process (bool, optional): If True, use one kernel to fused window shift & window partition for acceleration, similar for the reversed part. Default: False
    """

    def __init__(
        self,
        dim,
        input_resolution,
        num_heads,
        window_size=7,
        shift_size=0,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        fused_window_process=False,
    ):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim,
            window_size=to_2tuple(self.window_size),
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )
        if self.shift_size > 0:
            # calculate attention mask for SW-MSA:
            # assign a distinct region id to each shifted sub-area; after the
            # cyclic shift, tokens from different regions must not attend to
            # each other, so their pairwise entries are masked below.
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
            h_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None),
            )
            w_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None),
            )
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1
            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            # non-zero region-id difference -> -100 added pre-softmax (effectively -inf)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None
        self.register_buffer("attn_mask", attn_mask)
        self.fused_window_process = fused_window_process

    def forward(self, x):
        """Apply (shifted-)window attention and the MLP to ``x`` of shape (B, H*W, C)."""
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)
        # cyclic shift
        if self.shift_size > 0:
            if not self.fused_window_process:
                shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
                # partition windows
                x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
            else:
                # fused kernel performs shift + partition in one call
                x_windows = WindowProcess.apply(x, B, H, W, C, -self.shift_size, self.window_size)
        else:
            shifted_x = x
            # partition windows
            x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C
        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C
        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        # reverse cyclic shift
        if self.shift_size > 0:
            if not self.fused_window_process:
                shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C
                x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
            else:
                x = WindowProcessReverse.apply(attn_windows, B, H, W, C, self.shift_size, self.window_size)
        else:
            shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C
            x = shifted_x
        x = x.view(B, H * W, C)
        x = shortcut + self.drop_path(x)
        # FFN
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

    def extra_repr(self) -> str:
        return (
            f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, "
            f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
        )

    def flops(self):
        """Estimate FLOPs for one forward pass at the configured input resolution."""
        flops = 0
        H, W = self.input_resolution
        # norm1
        flops += self.dim * H * W
        # W-MSA/SW-MSA
        nW = H * W / self.window_size / self.window_size
        flops += nW * self.attn.flops(self.window_size * self.window_size)
        # mlp
        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
        # norm2
        flops += self.dim * H * W
        return flops
class PatchMerging(nn.Module):
    r"""Merge each 2x2 patch neighborhood: halves the resolution, doubles the channels.

    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """Merge patches; x: (B, H*W, C) -> (B, H/2*W/2, 2C)."""
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
        grid = x.view(B, H, W, C)
        # the four interleaved sub-grids, each (B, H/2, W/2, C), concatenated on channels
        corners = [grid[:, dh::2, dw::2, :] for dh, dw in ((0, 0), (1, 0), (0, 1), (1, 1))]
        merged = torch.cat(corners, -1).view(B, -1, 4 * C)  # B H/2*W/2 4*C
        return self.reduction(self.norm(merged))

    def extra_repr(self) -> str:
        return f"input_resolution={self.input_resolution}, dim={self.dim}"

    def flops(self):
        """Count FLOPs for norm (H*W*C) plus the 4C -> 2C linear reduction."""
        H, W = self.input_resolution
        return H * W * self.dim + (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
class BasicLayer(nn.Module):
    """One Swin Transformer stage: a stack of blocks plus an optional downsample.

    Every second block uses shifted windows (SW-MSA); the rest use regular
    windows (W-MSA).

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
        fused_window_process (bool, optional): If True, use one kernel to fused window shift & window partition for acceleration, similar for the reversed part. Default: False
    """

    def __init__(
        self,
        dim,
        input_resolution,
        depth,
        num_heads,
        window_size,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        norm_layer=nn.LayerNorm,
        downsample=None,
        use_checkpoint=False,
        fused_window_process=False,
    ):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint
        # build blocks; odd-indexed blocks shift their windows by half the window size
        blocks = []
        for block_idx in range(depth):
            block_drop_path = drop_path[block_idx] if isinstance(drop_path, list) else drop_path
            blocks.append(
                SwinTransformerBlock(
                    dim=dim,
                    input_resolution=input_resolution,
                    num_heads=num_heads,
                    window_size=window_size,
                    shift_size=window_size // 2 if block_idx % 2 else 0,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop,
                    attn_drop=attn_drop,
                    drop_path=block_drop_path,
                    norm_layer=norm_layer,
                    fused_window_process=fused_window_process,
                )
            )
        self.blocks = nn.ModuleList(blocks)
        # patch merging layer (None in the last stage)
        if downsample is None:
            self.downsample = None
        else:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)

    def forward(self, x):
        """Run every block (optionally checkpointed), then the downsample if present."""
        for block in self.blocks:
            x = checkpoint.checkpoint(block, x) if self.use_checkpoint else block(x)
        return x if self.downsample is None else self.downsample(x)

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"

    def flops(self):
        """Sum the FLOPs of all blocks plus the downsample layer, if any."""
        total = sum(block.flops() for block in self.blocks)
        if self.downsample is not None:
            total += self.downsample.flops()
        return total
class PatchEmbed(nn.Module):
    r"""Image to patch embedding via a strided convolution.

    Args:
        img_size (int): Image size. Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        grid = [
            img_size[0] // patch_size[0],
            img_size[1] // patch_size[1],
        ]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = grid
        self.num_patches = grid[0] * grid[1]
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        # one conv with stride == kernel == patch size extracts and projects patches
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer is not None else None

    def forward(self, x):
        """Embed images (B, C, H, W) into patch tokens (B, Ph*Pw, embed_dim)."""
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert (
            H == self.img_size[0] and W == self.img_size[1]
        ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        tokens = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
        if self.norm is not None:
            tokens = self.norm(tokens)
        return tokens

    def flops(self):
        """Count FLOPs of the patch projection plus the optional norm."""
        Ho, Wo = self.patches_resolution
        total = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        if self.norm is not None:
            total += Ho * Wo * self.embed_dim
        return total
class SwinTransformer(nn.Module, ResizingInterface):
r"""Swin Transformer
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
https://arxiv.org/pdf/2103.14030
Args:
img_size (int | tuple(int)): Input image size. Default 224
patch_size (int | tuple(int)): Patch size. Default: 4
in_chans (int): Number of input image channels. Default: 3
num_classes (int): Number of classes for classification head. Default: 1000
embed_dim (int): Patch embedding dimension. Default: 96
depths (tuple(int)): Depth of each Swin Transformer layer.
num_heads (tuple(int)): Number of attention heads in different layers.
window_size (int): Window size. Default: 7
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
drop_rate (float): Dropout rate. Default: 0
attn_drop_rate (float): Attention dropout rate. Default: 0
drop_path_rate (float): Stochastic depth rate. Default: 0.1
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
patch_norm (bool): If True, add normalization after patch embedding. Default: True
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
fused_window_process (bool, optional): If True, use one kernel to fused window shift & window partition for acceleration, similar for the reversed part. Default: False
"""
def __init__(
self,
img_size=224,
patch_size=4,
in_chans=3,
num_classes=1000,
embed_dim=96,
depths=[2, 2, 6, 2],
num_heads=[3, 6, 12, 24],
window_size=7,
mlp_ratio=4.0,
qkv_bias=True,
qk_scale=None,
drop_rate=0.0,
attn_drop_rate=0.0,
drop_path_rate=0.1,
norm_layer=nn.LayerNorm,
ape=False,
patch_norm=True,
use_checkpoint=False,
fused_window_process=False,
**kwargs,
):
super().__init__()
self.num_classes = num_classes
self.num_layers = len(depths)
self.embed_dim = embed_dim
self.ape = ape
self.patch_norm = patch_norm
self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
self.mlp_ratio = mlp_ratio
self.img_size = img_size
self.fused_window_process = fused_window_process
# split image into non-overlapping patches
self.patch_embed = PatchEmbed(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None,
)
num_patches = self.patch_embed.num_patches
patches_resolution = self.patch_embed.patches_resolution
self.patches_resolution = patches_resolution
self.embed_layer = PatchEmbed
self.patch_size = patch_size
self.in_chans = in_chans
self.norm_layer = norm_layer
# absolute position embedding
if self.ape:
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
trunc_normal_(self.absolute_pos_embed, std=0.02)
self.pos_drop = nn.Dropout(p=drop_rate)
# stochastic depth
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
# build layers
self.layers = nn.ModuleList()
for i_layer in range(self.num_layers):
layer = BasicLayer(
dim=int(embed_dim * 2**i_layer),
input_resolution=(
patches_resolution[0] // (2**i_layer),
patches_resolution[1] // (2**i_layer),
),
depth=depths[i_layer],
num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=self.mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
norm_layer=norm_layer,
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
use_checkpoint=use_checkpoint,
fused_window_process=fused_window_process,
)
self.layers.append(layer)
self.norm = norm_layer(self.num_features)
self.avgpool = nn.AdaptiveAvgPool1d(1)
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
self.apply(self._init_weights)
def set_num_classes(self, n_classes):
"""Reset the classification head with a new number of classes.
Parameters
----------
n_classes : int
new number of classes
"""
if n_classes == self.num_classes:
return
self.head = nn.Linear(self.num_features, n_classes) if n_classes > 0 else nn.Identity()
self.num_classes = n_classes
nn.init.trunc_normal_(self.head.weight, std=0.02)
nn.init.constant_(self.head.bias, 0)
def set_image_res(self, res):
    """Adapt the model in place to a new input image resolution.

    Rebuilds the patch embedding (reusing the trained weights), updates every
    layer's and block's input resolution, window size, shift size and attention
    mask, and bicubically resizes the absolute position embedding if used.

    Parameters
    ----------
    res : int
        new input image resolution (assumes square images — TODO confirm)
    """
    if res == self.img_size:
        return
    # rebuild the patch embedding at the new resolution, keeping trained weights
    old_patch_embed_state = copy(self.patch_embed.state_dict())
    self.patch_embed = self.embed_layer(
        img_size=res,
        patch_size=self.patch_size,
        in_chans=self.in_chans,
        embed_dim=self.embed_dim,
        norm_layer=self.norm_layer if self.patch_norm else None,
    )
    self.patch_embed.load_state_dict(old_patch_embed_state)
    patches_resolution = self.patch_embed.patches_resolution
    self.patches_resolution = patches_resolution
    for i_layer, layer in enumerate(self.layers):
        # each stage halves the spatial resolution
        input_resolution = (
            patches_resolution[0] // (2**i_layer),
            patches_resolution[1] // (2**i_layer),
        )
        layer.input_resolution = input_resolution
        # every stage except the last is followed by patch merging
        downsample = PatchMerging if (i_layer < self.num_layers - 1) else None
        if downsample is not None:
            layer.downsample = downsample(input_resolution, dim=layer.dim, norm_layer=self.norm_layer)
        for block in layer.blocks:
            block.input_resolution = input_resolution
            if min(input_resolution) <= block.window_size:
                # if window size is larger than input resolution, we don't partition windows
                block.shift_size = 0
                block.window_size = min(block.input_resolution)
            assert 0 <= block.shift_size < block.window_size, "shift_size must in 0-window_size"
            if block.shift_size > 0:
                # calculate attention mask for SW-MSA
                H, W = block.input_resolution
                img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
                h_slices = (
                    slice(0, -block.window_size),
                    slice(-block.window_size, -block.shift_size),
                    slice(-block.shift_size, None),
                )
                w_slices = (
                    slice(0, -block.window_size),
                    slice(-block.window_size, -block.shift_size),
                    slice(-block.shift_size, None),
                )
                # label each region with a distinct id; windows mixing ids get masked
                cnt = 0
                for h in h_slices:
                    for w in w_slices:
                        img_mask[:, h, w, :] = cnt
                        cnt += 1
                mask_windows = window_partition(img_mask, block.window_size)  # nW, window_size, window_size, 1
                mask_windows = mask_windows.view(-1, block.window_size * block.window_size)
                attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
                # -100 effectively zeroes those attention weights after softmax
                attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
                    attn_mask == 0, float(0.0)
                )
            else:
                attn_mask = None
            block.register_buffer("attn_mask", attn_mask)
    if self.ape:
        # bicubically resize the absolute position embedding to the new token grid
        # (assumes a square grid of tokens — TODO confirm for non-square inputs)
        orig_size = int((self.absolute_pos_embed.shape[-2]) ** 0.5)
        new_size = int(self.patch_embed.num_patches**0.5)
        pos_tokens = self.absolute_pos_embed[:, :]
        # make it shape rest x embed_dim x orig_size x orig_size
        pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, self.embed_dim).permute(0, 3, 1, 2)
        pos_tokens = nn.functional.interpolate(
            pos_tokens,
            size=(new_size, new_size),
            mode="bicubic",
            align_corners=False,
        )
        # make it shape rest x new_size^2 x embed_dim
        pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
        self.absolute_pos_embed = nn.Parameter(pos_tokens.contiguous())
    self.img_size = res
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=0.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
@torch.jit.ignore
def no_weight_decay(self):
    """Parameter names to exclude from weight decay."""
    return {"absolute_pos_embed"}
@torch.jit.ignore
def no_weight_decay_keywords(self):
    """Parameter-name keywords to exclude from weight decay."""
    return {"relative_position_bias_table"}
def forward_features(self, x):
    """Embed patches and run all Swin stages; returns token features of shape B L C."""
    x = self.patch_embed(x)
    if self.ape:
        # optional absolute position embedding
        x = x + self.absolute_pos_embed
    x = self.pos_drop(x)
    for layer in self.layers:
        x = layer(x)
    x = self.norm(x)  # B L C
    return x
def forward(self, x):
    """Full forward pass: token features -> average pool over tokens -> classification head."""
    x = self.forward_features(x)
    x = self.avgpool(x.transpose(1, 2))  # B C 1
    x = torch.flatten(x, 1)
    x = self.head(x)
    return x
def flops(self):
    """Estimate the model's forward-pass FLOPs.

    Sums the patch embedding, all stages, the final norm over the remaining
    tokens, and the classification head.

    Returns
    -------
    int
        estimated number of floating point operations
    """
    flops = 0
    flops += self.patch_embed.flops()
    # fix: the index from enumerate() was never used
    for layer in self.layers:
        flops += layer.flops()
    # final norm over the tokens left after num_layers-1 downsamplings
    flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2**self.num_layers)
    flops += self.num_features * self.num_classes
    return flops
# Swin model size presets: embedding dim, per-stage depths, per-stage head counts.
swin_sizes = {
    "Ti": dict(embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24]),
    "S": dict(embed_dim=96, depths=[2, 2, 18, 2], num_heads=[3, 6, 12, 24]),
    "B": dict(embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32]),
    "L": dict(embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48]),
}
@register_model
def swin_tiny_patch4_window7(pretrained=False, img_size=224, **kwargs):
    """Build a Swin-Tiny model (patch size 4, window size 7)."""
    # timm may inject "pretrained_cfg"; SwinTransformer does not accept it
    kwargs.pop("pretrained_cfg", None)
    return SwinTransformer(img_size=img_size, patch_size=4, window_size=7, **swin_sizes["Ti"], **kwargs)
@register_model
def swin_small_patch4_window7(pretrained=False, img_size=224, **kwargs):
    """Build a Swin-Small model (patch size 4, window size 7)."""
    # timm may inject "pretrained_cfg"; SwinTransformer does not accept it
    kwargs.pop("pretrained_cfg", None)
    return SwinTransformer(img_size=img_size, patch_size=4, window_size=7, **swin_sizes["S"], **kwargs)
@register_model
def swin_base_patch4_window7(pretrained=False, img_size=224, **kwargs):
    """Build a Swin-Base model (patch size 4, window size 7)."""
    # timm may inject "pretrained_cfg"; SwinTransformer does not accept it
    kwargs.pop("pretrained_cfg", None)
    return SwinTransformer(img_size=img_size, patch_size=4, window_size=7, **swin_sizes["B"], **kwargs)
@register_model
def swin_large_patch4_window7(pretrained=False, img_size=224, **kwargs):
    """Build a Swin-Large model (patch size 4, window size 7)."""
    # timm may inject "pretrained_cfg"; SwinTransformer does not accept it
    kwargs.pop("pretrained_cfg", None)
    return SwinTransformer(img_size=img_size, patch_size=4, window_size=7, **swin_sizes["L"], **kwargs)
@register_model
def swin_ashim(pretrained=False, img_size=112, **kwargs):
    """Build a single-stage Swin variant (patch size 2, window size 7, embed dim 384)."""
    # timm may inject "pretrained_cfg"; SwinTransformer does not accept it
    kwargs.pop("pretrained_cfg", None)
    if "num_heads" in kwargs:
        # the single-stage model expects a one-element list of head counts
        kwargs["num_heads"] = [kwargs["num_heads"]]
    cfg = dict(embed_dim=384, depths=[12], num_heads=[12])
    cfg.update(kwargs)
    return SwinTransformer(img_size=img_size, in_chans=3, patch_size=2, window_size=7, **cfg)

View File

@@ -0,0 +1,225 @@
from functools import partial
import torch
from loguru import logger
from timm.models.vision_transformer import (
Attention,
Block,
PatchEmbed,
VisionTransformer,
)
from resizing_interface import ResizingInterface
class _MatrixSaveAttn(Attention):
    """Timm ``Attention`` subclass that keeps the most recent attention matrix in ``attn_mat``."""

    # most recent (detached) attention matrix; None until the first forward pass
    attn_mat = None

    @classmethod
    def cast(cls, attn: Attention):
        """Re-class an existing ``Attention`` instance in place (no copy) and return it."""
        assert isinstance(attn, Attention), "Can only save attention from Timms attention class"
        attn.__class__ = cls
        assert isinstance(attn, _MatrixSaveAttn)
        return attn

    def forward(self, x):
        """Non-fused attention forward that additionally stores the attention matrix."""
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        # detach so saving the map does not keep the autograd graph alive
        self.attn_mat = attn.detach()
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        return self.proj_drop(x)
def _index_picker(tensor, idx=-1):
"""Pick a specific index from a tensor.
tensor: B x N x D -> B x D, by picking idx from N.
Args:
tensor (toch.tensor): tensor to pick from.
idx (int, optional): index to pick. Defaults to -1.
Returns:
torch.tensor: index from tensor
"""
return tensor[..., idx, :] # B x N x D -> B x D
class TimmViT(VisionTransformer, ResizingInterface):
    """Wrapper for *VisionTransformer* from *timm* library (https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py)."""

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        num_classes=1000,
        global_pool="token",
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_norm=False,
        init_values=None,
        class_token=True,
        no_embed_class=True,
        pre_norm=False,
        fc_norm=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        weight_init="",
        embed_layer=PatchEmbed,
        norm_layer=None,
        act_layer=None,
        block_fn=Block,
        save_attention_maps=False,
        fused_attn=True,
        **kwargs,
    ):
        """Create a Vision Transformer model.

        Parameters
        ----------
        img_size : int
            image dimensions -> img_size x img_size
        patch_size : int
            patch_size
        in_chans : int
            number of image channels
        num_classes : int
            number of classes for classification head
        global_pool : str,int
            type of global pooling for final sequence (default: 'token'), or index for token to be taken
        embed_dim : int
            embedding dimension
        depth : int
            number of transformer layers
        num_heads : int
            number of transformer heads
        mlp_ratio : float
            ratio of feed forward (mlp) hidden dimension to embedding dimension
        qkv_bias : bool
            enable bias for query, key, and value (qkv) embeddings
        qk_norm : bool
            normalize query and key embeddings
        init_values : float
            layer scale initial values
        class_token : bool
            use a class token [CLS]
        no_embed_class : bool
            no positional embedding for the class token
        pre_norm : bool
            use pre-norm architecture (norm before the blocks, not after)
        fc_norm : bool
            norm after pool (used when global_pool == 'avg')
        drop_rate : float
            dropout rate
        attn_drop_rate : float
            dropout rate in the attention module
        drop_path_rate : float
            drop path rate (stochastic depth)
        weight_init : str
            scheme for weight initialization
        embed_layer : nn.Module
            patch embedding layer
        norm_layer : nn.Module
            normalization layer
        act_layer : nn.Module
            activation function
        block_fn : nn.Module
            which block structure to use; for parallel attention layers, ...
        save_attention_maps : bool
            save attention maps for each block
        fused_attn : bool
            use fused attention
        kwargs : dict
            additional arguments (will be ignored)
        """
        # kwargs forwarded to the timm VisionTransformer base class; an int
        # global_pool is mapped to "avg" here and handled via attn_pool below
        init_kwargs = dict(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            num_classes=num_classes,
            global_pool=global_pool if isinstance(global_pool, str) else "avg",
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            init_values=init_values,
            class_token=class_token,
            no_embed_class=no_embed_class,
            pre_norm=pre_norm,
            fc_norm=fc_norm,
            drop_rate=drop_rate,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=drop_path_rate,
            weight_init=weight_init,
            embed_layer=embed_layer,
            norm_layer=norm_layer,
            act_layer=act_layer,
            block_fn=block_fn,
        )
        self.new_version = False  # TODO: check based on the timm version
        # NOTE(review): assigned before super().__init__; super() re-assigns
        # self.global_pool from init_kwargs, so an int value survives only via attn_pool
        self.global_pool = global_pool
        if isinstance(global_pool, int):
            # pool by picking a single token at the given index
            self.attn_pool = partial(_index_picker, idx=global_pool)
        if self.new_version:
            # newer timm versions take these as constructor arguments
            init_kwargs["qk_norm"] = qk_norm
            init_kwargs["proj_drop_rate"] = drop_rate
        super(TimmViT, self).__init__(**init_kwargs)
        # remember the construction parameters (used e.g. for resizing the model)
        self.embed_layer = embed_layer
        self.img_size = img_size
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.pre_norm = pre_norm
        self.class_token = class_token
        self.no_embed_class = no_embed_class
        self.num_classes = num_classes
        self.num_heads = num_heads
        self.depth = depth
        self.save_attention_maps = save_attention_maps
        if save_attention_maps:
            self.do_save_attention_maps()
        try:
            # honor the requested fused-attention setting on every block
            for block in self.blocks:
                block.attn.fused_attn = fused_attn and self.blocks[0].attn.fused_attn
            use_fused = self.blocks[0].attn.fused_attn
            logger.info(f"Use fused attention: {use_fused}")
        except:  # I'm lazy for now # noqa: E722
            pass

    def do_save_attention_maps(self):
        """Enable attention-map saving by re-classing every block's attention module."""
        self.save_attention_maps = True
        for block in self.blocks:
            block.attn = _MatrixSaveAttn.cast(block.attn)

    def attention_maps(self):
        """Return the last saved attention matrix per block (None where none was computed yet)."""
        assert self.save_attention_maps, "Have to save attention maps first"
        return [getattr(block.attn, "attn_mat", None) for block in self.blocks]

    def forward_features(self, x):
        """Patch-embed, add position embedding, run all transformer blocks and the final norm."""
        x = self.patch_embed(x)
        x = self._pos_embed(x)
        if self.new_version:
            # patch_drop only exists on newer timm versions
            x = self.patch_drop(x)
        x = self.norm_pre(x)
        x = self.blocks(x)
        return self.norm(x)

    def forward(self, x):
        """Full forward pass: features, then the (pooling + classification) head."""
        x = self.forward_features(x)
        return self.forward_head(x)

View File

@@ -0,0 +1,118 @@
"""Default configuration parameters.
Attributes:
default_kwargs (dict): Default hyperparameters for the training process.
slurm_defaults (dict): Default values for SLURM batch job settings.
"""
from paths_config import user
# Default hyperparameters for the DeiT-III style training recipe
# (selected via get_default_kwargs below).
default_kwargs = {
    "amp": True,
    "aug_color_jitter_factor": 0.3,
    "aug_crop": True,
    "aug_cutmix_alpha": 1.0,
    "aug_flip": True,
    "aug_gauss_blur": True,
    "aug_grayscale": True,
    "aug_mixup_alpha": 0.0,
    "aug_normalize": True,
    "aug_rand_rot": 0,
    "aug_random_erase_count": 1,
    "aug_random_erase_mode": "pixel",
    "aug_random_erase_prob": 0.0,
    "aug_repeated_augment_repeats": 1,
    "aug_resize": True,
    "aug_solarize": True,
    "augment_engine": "torchvision",
    "augment_strategy": "3-augment",
    "auto_augment_strategy": "rand-m9-mstd0.5-inc1",
    "batch_size": 2048,
    "compile_model": False,
    "cuda": True,
    "custom_dataset_path": None,
    "debug": False,
    "drop_path_rate": 0.05,
    "dropout": 0.0,
    "eval_amp": True,
    "experiment_name": "none",
    "fused_attn": True,
    "gather_stats_during_training": True,
    "imsize": 224,
    "input_dim": None,
    "keep_interm_states": 2,
    "label_smoothing": 0.1,
    "layer_scale": True,
    "layer_scale_init_values": 1e-4,
    "log_level": "info",
    "loss": "ce",
    "loss_weight": "none",
    "lr": 3e-3,
    "max_grad_norm": 1.0,
    "max_seq_len": None,
    "min_lr": 1e-5,
    "momentum": 0.0,
    "num_heads": None,
    "num_workers": 44,
    "opt": "fusedlamb",
    "opt_eps": 1e-7,
    "pin_memory": False,
    "pre_norm": False,
    "prefetch_factor": 2,
    "qkv_bias": True,
    "run_name": None,
    "save_epochs": 10,
    "sched": "cosine",
    "seed": None,
    "shuffle": True,
    "tqdm": True,
    "wandb": True,
    "warmup_epochs": 5,
    "warmup_lr": 1e-6,
    "warmup_sched": "linear",
    "weight_decay": 0.02,
    "weighted_sampler": False,
}
# , 'model_ema': True, 'model_ema_decay': 0.99996}
# Overrides applied on top of default_kwargs for the original-DeiT recipe
# (see get_default_kwargs("deit")).
deit_kwargs = {
    "aug_mixup_alpha": 0.8,
    "aug_repeated_augment_repeats": 3,
    "augment_strategy": "deit",
    "aug_random_erase_prob": 0.25,
    "batch_size": 1024,
    "lr": 1e-3,
    "max_grad_norm": 0.0,
    "num_workers": 10,
    "opt": "adamw",
    "opt_eps": 1e-8,
    "weight_decay": 0.05,
}
def get_default_kwargs(settings="deitiii"):
    """Return the default hyperparameter dict for a given settings name.

    Args:
        settings (str): name of the defaults preset ("deitiii" or "deit").

    Returns:
        dict: default hyperparameters.

    Raises:
        NotImplementedError: if the settings name is unknown.
    """
    preset = settings.lower()
    if preset == "deitiii":
        return default_kwargs
    if preset == "deit":
        # DeiT preset = DeiT-III defaults with the original-DeiT overrides applied
        return {**default_kwargs, **deit_kwargs}
    raise NotImplementedError(f"No such defaults setting: {settings}")
# Default values for SLURM batch job settings.
slurm_defaults = {
    "after_job": None,
    # fix (ruff F541): these were f-strings without placeholders; plain strings are identical
    "container_image": "PATH/TO/ENROOT/IMAGE",
    "container_mounts": 'MOUNT_ALL_IMPORTANT_STORAGE_SERVERS_HERE,"`pwd`":"`pwd`"',
    "container_workdir": '"`pwd`"',
    "cpus_per_task": 24,
    "exclude": None,
    "export": "ALL,TQDM_DISABLE=1",
    "job_name": None,
    "mem_per_gpu": 90,
    "nodes": 1,
    "ntasks": 4,
    "partition": ["A100", "H100", "H200"],
    "task_prolog": None,
    "time": "1-0",
}

View File

@@ -0,0 +1,142 @@
import albumentations as A
import cv2
import numpy as np
from albumentations.pytorch import ToTensorV2
from datadings.torch import CompressedToPIL
class AlbumTorchCompose(A.Compose):
    """Compose albumentation augmentations in a way that works with PIL images and datadings."""

    def __init__(self, *args, **kwargs):
        """Pass to A.Compose."""
        super().__init__(*args, **kwargs)
        # decodes raw (compressed) image bytes from datadings into a PIL image
        self.to_pil = CompressedToPIL()

    def __call__(self, image, mask=None, **kwargs):
        """Apply the composed augmentations to an image (and optional mask).

        Returns only the augmented image when no mask is given, otherwise the
        full albumentations result dict.
        """
        if isinstance(image, bytes):
            # datadings hands us compressed bytes; decode first
            image = self.to_pil(image)
        if mask is not None and len(mask) == 0:
            # treat an empty mask like no mask at all
            mask = None
        # albumentations works on numpy arrays, not PIL images
        if not isinstance(image, np.ndarray):
            image = np.array(image)
        if mask is not None and not isinstance(mask, np.ndarray):
            mask = np.array(mask)
        if mask is None:
            return super().__call__(image=image, **kwargs)["image"]
        return super().__call__(image=image, mask=mask, **kwargs)
class PILToNP(A.DualTransform):
    """Convert PIL image to numpy array."""

    def apply(self, image, **params):
        # conversion for the image
        return np.array(image)

    def apply_to_mask(self, image, **params):
        # same conversion for segmentation masks
        return np.array(image)

    def get_transform_init_args_names(self):
        # no init args to serialize
        return ()
class AlbumCompressedToPIL(A.DualTransform):
    """Convert compressed image bytes to a PIL image (albumentations transform)."""

    def __init__(self, *args, **kwargs):
        """Create the transform; all arguments are forwarded to ``A.DualTransform``."""
        super().__init__(*args, **kwargs)
        # bug fix: self.to_pil was never created, so apply()/apply_to_mask()
        # raised AttributeError on first use
        self.to_pil = CompressedToPIL()

    def apply(self, img, **params):
        return self.to_pil(img)

    def apply_to_mask(self, img, **params):
        return self.to_pil(img)

    def get_transform_init_args_names(self):
        # no init args to serialize
        return ()
def minimal_augment(args, test=False):
    """Get minimal augmentations for training or testing.

    Args:
        args (argparse.Namespace): arguments
        test (bool, optional): if True, return test augmentations. Defaults to False.

    Returns:
        List: Augmentation list
    """
    augs = []
    if args.aug_resize:
        # scale the shorter side to imsize (bicubic)
        augs.append(A.SmallestMaxSize(args.imsize, interpolation=cv2.INTER_CUBIC))
    if test and args.aug_crop:
        # deterministic crop for evaluation
        augs.append(A.CenterCrop(args.imsize, args.imsize))
    elif args.aug_crop:
        augs.append(A.RandomCrop(args.imsize, args.imsize))
    if not test and args.aug_flip:
        augs.append(A.HorizontalFlip(p=0.5))
    if args.aug_normalize:
        # ImageNet statistics
        augs.append(A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)))
    augs.append(ToTensorV2())
    return augs
def three_augment(args, as_list=False, test=False):
    """Create the data augmentation.

    Args:
        args (Namespace): arguments
        as_list (bool): return list of transformations, not composed transformation
        test (bool): In eval mode? If False => train mode

    Returns:
        torch.nn.Module | list[torch.nn.Module]: composed transformation or list of transformations
    """
    augs = []
    if args.aug_resize:
        # scale the shorter side to imsize (bicubic)
        augs.append(A.SmallestMaxSize(args.imsize, interpolation=cv2.INTER_CUBIC))
    if test and args.aug_crop:
        # deterministic crop for evaluation
        augs.append(A.CenterCrop(args.imsize, args.imsize))
    elif args.aug_crop:
        augs.append(A.RandomCrop(args.imsize, args.imsize, pad_if_needed=True, border_mode=cv2.BORDER_REFLECT))
    if not test:
        if args.aug_flip:
            augs.append(A.HorizontalFlip(p=0.5))
        # 3-augment: pick exactly one of grayscale / solarize / gaussian blur
        augs_choice = []
        if args.aug_grayscale:
            augs_choice.append(A.ToGray(p=1, num_output_channels=3))
        if args.aug_solarize:
            augs_choice.append(A.Solarize(p=1, threshold_range=(0.5, 0.5)))
        if args.aug_gauss_blur:
            augs_choice.append(A.GaussianBlur(p=1, sigma_limit=(0.2, 2.0), blur_limit=(7, 7)))
        if len(augs_choice) > 0:
            augs.append(A.OneOf(augs_choice))
        if args.aug_color_jitter_factor > 0.0:
            augs.append(
                A.ColorJitter(
                    brightness=args.aug_color_jitter_factor,
                    contrast=args.aug_color_jitter_factor,
                    saturation=args.aug_color_jitter_factor,
                    hue=0.0,
                )
            )
    if args.aug_normalize:
        # ImageNet statistics
        augs.append(A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)))
    augs.append(ToTensorV2())
    if as_list:
        return augs
    return AlbumTorchCompose(augs)

View File

@@ -0,0 +1,63 @@
import os
from loguru import logger
from PIL import Image
from torch.utils.data import Dataset
class CounterAnimal(Dataset):
    """Dataset to load the CounterAnimal dataset with ImageNet labels."""

    def __init__(self, base_path, mode, transform=None, target_transform=None, train=False):
        """Create the dataset.

        Args:
            base_path (str): path to the base folder (the one where the class folders are in)
            mode (str): mode/variant of the dataset (common/counter)
            transform: Image augmentation
            target_transform: label augmentation
            train: train or test set. Train set is not supported

        Raises:
            AssertionError: for an unknown mode, for train=True, or when no images are found.
        """
        super().__init__()
        self.base = base_path
        assert mode in ["counter", "common"], f"Supported modes are counter and common, but got '{mode}'"
        assert not train, "CounterAnimal only consists of test data, not training data."
        self.transform = transform
        self.target_transform = target_transform
        # index: list of (image path, ImageNet class index) pairs
        # (removed the commented-out debug prints that cluttered this scan)
        self.index = []
        for class_folder in os.listdir(self.base):
            if not os.path.isdir(os.path.join(self.base, class_folder)):
                continue
            # class folder names look like "<class idx> <class name>"
            class_idx = int(class_folder.split(" ")[0])
            for variant_folder in os.listdir(os.path.join(self.base, class_folder)):
                # only keep the requested variant (common/counter)
                if not variant_folder.startswith(mode):
                    continue
                _folder = os.path.join(self.base, class_folder, variant_folder)
                for file in os.listdir(_folder):
                    if file.lower().split(".")[-1] in ["jpg", "jpeg", "png"]:
                        self.index.append((os.path.join(_folder, file), class_idx))
        assert len(self.index) > 0, "did not find any images :("

    def __len__(self):
        return len(self.index)

    def __getitem__(self, idx):
        path, label = self.index[idx]
        img = Image.open(path).convert("RGB")
        if self.transform:
            img = self.transform(img)
        if self.target_transform:
            label = self.target_transform(label)
        return img, label

View File

@@ -0,0 +1,145 @@
from nvidia.dali import fn, pipeline_def, types
# see https://docs.nvidia.com/deeplearning/dali/user-guide/docs/plugins/pytorch_dali_proxy.html
@pipeline_def
def minimal_augment(args, test=False):
    """Minimal Augmentation set for images.

    Contains only resize, crop, flip, to tensor and normalize.

    Args:
        args (DotDict): Arguments: aug_resize, aug_crop, aug_flip, aug_normalize to turn on/off the respective augmentation.
        test (bool, optional): On the test set? Defaults to False.

    Returns:
        images: augmented images.
    """
    images = fn.external_source(name="images", no_copy=True)
    if args.aug_resize:
        images = fn.resize(images, size=args.imsize, mode="not_smaller")
    if test and args.aug_crop:
        # deterministic center crop at test time
        images = fn.crop(images, crop=(args.imsize, args.imsize), crop_pos_x=0.5, crop_pos_y=0.5)
    elif args.aug_crop:
        images = fn.crop(
            images,
            crop=(args.imsize, args.imsize),
            crop_pos_x=fn.random.uniform(range=(0, 1)),
            crop_pos_y=fn.random.uniform(range=(0, 1)),
        )
    # flip and normalization are fused into crop_mirror_normalize below
    # (removed the commented-out separate fn.flip / fn.normalize variant — dead code)
    return fn.crop_mirror_normalize(
        images,
        dtype=types.FLOAT,
        output_layout="CHW",
        crop=(args.imsize, args.imsize),
        mean=[0.485 * 255, 0.456 * 255, 0.406 * 255] if args.aug_normalize else [0, 0, 0],
        std=[0.229 * 255, 0.224 * 255, 0.225 * 255] if args.aug_normalize else [1, 1, 1],
        mirror=fn.random.coin_flip(probability=0.5) if args.aug_flip and not test else False,
    )
def dali_solarize(images, threshold=128):
    """Solarize implementation for nvidia DALI.

    Args:
        images (DALI Tensor): Images to solarize.
        threshold (int, optional): Threshold for solarization. Defaults to 128.

    Returns:
        images: solarized images.
    """
    # inverted copy of the image: 255 - pixel
    inv_images = types.Constant(255).uint8() - images
    # mask == 1 where pixel >= threshold, else 0
    mask = (images >= threshold) * types.Constant(1).uint8()
    # XOR with 1 flips the mask: take inverted pixels under the mask, originals elsewhere
    return mask * inv_images + (types.Constant(1).uint8() ^ mask) * images
@pipeline_def(enable_conditionals=True)
def three_augment(args, test=False):
    """3-augment data augmentation pipeline for nvidia DALI.

    Args:
        args (namespace): augmentation arguments.
        test (bool, optional): Test (or train) split. Defaults to False.

    Returns:
        images: augmented images.
    """
    images = fn.external_source(name="images", no_copy=True)
    if args.aug_resize:
        images = fn.resize(images, size=args.imsize, mode="not_smaller")
    if test and args.aug_crop:
        # deterministic center crop at test time
        images = fn.crop(images, crop=(args.imsize, args.imsize), crop_pos_x=0.5, crop_pos_y=0.5)
    elif args.aug_crop:
        images = fn.crop(
            images,
            crop=(args.imsize, args.imsize),
            crop_pos_x=fn.random.uniform(range=(0, 1)),
            crop_pos_y=fn.random.uniform(range=(0, 1)),
        )
    if not test:
        # pick exactly one of grayscale / solarize / gaussian blur,
        # weighted by which augmentations are enabled
        # (removed the dead `choices` list: it was never populated, so the
        # trailing `if len(choices) > 0: fn.random.choice(choices)` could never run)
        choice_ps = [1 * args.aug_grayscale, 1 * args.aug_solarize, 1 * args.aug_gauss_blur]
        choice_ps = [c / sum(choice_ps) for c in choice_ps]
        choice = fn.random.choice(
            [0, 1, 2],
            p=choice_ps,
        )
        if choice == 0:
            # grayscale with 3 output channels (RGB -> GRAY -> RGB)
            images = fn.color_space_conversion(
                fn.color_space_conversion(images, image_type=types.RGB, output_type=types.GRAY),
                image_type=types.GRAY,
                output_type=types.RGB,
            )
        elif choice == 1:
            images = dali_solarize(images, threshold=128)
        elif choice == 2:
            images = fn.gaussian_blur(images, window_size=7, sigma=fn.random.uniform(range=(0.2, 2.0)))
        if args.aug_color_jitter_factor > 0.0:
            images = fn.color_twist(
                images,
                brightness=fn.random.uniform(
                    range=(1 - args.aug_color_jitter_factor, 1 + args.aug_color_jitter_factor)
                ),
                contrast=fn.random.uniform(range=(1 - args.aug_color_jitter_factor, 1 + args.aug_color_jitter_factor)),
                saturation=fn.random.uniform(
                    range=(1 - args.aug_color_jitter_factor, 1 + args.aug_color_jitter_factor)
                ),
            )
    return fn.crop_mirror_normalize(
        images,
        dtype=types.FLOAT,
        output_layout="CHW",
        crop=(args.imsize, args.imsize),
        mean=[0.485 * 255, 0.456 * 255, 0.406 * 255] if args.aug_normalize else [0, 0, 0],
        std=[0.229 * 255, 0.224 * 255, 0.225 * 255] if args.aug_normalize else [1, 1, 1],
        mirror=fn.random.coin_flip(probability=0.5) if args.aug_flip and not test else False,
    )

View File

@@ -0,0 +1,407 @@
from random import uniform
import msgpack
import torch
import torchvision
from datadings.torch import CompressedToPIL
from datadings.torch import Dataset as DDDataset
from PIL import ImageFilter
from torchvision.transforms import (
CenterCrop,
ColorJitter,
Compose,
GaussianBlur,
Grayscale,
InterpolationMode,
Normalize,
RandomChoice,
RandomCrop,
RandomHorizontalFlip,
RandomResizedCrop,
RandomSolarize,
Resize,
ToTensor,
)
from torchvision.transforms import functional as F
# Geometric transforms that must be applied to the image and the dense target
# simultaneously (checked via isinstance in apply_dense_transforms).
# Fix: RandomRotation was listed twice; the duplicate was redundant.
_image_and_target_transforms = [
    torchvision.transforms.RandomCrop,
    torchvision.transforms.RandomHorizontalFlip,
    torchvision.transforms.CenterCrop,
    torchvision.transforms.RandomRotation,
    torchvision.transforms.RandomAffine,
    torchvision.transforms.RandomResizedCrop,
]
def apply_dense_transforms(x, y, transforms: torchvision.transforms.transforms.Compose):
    """Apply some transformations to both image and target.

    Args:
        x (torch.Tensor): image
        y (torch.Tensor): target (image)
        transforms (torchvision.transforms.transforms.Compose): transformations to apply

    Returns:
        tuple[torch.Tensor, torch.Tensor]: (x, y) with applied transformations
    """
    for trans in transforms.transforms:
        if isinstance(trans, torchvision.transforms.RandomResizedCrop):
            # sample crop parameters once and apply identically to image and target
            params = trans.get_params(x, trans.scale, trans.ratio)
            x = F.resized_crop(x, *params, trans.size, trans.interpolation, antialias=trans.antialias)
            y = F.resized_crop(y.unsqueeze(0), *params, trans.size, 0).squeeze(0)  # nearest neighbor interpolation
        elif isinstance(trans, Resize):
            pre_shape = x.shape
            x = trans(x)
            if x.shape != pre_shape:
                # only resize the target when the image actually changed shape
                y = F.resize(y.unsqueeze(0), trans.size, 0, trans.max_size, trans.antialias).squeeze(
                    0
                )  # nearest neighbor interpolation
        elif any(isinstance(trans, simul_transform) for simul_transform in _image_and_target_transforms):
            # stack image and target so random geometric transforms stay in sync
            xy = torch.cat([x, y.unsqueeze(0).float()], dim=0)
            xy = trans(xy)
            x, y = xy[:-1], xy[-1].long()
        elif isinstance(trans, torchvision.transforms.ToTensor):
            # skip ToTensor when x already is a tensor
            if not isinstance(x, torch.Tensor):
                x = trans(x)
        else:
            # photometric transforms apply to the image only
            x = trans(x)
    return x, y
def get_hf_transform(transform_f, trgt_transform_f=None, image_key="image"):
    """Convert the transform function to a huggingface compatible transform function.

    Args:
        transform_f (callable): Image transform.
        trgt_transform_f (callable, optional): Target transform. Defaults to None.
        image_key (str, optional): Key for the image in the hf ds return dict. Defaults to "image".
    """

    def _transform(samples):
        try:
            # transform each image in the batch dict in place
            samples[image_key] = [transform_f(image) for image in samples[image_key]]
            if trgt_transform_f is not None:
                samples["label"] = [trgt_transform_f(label) for label in samples["label"]]
        except TypeError as e:
            print(f"Type error when transforming samples: {samples}")
            raise e
        return samples

    return _transform
class DDDecodeDataset(DDDataset):
    """Datadings dataset with image decoding before transform."""

    def __init__(self, *args, transform=None, target_transform=None, transforms=None, **kwargs):
        """Create datadings dataset.

        Args:
            transform (callable, optional): Image transform. Overrides transforms['image']. Defaults to None.
            target_transform (callable, optional): Label transform. Overrides transforms['label']. Defaults to None.
            transforms (dict[str, callable], optional): Dict of transforms for each key. Defaults to None.
        """
        super().__init__(*args, **kwargs)
        if transforms is None:
            transforms = {}
        # explicit transform/target_transform win over the transforms dict
        self._decode_transform = transform if transform is not None else transforms.get("image", None)
        self._decode_target_transform = (
            target_transform if target_transform is not None else transforms.get("label", None)
        )
        # decodes compressed image bytes into PIL images
        self.ctp = CompressedToPIL()

    def __getitem__(self, idx):
        """Return the decoded and transformed (image, label) pair at ``idx``."""
        sample = super().__getitem__(idx)
        img, lbl = sample["image"], sample["label"]
        if isinstance(img, bytes):
            # decode compressed bytes before applying the transform
            img = self.ctp(img)
        if self._decode_transform is not None:
            img = self._decode_transform(img)
        if self._decode_target_transform is not None:
            lbl = self._decode_target_transform(lbl)
        return img, lbl
def minimal_augment(args, test=False):
    """Minimal Augmentation set for images.

    Contains only resize, crop, flip, to tensor and normalize.

    Args:
        args (DotDict): Arguments: aug_resize, aug_crop, aug_flip, aug_normalize to turn on/off the respective augmentation.
        test (bool, optional): On the test set? Defaults to False.

    Returns:
        List: Augmentation list
    """
    augs = []
    # convert to tensor first so the remaining ops work on tensors
    augs.append(ToTensor())
    if args.aug_resize:
        augs.append(Resize(args.imsize, interpolation=InterpolationMode.BICUBIC))
    if test and args.aug_crop:
        # deterministic crop for evaluation
        augs.append(CenterCrop(args.imsize))
    elif args.aug_crop:
        augs.append(RandomCrop(args.imsize, padding=4, padding_mode="reflect"))
    if not test and args.aug_flip:
        augs.append(RandomHorizontalFlip(p=0.5))
    if args.aug_normalize:
        # ImageNet statistics
        augs.append(
            Normalize(
                mean=torch.tensor([0.485, 0.456, 0.406]),
                std=torch.tensor([0.229, 0.224, 0.225]),
            )
        )
    return augs
def three_augment(args, as_list=False, test=False):
    """Create the data augmentation.

    Args:
        args (Namespace): arguments
        as_list (bool): return list of transformations, not composed transformation
        test (bool): In eval mode? If False => train mode

    Returns:
        torch.nn.Module | list[torch.nn.Module]: composed transformation or list of transformations
    """
    augs = []
    # convert to tensor first so the remaining ops work on tensors
    augs.append(ToTensor())
    if args.aug_resize:
        augs.append(Resize(args.imsize, interpolation=InterpolationMode.BICUBIC))
    if test and args.aug_crop:
        # deterministic crop for evaluation
        augs.append(CenterCrop(args.imsize))
    elif args.aug_crop:
        augs.append(RandomCrop(args.imsize, padding=4, padding_mode="reflect"))
    if not test:
        if args.aug_flip:
            augs.append(RandomHorizontalFlip(p=0.5))
        # 3-augment: pick exactly one of grayscale / solarize / gaussian blur
        augs_choice = []
        if args.aug_grayscale:
            augs_choice.append(Grayscale(num_output_channels=3))
        if args.aug_solarize:
            augs_choice.append(RandomSolarize(threshold=0.5, p=1.0))
        if args.aug_gauss_blur:
            # TODO: check kernel size?
            augs_choice.append(GaussianBlur(kernel_size=7, sigma=(0.2, 2.0)))
            # augs_choice.append(QuickGaussBlur())
        if len(augs_choice) > 0:
            augs.append(RandomChoice(augs_choice))
        if args.aug_color_jitter_factor > 0.0:
            augs.append(
                ColorJitter(
                    args.aug_color_jitter_factor,
                    args.aug_color_jitter_factor,
                    args.aug_color_jitter_factor,
                )
            )
    if args.aug_normalize:
        # ImageNet statistics
        augs.append(
            Normalize(
                mean=torch.tensor([0.485, 0.456, 0.406]),
                std=torch.tensor([0.229, 0.224, 0.225]),
            )
        )
    if as_list:
        return augs
    return Compose(augs)
def segment_augment(args, test=False):
    """Create the data augmentation for segmentation.

    No cropping in this part, as cropping has to be done for the image and labels simultaneously.

    Args:
        args (DotDict): arguments
        test (bool, optional): In eval mode? If False => train mode. Defaults to False.

    Returns:
        list[torch.nn.Module]: list of transformations
    """
    augs = []
    if test:
        # upscale only if needed, then deterministic center crop
        augs.append(ResizeUp(args.imsize))
        augs.append(CenterCrop(args.imsize))
    else:
        augs.append(RandomResizedCrop(args.imsize, scale=(0.08, 1.0), ratio=(3.0 / 4.0, 4.0 / 3.0)))
    if not test and args.aug_flip:
        augs.append(RandomHorizontalFlip(p=0.5))
    if args.aug_color_jitter_factor > 0.0:
        augs.append(
            ColorJitter(
                args.aug_color_jitter_factor,
                args.aug_color_jitter_factor,
                args.aug_color_jitter_factor,
            )
        )
    if args.aug_normalize:
        # ImageNet statistics
        augs.append(
            Normalize(
                mean=torch.tensor([0.485, 0.456, 0.406]),
                std=torch.tensor([0.229, 0.224, 0.225]),
            )
        )
    return augs
class QuickGaussBlur:
    """Gaussian blur transformation using PIL ImageFilter."""

    def __init__(self, sigma=(0.2, 2.0)):
        """Create Gaussian blur operator.

        Args:
            sigma (tuple[float, float]): range to sample the blur radius from.
        """
        self.sigma_min, self.sigma_max = sigma

    def __call__(self, img):
        # sample a radius uniformly from [sigma_min, sigma_max] on each call
        return img.filter(ImageFilter.GaussianBlur(radius=uniform(self.sigma_min, self.sigma_max)))
class RemoveTransform:
    """Remove data from transformation.

    To use with default collate function.
    """

    def __call__(self, x, y=None):
        # the input is discarded and replaced by a constant placeholder
        placeholder = [1]
        return placeholder if y is None else (placeholder, y)

    def __repr__(self):
        return f"{self.__class__.__name__}()"
def collate_imnet(data, image_key="image"):
    """Collate function for imagenet(1k / 21k) with datadings.

    Args:
    ----
    data : list[dict[str, Any]]
        images for a batch
    image_key : str
        dict key under which the image is stored

    Returns:
    -------
    tuple[torch.Tensor | list, torch.Tensor]
        images, labels
    """
    if isinstance(data[0][image_key], torch.Tensor):
        # stack transformed images into a batch tensor
        ims = torch.stack([d[image_key] for d in data], dim=0)
    else:
        # untransformed (e.g. PIL) images: keep them as a plain list
        ims = [d[image_key] for d in data]
    labels = torch.tensor([d["label"] for d in data])
    # removed the commented-out `keys` collection (dead code)
    return ims, labels
def collate_listops(data):
    """Collate function for ListOps with datadings.

    Args:
    ----
    data : list[tuple[torch.Tensor, torch.Tensor]]
        list of samples

    Returns:
    -------
    tuple[torch.Tensor, torch.Tensor]
        images, labels
    """
    # only the first sample is used; it already holds the whole batch
    sample = data[0]
    return sample[0], sample[1]
def no_param_transf(self, sample):
    """Call transformation without extra parameter.

    To use with datadings QuasiShuffler.

    Args:
    ----
    self : object
        use this as a method ( <obj>.<method_name> = MethodType(no_param_transf, <obj>) )
    sample : Any
        sample to transform

    Returns:
    -------
    Any
        transformed sample
    """
    if isinstance(sample, tuple):
        # Sample arrives as (name (str), data (bytes encoded)); keep the payload.
        sample = sample[1]
    if isinstance(sample, bytes):
        # Decode the msgpack-encoded payload.
        sample = msgpack.loads(sample)
    # One shared parameter set drives all per-key transforms of this sample.
    shared_params = self._rng(sample)
    for field, transform in self._transforms.items():
        sample[field] = transform(sample[field], shared_params)
    return sample
class ToOneHotSequence:
    """Convert a sequence of grayscale values (range 0 to 1) to a one-hot encoded sequence based on 8-bit rounded values."""

    def __call__(self, x, y=None):
        # Quantize to 8-bit integer levels and flatten to a 1D sequence
        # (e.g. 1 x 32 x 32 -> 1024).
        levels = (x * 255).round().to(torch.int64).reshape(-1)
        assert levels.max() < 256, f"Found max value {levels.max()} in {levels}."
        encoded = torch.nn.functional.one_hot(levels, num_classes=256).float()
        return encoded if y is None else (encoded, y)

    def __repr__(self):
        return f"{self.__class__.__name__}()"
class ResizeUp(Resize):
    """Resize (up) only if the image is smaller than the target size."""

    def forward(self, img):
        # img is laid out (..., H, W); delegate to the parent Resize only when
        # at least one spatial dimension is below the target size.
        height, width = img.shape[-2], img.shape[-1]
        if height < self.size or width < self.size:
            return super().forward(img)
        return img

View File

@@ -0,0 +1,484 @@
import json
import os
import zipfile
from io import BytesIO
from math import floor
import numpy as np
import PIL
from datadings.torch import Compose
from loguru import logger
from PIL import Image, ImageFilter
from torch.utils.data import Dataset, get_worker_info
from torchvision import transforms as T
from tqdm.auto import tqdm
from data.data_utils import apply_dense_transforms
class ForNet(Dataset):
    """Recombine ImageNet foregrounds and backgrounds.

    Note:
        This dataset has exactly the ImageNet classes.
    """

    # Supported values for `background_combination`.
    _back_combs = ["same", "all", "original"]
    # Transform types that act on the background alone when paste_pre_transform=False.
    _bg_transforms = {T.RandomCrop, T.CenterCrop, T.Resize, T.RandomResizedCrop}
    def __init__(
        self,
        root,
        transform=None,
        train=True,
        target_transform=None,
        background_combination="all",
        fg_scale_jitter=0.3,
        fg_transform=None,
        pruning_ratio=0.8,
        return_fg_masks=False,
        fg_size_mode="range",
        fg_bates_n=1,
        paste_pre_transform=True,
        mask_smoothing_sigma=4.0,
        rel_jut_out=0.0,
        fg_in_nonant=None,
        size_fact=1.0,
        orig_img_prob=0.0,
        orig_ds=None,
        _orig_ds_file_type="JPEG",
        epochs=0,
        _album_compose=False,
    ):
        """Create ForNet (RecombinationNet) dataset.

        Args:
            root (str): Root folder for the dataset.
            transform (T.Compose | list, optional): Transform to apply to the image. Defaults to None.
            train (bool, optional): On the train set (False -> val set). Defaults to True.
            target_transform (T.Compose | list, optional): Transform to apply to the target values. Defaults to None.
            background_combination (str, optional): Which backgrounds to combine with foregrounds; one of "same", "all", "original". Defaults to "all".
            fg_scale_jitter (float, optional): How much the size of the foreground may be changed (random ratio). Defaults to 0.3.
            fg_transform (_type_, optional): Transform to apply to the foreground before pasting it onto the background. This is supposed to be a random rotation, mainly. Defaults to None.
            pruning_ratio (float, optional): Prune backgrounds with (foreground size/background size) >= <pruning_ratio>. Backgrounds from images that contain very large foreground objects are mostly computer generated and therefore relatively unnatural. Defaults to 0.8.
            return_fg_masks (bool, optional): Return the foreground masks. Defaults to False.
            fg_size_mode (str, optional): How to determine the size of the foreground, based on the foreground sizes of the foreground and background images; one of "max", "min", "mean", "range". Defaults to "range".
            fg_bates_n (int, optional): Bates parameter for the distribution of the object position in the foreground. Defaults to 1 (uniform distribution). The higher the value, the more likely the object is in the center. For fg_bates_n = 0, the object is always in the center.
            paste_pre_transform (bool, optional): Paste the foreground onto the background before applying the transform. If False, the background will be cropped and resized before pasting the foreground. Defaults to True.
            mask_smoothing_sigma (float, optional): Sigma for the Gaussian blur of the mask edge. Defaults to 4.0.
            rel_jut_out (float, optional): How much the foreground is allowed to stand/jut out of the background (and then be cut off). Defaults to 0.0.
            fg_in_nonant (int, optional): If not None, the foreground will be placed in a specific nonant (0-8) of the image. Defaults to None.
            size_fact (float, optional): Factor to multiply the size of the foreground with. Defaults to 1.0.
            orig_img_prob (float | str, optional): Probability to use the original image, instead of the fg-bg recombinations; either a fixed float or a schedule name ("linear", "cos", "revlinear"). Defaults to 0.0.
            orig_ds (Dataset | str, optional): Original dataset (without transforms), or path to an image folder, to use for the original images. Defaults to None.
            _orig_ds_file_type (str, optional): File type of the original dataset. Defaults to "JPEG".
            epochs (int, optional): Number of epochs to train on. Used for the epoch-dependent orig_img_prob schedules.
            _album_compose (bool, optional): Compose transforms with the project's Albumentations-style wrapper instead of torchvision's Compose. Defaults to False.

        Note:
            For more information on the Bates distribution, see https://en.wikipedia.org/wiki/Bates_distribution.
            For fg_bates_n < 0, we extend the Bates distribution to focus more and more on the edges. This is done by sampling B ~ Bates(|fg_bates_n|) and then passing it through f(x) = x + 0.5 - floor(x + 0.5).
            For the list of transformations that will be applied to the background only (if paste_pre_transform=False), see ForNet._bg_transforms.
            A nonant in this case refers to a square in a 3x3 grid dividing the image.
        """
        assert (
            background_combination in self._back_combs
        ), f"background_combination={background_combination} is not supported. Use one of {self._back_combs}"
        # Storage layout: either one zip archive per split ("zip") or an extracted
        # tree root/<split>/{foregrounds,backgrounds}/<synset>/ ("folder").
        if (not os.path.exists(f"{root}/backgrounds_{'train' if train else 'val'}.zip")) and os.path.exists(
            os.path.join(root, "train" if train else "val", "backgrounds")
        ):
            self._mode = "folder"
        else:
            self._mode = "zip"
        if self._mode == "zip":
            try:
                # Only list the archive contents here; per-worker handles for
                # reading are opened lazily in _wrkr_info.
                with zipfile.ZipFile(f"{root}/backgrounds_{'train' if train else 'val'}.zip", "r") as bg_zip:
                    self.backgrounds = [f for f in bg_zip.namelist() if f.endswith(".JPEG")]
                with zipfile.ZipFile(f"{root}/foregrounds_{'train' if train else 'val'}.zip", "r") as fg_zip:
                    self.foregrounds = [f for f in fg_zip.namelist() if f.endswith(".WEBP")]
            except FileNotFoundError as e:
                logger.error(
                    f"RecombinationNet: {e}. Make sure to have the background and foreground zips in the root"
                    f" directory: found {os.listdir(root)}"
                )
                raise e
            # Archive paths look like .../<synset>/<file>.WEBP -> class is the parent folder.
            classes = set([f.split("/")[-2] for f in self.foregrounds])
        else:
            logger.info("ForNet folder mode: loading classes")
            classes = set(os.listdir(os.path.join(root, "train" if train else "val", "foregrounds")))
            foregrounds = []
            backgrounds = []
            for cls in tqdm(classes, desc="Loading files"):
                foregrounds.extend(
                    [
                        f"{cls}/{f}"
                        for f in os.listdir(os.path.join(root, "train" if train else "val", "foregrounds", cls))
                    ]
                )
                backgrounds.extend(
                    [
                        f"{cls}/{f}"
                        for f in os.listdir(os.path.join(root, "train" if train else "val", "backgrounds", cls))
                    ]
                )
            self.foregrounds = foregrounds
            self.backgrounds = backgrounds
        # Sort synset ids numerically (e.g. "n01440764" -> 1440764) for a stable class order.
        self.classes = sorted(list(classes), key=lambda x: int(x[1:]))
        assert os.path.exists(f"{root}/fg_bg_ratios_{'train' if train else 'val'}.json"), (
            f"{root}/fg_bg_ratios_{'train' if train else 'val'}.json not found, provide the information or set"
            " pruning_ratio=1.0"
        )
        with open(f"{root}/fg_bg_ratios_{'train' if train else 'val'}.json", "r") as f:
            self.fg_bg_ratios = json.load(f)
        if self._mode == "folder":
            # Ratio keys were stored with longer paths; reduce them to "<synset>/<file>".
            self.fg_bg_ratios = {"/".join(key.split("/")[-2:]): val for key, val in self.fg_bg_ratios.items()}
            logger.info(f"Renamed fg_bg_ratios keys to {list(self.fg_bg_ratios.keys())[:3]}...")
        if pruning_ratio <= 1.0:
            # Remember one background per class as a fallback in case pruning
            # removes every background of that class (used below for "same").
            backup_backgrounds = {}
            for bg_file in self.backgrounds:
                bg_cls = bg_file.split("/")[-2]
                backup_backgrounds[bg_cls] = bg_file
            self.backgrounds = [
                bg for bg in self.backgrounds if bg in self.fg_bg_ratios and self.fg_bg_ratios[bg] < pruning_ratio
            ]
            # logger.info(
            #     f"RecombinationNet: keep {len(self.backgrounds)} of {len(self.fg_bg_ratios)} backgrounds with pr {pruning_ratio}"
            # )
        self.root = root
        self.train = train
        self.background_combination = background_combination
        self.fg_scale_jitter = fg_scale_jitter
        self.fg_transform = fg_transform
        self.return_fg_masks = return_fg_masks
        self.paste_pre_transform = paste_pre_transform
        self.mask_smoothing_sigma = mask_smoothing_sigma
        self.rel_jut_out = rel_jut_out
        self.size_fact = size_fact
        self.fg_in_nonant = fg_in_nonant
        assert fg_in_nonant is None or -1 <= fg_in_nonant < 9, f"fg_in_nonant={fg_in_nonant} not in [0, 8] or None"
        self.orig_img_prob = orig_img_prob
        if orig_img_prob != 0.0:
            # orig_img_prob is either a fixed probability or a schedule name
            # evaluated per epoch in __getitem__.
            assert (isinstance(orig_img_prob, float) and orig_img_prob > 0.0) or orig_img_prob in [
                "linear",
                "cos",
                "revlinear",
            ]
            assert orig_ds is not None, "orig_ds must be provided if orig_img_prob > 0.0"
            assert not return_fg_masks, "can't provide fg masks for original images (yet)"
            assert os.path.exists(os.path.join(root, f"{'train' if train else 'val'}_indices.json")) or isinstance(
                orig_ds, str
            ), f"{root}/{'train' if train else 'val'}_indices.json must be provided if orig_ds is a dataset"
            if not isinstance(orig_ds, str):
                # Dataset object: load the mapping "synset/key" -> index into orig_ds.
                with open(os.path.join(root, f"{'train' if train else 'val'}_indices.json"), "r") as f:
                    self.key_to_orig_idx = json.load(f)
            else:
                # Plain image-folder path: make sure it points at the split directory.
                if not (
                    orig_ds.endswith("train" if train else "val") or orig_ds.endswith("train/" if train else "val/")
                ):
                    orig_ds = f"{orig_ds}/{'train' if train else 'val'}"
                self.key_to_orig_idx = None
        self._orig_ds_file_type = _orig_ds_file_type
        self.orig_ds = orig_ds
        self.epochs = epochs
        self._epoch = 0
        assert fg_size_mode in [
            "max",
            "min",
            "mean",
            "range",
        ], f"fg_size_mode={fg_size_mode} not supported; use one of ['max', 'min', 'mean', 'range']"
        self.fg_size_mode = fg_size_mode
        self.fg_bates_n = fg_bates_n
        if not paste_pre_transform:
            if isinstance(transform, (T.Compose, Compose)):
                transform = transform.transforms
            # do cropping and resizing mainly on background; paste foreground on top later
            self.bg_transform = []
            self.join_transform = []
            for tf in transform:
                if isinstance(tf, tuple(self._bg_transforms)):
                    self.bg_transform.append(tf)
                else:
                    self.join_transform.append(tf)
            if _album_compose:
                from data.album_transf import AlbumTorchCompose
                self.bg_transform = AlbumTorchCompose(self.bg_transform)
                self.join_transform = AlbumTorchCompose(self.join_transform)
            else:
                self.bg_transform = T.Compose(self.bg_transform)
                self.join_transform = T.Compose(self.join_transform)
        else:
            if isinstance(transform, list):
                if _album_compose:
                    from data.album_transf import AlbumTorchCompose
                    self.join_transform = AlbumTorchCompose(transform)
                else:
                    self.join_transform = T.Compose(transform)
            else:
                self.join_transform = transform
            self.bg_transform = None
        # Synset id -> integer label.
        self.trgt_map = {cls: i for i, cls in enumerate(self.classes)}
        self.target_transform = target_transform
        # For "same": map each class to the backgrounds it may use.
        self.cls_to_allowed_bg = {}
        for bg_file in self.backgrounds:
            if background_combination == "same":
                bg_cls = bg_file.split("/")[-2]
                if bg_cls not in self.cls_to_allowed_bg:
                    self.cls_to_allowed_bg[bg_cls] = []
                self.cls_to_allowed_bg[bg_cls].append(bg_file)
        if background_combination == "same":
            # NOTE(review): backup_backgrounds is only defined when
            # pruning_ratio <= 1.0; with pruning_ratio > 1.0 and a class left
            # without backgrounds this would raise NameError — confirm intended.
            for cls_code in classes:
                if cls_code not in self.cls_to_allowed_bg or len(self.cls_to_allowed_bg[cls_code]) == 0:
                    self.cls_to_allowed_bg[cls_code] = [backup_backgrounds[cls_code]]
                    logger.warning(f"No background for class {cls_code}, using {backup_backgrounds[cls_code]}")
        # Per-worker open zip handles: worker_id -> {"bg": ZipFile, "fg": ZipFile}.
        self._zf = {}
    @property
    def epoch(self):
        """Current epoch (drives the orig_img_prob schedules in __getitem__)."""
        return self._epoch
    @epoch.setter
    def epoch(self, value):
        self._epoch = value
    def __len__(self):
        """Size of the dataset.

        Returns:
            int: number of foregrounds
        """
        return len(self.foregrounds)
    def num_classes(self):
        """Return the number of classes."""
        return len(self.classes)
    def _get_fg(self, idx):
        """Load the raw foreground image at index idx.

        NOTE(review): reads via the per-worker zip handle; in folder mode
        self._zf has no entry for this worker, so this would raise KeyError —
        confirm it is only used in zip mode.
        """
        worker_id = self._wrkr_info()
        fg_file = self.foregrounds[idx]
        with self._zf[worker_id]["fg"].open(fg_file) as f:
            fg_data = BytesIO(f.read())
        return Image.open(fg_data)
    def _wrkr_info(self):
        """Return the dataloader worker id and lazily open this worker's zip handles."""
        worker_id = get_worker_info().id if get_worker_info() else 0
        if worker_id not in self._zf and self._mode == "zip":
            self._zf[worker_id] = {
                "bg": zipfile.ZipFile(f"{self.root}/backgrounds_{'train' if self.train else 'val'}.zip", "r"),
                "fg": zipfile.ZipFile(f"{self.root}/foregrounds_{'train' if self.train else 'val'}.zip", "r"),
            }
        return worker_id
    def __getitem__(self, idx):
        """Get the foreground at index idx and combine it with a (random) background.

        Args:
            idx (int): foreground index

        Returns:
            torch.Tensor, torch.Tensor: image, target
        """
        worker_id = self._wrkr_info()
        fg_file = self.foregrounds[idx]
        trgt_cls = fg_file.split("/")[-2]
        # Optionally return the unmodified original image; the probability is
        # either fixed (float) or follows an epoch-dependent schedule.
        if (
            (self.orig_img_prob == "linear" and np.random.rand() < self._epoch / self.epochs)
            or (self.orig_img_prob == "revlinear" and np.random.rand() < (self._epoch - self.epochs) / self.epochs)
            or (self.orig_img_prob == "cos" and np.random.rand() > np.cos(np.pi * self._epoch / (2 * self.epochs)))
            or (
                isinstance(self.orig_img_prob, float)
                and self.orig_img_prob > 0.0
                and np.random.rand() < self.orig_img_prob
            )
        ):
            data_key = f"{trgt_cls}/{fg_file.split('/')[-1].split('.')[0]}"
            if isinstance(self.orig_ds, str):
                # Image-folder mode: build the path to the original file directly.
                image_file = os.path.join(self.orig_ds, f"{data_key}.{self._orig_ds_file_type}")
                orig_img = Image.open(image_file).convert("RGB")
            else:
                orig_data = self.orig_ds[self.key_to_orig_idx[data_key]]
                orig_img = orig_data["image"] if isinstance(orig_data, dict) else orig_data[0]
            if self.bg_transform:
                orig_img = self.bg_transform(orig_img)
            if self.join_transform:
                orig_img = self.join_transform(orig_img)
            trgt = self.trgt_map[trgt_cls]
            if self.target_transform:
                trgt = self.target_transform(trgt)
            return orig_img, trgt
        # --- load the foreground (RGBA; alpha channel is the object mask) ---
        if self._mode == "zip":
            with self._zf[worker_id]["fg"].open(fg_file) as f:
                fg_data = BytesIO(f.read())
            try:
                fg_img = Image.open(fg_data).convert("RGBA")
            except PIL.UnidentifiedImageError as e:
                logger.error(f"Error with idx={idx}, file={self.foregrounds[idx]}")
                raise e
        else:
            # data_key = f"{trgt_cls}/{fg_file.split('/')[-1].split('.')[0]}"
            fg_img = Image.open(
                os.path.join(self.root, "train" if self.train else "val", "foregrounds", fg_file)
            ).convert("RGBA")
        if self.fg_transform:
            fg_img = self.fg_transform(fg_img)
        # Fraction of the foreground canvas actually covered by the object (mean alpha).
        fg_size_factor = T.ToTensor()(fg_img.split()[-1]).mean().item()
        # --- pick and load the background ---
        if self.background_combination == "all":
            bg_idx = np.random.randint(len(self.backgrounds))
            bg_file = self.backgrounds[bg_idx]
        elif self.background_combination == "original":
            bg_file = fg_file.replace("foregrounds", "backgrounds").replace("WEBP", "JPEG")
        else:
            bg_idx = np.random.randint(len(self.cls_to_allowed_bg[trgt_cls]))
            bg_file = self.cls_to_allowed_bg[trgt_cls][bg_idx]
        if self._mode == "zip":
            with self._zf[worker_id]["bg"].open(bg_file) as f:
                bg_data = BytesIO(f.read())
            bg_img = Image.open(bg_data).convert("RGB")
        else:
            bg_img = Image.open(
                os.path.join(self.root, "train" if self.train else "val", "backgrounds", bg_file)
            ).convert("RGB")
        if not self.paste_pre_transform:
            bg_img = self.bg_transform(bg_img)
        bg_size = bg_img.size
        # choose scale factor, such that relative area is in fg_scale
        bg_area = bg_size[0] * bg_size[1]
        if self.fg_in_nonant is not None:
            # The foreground must fit into one cell of the 3x3 grid.
            bg_area = bg_area / 9
        # logger.info(f"background: size={bg_size} area={bg_area}")
        # logger.info(f"fg_file={fg_file}, fg_bg_ratio_keys={list(self.fg_bg_ratios.keys())[:3]}...")
        orig_fg_ratio = self.fg_bg_ratios[fg_file.replace("foregrounds", "backgrounds").replace("WEBP", "JPEG")]
        bg_fg_ratio = self.fg_bg_ratios[bg_file]
        # Target relative foreground area, derived from both source images.
        if self.fg_size_mode == "max":
            goal_fg_ratio_lower = goal_fg_ratio_upper = max(orig_fg_ratio, bg_fg_ratio)
        elif self.fg_size_mode == "min":
            goal_fg_ratio_lower = goal_fg_ratio_upper = min(orig_fg_ratio, bg_fg_ratio)
        elif self.fg_size_mode == "mean":
            goal_fg_ratio_lower = goal_fg_ratio_upper = (orig_fg_ratio + bg_fg_ratio) / 2
        else:
            # range
            goal_fg_ratio_lower = min(orig_fg_ratio, bg_fg_ratio)
            goal_fg_ratio_upper = max(orig_fg_ratio, bg_fg_ratio)
        # logger.info(f"fg_bg_ratio={orig_fg_ratio}")
        fg_scale = (
            np.random.uniform(
                goal_fg_ratio_lower * (1 - self.fg_scale_jitter), goal_fg_ratio_upper * (1 + self.fg_scale_jitter)
            )
            / fg_size_factor
            * self.size_fact
        )
        # Resize the foreground so its area matches fg_scale while keeping its aspect ratio.
        goal_shape_y = round(np.sqrt(bg_area * fg_scale * fg_img.size[1] / fg_img.size[0]))
        goal_shape_x = round(np.sqrt(bg_area * fg_scale * fg_img.size[0] / fg_img.size[1]))
        fg_img = fg_img.resize((goal_shape_x, goal_shape_y))
        if fg_img.size[0] > bg_size[0] or fg_img.size[1] > bg_size[1]:
            # random crop to fit
            goal_w, goal_h = (min(fg_img.size[0], bg_size[0]), min(fg_img.size[1], bg_size[1]))
            fg_img = T.RandomCrop((goal_h, goal_w))(fg_img) if self.train else T.CenterCrop((goal_h, goal_w))(fg_img)
        # paste fg on bg
        z1, z2 = (
            (
                np.random.uniform(0, 1, abs(self.fg_bates_n)).mean(),  # bates distribution n=1 => uniform
                np.random.uniform(0, 1, abs(self.fg_bates_n)).mean(),
            )
            if self.fg_bates_n != 0
            else (0.5, 0.5)
        )
        if self.fg_bates_n < 0:
            # Negative n: fold the Bates sample so positions concentrate at the edges.
            z1 = z1 + 0.5 - floor(z1 + 0.5)
            z2 = z2 + 0.5 - floor(z2 + 0.5)
        # Allowed range for the paste offset (rel_jut_out lets the object hang over the edge).
        x_min = -self.rel_jut_out * fg_img.size[0]
        x_max = bg_size[0] - fg_img.size[0] * (1 - self.rel_jut_out)
        y_min = -self.rel_jut_out * fg_img.size[1]
        y_max = bg_size[1] - fg_img.size[1] * (1 - self.rel_jut_out)
        if self.fg_in_nonant is not None and self.fg_in_nonant >= 0:
            # Restrict the offset range to the requested cell of the 3x3 grid.
            x_min = (self.fg_in_nonant % 3) * bg_size[0] / 3
            x_max = ((self.fg_in_nonant % 3) + 1) * bg_size[0] / 3 - fg_img.size[0]
            y_min = (self.fg_in_nonant // 3) * bg_size[1] / 3
            y_max = ((self.fg_in_nonant // 3) + 1) * bg_size[1] / 3 - fg_img.size[1]
        if x_min > x_max:
            x_min = x_max = (x_min + x_max) / 2
        if y_min > y_max:
            y_min = y_max = (y_min + y_max) / 2
        offs_x = round(z1 * (x_max - x_min) + x_min)
        offs_y = round(z2 * (y_max - y_min) + y_min)
        paste_mask = fg_img.split()[-1]
        if self.mask_smoothing_sigma > 0.0:
            # Blur the mask edge, then re-harden the interior so only the rim stays soft.
            sigma = (np.random.rand() * 0.9 + 0.1) * self.mask_smoothing_sigma
            paste_mask = paste_mask.filter(ImageFilter.GaussianBlur(radius=sigma))
            paste_mask = paste_mask.point(lambda p: 2 * p - 255 if p > 128 else 0)
        bg_img.paste(fg_img.convert("RGB"), (offs_x, offs_y), paste_mask)
        bg_img = bg_img.convert("RGB")
        if self.return_fg_masks:
            # Build the full-size mask and apply geometry-consistent transforms
            # to image and mask together.
            fg_mask = Image.new("L", bg_size, 0)
            fg_mask.paste(paste_mask, (offs_x, offs_y))
            fg_mask = T.ToTensor()(fg_mask)[0]
            bg_img = T.ToTensor()(bg_img)
            if self.join_transform:
                # img_mask_stack = torch.cat([bg_img, fg_mask.unsqueeze(0)], dim=0)
                # img_mask_stack = self.join_transform(img_mask_stack)
                # bg_img, fg_mask = img_mask_stack[:-1], img_mask_stack[-1]
                bg_img, fg_mask = apply_dense_transforms(bg_img, fg_mask, self.join_transform)
        else:
            bg_img = self.join_transform(bg_img)
        if trgt_cls not in self.trgt_map:
            raise ValueError(f"trgt_cls={trgt_cls} not in trgt_map: {self.trgt_map}")
        trgt = self.trgt_map[trgt_cls]
        if self.target_transform:
            trgt = self.target_transform(trgt)
        if self.return_fg_masks:
            return bg_img, trgt, fg_mask
        return bg_img, trgt

View File

@@ -0,0 +1,151 @@
import argparse
import shutil
import zipfile
from os import listdir, makedirs, path
from random import choice
from datadings.reader import MsgpackReader
from datadings.writer import FileWriter
from tqdm.auto import tqdm
parser = argparse.ArgumentParser()
parser.add_argument("-tiny_imagenet_zip", type=str, required=True, help="Path to the Tiny ImageNet zip file")
parser.add_argument("-output_dir", type=str, required=True, help="Directory to extract the image names to")
parser.add_argument("-in_segment_dir", type=str, required=True, help="Directory that holds the segmented ImageNet")
parser.add_argument(
    "-imagenet_path", type=str, nargs="?", required=True, help="Path to the original ImageNet dataset (datadings)"
)
args = parser.parse_args()
# Collect the Tiny ImageNet image file names, split by train/val.
images = {"train": set(), "val": set()}
with zipfile.ZipFile(args.tiny_imagenet_zip, "r") as zip_ref:
    for info in tqdm(zip_ref.infolist(), desc="Gathering Images"):
        name = info.filename
        if not name.endswith(".JPEG"):
            continue
        base_name = name.split("/")[-1]
        if "/val/" in name:
            images["val"].add(base_name)
        elif "/train/" in name:
            images["train"].add(base_name)
# Persist the per-split image name lists.
with open(path.join(args.output_dir, "tiny_imagenet_train_images.txt"), "w+") as f:
    f.write("\n".join(images["train"]))
with open(path.join(args.output_dir, "tiny_imagenet_val_images.txt"), "w+") as f:
    f.write("\n".join(images["val"]))
print(f"Found {len(images['train'])} training images and {len(images['val'])} validation images")
# The synset id is the prefix of every train file name; sort synsets numerically.
classes = sorted({img_name.split("_")[0] for img_name in images["train"]}, key=lambda x: int(x[1:]))
assert len(classes) == 200, f"Expected 200 classes, found {len(classes)}"
assert len(images["train"]) == len(classes) * 500, f"Expected 100000 training images, found {len(images['train'])}"
assert len(images["val"]) == len(classes) * 50, f"Expected 10000 validation images, found {len(images['val'])}"
with open(path.join(args.output_dir, "tiny_imagenet_classes.txt"), "w+") as f:
    f.write("\n".join(classes))
# copy over the relevant images
for split in ["train", "val"]:
    # images per class in Tiny ImageNet: 500 train / 50 val
    ipc = 500 if split == "train" else 50
    part = "foregrounds_WEBP"
    with tqdm(total=len(images[split]), desc=f"Copying images for {split} > {part}") as pbar:
        for synset in classes:
            makedirs(path.join(args.output_dir, split, part, synset), exist_ok=True)
            if len(listdir(path.join(args.output_dir, split, part, synset))) >= ipc:
                # class already fully populated (e.g. from a previous run)
                tqdm.write(
                    f"skip {split} > {part} > {synset} with"
                    f" {len(listdir(path.join(args.output_dir, split, part, synset)))} ims"
                )
                pbar.update(ipc)
                continue
            # copy every segmented image whose original is part of Tiny ImageNet
            for img in listdir(path.join(args.in_segment_dir, split, part, synset)):
                orig_name = (
                    img.split(".")[0] + ".JPEG"
                    if split == "train"
                    else f"val_{int(img.split('_')[-1].split('.')[0])}.JPEG"
                )
                if orig_name in images[split]:
                    shutil.copy(
                        path.join(args.in_segment_dir, split, part, synset, img),
                        path.join(args.output_dir, split, part, synset, img),
                    )
                    pbar.update(1)
            # top up the class with random extra images until ipc images are
            # present (or the source directory is exhausted)
            while len(listdir(path.join(args.output_dir, split, part, synset))) < min(
                ipc, len(listdir(path.join(args.in_segment_dir, split, part, synset)))
            ):
                # BUGFIX: compare segmented file names against the output
                # directory (the old code compared the reconstructed ".JPEG"
                # names to ".WEBP" listings, so nothing was ever filtered and
                # already-copied files could be drawn again).
                already_copied = set(listdir(path.join(args.output_dir, split, part, synset)))
                candidates = [
                    img
                    for img in listdir(path.join(args.in_segment_dir, split, part, synset))
                    if img not in already_copied
                ]
                img = choice(candidates)
                shutil.copy(
                    path.join(args.in_segment_dir, split, part, synset, img),
                    path.join(args.output_dir, split, part, synset, img),
                )
                pbar.update(1)
                # BUGFIX: log the file that was actually copied (the old code
                # printed the stale orig_name from the previous loop)
                tqdm.write(f"Extra image: {img} to {split}/{part}/{synset}")
    # copy over the background images corresponding to those foregrounds
    part = "backgrounds_JPEG"
    with tqdm(total=len(images[split]), desc=f"Copying images for {split} > {part}") as pbar:
        for synset in classes:
            makedirs(path.join(args.output_dir, split, part, synset), exist_ok=True)
            if len(listdir(path.join(args.output_dir, split, part, synset))) >= ipc:
                tqdm.write(f"skip {split} > {part} > {synset}")
                pbar.update(ipc)
                continue
            for img in listdir(path.join(args.output_dir, split, "foregrounds_WEBP", synset)):
                bg_name = img.replace(".WEBP", ".JPEG")
                shutil.copy(
                    path.join(args.in_segment_dir, split, part, synset, bg_name),
                    path.join(args.output_dir, split, part, synset, bg_name),
                )
                pbar.update(1)
            # every copied foreground must have exactly one matching background
            assert len(listdir(path.join(args.output_dir, split, part, synset))) == len(
                listdir(path.join(args.output_dir, split, "foregrounds_WEBP", synset))
            ), (
                f"Expected {len(listdir(path.join(args.output_dir, split, 'foregrounds_WEBP', synset)))} background"
                f" images, found {len(listdir(path.join(args.output_dir, split, part, synset)))}"
            )
# write the original dataset to datadings
for part in ["train", "val"]:
    reader = MsgpackReader(path.join(args.imagenet_path, f"{part}.msgpack"))
    # Index the kept foregrounds once: image stem -> synset directory. The old
    # code re-listed synset directories for every msgpack record, which is
    # O(records * files); this single pass is equivalent (later synsets
    # overwrite earlier ones, matching the original last-match behavior).
    stem_to_synset = {}
    for synset in classes:
        for img in listdir(path.join(args.output_dir, part, "foregrounds_WEBP", synset)):
            stem_to_synset[img.split(".")[0]] = synset
    with FileWriter(path.join(args.output_dir, f"TinyIN_{part}.msgpack")) as writer:
        for data in tqdm(reader, desc=f"Writing {part} to datadings"):
            key = data["key"].split("/")[-1]
            # train keys are "<synset>_<id>..."; skip records of unused synsets
            if part == "train" and key.split("_")[0] not in classes:
                continue
            label_synset = stem_to_synset.get(key.split(".")[0])
            if label_synset is None:
                # no matching foreground was kept -> drop this record
                continue
            if part == "train" and label_synset != key.split("_")[0]:
                # train records only match foregrounds from their own synset
                continue
            data["label"] = classes.index(label_synset)
            writer.write(data)

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,48 @@
import argparse
import os
import json
import torch
parser = argparse.ArgumentParser("Script to convert ImageNet trained models to ImageNet-9")
parser.add_argument("-m", "--model", type=str, required=True, help="Model weights (.pt file).")
parser.add_argument(
    "--in_to_in9", type=str, default="/ds-sds/images/ImageNet-9/in_to_in9.json", help="Path to in_to_in9.json"
)
args = parser.parse_args()
# Load the checkpoint and locate the classifier-head parameters to convert.
checkpoint = torch.load(args.model, map_location="cpu")
model_state = checkpoint["model_state"]
head_keys = [key for key in model_state if ".head." in key or ".fc." in key]
print("weights that will be modified:", head_keys)
assert len(head_keys) > 0, "no head keys found :("
# Class mapping: ImageNet-1k index (as str) -> ImageNet-9 index, or -1 for "unused".
with open(args.in_to_in9, "r") as f:
    in_to_in9_classes = json.load(f)
print(f"{len([k for k, v in in_to_in9_classes.items() if v == -1])} classes get mapped to -1")
print("map", len(in_to_in9_classes), " classes to", set(in_to_in9_classes.values()))
print("Building conversion matrix")
# Binary (9 x 1000) matrix that sums all IN-1k head rows belonging to one IN-9 class.
conversion_matrix = torch.zeros((9, 1000))
for in_idx, in9_idx in in_to_in9_classes.items():
    if in9_idx != -1:
        conversion_matrix[in9_idx, int(in_idx)] = 1
print(f"Conversion matrix ({conversion_matrix.shape}) has {int(conversion_matrix.sum().item())} non-zero values")
for head_key in head_keys:
    print(f"converting {head_key} of shape {model_state[head_key].shape}", end=" ")
    model_state[head_key] = conversion_matrix @ model_state[head_key]
    print(f"\tto {model_state[head_key].shape}")
checkpoint["model_state"] = model_state
checkpoint["args"]["n_classes"] = 9
# Save next to the original weights with a "_to_in9" suffix before the extension.
save_folder = os.path.dirname(args.model)
orig_model_name = args.model.split(os.sep)[-1]
stem, _, extension = orig_model_name.rpartition(".")
new_model_name = f"{stem}_to_in9.{extension}"
print(f"saving model as {new_model_name} in {save_folder}")
torch.save(checkpoint, os.path.join(save_folder, new_model_name))

View File

@@ -0,0 +1,200 @@
n07695742: pretzel
n03902125: pay-phone, pay-station
n03980874: poncho
n01644900: tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui
n02730930: apron
n02699494: altar
n03201208: dining table, board
n02056570: king penguin, Aptenodytes patagonica
n04099969: rocking chair, rocker
n04366367: suspension bridge
n04067472: reel
n02808440: bathtub, bathing tub, bath, tub
n04540053: volleyball
n02403003: ox
n03100240: convertible
n04562935: water tower
n02788148: bannister, banister, balustrade, balusters, handrail
n02988304: CD player
n02423022: gazelle
n03637318: lampshade, lamp shade
n01774384: black widow, Latrodectus mactans
n01768244: trilobite
n07614500: ice cream, icecream
n04254777: sock
n02085620: Chihuahua
n01443537: goldfish, Carassius auratus
n01629819: European fire salamander, Salamandra salamandra
n02099601: golden retriever
n02321529: sea cucumber, holothurian
n03837869: obelisk
n02002724: black stork, Ciconia nigra
n02841315: binoculars, field glasses, opera glasses
n04560804: water jug
n02364673: guinea pig, Cavia cobaya
n03706229: magnetic compass
n09256479: coral reef
n09332890: lakeside, lakeshore
n03544143: hourglass
n02124075: Egyptian cat
n02948072: candle, taper, wax light
n01950731: sea slug, nudibranch
n02791270: barbershop
n03179701: desk
n02190166: fly
n04275548: spider web, spider's web
n04417672: thatch, thatched roof
n03930313: picket fence, paling
n02236044: mantis, mantid
n03976657: pole
n01774750: tarantula
n04376876: syringe
n04133789: sandal
n02099712: Labrador retriever
n04532670: viaduct
n04487081: trolleybus, trolley coach, trackless trolley
n09428293: seashore, coast, seacoast, sea-coast
n03160309: dam, dike, dyke
n03250847: drumstick
n02843684: birdhouse
n07768694: pomegranate
n03670208: limousine, limo
n03085013: computer keyboard, keypad
n02892201: brass, memorial tablet, plaque
n02233338: cockroach, roach
n03649909: lawn mower, mower
n03388043: fountain
n02917067: bullet train, bullet
n02486410: baboon
n04596742: wok
n03255030: dumbbell
n03937543: pill bottle
n02113799: standard poodle
n03977966: police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria
n02906734: broom
n07920052: espresso
n01698640: American alligator, Alligator mississipiensis
n02123394: Persian cat
n03424325: gasmask, respirator, gas helmet
n02129165: lion, king of beasts, Panthera leo
n04008634: projectile, missile
n03042490: cliff dwelling
n02415577: bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis
n02815834: beaker
n02395406: hog, pig, grunter, squealer, Sus scrofa
n01784675: centipede
n03126707: crane
n04399382: teddy, teddy bear
n07875152: potpie
n03733131: maypole
n02802426: basketball
n03891332: parking meter
n01910747: jellyfish
n03838899: oboe, hautboy, hautbois
n03770439: miniskirt, mini
n02281406: sulphur butterfly, sulfur butterfly
n03970156: plunger, plumber's helper
n09246464: cliff, drop, drop-off
n02206856: bee
n02074367: dugong, Dugong dugon
n03584254: iPod
n04179913: sewing machine
n04328186: stopwatch, stop watch
n07583066: guacamole
n01917289: brain coral
n03447447: gondola
n02823428: beer bottle
n03854065: organ, pipe organ
n02793495: barn
n04285008: sports car, sport car
n02231487: walking stick, walkingstick, stick insect
n04465501: tractor
n02814860: beacon, lighthouse, beacon light, pharos
n02883205: bow tie, bow-tie, bowtie
n02165456: ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle
n04149813: scoreboard
n04023962: punching bag, punch bag, punching ball, punchball
n02226429: grasshopper, hopper
n02279972: monarch, monarch butterfly, milkweed butterfly, Danaus plexippus
n02669723: academic gown, academic robe, judge's robe
n04486054: triumphal arch
n04070727: refrigerator, icebox
n03444034: go-kart
n02666196: abacus
n01945685: slug
n04251144: snorkel
n03617480: kimono
n03599486: jinrikisha, ricksha, rickshaw
n02437312: Arabian camel, dromedary, Camelus dromedarius
n01984695: spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish
n04118538: rugby ball
n01770393: scorpion
n04356056: sunglasses, dark glasses, shades
n03804744: nail
n02132136: brown bear, bruin, Ursus arctos
n03400231: frying pan, frypan, skillet
n03983396: pop bottle, soda bottle
n07734744: mushroom
n02480495: orangutan, orang, orangutang, Pongo pygmaeus
n02410509: bison
n03404251: fur coat
n04456115: torch
n02123045: tabby, tabby cat
n03026506: Christmas stocking
n07715103: cauliflower
n04398044: teapot
n02927161: butcher shop, meat market
n07749582: lemon
n07615774: ice lolly, lolly, lollipop, popsicle
n02795169: barrel, cask
n04532106: vestment
n02837789: bikini, two-piece
n02814533: beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon
n04265275: space heater
n02481823: chimpanzee, chimp, Pan troglodytes
n02509815: lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens
n06596364: comic book
n01983481: American lobster, Northern lobster, Maine lobster, Homarus americanus
n02504458: African elephant, Loxodonta africana
n03014705: chest
n01944390: snail
n04146614: school bus
n01641577: bullfrog, Rana catesbeiana
n07720875: bell pepper
n02999410: chain
n01855672: goose
n02125311: cougar, puma, catamount, mountain lion, painter, panther, Felis concolor
n07753592: banana
n07871810: meat loaf, meatloaf
n04501370: turnstile
n04311004: steel arch bridge
n02977058: cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM
n04074963: remote control, remote
n03662601: lifeboat
n02106662: German shepherd, German shepherd dog, German police dog, alsatian
n03089624: confectionery, confectionary, candy store
n04259630: sombrero
n03393912: freight car
n04597913: wooden spoon
n07711569: mashed potato
n03355925: flagpole, flagstaff
n02963159: cardigan
n07579787: plate
n02950826: cannon
n01882714: koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus
n02094433: Yorkshire terrier
n02909870: bucket, pail
n02058221: albatross, mollymawk
n01742172: boa constrictor, Constrictor constrictor
n09193705: alp
n04371430: swimming trunks, bathing trunks
n07747607: orange
n03814639: neck brace
n04507155: umbrella
n02268443: dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk
n02769748: backpack, back pack, knapsack, packsack, rucksack, haversack
n03763968: military uniform
n07873807: pizza, pizza pie
n03992509: potter's wheel
n03796401: moving van
n12267677: acorn

View File

@@ -0,0 +1,73 @@
# Repeat Augment sampler taken from DeiT: https://github.com/facebookresearch/deit/blob/main/samplers.py
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
import math
import torch
import torch.distributed as dist
class RASampler(torch.utils.data.Sampler):
    """Distributed sampler that repeats every sample for repeated augmentation.

    Each index is emitted ``num_repeats`` times and the copies are dealt out
    round-robin across processes, so every process (GPU) receives a
    differently augmented view of the same underlying sample.
    Heavily based on torch.utils.data.DistributedSampler
    (taken from DeiT: https://github.com/facebookresearch/deit/blob/main/samplers.py).
    """

    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, num_repeats: int = 3):
        # Fall back to the default process group for world size / rank when not given.
        if num_replicas is None or rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            if num_replicas is None:
                num_replicas = dist.get_world_size()
            if rank is None:
                rank = dist.get_rank()
        if num_repeats < 1:
            raise ValueError("num_repeats should be greater than 0")
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.num_repeats = num_repeats
        self.shuffle = shuffle
        self.epoch = 0
        # Length of the padded, repeated index list each replica draws from.
        self.num_samples = int(math.ceil(len(self.dataset) * self.num_repeats / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        # Per-epoch budget: dataset length rounded down to a multiple of 256,
        # then split evenly over the replicas.
        self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))

    def __iter__(self):
        n = len(self.dataset)
        if self.shuffle:
            # deterministically shuffle based on epoch (same seed on every rank)
            rng = torch.Generator()
            rng.manual_seed(self.epoch)
            order = torch.randperm(n, generator=rng)
        else:
            order = torch.arange(n)
        # Repeat every index num_repeats times, then pad with the head of the
        # list so the total is evenly divisible across replicas.
        repeated = torch.repeat_interleave(order, repeats=self.num_repeats, dim=0).tolist()
        shortfall = self.total_size - len(repeated)
        if shortfall > 0:
            repeated = repeated + repeated[:shortfall]
        assert len(repeated) == self.total_size
        # Round-robin subsample: consecutive repeats land on different ranks.
        mine = repeated[self.rank : self.total_size : self.num_replicas]
        assert len(mine) == self.num_samples
        return iter(mine[: self.num_selected_samples])

    def __len__(self):
        return self.num_selected_samples

    def set_epoch(self, epoch):
        self.epoch = epoch

    def __str__(self) -> str:
        return (
            f"{type(self).__name__}(num_replicas: {self.num_replicas}, rank: {self.rank}, num_repeats:"
            f" {self.num_repeats}, epoch: {self.epoch}, num_samples: {self.num_samples}, total_size: {self.total_size},"
            f" num_selected_samples: {self.num_selected_samples}, shuffle: {self.shuffle})"
        )

View File

@@ -0,0 +1,381 @@
import argparse
import json
import os
from copy import copy
from loguru import logger
from nltk.corpus import wordnet as wn
class bcolors:
    """ANSI escape sequences for coloring/styling terminal output."""

    HEADER = "\033[95m"  # bright magenta
    OKBLUE = "\033[94m"  # bright blue
    OKCYAN = "\033[96m"  # bright cyan
    OKGREEN = "\033[92m"  # bright green
    WARNING = "\033[93m"  # bright yellow
    FAIL = "\033[91m"  # bright red
    ENDC = "\033[0m"  # reset all attributes
    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"
def _lemmas_str(synset):
return ", ".join([lemma.name() for lemma in synset.lemmas()])
class WNEntry:
    """A single WordNet synset node inside a WNTree.

    Nodes only store the ids of their parent/children; traversal helpers
    therefore take the owning tree as an argument to resolve ids to nodes.
    Underscore-prefixed attributes are serialized as-is by ``to_dict`` and
    restored by ``from_dict``.

    Fix over the previous version: ``prune`` iterated ``self.child_ids``
    while each child's own ``prune`` removed itself from that very list,
    skipping every other child (and losing its image count). The loop now
    iterates a copy of the list.
    """

    def __init__(
        self,
        name: str,
        id: int,
        lemmas: str,
        parent_id: int,
        depth: int = None,
        in_image_net: bool = False,
        child_ids: list = None,
        in_main_tree: bool = True,
        _n_images: int = 0,
        _description: str = None,
        _name: str = None,
        _pruned: bool = False,
    ):
        """Create a node.

        Args:
            name: synset name, e.g. ``"dog.n.01"``.
            id: WordNet offset of the synset (the key in the tree).
            lemmas: comma-separated lemma names.
            parent_id: id of the hypernym node (None for the root).
            depth: distance from the root.
            in_image_net: whether this synset is an ImageNet class.
            child_ids: ids of the hyponym nodes (None = leaf so far).
            in_main_tree: False for nodes hanging off a parentless synset.
            _n_images: number of images assigned directly to this node.
            _description: cached WordNet definition (see ``description``).
            _name: kept only for (de)serialization round-trips.
            _pruned: whether this node has been merged into an ancestor.
        """
        self.name = name
        self.id = id
        self.lemmas = lemmas
        self.parent_id = parent_id
        self.depth = depth
        self.in_image_net = in_image_net
        self.child_ids = child_ids
        self.in_main_tree = in_main_tree
        self._n_images = _n_images
        self._description = _description
        self._name = _name
        self._pruned = _pruned

    def __str__(self, tree=None, accumulate=True):
        """Render this node (and, given a tree, its whole subtree) as colored text."""
        start_symb = f"{bcolors.OKGREEN}+{bcolors.ENDC}" if self.in_image_net else f"{bcolors.FAIL}-{bcolors.ENDC}"
        n_ims = f"{self._n_images} of Σ {self.n_images(tree)}" if accumulate and tree is not None else self._n_images
        if self.child_ids is None or tree is None:
            return f"{start_symb}{self.name} ({self.id}) > {n_ims}"
        else:
            return f"{start_symb}{self.name} ({self.id}) > {n_ims}\n " + "\n ".join(
                ["\n ".join(tree.nodes[child_id].__str__(tree).split("\n")) for child_id in self.child_ids]
            )

    def tree_diff(self, tree_1, tree_2):
        """Render the image-count differences of this subtree between two trees."""
        if tree_2[self.id]._n_images > tree_1[self.id]._n_images:
            start_symb = f"{bcolors.OKGREEN}+{bcolors.ENDC}"
        elif tree_2[self.id]._n_images < tree_1[self.id]._n_images:
            start_symb = f"{bcolors.FAIL}-{bcolors.ENDC}"
        else:
            start_symb = f"{bcolors.OKBLUE}={bcolors.ENDC}"
        n_ims = (
            f"{tree_1[self.id]._n_images} + {tree_2[self.id]._n_images - tree_1[self.id]._n_images} of Σ"
            f" {tree_1[self.id].n_images(tree_2)}/{tree_2[self.id].n_images(tree_2)}"
        )
        if self.child_ids is None:
            return f"{start_symb}{self.name} ({self.id}) > {n_ims}"
        return f"{start_symb}{self.name} ({self.id}) > {n_ims}\n " + "\n ".join(
            ["\n ".join(tree_1.nodes[child_id].tree_diff(tree_1, tree_2).split("\n")) for child_id in self.child_ids]
        )

    def prune(self, tree):
        """Merge this node and its whole subtree into the nearest unpruned ancestor.

        Recursively prunes the children first, then detaches this node from
        its parent's child list and moves its image count up. Idempotent via
        the ``_pruned`` flag.
        """
        if self._pruned:
            return
        if self.child_ids is not None:
            # Iterate a copy: each child's prune() removes it from self.child_ids.
            for child_id in list(self.child_ids):
                tree[child_id].prune(tree)
        self._pruned = True
        parent_node = tree.nodes[self.parent_id]
        try:
            parent_node.child_ids.remove(self.id)
        except ValueError as e:
            print(
                f"Error removing {self.name} from"
                f" {parent_node.name} ({[tree[cid].name for cid in parent_node.child_ids]}): {e}"
            )
        # Walk up to the first ancestor that is itself not pruned and give it our images.
        while parent_node._pruned:
            parent_node = tree.nodes[parent_node.parent_id]
        parent_node._n_images += self._n_images
        self._n_images = 0

    @property
    def description(self):
        """WordNet definition of the synset, fetched lazily and cached."""
        if not self._description:
            self._description = wn.synset_from_pos_and_offset("n", self.id).definition()
        return self._description

    @property
    def print_name(self):
        """Human-readable name: the synset name without the ".n.XX" suffix."""
        return self.name.split(".")[0]

    def get_branch(self, tree=None):
        """Return the "root > ... > this" path as a string."""
        if self.parent_id is None or tree is None:
            return self.print_name
        parent = tree.nodes[self.parent_id]
        return parent.get_branch(tree) + " > " + self.print_name

    def get_branch_list(self, tree):
        """Return the list of nodes from the root down to (and including) this node."""
        if self.parent_id is None:
            return [self]
        parent = tree.nodes[self.parent_id]
        return parent.get_branch_list(tree) + [self]

    def to_dict(self):
        """Serialize the node to a plain dict (inverse of ``from_dict``)."""
        return {
            "name": self.name,
            "id": self.id,
            "lemmas": self.lemmas,
            "parent_id": self.parent_id,
            "depth": self.depth,
            "in_image_net": self.in_image_net,
            "child_ids": self.child_ids,
            "in_main_tree": self.in_main_tree,
            "_n_images": self._n_images,
            "_description": self._description,
            "_name": self._name,
            "_pruned": self._pruned,
        }

    @classmethod
    def from_dict(cls, d):
        """Rebuild a node from a dict produced by ``to_dict``."""
        return cls(**d)

    def n_images(self, tree=None):
        """Number of images of this node; with a tree, including all descendants."""
        if tree is None or self.child_ids is None or len(self.child_ids) == 0:
            return self._n_images
        return sum([tree.nodes[child_id].n_images(tree) for child_id in self.child_ids]) + self._n_images

    def n_children(self, tree=None):
        """Number of direct children; with a tree, of all descendants."""
        if self.child_ids is None:
            return 0
        if tree is None or len(self.child_ids) == 0:
            return len(self.child_ids)
        return len(self.child_ids) + sum([tree.nodes[child_id].n_children(tree) for child_id in self.child_ids])

    def get_examples(self, tree, n_examples=3):
        """Return up to ``n_examples`` example children as "name (definition)" strings."""
        if self.child_ids is None or len(self.child_ids) == 0:
            return ""
        child_images = {child_id: tree.nodes[child_id].n_images(tree) for child_id in self.child_ids}
        max_images = max(child_images.values())
        if max_images == 0:
            # go on number of child nodes
            child_images = {child_id: tree.nodes[child_id].n_children(tree) for child_id in self.child_ids}
        # sorted child ids by number of images (descending)
        top_children = [
            child_id for child_id, n_images in sorted(child_images.items(), key=lambda x: x[1], reverse=True)
        ]
        top_children = top_children[: min(n_examples, len(top_children))]
        return ", ".join(
            [f"{tree.nodes[child_id].print_name} ({tree.nodes[child_id].description})" for child_id in top_children]
        )
class WNTree:
def __init__(self, root=1740, nodes=None):
if isinstance(root, int):
root_id = root
root_synset = wn.synset_from_pos_and_offset("n", root)
root_node = WNEntry(
root_synset.name(),
root_id,
_lemmas_str(root_synset),
parent_id=None,
depth=0,
)
else:
assert isinstance(root, WNEntry)
root_id = root.id
root_node = root
self.root = root_node
self.nodes = {root_id: self.root} if nodes is None else nodes
self.parentless = []
self.label_index = None
self.pruned = set()
def to_dict(self):
return {
"root": self.root.to_dict(),
"nodes": {node_id: node.to_dict() for node_id, node in self.nodes.items()},
"parentless": self.parentless,
"pruned": list(self.pruned),
}
def prune(self, min_images):
pruned_nodes = set()
# prune all nodes that have fewer than min_images below them
for node_id, node in self.nodes.items():
if node.n_images(self) < min_images:
pruned_nodes.add(node_id)
node.prune(self)
# prune all nodes that have fewer than min_images inside them, after all nodes below have been pruned
node_stack = [self.root]
node_idx = 0
while node_idx < len(node_stack):
node = node_stack[node_idx]
if node.child_ids is not None:
for child_id in node.child_ids:
child = self.nodes[child_id]
node_stack.append(child)
node_idx += 1
# now prune the stack from the bottom up
for node in node_stack[::-1]:
# only look at images of that class, not of additional children
if node.n_images() < min_images:
pruned_nodes.add(node.id)
node.prune(self)
self.pruned = pruned_nodes
return pruned_nodes
@classmethod
def from_dict(cls, d):
tree = cls()
tree.root = WNEntry.from_dict(d["root"])
tree.nodes = {int(node_id): WNEntry.from_dict(node) for node_id, node in d["nodes"].items()}
tree.parentless = d["parentless"]
if "pruned" in d:
tree.pruned = set(d["pruned"])
return tree
def add_node(self, node_id, in_in=True):
if node_id in self.nodes:
self.nodes[node_id].in_image_net = in_in or self.nodes[node_id].in_image_net
return
synset = wn.synset_from_pos_and_offset("n", node_id)
# print(f"adding node {synset.name()} with id {node_id}")
hypernyms = synset.hypernyms()
if len(hypernyms) == 0:
parent_id = None
self.parentless.append(node_id)
main_tree = False
print(f"--------- no hypernyms for {synset.name()} ({synset.offset()}) ------------")
else:
parent_id = synset.hypernyms()[0].offset()
if parent_id not in self.nodes:
self.add_node(parent_id, in_in=False)
parent = self.nodes[parent_id]
if parent.child_ids is None:
parent.child_ids = []
parent.child_ids.append(node_id)
main_tree = parent.in_main_tree
depth = self.nodes[parent_id].depth + 1 if parent_id is not None else 0
node = WNEntry(
synset.name(),
node_id,
_lemmas_str(synset),
parent_id=parent_id,
in_image_net=in_in,
depth=depth,
in_main_tree=main_tree,
)
self.nodes[node_id] = node
def __len__(self):
return len(self.nodes)
def image_net_len(self, only_main_tree=False):
return sum([node.in_image_net for node in self.nodes.values() if node.in_main_tree or not only_main_tree])
def max_depth(self, only_main_tree=False):
return max([node.depth for node in self.nodes.values() if node.in_main_tree or not only_main_tree])
def __str__(self):
return (
f"WordNet Tree with {len(self)} nodes, {self.image_net_len()} in ImageNet21k;"
f" {len(self.parentless)} parentless nodes:\n{self.root.__str__(tree=self)}\nParentless:\n"
+ "\n".join([self.nodes[node_id].__str__(tree=self) for node_id in self.parentless])
)
def save(self, path):
with open(path, "w") as f:
json.dump(self.to_dict(), f)
@classmethod
def load(cls, path):
with open(path, "r") as f:
tree_dict = json.load(f)
return cls.from_dict(tree_dict)
def subtree(self, node_id):
if node_id not in self.nodes:
return None
node_queue = [self.nodes[node_id]]
subtree_ids = set()
while len(node_queue) > 0:
node = node_queue.pop(0)
subtree_ids.add(node.id)
if node.child_ids is not None:
node_queue += [self.nodes[child_id] for child_id in node.child_ids]
subtree_nodes = {node_id: copy(self.nodes[node_id]) for node_id in subtree_ids}
subtree_root = subtree_nodes[node_id]
subtree_root.parent_id = None
depth_diff = subtree_root.depth
for node in subtree_nodes.values():
node.depth -= depth_diff
return WNTree(root=subtree_root, nodes=subtree_nodes)
def _make_label_index(self, include_merged=False):
self.label_index = sorted(
[
node_id
for node_id, node in self.nodes.items()
if node.n_images(self if include_merged else None) > 0 and not node._pruned
]
)
def get_label(self, node_id):
if self.label_index is None:
self._make_label_index()
while self.nodes[node_id]._pruned:
node_id = self.nodes[node_id].parent_id
return self.label_index.index(node_id)
def n_labels(self):
if self.label_index is None:
self._make_label_index()
return len(self.label_index)
def __contains__(self, item):
if isinstance(item, str):
if item[0] == "n":
item = int(item[1:])
else:
return False
if isinstance(item, int):
return item in self.nodes
if isinstance(item, WNEntry):
return item.id in self.nodes
return False
def __getitem__(self, item):
if isinstance(item, str) and item[0].startswith("n"):
try:
item = int(item[1:])
except ValueError:
pass
if isinstance(item, str) and ".n." in item:
for node in self.nodes.values():
if item == node.name:
return node
raise KeyError(f"Item {item} not found in tree")
if isinstance(item, int):
return self.nodes[item]
if isinstance(item, WNEntry):
return self.nodes[item.id]
raise KeyError(f"Item {item} not found in tree")

View File

@@ -0,0 +1,834 @@
import json
import os
import sys
from datetime import datetime
from functools import partial
from math import isfinite
from time import time
import torch
import torch.nn as nn
import torch.optim as optim
from loguru import logger
from timm.data import Mixup
from timm.optim import create_optimizer
from timm.scheduler import create_scheduler
from torch import distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from tqdm.auto import tqdm
from metrics import calculate_metrics, per_class_counts
from utils import (
NoScaler,
ScalerGradNormReturn,
SchedulerArgs,
log_formatter,
save_model_state,
)
try:
from apex.optimizers import FusedLAMB # noqa: F401
apex_available = True
except ImportError:
logger.error("Nvidia apex not available")
apex_available = False
try:
from lion_pytorch import Lion
lion_available = True
except ImportError:
logger.error("Lion not available")
lion_available = False
WANDB_AVAILABLE = False
try:
import wandb
WANDB_AVAILABLE = True
except ImportError:
logger.error("wandb not available")
def wandb_available(turn_off=False):
    """Report (and optionally disable) wandb availability.

    Args:
        turn_off (bool, optional): set wandb to be unavailable manually.
            This is sticky: it clears the module-level ``WANDB_AVAILABLE`` flag.

    Returns:
        bool: wandb is available
    """
    global WANDB_AVAILABLE
    if turn_off:
        WANDB_AVAILABLE = False
    return WANDB_AVAILABLE
# Pre-configure tqdm for this module: keep finished bars on screen and pin them to line 0.
tqdm = partial(tqdm, leave=True, position=0)  # noqa: F405
def setup_tracking_and_logging(args, rank, append_model_path=None, log_wandb=True):
    """Set up logging and tracking for an experiment.

    Creates (or reuses) the run folder, configures loguru logging to stdout and
    a run-local file, and optionally initializes a wandb run.

    Args:
        args (DotDict): Parsed command-line arguments
        rank (int): The rank of the current process
        append_model_path (str, optional): Path of an existing model; when given,
            its directory is reused as the run folder. By default None
        log_wandb (bool, optional): Whether to log to wandb, by default True

    Returns:
        str: folder, where all the run data is saved.

    Raises:
        AssertionError: If `dataset` or `model` is `None`, or run_name contains '%'.

    Notes:
        For the wandb logger, provide .wandb.apikey in the current directory.
        In distributed mode, rank 0 chooses the run folder and broadcasts it.
    """
    dataset, model, epochs = args.dataset.replace(os.sep, "_").lower(), args.model.replace(os.sep, "_"), args.epochs
    _base_folder = (
        os.path.join(args.results_folder, args.experiment_name, args.task.replace("-", ""), dataset)
        if args.out_dir is None
        else args.out_dir
    )
    # run folder name includes run name, model and a timestamp
    run_folder = os.path.join(
        _base_folder,
        f"{args.run_name.replace(os.sep, '_')}_{model}_{datetime.now().strftime('%d.%m.%Y_%H:%M:%S')}",
    )
    assert dataset is not None and model is not None
    if os.name == "nt":
        # Windows: strip characters that are not allowed in paths
        run_folder = run_folder.replace("@", "_").replace(" ", "_").replace(":", ".")
    if append_model_path is not None:
        # resume into the folder of an existing model instead of a fresh one
        run_folder = os.path.dirname(append_model_path)
        if "run_name" not in args or args.run_name is None:
            args.run_name = run_folder.split(os.sep)[-1].split("_")[0]
    elif args.distributed:
        # rank 0 decides the (timestamped) folder name; everyone else adopts it
        obj_list = [None]
        if rank == 0:
            obj_list[0] = run_folder
        dist.broadcast_object_list(obj_list, src=0)
        run_folder = obj_list[0]
        if rank == 0:
            os.makedirs(run_folder, exist_ok=True)
        dist.barrier()
    elif rank == 0:
        os.makedirs(run_folder, exist_ok=True)
    assert "%" not in args.run_name, f"found '%' in run_name '{args.run_name}'. This messes with string formatting..."
    if args.debug:
        args.log_level = "debug"
    # logger to stdout & file
    log_name = args.task.replace("-", "")
    if args.task not in ["pre-train", "fine-tune", "fine-tune-head"]:
        log_name += f"_{dataset}_{datetime.now().strftime('%d.%m.%Y_%H:%M:%S')}"
    log_file = os.path.join(run_folder, f"{log_name}.log")
    logger.remove()
    logger.configure(extra=dict(run_name=args.run_name, rank=rank, world_size=args.world_size))
    logger.add(sys.stderr, format=log_formatter, enqueue=True, colorize=True, level=args.log_level.upper())
    logger.add(log_file, format=log_formatter, enqueue=True, colorize=True, level=args.log_level.upper())
    logger.info(f"Run folder '{run_folder}'")
    if rank == 0:
        logger.info(f"{args.task.replace('-', '').capitalize()} {model} on {dataset} for {epochs} epochs")
    # wandb only when: importable, requested by caller AND args, and an API key file exists
    global WANDB_AVAILABLE
    WANDB_AVAILABLE = WANDB_AVAILABLE and log_wandb and os.path.isfile(".wandb.apikey") and args.wandb
    if WANDB_AVAILABLE:
        with open(".wandb.apikey", "r") as f:
            __wandb_api_key = f.read().strip()
        wandb.login(key=__wandb_api_key)
        if args.wandb_run_id is not None:
            # resume an existing wandb run
            wandb_args = dict(project=args.experiment_name, resume="must", id=args.wandb_run_id)
        else:
            wandb_args = dict(
                project=args.experiment_name,
                name=args.run_name.replace("_", "-").replace(" ", "-"),
                config={"logfile": log_file, **dict(args)},
                job_type=args.task,
                tags=[model, dataset],
                resume="allow",
                id=args.wandb_run_id,
            )
        wandb.init(**wandb_args)
        args["wandb_run_id"] = wandb.run.id
        if rank == 0:
            logger.info(f"wandb run initialized with id {args['wandb_run_id']}.")
    else:
        logger.info(
            f"Not using wandb. (args.wandb={args.wandb}, .wandb.apikey exists={os.path.isfile('.wandb.apikey')},"
            f" function declaration log_wandb={log_wandb})"
        )
    if args.distributed:
        dist.barrier()
    if args.debug:
        torch.autograd.set_detect_anomaly(True)
        logger.warning("torch.autograd anomaly detection enabled. Will slow down model.")
    return run_folder
def setup_model_optim_sched_scaler(model, device, epochs, args, head_only=False):
    """Set up model, optimizer, and scheduler with automatic mixed precision (amp) and distributed data parallel (DDP).

    Args:
        model (nn.Module): the loaded model
        device (torch.device): the current device
        epochs (int): total number of epochs to learn for (kept for interface
            compatibility; the scheduler reads ``args.epochs``)
        args: further arguments
        head_only (bool, optional): train only the linear head. Default: False

    Returns:
        tuple[nn.Module, optim.Optimizer, optim.lr_scheduler._LRScheduler, ScalerGradNormReturn]:
            model, optimizer, scheduler, scaler
    """
    model = model.to(device)
    if head_only:
        # Freeze everything except parameters whose name contains "head".
        # (Previously done in two redundant passes; the named pass alone is
        # equivalent, since it overwrote the first pass completely.)
        for name, param in model.named_parameters():
            param.requires_grad = "head" in name
        params = model.head.parameters()
    else:
        params = model  # pass the model itself and let timm build the parameter groups
    # Fall back through the optimizer choices if optional packages are missing.
    if args.opt == "lion" and not lion_available:
        args.opt = "fusedlamb"
        logger.warning("Falling back from lion to fusedlamb")
    if args.opt == "fusedlamb" and not apex_available:
        args.opt = "adamw"
        logger.warning("Falling back from fusedlamb to adamw")
    if args.opt == "lion":
        optimizer = Lion(params, lr=args["lr"], weight_decay=args["weight_decay"])
    else:
        optimizer = create_optimizer(args, params)
    # Gradient scaling only matters with amp; NoScaler is a no-op stand-in.
    scaler = ScalerGradNormReturn() if args.amp else NoScaler()
    if args.distributed:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = DDP(model, device_ids=[device])
    if args.compile_model:
        model = torch.compile(model)
    sched_args = SchedulerArgs(args.sched, args.epochs, args.min_lr, args.warmup_lr, args.warmup_epochs)
    scheduler, _ = create_scheduler(sched_args, optimizer)
    return model, optimizer, scheduler, scaler
def setup_criteria_mixup(args, dataset=None, **criterion_kwargs):
    """Set up the training/validation criteria and the mixup augmentation.

    Args:
        args: arguments (loss, loss_weight, n_classes, multi_label, ignore_index,
            label_smoothing, aug_cutmix/_alpha, aug_mixup_alpha, ...)
        dataset (torch.data.Dataset, optional): dataset that implements
            images_per_class, used to derive class weights (Default value = None)
        **criterion_kwargs: further keyword arguments passed to the loss class

    Returns:
        tuple[nn.Module, nn.Module, Mixup]: criterion, val_criterion, mixup
            (mixup is None when both mixup and cutmix alphas are 0)
    """
    # Optional per-class loss weights derived from the class frequencies.
    weight = None
    if args.loss_weight != "none":
        if dataset is not None and hasattr(dataset, "images_per_class"):
            ipc = dataset.images_per_class
            total_ims = sum(ipc)
            if args.loss_weight == "linear":
                # inverse-frequency weighting, normalized over classes
                weight = torch.tensor([total_ims / (ims * args.n_classes) for ims in ipc])
            elif args.loss_weight == "log":
                # weight by negative log-probability, normalized by the entropy
                p_c = torch.tensor([ims / total_ims for ims in ipc])
                log_p_c = torch.where(p_c > 0, p_c.log(), torch.zeros_like(p_c))
                entr = -(p_c * log_p_c).sum()
                weight = -log_p_c / entr
            elif args.loss_weight == "sqrt":
                # inverse square-root frequency weighting
                p_c = torch.tensor([ims / total_ims for ims in ipc])
                weight = 1 / (p_c.sqrt() * p_c.sqrt().sum())
        else:
            logger.warning("Could not find images_per_class in dataset. Using uniform weights.")
    if args.aug_cutmix or args.multi_label:
        # criterion = SoftTargetCrossEntropy()
        # Soft/multi-label targets: ignore_index is emulated via a zero class weight.
        if args.ignore_index >= 0:
            if weight is None:
                weight = torch.ones(args.n_classes)
            weight[args.ignore_index] = 0
        if args.multi_label:
            if args.loss == "ce":
                criterion = nn.BCEWithLogitsLoss(pos_weight=weight, **criterion_kwargs)
                val_criterion = nn.BCEWithLogitsLoss(pos_weight=weight, **criterion_kwargs)
            else:
                raise NotImplementedError(
                    f"Only BCEWithLogitsLoss (ce) is implemented for multi-label classification, not {args.loss}."
                )
        else:
            if args.loss == "ce":
                loss_cls = nn.CrossEntropyLoss
            elif args.loss == "baikal":
                # NOTE(review): BaikalLoss is not among this file's visible imports —
                # confirm it is importable here, otherwise this branch raises NameError.
                loss_cls = BaikalLoss
            else:
                raise NotImplementedError(f"'{args.loss}'-loss is not implemented.")
            criterion = loss_cls(weight=weight, **criterion_kwargs)
            val_criterion = loss_cls(weight=weight, **criterion_kwargs)
    else:
        # Hard-label classification path (label smoothing handled by the loss).
        if args.loss == "ce":
            loss_cls = nn.CrossEntropyLoss
        elif args.loss == "baikal":
            # NOTE(review): see BaikalLoss note above.
            loss_cls = BaikalLoss
        else:
            raise NotImplementedError(f"'{args.loss}'-loss is not implemented.")
        criterion = loss_cls(
            label_smoothing=args.label_smoothing,
            ignore_index=args.ignore_index if weight is None else -100,
            weight=weight,
            **criterion_kwargs,
        )  # LabelSmoothingCrossEntropy(smoothing=args.label_smoothing)
        val_criterion = loss_cls(
            label_smoothing=args.label_smoothing,
            ignore_index=args.ignore_index if weight is None else -100,
            weight=weight,
            **criterion_kwargs,
        )  # LabelSmoothingCrossEntropy(smoothing=0.)
    mixup_kwargs = dict(
        mixup_alpha=args.aug_mixup_alpha,
        cutmix_alpha=args.aug_cutmix_alpha,
        label_smoothing=args.label_smoothing,
        num_classes=args.n_classes,
    )
    # only build Mixup when at least one of the two augmentations is active
    mixup = Mixup(**mixup_kwargs) if abs(args.aug_cutmix_alpha) + abs(args.aug_mixup_alpha) > 0.0 else None
    return criterion, val_criterion, mixup
def _train(
    model,
    train_loader,
    optimizer,
    rank,
    epochs,
    device,
    mixup,
    criterion,
    world_size,
    scheduler,
    args,
    val_loader,
    val_criterion,
    model_folder,
    scaler,
    do_metrics_calculation=True,
    start_epoch=0,
    show_tqdm=True,
    topk=(1, 5),
    acc_dict_key=None,
    train_dali_server=None,
    val_dali_server=None,
):
    """Run the full training loop: train and validate every epoch, save checkpoints.

    Args:
        model: model to train (possibly DDP-wrapped)
        train_loader: training data loader
        optimizer: optimizer
        rank: rank of this process (logging/tracking/saving happen on rank 0 only)
        epochs: epoch index to train until
        device: device to train on
        mixup: optional timm Mixup for mixup/cutmix augmentation (may be None)
        criterion: training loss
        world_size: number of processes (forwarded to the efficiency metrics)
        scheduler: learning-rate scheduler (may be falsy)
        args: further arguments (DotDict)
        val_loader: validation data loader
        val_criterion: validation loss
        model_folder: folder to save model states into
        scaler: amp grad scaler (or NoScaler)
        do_metrics_calculation: compute efficiency metrics at the end (Default value = True)
        start_epoch: epoch to resume from (Default value = 0)
        show_tqdm: show a progress bar on rank 0 (Default value = True)
        topk: which top-k accuracies to report (Default value = (1, 5))
        acc_dict_key: format string for accuracy keys, e.g. "acc{}" (Default value = None)
        train_dali_server: optional DALI server started/stopped around each training epoch
        val_dali_server: optional DALI server used during validation

    Returns:
        dict: evaluation metrics at the end of training
    """
    if acc_dict_key is None:
        acc_dict_key = "acc{}"
    training_start = time()
    # drop top-k values that exceed the number of classes
    topk = tuple(k for k in topk if k <= args.n_classes)
    time_spend_training = time_spend_validating = 0
    current_best_acc = 0.0
    if rank == 0:
        logger.info(f"Dataloader has {len(train_loader)} batches")
        logger.debug("Starting training with the following settings:")
        logger.debug(f"criterion: {criterion}")
        logger.debug(f"train_loader: {train_loader}, sampler: {train_loader.sampler}")
        logger.debug(f"dataset: {train_loader.dataset}")
        logger.debug(f"optimizer: {optimizer}")
        logger.debug(f"device: {device}")
        logger.debug(f"start epoch: {start_epoch}, epochs: {epochs}")
        logger.debug(f"scaler: {scaler}")
        logger.debug(f"max_grad_norm: {args.max_grad_norm}")
        # logger.debug(f"model_ema:\n{model_ema}\n{model_ema.decay}\n{model_ema.device}")
        if mixup:
            logger.debug(
                f"mixup: {mixup}; mixup_alpha: {mixup.mixup_alpha}, cutmix_alpha: {mixup.cutmix_alpha},"
                f" cutmix_minmax: {mixup.cutmix_minmax}, prob: {mixup.mix_prob}, switch_prob: {mixup.switch_prob},"
                f" label_smoothing: {mixup.label_smoothing}, num_classes: {mixup.num_classes}, correct_lam:"
                f" {mixup.correct_lam}, mixup_enabled: {mixup.mixup_enabled}"
            )
        else:
            logger.debug(f"mixup: {mixup}")
    for epoch in range(start_epoch, epochs):
        with logger.contextualize(epoch=str(epoch + 1)):
            if args.distributed:
                # reshuffle the distributed sampler deterministically per epoch
                train_loader.sampler.set_epoch(epoch)
            # datasets may implement set_epoch themselves (e.g. for resampling)
            set_ep_func = getattr(train_loader.dataset, "set_epoch", None)
            if callable(set_ep_func):
                train_loader.dataset.set_epoch(epoch)
                val_loader.dataset.set_epoch(epoch)
            if train_dali_server:
                train_dali_server.start_thread()
                logger.info("started train dali server")
            epoch_time, epoch_stats = _train_one_epoch(
                model,
                train_loader,
                optimizer,
                rank,
                epoch,
                device,
                mixup,
                criterion,
                scheduler,
                scaler,
                args,
                topk,
                "train/" + acc_dict_key,
                show_tqdm,
            )
            time_spend_training += epoch_time
            if train_dali_server:
                train_dali_server.stop_thread()
            val_time, val_stats = _evaluate(
                model,
                val_loader,
                epoch,
                rank,
                device,
                val_criterion,
                args,
                topk,
                "val/" + acc_dict_key,
                dali_server=val_dali_server,
            )
            time_spend_validating += val_time
            if rank == 0:
                logger.info(f"total_time={time() - training_start}s")
            if rank == 0:
                top1_val_acc = val_stats["val/" + acc_dict_key.format(1)]
                # print metadata for grafana
                metadata = {
                    "epoch": epoch + 1,
                    "progress": (epoch + 1) / args.epochs,
                    **val_stats,
                    **epoch_stats,
                }
                # filter out Nan and infinity values
                metadata = {k: v for k, v in metadata.items() if isfinite(v)}
                print(json.dumps(metadata), flush=True)
                logger.debug(f"printed metadata: {json.dumps(metadata)}")
                if WANDB_AVAILABLE:
                    wandb.log(metadata, step=epoch + 1)
                # saving current state: either a new best model or a regular interval save
                if top1_val_acc > current_best_acc or (epoch + 1) % args.save_epochs == 0:
                    reason = "top" if top1_val_acc > current_best_acc else ""
                    if reason == "top":
                        current_best_acc = top1_val_acc
                        logger.info(f"found a new best model with acc: {current_best_acc}")
                    kwargs = dict(
                        model_state=model.state_dict(),
                        stats=metadata,
                        optimizer_state=optimizer.state_dict(),
                        additional_reason=reason,
                        regular_save=(epoch + 1) % args.save_epochs == 0,
                    )
                    if scheduler:
                        kwargs["scheduler_state"] = scheduler.state_dict()
                    save_model_state(
                        model_folder, epoch + 1, args, **kwargs, max_interm_ep_states=args.keep_interm_states
                    )
    if rank == 0:
        end_time = time()
        logger.info(
            f"training done: total time={end_time - training_start}, "
            f"time spend training={time_spend_training}, "
            f"time spend validating={time_spend_validating}"
        )
    # NOTE(review): uses val_stats/epoch_stats from the last epoch — assumes at least one epoch ran.
    results = {**val_stats, **epoch_stats, f"val/best_{acc_dict_key.format(1)}": current_best_acc}
    if rank == 0:
        save_model_state(
            model_folder,
            epoch + 1,
            args,
            model_state=model.state_dict(),
            stats=results,
            additional_reason="final",
            regular_save=False,
            max_interm_ep_states=args.keep_interm_states,
        )
    if do_metrics_calculation:
        # Calculate efficiency metrics on one real batch
        inp = next(iter(train_loader))[0].to(device)
        metrics = calculate_metrics(
            args,
            model,
            rank=rank,
            input=inp,
            device=device,
            did_training=True,
            all_metrics=False,
            world_size=world_size,
            key_start="eval/",
        )
        if rank == 0:
            logger.info(f"Efficiency metrics: {json.dumps(metrics)}")
    return results
def _mask_preds(preds, cls_masks, mask_val=-100):
"""Mask the predictions by the mask.
Args:
preds: model predictions
cls_masks: class masks
mask_val: (Default value = -100)
Returns:
torch.Tensor: masked predictions
"""
if cls_masks is None:
return preds
return torch.where(cls_masks.bool(), mask_val, preds)
def _evaluate(
    model,
    val_loader,
    epoch,
    rank,
    device,
    val_criterion,
    args,
    topk=(1, 5),
    acc_dict_key=None,
    dali_server=None,
):
    """Evaluate the model on the validation set.

    Args:
        model (nn.Module): the model to evaluate
        val_loader (DataLoader): loader for evaluation data
        epoch (int): the current epoch (for logger & tracking)
        rank (int): this processes rank (don't log n times)
        device (torch.device): device to evaluate on
        val_criterion (nn.Module): validation loss
        args (DotDict): further arguments
        topk (tuple[int], optional): top-k accuracies to compute, by default (1, 5)
        acc_dict_key (str, optional): format string for accuracy keys, by default "acc{}"
        dali_server: optional DALI server started/stopped around the evaluation

    Returns:
        tuple[float, dict]: validation wall time in seconds and the metrics dict
            (top-k accuracies, mean per-class accuracies, and "val/loss").
    """
    if not acc_dict_key:
        acc_dict_key = "acc{}"
    if dali_server:
        dali_server.start_thread()
    # drop top-k values that exceed the number of classes
    topk = tuple(k for k in topk if k <= args.n_classes)
    model.eval()
    val_loss = 0
    val_accs = {acc_dict_key.format(k): 0.0 for k in topk}
    val_start = time()
    iterator = (
        tqdm(val_loader, total=len(val_loader), desc=f"Validating epoch {epoch + 1}")
        if rank == 0 and args.tqdm
        else val_loader
    )
    # per-k, per-class (correct, wrong) counts
    class_counts = torch.zeros(1 if isinstance(topk, int) else len(topk), args.n_classes, 2)
    for batch_data in iterator:
        xs, ys = batch_data[:2]
        cls_masks = batch_data[2].to(device, non_blocking=True) if len(batch_data) == 3 else None
        if args.debug:
            logger.debug(f"y_max = {ys.max()}, y_min = {ys.min()}, num_classes={args.n_classes}")
        xs, ys = xs.to(device, non_blocking=True), ys.to(device, non_blocking=True)
        with torch.no_grad(), torch.amp.autocast("cuda", enabled=args.eval_amp):
            preds = model(xs)
            preds = _mask_preds(preds, cls_masks)
            if args.multi_label:
                # labels are float for BCELoss
                ys = ys.float()
            # sum of per-batch losses (averaged over ranks below, not over batches)
            val_loss += val_criterion(preds.transpose(1, -1), ys.transpose(1, -1) if len(ys.shape) > 1 else ys).item()
        class_counts += per_class_counts(preds, ys, args.n_classes, topk=topk, ignore_index=args.ignore_index)
    if args.distributed:
        dist.barrier()
    if dali_server:
        dali_server.stop_thread()
    val_end = time()
    if args.distributed:
        # average the loss over ranks; sum the per-class counts
        gather_tensor = torch.Tensor([val_loss]).to(device)
        dist.barrier()
        dist.all_reduce(gather_tensor, op=dist.ReduceOp.AVG)
        gather_tensor = gather_tensor.tolist()
        val_loss = gather_tensor[0]
        class_counts = class_counts.to(device)
        dist.all_reduce(class_counts, op=dist.ReduceOp.SUM)
        class_counts = class_counts.cpu()
    for i, k in enumerate(topk):
        key = acc_dict_key.format(k)
        mkey = key.replace("acc", "m-acc")
        # overall accuracy = correct / total; m-acc = mean of per-class accuracies
        val_accs[key] = class_counts[i].sum(dim=0)[0].item() / class_counts[i].sum(dim=0).sum(dim=-1).item()
        val_accs[mkey] = (class_counts[i, :, 0] / class_counts[i].sum(dim=-1)).mean().item()
    val_accs["val/loss"] = val_loss
    if rank == 0:
        log_s = f"val/time={val_end - val_start}s"
        for key, val in val_accs.items():
            log_s += f", {key}={val:.4f}"
        logger.info(log_s)
    return val_end - val_start, val_accs
def _train_one_epoch(
    model,
    train_loader,
    optimizer,
    rank,
    epoch,
    device,
    mixup,
    criterion,
    scheduler,
    scaler,
    args,
    topk=(1, 5),
    acc_dict_key=None,
    show_tqdm=True,
):
    """Train the model for one epoch.

    Runs a full pass over ``train_loader``: forward pass (optionally under AMP
    autocast), loss computation, scaled backward + optimizer step via ``scaler``,
    then averages the loss over all iterations (and over ranks when distributed)
    and steps the LR scheduler once.

    Args:
        model: model to train; may be DDP-wrapped (``get_internal_loss`` is looked up
            on the model itself or on ``model.module``).
        train_loader: training data loader yielding ``(xs, ys)`` or ``(xs, ys, cls_masks)``.
        optimizer: optimizer; the first param group's ``lr`` is reported in the stats.
        rank: global rank of this process; progress bar and logging only on rank 0.
        epoch: current epoch index (0-based; displayed 1-based).
        device: device to move batches to.
        mixup: optional mixup/cutmix callable applied per batch, or a falsy value.
        criterion: training loss; called with class dim moved last via ``transpose(1, -1)``.
        scheduler: optional LR scheduler stepped at the end of the epoch.
        scaler: AMP grad scaler callable; replaced by ``NoScaler()`` when ``args.amp`` is off.
        args: namespace of run options (amp, debug, multi_label, distributed, ...).
        topk: (Default value = (1, 5)); currently unused while the per-batch
            accuracy computation below is commented out.
        acc_dict_key: format string for accuracy dict keys (Default value = None,
            which falls back to "acc{}").
        show_tqdm: (Default value = True); show a progress bar on rank 0.

    Returns:
        tuple[float, dict]: time spent in training, and the epoch stats dict
        (``train/lr`` and ``train/loss``); the dict is empty when
        ``args.gather_stats_during_training`` is off.
    """
    if not acc_dict_key:
        acc_dict_key = "acc{}"
    model.train()
    iterator = (
        tqdm(train_loader, total=len(train_loader), desc=f"Training epoch {epoch + 1}")
        if rank == 0 and show_tqdm
        else train_loader
    )
    if not args.amp:
        # without AMP, NoScaler performs a plain backward/step with the same call signature
        scaler = NoScaler()
    epoch_loss = 0
    epoch_accs = {}
    epoch_start = time()
    grad_norms = []
    n_iters = 0
    if hasattr(train_loader.dataset, "epoch"):
        # some datasets (e.g. with per-epoch augmentation schedules) need the epoch index
        train_loader.dataset.epoch = epoch
    for i, batch_data in enumerate(iterator):
        xs, ys = batch_data[:2]
        # optional third element: per-sample class masks for partially-labeled data
        cls_masks = batch_data[2].to(device, non_blocking=True) if len(batch_data) == 3 else None
        optimizer.zero_grad()
        n_iters += 1
        xs = xs.to(device, non_blocking=True)
        ys = ys.to(device, non_blocking=True)
        if args.debug and i == 0:
            logger.debug(f"y_max = {ys.max()}, y_min = {ys.min()}, num_classes={args.n_classes}")
        if mixup:
            if args.multi_label:
                xs, ys = mixup(xs, ys, cls_masks)
            else:
                xs, ys = mixup(xs, ys)
        if args.debug and i == 0:
            logger.debug(f"input x: {type(xs)}; {xs.shape}, y: {type(ys)}; {ys.shape}")
        with torch.amp.autocast("cuda", enabled=args.amp):
            preds = model(xs)
            preds = _mask_preds(preds, cls_masks)
            if args.multi_label:
                # labels are float for BCELoss
                ys = ys.float()
            # transpose moves the class dim last; internal loss comes from the model
            # itself or, when DDP-wrapped, from model.module
            loss = criterion(preds.transpose(1, -1), ys.transpose(1, -1) if len(ys.shape) > 1 else ys) + (
                model.get_internal_loss() if hasattr(model, "get_internal_loss") else model.module.get_internal_loss()
            )
        if not isfinite(loss.item()):
            # diverged: dump diagnostics (NaN locations, grad-norm history) and abort
            logger.error(f"Got loss value {loss.item()}. Stopping training.")
            logger.info(f"input has nan: {xs.isnan().any().item()}")
            logger.info(f"target has nan: {ys.isnan().any().item()}")
            logger.info(f"output has nan: {preds.isnan().any().item()}")
            for name, param in model.named_parameters():
                if param.isnan().any().item():
                    logger.error(f"parameter {name} has a nan value")
            if len(grad_norms) > 0:
                grad_norms = torch.Tensor(grad_norms)
                logger.info(
                    f"Gradient norms until now: min={grad_norms.min().item()}, 20th"
                    f" %tile={torch.quantile(grad_norms, .2).item()}, mean={torch.mean(grad_norms)}, 80th"
                    f" %tile={torch.quantile(grad_norms, .8).item()}, max={grad_norms.max()}"
                )
            sys.exit(1)
        # scaler performs backward, (optional) grad clipping, and the optimizer step;
        # it returns the gradient norm of this iteration
        iter_grad_norm = scaler(
            loss,
            optimizer,
            parameters=model.parameters(),
            clip_grad=args.max_grad_norm if args.max_grad_norm > 0.0 else None,
        ).cpu()
        if args.gather_stats_during_training and isfinite(iter_grad_norm):
            grad_norms.append(iter_grad_norm)
        # if args.aug_cutmix:
        #     ys = ys.argmax(dim=-1)  # for accuracy with CutMix, just use the argmax for both
        #
        epoch_loss += loss.item()
        # accuracies = accuracy(preds, ys, topk=topk, dict_key=acc_dict_key, ignore_index=args.ignore_index)
        # for key in accuracies:
        #     epoch_accs[key] += accuracies[key]
    if args.distributed:
        dist.barrier()
    epoch_end = time()
    iterations = n_iters
    # epoch_accs = {key: val / iterations for key, val in epoch_accs.items()}
    epoch_loss = epoch_loss / iterations
    grad_norm_avrg = -1
    # iterations without a finite gradient norm (or with stats gathering disabled)
    inf_grads = iterations - len(grad_norms)
    if len(grad_norms) > 0 and args.gather_stats_during_training:
        grad_norm_max = max(grad_norms)
        grad_norms = torch.Tensor(grad_norms)
        grad_norm_20 = torch.quantile(grad_norms, 0.2).item()
        grad_norm_80 = torch.quantile(grad_norms, 0.8).item()
        grad_norm_avrg = torch.mean(grad_norms)
    if args.distributed:
        # grad norm is already synchronized
        # gather_tensor = torch.Tensor([epoch_loss, *[epoch_accs[acc_dict_key.format(k)] for k in topk]]).to(device)
        gather_tensor = torch.Tensor([epoch_loss]).to(device)
        dist.barrier()
        dist.all_reduce(gather_tensor, op=dist.ReduceOp.AVG)
        # gather_tensor = (gather_tensor / world_size).tolist()
        epoch_loss = gather_tensor.item()
        # for i, k in enumerate(topk):
        #     epoch_accs[acc_dict_key.format(k)] = gather_tensor[i + 1]
    lr = optimizer.param_groups[0]["lr"]
    epoch_accs["train/lr"] = lr
    epoch_accs["train/loss"] = epoch_loss
    if rank == 0:
        if args.gather_stats_during_training:
            print_s = f"train/time={epoch_end - epoch_start}s"
            logger.info(print_s)
            if len(grad_norms) > 0:
                logger.info(
                    f"grad norm avrg={grad_norm_avrg}, grad norm max={grad_norm_max}, "
                    f"inf grad norm={inf_grads}, grad norm 20%={grad_norm_20}, grad norm 80%={grad_norm_80}"
                )
            else:
                logger.warning(f"inf grad norm={inf_grads}")
                logger.error("100% of update steps with infinite grad norms!")
        else:
            logger.info(f"train/time={epoch_end - epoch_start}s")
    if scheduler:
        # LambdaLR counts its own steps; other schedulers here are stepped by epoch index
        if isinstance(scheduler, optim.lr_scheduler.LambdaLR):
            scheduler.step()
        else:
            scheduler.step(epoch)
    if args.gather_stats_during_training:
        return epoch_end - epoch_start, epoch_accs
    return epoch_end - epoch_start, {}

View File

@@ -0,0 +1,710 @@
"""Module to evaluate trained models."""
import os
from contextlib import nullcontext
from datetime import datetime
from math import sqrt
from time import time
import torch
from loguru import logger
from timm.loss import LabelSmoothingCrossEntropy
from timm.models.resnet import ResNet as TimmResNet
from torch import distributed as dist
from torch.nn import functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from tqdm import tqdm
from engine import (
_evaluate,
setup_criteria_mixup,
setup_model_optim_sched_scaler,
setup_tracking_and_logging,
wandb_available,
)
from load_dataset import prepare_dataset
from metrics import calculate_metrics
from models import load_pretrained
from utils import (
RepeatedDataset,
ddp_cleanup,
ddp_setup,
denormalize,
get_cpu_name,
grad_cam_reshape_transform,
prep_kwargs,
set_filter_warnings,
)
def evaluate_metrics(model, dataset, **kwargs):
    """Evaluate efficiency metrics for a given model.

    Loads a checkpoint, restores the run identity (name/experiment/wandb id) from
    the saved training args, rebuilds the training data loader, and hands
    everything to ``calculate_metrics``; results are logged and, if available,
    sent to wandb on rank 0.

    Args:
        model (str): path to model state .tar
        dataset (str): name of the dataset to evaluate on
        **kwargs: further arguments
    """
    set_filter_warnings()
    # `model` arrives as a path; keep it before the name is reused for the loaded module
    model_path = model
    args = prep_kwargs(kwargs)
    if args.cuda:
        args.distributed, device, world_size, rank, _ = ddp_setup()
        torch.cuda.set_device(device)
    else:
        args.distributed = False
        device = torch.device("cpu")
        rank = 0
        args.compile_model = False
    save_state = torch.load(model_path, map_location="cpu")
    # adopt identity of the original training run so logging/tracking attach to it
    old_args = prep_kwargs(save_state["args"])
    args.model = old_args.model
    args.dataset = dataset
    args.run_name = old_args.run_name
    args.experiment_name = old_args.experiment_name
    args.wandb_run_id = old_args.wandb_run_id
    setup_tracking_and_logging(args, rank=rank, append_model_path=model_path, log_wandb=args.wandb_run_id is not None)
    train_loader, args.n_classes, args.ignore_index, args.multi_label, _ = prepare_dataset(dataset, args)
    model, args, old_args, save_state = load_pretrained(model, args, new_dataset_params=True)
    old_args["eval_imsize"] = args.imsize
    args.model = model_name = old_args.model
    args.dataset = dataset
    args.epochs = 5
    # NOTE(review): epochs=10 here while args.epochs was just set to 5 — confirm which
    # value calculate_metrics actually consumes
    model, optim, _, scaler = setup_model_optim_sched_scaler(model, device, epochs=10, args=args, head_only=False)
    if rank == 0:
        logger.info(
            f"Evaluate metrics for model {model_name} on {dataset}. "
            f"It was {old_args.task.replace('-','')}d on {old_args.dataset} for {save_state['epoch']} "
            "epochs."
        )
        # logger.info(f"full set of arguments: {args}")
        logger.info(f"full set of training arguments: {old_args}")
        logger.info(f"full set of eval-metrics arguments: {args}")
    # logged on every rank so each device identifies itself
    logger.info(
        f"evaluating on {device} -> {torch.cuda.get_device_name(device) if device.type != 'cpu' else get_cpu_name()}"
    )
    metrics = calculate_metrics(
        args, model, rank=rank, device=device, optim=optim, scaler=scaler, train_loader=train_loader, key_start="eval/"
    )
    if rank == 0:
        logger.info(f"Metrics: {metrics}")
        if wandb_available():
            import wandb

            wandb.log(metrics)
def evaluate(model, dataset=None, val_dataset=None, **kwargs):
    """Evaluate model accuracy.

    Loads a checkpoint, restores the run identity from the saved training args,
    builds the validation loader, and runs ``_evaluate`` once. Results are logged
    (and sent to wandb, if available) on rank 0 only; other ranks still execute
    ``_evaluate`` so distributed collectives stay in sync.

    Args:
        model (str): path to model state .tar
        dataset (str, optional): name of the dataset to evaluate on (Default value = None)
        val_dataset (str, optional): name of the dataset to evaluate on (Default value = None)
        **kwargs: further arguments

    Note:
        If `val_dataset` is not provided, the model will be evaluated on `dataset`.
    """
    set_filter_warnings()
    # `model` arrives as a path; keep it before the name is reused for the loaded module
    model_path = model
    args = prep_kwargs(kwargs)
    if val_dataset is None:
        val_dataset = dataset
    args.dataset = dataset
    args.val_dataset = val_dataset
    if args.cuda:
        args.distributed, device, world_size, rank, _ = ddp_setup()
        torch.cuda.set_device(device)
    else:
        args.distributed = False
        device = torch.device("cpu")
        world_size = 1
        rank = 0
        args.compile_model = False
    # the configured batch size is global; divide by world size to get per-rank size
    args.batch_size = int(args.batch_size / world_size)
    save_state = torch.load(model_path, map_location="cpu")
    old_args = prep_kwargs(save_state["args"])
    args.model = old_args.model
    args.dataset = dataset
    args.run_name = old_args.run_name
    args.experiment_name = old_args.experiment_name
    args.wandb_run_id = old_args.wandb_run_id
    run_folder = setup_tracking_and_logging(
        args, rank, append_model_path=model_path, log_wandb=args.wandb_run_id is not None
    )
    val_loader, args.n_classes, args.ignore_index, args.multi_label, dali_server = prepare_dataset(
        val_dataset, args, train=False
    )
    model, args, old_args, save_state = load_pretrained(model, args, new_dataset_params=True)
    model = model.to(device)
    args.model = model_name = old_args.model
    args.dataset = dataset
    if rank == 0:
        logger.info(
            f"Evaluate model {model_name} on {val_dataset}. "
            f"It was pretrained on {old_args.dataset} for {save_state['epoch']} epochs."
        )
    if args.distributed:
        model = DDP(model)
    if args.compile_model:
        model = torch.compile(model)
    # log all devices
    logger.info(
        f"evaluating on {device} -> {torch.cuda.get_device_name(device) if device.type != 'cpu' else get_cpu_name()}"
    )
    if rank == 0:
        logger.info(f"torch version {torch.__version__}")
        logger.info(f"full set of arguments: {args}")
        logger.info(f"full set of old arguments: {old_args}")
    if args.seed:
        torch.manual_seed(args.seed)
    val_criterion = LabelSmoothingCrossEntropy(smoothing=args.label_smoothing)
    if rank == 0:
        logger.info("start evaluation")
        logger.info(f"Run info at: '{run_folder}'")
    if rank == 0:
        val_time, val_stats = _evaluate(
            model.to(device),
            val_loader,
            epoch=save_state["epoch"] - 1,
            rank=rank,
            device=device,
            val_criterion=val_criterion,
            args=args,
            dali_server=dali_server,
            acc_dict_key=f"eval_{val_dataset}/acc{{}}",
        )
        log_s = f"Evaluation done in {val_time}s"
        for key, val in val_stats.items():
            log_s += f", {key}={val:.4f}"
        logger.info(log_s)
        if wandb_available():
            import wandb

            wandb.log(val_stats)
    else:
        # non-zero ranks run the same evaluation so collective ops don't deadlock
        _evaluate(
            model.to(device),
            val_loader,
            epoch=save_state["epoch"] - 1,
            rank=rank,
            device=device,
            val_criterion=val_criterion,
            args=args,
            dali_server=dali_server,
            acc_dict_key=f"eval_{val_dataset}/acc{{}}",
        )
    ddp_cleanup(args=args, rank=rank)
def evaluate_center_bias(model, dataset=None, val_dataset=None, **kwargs):
    """Evaluate model accuracy in different nonants.

    Places the ForNet foreground into each cell ("nonant") of a 3x3 grid, plus a
    baseline placement, evaluates accuracy 5 times per placement, and derives a
    center-bias score from corner/edge accuracy relative to the center cell.
    Only supported on a single (rank-0) process.

    Args:
        model (str): path to model state .tar
        dataset (str, optional): name of the dataset to evaluate on (Default value = None)
        val_dataset (str, optional): name of the dataset to evaluate on (Default value = None)
        **kwargs: further arguments

    Note:
        If `val_dataset` is not provided, the model will be evaluated on `dataset`.
    """
    set_filter_warnings()
    model_path = model
    args = prep_kwargs(kwargs)
    # allow giving either -ds or -valds; mirror whichever one is set
    if val_dataset is None:
        val_dataset = dataset
    if dataset is None:
        dataset = val_dataset
    assert dataset is not None, "Specify validation dataset (-valds) or dataset (-ds)."
    args.dataset = dataset
    args.val_dataset = val_dataset
    if args.cuda:
        args.distributed, device, world_size, rank, _ = ddp_setup()
        torch.cuda.set_device(device)
    else:
        args.distributed = False
        device = torch.device("cpu")
        world_size = 1
        rank = 0
        args.compile_model = False
    # the configured batch size is global; divide by world size to get per-rank size
    args.batch_size = int(args.batch_size / world_size)
    save_state = torch.load(model_path, map_location="cpu")
    old_args = prep_kwargs(save_state["args"])
    args.model = old_args.model
    args.dataset = dataset
    args.run_name = old_args.run_name
    args.experiment_name = old_args.experiment_name
    args.wandb_run_id = old_args.wandb_run_id
    run_folder = setup_tracking_and_logging(
        args, rank, append_model_path=model_path, log_wandb=args.wandb_run_id is not None
    )
    assert "fornet" in val_dataset.lower(), "Only ForNet supported for center bias evaluation."
    _, args.n_classes, args.ignore_index, args.multi_label, __ = prepare_dataset(val_dataset, args, train=False)
    model, args, old_args, save_state = load_pretrained(model, args, new_dataset_params=True)
    model = model.to(device)
    args.model = model_name = old_args.model
    args.dataset = dataset
    if rank == 0:
        logger.info(
            f"Evaluate model {model_name} on {val_dataset}. "
            f"It was pretrained on {old_args.dataset} for {save_state['epoch']} epochs."
        )
    if args.distributed:
        model = DDP(model)
    if args.compile_model:
        model = torch.compile(model)
    # log all devices
    logger.info(
        f"evaluating on {device} -> {torch.cuda.get_device_name(device) if device.type != 'cpu' else get_cpu_name()}"
    )
    if rank == 0:
        logger.info(f"torch version {torch.__version__}")
        logger.info(f"full set of arguments: {args}")
        logger.info(f"full set of old arguments: {old_args}")
    if args.seed:
        torch.manual_seed(args.seed)
    val_criterion = LabelSmoothingCrossEntropy(smoothing=args.label_smoothing)
    if rank == 0:
        logger.info("start evaluation")
        logger.info(f"Run info at: '{run_folder}'")
    if rank == 0:
        # nonant -1 is the baseline placement; nonants 0..8 index the 3x3 grid
        # row-major, so nonant_accs[0] is the baseline and nonant_accs[i + 1]
        # holds nonant i
        nonant_accs = []
        for nonant in range(-1, 9):
            # fresh loader each placement so the dataset's placement flag takes effect
            val_loader, _, __, ___, dali_server = prepare_dataset(val_dataset, args, train=False)
            val_loader.dataset.fg_in_nonant = nonant
            logger.info(f"Evaluate nonant {nonant} for 5 rounds.")
            round_accs = []
            for _ in range(5):
                val_time, val_stats = _evaluate(
                    model.to(device),
                    val_loader,
                    epoch=save_state["epoch"] - 1,
                    rank=rank,
                    device=device,
                    val_criterion=val_criterion,
                    args=args,
                    dali_server=dali_server,
                )
                round_accs.append(val_stats["acc1"])
            nonant_accs.append(sum(round_accs) / len(round_accs))
        log_s = f"Evaluation done in {val_time}s: "
        for nonant, val in enumerate(nonant_accs[1:]):
            log_s += f", nonant {nonant}={val}% acc ({val / nonant_accs[0]} rel acc)"
        # indices 1/3/7/9 -> corner nonants 0/2/6/8, indices 2/4/6/8 -> edge
        # nonants 1/3/5/7, index 5 -> center nonant 4; bias is 1 minus the mean
        # of worst corner and worst edge relative to the center accuracy
        center_bias_val = 1 - (
            min([nonant_accs[1], nonant_accs[3], nonant_accs[7], nonant_accs[9]])
            + min([nonant_accs[2], nonant_accs[4], nonant_accs[6], nonant_accs[8]])
        ) / (2 * nonant_accs[5])
        log_s += f", center_bias={center_bias_val:.4f}"
        logger.info(log_s)
        if wandb_available():
            import wandb

            wandb.log({f"eval_{args.val_dataset}/center_bias": center_bias_val})
    else:
        raise NotImplementedError("Center bias evaluation not supported in distributed mode.")
    ddp_cleanup(args=args, rank=rank)
def evaluate_size_bias(model, dataset=None, val_dataset=None, **kwargs):
    """Evaluate model accuracy for differently scaled foregrounds.

    Rescales the ForNet foreground by a list of size factors (jitter disabled),
    evaluates accuracy 5 times per factor, and reports accuracy relative to
    factor 1.0. Only supported on a single (rank-0) process.

    Args:
        model (str): path to model state .tar
        dataset (str, optional): name of the dataset to evaluate on (Default value = None)
        val_dataset (str, optional): name of the dataset to evaluate on (Default value = None)
        **kwargs: further arguments

    Note:
        If `val_dataset` is not provided, the model will be evaluated on `dataset`.
    """
    set_filter_warnings()
    model_path = model
    args = prep_kwargs(kwargs)
    # allow giving either -ds or -valds; mirror whichever one is set
    if val_dataset is None:
        val_dataset = dataset
    if dataset is None:
        dataset = val_dataset
    assert val_dataset is not None and dataset is not None
    args.dataset = dataset
    args.val_dataset = val_dataset
    if args.cuda:
        args.distributed, device, world_size, rank, _ = ddp_setup()
        torch.cuda.set_device(device)
    else:
        args.distributed = False
        device = torch.device("cpu")
        world_size = 1
        rank = 0
        args.compile_model = False
    # the configured batch size is global; divide by world size to get per-rank size
    args.batch_size = int(args.batch_size / world_size)
    save_state = torch.load(model_path, map_location="cpu")
    old_args = prep_kwargs(save_state["args"])
    args.model = old_args.model
    args.dataset = dataset
    args.run_name = old_args.run_name
    args.experiment_name = old_args.experiment_name
    args.wandb_run_id = old_args.wandb_run_id
    run_folder = setup_tracking_and_logging(args, rank, append_model_path=model_path, log_wandb=False)
    # NOTE(review): message says "center bias" — looks copy-pasted from
    # evaluate_center_bias; the check itself (ForNet only) is correct here too
    assert "fornet" in val_dataset.lower(), "Only ForNet supported for center bias evaluation."
    _, args.n_classes, args.ignore_index, args.multi_label, __ = prepare_dataset(val_dataset, args, train=False)
    model, args, old_args, save_state = load_pretrained(model, args, new_dataset_params=True)
    model = model.to(device)
    args.model = model_name = old_args.model
    args.dataset = dataset
    if rank == 0:
        logger.info(
            f"Evaluate model {model_name} on {val_dataset}. "
            f"It was pretrained on {old_args.dataset} for {save_state['epoch']} epochs."
        )
    if args.distributed:
        model = DDP(model)
    if args.compile_model:
        model = torch.compile(model)
    # log all devices
    logger.info(
        f"evaluating on {device} -> {torch.cuda.get_device_name(device) if device.type != 'cpu' else get_cpu_name()}"
    )
    if rank == 0:
        logger.info(f"torch version {torch.__version__}")
        logger.info(f"full set of arguments: {args}")
        logger.info(f"full set of old arguments: {old_args}")
    if args.seed:
        torch.manual_seed(args.seed)
    val_criterion = LabelSmoothingCrossEntropy(smoothing=args.label_smoothing)
    if rank == 0:
        logger.info("start evaluation")
        logger.info(f"Run info at: '{run_folder}'")
    sizes = [0.1, 0.2, 0.4, 0.6, 0.8, 1.0, 1.25, 1.5, 1.75, 2.0]
    if rank == 0:
        size_accs = []
        val_times = 0
        for size in sizes:
            # fresh loader per factor so the dataset's scale settings take effect
            val_loader, _, __, ___, dali_server = prepare_dataset(val_dataset, args, train=False)
            val_loader.dataset.size_fact = size
            val_loader.dataset.fg_scale_jitter = 0.0
            logger.info(f"Evaluate size factor {size} for 5 rounds.")
            round_accs = []
            for _ in range(5):
                val_time, val_stats = _evaluate(
                    model.to(device),
                    val_loader,
                    epoch=save_state["epoch"] - 1,
                    rank=rank,
                    device=device,
                    val_criterion=val_criterion,
                    args=args,
                    dali_server=dali_server,
                )
                round_accs.append(val_stats["acc1"])
                val_times += val_time
            size_accs.append(sum(round_accs) / len(round_accs))
        log_s = f"Evaluation done in {val_times}s: "
        for size, val in zip(sizes, size_accs):
            # relative accuracy is w.r.t. the unscaled (factor 1.0) run
            log_s += f", rel_size {size}={val}% acc ({val / size_accs[sizes.index(1.0)]} rel acc)"
        logger.info(log_s)
    else:
        raise NotImplementedError("Center bias evaluation not supported in distributed mode.")
    ddp_cleanup(args=args, rank=rank)
def evaluate_attributions(model, dataset=None, val_dataset=None, **kwargs):
    """Evaluate model attributions using captum.

    Measures how much of the model's attribution mass (Integrated Gradients,
    GradCAM, GradCAM++, and — for ViTs — CLS-token attention) falls on the
    ForNet foreground masks, relative to the foreground's area. A weight of 1
    means attribution is spread uniformly; higher means the foreground is
    over-attended.

    Args:
        model (str): path to model state .tar
        dataset (str, optional): name of the dataset to evaluate on (Default value = None)
        val_dataset (str, optional): name of the dataset to evaluate on (Default value = None)
        **kwargs: further arguments

    Note:
        If `val_dataset` is not provided, the model will be evaluated on `dataset`.
        The `captum` package is required.
    """
    # optional heavy deps imported lazily so the module loads without them
    from captum.attr import IntegratedGradients
    from pytorch_grad_cam import GradCAM, GradCAMPlusPlus
    from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

    set_filter_warnings()
    model_path = model
    args = prep_kwargs(kwargs)
    if val_dataset is None:
        val_dataset = dataset
    assert val_dataset is not None, "Please set dataset (-ds) or validation dataset (-valds)"
    args.dataset = val_dataset
    args.val_dataset = val_dataset
    if args.cuda:
        args.distributed, device, world_size, rank, _ = ddp_setup()
        torch.cuda.set_device(device)
    else:
        args.distributed = False
        device = torch.device("cpu")
        world_size = 1
        rank = 0
        args.compile_model = False
    # the configured batch size is global; divide by world size to get per-rank size
    args.batch_size = int(args.batch_size / world_size)
    save_state = torch.load(model_path, map_location="cpu")
    old_args = prep_kwargs(save_state["args"])
    args.model = old_args.model
    args.dataset = val_dataset
    args.run_name = old_args.run_name
    args.experiment_name = old_args.experiment_name
    args.wandb_run_id = old_args.wandb_run_id
    run_folder = setup_tracking_and_logging(
        args, rank, append_model_path=model_path, log_wandb=args.wandb_run_id is not None
    )
    assert "fornet" in val_dataset.lower(), "Only ForNet supported for attribution evaluation."
    val_loader, args.n_classes, args.ignore_index, args.multi_label, dali_server = prepare_dataset(
        val_dataset, args, train=False
    )
    # batches must also carry the foreground masks for attribution scoring
    val_loader.dataset.return_fg_masks = True
    model, args, old_args, save_state = load_pretrained(model, args, new_dataset_params=True)
    model = model.to(device)
    args.model = model_name = old_args.model
    args.dataset = dataset
    # assert (
    #     args.imsize == old_args.imsize
    # ), f"Model was trained on {old_args.imsize}x{old_args.imsize} images. Not {args.imsize}x{args.imsize}."
    epoch = save_state["epoch"]
    if rank == 0:
        logger.info(
            f"Evaluate attributions of model {model_name} on {dataset}. "
            f"It was pretrained on {old_args.dataset} for {epoch} epochs."
        )
    if args.distributed:
        model = DDP(model)
    if args.compile_model:
        model = torch.compile(model)
    # log all devices
    logger.info(
        f"evaluating on {device} -> {torch.cuda.get_device_name(device) if device.type != 'cpu' else get_cpu_name()}"
    )
    if rank == 0:
        logger.info(f"torch version {torch.__version__}")
        if args.new_log:
            logger.info(f"full set of arguments: {args}")
            logger.info(f"full set of old arguments: {old_args}")
        else:
            logger.info(f"full set of attribution evaluation arguments: {old_args}")
    if args.seed:
        torch.manual_seed(args.seed)
    if rank == 0:
        logger.info(f"Run info at: '{run_folder}'")
    iterator = (
        tqdm(val_loader, total=len(val_loader), desc=f"Validating epoch {epoch}")
        if rank == 0 and args.tqdm
        else val_loader
    )
    if args.debug:
        from matplotlib import pyplot as plt
    # architecture-specific GradCAM target layers; ViTs additionally record
    # attention matrices for the CLS-token importance measure
    eval_attn_importance = False
    if isinstance(model, TimmResNet):
        reshape_transform = None
        target_layers = [model.layer4[-1]]
    elif model_name.lower().startswith("vit-"):
        reshape_transform = grad_cam_reshape_transform
        target_layers = [model.blocks[-1].norm1]
        eval_attn_importance = True
        from architectures.vit import _MatrixSaveAttn

        # swap in an attention module that stores its attention matrix
        model.blocks[-1].attn = _MatrixSaveAttn.cast(model.blocks[-1].attn)
    elif model_name.lower().startswith("swin_"):
        reshape_transform = grad_cam_reshape_transform
        target_layers = [model.layers[-1].blocks[-1].norm1]
    else:
        raise NotImplementedError(f"Model {model_name} not supported for attribution evaluation.")
    model.eval()
    val_start = time()
    # running sums of per-batch mean relative foreground weights
    rel_ig_weights = 0.0
    rel_attn_weights = 0.0
    rel_cam_weights = {"GradCAM": 0.0, "GradCAM++": 0.0}
    if rank == 0:
        logger.info("Start attribution evaluation")
    if dali_server:
        dali_server.start_thread()
    for batch_data in iterator:
        xs, ys, fg_masks = batch_data
        xs, ys, fg_masks = (
            xs.to(device, non_blocking=True),
            ys.to(device, non_blocking=True),
            fg_masks.float().to(device, non_blocking=True),
        )
        with torch.amp.autocast("cuda") if args.eval_amp else nullcontext():
            model.zero_grad()
            ig = IntegratedGradients(model)
            # we use attention temperature of 10 to make differences more apparent after exp
            attr_ig = (
                ig.attribute(xs, target=ys, baselines=0.0, internal_batch_size=args.batch_size * 4).sum(dim=1) * 10
            )  # B x W x H
        # softmax over all pixels turns the IG map into a spatial distribution
        attr_probs = attr_ig.view(xs.shape[0], -1).softmax(dim=-1).view(xs.shape[0], *xs.shape[2:])
        fg_masks = fg_masks.view(attr_probs.shape)
        fg_attrs = (attr_probs * fg_masks).sum(dim=(-1, -2))
        # normalize by foreground area fraction; 1.0 == uniform attribution
        rel_attr_weight = fg_attrs / fg_masks.mean(dim=(-1, -2))
        # empty masks would divide by zero — define those samples as neutral (1.0)
        rel_attr_weight = torch.where(fg_masks.mean(dim=(-1, -2)) > 0, rel_attr_weight, 1.0)
        if rel_attr_weight.isnan().any():
            logger.error(f"NaNs in rel_attr_weight: {rel_attr_weight}, fg_mask_weights: {fg_masks.mean(dim=(-1, -2))}")
            break
        rel_ig_weights += rel_attr_weight.mean().item()
        cam_targets = [ClassifierOutputTarget(int(trgt)) for trgt in ys.tolist()]
        for method, name in zip([GradCAM, GradCAMPlusPlus], ["GradCAM", "GradCAM++"]):
            with method(model=model, target_layers=target_layers, reshape_transform=reshape_transform) as cam, (
                torch.amp.autocast("cuda") if args.eval_amp else nullcontext()
            ):
                cam_attr = cam(input_tensor=xs, targets=cam_targets)
                cam_attr = torch.from_numpy(cam_attr).to(device)
                # fraction of CAM mass on the foreground, normalized by its area
                rel_cam_attr = (cam_attr * fg_masks).sum(dim=(-1, -2)) / cam_attr.sum(dim=(-1, -2))
                cam_attr_weight = rel_cam_attr / fg_masks.mean(dim=(-1, -2))
                cam_attr_weight = torch.where(
                    (fg_masks.mean(dim=(-1, -2)) > 0) & (cam_attr.sum(dim=(-1, -2)) > 0), cam_attr_weight, 1.0
                )
                rel_cam_weights[name] += cam_attr_weight.mean().item()
                if cam_attr_weight.isnan().any():
                    logger.error(
                        f"NaNs in cam_attr_weight ({name}): {cam_attr_weight}, fg_mask_weights:"
                        f" {fg_masks.mean(dim=(-1, -2))}"
                    )
                    break
        if eval_attn_importance:
            # forward pass only to populate the saved attention matrix
            with torch.amp.autocast("cuda") if args.eval_amp else nullcontext():
                pred = model(xs)  # noqa: F841
            last_attn_mat = model.blocks[-1].attn.attn_mat
            cls_tkn_attn = last_attn_mat[:, :, 0, 1:].mean(dim=1).squeeze(dim=1)  # B x H x 1(CLS Token) X N -> B x N
            B, N = cls_tkn_attn.shape
            att_HW = int(sqrt(N))
            cls_tkn_attn = cls_tkn_attn.view(B, 1, att_HW, att_HW)
            # upsample token-grid attention to image resolution
            attn_attr = F.interpolate(
                cls_tkn_attn, size=(xs.shape[-2], xs.shape[-1]), mode="bilinear", align_corners=False
            ).view(B, xs.shape[-2], xs.shape[-1])
            rel_attn_attr = (attn_attr * fg_masks).sum(dim=(-1, -2)) / attn_attr.sum(dim=(-1, -2))
            attn_attr_weight = rel_attn_attr / fg_masks.mean(dim=(-1, -2))
            attn_attr_weight = torch.where(
                (fg_masks.mean(dim=(-1, -2)) > 0) & (attn_attr.sum(dim=(-1, -2)) > 0), attn_attr_weight, 1.0
            )
            rel_attn_weights += attn_attr_weight.mean().item()
        if args.debug:
            logger.debug(f"Attribution scores: IG: {rel_attr_weight[:5]}, GradCAM(++): {cam_attr_weight[:5]}")
            num_subplots = 5 if eval_attn_importance else 4
            fig, axs = plt.subplots(num_subplots, 4)
            for plt_i in range(4):
                axs[0][plt_i].imshow(denormalize(xs[plt_i]).permute(1, 2, 0).cpu().numpy())
                axs[1][plt_i].imshow(fg_masks[plt_i].cpu().numpy())
                axs[2][plt_i].imshow(attr_probs[plt_i].cpu().numpy())
                axs[3][plt_i].imshow(cam_attr[plt_i].cpu().numpy())
                if eval_attn_importance:
                    axs[4][plt_i].imshow(attn_attr[plt_i].cpu().numpy())
            plt.show()
        # NOTE(review): iterator.n / set_description exist only on the tqdm wrapper —
        # this seems to assume rank 0 with args.tqdm enabled; confirm for the plain
        # DataLoader path
        iterator_desc = (
            f"IG weights: {rel_ig_weights / (iterator.n + 1):.4f}, GradCAM weights:"
            f" {rel_cam_weights['GradCAM'] / (iterator.n + 1):.4f}, GradCAM++ weights:"
            f" {rel_cam_weights['GradCAM++'] / (iterator.n + 1):.4f}"
        )
        if eval_attn_importance:
            iterator_desc += f", Attn weights: {rel_attn_weights / (iterator.n + 1):.4f}"
        iterator.set_description(iterator_desc)
    if args.distributed:
        dist.barrier()
    val_end = time()
    # convert running sums into per-batch averages
    rel_ig_weights /= len(iterator)
    rel_grad_cam = rel_cam_weights["GradCAM"] / len(iterator)
    rel_grad_cam_pp = rel_cam_weights["GradCAM++"] / len(iterator)
    rel_attn_weights /= len(iterator)
    if dali_server:
        dali_server.stop_thread()
    if args.distributed:
        gather_tensor = torch.Tensor([rel_ig_weights, rel_grad_cam, rel_grad_cam_pp, rel_attn_weights]).to(device)
        dist.barrier()
        dist.all_reduce(gather_tensor)
        # default all_reduce op is SUM, so divide by world size for the mean
        gather_tensor = (gather_tensor / world_size).tolist()
        rel_ig_weights, rel_grad_cam, rel_grad_cam_pp, rel_attn_weights = gather_tensor
    if rank == 0:
        output_text = (
            f"epoch {epoch}: eval_{args.val_dataset}/rel_ig_weights={rel_ig_weights},"
            f" eval_{args.val_dataset}/rel_grad_cam={rel_grad_cam},"
            f" eval_{args.val_dataset}/rel_grad_cam_pp={rel_grad_cam_pp}"
        )
        if eval_attn_importance:
            output_text += f", eval_{args.val_dataset}/rel_attn_weights={rel_attn_weights}"
        output_text += f", eval_{args.val_dataset}/attribution_eval_time={val_end - val_start}s"
        logger.info(output_text)
        if wandb_available():
            import wandb

            wandb_data = {
                f"eval_{args.val_dataset}/importance_ig": rel_ig_weights,
                f"eval_{args.val_dataset}/importance_grad_cam": rel_grad_cam,
                f"eval_{args.val_dataset}/importance_grad_cam_pp": rel_grad_cam_pp,
            }
            if eval_attn_importance:
                wandb_data[f"eval_{args.val_dataset}/importance_attn"] = rel_attn_weights
            wandb.log(wandb_data)
    ddp_cleanup(args=args, rank=rank)

View File

@@ -0,0 +1,211 @@
import argparse
import os
from functools import partial
from math import log
import numpy as np
import torch
from PIL import Image
from torchvision.transforms import Compose, RandomCrop, Resize, ToTensor
from tqdm.auto import tqdm
try:
from models import load_pretrained
except ModuleNotFoundError:
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from models import load_pretrained
from utils import prep_kwargs
def score_5(
    idx,
    bg_probs,
    mean_probs,
    fg_ratio,
    max_idx,
    fg_ratio_max=0.9,
    fg_ratio_min=0.002,
    fg_ratio_exp=0.4,  # learned: 0.487
    idx_exp=0.01,  # learned: 0.043
    bg_probs_exp=0.2,  # learned: 0.24
    opt_fg_ratio=0.1,  # learned: 0.1
    mean_probs_exp=0.2,  # learned: 0.2
    fg_ratio_penalty=1,  # learned: -0.446 ???
):
    """Score one candidate foreground version as a weighted sum of log-terms.

    Higher is better. Combines classifier confidence on monochrome backgrounds,
    (inverse) confidence on the bare background, closeness of the foreground
    area ratio to its optimum, a slight preference for earlier versions, and a
    flat bonus when the ratio lies inside the accepted range.
    """
    confidence_term = log(mean_probs) * mean_probs_exp  # high prob on monochrome bgs is good
    leakage_term = log(1 - bg_probs) * bg_probs_exp  # high prob on the bare bg is bad
    ratio_term = log(1 - abs(fg_ratio - opt_fg_ratio)) * fg_ratio_exp  # prefer ratios near the optimum
    order_term = log(1 - idx / (max_idx + 1)) * idx_exp  # slight preference for earlier versions
    range_bonus = (fg_ratio_min < fg_ratio < fg_ratio_max) * fg_ratio_penalty  # flat bonus inside the valid range
    return confidence_term + leakage_term + ratio_term + order_term + range_bonus  # TODO: KEEP IT LIKE IT IS
# Script body: for every ForNet image with multiple extracted foreground
# versions, score each version with an ensemble of pretrained classifiers and
# keep only the best-scoring one (renaming it to drop the "_vX" version tag).
parser = argparse.ArgumentParser(description="Inspect image versions")
parser.add_argument("-f", "--base_folder", type=str, required=True, help="Base folder to inspect")
parser.add_argument(
    "-batch_size",
    type=int,
    default=32,
    help="Batch size for model inspection. Will be 1 background and batch_size - 1 foregrounds",
)
parser.add_argument("-imsize", type=int, default=224, help="Image size")
parser.add_argument(
    "-score_f_weights", choices=["manual", "automatic"], default="manual", help="Score function hyperparameters"
)
parser.add_argument("-auto_fg_pen_val", type=float, default=-0.446, help="Automatic foreground penalty value")
parser.add_argument("-d", "--dataset", choices=["tinyimagenet", "imagenet"], default="tinyimagenet", help="Dataset")
args = parser.parse_args()
# choose hand-tuned defaults or the learned hyperparameters for score_5
score_f = (
    score_5
    if args.score_f_weights == "manual"
    else partial(
        score_5, **dict(fg_ratio_exp=0.487, idx_exp=0.043, bg_probs_exp=0.24, fg_ratio_penalty=args.auto_fg_pen_val)
    )
)
bg_folder = os.path.join(args.base_folder, "backgrounds")
fg_folder = os.path.join(args.base_folder, "foregrounds")
classes = os.listdir(fg_folder)
# class folders look like "n01440764" — sort by the numeric synset id
classes = sorted(classes, key=lambda x: int(x[1:]))
assert len(classes) in [200, 1_000], f"Expected 200 or 1_000 classes, got {len(classes)}"
# collect "class/imagebasename" stems for all images that still have "_vX" versions
total_images = set()
for in_cls in classes:
    cls_images = {
        os.path.join(in_cls, "_".join(img.split(".")[0].split("_")[:-1]))
        for img in os.listdir(os.path.join(fg_folder, in_cls))
        if img.split(".")[0].split("_")[-1].startswith("v")
    }
    total_images.update(cls_images)
total_images = list(total_images)
# base_folder = os.path.join(*(os.path.dirname(__file__).split("/")[:-1]))
# in_cls to print name/lemma
with open(os.path.join("data", "misc_dataset_files", "tinyimagenet_synset_names.txt"), "r") as f:
    in_cls_to_name = {line.split(":")[0].strip(): line.split(":")[1].strip() for line in f.readlines() if len(line) > 2}
if args.dataset == "tinyimagenet":
    inspection_model_paths = []  # PATHS TO MODEL WEIGHTS (.pt) PRETRAINED ON TinyImageNet
elif args.dataset == "imagenet":
    inspection_model_paths = []  # PATHS TO MODEL WEIGHTS (.pt) PRETRAINED ON ImageNet
else:
    raise ValueError(f"Unknown dataset {args.dataset}")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
inspection_models = [
    load_pretrained(path, prep_kwargs({}), new_dataset_params=False)[0].to(device) for path in inspection_model_paths
]
img_transform = Compose([Resize((args.imsize, args.imsize)), RandomCrop(args.imsize), ToTensor()])
total_versions = []
for img_name in tqdm(total_images, desc="Image version computation"):
    in_cls, img_name = img_name.split("/")
    # gather every file that shares this image's base name (i.e. all its versions)
    versions = set()
    for img in os.listdir(os.path.join(fg_folder, in_cls)):
        if "_".join(img.split("_")[: len(img_name.split("_"))]) == img_name:
            versions.add(img)
    if len(versions) == 1:
        # only one version left: just strip the version tag from fg + bg files
        version = list(versions)[0]
        if version.split(".")[0].split("_")[-1].startswith("v"):
            tqdm.write(f"renaming single version image {version} to {img_name}.WEBP")
            os.rename(os.path.join(fg_folder, in_cls, version), os.path.join(fg_folder, in_cls, f"{img_name}.WEBP"))
            os.rename(
                os.path.join(bg_folder, in_cls, version.replace(".WEBP", ".JPEG")),
                os.path.join(bg_folder, in_cls, f"{img_name}.JPEG"),
            )
        continue
    elif len(versions) == 0:
        tqdm.write(f"Image {img_name} has no versions")
        continue
    versions = sorted(list(versions))
    assert all(
        [version.split(".")[0].split("_")[-1].startswith("v") for version in versions]
    ), f"Weird Versions: {versions} for image {img_name}"
    assert len(versions) <= 3, f"Too many versions for image {img_name}: {versions}"
    version_scores = []
    for v_idx, version in enumerate(versions):
        img = Image.open(os.path.join(fg_folder, in_cls, version))
        bg_img = Image.open(os.path.join(bg_folder, in_cls, f"{version.split('.')[0]}.JPEG"))
        # alpha channel of the foreground is its mask; ratio of fg area to bg area
        img_mask = np.array(img.convert("RGBA").split()[-1])
        fg_ratio = np.sum(img_mask) / (255 * bg_img.size[0] * bg_img.size[1])
        fg_size = img.size
        # paste the foreground onto a grayscale ramp of backgrounds to probe the
        # classifiers independent of the original background
        monochrome_backgrounds = [
            Image.new(
                "RGB",
                (max(args.imsize, fg_size[0]), max(args.imsize, fg_size[1])),
                (255 * i // (args.batch_size - 2), 255 * i // (args.batch_size - 2), 255 * i // (args.batch_size - 2)),
            )
            for i in range(args.batch_size - 1)
        ]
        pasting_error = False
        for mc_bg in monochrome_backgrounds:
            try:
                mc_bg.paste(img, ((args.imsize - fg_size[0]) // 2, (args.imsize - fg_size[1]) // 2), img)
            except ValueError as e:
                tqdm.write(f"Image {img_name} could not be pasted into background: {e}")
                pasting_error = True
                break
        # batch layout: index 0 = bare background, rest = fg on monochrome bgs
        inp_batch = torch.stack(
            [img_transform(bg_img)] + [img_transform(mc_bg) for mc_bg in monochrome_backgrounds], dim=0
        ).to(device)
        cls_idx = classes.index(in_cls)
        bg_probs = []
        mean_probs = []
        for model in inspection_models:
            model.eval()
            with torch.no_grad():
                out_probs = model(inp_batch).softmax(dim=-1)[:, cls_idx].cpu().numpy()
            bg_probs.append(out_probs[0])
            mean_probs.append(np.mean(out_probs[1:]))
        # average the lists
        bg_probs = np.mean(bg_probs)
        mean_probs = np.mean(mean_probs)
        version_score = (
            score_f(
                idx=v_idx,
                bg_probs=float(bg_probs),
                mean_probs=float(mean_probs),
                fg_ratio=float(fg_ratio),
                max_idx=len(versions) - 1,
            )
            if not pasting_error
            else -100
        )
        version_scores.append(version_score)
    assert len(versions) == len(version_scores), f"Expected {len(versions)} scores, got {len(version_scores)}"
    if max(version_scores) > min(version_scores):
        # find best version
        best_version_idx = int(np.argmax(version_scores))
        best_version = versions[best_version_idx]
        # delete all other versions
        for version in versions:
            if version != best_version:
                os.remove(os.path.join(fg_folder, in_cls, version))
                os.remove(os.path.join(bg_folder, in_cls, f"{version.split('.')[0]}.JPEG"))
        # remove version tag in name
        new_version_name = "_".join(best_version.split("_")[:-1]) + "." + best_version.split(".")[-1]
        os.rename(os.path.join(fg_folder, in_cls, best_version), os.path.join(fg_folder, in_cls, new_version_name))
        os.rename(
            os.path.join(bg_folder, in_cls, f"{best_version.split('.')[0]}.JPEG"),
            os.path.join(bg_folder, in_cls, f"{new_version_name.split('.')[0]}.JPEG"),
        )
    else:
        tqdm.write(f"All versions have the same score for image {img_name}")

View File

@@ -0,0 +1,16 @@
#!/bin/bash
# Submit a python script as a single-node, single-GPU slurm job via srun.
# Usage: ./<this-script>.sh <script.py> [script args...]
# NOTE(review): the PATH/TO/... placeholders (container image, mounts,
# NLTK_DATA) must be replaced with real paths before this script is usable.
srun -K \
    --container-image=PATH/TO/SLURM/IMAGE \
    --container-workdir="$(pwd)" \
    --container-mounts=/ALL/IMPORTANT/MOUNTS,"$(pwd)":"$(pwd)" \
    --partition=RTXA6000,RTX3090,A100-40GB,A100-80GB,H100,H200 \
    --job-name="python" \
    --nodes=1 \
    --gpus=1 \
    --ntasks=1 \
    --cpus-per-task=24 \
    --mem=64G \
    --time=1-0 \
    --export="NLTK_DATA=/PATH/TO/NLTK_DATA" \
    python3 "$@"

View File

@@ -0,0 +1,346 @@
"""Module to load the datasets, using torch and datadings."""
import contextlib
import os
from functools import partial
import torchvision.transforms as tv_transforms
from datadings.reader import MsgpackReader
from timm.data import create_transform
from torch.utils.data import DataLoader, DistributedSampler, WeightedRandomSampler
from torchvision.datasets import (
CIFAR10,
CIFAR100,
FGVCAircraft,
Flowers102,
Food101,
ImageFolder,
OxfordIIITPet,
StanfordCars,
)
from data.counter_animal import CounterAnimal
from data.data_utils import (
DDDecodeDataset,
ToOneHotSequence,
collate_imnet,
collate_listops,
get_hf_transform,
minimal_augment,
segment_augment,
three_augment,
)
from data.fornet import ForNet
from data.samplers import RASampler
from paths_config import ds_path
def prepare_dataset(dataset_name, args, transform=None, train=True, rank=None):
    """Load a dataset from disk, different formats are used for different datasets.

    Supported datasets: CIFAR10, Stanford-Cars, Oxford-Pet, Flowers102, Food-101, FGVC-Aircraft,
    ImageNet, TinyImageNet, ForNet/TinyForNet, CounterAnimal, and ImageNet-9.

    Args:
        dataset_name (str): name of the dataset
        args: further arguments
        transform (list[Module] | str, optional): transformations to use on the data; the list gets composed, or give args.augment_strategy (Default value = None)
        train (bool, optional): use the training split (or test/validation split) (Default value = True)
        rank (int, optional): global rank of this process in distributed training (Default value = None)

    Returns:
        DataLoader: data loader for the dataset
        int: number of classes in the dataset
        int: ignore index for the dataset
        bool: whether the dataset is multi-label
        DALIServer | None: server driving the DALI pipeline when one was built here, else None
    """
    compose = tv_transforms.Compose
    dali_server = None
    # Build the transform pipeline only when the caller did not supply one.
    if transform is None:
        if args.augment_engine == "torchvision":
            if args.augment_strategy == "3-augment":
                transform = three_augment(args, as_list=False, test=not train)
            elif args.augment_strategy == "differentiable-transform":
                from data.distilled_dataset import differentiable_augment

                transform = differentiable_augment(args, as_list=False, test=not train)
            elif args.augment_strategy == "none":
                transform = []
            elif args.augment_strategy == "lm_one_hot":
                transform = [
                    tv_transforms.Grayscale(num_output_channels=1),
                    tv_transforms.ToTensor(),
                    ToOneHotSequence(),
                ]
            elif args.augment_strategy == "segment-augment":
                transform = segment_augment(args, test=not train)
            elif args.augment_strategy == "minimal":
                transform = minimal_augment(args, test=not train)
            elif args.augment_strategy == "deit":
                if train:
                    transform = create_transform(
                        input_size=args.imsize,
                        is_training=True,
                        color_jitter=args.aug_color_jitter_factor,
                        auto_augment=args.auto_augment_strategy,
                        interpolation="bicubic",
                        re_prob=args.aug_random_erase_prob,
                        re_mode=args.aug_random_erase_mode,
                        re_count=args.aug_random_erase_count,
                    )
                else:
                    transform = three_augment(args, test=True)  # only do resize, centercrop, and normalize
            else:
                raise NotImplementedError(
                    f"Augmentation strategy {args.augment_strategy} is not implemented for {args.augment_engine} (yet)."
                )
        elif args.augment_engine == "albumentations":
            from data import album_transf as ATf

            compose = ATf.AlbumTorchCompose
            if args.augment_strategy == "3-augment":
                transform = ATf.three_augment(args, as_list=False, test=not train)
            elif args.augment_strategy == "minimal":
                transform = ATf.minimal_augment(args, test=not train)
            else:
                raise NotImplementedError(
                    f"Augmentation strategy {args.augment_strategy} is not implemented for {args.augment_engine} (yet)."
                )
        elif args.augment_engine == "dali":
            from nvidia.dali.plugin.pytorch.experimental import proxy as dali_proxy

            from data import dali_transf as DTf

            dev_id = int(os.environ.get("LOCAL_RANK", 0))
            if args.augment_strategy == "3-augment":
                pipe = DTf.three_augment(
                    args,
                    test=not train,
                    batch_size=args.batch_size,
                    num_threads=args.num_workers,
                    device_id=dev_id,
                )
            elif args.augment_strategy == "minimal":
                pipe = DTf.minimal_augment(
                    args,
                    test=not train,
                    batch_size=args.batch_size,
                    num_threads=args.num_workers,
                    device_id=dev_id,
                )
            else:
                raise NotImplementedError(
                    f"Augmentation strategy {args.augment_strategy} is not implemented for {args.augment_engine} (yet)."
                )
            dali_server = dali_proxy.DALIServer(pipe)
            transform = dali_server.proxy
    dataset_name = dataset_name.lower()
    ignore_index = -100
    multi_label = False
    if isinstance(transform, list):
        transform = compose(transform)
    # Instantiate the dataset itself.
    if dataset_name == "cifar10":
        dataset = CIFAR10(root=ds_path("cifar"), train=train, download=False, transform=transform)
        n_classes, collate = 10, None
    elif dataset_name == "stanford-cars":
        dataset = StanfordCars(
            root=ds_path("stanford_cars"),
            split="train" if train else "test",
            download=False,
            transform=transform,
        )
        n_classes, collate = 196, None
    elif dataset_name == "oxford-pet":
        dataset = OxfordIIITPet(
            root=ds_path("oxford_pet"),
            split="trainval" if train else "test",
            download=False,
            transform=transform,
        )
        n_classes, collate = 37, None
    elif dataset_name == "flowers102":
        dataset = Flowers102(
            root=ds_path("flowers102"),
            split="train" if train else "test",
            download=False,
            transform=transform,
        )
        n_classes, collate = 102, None
    elif dataset_name == "food-101":
        dataset = Food101(
            root=ds_path("food101"),
            split="train" if train else "test",
            download=False,
            transform=transform,
        )
        n_classes, collate = 101, None
    elif dataset_name == "fgvc-aircraft":
        dataset = FGVCAircraft(
            root=ds_path("aircraft"),
            split="train" if train else "test",
            annotation_level="variant",
            download=False,
            transform=transform,
        )
        n_classes, collate = 100, None
    elif dataset_name == "imagenet":
        dataset = ImageFolder(os.path.join(ds_path("imagenet1k"), "train" if train else "val"), transform=transform)
        n_classes, collate = 1000, None
    elif dataset_name == "tinyimagenet":
        dataset = ImageFolder(os.path.join(ds_path("tinyimagenet"), "train" if train else "val"), transform=transform)
        n_classes, collate = 200, None
    elif dataset_name.startswith("fornet"):
        # Dataset spec is '/'-separated: fornet/<combination>/<pruning>/<fg_size_mode>/<paste>/<orig_prob>/<sigma>
        ds_def = dataset_name.split("/")
        comb_scheme = ds_def[1] if len(ds_def) > 1 else "same"
        pruning_ratio = 0.8 if len(ds_def) < 3 else float(ds_def[2])
        fg_size_mode = ("range" if train else "max") if len(ds_def) < 4 else ds_def[3]
        paste_pre_transform = True if len(ds_def) < 5 else ds_def[4] in ["y", "t"]
        orig_img_prob = (
            0.0 if len(ds_def) < 6 else (ds_def[5] if ds_def[5] in ["linear", "revlinear", "cos"] else float(ds_def[5]))
        )
        mask_smoothing_sigma = 0.0 if len(ds_def) < 7 else float(ds_def[6])
        assert len(ds_def) < 5 or ds_def[4] in [
            "y",
            "t",
            "n",
            "f",
        ], f"Invalid dataset definition: {ds_def[4]}; paste pre transform must be 'y'/'t' or 'n'/'f'"
        orig_ds = ds_path("imagenet1k")
        dataset = ForNet(
            ds_path("fornet"),
            train=train,
            background_combination=comb_scheme,
            pruning_ratio=pruning_ratio,
            transform=transform,
            fg_transform=(
                None if args.aug_rand_rot == 0 else tv_transforms.RandomRotation(args.aug_rand_rot, expand=True)
            ),
            fg_size_mode=fg_size_mode,
            paste_pre_transform=paste_pre_transform,
            orig_img_prob=orig_img_prob,
            orig_ds=orig_ds,
            mask_smoothing_sigma=mask_smoothing_sigma,
            epochs=args.epochs,
            _album_compose=args.augment_engine == "albumentations",
        )
        n_classes, collate = 1000, None
    elif dataset_name.startswith("tinyfornet"):
        # Same spec as fornet, but with an extra <fg_bates_n> field after <fg_size_mode>.
        ds_def = dataset_name.split("/")
        comb_scheme = ds_def[1] if len(ds_def) > 1 else "same"
        pruning_ratio = 1.1 if len(ds_def) < 3 else float(ds_def[2])
        fg_size_mode = "range" if len(ds_def) < 4 else ds_def[3]
        fg_bates_n = 1 if len(ds_def) < 5 else int(ds_def[4])
        paste_pre_transform = False if len(ds_def) < 6 else ds_def[5] in ["y", "t"]
        orig_img_prob = (
            0.0 if len(ds_def) < 7 else (ds_def[6] if ds_def[6] in ["linear", "revlinear", "cos"] else float(ds_def[6]))
        )
        mask_smoothing_sigma = 0.0 if len(ds_def) < 8 else float(ds_def[7])
        assert len(ds_def) < 6 or ds_def[5] in [
            "y",
            "t",
            "n",
            "f",
        ], f"Invalid dataset definition: {ds_def[5]}; paste pre transform must be 'y'/'t' or 'n'/'f'"
        assert "-" not in ds_def[0] or len(ds_def[0].split("-")) == 3, f"Invalid dataset definition: {ds_def[0]}"
        version = "" if "-" not in ds_def[0] else f"_v{ds_def[0].split('-')[1]}_f{ds_def[0].split('-')[2]}"
        orig_ds = ds_path("tinyimagenet")
        dataset = ForNet(
            f"{ds_path('tinyimagenet')}{version}",
            train=train,
            background_combination=comb_scheme,
            pruning_ratio=pruning_ratio,
            transform=transform,
            fg_transform=(
                None if args.aug_rand_rot == 0 else tv_transforms.RandomRotation(args.aug_rand_rot, expand=True)
            ),
            fg_size_mode=fg_size_mode,
            fg_bates_n=fg_bates_n,
            paste_pre_transform=paste_pre_transform,
            orig_img_prob=orig_img_prob,
            orig_ds=orig_ds,
            mask_smoothing_sigma=mask_smoothing_sigma,
            epochs=args.epochs,
            _album_compose=args.augment_engine == "albumentations",
        )
        n_classes, collate = 200, None
    elif dataset_name.startswith("counteranimal/"):
        mode = dataset_name.split("/")[1]
        dataset = CounterAnimal(ds_path("counteranimal"), mode=mode, transform=transform, train=train)
        n_classes, collate = 1000, None
    elif dataset_name.startswith("imagenet9/"):
        variant = dataset_name.split("/")[1]
        assert variant in [
            "next",
            "same",
            "rand",
        ], f"ImageNet-9 has possible variants next, same, and rand, but not '{variant}'."
        dataset = ImageFolder(os.path.join(ds_path("imagenet9"), f"mixed_{variant}", "val"), transform=transform)
        n_classes, collate = 9, None
    else:
        raise NotImplementedError(f"Dataset {dataset_name} is not implemented (yet).")
    # Choose the sampler: repeated-augment > weighted > distributed > default.
    if args.aug_repeated_augment_repeats > 1 and train:
        # use repeated augment sampler from DeiT
        sampler = RASampler(
            dataset,
            num_replicas=args.world_size,
            rank=rank,
            shuffle=args.shuffle,
            num_repeats=args.aug_repeated_augment_repeats,
        )
    elif args.weighted_sampler:
        assert hasattr(
            dataset, "per_sample_weights"
        ), f"Dataset {type(dataset)} should implement per_sample_weights function, but does not."
        sampler = WeightedRandomSampler(dataset.per_sample_weights(), num_samples=len(dataset) // args.world_size)
    elif args.distributed:
        sampler = DistributedSampler(dataset, num_replicas=args.world_size, rank=rank, shuffle=train and args.shuffle)
    else:
        sampler = None
    loader_batch_size = 1 if dataset_name.startswith("listops") else args.batch_size
    loader_kwargs = dict(
        batch_size=loader_batch_size,
        pin_memory=args.pin_memory,
        num_workers=args.num_workers,
        drop_last=train,
        prefetch_factor=args.prefetch_factor if args.num_workers > 0 else None,
        persistent_workers=False,
        collate_fn=collate,
        shuffle=None if sampler else train and args.shuffle,
        sampler=sampler,
    )
    # Guard on dali_server rather than args.augment_engine: `dali_proxy` is only
    # imported (and a server only built) when the DALI branch above ran. With a
    # caller-supplied `transform`, checking the engine name alone would raise a
    # NameError on `dali_proxy` here.
    if dali_server is not None:
        data_loader = dali_proxy.DataLoader(dali_server, dataset, **loader_kwargs)
    else:
        data_loader = DataLoader(dataset, **loader_kwargs)
    return data_loader, n_classes, ignore_index, multi_label, dali_server

View File

@@ -0,0 +1,679 @@
#!/usr/bin/env python3
"""Parse args and call the correct script inside slurm container.
Outside the container, on the head-node, create and call the correct srun command.
"""
import argparse
import os
import subprocess
from datetime import datetime
from config import default_kwargs, slurm_defaults
from paths_config import results_folder, slurm_output_folder
# Allowed values for --experiment-name; used to group runs in experiment tracking.
_EXPNAMES = ["EfficientCVBench", "test", "recombine_imagenet"]
def base_parser():
    """Create the argument parser with all the choices for the training / evaluation scripts.

    Returns:
        argparse.ArgumentParser: parser with argument groups for the main task, model
        parameters, experiment management, speedups, data loading, optimizer settings,
        and data augmentation. Most options default to None so that downstream config
        defaults can fill them in.
    """
    parser = argparse.ArgumentParser("Transformer training and evaluation.")
    # Main
    group = parser.add_argument_group("Main")
    group.add_argument(
        "-t",
        "--task",
        nargs="?",
        choices=[
            "pre-train",
            "fine-tune",
            "fine-tune-head",
            "eval",
            "parser-test",
            "eval-metrics",
            "eval-attr",
            "continue",
            "eval-center-bias",
            "eval-size-bias",
            "load-images",
            "save-images",
        ],
        required=True,
        help="Task to perform.",
    )
    group.add_argument(
        "-m",
        "--model",
        nargs="?",
        type=str,
        required=True,
        help="Model to use. Either model name for a new model or weights and dicts to load for fine-tuning.",
    )
    group.add_argument("-ds", "--dataset", nargs="?", type=str, help="Dataset to train on.")
    group.add_argument(
        "-valds", "--val-dataset", nargs="?", type=str, help="Validation dataset. Defaults to same as training."
    )
    group.add_argument("-ep", "--epochs", nargs="?", type=int, help="Number of epochs to train.")
    group.add_argument(
        "-run",
        "--run-name",
        nargs="?",
        type=str,
        help="A name for the run. If not give, the model name is used instead.",
    )
    group.add_argument(
        "--defaults", nargs="?", choices=["DeiT", "DeiTIII"], default="DeiTIII", help="Default settings to use."
    )
    # Further model parameters
    group = parser.add_argument_group("Further model parameters")
    group.add_argument("--drop-path-rate", nargs="?", type=float, help="Drop path rate for ViT models.")
    group.add_argument("--layer-scale-init-values", nargs="?", type=float, help="LayerScale initial values.")
    # BooleanOptionalAction registers paired --flag / --no-flag options.
    group.add_argument("--layer-scale", action=argparse.BooleanOptionalAction, help="Use layer scale?")
    group.add_argument(
        "--qkv-bias",
        action=argparse.BooleanOptionalAction,
        help="Use bias in linear transformation to queries, keys, and values?",
    )
    group.add_argument("--pre-norm", action=argparse.BooleanOptionalAction, help="Use norm first architecture?")
    group.add_argument("--dropout", nargs="?", type=float, help="Model dropout.")
    group.add_argument("-heads", "--num-heads", nargs="?", type=int, help="Number of parallel attention heads.")
    group.add_argument("--input-dim", nargs="?", type=int, help="Dimensionality of text encoding.")
    group.add_argument("--max-seq-len", nargs="?", type=int, help="Maximum sequence length for text data.")
    group.add_argument(
        "--fused-attn",
        action=argparse.BooleanOptionalAction,
        help="Use fused attention (for ViT with Timm's attention only)?",
    )
    # group.add_argument(
    #     "--perf-metric", nargs="?", choices=["acc", "mIoU"], help="Performance metric to use for evaluation."
    # )
    # group.add_argument("-no_model_ema", action="store_true",
    #     help="Don't use an exponential moving average for model parameters")
    # group.add_argument("-model_ema_decay", nargs='?', type=float, default=default_kwargs["model_ema_decay"],
    #     help="Decay rate for exponential moving average of model parameters")
    # Experiment management
    group = parser.add_argument_group("Experiment management")
    group.add_argument("--seed", nargs="?", type=int, help="Manual RNG seed.")
    group.add_argument(
        "-exp",
        "--experiment-name",
        nargs="?",
        choices=_EXPNAMES,
        help="Name for the experiment. Is used for grouping of runs.",
    )
    group.add_argument(
        "--save-epochs", nargs="?", type=int, help="Number of epochs after which to save the full training state."
    )
    group.add_argument(
        "--keep-interm-states",
        nargs="?",
        type=int,
        help="Number of intermediate states to keep. All others (earlier ones) will be deleted automatically.",
    )
    group.add_argument(
        "--custom-dataset-path", nargs="?", type=str, help="Overwrite the path to any dataset to this path."
    )
    group.add_argument(
        "--results-folder",
        nargs="?",
        default=results_folder,
        type=str,
        help="Folder to put script results (mlflow data, models, etc.).",
    )
    group.add_argument(
        "--gather-stats-during-training",
        action=argparse.BooleanOptionalAction,
        help="Gather training statistics from all GPUs?",
    )
    group.add_argument("--tqdm", action=argparse.BooleanOptionalAction, help="Show tqdm for every epoch?")
    group.add_argument(
        "--debug", action=argparse.BooleanOptionalAction, help="Debug mode: lots of intermediate prints."
    )
    group.add_argument("--wandb", action=argparse.BooleanOptionalAction, help="Use external logging via Wandb?")
    group.add_argument("--log-level", choices=["info", "debug"], help="Log level", metavar="LEVEL")
    group.add_argument("-out", "--out-dir", type=str, help="Output directory for additional outputs.")
    # Speedup
    group = parser.add_argument_group("Speedup")
    group.add_argument("--amp", action=argparse.BooleanOptionalAction, help="Use automatic mixed precision?")
    group.add_argument(
        "--eval-amp", action=argparse.BooleanOptionalAction, help="Use automatic mixed precision during evaluation?"
    )
    group.add_argument("--compile-model", action=argparse.BooleanOptionalAction, help="Use torch.compile?")
    group.add_argument("--cuda", action=argparse.BooleanOptionalAction, help="Use cuda?")
    # Data loading
    group = parser.add_argument_group("Data loading")
    group.add_argument("-bs", "--batch-size", nargs="?", type=int, help="Batch size over all graphics cards (togeter).")
    group.add_argument("--num-workers", nargs="?", type=int, help="Number of dataloader worker threads. Should be >0.")
    group.add_argument(
        "--pin-memory", action=argparse.BooleanOptionalAction, help="Use pin_memory of torch Dataloader?"
    )
    group.add_argument(
        "--prefetch-factor",
        nargs="?",
        type=int,
        help="Prefetch factor for dataloader workers (how many batches to fetch)",
    )
    group.add_argument("--shuffle", action=argparse.BooleanOptionalAction, help="Shuffle the training data?")
    group.add_argument(
        "--weighted-sampler",
        action=argparse.BooleanOptionalAction,
        help="Use a class-weighted sampler to sample evenly from all classes (train and val)?",
    )
    group.add_argument("--ipc", type=int, help="How many images per class to load and save.")
    # Optimizer
    group = parser.add_argument_group("Optimizer")
    group.add_argument("--opt", nargs="?", type=str, help="Optimizer to use.")
    group.add_argument("--weight-decay", nargs="?", type=float, help="Weight decay factor for use in optimizer.")
    group.add_argument("-lr", "--lr", nargs="?", type=float, help="Initial learning rate.")
    group.add_argument(
        "--max-grad-norm", nargs="?", type=float, help="Maximum norm for the gradients (used for cutoff)."
    )
    group.add_argument("--warmup-epochs", nargs="?", type=int, help="Number of epochs of linear warmup.")
    group.add_argument("--label-smoothing", nargs="?", type=float, help="Label smoothing factor.")
    group.add_argument("--loss", nargs="?", choices=["ce", "baikal"], type=str, help="Loss function to use.")
    group.add_argument(
        "--loss-weight", nargs="?", type=str, choices=["none", "linear", "log", "sqrt"], help="Per class loss weight."
    )
    group.add_argument("--sched", nargs="?", choices=["cosine", "const"], help="Learning rate schedule.")
    group.add_argument("--min-lr", nargs="?", type=float, help="Minimum learning rate to be hit by scheduler.")
    group.add_argument("--warmup-lr", nargs="?", type=float, help="Warmup learning rate.")
    group.add_argument("--warmup-sched", nargs="?", choices=["linear", "const"], help="Schedule for warmup")
    group.add_argument(
        "--opt-eps", nargs="?", type=float, help="Epsilon value added in the optimizer to stabilize training."
    )
    group.add_argument("--momentum", nargs="?", type=float, help="Optimizer momentum.")
    # Data augmentation
    group = parser.add_argument_group("Data augmentation")
    group.add_argument("--augment-strategy", nargs="?", type=str, help="Data augmentation strategy.")
    group.add_argument("--aug-rand-rot", nargs="?", type=int, help="Random rotation limit.")
    group.add_argument(
        "--aug-flip", action=argparse.BooleanOptionalAction, help="Use data augmentation: horizontal flip?"
    )
    group.add_argument(
        "--aug-crop",
        action=argparse.BooleanOptionalAction,
        help="Use data augmentation: cropping. This may break the skript?",
    )
    group.add_argument("--aug-resize", action=argparse.BooleanOptionalAction, help="Use data augmentation: resize?")
    group.add_argument(
        "--aug-grayscale", action=argparse.BooleanOptionalAction, help="Use data augmentation: grayscale?"
    )
    group.add_argument("--aug-solarize", action=argparse.BooleanOptionalAction, help="Use data augmentation: solarize?")
    group.add_argument(
        "--aug-gauss-blur", action=argparse.BooleanOptionalAction, help="Use data augmentation: gaussian blur?"
    )
    group.add_argument(
        "--aug-cutmix-alpha",
        type=float,
        help="Alpha value for using CutMix. CutMix is active when aug_cutmix_alpha > 0.",
    )
    group.add_argument(
        "--aug-mixup-alpha", type=float, help="Alpha value for using Mixup. Mixup is active when aug_mixup_alpha > 0."
    )
    group.add_argument(
        "--aug-color-jitter-factor",
        nargs="?",
        type=float,
        help="Factor to use for the data augmentation: color jitter.",
    )
    group.add_argument(
        "--aug-normalize", action=argparse.BooleanOptionalAction, help="Use data augmentation: Normalization?"
    )
    group.add_argument(
        "--aug-repeated-augment-repeats",
        type=int,
        help="Number of image repeats with repeat-augment from DeiT. 1 is not using repeat-augment.",
    )
    group.add_argument(
        "--aug-random-erase-prob", type=float, help="For DeiT augment: Probabiliy of RandomErase augmentation."
    )
    group.add_argument("--auto-augment-strategy", type=str, help="For DeiT augment: AutoAugment Policy to use.")
    group.add_argument("--imsize", nargs="?", type=int, help="Image size given to the model -> imsize x imsize.")
    group.add_argument(
        "--augment-engine",
        nargs="?",
        choices=["torchvision", "albumentations", "dali"],
        help="Which data augmentation engine to use.",
    )
    return parser
def partition_choices():
    """Automatically create a list of all possible slurm partitions.

    Queries `sinfo`; when it yields no useful output (e.g. not on a slurm head
    node), falls back to the configured default partitions.
    """
    names = {line.split(" ")[0] for line in os.popen("sinfo")}
    if len(names) <= 2:
        return slurm_defaults["partition"]
    # drop the header row; a '*' marks the default partition, so trim that character
    return [name[:-1] if "*" in name else name for name in names if name != "PARTITION"]
def slurm_parser(parser=None):
    """Add srun arguments to the given parser.

    Args:
        parser (argparse.ArgumentParser, optional): base parser to extend; default is parser from *base_parser*

    Returns:
        parser (argparse.ArgumentParser): extended parser
    """
    if parser is None:
        parser = base_parser()
    group = parser.add_argument_group("Slurm arguments")
    group.add_argument(
        "--partition",
        nargs="*",
        default=slurm_defaults["partition"],
        choices=partition_choices(),
        help="Slurm partition to use",
    )
    group.add_argument(
        "--container-image",
        nargs="?",
        default=slurm_defaults["container_image"],
        type=str,
        help="Path to slurm container image (.sqsh)",
    )
    group.add_argument(
        "--container-workdir",
        nargs="?",
        default=slurm_defaults["container_workdir"],
        type=str,
        help="Working directory in container",
    )
    group.add_argument(
        "--container-mounts",
        nargs="?",
        default=slurm_defaults["container_mounts"],
        type=str,
        help="All slurm mounts separated by ','.",
    )
    group.add_argument(
        "--job-name",
        nargs="?",
        default=slurm_defaults["job_name"],
        type=str,
        help="Slurm job name. Will default to '<model> <task> <dataset>'.",
    )
    group.add_argument(
        "--nodes", nargs="?", default=slurm_defaults["nodes"], type=int, help="Number of cluster nodes to use."
    )
    group.add_argument(
        "--ntasks", nargs="?", default=slurm_defaults["ntasks"], type=int, help="Number of GPUs to use for the job."
    )
    group.add_argument("--gpus", action=argparse.BooleanOptionalAction, default=True, help="Use gpus for this job?")
    group.add_argument(
        "-cpus",
        "--cpus-per-task",
        "--cpus-per-gpu",
        nargs="?",
        default=slurm_defaults["cpus_per_task"],
        type=int,
        help="Number of CPUs per task/GPU.",
    )
    # mem-per-gpu is later converted to a total --mem value in create_runscript.
    group.add_argument(
        "-mem",
        "--mem-per-gpu",
        "--mem-per-task",
        nargs="?",
        default=slurm_defaults["mem_per_gpu"],
        type=int,
        help="Ram per GPU (in Gb) to use. Will be given as total mem in srun command.",
    )
    group.add_argument(
        "--task-prolog",
        nargs="?",
        default=slurm_defaults["task_prolog"],
        type=str,
        help="Shell script for task prolog (installing packages, etc.).",
    )
    group.add_argument("--time", nargs="?", default=slurm_defaults["time"], type=str, help="Slurm time limit.")
    group.add_argument(
        "--export",
        nargs="?",
        default=slurm_defaults["export"],
        type=str,
        help="Additional environment variables to export.",
    )
    group.add_argument("--exclude", nargs="?", default=slurm_defaults["exclude"], type=str, help="Nodes to exclude.")
    group.add_argument(
        "--after-job", nargs="?", default=slurm_defaults["after_job"], type=int, help="Job ID to wait for."
    )
    group.add_argument(
        "--interactive",
        action="store_true",
        help=(
            "Run using srun instead of sbatch. This will print the output into the terminal, not the slurm output file."
            " The logfile will still be created as usual."
        ),
        default=False,
    )
    group = parser.add_argument_group("Run locally")
    group.add_argument("--local", action="store_true", help="Run locally; not in slurm", default=False)
    return parser
def parse_args(args=None, parser=None):
    """Parse args from *base_parser* and insert defaults.

    Args:
        args (argparse.Namespace | dict, optional): pre-parsed arguments; parsed
            from the command line when None (Default value = None)
        parser (argparse.ArgumentParser, optional): parser used for completeness
            error reporting; created via *base_parser* when None (Default value = None)

    Returns:
        dict: parsed arguments
    """
    if parser is None:
        parser = base_parser()
    if args is None:
        args = parser.parse_args()
    # __main__ passes a plain dict on the local-execution path (after stripping
    # slurm-only keys); vars() raises TypeError on a dict, so only apply it to
    # Namespace-like objects.
    args = dict(args) if isinstance(args, dict) else dict(vars(args))
    check_arg_completeness(args, parser)
    return args
def check_arg_completeness(args, parser):
    """Check completeness of arguments.

    Args:
        args (dict): arguments to check
        parser (argparse.ArgumentParser): for raising the parser error

    Note:
        will raise a parser error (SystemExit) if the arguments are not complete.
    """
    if args["task"] in ["pre-train", "fine-tune", "fine-tune-head"]:
        # training tasks need a run name, an experiment, and an epoch count
        if not args.get("run_name"):
            parser.error(f"--run-name is required for task {args['task']}")
        if not args.get("experiment_name"):
            parser.error(f"--experiment-name is required for task {args['task']}. Choose from {_EXPNAMES}")
        if args.get("epochs") is None:
            parser.error(f"--epochs is required for task {args['task']}")
    if args.get("dataset") is None and args["task"] in [
        "pre-train",
        "fine-tune",
        "fine-tune-head",
        "eval-metrics",
    ]:
        parser.error(f"--dataset is required for task {args['task']}")
    if args.get("val_dataset") is None and args.get("dataset") is None and args["task"] in ["eval"]:
        parser.error(f"--dataset or --val-dataset is required for task {args['task']}")
    # use .get(): the key may be absent when args was assembled programmatically
    repeats = args.get("aug_repeated_augment_repeats")
    if repeats is not None and repeats < 1:
        parser.error(
            "number of repeats for repeated augment has to be >= 1, but got --aug-repeated-augment-repeats ="
            f" {repeats}"
        )
    if args["task"] == "save-images" and args.get("out_dir") is None:
        parser.error("Need to set save directory (--out-dir) to save the images in.")
def inside_slurm():
    """Test for being inside a slurm container.

    Works by checking whether the environment variable 'RANK' is set.
    """
    return os.environ.get("RANK") is not None
# TODO: fix ./runscript.tmp: 18: Syntax error: Unterminated quoted string
def create_runscript(args, file_name=None):
    """Create a run script for a distributed training job using SLURM.

    Args:
        args (dict): A dictionary containing various arguments for the job, including parameters for SLURM and for training.
        file_name (str, optional, optional): The name of the file to create. Defaults to "runscript.tmp".

    Returns:
        str: The name of the created file.
        str: Additional command line arguments for sbatch.

    Example:
        >>> args = {"model": "vit_large_patch16_384", "task": "pre-train", "batch_size": 256, ...}
        >>> file_name = "my_run_script.sh"
        >>> create_runscript(args, file_name)
    """
    # fill in slurm defaults for keys the caller did not set
    for key, val in slurm_defaults.items():
        if key not in args and val is not None:
            args[key] = val
    # derive a run name from the model/task when none was given
    if "run_name" not in args or args["run_name"] is None:
        model_str = args["model"]
        if model_str.endswith(".pt"):
            model_str = os.path.dirname(model_str)
        run_name = args["task"] + " " + model_str.split(os.sep)[-1].split("_")[0]
    else:
        run_name = args["run_name"]
    # sanitize for use in file names
    job_name = run_name.replace(" ", "_").replace("/", "_").replace(">", "_").replace("<", "_")
    if file_name is None:
        file_name = (
            f"experiments/sbatch/run_{args['task']}_{job_name}_at_{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}.sbatch"
        )
    task_args = ""
    # slurm_command = "echo run distributed:\necho python3 main.py {0}\n\nsrun -K \\\n" # " --gpus-per-task=1 \\\n --gpu-bind=none \\\n"
    srun_command = "\nsrun -K \\\n"
    sbatch_commands = (  # outfile name is job name, date, job id, node name
        "#!/bin/bash\n\n#SBATCH"
        f" --output={slurm_output_folder}/%x-{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}-%j-%N.out\n"
    )
    sbatch_cmd_args = ""  # additional command line arguments for sbatch
    python_command = " python3 main.py {0}\n"
    # route each argument either to srun/sbatch (slurm keys) or to the python call
    for key, val in args.items():
        if key == "local":
            continue
        if key == "interactive":
            continue
        if key == "gpus":
            continue
        if key in slurm_defaults:
            # it's a parameter for srun
            # slurm has - instead of _
            key = key.replace("_", "-")
            if key == "mem-per-gpu":
                # convert mem-per-gpu to mem
                # slurm_command += f" --mem={val * args['ntasks'] // args['nodes']}G \\\n" # that amount of memory is assigned on each node
                key = "mem"
                val = f"{val * args['ntasks'] // args['nodes']}G"  # that amount of memory is assigned on each node
                # continue
            if key == "job-name" and val is None:
                # # default jobname is '<task> <model> <dataset>'
                # model_str = args["model"]
                # task = args["task"]
                # if task == "pre-train":
                #     # it's just the model name...
                #     model = model_str.split("_")[0]
                # else:
                #     # it's a path to the tar file
                #     if not model_str.startswith(res_folder):
                #         model = "<vit model>"
                #     else:
                #         model = model_str[len(res_folder) :].split("_")[1].split(" ")[0]
                # if "dataset" in args and args["dataset"] is not None:
                #     dataset = args["dataset"]
                # else:
                #     dataset = ""
                val = run_name
            if key == "job-name" and not val.startswith('"'):
                val = f'"{val}"'
            if key in ["task-prolog", "nodes", "exclude", "after-job"] and val is None:
                continue
            if key == "task-prolog":
                srun_command += f' --{key}="{val}" \\\n'
                continue
            if key == "after-job":
                # dependencies go on the sbatch command line, not into the script
                sbatch_cmd_args += f"--dependency=afterany:{val} "
                continue
            if key == "partition" and isinstance(val, list):
                val = ",".join(val)
            if key == "ntasks":
                if args["nodes"] == 1:
                    gpus = val if args["gpus"] else 0
                    # slurm_command += f" --gpus={val} \\\n"
                    sbatch_commands += f"#SBATCH --gpus={gpus}\n"
                else:
                    assert (
                        val % args["nodes"] == 0
                    ), f"Number of tasks ({val}) must be a multiple of the number of nodes ({args['nodes']})."
                    # slurm_command += f" --gpus-per-node={val // args['nodes']} \\\n"
                    sbatch_commands += f"#SBATCH --gpus-per-node={val // args['nodes']}\n"
                    # NOTE(review): ntasks-per-node is hard-coded to 8 — presumably the
                    # per-node GPU count of the target cluster; confirm for other nodes.
                    sbatch_commands += "#SBATCH --ntasks-per-node=8\n"
            # container-* options are srun-only; everything else becomes an #SBATCH line
            if "container" in key:
                srun_command += f" --{key}={val} \\\n"
            else:
                sbatch_commands += f"#SBATCH --{key}={val}\n"
                # slurm_command += f" --{key}={val} \\\n"
        else:
            # it's a parameter for the training
            if val is None:
                continue
            # skip values that still equal their module-level default
            if key in ["results_folder"] and val == globals()[key]:
                continue
            key = key.replace("_", "-")
            if isinstance(val, bool):
                if val:
                    task_args += f"--{key} "
                else:
                    task_args += f"--no-{key} "
                continue
            if isinstance(val, str):
                task_args += f'--{key} "{val}" '
            else:
                task_args += f"--{key} {val} "
    # slurm_command += "python3 main.py {0}\n"
    # os.umask(0) # make it possible to create an executable file
    # with open(file_name, "w+", opener=lambda pth, flgs: os.open(pth, flgs, 0o777)) as f:
    #     f.write(slurm_command.format(task_args))
    with open(file_name, "w+") as f:
        f.write(sbatch_commands + srun_command + python_command.format(task_args))
    # delete all runscripts older than a month
    n_old_files = int(os.popen("find experiments/sbatch/ -type f -mtime +30 | wc -l").read())
    if n_old_files > 0:
        print(f"Deleting {n_old_files} old runscripts.")
        os.system("find experiments/sbatch/ -type f -mtime +30 -delete")
    return file_name, sbatch_cmd_args
if __name__ == "__main__":
    if not inside_slurm():
        # Make execution script and execute it
        parser = slurm_parser()
        args = vars(parser.parse_args())
        if not args["local"]:
            script_name, cmd_args = create_runscript(args)
            # os.system("./" + script_name) # run srun to execute this script in slurm cluster
            # -> the following lines will be executed there
            if args["interactive"]:
                os.system(f"python3 srun-sbatch.py {script_name}")
            else:
                os.system(f"sbatch {cmd_args} {script_name}")  # sbatch to queue the job on the cluster
            exit(0)
        # local execution is wanted
        # NOTE(review): slurm_defaults keys use underscores (e.g. "container_image",
        # see slurm_parser), so key.replace("_", "-") can never match those keys —
        # multi-word slurm-only keys may survive into the training args; verify.
        for key in list(args.keys()):
            if key.replace("_", "-") in slurm_defaults:
                args.pop(key)
        args = parse_args(args, parser)
    else:
        args = parse_args()
    # record the exact code state for reproducibility
    args["branch"] = subprocess.check_output(["git", "branch", "--show-current"]).strip().decode("utf-8")
    args["commit"] = subprocess.check_output(["git", "rev-parse", "HEAD"]).strip().decode("utf-8")
    # dispatch to the task implementation; imports are deferred to keep startup light
    if args["task"] == "pre-train":
        from train import pretrain

        pretrain(**args)
    elif args["task"] == "fine-tune":
        from train import finetune

        finetune(**args)
    elif args["task"] == "fine-tune-head":
        from train import finetune

        finetune(**args, head_only=True)
    elif args["task"] == "parser-test":
        from copy import copy

        from utils import prep_kwargs, log_args

        kwargs = prep_kwargs(copy(args))
        log_args(kwargs)
        # keys = sorted(list(args.keys()))
        # fill_len = max(len(k) for k in keys)
        # for key in keys:
        #     print(f"{key + ' ' * (fill_len - len(key))} = {args[key]} -> {kwargs[key]}")
    elif args["task"] == "eval-metrics":
        from evaluate import evaluate_metrics

        evaluate_metrics(**args)
    elif args["task"] == "eval":
        from evaluate import evaluate

        evaluate(**args)
    elif args["task"] == "eval-attr":
        from evaluate import evaluate_attributions

        evaluate_attributions(**args)
    elif args["task"] == "continue":
        from recover import continue_training

        continue_training(**args)
    elif args["task"] == "eval-center-bias":
        from evaluate import evaluate_center_bias

        evaluate_center_bias(**args)
    elif args["task"] == "eval-size-bias":
        from evaluate import evaluate_size_bias

        evaluate_size_bias(**args)
    elif args["task"] == "load-images":
        from test import load_images

        load_images(**args)
    elif args["task"] == "save-images":
        from test import save_images

        save_images(**args)
    else:
        raise NotImplementedError(f"Task {args['task']} is not implemented.")

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,154 @@
"""Model loading and preparation."""
import importlib
import os
import warnings
from functools import partial
import timm
import torch
import torch.nn as nn
from loguru import logger
import utils
from architectures.vit import TimmViT
from resizing_interface import vit_sizes
_ARCHITECTURES_IMPORTED = False


def _import_architectures():
    """Import every ``.py`` module inside the ``architectures`` package, once.

    A module that fails to import is logged and skipped, so a single broken
    architecture file does not prevent loading the others.
    """
    global _ARCHITECTURES_IMPORTED
    if _ARCHITECTURES_IMPORTED:
        return
    arch_folder = os.path.join(os.path.dirname(os.path.abspath(__file__)), "architectures")
    for fname in os.listdir(arch_folder):
        if not fname.endswith(".py"):
            continue
        try:
            # silence import-time warnings from the individual architecture files
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                importlib.import_module(f"architectures.{fname[:-3]}")
            logger.debug(f"Imported architectures.{fname[:-3]}")
        except Exception as e:
            logger.error(f"\033[93mCould not import \033[0m\033[91m{fname}\033[0m")
            logger.error(e)
    _ARCHITECTURES_IMPORTED = True
def prepare_model(model_str, args):
    """Prepare a new model.

    If the name is of the format ViT-<size>/<patch_size>, use a *TimmViT*, else fall back to timm model loading.

    Args:
        model_str (str): model name
        args (utils.DotDict): further arguments, needs to have keys n_classes, drop_path_rate; key imsize or '_<imsize>' at the end of ViT specification

    Returns:
        torch.nn.Module: model
    """
    _import_architectures()
    # start from a copy of all run args and drop entries that carry no value
    kwargs = dict(args)
    for key in list([key for key, val in kwargs.items() if val is None]):
        kwargs.pop(key)
    # duplicated keys cover the different kwarg names used by different model constructors
    if args.layer_scale_init_values:
        kwargs["init_values"] = kwargs["init_scale"] = args.layer_scale_init_values
    if args.dropout and args.dropout > 0.0:
        kwargs["drop"] = kwargs["drop_rate"] = args.dropout
    if args.drop_path_rate and args.drop_path_rate > 0.0:
        # NOTE(review): this aliases drop_path_rate onto drop_block_rate — in timm these
        # are different regularizers; confirm the aliasing is intended.
        kwargs["drop_block_rate"] = args.drop_path_rate
    kwargs["num_classes"] = args.n_classes
    kwargs["img_size"] = args.imsize
    if model_str.startswith("ViT"):
        # Format: ViT-{Ti,S,B,L}/<patch_size>[_<image_res>]
        h1, h2 = model_str.split("/")
        _, model_size = h1.split("-")
        if "_" in h2:
            # a resolution encoded in the name must agree with args.imsize (if set)
            patch_size, image_res = h2.split("_")
            assert args.imsize is None or args.imsize == int(
                image_res
            ), f"Got two different image sizes: {args.imsize} vs {image_res}"
        else:
            patch_size = h2
        # size preset supplies embed_dim/depth/num_heads; explicit kwargs take precedence
        kwargs = {**vit_sizes[model_size], **kwargs}
        model = TimmViT(patch_size=int(patch_size), in_chans=3, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    else:
        logger.debug(f"Loading model via timm api {model_str} with args {kwargs}")
        model = timm.create_model(model_str, pretrained=False, **kwargs)
    return model
def load_pretrained(model_path, args, new_dataset_params=False):
    """Load a pretrained model from .tar file.

    Args:
        model_path (str): path to .tar file
        args: new model parameters
        new_dataset_params (bool, optional): change model parameters (imsize, n_classes) to the ones from args. (Default value = False)

    Returns:
        tuple: model, args, old_args, save_state

    Notes:
        Two known checkpoint mismatches are repaired automatically: checkpoints saved
        without LayerScale, and checkpoints whose head was an ``nn.Sequential`` instead
        of ``nn.Linear``. Any other state-dict mismatch terminates the process.
    """
    _import_architectures()
    save_state = torch.load(model_path, map_location="cpu")
    old_args = utils.prep_kwargs(save_state["args"])
    args.model = old_args.model
    old_args.cuda = args.cuda  # run on the current device setting, not the training one
    if old_args.model.startswith("flash_vit"):
        # flash_vit constructors do not take LayerScale kwargs
        args.pop("layer_scale_init_values", None)
        old_args.pop("layer_scale_init_values", None)
    # load the model (the old one first)
    model = prepare_model(old_args.model, old_args)
    logger.debug(f"loading model {old_args.model} from {model_path} with args {old_args}")
    # strip torch.compile ("_orig_mod.") and DDP prefixes from the saved keys
    file_save_state = utils.remove_prefix(save_state["model_state"], prefix="_orig_mod.")
    file_save_state = utils.remove_prefix(file_save_state)
    try:
        model.load_state_dict(file_save_state)
    except (UnboundLocalError, RuntimeError) as e:
        model_keys = set(model.state_dict().keys())
        file_keys = set(file_save_state.keys())
        logger.warning(f"Error loading state dict: {e}")
        model_minus_file = model_keys.difference(file_keys)
        file_minus_model = file_keys.difference(model_keys)
        logger.warning(f"model-file: {model_minus_file}\nfile-model: {file_minus_model}")
        if len(file_minus_model) == 0 and all(".ls" in key and key.endswith(".gamma") for key in model_minus_file):
            # only LayerScale gammas are missing -> rebuild the model without LayerScale
            logger.info("Old model was without LayerScale -> replicating")
            try:
                args.pop("layer_scale_init_values")
                old_args.pop("layer_scale_init_values")
                model = prepare_model(old_args.model, old_args)
                model.load_state_dict(file_save_state)
            except (UnboundLocalError, RuntimeError) as e:
                logger.error("Could not resolve conflict")
                logger.error(f"Still got error {e}")
                exit(-1)
        elif any("head.0." in key for key in file_minus_model):
            # fixed typo in this log message ("Seqeuntial" -> "Sequential")
            logger.info("Old model used nn.Sequential for head. Trying to fix -> nn.Linear")
            file_save_state = {key.replace("head.0.", "head."): val for key, val in file_save_state.items()}
            try:
                model.load_state_dict(file_save_state)
            except (UnboundLocalError, RuntimeError) as e:
                logger.error("Could not resolve conflict")
                logger.error(f"Still got error {e}")
                exit(-1)
        else:
            exit(-1)
    if new_dataset_params:
        # setup for finetuning parameters
        model.set_image_res(args.imsize)
        model.set_num_classes(args.n_classes)
        if args.max_seq_len is not None:
            model.set_max_seq_len(args.max_seq_len)
    return model, args, old_args, save_state

View File

@@ -0,0 +1,36 @@
import os

# Current user name (taken from the environment; may be None outside a login shell).
user = os.environ.get("USER")
results_folder = os.path.join("/BASE/FOLDER/TO/STORE/WEIGHTS/AND/LOGS", "EfficientCVBench")
# PATH: /netscratch/<user>/slurm
slurm_output_folder = os.path.join("/FOLDER/FOR/SLURM/TO/WRITE/LOGS/TO", "slurm")

# Lookup table: dataset name -> root folder on disk (placeholders to be filled in).
# NOTE(review): "CIFAT" looks like a typo for "CIFAR" — confirm before filling in.
_ds_paths = {
    "cifar": "/PATH/TO/CIFAT",
    "tinyimagenet": "/PATH/TO/TINYIMAGENET",
    "stanford_cars": "/PATH/TO/CARS/SUPERFOLDER",
    "oxford_pet": "/PATH/TO/PET/SUPERFOLDER",
    "flowers102": "/PATH/TO/FLOWERS/SUPERFOLDER",
    "food101": "/PATH/TO/FOOD/SUPERFOLDER",
    "aircraft": "/PATH/TO/AIRCRAFT/SUPERFOLDER",
    "fornet": "/PATH/TO/FORNET",
    "counteranimal": "/PATH/TO/CounterAnimal/LAION-final",
}


def ds_path(dataset, args=None):
    """Resolve the root folder of a dataset.

    Args:
        dataset (str): The dataset to look up.
        args (DotDict, optional): Run args; a non-None ``custom_dataset_path``
            entry overrides the lookup table.

    Returns:
        str: Path to the dataset root folder.
    """
    if args is not None and args.get("custom_dataset_path") is not None:
        return args["custom_dataset_path"]
    return _ds_paths[dataset]

View File

@@ -0,0 +1,136 @@
"""Continue pretraining / finetuning after something went wrong."""
import torch
from loguru import logger
from engine import _train, setup_criteria_mixup, setup_model_optim_sched_scaler, setup_tracking_and_logging
from load_dataset import prepare_dataset
from models import load_pretrained
from utils import ddp_cleanup, ddp_setup, log_args, prep_kwargs
def continue_training(model, **kwargs):
    """Continue training a model from a saved state.

    Restores model weights, optimizer, scheduler and run configuration from a
    checkpoint and resumes training at the stored epoch.

    Args:
        model (str): path to saved state.
        **kwargs: additional keyword arguments; used as fallbacks for values
            missing in the checkpoint args (e.g. ``batch_size``, ``epochs``).
    """
    model_path = model
    save_state = torch.load(model, map_location="cpu")
    # state is of the form
    #
    # state = {'epoch': epochs,
    #          'model_state': model.state_dict(),
    #          'optimizer_state': optimizer.state_dict(),
    #          'scheduler_state': scheduler.state_dict(),
    #          'args': dict(args),
    #          'run_name': run_name,
    #          'stats': metrics}
    args = prep_kwargs(save_state["args"])
    args.distributed, device, world_size, rank, gpu_id = ddp_setup()
    torch.cuda.set_device(device)
    # reconstruct the global batch size, then split it across the current world size
    if "world_size" in args and args.world_size is not None:
        global_bs = args.batch_size * args.world_size
    else:
        # assume global bs is given in kwargs
        global_bs = kwargs["batch_size"]
    args.batch_size = int(global_bs / world_size)
    args.world_size = world_size
    if "dataset" in args and args.dataset is not None:
        dataset = args.dataset
    else:
        # get default dataset for the task
        dataset = "ImageNet21k" if args.task == "pre-train" else "ImageNet"
    args.dataset = dataset
    if "val_dataset" in args and args.val_dataset is not None:
        val_dataset = args.val_dataset
    else:
        val_dataset = dataset
    args.val_dataset = val_dataset
    start_epoch = save_state["epoch"]
    # target epoch count: prefer checkpoint args (unless training already finished there)
    if "epochs" in args and args.epochs is not None and args.epochs != start_epoch:
        epochs = args.epochs
    else:
        epochs = kwargs["epochs"]
    run_folder = setup_tracking_and_logging(args, rank, append_model_path=model_path)
    logger.info(f"Logging run information to '{run_folder}'")
    # get the datasets & dataloaders
    train_loader, args.n_classes, args.ignore_index, args.multi_label, train_dali_server = prepare_dataset(
        dataset, args, rank=rank
    )
    val_loader, _, __, ___, val_dali_server = prepare_dataset(val_dataset, args, train=False, rank=rank)
    # model_name = args.model
    model, args, _, __ = load_pretrained(model_path, args)
    model, optimizer, scheduler, scaler = setup_model_optim_sched_scaler(model, device, epochs, args)
    try:
        optimizer.load_state_dict(save_state["optimizer_state"])
    except ValueError as e:
        # param groups do not match the checkpoint -> log both sides, then fail
        logger.error(f"Could not load optimizer state: {e}")
        logger.error(
            f"optimizer state: {optimizer.state_dict().keys()}, param groups: {optimizer.state_dict()['param_groups']}"
        )
        logger.error(
            f"saved state: {save_state['optimizer_state'].keys()}, param groups:"
            f" {save_state['optimizer_state']['param_groups']}"
        )
        raise e
    scheduler.load_state_dict(save_state["scheduler_state"])
    # log all devices
    logger.info(f"training on {device} -> {torch.cuda.get_device_name(device) if args.device != 'cpu' else ''}")
    if rank == 0:
        logger.info(f"torch version {torch.__version__}")
        log_args(args)
    if args.seed:
        torch.manual_seed(args.seed)
    criterion, val_criterion, mixup = setup_criteria_mixup(args)
    if rank == 0:
        logger.info(f"start training at epoch {start_epoch}")
        logger.info(f"Run info at: '{run_folder}'")
    res = _train(
        model,
        train_loader,
        optimizer,
        rank,
        epochs,
        device,
        mixup,
        criterion,
        world_size,
        scheduler,
        args,
        val_loader,
        val_criterion,
        run_folder,
        scaler=scaler,
        do_metrics_calculation=True,
        start_epoch=start_epoch,
        show_tqdm=args.tqdm,
        train_dali_server=train_dali_server,
        val_dali_server=val_dali_server,
    )
    if rank == 0:
        best_acc_key = [key for key in res.keys() if key.startswith("val/best_")][0]
        logger.info(f"Run '{args.run_name}' is done. Top-1 validation accuracy: {res[best_acc_key] * 100:.2f}%")
    ddp_cleanup(args=args, rank=rank)

View File

@@ -0,0 +1,25 @@
captum==0.8.0
# datadings==3.4.6
einops==0.8.1
fvcore==0.1.5.post20221221
grad-cam==1.5.4
halonet-pytorch==0.0.4
numpy==1.26.4
opencv-python==4.11.0.86
pytorch_wavelets==1.3.0
pywavelets==1.8.0
reformer-pytorch==1.4.4
routing-transformer==1.6.1
sinkhorn-transformer==0.11.4
timm==1.0.15
torch==2.6.0
torcheval==0.0.7
torchprofile==0.0.4
torchvision==0.21.0
tqdm==4.67.1
nltk==3.9.1
Pillow==11.1.0
wandb==0.19.9
psutil==7.0.0
# NOTE(review): removed duplicate numpy==1.26.4 (already pinned above) and
# psutils==3.3.9 -- "psutils" is an unrelated PyPI package; psutil==7.0.0 covers this need.

View File

@@ -0,0 +1,108 @@
from copy import copy
import torch
from loguru import logger
from torch import nn
# Preset ViT configurations, keyed by the size tag used in model names like "ViT-S/16".
# "na"/"mu" are smaller-than-Tiny presets (presumably nano/micro -- TODO confirm);
# the "LRA_*" entries are configurations for Long Range Arena benchmark tasks.
vit_sizes = {
    "na": dict(embed_dim=96, depth=2, num_heads=2),
    "mu": dict(embed_dim=144, depth=6, num_heads=3),
    "Ti": dict(embed_dim=192, depth=12, num_heads=3),  # Tiny
    "S": dict(embed_dim=384, depth=12, num_heads=6),  # Small
    "B": dict(embed_dim=768, depth=12, num_heads=12),  # Base
    "L": dict(embed_dim=1024, depth=24, num_heads=16),  # Large
    "LRA_CIFAR": dict(embed_dim=256, depth=1, num_heads=4, mlp_ratio=1.0),
    "LRA_IMDB": dict(embed_dim=256, depth=4, num_heads=4, mlp_ratio=4.0),
    "LRA_ListOps": dict(embed_dim=512, depth=4, num_heads=8, mlp_ratio=4.0),
}
class ResizingInterface:
    """Interface for resizing parts of a Vision Transformer model.

    Meant to be mixed into a ViT-style model that provides (at least)
    ``img_size``, ``patch_size``, ``patch_embed``, ``embed_layer``, ``in_chans``,
    ``embed_dim``, ``pre_norm``, ``no_embed_class``, ``num_prefix_tokens``,
    ``pos_embed``, ``num_classes`` and ``head``.
    """

    def get_internal_loss(self):
        """Add a term to the loss.

        Returns:
            float: extra loss contribution; 0.0 in the base interface.
        """
        return 0.0

    def set_image_res(self, res):
        """Set a new image resolution.

        Resets the (learned) patch embedding.

        Args:
            res (int): new image resolution
        """
        self._set_input_strand(res=res)

    def _set_input_strand(self, res=None, patch_size=None):
        """Set a new image resolution and patch size.

        Rebuilds the patch embedding for the new geometry (reusing the old weights)
        and bicubically interpolates the square positional-embedding grid.

        Args:
            res (int): new image resolution; None keeps the current one (Default value = None)
            patch_size (int): new patch size; changing it is not supported yet (Default value = None)
        """
        if res is None:
            res = self.img_size
        if patch_size is None:
            patch_size = self.patch_size
        else:
            # TODO: implement interpolation of patch_embed weights to new patch size/input shape
            raise NotImplementedError("Interpolation of patch_embed weights to new patch size not implemented yet.")
        if res == self.img_size and patch_size == self.patch_size:
            return  # nothing to do here
        logger.info(f"Resizing input from {self.img_size} to {res} with patch size {self.patch_size} to {patch_size}.")
        # rebuild the patch embedding for the new resolution, keeping the old weights
        old_patch_embed_state = copy(self.patch_embed.state_dict())
        self.patch_embed = self.embed_layer(
            img_size=res,
            patch_size=patch_size,
            in_chans=self.in_chans,
            embed_dim=self.embed_dim,
            bias=not self.pre_norm,  # disable bias if pre-norm is used (e.g. CLIP)
        )
        self.patch_embed.load_state_dict(old_patch_embed_state)
        # interpolate the positional embedding to the new number of patches;
        # prefix tokens (e.g. cls) are split off first and re-attached unchanged
        num_extra_tokens = 0 if self.no_embed_class else self.num_prefix_tokens
        orig_size = int((self.pos_embed.shape[-2] - num_extra_tokens) ** 0.5)
        new_size = int(self.patch_embed.num_patches**0.5)
        extra_tokens = self.pos_embed[:, :num_extra_tokens]
        pos_tokens = self.pos_embed[:, num_extra_tokens:]
        # make it shape rest x embed_dim x orig_size x orig_size
        pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, self.embed_dim).permute(0, 3, 1, 2)
        pos_tokens = nn.functional.interpolate(
            pos_tokens, size=(new_size, new_size), mode="bicubic", align_corners=False
        )
        # make it shape rest x new_size^2 x embed_dim
        pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
        if num_extra_tokens > 0:
            pos_tokens = torch.cat((extra_tokens, pos_tokens), dim=1)
        self.pos_embed = nn.Parameter(pos_tokens.contiguous())
        self.img_size = res
        self.patch_size = patch_size

    def set_num_classes(self, n_classes):
        """Reset the classification head with a new number of classes.

        Args:
            n_classes (int): new number of classes

        Notes:
            NOTE(review): for n_classes <= 0 the head becomes ``nn.Identity``,
            which has no ``.weight``/``.bias`` — the init calls below would then
            fail; confirm callers always pass n_classes > 0.
        """
        if n_classes == self.num_classes:
            return
        logger.info(f"Resizing classification head from {self.num_classes} to {n_classes}.")
        self.head = nn.Linear(self.embed_dim, n_classes) if n_classes > 0 else nn.Identity()
        self.num_classes = n_classes
        # init weight + bias
        # nn.init.zeros_(self.head.weight)
        # nn.init.constant_(self.head.bias, -log(self.num_classes))
        nn.init.trunc_normal_(self.head.weight, std=0.02)
        nn.init.constant_(self.head.bias, 0)

View File

@@ -0,0 +1,46 @@
"""Run the srun line(s) of a sbatch script directly (interactive execution).

Usage: python3 srun-sbatch.py <sbatch-script> [extra args...]

Reads the given sbatch script, forwards its ``#SBATCH`` parameters (except
``--output=``) to ``srun``, substitutes extra CLI arguments for ``"$@"``, and
executes every ``srun`` line the script contains.
"""
import os
import sys

# Explicit argument validation instead of `assert` (asserts vanish under -O).
if len(sys.argv) < 2:
    sys.exit("Add a sbatch script as the first argument")
if not os.path.isfile(sys.argv[1]):
    sys.exit(f"First argument has to be an executable script (a file that exists), but got '{sys.argv[1]}'")

with open(sys.argv[1], "r") as f:
    script = f.readlines()
script = [line.strip() for line in script if len(line.strip()) > 0]

# Join physical lines that end with a backslash continuation into one logical line.
joined_script = []
current_line = ""
for line in script:
    current_line += line
    if current_line.endswith("\\"):
        current_line = current_line[:-1]
    else:
        joined_script.append(current_line)
        current_line = ""
script = joined_script

# Collect the #SBATCH parameters (to forward to srun) and the srun lines themselves.
additional_srun_params = []
srun_lines = []
for line in script:
    if line.upper().startswith("#SBATCH "):
        param = line[len("#SBATCH ") :]
        if param.startswith("--output="):
            continue  # interactive run writes to the terminal; no log file needed
        additional_srun_params.append(param)
    elif line.startswith("srun "):
        srun_lines.append(line)

# Splice the sbatch parameters into each srun call and substitute extra CLI args for "$@".
further_args = " ".join(sys.argv[2:])
sruns = [
    line.replace("srun ", "srun " + " ".join(additional_srun_params) + " ").replace('"$@"', further_args)
    for line in srun_lines
]
for srun_line in sruns:
    print(f"I will run:\n{srun_line}", flush=True)
    os.system(srun_line)

View File

@@ -0,0 +1,97 @@
import math
import os
import sys
import numpy as np
from loguru import logger
from matplotlib import pyplot as plt
from PIL import Image
from tqdm.auto import tqdm
from load_dataset import prepare_dataset
from utils import log_args, log_formatter, prep_kwargs
def load_images(dataset, **kwargs):
    """Show a grid of example images from a dataset.

    Args:
        dataset (str): dataset name, resolved via ``prepare_dataset``.
        **kwargs (dict): further run arguments (batch size, augmentation flags, ...).
    """
    args = prep_kwargs(kwargs)
    args.dataset = dataset
    args.aug_normalize = False  # keep raw pixel values for display
    loader, args.n_classes, args.ignore_index, args.multi_label, _ = prepare_dataset(dataset, args)
    images = next(iter(loader))[0]
    images = images.permute(0, 2, 3, 1).numpy()  # NCHW -> NHWC for imshow
    images = [images[i] for i in range(images.shape[0])]
    # roughly a 2:1 grid: about twice as many columns as rows
    rows = math.ceil(math.sqrt(len(images) / 2))
    ims_per_row = len(images) // rows
    # squeeze=False keeps axs two-dimensional even for a single row/column
    # (the flattening below crashed on the squeezed 1-D/scalar result before)
    fig, axs = plt.subplots(rows, ims_per_row, squeeze=False)
    axs = [ax for row in axs for ax in row]
    for img, ax in zip(images, axs):
        ax.imshow(img)
    fig.suptitle(f"Examples from {dataset}")
    fig.tight_layout(pad=0)
    plt.show()
def save_images(dataset, out_dir, ipc=None, **kwargs):
    """Export a dataset's images to disk, one JPEG folder per class.

    Args:
        dataset (str): dataset name, resolved via ``prepare_dataset``.
        out_dir (str): output root; one subfolder per class is created.
        ipc (int, optional): maximum images to save per class. (Default value = None -> no limit)
        **kwargs (dict): further run arguments (batch size, log level, tqdm flag, ...).
    """
    args = prep_kwargs(kwargs)
    args.dataset = dataset
    args.out_dir = out_dir
    args.ipc = ipc
    args.aug_normalize = False  # keep raw pixel values for export
    log_file = os.path.join(out_dir, "save_images.log")
    logger.remove()
    logger.configure(extra=dict(run_name=f"Save images of {dataset}", rank=0, world_size=-1))
    logger.add(sys.stderr, format=log_formatter, enqueue=True, colorize=True, level=args.log_level.upper())
    logger.add(log_file, format=log_formatter, enqueue=True, colorize=True, level=args.log_level.upper())
    logger.info(f"Out dir '{out_dir}'")
    log_args(args)
    loader, args.n_classes, args.ignore_index, args.multi_label, _ = prepare_dataset(dataset, args)
    n_ims = [0 for i in range(args.n_classes)]  # per-class counter of saved images
    if args.n_classes == 1000:
        # assume its ImageNet classes
        logger.info("1000 classes => assuming ImageNet class names")
        with open("data/misc_dataset_files/imagenet_labels.txt", "r") as f:
            lines = f.readlines()
        class_names = [line.strip().split(" ")[0] for line in lines]
        lbl_to_cls_name = sorted(class_names)
    else:
        # str() is required: the names feed os.path.join below, which raises
        # TypeError for int components (was `[i for i in range(...)]` before)
        lbl_to_cls_name = [str(i) for i in range(args.n_classes)]
    for cls_name in lbl_to_cls_name:
        os.makedirs(os.path.join(args.out_dir, cls_name), exist_ok=True)
    skipped_ims = 0
    # compare as strings: os.environ values are str, so the previous `!= 0`
    # check disabled tqdm even when TQDM_DISABLE was set to "0"
    tqdm_is_disabled = (not args.tqdm) or os.environ.get("TQDM_DISABLE", "0") != "0"
    for i, (images, labels) in (
        pbar := tqdm(enumerate(loader), desc="Loading and saving images", disable=tqdm_is_disabled, total=len(loader))
    ):
        images = (images.permute(0, 2, 3, 1).numpy() * 255).astype(np.uint8)  # NCHW float -> NHWC uint8
        images = [images[j] for j in range(images.shape[0])]
        labels = labels.tolist()
        for img, lbl in zip(images, labels):
            if ipc is not None and n_ims[lbl] >= ipc:
                skipped_ims += 1
                continue
            # (the old `img = ....save(...)` assignment bound None; dropped)
            Image.fromarray(img).save(
                os.path.join(args.out_dir, lbl_to_cls_name[lbl], f"{lbl_to_cls_name[lbl]}_{n_ims[lbl]}.JPEG")
            )
            n_ims[lbl] += 1
        if ipc is not None and sum(n_ims) >= args.n_classes * ipc:
            break  # every class has reached its quota
        if tqdm_is_disabled:
            if i % 1000 == 0:
                logger.info(f"Batch [{i+1}/{len(loader)}]: Saved {sum(n_ims)}, skipped {skipped_ims}")
        else:
            pbar.set_description(f"Loading and saving (saved {sum(n_ims)}, skipped {skipped_ims})")
    logger.success(f"Extracted {sum(n_ims)} images.")

View File

@@ -0,0 +1,303 @@
import os
import random
import sys
import numpy as np
import timm
import torch
from loguru import logger
from engine import (
_train,
setup_criteria_mixup,
setup_model_optim_sched_scaler,
setup_tracking_and_logging,
wandb_available,
)
from load_dataset import prepare_dataset
from models import load_pretrained, prepare_model
from utils import ddp_cleanup, ddp_setup, log_args, prep_kwargs, set_filter_warnings
def finetune(model, dataset, epochs, val_dataset=None, head_only=False, **kwargs):
    """Finetune a pretrained model on a given dataset for a specified number of epochs.

    Args:
        model (str): Path to the pretrained model state file (in .tar format).
        dataset (str): Name of the dataset to finetune on.
        epochs (int): Number of epochs to train for.
        val_dataset (str, optional): Name of the validation dataset. (Default value = None)
        head_only (bool, optional): Whether to train only the head of the model. Default: False.
        **kwargs (dict): Further arguments for model setup, training, evaluation,...

    Notes:
        This function assumes that the model was pretrained on a different dataset using a different set of hyperparameters.
        It fine-tunes the model on a new dataset by loading the pretrained weights and training for the specified number of
        epochs. The function supports distributed training using the PyTorch DistributedDataParallel module.
    """
    set_filter_warnings()
    # Add defaults & make keys properties
    args = prep_kwargs(kwargs)
    if val_dataset is None:
        val_dataset = dataset
    args.val_dataset = val_dataset
    args.dataset = dataset
    args.epochs = epochs
    args.distributed, device, world_size, rank, gpu_id = ddp_setup()
    args.world_size = world_size
    try:
        torch.cuda.set_device(device)
    except RuntimeError as e:
        # dump the SLURM/DDP environment to make device-mapping failures debuggable
        logger.error(
            f"Could not set device {device} as current device; "
            f"environment parameters: RANK={os.environ.get('RANK')}, LOCAL_RANK={os.environ.get('LOCAL_RANK')}, "
            f"WORLD_SIZE={os.environ.get('WORLD_SIZE')}; SLURM_STEP_GPUS={os.environ.get('SLURM_STEP_GPUS')}, "
            f"GPU_DEVICE_ORDINAL={os.environ.get('GPU_DEVICE_ORDINAL')}, "
            f"CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')}, "
            f"SLURM_STEP_NODELIST={os.environ.get('SLURM_STEP_NODELIST')}, "
            f"SLURMD_NODENAME={os.environ.get('SLURMD_NODENAME')}"
        )
        raise e
    # args.batch_size arrives as the global batch size; convert to per-process size
    args.batch_size = int(args.batch_size / world_size)
    if args.seed is not None:
        # fix the seed for reproducibility
        seed = args.seed + rank
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)
    # get the datasets & dataloaders
    # transform only contains resize & crop here; everything else is handled on the GPU / in the training loop
    train_loader, args.n_classes, args.ignore_index, args.multi_label, train_dali_server = prepare_dataset(
        dataset, args, rank=rank
    )
    val_loader, _val_classes, _, __, val_dali_server = prepare_dataset(val_dataset, args, train=False, rank=rank)
    assert (
        args.n_classes == _val_classes
    ), f"Training and validation datasets have different numbers of classes: {args.n_classes} vs {_val_classes}"
    # peek at the checkpoint args to recover the model name before the full load below
    save_state = torch.load(model, map_location="cpu")
    old_args = prep_kwargs(save_state["args"])
    parent_folder = os.path.dirname(model)
    args.model = old_args.model
    run_folder = setup_tracking_and_logging(args, rank)
    if rank == 0:
        # link this finetuning run to its pretraining run on disk
        if not os.path.exists(os.path.join(run_folder, "parent_run")):
            os.symlink(parent_folder, os.path.join(run_folder, "parent_run"), target_is_directory=True)
        logger.info(
            f"environment parameters: RANK={os.environ.get('RANK')}, LOCAL_RANK={os.environ.get('LOCAL_RANK')}, "
            f"WORLD_SIZE={os.environ.get('WORLD_SIZE')}; SLURM_STEP_GPUS={os.environ.get('SLURM_STEP_GPUS')}, "
            f"GPU_DEVICE_ORDINAL={os.environ.get('GPU_DEVICE_ORDINAL')}, "
            f"CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')}, "
            f"SLURM_STEP_NODELIST={os.environ.get('SLURM_STEP_NODELIST')}, "
            f"SLURMD_NODENAME={os.environ.get('SLURMD_NODENAME')}"
        )
    if args.seed:
        logger.info(f"setting manual seed '{seed}' (arg: {args.seed} + rank: {rank})")
    model, args, old_args, save_state = load_pretrained(model, args, new_dataset_params=True)
    # model_name = old_args.model
    if rank == 0:
        logger.info(f"The model was pretrained on {old_args.dataset} for {save_state['epoch']} epochs.")
    model, optimizer, scheduler, scaler = setup_model_optim_sched_scaler(
        model, device, epochs, args, head_only=head_only
    )
    # log all devices
    logger.info(f"training on {device} -> {torch.cuda.get_device_name(device) if args.device != 'cpu' else ''}")
    if rank == 0:
        logger.info(f"torch version {torch.__version__}")
        logger.info(f"timm version {timm.__version__}")
        logger.info(f"full set of old arguments: {old_args}")
        log_args(args)
    if args.seed:
        torch.manual_seed(seed)
    criterion, val_criterion, mixup = setup_criteria_mixup(args)
    if rank == 0:
        logger.info(f"Run info at: '{run_folder}'")
    res = _train(
        model,
        train_loader,
        optimizer,
        rank,
        epochs,
        device,
        mixup,
        criterion,
        world_size,
        scheduler,
        args,
        val_loader,
        val_criterion,
        run_folder,
        scaler,
        do_metrics_calculation=True,
        show_tqdm=args.tqdm,
        train_dali_server=train_dali_server,
        val_dali_server=val_dali_server,
    )
    if rank == 0:
        # NOTE(review): picks the alphabetically first "val/best_*" key, while
        # pretrain/continue_training take the first in dict order — confirm intended
        best_acc_key = sorted([key for key in res.keys() if key.startswith("val/best_")])[0]
        logger.info(
            f"Run '{run_folder.split(os.sep)[-1]}' is done. Top-1 validation accuracy: {res[best_acc_key] * 100:.2f}%"
        )
    ddp_cleanup(args=args, sync_old_wandb=wandb_available(), rank=rank)
def pretrain(model, dataset, epochs, val_dataset=None, **kwargs):
    """Train or pretrain a model.

    Args:
        model (str): Name of the model to train.
        dataset (str): Name of the dataset to train the model on.
        epochs (int): Number of training epochs.
        val_dataset (str, optional): Name of the validation dataset, by default None
        **kwargs (dict): Additional keyword arguments.

    Notes:
        This function sets up logger, prepares the model, and trains the model on the given dataset.
        Supports distributed training (DDP); per-rank seeding when ``args.seed`` is set.
    """
    set_filter_warnings()
    # Add defaults & make args properties
    args = prep_kwargs(kwargs)
    if val_dataset is None:
        val_dataset = dataset
    args.val_dataset = val_dataset
    args.dataset = dataset
    args.model = model
    args.epochs = epochs
    args.distributed, device, world_size, rank, gpu_id = ddp_setup(args.cuda)
    args.world_size = world_size
    # sleep(rank * 5)
    # logger.debug(f'running environment commands for rank {rank}')
    # os.system('env')
    # os.system('nvidia-smi')
    # sleep((world_size - rank) * 5)
    logger.debug(
        f"rank params: RANK={os.environ.get('RANK')}, LOCAL_RANK={os.environ.get('LOCAL_RANK')}, "
        f"WORLD_SIZE={os.environ.get('WORLD_SIZE')}; gpu params: "
        f"SLURM_STEP_GPUS={os.environ.get('SLURM_STEP_GPUS')}, "
        f"GPU_DEVICE_ORDINAL={os.environ.get('GPU_DEVICE_ORDINAL')}, "
        f"CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')}"
    )
    if args.cuda:
        try:
            torch.cuda.set_device(device)
        except RuntimeError as e:
            # dump the SLURM/DDP environment to make device-mapping failures debuggable
            logger.error(f"Could not set device {device} as current device: {e}")
            logger.error(
                f"environment parameters: RANK={os.environ.get('RANK')}, LOCAL_RANK={os.environ.get('LOCAL_RANK')}, "
                f"WORLD_SIZE={os.environ.get('WORLD_SIZE')}; SLURM_STEP_GPUS={os.environ.get('SLURM_STEP_GPUS')}, "
                f"GPU_DEVICE_ORDINAL={os.environ.get('GPU_DEVICE_ORDINAL')}, "
                f"CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')}, "
                f"SLURM_STEP_NODELIST={os.environ.get('SLURM_STEP_NODELIST')}, "
                f"SLURMD_NODENAME={os.environ.get('SLURMD_NODENAME')}"
            )
            raise e
    # args.batch_size arrives as the global batch size; convert to per-process size
    args.batch_size = int(args.batch_size / world_size)
    run_folder = setup_tracking_and_logging(args, rank)
    if rank % world_size == 0:  # NOTE(review): equals `rank == 0` for 0 <= rank < world_size — confirm
        logger.info(
            f"environment parameters: RANK={os.environ.get('RANK')}, LOCAL_RANK={os.environ.get('LOCAL_RANK')}, "
            f"WORLD_SIZE={os.environ.get('WORLD_SIZE')}; SLURM_STEP_GPUS={os.environ.get('SLURM_STEP_GPUS')}, "
            f"GPU_DEVICE_ORDINAL={os.environ.get('GPU_DEVICE_ORDINAL')}, "
            f"CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')}, "
            f"SLURM_STEP_NODELIST={os.environ.get('SLURM_STEP_NODELIST')}, "
            f"SLURMD_NODENAME={os.environ.get('SLURMD_NODENAME')}"
        )
    if args.seed is not None:
        # fix the seed for reproducibility
        seed = args.seed + rank
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)
        logger.info(f"setting manual seed '{seed}' (arg: {args.seed} + rank: {rank})")
    # get the datasets & dataloaders
    train_loader, args.n_classes, args.ignore_index, args.multi_label, train_dali_server = prepare_dataset(
        dataset, args, rank=rank
    )
    val_loader, _val_classes, _, __, val_dali_server = prepare_dataset(val_dataset, args, train=False, rank=rank)
    assert (
        args.n_classes == _val_classes
    ), f"Training and validation datasets have different numbers of classes: {args.n_classes} vs {_val_classes}"
    # setup model with amp & DDP
    if isinstance(model, str):
        if model.startswith("ViT") and "_" not in model:
            model += f"_{args.imsize}"  # ViT names encode the target image resolution
        model = prepare_model(model, args)
    # (a `model_name` local was computed here before; it was unused and raised
    # NameError whenever a ready-built module was passed instead of a name string)
    model, optimizer, scheduler, scaler = setup_model_optim_sched_scaler(model, device, epochs, args)
    # log all devices
    logger.info(f"training on {device} -> {torch.cuda.get_device_name(device) if device != 'cpu' else ''}")
    if rank == 0:
        logger.info(f"python version {sys.version}")
        logger.info(f"torch version {torch.__version__}")
        logger.info(f"timm version {timm.__version__}")
        log_args(args)
    if args.seed:
        torch.manual_seed(seed)
    criterion, val_criterion, mixup = setup_criteria_mixup(args)
    if rank == 0:
        logger.info(f"Run info at: '{run_folder}'")
    res = _train(
        model,
        train_loader,
        optimizer,
        rank,
        epochs,
        device,
        mixup,
        criterion,
        world_size,
        scheduler,
        args,
        val_loader,
        val_criterion,
        run_folder,
        scaler,
        do_metrics_calculation=True,
        show_tqdm=args.tqdm,
        train_dali_server=train_dali_server,
        val_dali_server=val_dali_server,
    )
    if rank == 0:
        best_acc_key = [key for key in res.keys() if key.startswith("val/best_")][0]
        logger.info(
            f"Run '{run_folder.split(os.sep)[-1]}' is done. Top-1 validation accuracy: {res[best_acc_key] * 100:.2f}%"
        )
    ddp_cleanup(args=args, sync_old_wandb=wandb_available(), rank=rank)

View File

@@ -0,0 +1,594 @@
"""Utils and small helper functions."""
import collections.abc
import json
import os
import shutil
import sys
import warnings
from dataclasses import dataclass
from itertools import repeat
from math import cos, pi, sqrt
import numpy as np
import torch
import torch.distributed as dist
from loguru import logger
from timm.data import Mixup
from timm.utils import NativeScaler, dispatch_clip_grad
from torch.nn.modules.loss import _WeightedLoss
from torch.utils.data import Dataset
from torchvision.transforms import transforms
import paths_config
from config import default_kwargs, get_default_kwargs # noqa: F401 # is used in prep_kwargs
class RepeatedDataset(Dataset):
    """Wrapper that makes every sample of a dataset appear ``num_repeats`` times in a row."""

    def __init__(self, dataset, num_repeats):
        """Create repeated dataset.

        Args:
            dataset (Dataset): dataset to repeat.
            num_repeats (int): number of repeats.
        """
        self.dataset = dataset
        self.num_repeats = num_repeats

    def __getitem__(self, idx):
        """Map the repeated index back onto the underlying dataset."""
        base_idx, _ = divmod(idx, self.num_repeats)
        return self.dataset[base_idx]

    def __len__(self):
        """Total length: repeat count times the underlying length."""
        return self.num_repeats * len(self.dataset)
@dataclass
class SchedulerArgs:
    """Class for scheduler arguments.

    Bundles the keyword arguments consumed by the learning-rate scheduler setup.
    """

    sched: str  # schedule type, e.g. "cosine" or "const"
    epochs: int  # total number of epochs the schedule spans
    min_lr: float  # lower bound for the learning rate
    warmup_lr: float  # (starting) learning rate during warmup
    warmup_epochs: int  # number of epochs reserved for warmup
    cooldown_epochs: int = 0  # extra epochs after the main schedule — NOTE(review): consumer semantics not visible here; confirm
def scheduler_function_factory(
    epochs, sched, warmup_epochs=0, lr=None, min_lr=0.0, warmup_sched=None, warmup_lr=None, offset=0, **kwargs
):
    """Create a scheduler factor function.

    Args:
        epochs (int): length of the full schedule
        sched (str): the learning rate schedule type ("cosine" or "const")
        warmup_epochs (int, optional): number of epochs reserved for warmup (Default value = 0)
        lr (float, optional): learning rate (has to be given, when warmup or min_lr are set) (Default value = None)
        min_lr (float, optional): minimum learning rate (Default value = 0.0)
        warmup_sched (str, optional): the type of schedule during warmup ("linear" or "const") (Default value = None)
        warmup_lr (float, optional): (starting) learning rate during warmup (Default value = None)
        offset (int, optional): offset for the schedule (to be the same as the timm scheduler) (Default value = 0)
        **kwargs: unused

    Returns:
        function: maps an epoch index to a multiplicative learning-rate factor

    Raises:
        NotImplementedError: for unknown ``sched`` / ``warmup_sched`` values.
    """
    sched = sched.lower()

    # default: no warmup scaling
    def warmup_f(ep):
        return 1.0

    if warmup_epochs > 0:
        assert warmup_lr is not None, "Need warmup_lr, but got None"
        warmup_lr_factor = warmup_lr / lr
        if warmup_sched == "linear":

            def warmup_f(ep):
                return warmup_lr_factor + (1 - warmup_lr_factor) * max(ep, 0.0) / warmup_epochs

        elif warmup_sched == "const":

            def warmup_f(ep):
                return warmup_lr_factor

        else:
            raise NotImplementedError(f"Warmup schedule {warmup_sched} not implemented")
    # the main schedule only spans the epochs after warmup
    epochs = epochs - warmup_epochs + offset
    if sched == "cosine":
        # cos from 0 to pi
        def main_f(ep):
            return cos(pi * ep / epochs) / 2 + 0.5

    elif sched == "const":

        def main_f(ep):
            return 1.0

    else:
        raise NotImplementedError(f"Schedule {sched} is not implemented.")
    # rescale and add min_lr; the guard keeps lr optional when min_lr == 0
    # (the unconditional `min_lr / lr` raised TypeError with the default lr=None)
    min_lr_fact = min_lr / lr if min_lr else 0.0

    def main_f_with_min_lr(ep):
        return (1 - min_lr_fact) * main_f(ep) + min_lr_fact

    return lambda ep: (
        warmup_f(ep + offset) if ep + offset < warmup_epochs else main_f_with_min_lr(ep + offset - warmup_epochs)
    )
class DotDict(dict):
    """Extension of a Python dictionary to access its keys using dot notation."""

    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

    def __getattr__(self, item, default=None):
        """Look up ``item`` as a key.

        Args:
            item: key
            default (optional): value returned for missing keys. Defaults to None.

        Returns:
            the stored value, or ``default`` when the key is absent
        """
        # Missing keys yield `default` instead of raising AttributeError.
        return self.get(item, default)
def prep_kwargs(kwargs):
    """Prepare the arguments and add defaults.

    Args:
        kwargs (dict[str, Any]): dict of kwargs

    Returns:
        DotDict: prepared kwargs
    """
    if "defaults" not in kwargs:
        kwargs["defaults"] = "DeiTIII"
    defaults = get_default_kwargs(kwargs["defaults"])
    # Defaults only fill in keys that are missing or explicitly None.
    for k, v in defaults.items():
        if k not in kwargs or kwargs[k] is None:
            kwargs[k] = v
    if "results_folder" not in kwargs:
        # BUGFIX: previously assigned to kwargs[var_name] with `var_name` undefined,
        # raising a NameError whenever results_folder was not supplied.
        kwargs["results_folder"] = paths_config.results_folder
    # Normalize away a single trailing slash.
    if kwargs["results_folder"].endswith("/"):
        kwargs["results_folder"] = kwargs["results_folder"][:-1]
    if "val_dataset" not in kwargs and "dataset" in kwargs:
        kwargs["val_dataset"] = kwargs["dataset"]
    return DotDict(kwargs)
def denormalize(x, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Invert the normlize operation.

    Args:
        x (torch.Tensor): images to de-normalize
        mean (tuple, optional): normalization mean. Defaults to (0.485, 0.456, 0.406).
        std (tuple, optional): normalization std. Defaults to (0.229, 0.224, 0.225).

    Returns:
        torch.Tensor: de-normalized images
    """
    # Normalize with (-mean/std, 1/std) is the exact inverse of Normalize(mean, std).
    inverse_mean = [-mu / sigma for mu, sigma in zip(mean, std)]
    inverse_std = [1 / sigma for sigma in std]
    return transforms.Normalize(mean=inverse_mean, std=inverse_std)(x)
def log_formatter(record):
    """Build a loguru-style format string for the given record.

    Args:
        record (dict): log record; reads record["extra"] and record["level"].no.

    Returns:
        str: format string for the sink.
    """
    extra = record["extra"]
    if "run_name" not in extra:
        # Fallback format used before the run context is bound.
        return (
            "<g>{time:YYYY-MM-DD HH:mm:ss.SSS}</g> <c>|</c> <y>name TBD</y> > <y>?</y>/<y>?</y> <c>|</c> <level>{level:"
            " <8}</level> <c>|</c> {message}\n"
        )
    epoch_part = "@ epoch <y>{extra[epoch]: >3}</y> " if "epoch" in extra else ""
    # Only show the code location for WARNING (30) and above.
    location_part = "<r>{name}</r>.<r>{function}</r>:<r>{line}</r> - " if record["level"].no >= 30 else ""
    return (
        "<g>{time:YYYY-MM-DD HH:mm:ss.SSS}</g> <c>|</c> <m>{extra[run_name]}</m> >"
        " <m>{extra[rank]}</m>/<m>{extra[world_size]}</m> "
        + epoch_part
        + "<c>|</c> <level>{level: <8}</level> <c>|</c> "
        + location_part
        + "{message}\n"
    )
def ddp_setup(use_cuda=True):
    """Set up the distributed environment.

    Args:
        use_cuda: whether to place this process on a CUDA device (Default value = True)

    Returns:
        tuple: A tuple containing the following elements:
            * bool: Whether the training is distributed.
            * torch.device: The device to use for distributed training.
            * int: The total number of processes in the distributed setup.
            * int: The global rank of the current process in the distributed setup.
            * int: The local rank of the current process on its node.

    Notes:
        The 'nccl' backend is used.
    """
    # Drop pre-existing loguru sinks so the custom formatter added below is the only one.
    logger.remove()
    # Launcher-provided environment (torchrun/SLURM); defaults give a single-process setup.
    rank = int(os.getenv("RANK", 0))
    local_rank = int(os.getenv("LOCAL_RANK", 0))
    num_gpus = int(os.getenv("WORLD_SIZE", 1))
    # Only treat the run as distributed when a launcher set RANK and there is more than one process.
    distributed = "RANK" in os.environ and num_gpus > 1
    logger.add(sys.stderr, format=log_formatter, colorize=True, enqueue=True)
    if distributed:
        assert use_cuda, "Only use distributed mode with cuda."
        try:
            dist.init_process_group("nccl")
        except ValueError as e:
            # Dump the full GPU/SLURM environment so launcher misconfigurations are debuggable.
            logger.critical(f"Value error while setting up nccl process group: {e}")
            logger.info(
                f" CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')},"
                f" SLURM_STEP_GPUS={os.environ.get('SLURM_STEP_GPUS')},"
                f" GPU_DEVICE_ORDINAL={os.environ.get('GPU_DEVICE_ORDINAL')} for process:"
                f" RANK={rank} (LOCAL_RANK={local_rank}) of WORLD_SIZE={num_gpus}. Shutting down now."
            )
            raise e
    assert torch.cuda.is_available() or not use_cuda, "CUDA is not available"
    # Sanity check (skipped when use_cuda is False): each SLURM GPU variable must list
    # exactly one entry per process in the world.
    assert (
        len(str(os.environ.get("SLURM_STEP_GPUS")).split(","))
        == len(str(os.environ.get("CUDA_VISIBLE_DEVICES")).split(","))
        == len(str(os.environ.get("GPU_DEVICE_ORDINAL")).split(","))
        == num_gpus
    ) or not use_cuda, (
        f"SLURM GPU setup is incorrect: CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')},"
        f" SLURM_STEP_GPUS={os.environ.get('SLURM_STEP_GPUS')},"
        f" GPU_DEVICE_ORDINAL={os.environ.get('GPU_DEVICE_ORDINAL')} for process:"
        f" RANK={rank} (LOCAL_RANK={local_rank}) of WORLD_SIZE={num_gpus}"
    )
    return distributed, torch.device(f"cuda:{local_rank}") if use_cuda else "cpu", num_gpus, rank, local_rank
def ddp_cleanup(args, sync_old_wandb=False, rank=0):
    """Clean the distributed setup after use.

    Note: this function does not return — it terminates the process via ``exit(0)``.

    Args:
        args (DotDict): arguments
        sync_old_wandb (bool, optional): Whether to sync and remove wandb runs older than 100 hours (>3 days). Defaults to False.
        rank (int, optional): The rank of the current process, so only one process syncs wandb. Defaults to 0.
    """
    if sync_old_wandb and rank == 0:
        os.system("wandb sync --clean --clean-old-hours 100 --clean-force")
    if args.distributed:
        # Block until every rank arrives, so no process tears the group down early.
        logger.info("waiting for all processes to finish")
        dist.barrier()
        dist.destroy_process_group()
    logger.info("exiting now")
    exit(0)
def set_filter_warnings():
    """Filter out some warnings to reduce spam."""
    spammy_message_patterns = [
        # DataLoader number-of-workers warning
        ".*worker processes in total. Our suggested max number of worker in current system is.*",
        # datadings "only accepts varargs" warning
        ".*only accepts varargs so.*",
        # gather() is_namedtuple deprecation warning
        ".*is_namedtuple is deprecated, please use the python checks instead.*",
        # meshgrid indexing-argument warning
        ".*in an upcoming release, it will be required to pass the indexing.*",
        # timm warning when overwriting models
        ".*UserWarning: Overwriting .*",
    ]
    for pattern in spammy_message_patterns:
        warnings.filterwarnings("ignore", pattern)
def remove_prefix(state_dict, prefix="module."):
    """Remove a prefix from the keys in a state dictionary.

    Args:
        state_dict (dict[str, Any]): The state dictionary to remove the prefix from.
        prefix (str, optional): The prefix to remove from the keys. Default is 'module.'.

    Returns:
        dict[str, Any]: A new dictionary with the prefix removed from the keys.

    Examples:
        >>> state_dict = {'module.layer1.weight': 1, 'module.layer1.bias': 2}
        >>> remove_prefix(state_dict)
        {'layer1.weight': 1, 'layer1.bias': 2}
    """
    cut = len(prefix)
    stripped = {}
    for key, value in state_dict.items():
        # Keys without the prefix are carried over unchanged.
        stripped[key[cut:] if key.startswith(prefix) else key] = value
    return stripped
def prime_factors(n):
    """Calculate the prime factors of a given integer.

    Args:
        n (int): The integer to find the prime factors of.

    Returns:
        list[int]: The prime factors of n (with multiplicity, in ascending order).
    """
    factors = []
    candidate = 2
    # Trial division: only candidates up to sqrt(remaining n) are needed.
    while candidate * candidate <= n:
        while n % candidate == 0:
            n //= candidate
            factors.append(candidate)
        candidate += 1
    # Whatever remains above 1 is itself prime.
    if n > 1:
        factors.append(n)
    return factors
def linear_regession(points):
    """Calculate a linear interpolation of the points.

    Args:
        points (dict[float, float]): points to interpolate in the format points[x] = y

    Returns:
        function: A function that interpolates the points.
    """
    n = len(points)
    xs = np.array(list(points.keys()))
    ys = np.array(list(points.values()))
    # Ordinary least squares, closed form: slope and intercept.
    slope = (n * (xs * ys).sum() - xs.sum() * ys.sum()) / (n * (xs * xs).sum() - xs.sum() ** 2)
    intercept = (ys.sum() - slope * xs.sum()) / n
    return lambda z: slope * z + intercept
def save_model_state(
    model_folder,
    epoch,
    args,
    model_state,
    regular_save=True,
    stats=None,
    val_accs=None,
    epoch_accs=None,
    additional_reason="",
    max_interm_ep_states=2,
    **kwargs,
):
    """Save the model state.

    Args:
        model_folder: Folder to save model in
        epoch: current epoch
        args: arguments to guide and save
        model_state: state of the model
        regular_save: Is this a regular or a special save? (Default value = True)
        stats: model stats to save (Default value = None)
        val_accs: model accuracy to save (Default value = None)
        epoch_accs: training accuracy to save (Default value = None)
        additional_reason: save reason; in case it's not just a regular save interval. Would be "top" or "final", for example. (Default value = "")
        max_interm_ep_states: Number of regular epoch states to keep (Default value = 2)
        **kwargs: Further arguments to save
    """
    # make args dict, not DotDict to be able to save it
    state = {"epoch": epoch, "model_state": model_state, "run_name": args.run_name, "args": dict(args)}
    if stats is None:
        stats = {}
    # Merge validation and training accuracies into one stats dict (later keys win).
    if val_accs is not None:
        stats = {**stats, **val_accs}
    if epoch_accs is not None:
        stats = {**stats, **epoch_accs}
    state["stats"] = stats
    # Any extra keyword arguments become additional top-level checkpoint entries.
    state = {**state, **kwargs}
    logger.info(f"saving model state at epoch {epoch} ({additional_reason})")
    regular_file_name = f"ep_{epoch}.pt"
    # Special saves (e.g. "top", "final") get their own file name; otherwise the epoch file name.
    save_name = additional_reason + ".pt" if len(additional_reason) > 0 else regular_file_name
    outfile = os.path.join(model_folder, save_name)
    torch.save(state, outfile)
    if len(additional_reason) > 0 and regular_save:
        # Special save that also falls on a regular interval: duplicate it under the epoch name.
        shutil.copyfile(outfile, os.path.join(model_folder, regular_file_name))
    # remove intermediate epoch states (all but the last max_interm_ep_states)
    if max_interm_ep_states > 0:
        epoch_states = [f for f in os.listdir(model_folder) if f.startswith("ep_") and f.endswith(".pt")]
        # Sort numerically by the epoch embedded in "ep_<epoch>.pt", not lexically.
        epoch_states = sorted(epoch_states, key=lambda x: int(x.split("_")[1].split(".")[0]))
        if len(epoch_states) > max_interm_ep_states:
            for f in epoch_states[:-max_interm_ep_states]:
                os.remove(os.path.join(model_folder, f))
                logger.debug(f"removed intermediate epoch state {f}")
def log_args(args, rank=0):
    """Log the full argument set, on the rank-0 process only.

    Args:
        args: arguments (convertible via dict(args))
        rank (int, optional): rank of the calling process. Defaults to 0.
    """
    if rank != 0:
        return
    logger.info("full set of arguments: " + json.dumps(dict(args), sort_keys=True))
class ScalerGradNormReturn(NativeScaler):
    """A wrapper around PyTorch's NativeScaler that returns the gradient norm."""

    def __str__(self) -> str:
        return f"{type(self).__name__}(_scaler: {self._scaler})"

    def __call__(self, loss, optimizer, clip_grad=None, clip_mode="norm", parameters=None, create_graph=False):
        """Scale and backpropagate through the loss tensor, and return the gradient norm of the selected parameters.

        Does an optimizer step.

        Args:
            loss (torch.Tensor): The loss tensor to scale and backpropagate through.
            optimizer (torch.optim.Optimizer): The optimizer to use for the optimization step.
            clip_grad (float, optional): The maximum allowed norm of the gradients. If None, no clipping is performed.
            clip_mode (str, optional): The mode used for clipping the gradients. Only used if `clip_grad` is not None. Possible values are 'norm'
                (clipping the norm of the gradients) and 'value' (clipping the value of the gradients). (default='norm')
            parameters (iterable[torch.nn.Parameter], optional): The parameters to compute the gradient norm for. If None, the gradient norm is not computed.
            create_graph (bool, optional): Whether to create a computation graph for computing second-order gradients. (default=False)

        Returns:
            float: The gradient norm of the selected parameters (or -1 if `parameters` is None).
        """
        self._scaler.scale(loss).backward(create_graph=create_graph)
        # Unscale up front so that both the reported norm and the clipping below
        # operate on the true (unscaled) gradients of the optimizer's params.
        self._scaler.unscale_(optimizer)
        if parameters is None:
            grad_norm = -1
        else:
            grad_list = [p.grad for p in parameters if p.grad is not None]
            ref_device = grad_list[0].device
            # Overall L2 norm = L2 norm of the per-tensor L2 norms.
            per_tensor_norms = [torch.norm(g.detach(), 2).to(ref_device) for g in grad_list]
            grad_norm = torch.norm(torch.stack(per_tensor_norms), 2)
        if clip_grad is not None:
            dispatch_clip_grad(parameters, clip_grad, mode=clip_mode)
        self._scaler.step(optimizer)
        self._scaler.update()
        return grad_norm
class NoScaler:
    """Dummy gradient scaler that doesn't scale gradients.

    This scaler performs a simple backward pass with the given loss, and then updates the model's parameters
    with the given optimizer. The resulting gradient norm is computed and returned.
    """

    def __str__(self) -> str:
        return f"{type(self).__name__}()"

    def __call__(self, loss, optimizer, parameters=None, **kwargs):
        """Perform backward pass with the given loss, updates the model's parameters with the given optimizer, and computes the resulting gradient norm.

        Args:
            loss (torch.Tensor): The loss tensor that the gradients will be computed from.
            optimizer (torch.optim.Optimizer): The optimizer that will be used to update the model's parameters.
            parameters (iterable[torch.Tensor], optional): An iterable of model parameters to compute gradients. If None, returns -1.
            **kwargs: Additional keyword arguments; nothing will be done with these.

        Returns:
            float: The gradient norm computed after the backward pass (or -1 if `parameters` is None).
        """
        loss.backward()
        if parameters is None:
            grad_norm = -1
        else:
            grad_list = [p.grad for p in parameters if p.grad is not None]
            ref_device = grad_list[0].device
            # Overall L2 norm = L2 norm of the per-tensor L2 norms.
            grad_norm = torch.norm(torch.stack([torch.norm(g.detach(), 2).to(ref_device) for g in grad_list]), 2)
        optimizer.step()
        return grad_norm
def get_cpu_name():
    """Get the name of the CPU."""
    with open("/proc/cpuinfo", "r") as cpuinfo:
        # The first "model name" entry holds the human-readable CPU name.
        names = (line.split(":")[1].strip() for line in cpuinfo if line.startswith("model name"))
        return next(names, "unknown")
# From PyTorch internals
def _ntuple(n):
"""Make a function to create n-tuples.
Args:
n (int): tuple length
Returns:
function: function to create n-tuples
"""
def parse(x):
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
return tuple(x)
return tuple(repeat(x, n))
return parse
to_2tuple = _ntuple(2)
def make_divisible(v, divisor=8, min_value=None, round_limit=0.9):
    """Round ``v`` to a multiple of ``divisor`` without shrinking it too much.

    This function is primarily used to ensure that the output of a layer
    is divisible by a certain number, typically to align with hardware
    optimizations or memory layouts.

    Args:
        v (int): The input value.
        divisor (int, optional): The divisor. Defaults to 8.
        min_value (int, optional): The minimum value to return. If falsy, defaults to the divisor.
        round_limit (float, optional): A threshold for rounding down. If the rounded result is less than round_limit * v, the next multiple of the divisor is returned instead. Defaults to 0.9.

    Returns:
        int: A multiple of ``divisor`` near ``v``, at least ``min_value`` and at least ``round_limit * v``.
    """
    min_value = min_value or divisor
    # Round v to the nearest multiple of divisor, but never below min_value.
    nearest_multiple = int(v + divisor / 2) // divisor * divisor
    result = max(min_value, nearest_multiple)
    # Make sure that round down does not go down by more than 10%.
    if result < round_limit * v:
        result += divisor
    return result
def grad_cam_reshape_transform(tensor):
    """Transform the tensor for Grad-CAM calculation.

    Args:
        tensor (torch.Tensor): input tensor of shape (batch, tokens, channels)

    Returns:
        torch.Tensor: reshaped (batch, channels, H, W) tensor without [CLS] token.
    """
    num_tokens = tensor.shape[1]
    # A non-square token count implies a leading [CLS] token, which is dropped.
    tokens = tensor if int(sqrt(num_tokens)) ** 2 == num_tokens else tensor[:, 1:]
    batch, n, channels = tokens.shape
    side = int(sqrt(n))
    spatial = tokens.reshape(batch, side, side, channels)
    # Bring the channels to the first dimension, like in CNNs.
    return spatial.transpose(2, 3).transpose(1, 2)

BIN
ForAug-supplementary.pdf Normal file

Binary file not shown.

BIN
ForAug.pdf Normal file

Binary file not shown.

111
aaai2026.bib Normal file
View File

@@ -0,0 +1,111 @@
@book{em:86,
editor = "Engelmore, Robert and Morgan, Anthony",
title = "Blackboard Systems",
year = 1986,
address = "Reading, Mass.",
publisher = "Addison-Wesley",
}
@inproceedings{c:83,
author = "Clancey, William J.",
year = 1983,
title = "{Communication, Simulation, and Intelligent
Agents: Implications of Personal Intelligent Machines
for Medical Education}",
booktitle="Proceedings of the Eighth International Joint Conference on Artificial Intelligence {(IJCAI-83)}",
pages = "556-560",
address = "Menlo Park, Calif",
publisher = "{IJCAI Organization}",
}
@inproceedings{c:84,
author = "Clancey, William J.",
year = 1984,
title = "{Classification Problem Solving}",
booktitle = "Proceedings of the Fourth National
Conference on Artificial Intelligence",
pages = "45-54",
address = "Menlo Park, Calif.",
publisher="AAAI Press",
}
@article{r:80,
author = {Robinson, Arthur L.},
title = {New Ways to Make Microcircuits Smaller},
volume = {208},
number = {4447},
pages = {1019--1022},
year = {1980},
doi = {10.1126/science.208.4447.1019},
publisher = {American Association for the Advancement of Science},
issn = {0036-8075},
URL = {https://science.sciencemag.org/content/208/4447/1019},
eprint = {https://science.sciencemag.org/content/208/4447/1019.full.pdf},
journal = {Science},
}
@article{r:80x,
author = "Robinson, Arthur L.",
year = 1980,
title = "{New Ways to Make Microcircuits Smaller---Duplicate Entry}",
journal = "Science",
volume = 208,
pages = "1019-1026",
}
@article{hcr:83,
title = {Strategic explanations for a diagnostic consultation system},
journal = {International Journal of Man-Machine Studies},
volume = {20},
number = {1},
pages = {3-19},
year = {1984},
issn = {0020-7373},
doi = {https://doi.org/10.1016/S0020-7373(84)80003-6},
url = {https://www.sciencedirect.com/science/article/pii/S0020737384800036},
author = {Diane Warner Hasling and William J. Clancey and Glenn Rennels},
  abstract = {This article examines the problem of automatic explanation of reasoning, especially as it relates to expert systems. By explanation we mean the ability of a program to discuss what it is doing in some understandable way. We first present a general framework in which to view explanation and review some of the research done in this area. We then focus on the explanation system for NEOMYCIN, a medical consultation program. A consultation program interactively helps a user to solve a problem. Our goal is to have NEOMYCIN explain its problem-solving strategies. An explanation of strategy describes the plan the program is using to reach a solution. Such an explanation is usually concrete, referring to aspects of the current problem situation. Abstract explanations articulate a general principle, which can be applied in different situations; such explanations are useful in teaching and in explaining by analogy. We describe the aspects of NEOMYCIN that make abstract strategic explanations possible—the representation of strategic knowledge explicitly and separately from domain knowledge— and demonstrate how this representation can be used to generate explanations.}
}
@article{hcrt:83,
author = "Hasling, Diane Warner and Clancey, William J. and Rennels, Glenn R. and Test, Thomas",
year = 1983,
title = "{Strategic Explanations in Consultation---Duplicate}",
journal = "The International Journal of Man-Machine Studies",
volume = 20,
number = 1,
pages = "3-19",
}
@techreport{r:86,
author = "Rice, James",
year = 1986,
title = "{Poligon: A System for Parallel Problem Solving}",
type = "Technical Report",
number = "KSL-86-19",
institution = "Dept.\ of Computer Science, Stanford Univ.",
}
@phdthesis{c:79,
author = "Clancey, William J.",
year = 1979,
title = "{Transfer of Rule-Based Expertise
through a Tutorial Dialogue}",
type = "{Ph.D.} diss.",
school = "Dept.\ of Computer Science, Stanford Univ.",
address = "Stanford, Calif.",
}
@unpublished{c:21,
author = "Clancey, William J.",
title = "{The Engineering of Qualitative Models}",
year = 2021,
note = "Forthcoming",
}
@misc{c:22,
title={Attention Is All You Need},
author={Ashish Vaswani and Noam Shazeer and Niki Parmar and Jakob Uszkoreit and Llion Jones and Aidan N. Gomez and Lukasz Kaiser and Illia Polosukhin},
year={2017},
eprint={1706.03762},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
@misc{c:23,
title = "Pluto: The 'Other' Red Planet",
author = "{NASA}",
howpublished = "\url{https://www.nasa.gov/nh/pluto-the-other-red-planet}",
year = 2015,
note = "Accessed: 2018-12-06"
}

File diff suppressed because it is too large Load Diff

315
aaai2026.sty Normal file
View File

@@ -0,0 +1,315 @@
\NeedsTeXFormat{LaTeX2e}%
\ProvidesPackage{aaai2026}[2026/04/29 AAAI 2026 Submission format]%
\def\year{2026}%
\typeout{Conference Style for AAAI for LaTeX 2e -- version for submission}%
%
\def\copyright@on{T}
\def\showauthors@on{T}
\def\nocopyright{\gdef\copyright@on{}} % Copyright notice is required for camera-ready only.
\DeclareOption{submission}{%
\gdef\copyright@on{}%
\gdef\showauthors@on{}%
\long\gdef\pdfinfo #1{\relax}%
}%
\DeclareOption{draft}{%
\gdef\copyright@on{}%
}%
\ProcessOptions\relax%
% WARNING: IF YOU ARE USING THIS STYLE SHEET FOR AN AAAI PUBLICATION, YOU
% MAY NOT MODIFY IT FOR ANY REASON. MODIFICATIONS (IN YOUR SOURCE
% OR IN THIS STYLE SHEET WILL RESULT IN REJECTION OF YOUR PAPER).
%
% WARNING: This style is NOT guaranteed to work. It is provided in the
% hope that it might make the preparation of papers easier, but this style
% file is provided "as is" without warranty of any kind, either express or
% implied, including but not limited to the implied warranties of
% merchantability, fitness for a particular purpose, or noninfringement.
% You use this style file at your own risk. Standard disclaimers apply.
% There are undoubtably bugs in this style. If you would like to submit
% bug fixes, improvements, etc. please let us know. Please use the contact form
% at www.aaai.org.
%
% Do not use this file unless you are an experienced LaTeX user.
%
% PHYSICAL PAGE LAYOUT
\setlength\topmargin{-0.25in} \setlength\oddsidemargin{-0.25in}
\setlength\textheight{9.0in} \setlength\textwidth{7.0in}
\setlength\columnsep{0.375in} \newlength\titlebox \setlength\titlebox{2.25in}
\setlength\headheight{0pt} \setlength\headsep{0pt}
%\setlength\footheight{0pt} \setlength\footskip{0pt}
\thispagestyle{empty} \pagestyle{empty}
\flushbottom \twocolumn \sloppy
% We're never going to need a table of contents, so just flush it to
% save space --- suggested by drstrip@sandia-2
\def\addcontentsline#1#2#3{}
% gf: PRINT COPYRIGHT NOTICE
\def\copyright@year{\number\year}
\def\copyright@text{Copyright \copyright\space \copyright@year,
Association for the Advancement of Artificial Intelligence (www.aaai.org).
All rights reserved.}
\def\copyrighttext#1{\gdef\copyright@on{T}\gdef\copyright@text{#1}}
\def\copyrightyear#1{\gdef\copyright@on{T}\gdef\copyright@year{#1}}
% gf: End changes for copyright notice (used in \maketitle, below)
% Title stuff, taken from deproc.
%
\def\maketitle{%
\par%
\begingroup % to make the footnote style local to the title
\def\thefootnote{\fnsymbol{footnote}}
\twocolumn[\@maketitle] \@thanks%
\endgroup%
% Insert copyright slug unless turned off
\if T\copyright@on\insert\footins{\noindent\footnotesize\copyright@text}\fi%
%
\setcounter{footnote}{0}%
\let\maketitle\relax%
\let\@maketitle\relax%
\gdef\@thanks{}%
\gdef\@author{}%
\gdef\@title{}%
\let\thanks\relax%
}%
\long\gdef\affiliations #1{ \def \affiliations_{\if T\showauthors@on#1\fi}}%
%
\def\@maketitle{%
\def\theauthors{\if T\showauthors@on\@author\else Anonymous submission\fi}
\newcounter{eqfn}\setcounter{eqfn}{0}%
\newsavebox{\titlearea}
\sbox{\titlearea}{
\let\footnote\relax\let\thanks\relax%
\setcounter{footnote}{0}%
\def\equalcontrib{%
\ifnum\value{eqfn}=0%
\footnote{These authors contributed equally.}%
\setcounter{eqfn}{\value{footnote}}%
\else%
\footnotemark[\value{eqfn}]%
\fi%
}%
\vbox{%
\hsize\textwidth%
\linewidth\hsize%
\vskip 0.625in minus 0.125in%
\centering%
{\LARGE\bf \@title \par}%
\vskip 0.1in plus 0.5fil minus 0.05in%
{\Large{\textbf{\theauthors\ifhmode\\\fi}}}%
\vskip .2em plus 0.25fil%
{\normalsize \affiliations_\ifhmode\\\fi}%
\vskip 1em plus 2fil%
}%
}%
%
\newlength\actualheight%
\settoheight{\actualheight}{\usebox{\titlearea}}%
\ifdim\actualheight>\titlebox%
\setlength{\titlebox}{\actualheight}%
\fi%
%
\vbox to \titlebox {%
\let\footnote\thanks\relax%
\setcounter{footnote}{0}%
\def\equalcontrib{%
\ifnum\value{eqfn}=0%
\footnote{These authors contributed equally.}%
\setcounter{eqfn}{\value{footnote}}%
\else%
\footnotemark[\value{eqfn}]%
\fi%
}%
\hsize\textwidth%
\linewidth\hsize%
\vskip 0.625in minus 0.125in%
\centering%
{\LARGE\bf \@title \par}%
\vskip 0.1in plus 0.5fil minus 0.05in%
{\Large{\textbf{\theauthors\ifhmode\\\fi}}}%
\vskip .2em plus 0.25fil%
{\normalsize \affiliations_\ifhmode\\\fi}%
\vskip 1em plus 2fil%
}%
}%
%
\renewenvironment{abstract}{%
\centerline{\bf Abstract}%
\vspace{0.5ex}%
\setlength{\leftmargini}{10pt}%
\begin{quote}%
\small%
}{%
\par%
\end{quote}%
\vskip 1ex%
}%
\newenvironment{links}{%
\newcommand{\link}[2]{\par\textbf{##1} --- \url{##2}}%
\setlength{\hangindent}{10pt}%
\setlength{\parskip}{2pt}%
\begin{flushleft}%
}{%
\end{flushleft}%
\vskip 1ex%
}%
% jsp added:
\def\pubnote#1{
\thispagestyle{myheadings}%
\pagestyle{myheadings}%
\markboth{#1}{#1}%
\setlength\headheight{10pt}%
\setlength\headsep{10pt}%
}%
%
% SECTIONS with less space
\def\section{\@startsection {section}{1}{\z@}{-2.0ex plus
-0.5ex minus -.2ex}{3pt plus 2pt minus 1pt}{\Large\bf\centering}}
\def\subsection{\@startsection{subsection}{2}{\z@}{-2.0ex plus
-0.5ex minus -.2ex}{3pt plus 2pt minus 1pt}{\large\bf\raggedright}}
\def\subsubsection{\@startsection{subparagraph}{3}{\z@}{-6pt plus
%%% DIEGO changed: 29/11/2009
%% 2pt minus 1pt}{-1em}{\normalsize\bf}}
-2pt minus -1pt}{-1em}{\normalsize\bf}}
%%% END changed
\renewcommand\paragraph{\@startsection{paragraph}{4}{\z@}{-6pt plus -2pt minus -1pt}{-1em}{\normalsize\bf}}%
\setcounter{secnumdepth}{0}
% add period to section (but not subsection) numbers, reduce space after
%\renewcommand{\thesection}
% {\arabic{section}.\hskip-0.6em}
%\renewcommand{\thesubsection}
% {\arabic{section}.\arabic{subsection}\hskip-0.6em}
% FOOTNOTES
\footnotesep 6.65pt %
\skip\footins 9pt plus 4pt minus 2pt
\def\footnoterule{\kern-3pt \hrule width 5pc \kern 2.6pt }
\setcounter{footnote}{0}
% LISTS AND PARAGRAPHS
\parindent 10pt
\topsep 4pt plus 1pt minus 2pt
\partopsep 1pt plus 0.5pt minus 0.5pt
\itemsep 0.5pt plus 1pt minus 0.5pt
\parsep 2pt plus 1pt minus 0.5pt
\leftmargin 10pt \leftmargini 13pt \leftmarginii 10pt \leftmarginiii 5pt \leftmarginiv 5pt \leftmarginv 5pt \leftmarginvi 5pt
\labelwidth\leftmargini\advance\labelwidth-\labelsep \labelsep 5pt
\def\@listi{\leftmargin\leftmargini}
\def\@listii{\leftmargin\leftmarginii
\labelwidth\leftmarginii\advance\labelwidth-\labelsep
\topsep 2pt plus 1pt minus 0.5pt
\parsep 1pt plus 0.5pt minus 0.5pt
\itemsep \parsep}
\def\@listiii{\leftmargin\leftmarginiii
\labelwidth\leftmarginiii\advance\labelwidth-\labelsep
\topsep 1pt plus 0.5pt minus 0.5pt
\parsep \z@
\partopsep 0.5pt plus 0pt minus 0.5pt
\itemsep \topsep}
\def\@listiv{\leftmargin\leftmarginiv
\labelwidth\leftmarginiv\advance\labelwidth-\labelsep}
\def\@listv{\leftmargin\leftmarginv
\labelwidth\leftmarginv\advance\labelwidth-\labelsep}
\def\@listvi{\leftmargin\leftmarginvi
\labelwidth\leftmarginvi\advance\labelwidth-\labelsep}
\abovedisplayskip 7pt plus2pt minus5pt%
\belowdisplayskip \abovedisplayskip
\abovedisplayshortskip 0pt plus3pt%
\belowdisplayshortskip 4pt plus3pt minus3pt%
% Less leading in most fonts (due to the narrow columns)
% The choices were between 1-pt and 1.5-pt leading
\def\normalsize{\@setfontsize\normalsize\@xpt{11}} % 10 point on 11
\def\small{\@setfontsize\small\@ixpt{10}} % 9 point on 10
\def\footnotesize{\@setfontsize\footnotesize\@ixpt{10}} % 9 point on 10
\def\scriptsize{\@setfontsize\scriptsize\@viipt{10}} % 7 point on 8
\def\tiny{\@setfontsize\tiny\@vipt{7}} % 6 point on 7
\def\large{\@setfontsize\large\@xipt{12}} % 11 point on 12
\def\Large{\@setfontsize\Large\@xiipt{14}} % 12 point on 14
\def\LARGE{\@setfontsize\LARGE\@xivpt{16}} % 14 point on 16
\def\huge{\@setfontsize\huge\@xviipt{20}} % 17 point on 20
\def\Huge{\@setfontsize\Huge\@xxpt{23}} % 20 point on 23
\AtBeginDocument{%
\@ifpackageloaded{natbib}%
{%
% When natbib is in use, set the proper style and fix a few things
\let\cite\citep
\let\shortcite\citeyearpar
\setcitestyle{aysep={}}
\setlength\bibhang{0pt}
\bibliographystyle{aaai2026}
}{}%
\@ifpackageloaded{hyperref}%
{%
\PackageError{aaai}{You must not use hyperref in AAAI papers.}{You (or one of the packages you imported) are importing the hyperref package, which is forbidden in AAAI papers. You must remove it from the paper to proceed.}
}{}%
\@ifpackageloaded{bbm}%
{%
\PackageError{aaai}{You must not use bbm package in AAAI papers because it introduces Type 3 fonts which are forbidden.}{See https://tex.stackexchange.com/questions/479160/a-replacement-to-mathbbm1-with-type-1-fonts for possible alternatives.}
}{}%
\@ifpackageloaded{authblk}%
{%
\PackageError{aaai}{Package authblk is forbbidden.}{Package authblk is forbbiden. You must find an alternative.}
}{}%
\@ifpackageloaded{balance}%
{%
\PackageError{aaai}{Package balance is forbbidden.}{Package balance is forbbiden. You must find an alternative.}
}{}%
\@ifpackageloaded{CJK}%
{%
\PackageError{aaai}{Package CJK is forbbidden.}{Package CJK is forbbiden. You must find an alternative.}
}{}%
\@ifpackageloaded{flushend}%
{%
\PackageError{aaai}{Package flushend is forbbidden.}{Package flushend is forbbiden. You must find an alternative.}
}{}%
\@ifpackageloaded{fontenc}%
{%
\PackageError{aaai}{Package fontenc is forbbidden.}{Package fontenc is forbbiden. You must find an alternative.}
}{}%
\@ifpackageloaded{fullpage}%
{%
\PackageError{aaai}{Package fullpage is forbbidden.}{Package fullpage is forbbiden. You must find an alternative.}
}{}%
\@ifpackageloaded{geometry}%
{%
\PackageError{aaai}{Package geometry is forbbidden.}{Package geometry is forbbiden. You must find an alternative.}
}{}%
\@ifpackageloaded{grffile}%
{%
\PackageError{aaai}{Package grffile is forbbidden.}{Package grffile is forbbiden. You must find an alternative.}
}{}%
\@ifpackageloaded{navigator}%
{%
\PackageError{aaai}{Package navigator is forbbidden.}{Package navigator is forbbiden. You must find an alternative.}
}{}%
\@ifpackageloaded{savetrees}%
{%
\PackageError{aaai}{Package savetrees is forbbidden.}{Package savetrees is forbbiden. You must find an alternative.}
}{}%
\@ifpackageloaded{setspace}%
{%
\PackageError{aaai}{Package setspace is forbbidden.}{Package setspace is forbbiden. You must find an alternative.}
}{}%
\@ifpackageloaded{stfloats}%
{%
\PackageError{aaai}{Package stfloats is forbbidden.}{Package stfloats is forbbiden. You must find an alternative.}
}{}%
\@ifpackageloaded{tabu}%
{%
\PackageError{aaai}{Package tabu is forbbidden.}{Package tabu is forbbiden. You must find an alternative.}
}{}%
\@ifpackageloaded{titlesec}%
{%
\PackageError{aaai}{Package titlesec is forbbidden.}{Package titlesec is forbbiden. You must find an alternative.}
}{}%
\@ifpackageloaded{tocbibind}%
{%
\PackageError{aaai}{Package tocbibind is forbbidden.}{Package tocbibind is forbbiden. You must find an alternative.}
}{}%
\@ifpackageloaded{ulem}%
{%
\PackageError{aaai}{Package ulem is forbbidden.}{Package ulem is forbbiden. You must find an alternative.}
}{}%
\@ifpackageloaded{wrapfig}%
{%
\PackageError{aaai}{Package wrapfig is forbbidden.}{Package wrapfig is forbbiden. You must find an alternative.}
}{}%
}
\let\endthebibliography=\endlist

View File

@@ -1,79 +0,0 @@
% ALGORITHM STYLE -- Released 8 April 1996
% for LaTeX-2e
% Copyright -- 1994 Peter Williams
% E-mail Peter.Williams@dsto.defence.gov.au
\NeedsTeXFormat{LaTeX2e}
\ProvidesPackage{algorithm}
\typeout{Document Style `algorithm' - floating environment}
\RequirePackage{float}
\RequirePackage{ifthen}
\newcommand{\ALG@within}{nothing}
\newboolean{ALG@within}
\setboolean{ALG@within}{false}
\newcommand{\ALG@floatstyle}{ruled}
\newcommand{\ALG@name}{Algorithm}
\newcommand{\listalgorithmname}{List of \ALG@name s}
% Declare Options
% first appearance
\DeclareOption{plain}{
\renewcommand{\ALG@floatstyle}{plain}
}
\DeclareOption{ruled}{
\renewcommand{\ALG@floatstyle}{ruled}
}
\DeclareOption{boxed}{
\renewcommand{\ALG@floatstyle}{boxed}
}
% then numbering convention
\DeclareOption{part}{
\renewcommand{\ALG@within}{part}
\setboolean{ALG@within}{true}
}
\DeclareOption{chapter}{
\renewcommand{\ALG@within}{chapter}
\setboolean{ALG@within}{true}
}
\DeclareOption{section}{
\renewcommand{\ALG@within}{section}
\setboolean{ALG@within}{true}
}
\DeclareOption{subsection}{
\renewcommand{\ALG@within}{subsection}
\setboolean{ALG@within}{true}
}
\DeclareOption{subsubsection}{
\renewcommand{\ALG@within}{subsubsection}
\setboolean{ALG@within}{true}
}
\DeclareOption{nothing}{
\renewcommand{\ALG@within}{nothing}
\setboolean{ALG@within}{true}
}
\DeclareOption*{\edef\ALG@name{\CurrentOption}}
% ALGORITHM
%
\ProcessOptions
\floatstyle{\ALG@floatstyle}
\ifthenelse{\boolean{ALG@within}}{
\ifthenelse{\equal{\ALG@within}{part}}
{\newfloat{algorithm}{htbp}{loa}[part]}{}
\ifthenelse{\equal{\ALG@within}{chapter}}
{\newfloat{algorithm}{htbp}{loa}[chapter]}{}
\ifthenelse{\equal{\ALG@within}{section}}
{\newfloat{algorithm}{htbp}{loa}[section]}{}
\ifthenelse{\equal{\ALG@within}{subsection}}
{\newfloat{algorithm}{htbp}{loa}[subsection]}{}
\ifthenelse{\equal{\ALG@within}{subsubsection}}
{\newfloat{algorithm}{htbp}{loa}[subsubsection]}{}
\ifthenelse{\equal{\ALG@within}{nothing}}
{\newfloat{algorithm}{htbp}{loa}}{}
}{
\newfloat{algorithm}{htbp}{loa}
}
\floatname{algorithm}{\ALG@name}
\newcommand{\listofalgorithms}{\listof{algorithm}{\listalgorithmname}}

View File

@@ -1,201 +0,0 @@
% ALGORITHMIC STYLE -- Released 8 APRIL 1996
% for LaTeX version 2e
% Copyright -- 1994 Peter Williams
% E-mail PeterWilliams@dsto.defence.gov.au
%
% Modified by Alex Smola (08/2000)
% E-mail Alex.Smola@anu.edu.au
%
\NeedsTeXFormat{LaTeX2e}
\ProvidesPackage{algorithmic}
\typeout{Document Style `algorithmic' - environment}
%
\RequirePackage{ifthen}
\RequirePackage{calc}
% ALC@noend is set by the [noend] option and suppresses the printed
% "end if"/"end for"/... lines.
\newboolean{ALC@noend}
\setboolean{ALC@noend}{false}
% ALC@line: current line number.  ALC@rem: lines elapsed since a line
% number was last printed.  \ALC@tlm: accumulated left margin of the
% current nesting level (used to keep line-number labels aligned).
\newcounter{ALC@line}
\newcounter{ALC@rem}
\newlength{\ALC@tlm}
%
\DeclareOption{noend}{\setboolean{ALC@noend}{true}}
%
\ProcessOptions
%
% ALGORITHMIC
% Keyword commands; \renewcommand any of these to localise or restyle
% the printed keywords.
\newcommand{\algorithmicrequire}{\textbf{Require:}}
\newcommand{\algorithmicensure}{\textbf{Ensure:}}
\newcommand{\algorithmiccomment}[1]{\{#1\}}
\newcommand{\algorithmicend}{\textbf{end}}
\newcommand{\algorithmicif}{\textbf{if}}
\newcommand{\algorithmicthen}{\textbf{then}}
\newcommand{\algorithmicelse}{\textbf{else}}
\newcommand{\algorithmicelsif}{\algorithmicelse\ \algorithmicif}
\newcommand{\algorithmicendif}{\algorithmicend\ \algorithmicif}
\newcommand{\algorithmicfor}{\textbf{for}}
\newcommand{\algorithmicforall}{\textbf{for all}}
\newcommand{\algorithmicdo}{\textbf{do}}
\newcommand{\algorithmicendfor}{\algorithmicend\ \algorithmicfor}
\newcommand{\algorithmicwhile}{\textbf{while}}
\newcommand{\algorithmicendwhile}{\algorithmicend\ \algorithmicwhile}
\newcommand{\algorithmicloop}{\textbf{loop}}
\newcommand{\algorithmicendloop}{\algorithmicend\ \algorithmicloop}
\newcommand{\algorithmicrepeat}{\textbf{repeat}}
\newcommand{\algorithmicuntil}{\textbf{until}}
%changed by alex smola
\newcommand{\algorithmicinput}{\textbf{input}}
\newcommand{\algorithmicoutput}{\textbf{output}}
\newcommand{\algorithmicset}{\textbf{set}}
\newcommand{\algorithmictrue}{\textbf{true}}
\newcommand{\algorithmicfalse}{\textbf{false}}
\newcommand{\algorithmicand}{\textbf{and\ }}
\newcommand{\algorithmicor}{\textbf{or\ }}
\newcommand{\algorithmicfunction}{\textbf{function}}
\newcommand{\algorithmicendfunction}{\algorithmicend\ \algorithmicfunction}
\newcommand{\algorithmicmain}{\textbf{main}}
\newcommand{\algorithmicendmain}{\algorithmicend\ \algorithmicmain}
%end changed by alex smola
% \ALC@item: a variant of the LaTeX kernel's \@item used for algorithm
% lines.  Beyond the kernel behaviour it backs the label box out by
% \ALC@tlm and then skips forward by the same amount, so line-number
% labels stay aligned at the left margin at any nesting depth.
% NOTE(review): body mirrors kernel list internals; do not reflow.
\def\ALC@item[#1]{%
\if@noparitem \@donoparitem
\else \if@inlabel \indent \par \fi
\ifhmode \unskip\unskip \par \fi
\if@newlist \if@nobreak \@nbitem \else
\addpenalty\@beginparpenalty
\addvspace\@topsep \addvspace{-\parskip}\fi
\else \addpenalty\@itempenalty \addvspace\itemsep
\fi
\global\@inlabeltrue
\fi
\everypar{\global\@minipagefalse\global\@newlistfalse
\if@inlabel\global\@inlabelfalse \hskip -\parindent \box\@labels
\penalty\z@ \fi
\everypar{}}\global\@nobreakfalse
\if@noitemarg \@noitemargfalse \if@nmbrlist \refstepcounter{\@listctr}\fi \fi
\sbox\@tempboxa{\makelabel{#1}}%
\global\setbox\@labels
\hbox{\unhbox\@labels \hskip \itemindent
\hskip -\labelwidth \hskip -\ALC@tlm
\ifdim \wd\@tempboxa >\labelwidth
\box\@tempboxa
\else \hbox to\labelwidth {\unhbox\@tempboxa}\fi
\hskip \ALC@tlm}\ignorespaces}
%
% The algorithmic environment proper.  The optional argument is the
% line-numbering interval: a number is printed whenever ALC@rem returns
% to 0, i.e. on every #1-th line; the default 0 disables numbering.
\newenvironment{algorithmic}[1][0]{
\let\@item\ALC@item
% Print the line number (footnotesize) only when ALC@rem is 0.
\newcommand{\ALC@lno}{%
\ifthenelse{\equal{\arabic{ALC@rem}}{0}}
{{\footnotesize \arabic{ALC@line}:}}{}%
}
% All deeper list levels reuse the first-level layout.
\let\@listii\@listi
\let\@listiii\@listi
\let\@listiv\@listi
\let\@listv\@listi
\let\@listvi\@listi
\let\@listvii\@listi
% ALC@g: generic indented block; each nesting level adds 1em to the
% left margin and records it in \ALC@tlm.
\newenvironment{ALC@g}{
\begin{list}{\ALC@lno}{ \itemsep\z@ \itemindent\z@
\listparindent\z@ \rightmargin\z@
\topsep\z@ \partopsep\z@ \parskip\z@\parsep\z@
\leftmargin 1em
\addtolength{\ALC@tlm}{\leftmargin}
}
}
{\end{list}}
% Start a new algorithm line: advance both counters and reset ALC@rem
% once it reaches the numbering interval #1.
\newcommand{\ALC@it}{\addtocounter{ALC@line}{1}\addtocounter{ALC@rem}{1}\ifthenelse{\equal{\arabic{ALC@rem}}{#1}}{\setcounter{ALC@rem}{0}}{}\item}
% Typeset a trailing comment unless the argument is the sentinel
% `default' (the marker for an omitted optional argument).
\newcommand{\ALC@com}[1]{\ifthenelse{\equal{##1}{default}}%
{}{\ \algorithmiccomment{##1}}}
\newcommand{\REQUIRE}{\item[\algorithmicrequire]}
\newcommand{\ENSURE}{\item[\algorithmicensure]}
\newcommand{\STATE}{\ALC@it}
\newcommand{\COMMENT}[1]{\algorithmiccomment{##1}}
%changes by alex smola
\newcommand{\INPUT}{\item[\algorithmicinput]}
\newcommand{\OUTPUT}{\item[\algorithmicoutput]}
\newcommand{\SET}{\item[\algorithmicset]}
% \newcommand{\TRUE}{\algorithmictrue}
% \newcommand{\FALSE}{\algorithmicfalse}
\newcommand{\AND}{\algorithmicand}
\newcommand{\OR}{\algorithmicor}
\newenvironment{ALC@func}{\begin{ALC@g}}{\end{ALC@g}}
\newenvironment{ALC@main}{\begin{ALC@g}}{\end{ALC@g}}
%end changes by alex smola
% One indented block type per construct; all are aliases of ALC@g.
\newenvironment{ALC@if}{\begin{ALC@g}}{\end{ALC@g}}
\newenvironment{ALC@for}{\begin{ALC@g}}{\end{ALC@g}}
\newenvironment{ALC@whl}{\begin{ALC@g}}{\end{ALC@g}}
\newenvironment{ALC@loop}{\begin{ALC@g}}{\end{ALC@g}}
\newenvironment{ALC@rpt}{\begin{ALC@g}}{\end{ALC@g}}
% Make \\ break a line via the kernel's \@centercr.
\renewcommand{\\}{\@centercr}
% Construct commands: each prints its keyword line (with optional
% inline comment) and opens the matching indented block.
\newcommand{\IF}[2][default]{\ALC@it\algorithmicif\ ##2\ \algorithmicthen%
\ALC@com{##1}\begin{ALC@if}}
\newcommand{\SHORTIF}[2]{\ALC@it\algorithmicif\ ##1\
\algorithmicthen\ {##2}}
\newcommand{\ELSE}[1][default]{\end{ALC@if}\ALC@it\algorithmicelse%
\ALC@com{##1}\begin{ALC@if}}
\newcommand{\ELSIF}[2][default]%
{\end{ALC@if}\ALC@it\algorithmicelsif\ ##2\ \algorithmicthen%
\ALC@com{##1}\begin{ALC@if}}
\newcommand{\FOR}[2][default]{\ALC@it\algorithmicfor\ ##2\ \algorithmicdo%
\ALC@com{##1}\begin{ALC@for}}
\newcommand{\FORALL}[2][default]{\ALC@it\algorithmicforall\ ##2\ %
\algorithmicdo%
\ALC@com{##1}\begin{ALC@for}}
\newcommand{\SHORTFORALL}[2]{\ALC@it\algorithmicforall\ ##1\ %
\algorithmicdo\ {##2}}
\newcommand{\WHILE}[2][default]{\ALC@it\algorithmicwhile\ ##2\ %
\algorithmicdo%
\ALC@com{##1}\begin{ALC@whl}}
\newcommand{\LOOP}[1][default]{\ALC@it\algorithmicloop%
\ALC@com{##1}\begin{ALC@loop}}
%changed by alex smola
\newcommand{\FUNCTION}[2][default]{\ALC@it\algorithmicfunction\ ##2\ %
\ALC@com{##1}\begin{ALC@func}}
\newcommand{\MAIN}[2][default]{\ALC@it\algorithmicmain\ ##2\ %
\ALC@com{##1}\begin{ALC@main}}
%end changed by alex smola
\newcommand{\REPEAT}[1][default]{\ALC@it\algorithmicrepeat%
\ALC@com{##1}\begin{ALC@rpt}}
\newcommand{\UNTIL}[1]{\end{ALC@rpt}\ALC@it\algorithmicuntil\ ##1}
% With [noend] the END... commands only close their block; otherwise
% they also print the matching "end ..." keyword line.
\ifthenelse{\boolean{ALC@noend}}{
\newcommand{\ENDIF}{\end{ALC@if}}
\newcommand{\ENDFOR}{\end{ALC@for}}
\newcommand{\ENDWHILE}{\end{ALC@whl}}
\newcommand{\ENDLOOP}{\end{ALC@loop}}
\newcommand{\ENDFUNCTION}{\end{ALC@func}}
\newcommand{\ENDMAIN}{\end{ALC@main}}
}{
\newcommand{\ENDIF}{\end{ALC@if}\ALC@it\algorithmicendif}
\newcommand{\ENDFOR}{\end{ALC@for}\ALC@it\algorithmicendfor}
\newcommand{\ENDWHILE}{\end{ALC@whl}\ALC@it\algorithmicendwhile}
\newcommand{\ENDLOOP}{\end{ALC@loop}\ALC@it\algorithmicendloop}
\newcommand{\ENDFUNCTION}{\end{ALC@func}\ALC@it\algorithmicendfunction}
\newcommand{\ENDMAIN}{\end{ALC@main}\ALC@it\algorithmicendmain}
}
% Neutralise the kernel's "too deeply nested" error (\@toodeep).
\renewcommand{\@toodeep}{}
% The outer list: reset the counters and choose a label width wide
% enough for line numbers when numbering is enabled (#1 nonzero).
\begin{list}{\ALC@lno}{\setcounter{ALC@line}{0}\setcounter{ALC@rem}{0}%
\itemsep\z@ \itemindent\z@ \listparindent\z@%
\partopsep\z@ \parskip\z@ \parsep\z@%
\labelsep 0.5em \topsep 0.2em%
\ifthenelse{\equal{#1}{0}}
{\labelwidth 0.5em }
{\labelwidth 1.2em }
\leftmargin\labelwidth \addtolength{\leftmargin}{\labelsep}
\ALC@tlm\labelsep
}
}
{\end{list}}

View File

@@ -0,0 +1,817 @@
%File: anonymous-submission-latex-2026.tex
\documentclass[letterpaper]{article} % DO NOT CHANGE THIS
\usepackage[submission]{aaai2026} % DO NOT CHANGE THIS
\usepackage{times} % DO NOT CHANGE THIS
\usepackage{helvet} % DO NOT CHANGE THIS
\usepackage{courier} % DO NOT CHANGE THIS
\usepackage[hyphens]{url} % DO NOT CHANGE THIS
\usepackage{graphicx} % DO NOT CHANGE THIS
\urlstyle{rm} % DO NOT CHANGE THIS
\def\UrlFont{\rm} % DO NOT CHANGE THIS
\usepackage{natbib} % DO NOT CHANGE THIS AND DO NOT ADD ANY OPTIONS TO IT
\usepackage{caption} % DO NOT CHANGE THIS AND DO NOT ADD ANY OPTIONS TO IT
\frenchspacing % DO NOT CHANGE THIS
\setlength{\pdfpagewidth}{8.5in} % DO NOT CHANGE THIS
\setlength{\pdfpageheight}{11in} % DO NOT CHANGE THIS
%
% These are recommended to typeset algorithms but not required. See the subsubsection on algorithms. Remove them if you don't have algorithms in your paper.
\usepackage{algorithm}
\usepackage{algorithmic}
%
% These are recommended to typeset listings but not required. See the subsubsection on listings. Remove this block if you don't have listings in your paper.
\usepackage{newfloat}
\usepackage{listings}
\DeclareCaptionStyle{ruled}{labelfont=normalfont,labelsep=colon,strut=off} % DO NOT CHANGE THIS
\lstset{%
basicstyle={\footnotesize\ttfamily},% footnotesize acceptable for monospace
numbers=left,numberstyle=\footnotesize,xleftmargin=2em,% show line numbers, remove this entire line if you don't want the numbers.
aboveskip=0pt,belowskip=0pt,%
showstringspaces=false,tabsize=2,breaklines=true}
\floatstyle{ruled}
\newfloat{listing}{tb}{lst}{}
\floatname{listing}{Listing}
%
% Keep the \pdfinfo as shown here. There's no need
% for you to add the /Title and /Author tags.
\pdfinfo{
/TemplateVersion (2026.1)
}
% DISALLOWED PACKAGES
% \usepackage{authblk} -- This package is specifically forbidden
% \usepackage{balance} -- This package is specifically forbidden
% \usepackage{color (if used in text)
% \usepackage{CJK} -- This package is specifically forbidden
% \usepackage{float} -- This package is specifically forbidden
% \usepackage{flushend} -- This package is specifically forbidden
% \usepackage{fontenc} -- This package is specifically forbidden
% \usepackage{fullpage} -- This package is specifically forbidden
% \usepackage{geometry} -- This package is specifically forbidden
% \usepackage{grffile} -- This package is specifically forbidden
% \usepackage{hyperref} -- This package is specifically forbidden
% \usepackage{navigator} -- This package is specifically forbidden
% (or any other package that embeds links such as navigator or hyperref)
% \indentfirst} -- This package is specifically forbidden
% \layout} -- This package is specifically forbidden
% \multicol} -- This package is specifically forbidden
% \nameref} -- This package is specifically forbidden
% \usepackage{savetrees} -- This package is specifically forbidden
% \usepackage{setspace} -- This package is specifically forbidden
% \usepackage{stfloats} -- This package is specifically forbidden
% \usepackage{tabu} -- This package is specifically forbidden
% \usepackage{titlesec} -- This package is specifically forbidden
% \usepackage{tocbibind} -- This package is specifically forbidden
% \usepackage{ulem} -- This package is specifically forbidden
% \usepackage{wrapfig} -- This package is specifically forbidden
% DISALLOWED COMMANDS
% \nocopyright -- Your paper will not be published if you use this command
% \addtolength -- This command may not be used
% \balance -- This command may not be used
% \baselinestretch -- Your paper will not be published if you use this command
% \clearpage -- No page breaks of any kind may be used for the final version of your paper
% \columnsep -- This command may not be used
% \newpage -- No page breaks of any kind may be used for the final version of your paper
% \pagebreak -- No page breaks of any kind may be used for the final version of your paper
% \pagestyle -- This command may not be used
% \tiny -- This is not an acceptable font size.
% \vspace{- -- No negative value may be used in proximity of a caption, figure, table, section, subsection, subsubsection, or reference
% \vskip{- -- No negative value may be used to alter spacing above or below a caption, figure, table, section, subsection, subsubsection, or reference
\setcounter{secnumdepth}{0} %May be changed to 1 or 2 if section numbers are desired.
% The file aaai2026.sty is the style file for AAAI Press
% proceedings, working notes, and technical reports.
%
% Title
% Your title must be in mixed case, not sentence case.
% That means all verbs (including short verbs like be, is, using,and go),
% nouns, adverbs, adjectives should be capitalized, including both words in hyphenated terms, while
% articles, conjunctions, and prepositions are lower case unless they
% directly follow a colon or long dash
\title{AAAI Press Anonymous Submission\\Instructions for Authors Using \LaTeX{}}
\author{
%Authors
% All authors must be in the same font size and format.
Written by AAAI Press Staff\textsuperscript{\rm 1}\thanks{With help from the AAAI Publications Committee.}\\
AAAI Style Contributions by Pater Patel Schneider,
Sunil Issar,\\
J. Scott Penberthy,
George Ferguson,
Hans Guesgen,
Francisco Cruz\equalcontrib,
Marc Pujol-Gonzalez\equalcontrib
}
\affiliations{
%Affiliations
\textsuperscript{\rm 1}Association for the Advancement of Artificial Intelligence\\
% If you have multiple authors and multiple affiliations
% use superscripts in text and roman font to identify them.
% For example,
% Sunil Issar\textsuperscript{\rm 2},
% J. Scott Penberthy\textsuperscript{\rm 3},
% George Ferguson\textsuperscript{\rm 4},
% Hans Guesgen\textsuperscript{\rm 5}
% Note that the comma should be placed after the superscript
1101 Pennsylvania Ave, NW Suite 300\\
Washington, DC 20004 USA\\
% email address must be in roman text type, not monospace or sans serif
proceedings-questions@aaai.org
%
% See more examples next
}
%Example, Single Author, ->> remove \iffalse,\fi and place them surrounding AAAI title to use it
\iffalse
\title{My Publication Title --- Single Author}
\author {
Author Name
}
\affiliations{
Affiliation\\
Affiliation Line 2\\
name@example.com
}
\fi
\iffalse
%Example, Multiple Authors, ->> remove \iffalse,\fi and place them surrounding AAAI title to use it
\title{My Publication Title --- Multiple Authors}
\author {
% Authors
First Author Name\textsuperscript{\rm 1},
Second Author Name\textsuperscript{\rm 2},
Third Author Name\textsuperscript{\rm 1}
}
\affiliations {
% Affiliations
\textsuperscript{\rm 1}Affiliation 1\\
\textsuperscript{\rm 2}Affiliation 2\\
firstAuthor@affiliation1.com, secondAuthor@affilation2.com, thirdAuthor@affiliation1.com
}
\fi
% REMOVE THIS: bibentry
% This is only needed to show inline citations in the guidelines document. You should not need it and can safely delete it.
\usepackage{bibentry}
% END REMOVE bibentry
\begin{document}
\maketitle
\begin{abstract}
AAAI creates proceedings, working notes, and technical reports directly from electronic source furnished by the authors. To ensure that all papers in the publication have a uniform appearance, authors must adhere to the following instructions.
\end{abstract}
% Uncomment the following to link to your code, datasets, an extended version or similar.
% You must keep this block between (not within) the abstract and the main body of the paper.
% \begin{links}
% \link{Code}{https://aaai.org/example/code}
% \link{Datasets}{https://aaai.org/example/datasets}
% \link{Extended version}{https://aaai.org/example/extended-version}
% \end{links}
\section{Preparing an Anonymous Submission}
This document details the formatting requirements for anonymous submissions. The requirements are the same as for camera ready papers but with a few notable differences:
\begin{itemize}
\item Anonymous submissions must not include the author names and affiliations. Write ``Anonymous Submission'' as the ``sole author'' and leave the affiliations empty.
\item The PDF document's metadata should be cleared with a metadata-cleaning tool before submitting it. This is to prevent leaked information from revealing your identity.
\item References must be anonymized whenever the reader can infer that they are to the authors' previous work.
\item AAAI's copyright notice should not be included as a footer in the first page.
\item Only the PDF version is required at this stage. No source versions will be requested, nor any copyright transfer form.
\end{itemize}
You can remove the copyright notice and ensure that your names aren't shown by including \texttt{submission} option when loading the \texttt{aaai2026} package:
\begin{quote}\begin{scriptsize}\begin{verbatim}
\documentclass[letterpaper]{article}
\usepackage[submission]{aaai2026}
\end{verbatim}\end{scriptsize}\end{quote}
The remainder of this document is the original camera-ready instructions. Any contradiction of the above points ought to be ignored while preparing anonymous submissions.
\section{Camera-Ready Guidelines}
Congratulations on having a paper selected for inclusion in an AAAI Press proceedings or technical report! This document details the requirements necessary to get your accepted paper published using PDF\LaTeX{}. If you are using Microsoft Word, instructions are provided in a different document. AAAI Press does not support any other formatting software.
The instructions herein are provided as a general guide for experienced \LaTeX{} users. If you do not know how to use \LaTeX{}, please obtain assistance locally. AAAI cannot provide you with support and the accompanying style files are \textbf{not} guaranteed to work. If the results you obtain are not in accordance with the specifications you received, you must correct your source file to achieve the correct result.
These instructions are generic. Consequently, they do not include specific dates, page charges, and so forth. Please consult your specific written conference instructions for details regarding your submission. Please review the entire document for specific instructions that might apply to your particular situation. All authors must comply with the following:
\begin{itemize}
\item You must use the 2026 AAAI Press \LaTeX{} style file and the aaai2026.bst bibliography style files, which are located in the 2026 AAAI Author Kit (aaai2026.sty, aaai2026.bst).
\item You must complete, sign, and return by the deadline the AAAI copyright form (unless directed by AAAI Press to use the AAAI Distribution License instead).
\item You must read and format your paper source and PDF according to the formatting instructions for authors.
\item You must submit your electronic files and abstract using our electronic submission form \textbf{on time.}
\item You must pay any required page or formatting charges to AAAI Press so that they are received by the deadline.
\item You must check your paper before submitting it, ensuring that it compiles without error, and complies with the guidelines found in the AAAI Author Kit.
\end{itemize}
\section{Copyright}
All papers submitted for publication by AAAI Press must be accompanied by a valid signed copyright form. They must also contain the AAAI copyright notice at the bottom of the first page of the paper. There are no exceptions to these requirements. If you fail to provide us with a signed copyright form or disable the copyright notice, we will be unable to publish your paper. There are \textbf{no exceptions} to this policy. You will find a PDF version of the AAAI copyright form in the AAAI AuthorKit. Please see the specific instructions for your conference for submission details.
\section{Formatting Requirements in Brief}
We need source and PDF files that can be used in a variety of ways and can be output on a variety of devices. The design and appearance of the paper is strictly governed by the aaai style file (aaai2026.sty).
\textbf{You must not make any changes to the aaai style file, nor use any commands, packages, style files, or macros within your own paper that alter that design, including, but not limited to spacing, floats, margins, fonts, font size, and appearance.} AAAI imposes requirements on your source and PDF files that must be followed. Most of these requirements are based on our efforts to standardize conference manuscript properties and layout. All papers submitted to AAAI for publication will be recompiled for standardization purposes. Consequently, every paper submission must comply with the following requirements:
\begin{itemize}
\item Your .tex file must compile in PDF\LaTeX{} --- (you may not include .ps or .eps figure files.)
\item All fonts must be embedded in the PDF file --- including your figures.
\item Modifications to the style file, whether directly or via commands in your document may not ever be made, most especially when made in an effort to avoid extra page charges or make your paper fit in a specific number of pages.
\item No type 3 fonts may be used (even in illustrations).
\item You may not alter the spacing above and below captions, figures, headings, and subheadings.
\item You may not alter the font sizes of text elements, footnotes, heading elements, captions, or title information (for references and mathematics, please see the limited exceptions provided herein).
\item You may not alter the line spacing of text.
\item Your title must follow Title Case capitalization rules (not sentence case).
\item \LaTeX{} documents must use the Times or Nimbus font package (you may not use Computer Modern for the text of your paper).
\item No \LaTeX{} 209 documents may be used or submitted.
\item Your source must not require use of fonts for non-Roman alphabets within the text itself. If your paper includes symbols in other languages (such as, but not limited to, Arabic, Chinese, Hebrew, Japanese, Thai, Russian and other Cyrillic languages), you must restrict their use to bit-mapped figures. Fonts that require non-English language support (CID and Identity-H) must be converted to outlines or 300 dpi bitmap or removed from the document (even if they are in a graphics file embedded in the document).
\item Two-column format in AAAI style is required for all papers.
\item The paper size for final submission must be US letter without exception.
\item The source file must exactly match the PDF.
\item The document margins may not be exceeded (no overfull boxes).
\item The number of pages and the file size must be as specified for your event.
\item No document may be password protected.
\item Neither the PDFs nor the source may contain any embedded links or bookmarks (no hyperref or navigator packages).
\item Your source and PDF must not have any page numbers, footers, or headers (no pagestyle commands).
\item Your PDF must be compatible with Acrobat 5 or higher.
\item Your \LaTeX{} source file (excluding references) must consist of a \textbf{single} file (use of the ``input'' command is not allowed).
\item Your graphics must be sized appropriately outside of \LaTeX{} (do not use the ``clip'' or ``trim'' commands).
\end{itemize}
If you do not follow these requirements, your paper will be returned to you to correct the deficiencies.
\section{What Files to Submit}
You must submit the following items to ensure that your paper is published:
\begin{itemize}
\item A fully-compliant PDF file.
\item Your \LaTeX{} source file submitted as a \textbf{single} .tex file (do not use the ``input" command to include sections of your paper --- every section must be in the single source file). (The only allowable exception is .bib file, which should be included separately).
\item The bibliography (.bib) file(s).
\item Your source must compile on our system, which includes only standard \LaTeX{} 2020 TeXLive support files.
\item Only the graphics files used in compiling paper.
\item The \LaTeX{}-generated files (e.g. .aux, .bbl file, PDF, etc.).
\end{itemize}
Your \LaTeX{} source will be reviewed and recompiled on our system (if it does not compile, your paper will be returned to you). \textbf{Do not submit your source in multiple text files.} Your single \LaTeX{} source file must include all your text, your bibliography (formatted using aaai2026.bst), and any custom macros.
Your files should work without any supporting files (other than the program itself) on any computer with a standard \LaTeX{} distribution.
\textbf{Do not send files that are not actually used in the paper.} Avoid including any files not needed for compiling your paper, including, for example, this instructions file, unused graphics files, style files, additional material sent for the purpose of the paper review, intermediate build files and so forth.
\textbf{Obsolete style files.} The commands for some common packages (such as some used for algorithms), may have changed. Please be certain that you are not compiling your paper using old or obsolete style files.
\textbf{Final Archive.} Place your source files in a single archive which should be compressed using .zip. The final file size may not exceed 10 MB.
Name your source file with the last (family) name of the first author, even if that is not you.
\section{Using \LaTeX{} to Format Your Paper}
The latest version of the AAAI style file is available on AAAI's website. Download this file and place it in the \TeX\ search path. Placing it in the same directory as the paper should also work. You must download the latest version of the complete AAAI Author Kit so that you will have the latest instruction set and style file.
\subsection{Document Preamble}
In the \LaTeX{} source for your paper, you \textbf{must} place the following lines as shown in the example in this subsection. This command set-up is for three authors. Add or subtract author and address lines as necessary, and uncomment the portions that apply to you. In most instances, this is all you need to do to format your paper in the Times font. The helvet package will cause Helvetica to be used for sans serif. These files are part of the PSNFSS2e package, which is freely available from many Internet sites (and is often part of a standard installation).
Leave the setcounter for section number depth commented out and set at 0 unless you want to add section numbers to your paper. If you do add section numbers, you must uncomment this line and change the number to 1 (for section numbers), or 2 (for section and subsection numbers). The style file will not work properly with numbering of subsubsections, so do not use a number higher than 2.
\subsubsection{The Following Must Appear in Your Preamble}
\begin{quote}
\begin{scriptsize}\begin{verbatim}
\documentclass[letterpaper]{article}
% DO NOT CHANGE THIS
\usepackage[submission]{aaai2026} % DO NOT CHANGE THIS
\usepackage{times} % DO NOT CHANGE THIS
\usepackage{helvet} % DO NOT CHANGE THIS
\usepackage{courier} % DO NOT CHANGE THIS
\usepackage[hyphens]{url} % DO NOT CHANGE THIS
\usepackage{graphicx} % DO NOT CHANGE THIS
\urlstyle{rm} % DO NOT CHANGE THIS
\def\UrlFont{\rm} % DO NOT CHANGE THIS
\usepackage{graphicx} % DO NOT CHANGE THIS
\usepackage{natbib} % DO NOT CHANGE THIS
\usepackage{caption} % DO NOT CHANGE THIS
\frenchspacing % DO NOT CHANGE THIS
\setlength{\pdfpagewidth}{8.5in} % DO NOT CHANGE THIS
\setlength{\pdfpageheight}{11in} % DO NOT CHANGE THIS
%
% Keep the \pdfinfo as shown here. There's no need
% for you to add the /Title and /Author tags.
\pdfinfo{
/TemplateVersion (2026.1)
}
\end{verbatim}\end{scriptsize}
\end{quote}
\subsection{Preparing Your Paper}
After the preamble above, you should prepare your paper as follows:
\begin{quote}
\begin{scriptsize}\begin{verbatim}
\begin{document}
\maketitle
\begin{abstract}
%...
\end{abstract}\end{verbatim}\end{scriptsize}
\end{quote}
\noindent If you want to add links to the paper's code, dataset(s), and extended version or similar this is the place to add them, within a \emph{links} environment:
\begin{quote}%
\begin{scriptsize}\begin{verbatim}
\begin{links}
\link{Code}{https://aaai.org/example/guidelines}
\link{Datasets}{https://aaai.org/example/datasets}
\link{Extended version}{https://aaai.org/example}
\end{links}\end{verbatim}\end{scriptsize}
\end{quote}
\noindent Make sure that you do not de-anonymize yourself with these links.
\noindent You should then continue with the body of your paper. Your paper must conclude with the references, which should be inserted as follows:
\begin{quote}
\begin{scriptsize}\begin{verbatim}
% References and End of Paper
% These lines must be placed at the end of your paper
\bibliography{Bibliography-File}
\end{document}
\end{verbatim}\end{scriptsize}
\end{quote}
\begin{quote}
\begin{scriptsize}\begin{verbatim}
\begin{document}\\
\maketitle\\
...\\
\bibliography{Bibliography-File}\\
\end{document}\\
\end{verbatim}\end{scriptsize}
\end{quote}
\subsection{Commands and Packages That May Not Be Used}
\begin{table*}[t]
\centering
\begin{tabular}{l|l|l|l}
\textbackslash abovecaption &
\textbackslash abovedisplay &
\textbackslash addevensidemargin &
\textbackslash addsidemargin \\
\textbackslash addtolength &
\textbackslash baselinestretch &
\textbackslash belowcaption &
\textbackslash belowdisplay \\
\textbackslash break &
\textbackslash clearpage &
\textbackslash clip &
\textbackslash columnsep \\
\textbackslash float &
\textbackslash input &
\textbackslash input &
\textbackslash linespread \\
\textbackslash newpage &
\textbackslash pagebreak &
\textbackslash renewcommand &
\textbackslash setlength \\
\textbackslash text height &
\textbackslash tiny &
\textbackslash top margin &
\textbackslash trim \\
\textbackslash vskip\{- &
\textbackslash vspace\{- \\
\end{tabular}
%}
\caption{Commands that must not be used}
\label{table1}
\end{table*}
\begin{table}[t]
\centering
%\resizebox{.95\columnwidth}{!}{
\begin{tabular}{l|l|l|l}
authblk & babel & cjk & dvips \\
epsf & epsfig & euler & float \\
fullpage & geometry & graphics & hyperref \\
layout & linespread & lmodern & maltepaper \\
navigator & pdfcomment & pgfplots & psfig \\
pstricks & t1enc & titlesec & tocbibind \\
ulem
\end{tabular}
\caption{LaTeX style packages that must not be used.}
\label{table2}
\end{table}
There are a number of packages, commands, scripts, and macros that are incompatible with aaai2026.sty. The common ones are listed in Tables \ref{table1} and \ref{table2}. Generally, if a command, package, script, or macro alters floats, margins, fonts, sizing, linespacing, or the presentation of the references and citations, it is unacceptable. Note that negative vskip and vspace may not be used except in certain rare occurrences, and may never be used around tables, figures, captions, sections, subsections, subsubsections, or references.
\subsection{Page Breaks}
For your final camera ready copy, you must not use any page break commands. References must flow directly after the text without breaks. Note that some conferences require references to be on a separate page during the review process. AAAI Press, however, does not require this condition for the final paper.
\subsection{Paper Size, Margins, and Column Width}
Papers must be formatted to print in two-column format on 8.5 x 11 inch US letter-sized paper. The margins must be exactly as follows:
\begin{itemize}
\item Top margin: 1.25 inches (first page), .75 inches (others)
\item Left margin: .75 inches
\item Right margin: .75 inches
\item Bottom margin: 1.25 inches
\end{itemize}
The default paper size in most installations of \LaTeX{} is A4. However, because we require that your electronic paper be formatted in US letter size, the preamble we have provided includes commands that alter the default to US letter size. Please note that using any other package to alter page size (such as, but not limited to the Geometry package) will result in your final paper being returned to you for correction.
\subsubsection{Column Width and Margins.}
To ensure maximum readability, your paper must include two columns. Each column should be 3.3 inches wide (slightly more than 3.25 inches), with a .375 inch (.952 cm) gutter of white space between the two columns. The aaai2026.sty file will automatically create these columns for you.
\subsection{Overlength Papers}
If your paper is too long and you resort to formatting tricks to make it fit, it is quite likely that it will be returned to you. The best way to retain readability if the paper is overlength is to cut text, figures, or tables. There are a few acceptable ways to reduce paper size that don't affect readability. First, turn on \textbackslash frenchspacing, which will reduce the space after periods. Next, move all your figures and tables to the top of the page. Consider removing less important portions of a figure. If you use \textbackslash centering instead of \textbackslash begin\{center\} in your figure environment, you can also buy some space. For mathematical environments, you may reduce fontsize {\bf but not below 6.5 point}.
Commands that alter page layout are forbidden. These include \textbackslash columnsep, \textbackslash float, \textbackslash topmargin, \textbackslash topskip, \textbackslash textheight, \textbackslash textwidth, \textbackslash oddsidemargin, and \textbackslash evensidemargin (this list is not exhaustive). If you alter page layout, you will be required to pay the page fee. Other commands that are questionable and may cause your paper to be rejected include \textbackslash parindent, and \textbackslash parskip. Commands that alter the space between sections are forbidden. The titlesec package is not allowed. Regardless of the above, if your paper is obviously ``squeezed'' it is not going to be accepted. Options for reducing the length of a paper include reducing the size of your graphics, cutting text, or paying the extra page charge (if it is offered).
\subsection{Type Font and Size}
Your paper must be formatted in Times Roman or Nimbus. We will not accept papers formatted using Computer Modern or Palatino or some other font as the text or heading typeface. Sans serif, when used, should be Courier. Use Symbol or Lucida or Computer Modern for \textit{mathematics only. }
Do not use type 3 fonts for any portion of your paper, including graphics. Type 3 bitmapped fonts are designed for fixed resolution printers. Most print at 300 dpi even if the printer resolution is 1200 dpi or higher. They also often cause high resolution imagesetter devices to crash. Consequently, AAAI will not accept electronic files containing obsolete type 3 fonts. Files containing those fonts (even in graphics) will be rejected. (Authors using blackboard symbols must avoid packages that use type 3 fonts.)
Fortunately, there are effective workarounds that will prevent your file from embedding type 3 bitmapped fonts. The easiest workaround is to use the required times, helvet, and courier packages with \LaTeX{}2e. (Note that papers formatted in this way will still use Computer Modern for the mathematics. To make the math look good, you'll either have to use Symbol or Lucida, or you will need to install type 1 Computer Modern fonts --- for more on these fonts, see the section ``Obtaining Type 1 Computer Modern.")
If you are unsure if your paper contains type 3 fonts, view the PDF in Acrobat Reader. The Properties/Fonts window will display the font name, font type, and encoding properties of all the fonts in the document. If you are unsure if your graphics contain type 3 fonts (and they are PostScript or encapsulated PostScript documents), create PDF versions of them, and consult the properties window in Acrobat Reader.
The default size for your type must be ten-point with twelve-point leading (line spacing). Start all pages (except the first) directly under the top margin. (See the next section for instructions on formatting the title page.) Indent ten points when beginning a new paragraph, unless the paragraph begins directly below a heading or subheading.
\subsubsection{Obtaining Type 1 Computer Modern for \LaTeX{}.}
If you use Computer Modern for the mathematics in your paper (you cannot use it for the text) you may need to download type 1 Computer Modern fonts. They are available without charge from the American Mathematical Society:
http://www.ams.org/tex/type1-fonts.html.
\subsubsection{Nonroman Fonts.}
If your paper includes symbols in other languages (such as, but not limited to, Arabic, Chinese, Hebrew, Japanese, Thai, Russian and other Cyrillic languages), you must restrict their use to bit-mapped figures.
\subsection{Title and Authors}
Your title must appear centered over both text columns in sixteen-point bold type (twenty-four point leading). The title must be written in Title Case according to the Chicago Manual of Style rules. The rules are a bit involved, but in general verbs (including short verbs like be, is, using, and go), nouns, adverbs, adjectives, and pronouns should be capitalized, (including both words in hyphenated terms), while articles, conjunctions, and prepositions are lower case unless they directly follow a colon or long dash. You can use the online tool \url{https://titlecaseconverter.com/} to double-check the proper capitalization (select the "Chicago" style and mark the "Show explanations" checkbox).
Authors' names should appear below the title of the paper, centered in twelve-point type (with fifteen point leading), along with affiliation(s) and complete address(es) (including electronic mail address if available) in nine-point roman type (with twelve point leading). You should begin the two-column format when you come to the abstract.
\subsubsection{Formatting Author Information.}
Author information has to be set according to the following specification, depending on whether you have one or more than one affiliation. You may not use a table nor may you employ the \textbackslash authorblk.sty package. For one or several authors from the same institution, please separate them with commas and list all affiliations directly below (one affiliation per line) using the macros \textbackslash author and \textbackslash affiliations:
\begin{quote}\begin{scriptsize}\begin{verbatim}
\author{
Author 1, ..., Author n\\
}
\affiliations {
Address line\\
... \\
Address line\\
}
\end{verbatim}\end{scriptsize}\end{quote}
\noindent For authors from different institutions, use \textbackslash textsuperscript \{\textbackslash rm x \} to match authors and affiliations. Notice that there should not be any spaces between the author name (or comma following it) and the superscript.
\begin{quote}\begin{scriptsize}\begin{verbatim}
\author{
AuthorOne\equalcontrib\textsuperscript{\rm 1,\rm 2},
AuthorTwo\equalcontrib\textsuperscript{\rm 2},
AuthorThree\textsuperscript{\rm 3},\\
AuthorFour\textsuperscript{\rm 4},
AuthorFive\textsuperscript{\rm 5}
}
\affiliations {
\textsuperscript{\rm 1}AffiliationOne,\\
\textsuperscript{\rm 2}AffiliationTwo,\\
\textsuperscript{\rm 3}AffiliationThree,\\
\textsuperscript{\rm 4}AffiliationFour,\\
\textsuperscript{\rm 5}AffiliationFive\\
\{email, email\}@affiliation.com,
email@affiliation.com,
email@affiliation.com,
email@affiliation.com
}
\end{verbatim}\end{scriptsize}\end{quote}
You can indicate that some authors contributed equally using the \textbackslash equalcontrib command. This will add a marker after the author names and a footnote on the first page.
Note that you may want to break the author list for better visualization. You can achieve this using a simple line break (\textbackslash \textbackslash).
\subsection{\LaTeX{} Copyright Notice}
The copyright notice automatically appears if you use aaai2026.sty. It has been hardcoded and may not be disabled.
\subsection{Credits}
Any credits to a sponsoring agency should appear in the acknowledgments section, unless the agency requires different placement. If it is necessary to include this information on the front page, use
\textbackslash thanks in either the \textbackslash author or \textbackslash title commands.
For example:
\begin{quote}
\begin{small}
\textbackslash title\{Very Important Results in AI\textbackslash thanks\{This work is
supported by everybody.\}\}
\end{small}
\end{quote}
Multiple \textbackslash thanks commands can be given. Each will result in a separate footnote indication in the author or title with the corresponding text at the bottom of the first column of the document. Note that the \textbackslash thanks command is fragile. You will need to use \textbackslash protect.
Please do not include \textbackslash pubnote commands in your document.
\subsection{Abstract}
Follow the example commands in this document for creation of your abstract. The command \textbackslash begin\{abstract\} will automatically indent the text block. Please do not indent it further. {Do not include references in your abstract!}
\subsection{Page Numbers}
Do not print any page numbers on your paper. The use of \textbackslash pagestyle is forbidden.
\subsection{Text}
The main body of the paper must be formatted in black, ten-point Times Roman with twelve-point leading (line spacing). You may not reduce font size or the linespacing. Commands that alter font size or line spacing (including, but not limited to baselinestretch, baselineshift, linespread, and others) are expressly forbidden. In addition, you may not use color in the text.
\subsection{Citations}
Citations within the text should include the author's last name and year, for example (Newell 1980). Append lower-case letters to the year in cases of ambiguity. Multiple authors should be treated as follows: (Feigenbaum and Engelmore 1988) or (Ford, Hayes, and Glymour 1992). In the case of four or more authors, list only the first author, followed by et al. (Ford et al. 1997).
\subsection{Extracts}
Long quotations and extracts should be indented ten points from the left and right margins.
\begin{quote}
This is an example of an extract or quotation. Note the indent on both sides. Quotation marks are not necessary if you offset the text in a block like this, and properly identify and cite the quotation in the text.
\end{quote}
\subsection{Footnotes}
Use footnotes judiciously, taking into account that they interrupt the reading of the text. When required, they should be consecutively numbered throughout with superscript Arabic numbers. Footnotes should appear at the bottom of the page, separated from the text by a blank line space and a thin, half-point rule.
\subsection{Headings and Sections}
When necessary, headings should be used to separate major sections of your paper. Remember, you are writing a short paper, not a lengthy book! An overabundance of headings will tend to make your paper look more like an outline than a paper. The aaai2026.sty package will create headings for you. Do not alter their size nor their spacing above or below.
\subsubsection{Section Numbers.}
The use of section numbers in AAAI Press papers is optional. To use section numbers in \LaTeX{}, uncomment the setcounter line in your document preamble and change the 0 to a 1. Section numbers should not be used in short poster papers and/or extended abstracts.
\subsubsection{Section Headings.}
Sections should be arranged and headed as follows:
\begin{enumerate}
\item Main content sections
\item Appendices (optional)
\item Ethical Statement (optional, unnumbered)
\item Acknowledgements (optional, unnumbered)
\item References (unnumbered)
\end{enumerate}
\subsubsection{Appendices.}
Any appendices must appear after the main content. If your main sections are numbered, appendix sections must use letters instead of arabic numerals. In \LaTeX{} you can use the \texttt{\textbackslash appendix} command to achieve this effect and then use \texttt{\textbackslash section\{Heading\}} normally for your appendix sections.
\subsubsection{Ethical Statement.}
You can write a statement about the potential ethical impact of your work, including its broad societal implications, both positive and negative. If included, such statement must be written in an unnumbered section titled \emph{Ethical Statement}.
\subsubsection{Acknowledgments.}
The acknowledgments section, if included, appears right before the references and is headed ``Acknowledgments". It must not be numbered even if other sections are (use \texttt{\textbackslash section*\{Acknowledgements\}} in \LaTeX{}). This section includes acknowledgments of help from associates and colleagues, credits to sponsoring agencies, financial support, and permission to publish. Please acknowledge other contributors, grant support, and so forth, in this section. Do not put acknowledgments in a footnote on the first page. If your grant agency requires acknowledgment of the grant on page 1, limit the footnote to the required statement, and put the remaining acknowledgments at the back. Please try to limit acknowledgments to no more than three sentences.
\subsubsection{References.}
The references section should be labeled ``References" and must appear at the very end of the paper (don't end the paper with references, and then put a figure by itself on the last page). A sample list of references is given later on in these instructions. Please use a consistent format for references. Poorly prepared or sloppy references reflect badly on the quality of your paper and your research. Please prepare complete and accurate citations.
\subsection{Illustrations and Figures}
\begin{figure}[t]
\centering
\includegraphics[width=0.9\columnwidth]{figure1} % Reduce the figure size so that it is slightly narrower than the column. Don't use precise values for figure width.This setup will avoid overfull boxes.
\caption{Using the trim and clip commands produces fragile layers that can result in disasters (like this one from an actual paper) when the color space is corrected or the PDF combined with others for the final proceedings. Crop your figures properly in a graphics program -- not in LaTeX.}
\label{fig1}
\end{figure}
\begin{figure*}[t]
\centering
\includegraphics[width=0.8\textwidth]{figure2} % Reduce the figure size so that it is slightly narrower than the column.
\caption{Adjusting the bounding box instead of actually removing the unwanted data resulted in multiple layers in this paper. It also needlessly increased the PDF size. In this case, the size of the unwanted layer doubled the paper's size, and produced the following surprising results in final production. Crop your figures properly in a graphics program. Don't just alter the bounding box.}
\label{fig2}
\end{figure*}
% Using the \centering command instead of \begin{center} ... \end{center} will save space
% Positioning your figure at the top of the page will save space and make the paper more readable
% Using 0.95\columnwidth in conjunction with the
Your paper must compile in PDF\LaTeX{}. Consequently, all your figures must be .jpg, .png, or .pdf. You may not use the .gif (the resolution is too low), .ps, or .eps file format for your figures.
Figures, drawings, tables, and photographs should be placed throughout the paper on the page (or the subsequent page) where they are first discussed. Do not group them together at the end of the paper. If placed at the top of the paper, illustrations may run across both columns. Figures must not invade the top, bottom, or side margin areas. Figures must be inserted using the \textbackslash usepackage\{graphicx\}. Number figures sequentially, for example, figure 1, and so on. Do not use minipage to group figures.
If you normally create your figures using pgfplots, please create the figures first, and then import them as pdfs with proper bounding boxes, as the bounding and trim boxes created by pgfplots are fragile and not valid.
When you include your figures, you must crop them \textbf{outside} of \LaTeX{}. The command \textbackslash includegraphics*[clip=true, viewport=0 0 10 10]\{...\} might result in a PDF that looks great, but the image is \textbf{not really cropped.} The full image can reappear (and obscure whatever it is overlapping) when page numbers are applied or color space is standardized. Figures \ref{fig1}, and \ref{fig2} display some unwanted results that often occur.
If your paper includes illustrations that are not compatible with PDF\TeX{} (such as .eps or .ps documents), you will need to convert them. The epstopdf package will usually work for eps files. You will need to convert your ps files to PDF in either case.
\subsubsection {Figure Captions.}The illustration number and caption must appear \textit{under} the illustration. Labels and other text with the actual illustration must be at least nine-point type. However, the font and size of figure captions must be 10 point roman. Do not make them smaller, bold, or italic. (Individual words may be italicized if the context requires differentiation.)
\subsection{Tables}
Tables should be presented in 10 point roman type. If necessary, they may be altered to 9 point type. You must not use \texttt{\textbackslash resizebox} or other commands that resize the entire table to make it smaller, because you can't control the final font size this way.
If your table is too large you can use \texttt{\textbackslash setlength\{\textbackslash tabcolsep\}\{1mm\}} to compress the columns a bit or you can adapt the content (e.g.: reduce the decimal precision when presenting numbers, use shortened column titles, make some column double-line to get it narrower).
Tables that do not fit in a single column must be placed across double columns. If your table won't fit within the margins even when spanning both columns and using the above techniques, you must split it in two separate tables.
\subsubsection {Table Captions.} The number and caption for your table must appear \textit{under} (not above) the table. Additionally, the font and size of table captions must be 10 point roman and must be placed beneath the figure. Do not make them smaller, bold, or italic. (Individual words may be italicized if the context requires differentiation.)
\subsubsection{Low-Resolution Bitmaps.}
You may not use low-resolution (such as 72 dpi) screen-dumps and GIF files---these files contain so few pixels that they are always blurry, and illegible when printed. If they are color, they will become an indecipherable mess when converted to black and white. This is always the case with gif files, which should never be used. The resolution of screen dumps can be increased by reducing the print size of the original file while retaining the same number of pixels. You can also enlarge files by manipulating them in software such as PhotoShop. Your figures should be 300 dpi when incorporated into your document.
\subsubsection{\LaTeX{} Overflow.}
\LaTeX{} users please beware: \LaTeX{} will sometimes put portions of the figure or table or an equation in the margin. If this happens, you need to make the figure or table span both columns. If absolutely necessary, you may reduce the figure, or reformat the equation, or reconfigure the table.{ \bf Check your log file!} You must fix any overflow into the margin (that means no overfull boxes in \LaTeX{}). \textbf{Nothing is permitted to intrude into the margin or gutter.}
\subsubsection{Using Color.}
Use of color is restricted to figures only. It must be WCAG 2.0 compliant. (That is, the contrast ratio must be greater than 4.5:1 no matter the font size.) It must be CMYK, NOT RGB. It may never be used for any portion of the text of your paper. The archival version of your paper will be printed in black and white and grayscale. The web version must be readable by persons with disabilities. Consequently, because conversion to grayscale can cause undesirable effects (red changes to black, yellow can disappear, and so forth), we strongly suggest you avoid placing color figures in your document. If you do include color figures, you must (1) use the CMYK (not RGB) colorspace and (2) be mindful of readers who may happen to have trouble distinguishing colors. Your paper must be decipherable without using color for distinction.
\subsubsection{Drawings.}
We suggest you use computer drawing software (such as Adobe Illustrator or, (if unavoidable), the drawing tools in Microsoft Word) to create your illustrations. Do not use Microsoft Publisher. These illustrations will look best if all line widths are uniform (half- to two-point in size), and you do not create labels over shaded areas. Shading should be 133 lines per inch if possible. Use Times Roman or Helvetica for all figure call-outs. \textbf{Do not use hairline width lines} --- be sure that the stroke width of all lines is at least .5 pt. Zero point lines will print on a laser printer, but will completely disappear on the high-resolution devices used by our printers.
\subsubsection{Photographs and Images.}
Photographs and other images should be in grayscale (color photographs will not reproduce well; for example, red tones will reproduce as black, yellow may turn to white, and so forth) and set to a minimum of 300 dpi. Do not prescreen images.
\subsubsection{Resizing Graphics.}
Resize your graphics \textbf{before} you include them with LaTeX. You may \textbf{not} use trim or clip options as part of your \textbackslash includegraphics command. Resize the media box of your PDF using a graphics program instead.
\subsubsection{Fonts in Your Illustrations.}
You must embed all fonts in your graphics before including them in your LaTeX document.
\subsubsection{Algorithms.}
Algorithms and/or programs are a special kind of figures. Like all illustrations, they should appear floated to the top (preferably) or bottom of the page. However, their caption should appear in the header, left-justified and enclosed between horizontal lines, as shown in Algorithm~\ref{alg:algorithm}. The algorithm body should be terminated with another horizontal line. It is up to the authors to decide whether to show line numbers or not, how to format comments, etc.
In \LaTeX{} algorithms may be typeset using the {\tt algorithm} and {\tt algorithmic} packages, but you can also use one of the many other packages for the task.
\begin{algorithm}[tb]
\caption{Example algorithm}
\label{alg:algorithm}
\textbf{Input}: Your algorithm's input\\
\textbf{Parameter}: Optional list of parameters\\
\textbf{Output}: Your algorithm's output
\begin{algorithmic}[1] %[1] enables line numbers
\STATE Let $t=0$.
\WHILE{condition}
\STATE Do some action.
\IF {conditional}
\STATE Perform task A.
\ELSE
\STATE Perform task B.
\ENDIF
\ENDWHILE
\STATE \textbf{return} solution
\end{algorithmic}
\end{algorithm}
\subsubsection{Listings.}
Listings are much like algorithms and programs. They should also appear floated to the top (preferably) or bottom of the page. Listing captions should appear in the header, left-justified and enclosed between horizontal lines as shown in Listing~\ref{lst:listing}. Terminate the body with another horizontal line and avoid any background color. Line numbers, if included, must appear within the text column.
\begin{listing}[tb]%
\caption{Example listing {\tt quicksort.hs}}%
\label{lst:listing}%
\begin{lstlisting}[language=Haskell]
quicksort :: Ord a => [a] -> [a]
quicksort [] = []
quicksort (p:xs) = (quicksort lesser) ++ [p] ++ (quicksort greater)
where
lesser = filter (< p) xs
greater = filter (>= p) xs
\end{lstlisting}
\end{listing}
\subsection{References}
The AAAI style includes a set of definitions for use in formatting references with BibTeX. These definitions make the bibliography style fairly close to the ones specified in the Reference Examples appendix below. To use these definitions, you also need the BibTeX style file ``aaai2026.bst," available in the AAAI Author Kit on the AAAI web site. Then, at the end of your paper but before \textbackslash end\{document\}, you need to put the following lines:
\begin{quote}
\begin{small}
\textbackslash bibliography\{bibfile1,bibfile2,...\}
\end{small}
\end{quote}
Please note that the aaai2026.sty class already sets the bibliographystyle for you, so you do not have to place any \textbackslash bibliographystyle command in the document yourselves. The aaai2026.sty file is incompatible with the hyperref and navigator packages. If you use either, your references will be garbled and your paper will be returned to you.
References may be the same size as surrounding text.
However, in this section (only), you may reduce the size to {\em \textbackslash small} (9pt) if your paper exceeds the allowable number of pages. Making it any smaller than 9 point with 10 point linespacing, however, is not allowed.
The list of files in the \textbackslash bibliography command should be the names of your BibTeX source files (that is, the .bib files referenced in your paper).
The following commands are available for your use in citing references:
\begin{quote}
{\em \textbackslash cite:} Cites the given reference(s) with a full citation. This appears as ``(Author Year)'' for one reference, or ``(Author Year; Author Year)'' for multiple references.\smallskip\\
{\em \textbackslash shortcite:} Cites the given reference(s) with just the year. This appears as ``(Year)'' for one reference, or ``(Year; Year)'' for multiple references.\smallskip\\
{\em \textbackslash citeauthor:} Cites the given reference(s) with just the author name(s) and no parentheses.\smallskip\\
{\em \textbackslash citeyear:} Cites the given reference(s) with just the date(s) and no parentheses.
\end{quote}
You may also use any of the \emph{natbib} citation commands.
\section{Proofreading Your PDF}
Please check all the pages of your PDF file. The most commonly forgotten element is the acknowledgements --- especially the correct grant number. Authors also commonly forget to add the metadata to the source, use the wrong reference style file, or don't follow the capitalization rules or comma placement for their author-title information properly. A final common problem is text (especially equations) that runs into the margin. You will need to fix these common errors before submitting your file.
\section{Improperly Formatted Files }
In the past, AAAI has corrected improperly formatted files submitted by the authors. Unfortunately, this has become an increasingly burdensome expense that we can no longer absorb. Consequently, if your file is improperly formatted, it will be returned to you for correction.
\section{Naming Your Electronic File}
We require that you name your \LaTeX{} source file with the last name (family name) of the first author so that it can easily be differentiated from other submissions. Complete file-naming instructions will be provided to you in the submission instructions.
\section{Submitting Your Electronic Files to AAAI}
Instructions on paper submittal will be provided to you in your acceptance letter.
\section{Inquiries}
If you have any questions about the preparation or submission of your paper as instructed in this document, please contact AAAI Press at the address given below. If you have technical questions about implementation of the aaai style file, please contact an expert at your site. We do not provide technical support for \LaTeX{} or any other software package. To avoid problems, please keep your paper simple, and do not incorporate complicated macros and style files.
\begin{quote}
\noindent AAAI Press\\
1101 Pennsylvania Ave, NW Suite 300\\
Washington, DC 20004 USA\\
\textit{Telephone:} 1-202-360-4062\\
\textit{E-mail:} See the submission instructions for your particular conference or event.
\end{quote}
\section{Additional Resources}
\LaTeX{} is a difficult program to master. If you've used that software, and this document didn't help or some items were not explained clearly, we recommend you read Michael Shell's excellent document (testflow doc.txt V1.0a 2002/08/13) about obtaining correct PS/PDF output on \LaTeX{} systems. (It was written for another purpose, but it has general application as well). It is available at www.ctan.org in the tex-archive.
\appendix
\section{Reference Examples}
\label{sec:reference_examples}
\nobibliography*
Formatted bibliographies should look like the following examples. You should use BibTeX to generate the references. Missing fields are unacceptable when compiling references, and usually indicate that you are using the wrong type of entry (BibTeX class).
\paragraph{Book with multiple authors~\nocite{em:86}} Use the \texttt{@book} class.\\[.2em]
\bibentry{em:86}.
\paragraph{Journal and magazine articles~\nocite{r:80, hcr:83}} Use the \texttt{@article} class.\\[.2em]
\bibentry{r:80}.\\[.2em]
\bibentry{hcr:83}.
\paragraph{Proceedings paper published by a society, press or publisher~\nocite{c:83, c:84}} Use the \texttt{@inproceedings} class. You may abbreviate the \emph{booktitle} field, but make sure that the conference edition is clear.\\[.2em]
\bibentry{c:84}.\\[.2em]
\bibentry{c:83}.
\paragraph{University technical report~\nocite{r:86}} Use the \texttt{@techreport} class.\\[.2em]
\bibentry{r:86}.
\paragraph{Dissertation or thesis~\nocite{c:79}} Use the \texttt{@phdthesis} class.\\[.2em]
\bibentry{c:79}.
\paragraph{Forthcoming publication~\nocite{c:21}} Use the \texttt{@misc} class with a \texttt{note="Forthcoming"} annotation.
\begin{quote}
\begin{footnotesize}
\begin{verbatim}
@misc(key,
[...]
note="Forthcoming",
)
\end{verbatim}
\end{footnotesize}
\end{quote}
\bibentry{c:21}.
\paragraph{ArXiv paper~\nocite{c:22}} Fetch the BibTeX entry from the "Export Bibtex Citation" link in the arXiv website. Notice it uses the \texttt{@misc} class instead of the \texttt{@article} one, and that it includes the \texttt{eprint} and \texttt{archivePrefix} keys.
\begin{quote}
\begin{footnotesize}
\begin{verbatim}
@misc(key,
[...]
eprint="xxxx.yyyy",
archivePrefix="arXiv",
)
\end{verbatim}
\end{footnotesize}
\end{quote}
\bibentry{c:22}.
\paragraph{Website or online resource~\nocite{c:23}} Use the \texttt{@misc} class. Add the url in the \texttt{howpublished} field and the date of access in the \texttt{note} field:
\begin{quote}
\begin{footnotesize}
\begin{verbatim}
@misc(key,
[...]
howpublished="\url{http://...}",
note="Accessed: YYYY-mm-dd",
)
\end{verbatim}
\end{footnotesize}
\end{quote}
\bibentry{c:23}.
\vspace{.2em}
For the most up to date version of the AAAI reference style, please consult the \textit{AI Magazine} Author Guidelines at \url{https://aaai.org/ojs/index.php/aimagazine/about/submissions#authorGuidelines}
\section{Acknowledgments}
AAAI is especially grateful to Peter Patel Schneider for his work in implementing the original aaai.sty file, liberally using the ideas of other style hackers, including Barbara Beeton. We also acknowledge with thanks the work of George Ferguson for his guide to using the style and BibTeX files --- which has been incorporated into this document --- and Hans Guesgen, who provided several timely modifications, as well as the many others who have, from time to time, sent in suggestions on improvements to the AAAI style. We are especially grateful to Francisco Cruz, Marc Pujol-Gonzalez, and Mico Loretan for the improvements to the Bib\TeX{} and \LaTeX{} files made in 2020.
The preparation of the \LaTeX{} and Bib\TeX{} files that implement these instructions was supported by Schlumberger Palo Alto Research, AT\&T Bell Laboratories, Morgan Kaufmann Publishers, The Live Oak Press, LLC, and AAAI Press. Bibliography style changes were added by Sunil Issar. \verb+\+pubnote was added by J. Scott Penberthy. George Ferguson added support for printing the AAAI copyright slug. Additional changes to aaai2026.sty and aaai2026.bst have been made by Francisco Cruz and Marc Pujol-Gonzalez.
\bigskip
\noindent Thank you for reading these instructions carefully. We look forward to receiving your electronic files!
\bibliography{aaai2026}
\end{document}

View File

@@ -1,485 +0,0 @@
% fancyhdr.sty version 3.2
% Fancy headers and footers for LaTeX.
% Piet van Oostrum,
% Dept of Computer and Information Sciences, University of Utrecht,
% Padualaan 14, P.O. Box 80.089, 3508 TB Utrecht, The Netherlands
% Telephone: +31 30 2532180. Email: piet@cs.uu.nl
% ========================================================================
% LICENCE:
% This file may be distributed under the terms of the LaTeX Project Public
% License, as described in lppl.txt in the base LaTeX distribution.
% Either version 1 or, at your option, any later version.
% ========================================================================
% MODIFICATION HISTORY:
% Sep 16, 1994
% version 1.4: Correction for use with \reversemargin
% Sep 29, 1994:
% version 1.5: Added the \iftopfloat, \ifbotfloat and \iffloatpage commands
% Oct 4, 1994:
% version 1.6: Reset single spacing in headers/footers for use with
% setspace.sty or doublespace.sty
% Oct 4, 1994:
% version 1.7: changed \let\@mkboth\markboth to
% \def\@mkboth{\protect\markboth} to make it more robust
% Dec 5, 1994:
% version 1.8: corrections for amsbook/amsart: define \@chapapp and (more
% importantly) use the \chapter/sectionmark definitions from ps@headings if
% they exist (which should be true for all standard classes).
% May 31, 1995:
% version 1.9: The proposed \renewcommand{\headrulewidth}{\iffloatpage...
% construction in the doc did not work properly with the fancyplain style.
% June 1, 1995:
% version 1.91: The definition of \@mkboth wasn't restored on subsequent
% \pagestyle{fancy}'s.
% June 1, 1995:
% version 1.92: The sequence \pagestyle{fancyplain} \pagestyle{plain}
% \pagestyle{fancy} would erroneously select the plain version.
% June 1, 1995:
% version 1.93: \fancypagestyle command added.
% Dec 11, 1995:
% version 1.94: suggested by Conrad Hughes <chughes@maths.tcd.ie>
% CJCH, Dec 11, 1995: added \footruleskip to allow control over footrule
% position (old hardcoded value of .3\normalbaselineskip is far too high
% when used with very small footer fonts).
% Jan 31, 1996:
% version 1.95: call \@normalsize in the reset code if that is defined,
% otherwise \normalsize.
% this is to solve a problem with ucthesis.cls, as this doesn't
% define \@currsize. Unfortunately for latex209 calling \normalsize doesn't
% work as this is optimized to do very little, so there \@normalsize should
% be called. Hopefully this code works for all versions of LaTeX known to
% mankind.
% April 25, 1996:
% version 1.96: initialize \headwidth to a magic (negative) value to catch
% most common cases that people change it before calling \pagestyle{fancy}.
% Note it can't be initialized when reading in this file, because
% \textwidth could be changed afterwards. This is quite probable.
% We also switch to \MakeUppercase rather than \uppercase and introduce a
% \nouppercase command for use in headers. and footers.
% May 3, 1996:
% version 1.97: Two changes:
% 1. Undo the change in version 1.8 (using the pagestyle{headings} defaults
% for the chapter and section marks. The current version of amsbook and
% amsart classes don't seem to need them anymore. Moreover the standard
% latex classes don't use \markboth if twoside isn't selected, and this is
% confusing as \leftmark doesn't work as expected.
% 2. include a call to \ps@empty in ps@@fancy. This is to solve a problem
% in the amsbook and amsart classes, that make global changes to \topskip,
% which are reset in \ps@empty. Hopefully this doesn't break other things.
% May 7, 1996:
% version 1.98:
% Added % after the line \def\nouppercase
% May 7, 1996:
% version 1.99: This is the alpha version of fancyhdr 2.0
% Introduced the new commands \fancyhead, \fancyfoot, and \fancyhf.
% Changed \headrulewidth, \footrulewidth, \footruleskip to
% macros rather than length parameters, In this way they can be
% conditionalized and they don't consume length registers. There is no need
% to have them as length registers unless you want to do calculations with
% them, which is unlikely. Note that this may make some uses of them
% incompatible (i.e. if you have a file that uses \setlength or \xxxx=)
% May 10, 1996:
% version 1.99a:
% Added a few more % signs
% May 10, 1996:
% version 1.99b:
% Changed the syntax of \f@nfor to be resistant to catcode changes of :=
% Removed the [1] from the defs of \lhead etc. because the parameter is
% consumed by the \@[xy]lhead etc. macros.
% June 24, 1997:
% version 1.99c:
% corrected \nouppercase to also include the protected form of \MakeUppercase
% \global added to manipulation of \headwidth.
% \iffootnote command added.
% Some comments added about \@fancyhead and \@fancyfoot.
% Aug 24, 1998
% version 1.99d
% Changed the default \ps@empty to \ps@@empty in order to allow
% \fancypagestyle{empty} redefinition.
% Oct 11, 2000
% version 2.0
% Added LPPL license clause.
%
% A check for \headheight is added. An errormessage is given (once) if the
% header is too large. Empty headers don't generate the error even if
% \headheight is very small or even 0pt.
% Warning added for the use of 'E' option when twoside option is not used.
% In this case the 'E' fields will never be used.
%
% Mar 10, 2002
% version 2.1beta
% New command: \fancyhfoffset[place]{length}
% defines offsets to be applied to the header/footer to let it stick into
% the margins (if length > 0).
% place is like in fancyhead, except that only E,O,L,R can be used.
% This replaces the old calculation based on \headwidth and the marginpar
% area.
% \headwidth will be dynamically calculated in the headers/footers when
% this is used.
%
% Mar 26, 2002
% version 2.1beta2
% \fancyhfoffset now also takes h,f as possible letters in the argument to
% allow the header and footer widths to be different.
% New commands \fancyheadoffset and \fancyfootoffset added comparable to
% \fancyhead and \fancyfoot.
% Errormessages and warnings have been made more informative.
%
% Dec 9, 2002
% version 2.1
% The defaults for \footrulewidth, \plainheadrulewidth and
% \plainfootrulewidth are changed from \z@skip to 0pt. In this way when
% someone inadvertently uses \setlength to change any of these, the value
% of \z@skip will not be changed, rather an errormessage will be given.
% March 3, 2004
% Release of version 3.0
% Oct 7, 2004
% version 3.1
% Added '\endlinechar=13' to \fancy@reset to prevent problems with
% includegraphics in header when verbatiminput is active.
% March 22, 2005
% version 3.2
% reset \everypar (the real one) in \fancy@reset because spanish.ldf does
% strange things with \everypar between << and >>.
% \ifancy@mpty{arg}: test whether arg is empty (leaves the \ifx open for
% the caller's \else/\fi).
\def\ifancy@mpty#1{\def\temp@a{#1}\ifx\temp@a\@empty}
% \fancy@def\cs{text}: define a header/footer field macro. An empty text
% becomes \leavevmode (so the field still contributes to the box), a
% non-empty one gets a trailing \strut for consistent height.
\def\fancy@def#1#2{\ifancy@mpty{#2}\fancy@gbl\def#1{\leavevmode}\else
\fancy@gbl\def#1{#2\strut}\fi}
% Field definitions are global by default; \fancypagestyle rebinds
% \fancy@gbl to \relax to make them local to one page style.
\let\fancy@gbl\global
% Error/warning wrappers that degrade to \errmessage when the LaTeX2e
% package interface is unavailable (plain TeX / LaTeX 2.09).
\def\@fancyerrmsg#1{%
\ifx\PackageError\undefined
\errmessage{#1}\else
\PackageError{Fancyhdr}{#1}{}\fi}
\def\@fancywarning#1{%
\ifx\PackageWarning\undefined
\errmessage{#1}\else
\PackageWarning{Fancyhdr}{#1}{}\fi}
% Usage: \@forc \var{charstring}{command to be executed for each char}
% This is similar to LaTeX's \@tfor, but expands the charstring.
\def\@forc#1#2#3{\expandafter\f@rc\expandafter#1\expandafter{#2}{#3}}
% \f@rc/\f@@rc: recursive workers for \@forc; \f@rc stops on the empty
% string, \f@@rc peels one character off and re-invokes \f@rc.
\def\f@rc#1#2#3{\def\temp@ty{#2}\ifx\@empty\temp@ty\else
\f@@rc#1#2\f@@rc{#3}\fi}
\def\f@@rc#1#2#3\f@@rc#4{\def#1{#2}#4\f@rc#1{#3}{#4}}
% Usage: \f@nfor\name:=list\do{body}
% Like LaTeX's \@for but an empty list is treated as a list with an empty
% element
\newcommand{\f@nfor}[3]{\edef\@fortmp{#2}%
\expandafter\@forloop#2,\@nil,\@nil\@@#1{#3}}
% Usage: \def@ult \cs{defaults}{argument}
% sets \cs to the characters from defaults appearing in argument
% or defaults if it would be empty. All characters are lowercased.
\newcommand\def@ult[3]{%
\edef\temp@a{\lowercase{\edef\noexpand\temp@a{#3}}}\temp@a
\def#1{}%
\@forc\tmpf@ra{#2}%
{\expandafter\if@in\tmpf@ra\temp@a{\edef#1{#1\tmpf@ra}}{}}%
\ifx\@empty#1\def#1{#2}\fi}
%
% \if@in <char><set><truecase><falsecase>
% expands <truecase> if <char> occurs in <set>, else <falsecase>.
%
\newcommand{\if@in}[4]{%
\edef\temp@a{#2}\def\temp@b##1#1##2\temp@b{\def\temp@b{##1}}%
\expandafter\temp@b#2#1\temp@b\ifx\temp@a\temp@b #4\else #3\fi}
% User-level field commands. Each takes an optional [place] argument
% (letters from eolcrhf); with no bracket an empty place list is passed,
% which \def@ult later expands to the full default set.
\newcommand{\fancyhead}{\@ifnextchar[{\f@ncyhf\fancyhead h}%
{\f@ncyhf\fancyhead h[]}}
\newcommand{\fancyfoot}{\@ifnextchar[{\f@ncyhf\fancyfoot f}%
{\f@ncyhf\fancyfoot f[]}}
\newcommand{\fancyhf}{\@ifnextchar[{\f@ncyhf\fancyhf{}}%
{\f@ncyhf\fancyhf{}[]}}
% New commands for offsets added
% Same optional-argument pattern, but dispatching to \f@ncyhfoffs, which
% stores lengths instead of text fields.
\newcommand{\fancyheadoffset}{\@ifnextchar[{\f@ncyhfoffs\fancyheadoffset h}%
{\f@ncyhfoffs\fancyheadoffset h[]}}
\newcommand{\fancyfootoffset}{\@ifnextchar[{\f@ncyhfoffs\fancyfootoffset f}%
{\f@ncyhfoffs\fancyfootoffset f[]}}
\newcommand{\fancyhfoffset}{\@ifnextchar[{\f@ncyhfoffs\fancyhfoffset{}}%
{\f@ncyhfoffs\fancyhfoffset{}[]}}
% The header and footer fields are stored in command sequences with
% names of the form: \f@ncy<x><y><z> with <x> for [eo], <y> from [lcr]
% and <z> from [hf].
% \f@ncyhf\caller{h|f|}[place]{text}: validate the place letters, then
% for every (eo)x(lcr)x(hf) combination selected by the place list,
% define the corresponding \f@ncy... field to {text} via \fancy@def.
\def\f@ncyhf#1#2[#3]#4{%
\def\temp@c{}%
\@forc\tmpf@ra{#3}%
{\expandafter\if@in\tmpf@ra{eolcrhf,EOLCRHF}%
{}{\edef\temp@c{\temp@c\tmpf@ra}}}%
\ifx\@empty\temp@c\else
\@fancyerrmsg{Illegal char `\temp@c' in \string#1 argument:
[#3]}%
\fi
\f@nfor\temp@c{#3}%
{\def@ult\f@@@eo{eo}\temp@c
\if@twoside\else
\if\f@@@eo e\@fancywarning
{\string#1's `E' option without twoside option is useless}\fi\fi
\def@ult\f@@@lcr{lcr}\temp@c
\def@ult\f@@@hf{hf}{#2\temp@c}%
\@forc\f@@eo\f@@@eo
{\@forc\f@@lcr\f@@@lcr
{\@forc\f@@hf\f@@@hf
{\expandafter\fancy@def\csname
f@ncy\f@@eo\f@@lcr\f@@hf\endcsname
{#4}}}}}}
% \f@ncyhfoffs: same structure as \f@ncyhf, but place letters exclude c
% (only lr make sense for offsets), the value is a length stored in the
% \f@ncyO@... registers, and \fancy@setoffs is triggered at the end.
\def\f@ncyhfoffs#1#2[#3]#4{%
\def\temp@c{}%
\@forc\tmpf@ra{#3}%
{\expandafter\if@in\tmpf@ra{eolrhf,EOLRHF}%
{}{\edef\temp@c{\temp@c\tmpf@ra}}}%
\ifx\@empty\temp@c\else
\@fancyerrmsg{Illegal char `\temp@c' in \string#1 argument:
[#3]}%
\fi
\f@nfor\temp@c{#3}%
{\def@ult\f@@@eo{eo}\temp@c
\if@twoside\else
\if\f@@@eo e\@fancywarning
{\string#1's `E' option without twoside option is useless}\fi\fi
\def@ult\f@@@lcr{lr}\temp@c
\def@ult\f@@@hf{hf}{#2\temp@c}%
\@forc\f@@eo\f@@@eo
{\@forc\f@@lcr\f@@@lcr
{\@forc\f@@hf\f@@@hf
{\expandafter\setlength\csname
f@ncyO@\f@@eo\f@@lcr\f@@hf\endcsname
{#4}}}}}%
\fancy@setoffs}
% Fancyheadings version 1 commands. These are more or less deprecated,
% but they continue to work.
% Each \Xhead/\Xfoot takes an optional even-page argument: \lhead[even]{odd}.
% With no bracket the same text is used for both even and odd pages.
\newcommand{\lhead}{\@ifnextchar[{\@xlhead}{\@ylhead}}
\def\@xlhead[#1]#2{\fancy@def\f@ncyelh{#1}\fancy@def\f@ncyolh{#2}}
\def\@ylhead#1{\fancy@def\f@ncyelh{#1}\fancy@def\f@ncyolh{#1}}
\newcommand{\chead}{\@ifnextchar[{\@xchead}{\@ychead}}
\def\@xchead[#1]#2{\fancy@def\f@ncyech{#1}\fancy@def\f@ncyoch{#2}}
\def\@ychead#1{\fancy@def\f@ncyech{#1}\fancy@def\f@ncyoch{#1}}
\newcommand{\rhead}{\@ifnextchar[{\@xrhead}{\@yrhead}}
\def\@xrhead[#1]#2{\fancy@def\f@ncyerh{#1}\fancy@def\f@ncyorh{#2}}
\def\@yrhead#1{\fancy@def\f@ncyerh{#1}\fancy@def\f@ncyorh{#1}}
\newcommand{\lfoot}{\@ifnextchar[{\@xlfoot}{\@ylfoot}}
\def\@xlfoot[#1]#2{\fancy@def\f@ncyelf{#1}\fancy@def\f@ncyolf{#2}}
\def\@ylfoot#1{\fancy@def\f@ncyelf{#1}\fancy@def\f@ncyolf{#1}}
\newcommand{\cfoot}{\@ifnextchar[{\@xcfoot}{\@ycfoot}}
\def\@xcfoot[#1]#2{\fancy@def\f@ncyecf{#1}\fancy@def\f@ncyocf{#2}}
\def\@ycfoot#1{\fancy@def\f@ncyecf{#1}\fancy@def\f@ncyocf{#1}}
\newcommand{\rfoot}{\@ifnextchar[{\@xrfoot}{\@yrfoot}}
\def\@xrfoot[#1]#2{\fancy@def\f@ncyerf{#1}\fancy@def\f@ncyorf{#2}}
\def\@yrfoot#1{\fancy@def\f@ncyerf{#1}\fancy@def\f@ncyorf{#1}}
% \headwidth is an alias for our own register so that a user's
% \let\headwidth\textwidth can later be undone in \fancy@setoffs.
\newlength{\fancy@headwidth}
\let\headwidth\fancy@headwidth
% Offset registers for \fancyhfoffset: one per (eo)x(lr)x(hf) corner.
\newlength{\f@ncyO@elh}
\newlength{\f@ncyO@erh}
\newlength{\f@ncyO@olh}
\newlength{\f@ncyO@orh}
\newlength{\f@ncyO@elf}
\newlength{\f@ncyO@erf}
\newlength{\f@ncyO@olf}
\newlength{\f@ncyO@orf}
% Rule widths are macros, not length registers (see v1.99 history notes).
\newcommand{\headrulewidth}{0.4pt}
\newcommand{\footrulewidth}{0pt}
\newcommand{\footruleskip}{.3\normalbaselineskip}
% Fancyplain stuff shouldn't be used anymore (rather
% \fancypagestyle{plain} should be used), but it must be present for
% compatibility reasons.
\newcommand{\plainheadrulewidth}{0pt}
\newcommand{\plainfootrulewidth}{0pt}
\newif\if@fancyplain \@fancyplainfalse
\def\fancyplain#1#2{\if@fancyplain#1\else#2\fi}
% Sentinel: \ps@fancy detects this negative value and substitutes
% \textwidth (plus any user delta) at \pagestyle{fancy} time.
\headwidth=-123456789sp %magic constant
% Command to reset various things in the headers:
% a.o. single spacing (taken from setspace.sty)
% and the catcode of ^^M (so that epsf files in the header work if a
% verbatim crosses a page boundary)
% It also defines a \nouppercase command that disables \uppercase and
% \MakeUppercase. It can only be used in the headers and footers.
\let\fnch@everypar\everypar% save real \everypar because of spanish.ldf
\def\fancy@reset{\fnch@everypar{}\restorecr\endlinechar=13
\def\baselinestretch{1}%
\def\nouppercase##1{{\let\uppercase\relax\let\MakeUppercase\relax
\expandafter\let\csname MakeUppercase \endcsname\relax##1}}%
\ifx\undefined\@newbaseline% NFSS not present; 2.09 or 2e
\ifx\@normalsize\undefined \normalsize % for ucthesis.cls
\else \@normalsize \fi
\else% NFSS (2.09) present
\@newbaseline%
\fi}
% Initialization of the head and foot text.
% The default values still contain \fancyplain for compatibility.
\fancyhf{} % clear all
% lefthead empty on ``plain'' pages, \rightmark on even, \leftmark on odd pages
% evenhead empty on ``plain'' pages, \leftmark on even, \rightmark on odd pages
\if@twoside
\fancyhead[el,or]{\fancyplain{}{\sl\rightmark}}
\fancyhead[er,ol]{\fancyplain{}{\sl\leftmark}}
\else
\fancyhead[l]{\fancyplain{}{\sl\rightmark}}
\fancyhead[r]{\fancyplain{}{\sl\leftmark}}
\fi
\fancyfoot[c]{\rm\thepage} % page number
% Use box 0 as a temp box and dimen 0 as temp dimen.
% This can be done, because this code will always
% be used inside another box, and therefore the changes are local.
% \@fancyvbox\length{content}: typeset content in a vbox; if it is taller
% than \length, warn once and globally grow \length to fit.
\def\@fancyvbox#1#2{\setbox0\vbox{#2}\ifdim\ht0>#1\@fancywarning
{\string#1 is too small (\the#1): ^^J Make it at least \the\ht0.^^J
We now make it that large for the rest of the document.^^J
This may cause the page layout to be inconsistent, however\@gobble}%
\dimen0=#1\global\setlength{#1}{\ht0}\ht0=\dimen0\fi
\box0}
% Put together a header or footer given the left, center and
% right text, fillers at left and right and a rule.
% The \lap commands put the text into an hbox of zero size,
% so overlapping text does not generate an errormessage.
% These macros have 5 parameters:
% 1. LEFTSIDE BEARING % This determines at which side the header will stick
% out. When \fancyhfoffset is used this calculates \headwidth, otherwise
% it is \hss or \relax (after expansion).
% 2. \f@ncyolh, \f@ncyelh, \f@ncyolf or \f@ncyelf. This is the left component.
% 3. \f@ncyoch, \f@ncyech, \f@ncyocf or \f@ncyecf. This is the middle comp.
% 4. \f@ncyorh, \f@ncyerh, \f@ncyorf or \f@ncyerf. This is the right component.
% 5. RIGHTSIDE BEARING. This is always \relax or \hss (after expansion).
\def\@fancyhead#1#2#3#4#5{#1\hbox to\headwidth{\fancy@reset
\@fancyvbox\headheight{\hbox
{\rlap{\parbox[b]{\headwidth}{\raggedright#2}}\hfill
\parbox[b]{\headwidth}{\centering#3}\hfill
\llap{\parbox[b]{\headwidth}{\raggedleft#4}}}\headrule}}#5}
\def\@fancyfoot#1#2#3#4#5{#1\hbox to\headwidth{\fancy@reset
\@fancyvbox\footskip{\footrule
\hbox{\rlap{\parbox[t]{\headwidth}{\raggedright#2}}\hfill
\parbox[t]{\headwidth}{\centering#3}\hfill
\llap{\parbox[t]{\headwidth}{\raggedleft#4}}}}}#5}
% Rules below the header / above the footer; on fancyplain pages the
% plain widths are substituted.
\def\headrule{{\if@fancyplain\let\headrulewidth\plainheadrulewidth\fi
\hrule\@height\headrulewidth\@width\headwidth \vskip-\headrulewidth}}
\def\footrule{{\if@fancyplain\let\footrulewidth\plainfootrulewidth\fi
\vskip-\footruleskip\vskip-\footrulewidth
\hrule\@width\headwidth\@height\footrulewidth\vskip\footruleskip}}
% \ps@fancy: the \pagestyle{fancy} entry point. The first invocation also
% installs \sectionmark etc. and initializes \headwidth; it then redefines
% itself to the cheaper \ps@@fancy for subsequent calls.
\def\ps@fancy{%
\@ifundefined{@chapapp}{\let\@chapapp\chaptername}{}%for amsbook
%
% Define \MakeUppercase for old LaTeXen.
% Note: we used \def rather than \let, so that \let\uppercase\relax (from
% the version 1 documentation) will still work.
%
\@ifundefined{MakeUppercase}{\def\MakeUppercase{\uppercase}}{}%
\@ifundefined{chapter}{\def\sectionmark##1{\markboth
{\MakeUppercase{\ifnum \c@secnumdepth>\z@
\thesection\hskip 1em\relax \fi ##1}}{}}%
\def\subsectionmark##1{\markright {\ifnum \c@secnumdepth >\@ne
\thesubsection\hskip 1em\relax \fi ##1}}}%
{\def\chaptermark##1{\markboth {\MakeUppercase{\ifnum \c@secnumdepth>\m@ne
\@chapapp\ \thechapter. \ \fi ##1}}{}}%
\def\sectionmark##1{\markright{\MakeUppercase{\ifnum \c@secnumdepth >\z@
\thesection. \ \fi ##1}}}}%
%\csname ps@headings\endcsname % use \ps@headings defaults if they exist
\ps@@fancy
\gdef\ps@fancy{\@fancyplainfalse\ps@@fancy}%
% Initialize \headwidth if the user didn't
%
\ifdim\headwidth<0sp
%
% This catches the case that \headwidth hasn't been initialized and the
% case that the user added something to \headwidth in the expectation that
% it was initialized to \textwidth. We compensate this now. This loses if
% the user intended to multiply it by a factor. But that case is more
% likely done by saying something like \headwidth=1.2\textwidth.
% The doc says you have to change \headwidth after the first call to
% \pagestyle{fancy}. This code is just to catch the most common cases where
% that requirement is violated.
%
\global\advance\headwidth123456789sp\global\advance\headwidth\textwidth
\fi}
% fancyplain: like fancy, but the plain page style (first pages, chapter
% openings) also goes through the fancy machinery with \if@fancyplain set.
\def\ps@fancyplain{\ps@fancy \let\ps@plain\ps@plain@fancy}
\def\ps@plain@fancy{\@fancyplaintrue\ps@@fancy}
\let\ps@@empty\ps@empty
% \ps@@fancy: install the actual \@oddhead/\@evenfoot etc. hooks.
\def\ps@@fancy{%
\ps@@empty % This is for amsbook/amsart, which do strange things with \topskip
\def\@mkboth{\protect\markboth}%
\def\@oddhead{\@fancyhead\fancy@Oolh\f@ncyolh\f@ncyoch\f@ncyorh\fancy@Oorh}%
\def\@oddfoot{\@fancyfoot\fancy@Oolf\f@ncyolf\f@ncyocf\f@ncyorf\fancy@Oorf}%
\def\@evenhead{\@fancyhead\fancy@Oelh\f@ncyelh\f@ncyech\f@ncyerh\fancy@Oerh}%
\def\@evenfoot{\@fancyfoot\fancy@Oelf\f@ncyelf\f@ncyecf\f@ncyerf\fancy@Oerf}%
}
% Default definitions for compatibility mode:
% These cause the header/footer to take the defined \headwidth as width
% And to shift in the direction of the marginpar area
\def\fancy@Oolh{\if@reversemargin\hss\else\relax\fi}
\def\fancy@Oorh{\if@reversemargin\relax\else\hss\fi}
\let\fancy@Oelh\fancy@Oorh
\let\fancy@Oerh\fancy@Oolh
\let\fancy@Oolf\fancy@Oolh
\let\fancy@Oorf\fancy@Oorh
\let\fancy@Oelf\fancy@Oelh
\let\fancy@Oerf\fancy@Oerh
% New definitions for the use of \fancyhfoffset
% These calculate the \headwidth from \textwidth and the specified offsets.
\def\fancy@offsolh{\headwidth=\textwidth\advance\headwidth\f@ncyO@olh
\advance\headwidth\f@ncyO@orh\hskip-\f@ncyO@olh}
\def\fancy@offselh{\headwidth=\textwidth\advance\headwidth\f@ncyO@elh
\advance\headwidth\f@ncyO@erh\hskip-\f@ncyO@elh}
\def\fancy@offsolf{\headwidth=\textwidth\advance\headwidth\f@ncyO@olf
\advance\headwidth\f@ncyO@orf\hskip-\f@ncyO@olf}
\def\fancy@offself{\headwidth=\textwidth\advance\headwidth\f@ncyO@elf
\advance\headwidth\f@ncyO@erf\hskip-\f@ncyO@elf}
% Switch the bearing macros from compatibility mode to offset mode; called
% at the end of \f@ncyhfoffs.
\def\fancy@setoffs{%
% Just in case \let\headwidth\textwidth was used
\fancy@gbl\let\headwidth\fancy@headwidth
\fancy@gbl\let\fancy@Oolh\fancy@offsolh
\fancy@gbl\let\fancy@Oelh\fancy@offselh
\fancy@gbl\let\fancy@Oorh\hss
\fancy@gbl\let\fancy@Oerh\hss
\fancy@gbl\let\fancy@Oolf\fancy@offsolf
\fancy@gbl\let\fancy@Oelf\fancy@offself
\fancy@gbl\let\fancy@Oorf\hss
\fancy@gbl\let\fancy@Oerf\hss}
% Hook \@makecol to record footnote/float state so headers can test it
% via \iffootnote, \iftopfloat, \ifbotfloat and \iffloatpage.
\newif\iffootnote
\let\latex@makecol\@makecol
\def\@makecol{\ifvoid\footins\footnotetrue\else\footnotefalse\fi
\let\topfloat\@toplist\let\botfloat\@botlist\latex@makecol}
\def\iftopfloat#1#2{\ifx\topfloat\empty #2\else #1\fi}
\def\ifbotfloat#1#2{\ifx\botfloat\empty #2\else #1\fi}
\def\iffloatpage#1#2{\if@fcolmade #1\else #2\fi}
% \fancypagestyle{name}{setup}: define a page style whose field settings
% are local (\fancy@gbl is \relax while setup runs).
\newcommand{\fancypagestyle}[2]{%
\@namedef{ps@#1}{\let\fancy@gbl\relax#2\relax\ps@fancy}}

View File

@@ -1,805 +0,0 @@
% File: icml2024.sty (LaTeX style file for ICML-2024, version of 2023-11-23)
% This file contains the LaTeX formatting parameters for a two-column
% conference proceedings that is 8.5 inches wide by 11 inches high.
%
% Modified by Jonathan Scarlett 2024: changed years, volume, location
%
% Modified by Sivan Sabato 2023: changed years and volume number.
% Modified by Jonathan Scarlett 2023: added page numbers to every page
%
% Modified by Csaba Szepesvari 2022: changed years, PMLR ref. Turned off checking marginparwidth
% as marginparwidth only controls the space available for margin notes and margin notes
% will NEVER be used anyways in submitted versions, so there is no reason one should
% check whether marginparwidth has been tampered with.
% Also removed pdfview=FitH from hypersetup as it did not do its job; the default choice is a bit better
% but of course the double-column format is not supported by this hyperlink preview functionality
% in a completely satisfactory fashion.
% Modified by Gang Niu 2022: Changed color to xcolor
%
% Modified by Iain Murray 2018: changed years, location. Remove affiliation notes when anonymous.
% Move times dependency from .tex to .sty so fewer people delete it.
%
% Modified by Daniel Roy 2017: changed byline to use footnotes for affiliations, and removed emails
%
% Modified by Percy Liang 12/2/2013: changed the year, location from the previous template for ICML 2014
% Modified by Fei Sha 9/2/2013: changed the year, location form the previous template for ICML 2013
%
% Modified by Fei Sha 4/24/2013: (1) remove the extra whitespace after the first author's email address (in %the camera-ready version) (2) change the Proceeding ... of ICML 2010 to 2014 so PDF's metadata will show up % correctly
%
% Modified by Sanjoy Dasgupta, 2013: changed years, location
%
% Modified by Francesco Figari, 2012: changed years, location
%
% Modified by Christoph Sawade and Tobias Scheffer, 2011: added line
% numbers, changed years
%
% Modified by Hal Daume III, 2010: changed years, added hyperlinks
%
% Modified by Kiri Wagstaff, 2009: changed years
%
% Modified by Sam Roweis, 2008: changed years
%
% Modified by Ricardo Silva, 2007: update of the ifpdf verification
%
% Modified by Prasad Tadepalli and Andrew Moore, merely changing years.
%
% Modified by Kristian Kersting, 2005, based on Jennifer Dy's 2004 version
% - running title. If the original title is too long or is breaking a line,
% use \icmltitlerunning{...} in the preamble to supply a shorter form.
% Added fancyhdr package to get a running head.
% - Updated to store the page size because pdflatex does compile the
% page size into the pdf.
%
% Hacked by Terran Lane, 2003:
% - Updated to use LaTeX2e style file conventions (ProvidesPackage,
% etc.)
% - Added an ``appearing in'' block at the base of the first column
% (thus keeping the ``appearing in'' note out of the bottom margin
% where the printer should strip in the page numbers).
% - Added a package option [accepted] that selects between the ``Under
% review'' notice (default, when no option is specified) and the
% ``Appearing in'' notice (for use when the paper has been accepted
% and will appear).
%
% Originally created as: ml2k.sty (LaTeX style file for ICML-2000)
% by P. Langley (12/23/99)
%%%%%%%%%%%%%%%%%%%%
%% This version of the style file supports both a ``review'' version
%% and a ``final/accepted'' version. The difference is only in the
%% text that appears in the note at the bottom of the first column of
%% the first page. The default behavior is to print a note to the
%% effect that the paper is under review and don't distribute it. The
%% final/accepted version prints an ``Appearing in'' note. To get the
%% latter behavior, in the calling file change the ``usepackage'' line
%% from:
%% \usepackage{icml2024}
%% to
%% \usepackage[accepted]{icml2024}
%%%%%%%%%%%%%%%%%%%%
\NeedsTeXFormat{LaTeX2e}
\ProvidesPackage{icml2024}[2023/11/23 v2.0 ICML Conference Style File]
% Before 2018, \usepackage{times} was in the example TeX, but inevitably
% not everybody did it.
\RequirePackage{times}
% Use fancyhdr package
\RequirePackage{fancyhdr}
\RequirePackage{xcolor} % changed from color to xcolor (2021/11/24)
\RequirePackage{algorithm}
\RequirePackage{algorithmic}
\RequirePackage{natbib}
\RequirePackage{eso-pic} % used by \AddToShipoutPicture
\RequirePackage{forloop}
\RequirePackage{url}
%%%%%%%% Options
% [accepted] switches the first-page notice to the ``appearing in'' text
% and sets \isaccepted, which other code tests for anonymization.
\DeclareOption{accepted}{%
\renewcommand{\Notice@String}{\ICML@appearing}
\gdef\isaccepted{1}
}
% [nohyperref] suppresses the hypersetup configuration below.
\DeclareOption{nohyperref}{%
\gdef\nohyperref{1}
}
%%%%%%%%%%%%%%%%%%%%
% This string is printed at the bottom of the page for the
% final/accepted version of the ``appearing in'' note. Modify it to
% change that text.
%%%%%%%%%%%%%%%%%%%%
\newcommand{\ICML@appearing}{\textit{Proceedings of the
$\mathit{41}^{st}$ International Conference on Machine Learning},
Vienna, Austria. PMLR 235, 2024.
Copyright 2024 by the author(s).}
%%%%%%%%%%%%%%%%%%%%
% This string is printed at the bottom of the page for the draft/under
% review version of the ``appearing in'' note. Modify it to change
% that text.
%%%%%%%%%%%%%%%%%%%%
\newcommand{\Notice@String}{Preliminary work. Under review by the
International Conference on Machine Learning (ICML)\@. Do not distribute.}
% Cause the declared options to actually be parsed and activated
\ProcessOptions\relax
% Anonymize the PDF author metadata unless the paper is accepted.
\ifdefined\isaccepted\else\ifdefined\hypersetup
\hypersetup{pdfauthor={Anonymous Authors}}
\fi
\fi
% Configure hyperref (if loaded and not disabled): PDF metadata and
% dark-blue link colors.
\ifdefined\nohyperref\else\ifdefined\hypersetup
\definecolor{mydarkblue}{rgb}{0,0.08,0.45}
\hypersetup{ %
pdftitle={},
pdfsubject={Proceedings of the International Conference on Machine Learning 2024},
pdfkeywords={},
pdfborder=0 0 0,
pdfpagemode=UseNone,
colorlinks=true,
linkcolor=mydarkblue,
citecolor=mydarkblue,
filecolor=mydarkblue,
urlcolor=mydarkblue,
}
\fi
\fi
% Uncomment the following for debugging. It will cause LaTeX to dump
% the version of the ``appearing in'' string that will actually appear
% in the document.
%\typeout{>> Notice string='\Notice@String'}
% Change citation commands to be more like old ICML styles
\newcommand{\yrcite}[1]{\citeyearpar{#1}}
\renewcommand{\cite}[1]{\citep{#1}}
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% to ensure the letter format is used. pdflatex does compile the
% page size into the pdf. This is done using \pdfpagewidth and
% \pdfpageheight. As Latex does not know this directives, we first
% check whether pdflatex or latex is used.
%
% Kristian Kersting 2005
%
% in order to account for the more recent use of pdfetex as the default
% compiler, I have changed the pdf verification.
%
% Ricardo Silva 2007
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% US letter paper; also stored in the PDF when running under pdfTeX.
\paperwidth=8.5in
\paperheight=11in
% old PDFLaTex verification, circa 2005
%
%\newif\ifpdf\ifx\pdfoutput\undefined
% \pdffalse % we are not running PDFLaTeX
%\else
% \pdfoutput=1 % we are running PDFLaTeX
% \pdftrue
%\fi
% Detect PDF output mode: \pdfoutput must be defined, not \relax, and
% nonzero for \ifpdf to become true.
\newif\ifpdf %adapted from ifpdf.sty
\ifx\pdfoutput\undefined
\else
\ifx\pdfoutput\relax
\else
\ifcase\pdfoutput
\else
\pdftrue
\fi
\fi
\fi
% Under pdfTeX, compile the physical page size into the PDF.
\ifpdf
% \pdfpagewidth=\paperwidth
% \pdfpageheight=\paperheight
\setlength{\pdfpagewidth}{8.5in}
\setlength{\pdfpageheight}{11in}
\fi
% Physical page layout
\evensidemargin -0.23in
\oddsidemargin -0.23in
\setlength\textheight{9.0in}
\setlength\textwidth{6.75in}
\setlength\columnsep{0.25in}
\setlength\headheight{10pt}
\setlength\headsep{10pt}
\addtolength{\topmargin}{-20pt}
\addtolength{\topmargin}{-0.29in}
% Historically many authors tried to include packages like geometry or fullpage,
% which change the page layout. It either makes the proceedings inconsistent, or
% wastes organizers' time chasing authors. So let's nip these problems in the
% bud here. -- Iain Murray 2018.
%\RequirePackage{printlen}
% At \begin{document}, compare every layout length against the exact value
% set above; any deviation prints a warning paragraph into the document.
\AtBeginDocument{%
% To get the numbers below, include printlen package above and see lengths like this:
%\printlength\oddsidemargin\\
%\printlength\headheight\\
%\printlength\textheight\\
%\printlength\marginparsep\\
%\printlength\footskip\\
%\printlength\hoffset\\
%\printlength\paperwidth\\
%\printlength\topmargin\\
%\printlength\headsep\\
%\printlength\textwidth\\
%\printlength\marginparwidth\\
%\printlength\marginparpush\\
%\printlength\voffset\\
%\printlength\paperheight\\
%
\newif\ifmarginsmessedwith
\marginsmessedwithfalse
\ifdim\oddsidemargin=-16.62178pt \else oddsidemargin has been altered.\\ \marginsmessedwithtrue\fi
\ifdim\headheight=10.0pt \else headheight has been altered.\\ \marginsmessedwithtrue\fi
\ifdim\textheight=650.43pt \else textheight has been altered.\\ \marginsmessedwithtrue\fi
\ifdim\marginparsep=11.0pt \else marginparsep has been altered.\\ \marginsmessedwithtrue\fi
\ifdim\footskip=25.0pt \else footskip has been altered.\\ \marginsmessedwithtrue\fi
\ifdim\hoffset=0.0pt \else hoffset has been altered.\\ \marginsmessedwithtrue\fi
\ifdim\paperwidth=614.295pt \else paperwidth has been altered.\\ \marginsmessedwithtrue\fi
\ifdim\topmargin=-24.95781pt \else topmargin has been altered.\\ \marginsmessedwithtrue\fi
\ifdim\headsep=10.0pt \else headsep has been altered.\\ \marginsmessedwithtrue\fi
\ifdim\textwidth=487.8225pt \else textwidth has been altered.\\ \marginsmessedwithtrue\fi
%\ifdim\marginparwidth=65.0pt \else marginparwidth has been altered.\\ \marginsmessedwithtrue\fi
\ifdim\marginparpush=5.0pt \else marginparpush has been altered.\\ \marginsmessedwithtrue\fi
\ifdim\voffset=0.0pt \else voffset has been altered.\\ \marginsmessedwithtrue\fi
\ifdim\paperheight=794.96999pt \else paperheight has been altered.\\ \marginsmessedwithtrue\fi
\ifmarginsmessedwith
\textbf{\large \em The page layout violates the ICML style.}
Please do not change the page layout, or include packages like geometry,
savetrees, or fullpage, which change it for you.
We're not able to reliably undo arbitrary changes to the style. Please remove
the offending package(s), or layout-changing commands and try again.
\fi}
%% The following is adapted from code in the acmconf.sty conference
%% style file. The constants in it are somewhat magical, and appear
%% to work well with the two-column format on US letter paper that
%% ICML uses, but will break if you change that layout, or if you use
%% a longer block of text for the copyright notice string. Fiddle with
%% them if necessary to get the block to fit/look right.
%%
%% -- Terran Lane, 2003
%%
%% The following comments are included verbatim from acmconf.sty:
%%
%%% This section (written by KBT) handles the 1" box in the lower left
%%% corner of the left column of the first page by creating a picture,
%%% and inserting the predefined string at the bottom (with a negative
%%% displacement to offset the space allocated for a non-existent
%%% caption).
%%%
% Float type id for the notice box (must be distinct from figure/table).
\def\ftype@copyrightbox{8}
\def\@copyrightspace{
% Create a float object positioned at the bottom of the column. Note
% that because of the mystical nature of floats, this has to be called
% before the first column is populated with text (e.g., from the title
% or abstract blocks). Otherwise, the text will force the float to
% the next column. -- TDRL.
\@float{copyrightbox}[b]
\begin{center}
\setlength{\unitlength}{1pc}
\begin{picture}(20,1.5)
% Create a line separating the main text from the note block.
% 4.818pc==0.8in.
\put(0,2.5){\line(1,0){4.818}}
% Insert the text string itself. Note that the string has to be
% enclosed in a parbox -- the \put call needs a box object to
% position. Without the parbox, the text gets splattered across the
% bottom of the page semi-randomly. The 19.75pc distance seems to be
% the width of the column, though I can't find an appropriate distance
% variable to substitute here. -- TDRL.
\put(0,0){\parbox[b]{19.75pc}{\small \Notice@String}}
\end{picture}
\end{center}
\end@float}
% Note: A few Latex versions need the next line instead of the former.
% \addtolength{\topmargin}{0.3in}
% \setlength\footheight{0pt}
\setlength\footskip{25.0pt}
%\pagestyle{empty}
% Two-column layout with columns stretched to equal height; \sloppy
% relaxes line breaking to avoid overfull boxes in narrow columns.
\flushbottom \twocolumn
\sloppy
% Clear out the addcontentsline command
% (proceedings papers produce no table of contents)
\def\addcontentsline#1#2#3{}
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%% commands for formatting paper title, author names, and addresses.
%%start%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%%%% title as running head -- Kristian Kersting 2005 %%%%%%%%%%%%%
%\makeatletter
%\newtoks\mytoksa
%\newtoks\mytoksb
%\newcommand\addtomylist[2]{%
% \mytoksa\expandafter{#1}%
% \mytoksb{#2}%
% \edef#1{\the\mytoksa\the\mytoksb}%
%}
%\makeatother
% box to check the size of the running head
\newbox\titrun
% general page style
% fancyhdr with empty header fields and only the centered page number in
% the footer; \icmltitle later fills the running head via \chead.
\pagestyle{fancy}
\fancyhf{}
\fancyhead{}
\fancyfoot{}
\cfoot{\thepage}
% set the width of the head rule to 1 point
\renewcommand{\headrulewidth}{1pt}
% definition to set the head as running head in the preamble
\def\icmltitlerunning#1{\gdef\@icmltitlerunning{#1}}
% main definition adapting \icmltitle from 2004
% #1 = full paper title.  Side effects: sets the running head (from
% \icmltitlerunning if given, else from #1), registers the title as
% pdftitle with hyperref, and typesets the title block between rules.
\long\def\icmltitle#1{%
%check whether @icmltitlerunning exists
% if not \icmltitle is used as running head
\ifx\undefined\@icmltitlerunning%
\gdef\@icmltitlerunning{#1}
\fi
%add it to pdf information
% (skipped when \nohyperref is defined or hyperref is not loaded)
\ifdefined\nohyperref\else\ifdefined\hypersetup
\hypersetup{pdftitle={#1}}
\fi\fi
%get the dimension of the running title
% (typeset it in the running-head font into \titrun to measure it)
\global\setbox\titrun=\vbox{\small\bf\@icmltitlerunning}
% error flag
\gdef\@runningtitleerror{0}
% running title too long
\ifdim\wd\titrun>\textwidth%
{\gdef\@runningtitleerror{1}}%
% running title breaks a line
% (6.25pt is the maximum height of a single \small\bf line here)
\else\ifdim\ht\titrun>6.25pt
{\gdef\@runningtitleerror{2}}%
\fi
\fi
% if there is something wrong with the running title
\ifnum\@runningtitleerror>0
\typeout{}%
\typeout{}%
\typeout{*******************************************************}%
\typeout{Title exceeds size limitations for running head.}%
\typeout{Please supply a shorter form for the running head}
\typeout{with \string\icmltitlerunning{...}\space prior to \string\begin{document}}%
\typeout{*******************************************************}%
\typeout{}%
\typeout{}%
% set default running title
\chead{\small\bf Title Suppressed Due to Excessive Size}%
\else
% 'everything' fine, set provided running title
\chead{\small\bf\@icmltitlerunning}%
\fi
% no running title on the first page of the paper
\thispagestyle{plain}
%%%%%%%%%%%%%%%%%%%% Kristian Kersting %%%%%%%%%%%%%%%%%%%%%%%%%
%end%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
{\center\baselineskip 18pt
\toptitlebar{\Large\bf #1}\bottomtitlebar}
}
% Accumulator for the comma-separated list of all author names,
% used to populate the PDF `pdfauthor' metadata field.
\gdef\icmlfullauthorlist{}
% Append raw text to the accumulator (no separator handling).
\newcommand\addstringtofullauthorlist{\g@addto@macro\icmlfullauthorlist}
% Append one author name, prefixing ", " for every author after the
% first (\icmlanyauthors flags that at least one name is present).
\newcommand\addtofullauthorlist[1]{%
\ifdefined\icmlanyauthors%
\addstringtofullauthorlist{, #1}%
\else%
\addstringtofullauthorlist{#1}%
\gdef\icmlanyauthors{1}%
\fi%
% \ifdefined\nohyperref\else
\ifdefined\hypersetup%
\hypersetup{pdfauthor=\icmlfullauthorlist}%
\fi%\fi
}
% Horizontal rules with surrounding space placed above/below the title.
\def\toptitlebar{\hrule height1pt \vskip .25in}
\def\bottomtitlebar{\vskip .22in \hrule height1pt \vskip .3in}
% Environment wrapping the block of \icmlauthor entries: a centered
% block with list spacing suppressed.
\newenvironment{icmlauthorlist}{%
\setlength\topsep{0pt}
\setlength\parskip{0pt}
\begin{center}
}{%
\end{center}
}
% Running count of distinct affiliation labels seen so far.
\newcounter{@affiliationcounter}
% \@pa{label}: typeset the superscript marker for one affiliation
% label.  First occurrence of a numeric label allocates the next
% affiliation number (stored in counter @affil<label>); labels that
% were bound to an explicit symbol via \icmlsetsymbol print that
% symbol instead of a number.
\newcommand{\@pa}[1]{%
% ``#1''
\ifcsname the@affil#1\endcsname
% do nothing -- label already has a number
\else
\ifcsname @icmlsymbol#1\endcsname
% nothing -- label is symbol-based, no number needed
\else
\stepcounter{@affiliationcounter}%
\newcounter{@affil#1}%
\setcounter{@affil#1}{\value{@affiliationcounter}}%
\fi
\fi%
\ifcsname @icmlsymbol#1\endcsname
\textsuperscript{\csname @icmlsymbol#1\endcsname\,}%
\else
%\expandafter\footnotemark[\arabic{@affil#1}\,]%
\textsuperscript{\arabic{@affil#1}\,}%
\fi
}
%\newcommand{\icmlauthor}[2]{%
%\addtofullauthorlist{#1}%
%#1\@for\theaffil:=#2\do{\pa{\theaffil}}%
%}
% \icmlauthor{name}{label1,label2,...}: typeset one author with the
% superscript markers of all their affiliation labels.  In anonymous
% (non-accepted) mode, the first call prints "Anonymous Authors" and
% every later call is a no-op (guarded by \@icmlfirsttime).
\newcommand{\icmlauthor}[2]{%
\ifdefined\isaccepted
\mbox{\bf #1}\,\@for\theaffil:=#2\do{\@pa{\theaffil}} \addtofullauthorlist{#1}%
\else
\ifdefined\@icmlfirsttime
\else
\gdef\@icmlfirsttime{1}
\mbox{\bf Anonymous Authors}\@pa{@anon} \addtofullauthorlist{Anonymous Authors}
\fi
\fi
}
% \icmlsetsymbol{label}{symbol}: bind an affiliation label to an
% explicit marker symbol (e.g. * for equal contribution) instead of
% the automatically assigned number.  Must precede the \icmlauthor
% calls that use the label.
\newcommand{\icmlsetsymbol}[2]{%
\expandafter\gdef\csname @icmlsymbol#1\endcsname{#2}
}
% \icmlaffiliation{label}{institution text}: attach the institution
% text to an affiliation label previously introduced by \icmlauthor.
% Errors loudly (inline text + log) if the label was never used.
% In anonymous mode every call maps affiliation 1 to a placeholder.
\newcommand{\icmlaffiliation}[2]{%
\ifdefined\isaccepted
\ifcsname the@affil#1\endcsname
\expandafter\gdef\csname @affilname\csname the@affil#1\endcsname\endcsname{#2}%
\else
{\bf AUTHORERR: Error in use of \textbackslash{}icmlaffiliation command. Label ``#1'' not mentioned in some \textbackslash{}icmlauthor\{author name\}\{labels here\} command beforehand. }
\typeout{}%
\typeout{}%
\typeout{*******************************************************}%
\typeout{Affiliation label undefined. }%
\typeout{Make sure \string\icmlaffiliation\space follows }
\typeout{all of \string\icmlauthor\space commands}%
\typeout{*******************************************************}%
\typeout{}%
\typeout{}%
\fi
\else % \isaccepted
% can be called multiple times... it's idempotent
\expandafter\gdef\csname @affilname1\endcsname{Anonymous Institution, Anonymous City, Anonymous Region, Anonymous Country}
\fi
}
% \icmlcorrespondingauthor{name}{email}: accumulate the corresponding
% author line "name <email>" (comma-separated on repeat calls).
% Anonymous mode always overwrites with a placeholder.
\newcommand{\icmlcorrespondingauthor}[2]{
\ifdefined\isaccepted
\ifdefined\icmlcorrespondingauthor@text
\g@addto@macro\icmlcorrespondingauthor@text{, #1 \textless{}#2\textgreater{}}
\else
\gdef\icmlcorrespondingauthor@text{#1 \textless{}#2\textgreater{}}
\fi
\else
\gdef\icmlcorrespondingauthor@text{Anonymous Author \textless{}anon.email@domain.com\textgreater{}}
\fi
}
% Convenience text for the standard "*Equal contribution" note.
\newcommand{\icmlEqualContribution}{\textsuperscript{*}Equal contribution }
% Loop counter for iterating over affiliation numbers below.
\newcounter{@affilnum}
% \printAffiliationsAndNotice{extra}: emit an unmarked footnote that
% lists every numbered affiliation, the corresponding-author line, and
% \Notice@String.  #1 (e.g. \icmlEqualContribution) is printed first,
% but only in accepted mode.  Relies on \forloop from the forloop
% package (loaded elsewhere in this style -- not visible here).
\newcommand{\printAffiliationsAndNotice}[1]{%
% counter is stepped so the loop's < bound covers the last affiliation
\stepcounter{@affiliationcounter}%
{\let\thefootnote\relax\footnotetext{\hspace*{-\footnotesep}\ifdefined\isaccepted #1\fi%
\forloop{@affilnum}{1}{\value{@affilnum} < \value{@affiliationcounter}}{
\textsuperscript{\arabic{@affilnum}}\ifcsname @affilname\the@affilnum\endcsname%
\csname @affilname\the@affilnum\endcsname%
\else
{\bf AUTHORERR: Missing \textbackslash{}icmlaffiliation.}
\fi
}.
\ifdefined\icmlcorrespondingauthor@text
Correspondence to: \icmlcorrespondingauthor@text.
\else
{\bf AUTHORERR: Missing \textbackslash{}icmlcorrespondingauthor.}
\fi
\ \\
\Notice@String
}
}
}
%\makeatother
% Deprecated: \icmladdress was removed from the author API; print a
% bold inline error pointing users at the replacement commands.
% (Fixed typo in the message: \icmlauther -> \icmlauthor.)
\long\def\icmladdress#1{%
{\bf The \textbackslash{}icmladdress command is no longer used. See the example\_paper PDF .tex for usage of \textbackslash{}icmlauthor and \textbackslash{}icmlaffiliation.}
}
%% keywords as first class citizens
% \icmlkeywords{kw1, kw2, ...}: record the keywords in the PDF
% `pdfkeywords' metadata (when hyperref is loaded and \nohyperref is
% not defined).  The commented-out branches used to also print the
% keywords in the anonymous version.
\def\icmlkeywords#1{%
% \ifdefined\isaccepted \else
% \par {\bf Keywords:} #1%
% \fi
% \ifdefined\nohyperref\else\ifdefined\hypersetup
% \hypersetup{pdfkeywords={#1}}
% \fi\fi
% \ifdefined\isaccepted \else
% \par {\bf Keywords:} #1%
% \fi
\ifdefined\nohyperref\else\ifdefined\hypersetup
\hypersetup{pdfkeywords={#1}}
\fi\fi
}
% modification to natbib citations
% author-year style: (Author, 2024); multiple citations separated by ;
\setcitestyle{authoryear,round,citesep={;},aysep={,},yysep={;}}
% Redefinition of the abstract environment.
% Centers a bold "Abstract" heading and sets the body as a quote block
% with tightened vertical space.
\renewenvironment{abstract}
{%
% Insert the ``appearing in'' copyright notice.
%\@copyrightspace
\centerline{\large\bf Abstract}
\vspace{-0.12in}\begin{quote}}
{\par\end{quote}\vskip 0.12in}
% numbered section headings with different treatment of numbers
% Kernel override: \@startsection is changed so a negative beforeskip
% (#4) no longer suppresses indentation of the section's first
% paragraph, and unstarred headings go through \@sict instead of the
% kernel's \@sect.
\def\@startsection#1#2#3#4#5#6{\if@noskipsec \leavevmode \fi
\par \@tempskipa #4\relax
\@afterindenttrue
% Altered the following line to indent a section's first paragraph.
% \ifdim \@tempskipa <\z@ \@tempskipa -\@tempskipa \@afterindentfalse\fi
\ifdim \@tempskipa <\z@ \@tempskipa -\@tempskipa \fi
\if@nobreak \everypar{}\else
\addpenalty{\@secpenalty}\addvspace{\@tempskipa}\fi \@ifstar
{\@ssect{#3}{#4}{#5}{#6}}{\@dblarg{\@sict{#1}{#2}{#3}{#4}{#5}{#6}}}}
% Variant of the kernel's \@sect that typesets the section number
% followed by a period (e.g. "1.~Title") instead of trailing space.
\def\@sict#1#2#3#4#5#6[#7]#8{\ifnum #2>\c@secnumdepth
\def\@svsec{}\else
\refstepcounter{#1}\edef\@svsec{\csname the#1\endcsname}\fi
\@tempskipa #5\relax
\ifdim \@tempskipa>\z@
\begingroup #6\relax
\@hangfrom{\hskip #3\relax\@svsec.~}{\interlinepenalty \@M #8\par}
\endgroup
\csname #1mark\endcsname{#7}\addcontentsline
{toc}{#1}{\ifnum #2>\c@secnumdepth \else
\protect\numberline{\csname the#1\endcsname}\fi
#7}\else
\def\@svsechd{#6\hskip #3\@svsec #8\csname #1mark\endcsname
{#7}\addcontentsline
{toc}{#1}{\ifnum #2>\c@secnumdepth \else
\protect\numberline{\csname the#1\endcsname}\fi
#7}}\fi
\@xsect{#5}}
% Near-kernel \@sect kept as well (number followed by 0.4em space);
% retained for any code that still dispatches through \@sect.
\def\@sect#1#2#3#4#5#6[#7]#8{\ifnum #2>\c@secnumdepth
\def\@svsec{}\else
\refstepcounter{#1}\edef\@svsec{\csname the#1\endcsname\hskip 0.4em }\fi
\@tempskipa #5\relax
\ifdim \@tempskipa>\z@
\begingroup #6\relax
\@hangfrom{\hskip #3\relax\@svsec}{\interlinepenalty \@M #8\par}
\endgroup
\csname #1mark\endcsname{#7}\addcontentsline
{toc}{#1}{\ifnum #2>\c@secnumdepth \else
\protect\numberline{\csname the#1\endcsname}\fi
#7}\else
\def\@svsechd{#6\hskip #3\@svsec #8\csname #1mark\endcsname
{#7}\addcontentsline
{toc}{#1}{\ifnum #2>\c@secnumdepth \else
\protect\numberline{\csname the#1\endcsname}\fi
#7}}\fi
\@xsect{#5}}
% section headings with less space above and below them
% Arabic numbering: sections "1", subsections "1.1".
\def\thesection {\arabic{section}}
\def\thesubsection {\thesection.\arabic{subsection}}
% \@startsection args: {name}{level}{indent}{beforeskip}{afterskip}{style}
% (negative beforeskip values here are made positive by the modified
% \@startsection above without disabling first-paragraph indentation).
\def\section{\@startsection{section}{1}{\z@}{-0.12in}{0.02in}
{\large\bf\raggedright}}
\def\subsection{\@startsection{subsection}{2}{\z@}{-0.10in}{0.01in}
{\normalsize\bf\raggedright}}
\def\subsubsection{\@startsection{subsubsection}{3}{\z@}{-0.08in}{0.01in}
{\normalsize\sc\raggedright}}
% Run-in headings (negative afterskip -1em keeps text on the same line).
\def\paragraph{\@startsection{paragraph}{4}{\z@}{1.5ex plus
0.5ex minus .2ex}{-1em}{\normalsize\bf}}
\def\subparagraph{\@startsection{subparagraph}{5}{\z@}{1.5ex plus
0.5ex minus .2ex}{-1em}{\normalsize\bf}}
% Footnotes
% \footnotesep: strut height between footnotes; \skip\footins: gap
% between body text and the footnote block.
\footnotesep 6.65pt %
\skip\footins 9pt
% Short 0.8in separator rule above the footnote block.
\def\footnoterule{\kern-3pt \hrule width 0.8in \kern 2.6pt }
\setcounter{footnote}{0}
% Lists and paragraphs
% No paragraph indentation; vertical \parskip separates paragraphs
% instead.  The plus/minus components are stretch/shrink glue.
\parindent 0pt
\topsep 4pt plus 1pt minus 2pt
\partopsep 1pt plus 0.5pt minus 0.5pt
\itemsep 2pt plus 1pt minus 0.5pt
\parsep 2pt plus 1pt minus 0.5pt
\parskip 6pt
% Left margins per list nesting level (i through vi).
\leftmargin 2em \leftmargini\leftmargin \leftmarginii 2em
\leftmarginiii 1.5em \leftmarginiv 1.0em \leftmarginv .5em
\leftmarginvi .5em
\labelwidth\leftmargini\advance\labelwidth-\labelsep \labelsep 5pt
% Per-level list setup hooks: deeper levels get progressively tighter
% vertical spacing.
\def\@listi{\leftmargin\leftmargini}
\def\@listii{\leftmargin\leftmarginii
\labelwidth\leftmarginii\advance\labelwidth-\labelsep
\topsep 2pt plus 1pt minus 0.5pt
\parsep 1pt plus 0.5pt minus 0.5pt
\itemsep \parsep}
\def\@listiii{\leftmargin\leftmarginiii
\labelwidth\leftmarginiii\advance\labelwidth-\labelsep
\topsep 1pt plus 0.5pt minus 0.5pt
\parsep \z@ \partopsep 0.5pt plus 0pt minus 0.5pt
\itemsep \topsep}
\def\@listiv{\leftmargin\leftmarginiv
\labelwidth\leftmarginiv\advance\labelwidth-\labelsep}
\def\@listv{\leftmargin\leftmarginv
\labelwidth\leftmarginv\advance\labelwidth-\labelsep}
\def\@listvi{\leftmargin\leftmarginvi
\labelwidth\leftmarginvi\advance\labelwidth-\labelsep}
% Vertical glue around display math.
\abovedisplayskip 7pt plus2pt minus5pt%
\belowdisplayskip \abovedisplayskip
\abovedisplayshortskip 0pt plus3pt%
\belowdisplayshortskip 4pt plus3pt minus3pt%
% Less leading in most fonts (due to the narrow columns)
% The choices were between 1-pt and 1.5-pt leading
% \@setsize{cmd}{baselineskip}{size macro}{size hook} -- LaTeX 2.09
% style interface; arguments are (font command, leading, NFSS size).
\def\@normalsize{\@setsize\normalsize{11pt}\xpt\@xpt}
\def\small{\@setsize\small{10pt}\ixpt\@ixpt}
\def\footnotesize{\@setsize\footnotesize{10pt}\ixpt\@ixpt}
\def\scriptsize{\@setsize\scriptsize{8pt}\viipt\@viipt}
\def\tiny{\@setsize\tiny{7pt}\vipt\@vipt}
\def\large{\@setsize\large{14pt}\xiipt\@xiipt}
\def\Large{\@setsize\Large{16pt}\xivpt\@xivpt}
\def\LARGE{\@setsize\LARGE{20pt}\xviipt\@xviipt}
\def\huge{\@setsize\huge{23pt}\xxpt\@xxpt}
\def\Huge{\@setsize\Huge{28pt}\xxvpt\@xxvpt}
% Revised formatting for figure captions and table titles.
\newsavebox\newcaptionbox\newdimen\newcaptionboxwid
% \@makecaption{label}{text}: one-line captions are centered; longer
% ones are set as a paragraph with a slanted "Figure N." lead-in.
\long\def\@makecaption#1#2{
\vskip 10pt
\baselineskip 11pt
% measure the caption at natural width to pick single- vs multi-line
\setbox\@tempboxa\hbox{#1. #2}
\ifdim \wd\@tempboxa >\hsize
\sbox{\newcaptionbox}{\small\sl #1.~}
\newcaptionboxwid=\wd\newcaptionbox
\usebox\newcaptionbox {\footnotesize #2}
% \usebox\newcaptionbox {\small #2}
\else
\centerline{{\small\sl #1.} {\small #2}}
\fi}
% Caption label text.
\def\fnum@figure{Figure \thefigure}
\def\fnum@table{Table \thetable}
% Strut macros for skipping spaces above and below text in tables.
% Invisible zero-width rules that force row height (#1 above) or
% depth (#1 below the baseline).
\def\abovestrut#1{\rule[0in]{0in}{#1}\ignorespaces}
\def\belowstrut#1{\rule[-#1]{0in}{#1}\ignorespaces}
% Preset struts for typical table-row padding.
\def\abovespace{\abovestrut{0.20in}}
\def\aroundspace{\abovestrut{0.20in}\belowstrut{0.10in}}
\def\belowspace{\belowstrut{0.10in}}
% Various personal itemization commands.
% Hanging-indent item with a right-aligned marker in a 12pt box.
\def\texitem#1{\par\noindent\hangindent 12pt
\hbox to 12pt {\hss #1 ~}\ignorespaces}
\def\icmlitem{\texitem{$\bullet$}}
% To comment out multiple lines of text.
% (swallows its entire argument; works across paragraphs via \long)
\long\def\comment#1{}
%% Line counter (not in final version). Adapted from NIPS style file by Christoph Sawade
% Vertical Ruler
% This code is, largely, from the CVPR 2010 conference style file
% ----- define vruler
\makeatletter
% Registers used by the ruler machinery: the box holding the rendered
% ruler, the running line number, and scratch dimensions/counts.
\newbox\icmlrulerbox
\newcount\icmlrulercount
\newdimen\icmlruleroffset
\newdimen\cv@lineheight
\newdimen\cv@boxheight
\newbox\cv@tmpbox
\newcount\cv@refno
\newcount\cv@tot
% NUMBER with left flushed zeros \fillzeros[<WIDTH>]<NUMBER>
% Prints <NUMBER> zero-padded to <WIDTH> digits (sign handled first).
\newcount\cv@tmpc@ \newcount\cv@tmpc
\def\fillzeros[#1]#2{\cv@tmpc@=#2\relax\ifnum\cv@tmpc@<0\cv@tmpc@=-\cv@tmpc@\fi
\cv@tmpc=1 %
% count the number of digits in |#2| into \cv@tmpc
\loop\ifnum\cv@tmpc@<10 \else \divide\cv@tmpc@ by 10 \advance\cv@tmpc by 1 \fi
\ifnum\cv@tmpc@=10\relax\cv@tmpc@=11\relax\fi \ifnum\cv@tmpc@>10 \repeat
\ifnum#2<0\advance\cv@tmpc1\relax-\fi
% emit leading zeros until <WIDTH> is reached, then the number itself
\loop\ifnum\cv@tmpc<#1\relax0\advance\cv@tmpc1\relax\fi \ifnum\cv@tmpc<#1 \repeat
\cv@tmpc@=#2\relax\ifnum\cv@tmpc@<0\cv@tmpc@=-\cv@tmpc@\fi \relax\the\cv@tmpc@}%
% \makevruler[<SCALE>][<INITIAL_COUNT>][<STEP>][<DIGITS>][<HEIGHT>]
% Builds a vertical column of line numbers into \icmlrulerbox:
% one number every <SCALE> of height, starting at <INITIAL_COUNT>,
% incrementing \icmlrulercount by <STEP>, padded to <DIGITS> digits,
% filling a box of total height <HEIGHT>.
\def\makevruler[#1][#2][#3][#4][#5]{
\begingroup\offinterlineskip
\textheight=#5\vbadness=10000\vfuzz=120ex\overfullrule=0pt%
\global\setbox\icmlrulerbox=\vbox to \textheight{%
{
\parskip=0pt\hfuzz=150em\cv@boxheight=\textheight
\cv@lineheight=#1\global\icmlrulercount=#2%
% total iterations = boxheight/lineheight + 2 (to overfill the box)
\cv@tot\cv@boxheight\divide\cv@tot\cv@lineheight\advance\cv@tot2%
\cv@refno1\vskip-\cv@lineheight\vskip1ex%
\loop\setbox\cv@tmpbox=\hbox to0cm{ % side margin
\hfil {\hfil\fillzeros[#4]\icmlrulercount}
}%
% force each number box to exactly one line height, no depth
\ht\cv@tmpbox\cv@lineheight\dp\cv@tmpbox0pt\box\cv@tmpbox\break
\advance\cv@refno1\global\advance\icmlrulercount#3\relax
\ifnum\cv@refno<\cv@tot\repeat
}
}
\endgroup
}%
\makeatother
% ----- end of vruler
% \makevruler[<SCALE>][<INITIAL_COUNT>][<STEP>][<DIGITS>][<HEIGHT>]
% \icmlruler{start}: build and place a text-height ruler with 12pt
% line spacing and 3-digit numbers, starting at <start>.
\def\icmlruler#1{\makevruler[12pt][#1][1][3][\textheight]\usebox{\icmlrulerbox}}
% On every shipped page (requires the eso-pic package's
% \AddToShipoutPicture), draw a grey line-number ruler in the left
% margin -- but only in anonymous (non-accepted) mode.
\AddToShipoutPicture{%
\icmlruleroffset=\textheight
\advance\icmlruleroffset by 5.2pt % top margin
\color[rgb]{.7,.7,.7}
\ifdefined\isaccepted \else
\AtTextUpperLeft{%
\put(\LenToUnit{-35pt},\LenToUnit{-\icmlruleroffset}){%left ruler
\icmlruler{\icmlrulercount}}
% \put(\LenToUnit{1.04\textwidth},\LenToUnit{-\icmlruleroffset}){%right ruler
% \icmlruler{\icmlrulercount}}
}
\fi
}
% Stop reading this file here; anything after \endinput is ignored.
\endinput

Some files were not shown because too many files have changed in this diff Show More