AAAI Version
This commit is contained in:
834
AAAI Supplementary Material/Model Training Code/engine.py
Normal file
834
AAAI Supplementary Material/Model Training Code/engine.py
Normal file
@@ -0,0 +1,834 @@
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
from math import isfinite
|
||||
from time import time
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from loguru import logger
|
||||
from timm.data import Mixup
|
||||
from timm.optim import create_optimizer
|
||||
from timm.scheduler import create_scheduler
|
||||
from torch import distributed as dist
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from metrics import calculate_metrics, per_class_counts
|
||||
from utils import (
|
||||
NoScaler,
|
||||
ScalerGradNormReturn,
|
||||
SchedulerArgs,
|
||||
log_formatter,
|
||||
save_model_state,
|
||||
)
|
||||
|
||||
# Optional-dependency guards: probe for accelerator/tracking libraries at import
# time and record availability flags so later code can fall back gracefully
# (see setup_model_optim_sched_scaler's optimizer fallbacks).
try:
    from apex.optimizers import FusedLAMB  # noqa: F401

    apex_available = True
except ImportError:
    logger.error("Nvidia apex not available")
    apex_available = False
try:
    from lion_pytorch import Lion

    lion_available = True
except ImportError:
    logger.error("Lion not available")
    lion_available = False


# Module-global flag; may be further narrowed at runtime by
# setup_tracking_and_logging (api key / CLI flag checks) or wandb_available().
WANDB_AVAILABLE = False
try:
    import wandb

    WANDB_AVAILABLE = True
except ImportError:
    logger.error("wandb not available")
||||
|
||||
def wandb_available(turn_off=False):
    """Report whether wandb logging is currently enabled for this module.

    Args:
        turn_off (bool, optional): when True, permanently disable wandb for
            the rest of the process before reporting.

    Returns:
        bool: True if wandb can be used.
    """
    global WANDB_AVAILABLE
    # An explicit turn-off is sticky: it rewrites the module-level flag.
    WANDB_AVAILABLE = False if turn_off else WANDB_AVAILABLE
    return WANDB_AVAILABLE
