304 lines
11 KiB
Python
304 lines
11 KiB
Python
import os
|
|
import random
|
|
import sys
|
|
|
|
import numpy as np
|
|
import timm
|
|
import torch
|
|
from loguru import logger
|
|
|
|
from engine import (
|
|
_train,
|
|
setup_criteria_mixup,
|
|
setup_model_optim_sched_scaler,
|
|
setup_tracking_and_logging,
|
|
wandb_available,
|
|
)
|
|
from load_dataset import prepare_dataset
|
|
from models import load_pretrained, prepare_model
|
|
from utils import ddp_cleanup, ddp_setup, log_args, prep_kwargs, set_filter_warnings
|
|
|
|
|
|
def finetune(model, dataset, epochs, val_dataset=None, head_only=False, **kwargs):
    """Finetune a pretrained model on a given dataset for a specified number of epochs.

    Args:
        model (str): Path to the pretrained model state file (in .tar format).
        dataset (str): Name of the dataset to finetune on.
        epochs (int): Number of epochs to train for.
        val_dataset (str, optional): Name of the validation dataset. Defaults to ``dataset`` when None.
        head_only (bool, optional): Whether to train only the head of the model. Default: False.
        **kwargs (dict): Further arguments for model setup, training, evaluation,...

    Raises:
        RuntimeError: If the device returned by ``ddp_setup`` cannot be made the current CUDA device.
        ValueError: If the training and validation datasets report different numbers of classes.

    Notes:
        This function assumes that the model was pretrained on a different dataset using a different set of
        hyperparameters. It fine-tunes the model on a new dataset by loading the pretrained weights and training for
        the specified number of epochs. The function supports distributed training using the PyTorch
        DistributedDataParallel module.
    """
    set_filter_warnings()

    # Add defaults & make keys properties
    args = prep_kwargs(kwargs)

    if val_dataset is None:
        val_dataset = dataset

    args.val_dataset = val_dataset
    args.dataset = dataset
    args.epochs = epochs

    # Environment summary, reused by the device-error path and the rank-0 info log below.
    env_params = (
        f"environment parameters: RANK={os.environ.get('RANK')}, LOCAL_RANK={os.environ.get('LOCAL_RANK')}, "
        f"WORLD_SIZE={os.environ.get('WORLD_SIZE')}; SLURM_STEP_GPUS={os.environ.get('SLURM_STEP_GPUS')}, "
        f"GPU_DEVICE_ORDINAL={os.environ.get('GPU_DEVICE_ORDINAL')}, "
        f"CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')}, "
        f"SLURM_STEP_NODELIST={os.environ.get('SLURM_STEP_NODELIST')}, "
        f"SLURMD_NODENAME={os.environ.get('SLURMD_NODENAME')}"
    )

    # Pass args.cuda (as pretrain() does) so CPU-only runs do not try to select a CUDA device.
    args.distributed, device, world_size, rank, _gpu_id = ddp_setup(args.cuda)
    args.world_size = world_size
    if args.cuda:
        try:
            torch.cuda.set_device(device)
        except RuntimeError as e:
            logger.error(f"Could not set device {device} as current device: {e}; {env_params}")
            raise

    # args.batch_size is the global batch size; each process gets an equal share.
    args.batch_size = int(args.batch_size / world_size)

    seed = None
    if args.seed is not None:
        # fix the seed for reproducibility; the rank offset gives every process a different seed
        seed = args.seed + rank
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)

    # get the datasets & dataloaders
    # transform only contains resize & crop here; everything else is handled on the GPU / in the training loop
    train_loader, args.n_classes, args.ignore_index, args.multi_label, train_dali_server = prepare_dataset(
        dataset, args, rank=rank
    )
    val_loader, val_classes, _, _, val_dali_server = prepare_dataset(val_dataset, args, train=False, rank=rank)
    if args.n_classes != val_classes:
        raise ValueError(
            f"Training and validation datasets have different numbers of classes: {args.n_classes} vs {val_classes}"
        )

    # Only the stored training arguments are needed here; load_pretrained() below is given the path and
    # (presumably) loads the checkpoint itself, so drop this copy early to avoid holding two at once.
    # NOTE: torch.load unpickles the file -- only use checkpoints from trusted sources.
    checkpoint = torch.load(model, map_location="cpu")
    old_args = prep_kwargs(checkpoint["args"])
    del checkpoint
    # Symlink targets are resolved relative to the link's own directory, so the target must be absolute.
    parent_folder = os.path.abspath(os.path.dirname(model))
    args.model = old_args.model

    run_folder = setup_tracking_and_logging(args, rank)
    if rank == 0:
        parent_link = os.path.join(run_folder, "parent_run")
        # lexists (unlike exists) also reports a dangling link, which os.symlink would refuse to overwrite
        if not os.path.lexists(parent_link):
            os.symlink(parent_folder, parent_link, target_is_directory=True)
        logger.info(env_params)

    if seed is not None:
        logger.info(f"setting manual seed '{seed}' (arg: {args.seed} + rank: {rank})")

    # `model` is still the checkpoint path here; from now on it holds the loaded model object.
    model, args, old_args, save_state = load_pretrained(model, args, new_dataset_params=True)

    if rank == 0:
        logger.info(f"The model was pretrained on {old_args.dataset} for {save_state['epoch']} epochs.")

    model, optimizer, scheduler, scaler = setup_model_optim_sched_scaler(
        model, device, epochs, args, head_only=head_only
    )

    # log all devices; torch.device() normalizes str / int / torch.device inputs before the type check
    device_name = torch.cuda.get_device_name(device) if torch.device(device).type == "cuda" else ""
    logger.info(f"training on {device} -> {device_name}")
    if rank == 0:
        logger.info(f"torch version {torch.__version__}")
        logger.info(f"timm version {timm.__version__}")
        logger.info(f"full set of old arguments: {old_args}")
        log_args(args)

    if seed is not None:
        # re-seed after model/optimizer setup, presumably so later randomness does not depend on how many
        # random draws the setup consumed
        torch.manual_seed(seed)

    criterion, val_criterion, mixup = setup_criteria_mixup(args)
    if rank == 0:
        logger.info(f"Run info at: '{run_folder}'")

    res = _train(
        model,
        train_loader,
        optimizer,
        rank,
        epochs,
        device,
        mixup,
        criterion,
        world_size,
        scheduler,
        args,
        val_loader,
        val_criterion,
        run_folder,
        scaler,
        do_metrics_calculation=True,
        show_tqdm=args.tqdm,
        train_dali_server=train_dali_server,
        val_dali_server=val_dali_server,
    )

    if rank == 0:
        run_name = os.path.basename(os.path.normpath(run_folder))
        best_keys = sorted(key for key in (res or {}) if key.startswith("val/best_"))
        if best_keys:
            logger.info(f"Run '{run_name}' is done. Top-1 validation accuracy: {res[best_keys[0]] * 100:.2f}%")
        else:
            # avoid an IndexError here, which would skip ddp_cleanup below
            logger.warning(f"Run '{run_name}' is done, but no 'val/best_*' metric was reported.")

    ddp_cleanup(args=args, sync_old_wandb=wandb_available(), rank=rank)
