AAAI Version
This commit is contained in:
@@ -0,0 +1,142 @@
|
||||
import albumentations as A
|
||||
import cv2
|
||||
import numpy as np
|
||||
from albumentations.pytorch import ToTensorV2
|
||||
from datadings.torch import CompressedToPIL
|
||||
|
||||
|
||||
class AlbumTorchCompose(A.Compose):
    """Compose albumentation augmentations in a way that works with PIL images and datadings.

    Accepts compressed image bytes (datadings), PIL images, or numpy arrays and
    normalizes everything to numpy before running the composed pipeline.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to A.Compose and set up the bytes->PIL decoder."""
        super().__init__(*args, **kwargs)
        self.to_pil = CompressedToPIL()

    def __call__(self, image, mask=None, **kwargs):
        """Run the composed augmentations.

        Returns only the augmented image when no mask is given, otherwise the
        full result dict produced by A.Compose.
        """
        # Compressed bytes (datadings) -> PIL image first.
        image = self.to_pil(image) if isinstance(image, bytes) else image
        # Treat an empty mask like no mask at all.
        if mask is not None and len(mask) == 0:
            mask = None
        # Albumentations operates on numpy arrays only.
        if not isinstance(image, np.ndarray):
            image = np.array(image)
        if mask is None:
            return super().__call__(image=image, **kwargs)["image"]
        if not isinstance(mask, np.ndarray):
            mask = np.array(mask)
        return super().__call__(image=image, mask=mask, **kwargs)
|
||||
|
||||
|
||||
class PILToNP(A.DualTransform):
    """Albumentations transform converting PIL images (and masks) to numpy arrays."""

    def apply(self, image, **params):
        """Convert the image to a numpy array."""
        return np.array(image)

    def apply_to_mask(self, image, **params):
        """Convert the mask to a numpy array."""
        return np.array(image)

    def get_transform_init_args_names(self):
        """No extra init arguments to serialize."""
        return ()
|
||||
|
||||
|
||||
class AlbumCompressedToPIL(A.DualTransform):
    """Convert compressed image bytes to a PIL image.

    Fix: the original referenced ``self.to_pil`` without ever creating it, so
    ``apply``/``apply_to_mask`` always raised AttributeError. The decoder is now
    instantiated in ``__init__`` (same pattern as AlbumTorchCompose).
    """

    def __init__(self, *args, **kwargs):
        """Create the transform and its bytes->PIL decoder."""
        super().__init__(*args, **kwargs)
        self.to_pil = CompressedToPIL()

    def apply(self, img, **params):
        """Decode the compressed image bytes into a PIL image."""
        return self.to_pil(img)

    def apply_to_mask(self, img, **params):
        """Decode the compressed mask bytes into a PIL image."""
        return self.to_pil(img)

    def get_transform_init_args_names(self):
        """No extra init arguments to serialize."""
        return ()
|
||||
|
||||
|
||||
def minimal_augment(args, test=False):
    """Get minimal augmentations for training or testing.

    Args:
        args (argparse.Namespace): arguments
        test (bool, optional): if True, return test augmentations. Defaults to False.

    Returns:
        List: Augmentation list
    """
    pipeline = []

    if args.aug_resize:
        pipeline.append(A.SmallestMaxSize(args.imsize, interpolation=cv2.INTER_CUBIC))

    if args.aug_crop:
        # deterministic crop at test time, random crop for training
        crop_cls = A.CenterCrop if test else A.RandomCrop
        pipeline.append(crop_cls(args.imsize, args.imsize))

    if args.aug_flip and not test:
        pipeline.append(A.HorizontalFlip(p=0.5))

    if args.aug_normalize:
        pipeline.append(A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)))

    pipeline.append(ToTensorV2())
    return pipeline
|
||||
|
||||
|
||||
def three_augment(args, as_list=False, test=False):
    """Create the data augmentation.

    Args:
        args (Namespace): arguments
        as_list (bool): return list of transformations, not composed transformation
        test (bool): In eval mode? If False => train mode

    Returns:
        torch.nn.Module | list[torch.nn.Module]: composed transformation or list of transformations
    """
    augs = []

    if args.aug_resize:
        augs.append(A.SmallestMaxSize(args.imsize, interpolation=cv2.INTER_CUBIC))

    if args.aug_crop:
        if test:
            augs.append(A.CenterCrop(args.imsize, args.imsize))
        else:
            augs.append(A.RandomCrop(args.imsize, args.imsize, pad_if_needed=True, border_mode=cv2.BORDER_REFLECT))

    if not test:
        if args.aug_flip:
            augs.append(A.HorizontalFlip(p=0.5))

        # Exactly one of {grayscale, solarize, gaussian blur} per sample (if any enabled).
        one_of = []
        if args.aug_grayscale:
            one_of.append(A.ToGray(p=1, num_output_channels=3))
        if args.aug_solarize:
            one_of.append(A.Solarize(p=1, threshold_range=(0.5, 0.5)))
        if args.aug_gauss_blur:
            one_of.append(A.GaussianBlur(p=1, sigma_limit=(0.2, 2.0), blur_limit=(7, 7)))
        if one_of:
            augs.append(A.OneOf(one_of))

        if args.aug_color_jitter_factor > 0.0:
            jitter = args.aug_color_jitter_factor
            augs.append(A.ColorJitter(brightness=jitter, contrast=jitter, saturation=jitter, hue=0.0))

    if args.aug_normalize:
        augs.append(A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)))

    augs.append(ToTensorV2())

    return augs if as_list else AlbumTorchCompose(augs)
|
||||
@@ -0,0 +1,63 @@
|
||||
import os
|
||||
from loguru import logger
|
||||
|
||||
from PIL import Image
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
class CounterAnimal(Dataset):
    """Dataset to load the CounterAnimal dataset with ImageNet labels."""

    def __init__(self, base_path, mode, transform=None, target_transform=None, train=False):
        """Create the dataset.

        Args:
            base_path (str): path to the base folder (the one where the class folders are in)
            mode (str): mode/variant of the dataset (common/counter)
            transform: Image augmentation
            target_transform: label augmentation
            train: train or test set. Train set is not supported
        """
        super().__init__()
        self.base = base_path
        assert mode in ["counter", "common"], f"Supported modes are counter and common, but got '{mode}'"
        assert not train, "CounterAnimal only consists of test data, not training data."
        self.transform = transform
        self.target_transform = target_transform

        allowed_ext = ["jpg", "jpeg", "png"]
        self.index = []
        for cls_dir in os.listdir(self.base):
            cls_path = os.path.join(self.base, cls_dir)
            if not os.path.isdir(cls_path):
                continue
            # Class folders are named "<imagenet idx> <class name>".
            label = int(cls_dir.split(" ")[0])
            for variant in os.listdir(cls_path):
                # Keep only the requested variant ("common..." or "counter...").
                if not variant.startswith(mode):
                    continue
                variant_path = os.path.join(cls_path, variant)
                for fname in os.listdir(variant_path):
                    if fname.lower().split(".")[-1] in allowed_ext:
                        self.index.append((os.path.join(variant_path, fname), label))

        assert len(self.index) > 0, "did not find any images :("

    def __len__(self):
        """Number of indexed images."""
        return len(self.index)

    def __getitem__(self, idx):
        """Load one (image, label) pair and apply the configured transforms."""
        path, label = self.index[idx]
        sample = Image.open(path).convert("RGB")
        if self.transform:
            sample = self.transform(sample)
        if self.target_transform:
            label = self.target_transform(label)
        return sample, label
|
||||
@@ -0,0 +1,145 @@
|
||||
from nvidia.dali import fn, pipeline_def, types
|
||||
|
||||
# see https://docs.nvidia.com/deeplearning/dali/user-guide/docs/plugins/pytorch_dali_proxy.html
|
||||
|
||||
|
||||
@pipeline_def
def minimal_augment(args, test=False):
    """Minimal Augmentation set for images.

    Contains only resize, crop, flip, to tensor and normalize.

    Args:
        args (DotDict): Arguments: aug_resize, aug_crop, aug_flip, aug_normalize to turn on/off the respective augmentation.
        test (bool, optional): On the test set? Defaults to False.

    Returns:
        images: augmented images.
    """
    images = fn.external_source(name="images", no_copy=True)

    if args.aug_resize:
        images = fn.resize(images, size=args.imsize, mode="not_smaller")

    if args.aug_crop:
        if test:
            # deterministic center crop for evaluation
            pos_x, pos_y = 0.5, 0.5
        else:
            pos_x = fn.random.uniform(range=(0, 1))
            pos_y = fn.random.uniform(range=(0, 1))
        images = fn.crop(images, crop=(args.imsize, args.imsize), crop_pos_x=pos_x, crop_pos_y=pos_y)

    # Flip, normalization and HWC->CHW float conversion fused into one DALI op.
    return fn.crop_mirror_normalize(
        images,
        dtype=types.FLOAT,
        output_layout="CHW",
        crop=(args.imsize, args.imsize),
        mean=[0.485 * 255, 0.456 * 255, 0.406 * 255] if args.aug_normalize else [0, 0, 0],
        std=[0.229 * 255, 0.224 * 255, 0.225 * 255] if args.aug_normalize else [1, 1, 1],
        mirror=fn.random.coin_flip(probability=0.5) if args.aug_flip and not test else False,
    )
|
||||
|
||||
|
||||
def dali_solarize(images, threshold=128):
    """Solarize implementation for nvidia DALI.

    Pixels at or above ``threshold`` are inverted (255 - value); pixels below
    are kept unchanged.

    Args:
        images (DALI Tensor): Images to solarize.
        threshold (int, optional): Threshold for solarization. Defaults to 128.

    Returns:
        images: solarized images.
    """
    inverted = types.Constant(255).uint8() - images
    # 0/1 mask of pixels that should be inverted.
    above = (images >= threshold) * types.Constant(1).uint8()
    below = types.Constant(1).uint8() ^ above
    return above * inverted + below * images
|
||||
|
||||
|
||||
@pipeline_def(enable_conditionals=True)
def three_augment(args, test=False):
    """3-augment data augmentation pipeline for nvidia DALI.

    Fixes over the original:
    - removed the dead ``choices`` list (it was never filled, so
      ``fn.random.choice(choices)`` was unreachable);
    - guarded against ZeroDivisionError when grayscale, solarize and gaussian
      blur are all disabled (``sum(choice_ps)`` was 0 in that case).

    Args:
        args (namespace): augmentation arguments.
        test (bool, optional): Test (or train) split. Defaults to False.

    Returns:
        images: augmented images.
    """
    images = fn.external_source(name="images", no_copy=True)

    if args.aug_resize:
        images = fn.resize(images, size=args.imsize, mode="not_smaller")

    if test and args.aug_crop:
        images = fn.crop(images, crop=(args.imsize, args.imsize), crop_pos_x=0.5, crop_pos_y=0.5)
    elif args.aug_crop:
        images = fn.crop(
            images,
            crop=(args.imsize, args.imsize),
            crop_pos_x=fn.random.uniform(range=(0, 1)),
            crop_pos_y=fn.random.uniform(range=(0, 1)),
        )

    if not test:
        # Draw one of {grayscale, solarize, gaussian blur} uniformly among the
        # enabled options; skip entirely when none is enabled.
        weights = [1 * args.aug_grayscale, 1 * args.aug_solarize, 1 * args.aug_gauss_blur]
        total = sum(weights)
        if total > 0:
            choice = fn.random.choice(
                [0, 1, 2],
                p=[w / total for w in weights],
            )

            if choice == 0:
                # RGB -> GRAY -> RGB replicates the grayscale image on 3 channels.
                images = fn.color_space_conversion(
                    fn.color_space_conversion(images, image_type=types.RGB, output_type=types.GRAY),
                    image_type=types.GRAY,
                    output_type=types.RGB,
                )
            elif choice == 1:
                images = dali_solarize(images, threshold=128)
            elif choice == 2:
                images = fn.gaussian_blur(images, window_size=7, sigma=fn.random.uniform(range=(0.2, 2.0)))

        if args.aug_color_jitter_factor > 0.0:
            jitter_range = (1 - args.aug_color_jitter_factor, 1 + args.aug_color_jitter_factor)
            images = fn.color_twist(
                images,
                brightness=fn.random.uniform(range=jitter_range),
                contrast=fn.random.uniform(range=jitter_range),
                saturation=fn.random.uniform(range=jitter_range),
            )

    # Flip, normalization and HWC->CHW float conversion fused into one DALI op.
    return fn.crop_mirror_normalize(
        images,
        dtype=types.FLOAT,
        output_layout="CHW",
        crop=(args.imsize, args.imsize),
        mean=[0.485 * 255, 0.456 * 255, 0.406 * 255] if args.aug_normalize else [0, 0, 0],
        std=[0.229 * 255, 0.224 * 255, 0.225 * 255] if args.aug_normalize else [1, 1, 1],
        mirror=fn.random.coin_flip(probability=0.5) if args.aug_flip and not test else False,
    )
