"""Utils and small helper functions.""" import collections.abc import json import os import shutil import sys import warnings from dataclasses import dataclass from itertools import repeat from math import cos, pi, sqrt import numpy as np import torch import torch.distributed as dist from loguru import logger from timm.data import Mixup from timm.utils import NativeScaler, dispatch_clip_grad from torch.nn.modules.loss import _WeightedLoss from torch.utils.data import Dataset from torchvision.transforms import transforms import paths_config from config import default_kwargs, get_default_kwargs # noqa: F401 # is used in prep_kwargs class RepeatedDataset(Dataset): """Dataset that repeats the given dataset a number of times.""" def __init__(self, dataset, num_repeats): """Create repeated dataset. Args: dataset (Dataset): dataset to repeat. num_repeats (int): number of repeats. """ self.dataset = dataset self.num_repeats = num_repeats def __getitem__(self, idx): return self.dataset[idx // self.num_repeats] def __len__(self): return len(self.dataset) * self.num_repeats @dataclass class SchedulerArgs: """Class for scheduler arguments.""" sched: str epochs: int min_lr: float warmup_lr: float warmup_epochs: int cooldown_epochs: int = 0 def scheduler_function_factory( epochs, sched, warmup_epochs=0, lr=None, min_lr=0.0, warmup_sched=None, warmup_lr=None, offset=0, **kwargs ): """Create a scheduler factor function. Args: sched (str): the learning rate schedule type epochs (int): length of the full schedule warmup_epochs (int, optional): number of epochs reserved for warmup (Default value = 0) lr (float, optional): learning rate (has to be given, when warmup or min_lr are set) (Default value = None) min_lr (float, optional): minimum learning rate (Default value = 0.0) warmup_sched (str, optional): the type of schedule during warmup (Default value = None) warmup_lr (float, optional): (starting) learning rate during warmup (Default value = None) offset (int, optional): offset for the schedule (to be the same as the timm scheduler) (Default value = 0) **kwargs: unused Returns: function: scheduler function """ sched = sched.lower() def warmup_f(ep): return 1.0 if warmup_epochs > 0: assert warmup_lr is not None, "Need warmup_lr, but got None" warmup_lr_factor = warmup_lr / lr if warmup_sched == "linear": def warmup_f(ep): return warmup_lr_factor + (1 - warmup_lr_factor) * max(ep, 0.0) / warmup_epochs elif warmup_sched == "const": def warmup_f(ep): return warmup_lr_factor else: raise NotImplementedError(f"Warmup schedule {warmup_sched} not implemented") epochs = epochs - warmup_epochs + offset if sched == "cosine": # cos from 0 to pi def main_f(ep): return cos(pi * ep / epochs) / 2 + 0.5 elif sched == "const": def main_f(ep): return 1.0 else: raise NotImplementedError(f"Schedule {sched} is not implemented.") # rescale and add min_lr min_lr_fact = min_lr / lr def main_f_with_min_lr(ep): return (1 - min_lr_fact) * main_f(ep) + min_lr_fact return lambda ep: ( warmup_f(ep + offset) if ep + offset < warmup_epochs else main_f_with_min_lr(ep + offset - warmup_epochs) ) class DotDict(dict): """Extension of a Python dictionary to access its keys using dot notation.""" __setattr__ = dict.__setitem__ __delattr__ = dict.__delitem__ def __getattr__(self, item, default=None): """Get item from. Args: item: key default (optional): default value. Defaults to None. Returns: value """ if item not in self: return default return self.get(item) def prep_kwargs(kwargs): """Prepare the arguments and add defaults. Args: kwargs (dict[str, Any]): dict of kwargs Returns: DotDict: prepared kwargs """ if "defaults" not in kwargs: kwargs["defaults"] = "DeiTIII" defaults = get_default_kwargs(kwargs["defaults"]) for k, v in defaults.items(): if k not in kwargs or kwargs[k] is None: kwargs[k] = v if "results_folder" not in kwargs: kwargs[var_name] = paths_config.results_folder # globals()[var_name] if kwargs["results_folder"].endswith("/"): kwargs["results_folder"] = kwargs["results_folder"][:-1] if "val_dataset" not in kwargs and "dataset" in kwargs: kwargs["val_dataset"] = kwargs["dataset"] return DotDict(kwargs) def denormalize(x, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)): """Invert the normlize operation. Args: x (torch.Tensor): images to de-normalize mean (tuple, optional): normalization mean. Defaults to (0.485, 0.456, 0.406). std (tuple, optional): normalization std. Defaults to (0.229, 0.224, 0.225). Returns: torch.Tensor: de-normalized images """ operation = transforms.Normalize( mean=[-mu / sigma for mu, sigma in zip(mean, std)], std=[1 / sigma for sigma in std] ) return operation(x) def log_formatter(record): if "run_name" not in record["extra"]: return ( "{time:YYYY-MM-DD HH:mm:ss.SSS} | name TBD > ?/? | {level:" " <8} | {message}\n" ) epoch_str = "@ epoch {extra[epoch]: >3} " if "epoch" in record["extra"] else "" code_loc_str = "{name}.{function}:{line} - " if record["level"].no >= 30 else "" return ( "{time:YYYY-MM-DD HH:mm:ss.SSS} | {extra[run_name]} >" " {extra[rank]}/{extra[world_size]} " + epoch_str + "| {level: <8} | " + code_loc_str + "{message}\n" ) def ddp_setup(use_cuda=True): """Set up the distributed environment. Args: use_cuda: (Default value = True) Returns: tuple: A tuple containing the following elements: * bool: Whether the training is distributed. * torch.device: The device to use for distributed training. * int: The total number of processes in the distributed setup. * int: The global rank of the current process in the distributed setup. * int: The local rank of the current process on its node. Notes: The 'nccl' backend is used. """ logger.remove() rank = int(os.getenv("RANK", 0)) local_rank = int(os.getenv("LOCAL_RANK", 0)) num_gpus = int(os.getenv("WORLD_SIZE", 1)) distributed = "RANK" in os.environ and num_gpus > 1 logger.add(sys.stderr, format=log_formatter, colorize=True, enqueue=True) if distributed: assert use_cuda, "Only use distributed mode with cuda." try: dist.init_process_group("nccl") except ValueError as e: logger.critical(f"Value error while setting up nccl process group: {e}") logger.info( f" CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')}," f" SLURM_STEP_GPUS={os.environ.get('SLURM_STEP_GPUS')}," f" GPU_DEVICE_ORDINAL={os.environ.get('GPU_DEVICE_ORDINAL')} for process:" f" RANK={rank} (LOCAL_RANK={local_rank}) of WORLD_SIZE={num_gpus}. Shutting down now." ) raise e assert torch.cuda.is_available() or not use_cuda, "CUDA is not available" assert ( len(str(os.environ.get("SLURM_STEP_GPUS")).split(",")) == len(str(os.environ.get("CUDA_VISIBLE_DEVICES")).split(",")) == len(str(os.environ.get("GPU_DEVICE_ORDINAL")).split(",")) == num_gpus ) or not use_cuda, ( f"SLURM GPU setup is incorrect: CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')}," f" SLURM_STEP_GPUS={os.environ.get('SLURM_STEP_GPUS')}," f" GPU_DEVICE_ORDINAL={os.environ.get('GPU_DEVICE_ORDINAL')} for process:" f" RANK={rank} (LOCAL_RANK={local_rank}) of WORLD_SIZE={num_gpus}" ) return distributed, torch.device(f"cuda:{local_rank}") if use_cuda else "cpu", num_gpus, rank, local_rank def ddp_cleanup(args, sync_old_wandb=False, rank=0): """Clean the distributed setup after use. Args: args (DotDict): arguments sync_old_wandb (bool, optional): Whether to sync and remove wandb runs older than 100 hours (>3 days). Defaults to False. rank (int, optional): The rank of the current process, so only one process syncs wandb. Defaults to 0. """ if sync_old_wandb and rank == 0: os.system("wandb sync --clean --clean-old-hours 100 --clean-force") if args.distributed: logger.info("waiting for all processes to finish") dist.barrier() dist.destroy_process_group() logger.info("exiting now") exit(0) def set_filter_warnings(): """Filter out some warnings to reduce spam.""" # filter DataLoader number of workers warning warnings.filterwarnings( "ignore", ".*worker processes in total. Our suggested max number of worker in current system is.*" ) # Filter datadings only varargs warning warnings.filterwarnings("ignore", ".*only accepts varargs so.*") # Filter warnings from calculation of MACs & FLOPs # warnings.filterwarnings("ignore", ".*No handlers found:.*") # Filter warnings from gather warnings.filterwarnings("ignore", ".*is_namedtuple is deprecated, please use the python checks instead.*") # Filter warnings from meshgrid warnings.filterwarnings("ignore", ".*in an upcoming release, it will be required to pass the indexing.*") # Filter warnings from timm when overwriting models warnings.filterwarnings("ignore", ".*UserWarning: Overwriting .*") def remove_prefix(state_dict, prefix="module."): """Remove a prefix from the keys in a state dictionary. Args: state_dict (dict[str, Any]): The state dictionary to remove the prefix from. prefix (str, optional): The prefix to remove from the keys. Default is 'module.'. Returns: dict[str, Any]: A new dictionary with the prefix removed from the keys. Examples: >>> state_dict = {'module.layer1.weight': 1, 'module.layer1.bias': 2} >>> remove_prefix(state_dict) {'layer1.weight': 1, 'layer1.bias': 2} """ return {k[len(prefix) :] if k.startswith(prefix) else k: v for k, v in state_dict.items()} def prime_factors(n): """Calculate the prime factors of a given integer. Args: n (int): The integer to find the prime factors of. Returns: list[int]: The prime factors of n. """ i = 2 factors = [] while i * i <= n: if n % i: i += 1 else: n //= i factors.append(i) if n > 1: factors.append(n) return factors def linear_regession(points): """Calculate a linear interpolation of the points. Args: points (dict[float, float]): points to interpolate in the format points[x] = y Returns: function: A function that interpolates the points. """ N = len(points) x = [] y = [] for x_i, y_i in points.items(): x.append(x_i) y.append(y_i) x = np.array(x) y = np.array(y) a = (N * (x * y).sum() - x.sum() * y.sum()) / (N * (x * x).sum() - x.sum() ** 2) b = (y.sum() - a * x.sum()) / N return lambda z: a * z + b def save_model_state( model_folder, epoch, args, model_state, regular_save=True, stats=None, val_accs=None, epoch_accs=None, additional_reason="", max_interm_ep_states=2, **kwargs, ): """Save the model state. Args: model_folder: Folder to save model in epoch: current epoch args: arguments to guide and save model_state: state of the model regular_save: Is this a regular or a special save? (Default value = True) stats: model stats to save (Default value = None) val_accs: model accuracy to save (Default value = None) epoch_accs: training accuracy to save (Default value = None) additional_reason: save reason; in case it's not just a regular save interval. Would be "top" or "final", for example. (Default value = "") max_interm_ep_states: Number of regular epoch states to keep (Default value = 2) **kwargs: Further arguments to save """ # make args dict, not DotDict to be able to save it state = {"epoch": epoch, "model_state": model_state, "run_name": args.run_name, "args": dict(args)} if stats is None: stats = {} if val_accs is not None: stats = {**stats, **val_accs} if epoch_accs is not None: stats = {**stats, **epoch_accs} state["stats"] = stats state = {**state, **kwargs} logger.info(f"saving model state at epoch {epoch} ({additional_reason})") regular_file_name = f"ep_{epoch}.pt" save_name = additional_reason + ".pt" if len(additional_reason) > 0 else regular_file_name outfile = os.path.join(model_folder, save_name) torch.save(state, outfile) if len(additional_reason) > 0 and regular_save: shutil.copyfile(outfile, os.path.join(model_folder, regular_file_name)) # remove intermediate epoch states (all but the last max_interm_ep_states) if max_interm_ep_states > 0: epoch_states = [f for f in os.listdir(model_folder) if f.startswith("ep_") and f.endswith(".pt")] epoch_states = sorted(epoch_states, key=lambda x: int(x.split("_")[1].split(".")[0])) if len(epoch_states) > max_interm_ep_states: for f in epoch_states[:-max_interm_ep_states]: os.remove(os.path.join(model_folder, f)) logger.debug(f"removed intermediate epoch state {f}") def log_args(args, rank=0): if rank == 0: logger.info("full set of arguments: " + json.dumps(dict(args), sort_keys=True)) # keys = sorted(list(args.keys())) # for key in keys: # logger.info(f"arg: {key} = {args[key]}") class ScalerGradNormReturn(NativeScaler): """A wrapper around PyTorch's NativeScaler that returns the gradient norm.""" def __str__(self) -> str: return f"{type(self).__name__}(_scaler: {self._scaler})" def __call__(self, loss, optimizer, clip_grad=None, clip_mode="norm", parameters=None, create_graph=False): """Scale and backpropagate through the loss tensor, and return the gradient norm of the selected parameters. Does an optimizer step. Args: loss (torch.Tensor): The loss tensor to scale and backpropagate through. optimizer (torch.optim.Optimizer): The optimizer to use for the optimization step. clip_grad (float, optional): The maximum allowed norm of the gradients. If None, no clipping is performed. clip_mode (str, optional): The mode used for clipping the gradients. Only used if `clip_grad` is not None. Possible values are 'norm' (clipping the norm of the gradients) and 'value' (clipping the value of the gradients). (default='norm') parameters (iterable[torch.nn.Parameter], optional): The parameters to compute the gradient norm for. If None, the gradient norm is not computed. create_graph (bool, optional): Whether to create a computation graph for computing second-order gradients. (default=False) Returns: float: The gradient norm of the selected parameters. """ self._scaler.scale(loss).backward(create_graph=create_graph) # always unscale the gradients, since it's being done anyway self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place if parameters is not None: grads = [p.grad for p in parameters if p.grad is not None] device = grads[0].device grad_norm = torch.norm(torch.stack([torch.norm(g.detach(), 2).to(device) for g in grads]), 2) else: grad_norm = -1 if clip_grad is not None: dispatch_clip_grad(parameters, clip_grad, mode=clip_mode) self._scaler.step(optimizer) self._scaler.update() return grad_norm class NoScaler: """Dummy gradient scaler that doesn't scale gradients. This scaler performs a simple backward pass with the given loss, and then updates the model's parameters with the given optimizer. The resulting gradient norm is computed and returned. """ def __str__(self) -> str: return f"{type(self).__name__}()" def __call__(self, loss, optimizer, parameters=None, **kwargs): """Perform backward pass with the given loss, updates the model's parameters with the given optimizer, and computes the resulting gradient norm. Args: loss (torch.Tensor): The loss tensor that the gradients will be computed from. optimizer (torch.optim.Optimizer): The optimizer that will be used to update the model's parameters. parameters (iterable[torch.Tensor], optional): An iterable of model parameters to compute gradients. If None, returns -1. **kwargs: Additional keyword arguments; nothing will be done with these. Returns: float: The gradient norm computed after the optimizer step, if parameters is not None. """ loss.backward() if parameters is not None: grads = [p.grad for p in parameters if p.grad is not None] device = grads[0].device grad_norm = torch.norm(torch.stack([torch.norm(g.detach(), 2).to(device) for g in grads]), 2) else: grad_norm = -1 optimizer.step() return grad_norm def get_cpu_name(): """Get the name of the CPU.""" with open("/proc/cpuinfo", "r") as f: for line in f: if line.startswith("model name"): return line.split(":")[1].strip() return "unknown" # From PyTorch internals def _ntuple(n): """Make a function to create n-tuples. Args: n (int): tuple length Returns: function: function to create n-tuples """ def parse(x): if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): return tuple(x) return tuple(repeat(x, n)) return parse to_2tuple = _ntuple(2) def make_divisible(v, divisor=8, min_value=None, round_limit=0.9): """Calculate the smallest number >= v that is divisible by divisor. This function is primarily used to ensure that the output of a layer is divisible by a certain number, typically to align with hardware optimizations or memory layouts. Args: v (int): The input value. divisor (int, optional): The divisor. Defaults to 8. min_value (int, optional): The minimum value to return. If None, defaults to the divisor. round_limit (float, optional): A threshold for rounding down. If the result of rounding down is less than round_limit * v, the next multiple of the divisor is returned instead. Defaults to 0.9. Returns: int: The smallest number >= v that is divisible by divisor. """ min_value = min_value or divisor new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) # Make sure that round down does not go down by more than 10%. if new_v < round_limit * v: new_v += divisor return new_v def grad_cam_reshape_transform(tensor): """Transform the tensor for Grad-CAM calculation. Args: tensor (torch.Tensor): input tensor Returns: torch.Tensor: reshaped tensor without [CLS] token. """ n_squ = tensor.shape[1] result = tensor[:, 1:] if int(sqrt(n_squ)) ** 2 != n_squ else tensor bs, n, dim = result.shape result = result.reshape(bs, int(sqrt(n)), int(sqrt(n)), dim) # Bring the channels to the first dimension, # like in CNNs. return result.transpose(2, 3).transpose(1, 2)