AAAI Version

This commit is contained in:
Tobias Christian Nauen
2026-02-24 12:22:44 +01:00
parent 5c08f9d31a
commit ff34712155
378 changed files with 19844 additions and 4780 deletions

View File

@@ -0,0 +1,238 @@
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
import torch
import torch.nn as nn
from functools import partial
from timm.models.vision_transformer import VisionTransformer, _cfg
from timm.models.registry import register_model
from timm.models.layers import trunc_normal_
from architectures.vit import TimmViT
# Public API of this module. The entries must match the builder functions that are
# actually defined below; a name listed here without a matching definition makes
# `from <module> import *` fail with AttributeError.
__all__ = [
    "deit_tiny_patch16",
    "deit_small_patch16",
    "deit_base_patch16",
    "deit_tiny_distilled_patch16",
    "deit_small_distilled_patch16",
    "deit_base_distilled_patch16",
]
class DistilledVisionTransformer(TimmViT):
    """ViT variant with an extra distillation token and a second classification head.

    The token sequence is ``[cls, dist, patches...]``, so the positional embedding holds
    ``num_patches + 2`` entries.
    """

    def __init__(self, *args, **kwargs):
        """Build the base ViT, then add the distillation token, embedding and head."""
        super().__init__(*args, **kwargs)
        width = self.embed_dim
        self.dist_token = nn.Parameter(torch.zeros(1, 1, width))
        # Two extra positions: one for the class token, one for the distillation token.
        token_count = self.patch_embed.num_patches + 2
        self.pos_embed = nn.Parameter(torch.zeros(1, token_count, width))
        if self.num_classes > 0:
            self.head_dist = nn.Linear(width, self.num_classes)
        else:
            self.head_dist = nn.Identity()

        trunc_normal_(self.dist_token, std=0.02)
        trunc_normal_(self.pos_embed, std=0.02)
        self.head_dist.apply(self._init_weights)

    def forward_features(self, x):
        """Return the normalized class-token and distillation-token features.

        Adapted from timm's ``vision_transformer.py`` with the distillation token added.
        """
        batch = x.shape[0]
        patches = self.patch_embed(x)
        cls_tok = self.cls_token.expand(batch, -1, -1)
        dist_tok = self.dist_token.expand(batch, -1, -1)
        tokens = torch.cat((cls_tok, dist_tok, patches), dim=1)
        tokens = self.pos_drop(tokens + self.pos_embed)
        for layer in self.blocks:
            tokens = layer(tokens)
        tokens = self.norm(tokens)
        return tokens[:, 0], tokens[:, 1]

    def forward(self, x):
        """Return both head outputs while training, their average otherwise."""
        cls_feat, dist_feat = self.forward_features(x)
        logits = self.head(cls_feat)
        dist_logits = self.head_dist(dist_feat)
        if not self.training:
            # during inference, return the average of both classifier predictions
            return (logits + dist_logits) / 2
        return logits, dist_logits
