AAAI Version
This commit is contained in:
154
AAAI Supplementary Material/Model Training Code/models.py
Normal file
154
AAAI Supplementary Material/Model Training Code/models.py
Normal file
@@ -0,0 +1,154 @@
|
||||
"""Model loading and preparation."""
|
||||
|
||||
import importlib
|
||||
import os
|
||||
import warnings
|
||||
from functools import partial
|
||||
|
||||
import timm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from loguru import logger
|
||||
|
||||
import utils
|
||||
from architectures.vit import TimmViT
|
||||
from resizing_interface import vit_sizes
|
||||
|
||||
_ARCHITECTURES_IMPORTED = False
|
||||
|
||||
|
||||
def _import_architectures():
    """Import every architecture module under ``architectures/`` exactly once.

    Importing a module is what registers its model classes (side-effect
    registration, timm-style), so this only needs to happen once per process;
    the module-level ``_ARCHITECTURES_IMPORTED`` flag guards against repeats.
    A failing import is logged and skipped so one broken file does not take
    down the whole registry.
    """
    global _ARCHITECTURES_IMPORTED
    if _ARCHITECTURES_IMPORTED:
        return
    model_file_path = os.path.dirname(os.path.abspath(__file__))
    # sorted() makes the registration order deterministic; os.listdir order
    # is filesystem-dependent.
    for file in sorted(os.listdir(os.path.join(model_file_path, "architectures"))):
        # Skip non-Python files and dunder modules: importing
        # "architectures.__init__" would execute the package module a second
        # time under a different sys.modules key.
        if not file.endswith(".py") or file.startswith("__"):
            continue
        try:
            with warnings.catch_warnings():
                # Architecture modules may emit deprecation noise on import.
                warnings.simplefilter("ignore")
                importlib.import_module(f"architectures.{file[:-3]}")
                logger.debug(f"Imported architectures.{file[:-3]}")
        except Exception as e:
            logger.error(f"\033[93mCould not import \033[0m\033[91m{file}\033[0m")
            logger.error(e)
    _ARCHITECTURES_IMPORTED = True
