AAAI Version
This commit is contained in:
@@ -0,0 +1,108 @@
|
||||
from copy import copy
|
||||
|
||||
import torch
|
||||
from loguru import logger
|
||||
from torch import nn
|
||||
|
||||
vit_sizes = {
|
||||
"na": dict(embed_dim=96, depth=2, num_heads=2),
|
||||
"mu": dict(embed_dim=144, depth=6, num_heads=3),
|
||||
"Ti": dict(embed_dim=192, depth=12, num_heads=3),
|
||||
"S": dict(embed_dim=384, depth=12, num_heads=6),
|
||||
"B": dict(embed_dim=768, depth=12, num_heads=12),
|
||||
"L": dict(embed_dim=1024, depth=24, num_heads=16),
|
||||
"LRA_CIFAR": dict(embed_dim=256, depth=1, num_heads=4, mlp_ratio=1.0),
|
||||
"LRA_IMDB": dict(embed_dim=256, depth=4, num_heads=4, mlp_ratio=4.0),
|
||||
"LRA_ListOps": dict(embed_dim=512, depth=4, num_heads=8, mlp_ratio=4.0),
|
||||
}
|
||||
|
||||
|
||||
class ResizingInterface:
|
||||
"""Interface for resizing parts of a Vision Transformer model."""
|
||||
|
||||
def get_internal_loss(self):
|
||||
"""Add a term to the loss."""
|
||||
return 0.0
|
||||
|
||||
def set_image_res(self, res):
|
||||
"""Set a new image resolution.
|
||||
|
||||
Resets the (learned) patch embedding.
|
||||
|
||||
Args:
|
||||
res (int): new image resolution
|
||||
|
||||
"""
|
||||
self._set_input_strand(res=res)
|
||||
|
||||
def _set_input_strand(self, res=None, patch_size=None):
|
||||
"""Set a new image resolution and patch size.
|
||||
|
||||
Args:
|
||||
res (int): (Default value = None)
|
||||
patch_size (int): (Default value = None)
|
||||
|
||||
"""
|
||||
if res is None:
|
||||
res = self.img_size
|
||||
|
||||
if patch_size is None:
|
||||
patch_size = self.patch_size
|
||||
else:
|
||||
# TODO: implement interpolation of patch_embed weights to new patch size/input shape
|
||||
raise NotImplementedError("Interpolation of patch_embed weights to new patch size not implemented yet.")
|
||||
|
||||
if res == self.img_size and patch_size == self.patch_size:
|
||||
return # nothing to do here
|
||||
|
||||
logger.info(f"Resizing input from {self.img_size} to {res} with patch size {self.patch_size} to {patch_size}.")
|
||||
|
||||
old_patch_embed_state = copy(self.patch_embed.state_dict())
|
||||
self.patch_embed = self.embed_layer(
|
||||
img_size=res,
|
||||
patch_size=patch_size,
|
||||
in_chans=self.in_chans,
|
||||
embed_dim=self.embed_dim,
|
||||
bias=not self.pre_norm, # disable bias if pre-norm is used (e.g. CLIP)
|
||||
)
|
||||
|
||||
self.patch_embed.load_state_dict(old_patch_embed_state)
|
||||
|
||||
num_extra_tokens = 0 if self.no_embed_class else self.num_prefix_tokens
|
||||
orig_size = int((self.pos_embed.shape[-2] - num_extra_tokens) ** 0.5)
|
||||
new_size = int(self.patch_embed.num_patches**0.5)
|
||||
extra_tokens = self.pos_embed[:, :num_extra_tokens]
|
||||
pos_tokens = self.pos_embed[:, num_extra_tokens:]
|
||||
# make it shape rest x embed_dim x orig_size x orig_size
|
||||
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, self.embed_dim).permute(0, 3, 1, 2)
|
||||
pos_tokens = nn.functional.interpolate(
|
||||
pos_tokens, size=(new_size, new_size), mode="bicubic", align_corners=False
|
||||
)
|
||||
# make it shape rest x new_size^2 x embed_dim
|
||||
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
|
||||
if num_extra_tokens > 0:
|
||||
pos_tokens = torch.cat((extra_tokens, pos_tokens), dim=1)
|
||||
self.pos_embed = nn.Parameter(pos_tokens.contiguous())
|
||||
|
||||
self.img_size = res
|
||||
self.patch_size = patch_size
|
||||
|
||||
def set_num_classes(self, n_classes):
|
||||
"""Reset the classification head with a new number of classes.
|
||||
|
||||
Args:
|
||||
n_classes (int): new number of classes
|
||||
|
||||
"""
|
||||
if n_classes == self.num_classes:
|
||||
return
|
||||
logger.info(f"Resizing classification head from {self.num_classes} to {n_classes}.")
|
||||
self.head = nn.Linear(self.embed_dim, n_classes) if n_classes > 0 else nn.Identity()
|
||||
self.num_classes = n_classes
|
||||
|
||||
# init weight + bias
|
||||
# nn.init.zeros_(self.head.weight)
|
||||
# nn.init.constant_(self.head.bias, -log(self.num_classes))
|
||||
|
||||
nn.init.trunc_normal_(self.head.weight, std=0.02)
|
||||
nn.init.constant_(self.head.bias, 0)
|
||||
Reference in New Issue
Block a user