def _clean_kwargs(kwargs):
allowed_keys = {key for key in kwargs.keys() if not key.startswith("pretrain")}
allowed_keys = {key for key in allowed_keys if not key.startswith("cache")}
return {key: kwargs[key] for key in allowed_keys}
@register_model
def deit_tiny_patch16(pretrained=False, img_size=224, drop_path_rate=0.1, num_classes=1000, drop_rate=0.0, **kwargs):
    """Build a patch-16 ViT with width 192, depth 12 and 3 heads.

    Args:
        pretrained: If true, download and load the ``deit_tiny_patch16_224`` checkpoint.
        img_size: Input resolution passed to the model.
        drop_path_rate: Stochastic-depth rate passed to the model.
        num_classes: Size of the classification head.
        drop_rate: Dropout rate passed to the model.
        **kwargs: Filtered through ``_clean_kwargs``.

    Returns:
        The constructed ``TimmViT``.
    """
    # NOTE(review): the filtered kwargs are not forwarded to the constructor below.
    kwargs = _clean_kwargs(kwargs)
    model = TimmViT(
        img_size=img_size,
        num_classes=num_classes,
        drop_rate=drop_rate,
        drop_path_rate=drop_path_rate,
        patch_size=16,
        embed_dim=192,
        depth=12,
        num_heads=3,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    model.default_cfg = _cfg()
    if not pretrained:
        return model

    state = torch.hub.load_state_dict_from_url(
        url="https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth",
        map_location="cpu",
        check_hash=True,
    )
    model.load_state_dict(state["model"])
    return model
@register_model
def deit_small_patch16(pretrained=False, img_size=224, drop_path_rate=0.1, num_classes=1000, drop_rate=0.0, **kwargs):
    """Build a patch-16 ViT with width 384, depth 12 and 6 heads.

    Args:
        pretrained: If true, download and load the ``deit_small_patch16_224`` checkpoint.
        img_size: Input resolution passed to the model.
        drop_path_rate: Stochastic-depth rate passed to the model.
        num_classes: Size of the classification head.
        drop_rate: Dropout rate passed to the model.
        **kwargs: Filtered through ``_clean_kwargs``.

    Returns:
        The constructed ``TimmViT``.
    """
    # NOTE(review): the filtered kwargs are not forwarded to the constructor below.
    kwargs = _clean_kwargs(kwargs)
    model = TimmViT(
        img_size=img_size,
        num_classes=num_classes,
        drop_rate=drop_rate,
        drop_path_rate=drop_path_rate,
        patch_size=16,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    model.default_cfg = _cfg()
    if not pretrained:
        return model

    state = torch.hub.load_state_dict_from_url(
        url="https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth",
        map_location="cpu",
        check_hash=True,
    )
    model.load_state_dict(state["model"])
    return model
@register_model
def deit_base_patch16(pretrained=False, img_size=224, drop_path_rate=0.1, num_classes=1000, drop_rate=0.0, **kwargs):
    """Build a patch-16 ViT with width 768, depth 12 and 12 heads.

    Args:
        pretrained: If true, download and load the ``deit_base_patch16_224`` checkpoint.
        img_size: Input resolution passed to the model.
        drop_path_rate: Stochastic-depth rate passed to the model.
        num_classes: Size of the classification head.
        drop_rate: Dropout rate passed to the model.
        **kwargs: Filtered through ``_clean_kwargs``.

    Returns:
        The constructed ``TimmViT``.
    """
    # NOTE(review): the filtered kwargs are not forwarded to the constructor below.
    kwargs = _clean_kwargs(kwargs)
    model = TimmViT(
        img_size=img_size,
        num_classes=num_classes,
        drop_rate=drop_rate,
        drop_path_rate=drop_path_rate,
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    model.default_cfg = _cfg()
    if not pretrained:
        return model

    state = torch.hub.load_state_dict_from_url(
        url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
        map_location="cpu",
        check_hash=True,
    )
    model.load_state_dict(state["model"])
    return model
@register_model
def deit_tiny_distilled_patch16(
    pretrained=False, img_size=224, drop_path_rate=0.1, num_classes=1000, drop_rate=0.0, **kwargs
):
    """Build a distilled patch-16 ViT with width 192, depth 12 and 3 heads.

    Args:
        pretrained: If true, download and load the ``deit_tiny_distilled_patch16_224`` checkpoint.
        img_size: Input resolution passed to the model.
        drop_path_rate: Stochastic-depth rate passed to the model.
        num_classes: Size of both classification heads.
        drop_rate: Dropout rate passed to the model.
        **kwargs: Filtered through ``_clean_kwargs``.

    Returns:
        The constructed ``DistilledVisionTransformer``.
    """
    # NOTE(review): the filtered kwargs are not forwarded to the constructor below.
    kwargs = _clean_kwargs(kwargs)
    model = DistilledVisionTransformer(
        img_size=img_size,
        num_classes=num_classes,
        drop_rate=drop_rate,
        drop_path_rate=drop_path_rate,
        patch_size=16,
        embed_dim=192,
        depth=12,
        num_heads=3,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    model.default_cfg = _cfg()
    if not pretrained:
        return model

    state = torch.hub.load_state_dict_from_url(
        url="https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth",
        map_location="cpu",
        check_hash=True,
    )
    model.load_state_dict(state["model"])
    return model
@register_model
def deit_small_distilled_patch16(
    pretrained=False, img_size=224, drop_path_rate=0.1, num_classes=1000, drop_rate=0.0, **kwargs
):
    """Build a distilled patch-16 ViT with width 384, depth 12 and 6 heads.

    Args:
        pretrained: If true, download and load the ``deit_small_distilled_patch16_224`` checkpoint.
        img_size: Input resolution passed to the model.
        drop_path_rate: Stochastic-depth rate passed to the model.
        num_classes: Size of both classification heads.
        drop_rate: Dropout rate passed to the model.
        **kwargs: Filtered through ``_clean_kwargs``.

    Returns:
        The constructed ``DistilledVisionTransformer``.
    """
    # NOTE(review): the filtered kwargs are not forwarded to the constructor below.
    kwargs = _clean_kwargs(kwargs)
    model = DistilledVisionTransformer(
        img_size=img_size,
        num_classes=num_classes,
        drop_rate=drop_rate,
        drop_path_rate=drop_path_rate,
        patch_size=16,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    model.default_cfg = _cfg()
    if not pretrained:
        return model

    state = torch.hub.load_state_dict_from_url(
        url="https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth",
        map_location="cpu",
        check_hash=True,
    )
    model.load_state_dict(state["model"])
    return model
@register_model
def deit_base_distilled_patch16(
    pretrained=False, img_size=224, drop_path_rate=0.1, num_classes=1000, drop_rate=0.0, **kwargs
):
    """Build a distilled patch-16 ViT with width 768, depth 12 and 12 heads.

    Args:
        pretrained: If true, download and load the ``deit_base_distilled_patch16_224`` checkpoint.
        img_size: Input resolution passed to the model.
        drop_path_rate: Stochastic-depth rate passed to the model.
        num_classes: Size of both classification heads.
        drop_rate: Dropout rate passed to the model.
        **kwargs: Filtered through ``_clean_kwargs``.

    Returns:
        The constructed ``DistilledVisionTransformer``.
    """
    # NOTE(review): the filtered kwargs are not forwarded to the constructor below.
    kwargs = _clean_kwargs(kwargs)
    model = DistilledVisionTransformer(
        img_size=img_size,
        num_classes=num_classes,
        drop_rate=drop_rate,
        drop_path_rate=drop_path_rate,
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    model.default_cfg = _cfg()
    if not pretrained:
        return model

    state = torch.hub.load_state_dict_from_url(
        url="https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth",
        map_location="cpu",
        check_hash=True,
    )
    model.load_state_dict(state["model"])
    return model

View File

@@ -0,0 +1,850 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Taken from https://github.com/facebookresearch/deit with slight modifications
from loguru import logger
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
from timm.models.vision_transformer import Mlp, PatchEmbed, _cfg
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from timm.models.registry import register_model
from resizing_interface import ResizingInterface
class Attention(nn.Module):
    """Multi-head self-attention, adapted from timm's ``vision_transformer.py``."""

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0):
        """Create the fused qkv projection, output projection and dropouts.

        Args:
            dim: Channel width of the input tokens.
            num_heads: Number of attention heads.
            qkv_bias: Whether the fused qkv projection has a bias.
            qk_scale: Query scaling factor; any falsy value selects ``head_dim ** -0.5``.
            attn_drop: Dropout probability on the attention weights.
            proj_drop: Dropout probability on the projected output.
        """
        super().__init__()
        self.num_heads = num_heads
        self.scale = qk_scale if qk_scale else (dim // num_heads) ** -0.5

        self.qkv = nn.Linear(dim, 3 * dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Apply self-attention to ``x``, unpacked as ``(B, N, C)``; output has the same shape."""
        batch, tokens, channels = x.shape
        per_head = channels // self.num_heads
        # (B, N, 3*C) -> (3, B, heads, N, head_dim)
        fused = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, per_head).permute(2, 0, 3, 1, 4)
        query, key, value = fused[0], fused[1], fused[2]

        scores = (query * self.scale) @ key.transpose(-2, -1)
        weights = self.attn_drop(scores.softmax(dim=-1))

        merged = (weights @ value).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(merged))
class Block(nn.Module):
    """Pre-norm transformer block: attention and MLP residual branches, no LayerScale.

    Adapted from timm's ``vision_transformer.py``.
    """

    def __init__(
        self, dim, num_heads, mlp_ratio=4.0, qkv_bias=False, qk_scale=None, drop=0.0, attn_drop=0.0,
        drop_path=0.0, act_layer=nn.GELU, norm_layer=nn.LayerNorm, Attention_block=Attention, Mlp_block=Mlp,
        init_values=1e-4,
    ):
        """Assemble the two sub-layers.

        ``init_values`` is accepted so that all block classes share one signature; this block
        does not use it.
        """
        super().__init__()
        attn_options = dict(
            num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop
        )
        self.norm1 = norm_layer(dim)
        self.attn = Attention_block(dim, **attn_options)
        # Stochastic depth on each residual branch; a no-op module when the rate is zero.
        self.drop_path = nn.Identity() if not drop_path > 0.0 else DropPath(drop_path)
        self.norm2 = norm_layer(dim)
        hidden_width = int(dim * mlp_ratio)
        self.mlp = Mlp_block(in_features=dim, hidden_features=hidden_width, act_layer=act_layer, drop=drop)

    def forward(self, x):
        """Run the attention branch, then the MLP branch, each added residually."""
        attn_branch = self.drop_path(self.attn(self.norm1(x)))
        x = x + attn_branch
        mlp_branch = self.drop_path(self.mlp(self.norm2(x)))
        return x + mlp_branch
class Layer_scale_init_Block(nn.Module):
    """Pre-norm transformer block whose residual branches are scaled by learnable vectors.

    Adapted from timm's ``vision_transformer.py``. ``gamma_1`` scales the attention branch
    and ``gamma_2`` the MLP branch; both start at ``init_values``.
    """

    def __init__(
        self, dim, num_heads, mlp_ratio=4.0, qkv_bias=False, qk_scale=None, drop=0.0, attn_drop=0.0,
        drop_path=0.0, act_layer=nn.GELU, norm_layer=nn.LayerNorm, Attention_block=Attention, Mlp_block=Mlp,
        init_values=1e-4,
    ):
        """Assemble the two sub-layers and their per-channel scale parameters."""
        super().__init__()
        attn_options = dict(
            num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop
        )
        self.norm1 = norm_layer(dim)
        self.attn = Attention_block(dim, **attn_options)
        # Stochastic depth on each residual branch; a no-op module when the rate is zero.
        self.drop_path = nn.Identity() if not drop_path > 0.0 else DropPath(drop_path)
        self.norm2 = norm_layer(dim)
        hidden_width = int(dim * mlp_ratio)
        self.mlp = Mlp_block(in_features=dim, hidden_features=hidden_width, act_layer=act_layer, drop=drop)
        self.gamma_1 = nn.Parameter(init_values * torch.ones(dim), requires_grad=True)
        self.gamma_2 = nn.Parameter(init_values * torch.ones(dim), requires_grad=True)

    def forward(self, x):
        """Run the scaled attention branch, then the scaled MLP branch, each added residually."""
        attn_branch = self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
        x = x + attn_branch
        mlp_branch = self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
        return x + mlp_branch
class Layer_scale_init_Block_paralx2(nn.Module):
    """Transformer block with two parallel attention and two parallel MLP branches, each scaled.

    Adapted from timm's ``vision_transformer.py``. Every branch has its own norm layer and its
    own learnable per-channel scale initialised to ``init_values``.
    """

    def __init__(
        self, dim, num_heads, mlp_ratio=4.0, qkv_bias=False, qk_scale=None, drop=0.0, attn_drop=0.0,
        drop_path=0.0, act_layer=nn.GELU, norm_layer=nn.LayerNorm, Attention_block=Attention, Mlp_block=Mlp,
        init_values=1e-4,
    ):
        """Assemble both attention branches, both MLP branches and the four scale vectors."""
        super().__init__()
        attn_options = dict(
            num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop
        )
        self.norm1 = norm_layer(dim)
        self.norm11 = norm_layer(dim)
        self.attn = Attention_block(dim, **attn_options)
        self.attn1 = Attention_block(dim, **attn_options)
        # Stochastic depth shared by all branches; a no-op module when the rate is zero.
        self.drop_path = nn.Identity() if not drop_path > 0.0 else DropPath(drop_path)
        self.norm2 = norm_layer(dim)
        self.norm21 = norm_layer(dim)
        mlp_options = dict(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
        self.mlp = Mlp_block(**mlp_options)
        self.mlp1 = Mlp_block(**mlp_options)
        self.gamma_1 = nn.Parameter(init_values * torch.ones(dim), requires_grad=True)
        self.gamma_1_1 = nn.Parameter(init_values * torch.ones(dim), requires_grad=True)
        self.gamma_2 = nn.Parameter(init_values * torch.ones(dim), requires_grad=True)
        self.gamma_2_1 = nn.Parameter(init_values * torch.ones(dim), requires_grad=True)

    def forward(self, x):
        """Add both scaled attention branches to ``x``, then both scaled MLP branches."""
        first_attn = self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
        second_attn = self.drop_path(self.gamma_1_1 * self.attn1(self.norm11(x)))
        x = x + first_attn + second_attn

        first_mlp = self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
        second_mlp = self.drop_path(self.gamma_2_1 * self.mlp1(self.norm21(x)))
        return x + first_mlp + second_mlp
class Block_paralx2(nn.Module):
    """Transformer block with two parallel attention and two parallel MLP branches, unscaled.

    Adapted from timm's ``vision_transformer.py``. Every branch has its own norm layer.
    """

    def __init__(
        self, dim, num_heads, mlp_ratio=4.0, qkv_bias=False, qk_scale=None, drop=0.0, attn_drop=0.0,
        drop_path=0.0, act_layer=nn.GELU, norm_layer=nn.LayerNorm, Attention_block=Attention, Mlp_block=Mlp,
        init_values=1e-4,
    ):
        """Assemble both attention branches and both MLP branches.

        ``init_values`` is accepted so that all block classes share one signature; this block
        does not use it.
        """
        super().__init__()
        attn_options = dict(
            num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop
        )
        self.norm1 = norm_layer(dim)
        self.norm11 = norm_layer(dim)
        self.attn = Attention_block(dim, **attn_options)
        self.attn1 = Attention_block(dim, **attn_options)
        # Stochastic depth shared by all branches; a no-op module when the rate is zero.
        self.drop_path = nn.Identity() if not drop_path > 0.0 else DropPath(drop_path)
        self.norm2 = norm_layer(dim)
        self.norm21 = norm_layer(dim)
        mlp_options = dict(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
        self.mlp = Mlp_block(**mlp_options)
        self.mlp1 = Mlp_block(**mlp_options)

    def forward(self, x):
        """Add both attention branches to ``x``, then both MLP branches."""
        first_attn = self.drop_path(self.attn(self.norm1(x)))
        second_attn = self.drop_path(self.attn1(self.norm11(x)))
        x = x + first_attn + second_attn

        first_mlp = self.drop_path(self.mlp(self.norm2(x)))
        second_mlp = self.drop_path(self.mlp1(self.norm21(x)))
        return x + first_mlp + second_mlp
class hMLP_stem(nn.Module):
    """Hierarchical MLP patch stem (https://arxiv.org/pdf/2203.09795.pdf).

    Adapted from timm's ``vision_transformer.py``. Three non-overlapping convolutions with
    strides 4, 2 and 2 turn an image into a sequence of ``embed_dim``-wide tokens.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=nn.SyncBatchNorm):
        """Create the convolutional stem.

        Args:
            img_size: Input resolution (int or pair).
            patch_size: Patch size (int or pair); used for ``num_patches`` only, the
                convolution strides are fixed.
            in_chans: Number of input channels.
            embed_dim: Output token width.
            norm_layer: Normalisation layer constructor applied after each convolution.
        """
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])

        narrow = embed_dim // 4
        stages = [
            nn.Conv2d(in_chans, narrow, kernel_size=4, stride=4),
            norm_layer(narrow),
            nn.GELU(),
            nn.Conv2d(narrow, narrow, kernel_size=2, stride=2),
            norm_layer(narrow),
            nn.GELU(),
            nn.Conv2d(narrow, embed_dim, kernel_size=2, stride=2),
            norm_layer(embed_dim),
        ]
        self.proj = torch.nn.Sequential(*stages)

    def forward(self, x):
        """Embed a 4-D input and return tokens with the spatial axes flattened into one."""
        # The unpack only checks that the input has exactly four dimensions.
        _batch, _chans, _height, _width = x.shape
        feature_map = self.proj(x)
        return feature_map.flatten(2).transpose(1, 2)
class vit_models(nn.Module, ResizingInterface):
    """Vision Transformer with pluggable block types, including LayerScale blocks.

    LayerScale: https://arxiv.org/abs/2103.17239. Adapted from timm's
    ``vision_transformer.py``. The positional embedding covers the patch tokens only and is
    added before the class token is prepended.
    """

    def __init__(
        self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, num_heads=12,
        mlp_ratio=4.0, qkv_bias=False, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0,
        norm_layer=nn.LayerNorm, global_pool=None, block_layers=Block, Patch_layer=PatchEmbed, act_layer=nn.GELU,
        Attention_block=Attention, Mlp_block=Mlp, dpr_constant=True, init_scale=1e-4, mlp_ratio_clstk=4.0,
        **kwargs,
    ):
        """Build the patch embedding, token parameters, block stack, final norm and head.

        ``drop_rate`` is applied only to the features fed to the head (see ``forward``); the
        blocks are always built with ``drop=0.0``. Every block receives the same
        ``drop_path_rate``. ``global_pool``, ``dpr_constant``, ``mlp_ratio_clstk`` and extra
        ``kwargs`` are accepted but not used here.
        """
        super().__init__()
        self.dropout_rate = drop_rate
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim
        self.img_size = img_size
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.embed_layer = Patch_layer
        self.pre_norm = False
        self.no_embed_class = True

        self.patch_embed = Patch_layer(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim
        )
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, self.patch_embed.num_patches, embed_dim))

        shared_block_args = dict(
            dim=embed_dim,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            drop=0.0,
            attn_drop=attn_drop_rate,
            norm_layer=norm_layer,
            act_layer=act_layer,
            Attention_block=Attention_block,
            Mlp_block=Mlp_block,
            init_values=init_scale,
        )
        per_block_drop_path = [drop_path_rate] * depth
        self.blocks = nn.ModuleList(
            [block_layers(drop_path=rate, **shared_block_args) for rate in per_block_drop_path]
        )

        self.norm = norm_layer(embed_dim)
        self.feature_info = [dict(num_chs=embed_dim, reduction=0, module="head")]
        if num_classes > 0:
            self.head = nn.Linear(embed_dim, num_classes)
        else:
            self.head = nn.Identity()

        trunc_normal_(self.pos_embed, std=0.02)
        trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init_weights)

    def set_num_classes(self, n_classes):
        """Delegate to the parent implementation, then re-initialise the head weights."""
        super().set_num_classes(n_classes)
        self._init_weights(self.head)

    def _init_weights(self, m):
        """Initialise ``nn.Linear`` (truncated normal, zero bias) and ``nn.LayerNorm`` (weight 1, bias 0)."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            return
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the names of parameters to exclude from weight decay."""
        return {"pos_embed", "cls_token"}

    def get_classifier(self):
        """Return the classification head module."""
        return self.head

    def get_num_layers(self):
        """Return the number of transformer blocks."""
        return len(self.blocks)

    def reset_classifier(self, num_classes, global_pool=""):
        """Replace the head with a fresh one for ``num_classes`` outputs (Identity if not positive).

        ``global_pool`` is accepted but not used.
        """
        self.num_classes = num_classes
        if num_classes > 0:
            self.head = nn.Linear(self.embed_dim, num_classes)
        else:
            self.head = nn.Identity()

    def forward_features(self, x, test=False):
        """Return the normalized class-token feature for each input.

        When ``test`` is true, a NaN found after any stage is logged as an error; the
        computation itself is unaffected.
        """

        def report_nan(tensor, message):
            if test and tensor.isnan().any().item():
                logger.error(message)

        batch = x.shape[0]
        x = self.patch_embed(x)
        report_nan(x, "patch embedded input has nan value")

        cls_tokens = self.cls_token.expand(batch, -1, -1)
        x = x + self.pos_embed
        report_nan(x, "position embedded input has a nan value")

        x = torch.cat((cls_tokens, x), dim=1)
        report_nan(x, "input with [CLS] has a nan value")

        for index, layer in enumerate(self.blocks):
            x = layer(x)
            report_nan(x, f"output of block {index} has a nan value")

        return self.norm(x)[:, 0]

    def forward(self, x, test=False):
        """Compute features, apply head dropout if a rate is set, and return the head output."""
        features = self.forward_features(x, test=test)
        if self.dropout_rate:
            features = F.dropout(features, p=float(self.dropout_rate), training=self.training)
        return self.head(features)
# DeiT III: Revenge of the ViT (https://arxiv.org/abs/2204.07118)
@register_model
def deit_tiny_patch16_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """Build a LayerScale ViT: patch 16, width 192, depth 12, 3 heads.

    ``pretrained`` and ``pretrained_21k`` are accepted but unused; no checkpoint is loaded.
    Extra ``kwargs`` are forwarded to ``vit_models``.
    """
    architecture = dict(
        patch_size=16,
        embed_dim=192,
        depth=12,
        num_heads=3,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
    )
    return vit_models(img_size=img_size, **architecture, **kwargs)
@register_model
def deit_small_patch16_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """Build a LayerScale ViT: patch 16, width 384, depth 12, 6 heads.

    Args:
        pretrained: If true, download and load the ``deit_3_small`` checkpoint for ``img_size``.
        img_size: Input resolution; also part of the checkpoint file name.
        pretrained_21k: Selects the ``21k`` checkpoint instead of the ``1k`` one.
        **kwargs: Forwarded to ``vit_models`` after ``pretrained_cfg`` is removed.
    """
    kwargs.pop("pretrained_cfg", None)
    architecture = dict(
        patch_size=16,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
    )
    model = vit_models(img_size=img_size, **architecture, **kwargs)
    model.default_cfg = _cfg()
    if not pretrained:
        return model

    suffix = "21k.pth" if pretrained_21k else "1k.pth"
    url = "https://dl.fbaipublicfiles.com/deit/deit_3_small_" + str(img_size) + "_" + suffix
    state = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
    model.load_state_dict(state["model"])
    return model
@register_model
def deit_medium_patch16_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """Build a LayerScale ViT: patch 16, width 512, depth 12, 8 heads.

    Args:
        pretrained: If true, download and load the ``deit_3_medium`` checkpoint for ``img_size``.
        img_size: Input resolution; also part of the checkpoint file name.
        pretrained_21k: Selects the ``21k`` checkpoint instead of the ``1k`` one.
        **kwargs: Forwarded to ``vit_models``.

    Returns:
        The constructed ``vit_models`` instance.
    """
    model = vit_models(
        # img_size must reach the model: otherwise it is always built at the default
        # resolution while the checkpoint below is chosen for the requested one.
        img_size=img_size,
        patch_size=16,
        embed_dim=512,
        depth=12,
        num_heads=8,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
        **kwargs,
    )
    model.default_cfg = _cfg()
    if pretrained:
        suffix = "21k.pth" if pretrained_21k else "1k.pth"
        name = "https://dl.fbaipublicfiles.com/deit/deit_3_medium_" + str(img_size) + "_" + suffix
        checkpoint = torch.hub.load_state_dict_from_url(url=name, map_location="cpu", check_hash=True)
        model.load_state_dict(checkpoint["model"])
    return model
@register_model
def deit_base_patch16_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """Build a LayerScale ViT: patch 16, width 768, depth 12, 12 heads.

    Args:
        pretrained: If true, download and load the ``deit_3_base`` checkpoint for ``img_size``.
        img_size: Input resolution; also part of the checkpoint file name.
        pretrained_21k: Selects the ``21k`` checkpoint instead of the ``1k`` one.
        **kwargs: Forwarded to ``vit_models``.
    """
    architecture = dict(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
    )
    model = vit_models(img_size=img_size, **architecture, **kwargs)
    if not pretrained:
        return model

    suffix = "21k.pth" if pretrained_21k else "1k.pth"
    url = "https://dl.fbaipublicfiles.com/deit/deit_3_base_" + str(img_size) + "_" + suffix
    state = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
    model.load_state_dict(state["model"])
    return model
@register_model
def deit_large_patch16_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """Build a LayerScale ViT: patch 16, width 1024, depth 24, 16 heads.

    Args:
        pretrained: If true, download and load the ``deit_3_large`` checkpoint for ``img_size``.
        img_size: Input resolution; also part of the checkpoint file name.
        pretrained_21k: Selects the ``21k`` checkpoint instead of the ``1k`` one.
        **kwargs: Forwarded to ``vit_models``.
    """
    architecture = dict(
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
    )
    model = vit_models(img_size=img_size, **architecture, **kwargs)
    if not pretrained:
        return model

    suffix = "21k.pth" if pretrained_21k else "1k.pth"
    url = "https://dl.fbaipublicfiles.com/deit/deit_3_large_" + str(img_size) + "_" + suffix
    state = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
    model.load_state_dict(state["model"])
    return model
@register_model
def deit_huge_patch14_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """Build a LayerScale ViT: patch 14, width 1280, depth 32, 16 heads.

    Args:
        pretrained: If true, download and load the ``deit_3_huge`` checkpoint for ``img_size``.
        img_size: Input resolution; also part of the checkpoint file name.
        pretrained_21k: Selects the ``21k_v1`` checkpoint instead of the ``1k_v1`` one.
        **kwargs: Forwarded to ``vit_models``.
    """
    architecture = dict(
        patch_size=14,
        embed_dim=1280,
        depth=32,
        num_heads=16,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
    )
    model = vit_models(img_size=img_size, **architecture, **kwargs)
    if not pretrained:
        return model

    suffix = "21k_v1.pth" if pretrained_21k else "1k_v1.pth"
    url = "https://dl.fbaipublicfiles.com/deit/deit_3_huge_" + str(img_size) + "_" + suffix
    state = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
    model.load_state_dict(state["model"])
    return model
@register_model
def deit_huge_patch14_52_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """Build a LayerScale ViT: patch 14, width 1280, depth 52, 16 heads.

    ``pretrained`` and ``pretrained_21k`` are accepted but unused; no checkpoint is loaded.
    Extra ``kwargs`` are forwarded to ``vit_models``.
    """
    architecture = dict(
        patch_size=14,
        embed_dim=1280,
        depth=52,
        num_heads=16,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
    )
    return vit_models(img_size=img_size, **architecture, **kwargs)
@register_model
def deit_huge_patch14_26x2_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """Build a LayerScale ViT of 26 two-branch parallel blocks: patch 14, width 1280, 16 heads.

    ``pretrained`` and ``pretrained_21k`` are accepted but unused; no checkpoint is loaded.
    Extra ``kwargs`` are forwarded to ``vit_models``.
    """
    architecture = dict(
        patch_size=14,
        embed_dim=1280,
        depth=26,
        num_heads=16,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block_paralx2,
    )
    return vit_models(img_size=img_size, **architecture, **kwargs)
# @register_model
# def deit_Giant_48x2_patch14_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
# model = vit_models(
# img_size=img_size, patch_size=14, embed_dim=1664, depth=48, num_heads=16, mlp_ratio=4,
# norm_layer=partial(nn.LayerNorm, eps=1e-6), block_layers=Block_paral_LS, **kwargs)
#
# return model
# @register_model
# def deit_giant_40x2_patch14_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
# model = vit_models(
# img_size=img_size, patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=4,
# norm_layer=partial(nn.LayerNorm, eps=1e-6), block_layers=Block_paral_LS, **kwargs)
# return model
@register_model
def deit_Giant_48_patch14_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """Build a LayerScale ViT: patch 14, width 1664, depth 48, 16 heads.

    ``pretrained`` and ``pretrained_21k`` are accepted but unused; no checkpoint is loaded.
    Extra ``kwargs`` are forwarded to ``vit_models``.
    """
    architecture = dict(
        patch_size=14,
        embed_dim=1664,
        depth=48,
        num_heads=16,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
    )
    return vit_models(img_size=img_size, **architecture, **kwargs)
@register_model
def deit_giant_40_patch14_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """Build a LayerScale ViT: patch 14, width 1408, depth 40, 16 heads.

    ``pretrained`` and ``pretrained_21k`` are accepted but unused; no checkpoint is loaded
    and no ``default_cfg`` is attached. Extra ``kwargs`` are forwarded to ``vit_models``.
    """
    architecture = dict(
        patch_size=14,
        embed_dim=1408,
        depth=40,
        num_heads=16,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
    )
    return vit_models(img_size=img_size, **architecture, **kwargs)
# Models from Three things everyone should know about Vision Transformers (https://arxiv.org/pdf/2203.09795.pdf)
@register_model
def deit_small_patch16_36_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """Build a LayerScale ViT: patch 16, width 384, depth 36, 6 heads.

    ``pretrained`` and ``pretrained_21k`` are accepted but unused; no checkpoint is loaded.
    Extra ``kwargs`` are forwarded to ``vit_models``.
    """
    architecture = dict(
        patch_size=16,
        embed_dim=384,
        depth=36,
        num_heads=6,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
    )
    return vit_models(img_size=img_size, **architecture, **kwargs)
@register_model
def deit_small_patch16_36(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """Build a ViT with the default block type: patch 16, width 384, depth 36, 6 heads.

    ``pretrained`` and ``pretrained_21k`` are accepted but unused; no checkpoint is loaded.
    Extra ``kwargs`` are forwarded to ``vit_models``.
    """
    architecture = dict(
        patch_size=16,
        embed_dim=384,
        depth=36,
        num_heads=6,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    return vit_models(img_size=img_size, **architecture, **kwargs)
@register_model
def deit_small_patch16_18x2_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """Build a LayerScale ViT of 18 two-branch parallel blocks: patch 16, width 384, 6 heads.

    ``pretrained`` and ``pretrained_21k`` are accepted but unused; no checkpoint is loaded.
    Extra ``kwargs`` are forwarded to ``vit_models``.
    """
    architecture = dict(
        patch_size=16,
        embed_dim=384,
        depth=18,
        num_heads=6,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block_paralx2,
    )
    return vit_models(img_size=img_size, **architecture, **kwargs)
@register_model
def deit_small_patch16_18x2(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """Build a ViT of 18 two-branch parallel blocks without LayerScale: patch 16, width 384, 6 heads.

    ``pretrained`` and ``pretrained_21k`` are accepted but unused; no checkpoint is loaded.
    Extra ``kwargs`` are forwarded to ``vit_models``.
    """
    architecture = dict(
        patch_size=16,
        embed_dim=384,
        depth=18,
        num_heads=6,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Block_paralx2,
    )
    return vit_models(img_size=img_size, **architecture, **kwargs)
@register_model
def deit_base_patch16_18x2_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """Build a LayerScale ViT of 18 two-branch parallel blocks: patch 16, width 768, 12 heads.

    ``pretrained`` and ``pretrained_21k`` are accepted but unused; no checkpoint is loaded.
    Extra ``kwargs`` are forwarded to ``vit_models``.
    """
    architecture = dict(
        patch_size=16,
        embed_dim=768,
        depth=18,
        num_heads=12,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block_paralx2,
    )
    return vit_models(img_size=img_size, **architecture, **kwargs)
@register_model
def deit_base_patch16_18x2(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """Build a ViT of 18 two-branch parallel blocks without LayerScale: patch 16, width 768, 12 heads.

    ``pretrained`` and ``pretrained_21k`` are accepted but unused; no checkpoint is loaded.
    Extra ``kwargs`` are forwarded to ``vit_models``.
    """
    architecture = dict(
        patch_size=16,
        embed_dim=768,
        depth=18,
        num_heads=12,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Block_paralx2,
    )
    return vit_models(img_size=img_size, **architecture, **kwargs)
@register_model
def deit_base_patch16_36x1_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """Build a LayerScale ViT: patch 16, width 768, depth 36, 12 heads.

    ``pretrained`` and ``pretrained_21k`` are accepted but unused; no checkpoint is loaded.
    Extra ``kwargs`` are forwarded to ``vit_models``.
    """
    architecture = dict(
        patch_size=16,
        embed_dim=768,
        depth=36,
        num_heads=12,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
    )
    return vit_models(img_size=img_size, **architecture, **kwargs)
@register_model
def deit_base_patch16_36x1(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """Build a ViT with the default block type: patch 16, width 768, depth 36, 12 heads.

    ``pretrained`` and ``pretrained_21k`` are accepted but unused; no checkpoint is loaded.
    Extra ``kwargs`` are forwarded to ``vit_models``.
    """
    architecture = dict(
        patch_size=16,
        embed_dim=768,
        depth=36,
        num_heads=12,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    return vit_models(img_size=img_size, **architecture, **kwargs)

View File

@@ -0,0 +1,111 @@
from timm.models import register_model
from timm.models.resnet import BasicBlock, Bottleneck
from timm.models.resnet import ResNet as ResNetTimm
from torch import nn
from resizing_interface import ResizingInterface
class ResNet(ResNetTimm, ResizingInterface):
    """timm's ResNet combined with the project's ResizingInterface."""

    def __init__(self, *args, global_pool="avg", **kwargs):
        """Create a ResNet, dropping keyword arguments that are not in the admissible set.

        Args:
            *args: Positional arguments forwarded to timm's ResNet.
            global_pool (str, optional): Pooling mode forwarded to timm's ResNet and reused
                whenever the classifier is reset. Defaults to "avg".
            **kwargs: Keyword arguments for timm's ResNet. Keys outside the list below are
                silently discarded.

        Keyword Args:
            block, layers, num_classes, in_chans, output_stride, cardinality, base_width,
            stem_width, stem_type, replace_stem_pool, block_reduce_first, down_kernel_size,
            avg_down, act_layer, norm_layer, aa_layer, drop_rate, drop_path_rate,
            drop_block_rate, zero_init_last, block_args
        """
        admissible_kwargs = frozenset(
            (
                "block",
                "layers",
                "num_classes",
                "in_chans",
                "output_stride",
                "cardinality",
                "base_width",
                "stem_width",
                "stem_type",
                "replace_stem_pool",
                "block_reduce_first",
                "down_kernel_size",
                "avg_down",
                "act_layer",
                "norm_layer",
                "aa_layer",
                "drop_rate",
                "drop_path_rate",
                "drop_block_rate",
                "zero_init_last",
                "block_args",
            )
        )
        kwargs = {key: value for key, value in kwargs.items() if key in admissible_kwargs}
        super().__init__(*args, global_pool=global_pool, **kwargs)
        self.global_pool_str = global_pool

    def set_image_res(self, res):
        """Do nothing: resizing is not needed for CNNs with pooling."""
        return

    def set_num_classes(self, n_classes):
        """Swap the classifier for one with ``n_classes`` outputs.

        A non-positive ``n_classes`` removes the classifier (``fc`` becomes an identity).
        Nothing happens when the model already has ``n_classes`` outputs.
        """
        if self.num_classes == n_classes:
            return
        if n_classes > 0:
            self.reset_classifier(num_classes=n_classes, global_pool=self.global_pool_str)
        else:
            self.fc = nn.Identity()
            # Record that the head is gone; otherwise a later call with the previous class
            # count would hit the early return above and leave the identity in place.
            self.num_classes = 0
@register_model
def resnet18(pretrained=False, **kwargs):
    """Build a ResNet-18: BasicBlock stages of depth 2, 2, 2, 2. ``pretrained`` is unused."""
    return ResNet(block=BasicBlock, layers=[2, 2, 2, 2], **kwargs)
@register_model
def resnet34(pretrained=False, **kwargs):
    """Build a ResNet-34: BasicBlock stages of depth 3, 4, 6, 3. ``pretrained`` is unused."""
    return ResNet(block=BasicBlock, layers=[3, 4, 6, 3], **kwargs)
@register_model
def resnet26(pretrained=False, **kwargs):
    """Build a ResNet-26: Bottleneck stages of depth 2, 2, 2, 2. ``pretrained`` is unused."""
    return ResNet(block=Bottleneck, layers=[2, 2, 2, 2], **kwargs)
@register_model
def resnet50(pretrained=False, **kwargs):
    """Build a ResNet-50: Bottleneck stages of depth 3, 4, 6, 3. ``pretrained`` is unused."""
    return ResNet(block=Bottleneck, layers=[3, 4, 6, 3], **kwargs)
@register_model
def resnet101(pretrained=False, **kwargs):
    """Build a ResNet-101: Bottleneck stages of depth 3, 4, 23, 3. ``pretrained`` is unused."""
    return ResNet(block=Bottleneck, layers=[3, 4, 23, 3], **kwargs)
@register_model
def wide_resnet50_2(pretrained=False, **kwargs):
    """Build a Wide ResNet-50-2. ``pretrained`` is unused.

    Same layout as ResNet-50, but the bottleneck channel count is doubled in every block
    (``base_width=128``). The outer 1x1 convolutions keep their width: the last block of
    ResNet-50 has 2048-512-2048 channels, that of Wide ResNet-50-2 has 2048-1024-2048.
    """
    return ResNet(block=Bottleneck, layers=[3, 4, 6, 3], base_width=128, **kwargs)

View File

@@ -0,0 +1,893 @@
# Taken from https://github.com/microsoft/Swin-Transformer with slight modifications
# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu
# --------------------------------------------------------
from copy import copy
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from loguru import logger
from timm.models import register_model
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from resizing_interface import ResizingInterface
try:
import os
import sys
kernel_path = os.path.abspath(os.path.join(".."))
sys.path.append(kernel_path)
from kernels.window_process.window_process import (
WindowProcess,
WindowProcessReverse,
)
except ImportError:
WindowProcess = None
WindowProcessReverse = None
logger.warning("Fused window process have not been installed. Please refer to get_started.md for installation.")
class Mlp(nn.Module):
    """Two-layer feed-forward network with an activation and dropout after each linear layer."""

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0):
        """Create the layers.

        Args:
            in_features: Input width.
            hidden_features: Hidden width; any falsy value selects ``in_features``.
            out_features: Output width; any falsy value selects ``in_features``.
            act_layer: Activation constructor, called without arguments.
            drop: Dropout probability for the single dropout module used after both layers.
        """
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Apply fc1, activation, dropout, fc2, dropout in that order."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
def window_partition(x, window_size):
    """Cut a feature map into non-overlapping square windows.

    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C), ordered row-major over the
        window grid within each batch element.
    """
    batch, height, width, channels = x.shape
    rows = height // window_size
    cols = width // window_size
    grid = x.view(batch, rows, window_size, cols, window_size, channels)
    # Bring the two grid axes together so that each window becomes one contiguous slab.
    grid = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    return grid.view(-1, window_size, window_size, channels)
def window_reverse(windows, window_size, H, W):
    """Reassemble a feature map from its non-overlapping square windows.

    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    rows = H // window_size
    cols = W // window_size
    # Exact integer arithmetic: no float division followed by truncation.
    B = windows.shape[0] // (rows * cols)
    x = windows.view(B, rows, cols, window_size, window_size, -1)
    # Interleave grid and in-window axes back into (H, W).
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x
class WindowAttention(nn.Module):
    r"""Window based multi-head self attention (W-MSA) module with relative position bias.

    It supports both shifted and non-shifted windows.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(
        self,
        dim,
        window_size,
        num_heads,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5

        # learnable table holding one bias per head for every possible relative offset
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
        )  # 2*Wh-1 * 2*Ww-1, nH

        # pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        # explicit "ij" indexing: same result as the historical default, without the deprecation warning
        coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        # row offsets are scaled so that (row, col) pairs map to unique flat indices
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        trunc_normal_(self.relative_position_bias_table, std=0.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """Run windowed self-attention.

        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        Returns:
            tensor with the same shape as ``x``
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        # explicit indexing instead of tuple unpacking keeps torchscript happy
        q, k, v = qkv[0], qkv[1], qkv[2]

        q = q * self.scale
        attn = q @ k.transpose(-2, -1)

        window_area = self.window_size[0] * self.window_size[1]
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            window_area, window_area, -1
        )  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            # the mask is per window, so expose the window axis before broadcasting over batch and heads
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}"

    def flops(self, N):
        """Return the FLOP count for one window holding ``N`` tokens."""
        head_dim = self.dim // self.num_heads
        flops = 0
        # qkv = self.qkv(x)
        flops += N * self.dim * 3 * self.dim
        # attn = (q @ k.transpose(-2, -1))
        flops += self.num_heads * N * head_dim * N
        # x = (attn @ v)
        flops += self.num_heads * N * N * head_dim
        # x = self.proj(x)
        flops += N * self.dim * self.dim
        return flops
class SwinTransformerBlock(nn.Module):
    r"""Swin Transformer Block.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        fused_window_process (bool, optional): If True, use one kernel to fuse window shift & window partition
            for acceleration, similar for the reversed part. Default: False
    """

    def __init__(
        self,
        dim,
        input_resolution,
        num_heads,
        window_size=7,
        shift_size=0,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        fused_window_process=False,
    ):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio

        shortest_side = min(self.input_resolution)
        if shortest_side <= self.window_size:
            # the window covers the whole input: use one unshifted window
            self.shift_size = 0
            self.window_size = shortest_side
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim,
            window_size=to_2tuple(self.window_size),
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )

        # only shifted blocks need a mask; unshifted blocks register ``None``
        self.register_buffer("attn_mask", self._shifted_window_mask() if self.shift_size > 0 else None)
        self.fused_window_process = fused_window_process

    def _shifted_window_mask(self):
        """Build the SW-MSA attention mask for the current resolution, window size and shift size."""
        H, W = self.input_resolution
        ws, ss = self.window_size, self.shift_size
        region_ids = torch.zeros((1, H, W, 1))  # 1 H W 1
        # the same three bands are used along both axes, giving nine labelled regions
        bands = (slice(0, -ws), slice(-ws, -ss), slice(-ss, None))
        region = 0
        for rows in bands:
            for cols in bands:
                region_ids[:, rows, cols, :] = region
                region += 1
        per_window = window_partition(region_ids, ws).view(-1, ws * ws)  # nW, window_size*window_size
        # token pairs from different regions get -100 so softmax suppresses them, same-region pairs get 0
        diff = per_window.unsqueeze(1) - per_window.unsqueeze(2)
        return diff.masked_fill(diff != 0, float(-100.0)).masked_fill(diff == 0, float(0.0))

    def forward(self, x):
        """Apply (shifted) window attention followed by the MLP, each with a residual connection.

        Args:
            x: tensor of shape (B, H*W, C) where (H, W) equals ``input_resolution``
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        ws, ss = self.window_size, self.shift_size
        shifted = ss > 0
        # the fused kernels are only used when a shift is actually required
        fused = shifted and self.fused_window_process

        shortcut = x
        x = self.norm1(x).view(B, H, W, C)

        # cyclic shift + window partition
        if fused:
            x_windows = WindowProcess.apply(x, B, H, W, C, -ss, ws)
        else:
            if shifted:
                x = torch.roll(x, shifts=(-ss, -ss), dims=(1, 2))
            x_windows = window_partition(x, ws)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, ws * ws, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C
        attn_windows = attn_windows.view(-1, ws, ws, C)

        # merge windows + reverse cyclic shift
        if fused:
            x = WindowProcessReverse.apply(attn_windows, B, H, W, C, ss, ws)
        else:
            x = window_reverse(attn_windows, ws, H, W)  # B H' W' C
            if shifted:
                x = torch.roll(x, shifts=(ss, ss), dims=(1, 2))

        x = x.view(B, H * W, C)
        x = shortcut + self.drop_path(x)
        # FFN
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

    def extra_repr(self) -> str:
        return (
            f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, "
            f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
        )

    def flops(self):
        """Return the FLOP count of one forward pass at ``input_resolution``."""
        H, W = self.input_resolution
        ws = self.window_size
        tokens = H * W
        # norm1
        total = self.dim * tokens
        # W-MSA/SW-MSA: number of windows times the cost of one window
        total += (tokens / ws / ws) * self.attn.flops(ws * ws)
        # mlp
        total += 2 * tokens * self.dim * self.dim * self.mlp_ratio
        # norm2
        total += self.dim * tokens
        return total
class PatchMerging(nn.Module):
    r"""Patch Merging Layer.

    Concatenates the channels of every 2x2 neighbourhood, normalises the result and projects it
    from ``4 * dim`` to ``2 * dim`` channels.

    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """Merge patches.

        Args:
            x: B, H*W, C
        Returns:
            tensor of shape B, H/2*W/2, 2*C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        grid = x.view(B, H, W, C)
        # (row offset, col offset) of the four members of each 2x2 neighbourhood, in concatenation order
        offsets = ((0, 0), (1, 0), (0, 1), (1, 1))
        merged = torch.cat([grid[:, r::2, c::2, :] for r, c in offsets], -1)  # B H/2 W/2 4*C
        merged = merged.view(B, -1, 4 * C)  # B H/2*W/2 4*C
        return self.reduction(self.norm(merged))

    def extra_repr(self) -> str:
        return f"input_resolution={self.input_resolution}, dim={self.dim}"

    def flops(self):
        """Return the FLOP count (normalisation plus reduction) at ``input_resolution``."""
        H, W = self.input_resolution
        norm_cost = H * W * self.dim
        reduction_cost = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
        return norm_cost + reduction_cost
class BasicLayer(nn.Module):
    """A basic Swin Transformer layer for one stage.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | list[float] | tuple[float], optional): Stochastic depth rate, either one value for
            all blocks or one value per block. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
        fused_window_process (bool, optional): If True, use one kernel to fuse window shift & window partition
            for acceleration, similar for the reversed part. Default: False
    """

    def __init__(
        self,
        dim,
        input_resolution,
        depth,
        num_heads,
        window_size,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        norm_layer=nn.LayerNorm,
        downsample=None,
        use_checkpoint=False,
        fused_window_process=False,
    ):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks: even-indexed blocks use regular windows, odd-indexed ones use shifted windows
        self.blocks = nn.ModuleList(
            [
                SwinTransformerBlock(
                    dim=dim,
                    input_resolution=input_resolution,
                    num_heads=num_heads,
                    window_size=window_size,
                    shift_size=0 if (i % 2 == 0) else window_size // 2,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop,
                    attn_drop=attn_drop,
                    drop_path=self._block_drop_path(drop_path, i),
                    norm_layer=norm_layer,
                    fused_window_process=fused_window_process,
                )
                for i in range(depth)
            ]
        )

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    @staticmethod
    def _block_drop_path(drop_path, index):
        """Return the stochastic depth rate for block ``index`` (per-block sequence or shared scalar)."""
        # tuples are accepted as well as lists, matching the documented ``drop_path`` contract
        if isinstance(drop_path, (list, tuple)):
            return drop_path[index]
        return drop_path

    def forward(self, x):
        """Run all blocks in order, then the optional downsample layer."""
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)
        if self.downsample is not None:
            x = self.downsample(x)
        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"

    def flops(self):
        """Return the summed FLOP count of all blocks and the optional downsample layer."""
        flops = sum(blk.flops() for blk in self.blocks)
        if self.downsample is not None:
            flops += self.downsample.flops()
        return flops
class PatchEmbed(nn.Module):
    r"""Image to Patch Embedding

    Projects non-overlapping image patches to ``embed_dim`` channels with a strided convolution.

    Args:
        img_size (int): Image size. Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        # number of patches along height and width
        grid = [side // patch for side, patch in zip(img_size, patch_size)]

        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = grid
        self.num_patches = grid[0] * grid[1]
        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer is not None else None

    def forward(self, x):
        """Embed an image batch of shape (B, C, H, W) into tokens of shape (B, Ph*Pw, embed_dim)."""
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        expected_h, expected_w = self.img_size[0], self.img_size[1]
        assert (
            H == expected_h and W == expected_w
        ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        tokens = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
        if self.norm is None:
            return tokens
        return self.norm(tokens)

    def flops(self):
        """Return the FLOP count of the projection plus the optional normalisation."""
        Ho, Wo = self.patches_resolution
        total = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        if self.norm is not None:
            total += Ho * Wo * self.embed_dim
        return total
class SwinTransformer(nn.Module, ResizingInterface):
    r"""Swin Transformer

    A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
    https://arxiv.org/pdf/2103.14030

    Args:
        img_size (int | tuple(int)): Input image size. Default 224
        patch_size (int | tuple(int)): Patch size. Default: 4
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 1000
        embed_dim (int): Patch embedding dimension. Default: 96
        depths (tuple(int)): Depth of each Swin Transformer layer.
        num_heads (tuple(int)): Number of attention heads in different layers.
        window_size (int): Window size. Default: 7
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
        drop_rate (float): Dropout rate. Default: 0
        attn_drop_rate (float): Attention dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
        patch_norm (bool): If True, add normalization after patch embedding. Default: True
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
        fused_window_process (bool, optional): If True, use one kernel to fuse window shift & window partition
            for acceleration, similar for the reversed part. Default: False
        **kwargs: accepted and not used by this constructor.
    """

    def __init__(
        self,
        img_size=224,
        patch_size=4,
        in_chans=3,
        num_classes=1000,
        embed_dim=96,
        depths=(2, 2, 6, 2),
        num_heads=(3, 6, 12, 24),
        window_size=7,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.1,
        norm_layer=nn.LayerNorm,
        ape=False,
        patch_norm=True,
        use_checkpoint=False,
        fused_window_process=False,
        **kwargs,
    ):
        super().__init__()

        self.num_classes = num_classes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        # the channel count doubles at every stage transition
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        self.mlp_ratio = mlp_ratio
        self.img_size = img_size
        self.fused_window_process = fused_window_process

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None,
        )
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        # kept so that set_image_res can rebuild the patch embedding later
        self.embed_layer = PatchEmbed
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.norm_layer = norm_layer

        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=0.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth decay rule: rates grow linearly over all blocks of all stages
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(
                dim=int(embed_dim * 2**i_layer),
                input_resolution=(
                    patches_resolution[0] // (2**i_layer),
                    patches_resolution[1] // (2**i_layer),
                ),
                depth=depths[i_layer],
                num_heads=num_heads[i_layer],
                window_size=window_size,
                mlp_ratio=self.mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
                norm_layer=norm_layer,
                downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                use_checkpoint=use_checkpoint,
                fused_window_process=fused_window_process,
            )
            self.layers.append(layer)

        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        self.apply(self._init_weights)

    def _param_device(self):
        """Return the device of the model's first parameter."""
        return next(self.parameters()).device

    @staticmethod
    def _build_attn_mask(resolution, window_size, shift_size, device=None):
        """Build the SW-MSA attention mask for a shifted block at the given resolution."""
        H, W = resolution
        img_mask = torch.zeros((1, H, W, 1), device=device)  # 1 H W 1
        # the same three bands along both axes label nine regions
        bands = (
            slice(0, -window_size),
            slice(-window_size, -shift_size),
            slice(-shift_size, None),
        )
        cnt = 0
        for h in bands:
            for w in bands:
                img_mask[:, h, w, :] = cnt
                cnt += 1
        mask_windows = window_partition(img_mask, window_size)  # nW, window_size, window_size, 1
        mask_windows = mask_windows.view(-1, window_size * window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        # pairs from different regions get -100, pairs from the same region get 0
        return attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

    def set_num_classes(self, n_classes):
        """Reset the classification head with a new number of classes.

        The new head is placed on the same device as the rest of the model. With ``n_classes <= 0``
        the head becomes an identity mapping.

        Parameters
        ----------
        n_classes : int
            new number of classes
        """
        if n_classes == self.num_classes:
            return
        if n_classes > 0:
            head = nn.Linear(self.num_features, n_classes)
            nn.init.trunc_normal_(head.weight, std=0.02)
            nn.init.constant_(head.bias, 0)
        else:
            # nn.Identity has no weight or bias, so there is nothing to initialise
            head = nn.Identity()
        self.head = head.to(self._param_device())
        self.num_classes = n_classes

    def set_image_res(self, res):
        """Adapt the model to a new input resolution.

        Rebuilds the patch embedding (keeping its weights), updates the stored resolution of every
        layer, block and patch-merging module, recomputes the shifted-window masks and, when absolute
        position embeddings are used, interpolates them bicubically to the new token grid.

        Parameters
        ----------
        res : int | tuple(int)
            new input image size
        """
        if res == self.img_size:
            return

        device = self._param_device()
        old_patch_embed_state = copy(self.patch_embed.state_dict())
        self.patch_embed = self.embed_layer(
            img_size=res,
            patch_size=self.patch_size,
            in_chans=self.in_chans,
            embed_dim=self.embed_dim,
            norm_layer=self.norm_layer if self.patch_norm else None,
        ).to(device)
        self.patch_embed.load_state_dict(old_patch_embed_state)
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        for i_layer, layer in enumerate(self.layers):
            input_resolution = (
                patches_resolution[0] // (2**i_layer),
                patches_resolution[1] // (2**i_layer),
            )
            layer.input_resolution = input_resolution
            if layer.downsample is not None:
                # the merging weights do not depend on the resolution: keep them and only update the
                # stored resolution instead of replacing the module with a freshly initialised one
                layer.downsample.input_resolution = input_resolution
            for block in layer.blocks:
                block.input_resolution = input_resolution
                if min(input_resolution) <= block.window_size:
                    # if window size is larger than input resolution, we don't partition windows
                    # NOTE(review): block.attn keeps the window size it was built with; shrinking the
                    # block window here presumably needs a matching change there - confirm.
                    block.shift_size = 0
                    block.window_size = min(block.input_resolution)
                assert 0 <= block.shift_size < block.window_size, "shift_size must in 0-window_size"

                if block.shift_size > 0:
                    attn_mask = self._build_attn_mask(
                        block.input_resolution, block.window_size, block.shift_size, device=device
                    )
                else:
                    attn_mask = None
                block.register_buffer("attn_mask", attn_mask)

        if self.ape:
            # the token grid is treated as square here
            orig_size = int((self.absolute_pos_embed.shape[-2]) ** 0.5)
            new_size = int(self.patch_embed.num_patches**0.5)
            pos_tokens = self.absolute_pos_embed[:, :]
            # make it shape rest x embed_dim x orig_size x orig_size
            pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, self.embed_dim).permute(0, 3, 1, 2)
            pos_tokens = nn.functional.interpolate(
                pos_tokens,
                size=(new_size, new_size),
                mode="bicubic",
                align_corners=False,
            )
            # make it shape rest x new_size^2 x embed_dim
            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
            self.absolute_pos_embed = nn.Parameter(pos_tokens.contiguous())

        self.img_size = res

    def _init_weights(self, m):
        """Initialise linear layers with truncated normal weights and layer norms with unit scale."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {"absolute_pos_embed"}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {"relative_position_bias_table"}

    def forward_features(self, x):
        """Return the normalised token features of the last stage (B L C)."""
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        for layer in self.layers:
            x = layer(x)

        x = self.norm(x)  # B L C
        return x

    def forward(self, x):
        """Return the head output computed from the average-pooled token features."""
        x = self.forward_features(x)
        x = self.avgpool(x.transpose(1, 2))  # B C 1
        x = torch.flatten(x, 1)
        x = self.head(x)
        return x

    def flops(self):
        """Return the FLOP count of one forward pass at the current resolution."""
        flops = 0
        flops += self.patch_embed.flops()
        for layer in self.layers:
            flops += layer.flops()
        flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2**self.num_layers)
        flops += self.num_features * self.num_classes
        return flops
# Architecture presets keyed by size name; each entry supplies the size-dependent
# constructor arguments of SwinTransformer.
swin_sizes = {
    "Ti": {"embed_dim": 96, "depths": [2, 2, 6, 2], "num_heads": [3, 6, 12, 24]},
    "S": {"embed_dim": 96, "depths": [2, 2, 18, 2], "num_heads": [3, 6, 12, 24]},
    "B": {"embed_dim": 128, "depths": [2, 2, 18, 2], "num_heads": [4, 8, 16, 32]},
    "L": {"embed_dim": 192, "depths": [2, 2, 18, 2], "num_heads": [6, 12, 24, 48]},
}
@register_model
def swin_tiny_patch4_window7(pretrained=False, img_size=224, **kwargs):
    """Build a Swin-Ti model with patch size 4 and window size 7.

    ``pretrained`` is accepted but not used by this function. A ``pretrained_cfg`` entry in
    ``kwargs`` is dropped; everything else is forwarded to ``SwinTransformer``.
    """
    kwargs.pop("pretrained_cfg", None)
    return SwinTransformer(img_size=img_size, patch_size=4, window_size=7, **swin_sizes["Ti"], **kwargs)
@register_model
def swin_small_patch4_window7(pretrained=False, img_size=224, **kwargs):
    """Build a Swin-S model with patch size 4 and window size 7.

    ``pretrained`` is accepted but not used by this function. A ``pretrained_cfg`` entry in
    ``kwargs`` is dropped; everything else is forwarded to ``SwinTransformer``.
    """
    kwargs.pop("pretrained_cfg", None)
    return SwinTransformer(img_size=img_size, patch_size=4, window_size=7, **swin_sizes["S"], **kwargs)
@register_model
def swin_base_patch4_window7(pretrained=False, img_size=224, **kwargs):
    """Build a Swin-B model with patch size 4 and window size 7.

    ``pretrained`` is accepted but not used by this function. A ``pretrained_cfg`` entry in
    ``kwargs`` is dropped; everything else is forwarded to ``SwinTransformer``.
    """
    kwargs.pop("pretrained_cfg", None)
    return SwinTransformer(img_size=img_size, patch_size=4, window_size=7, **swin_sizes["B"], **kwargs)
@register_model
def swin_large_patch4_window7(pretrained=False, img_size=224, **kwargs):
    """Build a Swin-L model with patch size 4 and window size 7.

    ``pretrained`` is accepted but not used by this function. A ``pretrained_cfg`` entry in
    ``kwargs`` is dropped; everything else is forwarded to ``SwinTransformer``.
    """
    kwargs.pop("pretrained_cfg", None)
    return SwinTransformer(img_size=img_size, patch_size=4, window_size=7, **swin_sizes["L"], **kwargs)
@register_model
def swin_ashim(pretrained=False, img_size=112, **kwargs):
    """Build a single-stage Swin model (patch size 2, window size 7, 12 blocks, embed dim 384).

    ``pretrained`` is accepted but not used by this function and a ``pretrained_cfg`` entry in
    ``kwargs`` is dropped. A ``num_heads`` keyword is wrapped into a one-element list, since the
    model has a single stage. Remaining keyword arguments override the built-in defaults.
    """
    kwargs.pop("pretrained_cfg", None)
    if "num_heads" in kwargs:
        kwargs["num_heads"] = [kwargs["num_heads"]]
    config = {"embed_dim": 384, "depths": [12], "num_heads": [12]}
    # caller-supplied values take precedence over the defaults above
    config.update(kwargs)
    return SwinTransformer(img_size=img_size, in_chans=3, patch_size=2, window_size=7, **config)

View File

@@ -0,0 +1,225 @@
from functools import partial
import torch
from loguru import logger
from timm.models.vision_transformer import (
Attention,
Block,
PatchEmbed,
VisionTransformer,
)
from resizing_interface import ResizingInterface
class _MatrixSaveAttn(Attention):
    """Attention variant that keeps the post-softmax attention matrix of the latest forward pass."""

    # detached attention weights of the most recent forward call; None until forward has run
    attn_mat = None

    @classmethod
    def cast(cls, attn: Attention):
        """Turn an existing ``Attention`` instance into this class in place and return it."""
        assert isinstance(attn, Attention), "Can only save attention from Timms attention class"
        # re-labelling the instance keeps all of its parameters and only swaps the forward implementation
        attn.__class__ = cls
        assert isinstance(attn, _MatrixSaveAttn)
        return attn

    def forward(self, x):
        """Compute self-attention explicitly and store the detached attention weights in ``attn_mat``."""
        # NOTE(review): only qkv, scale, attn_drop, proj and proj_drop of the parent are used here;
        # confirm the installed timm Attention has no further steps (e.g. q/k normalisation) to mirror.
        B, N, C = x.shape
        head_dim = C // self.num_heads
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

        weights = ((q @ k.transpose(-2, -1)) * self.scale).softmax(dim=-1)
        self.attn_mat = weights.detach()
        weights = self.attn_drop(weights)

        out = (weights @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj_drop(self.proj(out))
def _index_picker(tensor, idx=-1):
    """Select one entry along the second-to-last axis of a tensor.

    tensor: B x N x D -> B x D, by picking idx from N.

    Args:
        tensor (torch.tensor): tensor to pick from.
        idx (int, optional): index to pick. Defaults to -1.

    Returns:
        torch.tensor: the slice of ``tensor`` at position ``idx`` of the second-to-last axis
    """
    picked = tensor[..., idx, :]  # B x N x D -> B x D
    return picked
class TimmViT(VisionTransformer, ResizingInterface):
    """Wrapper for *VisionTransformer* from *timm* library (https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py)."""

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        num_classes=1000,
        global_pool="token",
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_norm=False,
        init_values=None,
        class_token=True,
        no_embed_class=True,
        pre_norm=False,
        fc_norm=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        weight_init="",
        embed_layer=PatchEmbed,
        norm_layer=None,
        act_layer=None,
        block_fn=Block,
        save_attention_maps=False,
        fused_attn=True,
        **kwargs,
    ):
        """Create a Vision Transformer model.

        Parameters
        ----------
        img_size : int
            image dimensions -> img_size x img_size
        patch_size : int
            patch_size
        in_chans : int
            number of image channels
        num_classes : int
            number of classes for classification head
        global_pool : str,int
            type of global pooling for final sequence (default: 'token'), or index for token to be taken
        embed_dim : int
            embedding dimension
        depth : int
            number of transformer layers
        num_heads : int
            number of transformer heads
        mlp_ratio : float
            ratio of feed forward (mlp) hidden dimension to embedding dimension
        qkv_bias : bool
            enable bias for query, key, and value (qkv) embeddings
        qk_norm : bool
            normalize query and key embeddings (only forwarded to timm when ``new_version`` is set)
        init_values : float
            layer scale initial values
        class_token : bool
            use a class token [CLS]
        no_embed_class : bool
            no positional embedding for the class token
        pre_norm : bool
            use pre-norm architecture (norm before the blocks, not after)
        fc_norm : bool
            norm after pool (used when global_pool == 'avg')
        drop_rate : float
            dropout rate
        attn_drop_rate : float
            dropout rate in the attention module
        drop_path_rate : float
            drop path rate (stochastic depth)
        weight_init : str
            scheme for weight initialization
        embed_layer : nn.Module
            patch embedding layer
        norm_layer : nn.Module
            normalization layer
        act_layer : nn.Module
            activation function
        block_fn : nn.Module
            which block structure to use; for parallel attention layers, ...
        save_attention_maps : bool
            save attention maps for each block
        fused_attn : bool
            use fused attention
        kwargs : dict
            additional arguments (will be ignored)
        """
        init_kwargs = dict(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            num_classes=num_classes,
            # an integer token index is not a pooling mode the parent knows; it is given "avg" instead
            global_pool=global_pool if isinstance(global_pool, str) else "avg",
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            init_values=init_values,
            class_token=class_token,
            no_embed_class=no_embed_class,
            pre_norm=pre_norm,
            fc_norm=fc_norm,
            drop_rate=drop_rate,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=drop_path_rate,
            weight_init=weight_init,
            embed_layer=embed_layer,
            norm_layer=norm_layer,
            act_layer=act_layer,
            block_fn=block_fn,
        )

        self.new_version = False  # TODO: check based on the timm version
        # NOTE(review): the two attributes below are assigned before the parent constructor runs; if the
        # installed timm VisionTransformer sets ``global_pool`` / ``attn_pool`` itself, these values are
        # overwritten and the integer index has no effect - confirm against the timm version in use.
        self.global_pool = global_pool
        if isinstance(global_pool, int):
            self.attn_pool = partial(_index_picker, idx=global_pool)

        if self.new_version:
            init_kwargs["qk_norm"] = qk_norm
            init_kwargs["proj_drop_rate"] = drop_rate

        super().__init__(**init_kwargs)

        # remember the construction arguments on the instance
        self.embed_layer = embed_layer
        self.img_size = img_size
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.pre_norm = pre_norm
        self.class_token = class_token
        self.no_embed_class = no_embed_class
        self.num_classes = num_classes
        self.num_heads = num_heads
        self.depth = depth

        self.save_attention_maps = save_attention_maps
        if save_attention_maps:
            self.do_save_attention_maps()

        # Best effort: fused attention is enabled only when requested and already active on the blocks.
        try:
            for block in self.blocks:
                block.attn.fused_attn = fused_attn and self.blocks[0].attn.fused_attn
            use_fused = self.blocks[0].attn.fused_attn
            logger.info(f"Use fused attention: {use_fused}")
        except (AttributeError, IndexError) as err:
            # blocks without an ``attn.fused_attn`` attribute, or a model without blocks: nothing to set
            logger.debug(f"Could not configure fused attention: {err}")

    def do_save_attention_maps(self):
        """Switch every block's attention module to the variant that stores its attention matrix."""
        self.save_attention_maps = True
        for block in self.blocks:
            block.attn = _MatrixSaveAttn.cast(block.attn)

    def attention_maps(self):
        """Return the stored attention matrix of every block (``None`` for blocks that have none)."""
        assert self.save_attention_maps, "Have to save attention maps first"
        return [getattr(block.attn, "attn_mat", None) for block in self.blocks]

    def forward_features(self, x):
        """Embed the input, run all transformer blocks and return the normalised token sequence."""
        x = self.patch_embed(x)
        x = self._pos_embed(x)
        if self.new_version:
            x = self.patch_drop(x)
        x = self.norm_pre(x)
        x = self.blocks(x)
        return self.norm(x)

    def forward(self, x):
        """Return the head output for an input batch."""
        x = self.forward_features(x)
        return self.forward_head(x)