|
|
|
|
|
|
def pretrain(model, dataset, epochs, val_dataset=None, **kwargs):
    """Train or pretrain a model.

    Args:
        model (str): Name of the model to train. A non-string value is passed on to the setup unchanged.
        dataset (str): Name of the dataset to train the model on.
        epochs (int): Number of training epochs.
        val_dataset (str, optional): Name of the validation dataset. Defaults to ``dataset`` when None.
        **kwargs (dict): Additional keyword arguments.

    Raises:
        RuntimeError: If the device returned by ``ddp_setup`` cannot be made the current CUDA device.
        ValueError: If the training and validation datasets report different numbers of classes.

    Notes:
        This function sets up logging, prepares the model, and trains the model on the given dataset.
    """
    set_filter_warnings()

    # Add defaults & make args properties
    args = prep_kwargs(kwargs)

    if val_dataset is None:
        val_dataset = dataset

    args.val_dataset = val_dataset
    args.dataset = dataset
    args.model = model
    args.epochs = epochs

    args.distributed, device, world_size, rank, _gpu_id = ddp_setup(args.cuda)
    args.world_size = world_size

    logger.debug(
        f"rank params: RANK={os.environ.get('RANK')}, LOCAL_RANK={os.environ.get('LOCAL_RANK')}, "
        f"WORLD_SIZE={os.environ.get('WORLD_SIZE')}; gpu params: "
        f"SLURM_STEP_GPUS={os.environ.get('SLURM_STEP_GPUS')}, "
        f"GPU_DEVICE_ORDINAL={os.environ.get('GPU_DEVICE_ORDINAL')}, "
        f"CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')}"
    )

    # Environment summary, reused by the device-error path and the rank-0 info log below.
    env_params = (
        f"environment parameters: RANK={os.environ.get('RANK')}, LOCAL_RANK={os.environ.get('LOCAL_RANK')}, "
        f"WORLD_SIZE={os.environ.get('WORLD_SIZE')}; SLURM_STEP_GPUS={os.environ.get('SLURM_STEP_GPUS')}, "
        f"GPU_DEVICE_ORDINAL={os.environ.get('GPU_DEVICE_ORDINAL')}, "
        f"CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')}, "
        f"SLURM_STEP_NODELIST={os.environ.get('SLURM_STEP_NODELIST')}, "
        f"SLURMD_NODENAME={os.environ.get('SLURMD_NODENAME')}"
    )

    if args.cuda:
        try:
            torch.cuda.set_device(device)
        except RuntimeError as e:
            logger.error(f"Could not set device {device} as current device: {e}")
            logger.error(env_params)
            raise

    # args.batch_size is the global batch size; each process gets an equal share.
    args.batch_size = int(args.batch_size / world_size)

    run_folder = setup_tracking_and_logging(args, rank)
    if rank == 0:
        logger.info(env_params)

    seed = None
    if args.seed is not None:
        # fix the seed for reproducibility; the rank offset gives every process a different seed
        seed = args.seed + rank
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)
        logger.info(f"setting manual seed '{seed}' (arg: {args.seed} + rank: {rank})")

    # get the datasets & dataloaders
    train_loader, args.n_classes, args.ignore_index, args.multi_label, train_dali_server = prepare_dataset(
        dataset, args, rank=rank
    )
    val_loader, val_classes, _, _, val_dali_server = prepare_dataset(val_dataset, args, train=False, rank=rank)
    if args.n_classes != val_classes:
        raise ValueError(
            f"Training and validation datasets have different numbers of classes: {args.n_classes} vs {val_classes}"
        )

    # setup model with amp & DDP
    if isinstance(model, str):
        # "ViT..." names without a "_" suffix get the configured image size appended
        if model.startswith("ViT") and "_" not in model:
            model += f"_{args.imsize}"
        model = prepare_model(model, args)

    model, optimizer, scheduler, scaler = setup_model_optim_sched_scaler(model, device, epochs, args)

    # log all devices; torch.device() normalizes str / int / torch.device inputs before the type check
    device_name = torch.cuda.get_device_name(device) if torch.device(device).type == "cuda" else ""
    logger.info(f"training on {device} -> {device_name}")
    if rank == 0:
        logger.info(f"python version {sys.version}")
        logger.info(f"torch version {torch.__version__}")
        logger.info(f"timm version {timm.__version__}")
        log_args(args)

    if seed is not None:
        # re-seed after model/optimizer setup, presumably so later randomness does not depend on how many
        # random draws the setup consumed
        torch.manual_seed(seed)

    criterion, val_criterion, mixup = setup_criteria_mixup(args)

    if rank == 0:
        logger.info(f"Run info at: '{run_folder}'")

    res = _train(
        model,
        train_loader,
        optimizer,
        rank,
        epochs,
        device,
        mixup,
        criterion,
        world_size,
        scheduler,
        args,
        val_loader,
        val_criterion,
        run_folder,
        scaler,
        do_metrics_calculation=True,
        show_tqdm=args.tqdm,
        train_dali_server=train_dali_server,
        val_dali_server=val_dali_server,
    )

    if rank == 0:
        run_name = os.path.basename(os.path.normpath(run_folder))
        # sorted() makes the choice deterministic when several "val/best_*" keys exist
        best_keys = sorted(key for key in (res or {}) if key.startswith("val/best_"))
        if best_keys:
            logger.info(f"Run '{run_name}' is done. Top-1 validation accuracy: {res[best_keys[0]] * 100:.2f}%")
        else:
            # avoid an IndexError here, which would skip ddp_cleanup below
            logger.warning(f"Run '{run_name}' is done, but no 'val/best_*' metric was reported.")

    ddp_cleanup(args=args, sync_old_wandb=wandb_available(), rank=rank)
|