|
||||
|
||||
|
||||
tqdm = partial(tqdm, leave=True, position=0) # noqa: F405
|
||||
|
||||
|
||||
def setup_tracking_and_logging(args, rank, append_model_path=None, log_wandb=True):
    """Set up logging and tracking for an experiment.

    Builds (or reuses) a per-run output folder, configures loguru sinks to
    stderr and a run-local log file, and optionally initializes a wandb run.

    Args:
        args (DotDict): Parsed command-line arguments.
        rank (int): The rank of the current process.
        append_model_path (str, optional): Path of an existing model checkpoint;
            when given, its directory is reused as the run folder. By default None.
        log_wandb (bool, optional): Whether to log to wandb, by default True.

    Returns:
        str: folder where all the run data is saved.

    Raises:
        AssertionError: If `dataset` or `model` is `None`, or `args.run_name`
            contains '%'.

    Notes:
        Sets up logger to stdout and file, as well as wandb tracking for an
        experiment. For the wandb logger, provide a `.wandb.apikey` file in the
        current directory. Mutates the module-global WANDB_AVAILABLE and may
        write `args.run_name` / `args["wandb_run_id"]` back into `args`.
    """
    dataset, model, epochs = args.dataset.replace(os.sep, "_").lower(), args.model.replace(os.sep, "_"), args.epochs
    _base_folder = (
        os.path.join(args.results_folder, args.experiment_name, args.task.replace("-", ""), dataset)
        if args.out_dir is None
        else args.out_dir
    )
    # Run folder name embeds run name, model, and a timestamp for uniqueness.
    run_folder = os.path.join(
        _base_folder,
        f"{args.run_name.replace(os.sep, '_')}_{model}_{datetime.now().strftime('%d.%m.%Y_%H:%M:%S')}",
    )
    assert dataset is not None and model is not None

    # Windows forbids some characters (e.g. ':') in paths.
    if os.name == "nt":
        run_folder = run_folder.replace("@", "_").replace(" ", "_").replace(":", ".")

    if append_model_path is not None:
        # Resuming: reuse the checkpoint's directory instead of a fresh folder.
        run_folder = os.path.dirname(append_model_path)
        if "run_name" not in args or args.run_name is None:
            args.run_name = run_folder.split(os.sep)[-1].split("_")[0]
    elif args.distributed:
        # All ranks must agree on the (timestamped) folder: rank 0 decides and
        # broadcasts, then creates it before the barrier releases the others.
        obj_list = [None]
        if rank == 0:
            obj_list[0] = run_folder
        dist.broadcast_object_list(obj_list, src=0)
        run_folder = obj_list[0]
        if rank == 0:
            os.makedirs(run_folder, exist_ok=True)
        dist.barrier()
    elif rank == 0:
        os.makedirs(run_folder, exist_ok=True)

    assert "%" not in args.run_name, f"found '%' in run_name '{args.run_name}'. This messes with string formatting..."

    if args.debug:
        args.log_level = "debug"

    # logger to stdout & file
    log_name = args.task.replace("-", "")
    if args.task not in ["pre-train", "fine-tune", "fine-tune-head"]:
        log_name += f"_{dataset}_{datetime.now().strftime('%d.%m.%Y_%H:%M:%S')}"
    log_file = os.path.join(run_folder, f"{log_name}.log")
    logger.remove()
    logger.configure(extra=dict(run_name=args.run_name, rank=rank, world_size=args.world_size))
    logger.add(sys.stderr, format=log_formatter, enqueue=True, colorize=True, level=args.log_level.upper())
    logger.add(log_file, format=log_formatter, enqueue=True, colorize=True, level=args.log_level.upper())
    logger.info(f"Run folder '{run_folder}'")

    if rank == 0:
        logger.info(f"{args.task.replace('-', '').capitalize()} {model} on {dataset} for {epochs} epochs")

    # wandb is only used when the import succeeded AND the caller allows it AND
    # an api key file exists AND the CLI flag is set.
    global WANDB_AVAILABLE
    WANDB_AVAILABLE = WANDB_AVAILABLE and log_wandb and os.path.isfile(".wandb.apikey") and args.wandb
    if WANDB_AVAILABLE:
        with open(".wandb.apikey", "r") as f:
            __wandb_api_key = f.read().strip()
        wandb.login(key=__wandb_api_key)
        if args.wandb_run_id is not None:
            # Resume an existing wandb run.
            wandb_args = dict(project=args.experiment_name, resume="must", id=args.wandb_run_id)
        else:
            wandb_args = dict(
                project=args.experiment_name,
                name=args.run_name.replace("_", "-").replace(" ", "-"),
                config={"logfile": log_file, **dict(args)},
                job_type=args.task,
                tags=[model, dataset],
                resume="allow",
                id=args.wandb_run_id,
            )
        wandb.init(**wandb_args)
        # Persist the run id so a later resume can reattach to this run.
        args["wandb_run_id"] = wandb.run.id
        if rank == 0:
            logger.info(f"wandb run initialized with id {args['wandb_run_id']}.")
    else:
        logger.info(
            f"Not using wandb. (args.wandb={args.wandb}, .wandb.apikey exists={os.path.isfile('.wandb.apikey')},"
            f" function declaration log_wandb={log_wandb})"
        )

    if args.distributed:
        dist.barrier()

    if args.debug:
        torch.autograd.set_detect_anomaly(True)
        logger.warning("torch.autograd anomaly detection enabled. Will slow down model.")

    return run_folder
|
||||
|
||||
|
||||
def setup_model_optim_sched_scaler(model, device, epochs, args, head_only=False):
    """Set up model, optimizer, and scheduler with automatic mixed precision (amp) and distributed data parallel (DDP).

    Args:
        model (nn.Module): the loaded model
        device (torch.device): the current device
        epochs (int): total number of epochs to learn for (for scheduler).
            NOTE(review): currently unused — the scheduler reads ``args.epochs``
            instead; kept for interface compatibility.
        args: further arguments (opt, amp, distributed, compile_model,
            sched/min_lr/warmup_lr/warmup_epochs, lr, weight_decay)
        head_only (bool, optional): train only the linear head. Default: False

    Returns:
        tuple[nn.Module, optim.Optimizer, optim.lr_scheduler._LRScheduler, ScalerGradNormReturn]: model, optimizer, scheduler, scaler
    """
    model = model.to(device)

    if head_only:
        # Freeze everything except head parameters. A single pass over
        # named_parameters() suffices: the previous implementation first set
        # all params False, then model.head params True, then overwrote every
        # flag again with exactly this name-based rule — the first two loops
        # were dead code with identical net effect.
        for name, param in model.named_parameters():
            param.requires_grad = "head" in name
        params = model.head.parameters()
    else:
        params = model  # model.named_parameters() use model itself for now and let timm do the work...

    # Optimizer fallbacks when optional dependencies are missing:
    # lion -> fusedlamb -> adamw.
    if args.opt == "lion" and not lion_available:
        args.opt = "fusedlamb"
        logger.warning("Falling back from lion to fusedlamb")
    if args.opt == "fusedlamb" and not apex_available:
        args.opt = "adamw"
        logger.warning("Falling back from fusedlamb to adamw")
    if args.opt == "lion":
        optimizer = Lion(params, lr=args["lr"], weight_decay=args["weight_decay"])
    else:
        optimizer = create_optimizer(args, params)

    # NoScaler is a no-op stand-in so callers can use one code path either way.
    scaler = ScalerGradNormReturn() if args.amp else NoScaler()

    # if args.model_ema:
    #     # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
    #     ema_model = ModelEma(model, decay=args.model_ema_decay, resume='')

    if args.distributed:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = DDP(model, device_ids=[device])

    if args.compile_model:
        model = torch.compile(model)

    sched_args = SchedulerArgs(args.sched, args.epochs, args.min_lr, args.warmup_lr, args.warmup_epochs)
    scheduler, _ = create_scheduler(sched_args, optimizer)

    return model, optimizer, scheduler, scaler