|
||||
@@ -0,0 +1,407 @@
|
||||
from random import uniform
|
||||
|
||||
import msgpack
|
||||
import torch
|
||||
import torchvision
|
||||
from datadings.torch import CompressedToPIL
|
||||
from datadings.torch import Dataset as DDDataset
|
||||
from PIL import ImageFilter
|
||||
from torchvision.transforms import (
|
||||
CenterCrop,
|
||||
ColorJitter,
|
||||
Compose,
|
||||
GaussianBlur,
|
||||
Grayscale,
|
||||
InterpolationMode,
|
||||
Normalize,
|
||||
RandomChoice,
|
||||
RandomCrop,
|
||||
RandomHorizontalFlip,
|
||||
RandomResizedCrop,
|
||||
RandomSolarize,
|
||||
Resize,
|
||||
ToTensor,
|
||||
)
|
||||
from torchvision.transforms import functional as F
|
||||
|
||||
_image_and_target_transforms = [
|
||||
torchvision.transforms.RandomCrop,
|
||||
torchvision.transforms.RandomHorizontalFlip,
|
||||
torchvision.transforms.CenterCrop,
|
||||
torchvision.transforms.RandomRotation,
|
||||
torchvision.transforms.RandomAffine,
|
||||
torchvision.transforms.RandomResizedCrop,
|
||||
torchvision.transforms.RandomRotation,
|
||||
]
|
||||
|
||||
|
||||
def apply_dense_transforms(x, y, transforms: torchvision.transforms.transforms.Compose):
    """Apply some transformations to both image and target.

    Geometric transforms are applied to image and mask with identical random
    parameters; intensity transforms are applied to the image only.

    Args:
        x (torch.Tensor): image
        y (torch.Tensor): target (image)
        transforms (torchvision.transforms.transforms.Compose): transformations to apply

    Returns:
        tuple[torch.Tensor, torch.Tensor]: (x, y) with applied transformations

    """
    for trans in transforms.transforms:
        if isinstance(trans, torchvision.transforms.RandomResizedCrop):
            # Sample the crop parameters once and reuse them so image and mask stay aligned.
            params = trans.get_params(x, trans.scale, trans.ratio)
            x = F.resized_crop(x, *params, trans.size, trans.interpolation, antialias=trans.antialias)
            y = F.resized_crop(y.unsqueeze(0), *params, trans.size, 0).squeeze(0)  # nearest neighbor interpolation
        elif isinstance(trans, Resize):
            # Resize the mask only when the image shape actually changed.
            pre_shape = x.shape
            x = trans(x)
            if x.shape != pre_shape:
                y = F.resize(y.unsqueeze(0), trans.size, 0, trans.max_size, trans.antialias).squeeze(
                    0
                )  # nearest neighbor interpolation
        elif any(isinstance(trans, simul_transform) for simul_transform in _image_and_target_transforms):
            # Stack the mask as an extra channel so the random transform uses
            # identical parameters for image and mask, then split again.
            xy = torch.cat([x, y.unsqueeze(0).float()], dim=0)
            xy = trans(xy)
            x, y = xy[:-1], xy[-1].long()
        elif isinstance(trans, torchvision.transforms.ToTensor):
            # Only convert if not already a tensor (ToTensor would fail otherwise).
            if not isinstance(x, torch.Tensor):
                x = trans(x)
        else:
            # Intensity/color transforms: image only, mask untouched.
            x = trans(x)

    return x, y
|
||||
|
||||
|
||||
def get_hf_transform(transform_f, trgt_transform_f=None, image_key="image"):
    """Convert the transform function to a huggingface compatible transform function.

    Args:
        transform_f (callable): Image transform.
        trgt_transform_f (callable, optional): Target transform. Defaults to None.
        image_key (str, optional): Key for the image in the hf ds return dict. Defaults to "image".

    Returns:
        callable: function mapping a batch dict to the transformed batch dict.
    """

    def _transform(samples):
        try:
            transformed = [transform_f(im) for im in samples[image_key]]
            samples[image_key] = transformed
            if trgt_transform_f is not None:
                samples["label"] = [trgt_transform_f(tgt) for tgt in samples["label"]]
        except TypeError as e:
            # Surface the offending batch before re-raising for easier debugging.
            print(f"Type error when transforming samples: {samples}")
            raise e
        return samples

    return _transform
|
||||
|
||||
|
||||
class DDDecodeDataset(DDDataset):
    """Datadings dataset with image decoding before transform."""

    def __init__(self, *args, transform=None, target_transform=None, transforms=None, **kwargs):
        """Create datadings dataset.

        Args:
            transform (callable, optional): Image transform. Overrides transforms['image']. Defaults to None.
            target_transform (callable, optional): Label transform. Overrides transforms['label']. Defaults to None.
            transforms (dict[str, callable], optional): Dict of transforms for each key. Defaults to None.
        """
        super().__init__(*args, **kwargs)
        transforms = transforms if transforms is not None else {}
        # Explicit per-kind arguments take precedence over the transforms dict.
        self._decode_transform = transform if transform is not None else transforms.get("image", None)
        if target_transform is not None:
            self._decode_target_transform = target_transform
        else:
            self._decode_target_transform = transforms.get("label", None)

        self.ctp = CompressedToPIL()

    def __getitem__(self, idx):
        """Fetch a sample, decode compressed image bytes, then apply transforms."""
        sample = super().__getitem__(idx)
        img, lbl = sample["image"], sample["label"]
        if isinstance(img, bytes):
            # decode compressed bytes into a PIL image before transforming
            img = self.ctp(img)
        if self._decode_transform is not None:
            img = self._decode_transform(img)
        if self._decode_target_transform is not None:
            lbl = self._decode_target_transform(lbl)
        return img, lbl
|
||||
|
||||
|
||||
def minimal_augment(args, test=False):
    """Minimal Augmentation set for images.

    Contains only resize, crop, flip, to tensor and normalize.

    Args:
        args (DotDict): Arguments: aug_resize, aug_crop, aug_flip, aug_normalize to turn on/off the respective augmentation.
        test (bool, optional): On the test set? Defaults to False.

    Returns:
        List: Augmentation list
    """
    augs = [ToTensor()]

    if args.aug_resize:
        augs.append(Resize(args.imsize, interpolation=InterpolationMode.BICUBIC))

    if args.aug_crop:
        # deterministic center crop for eval, padded random crop for training
        if test:
            augs.append(CenterCrop(args.imsize))
        else:
            augs.append(RandomCrop(args.imsize, padding=4, padding_mode="reflect"))

    if args.aug_flip and not test:
        augs.append(RandomHorizontalFlip(p=0.5))

    if args.aug_normalize:
        augs.append(
            Normalize(
                mean=torch.tensor([0.485, 0.456, 0.406]),
                std=torch.tensor([0.229, 0.224, 0.225]),
            )
        )

    return augs
|
||||
|
||||
|
||||
def three_augment(args, as_list=False, test=False):
    """Create the data augmentation.

    Args:
        args (Namespace): arguments
        as_list (bool): return list of transformations, not composed transformation
        test (bool): In eval mode? If False => train mode

    Returns:
        torch.nn.Module | list[torch.nn.Module]: composed transformation or list of transformations
    """
    augs = [ToTensor()]

    if args.aug_resize:
        augs.append(Resize(args.imsize, interpolation=InterpolationMode.BICUBIC))

    if args.aug_crop:
        if test:
            augs.append(CenterCrop(args.imsize))
        else:
            augs.append(RandomCrop(args.imsize, padding=4, padding_mode="reflect"))

    if not test:
        if args.aug_flip:
            augs.append(RandomHorizontalFlip(p=0.5))

        # Exactly one of {grayscale, solarize, gaussian blur} per sample (if any enabled).
        one_of = []
        if args.aug_grayscale:
            one_of.append(Grayscale(num_output_channels=3))
        if args.aug_solarize:
            one_of.append(RandomSolarize(threshold=0.5, p=1.0))
        if args.aug_gauss_blur:
            # TODO: check kernel size?
            one_of.append(GaussianBlur(kernel_size=7, sigma=(0.2, 2.0)))
        if one_of:
            augs.append(RandomChoice(one_of))

        if args.aug_color_jitter_factor > 0.0:
            jitter = args.aug_color_jitter_factor
            augs.append(ColorJitter(jitter, jitter, jitter))

    if args.aug_normalize:
        augs.append(
            Normalize(
                mean=torch.tensor([0.485, 0.456, 0.406]),
                std=torch.tensor([0.229, 0.224, 0.225]),
            )
        )

    return augs if as_list else Compose(augs)
|
||||
|
||||
|
||||
def segment_augment(args, test=False):
    """Create the data augmentation for segmentation.

    No cropping in this part, as cropping has to be done for the image and labels simultaneously.

    Args:
        args (DotDict): arguments
        test (bool, optional): In eval mode? If False => train mode. Defaults to False.

    Returns:
        list[torch.nn.Module]: list of transformations
    """
    if test:
        # deterministic sizing at eval time
        augs = [ResizeUp(args.imsize), CenterCrop(args.imsize)]
    else:
        augs = [RandomResizedCrop(args.imsize, scale=(0.08, 1.0), ratio=(3.0 / 4.0, 4.0 / 3.0))]

    if not test and args.aug_flip:
        augs.append(RandomHorizontalFlip(p=0.5))

    if args.aug_color_jitter_factor > 0.0:
        jitter = args.aug_color_jitter_factor
        augs.append(ColorJitter(jitter, jitter, jitter))

    if args.aug_normalize:
        augs.append(
            Normalize(
                mean=torch.tensor([0.485, 0.456, 0.406]),
                std=torch.tensor([0.229, 0.224, 0.225]),
            )
        )

    return augs
|
||||
|
||||
|
||||
class QuickGaussBlur:
    """Gaussian blur transformation using PIL ImageFilter."""

    def __init__(self, sigma=(0.2, 2.0)):
        """Create Gaussian blur operator.

        Args:
            sigma (tuple[float, float]): range from which the blur radius is drawn.
        """
        self.sigma_min, self.sigma_max = sigma

    def __call__(self, img):
        """Blur ``img`` with a radius sampled uniformly from [sigma_min, sigma_max]."""
        radius = uniform(self.sigma_min, self.sigma_max)
        return img.filter(ImageFilter.GaussianBlur(radius=radius))
|
||||
|
||||
|
||||
class RemoveTransform:
    """Remove data from transformation.

    To use with default collate function.
    """

    def __call__(self, x, y=None):
        """Replace the sample data with a dummy list, keeping the label if given."""
        placeholder = [1]
        return placeholder if y is None else (placeholder, y)

    def __repr__(self):
        return f"{self.__class__.__name__}()"
