AAAI Version
This commit is contained in:
346
AAAI Supplementary Material/Model Training Code/load_dataset.py
Normal file
346
AAAI Supplementary Material/Model Training Code/load_dataset.py
Normal file
@@ -0,0 +1,346 @@
|
||||
"""Module to load the datasets, using torch and datadings."""
|
||||
|
||||
import contextlib
|
||||
import os
|
||||
from functools import partial
|
||||
|
||||
import torchvision.transforms as tv_transforms
|
||||
from datadings.reader import MsgpackReader
|
||||
from timm.data import create_transform
|
||||
from torch.utils.data import DataLoader, DistributedSampler, WeightedRandomSampler
|
||||
from torchvision.datasets import (
|
||||
CIFAR10,
|
||||
CIFAR100,
|
||||
FGVCAircraft,
|
||||
Flowers102,
|
||||
Food101,
|
||||
ImageFolder,
|
||||
OxfordIIITPet,
|
||||
StanfordCars,
|
||||
)
|
||||
|
||||
from data.counter_animal import CounterAnimal
|
||||
from data.data_utils import (
|
||||
DDDecodeDataset,
|
||||
ToOneHotSequence,
|
||||
collate_imnet,
|
||||
collate_listops,
|
||||
get_hf_transform,
|
||||
minimal_augment,
|
||||
segment_augment,
|
||||
three_augment,
|
||||
)
|
||||
from data.fornet import ForNet
|
||||
from data.samplers import RASampler
|
||||
from paths_config import ds_path
|
||||
|
||||
|
||||
def prepare_dataset(dataset_name, args, transform=None, train=True, rank=None):
    """Load a dataset from disk, different formats are used for different datasets.

    Supported datasets: CIFAR10, ImageNet, ImageNet21k

    Args:
        dataset_name (str): name of the dataset
        args: further arguments
        transform (list[Module] | str, optional): transformations to use on the data; the list gets composed, or give args.augment_strategy (Default value = None)
        train (bool, optional): use the training split (or test/validation split) (Default value = True)
        rank (int, optional): global rank of this process in distributed training (Default value = None)

    Returns:
        DataLoader: data loader for the dataset
        int: number of classes in the dataset
        int: ignore index for the dataset
        bool: whether the dataset is multi-label
        DALIServer | None: DALI server when the DALI augment engine built the pipeline, else None

    """
    compose = tv_transforms.Compose
    dali_server = None
    if transform is None:
        if args.augment_engine == "torchvision":
            if args.augment_strategy == "3-augment":
                transform = three_augment(args, as_list=False, test=not train)
            elif args.augment_strategy == "differentiable-transform":
                # imported lazily: the distilled-dataset module is only needed for this strategy
                from data.distilled_dataset import differentiable_augment

                transform = differentiable_augment(args, as_list=False, test=not train)
            elif args.augment_strategy == "none":
                transform = []
            elif args.augment_strategy == "lm_one_hot":
                transform = [
                    tv_transforms.Grayscale(num_output_channels=1),
                    tv_transforms.ToTensor(),
                    ToOneHotSequence(),
                ]
            elif args.augment_strategy == "segment-augment":
                transform = segment_augment(args, test=not train)
            elif args.augment_strategy == "minimal":
                transform = minimal_augment(args, test=not train)
            elif args.augment_strategy == "deit":
                if train:
                    # DeiT-style training recipe via timm (color jitter, auto-augment, random erasing)
                    transform = create_transform(
                        input_size=args.imsize,
                        is_training=True,
                        color_jitter=args.aug_color_jitter_factor,
                        auto_augment=args.auto_augment_strategy,
                        interpolation="bicubic",
                        re_prob=args.aug_random_erase_prob,
                        re_mode=args.aug_random_erase_mode,
                        re_count=args.aug_random_erase_count,
                    )
                else:
                    transform = three_augment(args, test=True)  # only do resize, centercrop, and normalize
            else:
                raise NotImplementedError(
                    f"Augmentation strategy {args.augment_strategy} is not implemented for {args.augment_engine} (yet)."
                )
        elif args.augment_engine == "albumentations":
            from data import album_transf as ATf

            # albumentations transforms need their own compose wrapper
            compose = ATf.AlbumTorchCompose

            if args.augment_strategy == "3-augment":
                transform = ATf.three_augment(args, as_list=False, test=not train)
            elif args.augment_strategy == "minimal":
                transform = ATf.minimal_augment(args, test=not train)
            else:
                raise NotImplementedError(
                    f"Augmentation strategy {args.augment_strategy} is not implemented for {args.augment_engine} (yet)."
                )
        elif args.augment_engine == "dali":
            from nvidia.dali.plugin.pytorch.experimental import proxy as dali_proxy

            from data import dali_transf as DTf

            dev_id = int(os.environ.get("LOCAL_RANK", 0))

            if args.augment_strategy == "3-augment":
                pipe = DTf.three_augment(
                    args,
                    test=not train,
                    batch_size=args.batch_size,
                    num_threads=args.num_workers,
                    device_id=dev_id,
                )
            elif args.augment_strategy == "minimal":
                pipe = DTf.minimal_augment(
                    args,
                    test=not train,
                    batch_size=args.batch_size,
                    num_threads=args.num_workers,
                    device_id=dev_id,
                )
            else:
                raise NotImplementedError(
                    f"Augmentation strategy {args.augment_strategy} is not implemented for {args.augment_engine} (yet)."
                )

            # the proxy stands in as a regular transform; the server runs the DALI pipeline
            dali_server = dali_proxy.DALIServer(pipe)
            transform = dali_server.proxy

    dataset_name = dataset_name.lower()
    ignore_index = -100
    multi_label = False

    if isinstance(transform, list):
        transform = compose(transform)

    if dataset_name == "cifar10":
        dataset = CIFAR10(root=ds_path("cifar"), train=train, download=False, transform=transform)
        n_classes, collate = 10, None

    elif dataset_name == "stanford-cars":
        dataset = StanfordCars(
            root=ds_path("stanford_cars"),
            split="train" if train else "test",
            download=False,
            transform=transform,
        )
        n_classes, collate = 196, None

    elif dataset_name == "oxford-pet":
        dataset = OxfordIIITPet(
            root=ds_path("oxford_pet"),
            split="trainval" if train else "test",
            download=False,
            transform=transform,
        )
        n_classes, collate = 37, None

    elif dataset_name == "flowers102":
        dataset = Flowers102(
            root=ds_path("flowers102"),
            split="train" if train else "test",
            download=False,
            transform=transform,
        )
        n_classes, collate = 102, None

    elif dataset_name == "food-101":
        dataset = Food101(
            root=ds_path("food101"),
            split="train" if train else "test",
            download=False,
            transform=transform,
        )
        n_classes, collate = 101, None

    elif dataset_name == "fgvc-aircraft":
        dataset = FGVCAircraft(
            root=ds_path("aircraft"),
            split="train" if train else "test",
            annotation_level="variant",
            download=False,
            transform=transform,
        )
        n_classes, collate = 100, None

    elif dataset_name == "imagenet":
        dataset = ImageFolder(os.path.join(ds_path("imagenet1k"), "train" if train else "val"), transform=transform)
        n_classes, collate = 1000, None

    elif dataset_name == "tinyimagenet":
        dataset = ImageFolder(os.path.join(ds_path("tinyimagenet"), "train" if train else "val"), transform=transform)
        n_classes, collate = 200, None

    elif dataset_name.startswith("fornet"):
        # dataset spec: fornet/<comb>/<pruning>/<fg_size_mode>/<paste y|t|n|f>/<orig_img_prob>/<mask_sigma>
        # missing fields fall back to the defaults below
        ds_def = dataset_name.split("/")
        comb_scheme = ds_def[1] if len(ds_def) > 1 else "same"
        pruning_ratio = 0.8 if len(ds_def) < 3 else float(ds_def[2])
        fg_size_mode = ("range" if train else "max") if len(ds_def) < 4 else ds_def[3]
        paste_pre_transform = True if len(ds_def) < 5 else ds_def[4] in ["y", "t"]
        orig_img_prob = (
            0.0 if len(ds_def) < 6 else (ds_def[5] if ds_def[5] in ["linear", "revlinear", "cos"] else float(ds_def[5]))
        )
        mask_smoothing_sigma = 0.0 if len(ds_def) < 7 else float(ds_def[6])
        assert len(ds_def) < 5 or ds_def[4] in [
            "y",
            "t",
            "n",
            "f",
        ], f"Invalid dataset definition: {ds_def[4]}; paste pre transform must be 'y'/'t' or 'n'/'f'"

        orig_ds = ds_path("imagenet1k")

        dataset = ForNet(
            ds_path("fornet"),
            train=train,
            background_combination=comb_scheme,
            pruning_ratio=pruning_ratio,
            transform=transform,
            fg_transform=(
                None if args.aug_rand_rot == 0 else tv_transforms.RandomRotation(args.aug_rand_rot, expand=True)
            ),
            fg_size_mode=fg_size_mode,
            paste_pre_transform=paste_pre_transform,
            orig_img_prob=orig_img_prob,
            orig_ds=orig_ds,
            mask_smoothing_sigma=mask_smoothing_sigma,
            epochs=args.epochs,
            _album_compose=args.augment_engine == "albumentations",
        )
        n_classes, collate = 1000, None

    elif dataset_name.startswith("tinyfornet"):
        # like "fornet" above, but with an extra <fg_bates_n> field after <fg_size_mode>
        ds_def = dataset_name.split("/")
        comb_scheme = ds_def[1] if len(ds_def) > 1 else "same"
        pruning_ratio = 1.1 if len(ds_def) < 3 else float(ds_def[2])
        fg_size_mode = "range" if len(ds_def) < 4 else ds_def[3]
        fg_bates_n = 1 if len(ds_def) < 5 else int(ds_def[4])
        paste_pre_transform = False if len(ds_def) < 6 else ds_def[5] in ["y", "t"]
        orig_img_prob = (
            0.0 if len(ds_def) < 7 else (ds_def[6] if ds_def[6] in ["linear", "revlinear", "cos"] else float(ds_def[6]))
        )
        mask_smoothing_sigma = 0.0 if len(ds_def) < 8 else float(ds_def[7])
        assert len(ds_def) < 6 or ds_def[5] in [
            "y",
            "t",
            "n",
            "f",
        ], f"Invalid dataset definition: {ds_def[5]}; paste pre transform must be 'y'/'t' or 'n'/'f'"
        assert "-" not in ds_def[0] or len(ds_def[0].split("-")) == 3, f"Invalid dataset definition: {ds_def[0]}"
        version = "" if "-" not in ds_def[0] else f"_v{ds_def[0].split('-')[1]}_f{ds_def[0].split('-')[2]}"

        orig_ds = ds_path("tinyimagenet")

        # NOTE(review): the root reuses the tinyimagenet path plus a version suffix rather than a
        # dedicated "tinyfornet" path entry — confirm this is intended and not a copy-paste slip.
        dataset = ForNet(
            f"{ds_path('tinyimagenet')}{version}",
            train=train,
            background_combination=comb_scheme,
            pruning_ratio=pruning_ratio,
            transform=transform,
            fg_transform=(
                None if args.aug_rand_rot == 0 else tv_transforms.RandomRotation(args.aug_rand_rot, expand=True)
            ),
            fg_size_mode=fg_size_mode,
            fg_bates_n=fg_bates_n,
            paste_pre_transform=paste_pre_transform,
            orig_img_prob=orig_img_prob,
            orig_ds=orig_ds,
            mask_smoothing_sigma=mask_smoothing_sigma,
            epochs=args.epochs,
            _album_compose=args.augment_engine == "albumentations",
        )
        n_classes, collate = 200, None

    elif dataset_name.startswith("counteranimal/"):
        mode = dataset_name.split("/")[1]

        dataset = CounterAnimal(ds_path("counteranimal"), mode=mode, transform=transform, train=train)
        n_classes, collate = 1000, None

    elif dataset_name.startswith("imagenet9/"):
        variant = dataset_name.split("/")[1]
        assert variant in [
            "next",
            "same",
            "rand",
        ], f"ImageNet-9 has possible variants next, same, and rand, but not '{variant}'."

        dataset = ImageFolder(os.path.join(ds_path("imagenet9"), f"mixed_{variant}", "val"), transform=transform)
        n_classes, collate = 9, None

    else:
        raise NotImplementedError(f"Dataset {dataset_name} is not implemented (yet).")

    if args.aug_repeated_augment_repeats > 1 and train:
        # use repeated augment sampler from DeiT
        sampler = RASampler(
            dataset,
            num_replicas=args.world_size,
            rank=rank,
            shuffle=args.shuffle,
            num_repeats=args.aug_repeated_augment_repeats,
        )
    elif args.weighted_sampler:
        assert hasattr(
            dataset, "per_sample_weights"
        ), f"Dataset {type(dataset)} should implement per_sample_weights function, but does not."

        sampler = WeightedRandomSampler(dataset.per_sample_weights(), num_samples=len(dataset) // args.world_size)
    elif args.distributed:
        sampler = DistributedSampler(dataset, num_replicas=args.world_size, rank=rank, shuffle=train and args.shuffle)
    else:
        sampler = None

    loader_batch_size = 1 if dataset_name.startswith("listops") else args.batch_size

    loader_kwargs = dict(
        batch_size=loader_batch_size,
        pin_memory=args.pin_memory,
        num_workers=args.num_workers,
        drop_last=train,
        prefetch_factor=args.prefetch_factor if args.num_workers > 0 else None,
        persistent_workers=False,
        collate_fn=collate,
        # shuffle and sampler are mutually exclusive; pass None when a sampler is set
        shuffle=None if sampler else train and args.shuffle,
        sampler=sampler,
    )

    # BUGFIX: guard on dali_server instead of args.augment_engine. When the engine is "dali" but a
    # transform was passed in by the caller, the DALI branch above never runs, so `dali_proxy` is
    # unbound and `dali_server` is None — the old engine-name check raised a NameError here.
    if dali_server is not None:
        data_loader = dali_proxy.DataLoader(dali_server, dataset, **loader_kwargs)
    else:
        data_loader = DataLoader(dataset, **loader_kwargs)

    return data_loader, n_classes, ignore_index, multi_label, dali_server
|
||||
Reference in New Issue
Block a user