"""Utils and small helper functions."""

import collections.abc
import json
import os
import shutil
import sys
import warnings
from dataclasses import dataclass
from itertools import repeat
from math import cos, pi, sqrt

import numpy as np
import torch
import torch.distributed as dist
from loguru import logger
from timm.data import Mixup
from timm.utils import NativeScaler, dispatch_clip_grad
from torch.nn.modules.loss import _WeightedLoss
from torch.utils.data import Dataset
from torchvision.transforms import transforms

import paths_config
from config import default_kwargs, get_default_kwargs  # noqa: F401 # is used in prep_kwargs

class RepeatedDataset(Dataset):
    """Wrap a dataset so that every sample appears ``num_repeats`` times in a row."""

    def __init__(self, dataset, num_repeats):
        """Create repeated dataset.

        Args:
            dataset (Dataset): dataset to repeat.
            num_repeats (int): number of repeats.

        """
        self.dataset = dataset
        self.num_repeats = num_repeats

    def __len__(self):
        return self.num_repeats * len(self.dataset)

    def __getitem__(self, idx):
        # consecutive indices map to the same underlying sample
        original_index = idx // self.num_repeats
        return self.dataset[original_index]
@dataclass
class SchedulerArgs:
    """Class for scheduler arguments."""

    # name of the learning-rate schedule (e.g. "cosine" or "const")
    sched: str
    # total number of epochs covered by the schedule
    epochs: int
    # lower bound for the learning rate at the end of the schedule
    min_lr: float
    # (starting) learning rate used during the warmup phase
    warmup_lr: float
    # number of epochs reserved for warmup
    warmup_epochs: int
    # extra epochs after the main schedule ends — presumably consumed by a timm
    # scheduler (mirrors timm's cooldown_epochs argument); TODO confirm with callers
    cooldown_epochs: int = 0
def scheduler_function_factory(
    epochs, sched, warmup_epochs=0, lr=None, min_lr=0.0, warmup_sched=None, warmup_lr=None, offset=0, **kwargs
):
    """Create a scheduler factor function.

    The returned function maps an epoch index to a multiplicative learning
    rate factor (1.0 means "use the base lr").

    Args:
        sched (str): the learning rate schedule type ("cosine" or "const")
        epochs (int): length of the full schedule
        warmup_epochs (int, optional): number of epochs reserved for warmup (Default value = 0)
        lr (float, optional): learning rate (has to be given, when warmup or min_lr are set) (Default value = None)
        min_lr (float, optional): minimum learning rate (Default value = 0.0)
        warmup_sched (str, optional): the type of schedule during warmup ("linear" or "const") (Default value = None)
        warmup_lr (float, optional): (starting) learning rate during warmup (Default value = None)
        offset (int, optional): offset for the schedule (to be the same as the timm scheduler) (Default value = 0)
        **kwargs: unused

    Returns:
        function: scheduler function mapping epoch -> lr factor

    Raises:
        NotImplementedError: for unknown ``sched`` or ``warmup_sched`` values.

    """
    sched = sched.lower()

    def warmup_f(ep):
        return 1.0

    if warmup_epochs > 0:
        assert warmup_lr is not None, "Need warmup_lr, but got None"
        assert lr is not None, "Need lr for warmup, but got None"
        warmup_lr_factor = warmup_lr / lr
        if warmup_sched == "linear":

            def warmup_f(ep):
                return warmup_lr_factor + (1 - warmup_lr_factor) * max(ep, 0.0) / warmup_epochs

        elif warmup_sched == "const":

            def warmup_f(ep):
                return warmup_lr_factor

        else:
            raise NotImplementedError(f"Warmup schedule {warmup_sched} not implemented")

    epochs = epochs - warmup_epochs + offset
    if sched == "cosine":
        # half cosine wave: factor 1.0 at ep=0 down to 0.0 at ep=epochs
        def main_f(ep):
            return cos(pi * ep / epochs) / 2 + 0.5

    elif sched == "const":

        def main_f(ep):
            return 1.0

    else:
        raise NotImplementedError(f"Schedule {sched} is not implemented.")

    # rescale so the factor never drops below min_lr / lr.
    # BUGFIX: the original computed min_lr / lr unconditionally, which raised a
    # TypeError for lr=None even when min_lr was left at its 0.0 default,
    # contradicting the documented contract above.
    if min_lr:
        assert lr is not None, "Need lr when min_lr is set, but got None"
        min_lr_fact = min_lr / lr
    else:
        min_lr_fact = 0.0

    def main_f_with_min_lr(ep):
        return (1 - min_lr_fact) * main_f(ep) + min_lr_fact

    return lambda ep: (
        warmup_f(ep + offset) if ep + offset < warmup_epochs else main_f_with_min_lr(ep + offset - warmup_epochs)
    )
