AAAI Version

This commit is contained in:
Tobias Christian Nauen
2026-02-24 12:22:44 +01:00
parent 5c08f9d31a
commit ff34712155
378 changed files with 19844 additions and 4780 deletions


@@ -0,0 +1,834 @@
import json
import os
import sys
from datetime import datetime
from functools import partial
from math import isfinite
from time import time
import torch
import torch.nn as nn
import torch.optim as optim
from loguru import logger
from timm.data import Mixup
from timm.optim import create_optimizer
from timm.scheduler import create_scheduler
from torch import distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from tqdm.auto import tqdm
from metrics import calculate_metrics, per_class_counts
from utils import (
NoScaler,
ScalerGradNormReturn,
SchedulerArgs,
log_formatter,
save_model_state,
)
try:
from apex.optimizers import FusedLAMB # noqa: F401
apex_available = True
except ImportError:
logger.error("Nvidia apex not available")
apex_available = False
try:
from lion_pytorch import Lion
lion_available = True
except ImportError:
logger.error("Lion not available")
lion_available = False
WANDB_AVAILABLE = False
try:
import wandb
WANDB_AVAILABLE = True
except ImportError:
logger.error("wandb not available")
def wandb_available(turn_off=False):
"""If wandb is available.
Args:
turn_off (bool, optional): set wandb to be unavailble manually.
Returns:
bool: wandb is available
"""
global WANDB_AVAILABLE
if turn_off:
WANDB_AVAILABLE = False
return WANDB_AVAILABLE
tqdm = partial(tqdm, leave=True, position=0) # noqa: F405
def setup_tracking_and_logging(args, rank, append_model_path=None, log_wandb=True):
"""Set up logging and tracking for an experiment.
Args:
args (DotDict): Parsed command-line arguments
rank (int): The rank of the current process
append_model_path (str, optional): Path of an existing model, by default None
log_wandb (bool, optional): Whether to log to wandb, by default True
Returns:
        str: folder where all the run data is saved.
Raises:
AssertionError: If `dataset` or `model` is `None`.
Notes:
        This function sets up logging to stdout and a file, as well as wandb tracking for the experiment.
        For wandb logging, provide a .wandb.apikey file in the current directory.
"""
dataset, model, epochs = args.dataset.replace(os.sep, "_").lower(), args.model.replace(os.sep, "_"), args.epochs
_base_folder = (
os.path.join(args.results_folder, args.experiment_name, args.task.replace("-", ""), dataset)
if args.out_dir is None
else args.out_dir
)
run_folder = os.path.join(
_base_folder,
f"{args.run_name.replace(os.sep, '_')}_{model}_{datetime.now().strftime('%d.%m.%Y_%H:%M:%S')}",
)
assert dataset is not None and model is not None
if os.name == "nt":
run_folder = run_folder.replace("@", "_").replace(" ", "_").replace(":", ".")
if append_model_path is not None:
run_folder = os.path.dirname(append_model_path)
if "run_name" not in args or args.run_name is None:
args.run_name = run_folder.split(os.sep)[-1].split("_")[0]
elif args.distributed:
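        # The run folder name contains a timestamp, so every rank would otherwise build a
        # different path; rank 0 therefore constructs the name and broadcasts it to all ranks.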
obj_list = [None]
if rank == 0:
obj_list[0] = run_folder
dist.broadcast_object_list(obj_list, src=0)
run_folder = obj_list[0]
if rank == 0:
os.makedirs(run_folder, exist_ok=True)
dist.barrier()
elif rank == 0:
os.makedirs(run_folder, exist_ok=True)
assert "%" not in args.run_name, f"found '%' in run_name '{args.run_name}'. This messes with string formatting..."
if args.debug:
args.log_level = "debug"
# logger to stdout & file
log_name = args.task.replace("-", "")
if args.task not in ["pre-train", "fine-tune", "fine-tune-head"]:
log_name += f"_{dataset}_{datetime.now().strftime('%d.%m.%Y_%H:%M:%S')}"
log_file = os.path.join(run_folder, f"{log_name}.log")
logger.remove()
logger.configure(extra=dict(run_name=args.run_name, rank=rank, world_size=args.world_size))
logger.add(sys.stderr, format=log_formatter, enqueue=True, colorize=True, level=args.log_level.upper())
logger.add(log_file, format=log_formatter, enqueue=True, colorize=True, level=args.log_level.upper())
logger.info(f"Run folder '{run_folder}'")
if rank == 0:
logger.info(f"{args.task.replace('-', '').capitalize()} {model} on {dataset} for {epochs} epochs")
global WANDB_AVAILABLE
WANDB_AVAILABLE = WANDB_AVAILABLE and log_wandb and os.path.isfile(".wandb.apikey") and args.wandb
if WANDB_AVAILABLE:
with open(".wandb.apikey", "r") as f:
__wandb_api_key = f.read().strip()
wandb.login(key=__wandb_api_key)
if args.wandb_run_id is not None:
wandb_args = dict(project=args.experiment_name, resume="must", id=args.wandb_run_id)
else:
wandb_args = dict(
project=args.experiment_name,
name=args.run_name.replace("_", "-").replace(" ", "-"),
config={"logfile": log_file, **dict(args)},
job_type=args.task,
tags=[model, dataset],
resume="allow",
id=args.wandb_run_id,
)
wandb.init(**wandb_args)
args["wandb_run_id"] = wandb.run.id
if rank == 0:
logger.info(f"wandb run initialized with id {args['wandb_run_id']}.")
else:
logger.info(
f"Not using wandb. (args.wandb={args.wandb}, .wandb.apikey exists={os.path.isfile('.wandb.apikey')},"
f" function declaration log_wandb={log_wandb})"
)
if args.distributed:
dist.barrier()
if args.debug:
torch.autograd.set_detect_anomaly(True)
logger.warning("torch.autograd anomaly detection enabled. Will slow down model.")
return run_folder
def setup_model_optim_sched_scaler(model, device, epochs, args, head_only=False):
"""Set up model, optimizer, and scheduler with automatic mixed precision (amp) and distributed data parallel (DDP).
Args:
model (nn.Module): the loaded model
device (torch.device): the current device
epochs (int): total number of epochs to learn for (for scheduler)
args: further arguments
head_only (bool, optional): train only the linear head. Default: False
Returns:
tuple[nn.Module, optim.Optimizer, optim.lr_scheduler._LRScheduler, ScalerGradNormReturn]: model, optimizer, scheduler, scaler
"""
model = model.to(device)
    if head_only:
        # freeze everything except the classification head
        for name, param in model.named_parameters():
            param.requires_grad = "head" in name
        params = model.head.parameters()
else:
        params = model  # pass the model itself and let timm's create_optimizer collect its parameters
if args.opt == "lion" and not lion_available:
args.opt = "fusedlamb"
logger.warning("Falling back from lion to fusedlamb")
if args.opt == "fusedlamb" and not apex_available:
args.opt = "adamw"
logger.warning("Falling back from fusedlamb to adamw")
if args.opt == "lion":
optimizer = Lion(params, lr=args["lr"], weight_decay=args["weight_decay"])
else:
optimizer = create_optimizer(args, params)
scaler = ScalerGradNormReturn() if args.amp else NoScaler()
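    # ScalerGradNormReturn and NoScaler are assumed to share one interface: a single call scales
    # the loss (or not), clips gradients, steps the optimizer, and returns the gradient norm,
    # which is how the scaler is used in _train_one_epoch below.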
# if args.model_ema:
# # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
# ema_model = ModelEma(model, decay=args.model_ema_decay, resume='')
if args.distributed:
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
model = DDP(model, device_ids=[device])
if args.compile_model:
model = torch.compile(model)
# scheduler = optim.lr_scheduler.LambdaLR(optimizer,
# lr_lambda=scheduler_function_factory(**args))
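    # SchedulerArgs is assumed to wrap these values (sched, epochs, min_lr, warmup_lr,
    # warmup_epochs) in the attribute-style object that timm's create_scheduler expects.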
sched_args = SchedulerArgs(args.sched, args.epochs, args.min_lr, args.warmup_lr, args.warmup_epochs)
scheduler, _ = create_scheduler(sched_args, optimizer)
return model, optimizer, scheduler, scaler
def setup_criteria_mixup(args, dataset=None, **criterion_kwargs):
"""Set up further objects that are needed for training.
Args:
args: arguments
        dataset (torch.utils.data.Dataset, optional): dataset that implements images_per_class, used for class weights (Default value = None)
        **criterion_kwargs: further keyword arguments passed to the criterion
Returns:
tuple[nn.Module, nn.Module, Mixup]: criterion, val_criterion, mixup
"""
weight = None
if args.loss_weight != "none":
if dataset is not None and hasattr(dataset, "images_per_class"):
ipc = dataset.images_per_class
total_ims = sum(ipc)
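            # Class-weighting schemes implemented below (N = total images, n_c = images in class c,
            # C = n_classes, p_c = n_c / N):
            #   linear: w_c = N / (n_c * C)                                 inverse class frequency
            #   log:    w_c = -log(p_c) / H(p), H(p) = -sum_c p_c log(p_c)  entropy-normalized surprisal
            #   sqrt:   w_c = 1 / (sqrt(p_c) * sum_c sqrt(p_c))             softened inverse frequency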
if args.loss_weight == "linear":
weight = torch.tensor([total_ims / (ims * args.n_classes) for ims in ipc])
elif args.loss_weight == "log":
p_c = torch.tensor([ims / total_ims for ims in ipc])
log_p_c = torch.where(p_c > 0, p_c.log(), torch.zeros_like(p_c))
entr = -(p_c * log_p_c).sum()
weight = -log_p_c / entr
elif args.loss_weight == "sqrt":
p_c = torch.tensor([ims / total_ims for ims in ipc])
weight = 1 / (p_c.sqrt() * p_c.sqrt().sum())
else:
logger.warning("Could not find images_per_class in dataset. Using uniform weights.")
if args.aug_cutmix or args.multi_label:
# criterion = SoftTargetCrossEntropy()
if args.ignore_index >= 0:
if weight is None:
weight = torch.ones(args.n_classes)
weight[args.ignore_index] = 0
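            # With soft (mixup/cutmix) or multi-label targets, ignore_index cannot be passed to the
            # loss, so the ignored class is excluded by giving it zero weight instead.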
if args.multi_label:
if args.loss == "ce":
criterion = nn.BCEWithLogitsLoss(pos_weight=weight, **criterion_kwargs)
val_criterion = nn.BCEWithLogitsLoss(pos_weight=weight, **criterion_kwargs)
else:
raise NotImplementedError(
f"Only BCEWithLogitsLoss (ce) is implemented for multi-label classification, not {args.loss}."
)
else:
if args.loss == "ce":
loss_cls = nn.CrossEntropyLoss
elif args.loss == "baikal":
loss_cls = BaikalLoss
else:
raise NotImplementedError(f"'{args.loss}'-loss is not implemented.")
criterion = loss_cls(weight=weight, **criterion_kwargs)
val_criterion = loss_cls(weight=weight, **criterion_kwargs)
else:
if args.loss == "ce":
loss_cls = nn.CrossEntropyLoss
elif args.loss == "baikal":
loss_cls = BaikalLoss
else:
raise NotImplementedError(f"'{args.loss}'-loss is not implemented.")
criterion = loss_cls(
label_smoothing=args.label_smoothing,
ignore_index=args.ignore_index if weight is None else -100,
weight=weight,
**criterion_kwargs,
) # LabelSmoothingCrossEntropy(smoothing=args.label_smoothing)
# criterion = LabelSmoothingCrossEntropy(smoothing=args.label_smoothing)
val_criterion = loss_cls(
label_smoothing=args.label_smoothing,
ignore_index=args.ignore_index if weight is None else -100,
weight=weight,
**criterion_kwargs,
) # LabelSmoothingCrossEntropy(smoothing=0.)
mixup_kwargs = dict(
mixup_alpha=args.aug_mixup_alpha,
cutmix_alpha=args.aug_cutmix_alpha,
label_smoothing=args.label_smoothing,
num_classes=args.n_classes,
)
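    # Mixup/CutMix is only instantiated when at least one alpha is non-zero; otherwise no
    # batch-level mixing augmentation is applied.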
mixup = Mixup(**mixup_kwargs) if abs(args.aug_cutmix_alpha) + abs(args.aug_mixup_alpha) > 0.0 else None
return criterion, val_criterion, mixup
def _train(
model,
train_loader,
optimizer,
rank,
epochs,
device,
mixup,
criterion,
world_size,
scheduler,
args,
val_loader,
val_criterion,
model_folder,
scaler,
do_metrics_calculation=True,
start_epoch=0,
show_tqdm=True,
topk=(1, 5),
acc_dict_key=None,
train_dali_server=None,
val_dali_server=None,
):
"""Train the model.
Args:
        model (nn.Module): the model to train
        train_loader (DataLoader): loader for training data
        optimizer (optim.Optimizer): the optimizer
        rank (int): this process's rank
        epochs (int): total number of epochs to train for
        device (torch.device): device to train on
        mixup (Mixup): mixup/cutmix augmentation, or None
        criterion (nn.Module): training loss
        world_size (int): number of distributed processes
        scheduler: learning-rate scheduler
        args (DotDict): further arguments
        val_loader (DataLoader): loader for validation data
        val_criterion (nn.Module): validation loss
        model_folder (str): folder to save model states to
        scaler: gradient scaler (ScalerGradNormReturn or NoScaler)
        do_metrics_calculation: compute efficiency metrics after training (Default value = True)
        start_epoch: epoch to resume training from (Default value = 0)
        show_tqdm: show a progress bar (Default value = True)
        topk: top-k accuracies to compute (Default value = (1, 5))
        acc_dict_key: format key for the accuracy dictionary (Default value = None)
        train_dali_server: optional DALI server for the training data (Default value = None)
        val_dali_server: optional DALI server for the validation data (Default value = None)
Returns:
dict: evaluation metrics at the end of training
"""
if acc_dict_key is None:
acc_dict_key = "acc{}"
training_start = time()
topk = tuple(k for k in topk if k <= args.n_classes)
    time_spent_training = time_spent_validating = 0
current_best_acc = 0.0
if rank == 0:
logger.info(f"Dataloader has {len(train_loader)} batches")
logger.debug("Starting training with the following settings:")
logger.debug(f"criterion: {criterion}")
logger.debug(f"train_loader: {train_loader}, sampler: {train_loader.sampler}")
logger.debug(f"dataset: {train_loader.dataset}")
logger.debug(f"optimizer: {optimizer}")
logger.debug(f"device: {device}")
logger.debug(f"start epoch: {start_epoch}, epochs: {epochs}")
logger.debug(f"scaler: {scaler}")
logger.debug(f"max_grad_norm: {args.max_grad_norm}")
# logger.debug(f"model_ema:\n{model_ema}\n{model_ema.decay}\n{model_ema.device}")
if mixup:
logger.debug(
f"mixup: {mixup}; mixup_alpha: {mixup.mixup_alpha}, cutmix_alpha: {mixup.cutmix_alpha},"
f" cutmix_minmax: {mixup.cutmix_minmax}, prob: {mixup.mix_prob}, switch_prob: {mixup.switch_prob},"
f" label_smoothing: {mixup.label_smoothing}, num_classes: {mixup.num_classes}, correct_lam:"
f" {mixup.correct_lam}, mixup_enabled: {mixup.mixup_enabled}"
)
else:
logger.debug(f"mixup: {mixup}")
for epoch in range(start_epoch, epochs):
with logger.contextualize(epoch=str(epoch + 1)):
if args.distributed:
train_loader.sampler.set_epoch(epoch)
set_ep_func = getattr(train_loader.dataset, "set_epoch", None)
if callable(set_ep_func):
train_loader.dataset.set_epoch(epoch)
val_loader.dataset.set_epoch(epoch)
if train_dali_server:
train_dali_server.start_thread()
logger.info("started train dali server")
epoch_time, epoch_stats = _train_one_epoch(
model,
train_loader,
optimizer,
rank,
epoch,
device,
mixup,
criterion,
scheduler,
scaler,
args,
topk,
"train/" + acc_dict_key,
show_tqdm,
)
            time_spent_training += epoch_time
if train_dali_server:
train_dali_server.stop_thread()
val_time, val_stats = _evaluate(
model,
val_loader,
epoch,
rank,
device,
val_criterion,
args,
topk,
"val/" + acc_dict_key,
dali_server=val_dali_server,
)
            time_spent_validating += val_time
if rank == 0:
logger.info(f"total_time={time() - training_start}s")
if rank == 0:
top1_val_acc = val_stats["val/" + acc_dict_key.format(1)]
# print metadata for grafana
metadata = {
"epoch": epoch + 1,
"progress": (epoch + 1) / args.epochs,
**val_stats,
**epoch_stats,
}
                # filter out NaN and infinite values
metadata = {k: v for k, v in metadata.items() if isfinite(v)}
print(json.dumps(metadata), flush=True)
logger.debug(f"printed metadata: {json.dumps(metadata)}")
if WANDB_AVAILABLE:
wandb.log(metadata, step=epoch + 1)
# saving current state
if top1_val_acc > current_best_acc or (epoch + 1) % args.save_epochs == 0:
reason = "top" if top1_val_acc > current_best_acc else "" # min(...) will be the top-1 accuracy
if reason == "top":
current_best_acc = top1_val_acc
logger.info(f"found a new best model with acc: {current_best_acc}")
kwargs = dict(
model_state=model.state_dict(),
stats=metadata,
optimizer_state=optimizer.state_dict(),
additional_reason=reason,
regular_save=(epoch + 1) % args.save_epochs == 0,
)
if scheduler:
kwargs["scheduler_state"] = scheduler.state_dict()
save_model_state(
model_folder, epoch + 1, args, **kwargs, max_interm_ep_states=args.keep_interm_states
)
if rank == 0:
end_time = time()
logger.info(
f"training done: total time={end_time - training_start}, "
f"time spend training={time_spend_training}, "
f"time spend validating={time_spend_validating}"
)
results = {**val_stats, **epoch_stats, f"val/best_{acc_dict_key.format(1)}": current_best_acc}
if rank == 0:
save_model_state(
model_folder,
epoch + 1,
args,
model_state=model.state_dict(),
stats=results,
additional_reason="final",
regular_save=False,
max_interm_ep_states=args.keep_interm_states,
)
if do_metrics_calculation:
# Calculate efficiency metrics
inp = next(iter(train_loader))[0].to(device)
metrics = calculate_metrics(
args,
model,
rank=rank,
input=inp,
device=device,
did_training=True,
all_metrics=False,
world_size=world_size,
key_start="eval/",
)
if rank == 0:
logger.info(f"Efficiency metrics: {json.dumps(metrics)}")
return results
def _mask_preds(preds, cls_masks, mask_val=-100):
"""Mask the predictions by the mask.
Args:
        preds (torch.Tensor): model predictions (logits)
        cls_masks (torch.Tensor): class masks, where 1 marks a class to mask out; may be None
        mask_val: value written into the masked logits (Default value = -100)
Returns:
torch.Tensor: masked predictions
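    Example (illustrative shapes):
        preds = torch.tensor([[2.0, 1.0, 0.5]])   # logits for one sample, 3 classes
        cls_masks = torch.tensor([[0, 1, 0]])     # 1 marks classes to mask out
        _mask_preds(preds, cls_masks)             # -> tensor([[2.0, -100.0, 0.5]])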
"""
if cls_masks is None:
return preds
return torch.where(cls_masks.bool(), mask_val, preds)
def _evaluate(
model,
val_loader,
epoch,
rank,
device,
val_criterion,
args,
topk=(1, 5),
acc_dict_key=None,
dali_server=None,
):
"""Evaluate the model.
Args:
model (nn.Module): the model to evaluate
val_loader (DataLoader): loader for evaluation data
epoch (int): the current epoch (for logger & tracking)
        rank (int): this process's rank (don't log n times)
device (torch.device): device to evaluate on
val_criterion (nn.Module): validation loss
args (DotDict): further arguments
        topk (tuple[int], optional): top-k accuracies to compute, by default (1, 5)
        acc_dict_key (str, optional): format key for the accuracy dictionary, by default 'acc{}'
        dali_server (optional): DALI server for the evaluation data, by default None
    Returns:
        tuple[float, dict]: validation time and a dict of validation metrics (loss and accuracies)
"""
if not acc_dict_key:
acc_dict_key = "acc{}"
if dali_server:
dali_server.start_thread()
topk = tuple(k for k in topk if k <= args.n_classes)
model.eval()
val_loss = 0
val_accs = {acc_dict_key.format(k): 0.0 for k in topk}
val_start = time()
n_iters = 0
iterator = (
tqdm(val_loader, total=len(val_loader), desc=f"Validating epoch {epoch + 1}")
if rank == 0 and args.tqdm
else val_loader
)
class_counts = torch.zeros(1 if isinstance(topk, int) else len(topk), args.n_classes, 2)
for batch_data in iterator:
xs, ys = batch_data[:2]
cls_masks = batch_data[2].to(device, non_blocking=True) if len(batch_data) == 3 else None
if args.debug:
logger.debug(f"y_max = {ys.max()}, y_min = {ys.min()}, num_classes={args.n_classes}")
xs, ys = xs.to(device, non_blocking=True), ys.to(device, non_blocking=True)
with torch.no_grad(), torch.amp.autocast("cuda", enabled=args.eval_amp):
preds = model(xs)
preds = _mask_preds(preds, cls_masks)
if args.multi_label:
# labels are float for BCELoss
ys = ys.float()
val_loss += val_criterion(preds.transpose(1, -1), ys.transpose(1, -1) if len(ys.shape) > 1 else ys).item()
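            # The criteria expect the class dimension at position 1, while the model emits it last,
            # hence the transpose; plain 1-D class-index targets are passed through unchanged.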
class_counts += per_class_counts(preds, ys, args.n_classes, topk=topk, ignore_index=args.ignore_index)
n_iters += 1
if args.distributed:
dist.barrier()
if dali_server:
dali_server.stop_thread()
val_end = time()
iterations = n_iters
if args.distributed:
gather_tensor = torch.Tensor([val_loss]).to(device)
dist.barrier()
dist.all_reduce(gather_tensor, op=dist.ReduceOp.AVG)
gather_tensor = gather_tensor.tolist()
val_loss = gather_tensor[0]
class_counts = class_counts.to(device)
dist.all_reduce(class_counts, op=dist.ReduceOp.SUM)
class_counts = class_counts.cpu()
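    # class_counts[i, c] is assumed to hold the (correct, incorrect) counts for class c at topk[i]:
    # the plain accuracy pools counts over all classes, while m-acc averages the per-class
    # accuracies, so rare classes count as much as frequent ones.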
for i, k in enumerate(topk):
key = acc_dict_key.format(k)
mkey = key.replace("acc", "m-acc")
val_accs[key] = class_counts[i].sum(dim=0)[0].item() / class_counts[i].sum(dim=0).sum(dim=-1).item()
val_accs[mkey] = (class_counts[i, :, 0] / class_counts[i].sum(dim=-1)).mean().item()
val_accs["val/loss"] = val_loss
if rank == 0:
log_s = f"val/time={val_end - val_start}s"
for key, val in val_accs.items():
log_s += f", {key}={val:.4f}"
logger.info(log_s)
return val_end - val_start, val_accs
def _train_one_epoch(
model,
train_loader,
optimizer,
rank,
epoch,
device,
mixup,
criterion,
scheduler,
scaler,
args,
topk=(1, 5),
acc_dict_key=None,
show_tqdm=True,
):
"""Train the model for one epoch.
Args:
        model (nn.Module): the model to train
        train_loader (DataLoader): loader for training data
        optimizer (optim.Optimizer): the optimizer
        rank (int): this process's rank
        epoch (int): the current epoch (for logger & tracking)
        device (torch.device): device to train on
        mixup (Mixup): mixup/cutmix augmentation, or None
        criterion (nn.Module): training loss
        scheduler: learning-rate scheduler
        scaler: gradient scaler (ScalerGradNormReturn or NoScaler)
        args (DotDict): further arguments
        topk: top-k accuracies to compute (Default value = (1, 5))
        acc_dict_key: format key for the accuracy dictionary (Default value = None)
        show_tqdm: show a progress bar (Default value = True)
    Returns:
        tuple[float, dict]: time spent training and a dict of epoch statistics (train loss and lr)
"""
if not acc_dict_key:
acc_dict_key = "acc{}"
model.train()
iterator = (
tqdm(train_loader, total=len(train_loader), desc=f"Training epoch {epoch + 1}")
if rank == 0 and show_tqdm
else train_loader
)
if not args.amp:
scaler = NoScaler()
epoch_loss = 0
epoch_accs = {}
epoch_start = time()
grad_norms = []
n_iters = 0
if hasattr(train_loader.dataset, "epoch"):
train_loader.dataset.epoch = epoch
for i, batch_data in enumerate(iterator):
xs, ys = batch_data[:2]
cls_masks = batch_data[2].to(device, non_blocking=True) if len(batch_data) == 3 else None
optimizer.zero_grad()
n_iters += 1
xs = xs.to(device, non_blocking=True)
ys = ys.to(device, non_blocking=True)
if args.debug and i == 0:
logger.debug(f"y_max = {ys.max()}, y_min = {ys.min()}, num_classes={args.n_classes}")
if mixup:
if args.multi_label:
xs, ys = mixup(xs, ys, cls_masks)
else:
xs, ys = mixup(xs, ys)
if args.debug and i == 0:
logger.debug(f"input x: {type(xs)}; {xs.shape}, y: {type(ys)}; {ys.shape}")
with torch.amp.autocast("cuda", enabled=args.amp):
preds = model(xs)
preds = _mask_preds(preds, cls_masks)
if args.multi_label:
# labels are float for BCELoss
ys = ys.float()
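            # The loss also adds the model's internal/auxiliary loss (get_internal_loss);
            # under DDP the method lives on the wrapped model.module.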
loss = criterion(preds.transpose(1, -1), ys.transpose(1, -1) if len(ys.shape) > 1 else ys) + (
model.get_internal_loss() if hasattr(model, "get_internal_loss") else model.module.get_internal_loss()
)
if not isfinite(loss.item()):
logger.error(f"Got loss value {loss.item()}. Stopping training.")
logger.info(f"input has nan: {xs.isnan().any().item()}")
logger.info(f"target has nan: {ys.isnan().any().item()}")
logger.info(f"output has nan: {preds.isnan().any().item()}")
for name, param in model.named_parameters():
if param.isnan().any().item():
logger.error(f"parameter {name} has a nan value")
if len(grad_norms) > 0:
grad_norms = torch.Tensor(grad_norms)
logger.info(
f"Gradient norms until now: min={grad_norms.min().item()}, 20th"
f" %tile={torch.quantile(grad_norms, .2).item()}, mean={torch.mean(grad_norms)}, 80th"
f" %tile={torch.quantile(grad_norms, .8).item()}, max={grad_norms.max()}"
)
sys.exit(1)
iter_grad_norm = scaler(
loss,
optimizer,
parameters=model.parameters(),
clip_grad=args.max_grad_norm if args.max_grad_norm > 0.0 else None,
).cpu()
if args.gather_stats_during_training and isfinite(iter_grad_norm):
grad_norms.append(iter_grad_norm)
# if args.aug_cutmix:
# ys = ys.argmax(dim=-1) # for accuracy with CutMix, just use the argmax for both
#
epoch_loss += loss.item()
# accuracies = accuracy(preds, ys, topk=topk, dict_key=acc_dict_key, ignore_index=args.ignore_index)
# for key in accuracies:
# epoch_accs[key] += accuracies[key]
if args.distributed:
dist.barrier()
epoch_end = time()
iterations = n_iters
# epoch_accs = {key: val / iterations for key, val in epoch_accs.items()}
epoch_loss = epoch_loss / iterations
grad_norm_avrg = -1
inf_grads = iterations - len(grad_norms)
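    # grad_norms only collects finite values (and only when gather_stats_during_training is set),
    # so this difference counts update steps whose gradient norm came back non-finite.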
if len(grad_norms) > 0 and args.gather_stats_during_training:
grad_norm_max = max(grad_norms)
grad_norms = torch.Tensor(grad_norms)
grad_norm_20 = torch.quantile(grad_norms, 0.2).item()
grad_norm_80 = torch.quantile(grad_norms, 0.8).item()
grad_norm_avrg = torch.mean(grad_norms)
if args.distributed:
# grad norm is already synchronized
# gather_tensor = torch.Tensor([epoch_loss, *[epoch_accs[acc_dict_key.format(k)] for k in topk]]).to(device)
gather_tensor = torch.Tensor([epoch_loss]).to(device)
dist.barrier()
dist.all_reduce(gather_tensor, op=dist.ReduceOp.AVG)
# gather_tensor = (gather_tensor / world_size).tolist()
epoch_loss = gather_tensor.item()
# for i, k in enumerate(topk):
# epoch_accs[acc_dict_key.format(k)] = gather_tensor[i + 1]
lr = optimizer.param_groups[0]["lr"]
epoch_accs["train/lr"] = lr
epoch_accs["train/loss"] = epoch_loss
if rank == 0:
if args.gather_stats_during_training:
print_s = f"train/time={epoch_end - epoch_start}s"
logger.info(print_s)
if len(grad_norms) > 0:
logger.info(
f"grad norm avrg={grad_norm_avrg}, grad norm max={grad_norm_max}, "
f"inf grad norm={inf_grads}, grad norm 20%={grad_norm_20}, grad norm 80%={grad_norm_80}"
)
else:
logger.warning(f"inf grad norm={inf_grads}")
logger.error("100% of update steps with infinite grad norms!")
else:
logger.info(f"train/time={epoch_end - epoch_start}s")
if scheduler:
if isinstance(scheduler, optim.lr_scheduler.LambdaLR):
scheduler.step()
else:
scheduler.step(epoch)
if args.gather_stats_during_training:
return epoch_end - epoch_start, epoch_accs
return epoch_end - epoch_start, {}