# ForAug / AAAI Supplementary Material / Model Training Code / train.py
import os
import random
import sys

import numpy as np
import timm
import torch
from loguru import logger

from engine import (
    _train,
    setup_criteria_mixup,
    setup_model_optim_sched_scaler,
    setup_tracking_and_logging,
    wandb_available,
)
from load_dataset import prepare_dataset
from models import load_pretrained, prepare_model
from utils import ddp_cleanup, ddp_setup, log_args, prep_kwargs, set_filter_warnings


def finetune(model, dataset, epochs, val_dataset=None, head_only=False, **kwargs):
    """Finetune a pretrained model on a given dataset for a specified number of epochs.

    Args:
        model (str): Path to the pretrained model state file (in .tar format).
        dataset (str): Name of the dataset to finetune on.
        epochs (int): Number of epochs to train for.
        val_dataset (str, optional): Name of the validation dataset. Defaults to None.
        head_only (bool, optional): Whether to train only the head of the model. Defaults to False.
        **kwargs (dict): Further arguments for model setup, training, evaluation, etc.

    Notes:
        This function assumes that the model was pretrained on a different dataset using a different
        set of hyperparameters. It fine-tunes the model on the new dataset by loading the pretrained
        weights and training for the specified number of epochs. Distributed training is supported
        via PyTorch's DistributedDataParallel module.
    """
    set_filter_warnings()
    # Add defaults & make keys properties
    args = prep_kwargs(kwargs)
    if val_dataset is None:
        val_dataset = dataset
    args.val_dataset = val_dataset
    args.dataset = dataset
    args.epochs = epochs
    args.distributed, device, world_size, rank, gpu_id = ddp_setup()
    args.world_size = world_size
    try:
        torch.cuda.set_device(device)
    except RuntimeError as e:
        logger.error(
            f"Could not set device {device} as current device; "
            f"environment parameters: RANK={os.environ.get('RANK')}, LOCAL_RANK={os.environ.get('LOCAL_RANK')}, "
            f"WORLD_SIZE={os.environ.get('WORLD_SIZE')}; SLURM_STEP_GPUS={os.environ.get('SLURM_STEP_GPUS')}, "
            f"GPU_DEVICE_ORDINAL={os.environ.get('GPU_DEVICE_ORDINAL')}, "
            f"CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')}, "
            f"SLURM_STEP_NODELIST={os.environ.get('SLURM_STEP_NODELIST')}, "
            f"SLURMD_NODENAME={os.environ.get('SLURMD_NODENAME')}"
        )
        raise e
    # args.batch_size is the global batch size; each DDP process gets an equal share
    args.batch_size = int(args.batch_size / world_size)
    if args.seed is not None:
        # fix the seed for reproducibility; offset by rank so each process gets a distinct stream
        seed = args.seed + rank
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)
    # get the datasets & dataloaders
    # transform only contains resize & crop here; everything else is handled on the GPU / in the training loop
    train_loader, args.n_classes, args.ignore_index, args.multi_label, train_dali_server = prepare_dataset(
        dataset, args, rank=rank
    )
    val_loader, _val_classes, _, __, val_dali_server = prepare_dataset(val_dataset, args, train=False, rank=rank)
    assert (
        args.n_classes == _val_classes
    ), f"Training and validation datasets have different numbers of classes: {args.n_classes} vs {_val_classes}"
    save_state = torch.load(model, map_location="cpu")
    old_args = prep_kwargs(save_state["args"])
    parent_folder = os.path.dirname(model)
    args.model = old_args.model
    run_folder = setup_tracking_and_logging(args, rank)
    if rank == 0:
        if not os.path.exists(os.path.join(run_folder, "parent_run")):
            os.symlink(parent_folder, os.path.join(run_folder, "parent_run"), target_is_directory=True)
        logger.info(
            f"environment parameters: RANK={os.environ.get('RANK')}, LOCAL_RANK={os.environ.get('LOCAL_RANK')}, "
            f"WORLD_SIZE={os.environ.get('WORLD_SIZE')}; SLURM_STEP_GPUS={os.environ.get('SLURM_STEP_GPUS')}, "
            f"GPU_DEVICE_ORDINAL={os.environ.get('GPU_DEVICE_ORDINAL')}, "
            f"CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')}, "
            f"SLURM_STEP_NODELIST={os.environ.get('SLURM_STEP_NODELIST')}, "
            f"SLURMD_NODENAME={os.environ.get('SLURMD_NODENAME')}"
        )
        if args.seed is not None:
            logger.info(f"setting manual seed '{seed}' (arg: {args.seed} + rank: {rank})")
    model, args, old_args, save_state = load_pretrained(model, args, new_dataset_params=True)
    # model_name = old_args.model
    if rank == 0:
        logger.info(f"The model was pretrained on {old_args.dataset} for {save_state['epoch']} epochs.")
    model, optimizer, scheduler, scaler = setup_model_optim_sched_scaler(
        model, device, epochs, args, head_only=head_only
    )
    # log all devices
    logger.info(f"training on {device} -> {torch.cuda.get_device_name(device) if args.device != 'cpu' else ''}")
    if rank == 0:
        logger.info(f"torch version {torch.__version__}")
        logger.info(f"timm version {timm.__version__}")
        logger.info(f"full set of old arguments: {old_args}")
        log_args(args)
    if args.seed is not None:
        # re-seed before training so data/augmentation randomness does not depend on RNG use during setup
        torch.manual_seed(seed)
    criterion, val_criterion, mixup = setup_criteria_mixup(args)
    if rank == 0:
        logger.info(f"Run info at: '{run_folder}'")
    res = _train(
        model,
        train_loader,
        optimizer,
        rank,
        epochs,
        device,
        mixup,
        criterion,
        world_size,
        scheduler,
        args,
        val_loader,
        val_criterion,
        run_folder,
        scaler,
        do_metrics_calculation=True,
        show_tqdm=args.tqdm,
        train_dali_server=train_dali_server,
        val_dali_server=val_dali_server,
    )
    if rank == 0:
        best_acc_key = sorted([key for key in res.keys() if key.startswith("val/best_")])[0]
        logger.info(
            f"Run '{run_folder.split(os.sep)[-1]}' is done. Top-1 validation accuracy: {res[best_acc_key] * 100:.2f}%"
        )
    ddp_cleanup(args=args, sync_old_wandb=wandb_available(), rank=rank)
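

# Illustrative usage (a sketch, not part of the original script): the checkpoint path,
# dataset name, and keyword arguments below are assumptions; the accepted kwargs are
# determined by prep_kwargs and the setup helpers in engine.py.
#
#   finetune(
#       "runs/pretrain_run/model_state.tar",  # hypothetical path to a pretrained .tar state file
#       dataset="CIFAR100",                   # hypothetical dataset name known to prepare_dataset
#       epochs=30,
#       head_only=True,                       # linear probing: train only the classification head
#       batch_size=256,                       # global batch size, split across DDP processes
#       seed=0,
#   )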


def pretrain(model, dataset, epochs, val_dataset=None, **kwargs):
    """Train or pretrain a model.

    Args:
        model (str): Name of the model to train.
        dataset (str): Name of the dataset to train the model on.
        epochs (int): Number of training epochs.
        val_dataset (str, optional): Name of the validation dataset. Defaults to None.
        **kwargs (dict): Additional keyword arguments.

    Notes:
        This function sets up logging, prepares the model, and trains it on the given dataset.
    """
    set_filter_warnings()
    # Add defaults & make args properties
    args = prep_kwargs(kwargs)
    if val_dataset is None:
        val_dataset = dataset
    args.val_dataset = val_dataset
    args.dataset = dataset
    args.model = model
    args.epochs = epochs
    args.distributed, device, world_size, rank, gpu_id = ddp_setup(args.cuda)
    args.world_size = world_size
    # sleep(rank * 5)
    # logger.debug(f'running environment commands for rank {rank}')
    # os.system('env')
    # os.system('nvidia-smi')
    # sleep((world_size - rank) * 5)
    logger.debug(
        f"rank params: RANK={os.environ.get('RANK')}, LOCAL_RANK={os.environ.get('LOCAL_RANK')}, "
        f"WORLD_SIZE={os.environ.get('WORLD_SIZE')}; gpu params: "
        f"SLURM_STEP_GPUS={os.environ.get('SLURM_STEP_GPUS')}, "
        f"GPU_DEVICE_ORDINAL={os.environ.get('GPU_DEVICE_ORDINAL')}, "
        f"CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')}"
    )
    if args.cuda:
        try:
            torch.cuda.set_device(device)
        except RuntimeError as e:
            logger.error(f"Could not set device {device} as current device: {e}")
            logger.error(
                f"environment parameters: RANK={os.environ.get('RANK')}, LOCAL_RANK={os.environ.get('LOCAL_RANK')}, "
                f"WORLD_SIZE={os.environ.get('WORLD_SIZE')}; SLURM_STEP_GPUS={os.environ.get('SLURM_STEP_GPUS')}, "
                f"GPU_DEVICE_ORDINAL={os.environ.get('GPU_DEVICE_ORDINAL')}, "
                f"CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')}, "
                f"SLURM_STEP_NODELIST={os.environ.get('SLURM_STEP_NODELIST')}, "
                f"SLURMD_NODENAME={os.environ.get('SLURMD_NODENAME')}"
            )
            raise e
    # args.batch_size is the global batch size; each DDP process gets an equal share
    args.batch_size = int(args.batch_size / world_size)
    run_folder = setup_tracking_and_logging(args, rank)
    if rank % world_size == 0:
        logger.info(
            f"environment parameters: RANK={os.environ.get('RANK')}, LOCAL_RANK={os.environ.get('LOCAL_RANK')}, "
            f"WORLD_SIZE={os.environ.get('WORLD_SIZE')}; SLURM_STEP_GPUS={os.environ.get('SLURM_STEP_GPUS')}, "
            f"GPU_DEVICE_ORDINAL={os.environ.get('GPU_DEVICE_ORDINAL')}, "
            f"CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')}, "
            f"SLURM_STEP_NODELIST={os.environ.get('SLURM_STEP_NODELIST')}, "
            f"SLURMD_NODENAME={os.environ.get('SLURMD_NODENAME')}"
        )
    if args.seed is not None:
        # fix the seed for reproducibility; offset by rank so each process gets a distinct stream
        seed = args.seed + rank
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)
        logger.info(f"setting manual seed '{seed}' (arg: {args.seed} + rank: {rank})")
    # get the datasets & dataloaders
    train_loader, args.n_classes, args.ignore_index, args.multi_label, train_dali_server = prepare_dataset(
        dataset, args, rank=rank
    )
    val_loader, _val_classes, _, __, val_dali_server = prepare_dataset(val_dataset, args, train=False, rank=rank)
    assert (
        args.n_classes == _val_classes
    ), f"Training and validation datasets have different numbers of classes: {args.n_classes} vs {_val_classes}"
    # setup model with amp & DDP
    model_name = None
    if isinstance(model, str):
        if model.startswith("ViT") and "_" not in model:
            # append the input image size to ViT model names that do not already encode it
            model += f"_{args.imsize}"
        model_name = model
    model = prepare_model(model, args)
    if not model_name:
        model_name = type(model).__name__
    model, optimizer, scheduler, scaler = setup_model_optim_sched_scaler(model, device, epochs, args)
    # log all devices
    logger.info(f"training on {device} -> {torch.cuda.get_device_name(device) if device != 'cpu' else ''}")
    if rank == 0:
        logger.info(f"python version {sys.version}")
        logger.info(f"torch version {torch.__version__}")
        logger.info(f"timm version {timm.__version__}")
        log_args(args)
    if args.seed is not None:
        # re-seed before training so data/augmentation randomness does not depend on RNG use during setup
        torch.manual_seed(seed)
    criterion, val_criterion, mixup = setup_criteria_mixup(args)
    if rank == 0:
        logger.info(f"Run info at: '{run_folder}'")
    res = _train(
        model,
        train_loader,
        optimizer,
        rank,
        epochs,
        device,
        mixup,
        criterion,
        world_size,
        scheduler,
        args,
        val_loader,
        val_criterion,
        run_folder,
        scaler,
        do_metrics_calculation=True,
        show_tqdm=args.tqdm,
        train_dali_server=train_dali_server,
        val_dali_server=val_dali_server,
    )
    if rank == 0:
        best_acc_key = sorted([key for key in res.keys() if key.startswith("val/best_")])[0]
        logger.info(
            f"Run '{run_folder.split(os.sep)[-1]}' is done. Top-1 validation accuracy: {res[best_acc_key] * 100:.2f}%"
        )
    ddp_cleanup(args=args, sync_old_wandb=wandb_available(), rank=rank)
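

# How these entry points might be invoked (an assumption; the original file shows no
# __main__ block). Since ddp_setup() relies on the RANK/LOCAL_RANK/WORLD_SIZE variables
# logged above, multi-GPU runs would typically be launched via torchrun, e.g.:
#
#   if __name__ == "__main__":
#       import fire  # hypothetical dependency exposing the functions as a CLI
#       fire.Fire({"pretrain": pretrain, "finetune": finetune})
#
#   # torchrun --nproc_per_node=4 train.py pretrain ViT-S in1k 300 --batch_size 1024
#   # (model and dataset names here are illustrative, not taken from the repo)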