AAAI Version

This commit is contained in:
Tobias Christian Nauen
2026-02-24 12:22:44 +01:00
parent 5c08f9d31a
commit ff34712155
378 changed files with 19844 additions and 4780 deletions

View File

@@ -0,0 +1,346 @@
"""Module to load the datasets, using torch and datadings."""
import contextlib
import os
from functools import partial
import torchvision.transforms as tv_transforms
from datadings.reader import MsgpackReader
from timm.data import create_transform
from torch.utils.data import DataLoader, DistributedSampler, WeightedRandomSampler
from torchvision.datasets import (
CIFAR10,
CIFAR100,
FGVCAircraft,
Flowers102,
Food101,
ImageFolder,
OxfordIIITPet,
StanfordCars,
)
from data.counter_animal import CounterAnimal
from data.data_utils import (
DDDecodeDataset,
ToOneHotSequence,
collate_imnet,
collate_listops,
get_hf_transform,
minimal_augment,
segment_augment,
three_augment,
)
from data.fornet import ForNet
from data.samplers import RASampler
from paths_config import ds_path
def prepare_dataset(dataset_name, args, transform=None, train=True, rank=None):
"""Load a dataset from disk, different formats are used for different datasets.
Supported datasets: CIFAR10, ImageNet, ImageNet21k
Args:
dataset_name (str): name of the dataset
args: further arguments
transform (list[Module] | str, optional): transformations to use on the data; the list gets composed, or give args.augment_strategy (Default value = None)
train (bool, optional): use the training split (or test/validation split) (Default value = True)
rank (int, optional): global rank of this process in distributed training (Default value = None)
Returns:
DataLoader: data loader for the dataset
int: number of classes in the dataset
int: ignore index for the dataset
bool: whether the dataset is multi-label
"""
compose = tv_transforms.Compose
dali_server = None
if transform is None:
if args.augment_engine == "torchvision":
if args.augment_strategy == "3-augment":
transform = three_augment(args, as_list=False, test=not train)
elif args.augment_strategy == "differentiable-transform":
from data.distilled_dataset import differentiable_augment
transform = differentiable_augment(args, as_list=False, test=not train)
elif args.augment_strategy == "none":
transform = []
elif args.augment_strategy == "lm_one_hot":
transform = [
tv_transforms.Grayscale(num_output_channels=1),
tv_transforms.ToTensor(),
ToOneHotSequence(),
]
elif args.augment_strategy == "segment-augment":
transform = segment_augment(args, test=not train)
elif args.augment_strategy == "minimal":
transform = minimal_augment(args, test=not train)
elif args.augment_strategy == "deit":
if train:
transform = create_transform(
input_size=args.imsize,
is_training=True,
color_jitter=args.aug_color_jitter_factor,
auto_augment=args.auto_augment_strategy,
interpolation="bicubic",
re_prob=args.aug_random_erase_prob,
re_mode=args.aug_random_erase_mode,
re_count=args.aug_random_erase_count,
)
else:
transform = three_augment(args, test=True) # only do resize, centercrop, and normalize
else:
raise NotImplementedError(
f"Augmentation strategy {args.augment_strategy} is not implemented for {args.augment_engine} (yet)."
)
elif args.augment_engine == "albumentations":
from data import album_transf as ATf
compose = ATf.AlbumTorchCompose
if args.augment_strategy == "3-augment":
transform = ATf.three_augment(args, as_list=False, test=not train)
elif args.augment_strategy == "minimal":
transform = ATf.minimal_augment(args, test=not train)
else:
raise NotImplementedError(
f"Augmentation strategy {args.augment_strategy} is not implemented for {args.augment_engine} (yet)."
)
elif args.augment_engine == "dali":
from nvidia.dali.plugin.pytorch.experimental import proxy as dali_proxy
from data import dali_transf as DTf
dev_id = int(os.environ.get("LOCAL_RANK", 0))
if args.augment_strategy == "3-augment":
pipe = DTf.three_augment(
args,
test=not train,
batch_size=args.batch_size,
num_threads=args.num_workers,
device_id=dev_id,
)
elif args.augment_strategy == "minimal":
pipe = DTf.minimal_augment(
args,
test=not train,
batch_size=args.batch_size,
num_threads=args.num_workers,
device_id=dev_id,
)
else:
raise NotImplementedError(
f"Augmentation strategy {args.augment_strategy} is not implemented for {args.augment_engine} (yet)."
)
dali_server = dali_proxy.DALIServer(pipe)
transform = dali_server.proxy
dataset_name_case_sensitive = dataset_name # keep the original name for AnimalNet folder
dataset_name = dataset_name.lower()
ignore_index = -100
multi_label = False
if isinstance(transform, list):
transform = compose(transform)
if dataset_name == "cifar10":
dataset = CIFAR10(root=ds_path("cifar"), train=train, download=False, transform=transform)
n_classes, collate = 10, None
elif dataset_name == "stanford-cars":
dataset = StanfordCars(
root=ds_path("stanford_cars"),
split="train" if train else "test",
download=False,
transform=transform,
)
n_classes, collate = 196, None
elif dataset_name == "oxford-pet":
dataset = OxfordIIITPet(
root=ds_path("oxford_pet"),
split="trainval" if train else "test",
download=False,
transform=transform,
)
n_classes, collate = 37, None
elif dataset_name == "flowers102":
dataset = Flowers102(
root=ds_path("flowers102"),
split="train" if train else "test",
download=False,
transform=transform,
)
n_classes, collate = 102, None
elif dataset_name == "food-101":
dataset = Food101(
root=ds_path("food101"),
split="train" if train else "test",
download=False,
transform=transform,
)
n_classes, collate = 101, None
elif dataset_name == "fgvc-aircraft":
dataset = FGVCAircraft(
root=ds_path("aircraft"),
split="train" if train else "test",
annotation_level="variant",
download=False,
transform=transform,
)
n_classes, collate = 100, None
elif dataset_name == "imagenet":
dataset = ImageFolder(os.path.join(ds_path("imagenet1k"), "train" if train else "val"), transform=transform)
n_classes, collate = 1000, None
elif dataset_name == "tinyimagenet":
dataset = ImageFolder(os.path.join(ds_path("tinyimagenet"), "train" if train else "val"), transform=transform)
n_classes, collate = 200, None
elif dataset_name.startswith("fornet"):
ds_def = dataset_name.split("/")
comb_scheme = ds_def[1] if len(ds_def) > 1 else "same"
pruning_ratio = 0.8 if len(ds_def) < 3 else float(ds_def[2])
fg_size_mode = ("range" if train else "max") if len(ds_def) < 4 else ds_def[3]
paste_pre_transform = True if len(ds_def) < 5 else ds_def[4] in ["y", "t"]
orig_img_prob = (
0.0 if len(ds_def) < 6 else (ds_def[5] if ds_def[5] in ["linear", "revlinear", "cos"] else float(ds_def[5]))
)
mask_smoothing_sigma = 0.0 if len(ds_def) < 7 else float(ds_def[6])
assert len(ds_def) < 5 or ds_def[4] in [
"y",
"t",
"n",
"f",
], f"Invalid dataset definition: {ds_def[4]}; paste pre transform must be 'y'/'t' or 'n'/'f'"
orig_ds = ds_path("imagenet1k")
dataset = ForNet(
ds_path("fornet"),
train=train,
background_combination=comb_scheme,
pruning_ratio=pruning_ratio,
transform=transform,
fg_transform=(
None if args.aug_rand_rot == 0 else tv_transforms.RandomRotation(args.aug_rand_rot, expand=True)
),
fg_size_mode=fg_size_mode,
paste_pre_transform=paste_pre_transform,
orig_img_prob=orig_img_prob,
orig_ds=orig_ds,
mask_smoothing_sigma=mask_smoothing_sigma,
epochs=args.epochs,
_album_compose=args.augment_engine == "albumentations",
)
n_classes, collate = 1000, None
elif dataset_name.startswith("tinyfornet"):
ds_def = dataset_name.split("/")
comb_scheme = ds_def[1] if len(ds_def) > 1 else "same"
pruning_ratio = 1.1 if len(ds_def) < 3 else float(ds_def[2])
fg_size_mode = "range" if len(ds_def) < 4 else ds_def[3]
fg_bates_n = 1 if len(ds_def) < 5 else int(ds_def[4])
paste_pre_transform = False if len(ds_def) < 6 else ds_def[5] in ["y", "t"]
orig_img_prob = (
0.0 if len(ds_def) < 7 else (ds_def[6] if ds_def[6] in ["linear", "revlinear", "cos"] else float(ds_def[6]))
)
mask_smoothing_sigma = 0.0 if len(ds_def) < 8 else float(ds_def[7])
assert len(ds_def) < 6 or ds_def[5] in [
"y",
"t",
"n",
"f",
], f"Invalid dataset definition: {ds_def[5]}; paste pre transform must be 'y'/'t' or 'n'/'f'"
assert "-" not in ds_def[0] or len(ds_def[0].split("-")) == 3, f"Invalid dataset definition: {ds_def[0]}"
version = "" if "-" not in ds_def[0] else f"_v{ds_def[0].split('-')[1]}_f{ds_def[0].split('-')[2]}"
orig_ds = ds_path("tinyimagenet")
dataset = ForNet(
f"{ds_path('tinyimagenet')}{version}",
train=train,
background_combination=comb_scheme,
pruning_ratio=pruning_ratio,
transform=transform,
fg_transform=(
None if args.aug_rand_rot == 0 else tv_transforms.RandomRotation(args.aug_rand_rot, expand=True)
),
fg_size_mode=fg_size_mode,
fg_bates_n=fg_bates_n,
paste_pre_transform=paste_pre_transform,
orig_img_prob=orig_img_prob,
orig_ds=orig_ds,
mask_smoothing_sigma=mask_smoothing_sigma,
epochs=args.epochs,
_album_compose=args.augment_engine == "albumentations",
)
n_classes, collate = 200, None
elif dataset_name.startswith("counteranimal/"):
mode = dataset_name.split("/")[1]
dataset = CounterAnimal(ds_path("counteranimal"), mode=mode, transform=transform, train=train)
n_classes, collate = 1000, None
elif dataset_name.startswith("imagenet9/"):
variant = dataset_name.split("/")[1]
assert variant in [
"next",
"same",
"rand",
], f"ImageNet-9 has possible variants next, same, and rand, but not '{variant}'."
dataset = ImageFolder(os.path.join(ds_path("imagenet9"), f"mixed_{variant}", "val"), transform=transform)
n_classes, collate = 9, None
else:
raise NotImplementedError(f"Dataset {dataset_name} is not implemented (yet).")
if args.aug_repeated_augment_repeats > 1 and train:
# use repeated augment sampler from DeiT
sampler = RASampler(
dataset,
num_replicas=args.world_size,
rank=rank,
shuffle=args.shuffle,
num_repeats=args.aug_repeated_augment_repeats,
)
elif args.weighted_sampler:
assert hasattr(
dataset, "per_sample_weights"
), f"Dataset {type(dataset)} should implement per_sample_weights function, but does not."
sampler = WeightedRandomSampler(dataset.per_sample_weights(), num_samples=len(dataset) // args.world_size)
elif args.distributed:
sampler = DistributedSampler(dataset, num_replicas=args.world_size, rank=rank, shuffle=train and args.shuffle)
else:
sampler = None
loader_batch_size = 1 if dataset_name.startswith("listops") else args.batch_size
loader_kwargs = dict(
batch_size=loader_batch_size,
pin_memory=args.pin_memory,
num_workers=args.num_workers,
drop_last=train,
prefetch_factor=args.prefetch_factor if args.num_workers > 0 else None,
persistent_workers=False,
collate_fn=collate,
shuffle=None if sampler else train and args.shuffle,
sampler=sampler,
)
if args.augment_engine == "dali":
data_loader = dali_proxy.DataLoader(dali_server, dataset, **loader_kwargs)
else:
data_loader = DataLoader(dataset, **loader_kwargs)
return data_loader, n_classes, ignore_index, multi_label, dali_server