AAAI Version

This commit is contained in:
Tobias Christian Nauen
2026-02-24 12:22:44 +01:00
parent 5c08f9d31a
commit ff34712155
378 changed files with 19844 additions and 4780 deletions

View File

@@ -0,0 +1,142 @@
import albumentations as A
import cv2
import numpy as np
from albumentations.pytorch import ToTensorV2
from datadings.torch import CompressedToPIL
class AlbumTorchCompose(A.Compose):
    """A.Compose variant that accepts datadings-compressed bytes and PIL images."""

    def __init__(self, *args, **kwargs):
        """Forward all arguments to A.Compose and set up the bytes->PIL decoder."""
        super().__init__(*args, **kwargs)
        self.to_pil = CompressedToPIL()

    def __call__(self, image, mask=None, **kwargs):
        """Decode/convert inputs to numpy and run the composed augmentations."""
        if isinstance(image, bytes):
            # datadings hands us compressed bytes; decode them to PIL first
            image = self.to_pil(image)
        if mask is not None and len(mask) == 0:
            # an empty mask container is treated as "no mask"
            mask = None
        # albumentations operates on numpy arrays (asarray is a no-op for ndarrays)
        image = np.asarray(image)
        if mask is None:
            return super().__call__(image=image, **kwargs)["image"]
        return super().__call__(image=image, mask=np.asarray(mask), **kwargs)
class PILToNP(A.DualTransform):
    """Convert PIL image to numpy array.

    Dual transform: the same conversion is applied to image and mask targets.
    """

    def apply(self, image, **params):
        # plain numpy conversion; no dtype or layout changes
        return np.array(image)

    def apply_to_mask(self, image, **params):
        # masks get the exact same conversion as images
        return np.array(image)

    def get_transform_init_args_names(self):
        # no constructor arguments to serialize
        return ()
class AlbumCompressedToPIL(A.DualTransform):
    """Convert compressed (bytes) image to PIL image.

    Bug fix: the original used ``self.to_pil`` without ever creating it, so
    every call to ``apply``/``apply_to_mask`` raised ``AttributeError``. The
    decoder is now instantiated in ``__init__`` (mirroring AlbumTorchCompose).
    """

    def __init__(self, *args, **kwargs):
        """Forward to A.DualTransform and create the bytes->PIL decoder."""
        super().__init__(*args, **kwargs)
        self.to_pil = CompressedToPIL()

    def apply(self, img, **params):
        return self.to_pil(img)

    def apply_to_mask(self, img, **params):
        return self.to_pil(img)

    def get_transform_init_args_names(self):
        # no serializable init args beyond the DualTransform defaults
        return ()
def minimal_augment(args, test=False):
    """Get minimal augmentations for training or testing.

    Args:
        args (argparse.Namespace): arguments
        test (bool, optional): if True, return test augmentations. Defaults to False.

    Returns:
        List: Augmentation list
    """
    augs = []
    if args.aug_resize:
        # resize the short side to imsize
        augs.append(A.SmallestMaxSize(args.imsize, interpolation=cv2.INTER_CUBIC))
    if args.aug_crop:
        # deterministic crop at test time, random crop during training
        crop_cls = A.CenterCrop if test else A.RandomCrop
        augs.append(crop_cls(args.imsize, args.imsize))
    if args.aug_flip and not test:
        augs.append(A.HorizontalFlip(p=0.5))
    if args.aug_normalize:
        augs.append(A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)))
    augs.append(ToTensorV2())
    return augs
def three_augment(args, as_list=False, test=False):
    """Create the 3-augment data augmentation (albumentations).

    Args:
        args (Namespace): arguments
        as_list (bool): return list of transformations, not composed transformation
        test (bool): In eval mode? If False => train mode

    Returns:
        torch.nn.Module | list[torch.nn.Module]: composed transformation or list of transformations
    """
    augs = []
    if args.aug_resize:
        augs.append(A.SmallestMaxSize(args.imsize, interpolation=cv2.INTER_CUBIC))
    if args.aug_crop:
        if test:
            augs.append(A.CenterCrop(args.imsize, args.imsize))
        else:
            augs.append(
                A.RandomCrop(args.imsize, args.imsize, pad_if_needed=True, border_mode=cv2.BORDER_REFLECT)
            )
    if not test:
        if args.aug_flip:
            augs.append(A.HorizontalFlip(p=0.5))
        # the "3-augment" choice: grayscale / solarize / gaussian blur
        choices = []
        if args.aug_grayscale:
            choices.append(A.ToGray(p=1, num_output_channels=3))
        if args.aug_solarize:
            choices.append(A.Solarize(p=1, threshold_range=(0.5, 0.5)))
        if args.aug_gauss_blur:
            choices.append(A.GaussianBlur(p=1, sigma_limit=(0.2, 2.0), blur_limit=(7, 7)))
        if choices:
            augs.append(A.OneOf(choices))
        if args.aug_color_jitter_factor > 0.0:
            jitter = args.aug_color_jitter_factor
            augs.append(A.ColorJitter(brightness=jitter, contrast=jitter, saturation=jitter, hue=0.0))
    if args.aug_normalize:
        augs.append(A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)))
    augs.append(ToTensorV2())
    return augs if as_list else AlbumTorchCompose(augs)

View File

@@ -0,0 +1,63 @@
import os
from loguru import logger
from PIL import Image
from torch.utils.data import Dataset
class CounterAnimal(Dataset):
    """Dataset to load the CounterAnimal dataset with ImageNet labels."""

    # accepted image file extensions (lower-cased)
    _EXTS = ("jpg", "jpeg", "png")

    def __init__(self, base_path, mode, transform=None, target_transform=None, train=False):
        """Create the dataset.

        Args:
            base_path (str): path to the base folder (the one where the class folders are in)
            mode (str): mode/variant of the dataset (common/counter)
            transform: Image augmentation
            target_transform: label augmentation
            train: train or test set. Train set is not supported

        Raises:
            AssertionError: on unknown mode, on train=True, or if no image is found.
        """
        super().__init__()
        self.base = base_path
        assert mode in ["counter", "common"], f"Supported modes are counter and common, but got '{mode}'"
        assert not train, "CounterAnimal only consists of test data, not training data."
        self.transform = transform
        self.target_transform = target_transform
        self.index = []
        # sort directory listings: os.listdir order is unspecified, which made
        # the sample indexing non-deterministic across machines/filesystems
        for class_folder in sorted(os.listdir(self.base)):
            if not os.path.isdir(os.path.join(self.base, class_folder)):
                continue
            # class folders are named "<imagenet idx> <class name>"
            class_idx = int(class_folder.split(" ")[0])
            for variant_folder in sorted(os.listdir(os.path.join(self.base, class_folder))):
                # keep only the requested variant (common*/counter*)
                if not variant_folder.startswith(mode):
                    continue
                _folder = os.path.join(self.base, class_folder, variant_folder)
                for file in sorted(os.listdir(_folder)):
                    if file.lower().split(".")[-1] in self._EXTS:
                        self.index.append((os.path.join(_folder, file), class_idx))
        assert len(self.index) > 0, "did not find any images :("

    def __len__(self):
        """Number of images in the selected variant."""
        return len(self.index)

    def __getitem__(self, idx):
        """Load image `idx` and return (image, label) after optional transforms."""
        path, label = self.index[idx]
        img = Image.open(path).convert("RGB")
        if self.transform:
            img = self.transform(img)
        if self.target_transform:
            label = self.target_transform(label)
        return img, label

View File

@@ -0,0 +1,145 @@
from nvidia.dali import fn, pipeline_def, types
# see https://docs.nvidia.com/deeplearning/dali/user-guide/docs/plugins/pytorch_dali_proxy.html
@pipeline_def
def minimal_augment(args, test=False):
    """Minimal augmentation pipeline for images (NVIDIA DALI).

    Contains only resize, crop, flip and normalize; flip and normalization are
    fused into the final crop_mirror_normalize op. (The previous standalone
    flip/normalize branches were dead commented-out code and were removed.)

    Args:
        args (DotDict): Arguments: aug_resize, aug_crop, aug_flip, aug_normalize to turn on/off the respective augmentation.
        test (bool, optional): On the test set? Defaults to False.

    Returns:
        images: augmented images.
    """
    images = fn.external_source(name="images", no_copy=True)
    if args.aug_resize:
        # resize the short side to imsize (DALI's SmallestMaxSize equivalent)
        images = fn.resize(images, size=args.imsize, mode="not_smaller")
    if test and args.aug_crop:
        # deterministic center crop at test time
        images = fn.crop(images, crop=(args.imsize, args.imsize), crop_pos_x=0.5, crop_pos_y=0.5)
    elif args.aug_crop:
        # random crop position during training
        images = fn.crop(
            images,
            crop=(args.imsize, args.imsize),
            crop_pos_x=fn.random.uniform(range=(0, 1)),
            crop_pos_y=fn.random.uniform(range=(0, 1)),
        )
    # fused crop + horizontal mirror + normalize + HWC->CHW; mean/std become
    # the identity (0/1) when normalization is disabled
    return fn.crop_mirror_normalize(
        images,
        dtype=types.FLOAT,
        output_layout="CHW",
        crop=(args.imsize, args.imsize),
        mean=[0.485 * 255, 0.456 * 255, 0.406 * 255] if args.aug_normalize else [0, 0, 0],
        std=[0.229 * 255, 0.224 * 255, 0.225 * 255] if args.aug_normalize else [1, 1, 1],
        mirror=fn.random.coin_flip(probability=0.5) if args.aug_flip and not test else False,
    )
def dali_solarize(images, threshold=128):
    """Solarize implementation for nvidia DALI.

    Pixels at or above ``threshold`` are inverted (255 - v); all others are
    kept unchanged.

    Args:
        images (DALI Tensor): Images to solarize.
        threshold (int, optional): Threshold for solarization. Defaults to 128.

    Returns:
        images: solarized images.
    """
    one = types.Constant(1).uint8()
    inverted = types.Constant(255).uint8() - images
    # binary mask: 1 where the pixel should be inverted, 0 elsewhere
    above = (images >= threshold) * one
    # blend inverted and original pixels according to the mask
    return above * inverted + (one ^ above) * images
