347 lines
13 KiB
Python
347 lines
13 KiB
Python
"""Module to load the datasets, using torch and datadings."""
|
|
|
|
import contextlib
|
|
import os
|
|
from functools import partial
|
|
|
|
import torchvision.transforms as tv_transforms
|
|
from datadings.reader import MsgpackReader
|
|
from timm.data import create_transform
|
|
from torch.utils.data import DataLoader, DistributedSampler, WeightedRandomSampler
|
|
from torchvision.datasets import (
|
|
CIFAR10,
|
|
CIFAR100,
|
|
FGVCAircraft,
|
|
Flowers102,
|
|
Food101,
|
|
ImageFolder,
|
|
OxfordIIITPet,
|
|
StanfordCars,
|
|
)
|
|
|
|
from data.counter_animal import CounterAnimal
|
|
from data.data_utils import (
|
|
DDDecodeDataset,
|
|
ToOneHotSequence,
|
|
collate_imnet,
|
|
collate_listops,
|
|
get_hf_transform,
|
|
minimal_augment,
|
|
segment_augment,
|
|
three_augment,
|
|
)
|
|
from data.fornet import ForNet
|
|
from data.samplers import RASampler
|
|
from paths_config import ds_path
|
|
|
|
|
|
def prepare_dataset(dataset_name, args, transform=None, train=True, rank=None):
|
|
"""Load a dataset from disk, different formats are used for different datasets.
|
|
|
|
Supported datasets: CIFAR10, ImageNet, ImageNet21k
|
|
|
|
Args:
|
|
dataset_name (str): name of the dataset
|
|
args: further arguments
|
|
transform (list[Module] | str, optional): transformations to use on the data; the list gets composed, or give args.augment_strategy (Default value = None)
|
|
train (bool, optional): use the training split (or test/validation split) (Default value = True)
|
|
rank (int, optional): global rank of this process in distributed training (Default value = None)
|
|
|
|
Returns:
|
|
DataLoader: data loader for the dataset
|
|
int: number of classes in the dataset
|
|
int: ignore index for the dataset
|
|
bool: whether the dataset is multi-label
|
|
|
|
"""
|
|
compose = tv_transforms.Compose
|
|
dali_server = None
|
|
if transform is None:
|
|
if args.augment_engine == "torchvision":
|
|
if args.augment_strategy == "3-augment":
|
|
transform = three_augment(args, as_list=False, test=not train)
|
|
elif args.augment_strategy == "differentiable-transform":
|
|
from data.distilled_dataset import differentiable_augment
|
|
|
|
transform = differentiable_augment(args, as_list=False, test=not train)
|
|
elif args.augment_strategy == "none":
|
|
transform = []
|
|
elif args.augment_strategy == "lm_one_hot":
|
|
transform = [
|
|
tv_transforms.Grayscale(num_output_channels=1),
|
|
tv_transforms.ToTensor(),
|
|
ToOneHotSequence(),
|
|
]
|
|
elif args.augment_strategy == "segment-augment":
|
|
transform = segment_augment(args, test=not train)
|
|
elif args.augment_strategy == "minimal":
|
|
transform = minimal_augment(args, test=not train)
|
|
elif args.augment_strategy == "deit":
|
|
if train:
|
|
transform = create_transform(
|
|
input_size=args.imsize,
|
|
is_training=True,
|
|
color_jitter=args.aug_color_jitter_factor,
|
|
auto_augment=args.auto_augment_strategy,
|
|
interpolation="bicubic",
|
|
re_prob=args.aug_random_erase_prob,
|
|
re_mode=args.aug_random_erase_mode,
|
|
re_count=args.aug_random_erase_count,
|
|
)
|
|
else:
|
|
transform = three_augment(args, test=True) # only do resize, centercrop, and normalize
|
|
else:
|
|
raise NotImplementedError(
|
|
f"Augmentation strategy {args.augment_strategy} is not implemented for {args.augment_engine} (yet)."
|
|
)
|
|
elif args.augment_engine == "albumentations":
|
|
from data import album_transf as ATf
|
|
|
|
compose = ATf.AlbumTorchCompose
|
|
|
|
if args.augment_strategy == "3-augment":
|
|
transform = ATf.three_augment(args, as_list=False, test=not train)
|
|
elif args.augment_strategy == "minimal":
|
|
transform = ATf.minimal_augment(args, test=not train)
|
|
else:
|
|
raise NotImplementedError(
|
|
f"Augmentation strategy {args.augment_strategy} is not implemented for {args.augment_engine} (yet)."
|
|
)
|
|
elif args.augment_engine == "dali":
|
|
from nvidia.dali.plugin.pytorch.experimental import proxy as dali_proxy
|
|
|
|
from data import dali_transf as DTf
|
|
|
|
dev_id = int(os.environ.get("LOCAL_RANK", 0))
|
|
|
|
if args.augment_strategy == "3-augment":
|
|
pipe = DTf.three_augment(
|
|
args,
|
|
test=not train,
|
|
batch_size=args.batch_size,
|
|
num_threads=args.num_workers,
|
|
device_id=dev_id,
|
|
)
|
|
elif args.augment_strategy == "minimal":
|
|
pipe = DTf.minimal_augment(
|
|
args,
|
|
test=not train,
|
|
batch_size=args.batch_size,
|
|
num_threads=args.num_workers,
|
|
device_id=dev_id,
|
|
)
|
|
else:
|
|
raise NotImplementedError(
|
|
f"Augmentation strategy {args.augment_strategy} is not implemented for {args.augment_engine} (yet)."
|
|
)
|
|
|
|
dali_server = dali_proxy.DALIServer(pipe)
|
|
transform = dali_server.proxy
|
|
|
|
dataset_name_case_sensitive = dataset_name # keep the original name for AnimalNet folder
|
|
dataset_name = dataset_name.lower()
|
|
ignore_index = -100
|
|
multi_label = False
|
|
|
|
if isinstance(transform, list):
|
|
transform = compose(transform)
|
|
|
|
if dataset_name == "cifar10":
|
|
dataset = CIFAR10(root=ds_path("cifar"), train=train, download=False, transform=transform)
|
|
n_classes, collate = 10, None
|
|
|
|
elif dataset_name == "stanford-cars":
|
|
dataset = StanfordCars(
|
|
root=ds_path("stanford_cars"),
|
|
split="train" if train else "test",
|
|
download=False,
|
|
transform=transform,
|
|
)
|
|
n_classes, collate = 196, None
|
|
|
|
elif dataset_name == "oxford-pet":
|
|
dataset = OxfordIIITPet(
|
|
root=ds_path("oxford_pet"),
|
|
split="trainval" if train else "test",
|
|
download=False,
|
|
transform=transform,
|
|
)
|
|
n_classes, collate = 37, None
|
|
|
|
elif dataset_name == "flowers102":
|
|
dataset = Flowers102(
|
|
root=ds_path("flowers102"),
|
|
split="train" if train else "test",
|
|
download=False,
|
|
transform=transform,
|
|
)
|
|
n_classes, collate = 102, None
|
|
|
|
elif dataset_name == "food-101":
|
|
dataset = Food101(
|
|
root=ds_path("food101"),
|
|
split="train" if train else "test",
|
|
download=False,
|
|
transform=transform,
|
|
)
|
|
n_classes, collate = 101, None
|
|
|
|
elif dataset_name == "fgvc-aircraft":
|
|
dataset = FGVCAircraft(
|
|
root=ds_path("aircraft"),
|
|
split="train" if train else "test",
|
|
annotation_level="variant",
|
|
download=False,
|
|
transform=transform,
|
|
)
|
|
n_classes, collate = 100, None
|
|
|
|
elif dataset_name == "imagenet":
|
|
dataset = ImageFolder(os.path.join(ds_path("imagenet1k"), "train" if train else "val"), transform=transform)
|
|
n_classes, collate = 1000, None
|
|
|
|
elif dataset_name == "tinyimagenet":
|
|
dataset = ImageFolder(os.path.join(ds_path("tinyimagenet"), "train" if train else "val"), transform=transform)
|
|
n_classes, collate = 200, None
|
|
|
|
elif dataset_name.startswith("fornet"):
|
|
ds_def = dataset_name.split("/")
|
|
comb_scheme = ds_def[1] if len(ds_def) > 1 else "same"
|
|
pruning_ratio = 0.8 if len(ds_def) < 3 else float(ds_def[2])
|
|
fg_size_mode = ("range" if train else "max") if len(ds_def) < 4 else ds_def[3]
|
|
paste_pre_transform = True if len(ds_def) < 5 else ds_def[4] in ["y", "t"]
|
|
orig_img_prob = (
|
|
0.0 if len(ds_def) < 6 else (ds_def[5] if ds_def[5] in ["linear", "revlinear", "cos"] else float(ds_def[5]))
|
|
)
|
|
mask_smoothing_sigma = 0.0 if len(ds_def) < 7 else float(ds_def[6])
|
|
assert len(ds_def) < 5 or ds_def[4] in [
|
|
"y",
|
|
"t",
|
|
"n",
|
|
"f",
|
|
], f"Invalid dataset definition: {ds_def[4]}; paste pre transform must be 'y'/'t' or 'n'/'f'"
|
|
|
|
orig_ds = ds_path("imagenet1k")
|
|
|
|
dataset = ForNet(
|
|
ds_path("fornet"),
|
|
train=train,
|
|
background_combination=comb_scheme,
|
|
pruning_ratio=pruning_ratio,
|
|
transform=transform,
|
|
fg_transform=(
|
|
None if args.aug_rand_rot == 0 else tv_transforms.RandomRotation(args.aug_rand_rot, expand=True)
|
|
),
|
|
fg_size_mode=fg_size_mode,
|
|
paste_pre_transform=paste_pre_transform,
|
|
orig_img_prob=orig_img_prob,
|
|
orig_ds=orig_ds,
|
|
mask_smoothing_sigma=mask_smoothing_sigma,
|
|
epochs=args.epochs,
|
|
_album_compose=args.augment_engine == "albumentations",
|
|
)
|
|
n_classes, collate = 1000, None
|
|
|
|
elif dataset_name.startswith("tinyfornet"):
|
|
ds_def = dataset_name.split("/")
|
|
comb_scheme = ds_def[1] if len(ds_def) > 1 else "same"
|
|
pruning_ratio = 1.1 if len(ds_def) < 3 else float(ds_def[2])
|
|
fg_size_mode = "range" if len(ds_def) < 4 else ds_def[3]
|
|
fg_bates_n = 1 if len(ds_def) < 5 else int(ds_def[4])
|
|
paste_pre_transform = False if len(ds_def) < 6 else ds_def[5] in ["y", "t"]
|
|
orig_img_prob = (
|
|
0.0 if len(ds_def) < 7 else (ds_def[6] if ds_def[6] in ["linear", "revlinear", "cos"] else float(ds_def[6]))
|
|
)
|
|
mask_smoothing_sigma = 0.0 if len(ds_def) < 8 else float(ds_def[7])
|
|
assert len(ds_def) < 6 or ds_def[5] in [
|
|
"y",
|
|
"t",
|
|
"n",
|
|
"f",
|
|
], f"Invalid dataset definition: {ds_def[5]}; paste pre transform must be 'y'/'t' or 'n'/'f'"
|
|
assert "-" not in ds_def[0] or len(ds_def[0].split("-")) == 3, f"Invalid dataset definition: {ds_def[0]}"
|
|
version = "" if "-" not in ds_def[0] else f"_v{ds_def[0].split('-')[1]}_f{ds_def[0].split('-')[2]}"
|
|
|
|
orig_ds = ds_path("tinyimagenet")
|
|
|
|
dataset = ForNet(
|
|
f"{ds_path('tinyimagenet')}{version}",
|
|
train=train,
|
|
background_combination=comb_scheme,
|
|
pruning_ratio=pruning_ratio,
|
|
transform=transform,
|
|
fg_transform=(
|
|
None if args.aug_rand_rot == 0 else tv_transforms.RandomRotation(args.aug_rand_rot, expand=True)
|
|
),
|
|
fg_size_mode=fg_size_mode,
|
|
fg_bates_n=fg_bates_n,
|
|
paste_pre_transform=paste_pre_transform,
|
|
orig_img_prob=orig_img_prob,
|
|
orig_ds=orig_ds,
|
|
mask_smoothing_sigma=mask_smoothing_sigma,
|
|
epochs=args.epochs,
|
|
_album_compose=args.augment_engine == "albumentations",
|
|
)
|
|
n_classes, collate = 200, None
|
|
|
|
elif dataset_name.startswith("counteranimal/"):
|
|
mode = dataset_name.split("/")[1]
|
|
|
|
dataset = CounterAnimal(ds_path("counteranimal"), mode=mode, transform=transform, train=train)
|
|
n_classes, collate = 1000, None
|
|
|
|
elif dataset_name.startswith("imagenet9/"):
|
|
variant = dataset_name.split("/")[1]
|
|
assert variant in [
|
|
"next",
|
|
"same",
|
|
"rand",
|
|
], f"ImageNet-9 has possible variants next, same, and rand, but not '{variant}'."
|
|
|
|
dataset = ImageFolder(os.path.join(ds_path("imagenet9"), f"mixed_{variant}", "val"), transform=transform)
|
|
n_classes, collate = 9, None
|
|
|
|
else:
|
|
raise NotImplementedError(f"Dataset {dataset_name} is not implemented (yet).")
|
|
|
|
if args.aug_repeated_augment_repeats > 1 and train:
|
|
# use repeated augment sampler from DeiT
|
|
sampler = RASampler(
|
|
dataset,
|
|
num_replicas=args.world_size,
|
|
rank=rank,
|
|
shuffle=args.shuffle,
|
|
num_repeats=args.aug_repeated_augment_repeats,
|
|
)
|
|
elif args.weighted_sampler:
|
|
assert hasattr(
|
|
dataset, "per_sample_weights"
|
|
), f"Dataset {type(dataset)} should implement per_sample_weights function, but does not."
|
|
|
|
sampler = WeightedRandomSampler(dataset.per_sample_weights(), num_samples=len(dataset) // args.world_size)
|
|
elif args.distributed:
|
|
sampler = DistributedSampler(dataset, num_replicas=args.world_size, rank=rank, shuffle=train and args.shuffle)
|
|
else:
|
|
sampler = None
|
|
|
|
loader_batch_size = 1 if dataset_name.startswith("listops") else args.batch_size
|
|
|
|
loader_kwargs = dict(
|
|
batch_size=loader_batch_size,
|
|
pin_memory=args.pin_memory,
|
|
num_workers=args.num_workers,
|
|
drop_last=train,
|
|
prefetch_factor=args.prefetch_factor if args.num_workers > 0 else None,
|
|
persistent_workers=False,
|
|
collate_fn=collate,
|
|
shuffle=None if sampler else train and args.shuffle,
|
|
sampler=sampler,
|
|
)
|
|
|
|
if args.augment_engine == "dali":
|
|
data_loader = dali_proxy.DataLoader(dali_server, dataset, **loader_kwargs)
|
|
else:
|
|
data_loader = DataLoader(dataset, **loader_kwargs)
|
|
|
|
return data_loader, n_classes, ignore_index, multi_label, dali_server
|