AAAI Version
This commit is contained in:
@@ -0,0 +1,407 @@
|
||||
from random import uniform
|
||||
|
||||
import msgpack
|
||||
import torch
|
||||
import torchvision
|
||||
from datadings.torch import CompressedToPIL
|
||||
from datadings.torch import Dataset as DDDataset
|
||||
from PIL import ImageFilter
|
||||
from torchvision.transforms import (
|
||||
CenterCrop,
|
||||
ColorJitter,
|
||||
Compose,
|
||||
GaussianBlur,
|
||||
Grayscale,
|
||||
InterpolationMode,
|
||||
Normalize,
|
||||
RandomChoice,
|
||||
RandomCrop,
|
||||
RandomHorizontalFlip,
|
||||
RandomResizedCrop,
|
||||
RandomSolarize,
|
||||
Resize,
|
||||
ToTensor,
|
||||
)
|
||||
from torchvision.transforms import functional as F
|
||||
|
||||
# Geometric transform classes that must be applied identically to both the
# image and its dense target (e.g. a segmentation mask). Membership is
# checked via isinstance in apply_dense_transforms, so each class needs to
# appear only once (the original list contained RandomRotation twice).
_image_and_target_transforms = [
    torchvision.transforms.RandomCrop,
    torchvision.transforms.RandomHorizontalFlip,
    torchvision.transforms.CenterCrop,
    torchvision.transforms.RandomRotation,
    torchvision.transforms.RandomAffine,
    torchvision.transforms.RandomResizedCrop,
]
|
||||
|
||||
|
||||
def apply_dense_transforms(x, y, transforms: torchvision.transforms.transforms.Compose):
    """Apply some transfomations to both image and target.

    Random geometric transforms are applied with identical parameters to the
    image and the dense target; photometric transforms touch only the image.

    Args:
        x (torch.Tensor): image
        y (torch.Tensor): target (image)
        transforms (torchvision.transforms.transforms.Compose): transformations to apply

    Returns:
        tuple[torch.Tensor, torch.Tensor]: (x, y) with applied transformations

    """
    for trans in transforms.transforms:
        if isinstance(trans, torchvision.transforms.RandomResizedCrop):
            # Sample the crop parameters once so image and target get the
            # exact same crop.
            params = trans.get_params(x, trans.scale, trans.ratio)
            x = F.resized_crop(x, *params, trans.size, trans.interpolation, antialias=trans.antialias)
            # interpolation 0 keeps label values discrete
            y = F.resized_crop(y.unsqueeze(0), *params, trans.size, 0).squeeze(0)  # nearest neighbor interpolation
        elif isinstance(trans, Resize):
            pre_shape = x.shape
            x = trans(x)
            # Only resize the target when the image shape actually changed.
            if x.shape != pre_shape:
                y = F.resize(y.unsqueeze(0), trans.size, 0, trans.max_size, trans.antialias).squeeze(
                    0
                )  # nearest neighbor interpolation
        elif any(isinstance(trans, simul_transform) for simul_transform in _image_and_target_transforms):
            # Stack the target as an extra channel so the random transform
            # uses identical parameters for both; cast back to long after.
            xy = torch.cat([x, y.unsqueeze(0).float()], dim=0)
            xy = trans(xy)
            x, y = xy[:-1], xy[-1].long()
        elif isinstance(trans, torchvision.transforms.ToTensor):
            # Guard against applying ToTensor to an already-converted tensor.
            if not isinstance(x, torch.Tensor):
                x = trans(x)
        else:
            # Photometric / image-only transform: leave the target untouched.
            x = trans(x)

    return x, y
|
||||
|
||||
|
||||
def get_hf_transform(transform_f, trgt_transform_f=None, image_key="image"):
    """Wrap per-sample transforms into a huggingface-compatible batch transform.

    Args:
        transform_f (callable): Image transform applied to every image in a batch.
        trgt_transform_f (callable, optional): Target transform applied to every label. Defaults to None.
        image_key (str, optional): Key for the image in the hf ds return dict. Defaults to "image".

    Returns:
        callable: Function mapping a batch dict to the transformed batch dict.
    """

    def _transform(samples):
        try:
            samples[image_key] = list(map(transform_f, samples[image_key]))
            if trgt_transform_f is not None:
                samples["label"] = list(map(trgt_transform_f, samples["label"]))
        except TypeError as e:
            # Surface the offending batch before re-raising for easier debugging.
            print(f"Type error when transforming samples: {samples}")
            raise e
        return samples

    return _transform