@pipeline_def(enable_conditionals=True)
def three_augment(args, test=False):
    """3-augment data augmentation pipeline for nvidia DALI.

    Args:
        args (namespace): augmentation arguments.
        test (bool, optional): Test (or train) split. Defaults to False.

    Returns:
        images: augmented images.

    Fixes vs. the original:
        * no ZeroDivisionError when none of aug_grayscale / aug_solarize /
          aug_gauss_blur is enabled (the choice step is skipped instead);
        * removed the dead ``choices`` list that was never populated.
    """
    images = fn.external_source(name="images", no_copy=True)
    if args.aug_resize:
        # resize the short side to imsize
        images = fn.resize(images, size=args.imsize, mode="not_smaller")
    if test and args.aug_crop:
        # deterministic center crop at test time
        images = fn.crop(images, crop=(args.imsize, args.imsize), crop_pos_x=0.5, crop_pos_y=0.5)
    elif args.aug_crop:
        # random crop position during training
        images = fn.crop(
            images,
            crop=(args.imsize, args.imsize),
            crop_pos_x=fn.random.uniform(range=(0, 1)),
            crop_pos_y=fn.random.uniform(range=(0, 1)),
        )
    if not test:
        # pick one of grayscale / solarize / blur, uniformly over enabled ops
        choice_ps = [1 * args.aug_grayscale, 1 * args.aug_solarize, 1 * args.aug_gauss_blur]
        total = sum(choice_ps)
        if total > 0:  # guard: skip entirely when no op is enabled
            choice_ps = [c / total for c in choice_ps]
            choice = fn.random.choice([0, 1, 2], p=choice_ps)
            if choice == 0:
                # grayscale with 3 output channels: RGB -> GRAY -> RGB
                images = fn.color_space_conversion(
                    fn.color_space_conversion(images, image_type=types.RGB, output_type=types.GRAY),
                    image_type=types.GRAY,
                    output_type=types.RGB,
                )
            elif choice == 1:
                images = dali_solarize(images, threshold=128)
            elif choice == 2:
                images = fn.gaussian_blur(images, window_size=7, sigma=fn.random.uniform(range=(0.2, 2.0)))
        if args.aug_color_jitter_factor > 0.0:
            jitter_range = (1 - args.aug_color_jitter_factor, 1 + args.aug_color_jitter_factor)
            images = fn.color_twist(
                images,
                brightness=fn.random.uniform(range=jitter_range),
                contrast=fn.random.uniform(range=jitter_range),
                saturation=fn.random.uniform(range=jitter_range),
            )
    # fused crop + mirror + normalize + HWC->CHW; identity mean/std when
    # normalization is disabled
    return fn.crop_mirror_normalize(
        images,
        dtype=types.FLOAT,
        output_layout="CHW",
        crop=(args.imsize, args.imsize),
        mean=[0.485 * 255, 0.456 * 255, 0.406 * 255] if args.aug_normalize else [0, 0, 0],
        std=[0.229 * 255, 0.224 * 255, 0.225 * 255] if args.aug_normalize else [1, 1, 1],
        mirror=fn.random.coin_flip(probability=0.5) if args.aug_flip and not test else False,
    )

View File

@@ -0,0 +1,407 @@
from random import uniform
import msgpack
import torch
import torchvision
from datadings.torch import CompressedToPIL
from datadings.torch import Dataset as DDDataset
from PIL import ImageFilter
from torchvision.transforms import (
CenterCrop,
ColorJitter,
Compose,
GaussianBlur,
Grayscale,
InterpolationMode,
Normalize,
RandomChoice,
RandomCrop,
RandomHorizontalFlip,
RandomResizedCrop,
RandomSolarize,
Resize,
ToTensor,
)
from torchvision.transforms import functional as F
# Geometric transforms whose random parameters must hit image and dense target
# identically (checked via isinstance in apply_dense_transforms).
# Fix: RandomRotation was listed twice; the duplicate entry is removed
# (membership checks are unaffected).
_image_and_target_transforms = [
    torchvision.transforms.RandomCrop,
    torchvision.transforms.RandomHorizontalFlip,
    torchvision.transforms.CenterCrop,
    torchvision.transforms.RandomRotation,
    torchvision.transforms.RandomAffine,
    torchvision.transforms.RandomResizedCrop,
]
def apply_dense_transforms(x, y, transforms: torchvision.transforms.transforms.Compose):
    """Apply a transform pipeline so geometric ops hit image and target identically.

    Random geometric transforms must reuse the same sampled parameters for the
    image and its dense target; photometric/other transforms touch only the image.

    Args:
        x (torch.Tensor): image
        y (torch.Tensor): target (dense label image, e.g. a segmentation map)
        transforms (torchvision.transforms.transforms.Compose): transformations to apply

    Returns:
        tuple[torch.Tensor, torch.Tensor]: (x, y) with applied transformations
    """
    for trans in transforms.transforms:
        if isinstance(trans, torchvision.transforms.RandomResizedCrop):
            # sample the crop once, then apply it to both image and target
            params = trans.get_params(x, trans.scale, trans.ratio)
            x = F.resized_crop(x, *params, trans.size, trans.interpolation, antialias=trans.antialias)
            y = F.resized_crop(y.unsqueeze(0), *params, trans.size, 0).squeeze(0)  # nearest neighbor interpolation
        elif isinstance(trans, Resize):
            pre_shape = x.shape
            x = trans(x)
            # only resize the target when the image shape actually changed
            if x.shape != pre_shape:
                y = F.resize(y.unsqueeze(0), trans.size, 0, trans.max_size, trans.antialias).squeeze(
                    0
                )  # nearest neighbor interpolation
        elif any(isinstance(trans, simul_transform) for simul_transform in _image_and_target_transforms):
            # stack the target as an extra channel so the random geometric
            # transform acts on image and target with the same parameters
            xy = torch.cat([x, y.unsqueeze(0).float()], dim=0)
            xy = trans(xy)
            x, y = xy[:-1], xy[-1].long()
        elif isinstance(trans, torchvision.transforms.ToTensor):
            # skip ToTensor when x is already a tensor (it would raise otherwise)
            if not isinstance(x, torch.Tensor):
                x = trans(x)
        else:
            # photometric / other transforms: image only
            x = trans(x)
    return x, y
def get_hf_transform(transform_f, trgt_transform_f=None, image_key="image"):
    """Convert the transform function to a huggingface compatible transform function.

    Fix: the original docstring documented a nonexistent parameter name
    (``trgt_transform``) and omitted the return value.

    Args:
        transform_f (callable): Image transform, applied to every image of a batch.
        trgt_transform_f (callable, optional): Target transform, applied to every label. Defaults to None.
        image_key (str, optional): Key for the image in the hf ds return dict. Defaults to "image".

    Returns:
        callable: function mapping a batched sample dict to the transformed dict (in place).
    """

    def _transform(samples):
        try:
            samples[image_key] = [transform_f(im) for im in samples[image_key]]
            if trgt_transform_f is not None:
                samples["label"] = [trgt_transform_f(tgt) for tgt in samples["label"]]
        except TypeError as e:
            # surface the offending batch to ease debugging of transform mismatches
            print(f"Type error when transforming samples: {samples}")
            raise e
        return samples

    return _transform
class DDDecodeDataset(DDDataset):
    """Datadings dataset that decodes compressed images before applying transforms."""

    def __init__(self, *args, transform=None, target_transform=None, transforms=None, **kwargs):
        """Create datadings dataset.

        Args:
            transform (callable, optional): Image transform. Overrides transforms['image']. Defaults to None.
            target_transform (callable, optional): Label transform. Overrides transforms['label']. Defaults to None.
            transforms (dict[str, callable], optional): Dict of transforms for each key. Defaults to None.
        """
        super().__init__(*args, **kwargs)
        transforms = transforms or {}
        # the explicit keyword arguments take precedence over the dict entries
        self._decode_transform = transforms.get("image") if transform is None else transform
        self._decode_target_transform = transforms.get("label") if target_transform is None else target_transform
        self.ctp = CompressedToPIL()

    def __getitem__(self, idx):
        """Return (image, label) with bytes decoded to PIL and transforms applied."""
        sample = super().__getitem__(idx)
        img = sample["image"]
        lbl = sample["label"]
        if isinstance(img, bytes):
            # decode datadings-compressed image bytes to a PIL image
            img = self.ctp(img)
        if self._decode_transform is not None:
            img = self._decode_transform(img)
        if self._decode_target_transform is not None:
            lbl = self._decode_target_transform(lbl)
        return img, lbl
def minimal_augment(args, test=False):
    """Minimal Augmentation set for images.

    Contains only resize, crop, flip, to tensor and normalize.

    Args:
        args (DotDict): Arguments: aug_resize, aug_crop, aug_flip, aug_normalize to turn on/off the respective augmentation.
        test (bool, optional): On the test set? Defaults to False.

    Returns:
        List: Augmentation list
    """
    augs = [ToTensor()]
    if args.aug_resize:
        augs.append(Resize(args.imsize, interpolation=InterpolationMode.BICUBIC))
    if args.aug_crop:
        # deterministic crop at test time, reflect-padded random crop otherwise
        if test:
            augs.append(CenterCrop(args.imsize))
        else:
            augs.append(RandomCrop(args.imsize, padding=4, padding_mode="reflect"))
    if args.aug_flip and not test:
        augs.append(RandomHorizontalFlip(p=0.5))
    if args.aug_normalize:
        augs.append(
            Normalize(mean=torch.tensor([0.485, 0.456, 0.406]), std=torch.tensor([0.229, 0.224, 0.225]))
        )
    return augs
def three_augment(args, as_list=False, test=False):
    """Create the 3-augment data augmentation (torchvision).

    Args:
        args (Namespace): arguments
        as_list (bool): return list of transformations, not composed transformation
        test (bool): In eval mode? If False => train mode

    Returns:
        torch.nn.Module | list[torch.nn.Module]: composed transformation or list of transformations
    """
    augs = [ToTensor()]
    if args.aug_resize:
        augs.append(Resize(args.imsize, interpolation=InterpolationMode.BICUBIC))
    if args.aug_crop:
        if test:
            augs.append(CenterCrop(args.imsize))
        else:
            augs.append(RandomCrop(args.imsize, padding=4, padding_mode="reflect"))
    if not test:
        if args.aug_flip:
            augs.append(RandomHorizontalFlip(p=0.5))
        # the "3-augment" choice: grayscale / solarize / gaussian blur
        choices = []
        if args.aug_grayscale:
            choices.append(Grayscale(num_output_channels=3))
        if args.aug_solarize:
            choices.append(RandomSolarize(threshold=0.5, p=1.0))
        if args.aug_gauss_blur:
            # TODO: check kernel size?
            choices.append(GaussianBlur(kernel_size=7, sigma=(0.2, 2.0)))
        if choices:
            augs.append(RandomChoice(choices))
        if args.aug_color_jitter_factor > 0.0:
            jitter = args.aug_color_jitter_factor
            augs.append(ColorJitter(jitter, jitter, jitter))
    if args.aug_normalize:
        augs.append(
            Normalize(mean=torch.tensor([0.485, 0.456, 0.406]), std=torch.tensor([0.229, 0.224, 0.225]))
        )
    return augs if as_list else Compose(augs)
