AAAI Version
from functools import partial

import torch
from loguru import logger
from timm.models.vision_transformer import (
    Attention,
    Block,
    PatchEmbed,
    VisionTransformer,
)

from resizing_interface import ResizingInterface


class _MatrixSaveAttn(Attention):
    """Attention module that keeps a detached copy of its most recent attention matrix."""

    attn_mat = None

    @classmethod
    def cast(cls, attn: Attention):
        """Re-class an existing timm ``Attention`` module in place so its forward pass saves the attention matrix."""
        assert isinstance(attn, Attention), "Can only save attention from timm's Attention class"
        attn.__class__ = cls
        assert isinstance(attn, _MatrixSaveAttn)
        return attn

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

        # Explicit (non-fused) attention so the matrix is materialized and can be stored.
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        self.attn_mat = attn.detach()  # B x num_heads x N x N
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        return self.proj_drop(x)
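
# Minimal usage sketch (illustration only, not part of the module API): cast a block's
# attention in place and read the saved matrix back after a forward pass. The shapes
# assume a standard timm ViT block with dim=768, num_heads=12 and a 197-token sequence.
#
#   blk = Block(dim=768, num_heads=12)
#   blk.attn = _MatrixSaveAttn.cast(blk.attn)
#   tokens = torch.randn(2, 197, 768)   # B x N x D
#   _ = blk(tokens)
#   blk.attn.attn_mat.shape             # torch.Size([2, 12, 197, 197])
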
def _index_picker(tensor, idx=-1):
    """Pick a specific index from a tensor.

    tensor: B x N x D -> B x D, by picking idx from N.

    Args:
        tensor (torch.Tensor): tensor to pick from.
        idx (int, optional): index to pick. Defaults to -1.

    Returns:
        torch.Tensor: the selected slice of the tensor.

    """
    return tensor[..., idx, :]  # B x N x D -> B x D
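
# Illustration (hypothetical values): with a batch of token sequences of shape
# B x N x D, `_index_picker(x, idx=0)` returns the first token of every sequence,
# e.g. the [CLS] token when a class token is prepended.
#
#   x = torch.randn(4, 197, 768)
#   _index_picker(x, idx=0).shape   # torch.Size([4, 768])
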
class TimmViT(VisionTransformer, ResizingInterface):
    """Wrapper for *VisionTransformer* from the *timm* library (https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py)."""

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        num_classes=1000,
        global_pool="token",
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_norm=False,
        init_values=None,
        class_token=True,
        no_embed_class=True,
        pre_norm=False,
        fc_norm=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        weight_init="",
        embed_layer=PatchEmbed,
        norm_layer=None,
        act_layer=None,
        block_fn=Block,
        save_attention_maps=False,
        fused_attn=True,
        **kwargs,
    ):
"""Create a Vision Transformer model.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
img_size : int
|
||||
image dimensions -> img_size x img_size
|
||||
patch_size : int
|
||||
patch_size
|
||||
in_chans : int
|
||||
number of image channels
|
||||
num_classes : int
|
||||
number of classes for classification head
|
||||
global_pool : str,int
|
||||
type of global pooling for final sequence (default: 'token'), or index for token to be taken
|
||||
embed_dim : int
|
||||
embedding dimension
|
||||
depth : int
|
||||
number of transformer layers
|
||||
num_heads : int
|
||||
number of transformer heads
|
||||
mlp_ratio : float
|
||||
ratio of feed forward (mlp) hidden dimension to embedding dimension
|
||||
qkv_bias : bool
|
||||
enable bias for query, key, and value (qkv) embeddings
|
||||
qk_norm : bool
|
||||
normalize query and key embeddings
|
||||
init_values : float
|
||||
layer scale initial values
|
||||
class_token : bool
|
||||
use a class token [CLS]
|
||||
no_embed_class : bool
|
||||
no positional embedding for the class token
|
||||
pre_norm : bool
|
||||
use pre-norm architecture (norm before the blocks, not after)
|
||||
fc_norm : bool
|
||||
norm after pool (used when global_pool == 'avg')
|
||||
drop_rate : float
|
||||
dropout rate
|
||||
attn_drop_rate : float
|
||||
dropout rate in the attention module
|
||||
drop_path_rate : float
|
||||
drop path rate (stochastic depth)
|
||||
weight_init : str
|
||||
scheme for weight initialization
|
||||
embed_layer : nn.Module
|
||||
patch embedding layer
|
||||
norm_layer : nn.Module
|
||||
normalization layer
|
||||
act_layer : nn.Module
|
||||
activation function
|
||||
block_fn : nn.Module
|
||||
which block structure to use; for parallel attention layers, ...
|
||||
save_attention_maps : bool
|
||||
save attention maps for each block
|
||||
fused_attn : bool
|
||||
use fused attention
|
||||
kwargs : dict
|
||||
additional arguments (will be ignored)
|
||||
|
||||
"""
|
||||
        init_kwargs = dict(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            num_classes=num_classes,
            global_pool=global_pool if isinstance(global_pool, str) else "avg",
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            init_values=init_values,
            class_token=class_token,
            no_embed_class=no_embed_class,
            pre_norm=pre_norm,
            fc_norm=fc_norm,
            drop_rate=drop_rate,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=drop_path_rate,
            weight_init=weight_init,
            embed_layer=embed_layer,
            norm_layer=norm_layer,
            act_layer=act_layer,
            block_fn=block_fn,
        )
        self.new_version = False  # TODO: check based on the timm version
        if self.new_version:
            init_kwargs["qk_norm"] = qk_norm
            init_kwargs["proj_drop_rate"] = drop_rate
        super(TimmViT, self).__init__(**init_kwargs)

        # Set the pooling configuration after super().__init__(), which would otherwise
        # overwrite an integer index with the "avg" fallback passed in init_kwargs.
        self.global_pool = global_pool
        if isinstance(global_pool, int):
            self.attn_pool = partial(_index_picker, idx=global_pool)

        self.embed_layer = embed_layer
        self.img_size = img_size
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.pre_norm = pre_norm
        self.class_token = class_token
        self.no_embed_class = no_embed_class
        self.num_classes = num_classes
        self.num_heads = num_heads
        self.depth = depth
        self.save_attention_maps = save_attention_maps
        if save_attention_maps:
            self.do_save_attention_maps()
        try:
            for block in self.blocks:
                block.attn.fused_attn = fused_attn and self.blocks[0].attn.fused_attn
            use_fused = self.blocks[0].attn.fused_attn
            logger.info(f"Use fused attention: {use_fused}")
        except AttributeError:
            # Older timm Attention modules have no `fused_attn` flag; keep their default behaviour.
            pass
    def do_save_attention_maps(self):
        """Cast every block's attention module so it stores its attention matrix on each forward pass."""
        self.save_attention_maps = True
        for block in self.blocks:
            block.attn = _MatrixSaveAttn.cast(block.attn)
    def attention_maps(self):
        """Return the attention matrices saved during the last forward pass: one B x num_heads x N x N tensor per block (None for blocks that have not run yet)."""
        assert self.save_attention_maps, "Have to save attention maps first"
        return [getattr(block.attn, "attn_mat", None) for block in self.blocks]
    def forward_features(self, x):
        x = self.patch_embed(x)
        x = self._pos_embed(x)
        if self.new_version:
            x = self.patch_drop(x)
        x = self.norm_pre(x)
        x = self.blocks(x)
        return self.norm(x)
    def forward(self, x):
        x = self.forward_features(x)
        return self.forward_head(x)
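

if __name__ == "__main__":
    # Usage sketch, not part of the original commit: build a small TimmViT with
    # attention-map saving enabled and inspect the outputs. Assumes that
    # ResizingInterface is a mixin needing no constructor arguments and that a
    # standard timm version is installed; the hyper-parameters are illustrative.
    model = TimmViT(
        img_size=224,
        patch_size=16,
        embed_dim=192,
        depth=2,
        num_heads=3,
        num_classes=10,
        save_attention_maps=True,
    )
    images = torch.randn(2, 3, 224, 224)
    logits = model(images)
    maps = model.attention_maps()
    logger.info(f"logits shape: {tuple(logits.shape)}")                 # (2, 10)
    logger.info(f"blocks with saved maps: {len(maps)}")                 # 2
    logger.info(f"attention map shape: {tuple(maps[0].shape)}")         # (2, 3, 197, 197)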