class DotDict(dict):
    """Dictionary whose entries can also be read, written, and deleted via dot notation."""

    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

    def __getattr__(self, item, default=None):
        """Look up ``item`` as a key, falling back to ``default``.

        Args:
            item: key
            default (optional): value returned when the key is missing. Defaults to None.

        Returns:
            the stored value, or ``default`` when the key is absent

        """
        # dict.get already covers the missing-key case
        return self.get(item, default)
def prep_kwargs(kwargs):
    """Prepare the arguments and add defaults.

    Args:
        kwargs (dict[str, Any]): dict of kwargs

    Returns:
        DotDict: prepared kwargs with defaults filled in

    """
    # fill in the chosen default preset for every argument that is missing or None
    if "defaults" not in kwargs:
        kwargs["defaults"] = "DeiTIII"
    defaults = get_default_kwargs(kwargs["defaults"])
    for k, v in defaults.items():
        if k not in kwargs or kwargs[k] is None:
            kwargs[k] = v

    if "results_folder" not in kwargs:
        # BUGFIX: the original assigned to kwargs[var_name] with ``var_name``
        # undefined, raising a NameError whenever results_folder was not given.
        kwargs["results_folder"] = paths_config.results_folder

    # normalize: strip a single trailing slash
    if kwargs["results_folder"].endswith("/"):
        kwargs["results_folder"] = kwargs["results_folder"][:-1]

    # validate on the training set by default
    if "val_dataset" not in kwargs and "dataset" in kwargs:
        kwargs["val_dataset"] = kwargs["dataset"]

    return DotDict(kwargs)
