AAAI Version
This commit is contained in:
136
AAAI Supplementary Material/Model Training Code/recover.py
Normal file
136
AAAI Supplementary Material/Model Training Code/recover.py
Normal file
@@ -0,0 +1,136 @@
|
||||
"""Continue pretraining / finetuning after something went wrong."""
|
||||
|
||||
import torch
|
||||
from loguru import logger
|
||||
|
||||
from engine import _train, setup_criteria_mixup, setup_model_optim_sched_scaler, setup_tracking_and_logging
|
||||
from load_dataset import prepare_dataset
|
||||
from models import load_pretrained
|
||||
from utils import ddp_cleanup, ddp_setup, log_args, prep_kwargs
|
||||
|
||||
|
||||
def continue_training(model, **kwargs):
    """Continue training a model from a saved checkpoint state.

    Loads the checkpoint, restores the training args, optimizer and
    scheduler state, rebuilds the dataloaders, and resumes training at
    the epoch recorded in the checkpoint.

    Args:
        model (str): path to the saved state (a ``torch.save``-d dict, see
            the layout comment below).
        **kwargs: fallback values used when the saved args do not provide
            them — ``batch_size`` (global batch size) and ``epochs`` are
            read here when needed.
    """
    model_path = model
    # map_location="cpu" so the checkpoint loads regardless of the device
    # layout it was saved on; weights move to the right device later.
    save_state = torch.load(model, map_location="cpu")

    # state is of the form
    #
    # state = {'epoch': epochs,
    #          'model_state': model.state_dict(),
    #          'optimizer_state': optimizer.state_dict(),
    #          'scheduler_state': scheduler.state_dict(),
    #          'args': dict(args),
    #          'run_name': run_name,
    #          'stats': metrics}

    args = prep_kwargs(save_state["args"])

    args.distributed, device, world_size, rank, gpu_id = ddp_setup()
    torch.cuda.set_device(device)

    # Recover the *global* batch size, then re-split it across the current
    # world size (which may differ from the one the checkpoint was saved on).
    if "world_size" in args and args.world_size is not None:
        global_bs = args.batch_size * args.world_size
    else:
        # assume global bs is given in kwargs
        global_bs = kwargs["batch_size"]
    args.batch_size = int(global_bs / world_size)
    args.world_size = world_size

    if "dataset" in args and args.dataset is not None:
        dataset = args.dataset
    else:
        # get default dataset for the task
        dataset = "ImageNet21k" if args.task == "pre-train" else "ImageNet"
    args.dataset = dataset

    if "val_dataset" in args and args.val_dataset is not None:
        val_dataset = args.val_dataset
    else:
        val_dataset = dataset
    args.val_dataset = val_dataset

    start_epoch = save_state["epoch"]
    if "epochs" in args and args.epochs is not None and args.epochs != start_epoch:
        epochs = args.epochs
    else:
        # NOTE(review): when the checkpoint was saved exactly at args.epochs,
        # this falls through and a missing 'epochs' kwarg raises KeyError —
        # confirm that callers always pass it in that case.
        epochs = kwargs["epochs"]

    run_folder = setup_tracking_and_logging(args, rank, append_model_path=model_path)
    logger.info(f"Logging run information to '{run_folder}'")

    # get the datasets & dataloaders
    train_loader, args.n_classes, args.ignore_index, args.multi_label, train_dali_server = prepare_dataset(
        dataset, args, rank=rank
    )
    val_loader, _, __, ___, val_dali_server = prepare_dataset(val_dataset, args, train=False, rank=rank)

    # model_name = args.model

    model, args, _, __ = load_pretrained(model_path, args)

    model, optimizer, scheduler, scaler = setup_model_optim_sched_scaler(model, device, epochs, args)

    try:
        optimizer.load_state_dict(save_state["optimizer_state"])
    except ValueError as e:
        # Dump both sides of the mismatch (typically differing param groups)
        # before re-raising so the failure is diagnosable from the logs.
        logger.error(f"Could not load optimizer state: {e}")
        logger.error(
            f"optimizer state: {optimizer.state_dict().keys()}, param groups: {optimizer.state_dict()['param_groups']}"
        )
        logger.error(
            f"saved state: {save_state['optimizer_state'].keys()}, param groups:"
            f" {save_state['optimizer_state']['param_groups']}"
        )
        # bare `raise` preserves the original traceback; `raise e` would
        # re-anchor it at this line
        raise

    scheduler.load_state_dict(save_state["scheduler_state"])

    # log all devices
    logger.info(f"training on {device} -> {torch.cuda.get_device_name(device) if args.device != 'cpu' else ''}")
    if rank == 0:
        logger.info(f"torch version {torch.__version__}")
        log_args(args)

    # NOTE(review): truthiness check skips seeding when args.seed == 0 —
    # confirm whether a zero seed should be honored.
    if args.seed:
        torch.manual_seed(args.seed)

    criterion, val_criterion, mixup = setup_criteria_mixup(args)

    if rank == 0:
        logger.info(f"start training at epoch {start_epoch}")
        logger.info(f"Run info at: '{run_folder}'")

    res = _train(
        model,
        train_loader,
        optimizer,
        rank,
        epochs,
        device,
        mixup,
        criterion,
        world_size,
        scheduler,
        args,
        val_loader,
        val_criterion,
        run_folder,
        scaler=scaler,
        do_metrics_calculation=True,
        start_epoch=start_epoch,
        show_tqdm=args.tqdm,
        train_dali_server=train_dali_server,
        val_dali_server=val_dali_server,
    )

    if rank == 0:
        # _train reports metrics keyed by split; pick the single best-accuracy key
        best_acc_key = [key for key in res.keys() if key.startswith("val/best_")][0]
        logger.info(f"Run '{args.run_name}' is done. Top-1 validation accuracy: {res[best_acc_key] * 100:.2f}%")
    ddp_cleanup(args=args, rank=rank)
|
||||
Reference in New Issue
Block a user