|
||||
|
||||
|
||||
def setup_criteria_mixup(args, dataset=None, **criterion_kwargs):
    """Set up the training/validation criteria and the mixup augmenter.

    Args:
        args: arguments (loss, loss_weight, n_classes, ignore_index,
            label_smoothing, multi_label, aug_cutmix, aug_mixup_alpha,
            aug_cutmix_alpha)
        dataset (torch.data.Dataset, optional): dataset that implements
            images_per_class, for class weights (Default value = None)
        **criterion_kwargs: further keyword arguments forwarded to the criterion

    Returns:
        tuple[nn.Module, nn.Module, Mixup]: criterion, val_criterion, mixup
            (mixup is None when both mixup/cutmix alphas are 0)

    Raises:
        NotImplementedError: for unsupported `args.loss` values.
    """
    # Optional per-class loss weights derived from the class frequency.
    weight = None
    if args.loss_weight != "none":
        if dataset is not None and hasattr(dataset, "images_per_class"):
            ipc = dataset.images_per_class
            total_ims = sum(ipc)

            if args.loss_weight == "linear":
                # Inverse-frequency weights, normalized by class count.
                weight = torch.tensor([total_ims / (ims * args.n_classes) for ims in ipc])
            elif args.loss_weight == "log":
                # Weights proportional to -log p_c, normalized by the entropy.
                p_c = torch.tensor([ims / total_ims for ims in ipc])
                log_p_c = torch.where(p_c > 0, p_c.log(), torch.zeros_like(p_c))
                entr = -(p_c * log_p_c).sum()
                weight = -log_p_c / entr
            elif args.loss_weight == "sqrt":
                # Inverse square-root frequency weights.
                p_c = torch.tensor([ims / total_ims for ims in ipc])
                weight = 1 / (p_c.sqrt() * p_c.sqrt().sum())

        else:
            logger.warning("Could not find images_per_class in dataset. Using uniform weights.")

    if args.aug_cutmix or args.multi_label:
        # criterion = SoftTargetCrossEntropy()
        # With soft targets, ignore_index cannot be used directly; instead
        # zero out that class's weight.
        if args.ignore_index >= 0:
            if weight is None:
                weight = torch.ones(args.n_classes)
            weight[args.ignore_index] = 0
        if args.multi_label:
            if args.loss == "ce":
                criterion = nn.BCEWithLogitsLoss(pos_weight=weight, **criterion_kwargs)
                val_criterion = nn.BCEWithLogitsLoss(pos_weight=weight, **criterion_kwargs)
            else:
                raise NotImplementedError(
                    f"Only BCEWithLogitsLoss (ce) is implemented for multi-label classification, not {args.loss}."
                )
        else:
            # NOTE(review): BaikalLoss is not imported in this file's visible
            # imports — confirm it is defined/imported elsewhere in the project.
            if args.loss == "ce":
                loss_cls = nn.CrossEntropyLoss
            elif args.loss == "baikal":
                loss_cls = BaikalLoss
            else:
                raise NotImplementedError(f"'{args.loss}'-loss is not implemented.")
            # No label_smoothing here — presumably mixup/cutmix already
            # produce soft targets; TODO confirm.
            criterion = loss_cls(weight=weight, **criterion_kwargs)
            val_criterion = loss_cls(weight=weight, **criterion_kwargs)
    else:
        if args.loss == "ce":
            loss_cls = nn.CrossEntropyLoss
        elif args.loss == "baikal":
            loss_cls = BaikalLoss
        else:
            raise NotImplementedError(f"'{args.loss}'-loss is not implemented.")
        # ignore_index and weight are mutually exclusive here: once weights are
        # used, ignoring is handled by a zero weight, so pass the default -100.
        criterion = loss_cls(
            label_smoothing=args.label_smoothing,
            ignore_index=args.ignore_index if weight is None else -100,
            weight=weight,
            **criterion_kwargs,
        )  # LabelSmoothingCrossEntropy(smoothing=args.label_smoothing)
        # criterion = LabelSmoothingCrossEntropy(smoothing=args.label_smoothing)
        val_criterion = loss_cls(
            label_smoothing=args.label_smoothing,
            ignore_index=args.ignore_index if weight is None else -100,
            weight=weight,
            **criterion_kwargs,
        )  # LabelSmoothingCrossEntropy(smoothing=0.)

    mixup_kwargs = dict(
        mixup_alpha=args.aug_mixup_alpha,
        cutmix_alpha=args.aug_cutmix_alpha,
        label_smoothing=args.label_smoothing,
        num_classes=args.n_classes,
    )
    # Only construct Mixup when at least one augmentation alpha is non-zero.
    mixup = Mixup(**mixup_kwargs) if abs(args.aug_cutmix_alpha) + abs(args.aug_mixup_alpha) > 0.0 else None

    return criterion, val_criterion, mixup
