Files
ForAug/AAAI Supplementary Material/Model Training Code/resizing_interface.py
Tobias Christian Nauen ff34712155 AAAI Version
2026-02-24 12:22:44 +01:00

109 lines
3.9 KiB
Python

from copy import copy
import torch
from loguru import logger
from torch import nn
# Preset model hyper-parameters keyed by a short size tag. Each entry holds the
# embedding width (``embed_dim``), number of transformer blocks (``depth``) and
# attention heads (``num_heads``); the LRA_* presets also fix ``mlp_ratio``.
# Presumably these are splatted into the model constructor as keyword
# arguments — verify against the caller.
vit_sizes = {
    "na": {"embed_dim": 96, "depth": 2, "num_heads": 2},
    "mu": {"embed_dim": 144, "depth": 6, "num_heads": 3},
    "Ti": {"embed_dim": 192, "depth": 12, "num_heads": 3},
    "S": {"embed_dim": 384, "depth": 12, "num_heads": 6},
    "B": {"embed_dim": 768, "depth": 12, "num_heads": 12},
    "L": {"embed_dim": 1024, "depth": 24, "num_heads": 16},
    "LRA_CIFAR": {"embed_dim": 256, "depth": 1, "num_heads": 4, "mlp_ratio": 1.0},
    "LRA_IMDB": {"embed_dim": 256, "depth": 4, "num_heads": 4, "mlp_ratio": 4.0},
    "LRA_ListOps": {"embed_dim": 512, "depth": 4, "num_heads": 8, "mlp_ratio": 4.0},
}
class ResizingInterface:
    """Interface (mixin) for resizing parts of a Vision Transformer model in place.

    The methods read and write attributes that the host model must provide:
    ``img_size``, ``patch_size``, ``patch_embed``, ``embed_layer``, ``in_chans``,
    ``embed_dim``, ``pre_norm``, ``pos_embed``, ``no_embed_class``,
    ``num_prefix_tokens``, ``head`` and ``num_classes``. Presumably the host is a
    timm-style ``VisionTransformer`` subclass — verify against the host class.
    """

    def get_internal_loss(self):
        """Return an extra term to add to the training loss.

        The base implementation contributes nothing and returns ``0.0``.
        """
        return 0.0

    def set_image_res(self, res):
        """Set a new input image resolution.

        Rebuilds the patch embedding for the new resolution while keeping its
        learned weights, and interpolates the positional embedding to the new
        number of patches.

        Args:
            res (int): new image resolution
        """
        self._set_input_strand(res=res)

    @staticmethod
    def _first_parameter(*modules):
        """Return the first parameter of the first ``nn.Module`` that has one, else ``None``.

        Non-module arguments (e.g. ``None``) are skipped. Used to find the
        device/dtype that newly built submodules should be moved to.
        """
        for module in modules:
            if isinstance(module, nn.Module):
                param = next(module.parameters(), None)
                if param is not None:
                    return param
        return None

    def _set_input_strand(self, res=None, patch_size=None):
        """Set a new image resolution and patch size.

        Args:
            res (int): new image resolution; ``None`` keeps the current one.
            patch_size (int): new patch size; only ``None`` (keep the current
                one) is supported.

        Raises:
            NotImplementedError: if ``patch_size`` is not ``None``.
        """
        if res is None:
            res = self.img_size
        if patch_size is None:
            patch_size = self.patch_size
        else:
            # TODO: implement interpolation of patch_embed weights to new patch size/input shape
            raise NotImplementedError("Interpolation of patch_embed weights to new patch size not implemented yet.")
        if res == self.img_size and patch_size == self.patch_size:
            return  # nothing to do here
        logger.info(f"Resizing input from {self.img_size} to {res} with patch size {self.patch_size} to {patch_size}.")

        old_patch_embed = self.patch_embed
        old_patch_embed_state = copy(old_patch_embed.state_dict())
        new_patch_embed = self.embed_layer(
            img_size=res,
            patch_size=patch_size,
            in_chans=self.in_chans,
            embed_dim=self.embed_dim,
            bias=not self.pre_norm,  # disable bias if pre-norm is used (e.g. CLIP)
        )
        # Freshly constructed modules live on CPU with the default dtype; move the
        # new embedding to where the old one lived so a model that was already
        # moved to GPU / cast to half precision keeps working.
        reference = self._first_parameter(old_patch_embed)
        if reference is not None:
            new_patch_embed.to(device=reference.device, dtype=reference.dtype)
        # Patch size is unchanged, so the learned weights can be reused as-is.
        new_patch_embed.load_state_dict(old_patch_embed_state)
        self.patch_embed = new_patch_embed

        # Interpolate the grid part of the positional embedding; prefix tokens
        # (e.g. class token) are kept unchanged unless no_embed_class is set.
        num_extra_tokens = 0 if self.no_embed_class else self.num_prefix_tokens
        orig_size = int((self.pos_embed.shape[-2] - num_extra_tokens) ** 0.5)
        new_size = int(self.patch_embed.num_patches**0.5)
        with torch.no_grad():  # pure re-initialization; no need to track gradients
            extra_tokens = self.pos_embed[:, :num_extra_tokens]
            pos_tokens = self.pos_embed[:, num_extra_tokens:]
            orig_dtype = pos_tokens.dtype
            # make it shape rest x embed_dim x orig_size x orig_size
            pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, self.embed_dim).permute(0, 3, 1, 2)
            # bicubic interpolation is not implemented for half precision on all
            # backends, so interpolate in (at least) float32 and cast back.
            interp_dtype = torch.promote_types(orig_dtype, torch.float32)
            pos_tokens = nn.functional.interpolate(
                pos_tokens.to(interp_dtype), size=(new_size, new_size), mode="bicubic", align_corners=False
            ).to(orig_dtype)
            # make it shape rest x new_size^2 x embed_dim
            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
            if num_extra_tokens > 0:
                pos_tokens = torch.cat((extra_tokens, pos_tokens), dim=1)
        self.pos_embed = nn.Parameter(pos_tokens.contiguous())
        self.img_size = res
        self.patch_size = patch_size

    def set_num_classes(self, n_classes):
        """Reset the classification head with a new number of classes.

        A positive ``n_classes`` installs a freshly initialized ``nn.Linear``
        head (truncated-normal weights with std 0.02, zero bias); otherwise the
        head becomes ``nn.Identity``.

        Args:
            n_classes (int): new number of classes
        """
        if n_classes == self.num_classes:
            return
        logger.info(f"Resizing classification head from {self.num_classes} to {n_classes}.")
        if n_classes > 0:
            head = nn.Linear(self.embed_dim, n_classes)
            # init weight + bias (only a Linear head has them; Identity does not)
            nn.init.trunc_normal_(head.weight, std=0.02)
            nn.init.constant_(head.bias, 0)
            # Place the new head on the same device/dtype as the model: prefer the
            # old head's parameters, fall back to any parameter of the model.
            reference = self._first_parameter(getattr(self, "head", None), self)
            if reference is not None:
                head.to(device=reference.device, dtype=reference.dtype)
        else:
            head = nn.Identity()
        self.head = head
        self.num_classes = n_classes