def segment_augment(args, test=False):
    """Create the data augmentation for segmentation.

    No cropping in this part, as cropping has to be done for the image and
    labels simultaneously.

    Args:
        args (DotDict): arguments
        test (bool, optional): In eval mode? If False => train mode. Defaults to False.

    Returns:
        list[torch.nn.Module]: list of transformations
    """
    if test:
        # deterministic eval path: upscale only if needed, then center crop
        augs = [ResizeUp(args.imsize), CenterCrop(args.imsize)]
    else:
        augs = [RandomResizedCrop(args.imsize, scale=(0.08, 1.0), ratio=(3.0 / 4.0, 4.0 / 3.0))]
        if args.aug_flip:
            augs.append(RandomHorizontalFlip(p=0.5))
    # color jitter is applied in both modes (matches original behavior)
    if args.aug_color_jitter_factor > 0.0:
        jitter = args.aug_color_jitter_factor
        augs.append(ColorJitter(jitter, jitter, jitter))
    if args.aug_normalize:
        augs.append(
            Normalize(mean=torch.tensor([0.485, 0.456, 0.406]), std=torch.tensor([0.229, 0.224, 0.225]))
        )
    return augs
class QuickGaussBlur:
    """Gaussian blur transformation using PIL ImageFilter with random sigma."""

    def __init__(self, sigma=(0.2, 2.0)):
        """Create Gaussian blur operator.

        Args:
            sigma (tuple[float, float]): (min, max) range to sample the blur radius from.
        """
        self.sigma_min, self.sigma_max = sigma

    def __call__(self, img):
        """Return a blurred copy of the PIL image ``img``."""
        radius = uniform(self.sigma_min, self.sigma_max)
        return img.filter(ImageFilter.GaussianBlur(radius=radius))
class RemoveTransform:
    """Replace the data with a dummy value, keeping only the label.

    To use with default collate function.
    """

    def __call__(self, x, y=None):
        """Drop ``x``; return dummy ``[1]`` (paired with ``y`` when given)."""
        dummy = [1]
        return dummy if y is None else (dummy, y)

    def __repr__(self):
        return f"{self.__class__.__name__}()"
def collate_imnet(data, image_key="image"):
    """Collate function for imagenet(1k / 21k) with datadings.

    Args:
        data (list[dict[str, Any]]): samples for a batch
        image_key (str): key under which the image is stored in each sample dict

    Returns:
        tuple[torch.Tensor | list, torch.Tensor]: images, labels
    """
    images = [sample[image_key] for sample in data]
    # stack into one batch tensor when possible, otherwise keep the plain list
    if isinstance(images[0], torch.Tensor):
        images = torch.stack(images, dim=0)
    labels = torch.tensor([sample["label"] for sample in data])
    return images, labels
def collate_listops(data):
    """Collate function for ListOps with datadings.

    Only the first sample of the batch is used (batching happens upstream).

    Args:
        data (list[tuple[torch.Tensor, torch.Tensor]]): list of samples

    Returns:
        tuple[torch.Tensor, torch.Tensor]: tokens, labels of the first sample
    """
    first = data[0]
    return first[0], first[1]
def no_param_transf(self, sample):
    """Call transformation without extra parameter.

    To use with datadings QuasiShuffler. Bind it as a method:
    ``<obj>.<method_name> = MethodType(no_param_transf, <obj>)``.

    Args:
        self (object): object providing ``_rng`` and ``_transforms``
        sample (Any): sample to transform; may be a (name, data) tuple or
            msgpack-encoded bytes

    Returns:
        Any: transformed sample
    """
    if isinstance(sample, tuple):
        # sample of type (name (str), data (bytes encoded)) -> keep the data
        sample = sample[1]
    if isinstance(sample, bytes):
        # decode msgpack bytes
        sample = msgpack.loads(sample)
    params = self._rng(sample)
    for key, transf in self._transforms.items():
        sample[key] = transf(sample[key], params)
    return sample
class ToOneHotSequence:
    """Convert a sequence of grayscale values (range 0 to 1) to a one-hot encoded sequence based on 8-bit rounded values."""

    def __call__(self, x, y=None):
        """One-hot encode ``x``; pass ``y`` through unchanged when given."""
        # map [0, 1] floats to integer bins 0..255 and flatten to a sequence
        x = (255 * x).round().to(torch.int64).view(-1)
        assert x.max() < 256, f"Found max value {x.max()} in {x}."
        x = torch.nn.functional.one_hot(x, num_classes=256).float()
        return x if y is None else (x, y)

    def __repr__(self):
        return f"{self.__class__.__name__}()"
class ResizeUp(Resize):
    """Resize up if image is smaller than target size.

    Note:
        Assumes tensor-like input with spatial dims last (uses ``.shape``) and
        a scalar ``self.size`` — TODO confirm; a PIL image or sequence size
        would break the shape access / comparison here.
    """

    def forward(self, img):
        # NOTE(review): shape[-2] is height and shape[-1] is width for tensors,
        # so the names w/h look swapped — harmless, since only the "any side
        # smaller than size" check matters.
        w, h = img.shape[-2], img.shape[-1]
        if w < self.size or h < self.size:
            # delegate to torchvision Resize only when upscaling is needed
            img = super().forward(img)
        return img

View File