|
||||
|
||||
|
||||
class DDDecodeDataset(DDDataset):
    """Datadings dataset that decodes compressed images before applying transforms."""

    def __init__(self, *args, transform=None, target_transform=None, transforms=None, **kwargs):
        """Create datadings dataset.

        Args:
            transform (callable, optional): Image transform. Overrides transforms['image']. Defaults to None.
            target_transform (callable, optional): Label transform. Overrides transforms['label']. Defaults to None.
            transforms (dict[str, callable], optional): Dict of transforms for each key. Defaults to None.

        """
        super().__init__(*args, **kwargs)
        if transforms is None:
            transforms = {}
        # Explicit kwargs take precedence over the transforms dict.
        if transform is None:
            transform = transforms.get("image", None)
        if target_transform is None:
            target_transform = transforms.get("label", None)
        self._decode_transform = transform
        self._decode_target_transform = target_transform

        self.ctp = CompressedToPIL()

    def __getitem__(self, idx):
        sample = super().__getitem__(idx)
        img = sample["image"]
        lbl = sample["label"]
        if isinstance(img, bytes):
            # Raw compressed bytes -> PIL image before any transform runs.
            img = self.ctp(img)
        if self._decode_transform is not None:
            img = self._decode_transform(img)
        if self._decode_target_transform is not None:
            lbl = self._decode_target_transform(lbl)
        return img, lbl
|
||||
|
||||
|
||||
def minimal_augment(args, test=False):
    """Minimal Augmentation set for images.

    Contains only resize, crop, flip, to tensor and normalize.

    Args:
        args (DotDict): Arguments: aug_resize, aug_crop, aug_flip, aug_normalize to turn on/off the respective augmentation.
        test (bool, optional): On the test set? Defaults to False.

    Returns:
        List: Augmentation list

    """
    augs = [ToTensor()]

    if args.aug_resize:
        augs.append(Resize(args.imsize, interpolation=InterpolationMode.BICUBIC))

    if args.aug_crop:
        # Deterministic center crop for evaluation, random crop for training.
        if test:
            crop = CenterCrop(args.imsize)
        else:
            crop = RandomCrop(args.imsize, padding=4, padding_mode="reflect")
        augs.append(crop)

    if args.aug_flip and not test:
        augs.append(RandomHorizontalFlip(p=0.5))

    if args.aug_normalize:
        # ImageNet channel statistics.
        normalize = Normalize(
            mean=torch.tensor([0.485, 0.456, 0.406]),
            std=torch.tensor([0.229, 0.224, 0.225]),
        )
        augs.append(normalize)

    return augs
|
||||
|
||||
|
||||
def three_augment(args, as_list=False, test=False):
    """Create the data augmentation pipeline.

    Args:
        args (DotDict): arguments controlling the individual augmentations.
        as_list (bool, optional): Return the list of transformations instead of a composed transformation. Defaults to False.
        test (bool, optional): In eval mode? If False => train mode. Defaults to False.

    Returns:
        torch.nn.Module | list[torch.nn.Module]: composed transformation or list of transformations

    """
    augs = [ToTensor()]

    if args.aug_resize:
        augs.append(Resize(args.imsize, interpolation=InterpolationMode.BICUBIC))

    if args.aug_crop:
        # Deterministic center crop for evaluation, random crop for training.
        if test:
            augs.append(CenterCrop(args.imsize))
        else:
            augs.append(RandomCrop(args.imsize, padding=4, padding_mode="reflect"))

    if not test:
        if args.aug_flip:
            augs.append(RandomHorizontalFlip(p=0.5))

        # One of grayscale / solarize / gaussian blur is picked at random
        # per sample (the "3-Augment" choice).
        choices = []
        if args.aug_grayscale:
            choices.append(Grayscale(num_output_channels=3))
        if args.aug_solarize:
            choices.append(RandomSolarize(threshold=0.5, p=1.0))
        if args.aug_gauss_blur:
            # TODO: check kernel size?
            choices.append(GaussianBlur(kernel_size=7, sigma=(0.2, 2.0)))
            # choices.append(QuickGaussBlur())
        if len(choices) > 0:
            augs.append(RandomChoice(choices))

        jitter = args.aug_color_jitter_factor
        if jitter > 0.0:
            augs.append(ColorJitter(jitter, jitter, jitter))

    if args.aug_normalize:
        # ImageNet channel statistics.
        augs.append(
            Normalize(
                mean=torch.tensor([0.485, 0.456, 0.406]),
                std=torch.tensor([0.229, 0.224, 0.225]),
            )
        )

    if as_list:
        return augs
    return Compose(augs)