|
||||
|
||||
|
||||
def collate_imnet(data, image_key="image"):
    """Collate function for imagenet(1k / 21k) with datadings.

    Args:
        data (list[dict[str, Any]]): samples for a batch.
        image_key (str, optional): dict key holding the image. Defaults to "image".

    Returns:
        tuple[torch.Tensor | list, torch.Tensor]: images, labels
    """
    images = [sample[image_key] for sample in data]
    # Stack into a batch tensor only when images are already tensors
    # (e.g. raw PIL images stay as a plain list).
    if isinstance(images[0], torch.Tensor):
        images = torch.stack(images, dim=0)
    labels = torch.tensor([sample["label"] for sample in data])
    return images, labels
|
||||
|
||||
|
||||
def collate_listops(data):
    """Collate function for ListOps with datadings.

    The loader already delivers whole batches, so this just unpacks the single
    (inputs, labels) pair.

    Args:
        data (list[tuple[torch.Tensor, torch.Tensor]]): list with one sample.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: inputs, labels
    """
    inputs, labels = data[0][0], data[0][1]
    return inputs, labels
|
||||
|
||||
|
||||
def no_param_transf(self, sample):
    """Call transformation without extra parameter.

    To use with datadings QuasiShuffler: bind as a method via
    ``<obj>.<method_name> = MethodType(no_param_transf, <obj>)``.

    Args:
        self: object providing ``_rng`` and ``_transforms``.
        sample: sample to transform; may be a (name, msgpack bytes) tuple.

    Returns:
        the transformed sample.
    """
    if isinstance(sample, tuple):
        # sample arrives as (name (str), data (bytes encoded)) - keep the payload
        sample = sample[1]
    if isinstance(sample, bytes):
        # decode the msgpack-encoded payload into a dict
        sample = msgpack.loads(sample)
    params = self._rng(sample)
    for key, transf in self._transforms.items():
        sample[key] = transf(sample[key], params)
    return sample
|
||||
|
||||
|
||||
class ToOneHotSequence:
    """Convert a sequence of grayscale values (range 0 to 1) to a one-hot encoded sequence based on 8-bit rounded values."""

    def __call__(self, x, y=None):
        """Flatten ``x``, quantize to 8-bit integers, and one-hot encode to 256 classes."""
        # x is 1 x 32 x 32
        levels = (255 * x).round().to(torch.int64).view(-1)
        assert levels.max() < 256, f"Found max value {levels.max()} in {levels}."
        encoded = torch.nn.functional.one_hot(levels, num_classes=256).float()
        return encoded if y is None else (encoded, y)

    def __repr__(self):
        return f"{self.__class__.__name__}()"
|
||||
|
||||
|
||||
class ResizeUp(Resize):
    """Resize up if image is smaller than target size.

    NOTE(review): comparison against ``self.size`` assumes the Resize was
    constructed with a single int, not a (h, w) sequence — confirm at call sites.
    """

    def forward(self, img):
        """Resize ``img`` only when one of its spatial sides is below the target size."""
        # For (..., H, W) tensors shape[-2] is the height and shape[-1] the width.
        height, width = img.shape[-2], img.shape[-1]
        needs_upscale = height < self.size or width < self.size
        return super().forward(img) if needs_upscale else img
|
||||
484
AAAI Supplementary Material/Model Training Code/data/fornet.py
Normal file
484
AAAI Supplementary Material/Model Training Code/data/fornet.py
Normal file
@@ -0,0 +1,484 @@
|
||||
import json
|
||||
import os
|
||||
import zipfile
|
||||
from io import BytesIO
|
||||
from math import floor
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
from datadings.torch import Compose
|
||||
from loguru import logger
|
||||
from PIL import Image, ImageFilter
|
||||
from torch.utils.data import Dataset, get_worker_info
|
||||
from torchvision import transforms as T
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from data.data_utils import apply_dense_transforms
|
||||
|
||||
|
||||
class ForNet(Dataset):
|
||||
"""Recombine ImageNet forgrounds and backgrounds.
|
||||
|
||||
Note:
|
||||
This dataset has exactly the ImageNet classes.
|
||||
|
||||
"""
|
||||
|
||||
_back_combs = ["same", "all", "original"]
|
||||
_bg_transforms = {T.RandomCrop, T.CenterCrop, T.Resize, T.RandomResizedCrop}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
root,
|
||||
transform=None,
|
||||
train=True,
|
||||
target_transform=None,
|
||||
background_combination="all",
|
||||
fg_scale_jitter=0.3,
|
||||
fg_transform=None,
|
||||
pruning_ratio=0.8,
|
||||
return_fg_masks=False,
|
||||
fg_size_mode="range",
|
||||
fg_bates_n=1,
|
||||
paste_pre_transform=True,
|
||||
mask_smoothing_sigma=4.0,
|
||||
rel_jut_out=0.0,
|
||||
fg_in_nonant=None,
|
||||
size_fact=1.0,
|
||||
orig_img_prob=0.0,
|
||||
orig_ds=None,
|
||||
_orig_ds_file_type="JPEG",
|
||||
epochs=0,
|
||||
_album_compose=False,
|
||||
):
|
||||
"""Create RecombinationNet dataset.
|
||||
|
||||
Args:
|
||||
root (str): Root folder for the dataset.
|
||||
transform (T.Collate | list, optional): Transform to apply to the image. Defaults to None.
|
||||
train (bool, optional): On the train set (False -> val set). Defaults to True.
|
||||
target_transform (T.Collate | list, optional): Transform to apply to the target values. Defaults to None.
|
||||
background_combination (str, optional): Which backgrounds to combine with foregrounds. Defaults to "same".
|
||||
fg_scale_jitter (tuple, optional): How much should the size of the foreground be changed (random ratio). Defaults to (0.1, 0.8).
|
||||
fg_transform (_type_, optional): Transform to apply to the foreground before applying to the background. This is supposed to be a random rotation, mainly. Defaults to None.
|
||||
pruning_ratio (float, optional): For pruning backgrounds, with (foreground size/background size) >= <pruning_ratio>. Backgrounds from images that contain very large foreground objects are mostly computer generated and therefore relatively unnatural. Defaults to full dataset.
|
||||
return_fg_masks (bool, optional): Return the foreground masks. Defaults to False.
|
||||
fg_size_mode (str, optional): How to determine the size of the foreground, based on the foreground sizes of the foreground and background images. Defaults to "max".
|
||||
fg_bates_n (int, optional): Bates parameter for the distribution of the object position in the foreground. Defaults to 1 (uniform distribution). The higher the value, the more likely the object is in the center. For fg_bates_n = 0, the object is always in the center.
|
||||
paste_pre_transform (bool, optional): Paste the foreground onto the background before applying the transform. If false, the background will be cropped and resized before pasting the foreground. Defaults to False.
|
||||
mask_smoothing_sigma (float, optional): Sigma for the Gaussian blur of the mask edge. Defaults to 0.0. Try 2.0 or 4.0?
|
||||
rel_jut_out (float, optional): How much is the foreground allowed to stand/jut out of the background (and then cut off). Defaults to 0.0.
|
||||
fg_in_nonant (int, optional): If not None, the foreground will be placed in a specific nonant (0-8) of the image. Defaults to None.
|
||||
size_fact (float, optional): Factor to multiply the size of the foreground with. Defaults to 1.0.
|
||||
orig_img_prob (float | str, optional): Probability to use the original image, instead of the fg-bg recombinations. Defaults to 0.0.
|
||||
orig_ds (Dataset, optional): Original dataset (without transforms) to use for the original images. Defaults to None.
|
||||
_orig_ds_file_type (str, optional): File type of the original dataset. Defaults to "JPEG".
|
||||
epochs (int, optional): Number of epochs to train on. Used for linear increase of orig_img_prob.
|
||||
|
||||
Note:
|
||||
For more information on the bates distribution, see https://en.wikipedia.org/wiki/Bates_distribution.
|
||||
For fg_bats_n < 0, we take extend the bates dirstribution to focus more and more on the edges. This is done by sampling B ~ Bates(|fg_bates_n|) and then passing through f(x) = x + 0.5 - floor(x + 0.5).
|
||||
|
||||
For the list of transformations that will be applied to the background only (if paste_pre_transform=False), see RecombinationNet._bg_transforms.
|
||||
|
||||
A nonant in this case refers to a square in a 3x3 grid dividing the image.
|
||||
|
||||
"""
|
||||
assert (
|
||||
background_combination in self._back_combs
|
||||
), f"background_combination={background_combination} is not supported. Use one of {self._back_combs}"
|
||||
|
||||
if (not os.path.exists(f"{root}/backgrounds_{'train' if train else 'val'}.zip")) and os.path.exists(
|
||||
os.path.join(root, "train" if train else "val", "backgrounds")
|
||||
):
|
||||
self._mode = "folder"
|
||||
else:
|
||||
self._mode = "zip"
|
||||
|
||||
if self._mode == "zip":
|
||||
try:
|
||||
with zipfile.ZipFile(f"{root}/backgrounds_{'train' if train else 'val'}.zip", "r") as bg_zip:
|
||||
self.backgrounds = [f for f in bg_zip.namelist() if f.endswith(".JPEG")]
|
||||
with zipfile.ZipFile(f"{root}/foregrounds_{'train' if train else 'val'}.zip", "r") as fg_zip:
|
||||
self.foregrounds = [f for f in fg_zip.namelist() if f.endswith(".WEBP")]
|
||||
except FileNotFoundError as e:
|
||||
logger.error(
|
||||
f"RecombinationNet: {e}. Make sure to have the background and foreground zips in the root"
|
||||
f" directory: found {os.listdir(root)}"
|
||||
)
|
||||
raise e
|
||||
classes = set([f.split("/")[-2] for f in self.foregrounds])
|
||||
else:
|
||||
logger.info("ForNet folder mode: loading classes")
|
||||
classes = set(os.listdir(os.path.join(root, "train" if train else "val", "foregrounds")))
|
||||
foregrounds = []
|
||||
backgrounds = []
|
||||
for cls in tqdm(classes, desc="Loading files"):
|
||||
foregrounds.extend(
|
||||
[
|
||||
f"{cls}/{f}"
|
||||
for f in os.listdir(os.path.join(root, "train" if train else "val", "foregrounds", cls))
|
||||
]
|
||||
)
|
||||
backgrounds.extend(
|
||||
[
|
||||
f"{cls}/{f}"
|
||||
for f in os.listdir(os.path.join(root, "train" if train else "val", "backgrounds", cls))
|
||||
]
|
||||
)
|
||||
self.foregrounds = foregrounds
|
||||
self.backgrounds = backgrounds
|
||||
|
||||
self.classes = sorted(list(classes), key=lambda x: int(x[1:]))
|
||||
|
||||
assert os.path.exists(f"{root}/fg_bg_ratios_{'train' if train else 'val'}.json"), (
|
||||
f"{root}/fg_bg_ratios_{'train' if train else 'val'}.json not found, provide the information or set"
|
||||
" pruning_ratio=1.0"
|
||||
)
|
||||
with open(f"{root}/fg_bg_ratios_{'train' if train else 'val'}.json", "r") as f:
|
||||
self.fg_bg_ratios = json.load(f)
|
||||
if self._mode == "folder":
|
||||
self.fg_bg_ratios = {"/".join(key.split("/")[-2:]): val for key, val in self.fg_bg_ratios.items()}
|
||||
logger.info(f"Renamed fg_bg_ratios keys to {list(self.fg_bg_ratios.keys())[:3]}...")
|
||||
|
||||
if pruning_ratio <= 1.0:
|
||||
backup_backgrounds = {}
|
||||
for bg_file in self.backgrounds:
|
||||
bg_cls = bg_file.split("/")[-2]
|
||||
backup_backgrounds[bg_cls] = bg_file
|
||||
self.backgrounds = [
|
||||
bg for bg in self.backgrounds if bg in self.fg_bg_ratios and self.fg_bg_ratios[bg] < pruning_ratio
|
||||
]
|
||||
# logger.info(
|
||||
# f"RecombinationNet: keep {len(self.backgrounds)} of {len(self.fg_bg_ratios)} backgrounds with pr {pruning_ratio}"
|
||||
# )
|
||||
|
||||
self.root = root
|
||||
self.train = train
|
||||
self.background_combination = background_combination
|
||||
self.fg_scale_jitter = fg_scale_jitter
|
||||
self.fg_transform = fg_transform
|
||||
self.return_fg_masks = return_fg_masks
|
||||
self.paste_pre_transform = paste_pre_transform
|
||||
self.mask_smoothing_sigma = mask_smoothing_sigma
|
||||
self.rel_jut_out = rel_jut_out
|
||||
self.size_fact = size_fact
|
||||
self.fg_in_nonant = fg_in_nonant
|
||||
assert fg_in_nonant is None or -1 <= fg_in_nonant < 9, f"fg_in_nonant={fg_in_nonant} not in [0, 8] or None"
|
||||
|
||||
self.orig_img_prob = orig_img_prob
|
||||
if orig_img_prob != 0.0:
|
||||
assert (isinstance(orig_img_prob, float) and orig_img_prob > 0.0) or orig_img_prob in [
|
||||
"linear",
|
||||
"cos",
|
||||
"revlinear",
|
||||
]
|
||||
assert orig_ds is not None, "orig_ds must be provided if orig_img_prob > 0.0"
|
||||
assert not return_fg_masks, "can't provide fg masks for original images (yet)"
|
||||
assert os.path.exists(os.path.join(root, f"{'train' if train else 'val'}_indices.json")) or isinstance(
|
||||
orig_ds, str
|
||||
), f"{root}/{'train' if train else 'val'}_indices.json must be provided if orig_ds is a dataset"
|
||||
if not isinstance(orig_ds, str):
|
||||
with open(os.path.join(root, f"{'train' if train else 'val'}_indices.json"), "r") as f:
|
||||
self.key_to_orig_idx = json.load(f)
|
||||
else:
|
||||
if not (
|
||||
orig_ds.endswith("train" if train else "val") or orig_ds.endswith("train/" if train else "val/")
|
||||
):
|
||||
orig_ds = f"{orig_ds}/{'train' if train else 'val'}"
|
||||
self.key_to_orig_idx = None
|
||||
self._orig_ds_file_type = _orig_ds_file_type
|
||||
self.orig_ds = orig_ds
|
||||
self.epochs = epochs
|
||||
self._epoch = 0
|
||||
|
||||
assert fg_size_mode in [
|
||||
"max",
|
||||
"min",
|
||||
"mean",
|
||||
"range",
|
||||
], f"fg_size_mode={fg_size_mode} not supported; use one of ['max', 'min', 'mean', 'range']"
|
||||
self.fg_size_mode = fg_size_mode
|
||||
self.fg_bates_n = fg_bates_n
|
||||
|
||||
if not paste_pre_transform:
|
||||
if isinstance(transform, (T.Compose, Compose)):
|
||||
transform = transform.transforms
|
||||
|
||||
# do cropping and resizing mainly on background; paste foreground on top later
|
||||
self.bg_transform = []
|
||||
self.join_transform = []
|
||||
for tf in transform:
|
||||
if isinstance(tf, tuple(self._bg_transforms)):
|
||||
self.bg_transform.append(tf)
|
||||
else:
|
||||
self.join_transform.append(tf)
|
||||
|
||||
if _album_compose:
|
||||
from data.album_transf import AlbumTorchCompose
|
||||
|
||||
self.bg_transform = AlbumTorchCompose(self.bg_transform)
|
||||
self.join_transform = AlbumTorchCompose(self.join_transform)
|
||||
else:
|
||||
self.bg_transform = T.Compose(self.bg_transform)
|
||||
self.join_transform = T.Compose(self.join_transform)
|
||||
|
||||
else:
|
||||
if isinstance(transform, list):
|
||||
if _album_compose:
|
||||
from data.album_transf import AlbumTorchCompose
|
||||
|
||||
self.join_transform = AlbumTorchCompose(transform)
|
||||
else:
|
||||
self.join_transform = T.Compose(transform)
|
||||
else:
|
||||
self.join_transform = transform
|
||||
self.bg_transform = None
|
||||
|
||||
self.trgt_map = {cls: i for i, cls in enumerate(self.classes)}
|
||||
|
||||
self.target_transform = target_transform
|
||||
|
||||
self.cls_to_allowed_bg = {}
|
||||
for bg_file in self.backgrounds:
|
||||
if background_combination == "same":
|
||||
bg_cls = bg_file.split("/")[-2]
|
||||
if bg_cls not in self.cls_to_allowed_bg:
|
||||
self.cls_to_allowed_bg[bg_cls] = []
|
||||
self.cls_to_allowed_bg[bg_cls].append(bg_file)
|
||||
|
||||
if background_combination == "same":
|
||||
for cls_code in classes:
|
||||
if cls_code not in self.cls_to_allowed_bg or len(self.cls_to_allowed_bg[cls_code]) == 0:
|
||||
self.cls_to_allowed_bg[cls_code] = [backup_backgrounds[cls_code]]
|
||||
logger.warning(f"No background for class {cls_code}, using {backup_backgrounds[cls_code]}")
|
||||
|
||||
self._zf = {}
|
||||
|
||||
@property
def epoch(self):
    """Current epoch index; drives the epoch-scheduled orig_img_prob ramps in __getitem__."""
    return self._epoch