|
||||
|
||||
|
||||
def prepare_model(model_str, args):
    """Build a fresh (randomly initialised) model.

    If the name is of the format ViT-<size>/<patch_size>, use a *TimmViT*, else fall back to timm model loading.

    Args:
        model_str (str): model name
        args (utils.DotDict): further arguments, needs to have keys n_classes, drop_path_rate; key imsize or '_<imsize>' at the end of ViT specification

    Returns:
        torch.nn.Module: model

    """
    _import_architectures()

    # Copy args into plain kwargs, dropping every None-valued entry so that
    # unset options never override a constructor default.
    kwargs = {key: val for key, val in dict(args).items() if val is not None}

    if args.layer_scale_init_values:
        # Different architectures spell the LayerScale init differently.
        kwargs["init_values"] = args.layer_scale_init_values
        kwargs["init_scale"] = args.layer_scale_init_values
    if args.dropout and args.dropout > 0.0:
        kwargs["drop"] = args.dropout
        kwargs["drop_rate"] = args.dropout
    if args.drop_path_rate and args.drop_path_rate > 0.0:
        # NOTE(review): the stochastic-depth rate is also forwarded as
        # drop_block_rate here (kwargs already carries drop_path_rate from
        # args) — confirm DropBlock at this rate is actually intended.
        kwargs["drop_block_rate"] = args.drop_path_rate
    kwargs["num_classes"] = args.n_classes
    kwargs["img_size"] = args.imsize

    if not model_str.startswith("ViT"):
        # Anything that is not a ViT spec goes straight through timm.
        logger.debug(f"Loading model via timm api {model_str} with args {kwargs}")
        return timm.create_model(model_str, pretrained=False, **kwargs)

    # Format: ViT-{Ti,S,B,L}/<patch_size>[_<image_res>]
    size_part, patch_part = model_str.split("/")
    _, model_size = size_part.split("-")
    if "_" in patch_part:
        patch_size, image_res = patch_part.split("_")
        # An image size embedded in the name must agree with args.imsize.
        assert args.imsize is None or args.imsize == int(
            image_res
        ), f"Got two different image sizes: {args.imsize} vs {image_res}"
    else:
        patch_size = patch_part

    # Explicit kwargs win over the size preset.
    kwargs = {**vit_sizes[model_size], **kwargs}
    return TimmViT(patch_size=int(patch_size), in_chans=3, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
||||
|
||||
|
||||
def load_pretrained(model_path, args, new_dataset_params=False):
    """Load a pretrained model from .tar file.

    Rebuilds the architecture from the arguments stored in the checkpoint,
    loads its weights, and — when a plain load fails — attempts two known
    repairs (missing LayerScale parameters, legacy ``nn.Sequential`` head).
    Unresolvable mismatches terminate the process via ``exit(-1)``.

    Args:
        new_dataset_params (bool, optional): change model parameters (imsize, n_classes) to the ones from args. (Default value = False)
        model_path (str): path to .tar file
        args: new model parameters; mutated in place (``args.model`` is overwritten
            with the checkpoint's model name)

    Returns:
        tuple: model, args, old_args, save_state

    """
    _import_architectures()

    # NOTE(review): torch.load without weights_only unpickles arbitrary
    # objects — only load checkpoints from trusted sources.
    save_state = torch.load(model_path, map_location="cpu")
    old_args = utils.prep_kwargs(save_state["args"])
    # The checkpoint dictates the architecture; the caller dictates the device.
    args.model = old_args.model
    old_args.cuda = args.cuda

    if old_args.model.startswith("flash_vit"):
        # flash_vit models do not take the LayerScale option — drop it from
        # both argument sets before rebuilding.
        args.pop("layer_scale_init_values", None)
        old_args.pop("layer_scale_init_values", None)

    # load the model (the old one first)
    model = prepare_model(old_args.model, old_args)
    logger.debug(f"loading model {old_args.model} from {model_path} with args {old_args}")
    # Strip torch.compile's "_orig_mod." prefix, then any remaining wrapper
    # prefix (e.g. from DataParallel) — presumably; verify utils.remove_prefix.
    file_save_state = utils.remove_prefix(save_state["model_state"], prefix="_orig_mod.")
    file_save_state = utils.remove_prefix(file_save_state)
    try:
        model.load_state_dict(file_save_state)
    except (UnboundLocalError, RuntimeError) as e:
        # Diff the two key sets to decide which known repair applies.
        model_keys = set(model.state_dict().keys())
        file_keys = set(file_save_state.keys())
        logger.warning(f"Error loading state dict: {e}")
        model_minus_file = model_keys.difference(file_keys)
        file_minus_model = file_keys.difference(model_keys)
        logger.warning(f"model-file: {model_minus_file}\nfile-model: {file_minus_model}")
        if len(file_minus_model) == 0 and all([".ls" in key and key.endswith(".gamma") for key in model_minus_file]):
            # Every missing key is a LayerScale gamma: rebuild the model
            # without LayerScale and retry.
            logger.info("Old model was without LayerScale -> replicating")
            try:
                args.pop("layer_scale_init_values")
                old_args.pop("layer_scale_init_values")
                model = prepare_model(old_args.model, old_args)
                model.load_state_dict(file_save_state)
            except (UnboundLocalError, RuntimeError) as e:
                logger.error("Could not resolve conflict")
                logger.error(f"Still got error {e}")
                # NOTE(review): exit() comes from the site module; sys.exit
                # would be the conventional choice here.
                exit(-1)
        elif any("head.0." in key for key in file_minus_model):
            # Legacy checkpoints wrapped the head in nn.Sequential; rename
            # "head.0.*" keys to "head.*" and retry.
            logger.info("Old model used nn.Seqeuntial for head. Trying to fix -> nn.Linear")
            file_save_state = {key.replace("head.0.", "head."): val for key, val in file_save_state.items()}
            try:
                model.load_state_dict(file_save_state)
            except (UnboundLocalError, RuntimeError) as e:
                logger.error("Could not resolve conflict")
                logger.error(f"Still got error {e}")
                exit(-1)
        else:
            # No known repair matches the key diff.
            exit(-1)

    if new_dataset_params:
        # setup for finetuning parameters
        model.set_image_res(args.imsize)
        model.set_num_classes(args.n_classes)

    if args.max_seq_len is not None:
        model.set_max_seq_len(args.max_seq_len)

    return model, args, old_args, save_state
|
||||
Reference in New Issue
Block a user