137 lines
4.4 KiB
Python
137 lines
4.4 KiB
Python
"""Continue pretraining / finetuning after something went wrong."""
|
|
|
|
import torch
|
|
from loguru import logger
|
|
|
|
from engine import _train, setup_criteria_mixup, setup_model_optim_sched_scaler, setup_tracking_and_logging
|
|
from load_dataset import prepare_dataset
|
|
from models import load_pretrained
|
|
from utils import ddp_cleanup, ddp_setup, log_args, prep_kwargs
|
|
|
|
|
|
def _fallback_kwarg(kwargs, key, reason):
    """Return ``kwargs[key]``, raising a descriptive KeyError if the caller did not supply it."""
    try:
        return kwargs[key]
    except KeyError:
        raise KeyError(f"'{key}' must be passed to continue_training: {reason}") from None


def _resolve_batch_size(args, world_size, kwargs):
    """Re-split the global batch size over the current ``world_size``; updates ``args`` in place."""
    if "world_size" in args and args.world_size is not None:
        # the saved batch_size is per process: rebuild the global batch size from the old world size
        global_bs = args.batch_size * args.world_size
    else:
        # assume global bs is given in kwargs
        global_bs = _fallback_kwarg(
            kwargs, "batch_size", "the saved args carry no world_size, so the global batch size is unknown"
        )
    if global_bs < world_size:
        # the per-process batch size would be 0
        raise ValueError(f"global batch size {global_bs} is smaller than the world size {world_size}")
    if global_bs % world_size:
        logger.warning(
            f"global batch size {global_bs} is not divisible by world size {world_size}; "
            f"effective global batch size will be {(global_bs // world_size) * world_size}"
        )
    args.batch_size = int(global_bs // world_size)
    args.world_size = world_size


def _resolve_datasets(args):
    """Fill in ``args.dataset`` / ``args.val_dataset`` (with task defaults) and return both names."""
    if "dataset" in args and args.dataset is not None:
        dataset = args.dataset
    else:
        # get default dataset for the task
        dataset = "ImageNet21k" if args.task == "pre-train" else "ImageNet"
    args.dataset = dataset

    if "val_dataset" in args and args.val_dataset is not None:
        val_dataset = args.val_dataset
    else:
        val_dataset = dataset
    args.val_dataset = val_dataset
    return dataset, val_dataset


def _resolve_epochs(args, start_epoch, kwargs):
    """Return the total number of epochs to train for and record it in ``args.epochs``.

    The saved ``args.epochs`` is reused unless it is missing or equals ``start_epoch`` (the saved run already
    finished), in which case the caller must pass a new total via ``kwargs["epochs"]``.
    """
    if "epochs" in args and args.epochs is not None and args.epochs != start_epoch:
        epochs = args.epochs
    else:
        epochs = _fallback_kwarg(
            kwargs, "epochs", f"the saved args carry no epoch count beyond the saved epoch {start_epoch}"
        )
    # write back like the other resolved settings, so checkpoints saved by this run (which store the args)
    # carry the epoch count actually used and can themselves be resumed correctly
    args.epochs = epochs
    return epochs


def _load_optimizer_state(optimizer, saved_optimizer_state):
    """Load ``saved_optimizer_state`` into ``optimizer``; on a mismatch, log both states and re-raise."""
    try:
        optimizer.load_state_dict(saved_optimizer_state)
    except ValueError as e:
        current_state = optimizer.state_dict()
        logger.error(f"Could not load optimizer state: {e}")
        logger.error(f"optimizer state: {current_state.keys()}, param groups: {current_state['param_groups']}")
        logger.error(
            f"saved state: {saved_optimizer_state.keys()}, param groups:"
            f" {saved_optimizer_state['param_groups']}"
        )
        raise


def continue_training(model, **kwargs):
    """Continue training a model from a saved state.

    The saved state is expected to be a dict of the form::

        {
            "epoch": epochs,
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "scheduler_state": scheduler.state_dict(),
            "args": dict(args),
            "run_name": run_name,
            "stats": metrics,
        }

    The run configuration is rebuilt from the saved ``args``; ``kwargs`` only supplies values the saved
    state cannot provide.

    Args:
        model (str): path to saved state.
        **kwargs: fallback values. ``batch_size`` (the *global* batch size) is read when the saved args carry
            no ``world_size``; ``epochs`` (total number of epochs) is read when the saved args carry no
            ``epochs`` or the saved run already reached its configured epoch count.

    Raises:
        KeyError: if a required fallback value is missing from ``kwargs``.
        ValueError: if the global batch size is smaller than the world size, or if the saved optimizer state
            does not match the rebuilt optimizer.
    """
    model_path = model
    save_state = torch.load(model_path, map_location="cpu")

    args = prep_kwargs(save_state["args"])

    args.distributed, device, world_size, rank, _gpu_id = ddp_setup()
    # NOTE(review): torch.cuda.set_device rejects a CPU device, while the device log line below allows
    # args.device == 'cpu' -- confirm ddp_setup always returns a CUDA device here.
    torch.cuda.set_device(device)

    _resolve_batch_size(args, world_size, kwargs)
    dataset, val_dataset = _resolve_datasets(args)

    start_epoch = save_state["epoch"]
    epochs = _resolve_epochs(args, start_epoch, kwargs)

    run_folder = setup_tracking_and_logging(args, rank, append_model_path=model_path)
    logger.info(f"Logging run information to '{run_folder}'")

    # get the datasets & dataloaders
    train_loader, args.n_classes, args.ignore_index, args.multi_label, train_dali_server = prepare_dataset(
        dataset, args, rank=rank
    )
    val_loader, _, __, ___, val_dali_server = prepare_dataset(val_dataset, args, train=False, rank=rank)

    model, args, _, __ = load_pretrained(model_path, args)
    model, optimizer, scheduler, scaler = setup_model_optim_sched_scaler(model, device, epochs, args)

    # restore optimizer / scheduler progress so training resumes where it stopped
    _load_optimizer_state(optimizer, save_state["optimizer_state"])
    scheduler.load_state_dict(save_state["scheduler_state"])

    # log all devices
    logger.info(f"training on {device} -> {torch.cuda.get_device_name(device) if args.device != 'cpu' else ''}")
    if rank == 0:
        logger.info(f"torch version {torch.__version__}")
        log_args(args)

    # NOTE(review): a seed of 0 is falsy and therefore skipped here -- confirm that is intended.
    if args.seed:
        torch.manual_seed(args.seed)

    criterion, val_criterion, mixup = setup_criteria_mixup(args)

    if rank == 0:
        logger.info(f"start training at epoch {start_epoch}")
        logger.info(f"Run info at: '{run_folder}'")

    res = _train(
        model,
        train_loader,
        optimizer,
        rank,
        epochs,
        device,
        mixup,
        criterion,
        world_size,
        scheduler,
        args,
        val_loader,
        val_criterion,
        run_folder,
        scaler=scaler,
        do_metrics_calculation=True,
        start_epoch=start_epoch,
        show_tqdm=args.tqdm,
        train_dali_server=train_dali_server,
        val_dali_server=val_dali_server,
    )

    if rank == 0:
        # a missing metric must not crash the run after training finished (and skip ddp_cleanup below)
        best_acc_key = next((key for key in res.keys() if key.startswith("val/best_")), None)
        if best_acc_key is None:
            logger.warning(f"Run '{args.run_name}' is done, but no 'val/best_*' metric was returned.")
        else:
            logger.info(f"Run '{args.run_name}' is done. Top-1 validation accuracy: {res[best_acc_key] * 100:.2f}%")
    ddp_cleanup(args=args, rank=rank)
|