import json
import os
import sys
from datetime import datetime
from functools import partial
from math import isfinite
from time import time

import torch
import torch.nn as nn
import torch.optim as optim
from loguru import logger
from timm.data import Mixup
from timm.optim import create_optimizer
from timm.scheduler import create_scheduler
from torch import distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from tqdm.auto import tqdm

from metrics import calculate_metrics, per_class_counts
from utils import (
    NoScaler,
    ScalerGradNormReturn,
    SchedulerArgs,
    log_formatter,
    save_model_state,
)

try:
    from apex.optimizers import FusedLAMB  # noqa: F401

    apex_available = True
except ImportError:
    logger.error("Nvidia apex not available")
    apex_available = False
try:
    from lion_pytorch import Lion

    lion_available = True
except ImportError:
    logger.error("Lion not available")
    lion_available = False


WANDB_AVAILABLE = False
try:
    import wandb

    WANDB_AVAILABLE = True
except ImportError:
    logger.error("wandb not available")


def wandb_available(turn_off=False):
    """Return whether wandb is available.

    Args:
        turn_off (bool, optional): set wandb to be unavailable manually.

    Returns:
        bool: wandb is available
    """
    global WANDB_AVAILABLE
    if turn_off:
        WANDB_AVAILABLE = False
    return WANDB_AVAILABLE


tqdm = partial(tqdm, leave=True, position=0)  # noqa: F405


def setup_tracking_and_logging(args, rank, append_model_path=None, log_wandb=True):
    """Set up logging and tracking for an experiment.

    Args:
        args (DotDict): Parsed command-line arguments
        rank (int): The rank of the current process
        append_model_path (str, optional): Path of an existing model, by default None
        log_wandb (bool, optional): Whether to log to wandb, by default True

    Returns:
        str: The folder where all the run data is saved.

    Raises:
        AssertionError: If `dataset` or `model` is `None`.

    Notes:
        This function sets up the logger (stdout and file) as well as wandb tracking for an experiment.
        For the wandb logger, provide .wandb.apikey in the current directory.
    """
    dataset, model, epochs = args.dataset.replace(os.sep, "_").lower(), args.model.replace(os.sep, "_"), args.epochs
    _base_folder = (
        os.path.join(args.results_folder, args.experiment_name, args.task.replace("-", ""), dataset)
        if args.out_dir is None
        else args.out_dir
    )
    run_folder = os.path.join(
        _base_folder,
        f"{args.run_name.replace(os.sep, '_')}_{model}_{datetime.now().strftime('%d.%m.%Y_%H:%M:%S')}",
    )
    assert dataset is not None and model is not None

    if os.name == "nt":
        run_folder = run_folder.replace("@", "_").replace(" ", "_").replace(":", ".")

    if append_model_path is not None:
        run_folder = os.path.dirname(append_model_path)
        if "run_name" not in args or args.run_name is None:
            args.run_name = run_folder.split(os.sep)[-1].split("_")[0]
    elif args.distributed:
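        # Rank 0 owns the timestamped folder name; broadcast it so every rank logs and
        # saves into the same directory.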
        obj_list = [None]
        if rank == 0:
            obj_list[0] = run_folder
        dist.broadcast_object_list(obj_list, src=0)
        run_folder = obj_list[0]
        if rank == 0:
            os.makedirs(run_folder, exist_ok=True)
        dist.barrier()
    elif rank == 0:
        os.makedirs(run_folder, exist_ok=True)

    assert "%" not in args.run_name, f"found '%' in run_name '{args.run_name}'. This messes with string formatting..."

    if args.debug:
        args.log_level = "debug"

    # logger to stdout & file
    log_name = args.task.replace("-", "")
    if args.task not in ["pre-train", "fine-tune", "fine-tune-head"]:
        log_name += f"_{dataset}_{datetime.now().strftime('%d.%m.%Y_%H:%M:%S')}"
    log_file = os.path.join(run_folder, f"{log_name}.log")
    logger.remove()
    logger.configure(extra=dict(run_name=args.run_name, rank=rank, world_size=args.world_size))
    logger.add(sys.stderr, format=log_formatter, enqueue=True, colorize=True, level=args.log_level.upper())
    logger.add(log_file, format=log_formatter, enqueue=True, colorize=True, level=args.log_level.upper())
    logger.info(f"Run folder '{run_folder}'")

    if rank == 0:
        logger.info(f"{args.task.replace('-', '').capitalize()} {model} on {dataset} for {epochs} epochs")

    global WANDB_AVAILABLE
    WANDB_AVAILABLE = WANDB_AVAILABLE and log_wandb and os.path.isfile(".wandb.apikey") and args.wandb
    if WANDB_AVAILABLE:
        with open(".wandb.apikey", "r") as f:
            __wandb_api_key = f.read().strip()
        wandb.login(key=__wandb_api_key)
        if args.wandb_run_id is not None:
            wandb_args = dict(project=args.experiment_name, resume="must", id=args.wandb_run_id)
        else:
            wandb_args = dict(
                project=args.experiment_name,
                name=args.run_name.replace("_", "-").replace(" ", "-"),
                config={"logfile": log_file, **dict(args)},
                job_type=args.task,
                tags=[model, dataset],
                resume="allow",
                id=args.wandb_run_id,
            )
        wandb.init(**wandb_args)
        args["wandb_run_id"] = wandb.run.id
        if rank == 0:
            logger.info(f"wandb run initialized with id {args['wandb_run_id']}.")
    else:
        logger.info(
            f"Not using wandb. (args.wandb={args.wandb}, .wandb.apikey exists={os.path.isfile('.wandb.apikey')},"
            f" function argument log_wandb={log_wandb})"
        )

    if args.distributed:
        dist.barrier()

    if args.debug:
        torch.autograd.set_detect_anomaly(True)
        logger.warning("torch.autograd anomaly detection enabled. This will slow down the model.")

    return run_folder


def setup_model_optim_sched_scaler(model, device, epochs, args, head_only=False):
    """Set up model, optimizer, and scheduler with automatic mixed precision (amp) and distributed data parallel (DDP).

    Args:
        model (nn.Module): the loaded model
        device (torch.device): the current device
        epochs (int): total number of epochs to learn for (for scheduler)
        args: further arguments
        head_only (bool, optional): train only the linear head. Default: False

    Returns:
        tuple[nn.Module, optim.Optimizer, optim.lr_scheduler._LRScheduler, ScalerGradNormReturn]: model, optimizer, scheduler, scaler

    """
    model = model.to(device)

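    # head_only: freeze the backbone and train only the classification head (parameters
    # whose name contains "head"); only the head's parameters are handed to the optimizer.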
    if head_only:
        for name, param in model.named_parameters():
            param.requires_grad = "head" in name
        params = model.head.parameters()
    else:
        params = model  # pass the model itself for now and let timm's create_optimizer handle the parameters

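    # Optimizer fallback chain: lion -> fusedlamb -> adamw, depending on which of the
    # optional packages (lion_pytorch, apex) were importable above.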
if args.opt == "lion" and not lion_available:
|
|
args.opt = "fusedlamb"
|
|
logger.warning("Falling back from lion to fusedlamb")
|
|
if args.opt == "fusedlamb" and not apex_available:
|
|
args.opt = "adamw"
|
|
logger.warning("Falling back from fusedlamb to adamw")
|
|
if args.opt == "lion":
|
|
optimizer = Lion(params, lr=args["lr"], weight_decay=args["weight_decay"])
|
|
else:
|
|
optimizer = create_optimizer(args, params)
|
|
|
|
scaler = ScalerGradNormReturn() if args.amp else NoScaler()
|
|
|
|
# if args.model_ema:
|
|
# # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
|
|
# ema_model = ModelEma(model, decay=args.model_ema_decay, resume='')
|
|
|
|
if args.distributed:
|
|
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
|
|
model = DDP(model, device_ids=[device])
|
|
|
|
if args.compile_model:
|
|
model = torch.compile(model)
|
|
|
|
# scheduler = optim.lr_scheduler.LambdaLR(optimizer,
|
|
# lr_lambda=scheduler_function_factory(**args))
|
|
sched_args = SchedulerArgs(args.sched, args.epochs, args.min_lr, args.warmup_lr, args.warmup_epochs)
|
|
scheduler, _ = create_scheduler(sched_args, optimizer)
|
|
|
|
return model, optimizer, scheduler, scaler
|
|
|
|
|
|
def setup_criteria_mixup(args, dataset=None, **criterion_kwargs):
    """Set up further objects that are needed for training.

    Args:
        args: arguments
        dataset (torch.data.Dataset, optional): dataset that implements images_per_class, for class weights (Default value = None)
        **criterion_kwargs: further arguments for the criterion

    Returns:
        tuple[nn.Module, nn.Module, Mixup]: criterion, val_criterion, mixup

    """
    weight = None
    if args.loss_weight != "none":
        if dataset is not None and hasattr(dataset, "images_per_class"):
            ipc = dataset.images_per_class
            total_ims = sum(ipc)

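            # Class weights w_c for class c with n_c images, N = sum(n_c), p_c = n_c / N:
            #   linear: w_c = N / (n_c * n_classes)                  (inverse frequency)
            #   log:    w_c = -log(p_c) / H(p) with H(p) = -sum_c p_c * log(p_c)
            #   sqrt:   w_c = 1 / (sqrt(p_c) * sum_k sqrt(p_k))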
if args.loss_weight == "linear":
|
|
weight = torch.tensor([total_ims / (ims * args.n_classes) for ims in ipc])
|
|
elif args.loss_weight == "log":
|
|
p_c = torch.tensor([ims / total_ims for ims in ipc])
|
|
log_p_c = torch.where(p_c > 0, p_c.log(), torch.zeros_like(p_c))
|
|
entr = -(p_c * log_p_c).sum()
|
|
weight = -log_p_c / entr
|
|
elif args.loss_weight == "sqrt":
|
|
p_c = torch.tensor([ims / total_ims for ims in ipc])
|
|
weight = 1 / (p_c.sqrt() * p_c.sqrt().sum())
|
|
|
|
else:
|
|
logger.warning("Could not find images_per_class in dataset. Using uniform weights.")
|
|
|
|
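    # With CutMix or multi-label training the targets are soft/multi-hot, so
    # CrossEntropyLoss's ignore_index cannot be used; instead the ignored class gets weight 0.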
    if args.aug_cutmix or args.multi_label:
        # criterion = SoftTargetCrossEntropy()
        if args.ignore_index >= 0:
            if weight is None:
                weight = torch.ones(args.n_classes)
            weight[args.ignore_index] = 0
        if args.multi_label:
            if args.loss == "ce":
                criterion = nn.BCEWithLogitsLoss(pos_weight=weight, **criterion_kwargs)
                val_criterion = nn.BCEWithLogitsLoss(pos_weight=weight, **criterion_kwargs)
            else:
                raise NotImplementedError(
                    f"Only BCEWithLogitsLoss (ce) is implemented for multi-label classification, not {args.loss}."
                )
        else:
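            # Note: BaikalLoss is referenced below but not imported in this module; it is
            # assumed to be provided by a project-local module.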
if args.loss == "ce":
|
|
loss_cls = nn.CrossEntropyLoss
|
|
elif args.loss == "baikal":
|
|
loss_cls = BaikalLoss
|
|
else:
|
|
raise NotImplementedError(f"'{args.loss}'-loss is not implemented.")
|
|
criterion = loss_cls(weight=weight, **criterion_kwargs)
|
|
val_criterion = loss_cls(weight=weight, **criterion_kwargs)
|
|
else:
|
|
if args.loss == "ce":
|
|
loss_cls = nn.CrossEntropyLoss
|
|
elif args.loss == "baikal":
|
|
loss_cls = BaikalLoss
|
|
else:
|
|
raise NotImplementedError(f"'{args.loss}'-loss is not implemented.")
|
|
criterion = loss_cls(
|
|
label_smoothing=args.label_smoothing,
|
|
ignore_index=args.ignore_index if weight is None else -100,
|
|
weight=weight,
|
|
**criterion_kwargs,
|
|
) # LabelSmoothingCrossEntropy(smoothing=args.label_smoothing)
|
|
# criterion = LabelSmoothingCrossEntropy(smoothing=args.label_smoothing)
|
|
val_criterion = loss_cls(
|
|
label_smoothing=args.label_smoothing,
|
|
ignore_index=args.ignore_index if weight is None else -100,
|
|
weight=weight,
|
|
**criterion_kwargs,
|
|
) # LabelSmoothingCrossEntropy(smoothing=0.)
|
|
|
|
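    # timm's Mixup covers both mixup and cutmix; it is only constructed if at least one
    # of the two alphas is non-zero, otherwise no mixing is applied.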
    mixup_kwargs = dict(
        mixup_alpha=args.aug_mixup_alpha,
        cutmix_alpha=args.aug_cutmix_alpha,
        label_smoothing=args.label_smoothing,
        num_classes=args.n_classes,
    )
    mixup = Mixup(**mixup_kwargs) if abs(args.aug_cutmix_alpha) + abs(args.aug_mixup_alpha) > 0.0 else None

    return criterion, val_criterion, mixup


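# Illustrative wiring of the setup helpers above (a sketch only, not the project's entry
# point; `args`, the data loaders, `dataset`, and `device` are assumed to come from the caller):
#
#   run_folder = setup_tracking_and_logging(args, rank)
#   model, optimizer, scheduler, scaler = setup_model_optim_sched_scaler(model, device, args.epochs, args)
#   criterion, val_criterion, mixup = setup_criteria_mixup(args, dataset=train_loader.dataset)
#   _train(model, train_loader, optimizer, rank, args.epochs, device, mixup, criterion,
#          args.world_size, scheduler, args, val_loader, val_criterion, run_folder, scaler)
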
def _train(
    model,
    train_loader,
    optimizer,
    rank,
    epochs,
    device,
    mixup,
    criterion,
    world_size,
    scheduler,
    args,
    val_loader,
    val_criterion,
    model_folder,
    scaler,
    do_metrics_calculation=True,
    start_epoch=0,
    show_tqdm=True,
    topk=(1, 5),
    acc_dict_key=None,
    train_dali_server=None,
    val_dali_server=None,
):
    """Train the model.

    Args:
        model (nn.Module): the model to train
        train_loader (DataLoader): loader for the training data
        optimizer (optim.Optimizer): optimizer to train with
        rank (int): this process's rank
        epochs (int): total number of epochs to train for
        device (torch.device): device to train on
        mixup (Mixup, optional): mixup/cutmix augmentation, may be None
        criterion (nn.Module): training loss
        world_size (int): number of participating processes
        scheduler: learning-rate scheduler
        args (DotDict): further arguments
        val_loader (DataLoader): loader for the validation data
        val_criterion (nn.Module): validation loss
        model_folder (str): folder to save model states to
        scaler: gradient scaler (amp) or NoScaler
        do_metrics_calculation (bool, optional): compute efficiency metrics after training (Default value = True)
        start_epoch (int, optional): epoch to resume from (Default value = 0)
        show_tqdm (bool, optional): show a progress bar (Default value = True)
        topk (tuple[int], optional): top-k accuracies to report (Default value = (1, 5))
        acc_dict_key (str, optional): key template for the accuracy dictionary (Default value = None)
        train_dali_server (optional): DALI server for the training data (Default value = None)
        val_dali_server (optional): DALI server for the validation data (Default value = None)

    Returns:
        dict: evaluation metrics at the end of training

    """
    if acc_dict_key is None:
        acc_dict_key = "acc{}"
    training_start = time()
    topk = tuple(k for k in topk if k <= args.n_classes)
    time_spent_training = time_spent_validating = 0
    current_best_acc = 0.0
    if rank == 0:
        logger.info(f"Dataloader has {len(train_loader)} batches")

    logger.debug("Starting training with the following settings:")
    logger.debug(f"criterion: {criterion}")
    logger.debug(f"train_loader: {train_loader}, sampler: {train_loader.sampler}")
    logger.debug(f"dataset: {train_loader.dataset}")
    logger.debug(f"optimizer: {optimizer}")
    logger.debug(f"device: {device}")
    logger.debug(f"start epoch: {start_epoch}, epochs: {epochs}")
    logger.debug(f"scaler: {scaler}")
    logger.debug(f"max_grad_norm: {args.max_grad_norm}")
    # logger.debug(f"model_ema:\n{model_ema}\n{model_ema.decay}\n{model_ema.device}")
    if mixup:
        logger.debug(
            f"mixup: {mixup}; mixup_alpha: {mixup.mixup_alpha}, cutmix_alpha: {mixup.cutmix_alpha},"
            f" cutmix_minmax: {mixup.cutmix_minmax}, prob: {mixup.mix_prob}, switch_prob: {mixup.switch_prob},"
            f" label_smoothing: {mixup.label_smoothing}, num_classes: {mixup.num_classes}, correct_lam:"
            f" {mixup.correct_lam}, mixup_enabled: {mixup.mixup_enabled}"
        )
    else:
        logger.debug(f"mixup: {mixup}")

    for epoch in range(start_epoch, epochs):
        with logger.contextualize(epoch=str(epoch + 1)):
            if args.distributed:
                train_loader.sampler.set_epoch(epoch)

            set_ep_func = getattr(train_loader.dataset, "set_epoch", None)
            if callable(set_ep_func):
                train_loader.dataset.set_epoch(epoch)
                val_loader.dataset.set_epoch(epoch)

            if train_dali_server:
                train_dali_server.start_thread()
                logger.info("started train dali server")

            epoch_time, epoch_stats = _train_one_epoch(
                model,
                train_loader,
                optimizer,
                rank,
                epoch,
                device,
                mixup,
                criterion,
                scheduler,
                scaler,
                args,
                topk,
                "train/" + acc_dict_key,
                show_tqdm,
            )
            time_spent_training += epoch_time

            if train_dali_server:
                train_dali_server.stop_thread()

            val_time, val_stats = _evaluate(
                model,
                val_loader,
                epoch,
                rank,
                device,
                val_criterion,
                args,
                topk,
                "val/" + acc_dict_key,
                dali_server=val_dali_server,
            )
            time_spent_validating += val_time

            if rank == 0:
                logger.info(f"total_time={time() - training_start}s")

            if rank == 0:
                top1_val_acc = val_stats["val/" + acc_dict_key.format(1)]
                # print metadata for grafana
                metadata = {
                    "epoch": epoch + 1,
                    "progress": (epoch + 1) / args.epochs,
                    **val_stats,
                    **epoch_stats,
                }
                # filter out NaN and infinity values
                metadata = {k: v for k, v in metadata.items() if isfinite(v)}
                print(json.dumps(metadata), flush=True)
                logger.debug(f"printed metadata: {json.dumps(metadata)}")
                if WANDB_AVAILABLE:
                    wandb.log(metadata, step=epoch + 1)

                # save a checkpoint whenever validation top-1 improves or every args.save_epochs epochs
                if top1_val_acc > current_best_acc or (epoch + 1) % args.save_epochs == 0:
                    reason = "top" if top1_val_acc > current_best_acc else ""
                    if reason == "top":
                        current_best_acc = top1_val_acc
                        logger.info(f"found a new best model with acc: {current_best_acc}")
                    kwargs = dict(
                        model_state=model.state_dict(),
                        stats=metadata,
                        optimizer_state=optimizer.state_dict(),
                        additional_reason=reason,
                        regular_save=(epoch + 1) % args.save_epochs == 0,
                    )
                    if scheduler:
                        kwargs["scheduler_state"] = scheduler.state_dict()
                    save_model_state(
                        model_folder, epoch + 1, args, **kwargs, max_interm_ep_states=args.keep_interm_states
                    )

    if rank == 0:
        end_time = time()
        logger.info(
            f"training done: total time={end_time - training_start}, "
            f"time spent training={time_spent_training}, "
            f"time spent validating={time_spent_validating}"
        )

    results = {**val_stats, **epoch_stats, f"val/best_{acc_dict_key.format(1)}": current_best_acc}

    if rank == 0:
        save_model_state(
            model_folder,
            epoch + 1,
            args,
            model_state=model.state_dict(),
            stats=results,
            additional_reason="final",
            regular_save=False,
            max_interm_ep_states=args.keep_interm_states,
        )

    if do_metrics_calculation:
        # Calculate efficiency metrics
        inp = next(iter(train_loader))[0].to(device)
        metrics = calculate_metrics(
            args,
            model,
            rank=rank,
            input=inp,
            device=device,
            did_training=True,
            all_metrics=False,
            world_size=world_size,
            key_start="eval/",
        )

        if rank == 0:
            logger.info(f"Efficiency metrics: {json.dumps(metrics)}")
    return results


def _mask_preds(preds, cls_masks, mask_val=-100):
    """Mask out predictions for classes that are flagged in the class masks.

    Args:
        preds: model predictions
        cls_masks: class masks; entries that evaluate to True are replaced by mask_val
        mask_val: value written into masked positions (Default value = -100)

    Returns:
        torch.Tensor: masked predictions

    """
    if cls_masks is None:
        return preds
    return torch.where(cls_masks.bool(), mask_val, preds)


def _evaluate(
    model,
    val_loader,
    epoch,
    rank,
    device,
    val_criterion,
    args,
    topk=(1, 5),
    acc_dict_key=None,
    dali_server=None,
):
    """Evaluate the model.

    Args:
        model (nn.Module): the model to evaluate
        val_loader (DataLoader): loader for evaluation data
        epoch (int): the current epoch (for logger & tracking)
        rank (int): this process's rank (don't log n times)
        device (torch.device): device to evaluate on
        val_criterion (nn.Module): validation loss
        args (DotDict): further arguments
        topk (tuple[int], optional): top-k accuracies to report, by default (1, 5)
        acc_dict_key (str, optional): key template for the accuracy dictionary, by default "acc{}"
        dali_server (optional): DALI server for the evaluation data, by default None

    Returns:
        tuple[float, dict]: validation time and a dict with the validation loss and accuracies

    """
    if not acc_dict_key:
        acc_dict_key = "acc{}"

    if dali_server:
        dali_server.start_thread()
    topk = tuple(k for k in topk if k <= args.n_classes)
    model.eval()
    val_loss = 0
    val_accs = {acc_dict_key.format(k): 0.0 for k in topk}
    val_start = time()
    n_iters = 0
    iterator = (
        tqdm(val_loader, total=len(val_loader), desc=f"Validating epoch {epoch + 1}")
        if rank == 0 and args.tqdm
        else val_loader
    )
    class_counts = torch.zeros(1 if isinstance(topk, int) else len(topk), args.n_classes, 2)
    for batch_data in iterator:
        xs, ys = batch_data[:2]
        cls_masks = batch_data[2].to(device, non_blocking=True) if len(batch_data) == 3 else None

        if args.debug:
            logger.debug(f"y_max = {ys.max()}, y_min = {ys.min()}, num_classes={args.n_classes}")

        xs, ys = xs.to(device, non_blocking=True), ys.to(device, non_blocking=True)
        with torch.no_grad(), torch.amp.autocast("cuda", enabled=args.eval_amp):
            preds = model(xs)
            preds = _mask_preds(preds, cls_masks)

        if args.multi_label:
            # labels are float for BCEWithLogitsLoss
            ys = ys.float()
        val_loss += val_criterion(preds.transpose(1, -1), ys.transpose(1, -1) if len(ys.shape) > 1 else ys).item()
        class_counts += per_class_counts(preds, ys, args.n_classes, topk=topk, ignore_index=args.ignore_index)
        n_iters += 1

    if args.distributed:
        dist.barrier()

    if dali_server:
        dali_server.stop_thread()
    val_end = time()
    iterations = n_iters

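    # Aggregate across ranks: the summed loss is averaged and the per-class counts are
    # summed, so the accuracies computed below reflect the whole validation set.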
    if args.distributed:
        gather_tensor = torch.Tensor([val_loss]).to(device)
        dist.barrier()
        dist.all_reduce(gather_tensor, op=dist.ReduceOp.AVG)
        gather_tensor = gather_tensor.tolist()
        val_loss = gather_tensor[0]
        class_counts = class_counts.to(device)
        dist.all_reduce(class_counts, op=dist.ReduceOp.SUM)
        class_counts = class_counts.cpu()

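    # class_counts[i, c] holds (correct, incorrect) counts for class c at top-k = topk[i]:
    # pooling the counts over classes gives the overall accuracy, while m-acc is the mean
    # of the per-class accuracies.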
    for i, k in enumerate(topk):
        key = acc_dict_key.format(k)
        mkey = key.replace("acc", "m-acc")
        val_accs[key] = class_counts[i].sum(dim=0)[0].item() / class_counts[i].sum(dim=0).sum(dim=-1).item()
        val_accs[mkey] = (class_counts[i, :, 0] / class_counts[i].sum(dim=-1)).mean().item()

    val_accs["val/loss"] = val_loss

    if rank == 0:
        log_s = f"val/time={val_end - val_start}s"
        for key, val in val_accs.items():
            log_s += f", {key}={val:.4f}"
        logger.info(log_s)

    return val_end - val_start, val_accs


def _train_one_epoch(
    model,
    train_loader,
    optimizer,
    rank,
    epoch,
    device,
    mixup,
    criterion,
    scheduler,
    scaler,
    args,
    topk=(1, 5),
    acc_dict_key=None,
    show_tqdm=True,
):
    """Train the model for one epoch.

    Args:
        model (nn.Module): the model to train
        train_loader (DataLoader): loader for the training data
        optimizer (optim.Optimizer): optimizer to train with
        rank (int): this process's rank
        epoch (int): the current epoch (for logger & tracking)
        device (torch.device): device to train on
        mixup (Mixup, optional): mixup/cutmix augmentation, may be None
        criterion (nn.Module): training loss
        scheduler: learning-rate scheduler
        scaler: gradient scaler (amp) or NoScaler
        args (DotDict): further arguments
        topk (tuple[int], optional): top-k accuracies to report (Default value = (1, 5))
        acc_dict_key (str, optional): key template for the accuracy dictionary (Default value = None)
        show_tqdm (bool, optional): show a progress bar (Default value = True)

    Returns:
        tuple[float, dict]: time spent training and a dict with the epoch loss and learning rate
        (empty if stats gathering is disabled)

    """
    if not acc_dict_key:
        acc_dict_key = "acc{}"

    model.train()
    iterator = (
        tqdm(train_loader, total=len(train_loader), desc=f"Training epoch {epoch + 1}")
        if rank == 0 and show_tqdm
        else train_loader
    )

    if not args.amp:
        scaler = NoScaler()

    epoch_loss = 0
    epoch_accs = {}
    epoch_start = time()
    grad_norms = []
    n_iters = 0
    if hasattr(train_loader.dataset, "epoch"):
        train_loader.dataset.epoch = epoch
    for i, batch_data in enumerate(iterator):
        xs, ys = batch_data[:2]
        cls_masks = batch_data[2].to(device, non_blocking=True) if len(batch_data) == 3 else None
        optimizer.zero_grad()
        n_iters += 1
        xs = xs.to(device, non_blocking=True)
        ys = ys.to(device, non_blocking=True)

        if args.debug and i == 0:
            logger.debug(f"y_max = {ys.max()}, y_min = {ys.min()}, num_classes={args.n_classes}")

        if mixup:
            if args.multi_label:
                xs, ys = mixup(xs, ys, cls_masks)
            else:
                xs, ys = mixup(xs, ys)

        if args.debug and i == 0:
            logger.debug(f"input x: {type(xs)}; {xs.shape}, y: {type(ys)}; {ys.shape}")

        with torch.amp.autocast("cuda", enabled=args.amp):
            preds = model(xs)
            preds = _mask_preds(preds, cls_masks)
            if args.multi_label:
                # labels are float for BCEWithLogitsLoss
                ys = ys.float()
            loss = criterion(preds.transpose(1, -1), ys.transpose(1, -1) if len(ys.shape) > 1 else ys) + (
                model.get_internal_loss() if hasattr(model, "get_internal_loss") else model.module.get_internal_loss()
            )

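        # Abort on a non-finite loss: log which of input/target/output/parameters contain
        # NaNs plus the gradient-norm statistics collected so far, then exit so the job fails loudly.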
        if not isfinite(loss.item()):
            logger.error(f"Got loss value {loss.item()}. Stopping training.")
            logger.info(f"input has nan: {xs.isnan().any().item()}")
            logger.info(f"target has nan: {ys.isnan().any().item()}")
            logger.info(f"output has nan: {preds.isnan().any().item()}")
            for name, param in model.named_parameters():
                if param.isnan().any().item():
                    logger.error(f"parameter {name} has a nan value")
            if len(grad_norms) > 0:
                grad_norms = torch.Tensor(grad_norms)
                logger.info(
                    f"Gradient norms until now: min={grad_norms.min().item()}, 20th"
                    f" %tile={torch.quantile(grad_norms, .2).item()}, mean={torch.mean(grad_norms)}, 80th"
                    f" %tile={torch.quantile(grad_norms, .8).item()}, max={grad_norms.max()}"
                )
            sys.exit(1)

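        # The scaler (ScalerGradNormReturn, or NoScaler without amp; see utils) runs the
        # backward pass, optional gradient clipping, and the optimizer step, and returns
        # the gradient norm of this iteration for logging.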
        iter_grad_norm = scaler(
            loss,
            optimizer,
            parameters=model.parameters(),
            clip_grad=args.max_grad_norm if args.max_grad_norm > 0.0 else None,
        ).cpu()

        if args.gather_stats_during_training and isfinite(iter_grad_norm):
            grad_norms.append(iter_grad_norm)

        # if args.aug_cutmix:
        #     ys = ys.argmax(dim=-1)  # for accuracy with CutMix, just use the argmax for both
        #
        epoch_loss += loss.item()
        # accuracies = accuracy(preds, ys, topk=topk, dict_key=acc_dict_key, ignore_index=args.ignore_index)
        # for key in accuracies:
        #     epoch_accs[key] += accuracies[key]

    if args.distributed:
        dist.barrier()
    epoch_end = time()

    iterations = n_iters
    # epoch_accs = {key: val / iterations for key, val in epoch_accs.items()}
    epoch_loss = epoch_loss / iterations
    grad_norm_avrg = -1
    inf_grads = iterations - len(grad_norms)
    if len(grad_norms) > 0 and args.gather_stats_during_training:
        grad_norm_max = max(grad_norms)
        grad_norms = torch.Tensor(grad_norms)
        grad_norm_20 = torch.quantile(grad_norms, 0.2).item()
        grad_norm_80 = torch.quantile(grad_norms, 0.8).item()
        grad_norm_avrg = torch.mean(grad_norms)

    if args.distributed:
        # grad norm is already synchronized
        # gather_tensor = torch.Tensor([epoch_loss, *[epoch_accs[acc_dict_key.format(k)] for k in topk]]).to(device)
        gather_tensor = torch.Tensor([epoch_loss]).to(device)
        dist.barrier()
        dist.all_reduce(gather_tensor, op=dist.ReduceOp.AVG)
        # gather_tensor = (gather_tensor / world_size).tolist()
        epoch_loss = gather_tensor.item()
        # for i, k in enumerate(topk):
        #     epoch_accs[acc_dict_key.format(k)] = gather_tensor[i + 1]

lr = optimizer.param_groups[0]["lr"]
|
|
epoch_accs["train/lr"] = lr
|
|
epoch_accs["train/loss"] = epoch_loss
|
|
|
|
if rank == 0:
|
|
if args.gather_stats_during_training:
|
|
print_s = f"train/time={epoch_end - epoch_start}s"
|
|
logger.info(print_s)
|
|
if len(grad_norms) > 0:
|
|
logger.info(
|
|
f"grad norm avrg={grad_norm_avrg}, grad norm max={grad_norm_max}, "
|
|
f"inf grad norm={inf_grads}, grad norm 20%={grad_norm_20}, grad norm 80%={grad_norm_80}"
|
|
)
|
|
else:
|
|
logger.warning(f"inf grad norm={inf_grads}")
|
|
logger.error("100% of update steps with infinite grad norms!")
|
|
else:
|
|
logger.info(f"train/time={epoch_end - epoch_start}s")
|
|
|
|
if scheduler:
|
|
if isinstance(scheduler, optim.lr_scheduler.LambdaLR):
|
|
scheduler.step()
|
|
else:
|
|
scheduler.step(epoch)
|
|
|
|
if args.gather_stats_during_training:
|
|
return epoch_end - epoch_start, epoch_accs
|
|
return epoch_end - epoch_start, {}
|