AAAI Version

This commit is contained in:
Tobias Christian Nauen
2026-02-24 12:22:44 +01:00
parent 5c08f9d31a
commit ff34712155
378 changed files with 19844 additions and 4780 deletions

@@ -0,0 +1,594 @@
"""Utils and small helper functions."""
import collections.abc
import json
import os
import shutil
import sys
import warnings
from dataclasses import dataclass
from itertools import repeat
from math import cos, pi, sqrt
import numpy as np
import torch
import torch.distributed as dist
from loguru import logger
from timm.data import Mixup
from timm.utils import NativeScaler, dispatch_clip_grad
from torch.nn.modules.loss import _WeightedLoss
from torch.utils.data import Dataset
from torchvision.transforms import transforms
import paths_config
from config import default_kwargs, get_default_kwargs # noqa: F401 # is used in prep_kwargs
class RepeatedDataset(Dataset):
"""Dataset that repeats the given dataset a number of times."""
def __init__(self, dataset, num_repeats):
"""Create repeated dataset.
Args:
dataset (Dataset): dataset to repeat.
num_repeats (int): number of repeats.
"""
self.dataset = dataset
self.num_repeats = num_repeats
def __getitem__(self, idx):
return self.dataset[idx // self.num_repeats]
def __len__(self):
return len(self.dataset) * self.num_repeats
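# Illustrative usage (not part of the original file; `base_dataset` is a placeholder):
# >>> repeated = RepeatedDataset(base_dataset, num_repeats=3)
# >>> len(repeated)  # 3 * len(base_dataset)
# >>> repeated[0], repeated[1], repeated[2]  # all map to base_dataset[0] via idx // num_repeats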
@dataclass
class SchedulerArgs:
"""Class for scheduler arguments."""
sched: str
epochs: int
min_lr: float
warmup_lr: float
warmup_epochs: int
cooldown_epochs: int = 0
def scheduler_function_factory(
epochs, sched, warmup_epochs=0, lr=None, min_lr=0.0, warmup_sched=None, warmup_lr=None, offset=0, **kwargs
):
"""Create a scheduler factor function.
Args:
sched (str): the learning rate schedule type
epochs (int): length of the full schedule
warmup_epochs (int, optional): number of epochs reserved for warmup (Default value = 0)
lr (float, optional): base learning rate; required when warmup or min_lr are used (Default value = None)
min_lr (float, optional): minimum learning rate (Default value = 0.0)
warmup_sched (str, optional): the type of schedule during warmup (Default value = None)
warmup_lr (float, optional): (starting) learning rate during warmup (Default value = None)
offset (int, optional): offset for the schedule (to be the same as the timm scheduler) (Default value = 0)
**kwargs: unused
Returns:
function: scheduler function
"""
sched = sched.lower()
def warmup_f(ep):
return 1.0
if warmup_epochs > 0:
assert warmup_lr is not None, "Need warmup_lr, but got None"
warmup_lr_factor = warmup_lr / lr
if warmup_sched == "linear":
def warmup_f(ep):
return warmup_lr_factor + (1 - warmup_lr_factor) * max(ep, 0.0) / warmup_epochs
elif warmup_sched == "const":
def warmup_f(ep):
return warmup_lr_factor
else:
raise NotImplementedError(f"Warmup schedule {warmup_sched} not implemented")
epochs = epochs - warmup_epochs + offset
if sched == "cosine":
# cos from 0 to pi
def main_f(ep):
return cos(pi * ep / epochs) / 2 + 0.5
elif sched == "const":
def main_f(ep):
return 1.0
else:
raise NotImplementedError(f"Schedule {sched} is not implemented.")
# rescale and add min_lr
min_lr_fact = min_lr / lr
def main_f_with_min_lr(ep):
return (1 - min_lr_fact) * main_f(ep) + min_lr_fact
return lambda ep: (
warmup_f(ep + offset) if ep + offset < warmup_epochs else main_f_with_min_lr(ep + offset - warmup_epochs)
)
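# A minimal sketch (not part of the original file) of how the returned factor function could
# be plugged into torch.optim.lr_scheduler.LambdaLR, which multiplies the base lr by a
# per-epoch factor; `model` and the hyperparameter values below are placeholders.
# >>> optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
# >>> lr_f = scheduler_function_factory(
# ...     epochs=100, sched="cosine", warmup_epochs=5, warmup_sched="linear",
# ...     lr=1e-3, warmup_lr=1e-6, min_lr=1e-5,
# ... )
# >>> scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_f)
# >>> # call scheduler.step() once per epoch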
class DotDict(dict):
"""Extension of a Python dictionary to access its keys using dot notation."""
__setattr__ = dict.__setitem__
__delattr__ = dict.__delitem__
def __getattr__(self, item, default=None):
"""Get item from.
Args:
item: key
default (optional): default value. Defaults to None.
Returns:
value
"""
if item not in self:
return default
return self.get(item)
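# Illustrative behaviour (not part of the original file): keys become attributes, and
# missing keys fall back to None instead of raising AttributeError.
# >>> cfg = DotDict({"lr": 1e-3})
# >>> cfg.lr        # 0.001
# >>> cfg.momentum  # None
# >>> cfg.epochs = 100  # stored as cfg["epochs"]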
def prep_kwargs(kwargs):
"""Prepare the arguments and add defaults.
Args:
kwargs (dict[str, Any]): dict of kwargs
Returns:
DotDict: prepared kwargs
"""
if "defaults" not in kwargs:
kwargs["defaults"] = "DeiTIII"
defaults = get_default_kwargs(kwargs["defaults"])
for k, v in defaults.items():
if k not in kwargs or kwargs[k] is None:
kwargs[k] = v
if "results_folder" not in kwargs:
kwargs["results_folder"] = paths_config.results_folder
if kwargs["results_folder"].endswith("/"):
kwargs["results_folder"] = kwargs["results_folder"][:-1]
if "val_dataset" not in kwargs and "dataset" in kwargs:
kwargs["val_dataset"] = kwargs["dataset"]
return DotDict(kwargs)
def denormalize(x, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
"""Invert the normlize operation.
Args:
x (torch.Tensor): images to de-normalize
mean (tuple, optional): normalization mean. Defaults to (0.485, 0.456, 0.406).
std (tuple, optional): normalization std. Defaults to (0.229, 0.224, 0.225).
Returns:
torch.Tensor: de-normalized images
"""
operation = transforms.Normalize(
mean=[-mu / sigma for mu, sigma in zip(mean, std)], std=[1 / sigma for sigma in std]
)
return operation(x)
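# Illustrative round trip (not part of the original file): denormalize inverts the standard
# ImageNet normalization up to floating-point error.
# >>> norm = transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
# >>> img = torch.rand(3, 224, 224)
# >>> torch.allclose(denormalize(norm(img)), img, atol=1e-6)  # True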
def log_formatter(record):
"""Build the loguru format string, including run name, rank/world size, and epoch from record["extra"] when available."""
if "run_name" not in record["extra"]:
return (
"<g>{time:YYYY-MM-DD HH:mm:ss.SSS}</g> <c>|</c> <y>name TBD</y> > <y>?</y>/<y>?</y> <c>|</c> <level>{level:"
" <8}</level> <c>|</c> {message}\n"
)
epoch_str = "@ epoch <y>{extra[epoch]: >3}</y> " if "epoch" in record["extra"] else ""
code_loc_str = "<r>{name}</r>.<r>{function}</r>:<r>{line}</r> - " if record["level"].no >= 30 else ""
return (
"<g>{time:YYYY-MM-DD HH:mm:ss.SSS}</g> <c>|</c> <m>{extra[run_name]}</m> >"
" <m>{extra[rank]}</m>/<m>{extra[world_size]}</m> "
+ epoch_str
+ "<c>|</c> <level>{level: <8}</level> <c>|</c> "
+ code_loc_str
+ "{message}\n"
)
def ddp_setup(use_cuda=True):
"""Set up the distributed environment.
Args:
use_cuda (bool, optional): whether to set up CUDA devices (Default value = True)
Returns:
tuple: A tuple containing the following elements:
* bool: Whether the training is distributed.
* torch.device: The device to use for distributed training.
* int: The total number of processes in the distributed setup.
* int: The global rank of the current process in the distributed setup.
* int: The local rank of the current process on its node.
Notes:
The 'nccl' backend is used.
"""
logger.remove()
rank = int(os.getenv("RANK", 0))
local_rank = int(os.getenv("LOCAL_RANK", 0))
num_gpus = int(os.getenv("WORLD_SIZE", 1))
distributed = "RANK" in os.environ and num_gpus > 1
logger.add(sys.stderr, format=log_formatter, colorize=True, enqueue=True)
if distributed:
assert use_cuda, "Only use distributed mode with cuda."
try:
dist.init_process_group("nccl")
except ValueError as e:
logger.critical(f"Value error while setting up nccl process group: {e}")
logger.info(
f" CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')},"
f" SLURM_STEP_GPUS={os.environ.get('SLURM_STEP_GPUS')},"
f" GPU_DEVICE_ORDINAL={os.environ.get('GPU_DEVICE_ORDINAL')} for process:"
f" RANK={rank} (LOCAL_RANK={local_rank}) of WORLD_SIZE={num_gpus}. Shutting down now."
)
raise e
assert torch.cuda.is_available() or not use_cuda, "CUDA is not available"
assert (
len(str(os.environ.get("SLURM_STEP_GPUS")).split(","))
== len(str(os.environ.get("CUDA_VISIBLE_DEVICES")).split(","))
== len(str(os.environ.get("GPU_DEVICE_ORDINAL")).split(","))
== num_gpus
) or not use_cuda, (
f"SLURM GPU setup is incorrect: CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')},"
f" SLURM_STEP_GPUS={os.environ.get('SLURM_STEP_GPUS')},"
f" GPU_DEVICE_ORDINAL={os.environ.get('GPU_DEVICE_ORDINAL')} for process:"
f" RANK={rank} (LOCAL_RANK={local_rank}) of WORLD_SIZE={num_gpus}"
)
return distributed, torch.device(f"cuda:{local_rank}") if use_cuda else torch.device("cpu"), num_gpus, rank, local_rank
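# Illustrative call (not part of the original file); the returned values depend on the
# RANK / LOCAL_RANK / WORLD_SIZE environment variables set by the launcher (e.g. torchrun or SLURM).
# >>> distributed, device, world_size, rank, local_rank = ddp_setup(use_cuda=True)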
def ddp_cleanup(args, sync_old_wandb=False, rank=0):
"""Clean the distributed setup after use.
Args:
args (DotDict): arguments
sync_old_wandb (bool, optional): Whether to sync and remove wandb runs older than 100 hours (>4 days). Defaults to False.
rank (int, optional): The rank of the current process, so only one process syncs wandb. Defaults to 0.
"""
if sync_old_wandb and rank == 0:
os.system("wandb sync --clean --clean-old-hours 100 --clean-force")
if args.distributed:
logger.info("waiting for all processes to finish")
dist.barrier()
dist.destroy_process_group()
logger.info("exiting now")
exit(0)
def set_filter_warnings():
"""Filter out some warnings to reduce spam."""
# filter DataLoader number of workers warning
warnings.filterwarnings(
"ignore", ".*worker processes in total. Our suggested max number of worker in current system is.*"
)
# Filter datadings only varargs warning
warnings.filterwarnings("ignore", ".*only accepts varargs so.*")
# Filter warnings from calculation of MACs & FLOPs
# warnings.filterwarnings("ignore", ".*No handlers found:.*")
# Filter warnings from gather
warnings.filterwarnings("ignore", ".*is_namedtuple is deprecated, please use the python checks instead.*")
# Filter warnings from meshgrid
warnings.filterwarnings("ignore", ".*in an upcoming release, it will be required to pass the indexing.*")
# Filter warnings from timm when overwriting models
warnings.filterwarnings("ignore", ".*UserWarning: Overwriting .*")
def remove_prefix(state_dict, prefix="module."):
"""Remove a prefix from the keys in a state dictionary.
Args:
state_dict (dict[str, Any]): The state dictionary to remove the prefix from.
prefix (str, optional): The prefix to remove from the keys. Default is 'module.'.
Returns:
dict[str, Any]: A new dictionary with the prefix removed from the keys.
Examples:
>>> state_dict = {'module.layer1.weight': 1, 'module.layer1.bias': 2}
>>> remove_prefix(state_dict)
{'layer1.weight': 1, 'layer1.bias': 2}
"""
return {k[len(prefix) :] if k.startswith(prefix) else k: v for k, v in state_dict.items()}
def prime_factors(n):
"""Calculate the prime factors of a given integer.
Args:
n (int): The integer to find the prime factors of.
Returns:
list[int]: The prime factors of n.
"""
i = 2
factors = []
while i * i <= n:
if n % i:
i += 1
else:
n //= i
factors.append(i)
if n > 1:
factors.append(n)
return factors
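# Illustrative examples (not part of the original file):
# >>> prime_factors(360)  # [2, 2, 2, 3, 3, 5]
# >>> prime_factors(97)   # [97] (97 is prime)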
def linear_regession(points):
"""Calculate a linear interpolation of the points.
Args:
points (dict[float, float]): points to interpolate in the format points[x] = y
Returns:
function: The fitted line, as a function z -> a * z + b.
"""
N = len(points)
x = []
y = []
for x_i, y_i in points.items():
x.append(x_i)
y.append(y_i)
x = np.array(x)
y = np.array(y)
a = (N * (x * y).sum() - x.sum() * y.sum()) / (N * (x * x).sum() - x.sum() ** 2)
b = (y.sum() - a * x.sum()) / N
return lambda z: a * z + b
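# Illustrative example (not part of the original file): for collinear points the
# least-squares fit recovers the exact line y = 2x + 1.
# >>> line = linear_regession({0.0: 1.0, 1.0: 3.0, 2.0: 5.0})
# >>> line(4.0)  # 9.0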
def save_model_state(
model_folder,
epoch,
args,
model_state,
regular_save=True,
stats=None,
val_accs=None,
epoch_accs=None,
additional_reason="",
max_interm_ep_states=2,
**kwargs,
):
"""Save the model state.
Args:
model_folder: Folder to save model in
epoch: current epoch
args: arguments to guide and save
model_state: state of the model
regular_save: Whether this save also counts as a regular interval save; if True and additional_reason is given, the state is also copied to the regular epoch file. (Default value = True)
stats: model stats to save (Default value = None)
val_accs: model accuracy to save (Default value = None)
epoch_accs: training accuracy to save (Default value = None)
additional_reason: save reason; in case it's not just a regular save interval. Would be "top" or "final", for example. (Default value = "")
max_interm_ep_states: Number of regular epoch states to keep (Default value = 2)
**kwargs: Further arguments to save
"""
# make args dict, not DotDict to be able to save it
state = {"epoch": epoch, "model_state": model_state, "run_name": args.run_name, "args": dict(args)}
if stats is None:
stats = {}
if val_accs is not None:
stats = {**stats, **val_accs}
if epoch_accs is not None:
stats = {**stats, **epoch_accs}
state["stats"] = stats
state = {**state, **kwargs}
logger.info(f"saving model state at epoch {epoch} ({additional_reason})")
regular_file_name = f"ep_{epoch}.pt"
save_name = additional_reason + ".pt" if len(additional_reason) > 0 else regular_file_name
outfile = os.path.join(model_folder, save_name)
torch.save(state, outfile)
if len(additional_reason) > 0 and regular_save:
shutil.copyfile(outfile, os.path.join(model_folder, regular_file_name))
# remove intermediate epoch states (all but the last max_interm_ep_states)
if max_interm_ep_states > 0:
epoch_states = [f for f in os.listdir(model_folder) if f.startswith("ep_") and f.endswith(".pt")]
epoch_states = sorted(epoch_states, key=lambda x: int(x.split("_")[1].split(".")[0]))
if len(epoch_states) > max_interm_ep_states:
for f in epoch_states[:-max_interm_ep_states]:
os.remove(os.path.join(model_folder, f))
logger.debug(f"removed intermediate epoch state {f}")
def log_args(args, rank=0):
"""Log the full set of arguments as sorted JSON (only on rank 0)."""
if rank == 0:
logger.info("full set of arguments: " + json.dumps(dict(args), sort_keys=True))
# keys = sorted(list(args.keys()))
# for key in keys:
# logger.info(f"arg: {key} = {args[key]}")
class ScalerGradNormReturn(NativeScaler):
"""A wrapper around PyTorch's NativeScaler that returns the gradient norm."""
def __str__(self) -> str:
return f"{type(self).__name__}(_scaler: {self._scaler})"
def __call__(self, loss, optimizer, clip_grad=None, clip_mode="norm", parameters=None, create_graph=False):
"""Scale and backpropagate through the loss tensor, and return the gradient norm of the selected parameters.
Does an optimizer step.
Args:
loss (torch.Tensor): The loss tensor to scale and backpropagate through.
optimizer (torch.optim.Optimizer): The optimizer to use for the optimization step.
clip_grad (float, optional): The maximum allowed norm of the gradients. If None, no clipping is performed.
clip_mode (str, optional): The mode used for clipping the gradients. Only used if `clip_grad` is not None. Possible values are 'norm'
(clipping the norm of the gradients) and 'value' (clipping the value of the gradients). (default='norm')
parameters (iterable[torch.nn.Parameter], optional): The parameters to compute the gradient norm for. If None, the gradient norm is not computed.
create_graph (bool, optional): Whether to create a computation graph for computing second-order gradients. (default=False)
Returns:
float: The gradient norm of the selected parameters, or -1 if parameters is None.
"""
self._scaler.scale(loss).backward(create_graph=create_graph)
# always unscale the gradients, since it's being done anyway
self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
if parameters is not None:
grads = [p.grad for p in parameters if p.grad is not None]
device = grads[0].device
grad_norm = torch.norm(torch.stack([torch.norm(g.detach(), 2).to(device) for g in grads]), 2)
else:
grad_norm = -1
if clip_grad is not None:
dispatch_clip_grad(parameters, clip_grad, mode=clip_mode)
self._scaler.step(optimizer)
self._scaler.update()
return grad_norm
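# A hedged sketch (not part of the original file) of a typical mixed-precision training step
# with this scaler; `model`, `images`, `targets`, and `criterion` are placeholders.
# >>> loss_scaler = ScalerGradNormReturn()
# >>> with torch.cuda.amp.autocast():
# ...     loss = criterion(model(images), targets)
# >>> optimizer.zero_grad()
# >>> grad_norm = loss_scaler(
# ...     loss, optimizer, clip_grad=1.0, clip_mode="norm", parameters=model.parameters()
# ... )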
class NoScaler:
"""Dummy gradient scaler that doesn't scale gradients.
This scaler performs a simple backward pass with the given loss, and then updates the model's parameters
with the given optimizer. The resulting gradient norm is computed and returned.
"""
def __str__(self) -> str:
return f"{type(self).__name__}()"
def __call__(self, loss, optimizer, parameters=None, **kwargs):
"""Perform backward pass with the given loss, updates the model's parameters with the given optimizer, and computes the resulting gradient norm.
Args:
loss (torch.Tensor): The loss tensor that the gradients will be computed from.
optimizer (torch.optim.Optimizer): The optimizer that will be used to update the model's parameters.
parameters (iterable[torch.Tensor], optional): An iterable of model parameters to compute gradients. If None, returns -1.
**kwargs: Additional keyword arguments; nothing will be done with these.
Returns:
float: The gradient norm of the given parameters (computed right after the backward pass), or -1 if parameters is None.
"""
loss.backward()
if parameters is not None:
grads = [p.grad for p in parameters if p.grad is not None]
device = grads[0].device
grad_norm = torch.norm(torch.stack([torch.norm(g.detach(), 2).to(device) for g in grads]), 2)
else:
grad_norm = -1
optimizer.step()
return grad_norm
def get_cpu_name():
"""Get the name of the CPU."""
with open("/proc/cpuinfo", "r") as f:
for line in f:
if line.startswith("model name"):
return line.split(":")[1].strip()
return "unknown"
# From PyTorch internals
def _ntuple(n):
"""Make a function to create n-tuples.
Args:
n (int): tuple length
Returns:
function: function to create n-tuples
"""
def parse(x):
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
return tuple(x)
return tuple(repeat(x, n))
return parse
to_2tuple = _ntuple(2)
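# Illustrative behaviour (not part of the original file): scalars are repeated, while
# non-string iterables are passed through as tuples.
# >>> to_2tuple(7)       # (7, 7)
# >>> to_2tuple((3, 5))  # (3, 5)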
def make_divisible(v, divisor=8, min_value=None, round_limit=0.9):
"""Calculate the smallest number >= v that is divisible by divisor.
This function is primarily used to ensure that the output of a layer
is divisible by a certain number, typically to align with hardware
optimizations or memory layouts.
Args:
v (int): The input value.
divisor (int, optional): The divisor. Defaults to 8.
min_value (int, optional): The minimum value to return. If None, defaults to the divisor.
round_limit (float, optional): A threshold for rounding down. If the result of rounding down is less than round_limit * v, the next multiple of the divisor is returned instead. Defaults to 0.9.
Returns:
int: A multiple of divisor near v, at least min_value and at least round_limit * v.
"""
min_value = min_value or divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
# Make sure that rounding down does not reduce v by more than (1 - round_limit), i.e. 10% with the default round_limit.
if new_v < round_limit * v:
new_v += divisor
return new_v
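# Illustrative values (not part of the original file), using the default divisor of 8:
# >>> make_divisible(10)  # 16 (8 would drop below round_limit * 10, so bump up by one divisor)
# >>> make_divisible(20)  # 24 (nearest multiple of 8; ties round up)
# >>> make_divisible(33)  # 32 (rounding down by ~3% stays above round_limit * 33)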
def grad_cam_reshape_transform(tensor):
"""Transform the tensor for Grad-CAM calculation.
Args:
tensor (torch.Tensor): input tensor
Returns:
torch.Tensor: reshaped tensor in (batch, channels, height, width) layout, with the [CLS] token removed if present.
"""
n_squ = tensor.shape[1]
result = tensor[:, 1:] if int(sqrt(n_squ)) ** 2 != n_squ else tensor
bs, n, dim = result.shape
result = result.reshape(bs, int(sqrt(n)), int(sqrt(n)), dim)
# Bring the channels to the first dimension,
# like in CNNs.
return result.transpose(2, 3).transpose(1, 2)
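# Illustrative shape transformation (not part of the original file), for a ViT with a
# 14x14 patch grid plus a [CLS] token:
# >>> x = torch.randn(2, 197, 768)
# >>> grad_cam_reshape_transform(x).shape  # torch.Size([2, 768, 14, 14])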