@@ -0,0 +1,484 @@
import json
import os
import zipfile
from io import BytesIO
from math import floor
import numpy as np
import PIL
from datadings.torch import Compose
from loguru import logger
from PIL import Image, ImageFilter
from torch.utils.data import Dataset, get_worker_info
from torchvision import transforms as T
from tqdm.auto import tqdm
from data.data_utils import apply_dense_transforms
class ForNet(Dataset):
"""Recombine ImageNet forgrounds and backgrounds.
Note:
This dataset has exactly the ImageNet classes.
"""
_back_combs = ["same", "all", "original"]
_bg_transforms = {T.RandomCrop, T.CenterCrop, T.Resize, T.RandomResizedCrop}
def __init__(
self,
root,
transform=None,
train=True,
target_transform=None,
background_combination="all",
fg_scale_jitter=0.3,
fg_transform=None,
pruning_ratio=0.8,
return_fg_masks=False,
fg_size_mode="range",
fg_bates_n=1,
paste_pre_transform=True,
mask_smoothing_sigma=4.0,
rel_jut_out=0.0,
fg_in_nonant=None,
size_fact=1.0,
orig_img_prob=0.0,
orig_ds=None,
_orig_ds_file_type="JPEG",
epochs=0,
_album_compose=False,
):
"""Create RecombinationNet dataset.
Args:
root (str): Root folder for the dataset.
transform (T.Collate | list, optional): Transform to apply to the image. Defaults to None.
train (bool, optional): On the train set (False -> val set). Defaults to True.
target_transform (T.Collate | list, optional): Transform to apply to the target values. Defaults to None.
background_combination (str, optional): Which backgrounds to combine with foregrounds. Defaults to "same".
fg_scale_jitter (tuple, optional): How much should the size of the foreground be changed (random ratio). Defaults to (0.1, 0.8).
fg_transform (_type_, optional): Transform to apply to the foreground before applying to the background. This is supposed to be a random rotation, mainly. Defaults to None.
pruning_ratio (float, optional): For pruning backgrounds, with (foreground size/background size) >= <pruning_ratio>. Backgrounds from images that contain very large foreground objects are mostly computer generated and therefore relatively unnatural. Defaults to full dataset.
return_fg_masks (bool, optional): Return the foreground masks. Defaults to False.
fg_size_mode (str, optional): How to determine the size of the foreground, based on the foreground sizes of the foreground and background images. Defaults to "max".
fg_bates_n (int, optional): Bates parameter for the distribution of the object position in the foreground. Defaults to 1 (uniform distribution). The higher the value, the more likely the object is in the center. For fg_bates_n = 0, the object is always in the center.
paste_pre_transform (bool, optional): Paste the foreground onto the background before applying the transform. If false, the background will be cropped and resized before pasting the foreground. Defaults to False.
mask_smoothing_sigma (float, optional): Sigma for the Gaussian blur of the mask edge. Defaults to 0.0. Try 2.0 or 4.0?
rel_jut_out (float, optional): How much is the foreground allowed to stand/jut out of the background (and then cut off). Defaults to 0.0.
fg_in_nonant (int, optional): If not None, the foreground will be placed in a specific nonant (0-8) of the image. Defaults to None.
size_fact (float, optional): Factor to multiply the size of the foreground with. Defaults to 1.0.
orig_img_prob (float | str, optional): Probability to use the original image, instead of the fg-bg recombinations. Defaults to 0.0.
orig_ds (Dataset, optional): Original dataset (without transforms) to use for the original images. Defaults to None.
_orig_ds_file_type (str, optional): File type of the original dataset. Defaults to "JPEG".
epochs (int, optional): Number of epochs to train on. Used for linear increase of orig_img_prob.
Note:
For more information on the bates distribution, see https://en.wikipedia.org/wiki/Bates_distribution.
For fg_bats_n < 0, we take extend the bates dirstribution to focus more and more on the edges. This is done by sampling B ~ Bates(|fg_bates_n|) and then passing through f(x) = x + 0.5 - floor(x + 0.5).
For the list of transformations that will be applied to the background only (if paste_pre_transform=False), see RecombinationNet._bg_transforms.
A nonant in this case refers to a square in a 3x3 grid dividing the image.
"""
assert (
background_combination in self._back_combs
), f"background_combination={background_combination} is not supported. Use one of {self._back_combs}"
if (not os.path.exists(f"{root}/backgrounds_{'train' if train else 'val'}.zip")) and os.path.exists(
os.path.join(root, "train" if train else "val", "backgrounds")
):
self._mode = "folder"
else:
self._mode = "zip"
if self._mode == "zip":
try:
with zipfile.ZipFile(f"{root}/backgrounds_{'train' if train else 'val'}.zip", "r") as bg_zip:
self.backgrounds = [f for f in bg_zip.namelist() if f.endswith(".JPEG")]
with zipfile.ZipFile(f"{root}/foregrounds_{'train' if train else 'val'}.zip", "r") as fg_zip:
self.foregrounds = [f for f in fg_zip.namelist() if f.endswith(".WEBP")]
except FileNotFoundError as e:
logger.error(
f"RecombinationNet: {e}. Make sure to have the background and foreground zips in the root"
f" directory: found {os.listdir(root)}"
)
raise e
classes = set([f.split("/")[-2] for f in self.foregrounds])
else:
logger.info("ForNet folder mode: loading classes")
classes = set(os.listdir(os.path.join(root, "train" if train else "val", "foregrounds")))
foregrounds = []
backgrounds = []
for cls in tqdm(classes, desc="Loading files"):
foregrounds.extend(
[
f"{cls}/{f}"
for f in os.listdir(os.path.join(root, "train" if train else "val", "foregrounds", cls))
]
)
backgrounds.extend(
[
f"{cls}/{f}"
for f in os.listdir(os.path.join(root, "train" if train else "val", "backgrounds", cls))
]
)
self.foregrounds = foregrounds
self.backgrounds = backgrounds
self.classes = sorted(list(classes), key=lambda x: int(x[1:]))
assert os.path.exists(f"{root}/fg_bg_ratios_{'train' if train else 'val'}.json"), (
f"{root}/fg_bg_ratios_{'train' if train else 'val'}.json not found, provide the information or set"
" pruning_ratio=1.0"
)
with open(f"{root}/fg_bg_ratios_{'train' if train else 'val'}.json", "r") as f:
self.fg_bg_ratios = json.load(f)
if self._mode == "folder":
self.fg_bg_ratios = {"/".join(key.split("/")[-2:]): val for key, val in self.fg_bg_ratios.items()}
logger.info(f"Renamed fg_bg_ratios keys to {list(self.fg_bg_ratios.keys())[:3]}...")
if pruning_ratio <= 1.0:
backup_backgrounds = {}
for bg_file in self.backgrounds:
bg_cls = bg_file.split("/")[-2]
backup_backgrounds[bg_cls] = bg_file
self.backgrounds = [
bg for bg in self.backgrounds if bg in self.fg_bg_ratios and self.fg_bg_ratios[bg] < pruning_ratio
]
# logger.info(
# f"RecombinationNet: keep {len(self.backgrounds)} of {len(self.fg_bg_ratios)} backgrounds with pr {pruning_ratio}"
# )
self.root = root
self.train = train
self.background_combination = background_combination
self.fg_scale_jitter = fg_scale_jitter
self.fg_transform = fg_transform
self.return_fg_masks = return_fg_masks
self.paste_pre_transform = paste_pre_transform
self.mask_smoothing_sigma = mask_smoothing_sigma
self.rel_jut_out = rel_jut_out
self.size_fact = size_fact
self.fg_in_nonant = fg_in_nonant
assert fg_in_nonant is None or -1 <= fg_in_nonant < 9, f"fg_in_nonant={fg_in_nonant} not in [0, 8] or None"
self.orig_img_prob = orig_img_prob
if orig_img_prob != 0.0:
assert (isinstance(orig_img_prob, float) and orig_img_prob > 0.0) or orig_img_prob in [
"linear",
"cos",
"revlinear",
]
assert orig_ds is not None, "orig_ds must be provided if orig_img_prob > 0.0"
assert not return_fg_masks, "can't provide fg masks for original images (yet)"
assert os.path.exists(os.path.join(root, f"{'train' if train else 'val'}_indices.json")) or isinstance(
orig_ds, str
), f"{root}/{'train' if train else 'val'}_indices.json must be provided if orig_ds is a dataset"
if not isinstance(orig_ds, str):
with open(os.path.join(root, f"{'train' if train else 'val'}_indices.json"), "r") as f:
self.key_to_orig_idx = json.load(f)
else:
if not (
orig_ds.endswith("train" if train else "val") or orig_ds.endswith("train/" if train else "val/")
):
orig_ds = f"{orig_ds}/{'train' if train else 'val'}"
self.key_to_orig_idx = None
self._orig_ds_file_type = _orig_ds_file_type
self.orig_ds = orig_ds
self.epochs = epochs
self._epoch = 0
assert fg_size_mode in [
"max",
"min",
"mean",
"range",
], f"fg_size_mode={fg_size_mode} not supported; use one of ['max', 'min', 'mean', 'range']"
self.fg_size_mode = fg_size_mode
self.fg_bates_n = fg_bates_n
if not paste_pre_transform:
if isinstance(transform, (T.Compose, Compose)):
transform = transform.transforms
# do cropping and resizing mainly on background; paste foreground on top later
self.bg_transform = []
self.join_transform = []
for tf in transform:
if isinstance(tf, tuple(self._bg_transforms)):
self.bg_transform.append(tf)
else:
self.join_transform.append(tf)
if _album_compose:
from data.album_transf import AlbumTorchCompose
self.bg_transform = AlbumTorchCompose(self.bg_transform)
self.join_transform = AlbumTorchCompose(self.join_transform)
else:
self.bg_transform = T.Compose(self.bg_transform)
self.join_transform = T.Compose(self.join_transform)
else:
if isinstance(transform, list):
if _album_compose:
from data.album_transf import AlbumTorchCompose
self.join_transform = AlbumTorchCompose(transform)
else:
self.join_transform = T.Compose(transform)
else:
self.join_transform = transform
self.bg_transform = None
self.trgt_map = {cls: i for i, cls in enumerate(self.classes)}
self.target_transform = target_transform
self.cls_to_allowed_bg = {}
for bg_file in self.backgrounds:
if background_combination == "same":
bg_cls = bg_file.split("/")[-2]
if bg_cls not in self.cls_to_allowed_bg:
self.cls_to_allowed_bg[bg_cls] = []
self.cls_to_allowed_bg[bg_cls].append(bg_file)
if background_combination == "same":
for cls_code in classes:
if cls_code not in self.cls_to_allowed_bg or len(self.cls_to_allowed_bg[cls_code]) == 0:
self.cls_to_allowed_bg[cls_code] = [backup_backgrounds[cls_code]]
logger.warning(f"No background for class {cls_code}, using {backup_backgrounds[cls_code]}")
self._zf = {}
@property
def epoch(self):
return self._epoch
@epoch.setter
def epoch(self, value):
self._epoch = value
def __len__(self):
"""Size of the dataset.
Returns:
int: number of foregrounds
"""
return len(self.foregrounds)
def num_classes(self):
return len(self.classes)
def _get_fg(self, idx):
worker_id = self._wrkr_info()
fg_file = self.foregrounds[idx]
with self._zf[worker_id]["fg"].open(fg_file) as f:
fg_data = BytesIO(f.read())
return Image.open(fg_data)
def _wrkr_info(self):
worker_id = get_worker_info().id if get_worker_info() else 0
if worker_id not in self._zf and self._mode == "zip":
self._zf[worker_id] = {
"bg": zipfile.ZipFile(f"{self.root}/backgrounds_{'train' if self.train else 'val'}.zip", "r"),
"fg": zipfile.ZipFile(f"{self.root}/foregrounds_{'train' if self.train else 'val'}.zip", "r"),
}
return worker_id
def __getitem__(self, idx):
"""Get the foreground at index idx and combine it with a (random) background.
Args:
idx (int): foreground index
Returns:
torch.Tensor, torch.Tensor: image, target
"""
worker_id = self._wrkr_info()
fg_file = self.foregrounds[idx]
trgt_cls = fg_file.split("/")[-2]
if (
(self.orig_img_prob == "linear" and np.random.rand() < self._epoch / self.epochs)
or (self.orig_img_prob == "revlinear" and np.random.rand() < (self._epoch - self.epochs) / self.epochs)
or (self.orig_img_prob == "cos" and np.random.rand() > np.cos(np.pi * self._epoch / (2 * self.epochs)))
or (
isinstance(self.orig_img_prob, float)
and self.orig_img_prob > 0.0
and np.random.rand() < self.orig_img_prob
)
):
data_key = f"{trgt_cls}/{fg_file.split('/')[-1].split('.')[0]}"
if isinstance(self.orig_ds, str):
image_file = os.path.join(self.orig_ds, f"{data_key}.{self._orig_ds_file_type}")
orig_img = Image.open(image_file).convert("RGB")
else:
orig_data = self.orig_ds[self.key_to_orig_idx[data_key]]
orig_img = orig_data["image"] if isinstance(orig_data, dict) else orig_data[0]
if self.bg_transform:
orig_img = self.bg_transform(orig_img)
if self.join_transform:
orig_img = self.join_transform(orig_img)
trgt = self.trgt_map[trgt_cls]
if self.target_transform:
trgt = self.target_transform(trgt)
return orig_img, trgt
if self._mode == "zip":
with self._zf[worker_id]["fg"].open(fg_file) as f:
fg_data = BytesIO(f.read())
try:
fg_img = Image.open(fg_data).convert("RGBA")
except PIL.UnidentifiedImageError as e:
logger.error(f"Error with idx={idx}, file={self.foregrounds[idx]}")
raise e
else:
# data_key = f"{trgt_cls}/{fg_file.split('/')[-1].split('.')[0]}"
fg_img = Image.open(
os.path.join(self.root, "train" if self.train else "val", "foregrounds", fg_file)
).convert("RGBA")
if self.fg_transform:
fg_img = self.fg_transform(fg_img)
fg_size_factor = T.ToTensor()(fg_img.split()[-1]).mean().item()
if self.background_combination == "all":
bg_idx = np.random.randint(len(self.backgrounds))
bg_file = self.backgrounds[bg_idx]
elif self.background_combination == "original":
bg_file = fg_file.replace("foregrounds", "backgrounds").replace("WEBP", "JPEG")
else:
bg_idx = np.random.randint(len(self.cls_to_allowed_bg[trgt_cls]))
bg_file = self.cls_to_allowed_bg[trgt_cls][bg_idx]
if self._mode == "zip":
with self._zf[worker_id]["bg"].open(bg_file) as f:
bg_data = BytesIO(f.read())
bg_img = Image.open(bg_data).convert("RGB")
else:
bg_img = Image.open(
os.path.join(self.root, "train" if self.train else "val", "backgrounds", bg_file)
).convert("RGB")
if not self.paste_pre_transform:
bg_img = self.bg_transform(bg_img)
bg_size = bg_img.size
# choose scale factor, such that relative area is in fg_scale
bg_area = bg_size[0] * bg_size[1]
if self.fg_in_nonant is not None:
bg_area = bg_area / 9
# logger.info(f"background: size={bg_size} area={bg_area}")
# logger.info(f"fg_file={fg_file}, fg_bg_ratio_keys={list(self.fg_bg_ratios.keys())[:3]}...")
orig_fg_ratio = self.fg_bg_ratios[fg_file.replace("foregrounds", "backgrounds").replace("WEBP", "JPEG")]
bg_fg_ratio = self.fg_bg_ratios[bg_file]
if self.fg_size_mode == "max":
goal_fg_ratio_lower = goal_fg_ratio_upper = max(orig_fg_ratio, bg_fg_ratio)
elif self.fg_size_mode == "min":
goal_fg_ratio_lower = goal_fg_ratio_upper = min(orig_fg_ratio, bg_fg_ratio)
elif self.fg_size_mode == "mean":
goal_fg_ratio_lower = goal_fg_ratio_upper = (orig_fg_ratio + bg_fg_ratio) / 2
else:
# range
goal_fg_ratio_lower = min(orig_fg_ratio, bg_fg_ratio)
goal_fg_ratio_upper = max(orig_fg_ratio, bg_fg_ratio)
# logger.info(f"fg_bg_ratio={orig_fg_ratio}")
fg_scale = (
np.random.uniform(
goal_fg_ratio_lower * (1 - self.fg_scale_jitter), goal_fg_ratio_upper * (1 + self.fg_scale_jitter)
)
/ fg_size_factor
* self.size_fact
)
goal_shape_y = round(np.sqrt(bg_area * fg_scale * fg_img.size[1] / fg_img.size[0]))
goal_shape_x = round(np.sqrt(bg_area * fg_scale * fg_img.size[0] / fg_img.size[1]))
fg_img = fg_img.resize((goal_shape_x, goal_shape_y))
if fg_img.size[0] > bg_size[0] or fg_img.size[1] > bg_size[1]:
# random crop to fit
goal_w, goal_h = (min(fg_img.size[0], bg_size[0]), min(fg_img.size[1], bg_size[1]))
fg_img = T.RandomCrop((goal_h, goal_w))(fg_img) if self.train else T.CenterCrop((goal_h, goal_w))(fg_img)
# paste fg on bg
z1, z2 = (
(
np.random.uniform(0, 1, abs(self.fg_bates_n)).mean(), # bates distribution n=1 => uniform
np.random.uniform(0, 1, abs(self.fg_bates_n)).mean(),
)
if self.fg_bates_n != 0
else (0.5, 0.5)
)
if self.fg_bates_n < 0:
z1 = z1 + 0.5 - floor(z1 + 0.5)
z2 = z2 + 0.5 - floor(z2 + 0.5)
x_min = -self.rel_jut_out * fg_img.size[0]
x_max = bg_size[0] - fg_img.size[0] * (1 - self.rel_jut_out)
y_min = -self.rel_jut_out * fg_img.size[1]
y_max = bg_size[1] - fg_img.size[1] * (1 - self.rel_jut_out)
if self.fg_in_nonant is not None and self.fg_in_nonant >= 0:
x_min = (self.fg_in_nonant % 3) * bg_size[0] / 3
x_max = ((self.fg_in_nonant % 3) + 1) * bg_size[0] / 3 - fg_img.size[0]
y_min = (self.fg_in_nonant // 3) * bg_size[1] / 3
y_max = ((self.fg_in_nonant // 3) + 1) * bg_size[1] / 3 - fg_img.size[1]
if x_min > x_max:
x_min = x_max = (x_min + x_max) / 2
if y_min > y_max:
y_min = y_max = (y_min + y_max) / 2
offs_x = round(z1 * (x_max - x_min) + x_min)
offs_y = round(z2 * (y_max - y_min) + y_min)
paste_mask = fg_img.split()[-1]
if self.mask_smoothing_sigma > 0.0:
sigma = (np.random.rand() * 0.9 + 0.1) * self.mask_smoothing_sigma
paste_mask = paste_mask.filter(ImageFilter.GaussianBlur(radius=sigma))
paste_mask = paste_mask.point(lambda p: 2 * p - 255 if p > 128 else 0)
bg_img.paste(fg_img.convert("RGB"), (offs_x, offs_y), paste_mask)
bg_img = bg_img.convert("RGB")
if self.return_fg_masks:
fg_mask = Image.new("L", bg_size, 0)
fg_mask.paste(paste_mask, (offs_x, offs_y))
fg_mask = T.ToTensor()(fg_mask)[0]
bg_img = T.ToTensor()(bg_img)
if self.join_transform:
# img_mask_stack = torch.cat([bg_img, fg_mask.unsqueeze(0)], dim=0)
# img_mask_stack = self.join_transform(img_mask_stack)
# bg_img, fg_mask = img_mask_stack[:-1], img_mask_stack[-1]
bg_img, fg_mask = apply_dense_transforms(bg_img, fg_mask, self.join_transform)
else:
bg_img = self.join_transform(bg_img)
if trgt_cls not in self.trgt_map:
raise ValueError(f"trgt_cls={trgt_cls} not in trgt_map: {self.trgt_map}")
trgt = self.trgt_map[trgt_cls]
if self.target_transform:
trgt = self.target_transform(trgt)
if self.return_fg_masks:
return bg_img, trgt, fg_mask
return bg_img, trgt

View File

@@ -0,0 +1,151 @@
import argparse
import shutil
import zipfile
from os import listdir, makedirs, path
from random import choice
from datadings.reader import MsgpackReader
from datadings.writer import FileWriter
from tqdm.auto import tqdm
# CLI: build a Tiny-ImageNet-sized subset of the segmented ImageNet dataset.
parser = argparse.ArgumentParser()
parser.add_argument("-tiny_imagenet_zip", type=str, required=True, help="Path to the Tiny ImageNet zip file")
parser.add_argument("-output_dir", type=str, required=True, help="Directory to extract the image names to")
parser.add_argument("-in_segment_dir", type=str, required=True, help="Directory that holds the segmented ImageNet")
parser.add_argument(
    "-imagenet_path", type=str, nargs="?", required=True, help="Path to the original ImageNet dataset (datadings)"
)
args = parser.parse_args()
# Collect the Tiny ImageNet image file names per split (without extracting the zip).
images = {"train": set(), "val": set()}
with zipfile.ZipFile(args.tiny_imagenet_zip, "r") as zip_ref:
    for info in tqdm(zip_ref.infolist(), desc="Gathering Images"):
        if info.filename.endswith(".JPEG"):
            if "/val/" in info.filename:
                images["val"].add(info.filename.split("/")[-1])
            elif "/train/" in info.filename:
                images["train"].add(info.filename.split("/")[-1])
with open(path.join(args.output_dir, "tiny_imagenet_train_images.txt"), "w+") as f:
    f.write("\n".join(images["train"]))
with open(path.join(args.output_dir, "tiny_imagenet_val_images.txt"), "w+") as f:
    f.write("\n".join(images["val"]))
print(f"Found {len(images['train'])} training images and {len(images['val'])} validation images")
# Train file names start with the synset id (e.g. "n01443537_123.JPEG").
classes = {img_name.split("_")[0] for img_name in images["train"]}
classes = sorted(list(classes), key=lambda x: int(x[1:]))
# Sanity checks: Tiny ImageNet has 200 classes with 500 train / 50 val images each.
assert len(classes) == 200, f"Expected 200 classes, found {len(classes)}"
assert len(images["train"]) == len(classes) * 500, f"Expected 100000 training images, found {len(images['train'])}"
assert len(images["val"]) == len(classes) * 50, f"Expected 10000 validation images, found {len(images['val'])}"
with open(path.join(args.output_dir, "tiny_imagenet_classes.txt"), "w+") as f:
    f.write("\n".join(classes))
# copy over the relevant images
# Copy over the relevant segmented foregrounds (and matching backgrounds) for
# every Tiny ImageNet class, topping up with random extras when a class has
# fewer overlapping images than needed.
for split in ["train", "val"]:
    ipc = 500 if split == "train" else 50  # images per class
    part = "foregrounds_WEBP"
    with tqdm(total=len(images[split]), desc=f"Copying images for {split} > {part}") as pbar:
        for synset in classes:
            makedirs(path.join(args.output_dir, split, part, synset), exist_ok=True)
            if len(listdir(path.join(args.output_dir, split, part, synset))) >= ipc:
                tqdm.write(
                    f"skip {split} > {part} > {synset} with"
                    f" {len(listdir(path.join(args.output_dir, split, part, synset)))} ims"
                )
                pbar.update(ipc)
                continue
            # First pass: copy all segmented images whose original is in Tiny ImageNet.
            for img in listdir(path.join(args.in_segment_dir, split, part, synset)):
                orig_name = (
                    img.split(".")[0] + ".JPEG"
                    if split == "train"
                    else f"val_{int(img.split('_')[-1].split('.')[0])}.JPEG"
                )
                if orig_name in images[split]:
                    shutil.copy(
                        path.join(args.in_segment_dir, split, part, synset, img),
                        path.join(args.output_dir, split, part, synset, img),
                    )
                    pbar.update(1)
            # Second pass: top up with random extra images until ipc (or all
            # available) images are present.
            while len(listdir(path.join(args.output_dir, split, part, synset))) < min(
                ipc, len(listdir(path.join(args.in_segment_dir, split, part, synset)))
            ):
                # BUGFIX: the original filtered candidates by the derived .JPEG
                # name against a directory of segmented (.WEBP) file names, so
                # already-copied files were never excluded and could be picked
                # again (overwriting in place and stalling the counter).
                already_copied = set(listdir(path.join(args.output_dir, split, part, synset)))
                candidates = [
                    img
                    for img in listdir(path.join(args.in_segment_dir, split, part, synset))
                    if img not in already_copied
                ]
                img = choice(candidates)
                shutil.copy(
                    path.join(args.in_segment_dir, split, part, synset, img),
                    path.join(args.output_dir, split, part, synset, img),
                )
                pbar.update(1)
                # BUGFIX: log the file actually copied (was the stale orig_name
                # left over from the previous loop).
                tqdm.write(f"Extra image: {img} to {split}/{part}/{synset}")
    # copy over the background images corresponding to those foregrounds
    part = "backgrounds_JPEG"
    with tqdm(total=len(images[split]), desc=f"Copying images for {split} > {part}") as pbar:
        for synset in classes:
            makedirs(path.join(args.output_dir, split, part, synset), exist_ok=True)
            if len(listdir(path.join(args.output_dir, split, part, synset))) >= ipc:
                tqdm.write(f"skip {split} > {part} > {synset}")
                pbar.update(ipc)
                continue
            # One background per copied foreground, matched by file name.
            for img in listdir(path.join(args.output_dir, split, "foregrounds_WEBP", synset)):
                bg_name = img.replace(".WEBP", ".JPEG")
                shutil.copy(
                    path.join(args.in_segment_dir, split, part, synset, bg_name),
                    path.join(args.output_dir, split, part, synset, bg_name),
                )
                pbar.update(1)
            assert len(listdir(path.join(args.output_dir, split, part, synset))) == len(
                listdir(path.join(args.output_dir, split, "foregrounds_WEBP", synset))
            ), (
                f"Expected {len(listdir(path.join(args.output_dir, split, 'foregrounds_WEBP', synset)))} background"
                f" images, found {len(listdir(path.join(args.output_dir, split, part, synset)))}"
            )
# write the original dataset to datadings
for part in ["train", "val"]:
reader = MsgpackReader(path.join(args.imagenet_path, f"{part}.msgpack"))
with FileWriter(path.join(args.output_dir, f"TinyIN_{part}.msgpack")) as writer:
for data in tqdm(reader, desc=f"Writing {part} to datadings"):
key = data["key"].split("/")[-1]
allowed_synsets = [key.split("_")[0]] if part == "train" else classes
if part == "train" and allowed_synsets[0] not in classes:
continue
keep_image = False
label_synset = None
for synset in allowed_synsets:
for img in listdir(path.join(args.output_dir, part, "foregrounds_WEBP", synset)):
if img.split(".")[0] == key.split(".")[0]:
keep_image = True
label_synset = synset
break
if not keep_image:
continue
data["label"] = classes.index(label_synset)
writer.write(data)

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,48 @@
import argparse
import os
import json
import torch
# Convert a 1000-way ImageNet classifier head into a 9-way ImageNet-9 head by
# summing, per super-class, the head weights of all ImageNet classes mapped to it.
parser = argparse.ArgumentParser("Script to convert ImageNet trained models to ImageNet-9")
parser.add_argument("-m", "--model", type=str, required=True, help="Model weights (.pt file).")
parser.add_argument(
    "--in_to_in9", type=str, default="/ds-sds/images/ImageNet-9/in_to_in9.json", help="Path to in_to_in9.json"
)
args = parser.parse_args()
checkpoint = torch.load(args.model, map_location="cpu")
model_state = checkpoint["model_state"]
# Heuristic: classifier parameters live under ".head." or ".fc." in the state dict.
head_keys = [k for k in model_state.keys() if ".head." in k or ".fc." in k]
print("weights that will be modified:", head_keys)
assert len(head_keys) > 0, "no head keys found :("
with open(args.in_to_in9, "r") as f:
    in_to_in9_classes = json.load(f)
# -1 marks ImageNet classes that have no ImageNet-9 super-class.
print(f"{len([k for k, v in in_to_in9_classes.items() if v == -1])} classes get mapped to -1")
print("map", len(in_to_in9_classes), " classes to", set(in_to_in9_classes.values()))
print("Building conversion matrix")
# conversion_matrix[j, i] == 1 iff ImageNet class i belongs to ImageNet-9 class j.
conversion_matrix = torch.zeros((9, 1000))
for in_idx, in9_idx in in_to_in9_classes.items():
    if in9_idx == -1:
        continue
    in_idx = int(in_idx)
    conversion_matrix[in9_idx, in_idx] = 1
print(f"Conversion matrix ({conversion_matrix.shape}) has {int(conversion_matrix.sum().item())} non-zero values")
for head_key in head_keys:
    # Works for both weight matrices (1000, d) -> (9, d) and bias vectors (1000,) -> (9,).
    print(f"converting {head_key} of shape {model_state[head_key].shape}", end=" ")
    model_state[head_key] = conversion_matrix @ model_state[head_key]
    print(f"\tto {model_state[head_key].shape}")
checkpoint["model_state"] = model_state
checkpoint["args"]["n_classes"] = 9
# Save next to the original model with a "_to_in9" suffix before the extension.
save_folder = os.path.dirname(args.model)
orig_model_name = args.model.split(os.sep)[-1]
new_model_name = ".".join(orig_model_name.split(".")[:-1]) + "_to_in9." + orig_model_name.split(".")[-1]
print(f"saving model as {new_model_name} in {save_folder}")
torch.save(checkpoint, os.path.join(save_folder, new_model_name))

View File

@@ -0,0 +1,200 @@
n07695742: pretzel
n03902125: pay-phone, pay-station
n03980874: poncho
n01644900: tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui
n02730930: apron
n02699494: altar
n03201208: dining table, board
n02056570: king penguin, Aptenodytes patagonica
n04099969: rocking chair, rocker
n04366367: suspension bridge
n04067472: reel
n02808440: bathtub, bathing tub, bath, tub
n04540053: volleyball
n02403003: ox
n03100240: convertible
n04562935: water tower
n02788148: bannister, banister, balustrade, balusters, handrail
n02988304: CD player
n02423022: gazelle
n03637318: lampshade, lamp shade
n01774384: black widow, Latrodectus mactans
n01768244: trilobite
n07614500: ice cream, icecream
n04254777: sock
n02085620: Chihuahua
n01443537: goldfish, Carassius auratus
n01629819: European fire salamander, Salamandra salamandra
n02099601: golden retriever
n02321529: sea cucumber, holothurian
n03837869: obelisk
n02002724: black stork, Ciconia nigra
n02841315: binoculars, field glasses, opera glasses
n04560804: water jug
n02364673: guinea pig, Cavia cobaya
n03706229: magnetic compass
n09256479: coral reef
n09332890: lakeside, lakeshore
n03544143: hourglass
n02124075: Egyptian cat
n02948072: candle, taper, wax light
n01950731: sea slug, nudibranch
n02791270: barbershop
n03179701: desk
n02190166: fly
n04275548: spider web, spider's web
n04417672: thatch, thatched roof
n03930313: picket fence, paling
n02236044: mantis, mantid
n03976657: pole
n01774750: tarantula
n04376876: syringe
n04133789: sandal
n02099712: Labrador retriever
n04532670: viaduct
n04487081: trolleybus, trolley coach, trackless trolley
n09428293: seashore, coast, seacoast, sea-coast
n03160309: dam, dike, dyke
n03250847: drumstick
n02843684: birdhouse
n07768694: pomegranate
n03670208: limousine, limo
n03085013: computer keyboard, keypad
n02892201: brass, memorial tablet, plaque
n02233338: cockroach, roach
n03649909: lawn mower, mower
n03388043: fountain
n02917067: bullet train, bullet
n02486410: baboon
n04596742: wok
n03255030: dumbbell
n03937543: pill bottle
n02113799: standard poodle
n03977966: police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria
n02906734: broom
n07920052: espresso
n01698640: American alligator, Alligator mississipiensis
n02123394: Persian cat
n03424325: gasmask, respirator, gas helmet
n02129165: lion, king of beasts, Panthera leo
n04008634: projectile, missile
n03042490: cliff dwelling
n02415577: bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis
n02815834: beaker
n02395406: hog, pig, grunter, squealer, Sus scrofa
n01784675: centipede
n03126707: crane
n04399382: teddy, teddy bear
n07875152: potpie
n03733131: maypole
n02802426: basketball
n03891332: parking meter
n01910747: jellyfish
n03838899: oboe, hautboy, hautbois
n03770439: miniskirt, mini
n02281406: sulphur butterfly, sulfur butterfly
n03970156: plunger, plumber's helper
n09246464: cliff, drop, drop-off
n02206856: bee
n02074367: dugong, Dugong dugon
n03584254: iPod
n04179913: sewing machine
n04328186: stopwatch, stop watch
n07583066: guacamole
n01917289: brain coral
n03447447: gondola
n02823428: beer bottle
n03854065: organ, pipe organ
n02793495: barn
n04285008: sports car, sport car
n02231487: walking stick, walkingstick, stick insect
n04465501: tractor
n02814860: beacon, lighthouse, beacon light, pharos
n02883205: bow tie, bow-tie, bowtie
n02165456: ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle
n04149813: scoreboard
n04023962: punching bag, punch bag, punching ball, punchball
n02226429: grasshopper, hopper
n02279972: monarch, monarch butterfly, milkweed butterfly, Danaus plexippus
n02669723: academic gown, academic robe, judge's robe
n04486054: triumphal arch
n04070727: refrigerator, icebox
n03444034: go-kart
n02666196: abacus
n01945685: slug
n04251144: snorkel
n03617480: kimono
n03599486: jinrikisha, ricksha, rickshaw
n02437312: Arabian camel, dromedary, Camelus dromedarius
n01984695: spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish
n04118538: rugby ball
n01770393: scorpion
n04356056: sunglasses, dark glasses, shades
n03804744: nail
n02132136: brown bear, bruin, Ursus arctos
n03400231: frying pan, frypan, skillet
n03983396: pop bottle, soda bottle
n07734744: mushroom
n02480495: orangutan, orang, orangutang, Pongo pygmaeus
n02410509: bison
n03404251: fur coat
n04456115: torch
n02123045: tabby, tabby cat
n03026506: Christmas stocking
n07715103: cauliflower
n04398044: teapot
n02927161: butcher shop, meat market
n07749582: lemon
n07615774: ice lolly, lolly, lollipop, popsicle
n02795169: barrel, cask
n04532106: vestment
n02837789: bikini, two-piece
n02814533: beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon
n04265275: space heater
n02481823: chimpanzee, chimp, Pan troglodytes
n02509815: lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens
n06596364: comic book
n01983481: American lobster, Northern lobster, Maine lobster, Homarus americanus
n02504458: African elephant, Loxodonta africana
n03014705: chest
n01944390: snail
n04146614: school bus
n01641577: bullfrog, Rana catesbeiana
n07720875: bell pepper
n02999410: chain
n01855672: goose
n02125311: cougar, puma, catamount, mountain lion, painter, panther, Felis concolor
n07753592: banana
n07871810: meat loaf, meatloaf
n04501370: turnstile
n04311004: steel arch bridge
n02977058: cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM
n04074963: remote control, remote
n03662601: lifeboat
n02106662: German shepherd, German shepherd dog, German police dog, alsatian
n03089624: confectionery, confectionary, candy store
n04259630: sombrero
n03393912: freight car
n04597913: wooden spoon
n07711569: mashed potato
n03355925: flagpole, flagstaff
n02963159: cardigan
n07579787: plate
n02950826: cannon
n01882714: koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus
n02094433: Yorkshire terrier
n02909870: bucket, pail
n02058221: albatross, mollymawk
n01742172: boa constrictor, Constrictor constrictor
n09193705: alp
n04371430: swimming trunks, bathing trunks
n07747607: orange
n03814639: neck brace
n04507155: umbrella
n02268443: dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk
n02769748: backpack, back pack, knapsack, packsack, rucksack, haversack
n03763968: military uniform
n07873807: pizza, pizza pie
n03992509: potter's wheel
n03796401: moving van
n12267677: acorn

View File

@@ -0,0 +1,73 @@
# Repeat Augment sampler taken from DeiT: https://github.com/facebookresearch/deit/blob/main/samplers.py
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
import math
import torch
import torch.distributed as dist
class RASampler(torch.utils.data.Sampler):
    """Distributed sampler with repeated augmentation (RA).

    Every sample index is repeated ``num_repeats`` times and the repeats are
    spread round-robin over the distributed ranks, so each process (GPU) sees
    a different augmented copy of the same image within one epoch.
    Heavily based on torch.utils.data.DistributedSampler.
    """

    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, num_repeats: int = 3):
        if num_replicas is None or rank is None:
            # fall back to the process group for missing values
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            if num_replicas is None:
                num_replicas = dist.get_world_size()
            if rank is None:
                rank = dist.get_rank()
        if num_repeats < 1:
            raise ValueError("num_repeats should be greater than 0")
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.num_repeats = num_repeats
        self.epoch = 0
        # every sample appears num_repeats times, split evenly over the replicas
        self.num_samples = int(math.ceil(len(self.dataset) * self.num_repeats / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        # per rank, only ~len(dataset)/num_replicas indices are actually yielded,
        # with the global count rounded down to a multiple of 256
        self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
        self.shuffle = shuffle

    def __iter__(self):
        dataset_len = len(self.dataset)
        if self.shuffle:
            # epoch-seeded generator => identical permutation on every rank
            rng = torch.Generator()
            rng.manual_seed(self.epoch)
            order = torch.randperm(dataset_len, generator=rng)
        else:
            order = torch.arange(start=0, end=dataset_len)
        # repeat each index num_repeats times (repeats stay adjacent)
        repeated = torch.repeat_interleave(order, repeats=self.num_repeats, dim=0).tolist()
        # pad with leading indices so the list splits evenly across replicas
        shortfall = self.total_size - len(repeated)
        if shortfall > 0:
            repeated.extend(repeated[:shortfall])
        assert len(repeated) == self.total_size
        # round-robin subsample: rank r takes positions r, r+R, r+2R, ...
        mine = repeated[self.rank : self.total_size : self.num_replicas]
        assert len(mine) == self.num_samples
        return iter(mine[: self.num_selected_samples])

    def __len__(self):
        return self.num_selected_samples

    def set_epoch(self, epoch):
        """Seed the shuffle for a new epoch; call once before each epoch."""
        self.epoch = epoch

    def __str__(self) -> str:
        return (
            f"{type(self).__name__}(num_replicas: {self.num_replicas}, rank: {self.rank}, num_repeats:"
            f" {self.num_repeats}, epoch: {self.epoch}, num_samples: {self.num_samples}, total_size: {self.total_size},"
            f" num_selected_samples: {self.num_selected_samples}, shuffle: {self.shuffle})"
        )

View File

@@ -0,0 +1,381 @@
import argparse
import json
import os
from copy import copy
from loguru import logger
from nltk.corpus import wordnet as wn
class bcolors:
    """ANSI escape sequences for colored/emphasised terminal output."""

    HEADER = "\033[95m"
    OKBLUE = "\033[94m"
    OKCYAN = "\033[96m"
    OKGREEN = "\033[92m"
    WARNING = "\033[93m"
    FAIL = "\033[91m"
    ENDC = "\033[0m"  # reset all attributes
    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"
def _lemmas_str(synset):
return ", ".join([lemma.name() for lemma in synset.lemmas()])
class WNEntry:
    """One WordNet noun synset in a WNTree.

    Tracks tree structure (parent/children ids), whether the synset is an
    ImageNet class, and a per-node image count ``_n_images``. Pruned nodes stay
    in the tree but are flagged ``_pruned`` and donate their images to the
    nearest surviving ancestor.
    """

    def __init__(
        self,
        name: str,
        id: int,
        lemmas: str,
        parent_id: int,
        depth: int = None,
        in_image_net: bool = False,
        child_ids: list = None,
        in_main_tree: bool = True,
        _n_images: int = 0,
        _description: str = None,
        _name: str = None,
        _pruned: bool = False,
    ):
        self.name = name  # synset name, e.g. "entity.n.01"
        self.id = id  # WordNet offset (the number in "n<offset>")
        self.lemmas = lemmas  # comma-separated lemma names
        self.parent_id = parent_id  # offset of the hypernym; None for the root
        self.depth = depth  # distance from the tree root
        self.in_image_net = in_image_net  # True if this synset is an ImageNet class
        self.child_ids = child_ids  # hyponym offsets; None until a child is added
        self.in_main_tree = in_main_tree  # False for nodes hanging off parentless synsets
        self._n_images = _n_images  # images assigned directly to this node
        self._description = _description  # lazily fetched WordNet definition
        self._name = _name
        self._pruned = _pruned  # True once prune() removed this node

    def __str__(self, tree=None, accumulate=True):
        """Render this node (and, given a tree, its whole subtree) with ANSI colors.

        Green "+" marks ImageNet classes, red "-" everything else. With
        ``accumulate`` the subtree image total is shown next to the own count.
        """
        start_symb = f"{bcolors.OKGREEN}+{bcolors.ENDC}" if self.in_image_net else f"{bcolors.FAIL}-{bcolors.ENDC}"
        n_ims = f"{self._n_images} of Σ {self.n_images(tree)}" if accumulate and tree is not None else self._n_images
        if self.child_ids is None or tree is None:
            return f"{start_symb}{self.name} ({self.id}) > {n_ims}"
        else:
            # recurse into children, indenting each nested line by two spaces
            return f"{start_symb}{self.name} ({self.id}) > {n_ims}\n  " + "\n  ".join(
                ["\n  ".join(tree.nodes[child_id].__str__(tree).split("\n")) for child_id in self.child_ids]
            )

    def tree_diff(self, tree_1, tree_2):
        """Render this subtree with per-node image-count differences between two trees.

        "+"/"-"/"=" marks whether tree_2 has more/fewer/equal images at the node;
        recursion follows tree_1's structure.
        """
        if tree_2[self.id]._n_images > tree_1[self.id]._n_images:
            start_symb = f"{bcolors.OKGREEN}+{bcolors.ENDC}"
        elif tree_2[self.id]._n_images < tree_1[self.id]._n_images:
            start_symb = f"{bcolors.FAIL}-{bcolors.ENDC}"
        else:
            start_symb = f"{bcolors.OKBLUE}={bcolors.ENDC}"
        n_ims = (
            f"{tree_1[self.id]._n_images} + {tree_2[self.id]._n_images - tree_1[self.id]._n_images} of Σ"
            f" {tree_1[self.id].n_images(tree_2)}/{tree_2[self.id].n_images(tree_2)}"
        )
        if self.child_ids is None:
            return f"{start_symb}{self.name} ({self.id}) > {n_ims}"
        return f"{start_symb}{self.name} ({self.id}) > {n_ims}\n  " + "\n  ".join(
            ["\n  ".join(tree_1.nodes[child_id].tree_diff(tree_1, tree_2).split("\n")) for child_id in self.child_ids]
        )

    def prune(self, tree):
        """Remove this node (and its subtree) from *tree*.

        Children are pruned first; the node is detached from its parent's
        child_ids, and its image count is folded into the nearest ancestor
        that has not itself been pruned. Idempotent via the _pruned flag.
        """
        if self._pruned:
            return
        if self.child_ids is not None:
            for child_id in self.child_ids:
                tree[child_id].prune(tree)
        self._pruned = True
        parent_node = tree.nodes[self.parent_id]
        try:
            parent_node.child_ids.remove(self.id)
        except ValueError as e:
            print(
                f"Error removing {self.name} from"
                f" {parent_node.name} ({[tree[cid].name for cid in parent_node.child_ids]}): {e}"
            )
        # walk up to the first ancestor that survives pruning
        while parent_node._pruned:
            parent_node = tree.nodes[parent_node.parent_id]
        parent_node._n_images += self._n_images
        self._n_images = 0

    @property
    def description(self):
        """str: WordNet gloss, fetched lazily (requires the nltk wordnet corpus)."""
        if not self._description:
            self._description = wn.synset_from_pos_and_offset("n", self.id).definition()
        return self._description

    @property
    def print_name(self):
        """str: synset name without the POS/sense suffix (part before the first dot)."""
        return self.name.split(".")[0]

    def get_branch(self, tree=None):
        """Return the root-to-node path as a " > "-joined string."""
        if self.parent_id is None or tree is None:
            return self.print_name
        parent = tree.nodes[self.parent_id]
        return parent.get_branch(tree) + " > " + self.print_name

    def get_branch_list(self, tree):
        """Return the root-to-node path as a list of WNEntry objects."""
        if self.parent_id is None:
            return [self]
        parent = tree.nodes[self.parent_id]
        return parent.get_branch_list(tree) + [self]

    def to_dict(self):
        """Serialize all fields to a JSON-compatible dict (inverse of from_dict)."""
        return {
            "name": self.name,
            "id": self.id,
            "lemmas": self.lemmas,
            "parent_id": self.parent_id,
            "depth": self.depth,
            "in_image_net": self.in_image_net,
            "child_ids": self.child_ids,
            "in_main_tree": self.in_main_tree,
            "_n_images": self._n_images,
            "_description": self._description,
            "_name": self._name,
            "_pruned": self._pruned,
        }

    @classmethod
    def from_dict(cls, d):
        """Rebuild a WNEntry from a dict produced by to_dict()."""
        return cls(**d)

    def n_images(self, tree=None):
        """Return the image count: own count only, or the subtree total when a tree is given."""
        if tree is None or self.child_ids is None or len(self.child_ids) == 0:
            return self._n_images
        return sum([tree.nodes[child_id].n_images(tree) for child_id in self.child_ids]) + self._n_images

    def n_children(self, tree=None):
        """Return the number of children (direct only, or all descendants when a tree is given)."""
        if self.child_ids is None:
            return 0
        if tree is None or len(self.child_ids) == 0:
            return len(self.child_ids)
        return len(self.child_ids) + sum([tree.nodes[child_id].n_children(tree) for child_id in self.child_ids])

    def get_examples(self, tree, n_examples=3):
        """Return up to *n_examples* representative children as "name (gloss)" strings.

        Children are ranked by subtree image count; if no child has images,
        by descendant count instead.
        """
        if self.child_ids is None or len(self.child_ids) == 0:
            return ""
        child_images = {child_id: tree.nodes[child_id].n_images(tree) for child_id in self.child_ids}
        max_images = max(child_images.values())
        if max_images == 0:
            # go on number of child nodes
            child_images = {child_id: tree.nodes[child_id].n_children(tree) for child_id in self.child_ids}
        # sorted childids by number of images
        top_children = [
            child_id for child_id, n_images in sorted(child_images.items(), key=lambda x: x[1], reverse=True)
        ]
        top_children = top_children[: min(n_examples, len(top_children))]
        return ", ".join(
            [f"{tree.nodes[child_id].print_name} ({tree.nodes[child_id].description})" for child_id in top_children]
        )
class WNTree:
    def __init__(self, root=1740, nodes=None):
        """Build a tree rooted at a WordNet offset or an existing WNEntry.

        Args:
            root (int | WNEntry): WordNet noun offset (default 1740 — presumably
                the "entity" root synset; TODO confirm) or a ready-made entry.
            nodes (dict, optional): pre-built id->WNEntry map; a fresh map
                containing only the root is created when omitted.
        """
        if isinstance(root, int):
            root_id = root
            # requires the nltk wordnet corpus to be available
            root_synset = wn.synset_from_pos_and_offset("n", root)
            root_node = WNEntry(
                root_synset.name(),
                root_id,
                _lemmas_str(root_synset),
                parent_id=None,
                depth=0,
            )
        else:
            assert isinstance(root, WNEntry)
            root_id = root.id
            root_node = root
        self.root = root_node
        self.nodes = {root_id: self.root} if nodes is None else nodes
        self.parentless = []  # offsets of synsets without hypernyms (besides the root)
        self.label_index = None  # lazily built sorted list of label-bearing node ids
        self.pruned = set()  # ids removed by the last prune() call
def to_dict(self):
return {
"root": self.root.to_dict(),
"nodes": {node_id: node.to_dict() for node_id, node in self.nodes.items()},
"parentless": self.parentless,
"pruned": list(self.pruned),
}
    def prune(self, min_images):
        """Prune nodes with too few images, in two passes.

        Pass 1 removes every node whose *subtree* total is below min_images.
        Pass 2 walks the remaining tree bottom-up and removes nodes whose *own*
        count (after children were folded into them) is still below min_images.
        Pruned nodes stay in self.nodes but are flagged and their images are
        credited to the nearest surviving ancestor (see WNEntry.prune).

        Args:
            min_images (int): minimum image count a node must retain.

        Returns:
            set: ids of all pruned nodes (also stored in self.pruned).
        """
        pruned_nodes = set()
        # prune all nodes that have fewer than min_images below them
        for node_id, node in self.nodes.items():
            if node.n_images(self) < min_images:
                pruned_nodes.add(node_id)
                node.prune(self)
        # prune all nodes that have fewer than min_images inside them, after all nodes below have been pruned
        node_stack = [self.root]
        node_idx = 0
        # breadth-first collection of the surviving structure
        while node_idx < len(node_stack):
            node = node_stack[node_idx]
            if node.child_ids is not None:
                for child_id in node.child_ids:
                    child = self.nodes[child_id]
                    node_stack.append(child)
            node_idx += 1
        # now prune the stack from the bottom up
        for node in node_stack[::-1]:
            # only look at images of that class, not of additional children
            if node.n_images() < min_images:
                pruned_nodes.add(node.id)
                node.prune(self)
        self.pruned = pruned_nodes
        return pruned_nodes
@classmethod
def from_dict(cls, d):
tree = cls()
tree.root = WNEntry.from_dict(d["root"])
tree.nodes = {int(node_id): WNEntry.from_dict(node) for node_id, node in d["nodes"].items()}
tree.parentless = d["parentless"]
if "pruned" in d:
tree.pruned = set(d["pruned"])
return tree
    def add_node(self, node_id, in_in=True):
        """Add a synset (by WordNet offset) and, recursively, its missing hypernym chain.

        Args:
            node_id (int): WordNet noun offset to insert.
            in_in (bool): whether the synset itself is an ImageNet class;
                ancestors added on the way up are inserted with in_in=False.
        """
        if node_id in self.nodes:
            # already present: only upgrade the ImageNet flag, never downgrade it
            self.nodes[node_id].in_image_net = in_in or self.nodes[node_id].in_image_net
            return
        synset = wn.synset_from_pos_and_offset("n", node_id)
        # print(f"adding node {synset.name()} with id {node_id}")
        hypernyms = synset.hypernyms()
        if len(hypernyms) == 0:
            # no hypernym: this synset becomes a detached secondary root
            parent_id = None
            self.parentless.append(node_id)
            main_tree = False
            print(f"--------- no hypernyms for {synset.name()} ({synset.offset()}) ------------")
        else:
            # only the first hypernym is followed (WordNet can have several)
            parent_id = synset.hypernyms()[0].offset()
            if parent_id not in self.nodes:
                self.add_node(parent_id, in_in=False)
            parent = self.nodes[parent_id]
            if parent.child_ids is None:
                parent.child_ids = []
            parent.child_ids.append(node_id)
            main_tree = parent.in_main_tree
        depth = self.nodes[parent_id].depth + 1 if parent_id is not None else 0
        node = WNEntry(
            synset.name(),
            node_id,
            _lemmas_str(synset),
            parent_id=parent_id,
            in_image_net=in_in,
            depth=depth,
            in_main_tree=main_tree,
        )
        self.nodes[node_id] = node
def __len__(self):
return len(self.nodes)
def image_net_len(self, only_main_tree=False):
return sum([node.in_image_net for node in self.nodes.values() if node.in_main_tree or not only_main_tree])
def max_depth(self, only_main_tree=False):
return max([node.depth for node in self.nodes.values() if node.in_main_tree or not only_main_tree])
    def __str__(self):
        """Multi-line colored rendering: summary header, the main tree, then parentless subtrees."""
        return (
            f"WordNet Tree with {len(self)} nodes, {self.image_net_len()} in ImageNet21k;"
            f" {len(self.parentless)} parentless nodes:\n{self.root.__str__(tree=self)}\nParentless:\n"
            + "\n".join([self.nodes[node_id].__str__(tree=self) for node_id in self.parentless])
        )
def save(self, path):
with open(path, "w") as f:
json.dump(self.to_dict(), f)
@classmethod
def load(cls, path):
with open(path, "r") as f:
tree_dict = json.load(f)
return cls.from_dict(tree_dict)
    def subtree(self, node_id):
        """Return a new WNTree rooted at *node_id*, or None if the id is unknown.

        Nodes are collected breadth-first and copied; depths are rebased so the
        new root has depth 0.

        NOTE(review): copy() is shallow — the copied entries share their
        child_ids lists with the original tree, so mutating the subtree's
        structure would also mutate the original; confirm callers treat the
        result as read-only.
        """
        if node_id not in self.nodes:
            return None
        node_queue = [self.nodes[node_id]]
        subtree_ids = set()
        while len(node_queue) > 0:
            node = node_queue.pop(0)
            subtree_ids.add(node.id)
            if node.child_ids is not None:
                node_queue += [self.nodes[child_id] for child_id in node.child_ids]
        subtree_nodes = {node_id: copy(self.nodes[node_id]) for node_id in subtree_ids}
        subtree_root = subtree_nodes[node_id]
        subtree_root.parent_id = None
        # rebase depths relative to the new root
        depth_diff = subtree_root.depth
        for node in subtree_nodes.values():
            node.depth -= depth_diff
        return WNTree(root=subtree_root, nodes=subtree_nodes)
    def _make_label_index(self, include_merged=False):
        """Build the sorted list of label-bearing node ids.

        A node gets a label when it is not pruned and has images — its own only,
        or including its subtree when include_merged is set.
        """
        self.label_index = sorted(
            [
                node_id
                for node_id, node in self.nodes.items()
                if node.n_images(self if include_merged else None) > 0 and not node._pruned
            ]
        )
    def get_label(self, node_id):
        """Return the dense label index for *node_id*.

        Pruned nodes are resolved to their nearest surviving ancestor before the
        lookup. NOTE(review): assumes an unpruned ancestor exists (the walk would
        hit parent_id=None otherwise) and that the resolved id is in the index;
        list.index is O(n) per call.
        """
        if self.label_index is None:
            self._make_label_index()
        while self.nodes[node_id]._pruned:
            node_id = self.nodes[node_id].parent_id
        return self.label_index.index(node_id)
    def n_labels(self):
        """Return the number of distinct labels (building the index on first use)."""
        if self.label_index is None:
            self._make_label_index()
        return len(self.label_index)
def __contains__(self, item):
if isinstance(item, str):
if item[0] == "n":
item = int(item[1:])
else:
return False
if isinstance(item, int):
return item in self.nodes
if isinstance(item, WNEntry):
return item.id in self.nodes
return False
    def __getitem__(self, item):
        """Look up a node by "n<offset>" string, synset name ("x.n.01"), offset int, or WNEntry.

        Raises:
            KeyError: when the item cannot be resolved to a node.

        NOTE(review): ``item[0].startswith("n")`` behaves like ``item[0] == "n"``
        and raises IndexError for an empty string — confirm whether lenient
        handling (as in __contains__) is wanted here.
        """
        if isinstance(item, str) and item[0].startswith("n"):
            # "n<digits>" → numeric offset; anything else falls through unchanged
            try:
                item = int(item[1:])
            except ValueError:
                pass
        if isinstance(item, str) and ".n." in item:
            # full synset name: linear scan over all nodes
            for node in self.nodes.values():
                if item == node.name:
                    return node
            raise KeyError(f"Item {item} not found in tree")
        if isinstance(item, int):
            return self.nodes[item]
        if isinstance(item, WNEntry):
            return self.nodes[item.id]
        raise KeyError(f"Item {item} not found in tree")