|
||||
|
||||
|
||||
def _train(
    model,
    train_loader,
    optimizer,
    rank,
    epochs,
    device,
    mixup,
    criterion,
    world_size,
    scheduler,
    args,
    val_loader,
    val_criterion,
    model_folder,
    scaler,
    do_metrics_calculation=True,
    start_epoch=0,
    show_tqdm=True,
    topk=(1, 5),
    acc_dict_key=None,
    train_dali_server=None,
    val_dali_server=None,
):
    """Run the full training loop: per-epoch training + validation, logging,
    checkpointing, and an optional final efficiency-metrics pass.

    Args:
        model: model to train (possibly DDP-wrapped)
        train_loader: training DataLoader
        optimizer: optimizer instance
        rank: this process's rank (rank 0 does the logging/saving)
        epochs: total number of epochs
        device: device to train on
        mixup: timm Mixup instance or None
        criterion: training loss
        world_size: number of distributed processes
        scheduler: LR scheduler or None
        args: further arguments
        val_loader: validation DataLoader
        val_criterion: validation loss
        model_folder: folder checkpoints are written to
        scaler: grad scaler (ScalerGradNormReturn or NoScaler)
        do_metrics_calculation: run efficiency metrics at the end (Default value = True)
        start_epoch: first epoch index, for resuming (Default value = 0)
        show_tqdm: show progress bars on rank 0 (Default value = True)
        topk: top-k accuracies to compute (Default value = (1, 5))
        acc_dict_key: format string for accuracy keys (Default value = None -> "acc{}")
        train_dali_server: optional DALI server started/stopped around each
            training epoch — semantics assumed from start_thread/stop_thread calls
        val_dali_server: optional DALI server forwarded to _evaluate

    Returns:
        dict: evaluation metrics at the end of training
    """
    if acc_dict_key is None:
        acc_dict_key = "acc{}"
    training_start = time()
    # Drop k values that exceed the number of classes.
    topk = tuple(k for k in topk if k <= args.n_classes)
    time_spend_training = time_spend_validating = 0
    current_best_acc = 0.0
    if rank == 0:
        logger.info(f"Dataloader has {len(train_loader)} batches")

    logger.debug("Starting training with the following settings:")
    logger.debug(f"criterion: {criterion}")
    logger.debug(f"train_loader: {train_loader}, sampler: {train_loader.sampler}")
    logger.debug(f"dataset: {train_loader.dataset}")
    logger.debug(f"optimizer: {optimizer}")
    logger.debug(f"device: {device}")
    logger.debug(f"start epoch: {start_epoch}, epochs: {epochs}")
    logger.debug(f"scaler: {scaler}")
    logger.debug(f"max_grad_norm: {args.max_grad_norm}")
    # logger.debug(f"model_ema:\n{model_ema}\n{model_ema.decay}\n{model_ema.device}")
    if mixup:
        logger.debug(
            f"mixup: {mixup}; mixup_alpha: {mixup.mixup_alpha}, cutmix_alpha: {mixup.cutmix_alpha},"
            f" cutmix_minmax: {mixup.cutmix_minmax}, prob: {mixup.mix_prob}, switch_prob: {mixup.switch_prob},"
            f" label_smoothing: {mixup.label_smoothing}, num_classes: {mixup.num_classes}, correct_lam:"
            f" {mixup.correct_lam}, mixup_enabled: {mixup.mixup_enabled}"
        )
    else:
        logger.debug(f"mixup: {mixup}")

    for epoch in range(start_epoch, epochs):
        # Tag all log lines of this epoch with the (1-based) epoch number.
        with logger.contextualize(epoch=str(epoch + 1)):
            if args.distributed:
                # Reshuffle the distributed sampler each epoch.
                train_loader.sampler.set_epoch(epoch)

            # Datasets may expose set_epoch for epoch-dependent behavior.
            set_ep_func = getattr(train_loader.dataset, "set_epoch", None)
            if callable(set_ep_func):
                train_loader.dataset.set_epoch(epoch)
                val_loader.dataset.set_epoch(epoch)

            if train_dali_server:
                train_dali_server.start_thread()
                logger.info("started train dali server")

            epoch_time, epoch_stats = _train_one_epoch(
                model,
                train_loader,
                optimizer,
                rank,
                epoch,
                device,
                mixup,
                criterion,
                scheduler,
                scaler,
                args,
                topk,
                "train/" + acc_dict_key,
                show_tqdm,
            )
            time_spend_training += epoch_time

            if train_dali_server:
                train_dali_server.stop_thread()

            val_time, val_stats = _evaluate(
                model,
                val_loader,
                epoch,
                rank,
                device,
                val_criterion,
                args,
                topk,
                "val/" + acc_dict_key,
                dali_server=val_dali_server,
            )
            time_spend_validating += val_time

            if rank == 0:
                logger.info(f"total_time={time() - training_start}s")

            if rank == 0:
                top1_val_acc = val_stats["val/" + acc_dict_key.format(1)]
                # print metadata for grafana
                metadata = {
                    "epoch": epoch + 1,
                    "progress": (epoch + 1) / args.epochs,
                    **val_stats,
                    **epoch_stats,
                }
                # filter out Nan and infinity values
                metadata = {k: v for k, v in metadata.items() if isfinite(v)}
                print(json.dumps(metadata), flush=True)
                logger.debug(f"printed metadata: {json.dumps(metadata)}")
                if WANDB_AVAILABLE:
                    wandb.log(metadata, step=epoch + 1)

                # saving current state: either a new best model or a periodic save
                if top1_val_acc > current_best_acc or (epoch + 1) % args.save_epochs == 0:
                    reason = "top" if top1_val_acc > current_best_acc else ""  # min(...) will be the top-1 accuracy
                    if reason == "top":
                        current_best_acc = top1_val_acc
                        logger.info(f"found a new best model with acc: {current_best_acc}")
                    kwargs = dict(
                        model_state=model.state_dict(),
                        stats=metadata,
                        optimizer_state=optimizer.state_dict(),
                        additional_reason=reason,
                        regular_save=(epoch + 1) % args.save_epochs == 0,
                    )
                    if scheduler:
                        kwargs["scheduler_state"] = scheduler.state_dict()
                    save_model_state(
                        model_folder, epoch + 1, args, **kwargs, max_interm_ep_states=args.keep_interm_states
                    )

    if rank == 0:
        end_time = time()
        logger.info(
            f"training done: total time={end_time - training_start}, "
            f"time spend training={time_spend_training}, "
            f"time spend validating={time_spend_validating}"
        )

    # Stats of the last epoch plus the best observed top-1 validation accuracy.
    results = {**val_stats, **epoch_stats, f"val/best_{acc_dict_key.format(1)}": current_best_acc}

    if rank == 0:
        # Always write a final checkpoint regardless of accuracy.
        save_model_state(
            model_folder,
            epoch + 1,
            args,
            model_state=model.state_dict(),
            stats=results,
            additional_reason="final",
            regular_save=False,
            max_interm_ep_states=args.keep_interm_states,
        )

    if do_metrics_calculation:
        # Calculate efficiency metrics on one real training batch.
        inp = next(iter(train_loader))[0].to(device)
        metrics = calculate_metrics(
            args,
            model,
            rank=rank,
            input=inp,
            device=device,
            did_training=True,
            all_metrics=False,
            world_size=world_size,
            key_start="eval/",
        )

        if rank == 0:
            logger.info(f"Efficiency metrics: {json.dumps(metrics)}")
    return results
