"""Module to load the datasets, using torch and datadings.""" import contextlib import os from functools import partial import torchvision.transforms as tv_transforms from datadings.reader import MsgpackReader from timm.data import create_transform from torch.utils.data import DataLoader, DistributedSampler, WeightedRandomSampler from torchvision.datasets import ( CIFAR10, CIFAR100, FGVCAircraft, Flowers102, Food101, ImageFolder, OxfordIIITPet, StanfordCars, ) from data.counter_animal import CounterAnimal from data.data_utils import ( DDDecodeDataset, ToOneHotSequence, collate_imnet, collate_listops, get_hf_transform, minimal_augment, segment_augment, three_augment, ) from data.fornet import ForNet from data.samplers import RASampler from paths_config import ds_path def prepare_dataset(dataset_name, args, transform=None, train=True, rank=None): """Load a dataset from disk, different formats are used for different datasets. Supported datasets: CIFAR10, ImageNet, ImageNet21k Args: dataset_name (str): name of the dataset args: further arguments transform (list[Module] | str, optional): transformations to use on the data; the list gets composed, or give args.augment_strategy (Default value = None) train (bool, optional): use the training split (or test/validation split) (Default value = True) rank (int, optional): global rank of this process in distributed training (Default value = None) Returns: DataLoader: data loader for the dataset int: number of classes in the dataset int: ignore index for the dataset bool: whether the dataset is multi-label """ compose = tv_transforms.Compose dali_server = None if transform is None: if args.augment_engine == "torchvision": if args.augment_strategy == "3-augment": transform = three_augment(args, as_list=False, test=not train) elif args.augment_strategy == "differentiable-transform": from data.distilled_dataset import differentiable_augment transform = differentiable_augment(args, as_list=False, test=not train) elif args.augment_strategy == "none": transform = [] elif args.augment_strategy == "lm_one_hot": transform = [ tv_transforms.Grayscale(num_output_channels=1), tv_transforms.ToTensor(), ToOneHotSequence(), ] elif args.augment_strategy == "segment-augment": transform = segment_augment(args, test=not train) elif args.augment_strategy == "minimal": transform = minimal_augment(args, test=not train) elif args.augment_strategy == "deit": if train: transform = create_transform( input_size=args.imsize, is_training=True, color_jitter=args.aug_color_jitter_factor, auto_augment=args.auto_augment_strategy, interpolation="bicubic", re_prob=args.aug_random_erase_prob, re_mode=args.aug_random_erase_mode, re_count=args.aug_random_erase_count, ) else: transform = three_augment(args, test=True) # only do resize, centercrop, and normalize else: raise NotImplementedError( f"Augmentation strategy {args.augment_strategy} is not implemented for {args.augment_engine} (yet)." ) elif args.augment_engine == "albumentations": from data import album_transf as ATf compose = ATf.AlbumTorchCompose if args.augment_strategy == "3-augment": transform = ATf.three_augment(args, as_list=False, test=not train) elif args.augment_strategy == "minimal": transform = ATf.minimal_augment(args, test=not train) else: raise NotImplementedError( f"Augmentation strategy {args.augment_strategy} is not implemented for {args.augment_engine} (yet)." ) elif args.augment_engine == "dali": from nvidia.dali.plugin.pytorch.experimental import proxy as dali_proxy from data import dali_transf as DTf dev_id = int(os.environ.get("LOCAL_RANK", 0)) if args.augment_strategy == "3-augment": pipe = DTf.three_augment( args, test=not train, batch_size=args.batch_size, num_threads=args.num_workers, device_id=dev_id, ) elif args.augment_strategy == "minimal": pipe = DTf.minimal_augment( args, test=not train, batch_size=args.batch_size, num_threads=args.num_workers, device_id=dev_id, ) else: raise NotImplementedError( f"Augmentation strategy {args.augment_strategy} is not implemented for {args.augment_engine} (yet)." ) dali_server = dali_proxy.DALIServer(pipe) transform = dali_server.proxy dataset_name_case_sensitive = dataset_name # keep the original name for AnimalNet folder dataset_name = dataset_name.lower() ignore_index = -100 multi_label = False if isinstance(transform, list): transform = compose(transform) if dataset_name == "cifar10": dataset = CIFAR10(root=ds_path("cifar"), train=train, download=False, transform=transform) n_classes, collate = 10, None elif dataset_name == "stanford-cars": dataset = StanfordCars( root=ds_path("stanford_cars"), split="train" if train else "test", download=False, transform=transform, ) n_classes, collate = 196, None elif dataset_name == "oxford-pet": dataset = OxfordIIITPet( root=ds_path("oxford_pet"), split="trainval" if train else "test", download=False, transform=transform, ) n_classes, collate = 37, None elif dataset_name == "flowers102": dataset = Flowers102( root=ds_path("flowers102"), split="train" if train else "test", download=False, transform=transform, ) n_classes, collate = 102, None elif dataset_name == "food-101": dataset = Food101( root=ds_path("food101"), split="train" if train else "test", download=False, transform=transform, ) n_classes, collate = 101, None elif dataset_name == "fgvc-aircraft": dataset = FGVCAircraft( root=ds_path("aircraft"), split="train" if train else "test", annotation_level="variant", download=False, transform=transform, ) n_classes, collate = 100, None elif dataset_name == "imagenet": dataset = ImageFolder(os.path.join(ds_path("imagenet1k"), "train" if train else "val"), transform=transform) n_classes, collate = 1000, None elif dataset_name == "tinyimagenet": dataset = ImageFolder(os.path.join(ds_path("tinyimagenet"), "train" if train else "val"), transform=transform) n_classes, collate = 200, None elif dataset_name.startswith("fornet"): ds_def = dataset_name.split("/") comb_scheme = ds_def[1] if len(ds_def) > 1 else "same" pruning_ratio = 0.8 if len(ds_def) < 3 else float(ds_def[2]) fg_size_mode = ("range" if train else "max") if len(ds_def) < 4 else ds_def[3] paste_pre_transform = True if len(ds_def) < 5 else ds_def[4] in ["y", "t"] orig_img_prob = ( 0.0 if len(ds_def) < 6 else (ds_def[5] if ds_def[5] in ["linear", "revlinear", "cos"] else float(ds_def[5])) ) mask_smoothing_sigma = 0.0 if len(ds_def) < 7 else float(ds_def[6]) assert len(ds_def) < 5 or ds_def[4] in [ "y", "t", "n", "f", ], f"Invalid dataset definition: {ds_def[4]}; paste pre transform must be 'y'/'t' or 'n'/'f'" orig_ds = ds_path("imagenet1k") dataset = ForNet( ds_path("fornet"), train=train, background_combination=comb_scheme, pruning_ratio=pruning_ratio, transform=transform, fg_transform=( None if args.aug_rand_rot == 0 else tv_transforms.RandomRotation(args.aug_rand_rot, expand=True) ), fg_size_mode=fg_size_mode, paste_pre_transform=paste_pre_transform, orig_img_prob=orig_img_prob, orig_ds=orig_ds, mask_smoothing_sigma=mask_smoothing_sigma, epochs=args.epochs, _album_compose=args.augment_engine == "albumentations", ) n_classes, collate = 1000, None elif dataset_name.startswith("tinyfornet"): ds_def = dataset_name.split("/") comb_scheme = ds_def[1] if len(ds_def) > 1 else "same" pruning_ratio = 1.1 if len(ds_def) < 3 else float(ds_def[2]) fg_size_mode = "range" if len(ds_def) < 4 else ds_def[3] fg_bates_n = 1 if len(ds_def) < 5 else int(ds_def[4]) paste_pre_transform = False if len(ds_def) < 6 else ds_def[5] in ["y", "t"] orig_img_prob = ( 0.0 if len(ds_def) < 7 else (ds_def[6] if ds_def[6] in ["linear", "revlinear", "cos"] else float(ds_def[6])) ) mask_smoothing_sigma = 0.0 if len(ds_def) < 8 else float(ds_def[7]) assert len(ds_def) < 6 or ds_def[5] in [ "y", "t", "n", "f", ], f"Invalid dataset definition: {ds_def[5]}; paste pre transform must be 'y'/'t' or 'n'/'f'" assert "-" not in ds_def[0] or len(ds_def[0].split("-")) == 3, f"Invalid dataset definition: {ds_def[0]}" version = "" if "-" not in ds_def[0] else f"_v{ds_def[0].split('-')[1]}_f{ds_def[0].split('-')[2]}" orig_ds = ds_path("tinyimagenet") dataset = ForNet( f"{ds_path('tinyimagenet')}{version}", train=train, background_combination=comb_scheme, pruning_ratio=pruning_ratio, transform=transform, fg_transform=( None if args.aug_rand_rot == 0 else tv_transforms.RandomRotation(args.aug_rand_rot, expand=True) ), fg_size_mode=fg_size_mode, fg_bates_n=fg_bates_n, paste_pre_transform=paste_pre_transform, orig_img_prob=orig_img_prob, orig_ds=orig_ds, mask_smoothing_sigma=mask_smoothing_sigma, epochs=args.epochs, _album_compose=args.augment_engine == "albumentations", ) n_classes, collate = 200, None elif dataset_name.startswith("counteranimal/"): mode = dataset_name.split("/")[1] dataset = CounterAnimal(ds_path("counteranimal"), mode=mode, transform=transform, train=train) n_classes, collate = 1000, None elif dataset_name.startswith("imagenet9/"): variant = dataset_name.split("/")[1] assert variant in [ "next", "same", "rand", ], f"ImageNet-9 has possible variants next, same, and rand, but not '{variant}'." dataset = ImageFolder(os.path.join(ds_path("imagenet9"), f"mixed_{variant}", "val"), transform=transform) n_classes, collate = 9, None else: raise NotImplementedError(f"Dataset {dataset_name} is not implemented (yet).") if args.aug_repeated_augment_repeats > 1 and train: # use repeated augment sampler from DeiT sampler = RASampler( dataset, num_replicas=args.world_size, rank=rank, shuffle=args.shuffle, num_repeats=args.aug_repeated_augment_repeats, ) elif args.weighted_sampler: assert hasattr( dataset, "per_sample_weights" ), f"Dataset {type(dataset)} should implement per_sample_weights function, but does not." sampler = WeightedRandomSampler(dataset.per_sample_weights(), num_samples=len(dataset) // args.world_size) elif args.distributed: sampler = DistributedSampler(dataset, num_replicas=args.world_size, rank=rank, shuffle=train and args.shuffle) else: sampler = None loader_batch_size = 1 if dataset_name.startswith("listops") else args.batch_size loader_kwargs = dict( batch_size=loader_batch_size, pin_memory=args.pin_memory, num_workers=args.num_workers, drop_last=train, prefetch_factor=args.prefetch_factor if args.num_workers > 0 else None, persistent_workers=False, collate_fn=collate, shuffle=None if sampler else train and args.shuffle, sampler=sampler, ) if args.augment_engine == "dali": data_loader = dali_proxy.DataLoader(dali_server, dataset, **loader_kwargs) else: data_loader = DataLoader(dataset, **loader_kwargs) return data_loader, n_classes, ignore_index, multi_label, dali_server