def denormalize(x, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Invert the normalize operation.

    Computes ``x * std + mean`` per channel, the exact inverse of
    ``transforms.Normalize(mean, std)``.

    Args:
        x (torch.Tensor): images to de-normalize, shaped (..., C, H, W)
        mean (tuple, optional): normalization mean. Defaults to (0.485, 0.456, 0.406).
        std (tuple, optional): normalization std. Defaults to (0.229, 0.224, 0.225).

    Returns:
        torch.Tensor: de-normalized images

    """
    # Computed directly in torch instead of building a torchvision Normalize
    # transform object on every call. The math is identical:
    # Normalize(-mu/sigma, 1/sigma)(x) == (x + mu/sigma) * sigma == x * sigma + mu
    std_t = torch.as_tensor(std, dtype=x.dtype, device=x.device).view(-1, 1, 1)
    mean_t = torch.as_tensor(mean, dtype=x.dtype, device=x.device).view(-1, 1, 1)
    return x * std_t + mean_t
def log_formatter(record):
    """Build a loguru format string, adapted to what ``record["extra"]`` provides."""
    extra = record["extra"]

    # before run_name/rank/world_size have been bound, fall back to placeholders
    if "run_name" not in extra:
        return (
            "<g>{time:YYYY-MM-DD HH:mm:ss.SSS}</g> <c>|</c> <y>name TBD</y> > <y>?</y>/<y>?</y> <c>|</c> <level>{level:"
            " <8}</level> <c>|</c> {message}\n"
        )

    parts = [
        "<g>{time:YYYY-MM-DD HH:mm:ss.SSS}</g> <c>|</c> <m>{extra[run_name]}</m> >"
        " <m>{extra[rank]}</m>/<m>{extra[world_size]}</m> "
    ]
    if "epoch" in extra:
        parts.append("@ epoch <y>{extra[epoch]: >3}</y> ")
    parts.append("<c>|</c> <level>{level: <8}</level> <c>|</c> ")
    # include the code location for WARNING (level 30) and above
    if record["level"].no >= 30:
        parts.append("<r>{name}</r>.<r>{function}</r>:<r>{line}</r> - ")
    parts.append("{message}\n")
    return "".join(parts)
def ddp_setup(use_cuda=True):
    """Set up the distributed environment.

    Reads RANK / LOCAL_RANK / WORLD_SIZE from the environment (typically
    exported by a launcher such as torchrun or SLURM) and initializes the
    process group when more than one process is present.

    Args:
        use_cuda: whether training runs on GPUs (Default value = True)

    Returns:
        tuple: A tuple containing the following elements:
            * bool: Whether the training is distributed.
            * torch.device: The device to use for distributed training.
            * int: The total number of processes in the distributed setup.
            * int: The global rank of the current process in the distributed setup.
            * int: The local rank of the current process on its node.

    Notes:
        The 'nccl' backend is used.

    """
    # replace loguru's default sink with one that uses the custom formatter
    logger.remove()
    rank = int(os.getenv("RANK", 0))
    local_rank = int(os.getenv("LOCAL_RANK", 0))
    num_gpus = int(os.getenv("WORLD_SIZE", 1))
    # only treat the run as distributed when a launcher exported RANK
    # and there is more than one process
    distributed = "RANK" in os.environ and num_gpus > 1
    logger.add(sys.stderr, format=log_formatter, colorize=True, enqueue=True)
    if distributed:
        assert use_cuda, "Only use distributed mode with cuda."
        try:
            dist.init_process_group("nccl")
        except ValueError as e:
            # log the GPU-related environment for debugging, then re-raise
            logger.critical(f"Value error while setting up nccl process group: {e}")
            logger.info(
                f" CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')},"
                f" SLURM_STEP_GPUS={os.environ.get('SLURM_STEP_GPUS')},"
                f" GPU_DEVICE_ORDINAL={os.environ.get('GPU_DEVICE_ORDINAL')} for process:"
                f" RANK={rank} (LOCAL_RANK={local_rank}) of WORLD_SIZE={num_gpus}. Shutting down now."
            )
            raise e

    assert torch.cuda.is_available() or not use_cuda, "CUDA is not available"
    # sanity-check that the SLURM GPU bookkeeping matches the world size
    # (all three env vars must list exactly num_gpus devices)
    assert (
        len(str(os.environ.get("SLURM_STEP_GPUS")).split(","))
        == len(str(os.environ.get("CUDA_VISIBLE_DEVICES")).split(","))
        == len(str(os.environ.get("GPU_DEVICE_ORDINAL")).split(","))
        == num_gpus
    ) or not use_cuda, (
        f"SLURM GPU setup is incorrect: CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')},"
        f" SLURM_STEP_GPUS={os.environ.get('SLURM_STEP_GPUS')},"
        f" GPU_DEVICE_ORDINAL={os.environ.get('GPU_DEVICE_ORDINAL')} for process:"
        f" RANK={rank} (LOCAL_RANK={local_rank}) of WORLD_SIZE={num_gpus}"
    )
    return distributed, torch.device(f"cuda:{local_rank}") if use_cuda else "cpu", num_gpus, rank, local_rank
def ddp_cleanup(args, sync_old_wandb=False, rank=0):
    """Clean the distributed setup after use.

    In the distributed case this waits for all processes, tears down the
    process group, and terminates the process; otherwise it returns normally.

    Args:
        args (DotDict): arguments; only ``args.distributed`` is read
        sync_old_wandb (bool, optional): Whether to sync and remove wandb runs older than 100 hours (>3 days). Defaults to False.
        rank (int, optional): The rank of the current process, so only one process syncs wandb. Defaults to 0.

    """
    if sync_old_wandb and rank == 0:
        os.system("wandb sync --clean --clean-old-hours 100 --clean-force")

    if args.distributed:
        logger.info("waiting for all processes to finish")
        dist.barrier()
        dist.destroy_process_group()
        logger.info("exiting now")
        # BUGFIX: use sys.exit instead of the site-provided exit() builtin,
        # which is meant for interactive sessions and may be unavailable in
        # some deployment modes (e.g. frozen/embedded interpreters).
        sys.exit(0)
def set_filter_warnings():
    """Filter out some warnings to reduce spam."""
    ignored_message_patterns = [
        # DataLoader number-of-workers warning
        ".*worker processes in total. Our suggested max number of worker in current system is.*",
        # datadings only-varargs warning
        ".*only accepts varargs so.*",
        # gather deprecation warning
        ".*is_namedtuple is deprecated, please use the python checks instead.*",
        # meshgrid indexing-argument warning
        ".*in an upcoming release, it will be required to pass the indexing.*",
        # timm warning when overwriting models
        ".*UserWarning: Overwriting .*",
    ]
    for pattern in ignored_message_patterns:
        warnings.filterwarnings("ignore", pattern)
def remove_prefix(state_dict, prefix="module."):
    """Remove a prefix from the keys in a state dictionary.

    Args:
        state_dict (dict[str, Any]): The state dictionary to remove the prefix from.
        prefix (str, optional): The prefix to remove from the keys. Default is 'module.'.

    Returns:
        dict[str, Any]: A new dictionary with the prefix removed from the keys.

    Examples:
        >>> remove_prefix({'module.layer1.weight': 1, 'module.layer1.bias': 2})
        {'layer1.weight': 1, 'layer1.bias': 2}

    """
    stripped = {}
    cut = len(prefix)
    for key, value in state_dict.items():
        # keys without the prefix are passed through unchanged
        new_key = key[cut:] if key.startswith(prefix) else key
        stripped[new_key] = value
    return stripped
def prime_factors(n):
    """Calculate the prime factors of a given integer.

    Args:
        n (int): The integer to find the prime factors of.

    Returns:
        list[int]: The prime factors of n, ascending, with multiplicity.

    """
    factors = []
    candidate = 2
    remaining = n
    # trial division: only candidates up to sqrt(remaining) can still divide it
    while candidate * candidate <= remaining:
        quotient, rest = divmod(remaining, candidate)
        if rest == 0:
            factors.append(candidate)
            remaining = quotient
        else:
            candidate += 1
    # whatever is left over is itself prime
    if remaining > 1:
        factors.append(remaining)
    return factors
def linear_regession(points):
    """Calculate a linear interpolation of the points.

    Fits y = a*x + b by the closed-form ordinary-least-squares solution.

    Args:
        points (dict[float, float]): points to interpolate in the format points[x] = y

    Returns:
        function: A function that interpolates the points.

    """
    n = len(points)
    x = np.fromiter(points.keys(), dtype=float, count=n)
    y = np.fromiter(points.values(), dtype=float, count=n)

    # slope and intercept from the normal equations
    slope = (n * np.dot(x, y) - x.sum() * y.sum()) / (n * np.dot(x, x) - x.sum() ** 2)
    intercept = (y.sum() - slope * x.sum()) / n
    return lambda z: slope * z + intercept
def save_model_state(
    model_folder,
    epoch,
    args,
    model_state,
    regular_save=True,
    stats=None,
    val_accs=None,
    epoch_accs=None,
    additional_reason="",
    max_interm_ep_states=2,
    **kwargs,
):
    """Save the model state.

    Writes a checkpoint dict (epoch, model state, run name, args, stats, and
    any extra kwargs) into ``model_folder`` and then prunes old ``ep_<N>.pt``
    files, keeping only the newest ``max_interm_ep_states`` of them.

    Args:
        model_folder: Folder to save model in
        epoch: current epoch
        args: arguments to guide and save
        model_state: state of the model
        regular_save: Is this a regular or a special save? (Default value = True)
        stats: model stats to save (Default value = None)
        val_accs: model accuracy to save (Default value = None)
        epoch_accs: training accuracy to save (Default value = None)
        additional_reason: save reason; in case it's not just a regular save interval. Would be "top" or "final", for example. (Default value = "")
        max_interm_ep_states: Number of regular epoch states to keep (Default value = 2)
        **kwargs: Further arguments to save

    """
    # make args dict, not DotDict to be able to save it
    state = {"epoch": epoch, "model_state": model_state, "run_name": args.run_name, "args": dict(args)}
    if stats is None:
        stats = {}
    # merge validation and training accuracies into the stats dict
    # (later dicts win on key collisions)
    if val_accs is not None:
        stats = {**stats, **val_accs}
    if epoch_accs is not None:
        stats = {**stats, **epoch_accs}
    state["stats"] = stats
    state = {**state, **kwargs}
    logger.info(f"saving model state at epoch {epoch} ({additional_reason})")
    regular_file_name = f"ep_{epoch}.pt"
    # special saves (e.g. "top"/"final") go to their own file name ...
    save_name = additional_reason + ".pt" if len(additional_reason) > 0 else regular_file_name
    outfile = os.path.join(model_folder, save_name)
    torch.save(state, outfile)
    # ... and are additionally mirrored into the regular per-epoch file
    if len(additional_reason) > 0 and regular_save:
        shutil.copyfile(outfile, os.path.join(model_folder, regular_file_name))

    # remove intermediate epoch states (all but the last max_interm_ep_states)
    if max_interm_ep_states > 0:
        epoch_states = [f for f in os.listdir(model_folder) if f.startswith("ep_") and f.endswith(".pt")]
        # sort numerically by the epoch number embedded in "ep_<N>.pt"
        epoch_states = sorted(epoch_states, key=lambda x: int(x.split("_")[1].split(".")[0]))
        if len(epoch_states) > max_interm_ep_states:
            for f in epoch_states[:-max_interm_ep_states]:
                os.remove(os.path.join(model_folder, f))
                logger.debug(f"removed intermediate epoch state {f}")
def log_args(args, rank=0):
    """Log the full argument dict on rank 0 only, to avoid duplicate lines."""
    if rank != 0:
        return
    logger.info("full set of arguments: " + json.dumps(dict(args), sort_keys=True))
class ScalerGradNormReturn(NativeScaler):
    """A wrapper around PyTorch's NativeScaler that returns the gradient norm."""

    def __str__(self) -> str:
        return f"{type(self).__name__}(_scaler: {self._scaler})"

    def __call__(self, loss, optimizer, clip_grad=None, clip_mode="norm", parameters=None, create_graph=False):
        """Scale and backpropagate through the loss tensor, and return the gradient norm of the selected parameters.

        Does an optimizer step.

        Args:
            loss (torch.Tensor): The loss tensor to scale and backpropagate through.
            optimizer (torch.optim.Optimizer): The optimizer to use for the optimization step.
            clip_grad (float, optional): The maximum allowed norm of the gradients. If None, no clipping is performed.
            clip_mode (str, optional): The mode used for clipping the gradients. Only used if `clip_grad` is not None. Possible values are 'norm'
                (clipping the norm of the gradients) and 'value' (clipping the value of the gradients). (default='norm')
            parameters (iterable[torch.nn.Parameter], optional): The parameters to compute the gradient norm for. If None, the gradient norm is not computed.
            create_graph (bool, optional): Whether to create a computation graph for computing second-order gradients. (default=False)

        Returns:
            float: The gradient norm of the selected parameters, or -1 if ``parameters`` is None.

        """
        self._scaler.scale(loss).backward(create_graph=create_graph)

        # always unscale the gradients, since it's being done anyway
        self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
        if parameters is not None:
            grads = [p.grad for p in parameters if p.grad is not None]
            device = grads[0].device
            # total 2-norm over all gradients = 2-norm of the per-tensor 2-norms
            grad_norm = torch.norm(torch.stack([torch.norm(g.detach(), 2).to(device) for g in grads]), 2)
        else:
            grad_norm = -1
        if clip_grad is not None:
            # clip AFTER computing the norm, so the returned norm is pre-clipping
            dispatch_clip_grad(parameters, clip_grad, mode=clip_mode)
        # stepping through the scaler skips the step on inf/nan grads;
        # update() then adjusts the loss scale for the next iteration
        self._scaler.step(optimizer)
        self._scaler.update()
        return grad_norm
class NoScaler:
    """Dummy gradient scaler that doesn't scale gradients.

    Performs a plain backward pass with the given loss, then updates the
    model's parameters with the given optimizer. The gradient norm of the
    selected parameters is computed and returned.

    """

    def __str__(self) -> str:
        return f"{type(self).__name__}()"

    def __call__(self, loss, optimizer, parameters=None, **kwargs):
        """Backpropagate ``loss``, step ``optimizer``, and return the gradient norm.

        Args:
            loss (torch.Tensor): The loss tensor that the gradients will be computed from.
            optimizer (torch.optim.Optimizer): The optimizer that will be used to update the model's parameters.
            parameters (iterable[torch.Tensor], optional): An iterable of model parameters to compute gradients. If None, returns -1.
            **kwargs: Additional keyword arguments; nothing will be done with these.

        Returns:
            float: The gradient norm computed after the backward pass, or -1 if parameters is None.

        """
        loss.backward()
        if parameters is None:
            grad_norm = -1
        else:
            existing_grads = [p.grad.detach() for p in parameters if p.grad is not None]
            target_device = existing_grads[0].device
            # total 2-norm over all gradients = 2-norm of the per-tensor 2-norms
            per_tensor_norms = torch.stack([torch.norm(g, 2).to(target_device) for g in existing_grads])
            grad_norm = torch.norm(per_tensor_norms, 2)
        optimizer.step()
        return grad_norm
def get_cpu_name():
    """Get the name of the CPU.

    Returns:
        str: the CPU model name from ``/proc/cpuinfo``, or "unknown" when it
        cannot be determined (e.g. on systems without procfs).

    """
    try:
        with open("/proc/cpuinfo", "r") as f:
            for line in f:
                if line.startswith("model name"):
                    # maxsplit=1 keeps any colons inside the model name intact
                    return line.split(":", 1)[1].strip()
    except OSError:
        # BUGFIX: /proc/cpuinfo does not exist on non-Linux systems; the
        # original raised FileNotFoundError there instead of falling back.
        pass
    return "unknown"
# From PyTorch internals
|
|
def _ntuple(n):
|
|
"""Make a function to create n-tuples.
|
|
|
|
Args:
|
|
n (int): tuple length
|
|
|
|
Returns:
|
|
function: function to create n-tuples
|
|
|
|
"""
|
|
|
|
def parse(x):
|
|
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
|
|
return tuple(x)
|
|
return tuple(repeat(x, n))
|
|
|
|
return parse
|
|
|
|
|
|
to_2tuple = _ntuple(2)
|
|
|
|
|
|
def make_divisible(v, divisor=8, min_value=None, round_limit=0.9):
    """Calculate a number close to ``v`` that is divisible by ``divisor``.

    This function is primarily used to ensure that the output of a layer
    is divisible by a certain number, typically to align with hardware
    optimizations or memory layouts.

    Args:
        v (int): The input value.
        divisor (int, optional): The divisor. Defaults to 8.
        min_value (int, optional): The minimum value to return. If None, defaults to the divisor.
        round_limit (float, optional): A threshold for rounding down. If the result of rounding down is less than round_limit * v, the next multiple of the divisor is returned instead. Defaults to 0.9.

    Returns:
        int: A multiple of ``divisor`` close to ``v`` (at least ``min_value``).

    """
    if not min_value:
        min_value = divisor
    # round to the nearest multiple of divisor, clamped from below
    rounded = int(v + divisor / 2) // divisor * divisor
    candidate = max(min_value, rounded)
    # make sure that rounding down does not shrink v by more than ~10%
    if candidate < round_limit * v:
        candidate += divisor
    return candidate
def grad_cam_reshape_transform(tensor):
    """Transform token activations into a CNN-style feature map for Grad-CAM.

    Args:
        tensor (torch.Tensor): input tensor of shape (batch, tokens, dim)

    Returns:
        torch.Tensor: reshaped tensor of shape (batch, dim, side, side) without [CLS] token.

    """
    num_tokens = tensor.shape[1]
    # a non-square token count indicates a leading [CLS] token to strip
    has_cls_token = int(sqrt(num_tokens)) ** 2 != num_tokens
    patch_tokens = tensor[:, 1:] if has_cls_token else tensor

    batch, n, dim = patch_tokens.shape
    side = int(sqrt(n))
    grid = patch_tokens.reshape(batch, side, side, dim)

    # Bring the channels to the first dimension, like in CNNs.
    return grid.transpose(2, 3).transpose(1, 2)