|
||||
|
||||
|
||||
def _mask_preds(preds, cls_masks, mask_val=-100):
|
||||
"""Mask the predictions by the mask.
|
||||
|
||||
Args:
|
||||
preds: model predictions
|
||||
cls_masks: class masks
|
||||
mask_val: (Default value = -100)
|
||||
|
||||
Returns:
|
||||
torch.Tensor: masked predictions
|
||||
|
||||
"""
|
||||
if cls_masks is None:
|
||||
return preds
|
||||
return torch.where(cls_masks.bool(), mask_val, preds)
|
||||
|
||||
|
||||
def _evaluate(
    model,
    val_loader,
    epoch,
    rank,
    device,
    val_criterion,
    args,
    topk=(1, 5),
    acc_dict_key=None,
    dali_server=None,
):
    """Evaluate the model.

    Args:
        model (nn.Module): the model to evaluate
        val_loader (DataLoader): loader for evaluation data
        epoch (int): the current epoch (for logger & tracking)
        rank (int): this processes rank (don't log n times)
        device (torch.device): device to evaluate on
        val_criterion (nn.Module): validation loss
        args (DotDict): further arguments
        topk (tuple[int], optional): top-k accuracy, by default (1, 5)
        acc_dict_key (str, optional): format string for accuracy keys,
            by default "acc{}"
        dali_server: optional DALI server started before and stopped after the
            pass — semantics assumed from start_thread/stop_thread calls

    Returns:
        tuple[float, dict]: validation wall time in seconds, and a dict of
            accuracies / mean per-class accuracies / "val/loss"
    """
    if not acc_dict_key:
        acc_dict_key = "acc{}"

    if dali_server:
        dali_server.start_thread()
    # Drop k values that exceed the number of classes.
    topk = tuple(k for k in topk if k <= args.n_classes)
    model.eval()
    val_loss = 0
    val_accs = {acc_dict_key.format(k): 0.0 for k in topk}
    val_start = time()
    n_iters = 0
    iterator = (
        tqdm(val_loader, total=len(val_loader), desc=f"Validating epoch {epoch + 1}")
        if rank == 0 and args.tqdm
        else val_loader
    )
    # class_counts[k_idx, class, 0] = correct, [..., 1] = wrong (per top-k).
    class_counts = torch.zeros(1 if isinstance(topk, int) else len(topk), args.n_classes, 2)
    for batch_data in iterator:
        xs, ys = batch_data[:2]
        # Optional third element is a per-sample class mask.
        cls_masks = batch_data[2].to(device, non_blocking=True) if len(batch_data) == 3 else None

        if args.debug:
            logger.debug(f"y_max = {ys.max()}, y_min = {ys.min()}, num_classes={args.n_classes}")

        xs, ys = xs.to(device, non_blocking=True), ys.to(device, non_blocking=True)
        with torch.no_grad(), torch.amp.autocast("cuda", enabled=args.eval_amp):
            preds = model(xs)
            preds = _mask_preds(preds, cls_masks)

        if args.multi_label:
            # labels are float for BCELoss
            ys = ys.float()
        # Class dim is moved to position 1 as the criterion expects (N, C, ...).
        val_loss += val_criterion(preds.transpose(1, -1), ys.transpose(1, -1) if len(ys.shape) > 1 else ys).item()
        class_counts += per_class_counts(preds, ys, args.n_classes, topk=topk, ignore_index=args.ignore_index)
        n_iters += 1

    if args.distributed:
        dist.barrier()

    if dali_server:
        dali_server.stop_thread()
    val_end = time()
    # NOTE(review): `iterations` is never used and val_loss is reported as a
    # SUM over batches (unlike the per-batch average in _train_one_epoch) —
    # confirm this asymmetry is intended.
    iterations = n_iters

    if args.distributed:
        # Average the loss and sum the per-class counts across ranks.
        gather_tensor = torch.Tensor([val_loss]).to(device)
        dist.barrier()
        dist.all_reduce(gather_tensor, op=dist.ReduceOp.AVG)
        gather_tensor = gather_tensor.tolist()
        val_loss = gather_tensor[0]
        class_counts = class_counts.to(device)
        dist.all_reduce(class_counts, op=dist.ReduceOp.SUM)
        class_counts = class_counts.cpu()

    for i, k in enumerate(topk):
        key = acc_dict_key.format(k)
        mkey = key.replace("acc", "m-acc")
        # Overall accuracy: total correct / total samples.
        val_accs[key] = class_counts[i].sum(dim=0)[0].item() / class_counts[i].sum(dim=0).sum(dim=-1).item()
        # Mean per-class accuracy (macro average).
        val_accs[mkey] = (class_counts[i, :, 0] / class_counts[i].sum(dim=-1)).mean().item()

    val_accs["val/loss"] = val_loss

    if rank == 0:
        log_s = f"val/time={val_end - val_start}s"
        for key, val in val_accs.items():
            log_s += f", {key}={val:.4f}"
        logger.info(log_s)

    return val_end - val_start, val_accs