|
||||
|
||||
|
||||
def segment_augment(args, test=False):
    """Create the data augmentation for segmentation.

    The geometric transforms added here (crops, flip) are presumably applied
    jointly to image and labels downstream (see apply_dense_transforms), so
    no separate label handling happens in this function.

    Args:
        args (DotDict): arguments
        test (bool, optional): In eval mode? If False => train mode. Defaults to False.

    Returns:
        list[torch.nn.Module]: list of transformations

    """
    augs = []

    if test:
        # Upscale small images first, then take a deterministic center crop.
        augs.append(ResizeUp(args.imsize))
        augs.append(CenterCrop(args.imsize))
    else:
        augs.append(RandomResizedCrop(args.imsize, scale=(0.08, 1.0), ratio=(3.0 / 4.0, 4.0 / 3.0)))

    if not test and args.aug_flip:
        augs.append(RandomHorizontalFlip(p=0.5))

    jitter = args.aug_color_jitter_factor
    if jitter > 0.0:
        augs.append(ColorJitter(jitter, jitter, jitter))

    if args.aug_normalize:
        # ImageNet channel statistics.
        augs.append(
            Normalize(
                mean=torch.tensor([0.485, 0.456, 0.406]),
                std=torch.tensor([0.229, 0.224, 0.225]),
            )
        )

    return augs
|
||||
|
||||
|
||||
class QuickGaussBlur:
    """Gaussian blur transformation using PIL ImageFilter."""

    def __init__(self, sigma=(0.2, 2.0)):
        """Create Gaussian blur operator.

        Args:
            sigma (tuple[float, float]): range from which the blur radius is sampled.

        """
        self.sigma_min, self.sigma_max = sigma

    def __call__(self, img):
        # A fresh radius is drawn for every image.
        radius = uniform(self.sigma_min, self.sigma_max)
        return img.filter(ImageFilter.GaussianBlur(radius=radius))
|
||||
|
||||
|
||||
class RemoveTransform:
    """Replace the sample data with a dummy value.

    To use with default collate function.
    """

    def __call__(self, x, y=None):
        dummy = [1]
        if y is None:
            return dummy
        return dummy, y

    def __repr__(self):
        return f"{self.__class__.__name__}()"
|
||||
|
||||
|
||||
def collate_imnet(data, image_key="image"):
    """Collate function for imagenet(1k / 21k) with datadings.

    Args:
        data (list[dict[str, Any]]): samples for a batch.
        image_key (str, optional): key under which each sample stores its image. Defaults to "image".

    Returns:
        tuple[torch.Tensor | list, torch.Tensor]: images (stacked when they are tensors), labels

    """
    images = [sample[image_key] for sample in data]
    if isinstance(images[0], torch.Tensor):
        # Decoded tensors can be stacked into one batch tensor; otherwise
        # (e.g. still-compressed or PIL images) keep the plain list.
        images = torch.stack(images, dim=0)
    labels = torch.tensor([sample["label"] for sample in data])
    # keys = [d['key'] for d in data]
    return images, labels  # , keys
|
||||
|
||||
|
||||
def collate_listops(data):
    """Collate function for ListOps with datadings.

    Only the first sample of the batch is used.

    Args:
        data (list[tuple[torch.Tensor, torch.Tensor]]): list of samples.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: tokens, labels

    """
    first = data[0]
    return first[0], first[1]
|
||||
|
||||
|
||||
def no_param_transf(self, sample):
    """Call the stored transformations without an extra parameter.

    To use with datadings QuasiShuffler, bound as a method:
    ``<obj>.<method_name> = MethodType(no_param_transf, <obj>)``.

    Args:
        self (object): object providing ``_rng`` and ``_transforms``.
        sample (Any): sample to transform.

    Returns:
        Any: transformed sample.

    """
    if isinstance(sample, tuple):
        # sample of type (name (str), data (bytes encoded))
        sample = sample[1]
    if isinstance(sample, bytes):
        # decode msgpack bytes
        sample = msgpack.loads(sample)
    # One shared parameter set for all per-key transforms of this sample.
    params = self._rng(sample)
    for key, transform in self._transforms.items():
        sample[key] = transform(sample[key], params)
    return sample
|
||||
|
||||
|
||||
class ToOneHotSequence:
    """Convert a sequence of grayscale values (range 0 to 1) to a one-hot encoded sequence based on 8-bit rounded values."""

    def __call__(self, x, y=None):
        # Scale to 8-bit integer values and flatten to a 1-D sequence.
        x = x.mul(255).round().to(torch.int64).view(-1)
        assert x.max() < 256, f"Found max value {x.max()} in {x}."
        x = torch.nn.functional.one_hot(x, num_classes=256).float()
        if y is not None:
            return x, y
        return x

    def __repr__(self):
        return f"{self.__class__.__name__}()"
|
||||
|
||||
|
||||
class ResizeUp(Resize):
    """Resize up if image is smaller than target size."""

    def forward(self, img):
        # NOTE(review): the comparison assumes self.size is a single int;
        # a sequence size would fail here — confirm against callers.
        dim_a, dim_b = img.shape[-2], img.shape[-1]
        if dim_a < self.size or dim_b < self.size:
            return super().forward(img)
        return img
|
||||
Reference in New Issue
Block a user