@epoch.setter
def epoch(self, value):
    # Typically called by the training loop once per epoch so that the
    # "linear"/"cos"/"revlinear" original-image schedules can advance.
    self._epoch = value
|
||||
|
||||
def __len__(self):
    """Size of the dataset.

    One sample corresponds to one foreground (composited onto a background
    in __getitem__), so the dataset length is the number of foregrounds.

    Returns:
        int: number of foregrounds

    """
    return len(self.foregrounds)
|
||||
|
||||
def num_classes(self) -> int:
    """Return the number of target classes in this dataset."""
    return len(self.classes)
|
||||
|
||||
def _get_fg(self, idx):
    """Load the raw foreground image at *idx* without any compositing.

    Mirrors the foreground-loading logic of __getitem__: in zip mode the
    image is read through the per-worker zip handle, otherwise straight
    from the extracted directory tree.

    Args:
        idx (int): index into self.foregrounds.

    Returns:
        PIL.Image.Image: the (lazily decoded) foreground image.
    """
    worker_id = self._wrkr_info()
    fg_file = self.foregrounds[idx]
    if self._mode == "zip":
        with self._zf[worker_id]["fg"].open(fg_file) as f:
            fg_data = BytesIO(f.read())
        return Image.open(fg_data)
    # Fix: previously this method unconditionally read from self._zf and
    # raised KeyError when the dataset was backed by plain directories;
    # fall back to the same directory layout __getitem__ uses.
    return Image.open(os.path.join(self.root, "train" if self.train else "val", "foregrounds", fg_file))