|
||||
|
||||
|
||||
def _train_one_epoch(
    model,
    train_loader,
    optimizer,
    rank,
    epoch,
    device,
    mixup,
    criterion,
    scheduler,
    scaler,
    args,
    topk=(1, 5),
    acc_dict_key=None,
    show_tqdm=True,
):
    """Train the model for one epoch.

    Args:
        model: model in training mode (possibly DDP-wrapped)
        train_loader: training DataLoader
        optimizer: optimizer instance
        rank: this process's rank (rank 0 shows tqdm / logs)
        epoch: current epoch index (0-based)
        device: device to train on
        mixup: timm Mixup instance or None
        criterion: training loss
        scheduler: LR scheduler or None; stepped once at the end of the epoch
        scaler: grad scaler; called as scaler(loss, optimizer, ...) and returns
            the gradient norm of the step
        args: further arguments
        topk: top-k accuracies (Default value = (1, 5)) — currently only used
            by the commented-out accuracy code
        acc_dict_key: format string for stat keys (Default value = None -> "acc{}")
        show_tqdm: show a progress bar on rank 0 (Default value = True)

    Returns:
        tuple[float, dict]: time spent in the epoch, and the epoch stats
            ("train/lr", "train/loss") — an empty dict when
            args.gather_stats_during_training is off
    """
    if not acc_dict_key:
        acc_dict_key = "acc{}"

    model.train()
    iterator = (
        tqdm(train_loader, total=len(train_loader), desc=f"Training epoch {epoch + 1}")
        if rank == 0 and show_tqdm
        else train_loader
    )

    # Without AMP, force the no-op scaler regardless of what was passed in.
    if not args.amp:
        scaler = NoScaler()

    epoch_loss = 0
    epoch_accs = {}
    epoch_start = time()
    grad_norms = []
    n_iters = 0
    # Some datasets track the epoch via an attribute rather than set_epoch().
    if hasattr(train_loader.dataset, "epoch"):
        train_loader.dataset.epoch = epoch
    for i, batch_data in enumerate(iterator):
        xs, ys = batch_data[:2]
        # Optional third element is a per-sample class mask.
        cls_masks = batch_data[2].to(device, non_blocking=True) if len(batch_data) == 3 else None
        optimizer.zero_grad()
        n_iters += 1
        xs = xs.to(device, non_blocking=True)
        ys = ys.to(device, non_blocking=True)

        if args.debug and i == 0:
            logger.debug(f"y_max = {ys.max()}, y_min = {ys.min()}, num_classes={args.n_classes}")

        if mixup:
            if args.multi_label:
                xs, ys = mixup(xs, ys, cls_masks)
            else:
                xs, ys = mixup(xs, ys)

        if args.debug and i == 0:
            logger.debug(f"input x: {type(xs)}; {xs.shape}, y: {type(ys)}; {ys.shape}")

        with torch.amp.autocast("cuda", enabled=args.amp):
            preds = model(xs)
            preds = _mask_preds(preds, cls_masks)
            if args.multi_label:
                # labels are float for BCELoss
                ys = ys.float()
            # Class dim moved to position 1 as the criterion expects (N, C, ...);
            # get_internal_loss lives on the raw model or under DDP's .module.
            loss = criterion(preds.transpose(1, -1), ys.transpose(1, -1) if len(ys.shape) > 1 else ys) + (
                model.get_internal_loss() if hasattr(model, "get_internal_loss") else model.module.get_internal_loss()
            )

        # Abort on NaN/inf loss, after dumping as much diagnostics as possible.
        if not isfinite(loss.item()):
            logger.error(f"Got loss value {loss.item()}. Stopping training.")
            logger.info(f"input has nan: {xs.isnan().any().item()}")
            logger.info(f"target has nan: {ys.isnan().any().item()}")
            logger.info(f"output has nan: {preds.isnan().any().item()}")
            for name, param in model.named_parameters():
                if param.isnan().any().item():
                    logger.error(f"parameter {name} has a nan value")
            if len(grad_norms) > 0:
                grad_norms = torch.Tensor(grad_norms)
                logger.info(
                    f"Gradient norms until now: min={grad_norms.min().item()}, 20th"
                    f" %tile={torch.quantile(grad_norms, .2).item()}, mean={torch.mean(grad_norms)}, 80th"
                    f" %tile={torch.quantile(grad_norms, .8).item()}, max={grad_norms.max()}"
                )
            sys.exit(1)

        # Backward + clip + optimizer step, all inside the scaler; returns the
        # gradient norm of this step.
        iter_grad_norm = scaler(
            loss,
            optimizer,
            parameters=model.parameters(),
            clip_grad=args.max_grad_norm if args.max_grad_norm > 0.0 else None,
        ).cpu()

        if args.gather_stats_during_training and isfinite(iter_grad_norm):
            grad_norms.append(iter_grad_norm)

        # if args.aug_cutmix:
        #     ys = ys.argmax(dim=-1)  # for accuracy with CutMix, just use the argmax for both
        #
        epoch_loss += loss.item()
        # accuracies = accuracy(preds, ys, topk=topk, dict_key=acc_dict_key, ignore_index=args.ignore_index)
        # for key in accuracies:
        #     epoch_accs[key] += accuracies[key]

    if args.distributed:
        dist.barrier()
    epoch_end = time()

    iterations = n_iters
    # epoch_accs = {key: val / iterations for key, val in epoch_accs.items()}
    epoch_loss = epoch_loss / iterations
    grad_norm_avrg = -1
    # Steps whose grad norm was non-finite (or not gathered).
    inf_grads = iterations - len(grad_norms)
    if len(grad_norms) > 0 and args.gather_stats_during_training:
        grad_norm_max = max(grad_norms)
        grad_norms = torch.Tensor(grad_norms)
        grad_norm_20 = torch.quantile(grad_norms, 0.2).item()
        grad_norm_80 = torch.quantile(grad_norms, 0.8).item()
        grad_norm_avrg = torch.mean(grad_norms)

    if args.distributed:
        # grad norm is already synchronized
        # gather_tensor = torch.Tensor([epoch_loss, *[epoch_accs[acc_dict_key.format(k)] for k in topk]]).to(device)
        gather_tensor = torch.Tensor([epoch_loss]).to(device)
        dist.barrier()
        dist.all_reduce(gather_tensor, op=dist.ReduceOp.AVG)
        # gather_tensor = (gather_tensor / world_size).tolist()
        epoch_loss = gather_tensor.item()
        # for i, k in enumerate(topk):
        #     epoch_accs[acc_dict_key.format(k)] = gather_tensor[i + 1]

    lr = optimizer.param_groups[0]["lr"]
    epoch_accs["train/lr"] = lr
    epoch_accs["train/loss"] = epoch_loss

    if rank == 0:
        if args.gather_stats_during_training:
            print_s = f"train/time={epoch_end - epoch_start}s"
            logger.info(print_s)
            if len(grad_norms) > 0:
                logger.info(
                    f"grad norm avrg={grad_norm_avrg}, grad norm max={grad_norm_max}, "
                    f"inf grad norm={inf_grads}, grad norm 20%={grad_norm_20}, grad norm 80%={grad_norm_80}"
                )
            else:
                # Every single step produced a non-finite gradient norm.
                logger.warning(f"inf grad norm={inf_grads}")
                logger.error("100% of update steps with infinite grad norms!")
        else:
            logger.info(f"train/time={epoch_end - epoch_start}s")

    if scheduler:
        # timm schedulers take the epoch; torch LambdaLR steps implicitly.
        if isinstance(scheduler, optim.lr_scheduler.LambdaLR):
            scheduler.step()
        else:
            scheduler.step(epoch)

    if args.gather_stats_during_training:
        return epoch_end - epoch_start, epoch_accs
    return epoch_end - epoch_start, {}
|
||||
Reference in New Issue
Block a user