|
||||
|
||||
def _wrkr_info(self):
|
||||
worker_id = get_worker_info().id if get_worker_info() else 0
|
||||
|
||||
if worker_id not in self._zf and self._mode == "zip":
|
||||
self._zf[worker_id] = {
|
||||
"bg": zipfile.ZipFile(f"{self.root}/backgrounds_{'train' if self.train else 'val'}.zip", "r"),
|
||||
"fg": zipfile.ZipFile(f"{self.root}/foregrounds_{'train' if self.train else 'val'}.zip", "r"),
|
||||
}
|
||||
return worker_id
|
||||
|
||||
def __getitem__(self, idx):
    """Get the foreground at index idx and combine it with a (random) background.

    With probability self.orig_img_prob (possibly epoch-scheduled) the
    untouched original image is returned instead of a composite.

    Args:
        idx (int): foreground index

    Returns:
        torch.Tensor, torch.Tensor: image, target
        (plus a foreground-mask tensor when self.return_fg_masks is set)

    """
    worker_id = self._wrkr_info()
    fg_file = self.foregrounds[idx]
    # class synset is encoded as the parent directory of the foreground file
    trgt_cls = fg_file.split("/")[-2]

    # Decide whether to return an ORIGINAL (non-composited) image:
    #   "linear"    - probability grows linearly with epoch
    #   "revlinear" - (epoch - epochs) / epochs is <= 0 while epoch <= epochs,
    #                 so rand() is never below it; NOTE(review): this branch
    #                 looks like it can never fire - confirm intent
    #   "cos"       - probability follows a 1 - cos(...) ramp over training
    #   float       - constant probability
    if (
        (self.orig_img_prob == "linear" and np.random.rand() < self._epoch / self.epochs)
        or (self.orig_img_prob == "revlinear" and np.random.rand() < (self._epoch - self.epochs) / self.epochs)
        or (self.orig_img_prob == "cos" and np.random.rand() > np.cos(np.pi * self._epoch / (2 * self.epochs)))
        or (
            isinstance(self.orig_img_prob, float)
            and self.orig_img_prob > 0.0
            and np.random.rand() < self.orig_img_prob
        )
    ):
        data_key = f"{trgt_cls}/{fg_file.split('/')[-1].split('.')[0]}"
        if isinstance(self.orig_ds, str):
            # original dataset given as a directory of image files
            image_file = os.path.join(self.orig_ds, f"{data_key}.{self._orig_ds_file_type}")
            orig_img = Image.open(image_file).convert("RGB")
        else:
            # original dataset given as an indexable dataset object
            orig_data = self.orig_ds[self.key_to_orig_idx[data_key]]
            orig_img = orig_data["image"] if isinstance(orig_data, dict) else orig_data[0]

        # original images get the full pipeline: background + joint transforms
        if self.bg_transform:
            orig_img = self.bg_transform(orig_img)
        if self.join_transform:
            orig_img = self.join_transform(orig_img)
        trgt = self.trgt_map[trgt_cls]
        if self.target_transform:
            trgt = self.target_transform(trgt)
        return orig_img, trgt

    # ---- load the foreground (RGBA; the alpha channel is the object mask) ----
    if self._mode == "zip":
        with self._zf[worker_id]["fg"].open(fg_file) as f:
            fg_data = BytesIO(f.read())
        try:
            fg_img = Image.open(fg_data).convert("RGBA")
        except PIL.UnidentifiedImageError as e:
            logger.error(f"Error with idx={idx}, file={self.foregrounds[idx]}")
            raise e
    else:
        # data_key = f"{trgt_cls}/{fg_file.split('/')[-1].split('.')[0]}"
        fg_img = Image.open(
            os.path.join(self.root, "train" if self.train else "val", "foregrounds", fg_file)
        ).convert("RGBA")

    if self.fg_transform:
        fg_img = self.fg_transform(fg_img)
    # mean alpha = fraction of the foreground crop actually covered by the object
    fg_size_factor = T.ToTensor()(fg_img.split()[-1]).mean().item()

    # ---- pick a background according to the combination policy ----
    if self.background_combination == "all":
        bg_idx = np.random.randint(len(self.backgrounds))
        bg_file = self.backgrounds[bg_idx]
    elif self.background_combination == "original":
        # the background that originally belonged to this foreground
        bg_file = fg_file.replace("foregrounds", "backgrounds").replace("WEBP", "JPEG")
    else:
        # "same": any background from the same class
        bg_idx = np.random.randint(len(self.cls_to_allowed_bg[trgt_cls]))
        bg_file = self.cls_to_allowed_bg[trgt_cls][bg_idx]

    if self._mode == "zip":
        with self._zf[worker_id]["bg"].open(bg_file) as f:
            bg_data = BytesIO(f.read())
        bg_img = Image.open(bg_data).convert("RGB")
    else:
        bg_img = Image.open(
            os.path.join(self.root, "train" if self.train else "val", "backgrounds", bg_file)
        ).convert("RGB")

    if not self.paste_pre_transform:
        # crop/resize the background first; the foreground is pasted afterwards
        bg_img = self.bg_transform(bg_img)

    bg_size = bg_img.size

    # choose scale factor, such that relative area is in fg_scale
    bg_area = bg_size[0] * bg_size[1]
    if self.fg_in_nonant is not None:
        # foreground is confined to one of the 9 grid cells -> scale to cell area
        bg_area = bg_area / 9

    # logger.info(f"background: size={bg_size} area={bg_area}")
    # logger.info(f"fg_file={fg_file}, fg_bg_ratio_keys={list(self.fg_bg_ratios.keys())[:3]}...")
    # object/background area ratios measured on the ORIGINAL images
    # (self.fg_bg_ratios is keyed by background file names)
    orig_fg_ratio = self.fg_bg_ratios[fg_file.replace("foregrounds", "backgrounds").replace("WEBP", "JPEG")]
    bg_fg_ratio = self.fg_bg_ratios[bg_file]

    # target relative object area, combined from the two original ratios
    if self.fg_size_mode == "max":
        goal_fg_ratio_lower = goal_fg_ratio_upper = max(orig_fg_ratio, bg_fg_ratio)
    elif self.fg_size_mode == "min":
        goal_fg_ratio_lower = goal_fg_ratio_upper = min(orig_fg_ratio, bg_fg_ratio)
    elif self.fg_size_mode == "mean":
        goal_fg_ratio_lower = goal_fg_ratio_upper = (orig_fg_ratio + bg_fg_ratio) / 2
    else:
        # range
        goal_fg_ratio_lower = min(orig_fg_ratio, bg_fg_ratio)
        goal_fg_ratio_upper = max(orig_fg_ratio, bg_fg_ratio)

    # logger.info(f"fg_bg_ratio={orig_fg_ratio}")
    # dividing by fg_size_factor compensates for transparent padding in the crop
    fg_scale = (
        np.random.uniform(
            goal_fg_ratio_lower * (1 - self.fg_scale_jitter), goal_fg_ratio_upper * (1 + self.fg_scale_jitter)
        )
        / fg_size_factor
        * self.size_fact
    )

    # resize the foreground so its area is bg_area * fg_scale, keeping aspect ratio
    goal_shape_y = round(np.sqrt(bg_area * fg_scale * fg_img.size[1] / fg_img.size[0]))
    goal_shape_x = round(np.sqrt(bg_area * fg_scale * fg_img.size[0] / fg_img.size[1]))

    fg_img = fg_img.resize((goal_shape_x, goal_shape_y))

    if fg_img.size[0] > bg_size[0] or fg_img.size[1] > bg_size[1]:
        # random crop to fit
        goal_w, goal_h = (min(fg_img.size[0], bg_size[0]), min(fg_img.size[1], bg_size[1]))
        fg_img = T.RandomCrop((goal_h, goal_w))(fg_img) if self.train else T.CenterCrop((goal_h, goal_w))(fg_img)

    # paste fg on bg
    z1, z2 = (
        (
            np.random.uniform(0, 1, abs(self.fg_bates_n)).mean(),  # bates distribution n=1 => uniform
            np.random.uniform(0, 1, abs(self.fg_bates_n)).mean(),
        )
        if self.fg_bates_n != 0
        else (0.5, 0.5)
    )
    if self.fg_bates_n < 0:
        # negative n: shift the center-peaked Bates sample by 0.5 (mod 1) so
        # probability mass concentrates near the borders instead of the center
        z1 = z1 + 0.5 - floor(z1 + 0.5)
        z2 = z2 + 0.5 - floor(z2 + 0.5)

    # allowed paste-offset range; rel_jut_out lets the object extend past the border
    x_min = -self.rel_jut_out * fg_img.size[0]
    x_max = bg_size[0] - fg_img.size[0] * (1 - self.rel_jut_out)
    y_min = -self.rel_jut_out * fg_img.size[1]
    y_max = bg_size[1] - fg_img.size[1] * (1 - self.rel_jut_out)

    if self.fg_in_nonant is not None and self.fg_in_nonant >= 0:
        # restrict placement to the chosen cell of the 3x3 grid
        x_min = (self.fg_in_nonant % 3) * bg_size[0] / 3
        x_max = ((self.fg_in_nonant % 3) + 1) * bg_size[0] / 3 - fg_img.size[0]
        y_min = (self.fg_in_nonant // 3) * bg_size[1] / 3
        y_max = ((self.fg_in_nonant // 3) + 1) * bg_size[1] / 3 - fg_img.size[1]

    # degenerate range (e.g. foreground larger than a grid cell): center it
    if x_min > x_max:
        x_min = x_max = (x_min + x_max) / 2
    if y_min > y_max:
        y_min = y_max = (y_min + y_max) / 2

    offs_x = round(z1 * (x_max - x_min) + x_min)
    offs_y = round(z2 * (y_max - y_min) + y_min)

    # alpha channel of the (possibly resized) foreground is the paste mask
    paste_mask = fg_img.split()[-1]
    if self.mask_smoothing_sigma > 0.0:
        sigma = (np.random.rand() * 0.9 + 0.1) * self.mask_smoothing_sigma
        paste_mask = paste_mask.filter(ImageFilter.GaussianBlur(radius=sigma))
        # re-sharpen: keep the smooth upper half, cut the faint halo below 128
        paste_mask = paste_mask.point(lambda p: 2 * p - 255 if p > 128 else 0)

    bg_img.paste(fg_img.convert("RGB"), (offs_x, offs_y), paste_mask)
    bg_img = bg_img.convert("RGB")

    # NOTE(review): nesting of this section reconstructed from semantics
    # (source formatting was lost) - confirm against the original file.
    if self.return_fg_masks:
        # full-size mask of where the foreground ended up on the background
        fg_mask = Image.new("L", bg_size, 0)
        fg_mask.paste(paste_mask, (offs_x, offs_y))

        fg_mask = T.ToTensor()(fg_mask)[0]

        bg_img = T.ToTensor()(bg_img)

        if self.join_transform:
            # img_mask_stack = torch.cat([bg_img, fg_mask.unsqueeze(0)], dim=0)
            # img_mask_stack = self.join_transform(img_mask_stack)
            # bg_img, fg_mask = img_mask_stack[:-1], img_mask_stack[-1]
            # apply the joint transforms to image and mask consistently
            bg_img, fg_mask = apply_dense_transforms(bg_img, fg_mask, self.join_transform)
    else:
        bg_img = self.join_transform(bg_img)

    if trgt_cls not in self.trgt_map:
        raise ValueError(f"trgt_cls={trgt_cls} not in trgt_map: {self.trgt_map}")
    trgt = self.trgt_map[trgt_cls]
    if self.target_transform:
        trgt = self.target_transform(trgt)

    if self.return_fg_masks:
        return bg_img, trgt, fg_mask

    return bg_img, trgt
|
||||
@@ -0,0 +1,151 @@
|
||||
import argparse
|
||||
import shutil
|
||||
import zipfile
|
||||
from os import listdir, makedirs, path
|
||||
from random import choice
|
||||
|
||||
from datadings.reader import MsgpackReader
|
||||
from datadings.writer import FileWriter
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
# CLI for building a segmented Tiny-ImageNet subset out of a segmented ImageNet dump.
parser = argparse.ArgumentParser()
parser.add_argument("-tiny_imagenet_zip", type=str, required=True, help="Path to the Tiny ImageNet zip file")
parser.add_argument("-output_dir", type=str, required=True, help="Directory to extract the image names to")
parser.add_argument("-in_segment_dir", type=str, required=True, help="Directory that holds the segmented ImageNet")
parser.add_argument(
    "-imagenet_path", type=str, nargs="?", required=True, help="Path to the original ImageNet dataset (datadings)"
)
args = parser.parse_args()
|
||||
|
||||
|
||||
# Collect the Tiny ImageNet image filenames per split from the official zip,
# dump them to text files, and derive the 200 class synsets.
images = {"train": set(), "val": set()}

with zipfile.ZipFile(args.tiny_imagenet_zip, "r") as zip_ref:
    for info in tqdm(zip_ref.infolist(), desc="Gathering Images"):
        if info.filename.endswith(".JPEG"):
            if "/val/" in info.filename:
                images["val"].add(info.filename.split("/")[-1])
            elif "/train/" in info.filename:
                images["train"].add(info.filename.split("/")[-1])

with open(path.join(args.output_dir, "tiny_imagenet_train_images.txt"), "w+") as f:
    f.write("\n".join(images["train"]))
with open(path.join(args.output_dir, "tiny_imagenet_val_images.txt"), "w+") as f:
    f.write("\n".join(images["val"]))

print(f"Found {len(images['train'])} training images and {len(images['val'])} validation images")
# training filenames look like "<synset>_<idx>.JPEG" -> derive class synsets
classes = {img_name.split("_")[0] for img_name in images["train"]}

# fix: drop the redundant list() around the set - sorted() accepts any iterable
classes = sorted(classes, key=lambda x: int(x[1:]))
assert len(classes) == 200, f"Expected 200 classes, found {len(classes)}"
assert len(images["train"]) == len(classes) * 500, f"Expected 100000 training images, found {len(images['train'])}"
assert len(images["val"]) == len(classes) * 50, f"Expected 10000 validation images, found {len(images['val'])}"
with open(path.join(args.output_dir, "tiny_imagenet_classes.txt"), "w+") as f:
    f.write("\n".join(classes))
|
||||
|
||||
# Copy the relevant foreground/background crops for each split, topping up
# classes that have fewer than ipc Tiny-ImageNet images with random extras.
for split in ["train", "val"]:
    ipc = 500 if split == "train" else 50  # images per class in Tiny ImageNet
    part = "foregrounds_WEBP"
    with tqdm(total=len(images[split]), desc=f"Copying images for {split} > {part}") as pbar:
        for synset in classes:
            makedirs(path.join(args.output_dir, split, part, synset), exist_ok=True)
            # resume support: skip classes that are already fully populated
            if len(listdir(path.join(args.output_dir, split, part, synset))) >= ipc:
                tqdm.write(
                    f"skip {split} > {part} > {synset} with"
                    f" {len(listdir(path.join(args.output_dir, split, part, synset)))} ims"
                )
                pbar.update(ipc)
                continue
            for img in listdir(path.join(args.in_segment_dir, split, part, synset)):
                # map segmented file name back to the original ImageNet name
                orig_name = (
                    img.split(".")[0] + ".JPEG"
                    if split == "train"
                    else f"val_{int(img.split('_')[-1].split('.')[0])}.JPEG"
                )
                if orig_name in images[split]:
                    shutil.copy(
                        path.join(args.in_segment_dir, split, part, synset, img),
                        path.join(args.output_dir, split, part, synset, img),
                    )
                    pbar.update(1)

            # Top up with random extra segmented images until ipc is reached
            # (or the source directory is exhausted).
            while len(listdir(path.join(args.output_dir, split, part, synset))) < min(
                ipc, len(listdir(path.join(args.in_segment_dir, split, part, synset)))
            ):
                # Fix: compare against the copied WEBP file names. The old code
                # compared the derived .JPEG names against the WEBP listing, so
                # the filter excluded nothing and duplicates could be re-picked.
                # Also hoist the listdir call out of the comprehension (was
                # called once per candidate).
                already_copied = set(listdir(path.join(args.output_dir, split, part, synset)))
                candidates = [
                    img
                    for img in listdir(path.join(args.in_segment_dir, split, part, synset))
                    if img not in already_copied
                ]
                img = choice(candidates)
                shutil.copy(
                    path.join(args.in_segment_dir, split, part, synset, img),
                    path.join(args.output_dir, split, part, synset, img),
                )
                pbar.update(1)
                # fix: log the image actually copied (previously logged a stale
                # orig_name left over from the loop above)
                tqdm.write(f"Extra image: {img} to {split}/{part}/{synset}")

    # copy over the background images corresponding to those foregrounds
    part = "backgrounds_JPEG"
    with tqdm(total=len(images[split]), desc=f"Copying images for {split} > {part}") as pbar:
        for synset in classes:
            makedirs(path.join(args.output_dir, split, part, synset), exist_ok=True)
            if len(listdir(path.join(args.output_dir, split, part, synset))) >= ipc:
                tqdm.write(f"skip {split} > {part} > {synset}")
                pbar.update(ipc)
                continue
            for img in listdir(path.join(args.output_dir, split, "foregrounds_WEBP", synset)):
                bg_name = img.replace(".WEBP", ".JPEG")
                shutil.copy(
                    path.join(args.in_segment_dir, split, part, synset, bg_name),
                    path.join(args.output_dir, split, part, synset, bg_name),
                )
                pbar.update(1)

            # every foreground must have exactly one matching background
            assert len(listdir(path.join(args.output_dir, split, part, synset))) == len(
                listdir(path.join(args.output_dir, split, "foregrounds_WEBP", synset))
            ), (
                f"Expected {len(listdir(path.join(args.output_dir, split, 'foregrounds_WEBP', synset)))} background"
                f" images, found {len(listdir(path.join(args.output_dir, split, part, synset)))}"
            )
|
||||
|
||||
# Write datadings files containing only the original images whose foregrounds
# survived the copy step, relabelled with TinyIN class indices.
for part in ["train", "val"]:
    reader = MsgpackReader(path.join(args.imagenet_path, f"{part}.msgpack"))
    with FileWriter(path.join(args.output_dir, f"TinyIN_{part}.msgpack")) as writer:
        for data in tqdm(reader, desc=f"Writing {part} to datadings"):
            key = data["key"].split("/")[-1]
            # train keys encode their synset; val images may belong to any class
            allowed_synsets = [key.split("_")[0]] if part == "train" else classes

            if part == "train" and allowed_synsets[0] not in classes:
                continue

            # Fix: the old inner `break` only exited the per-file loop, so the
            # synset scan kept running over all remaining synsets after a
            # match. Stop at the first synset containing the key instead
            # (key stems are unique, so the result is unchanged).
            stem = key.split(".")[0]
            label_synset = None
            for synset in allowed_synsets:
                fg_dir = path.join(args.output_dir, part, "foregrounds_WEBP", synset)
                if any(img.split(".")[0] == stem for img in listdir(fg_dir)):
                    label_synset = synset
                    break

            if label_synset is None:
                continue

            data["label"] = classes.index(label_synset)

            writer.write(data)
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,48 @@
|
||||
import argparse
|
||||
import os
|
||||
import json
|
||||
|
||||
import torch
|
||||
|
||||
# CLI: collapse a 1000-class ImageNet classifier head to the 9 ImageNet-9 classes.
parser = argparse.ArgumentParser("Script to convert ImageNet trained models to ImageNet-9")
parser.add_argument("-m", "--model", type=str, required=True, help="Model weights (.pt file).")
parser.add_argument(
    "--in_to_in9", type=str, default="/ds-sds/images/ImageNet-9/in_to_in9.json", help="Path to in_to_in9.json"
)
args = parser.parse_args()


checkpoint = torch.load(args.model, map_location="cpu")

model_state = checkpoint["model_state"]
# classifier parameters are assumed to live under ".head." (ViT-style) or
# ".fc." (ResNet-style) key fragments
head_keys = [k for k in model_state.keys() if ".head." in k or ".fc." in k]
print("weights that will be modified:", head_keys)
assert len(head_keys) > 0, "no head keys found :("

# mapping: ImageNet class index (as str) -> ImageNet-9 class index (-1 = dropped)
with open(args.in_to_in9, "r") as f:
    in_to_in9_classes = json.load(f)
print(f"{len([k for k, v in in_to_in9_classes.items() if v == -1])} classes get mapped to -1")

print("map", len(in_to_in9_classes), " classes to", set(in_to_in9_classes.values()))

# 9x1000 0/1 matrix; left-multiplying head parameters sums the rows of all
# ImageNet classes belonging to each ImageNet-9 super-class
print("Building conversion matrix")
conversion_matrix = torch.zeros((9, 1000))
for in_idx, in9_idx in in_to_in9_classes.items():
    if in9_idx == -1:
        continue
    in_idx = int(in_idx)
    conversion_matrix[in9_idx, in_idx] = 1
print(f"Conversion matrix ({conversion_matrix.shape}) has {int(conversion_matrix.sum().item())} non-zero values")

for head_key in head_keys:
    print(f"converting {head_key} of shape {model_state[head_key].shape}", end=" ")
    # works for both weight matrices (1000 x d) and bias vectors (1000,)
    model_state[head_key] = conversion_matrix @ model_state[head_key]
    print(f"\tto {model_state[head_key].shape}")

checkpoint["model_state"] = model_state
checkpoint["args"]["n_classes"] = 9  # NOTE(review): assumes checkpoint["args"] is a dict - confirm
save_folder = os.path.dirname(args.model)
orig_model_name = args.model.split(os.sep)[-1]
# insert "_to_in9" before the file extension
new_model_name = ".".join(orig_model_name.split(".")[:-1]) + "_to_in9." + orig_model_name.split(".")[-1]
print(f"saving model as {new_model_name} in {save_folder}")
torch.save(checkpoint, os.path.join(save_folder, new_model_name))
|
||||
@@ -0,0 +1,200 @@
|
||||
n07695742: pretzel
|
||||
n03902125: pay-phone, pay-station
|
||||
n03980874: poncho
|
||||
n01644900: tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui
|
||||
n02730930: apron
|
||||
n02699494: altar
|
||||
n03201208: dining table, board
|
||||
n02056570: king penguin, Aptenodytes patagonica
|
||||
n04099969: rocking chair, rocker
|
||||
n04366367: suspension bridge
|
||||
n04067472: reel
|
||||
n02808440: bathtub, bathing tub, bath, tub
|
||||
n04540053: volleyball
|
||||
n02403003: ox
|
||||
n03100240: convertible
|
||||
n04562935: water tower
|
||||
n02788148: bannister, banister, balustrade, balusters, handrail
|
||||
n02988304: CD player
|
||||
n02423022: gazelle
|
||||
n03637318: lampshade, lamp shade
|
||||
n01774384: black widow, Latrodectus mactans
|
||||
n01768244: trilobite
|
||||
n07614500: ice cream, icecream
|
||||
n04254777: sock
|
||||
n02085620: Chihuahua
|
||||
n01443537: goldfish, Carassius auratus
|
||||
n01629819: European fire salamander, Salamandra salamandra
|
||||
n02099601: golden retriever
|
||||
n02321529: sea cucumber, holothurian
|
||||
n03837869: obelisk
|
||||
n02002724: black stork, Ciconia nigra
|
||||
n02841315: binoculars, field glasses, opera glasses
|
||||
n04560804: water jug
|
||||
n02364673: guinea pig, Cavia cobaya
|
||||
n03706229: magnetic compass
|
||||
n09256479: coral reef
|
||||
n09332890: lakeside, lakeshore
|
||||
n03544143: hourglass
|
||||
n02124075: Egyptian cat
|
||||
n02948072: candle, taper, wax light
|
||||
n01950731: sea slug, nudibranch
|
||||
n02791270: barbershop
|
||||
n03179701: desk
|
||||
n02190166: fly
|
||||
n04275548: spider web, spider's web
|
||||
n04417672: thatch, thatched roof
|
||||
n03930313: picket fence, paling
|
||||
n02236044: mantis, mantid
|
||||
n03976657: pole
|
||||
n01774750: tarantula
|
||||
n04376876: syringe
|
||||
n04133789: sandal
|
||||
n02099712: Labrador retriever
|
||||
n04532670: viaduct
|
||||
n04487081: trolleybus, trolley coach, trackless trolley
|
||||
n09428293: seashore, coast, seacoast, sea-coast
|
||||
n03160309: dam, dike, dyke
|
||||
n03250847: drumstick
|
||||
n02843684: birdhouse
|
||||
n07768694: pomegranate
|
||||
n03670208: limousine, limo
|
||||
n03085013: computer keyboard, keypad
|
||||
n02892201: brass, memorial tablet, plaque
|
||||
n02233338: cockroach, roach
|
||||
n03649909: lawn mower, mower
|
||||
n03388043: fountain
|
||||
n02917067: bullet train, bullet
|
||||
n02486410: baboon
|
||||
n04596742: wok
|
||||
n03255030: dumbbell
|
||||
n03937543: pill bottle
|
||||
n02113799: standard poodle
|
||||
n03977966: police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria
|
||||
n02906734: broom
|
||||
n07920052: espresso
|
||||
n01698640: American alligator, Alligator mississipiensis
|
||||
n02123394: Persian cat
|
||||
n03424325: gasmask, respirator, gas helmet
|
||||
n02129165: lion, king of beasts, Panthera leo
|
||||
n04008634: projectile, missile
|
||||
n03042490: cliff dwelling
|
||||
n02415577: bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis
|
||||
n02815834: beaker
|
||||
n02395406: hog, pig, grunter, squealer, Sus scrofa
|
||||
n01784675: centipede
|
||||
n03126707: crane
|
||||
n04399382: teddy, teddy bear
|
||||
n07875152: potpie
|
||||
n03733131: maypole
|
||||
n02802426: basketball
|
||||
n03891332: parking meter
|
||||
n01910747: jellyfish
|
||||
n03838899: oboe, hautboy, hautbois
|
||||
n03770439: miniskirt, mini
|
||||
n02281406: sulphur butterfly, sulfur butterfly
|
||||
n03970156: plunger, plumber's helper
|
||||
n09246464: cliff, drop, drop-off
|
||||
n02206856: bee
|
||||
n02074367: dugong, Dugong dugon
|
||||
n03584254: iPod
|
||||
n04179913: sewing machine
|
||||
n04328186: stopwatch, stop watch
|
||||
n07583066: guacamole
|
||||
n01917289: brain coral
|
||||
n03447447: gondola
|
||||
n02823428: beer bottle
|
||||
n03854065: organ, pipe organ
|
||||
n02793495: barn
|
||||
n04285008: sports car, sport car
|
||||
n02231487: walking stick, walkingstick, stick insect
|
||||
n04465501: tractor
|
||||
n02814860: beacon, lighthouse, beacon light, pharos
|
||||
n02883205: bow tie, bow-tie, bowtie
|
||||
n02165456: ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle
|
||||
n04149813: scoreboard
|
||||
n04023962: punching bag, punch bag, punching ball, punchball
|
||||
n02226429: grasshopper, hopper
|
||||
n02279972: monarch, monarch butterfly, milkweed butterfly, Danaus plexippus
|
||||
n02669723: academic gown, academic robe, judge's robe
|
||||
n04486054: triumphal arch
|
||||
n04070727: refrigerator, icebox
|
||||
n03444034: go-kart
|
||||
n02666196: abacus
|
||||
n01945685: slug
|
||||
n04251144: snorkel
|
||||
n03617480: kimono
|
||||
n03599486: jinrikisha, ricksha, rickshaw
|
||||
n02437312: Arabian camel, dromedary, Camelus dromedarius
|
||||
n01984695: spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish
|
||||
n04118538: rugby ball
|
||||
n01770393: scorpion
|
||||
n04356056: sunglasses, dark glasses, shades
|
||||
n03804744: nail
|
||||
n02132136: brown bear, bruin, Ursus arctos
|
||||
n03400231: frying pan, frypan, skillet
|
||||
n03983396: pop bottle, soda bottle
|
||||
n07734744: mushroom
|
||||
n02480495: orangutan, orang, orangutang, Pongo pygmaeus
|
||||
n02410509: bison
|
||||
n03404251: fur coat
|
||||
n04456115: torch
|
||||
n02123045: tabby, tabby cat
|
||||
n03026506: Christmas stocking
|
||||
n07715103: cauliflower
|
||||
n04398044: teapot
|
||||
n02927161: butcher shop, meat market
|
||||
n07749582: lemon
|
||||
n07615774: ice lolly, lolly, lollipop, popsicle
|
||||
n02795169: barrel, cask
|
||||
n04532106: vestment
|
||||
n02837789: bikini, two-piece
|
||||
n02814533: beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon
|
||||
n04265275: space heater
|
||||
n02481823: chimpanzee, chimp, Pan troglodytes
|
||||
n02509815: lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens
|
||||
n06596364: comic book
|
||||
n01983481: American lobster, Northern lobster, Maine lobster, Homarus americanus
|
||||
n02504458: African elephant, Loxodonta africana
|
||||
n03014705: chest
|
||||
n01944390: snail
|
||||
n04146614: school bus
|
||||
n01641577: bullfrog, Rana catesbeiana
|
||||
n07720875: bell pepper
|
||||
n02999410: chain
|
||||
n01855672: goose
|
||||
n02125311: cougar, puma, catamount, mountain lion, painter, panther, Felis concolor
|
||||
n07753592: banana
|
||||
n07871810: meat loaf, meatloaf
|
||||
n04501370: turnstile
|
||||
n04311004: steel arch bridge
|
||||
n02977058: cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM
|
||||
n04074963: remote control, remote
|
||||
n03662601: lifeboat
|
||||
n02106662: German shepherd, German shepherd dog, German police dog, alsatian
|
||||
n03089624: confectionery, confectionary, candy store
|
||||
n04259630: sombrero
|
||||
n03393912: freight car
|
||||
n04597913: wooden spoon
|
||||
n07711569: mashed potato
|
||||
n03355925: flagpole, flagstaff
|
||||
n02963159: cardigan
|
||||
n07579787: plate
|
||||
n02950826: cannon
|
||||
n01882714: koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus
|
||||
n02094433: Yorkshire terrier
|
||||
n02909870: bucket, pail
|
||||
n02058221: albatross, mollymawk
|
||||
n01742172: boa constrictor, Constrictor constrictor
|
||||
n09193705: alp
|
||||
n04371430: swimming trunks, bathing trunks
|
||||
n07747607: orange
|
||||
n03814639: neck brace
|
||||
n04507155: umbrella
|
||||
n02268443: dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk
|
||||
n02769748: backpack, back pack, knapsack, packsack, rucksack, haversack
|
||||
n03763968: military uniform
|
||||
n07873807: pizza, pizza pie
|
||||
n03992509: potter's wheel
|
||||
n03796401: moving van
|
||||
n12267677: acorn
|
||||
@@ -0,0 +1,73 @@
|
||||
# Repeat Augment sampler taken from DeiT: https://github.com/facebookresearch/deit/blob/main/samplers.py
|
||||
# Copyright (c) 2015-present, Facebook, Inc.
|
||||
# All rights reserved.
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
class RASampler(torch.utils.data.Sampler):
    """Distributed sampler with repeated augmentation (RA), as used in DeiT.

    Every sample index is repeated ``num_repeats`` times; the repeated index
    stream is then sharded across processes so each process (GPU) sees a
    differently augmented version of (roughly) the same samples.
    Heavily based on torch.utils.data.DistributedSampler.
    """

    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, num_repeats: int = 3):
        if num_replicas is None or rank is None:
            # fall back to the active process group for world size / rank
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            if num_replicas is None:
                num_replicas = dist.get_world_size()
            if rank is None:
                rank = dist.get_rank()
        if num_repeats < 1:
            raise ValueError("num_repeats should be greater than 0")
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.num_repeats = num_repeats
        self.epoch = 0
        # per-process share of the repeated (and padded) index stream
        self.num_samples = int(math.ceil(len(self.dataset) * self.num_repeats / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        # only len(dataset) indices - rounded down to a multiple of 256 - are
        # actually yielded per epoch, split evenly across processes
        self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
        self.shuffle = shuffle

    def __iter__(self):
        if self.shuffle:
            # deterministic shuffle, seeded by the current epoch
            gen = torch.Generator()
            gen.manual_seed(self.epoch)
            order = torch.randperm(len(self.dataset), generator=gen)
        else:
            order = torch.arange(start=0, end=len(self.dataset))

        # repeat each index num_repeats times, then pad the stream so it is
        # evenly divisible across processes
        repeated = torch.repeat_interleave(order, repeats=self.num_repeats, dim=0).tolist()
        shortfall: int = self.total_size - len(repeated)
        if shortfall > 0:
            repeated += repeated[:shortfall]
        assert len(repeated) == self.total_size

        # strided slice per rank: consecutive repeats of one sample land on
        # different processes
        shard = repeated[self.rank : self.total_size : self.num_replicas]
        assert len(shard) == self.num_samples

        return iter(shard[: self.num_selected_samples])

    def __len__(self):
        """Number of indices yielded per epoch on this process."""
        return self.num_selected_samples

    def set_epoch(self, epoch):
        """Set the epoch used to seed the deterministic shuffle."""
        self.epoch = epoch

    def __str__(self) -> str:
        return (
            f"{type(self).__name__}(num_replicas: {self.num_replicas}, rank: {self.rank}, num_repeats:"
            f" {self.num_repeats}, epoch: {self.epoch}, num_samples: {self.num_samples}, total_size: {self.total_size},"
            f" num_selected_samples: {self.num_selected_samples}, shuffle: {self.shuffle})"
        )
|
||||
@@ -0,0 +1,381 @@
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
from copy import copy
|
||||
|
||||
from loguru import logger
|
||||
from nltk.corpus import wordnet as wn
|
||||
|
||||
|
||||
class bcolors:
    """ANSI terminal escape codes for colored/emphasized console output."""

    HEADER = "\033[95m"
    OKBLUE = "\033[94m"
    OKCYAN = "\033[96m"
    OKGREEN = "\033[92m"
    WARNING = "\033[93m"
    FAIL = "\033[91m"
    ENDC = "\033[0m"  # reset all attributes
    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"
||||
|
||||
|
||||
def _lemmas_str(synset):
|
||||
return ", ".join([lemma.name() for lemma in synset.lemmas()])
|
||||
|
||||
|
||||
class WNEntry:
|
||||
def __init__(
    self,
    name: str,
    id: int,
    lemmas: str,
    parent_id: int,
    depth: int = None,
    in_image_net: bool = False,
    child_ids: list = None,
    in_main_tree: bool = True,
    _n_images: int = 0,
    _description: str = None,
    _name: str = None,
    _pruned: bool = False,
):
    """Node of a WordNet synset tree.

    Args:
        name: human-readable synset name.
        id: synset identifier (used as key into the owning tree's nodes).
        lemmas: lemma string for the synset (presumably comma-separated,
            as produced by _lemmas_str - confirm).
        parent_id: id of the hypernym (parent) node.
        depth: depth in the tree; None if not yet computed.
        in_image_net: whether this synset is an ImageNet class.
        child_ids: ids of hyponym (child) nodes; None for leaves.
        in_main_tree: whether the node is attached to the main tree.
        _n_images: number of images directly assigned to this node.
        _description: optional description/gloss text.
        _name: optional alternative name.
        _pruned: whether the node has already been removed by prune().
    """
    self.name = name
    self.id = id
    self.lemmas = lemmas
    self.parent_id = parent_id
    self.depth = depth
    self.in_image_net = in_image_net
    self.child_ids = child_ids
    self.in_main_tree = in_main_tree
    self._n_images = _n_images
    self._description = _description
    self._name = _name
    self._pruned = _pruned
|
||||
|
||||
def __str__(self, tree=None, accumulate=True):
|
||||
start_symb = f"{bcolors.OKGREEN}+{bcolors.ENDC}" if self.in_image_net else f"{bcolors.FAIL}-{bcolors.ENDC}"
|
||||
n_ims = f"{self._n_images} of Σ {self.n_images(tree)}" if accumulate and tree is not None else self._n_images
|
||||
if self.child_ids is None or tree is None:
|
||||
return f"{start_symb}{self.name} ({self.id}) > {n_ims}"
|
||||
else:
|
||||
return f"{start_symb}{self.name} ({self.id}) > {n_ims}\n " + "\n ".join(
|
||||
["\n ".join(tree.nodes[child_id].__str__(tree).split("\n")) for child_id in self.child_ids]
|
||||
)
|
||||
|
||||
def tree_diff(self, tree_1, tree_2):
|
||||
if tree_2[self.id]._n_images > tree_1[self.id]._n_images:
|
||||
start_symb = f"{bcolors.OKGREEN}+{bcolors.ENDC}"
|
||||
elif tree_2[self.id]._n_images < tree_1[self.id]._n_images:
|
||||
start_symb = f"{bcolors.FAIL}-{bcolors.ENDC}"
|
||||
else:
|
||||
start_symb = f"{bcolors.OKBLUE}={bcolors.ENDC}"
|
||||
n_ims = (
|
||||
f"{tree_1[self.id]._n_images} + {tree_2[self.id]._n_images - tree_1[self.id]._n_images} of Σ"
|
||||
f" {tree_1[self.id].n_images(tree_2)}/{tree_2[self.id].n_images(tree_2)}"
|
||||
)
|
||||
|
||||
if self.child_ids is None:
|
||||
return f"{start_symb}{self.name} ({self.id}) > {n_ims}"
|
||||
|
||||
return f"{start_symb}{self.name} ({self.id}) > {n_ims}\n " + "\n ".join(
|
||||
["\n ".join(tree_1.nodes[child_id].tree_diff(tree_1, tree_2).split("\n")) for child_id in self.child_ids]
|
||||
)
|
||||
|
||||
def prune(self, tree):
|
||||
if self._pruned:
|
||||
return
|
||||
|
||||
if self.child_ids is not None:
|
||||
for child_id in self.child_ids:
|
||||
tree[child_id].prune(tree)
|
||||
|
||||
self._pruned = True
|
||||
parent_node = tree.nodes[self.parent_id]
|
||||
try:
|
||||
parent_node.child_ids.remove(self.id)
|
||||
except ValueError as e:
|
||||
print(
|
||||
f"Error removing {self.name} from"
|
||||
f" {parent_node.name} ({[tree[cid].name for cid in parent_node.child_ids]}): {e}"
|
||||
)
|
||||
while parent_node._pruned:
|
||||
parent_node = tree.nodes[parent_node.parent_id]
|
||||
parent_node._n_images += self._n_images
|
||||
self._n_images = 0
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
if not self._description:
|
||||
self._description = wn.synset_from_pos_and_offset("n", self.id).definition()
|
||||
return self._description
|
||||
|
||||
@property
|
||||
def print_name(self):
|
||||
return self.name.split(".")[0]
|
||||
|
||||
def get_branch(self, tree=None):
|
||||
if self.parent_id is None or tree is None:
|
||||
return self.print_name
|
||||
|
||||
parent = tree.nodes[self.parent_id]
|
||||
return parent.get_branch(tree) + " > " + self.print_name
|
||||
|
||||
def get_branch_list(self, tree):
|
||||
if self.parent_id is None:
|
||||
return [self]
|
||||
parent = tree.nodes[self.parent_id]
|
||||
return parent.get_branch_list(tree) + [self]
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"name": self.name,
|
||||
"id": self.id,
|
||||
"lemmas": self.lemmas,
|
||||
"parent_id": self.parent_id,
|
||||
"depth": self.depth,
|
||||
"in_image_net": self.in_image_net,
|
||||
"child_ids": self.child_ids,
|
||||
"in_main_tree": self.in_main_tree,
|
||||
"_n_images": self._n_images,
|
||||
"_description": self._description,
|
||||
"_name": self._name,
|
||||
"_pruned": self._pruned,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, d):
|
||||
return cls(**d)
|
||||
|
||||
def n_images(self, tree=None):
|
||||
if tree is None or self.child_ids is None or len(self.child_ids) == 0:
|
||||
return self._n_images
|
||||
return sum([tree.nodes[child_id].n_images(tree) for child_id in self.child_ids]) + self._n_images
|
||||
|
||||
def n_children(self, tree=None):
|
||||
if self.child_ids is None:
|
||||
return 0
|
||||
if tree is None or len(self.child_ids) == 0:
|
||||
return len(self.child_ids)
|
||||
return len(self.child_ids) + sum([tree.nodes[child_id].n_children(tree) for child_id in self.child_ids])
|
||||
|
||||
def get_examples(self, tree, n_examples=3):
|
||||
if self.child_ids is None or len(self.child_ids) == 0:
|
||||
return ""
|
||||
child_images = {child_id: tree.nodes[child_id].n_images(tree) for child_id in self.child_ids}
|
||||
max_images = max(child_images.values())
|
||||
if max_images == 0:
|
||||
# go on number of child nodes
|
||||
child_images = {child_id: tree.nodes[child_id].n_children(tree) for child_id in self.child_ids}
|
||||
# sorted childids by number of images
|
||||
top_children = [
|
||||
child_id for child_id, n_images in sorted(child_images.items(), key=lambda x: x[1], reverse=True)
|
||||
]
|
||||
top_children = top_children[: min(n_examples, len(top_children))]
|
||||
return ", ".join(
|
||||
[f"{tree.nodes[child_id].print_name} ({tree.nodes[child_id].description})" for child_id in top_children]
|
||||
)
|
||||
|
||||
|
||||
class WNTree:
    """A hypernym tree of :class:`WNEntry` nodes keyed by WordNet offset.

    Supports building from the WordNet corpus (``add_node``), JSON
    round-tripping (``save``/``load``), pruning sparsely-populated nodes, and
    mapping surviving nodes to dense classifier labels.
    """

    def __init__(self, root=1740, nodes=None):
        """Create a tree from a root offset or an existing WNEntry.

        Args:
            root: either a WordNet offset (default 1740, the 'entity' synset,
                looked up via the corpus) or a ready-made WNEntry to adopt.
            nodes: optional pre-built ``{offset: WNEntry}`` mapping; when None
                the mapping starts with just the root.
        """
        if isinstance(root, int):
            root_id = root
            root_synset = wn.synset_from_pos_and_offset("n", root)
            root_node = WNEntry(
                root_synset.name(),
                root_id,
                _lemmas_str(root_synset),
                parent_id=None,
                depth=0,
            )
        else:
            assert isinstance(root, WNEntry)
            root_id = root.id
            root_node = root

        self.root = root_node
        self.nodes = {root_id: self.root} if nodes is None else nodes
        self.parentless = []  # offsets of synsets that have no hypernym
        self.label_index = None  # lazily built by _make_label_index
        self.pruned = set()  # node ids removed by the last prune() call

    def to_dict(self):
        """Serialize the tree (root, nodes, parentless ids, pruned ids) to plain data."""
        return {
            "root": self.root.to_dict(),
            "nodes": {node_id: node.to_dict() for node_id, node in self.nodes.items()},
            "parentless": self.parentless,
            "pruned": list(self.pruned),
        }

    def prune(self, min_images):
        """Collapse nodes with too few images into their parents.

        Two passes: first every node whose entire subtree holds fewer than
        ``min_images`` images is pruned; then, walking the remaining tree
        bottom-up, nodes whose own (non-descendant) image count is still below
        the threshold are pruned as well.

        Returns:
            The set of pruned node ids (also stored on ``self.pruned``).
        """
        pruned_nodes = set()

        # prune all nodes that have fewer than min_images below them
        for node_id, node in self.nodes.items():
            if node.n_images(self) < min_images:
                pruned_nodes.add(node_id)
                node.prune(self)

        # prune all nodes that have fewer than min_images inside them, after all nodes below have been pruned
        # BFS from the root so node_stack is ordered top-down; reversing it
        # below guarantees children are visited before their parents.
        node_stack = [self.root]
        node_idx = 0
        while node_idx < len(node_stack):
            node = node_stack[node_idx]
            if node.child_ids is not None:
                for child_id in node.child_ids:
                    child = self.nodes[child_id]
                    node_stack.append(child)
            node_idx += 1

        # now prune the stack from the bottom up
        for node in node_stack[::-1]:
            # only look at images of that class, not of additional children
            if node.n_images() < min_images:
                pruned_nodes.add(node.id)
                node.prune(self)

        self.pruned = pruned_nodes
        return pruned_nodes

    @classmethod
    def from_dict(cls, d):
        """Rebuild a tree from :meth:`to_dict` output (JSON string keys become ints).

        NOTE(review): ``cls()`` first constructs a default root via the
        WordNet corpus and then overwrites it — works, but requires nltk data
        to be present even when loading from a file.
        """
        tree = cls()
        tree.root = WNEntry.from_dict(d["root"])
        tree.nodes = {int(node_id): WNEntry.from_dict(node) for node_id, node in d["nodes"].items()}
        tree.parentless = d["parentless"]
        if "pruned" in d:
            tree.pruned = set(d["pruned"])
        return tree

    def add_node(self, node_id, in_in=True):
        """Insert the synset with offset ``node_id``, recursively adding its hypernym chain.

        Re-adding an existing node only upgrades its ``in_image_net`` flag.
        Synsets without hypernyms become detached roots, recorded in
        ``self.parentless`` with ``in_main_tree=False``.

        Args:
            node_id: WordNet offset of the synset to add.
            in_in: whether the synset itself is an ImageNet class; ancestors
                added on the way up are inserted with ``in_in=False``.
        """
        if node_id in self.nodes:
            self.nodes[node_id].in_image_net = in_in or self.nodes[node_id].in_image_net
            return

        synset = wn.synset_from_pos_and_offset("n", node_id)

        # print(f"adding node {synset.name()} with id {node_id}")

        hypernyms = synset.hypernyms()
        if len(hypernyms) == 0:
            parent_id = None
            self.parentless.append(node_id)
            main_tree = False
            print(f"--------- no hypernyms for {synset.name()} ({synset.offset()}) ------------")
        else:
            # Only the first hypernym is followed; WordNet's occasional
            # multiple inheritance is flattened to a tree here.
            parent_id = synset.hypernyms()[0].offset()
            if parent_id not in self.nodes:
                self.add_node(parent_id, in_in=False)
            parent = self.nodes[parent_id]

            if parent.child_ids is None:
                parent.child_ids = []
            parent.child_ids.append(node_id)
            main_tree = parent.in_main_tree

        depth = self.nodes[parent_id].depth + 1 if parent_id is not None else 0
        node = WNEntry(
            synset.name(),
            node_id,
            _lemmas_str(synset),
            parent_id=parent_id,
            in_image_net=in_in,
            depth=depth,
            in_main_tree=main_tree,
        )

        self.nodes[node_id] = node

    def __len__(self):
        """Total number of nodes in the tree (including detached subtrees)."""
        return len(self.nodes)

    def image_net_len(self, only_main_tree=False):
        """Count nodes flagged ``in_image_net``, optionally restricted to the main tree."""
        return sum([node.in_image_net for node in self.nodes.values() if node.in_main_tree or not only_main_tree])

    def max_depth(self, only_main_tree=False):
        """Maximum node depth, optionally restricted to the main tree."""
        return max([node.depth for node in self.nodes.values() if node.in_main_tree or not only_main_tree])

    def __str__(self):
        """Full textual dump: summary line, main tree, then each parentless subtree."""
        return (
            f"WordNet Tree with {len(self)} nodes, {self.image_net_len()} in ImageNet21k;"
            f" {len(self.parentless)} parentless nodes:\n{self.root.__str__(tree=self)}\nParentless:\n"
            + "\n".join([self.nodes[node_id].__str__(tree=self) for node_id in self.parentless])
        )

    def save(self, path):
        """Write the tree as JSON to ``path``."""
        with open(path, "w") as f:
            json.dump(self.to_dict(), f)

    @classmethod
    def load(cls, path):
        """Read a tree previously written by :meth:`save`."""
        with open(path, "r") as f:
            tree_dict = json.load(f)
        return cls.from_dict(tree_dict)

    def subtree(self, node_id):
        """Return a new WNTree rooted at ``node_id``, or None if the id is absent.

        Nodes are shallow-copied, so ``child_ids`` lists are still shared with
        this tree; depths are re-based so the new root has depth 0.
        """
        if node_id not in self.nodes:
            return None
        node_queue = [self.nodes[node_id]]
        subtree_ids = set()
        while len(node_queue) > 0:
            node = node_queue.pop(0)
            subtree_ids.add(node.id)
            if node.child_ids is not None:
                node_queue += [self.nodes[child_id] for child_id in node.child_ids]
        # The comprehension variable shadows the parameter only inside the
        # comprehension; the parameter is intact for the lookup below.
        subtree_nodes = {node_id: copy(self.nodes[node_id]) for node_id in subtree_ids}
        subtree_root = subtree_nodes[node_id]
        subtree_root.parent_id = None
        depth_diff = subtree_root.depth
        for node in subtree_nodes.values():
            node.depth -= depth_diff
        return WNTree(root=subtree_root, nodes=subtree_nodes)

    def _make_label_index(self, include_merged=False):
        """Build the sorted id list of label classes: unpruned nodes with images.

        With ``include_merged`` the per-node count also includes descendants'
        images (the tree is passed to ``n_images``); otherwise only images
        directly on the node qualify it.
        """
        self.label_index = sorted(
            [
                node_id
                for node_id, node in self.nodes.items()
                if node.n_images(self if include_merged else None) > 0 and not node._pruned
            ]
        )

    def get_label(self, node_id):
        """Map a node id to its dense label index, following pruned nodes up to their surviving ancestor."""
        if self.label_index is None:
            self._make_label_index()
        while self.nodes[node_id]._pruned:
            node_id = self.nodes[node_id].parent_id
        return self.label_index.index(node_id)

    def n_labels(self):
        """Number of label classes (size of the lazily built label index)."""
        if self.label_index is None:
            self._make_label_index()
        return len(self.label_index)

    def __contains__(self, item):
        """Membership test for an ``"n########"`` string, an offset int, or a WNEntry."""
        if isinstance(item, str):
            if item[0] == "n":
                item = int(item[1:])
            else:
                return False
        if isinstance(item, int):
            return item in self.nodes
        if isinstance(item, WNEntry):
            return item.id in self.nodes
        return False

    def __getitem__(self, item):
        """Look up a node by ``"n########"`` string, synset name, offset int, or WNEntry.

        Raises:
            KeyError: when no node matches ``item``.
        """
        # NOTE(review): item[0].startswith("n") is equivalent to
        # item[0] == "n" (single char) and raises IndexError on "" — unlike
        # the plain comparison used in __contains__.
        if isinstance(item, str) and item[0].startswith("n"):
            try:
                item = int(item[1:])
            except ValueError:
                pass
        if isinstance(item, str) and ".n." in item:
            # Fall back to a linear scan by full synset name (e.g. "dog.n.01").
            for node in self.nodes.values():
                if item == node.name:
                    return node
            raise KeyError(f"Item {item} not found in tree")
        if isinstance(item, int):
            return self.nodes[item]
        if isinstance(item, WNEntry):
            return self.nodes[item.id]
        raise KeyError(f"Item {item} not found in tree")
|
||||
Reference in New Issue
Block a user