AAAI Version

This commit is contained in:
Tobias Christian Nauen
2026-02-24 12:22:44 +01:00
parent 5c08f9d31a
commit ff34712155
378 changed files with 19844 additions and 4780 deletions

View File

@@ -0,0 +1,76 @@
# Exclude a variety of commonly ignored directories.
exclude = [
".bzr",
".direnv",
".eggs",
".git",
".git-rewrite",
".hg",
".ipynb_checkpoints",
".mypy_cache",
".nox",
".pants.d",
".pyenv",
".pytest_cache",
".pytype",
".ruff_cache",
".svn",
".tox",
".venv",
".vscode",
"__pypackages__",
"_build",
"buck-out",
"build",
"dist",
"node_modules",
"site-packages",
"venv",
]
# Same as Black.
line-length = 120
indent-width = 4
[lint]
# Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default.
# Unlike Flake8, Ruff doesn't enable pycodestyle warnings (`W`) or
# McCabe complexity (`C901`) by default.
select = ["E", "F", "D", "B", "NPY", "PD", "TD005", "TD006", "TD007", "SIM", "RET", "Q", "ICN", "I"]
ignore = ["D203", "D213", "E501", "D100", "NPY002", "D102", "B008", "PD011", "D105", "SIM118", "D417"]
# Do not autofix flake8-bugbear (`B`) violations, even when `--fix` is provided.
unfixable = ["B"]
# Allow unused variables when underscore-prefixed.
dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
[lint.pydocstyle]
convention = "google"
[format]
# Like Black, use double quotes for strings.
quote-style = "double"
# Like Black, indent with spaces, rather than tabs.
indent-style = "space"
# Like Black, respect magic trailing commas.
skip-magic-trailing-comma = true
# Like Black, automatically detect the appropriate line ending.
line-ending = "auto"
# Enable auto-formatting of code examples in docstrings. Markdown,
# reStructuredText code/literal blocks and doctests are all supported.
#
# This is currently disabled by default, but it is planned for this
# to be opt-out in the future.
docstring-code-format = true
# Set the line length limit used when formatting code snippets in
# docstrings.
#
# This only has an effect when the `docstring-code-format` setting is
# enabled.
docstring-code-line-length = "dynamic"

View File

@@ -0,0 +1,107 @@
# ForNet
This is the training code for the ForNet paper.
All our experiments and evaluations were run using this codebase.
## Requirements
This project heavily builds on [timm](https://github.com/huggingface/pytorch-image-models) and open source implementations of the models that are tested.
All requirements are listed in [requirements.txt](./requirements.txt).
To install those, run
```commandline
pip install -r requirements.txt
```
## Usage
After **cloning this repository**, you can train and test a lot of different models.
By default, a `srun` command is executed to run the code on a slurm cluster.
To run on the local machine, append the `-local` flag to the command.
### General Preparation
After cloning the repository on a slurm cluster, make sure `main.py` is executable (e.g. by running `chmod a+x main.py`).
To run the project on a slurm cluster, you need to create a docker image from the requirements file.
You will also want to adapt the default slurm parameters in `config.py`.
Next, adjust the paths in paths_config.py for your system, specifically results_folder, slurm_output_folder and dataset folders.
Finally, if you want to use Weights and Biases for Tracking, create the file ".wandb.apikey" in this folder and paste your API Key into it.
### Training
#### Pretraining
To pretrain a `ViT-S` on a given dataset, run
```commandline
./main.py --task pre-train --model ViT-S/16 --epochs 300 --run_name <name_or_description_of_the_run> --experiment_name recombine_imagenet --lr 3e-3 (--local)
```
This will save a checkpoint (`.pt` file) every `<save_epochs>` epochs (the default is 10), which contains all the model weights, along with the optimizer and scheduler state, and the current training stats.
#### Finetuning
A model (checkpoint) can be finetuned on another dataset using the following command:
```commandline
./main.py --task fine-tune --model <model_checkpoint.pt> --epochs 300 --run_name <name_or_description_of_the_run> --experiment_name recombine_imagenet --lr 3e-3 (--local)
```
This will also save new checkpoints during training.
### Evaluation
It is also possible to evaluate the models.
To evaluate the model's accuracy on a specific dataset, run
```commandline
./main.py -t eval -ds <dataset name> -m <model_checkpoint.pt> --ntasks 1 -bs 512 --num-workers 10 --cpus-per-task 10 --time 10:00 (--local)
```
You can run our center-bias, size-bias, and foreground-focus evaluations using the `eval-attr`, `eval-center-bias`, and `eval-size-bias` tasks (`-t` or `--task` argument).
### Further Arguments
There can be multiple further arguments and flags given to the scripts.
The most important ones are
| Arg | Description |
| :------------------------------ | :----------------------------------------------------- |
| `--model <model>` | Model name or checkpoint. |
| `--run_name <name for the run>` | Name or description of this training run. |
| `--dataset <dataset>` | Specifies a dataset to use. |
| `--task <task>` | Specifies a task. The default is `pre-train`. |
| `--local` | Run on the local machine, not on a slurm cluster. |
| `--epochs <epochs>` | Epochs to train. |
| `--lr <lr>` | Learning rate. Default is 3e-3. |
| `--batch_size <bs>` | Batch size. Default is 2048. |
| `--weight_decay <wd>` | Weight decay. Default is 0.02. |
| `--imsize <image resolution>`   | Resolution of the image to train with. Default is 224. |
For a list of all arguments, run
```commandline
./main.py --help
```
## Supported Models
These are the models we support. Links are to original code sources. If no link is provided, we implemented the architecture from scratch, following the specific paper.
| Architecture | Versions |
| :----------------------------------------------------------------------------------------------------- | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| [DeiT](https://github.com/facebookresearch/deit) | `deit_tiny_patch16_LS`, `deit_small_patch16_LS`, `deit_medium_patch16_LS`, `deit_base_patch16_LS`, `deit_large_patch16_LS`, `deit_huge_patch14_LS`, `deit_huge_patch14_52_LS`, `deit_huge_patch14_26x2_LS`, `deit_Giant_48_patch14_LS`, `deit_giant_40_patch14_LS`, `deit_small_patch16_36_LS`, `deit_small_patch16_36`, `deit_small_patch16_18x2_LS`, `deit_small_patch16_18x2`, `deit_base_patch16_18x2_LS`, `deit_base_patch16_18x2`, `deit_base_patch16_36x1_LS`, `deit_base_patch16_36x1` |
| [ResNet](https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/resnet.py) | `resnet18`, `resnet34`, `resnet26`, `resnet50`, `resnet101`, `wide_resnet50_2` |
| [Swin](https://github.com/microsoft/Swin-Transformer) | `swin_tiny_patch4_window7`, `swin_small_patch4_window7`, `swin_base_patch4_window7`, `swin_large_patch4_window7` |
| [ViT](https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py) | `ViT-{Ti,S,B,L}/<patch_size>` |
## License
We release this code under the [MIT license](./LICENSE).
## Citation
If you use this codebase in your project, please cite:

View File

@@ -0,0 +1,3 @@
albumentations==2.0.5
datasets==3.5.0
nvidia-dali-cuda120==1.47.0

View File

@@ -0,0 +1,238 @@
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
import torch
import torch.nn as nn
from functools import partial
from timm.models.vision_transformer import VisionTransformer, _cfg
from timm.models.registry import register_model
from timm.models.layers import trunc_normal_
from architectures.vit import TimmViT
# Public API of this module.
# Fix: the previous list used upstream DeiT names with resolution suffixes
# (e.g. "deit_tiny_patch16_224", "deit_base_patch16_384") that are not defined
# in this file, so `from <module> import *` would raise AttributeError.
__all__ = [
    "deit_tiny_patch16",
    "deit_small_patch16",
    "deit_base_patch16",
    "deit_tiny_distilled_patch16",
    "deit_small_distilled_patch16",
    "deit_base_distilled_patch16",
]
class DistilledVisionTransformer(TimmViT):
    """DeiT-style ViT with an extra distillation token and classifier head.

    Adds a learnable ``dist_token`` next to the [CLS] token (hence the
    ``num_patches + 2`` positional embedding) and a second linear head
    (``head_dist``) read off the distillation token.
    """

    def __init__(self, *args, **kwargs):
        """Initialize the distilled ViT; all arguments are forwarded to TimmViT."""
        super().__init__(*args, **kwargs)
        self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        num_patches = self.patch_embed.num_patches
        # +2 slots: one for [CLS], one for the distillation token.
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.embed_dim))
        self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()
        trunc_normal_(self.dist_token, std=0.02)
        trunc_normal_(self.pos_embed, std=0.02)
        self.head_dist.apply(self._init_weights)

    def forward_features(self, x):
        """Return the final [CLS] and distillation token embeddings.

        Args:
            x: Image batch of shape (B, C, H, W) — presumably; confirm against patch_embed.

        Returns:
            Tuple ``(cls_embedding, dist_embedding)``, each of shape (B, embed_dim).
        """
        # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
        # with slight modifications to add the dist_token
        B = x.shape[0]
        x = self.patch_embed(x)
        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        dist_token = self.dist_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, dist_token, x), dim=1)
        x = x + self.pos_embed
        x = self.pos_drop(x)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        return x[:, 0], x[:, 1]

    def forward(self, x):
        """Classify x.

        Returns:
            ``(cls_logits, dist_logits)`` while training; their average in eval mode.
        """
        x, x_dist = self.forward_features(x)
        x = self.head(x)
        x_dist = self.head_dist(x_dist)
        if self.training:
            return x, x_dist
        else:
            # during inference, return the average of both classifier predictions
            return (x + x_dist) / 2
def _clean_kwargs(kwargs):
allowed_keys = {key for key in kwargs.keys() if not key.startswith("pretrain")}
allowed_keys = {key for key in allowed_keys if not key.startswith("cache")}
return {key: kwargs[key] for key in allowed_keys}
@register_model
def deit_tiny_patch16(pretrained=False, img_size=224, drop_path_rate=0.1, num_classes=1000, drop_rate=0.0, **kwargs):
    """DeiT-Tiny/16 (embed 192, depth 12, 3 heads) built on TimmViT.

    NOTE(review): kwargs are filtered by _clean_kwargs but never forwarded to
    the constructor — confirm extra kwargs are meant to be dropped entirely.
    """
    kwargs = _clean_kwargs(kwargs)
    model = TimmViT(
        patch_size=16,
        embed_dim=192,
        depth=12,
        num_heads=3,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        img_size=img_size,
        drop_path_rate=drop_path_rate,
        num_classes=num_classes,
        drop_rate=drop_rate,
    )
    model.default_cfg = _cfg()
    if pretrained:
        # Official DeiT weights released by Facebook AI.
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth",
            map_location="cpu",
            check_hash=True,
        )
        model.load_state_dict(checkpoint["model"])
    return model
@register_model
def deit_small_patch16(pretrained=False, img_size=224, drop_path_rate=0.1, num_classes=1000, drop_rate=0.0, **kwargs):
    """DeiT-Small/16 (embed 384, depth 12, 6 heads) built on TimmViT.

    NOTE(review): cleaned kwargs are never forwarded to the constructor — confirm intended.
    """
    kwargs = _clean_kwargs(kwargs)
    model = TimmViT(
        patch_size=16,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        img_size=img_size,
        drop_path_rate=drop_path_rate,
        num_classes=num_classes,
        drop_rate=drop_rate,
    )
    model.default_cfg = _cfg()
    if pretrained:
        # Official DeiT weights released by Facebook AI.
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth",
            map_location="cpu",
            check_hash=True,
        )
        model.load_state_dict(checkpoint["model"])
    return model
@register_model
def deit_base_patch16(pretrained=False, img_size=224, drop_path_rate=0.1, num_classes=1000, drop_rate=0.0, **kwargs):
    """DeiT-Base/16 (embed 768, depth 12, 12 heads) built on TimmViT.

    NOTE(review): cleaned kwargs are never forwarded to the constructor — confirm intended.
    """
    kwargs = _clean_kwargs(kwargs)
    model = TimmViT(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        img_size=img_size,
        drop_path_rate=drop_path_rate,
        num_classes=num_classes,
        drop_rate=drop_rate,
    )
    model.default_cfg = _cfg()
    if pretrained:
        # Official DeiT weights released by Facebook AI.
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
            map_location="cpu",
            check_hash=True,
        )
        model.load_state_dict(checkpoint["model"])
    return model
@register_model
def deit_tiny_distilled_patch16(
    pretrained=False, img_size=224, drop_path_rate=0.1, num_classes=1000, drop_rate=0.0, **kwargs
):
    """Distilled DeiT-Tiny/16 (adds a distillation token and head).

    NOTE(review): cleaned kwargs are never forwarded to the constructor — confirm intended.
    """
    kwargs = _clean_kwargs(kwargs)
    model = DistilledVisionTransformer(
        patch_size=16,
        embed_dim=192,
        depth=12,
        num_heads=3,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        img_size=img_size,
        drop_path_rate=drop_path_rate,
        num_classes=num_classes,
        drop_rate=drop_rate,
    )
    model.default_cfg = _cfg()
    if pretrained:
        # Official DeiT weights released by Facebook AI.
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth",
            map_location="cpu",
            check_hash=True,
        )
        model.load_state_dict(checkpoint["model"])
    return model
@register_model
def deit_small_distilled_patch16(
    pretrained=False, img_size=224, drop_path_rate=0.1, num_classes=1000, drop_rate=0.0, **kwargs
):
    """Distilled DeiT-Small/16 (adds a distillation token and head).

    NOTE(review): cleaned kwargs are never forwarded to the constructor — confirm intended.
    """
    kwargs = _clean_kwargs(kwargs)
    model = DistilledVisionTransformer(
        patch_size=16,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        img_size=img_size,
        drop_path_rate=drop_path_rate,
        num_classes=num_classes,
        drop_rate=drop_rate,
    )
    model.default_cfg = _cfg()
    if pretrained:
        # Official DeiT weights released by Facebook AI.
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth",
            map_location="cpu",
            check_hash=True,
        )
        model.load_state_dict(checkpoint["model"])
    return model
@register_model
def deit_base_distilled_patch16(
    pretrained=False, img_size=224, drop_path_rate=0.1, num_classes=1000, drop_rate=0.0, **kwargs
):
    """Distilled DeiT-Base/16 (adds a distillation token and head).

    NOTE(review): cleaned kwargs are never forwarded to the constructor — confirm intended.
    """
    kwargs = _clean_kwargs(kwargs)
    model = DistilledVisionTransformer(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        img_size=img_size,
        drop_path_rate=drop_path_rate,
        num_classes=num_classes,
        drop_rate=drop_rate,
    )
    model.default_cfg = _cfg()
    if pretrained:
        # Official DeiT weights released by Facebook AI.
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth",
            map_location="cpu",
            check_hash=True,
        )
        model.load_state_dict(checkpoint["model"])
    return model

View File

@@ -0,0 +1,850 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Taken from https://github.com/facebookresearch/deit with slight modifications
from loguru import logger
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
from timm.models.vision_transformer import Mlp, PatchEmbed, _cfg
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from timm.models.registry import register_model
from resizing_interface import ResizingInterface
class Attention(nn.Module):
    """Multi-head self-attention, mirroring the timm vision-transformer layer."""

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        """Build the fused QKV projection and the output projection.

        Args:
            dim: Token embedding dimension.
            num_heads: Number of attention heads.
            qkv_bias: Whether the QKV projection has a bias term.
            qk_scale: Optional override for the 1/sqrt(head_dim) scaling.
            attn_drop: Dropout on the attention weights.
            proj_drop: Dropout on the output projection.
        """
        super().__init__()
        self.num_heads = num_heads
        self.scale = qk_scale or (dim // num_heads) ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Attend over the token sequence; input and output are (B, N, C)."""
        batch, tokens, channels = x.shape
        per_head = channels // self.num_heads
        qkv = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, per_head)
        # (3, B, heads, N, per_head) -> separate q, k, v without copying.
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)
        scores = (q * self.scale) @ k.transpose(-2, -1)
        weights = self.attn_drop(scores.softmax(dim=-1))
        out = (weights @ v).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(out))
class Block(nn.Module):
    """Standard pre-norm transformer block: attention + MLP, each a residual branch.

    Mirrors the timm vision-transformer block. ``init_values`` is accepted only
    for signature compatibility with the LayerScale variants and is unused here.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        Attention_block=Attention,
        Mlp_block=Mlp,
        init_values=1e-4,
    ):
        super().__init__()
        # Attention branch.
        self.norm1 = norm_layer(dim)
        self.attn = Attention_block(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop
        )
        # Stochastic depth, shared by both residual branches.
        self.drop_path = nn.Identity() if drop_path <= 0.0 else DropPath(drop_path)
        # MLP branch.
        self.norm2 = norm_layer(dim)
        hidden = int(dim * mlp_ratio)
        self.mlp = Mlp_block(in_features=dim, hidden_features=hidden, act_layer=act_layer, drop=drop)

    def forward(self, x):
        """Apply the attention and MLP sub-blocks as drop-path residuals."""
        x = x + self.drop_path(self.attn(self.norm1(x)))
        return x + self.drop_path(self.mlp(self.norm2(x)))
class Layer_scale_init_Block(nn.Module):
    """Pre-norm transformer block with LayerScale (https://arxiv.org/abs/2103.17239).

    Like ``Block``, but each residual branch is scaled by a learnable per-channel
    vector (gamma) initialized to ``init_values``.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        Attention_block=Attention,
        Mlp_block=Mlp,
        init_values=1e-4,
    ):
        super().__init__()
        # Attention branch.
        self.norm1 = norm_layer(dim)
        self.attn = Attention_block(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop
        )
        # Stochastic depth, shared by both residual branches.
        self.drop_path = nn.Identity() if drop_path <= 0.0 else DropPath(drop_path)
        # MLP branch.
        self.norm2 = norm_layer(dim)
        hidden = int(dim * mlp_ratio)
        self.mlp = Mlp_block(in_features=dim, hidden_features=hidden, act_layer=act_layer, drop=drop)
        # LayerScale gains, one per branch.
        self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
        self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)

    def forward(self, x):
        """Residual update: x + drop_path(gamma * branch(norm(x))) for both branches."""
        attn_out = self.gamma_1 * self.attn(self.norm1(x))
        x = x + self.drop_path(attn_out)
        mlp_out = self.gamma_2 * self.mlp(self.norm2(x))
        return x + self.drop_path(mlp_out)
class Layer_scale_init_Block_paralx2(nn.Module):
    """Transformer block with TWO parallel attention and MLP branches plus LayerScale.

    Used by the "x2" DeiT-III variants: each residual step sums two independently
    normalized and gamma-scaled sub-branches.
    taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
    with slight modifications
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        Attention_block=Attention,
        Mlp_block=Mlp,
        init_values=1e-4,
    ):
        super().__init__()
        # First and second attention branches, each with its own pre-norm.
        self.norm1 = norm_layer(dim)
        self.norm11 = norm_layer(dim)
        self.attn = Attention_block(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.attn1 = Attention_block(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        # First and second MLP branches, each with its own pre-norm.
        self.norm2 = norm_layer(dim)
        self.norm21 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp_block(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )
        self.mlp1 = Mlp_block(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )
        # LayerScale gains, one per sub-branch.
        self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
        self.gamma_1_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
        self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
        self.gamma_2_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)

    def forward(self, x):
        """Sum both attention branches into the residual, then both MLP branches."""
        x = (
            x
            + self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
            + self.drop_path(self.gamma_1_1 * self.attn1(self.norm11(x)))
        )
        x = (
            x
            + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
            + self.drop_path(self.gamma_2_1 * self.mlp1(self.norm21(x)))
        )
        return x
class Block_paralx2(nn.Module):
    """Transformer block with TWO parallel attention and MLP branches, no LayerScale.

    Same structure as ``Layer_scale_init_Block_paralx2`` minus the gamma gains;
    ``init_values`` is accepted for signature parity but unused.
    taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
    with slight modifications
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        Attention_block=Attention,
        Mlp_block=Mlp,
        init_values=1e-4,
    ):
        super().__init__()
        # First and second attention branches, each with its own pre-norm.
        self.norm1 = norm_layer(dim)
        self.norm11 = norm_layer(dim)
        self.attn = Attention_block(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.attn1 = Attention_block(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        # First and second MLP branches, each with its own pre-norm.
        self.norm2 = norm_layer(dim)
        self.norm21 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp_block(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )
        self.mlp1 = Mlp_block(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, x):
        """Sum both attention branches into the residual, then both MLP branches."""
        x = x + self.drop_path(self.attn(self.norm1(x))) + self.drop_path(self.attn1(self.norm11(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x))) + self.drop_path(self.mlp1(self.norm21(x)))
        return x
class hMLP_stem(nn.Module):
    """hMLP_stem: https://arxiv.org/pdf/2203.09795.pdf

    taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
    with slight modifications

    Hierarchical MLP stem: three conv stages (4x4/s4, 2x2/s2, 2x2/s2) for an
    overall 16x16 patchification with no pixel overlap between patches.
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        norm_layer=nn.SyncBatchNorm,
    ):
        """Create the stem.

        Args:
            img_size: Input resolution (int or (H, W) tuple).
            patch_size: Patch size; only used to derive ``num_patches``.
            in_chans: Number of input image channels.
            embed_dim: Output token embedding dimension.
            norm_layer: Normalization applied after each conv stage.
        """
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches
        self.proj = torch.nn.Sequential(
            *[
                nn.Conv2d(in_chans, embed_dim // 4, kernel_size=4, stride=4),
                norm_layer(embed_dim // 4),
                nn.GELU(),
                nn.Conv2d(embed_dim // 4, embed_dim // 4, kernel_size=2, stride=2),
                norm_layer(embed_dim // 4),
                nn.GELU(),
                nn.Conv2d(embed_dim // 4, embed_dim, kernel_size=2, stride=2),
                norm_layer(embed_dim),
            ]
        )

    def forward(self, x):
        """Project an image batch (B, C, H, W) to tokens (B, num_patches, embed_dim)."""
        # Fix: removed the unused `B, C, H, W = x.shape` unpacking.
        return self.proj(x).flatten(2).transpose(1, 2)
class vit_models(nn.Module, ResizingInterface):
    """Vision Transformer with LayerScale (https://arxiv.org/abs/2103.17239) support
    taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
    with slight modifications
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        num_classes=1000,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        norm_layer=nn.LayerNorm,
        global_pool=None,
        block_layers=Block,
        Patch_layer=PatchEmbed,
        act_layer=nn.GELU,
        Attention_block=Attention,
        Mlp_block=Mlp,
        dpr_constant=True,
        init_scale=1e-4,
        mlp_ratio_clstk=4.0,
        **kwargs,
    ):
        """Build the patch embedding, transformer blocks, and classifier head.

        Args:
            img_size: Input image resolution.
            patch_size: Patch size for the patch embedding.
            in_chans: Number of input channels.
            num_classes: Classifier output size (0 -> Identity head).
            embed_dim: Token embedding dimension.
            depth: Number of transformer blocks.
            num_heads: Attention heads per block.
            mlp_ratio: MLP hidden size as a multiple of embed_dim.
            qkv_bias: Whether attention QKV projections have biases.
            qk_scale: Optional attention scale override.
            drop_rate: Dropout applied to the pooled feature before the head.
            attn_drop_rate: Dropout on attention weights.
            drop_path_rate: Stochastic-depth rate (constant across blocks here).
            norm_layer: Normalization layer factory.
            block_layers: Block class (e.g. Block, Layer_scale_init_Block).
            Patch_layer: Patch embedding class.
            act_layer: Activation used in the MLPs.
            Attention_block: Attention class passed to each block.
            Mlp_block: MLP class passed to each block.
            init_scale: LayerScale init value forwarded as ``init_values``.

        NOTE(review): ``global_pool``, ``dpr_constant``, ``mlp_ratio_clstk`` and
        ``**kwargs`` are accepted but not used anywhere below — confirm intended.
        """
        super().__init__()
        self.dropout_rate = drop_rate
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim
        self.img_size = img_size
        self.patch_embed = Patch_layer(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
        num_patches = self.patch_embed.num_patches
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.embed_layer = Patch_layer
        self.pre_norm = False
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # pos_embed covers patch tokens only (no [CLS] slot); see forward_features,
        # where it is added before the [CLS] token is concatenated.
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
        self.no_embed_class = True
        # Constant (non-decayed) stochastic-depth schedule across all blocks.
        dpr = [drop_path_rate for i in range(depth)]
        self.blocks = nn.ModuleList(
            [
                block_layers(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=0.0,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[i],
                    norm_layer=norm_layer,
                    act_layer=act_layer,
                    Attention_block=Attention_block,
                    Mlp_block=Mlp_block,
                    init_values=init_scale,
                )
                for i in range(depth)
            ]
        )
        self.norm = norm_layer(embed_dim)
        self.feature_info = [dict(num_chs=embed_dim, reduction=0, module="head")]
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        trunc_normal_(self.pos_embed, std=0.02)
        trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init_weights)

    def set_num_classes(self, n_classes):
        """Resize the classifier (via ResizingInterface) and re-init the new head."""
        super().set_num_classes(n_classes)
        self._init_weights(self.head)

    def _init_weights(self, m):
        """Truncated-normal init for Linear layers, unit/zero init for LayerNorm."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Parameter names excluded from weight decay."""
        return {"pos_embed", "cls_token"}

    def get_classifier(self):
        """Return the classification head module."""
        return self.head

    def get_num_layers(self):
        """Return the number of transformer blocks."""
        return len(self.blocks)

    def reset_classifier(self, num_classes, global_pool=""):
        """Replace the head with a fresh Linear (or Identity for num_classes == 0)."""
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x, test=False):
        """Return the final [CLS] embedding; with test=True, log any NaNs per stage."""
        B = x.shape[0]
        x = self.patch_embed(x)
        if test and x.isnan().any().item():
            logger.error("patch embedded input has nan value")
        cls_tokens = self.cls_token.expand(B, -1, -1)
        # Positional embedding is added to patch tokens before [CLS] is prepended.
        x = x + self.pos_embed
        if test and x.isnan().any().item():
            logger.error("position embedded input has a nan value")
        x = torch.cat((cls_tokens, x), dim=1)
        if test and x.isnan().any().item():
            logger.error("input with [CLS] has a nan value")
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if test and x.isnan().any().item():
                logger.error(f"output of block {i} has a nan value")
        x = self.norm(x)
        return x[:, 0]

    def forward(self, x, test=False):
        """Classify x: features -> optional dropout -> linear head."""
        x = self.forward_features(x, test=test)
        if self.dropout_rate:
            x = F.dropout(x, p=float(self.dropout_rate), training=self.training)
        x = self.head(x)
        return x
# DeiT III: Revenge of the ViT (https://arxiv.org/abs/2204.07118)
@register_model
def deit_tiny_patch16_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """DeiT III Tiny/16 with LayerScale.

    NOTE(review): `pretrained`/`pretrained_21k` are accepted for registry-signature
    parity but no checkpoint is loaded here.
    """
    model = vit_models(
        img_size=img_size,
        patch_size=16,
        embed_dim=192,
        depth=12,
        num_heads=3,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
        **kwargs,
    )
    return model
@register_model
def deit_small_patch16_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """DeiT III Small/16 with LayerScale; optionally loads official DeiT-3 weights."""
    if "pretrained_cfg" in kwargs:
        # timm's registry may inject pretrained_cfg; vit_models does not accept it.
        kwargs.pop("pretrained_cfg")
    model = vit_models(
        img_size=img_size,
        patch_size=16,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
        **kwargs,
    )
    model.default_cfg = _cfg()
    if pretrained:
        # Checkpoint URL encodes resolution and 21k/1k pretraining corpus.
        name = "https://dl.fbaipublicfiles.com/deit/deit_3_small_" + str(img_size) + "_"
        if pretrained_21k:
            name += "21k.pth"
        else:
            name += "1k.pth"
        checkpoint = torch.hub.load_state_dict_from_url(url=name, map_location="cpu", check_hash=True)
        model.load_state_dict(checkpoint["model"])
    return model
@register_model
def deit_medium_patch16_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """DeiT III Medium/16 with LayerScale; optionally loads official DeiT-3 weights.

    Fix: forward ``img_size`` to ``vit_models`` — previously it was only used to
    build the checkpoint URL, so (unlike every sibling factory) the model itself
    was always constructed at the 224 default regardless of the argument.
    """
    model = vit_models(
        img_size=img_size,
        patch_size=16,
        embed_dim=512,
        depth=12,
        num_heads=8,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
        **kwargs,
    )
    model.default_cfg = _cfg()
    if pretrained:
        # Checkpoint URL encodes resolution and 21k/1k pretraining corpus.
        name = "https://dl.fbaipublicfiles.com/deit/deit_3_medium_" + str(img_size) + "_"
        if pretrained_21k:
            name += "21k.pth"
        else:
            name += "1k.pth"
        checkpoint = torch.hub.load_state_dict_from_url(url=name, map_location="cpu", check_hash=True)
        model.load_state_dict(checkpoint["model"])
    return model
@register_model
def deit_base_patch16_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """DeiT III Base/16 with LayerScale; optionally loads official DeiT-3 weights.

    NOTE(review): unlike tiny/small, no ``default_cfg`` is set here — confirm intended.
    """
    model = vit_models(
        img_size=img_size,
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
        **kwargs,
    )
    if pretrained:
        # Checkpoint URL encodes resolution and 21k/1k pretraining corpus.
        name = "https://dl.fbaipublicfiles.com/deit/deit_3_base_" + str(img_size) + "_"
        if pretrained_21k:
            name += "21k.pth"
        else:
            name += "1k.pth"
        checkpoint = torch.hub.load_state_dict_from_url(url=name, map_location="cpu", check_hash=True)
        model.load_state_dict(checkpoint["model"])
    return model
@register_model
def deit_large_patch16_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """DeiT III Large/16 with LayerScale; optionally loads official DeiT-3 weights."""
    model = vit_models(
        img_size=img_size,
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
        **kwargs,
    )
    if pretrained:
        # Checkpoint URL encodes resolution and 21k/1k pretraining corpus.
        name = "https://dl.fbaipublicfiles.com/deit/deit_3_large_" + str(img_size) + "_"
        if pretrained_21k:
            name += "21k.pth"
        else:
            name += "1k.pth"
        checkpoint = torch.hub.load_state_dict_from_url(url=name, map_location="cpu", check_hash=True)
        model.load_state_dict(checkpoint["model"])
    return model
@register_model
def deit_huge_patch14_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """DeiT III Huge/14 with LayerScale; optionally loads official DeiT-3 weights."""
    model = vit_models(
        img_size=img_size,
        patch_size=14,
        embed_dim=1280,
        depth=32,
        num_heads=16,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
        **kwargs,
    )
    if pretrained:
        # Huge checkpoints use the "_v1" suffix.
        name = "https://dl.fbaipublicfiles.com/deit/deit_3_huge_" + str(img_size) + "_"
        if pretrained_21k:
            name += "21k_v1.pth"
        else:
            name += "1k_v1.pth"
        checkpoint = torch.hub.load_state_dict_from_url(url=name, map_location="cpu", check_hash=True)
        model.load_state_dict(checkpoint["model"])
    return model
@register_model
def deit_huge_patch14_52_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """DeiT III Huge/14 with 52 layers and LayerScale (no checkpoint loading)."""
    model = vit_models(
        img_size=img_size,
        patch_size=14,
        embed_dim=1280,
        depth=52,
        num_heads=16,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
        **kwargs,
    )
    return model
@register_model
def deit_huge_patch14_26x2_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """DeiT III Huge/14 with 26 dual-parallel layers and LayerScale (no checkpoint loading)."""
    model = vit_models(
        img_size=img_size,
        patch_size=14,
        embed_dim=1280,
        depth=26,
        num_heads=16,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block_paralx2,
        **kwargs,
    )
    return model
# @register_model
# def deit_Giant_48x2_patch14_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
# model = vit_models(
# img_size=img_size, patch_size=14, embed_dim=1664, depth=48, num_heads=16, mlp_ratio=4,
# norm_layer=partial(nn.LayerNorm, eps=1e-6), block_layers=Block_paral_LS, **kwargs)
#
# return model
# @register_model
# def deit_giant_40x2_patch14_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
# model = vit_models(
# img_size=img_size, patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=4,
# norm_layer=partial(nn.LayerNorm, eps=1e-6), block_layers=Block_paral_LS, **kwargs)
# return model
@register_model
def deit_Giant_48_patch14_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """DeiT III Giant/14 (embed 1664, 48 layers) with LayerScale (no checkpoint loading)."""
    model = vit_models(
        img_size=img_size,
        patch_size=14,
        embed_dim=1664,
        depth=48,
        num_heads=16,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
        **kwargs,
    )
    return model
@register_model
def deit_giant_40_patch14_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """DeiT III giant/14 (embed 1408, 40 layers) with LayerScale (no checkpoint loading)."""
    model = vit_models(
        img_size=img_size,
        patch_size=14,
        embed_dim=1408,
        depth=40,
        num_heads=16,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
        **kwargs,
    )
    # model.default_cfg = _cfg()
    return model
# Models from Three things everyone should know about Vision Transformers (https://arxiv.org/pdf/2203.09795.pdf)
@register_model
def deit_small_patch16_36_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """DeiT-S/16 with 36 layers and LayerScale ("Three things" paper; no checkpoint loading)."""
    model = vit_models(
        img_size=img_size,
        patch_size=16,
        embed_dim=384,
        depth=36,
        num_heads=6,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
        **kwargs,
    )
    return model
@register_model
def deit_small_patch16_36(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """DeiT-S/16 with 36 standard (non-LayerScale) layers (no checkpoint loading)."""
    model = vit_models(
        img_size=img_size,
        patch_size=16,
        embed_dim=384,
        depth=36,
        num_heads=6,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
    return model
@register_model
def deit_small_patch16_18x2_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """DeiT-S/16 with 18 dual-parallel LayerScale layers (no checkpoint loading)."""
    model = vit_models(
        img_size=img_size,
        patch_size=16,
        embed_dim=384,
        depth=18,
        num_heads=6,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block_paralx2,
        **kwargs,
    )
    return model
@register_model
def deit_small_patch16_18x2(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """DeiT-S/16 with 18 dual-parallel layers, no LayerScale (no checkpoint loading)."""
    model = vit_models(
        img_size=img_size,
        patch_size=16,
        embed_dim=384,
        depth=18,
        num_heads=6,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Block_paralx2,
        **kwargs,
    )
    return model
@register_model
def deit_base_patch16_18x2_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """DeiT-Base/16, 18 parallel(x2) layers, LayerScale-initialized blocks."""
    config = dict(
        img_size=img_size,
        patch_size=16,
        embed_dim=768,
        depth=18,
        num_heads=12,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block_paralx2,
    )
    return vit_models(**config, **kwargs)
@register_model
def deit_base_patch16_18x2(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """DeiT-Base/16, 18 parallel(x2) layers, without LayerScale."""
    config = dict(
        img_size=img_size,
        patch_size=16,
        embed_dim=768,
        depth=18,
        num_heads=12,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Block_paralx2,
    )
    return vit_models(**config, **kwargs)
@register_model
def deit_base_patch16_36x1_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """DeiT-Base/16 with 36 sequential layers and LayerScale-initialized blocks."""
    config = dict(
        img_size=img_size,
        patch_size=16,
        embed_dim=768,
        depth=36,
        num_heads=12,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
    )
    return vit_models(**config, **kwargs)
@register_model
def deit_base_patch16_36x1(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    """DeiT-Base/16 with 36 sequential layers (default block type)."""
    config = dict(
        img_size=img_size,
        patch_size=16,
        embed_dim=768,
        depth=36,
        num_heads=12,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    return vit_models(**config, **kwargs)

View File

@@ -0,0 +1,111 @@
from timm.models import register_model
from timm.models.resnet import BasicBlock, Bottleneck
from timm.models.resnet import ResNet as ResNetTimm
from torch import nn
from resizing_interface import ResizingInterface
class ResNet(ResNetTimm, ResizingInterface):
    """The popular ResNet model with a ResizingInterface."""

    # Keyword arguments understood by timm's ResNet constructor; anything else
    # passed through **kwargs is silently dropped in __init__ so that generic
    # model-factory kwargs do not crash construction.
    _ADMISSIBLE_KWARGS = frozenset(
        {
            "block",
            "layers",
            "num_classes",
            "in_chans",
            "output_stride",
            "cardinality",
            "base_width",
            "stem_width",
            "stem_type",
            "replace_stem_pool",
            "block_reduce_first",
            "down_kernel_size",
            "avg_down",
            "act_layer",
            "norm_layer",
            "aa_layer",
            "drop_rate",
            "drop_path_rate",
            "drop_block_rate",
            "zero_init_last",
            "block_args",
        }
    )

    def __init__(self, *args, global_pool="avg", **kwargs):
        """Create ResNet model with resizing capabilities.

        Args:
            *args: Arguments for the ResNet model.
            global_pool (str, optional): Global pooling type used for the classifier head. Defaults to "avg".
            **kwargs: Keyword arguments for the ResNet model (from Timm).

        Keyword Args:
            block, layers, num_classes, in_chans, output_stride, cardinality, base_width, stem_width, stem_type, replace_stem_pool, block_reduce_first, down_kernel_size, avg_down, act_layer, norm_layer, aa_layer, drop_rate, drop_path_rate, drop_block_rate, zero_init_last, block_args
        """
        # Drop anything the timm constructor would reject (set lookup, single pass).
        kwargs = {key: value for key, value in kwargs.items() if key in self._ADMISSIBLE_KWARGS}
        super().__init__(*args, global_pool=global_pool, **kwargs)
        # Remember the pooling mode so the classifier can be rebuilt later.
        self.global_pool_str = global_pool

    def set_image_res(self, res):
        """No-op: CNNs with global pooling accept any input resolution."""
        return

    def set_num_classes(self, n_classes):
        """Replace the classification head for a new number of classes.

        Args:
            n_classes (int): new number of classes; values <= 0 remove the head.
        """
        if self.num_classes == n_classes:
            return
        if n_classes > 0:
            # reset_classifier also updates self.num_classes internally.
            self.reset_classifier(num_classes=n_classes, global_pool=self.global_pool_str)
        else:
            self.fc = nn.Identity()
            # Bug fix: keep num_classes in sync so the early-return guard above
            # behaves correctly on subsequent calls.
            self.num_classes = n_classes
@register_model
def resnet18(pretrained=False, **kwargs):
    """Construct a ResNet-18 model."""
    return ResNet(block=BasicBlock, layers=[2, 2, 2, 2], **kwargs)
@register_model
def resnet34(pretrained=False, **kwargs):
    """Construct a ResNet-34 model."""
    return ResNet(block=BasicBlock, layers=[3, 4, 6, 3], **kwargs)
@register_model
def resnet26(pretrained=False, **kwargs):
    """Construct a ResNet-26 model."""
    return ResNet(block=Bottleneck, layers=[2, 2, 2, 2], **kwargs)
@register_model
def resnet50(pretrained=False, **kwargs):
    """Construct a ResNet-50 model."""
    return ResNet(block=Bottleneck, layers=[3, 4, 6, 3], **kwargs)
@register_model
def resnet101(pretrained=False, **kwargs):
    """Construct a ResNet-101 model."""
    return ResNet(block=Bottleneck, layers=[3, 4, 23, 3], **kwargs)
@register_model
def wide_resnet50_2(pretrained=False, **kwargs):
    """Construct a Wide ResNet-50-2 model.

    The model is the same as ResNet except for the bottleneck number of channels
    which is twice larger in every block. The number of channels in outer 1x1
    convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
    channels, and in Wide ResNet-50-2 has 2048-1024-2048.
    """
    return ResNet(block=Bottleneck, layers=[3, 4, 6, 3], base_width=128, **kwargs)

View File

@@ -0,0 +1,893 @@
# Taken from https://github.com/microsoft/Swin-Transformer with slight modifications
# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu
# --------------------------------------------------------
from copy import copy
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from loguru import logger
from timm.models import register_model
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from resizing_interface import ResizingInterface
# Optionally load the fused CUDA window-shift/partition kernels from the Swin
# repository; if they are not built, fall back to the pure-PyTorch path by
# leaving both symbols as None (checked via `fused_window_process` at runtime).
try:
    import os
    import sys

    # The kernels live one directory up from this file's working directory.
    kernel_path = os.path.abspath(os.path.join(".."))
    sys.path.append(kernel_path)
    from kernels.window_process.window_process import (
        WindowProcess,
        WindowProcessReverse,
    )
except ImportError:
    WindowProcess = None
    WindowProcessReverse = None
    logger.warning("Fused window process have not been installed. Please refer to get_started.md for installation.")
class Mlp(nn.Module):
    """Two-layer feed-forward block: linear -> activation -> dropout -> linear -> dropout.

    Args:
        in_features (int): input (and default hidden/output) dimension.
        hidden_features (int, optional): hidden dimension; defaults to ``in_features``.
        out_features (int, optional): output dimension; defaults to ``in_features``.
        act_layer (nn.Module): activation class. Default: nn.GELU.
        drop (float): dropout probability applied after each linear layer.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Apply the MLP token-wise; preserves all leading dimensions."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
def window_partition(x, window_size):
    """Split a feature map into non-overlapping square windows.

    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    nh, nw = H // window_size, W // window_size
    tiles = x.view(B, nh, window_size, nw, window_size, C)
    # Bring the two window-grid axes together, then flatten them into the batch.
    return tiles.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size, window_size, C)
def window_reverse(windows, window_size, H, W):
    """Reassemble windows back into a full feature map (inverse of window_partition).

    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    windows_per_image = (H // window_size) * (W // window_size)
    B = windows.shape[0] // windows_per_image
    grid = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    # Interleave the window-grid and intra-window axes to restore row-major H x W.
    return grid.permute(0, 1, 3, 2, 4, 5).reshape(B, H, W, -1)
class WindowAttention(nn.Module):
    r"""Window based multi-head self attention (W-MSA) module with relative position bias.

    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(
        self,
        dim,
        window_size,
        num_heads,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5
        # define a parameter table of relative position bias
        # One learnable bias per head for every possible (dh, dw) offset between
        # two tokens inside a window.
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
        )  # 2*Wh-1 * 2*Ww-1, nH
        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        # Encode the (dh, dw) pair as a single flat index into the bias table.
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        trunc_normal_(self.relative_position_bias_table, std=0.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = (
            qkv[0],
            qkv[1],
            qkv[2],
        )  # make torchscript happy (cannot use tensor as tuple)
        q = q * self.scale
        attn = q @ k.transpose(-2, -1)
        # Look up the learned bias for every token pair and add it to the logits.
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1],
            self.window_size[0] * self.window_size[1],
            -1,
        )  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)
        if mask is not None:
            # Apply the per-window shift mask (broadcast over batch and heads).
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}"

    def flops(self, N):
        # calculate flops for 1 window with token length of N
        flops = 0
        # qkv = self.qkv(x)
        flops += N * self.dim * 3 * self.dim
        # attn = (q @ k.transpose(-2, -1))
        flops += self.num_heads * N * (self.dim // self.num_heads) * N
        # x = (attn @ v)
        flops += self.num_heads * N * N * (self.dim // self.num_heads)
        # x = self.proj(x)
        flops += N * self.dim * self.dim
        return flops
class SwinTransformerBlock(nn.Module):
    r"""Swin Transformer Block.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resulotion.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        fused_window_process (bool, optional): If True, use one kernel to fused window shift & window partition for acceleration, similar for the reversed part. Default: False
    """

    def __init__(
        self,
        dim,
        input_resolution,
        num_heads,
        window_size=7,
        shift_size=0,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        fused_window_process=False,
    ):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim,
            window_size=to_2tuple(self.window_size),
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )
        if self.shift_size > 0:
            # calculate attention mask for SW-MSA
            # Tokens that wrap around after the cyclic shift must not attend to
            # each other: label each region, then mask pairs with differing labels.
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
            h_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None),
            )
            w_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None),
            )
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1
            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            # -100 is effectively -inf after softmax for same-magnitude logits.
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None
        self.register_buffer("attn_mask", attn_mask)
        self.fused_window_process = fused_window_process

    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)
        # cyclic shift
        if self.shift_size > 0:
            if not self.fused_window_process:
                shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
                # partition windows
                x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
            else:
                # Fused CUDA kernel combines the roll and the partition.
                x_windows = WindowProcess.apply(x, B, H, W, C, -self.shift_size, self.window_size)
        else:
            shifted_x = x
            # partition windows
            x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C
        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C
        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        # reverse cyclic shift
        if self.shift_size > 0:
            if not self.fused_window_process:
                shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C
                x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
            else:
                x = WindowProcessReverse.apply(attn_windows, B, H, W, C, self.shift_size, self.window_size)
        else:
            shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C
            x = shifted_x
        x = x.view(B, H * W, C)
        x = shortcut + self.drop_path(x)
        # FFN
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

    def extra_repr(self) -> str:
        return (
            f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, "
            f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
        )

    def flops(self):
        flops = 0
        H, W = self.input_resolution
        # norm1
        flops += self.dim * H * W
        # W-MSA/SW-MSA
        nW = H * W / self.window_size / self.window_size
        flops += nW * self.attn.flops(self.window_size * self.window_size)
        # mlp
        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
        # norm2
        flops += self.dim * H * W
        return flops
class PatchMerging(nn.Module):
    r"""Patch Merging Layer.

    Downsamples a (H, W) token grid by 2x in each spatial direction by
    concatenating each 2x2 neighborhood channel-wise (4C), normalizing, and
    linearly projecting down to 2C.

    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """Merge patches; x has shape (B, H*W, C), returns (B, H/2*W/2, 2C)."""
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
        grid = x.view(B, H, W, C)
        # Gather the four interleaved sub-grids of each 2x2 neighborhood
        # in the same order as the reference implementation: (0,0),(1,0),(0,1),(1,1).
        quadrants = [grid[:, dh::2, dw::2, :] for dh, dw in ((0, 0), (1, 0), (0, 1), (1, 1))]
        merged = torch.cat(quadrants, -1).view(B, -1, 4 * C)  # B H/2*W/2 4*C
        return self.reduction(self.norm(merged))

    def extra_repr(self) -> str:
        return f"input_resolution={self.input_resolution}, dim={self.dim}"

    def flops(self):
        H, W = self.input_resolution
        norm_cost = H * W * self.dim
        reduce_cost = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
        return norm_cost + reduce_cost
class BasicLayer(nn.Module):
    """A basic Swin Transformer layer for one stage.

    Stacks ``depth`` SwinTransformerBlocks (alternating regular and shifted
    windows) and optionally appends a downsampling layer.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
        fused_window_process (bool, optional): If True, use one kernel to fused window shift & window partition for acceleration, similar for the reversed part. Default: False
    """

    def __init__(
        self,
        dim,
        input_resolution,
        depth,
        num_heads,
        window_size,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        norm_layer=nn.LayerNorm,
        downsample=None,
        use_checkpoint=False,
        fused_window_process=False,
    ):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint
        # Odd-indexed blocks use shifted windows (SW-MSA), even-indexed regular ones.
        blocks = []
        for idx in range(depth):
            blocks.append(
                SwinTransformerBlock(
                    dim=dim,
                    input_resolution=input_resolution,
                    num_heads=num_heads,
                    window_size=window_size,
                    shift_size=window_size // 2 if idx % 2 else 0,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop,
                    attn_drop=attn_drop,
                    drop_path=(drop_path[idx] if isinstance(drop_path, list) else drop_path),
                    norm_layer=norm_layer,
                    fused_window_process=fused_window_process,
                )
            )
        self.blocks = nn.ModuleList(blocks)
        # patch merging layer
        if downsample is None:
            self.downsample = None
        else:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)

    def forward(self, x):
        """Run all blocks (optionally gradient-checkpointed), then downsample."""
        for blk in self.blocks:
            x = checkpoint.checkpoint(blk, x) if self.use_checkpoint else blk(x)
        return x if self.downsample is None else self.downsample(x)

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"

    def flops(self):
        total = sum(blk.flops() for blk in self.blocks)
        if self.downsample is not None:
            total += self.downsample.flops()
        return total
class PatchEmbed(nn.Module):
    r"""Image to Patch Embedding

    Projects an image into a sequence of patch tokens with a strided convolution.

    Args:
        img_size (int): Image size. Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        grid = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = grid
        self.num_patches = grid[0] * grid[1]
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        # Non-overlapping patches: kernel size == stride == patch size.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer is not None else None

    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert (
            H == self.img_size[0] and W == self.img_size[1]
        ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        tokens = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
        return tokens if self.norm is None else self.norm(tokens)

    def flops(self):
        Ho, Wo = self.patches_resolution
        total = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        if self.norm is not None:
            total += Ho * Wo * self.embed_dim
        return total
class SwinTransformer(nn.Module, ResizingInterface):
    r"""Swin Transformer

    A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
    https://arxiv.org/pdf/2103.14030

    Args:
        img_size (int | tuple(int)): Input image size. Default 224
        patch_size (int | tuple(int)): Patch size. Default: 4
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 1000
        embed_dim (int): Patch embedding dimension. Default: 96
        depths (tuple(int)): Depth of each Swin Transformer layer.
        num_heads (tuple(int)): Number of attention heads in different layers.
        window_size (int): Window size. Default: 7
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
        drop_rate (float): Dropout rate. Default: 0
        attn_drop_rate (float): Attention dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
        patch_norm (bool): If True, add normalization after patch embedding. Default: True
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
        fused_window_process (bool, optional): If True, use one kernel to fused window shift & window partition for acceleration, similar for the reversed part. Default: False
    """

    def __init__(
        self,
        img_size=224,
        patch_size=4,
        in_chans=3,
        num_classes=1000,
        embed_dim=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        window_size=7,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.1,
        norm_layer=nn.LayerNorm,
        ape=False,
        patch_norm=True,
        use_checkpoint=False,
        fused_window_process=False,
        **kwargs,
    ):
        super().__init__()
        self.num_classes = num_classes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        # Channel count after the final stage: embed_dim doubles at every merge.
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        self.mlp_ratio = mlp_ratio
        self.img_size = img_size
        self.fused_window_process = fused_window_process
        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None,
        )
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution
        # Kept so set_image_res can rebuild the patch embedding at a new resolution.
        self.embed_layer = PatchEmbed
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.norm_layer = norm_layer
        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=0.02)
        self.pos_drop = nn.Dropout(p=drop_rate)
        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(
                dim=int(embed_dim * 2**i_layer),
                input_resolution=(
                    patches_resolution[0] // (2**i_layer),
                    patches_resolution[1] // (2**i_layer),
                ),
                depth=depths[i_layer],
                num_heads=num_heads[i_layer],
                window_size=window_size,
                mlp_ratio=self.mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
                norm_layer=norm_layer,
                downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                use_checkpoint=use_checkpoint,
                fused_window_process=fused_window_process,
            )
            self.layers.append(layer)
        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
        self.apply(self._init_weights)

    def set_num_classes(self, n_classes):
        """Reset the classification head with a new number of classes.

        Parameters
        ----------
        n_classes : int
            new number of classes
        """
        if n_classes == self.num_classes:
            return
        self.head = nn.Linear(self.num_features, n_classes) if n_classes > 0 else nn.Identity()
        self.num_classes = n_classes
        nn.init.trunc_normal_(self.head.weight, std=0.02)
        nn.init.constant_(self.head.bias, 0)

    def set_image_res(self, res):
        """Resize the model in place to accept `res` x `res` inputs.

        Rebuilds the patch embedding (reusing the trained weights), updates
        every layer's/block's resolution-dependent state (downsample module,
        window size, shift mask), and bicubically interpolates the absolute
        position embedding if present.
        """
        if res == self.img_size:
            return
        # Rebuild patch embedding at the new resolution but keep trained weights.
        old_patch_embed_state = copy(self.patch_embed.state_dict())
        self.patch_embed = self.embed_layer(
            img_size=res,
            patch_size=self.patch_size,
            in_chans=self.in_chans,
            embed_dim=self.embed_dim,
            norm_layer=self.norm_layer if self.patch_norm else None,
        )
        self.patch_embed.load_state_dict(old_patch_embed_state)
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution
        for i_layer, layer in enumerate(self.layers):
            input_resolution = (
                patches_resolution[0] // (2**i_layer),
                patches_resolution[1] // (2**i_layer),
            )
            layer.input_resolution = input_resolution
            # NOTE(review): the downsample module is rebuilt from scratch here,
            # which re-initializes its weights — presumably acceptable for this
            # resizing workflow; confirm against the training procedure.
            downsample = PatchMerging if (i_layer < self.num_layers - 1) else None
            if downsample is not None:
                layer.downsample = downsample(input_resolution, dim=layer.dim, norm_layer=self.norm_layer)
            for block in layer.blocks:
                block.input_resolution = input_resolution
                if min(input_resolution) <= block.window_size:
                    # if window size is larger than input resolution, we don't partition windows
                    block.shift_size = 0
                    block.window_size = min(block.input_resolution)
                assert 0 <= block.shift_size < block.window_size, "shift_size must in 0-window_size"
                if block.shift_size > 0:
                    # calculate attention mask for SW-MSA
                    # (same construction as in SwinTransformerBlock.__init__)
                    H, W = block.input_resolution
                    img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
                    h_slices = (
                        slice(0, -block.window_size),
                        slice(-block.window_size, -block.shift_size),
                        slice(-block.shift_size, None),
                    )
                    w_slices = (
                        slice(0, -block.window_size),
                        slice(-block.window_size, -block.shift_size),
                        slice(-block.shift_size, None),
                    )
                    cnt = 0
                    for h in h_slices:
                        for w in w_slices:
                            img_mask[:, h, w, :] = cnt
                            cnt += 1
                    mask_windows = window_partition(img_mask, block.window_size)  # nW, window_size, window_size, 1
                    mask_windows = mask_windows.view(-1, block.window_size * block.window_size)
                    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
                    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
                        attn_mask == 0, float(0.0)
                    )
                else:
                    attn_mask = None
                block.register_buffer("attn_mask", attn_mask)
        if self.ape:
            # Bicubically resample the absolute position embedding to the new grid.
            orig_size = int((self.absolute_pos_embed.shape[-2]) ** 0.5)
            new_size = int(self.patch_embed.num_patches**0.5)
            pos_tokens = self.absolute_pos_embed[:, :]
            # make it shape rest x embed_dim x orig_size x orig_size
            pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, self.embed_dim).permute(0, 3, 1, 2)
            pos_tokens = nn.functional.interpolate(
                pos_tokens,
                size=(new_size, new_size),
                mode="bicubic",
                align_corners=False,
            )
            # make it shape rest x new_size^2 x embed_dim
            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
            self.absolute_pos_embed = nn.Parameter(pos_tokens.contiguous())
        self.img_size = res

    def _init_weights(self, m):
        # Truncated-normal linear weights, zero biases, identity LayerNorm.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {"absolute_pos_embed"}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {"relative_position_bias_table"}

    def forward_features(self, x):
        # Patch embed -> (optional) abs. pos. embed -> stages -> final norm.
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)
        for layer in self.layers:
            x = layer(x)
        x = self.norm(x)  # B L C
        return x

    def forward(self, x):
        x = self.forward_features(x)
        # Global average pool over tokens before the classifier head.
        x = self.avgpool(x.transpose(1, 2))  # B C 1
        x = torch.flatten(x, 1)
        x = self.head(x)
        return x

    def flops(self):
        flops = 0
        flops += self.patch_embed.flops()
        for i, layer in enumerate(self.layers):
            flops += layer.flops()
        flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2**self.num_layers)
        flops += self.num_features * self.num_classes
        return flops
# Canonical Swin configurations (Tiny/Small/Base/Large) keyed by size tag,
# expanded into SwinTransformer kwargs by the factory functions below.
swin_sizes = {
    "Ti": dict(embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24]),
    "S": dict(embed_dim=96, depths=[2, 2, 18, 2], num_heads=[3, 6, 12, 24]),
    "B": dict(embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32]),
    "L": dict(embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48]),
}
@register_model
def swin_tiny_patch4_window7(pretrained=False, img_size=224, **kwargs):
    """Swin-Tiny with patch size 4 and window size 7."""
    # timm's factory may inject pretrained_cfg, which SwinTransformer ignores.
    kwargs.pop("pretrained_cfg", None)
    return SwinTransformer(img_size=img_size, patch_size=4, window_size=7, **swin_sizes["Ti"], **kwargs)
@register_model
def swin_small_patch4_window7(pretrained=False, img_size=224, **kwargs):
    """Swin-Small with patch size 4 and window size 7."""
    kwargs.pop("pretrained_cfg", None)
    return SwinTransformer(img_size=img_size, patch_size=4, window_size=7, **swin_sizes["S"], **kwargs)
@register_model
def swin_base_patch4_window7(pretrained=False, img_size=224, **kwargs):
    """Swin-Base with patch size 4 and window size 7."""
    kwargs.pop("pretrained_cfg", None)
    return SwinTransformer(img_size=img_size, patch_size=4, window_size=7, **swin_sizes["B"], **kwargs)
@register_model
def swin_large_patch4_window7(pretrained=False, img_size=224, **kwargs):
    """Swin-Large with patch size 4 and window size 7."""
    kwargs.pop("pretrained_cfg", None)
    return SwinTransformer(img_size=img_size, patch_size=4, window_size=7, **swin_sizes["L"], **kwargs)
@register_model
def swin_ashim(pretrained=False, img_size=112, **kwargs):
    """Single-stage Swin variant (one 12-block stage, patch size 2, 112px default)."""
    kwargs.pop("pretrained_cfg", None)
    base = dict(embed_dim=384, depths=[12], num_heads=[12])
    if "num_heads" in kwargs:
        # Callers pass a scalar head count; the model expects one entry per stage.
        kwargs["num_heads"] = [kwargs["num_heads"]]
    return SwinTransformer(img_size=img_size, in_chans=3, patch_size=2, window_size=7, **{**base, **kwargs})

View File

@@ -0,0 +1,225 @@
from functools import partial
import torch
from loguru import logger
from timm.models.vision_transformer import (
Attention,
Block,
PatchEmbed,
VisionTransformer,
)
from resizing_interface import ResizingInterface
class _MatrixSaveAttn(Attention):
    """timm ``Attention`` subclass that stashes the last attention matrix in ``attn_mat``."""

    # Detached copy of the most recent softmaxed attention weights (set in forward).
    attn_mat = None

    @classmethod
    def cast(cls, attn: Attention):
        """Re-class an existing timm Attention module in place and return it."""
        assert isinstance(attn, Attention), "Can only save attention from Timms attention class"
        attn.__class__ = cls
        assert isinstance(attn, _MatrixSaveAttn)
        return attn

    def forward(self, x):
        batch, tokens, channels = x.shape
        head_dim = channels // self.num_heads
        qkv = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)  # make torchscript happy (cannot use tensor as tuple)
        weights = ((q @ k.transpose(-2, -1)) * self.scale).softmax(dim=-1)
        # Keep a detached snapshot for later inspection.
        self.attn_mat = weights.detach()
        weights = self.attn_drop(weights)
        out = (weights @ v).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(out))
def _index_picker(tensor, idx=-1):
"""Pick a specific index from a tensor.
tensor: B x N x D -> B x D, by picking idx from N.
Args:
tensor (toch.tensor): tensor to pick from.
idx (int, optional): index to pick. Defaults to -1.
Returns:
torch.tensor: index from tensor
"""
return tensor[..., idx, :] # B x N x D -> B x D
class TimmViT(VisionTransformer, ResizingInterface):
    """Wrapper for *VisionTransformer* from *timm* library (https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py)."""
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        num_classes=1000,
        global_pool="token",
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_norm=False,
        init_values=None,
        class_token=True,
        no_embed_class=True,
        pre_norm=False,
        fc_norm=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        weight_init="",
        embed_layer=PatchEmbed,
        norm_layer=None,
        act_layer=None,
        block_fn=Block,
        save_attention_maps=False,
        fused_attn=True,
        **kwargs,
    ):
        """Create a Vision Transformer model.
        Parameters
        ----------
        img_size : int
            image dimensions -> img_size x img_size
        patch_size : int
            patch_size
        in_chans : int
            number of image channels
        num_classes : int
            number of classes for classification head
        global_pool : str,int
            type of global pooling for final sequence (default: 'token'), or index for token to be taken
        embed_dim : int
            embedding dimension
        depth : int
            number of transformer layers
        num_heads : int
            number of transformer heads
        mlp_ratio : float
            ratio of feed forward (mlp) hidden dimension to embedding dimension
        qkv_bias : bool
            enable bias for query, key, and value (qkv) embeddings
        qk_norm : bool
            normalize query and key embeddings
        init_values : float
            layer scale initial values
        class_token : bool
            use a class token [CLS]
        no_embed_class : bool
            no positional embedding for the class token
        pre_norm : bool
            use pre-norm architecture (norm before the blocks, not after)
        fc_norm : bool
            norm after pool (used when global_pool == 'avg')
        drop_rate : float
            dropout rate
        attn_drop_rate : float
            dropout rate in the attention module
        drop_path_rate : float
            drop path rate (stochastic depth)
        weight_init : str
            scheme for weight initialization
        embed_layer : nn.Module
            patch embedding layer
        norm_layer : nn.Module
            normalization layer
        act_layer : nn.Module
            activation function
        block_fn : nn.Module
            which block structure to use; for parallel attention layers, ...
        save_attention_maps : bool
            save attention maps for each block
        fused_attn : bool
            use fused attention
        kwargs : dict
            additional arguments (will be ignored)
        """
        # Arguments forwarded to timm's VisionTransformer constructor.  An integer
        # global_pool is not understood by timm, so "avg" is passed as a stand-in
        # and the pooling is overridden below via attn_pool.
        init_kwargs = dict(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            num_classes=num_classes,
            global_pool=global_pool if isinstance(global_pool, str) else "avg",
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            init_values=init_values,
            class_token=class_token,
            no_embed_class=no_embed_class,
            pre_norm=pre_norm,
            fc_norm=fc_norm,
            drop_rate=drop_rate,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=drop_path_rate,
            weight_init=weight_init,
            embed_layer=embed_layer,
            norm_layer=norm_layer,
            act_layer=act_layer,
            block_fn=block_fn,
        )
        # Flag for newer timm APIs (patch_drop, qk_norm, proj_drop_rate); currently
        # hard-coded off, so the branches below are disabled.
        self.new_version = False  # TODO: check based on the timm version
        self.global_pool = global_pool
        if isinstance(global_pool, int):
            # Pool by picking a single token (by index) instead of timm's string modes.
            # NOTE(review): assumes the parent honors self.attn_pool set before __init__ — confirm against the installed timm version.
            self.attn_pool = partial(_index_picker, idx=global_pool)
        if self.new_version:
            # These kwargs only exist in newer timm releases.
            init_kwargs["qk_norm"] = qk_norm
            init_kwargs["proj_drop_rate"] = drop_rate
        super(TimmViT, self).__init__(**init_kwargs)
        # Keep the configuration around for later resizing / introspection
        # (ResizingInterface reads these).
        self.embed_layer = embed_layer
        self.img_size = img_size
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.pre_norm = pre_norm
        self.class_token = class_token
        self.no_embed_class = no_embed_class
        self.num_classes = num_classes
        self.num_heads = num_heads
        self.depth = depth
        self.save_attention_maps = save_attention_maps
        if save_attention_maps:
            self.do_save_attention_maps()
        # Best effort: propagate the fused-attention setting to every block.  Older
        # timm attention modules have no fused_attn flag, hence the blanket except.
        try:
            for block in self.blocks:
                # only enable fusion if the installed timm supports/enables it
                block.attn.fused_attn = fused_attn and self.blocks[0].attn.fused_attn
            use_fused = self.blocks[0].attn.fused_attn
            logger.info(f"Use fused attention: {use_fused}")
        except: # I'm lazy for now # noqa: E722
            pass
    def do_save_attention_maps(self):
        """Replace every block's attention module with one that caches its attention matrix."""
        self.save_attention_maps = True
        for block in self.blocks:
            block.attn = _MatrixSaveAttn.cast(block.attn)
    def attention_maps(self):
        """Return the attention matrices cached during the last forward pass (one entry per block, may be None)."""
        assert self.save_attention_maps, "Have to save attention maps first"
        return [getattr(block.attn, "attn_mat", None) for block in self.blocks]
    def forward_features(self, x):
        """Embed, position-encode and run ``x`` through all transformer blocks (no head)."""
        x = self.patch_embed(x)
        x = self._pos_embed(x)
        if self.new_version:
            # patch_drop only exists in newer timm versions
            x = self.patch_drop(x)
        x = self.norm_pre(x)
        x = self.blocks(x)
        return self.norm(x)
    def forward(self, x):
        """Full forward pass: features followed by the classification head."""
        x = self.forward_features(x)
        return self.forward_head(x)

View File

@@ -0,0 +1,118 @@
"""Default configuration parameters.
Attributes:
default_kwargs (dict): Default hyperparameters for the training process.
slurm_defaults (dict): Default values for SLURM batch job settings.
"""
from paths_config import user
# Default hyperparameters for the training process (DeiT-III style recipe);
# presets are selected via get_default_kwargs below.
default_kwargs = {
    "amp": True,  # automatic mixed precision during training
    "aug_color_jitter_factor": 0.3,
    "aug_crop": True,
    "aug_cutmix_alpha": 1.0,
    "aug_flip": True,
    "aug_gauss_blur": True,
    "aug_grayscale": True,
    "aug_mixup_alpha": 0.0,
    "aug_normalize": True,
    "aug_rand_rot": 0,
    "aug_random_erase_count": 1,
    "aug_random_erase_mode": "pixel",
    "aug_random_erase_prob": 0.0,  # random erasing disabled by default
    "aug_repeated_augment_repeats": 1,
    "aug_resize": True,
    "aug_solarize": True,
    "augment_engine": "torchvision",
    "augment_strategy": "3-augment",
    "auto_augment_strategy": "rand-m9-mstd0.5-inc1",
    "batch_size": 2048,
    "compile_model": False,
    "cuda": True,
    "custom_dataset_path": None,
    "debug": False,
    "drop_path_rate": 0.05,  # stochastic depth
    "dropout": 0.0,
    "eval_amp": True,
    "experiment_name": "none",
    "fused_attn": True,
    "gather_stats_during_training": True,
    "imsize": 224,
    "input_dim": None,
    "keep_interm_states": 2,
    "label_smoothing": 0.1,
    "layer_scale": True,
    "layer_scale_init_values": 1e-4,
    "log_level": "info",
    "loss": "ce",  # cross entropy
    "loss_weight": "none",
    "lr": 3e-3,  # base learning rate
    "max_grad_norm": 1.0,  # gradient clipping
    "max_seq_len": None,
    "min_lr": 1e-5,
    "momentum": 0.0,
    "num_heads": None,
    "num_workers": 44,
    "opt": "fusedlamb",
    "opt_eps": 1e-7,
    "pin_memory": False,
    "pre_norm": False,
    "prefetch_factor": 2,
    "qkv_bias": True,
    "run_name": None,
    "save_epochs": 10,  # checkpoint interval
    "sched": "cosine",
    "seed": None,
    "shuffle": True,
    "tqdm": True,
    "wandb": True,
    "warmup_epochs": 5,
    "warmup_lr": 1e-6,
    "warmup_sched": "linear",
    "weight_decay": 0.02,
    "weighted_sampler": False,
}
# , 'model_ema': True, 'model_ema_decay': 0.99996}
# Overrides applied on top of default_kwargs for the original DeiT recipe
# (selected via get_default_kwargs("deit")).
deit_kwargs = {
    "aug_mixup_alpha": 0.8,
    "aug_repeated_augment_repeats": 3,
    "augment_strategy": "deit",
    "aug_random_erase_prob": 0.25,
    "batch_size": 1024,
    "lr": 1e-3,
    "max_grad_norm": 0.0,  # no gradient clipping
    "num_workers": 10,
    "opt": "adamw",
    "opt_eps": 1e-8,
    "weight_decay": 0.05,
}
def get_default_kwargs(settings="deitiii"):
    """Return the default hyperparameter dict for a preset.

    Args:
        settings (str): preset name, "deitiii" or "deit" (case-insensitive).

    Returns:
        dict: default hyperparameters for the preset.

    Raises:
        NotImplementedError: for unknown preset names.
    """
    preset = settings.lower()
    if preset == "deitiii":
        return default_kwargs
    if preset == "deit":
        # DeiT overrides win over the DeiT-III defaults
        return {**default_kwargs, **deit_kwargs}
    raise NotImplementedError(f"No such defaults setting: {settings}")
# Default values for SLURM batch job settings.
# Fix: the container_image / container_mounts values were f-strings with no
# placeholders (ruff F541); plain string literals produce identical values.
slurm_defaults = {
    "after_job": None,  # optional job dependency
    "container_image": "PATH/TO/ENROOT/IMAGE",  # placeholder path, to be configured per site
    "container_mounts": 'MOUNT_ALL_IMPORTANT_STORAGE_SERVERS_HERE,"`pwd`":"`pwd`"',
    "container_workdir": '"`pwd`"',
    "cpus_per_task": 24,
    "exclude": None,  # nodes to exclude
    "export": "ALL,TQDM_DISABLE=1",
    "job_name": None,
    "mem_per_gpu": 90,  # GB
    "nodes": 1,
    "ntasks": 4,
    "partition": ["A100", "H100", "H200"],
    "task_prolog": None,
    "time": "1-0",  # 1 day
}

View File

@@ -0,0 +1,142 @@
import albumentations as A
import cv2
import numpy as np
from albumentations.pytorch import ToTensorV2
from datadings.torch import CompressedToPIL
class AlbumTorchCompose(A.Compose):
    """Compose albumentation augmentations in a way that works with PIL images and datadings."""

    def __init__(self, *args, **kwargs):
        """Pass to A.Compose."""
        super().__init__(*args, **kwargs)
        self.to_pil = CompressedToPIL()

    def __call__(self, image, mask=None, **kwargs):
        """Normalize inputs to numpy arrays, then run the composed pipeline."""
        if isinstance(image, bytes):
            # datadings delivers compressed bytes; decode to PIL first
            image = self.to_pil(image)
        if mask is not None and len(mask) == 0:
            mask = None  # treat an empty mask as no mask
        if not isinstance(image, np.ndarray):
            image = np.array(image)
        if mask is not None and not isinstance(mask, np.ndarray):
            mask = np.array(mask)
        if mask is None:
            # image-only call: unwrap the augmented image from the result dict
            return super().__call__(image=image, **kwargs)["image"]
        return super().__call__(image=image, mask=mask, **kwargs)
class PILToNP(A.DualTransform):
    """Convert PIL image to numpy array."""

    def apply(self, image, **params):
        """Convert the image to a numpy array."""
        return np.array(image)

    def apply_to_mask(self, image, **params):
        """Convert the mask exactly like the image."""
        return self.apply(image, **params)

    def get_transform_init_args_names(self):
        """No init args to serialize."""
        return ()
class AlbumCompressedToPIL(A.DualTransform):
    """Convert compressed image to PIL image."""

    def __init__(self, *args, **kwargs):
        """Create the transform and its decoder.

        Bug fix: ``self.to_pil`` was used in ``apply``/``apply_to_mask`` but never
        initialized, so every call raised ``AttributeError``.
        """
        super().__init__(*args, **kwargs)
        self.to_pil = CompressedToPIL()

    def apply(self, img, **params):
        """Decode a compressed image to a PIL image."""
        return self.to_pil(img)

    def apply_to_mask(self, img, **params):
        """Decode a compressed mask to a PIL image."""
        return self.to_pil(img)

    def get_transform_init_args_names(self):
        """No init args to serialize."""
        return ()
def minimal_augment(args, test=False):
    """Get minimal augmentations for training or testing.

    Args:
        args (argparse.Namespace): arguments
        test (bool, optional): if True, return test augmentations. Defaults to False.

    Returns:
        List: Augmentation list
    """
    pipeline = []
    if args.aug_resize:
        pipeline.append(A.SmallestMaxSize(args.imsize, interpolation=cv2.INTER_CUBIC))
    if args.aug_crop:
        # deterministic center crop at test time, random crop otherwise
        crop = A.CenterCrop(args.imsize, args.imsize) if test else A.RandomCrop(args.imsize, args.imsize)
        pipeline.append(crop)
    if not test and args.aug_flip:
        pipeline.append(A.HorizontalFlip(p=0.5))
    if args.aug_normalize:
        pipeline.append(A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)))
    pipeline.append(ToTensorV2())
    return pipeline
def three_augment(args, as_list=False, test=False):
    """Create the 3-augment data augmentation (albumentations backend).

    Args:
        args (Namespace): arguments
        as_list (bool): return list of transformations, not composed transformation
        test (bool): In eval mode? If False => train mode

    Returns:
        torch.nn.Module | list[torch.nn.Module]: composed transformation or list of transformations
    """
    pipeline = []
    if args.aug_resize:
        pipeline.append(A.SmallestMaxSize(args.imsize, interpolation=cv2.INTER_CUBIC))
    if args.aug_crop:
        if test:
            pipeline.append(A.CenterCrop(args.imsize, args.imsize))
        else:
            pipeline.append(A.RandomCrop(args.imsize, args.imsize, pad_if_needed=True, border_mode=cv2.BORDER_REFLECT))
    if not test:
        if args.aug_flip:
            pipeline.append(A.HorizontalFlip(p=0.5))
        # 3-augment core: apply exactly one of grayscale / solarize / blur
        one_of = []
        if args.aug_grayscale:
            one_of.append(A.ToGray(p=1, num_output_channels=3))
        if args.aug_solarize:
            one_of.append(A.Solarize(p=1, threshold_range=(0.5, 0.5)))
        if args.aug_gauss_blur:
            one_of.append(A.GaussianBlur(p=1, sigma_limit=(0.2, 2.0), blur_limit=(7, 7)))
        if len(one_of) > 0:
            pipeline.append(A.OneOf(one_of))
        if args.aug_color_jitter_factor > 0.0:
            jitter = args.aug_color_jitter_factor
            pipeline.append(A.ColorJitter(brightness=jitter, contrast=jitter, saturation=jitter, hue=0.0))
    if args.aug_normalize:
        pipeline.append(A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)))
    pipeline.append(ToTensorV2())
    if as_list:
        return pipeline
    return AlbumTorchCompose(pipeline)

View File

@@ -0,0 +1,63 @@
import os
from loguru import logger
from PIL import Image
from torch.utils.data import Dataset
class CounterAnimal(Dataset):
    """Dataset to load the CounterAnimal dataset with ImageNet labels."""

    # Accepted image file extensions (checked case-insensitively).
    _IMAGE_EXTENSIONS = ("jpg", "jpeg", "png")

    def __init__(self, base_path, mode, transform=None, target_transform=None, train=False):
        """Create the dataset.

        Args:
            base_path (str): path to the base folder (the one where the class folders are in)
            mode (str): mode/variant of the dataset (common/counter)
            transform: Image augmentation
            target_transform: label augmentation
            train: train or test set. Train set is not supported

        Raises:
            AssertionError: for an unsupported ``mode``, ``train=True``, or if no images are found.
        """
        super().__init__()
        self.base = base_path
        assert mode in ["counter", "common"], f"Supported modes are counter and common, but got '{mode}'"
        assert not train, "CounterAnimal only consists of test data, not training data."
        self.transform = transform
        self.target_transform = target_transform
        # index holds (file path, ImageNet class index) tuples
        self.index = []
        for class_folder in os.listdir(self.base):
            if not os.path.isdir(os.path.join(self.base, class_folder)):
                continue
            # class folders are named "<imagenet-index> <class name>"
            class_idx = int(class_folder.split(" ")[0])
            for variant_folder in os.listdir(os.path.join(self.base, class_folder)):
                # keep only the requested variant ("common ..." / "counter ...")
                if not variant_folder.startswith(mode):
                    continue
                _folder = os.path.join(self.base, class_folder, variant_folder)
                for file in os.listdir(_folder):
                    if file.lower().split(".")[-1] in self._IMAGE_EXTENSIONS:
                        self.index.append((os.path.join(_folder, file), class_idx))
        assert len(self.index) > 0, "did not find any images :("

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.index)

    def __getitem__(self, idx):
        """Load and return the ``(image, label)`` pair at ``idx``."""
        path, label = self.index[idx]
        img = Image.open(path).convert("RGB")
        # explicit None checks so falsy-but-valid callables are not skipped
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return img, label

View File

@@ -0,0 +1,145 @@
from nvidia.dali import fn, pipeline_def, types
# see https://docs.nvidia.com/deeplearning/dali/user-guide/docs/plugins/pytorch_dali_proxy.html
@pipeline_def
def minimal_augment(args, test=False):
    """Minimal Augmentation set for images.

    Contains only resize, crop, flip, to tensor and normalize.
    (Cleanup: removed the commented-out flip/normalize code that was superseded
    by the fused ``crop_mirror_normalize`` call below.)

    Args:
        args (DotDict): Arguments: aug_resize, aug_crop, aug_flip, aug_normalize to turn on/off the respective augmentation.
        test (bool, optional): On the test set? Defaults to False.

    Returns:
        images: augmented images.
    """
    images = fn.external_source(name="images", no_copy=True)
    if args.aug_resize:
        # resize so the smaller edge is at least imsize
        images = fn.resize(images, size=args.imsize, mode="not_smaller")
    if test and args.aug_crop:
        # deterministic center crop at eval time
        images = fn.crop(images, crop=(args.imsize, args.imsize), crop_pos_x=0.5, crop_pos_y=0.5)
    elif args.aug_crop:
        # random crop position at train time
        images = fn.crop(
            images,
            crop=(args.imsize, args.imsize),
            crop_pos_x=fn.random.uniform(range=(0, 1)),
            crop_pos_y=fn.random.uniform(range=(0, 1)),
        )
    # crop + horizontal mirror + normalization fused into one DALI op
    return fn.crop_mirror_normalize(
        images,
        dtype=types.FLOAT,
        output_layout="CHW",
        crop=(args.imsize, args.imsize),
        mean=[0.485 * 255, 0.456 * 255, 0.406 * 255] if args.aug_normalize else [0, 0, 0],
        std=[0.229 * 255, 0.224 * 255, 0.225 * 255] if args.aug_normalize else [1, 1, 1],
        mirror=fn.random.coin_flip(probability=0.5) if args.aug_flip and not test else False,
    )
def dali_solarize(images, threshold=128):
    """Solarize implementation for nvidia DALI.

    Args:
        images (DALI Tensor): Images to solarize. Assumed uint8 in [0, 255].
        threshold (int, optional): Threshold for solarization. Defaults to 128.

    Returns:
        images: solarized images.
    """
    # inverted pixel values: 255 - pixel
    inv_images = types.Constant(255).uint8() - images
    # 0/1 mask: 1 where the pixel is at or above the threshold
    mask = (images >= threshold) * types.Constant(1).uint8()
    # XOR with 1 flips the 0/1 mask, so each pixel takes either its inverted
    # (at/above threshold) or its original (below threshold) value
    return mask * inv_images + (types.Constant(1).uint8() ^ mask) * images
@pipeline_def(enable_conditionals=True)
def three_augment(args, test=False):
    """3-augment data augmentation pipeline for nvidia DALI.

    Cleanup: removed the dead ``choices`` list (it was never populated, so the
    trailing ``fn.random.choice(choices)`` branch could never run) and the
    leftover commented-out debug lines.

    Args:
        args (namespace): augmentation arguments.
        test (bool, optional): Test (or train) split. Defaults to False.

    Returns:
        images: augmented images.
    """
    images = fn.external_source(name="images", no_copy=True)
    if args.aug_resize:
        images = fn.resize(images, size=args.imsize, mode="not_smaller")
    if test and args.aug_crop:
        images = fn.crop(images, crop=(args.imsize, args.imsize), crop_pos_x=0.5, crop_pos_y=0.5)
    elif args.aug_crop:
        images = fn.crop(
            images,
            crop=(args.imsize, args.imsize),
            crop_pos_x=fn.random.uniform(range=(0, 1)),
            crop_pos_y=fn.random.uniform(range=(0, 1)),
        )
    if not test:
        # 3-augment core: pick one of grayscale / solarize / gaussian blur,
        # weighted by which of them are enabled.
        # NOTE(review): if all three flags are off this divides by zero —
        # callers appear expected to enable at least one; confirm.
        choice_ps = [1 * args.aug_grayscale, 1 * args.aug_solarize, 1 * args.aug_gauss_blur]
        total = sum(choice_ps)
        choice_ps = [c / total for c in choice_ps]
        choice = fn.random.choice(
            [0, 1, 2],
            p=choice_ps,
        )
        if choice == 0:
            # grayscale via an RGB -> GRAY -> RGB round trip
            images = fn.color_space_conversion(
                fn.color_space_conversion(images, image_type=types.RGB, output_type=types.GRAY),
                image_type=types.GRAY,
                output_type=types.RGB,
            )
        elif choice == 1:
            images = dali_solarize(images, threshold=128)
        elif choice == 2:
            images = fn.gaussian_blur(images, window_size=7, sigma=fn.random.uniform(range=(0.2, 2.0)))
        if args.aug_color_jitter_factor > 0.0:
            jitter_range = (1 - args.aug_color_jitter_factor, 1 + args.aug_color_jitter_factor)
            images = fn.color_twist(
                images,
                brightness=fn.random.uniform(range=jitter_range),
                contrast=fn.random.uniform(range=jitter_range),
                saturation=fn.random.uniform(range=jitter_range),
            )
    # crop + horizontal mirror + normalization fused into one DALI op
    return fn.crop_mirror_normalize(
        images,
        dtype=types.FLOAT,
        output_layout="CHW",
        crop=(args.imsize, args.imsize),
        mean=[0.485 * 255, 0.456 * 255, 0.406 * 255] if args.aug_normalize else [0, 0, 0],
        std=[0.229 * 255, 0.224 * 255, 0.225 * 255] if args.aug_normalize else [1, 1, 1],
        mirror=fn.random.coin_flip(probability=0.5) if args.aug_flip and not test else False,
    )

View File

@@ -0,0 +1,407 @@
from random import uniform
import msgpack
import torch
import torchvision
from datadings.torch import CompressedToPIL
from datadings.torch import Dataset as DDDataset
from PIL import ImageFilter
from torchvision.transforms import (
CenterCrop,
ColorJitter,
Compose,
GaussianBlur,
Grayscale,
InterpolationMode,
Normalize,
RandomChoice,
RandomCrop,
RandomHorizontalFlip,
RandomResizedCrop,
RandomSolarize,
Resize,
ToTensor,
)
from torchvision.transforms import functional as F
# Transform classes whose random parameters must be applied to image and target
# alike (used by apply_dense_transforms).
# Fix: RandomRotation was listed twice; the duplicate changed nothing for the
# isinstance checks and has been removed.
_image_and_target_transforms = [
    torchvision.transforms.RandomCrop,
    torchvision.transforms.RandomHorizontalFlip,
    torchvision.transforms.CenterCrop,
    torchvision.transforms.RandomRotation,
    torchvision.transforms.RandomAffine,
    torchvision.transforms.RandomResizedCrop,
]
def apply_dense_transforms(x, y, transforms: torchvision.transforms.transforms.Compose):
    """Apply some transformations to both image and target.

    Geometric transforms are applied with identical random parameters to image
    and target; color/tensor transforms are applied to the image only.

    Args:
        x (torch.Tensor): image
        y (torch.Tensor): target (image); assumed to be an integer mask of shape (H, W) — the unsqueeze/long round-trips below imply this, TODO confirm.
        transforms (torchvision.transforms.transforms.Compose): transformations to apply

    Returns:
        tuple[torch.Tensor, torch.Tensor]: (x, y) with applied transformations
    """
    for trans in transforms.transforms:
        if isinstance(trans, torchvision.transforms.RandomResizedCrop):
            # sample the crop once, then apply it to image and target alike
            params = trans.get_params(x, trans.scale, trans.ratio)
            x = F.resized_crop(x, *params, trans.size, trans.interpolation, antialias=trans.antialias)
            y = F.resized_crop(y.unsqueeze(0), *params, trans.size, 0).squeeze(0)  # nearest neighbor interpolation
        elif isinstance(trans, Resize):
            # only resize the target if the image shape actually changed
            pre_shape = x.shape
            x = trans(x)
            if x.shape != pre_shape:
                y = F.resize(y.unsqueeze(0), trans.size, 0, trans.max_size, trans.antialias).squeeze(
                    0
                )  # nearest neighbor interpolation
        elif any(isinstance(trans, simul_transform) for simul_transform in _image_and_target_transforms):
            # stack target as an extra channel so both get the same random geometry
            xy = torch.cat([x, y.unsqueeze(0).float()], dim=0)
            xy = trans(xy)
            x, y = xy[:-1], xy[-1].long()
        elif isinstance(trans, torchvision.transforms.ToTensor):
            # skip ToTensor when x is already a tensor
            if not isinstance(x, torch.Tensor):
                x = trans(x)
        else:
            # non-geometric transform: image only
            x = trans(x)
    return x, y
def get_hf_transform(transform_f, trgt_transform_f=None, image_key="image"):
    """Convert the transform function to a huggingface compatible transform function.

    Args:
        transform_f (callable): Image transform.
        trgt_transform_f (callable, optional): Target transform. Defaults to None.
        image_key (str, optional): Key for the image in the hf ds return dict. Defaults to "image".

    Returns:
        callable: batch transform that maps the per-sample transforms over a dict of lists.
    """

    def _transform(samples):
        try:
            samples[image_key] = list(map(transform_f, samples[image_key]))
            if trgt_transform_f is not None:
                samples["label"] = list(map(trgt_transform_f, samples["label"]))
        except TypeError as e:
            # surface the offending batch before re-raising
            print(f"Type error when transforming samples: {samples}")
            raise e
        return samples

    return _transform
class DDDecodeDataset(DDDataset):
    """Datadings dataset with image decoding before transform."""

    def __init__(self, *args, transform=None, target_transform=None, transforms=None, **kwargs):
        """Create datadings dataset.

        Args:
            transform (callable, optional): Image transform. Overrides transforms['image']. Defaults to None.
            target_transform (callable, optional): Label transform. Overrides transforms['label']. Defaults to None.
            transforms (dict[str, callable], optional): Dict of transforms for each key. Defaults to None.
        """
        super().__init__(*args, **kwargs)
        lookup = transforms if transforms is not None else {}
        # explicit per-key arguments take precedence over the transforms dict
        self._decode_transform = lookup.get("image", None) if transform is None else transform
        self._decode_target_transform = lookup.get("label", None) if target_transform is None else target_transform
        self.ctp = CompressedToPIL()

    def __getitem__(self, idx):
        """Return the decoded and transformed ``(image, label)`` pair at ``idx``."""
        sample = super().__getitem__(idx)
        img = sample["image"]
        lbl = sample["label"]
        if isinstance(img, bytes):
            # decode compressed bytes to a PIL image before any transform
            img = self.ctp(img)
        if self._decode_transform is not None:
            img = self._decode_transform(img)
        if self._decode_target_transform is not None:
            lbl = self._decode_target_transform(lbl)
        return img, lbl
def minimal_augment(args, test=False):
    """Minimal Augmentation set for images.

    Contains only resize, crop, flip, to tensor and normalize.

    Args:
        args (DotDict): Arguments: aug_resize, aug_crop, aug_flip, aug_normalize to turn on/off the respective augmentation.
        test (bool, optional): On the test set? Defaults to False.

    Returns:
        List: Augmentation list
    """
    pipeline = [ToTensor()]
    if args.aug_resize:
        pipeline.append(Resize(args.imsize, interpolation=InterpolationMode.BICUBIC))
    if args.aug_crop:
        # deterministic center crop at test time, padded random crop otherwise
        crop = CenterCrop(args.imsize) if test else RandomCrop(args.imsize, padding=4, padding_mode="reflect")
        pipeline.append(crop)
    if not test and args.aug_flip:
        pipeline.append(RandomHorizontalFlip(p=0.5))
    if args.aug_normalize:
        pipeline.append(
            Normalize(
                mean=torch.tensor([0.485, 0.456, 0.406]),
                std=torch.tensor([0.229, 0.224, 0.225]),
            )
        )
    return pipeline
def three_augment(args, as_list=False, test=False):
    """Create the data augmentation (3-augment, torchvision backend).

    Args:
        args (Namespace): arguments
        as_list (bool): return list of transformations, not composed transformation
        test (bool): In eval mode? If False => train mode

    Returns:
        torch.nn.Module | list[torch.nn.Module]: composed transformation or list of transformations
    """
    augs = []
    augs.append(ToTensor())
    if args.aug_resize:
        augs.append(Resize(args.imsize, interpolation=InterpolationMode.BICUBIC))
    if test and args.aug_crop:
        # deterministic center crop at eval time
        augs.append(CenterCrop(args.imsize))
    elif args.aug_crop:
        augs.append(RandomCrop(args.imsize, padding=4, padding_mode="reflect"))
    if not test:
        if args.aug_flip:
            augs.append(RandomHorizontalFlip(p=0.5))
        # 3-augment core: apply exactly one of grayscale / solarize / blur
        augs_choice = []
        if args.aug_grayscale:
            augs_choice.append(Grayscale(num_output_channels=3))
        if args.aug_solarize:
            augs_choice.append(RandomSolarize(threshold=0.5, p=1.0))
        if args.aug_gauss_blur:
            # TODO: check kernel size?
            augs_choice.append(GaussianBlur(kernel_size=7, sigma=(0.2, 2.0)))
            # augs_choice.append(QuickGaussBlur())
        if len(augs_choice) > 0:
            augs.append(RandomChoice(augs_choice))
        if args.aug_color_jitter_factor > 0.0:
            augs.append(
                ColorJitter(
                    args.aug_color_jitter_factor,
                    args.aug_color_jitter_factor,
                    args.aug_color_jitter_factor,
                )
            )
    if args.aug_normalize:
        augs.append(
            Normalize(
                mean=torch.tensor([0.485, 0.456, 0.406]),
                std=torch.tensor([0.229, 0.224, 0.225]),
            )
        )
    if as_list:
        return augs
    return Compose(augs)
def segment_augment(args, test=False):
    """Create the data augmentation for segmentation.

    No cropping in this part, as cropping has to be done for the image and labels simultaneously.

    Args:
        args (DotDict): arguments
        test (bool, optional): In eval mode? If False => train mode. Defaults to False.

    Returns:
        list[torch.nn.Module]: list of transformations
    """
    if test:
        # deterministic eval path: upscale small images, then center crop
        pipeline = [ResizeUp(args.imsize), CenterCrop(args.imsize)]
    else:
        pipeline = [RandomResizedCrop(args.imsize, scale=(0.08, 1.0), ratio=(3.0 / 4.0, 4.0 / 3.0))]
        if args.aug_flip:
            pipeline.append(RandomHorizontalFlip(p=0.5))
    if args.aug_color_jitter_factor > 0.0:
        jitter = args.aug_color_jitter_factor
        pipeline.append(ColorJitter(jitter, jitter, jitter))
    if args.aug_normalize:
        pipeline.append(
            Normalize(
                mean=torch.tensor([0.485, 0.456, 0.406]),
                std=torch.tensor([0.229, 0.224, 0.225]),
            )
        )
    return pipeline
class QuickGaussBlur:
    """Gaussian blur transformation using PIL ImageFilter."""

    def __init__(self, sigma=(0.2, 2.0)):
        """Create Gaussian blur operator.

        Args:
            sigma (tuple[float, float]): range of sigma for blur
        """
        self.sigma_min, self.sigma_max = sigma

    def __call__(self, img):
        """Blur ``img`` with a radius drawn uniformly from the configured range."""
        radius = uniform(self.sigma_min, self.sigma_max)
        return img.filter(ImageFilter.GaussianBlur(radius=radius))
class RemoveTransform:
    """Remove data from transformation.

    To use with default collate function.
    """

    def __call__(self, x, y=None):
        """Replace ``x`` by a dummy list, passing ``y`` through when given."""
        dummy = [1]
        return dummy if y is None else (dummy, y)

    def __repr__(self):
        """Return the canonical constructor representation."""
        return f"{type(self).__name__}()"
def collate_imnet(data, image_key="image"):
    """Collate function for imagenet(1k / 21k) with datadings.

    Args:
        data (list[dict[str, Any]]): images for a batch
        image_key (str): key under which each sample stores its image.

    Returns:
        tuple[torch.Tensor | list, torch.Tensor]: images, labels
    """
    images = [sample[image_key] for sample in data]
    if isinstance(images[0], torch.Tensor):
        # already-decoded tensors are stacked into a batch; anything else stays a list
        images = torch.stack(images, dim=0)
    labels = torch.tensor([sample["label"] for sample in data])
    return images, labels
def collate_listops(data):
    """Collate function for ListOps with datadings.

    Only the first sample of the batch is used.

    Args:
        data (list[tuple[torch.Tensor, torch.Tensor]]): list of samples

    Returns:
        tuple[torch.Tensor, torch.Tensor]: images, labels
    """
    first = data[0]
    return first[0], first[1]
def no_param_transf(self, sample):
    """Call transformation without extra parameter.

    To use with datadings QuasiShuffler.

    Args:
        self (object): use this as a method ( <obj>.<method_name> = MethodType(no_param_transf, <obj>) )
        sample (Any): sample to transform

    Returns:
        Any: transformed sample
    """
    if isinstance(sample, tuple):
        # sample of type (name (str), data (bytes encoded)) — keep only the data
        sample = sample[1]
    if isinstance(sample, bytes):
        # decode msgpack bytes
        sample = msgpack.loads(sample)
    params = self._rng(sample)
    for key, transform in self._transforms.items():
        sample[key] = transform(sample[key], params)
    return sample
class ToOneHotSequence:
    """Convert a sequence of grayscale values (range 0 to 1) to a one-hot encoded sequence based on 8-bit rounded values."""

    def __call__(self, x, y=None):
        """One-hot encode ``x`` over 256 levels; pass ``y`` through when given."""
        # x is 1 x 32 x 32; quantize to 8-bit integer levels and flatten to a sequence
        x = (255 * x).round().to(torch.int64).view(-1)
        assert x.max() < 256, f"Found max value {x.max()} in {x}."
        x = torch.nn.functional.one_hot(x, num_classes=256).float()
        return x if y is None else (x, y)

    def __repr__(self):
        """Return the canonical constructor representation."""
        return f"{type(self).__name__}()"
class ResizeUp(Resize):
    """Resize up if image is smaller than target size."""

    def forward(self, img):
        """Resize ``img`` only when one of its spatial sides is below the target size.

        Fixes: ``shape[-2]`` is the height and ``shape[-1]`` the width (the names
        were swapped), and ``self.size`` may be an ``(h, w)`` sequence, for which
        the old ``int < list`` comparison raised a TypeError.

        Args:
            img (torch.Tensor): image tensor with shape ``(..., H, W)``.

        Returns:
            torch.Tensor: the (possibly resized) image.
        """
        h, w = img.shape[-2], img.shape[-1]
        # use the smaller requested edge as the upscaling threshold
        target = self.size if isinstance(self.size, int) else min(self.size)
        if h < target or w < target:
            img = super().forward(img)
        return img

View File

@@ -0,0 +1,484 @@
import json
import os
import zipfile
from io import BytesIO
from math import floor
import numpy as np
import PIL
from datadings.torch import Compose
from loguru import logger
from PIL import Image, ImageFilter
from torch.utils.data import Dataset, get_worker_info
from torchvision import transforms as T
from tqdm.auto import tqdm
from data.data_utils import apply_dense_transforms
class ForNet(Dataset):
    """Recombine ImageNet forgrounds and backgrounds.

    Note:
        This dataset has exactly the ImageNet classes.
    """

    # Supported values for ``background_combination``.
    _back_combs = ["same", "all", "original"]
    # Transform classes that are applied to the background only (used when paste_pre_transform=False).
    _bg_transforms = {T.RandomCrop, T.CenterCrop, T.Resize, T.RandomResizedCrop}

    def __init__(
        self,
        root,
        transform=None,
        train=True,
        target_transform=None,
        background_combination="all",
        fg_scale_jitter=0.3,
        fg_transform=None,
        pruning_ratio=0.8,
        return_fg_masks=False,
        fg_size_mode="range",
        fg_bates_n=1,
        paste_pre_transform=True,
        mask_smoothing_sigma=4.0,
        rel_jut_out=0.0,
        fg_in_nonant=None,
        size_fact=1.0,
        orig_img_prob=0.0,
        orig_ds=None,
        _orig_ds_file_type="JPEG",
        epochs=0,
        _album_compose=False,
    ):
        """Create RecombinationNet dataset.

        Args:
            root (str): Root folder for the dataset.
            transform (T.Collate | list, optional): Transform to apply to the image. Defaults to None.
            train (bool, optional): On the train set (False -> val set). Defaults to True.
            target_transform (T.Collate | list, optional): Transform to apply to the target values. Defaults to None.
            background_combination (str, optional): Which backgrounds to combine with foregrounds. Defaults to "same".
            fg_scale_jitter (tuple, optional): How much should the size of the foreground be changed (random ratio). Defaults to (0.1, 0.8).
            fg_transform (_type_, optional): Transform to apply to the foreground before applying to the background. This is supposed to be a random rotation, mainly. Defaults to None.
            pruning_ratio (float, optional): For pruning backgrounds, with (foreground size/background size) >= <pruning_ratio>. Backgrounds from images that contain very large foreground objects are mostly computer generated and therefore relatively unnatural. Defaults to full dataset.
            return_fg_masks (bool, optional): Return the foreground masks. Defaults to False.
            fg_size_mode (str, optional): How to determine the size of the foreground, based on the foreground sizes of the foreground and background images. Defaults to "max".
            fg_bates_n (int, optional): Bates parameter for the distribution of the object position in the foreground. Defaults to 1 (uniform distribution). The higher the value, the more likely the object is in the center. For fg_bates_n = 0, the object is always in the center.
            paste_pre_transform (bool, optional): Paste the foreground onto the background before applying the transform. If false, the background will be cropped and resized before pasting the foreground. Defaults to False.
            mask_smoothing_sigma (float, optional): Sigma for the Gaussian blur of the mask edge. Defaults to 0.0. Try 2.0 or 4.0?
            rel_jut_out (float, optional): How much is the foreground allowed to stand/jut out of the background (and then cut off). Defaults to 0.0.
            fg_in_nonant (int, optional): If not None, the foreground will be placed in a specific nonant (0-8) of the image. Defaults to None.
            size_fact (float, optional): Factor to multiply the size of the foreground with. Defaults to 1.0.
            orig_img_prob (float | str, optional): Probability to use the original image, instead of the fg-bg recombinations. Defaults to 0.0.
            orig_ds (Dataset, optional): Original dataset (without transforms) to use for the original images. Defaults to None.
            _orig_ds_file_type (str, optional): File type of the original dataset. Defaults to "JPEG".
            epochs (int, optional): Number of epochs to train on. Used for linear increase of orig_img_prob.

        Note:
            For more information on the bates distribution, see https://en.wikipedia.org/wiki/Bates_distribution.
            For fg_bats_n < 0, we take extend the bates dirstribution to focus more and more on the edges. This is done by sampling B ~ Bates(|fg_bates_n|) and then passing through f(x) = x + 0.5 - floor(x + 0.5).
            For the list of transformations that will be applied to the background only (if paste_pre_transform=False), see RecombinationNet._bg_transforms.
            A nonant in this case refers to a square in a 3x3 grid dividing the image.
        """
        assert (
            background_combination in self._back_combs
        ), f"background_combination={background_combination} is not supported. Use one of {self._back_combs}"
        # Prefer the extracted folder layout when the zips are absent but the folders exist.
        if (not os.path.exists(f"{root}/backgrounds_{'train' if train else 'val'}.zip")) and os.path.exists(
            os.path.join(root, "train" if train else "val", "backgrounds")
        ):
            self._mode = "folder"
        else:
            self._mode = "zip"
        if self._mode == "zip":
            try:
                # Only the file listings are read here; per-worker handles are opened lazily in _wrkr_info.
                with zipfile.ZipFile(f"{root}/backgrounds_{'train' if train else 'val'}.zip", "r") as bg_zip:
                    self.backgrounds = [f for f in bg_zip.namelist() if f.endswith(".JPEG")]
                with zipfile.ZipFile(f"{root}/foregrounds_{'train' if train else 'val'}.zip", "r") as fg_zip:
                    self.foregrounds = [f for f in fg_zip.namelist() if f.endswith(".WEBP")]
            except FileNotFoundError as e:
                logger.error(
                    f"RecombinationNet: {e}. Make sure to have the background and foreground zips in the root"
                    f" directory: found {os.listdir(root)}"
                )
                raise e
            # Archive paths end in <synset>/<file>, so the class id is the second-to-last path component.
            classes = set([f.split("/")[-2] for f in self.foregrounds])
        else:
            logger.info("ForNet folder mode: loading classes")
            classes = set(os.listdir(os.path.join(root, "train" if train else "val", "foregrounds")))
            foregrounds = []
            backgrounds = []
            for cls in tqdm(classes, desc="Loading files"):
                foregrounds.extend(
                    [
                        f"{cls}/{f}"
                        for f in os.listdir(os.path.join(root, "train" if train else "val", "foregrounds", cls))
                    ]
                )
                backgrounds.extend(
                    [
                        f"{cls}/{f}"
                        for f in os.listdir(os.path.join(root, "train" if train else "val", "backgrounds", cls))
                    ]
                )
            self.foregrounds = foregrounds
            self.backgrounds = backgrounds
        # Synset ids like "n01443537" sort by their numeric part.
        self.classes = sorted(list(classes), key=lambda x: int(x[1:]))
        assert os.path.exists(f"{root}/fg_bg_ratios_{'train' if train else 'val'}.json"), (
            f"{root}/fg_bg_ratios_{'train' if train else 'val'}.json not found, provide the information or set"
            " pruning_ratio=1.0"
        )
        with open(f"{root}/fg_bg_ratios_{'train' if train else 'val'}.json", "r") as f:
            self.fg_bg_ratios = json.load(f)
        if self._mode == "folder":
            # The JSON keys carry full zip paths; in folder mode reduce them to "<synset>/<file>".
            self.fg_bg_ratios = {"/".join(key.split("/")[-2:]): val for key, val in self.fg_bg_ratios.items()}
            logger.info(f"Renamed fg_bg_ratios keys to {list(self.fg_bg_ratios.keys())[:3]}...")
        if pruning_ratio <= 1.0:
            # Remember one fallback background per class, in case pruning empties a class.
            # NOTE(review): with pruning_ratio > 1.0, backup_backgrounds stays undefined, which would
            # raise a NameError below if background_combination == "same" hits an empty class — confirm intended.
            backup_backgrounds = {}
            for bg_file in self.backgrounds:
                bg_cls = bg_file.split("/")[-2]
                backup_backgrounds[bg_cls] = bg_file
            self.backgrounds = [
                bg for bg in self.backgrounds if bg in self.fg_bg_ratios and self.fg_bg_ratios[bg] < pruning_ratio
            ]
            # logger.info(
            #     f"RecombinationNet: keep {len(self.backgrounds)} of {len(self.fg_bg_ratios)} backgrounds with pr {pruning_ratio}"
            # )
        self.root = root
        self.train = train
        self.background_combination = background_combination
        self.fg_scale_jitter = fg_scale_jitter
        self.fg_transform = fg_transform
        self.return_fg_masks = return_fg_masks
        self.paste_pre_transform = paste_pre_transform
        self.mask_smoothing_sigma = mask_smoothing_sigma
        self.rel_jut_out = rel_jut_out
        self.size_fact = size_fact
        self.fg_in_nonant = fg_in_nonant
        assert fg_in_nonant is None or -1 <= fg_in_nonant < 9, f"fg_in_nonant={fg_in_nonant} not in [0, 8] or None"
        self.orig_img_prob = orig_img_prob
        if orig_img_prob != 0.0:
            # orig_img_prob is either a fixed probability or a schedule name evaluated per epoch in __getitem__.
            assert (isinstance(orig_img_prob, float) and orig_img_prob > 0.0) or orig_img_prob in [
                "linear",
                "cos",
                "revlinear",
            ]
            assert orig_ds is not None, "orig_ds must be provided if orig_img_prob > 0.0"
            assert not return_fg_masks, "can't provide fg masks for original images (yet)"
            assert os.path.exists(os.path.join(root, f"{'train' if train else 'val'}_indices.json")) or isinstance(
                orig_ds, str
            ), f"{root}/{'train' if train else 'val'}_indices.json must be provided if orig_ds is a dataset"
            if not isinstance(orig_ds, str):
                # Dataset object: map "<synset>/<stem>" keys to indices into the original dataset.
                with open(os.path.join(root, f"{'train' if train else 'val'}_indices.json"), "r") as f:
                    self.key_to_orig_idx = json.load(f)
            else:
                # Path to an image folder: make sure it ends in the split directory.
                if not (
                    orig_ds.endswith("train" if train else "val") or orig_ds.endswith("train/" if train else "val/")
                ):
                    orig_ds = f"{orig_ds}/{'train' if train else 'val'}"
                self.key_to_orig_idx = None
        self._orig_ds_file_type = _orig_ds_file_type
        self.orig_ds = orig_ds
        self.epochs = epochs
        self._epoch = 0
        assert fg_size_mode in [
            "max",
            "min",
            "mean",
            "range",
        ], f"fg_size_mode={fg_size_mode} not supported; use one of ['max', 'min', 'mean', 'range']"
        self.fg_size_mode = fg_size_mode
        self.fg_bates_n = fg_bates_n
        if not paste_pre_transform:
            if isinstance(transform, (T.Compose, Compose)):
                transform = transform.transforms
            # do cropping and resizing mainly on background; paste foreground on top later
            self.bg_transform = []
            self.join_transform = []
            for tf in transform:
                if isinstance(tf, tuple(self._bg_transforms)):
                    self.bg_transform.append(tf)
                else:
                    self.join_transform.append(tf)
            if _album_compose:
                from data.album_transf import AlbumTorchCompose

                self.bg_transform = AlbumTorchCompose(self.bg_transform)
                self.join_transform = AlbumTorchCompose(self.join_transform)
            else:
                self.bg_transform = T.Compose(self.bg_transform)
                self.join_transform = T.Compose(self.join_transform)
        else:
            if isinstance(transform, list):
                if _album_compose:
                    from data.album_transf import AlbumTorchCompose

                    self.join_transform = AlbumTorchCompose(transform)
                else:
                    self.join_transform = T.Compose(transform)
            else:
                self.join_transform = transform
            self.bg_transform = None
        self.trgt_map = {cls: i for i, cls in enumerate(self.classes)}
        self.target_transform = target_transform
        # Map each class to the backgrounds it may be combined with (only populated for "same").
        self.cls_to_allowed_bg = {}
        for bg_file in self.backgrounds:
            if background_combination == "same":
                bg_cls = bg_file.split("/")[-2]
                if bg_cls not in self.cls_to_allowed_bg:
                    self.cls_to_allowed_bg[bg_cls] = []
                self.cls_to_allowed_bg[bg_cls].append(bg_file)
        if background_combination == "same":
            for cls_code in classes:
                if cls_code not in self.cls_to_allowed_bg or len(self.cls_to_allowed_bg[cls_code]) == 0:
                    self.cls_to_allowed_bg[cls_code] = [backup_backgrounds[cls_code]]
                    logger.warning(f"No background for class {cls_code}, using {backup_backgrounds[cls_code]}")
        # Per-worker zip handles, opened lazily in _wrkr_info (zip mode only).
        self._zf = {}

    @property
    def epoch(self):
        # Current epoch; used to evaluate the orig_img_prob schedule in __getitem__.
        return self._epoch

    @epoch.setter
    def epoch(self, value):
        self._epoch = value

    def __len__(self):
        """Size of the dataset.

        Returns:
            int: number of foregrounds
        """
        return len(self.foregrounds)

    def num_classes(self):
        """Return the number of classes."""
        return len(self.classes)

    def _get_fg(self, idx):
        """Load the raw foreground image at index ``idx`` from this worker's zip handle."""
        worker_id = self._wrkr_info()
        fg_file = self.foregrounds[idx]
        with self._zf[worker_id]["fg"].open(fg_file) as f:
            fg_data = BytesIO(f.read())
        return Image.open(fg_data)

    def _wrkr_info(self):
        """Return this worker's id, opening its zip handles on first use (zip mode only)."""
        worker_id = get_worker_info().id if get_worker_info() else 0
        if worker_id not in self._zf and self._mode == "zip":
            self._zf[worker_id] = {
                "bg": zipfile.ZipFile(f"{self.root}/backgrounds_{'train' if self.train else 'val'}.zip", "r"),
                "fg": zipfile.ZipFile(f"{self.root}/foregrounds_{'train' if self.train else 'val'}.zip", "r"),
            }
        return worker_id

    def __getitem__(self, idx):
        """Get the foreground at index idx and combine it with a (random) background.

        Args:
            idx (int): foreground index

        Returns:
            torch.Tensor, torch.Tensor: image, target
        """
        worker_id = self._wrkr_info()
        fg_file = self.foregrounds[idx]
        trgt_cls = fg_file.split("/")[-2]
        # Possibly return the original (non-recombined) image, depending on the orig_img_prob schedule.
        if (
            (self.orig_img_prob == "linear" and np.random.rand() < self._epoch / self.epochs)
            or (self.orig_img_prob == "revlinear" and np.random.rand() < (self._epoch - self.epochs) / self.epochs)
            or (self.orig_img_prob == "cos" and np.random.rand() > np.cos(np.pi * self._epoch / (2 * self.epochs)))
            or (
                isinstance(self.orig_img_prob, float)
                and self.orig_img_prob > 0.0
                and np.random.rand() < self.orig_img_prob
            )
        ):
            data_key = f"{trgt_cls}/{fg_file.split('/')[-1].split('.')[0]}"
            if isinstance(self.orig_ds, str):
                image_file = os.path.join(self.orig_ds, f"{data_key}.{self._orig_ds_file_type}")
                orig_img = Image.open(image_file).convert("RGB")
            else:
                orig_data = self.orig_ds[self.key_to_orig_idx[data_key]]
                orig_img = orig_data["image"] if isinstance(orig_data, dict) else orig_data[0]
            if self.bg_transform:
                orig_img = self.bg_transform(orig_img)
            if self.join_transform:
                orig_img = self.join_transform(orig_img)
            trgt = self.trgt_map[trgt_cls]
            if self.target_transform:
                trgt = self.target_transform(trgt)
            return orig_img, trgt
        # Load the foreground as RGBA; the alpha channel is the object mask.
        if self._mode == "zip":
            with self._zf[worker_id]["fg"].open(fg_file) as f:
                fg_data = BytesIO(f.read())
            try:
                fg_img = Image.open(fg_data).convert("RGBA")
            except PIL.UnidentifiedImageError as e:
                logger.error(f"Error with idx={idx}, file={self.foregrounds[idx]}")
                raise e
        else:
            # data_key = f"{trgt_cls}/{fg_file.split('/')[-1].split('.')[0]}"
            fg_img = Image.open(
                os.path.join(self.root, "train" if self.train else "val", "foregrounds", fg_file)
            ).convert("RGBA")
        if self.fg_transform:
            fg_img = self.fg_transform(fg_img)
        # Fraction of the foreground canvas actually covered by the object (mean of the alpha mask).
        fg_size_factor = T.ToTensor()(fg_img.split()[-1]).mean().item()
        # Pick the background file according to the combination policy.
        if self.background_combination == "all":
            bg_idx = np.random.randint(len(self.backgrounds))
            bg_file = self.backgrounds[bg_idx]
        elif self.background_combination == "original":
            bg_file = fg_file.replace("foregrounds", "backgrounds").replace("WEBP", "JPEG")
        else:
            bg_idx = np.random.randint(len(self.cls_to_allowed_bg[trgt_cls]))
            bg_file = self.cls_to_allowed_bg[trgt_cls][bg_idx]
        if self._mode == "zip":
            with self._zf[worker_id]["bg"].open(bg_file) as f:
                bg_data = BytesIO(f.read())
            bg_img = Image.open(bg_data).convert("RGB")
        else:
            bg_img = Image.open(
                os.path.join(self.root, "train" if self.train else "val", "backgrounds", bg_file)
            ).convert("RGB")
        if not self.paste_pre_transform:
            # Crop/resize the background first; the foreground is pasted afterwards.
            bg_img = self.bg_transform(bg_img)
        bg_size = bg_img.size
        # choose scale factor, such that relative area is in fg_scale
        bg_area = bg_size[0] * bg_size[1]
        if self.fg_in_nonant is not None:
            # The foreground must fit into one cell of the 3x3 grid.
            bg_area = bg_area / 9
        # logger.info(f"background: size={bg_size} area={bg_area}")
        # logger.info(f"fg_file={fg_file}, fg_bg_ratio_keys={list(self.fg_bg_ratios.keys())[:3]}...")
        orig_fg_ratio = self.fg_bg_ratios[fg_file.replace("foregrounds", "backgrounds").replace("WEBP", "JPEG")]
        bg_fg_ratio = self.fg_bg_ratios[bg_file]
        if self.fg_size_mode == "max":
            goal_fg_ratio_lower = goal_fg_ratio_upper = max(orig_fg_ratio, bg_fg_ratio)
        elif self.fg_size_mode == "min":
            goal_fg_ratio_lower = goal_fg_ratio_upper = min(orig_fg_ratio, bg_fg_ratio)
        elif self.fg_size_mode == "mean":
            goal_fg_ratio_lower = goal_fg_ratio_upper = (orig_fg_ratio + bg_fg_ratio) / 2
        else:
            # range
            goal_fg_ratio_lower = min(orig_fg_ratio, bg_fg_ratio)
            goal_fg_ratio_upper = max(orig_fg_ratio, bg_fg_ratio)
        # logger.info(f"fg_bg_ratio={orig_fg_ratio}")
        # Jitter the target area ratio and correct for how much of the fg canvas the object covers.
        fg_scale = (
            np.random.uniform(
                goal_fg_ratio_lower * (1 - self.fg_scale_jitter), goal_fg_ratio_upper * (1 + self.fg_scale_jitter)
            )
            / fg_size_factor
            * self.size_fact
        )
        # Resize the foreground so its area is fg_scale * bg_area, preserving the aspect ratio.
        goal_shape_y = round(np.sqrt(bg_area * fg_scale * fg_img.size[1] / fg_img.size[0]))
        goal_shape_x = round(np.sqrt(bg_area * fg_scale * fg_img.size[0] / fg_img.size[1]))
        fg_img = fg_img.resize((goal_shape_x, goal_shape_y))
        if fg_img.size[0] > bg_size[0] or fg_img.size[1] > bg_size[1]:
            # random crop to fit
            goal_w, goal_h = (min(fg_img.size[0], bg_size[0]), min(fg_img.size[1], bg_size[1]))
            fg_img = T.RandomCrop((goal_h, goal_w))(fg_img) if self.train else T.CenterCrop((goal_h, goal_w))(fg_img)
        # paste fg on bg
        z1, z2 = (
            (
                np.random.uniform(0, 1, abs(self.fg_bates_n)).mean(),  # bates distribution n=1 => uniform
                np.random.uniform(0, 1, abs(self.fg_bates_n)).mean(),
            )
            if self.fg_bates_n != 0
            else (0.5, 0.5)
        )
        if self.fg_bates_n < 0:
            # Negative n shifts the position distribution's mass towards the edges (see class docstring).
            z1 = z1 + 0.5 - floor(z1 + 0.5)
            z2 = z2 + 0.5 - floor(z2 + 0.5)
        # Valid offset range, optionally letting the fg jut out by rel_jut_out of its own size.
        x_min = -self.rel_jut_out * fg_img.size[0]
        x_max = bg_size[0] - fg_img.size[0] * (1 - self.rel_jut_out)
        y_min = -self.rel_jut_out * fg_img.size[1]
        y_max = bg_size[1] - fg_img.size[1] * (1 - self.rel_jut_out)
        if self.fg_in_nonant is not None and self.fg_in_nonant >= 0:
            # Constrain the offsets to the requested cell of the 3x3 grid.
            x_min = (self.fg_in_nonant % 3) * bg_size[0] / 3
            x_max = ((self.fg_in_nonant % 3) + 1) * bg_size[0] / 3 - fg_img.size[0]
            y_min = (self.fg_in_nonant // 3) * bg_size[1] / 3
            y_max = ((self.fg_in_nonant // 3) + 1) * bg_size[1] / 3 - fg_img.size[1]
        # If the fg is too large for the allowed region, collapse the range to its center.
        if x_min > x_max:
            x_min = x_max = (x_min + x_max) / 2
        if y_min > y_max:
            y_min = y_max = (y_min + y_max) / 2
        offs_x = round(z1 * (x_max - x_min) + x_min)
        offs_y = round(z2 * (y_max - y_min) + y_min)
        paste_mask = fg_img.split()[-1]
        if self.mask_smoothing_sigma > 0.0:
            # Blur the mask edge with a random fraction of mask_smoothing_sigma, then re-sharpen:
            # values <= 128 become fully transparent, the rest ramp back to full opacity.
            sigma = (np.random.rand() * 0.9 + 0.1) * self.mask_smoothing_sigma
            paste_mask = paste_mask.filter(ImageFilter.GaussianBlur(radius=sigma))
            paste_mask = paste_mask.point(lambda p: 2 * p - 255 if p > 128 else 0)
        bg_img.paste(fg_img.convert("RGB"), (offs_x, offs_y), paste_mask)
        bg_img = bg_img.convert("RGB")
        if self.return_fg_masks:
            # Place the paste mask on a background-sized canvas and transform image + mask jointly.
            fg_mask = Image.new("L", bg_size, 0)
            fg_mask.paste(paste_mask, (offs_x, offs_y))
            fg_mask = T.ToTensor()(fg_mask)[0]
            bg_img = T.ToTensor()(bg_img)
            if self.join_transform:
                # img_mask_stack = torch.cat([bg_img, fg_mask.unsqueeze(0)], dim=0)
                # img_mask_stack = self.join_transform(img_mask_stack)
                # bg_img, fg_mask = img_mask_stack[:-1], img_mask_stack[-1]
                bg_img, fg_mask = apply_dense_transforms(bg_img, fg_mask, self.join_transform)
        else:
            bg_img = self.join_transform(bg_img)
        if trgt_cls not in self.trgt_map:
            raise ValueError(f"trgt_cls={trgt_cls} not in trgt_map: {self.trgt_map}")
        trgt = self.trgt_map[trgt_cls]
        if self.target_transform:
            trgt = self.target_transform(trgt)
        if self.return_fg_masks:
            return bg_img, trgt, fg_mask
        return bg_img, trgt

View File

@@ -0,0 +1,151 @@
import argparse
import shutil
import zipfile
from os import listdir, makedirs, path
from random import choice
from datadings.reader import MsgpackReader
from datadings.writer import FileWriter
from tqdm.auto import tqdm
parser = argparse.ArgumentParser()
parser.add_argument("-tiny_imagenet_zip", type=str, required=True, help="Path to the Tiny ImageNet zip file")
parser.add_argument("-output_dir", type=str, required=True, help="Directory to extract the image names to")
parser.add_argument("-in_segment_dir", type=str, required=True, help="Directory that holds the segmented ImageNet")
parser.add_argument(
    "-imagenet_path", type=str, nargs="?", required=True, help="Path to the original ImageNet dataset (datadings)"
)
args = parser.parse_args()

# Collect the Tiny ImageNet image file names per split from the zip listing.
images = {"train": set(), "val": set()}
with zipfile.ZipFile(args.tiny_imagenet_zip, "r") as zip_ref:
    for info in tqdm(zip_ref.infolist(), desc="Gathering Images"):
        if info.filename.endswith(".JPEG"):
            if "/val/" in info.filename:
                images["val"].add(info.filename.split("/")[-1])
            elif "/train/" in info.filename:
                images["train"].add(info.filename.split("/")[-1])
# Persist the per-split image name lists.
with open(path.join(args.output_dir, "tiny_imagenet_train_images.txt"), "w+") as f:
    f.write("\n".join(images["train"]))
with open(path.join(args.output_dir, "tiny_imagenet_val_images.txt"), "w+") as f:
    f.write("\n".join(images["val"]))
print(f"Found {len(images['train'])} training images and {len(images['val'])} validation images")
# Train file names start with the synset id; sanity-check the expected Tiny ImageNet sizes.
classes = {img_name.split("_")[0] for img_name in images["train"]}
classes = sorted(list(classes), key=lambda x: int(x[1:]))
assert len(classes) == 200, f"Expected 200 classes, found {len(classes)}"
assert len(images["train"]) == len(classes) * 500, f"Expected 100000 training images, found {len(images['train'])}"
assert len(images["val"]) == len(classes) * 50, f"Expected 10000 validation images, found {len(images['val'])}"
with open(path.join(args.output_dir, "tiny_imagenet_classes.txt"), "w+") as f:
    f.write("\n".join(classes))
# copy over the relevant images
for split in ["train", "val"]:
    # Per-class quota: Tiny ImageNet has 500 train / 50 val images per synset.
    ipc = 500 if split == "train" else 50
    part = "foregrounds_WEBP"
    with tqdm(total=len(images[split]), desc=f"Copying images for {split} > {part}") as pbar:
        for synset in classes:
            out_dir = path.join(args.output_dir, split, part, synset)
            src_dir = path.join(args.in_segment_dir, split, part, synset)
            makedirs(out_dir, exist_ok=True)
            # Skip classes whose quota was already filled by an earlier run.
            if len(listdir(out_dir)) >= ipc:
                tqdm.write(f"skip {split} > {part} > {synset} with {len(listdir(out_dir))} ims")
                pbar.update(ipc)
                continue
            # First pass: copy every segmented image that belongs to Tiny ImageNet.
            for img in listdir(src_dir):
                orig_name = (
                    img.split(".")[0] + ".JPEG"
                    if split == "train"
                    else f"val_{int(img.split('_')[-1].split('.')[0])}.JPEG"
                )
                if orig_name in images[split]:
                    shutil.copy(path.join(src_dir, img), path.join(out_dir, img))
                    pbar.update(1)
            # Second pass: top up with random extra images until the quota (or the source) is exhausted.
            while len(listdir(out_dir)) < min(ipc, len(listdir(src_dir))):
                # BUGFIX: filter on the segment file names themselves. The original compared the
                # derived ".JPEG" name against the copied segment files, which never matched, so
                # already-copied images could be re-picked (harmless overwrite, but the loop could
                # spin for a long time before randomly hitting a new image).
                already_copied = set(listdir(out_dir))
                candidates = [img for img in listdir(src_dir) if img not in already_copied]
                img = choice(candidates)
                shutil.copy(path.join(src_dir, img), path.join(out_dir, img))
                pbar.update(1)
                # BUGFIX: log the file that was actually copied (previously logged the stale
                # ``orig_name`` left over from the first-pass loop above).
                tqdm.write(f"Extra image: {img} to {split}/{part}/{synset}")
    # copy over the background images corresponding to those foregrounds
    part = "backgrounds_JPEG"
    with tqdm(total=len(images[split]), desc=f"Copying images for {split} > {part}") as pbar:
        for synset in classes:
            out_dir = path.join(args.output_dir, split, part, synset)
            src_dir = path.join(args.in_segment_dir, split, part, synset)
            fg_dir = path.join(args.output_dir, split, "foregrounds_WEBP", synset)
            makedirs(out_dir, exist_ok=True)
            if len(listdir(out_dir)) >= ipc:
                tqdm.write(f"skip {split} > {part} > {synset}")
                pbar.update(ipc)
                continue
            # Every copied foreground gets its matching background (same stem, .JPEG extension).
            for img in listdir(fg_dir):
                bg_name = img.replace(".WEBP", ".JPEG")
                shutil.copy(path.join(src_dir, bg_name), path.join(out_dir, bg_name))
                pbar.update(1)
            assert len(listdir(out_dir)) == len(listdir(fg_dir)), (
                f"Expected {len(listdir(fg_dir))} background"
                f" images, found {len(listdir(out_dir))}"
            )
# write the original dataset to datadings
for part in ["train", "val"]:
    reader = MsgpackReader(path.join(args.imagenet_path, f"{part}.msgpack"))
    with FileWriter(path.join(args.output_dir, f"TinyIN_{part}.msgpack")) as writer:
        for data in tqdm(reader, desc=f"Writing {part} to datadings"):
            key = data["key"].split("/")[-1]
            # Train keys embed the synset id; for val all classes must be searched.
            allowed_synsets = [key.split("_")[0]] if part == "train" else classes
            if part == "train" and allowed_synsets[0] not in classes:
                continue
            # Keep only samples whose foreground was copied above; the matching synset gives the label.
            # NOTE(review): the outer loop keeps scanning remaining synsets after a match (the break only
            # exits the inner loop) — a later synset with the same file stem would overwrite label_synset.
            # Confirm stems are unique across synsets.
            keep_image = False
            label_synset = None
            for synset in allowed_synsets:
                for img in listdir(path.join(args.output_dir, part, "foregrounds_WEBP", synset)):
                    if img.split(".")[0] == key.split(".")[0]:
                        keep_image = True
                        label_synset = synset
                        break
            if not keep_image:
                continue
            # Labels are indices into the sorted Tiny ImageNet class list.
            data["label"] = classes.index(label_synset)
            writer.write(data)

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,48 @@
import argparse
import os
import json
import torch
parser = argparse.ArgumentParser("Script to convert ImageNet trained models to ImageNet-9")
parser.add_argument("-m", "--model", type=str, required=True, help="Model weights (.pt file).")
parser.add_argument(
    "--in_to_in9", type=str, default="/ds-sds/images/ImageNet-9/in_to_in9.json", help="Path to in_to_in9.json"
)
args = parser.parse_args()

checkpoint = torch.load(args.model, map_location="cpu")
model_state = checkpoint["model_state"]
# Classifier-head tensors are identified by their parameter names.
head_keys = [k for k in model_state.keys() if ".head." in k or ".fc." in k]
print("weights that will be modified:", head_keys)
assert len(head_keys) > 0, "no head keys found :("

# in_to_in9 maps each ImageNet class index to an ImageNet-9 index (or -1 for dropped classes).
with open(args.in_to_in9, "r") as f:
    in_to_in9_classes = json.load(f)
print(f"{len([k for k, v in in_to_in9_classes.items() if v == -1])} classes get mapped to -1")
print("map", len(in_to_in9_classes), " classes to", set(in_to_in9_classes.values()))

print("Building conversion matrix")
# Binary (9 x 1000) matrix: row i sums the head outputs of all ImageNet classes mapped to IN-9 class i.
conversion_matrix = torch.zeros((9, 1000))
for in_idx, in9_idx in in_to_in9_classes.items():
    if in9_idx == -1:
        continue
    # JSON keys are strings; convert back to integer class indices.
    in_idx = int(in_idx)
    conversion_matrix[in9_idx, in_idx] = 1
print(f"Conversion matrix ({conversion_matrix.shape}) has {int(conversion_matrix.sum().item())} non-zero values")

# Collapse each head tensor's 1000-class dimension to 9 via the conversion matrix.
for head_key in head_keys:
    print(f"converting {head_key} of shape {model_state[head_key].shape}", end=" ")
    model_state[head_key] = conversion_matrix @ model_state[head_key]
    print(f"\tto {model_state[head_key].shape}")

checkpoint["model_state"] = model_state
checkpoint["args"]["n_classes"] = 9
# Save next to the original weights, inserting a "_to_in9" suffix before the file extension.
save_folder = os.path.dirname(args.model)
orig_model_name = args.model.split(os.sep)[-1]
new_model_name = ".".join(orig_model_name.split(".")[:-1]) + "_to_in9." + orig_model_name.split(".")[-1]
print(f"saving model as {new_model_name} in {save_folder}")
torch.save(checkpoint, os.path.join(save_folder, new_model_name))

View File

@@ -0,0 +1,200 @@
n07695742: pretzel
n03902125: pay-phone, pay-station
n03980874: poncho
n01644900: tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui
n02730930: apron
n02699494: altar
n03201208: dining table, board
n02056570: king penguin, Aptenodytes patagonica
n04099969: rocking chair, rocker
n04366367: suspension bridge
n04067472: reel
n02808440: bathtub, bathing tub, bath, tub
n04540053: volleyball
n02403003: ox
n03100240: convertible
n04562935: water tower
n02788148: bannister, banister, balustrade, balusters, handrail
n02988304: CD player
n02423022: gazelle
n03637318: lampshade, lamp shade
n01774384: black widow, Latrodectus mactans
n01768244: trilobite
n07614500: ice cream, icecream
n04254777: sock
n02085620: Chihuahua
n01443537: goldfish, Carassius auratus
n01629819: European fire salamander, Salamandra salamandra
n02099601: golden retriever
n02321529: sea cucumber, holothurian
n03837869: obelisk
n02002724: black stork, Ciconia nigra
n02841315: binoculars, field glasses, opera glasses
n04560804: water jug
n02364673: guinea pig, Cavia cobaya
n03706229: magnetic compass
n09256479: coral reef
n09332890: lakeside, lakeshore
n03544143: hourglass
n02124075: Egyptian cat
n02948072: candle, taper, wax light
n01950731: sea slug, nudibranch
n02791270: barbershop
n03179701: desk
n02190166: fly
n04275548: spider web, spider's web
n04417672: thatch, thatched roof
n03930313: picket fence, paling
n02236044: mantis, mantid
n03976657: pole
n01774750: tarantula
n04376876: syringe
n04133789: sandal
n02099712: Labrador retriever
n04532670: viaduct
n04487081: trolleybus, trolley coach, trackless trolley
n09428293: seashore, coast, seacoast, sea-coast
n03160309: dam, dike, dyke
n03250847: drumstick
n02843684: birdhouse
n07768694: pomegranate
n03670208: limousine, limo
n03085013: computer keyboard, keypad
n02892201: brass, memorial tablet, plaque
n02233338: cockroach, roach
n03649909: lawn mower, mower
n03388043: fountain
n02917067: bullet train, bullet
n02486410: baboon
n04596742: wok
n03255030: dumbbell
n03937543: pill bottle
n02113799: standard poodle
n03977966: police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria
n02906734: broom
n07920052: espresso
n01698640: American alligator, Alligator mississipiensis
n02123394: Persian cat
n03424325: gasmask, respirator, gas helmet
n02129165: lion, king of beasts, Panthera leo
n04008634: projectile, missile
n03042490: cliff dwelling
n02415577: bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis
n02815834: beaker
n02395406: hog, pig, grunter, squealer, Sus scrofa
n01784675: centipede
n03126707: crane
n04399382: teddy, teddy bear
n07875152: potpie
n03733131: maypole
n02802426: basketball
n03891332: parking meter
n01910747: jellyfish
n03838899: oboe, hautboy, hautbois
n03770439: miniskirt, mini
n02281406: sulphur butterfly, sulfur butterfly
n03970156: plunger, plumber's helper
n09246464: cliff, drop, drop-off
n02206856: bee
n02074367: dugong, Dugong dugon
n03584254: iPod
n04179913: sewing machine
n04328186: stopwatch, stop watch
n07583066: guacamole
n01917289: brain coral
n03447447: gondola
n02823428: beer bottle
n03854065: organ, pipe organ
n02793495: barn
n04285008: sports car, sport car
n02231487: walking stick, walkingstick, stick insect
n04465501: tractor
n02814860: beacon, lighthouse, beacon light, pharos
n02883205: bow tie, bow-tie, bowtie
n02165456: ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle
n04149813: scoreboard
n04023962: punching bag, punch bag, punching ball, punchball
n02226429: grasshopper, hopper
n02279972: monarch, monarch butterfly, milkweed butterfly, Danaus plexippus
n02669723: academic gown, academic robe, judge's robe
n04486054: triumphal arch
n04070727: refrigerator, icebox
n03444034: go-kart
n02666196: abacus
n01945685: slug
n04251144: snorkel
n03617480: kimono
n03599486: jinrikisha, ricksha, rickshaw
n02437312: Arabian camel, dromedary, Camelus dromedarius
n01984695: spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish
n04118538: rugby ball
n01770393: scorpion
n04356056: sunglasses, dark glasses, shades
n03804744: nail
n02132136: brown bear, bruin, Ursus arctos
n03400231: frying pan, frypan, skillet
n03983396: pop bottle, soda bottle
n07734744: mushroom
n02480495: orangutan, orang, orangutang, Pongo pygmaeus
n02410509: bison
n03404251: fur coat
n04456115: torch
n02123045: tabby, tabby cat
n03026506: Christmas stocking
n07715103: cauliflower
n04398044: teapot
n02927161: butcher shop, meat market
n07749582: lemon
n07615774: ice lolly, lolly, lollipop, popsicle
n02795169: barrel, cask
n04532106: vestment
n02837789: bikini, two-piece
n02814533: beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon
n04265275: space heater
n02481823: chimpanzee, chimp, Pan troglodytes
n02509815: lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens
n06596364: comic book
n01983481: American lobster, Northern lobster, Maine lobster, Homarus americanus
n02504458: African elephant, Loxodonta africana
n03014705: chest
n01944390: snail
n04146614: school bus
n01641577: bullfrog, Rana catesbeiana
n07720875: bell pepper
n02999410: chain
n01855672: goose
n02125311: cougar, puma, catamount, mountain lion, painter, panther, Felis concolor
n07753592: banana
n07871810: meat loaf, meatloaf
n04501370: turnstile
n04311004: steel arch bridge
n02977058: cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM
n04074963: remote control, remote
n03662601: lifeboat
n02106662: German shepherd, German shepherd dog, German police dog, alsatian
n03089624: confectionery, confectionary, candy store
n04259630: sombrero
n03393912: freight car
n04597913: wooden spoon
n07711569: mashed potato
n03355925: flagpole, flagstaff
n02963159: cardigan
n07579787: plate
n02950826: cannon
n01882714: koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus
n02094433: Yorkshire terrier
n02909870: bucket, pail
n02058221: albatross, mollymawk
n01742172: boa constrictor, Constrictor constrictor
n09193705: alp
n04371430: swimming trunks, bathing trunks
n07747607: orange
n03814639: neck brace
n04507155: umbrella
n02268443: dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk
n02769748: backpack, back pack, knapsack, packsack, rucksack, haversack
n03763968: military uniform
n07873807: pizza, pizza pie
n03992509: potter's wheel
n03796401: moving van
n12267677: acorn

View File

@@ -0,0 +1,73 @@
# Repeat Augment sampler taken from DeiT: https://github.com/facebookresearch/deit/blob/main/samplers.py
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
import math
import torch
import torch.distributed as dist
class RASampler(torch.utils.data.Sampler):
    """Sampler that restricts data loading to a subset of the dataset for distributed, with repeated augmentation.

    It ensures that different each augmented version of a sample will be visible to a
    different process (GPU)
    Heavily based on torch.utils.data.DistributedSampler
    """

    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, num_repeats: int = 3):
        # Fall back to the process-group world size / rank when not given explicitly.
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        if num_repeats < 1:
            raise ValueError("num_repeats should be greater than 0")
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.num_repeats = num_repeats
        self.epoch = 0
        self.shuffle = shuffle
        # Per-rank share of the repeated index list, padded to divide evenly.
        self.num_samples = int(math.ceil(len(self.dataset) * self.num_repeats / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        # self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas))
        # Yield only a dataset-length worth of indices per epoch, truncated to a multiple of 256.
        self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))

    def __iter__(self):
        dataset_len = len(self.dataset)
        if self.shuffle:
            # deterministically shuffle based on epoch
            gen = torch.Generator()
            gen.manual_seed(self.epoch)
            order = torch.randperm(dataset_len, generator=gen)
        else:
            order = torch.arange(start=0, end=dataset_len)
        # Repeat each index num_repeats times, then pad so the list divides evenly across replicas.
        repeated = torch.repeat_interleave(order, repeats=self.num_repeats, dim=0).tolist()
        shortfall: int = self.total_size - len(repeated)
        if shortfall > 0:
            repeated += repeated[:shortfall]
        assert len(repeated) == self.total_size
        # Strided subsample: repeats of one sample land on different ranks.
        own_indices = repeated[self.rank : self.total_size : self.num_replicas]
        assert len(own_indices) == self.num_samples
        return iter(own_indices[: self.num_selected_samples])

    def __len__(self):
        return self.num_selected_samples

    def set_epoch(self, epoch):
        self.epoch = epoch

    def __str__(self) -> str:
        return (
            f"{type(self).__name__}(num_replicas: {self.num_replicas}, rank: {self.rank}, num_repeats:"
            f" {self.num_repeats}, epoch: {self.epoch}, num_samples: {self.num_samples}, total_size: {self.total_size},"
            f" num_selected_samples: {self.num_selected_samples}, shuffle: {self.shuffle})"
        )

View File

@@ -0,0 +1,381 @@
import argparse
import json
import os
from copy import copy
from loguru import logger
from nltk.corpus import wordnet as wn
class bcolors:
    """ANSI escape sequences for colored/styled terminal output.

    Use as e.g. ``f"{bcolors.FAIL}error{bcolors.ENDC}"``; ``ENDC`` resets
    all attributes back to the terminal default.
    """

    HEADER = "\033[95m"  # bright magenta
    OKBLUE = "\033[94m"  # bright blue
    OKCYAN = "\033[96m"  # bright cyan
    OKGREEN = "\033[92m"  # bright green
    WARNING = "\033[93m"  # bright yellow
    FAIL = "\033[91m"  # bright red
    ENDC = "\033[0m"  # reset
    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"
def _lemmas_str(synset):
return ", ".join([lemma.name() for lemma in synset.lemmas()])
class WNEntry:
    """A single WordNet noun synset stored as one node of a :class:`WNTree`.

    Tracks the synset's identity (name, WordNet offset ``id``, lemmas), its
    tree position (parent/children, depth), whether it is an ImageNet-21k
    class, and a per-node image count used for pruning and statistics.
    Underscore-prefixed attributes are serialized state, not derived values.
    """

    def __init__(
        self,
        name: str,
        id: int,
        lemmas: str,
        parent_id: int,
        depth: int = None,
        in_image_net: bool = False,
        child_ids: list = None,
        in_main_tree: bool = True,
        _n_images: int = 0,
        _description: str = None,
        _name: str = None,
        _pruned: bool = False,
    ):
        """Store all node attributes.

        Args:
            name: synset name, e.g. "dog.n.01".
            id: WordNet offset of the synset; also the key in the tree's node dict.
            parent_id: offset of the hypernym node, or None for the root.
            lemmas: comma-separated lemma names.
            depth: distance from the tree root. (Default value = None)
            in_image_net: whether this synset is an ImageNet-21k class. (Default value = False)
            child_ids: offsets of hyponym nodes. (Default value = None)
            in_main_tree: whether the node is connected to the main root. (Default value = True)
            _n_images: images assigned directly to this node. (Default value = 0)
            _description: cached WordNet definition. (Default value = None)
            _name: cached alternative display name. (Default value = None)
            _pruned: whether the node has been pruned out of the tree. (Default value = False)
        """
        self.name = name
        self.id = id
        self.lemmas = lemmas
        self.parent_id = parent_id
        self.depth = depth
        self.in_image_net = in_image_net
        self.child_ids = child_ids
        self.in_main_tree = in_main_tree
        self._n_images = _n_images
        self._description = _description
        self._name = _name
        self._pruned = _pruned

    def __str__(self, tree=None, accumulate=True):
        """Render this node (and, if a tree is given, its subtree) as an indented text block.

        A green "+" marks ImageNet classes, a red "-" marks all other synsets.
        """
        start_symb = f"{bcolors.OKGREEN}+{bcolors.ENDC}" if self.in_image_net else f"{bcolors.FAIL}-{bcolors.ENDC}"
        # With a tree, show both the direct image count and the subtree sum.
        n_ims = f"{self._n_images} of Σ {self.n_images(tree)}" if accumulate and tree is not None else self._n_images
        if self.child_ids is None or tree is None:
            return f"{start_symb}{self.name} ({self.id}) > {n_ims}"
        else:
            return f"{start_symb}{self.name} ({self.id}) > {n_ims}\n " + "\n ".join(
                ["\n ".join(tree.nodes[child_id].__str__(tree).split("\n")) for child_id in self.child_ids]
            )

    def tree_diff(self, tree_1, tree_2):
        """Render this subtree showing per-node image-count differences between two trees.

        "+" / "-" / "=" marks whether tree_2 has more / fewer / equal direct
        images for the node compared to tree_1. Assumes both trees contain the
        same node ids for this subtree.
        """
        if tree_2[self.id]._n_images > tree_1[self.id]._n_images:
            start_symb = f"{bcolors.OKGREEN}+{bcolors.ENDC}"
        elif tree_2[self.id]._n_images < tree_1[self.id]._n_images:
            start_symb = f"{bcolors.FAIL}-{bcolors.ENDC}"
        else:
            start_symb = f"{bcolors.OKBLUE}={bcolors.ENDC}"
        n_ims = (
            f"{tree_1[self.id]._n_images} + {tree_2[self.id]._n_images - tree_1[self.id]._n_images} of Σ"
            f" {tree_1[self.id].n_images(tree_2)}/{tree_2[self.id].n_images(tree_2)}"
        )
        if self.child_ids is None:
            return f"{start_symb}{self.name} ({self.id}) > {n_ims}"
        return f"{start_symb}{self.name} ({self.id}) > {n_ims}\n " + "\n ".join(
            ["\n ".join(tree_1.nodes[child_id].tree_diff(tree_1, tree_2).split("\n")) for child_id in self.child_ids]
        )

    def prune(self, tree):
        """Prune this node and its subtree, folding its images into the nearest live ancestor.

        Idempotent: a second call returns immediately via ``_pruned``.
        """
        if self._pruned:
            return
        if self.child_ids is not None:
            # NOTE(review): children remove themselves from this very list while
            # we iterate it (via the parent_node.child_ids.remove below) —
            # confirm no child is skipped by the mutation during iteration.
            for child_id in self.child_ids:
                tree[child_id].prune(tree)
        self._pruned = True
        parent_node = tree.nodes[self.parent_id]
        try:
            parent_node.child_ids.remove(self.id)
        except ValueError as e:
            print(
                f"Error removing {self.name} from"
                f" {parent_node.name} ({[tree[cid].name for cid in parent_node.child_ids]}): {e}"
            )
        # Walk up past already-pruned ancestors to the first one still in the tree.
        while parent_node._pruned:
            parent_node = tree.nodes[parent_node.parent_id]
        parent_node._n_images += self._n_images
        self._n_images = 0

    @property
    def description(self):
        """WordNet definition of the synset (lazily fetched and cached)."""
        if not self._description:
            self._description = wn.synset_from_pos_and_offset("n", self.id).definition()
        return self._description

    @property
    def print_name(self):
        """Short display name: the synset name without the ".n.XX" suffix."""
        return self.name.split(".")[0]

    def get_branch(self, tree=None):
        """Return the root-to-node path as a " > "-joined string of print names."""
        if self.parent_id is None or tree is None:
            return self.print_name
        parent = tree.nodes[self.parent_id]
        return parent.get_branch(tree) + " > " + self.print_name

    def get_branch_list(self, tree):
        """Return the root-to-node path as a list of WNEntry objects (root first)."""
        if self.parent_id is None:
            return [self]
        parent = tree.nodes[self.parent_id]
        return parent.get_branch_list(tree) + [self]

    def to_dict(self):
        """Serialize all attributes into a JSON-compatible dict (inverse of from_dict)."""
        return {
            "name": self.name,
            "id": self.id,
            "lemmas": self.lemmas,
            "parent_id": self.parent_id,
            "depth": self.depth,
            "in_image_net": self.in_image_net,
            "child_ids": self.child_ids,
            "in_main_tree": self.in_main_tree,
            "_n_images": self._n_images,
            "_description": self._description,
            "_name": self._name,
            "_pruned": self._pruned,
        }

    @classmethod
    def from_dict(cls, d):
        """Reconstruct a WNEntry from a dict produced by :meth:`to_dict`."""
        return cls(**d)

    def n_images(self, tree=None):
        """Return the image count: direct images plus, if a tree is given, all descendants'."""
        if tree is None or self.child_ids is None or len(self.child_ids) == 0:
            return self._n_images
        return sum([tree.nodes[child_id].n_images(tree) for child_id in self.child_ids]) + self._n_images

    def n_children(self, tree=None):
        """Return the number of children (direct only, or the full descendant count with a tree)."""
        if self.child_ids is None:
            return 0
        if tree is None or len(self.child_ids) == 0:
            return len(self.child_ids)
        return len(self.child_ids) + sum([tree.nodes[child_id].n_children(tree) for child_id in self.child_ids])

    def get_examples(self, tree, n_examples=3):
        """Return up to n_examples representative children as "name (definition)" strings.

        Children are ranked by subtree image count; if no child has images,
        they are ranked by descendant count instead.
        """
        if self.child_ids is None or len(self.child_ids) == 0:
            return ""
        child_images = {child_id: tree.nodes[child_id].n_images(tree) for child_id in self.child_ids}
        max_images = max(child_images.values())
        if max_images == 0:
            # go on number of child nodes
            child_images = {child_id: tree.nodes[child_id].n_children(tree) for child_id in self.child_ids}
        # sorted childids by number of images
        top_children = [
            child_id for child_id, n_images in sorted(child_images.items(), key=lambda x: x[1], reverse=True)
        ]
        top_children = top_children[: min(n_examples, len(top_children))]
        return ", ".join(
            [f"{tree.nodes[child_id].print_name} ({tree.nodes[child_id].description})" for child_id in top_children]
        )
class WNTree:
    """A tree of :class:`WNEntry` nodes built from the WordNet noun hierarchy.

    Nodes are stored in a flat ``{offset: WNEntry}`` dict; tree structure lives
    in each entry's ``parent_id``/``child_ids``. Supports adding synsets (with
    automatic hypernym insertion), pruning by image count, JSON round-tripping,
    subtree extraction, and mapping node ids to dense label indices.
    """

    def __init__(self, root=1740, nodes=None):
        """Create a tree from a root offset or an existing root entry.

        Args:
            root: WordNet offset of the root synset (1740 — presumably
                "entity.n.01", verify against WordNet), or a ready WNEntry.
            nodes: optional pre-built ``{offset: WNEntry}`` dict to adopt.
        """
        if isinstance(root, int):
            root_id = root
            root_synset = wn.synset_from_pos_and_offset("n", root)
            root_node = WNEntry(
                root_synset.name(),
                root_id,
                _lemmas_str(root_synset),
                parent_id=None,
                depth=0,
            )
        else:
            assert isinstance(root, WNEntry)
            root_id = root.id
            root_node = root
        self.root = root_node
        self.nodes = {root_id: self.root} if nodes is None else nodes
        self.parentless = []  # offsets of nodes whose synset has no hypernym
        self.label_index = None  # lazily built sorted list of label-bearing node ids
        self.pruned = set()  # ids removed by the last prune() call

    def to_dict(self):
        """Serialize the whole tree into a JSON-compatible dict (inverse of from_dict)."""
        return {
            "root": self.root.to_dict(),
            "nodes": {node_id: node.to_dict() for node_id, node in self.nodes.items()},
            "parentless": self.parentless,
            "pruned": list(self.pruned),
        }

    def prune(self, min_images):
        """Prune nodes with too few images; return the set of pruned node ids.

        Two passes: first remove nodes whose whole subtree has fewer than
        ``min_images``; then, bottom-up, remove nodes whose own (direct) image
        count is below the threshold, folding their images into ancestors.
        """
        pruned_nodes = set()
        # prune all nodes that have fewer than min_images below them
        for node_id, node in self.nodes.items():
            if node.n_images(self) < min_images:
                pruned_nodes.add(node_id)
                node.prune(self)
        # prune all nodes that have fewer than min_images inside them, after all nodes below have been pruned
        node_stack = [self.root]
        node_idx = 0
        # Breadth-first expansion so node_stack ends up in top-down order.
        while node_idx < len(node_stack):
            node = node_stack[node_idx]
            if node.child_ids is not None:
                for child_id in node.child_ids:
                    child = self.nodes[child_id]
                    node_stack.append(child)
            node_idx += 1
        # now prune the stack from the bottom up
        for node in node_stack[::-1]:
            # only look at images of that class, not of additional children
            if node.n_images() < min_images:
                pruned_nodes.add(node.id)
                node.prune(self)
        self.pruned = pruned_nodes
        return pruned_nodes

    @classmethod
    def from_dict(cls, d):
        """Reconstruct a WNTree from a dict produced by :meth:`to_dict`."""
        tree = cls()
        tree.root = WNEntry.from_dict(d["root"])
        # JSON keys are strings; node ids must be ints again.
        tree.nodes = {int(node_id): WNEntry.from_dict(node) for node_id, node in d["nodes"].items()}
        tree.parentless = d["parentless"]
        if "pruned" in d:
            tree.pruned = set(d["pruned"])
        return tree

    def add_node(self, node_id, in_in=True):
        """Insert the synset with the given offset, recursively adding missing hypernyms.

        Args:
            node_id: WordNet offset of the synset to add.
            in_in: mark the node as an ImageNet class. (Default value = True)
        """
        if node_id in self.nodes:
            # Already present: only (possibly) upgrade its ImageNet flag.
            self.nodes[node_id].in_image_net = in_in or self.nodes[node_id].in_image_net
            return
        synset = wn.synset_from_pos_and_offset("n", node_id)
        # print(f"adding node {synset.name()} with id {node_id}")
        hypernyms = synset.hypernyms()
        if len(hypernyms) == 0:
            parent_id = None
            self.parentless.append(node_id)
            main_tree = False
            print(f"--------- no hypernyms for {synset.name()} ({synset.offset()}) ------------")
        else:
            # Only the first hypernym is used; WordNet's DAG is flattened to a tree.
            parent_id = synset.hypernyms()[0].offset()
            if parent_id not in self.nodes:
                self.add_node(parent_id, in_in=False)
            parent = self.nodes[parent_id]
            if parent.child_ids is None:
                parent.child_ids = []
            parent.child_ids.append(node_id)
            main_tree = parent.in_main_tree
        depth = self.nodes[parent_id].depth + 1 if parent_id is not None else 0
        node = WNEntry(
            synset.name(),
            node_id,
            _lemmas_str(synset),
            parent_id=parent_id,
            in_image_net=in_in,
            depth=depth,
            in_main_tree=main_tree,
        )
        self.nodes[node_id] = node

    def __len__(self):
        """Total number of nodes in the tree (including parentless ones)."""
        return len(self.nodes)

    def image_net_len(self, only_main_tree=False):
        """Number of nodes flagged as ImageNet classes (optionally main tree only)."""
        return sum([node.in_image_net for node in self.nodes.values() if node.in_main_tree or not only_main_tree])

    def max_depth(self, only_main_tree=False):
        """Maximum node depth in the tree (optionally main tree only)."""
        return max([node.depth for node in self.nodes.values() if node.in_main_tree or not only_main_tree])

    def __str__(self):
        """Full textual dump: summary line, main tree, then parentless subtrees."""
        return (
            f"WordNet Tree with {len(self)} nodes, {self.image_net_len()} in ImageNet21k;"
            f" {len(self.parentless)} parentless nodes:\n{self.root.__str__(tree=self)}\nParentless:\n"
            + "\n".join([self.nodes[node_id].__str__(tree=self) for node_id in self.parentless])
        )

    def save(self, path):
        """Write the tree as JSON to ``path``."""
        with open(path, "w") as f:
            json.dump(self.to_dict(), f)

    @classmethod
    def load(cls, path):
        """Load a tree previously written by :meth:`save`."""
        with open(path, "r") as f:
            tree_dict = json.load(f)
        return cls.from_dict(tree_dict)

    def subtree(self, node_id):
        """Return a new WNTree rooted at ``node_id`` (shallow copies, depths re-based), or None."""
        if node_id not in self.nodes:
            return None
        node_queue = [self.nodes[node_id]]
        subtree_ids = set()
        # BFS to collect all descendant ids.
        while len(node_queue) > 0:
            node = node_queue.pop(0)
            subtree_ids.add(node.id)
            if node.child_ids is not None:
                node_queue += [self.nodes[child_id] for child_id in node.child_ids]
        subtree_nodes = {node_id: copy(self.nodes[node_id]) for node_id in subtree_ids}
        subtree_root = subtree_nodes[node_id]
        subtree_root.parent_id = None
        # Shift depths so the new root sits at depth 0.
        depth_diff = subtree_root.depth
        for node in subtree_nodes.values():
            node.depth -= depth_diff
        return WNTree(root=subtree_root, nodes=subtree_nodes)

    def _make_label_index(self, include_merged=False):
        """Build the sorted list of node ids that carry images and are not pruned."""
        self.label_index = sorted(
            [
                node_id
                for node_id, node in self.nodes.items()
                if node.n_images(self if include_merged else None) > 0 and not node._pruned
            ]
        )

    def get_label(self, node_id):
        """Return the dense label index for a node, resolving pruned nodes to their ancestor."""
        if self.label_index is None:
            self._make_label_index()
        while self.nodes[node_id]._pruned:
            node_id = self.nodes[node_id].parent_id
        return self.label_index.index(node_id)

    def n_labels(self):
        """Number of distinct labels (size of the label index)."""
        if self.label_index is None:
            self._make_label_index()
        return len(self.label_index)

    def __contains__(self, item):
        """Membership test for "nXXXXXXXX" strings, int offsets, or WNEntry objects."""
        if isinstance(item, str):
            if item[0] == "n":
                item = int(item[1:])
            else:
                return False
        if isinstance(item, int):
            return item in self.nodes
        if isinstance(item, WNEntry):
            return item.id in self.nodes
        return False

    def __getitem__(self, item):
        """Node lookup by "nXXXXXXXX" string, synset name ("dog.n.01"), int offset, or WNEntry.

        Raises:
            KeyError: if the item cannot be resolved to a node.
        """
        if isinstance(item, str) and item[0].startswith("n"):
            try:
                item = int(item[1:])
            except ValueError:
                pass
        if isinstance(item, str) and ".n." in item:
            # Fall back to a linear search by synset name.
            for node in self.nodes.values():
                if item == node.name:
                    return node
            raise KeyError(f"Item {item} not found in tree")
        if isinstance(item, int):
            return self.nodes[item]
        if isinstance(item, WNEntry):
            return self.nodes[item.id]
        raise KeyError(f"Item {item} not found in tree")

View File

@@ -0,0 +1,834 @@
import json
import os
import sys
from datetime import datetime
from functools import partial
from math import isfinite
from time import time
import torch
import torch.nn as nn
import torch.optim as optim
from loguru import logger
from timm.data import Mixup
from timm.optim import create_optimizer
from timm.scheduler import create_scheduler
from torch import distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from tqdm.auto import tqdm
from metrics import calculate_metrics, per_class_counts
from utils import (
NoScaler,
ScalerGradNormReturn,
SchedulerArgs,
log_formatter,
save_model_state,
)
try:
from apex.optimizers import FusedLAMB # noqa: F401
apex_available = True
except ImportError:
logger.error("Nvidia apex not available")
apex_available = False
try:
from lion_pytorch import Lion
lion_available = True
except ImportError:
logger.error("Lion not available")
lion_available = False
WANDB_AVAILABLE = False
try:
import wandb
WANDB_AVAILABLE = True
except ImportError:
logger.error("wandb not available")
def wandb_available(turn_off=False):
    """Report whether wandb logging is currently available.

    Args:
        turn_off (bool, optional): set wandb to be unavailble manually.

    Returns:
        bool: wandb is available
    """
    global WANDB_AVAILABLE
    if turn_off:
        # Permanently disable wandb for this process.
        WANDB_AVAILABLE = False
    return WANDB_AVAILABLE
# Pre-configure tqdm so every progress bar in this module stays on one fixed line.
tqdm = partial(tqdm, leave=True, position=0)  # noqa: F405
def setup_tracking_and_logging(args, rank, append_model_path=None, log_wandb=True):
    """Set up logging and tracking for an experiment.

    Creates (or reuses) a run folder, configures loguru to stderr and a log
    file, and optionally initializes wandb (API key read from ``.wandb.apikey``).
    In distributed mode the run folder chosen by rank 0 is broadcast to all
    ranks so every process logs into the same directory.

    Args:
        args (DotDict): Parsed command-line arguments
        rank (int): The rank of the current process
        append_model_path (str, optional): Path of an existing model; its directory
            is reused as the run folder, by default None
        log_wandb (bool, optional): Whether to log to wandb, by default True

    Returns:
        str: folder, where all the run data is saved.

    Raises:
        AssertionError: If `dataset` or `model` is `None`, or run_name contains '%'.

    Notes:
        This function sets up logger to stdout and file, as well as MLflow tracking for an experiment.
        For wandb logger, provide .wandb.apikey in the current directory.
    """
    dataset, model, epochs = args.dataset.replace(os.sep, "_").lower(), args.model.replace(os.sep, "_"), args.epochs
    _base_folder = (
        os.path.join(args.results_folder, args.experiment_name, args.task.replace("-", ""), dataset)
        if args.out_dir is None
        else args.out_dir
    )
    run_folder = os.path.join(
        _base_folder,
        f"{args.run_name.replace(os.sep, '_')}_{model}_{datetime.now().strftime('%d.%m.%Y_%H:%M:%S')}",
    )
    assert dataset is not None and model is not None
    if os.name == "nt":
        # Windows forbids some characters that are fine on POSIX paths.
        run_folder = run_folder.replace("@", "_").replace(" ", "_").replace(":", ".")
    if append_model_path is not None:
        # Resume mode: log next to the existing checkpoint instead of a fresh folder.
        run_folder = os.path.dirname(append_model_path)
        if "run_name" not in args or args.run_name is None:
            args.run_name = run_folder.split(os.sep)[-1].split("_")[0]
    elif args.distributed:
        # Rank 0 decides the (timestamped) folder name; broadcast it to all ranks.
        obj_list = [None]
        if rank == 0:
            obj_list[0] = run_folder
        dist.broadcast_object_list(obj_list, src=0)
        run_folder = obj_list[0]
        if rank == 0:
            os.makedirs(run_folder, exist_ok=True)
        dist.barrier()
    elif rank == 0:
        os.makedirs(run_folder, exist_ok=True)
    assert "%" not in args.run_name, f"found '%' in run_name '{args.run_name}'. This messes with string formatting..."
    if args.debug:
        args.log_level = "debug"
    # logger to stdout & file
    log_name = args.task.replace("-", "")
    if args.task not in ["pre-train", "fine-tune", "fine-tune-head"]:
        log_name += f"_{dataset}_{datetime.now().strftime('%d.%m.%Y_%H:%M:%S')}"
    log_file = os.path.join(run_folder, f"{log_name}.log")
    logger.remove()
    logger.configure(extra=dict(run_name=args.run_name, rank=rank, world_size=args.world_size))
    logger.add(sys.stderr, format=log_formatter, enqueue=True, colorize=True, level=args.log_level.upper())
    logger.add(log_file, format=log_formatter, enqueue=True, colorize=True, level=args.log_level.upper())
    logger.info(f"Run folder '{run_folder}'")
    if rank == 0:
        logger.info(f"{args.task.replace('-', '').capitalize()} {model} on {dataset} for {epochs} epochs")
    global WANDB_AVAILABLE
    # wandb is only used when the import succeeded, the caller wants it, an API
    # key file exists, and it is enabled via args.
    WANDB_AVAILABLE = WANDB_AVAILABLE and log_wandb and os.path.isfile(".wandb.apikey") and args.wandb
    if WANDB_AVAILABLE:
        with open(".wandb.apikey", "r") as f:
            __wandb_api_key = f.read().strip()
        wandb.login(key=__wandb_api_key)
        if args.wandb_run_id is not None:
            # Resume an existing wandb run.
            wandb_args = dict(project=args.experiment_name, resume="must", id=args.wandb_run_id)
        else:
            wandb_args = dict(
                project=args.experiment_name,
                name=args.run_name.replace("_", "-").replace(" ", "-"),
                config={"logfile": log_file, **dict(args)},
                job_type=args.task,
                tags=[model, dataset],
                resume="allow",
                id=args.wandb_run_id,
            )
        wandb.init(**wandb_args)
        args["wandb_run_id"] = wandb.run.id
        if rank == 0:
            logger.info(f"wandb run initialized with id {args['wandb_run_id']}.")
    else:
        logger.info(
            f"Not using wandb. (args.wandb={args.wandb}, .wandb.apikey exists={os.path.isfile('.wandb.apikey')},"
            f" function declaration log_wandb={log_wandb})"
        )
    if args.distributed:
        dist.barrier()
    if args.debug:
        torch.autograd.set_detect_anomaly(True)
        logger.warning("torch.autograd anomaly detection enabled. Will slow down model.")
    return run_folder
def setup_model_optim_sched_scaler(model, device, epochs, args, head_only=False):
    """Set up model, optimizer, and scheduler with automatic mixed precision (amp) and distributed data parallel (DDP).

    Order matters: the model is moved to the device and (optionally) frozen
    first, then the optimizer is created, and only afterwards is the model
    wrapped in SyncBatchNorm/DDP and optionally compiled.

    Args:
        model (nn.Module): the loaded model
        device (torch.device): the current device
        epochs (int): total number of epochs to learn for (for scheduler)
        args: further arguments
        head_only (bool, optional): train only the linear head. Default: False

    Returns:
        tuple[nn.Module, optim.Optimizer, optim.lr_scheduler._LRScheduler, ScalerGradNormReturn]: model, optimizer, scheduler, scaler
    """
    model = model.to(device)
    if head_only:
        # Freeze everything, then re-enable the head (both by attribute and by name,
        # to also catch head parameters not reachable via model.head).
        for param in model.parameters():
            param.requires_grad = False
        for param in model.head.parameters():
            param.requires_grad = True
        for name, param in model.named_parameters():
            if "head" in name:
                param.requires_grad = True
            else:
                param.requires_grad = False
        params = model.head.parameters()
    else:
        params = model  # model.named_parameters() use model itself for now and let timm do the work...
    # Fall back through optimizer choices whose packages failed to import.
    if args.opt == "lion" and not lion_available:
        args.opt = "fusedlamb"
        logger.warning("Falling back from lion to fusedlamb")
    if args.opt == "fusedlamb" and not apex_available:
        args.opt = "adamw"
        logger.warning("Falling back from fusedlamb to adamw")
    if args.opt == "lion":
        optimizer = Lion(params, lr=args["lr"], weight_decay=args["weight_decay"])
    else:
        optimizer = create_optimizer(args, params)
    scaler = ScalerGradNormReturn() if args.amp else NoScaler()
    # if args.model_ema:
    #     # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
    #     ema_model = ModelEma(model, decay=args.model_ema_decay, resume='')
    if args.distributed:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = DDP(model, device_ids=[device])
    if args.compile_model:
        model = torch.compile(model)
    # scheduler = optim.lr_scheduler.LambdaLR(optimizer,
    #                                         lr_lambda=scheduler_function_factory(**args))
    sched_args = SchedulerArgs(args.sched, args.epochs, args.min_lr, args.warmup_lr, args.warmup_epochs)
    scheduler, _ = create_scheduler(sched_args, optimizer)
    return model, optimizer, scheduler, scaler
def setup_criteria_mixup(args, dataset=None, **criterion_kwargs):
    """Set up training/validation criteria and the mixup/cutmix augmenter.

    Optionally computes per-class loss weights ("linear", "log", or "sqrt")
    from the dataset's ``images_per_class``, then picks the loss depending on
    whether soft targets (cutmix) or multi-label targets are used.

    Args:
        args: arguments
        dataset (torch.data.Dataset, optional): dataset that implements images_per_class, for class weights (Default value = None)
        **criterion_kwargs: further arguments for the criterion

    Returns:
        tuple[nn.Module, nn.Module, Mixup]: criterion, val_criterion, mixup
            (mixup is None if both mixup and cutmix alphas are 0)

    Raises:
        NotImplementedError: for unsupported loss/multi-label combinations.
    """
    weight = None
    if args.loss_weight != "none":
        if dataset is not None and hasattr(dataset, "images_per_class"):
            ipc = dataset.images_per_class
            total_ims = sum(ipc)
            if args.loss_weight == "linear":
                # Inverse-frequency weights, normalized so they average to 1.
                weight = torch.tensor([total_ims / (ims * args.n_classes) for ims in ipc])
            elif args.loss_weight == "log":
                # Entropy-normalized negative log class probabilities.
                p_c = torch.tensor([ims / total_ims for ims in ipc])
                log_p_c = torch.where(p_c > 0, p_c.log(), torch.zeros_like(p_c))
                entr = -(p_c * log_p_c).sum()
                weight = -log_p_c / entr
            elif args.loss_weight == "sqrt":
                weight = 1 / (p_c.sqrt() * p_c.sqrt().sum())
        else:
            logger.warning("Could not find images_per_class in dataset. Using uniform weights.")
    if args.aug_cutmix or args.multi_label:
        # criterion = SoftTargetCrossEntropy()
        if args.ignore_index >= 0:
            # ignore_index cannot be passed to BCE/soft-target losses; emulate it
            # by zeroing that class's weight instead.
            if weight is None:
                weight = torch.ones(args.n_classes)
            weight[args.ignore_index] = 0
        if args.multi_label:
            if args.loss == "ce":
                criterion = nn.BCEWithLogitsLoss(pos_weight=weight, **criterion_kwargs)
                val_criterion = nn.BCEWithLogitsLoss(pos_weight=weight, **criterion_kwargs)
            else:
                raise NotImplementedError(
                    f"Only BCEWithLogitsLoss (ce) is implemented for multi-label classification, not {args.loss}."
                )
        else:
            if args.loss == "ce":
                loss_cls = nn.CrossEntropyLoss
            elif args.loss == "baikal":
                # NOTE(review): BaikalLoss is not among the visible imports of this
                # module — confirm it is defined/imported elsewhere in the file.
                loss_cls = BaikalLoss
            else:
                raise NotImplementedError(f"'{args.loss}'-loss is not implemented.")
            criterion = loss_cls(weight=weight, **criterion_kwargs)
            val_criterion = loss_cls(weight=weight, **criterion_kwargs)
    else:
        if args.loss == "ce":
            loss_cls = nn.CrossEntropyLoss
        elif args.loss == "baikal":
            loss_cls = BaikalLoss
        else:
            raise NotImplementedError(f"'{args.loss}'-loss is not implemented.")
        # When class weights are used, ignore_index is disabled (set to the
        # default -100) since the ignored class is handled via its weight.
        criterion = loss_cls(
            label_smoothing=args.label_smoothing,
            ignore_index=args.ignore_index if weight is None else -100,
            weight=weight,
            **criterion_kwargs,
        )  # LabelSmoothingCrossEntropy(smoothing=args.label_smoothing)
        # criterion = LabelSmoothingCrossEntropy(smoothing=args.label_smoothing)
        val_criterion = loss_cls(
            label_smoothing=args.label_smoothing,
            ignore_index=args.ignore_index if weight is None else -100,
            weight=weight,
            **criterion_kwargs,
        )  # LabelSmoothingCrossEntropy(smoothing=0.)
    mixup_kwargs = dict(
        mixup_alpha=args.aug_mixup_alpha,
        cutmix_alpha=args.aug_cutmix_alpha,
        label_smoothing=args.label_smoothing,
        num_classes=args.n_classes,
    )
    mixup = Mixup(**mixup_kwargs) if abs(args.aug_cutmix_alpha) + abs(args.aug_mixup_alpha) > 0.0 else None
    return criterion, val_criterion, mixup
def _train(
    model,
    train_loader,
    optimizer,
    rank,
    epochs,
    device,
    mixup,
    criterion,
    world_size,
    scheduler,
    args,
    val_loader,
    val_criterion,
    model_folder,
    scaler,
    do_metrics_calculation=True,
    start_epoch=0,
    show_tqdm=True,
    topk=(1, 5),
    acc_dict_key=None,
    train_dali_server=None,
    val_dali_server=None,
):
    """Run the full training loop: train and validate each epoch, log, and checkpoint.

    Args:
        model: model to train (possibly DDP-wrapped).
        train_loader: training data loader.
        optimizer: optimizer.
        rank: rank of this process; only rank 0 logs metrics and saves checkpoints.
        epochs: total number of epochs.
        device: device to train on.
        mixup: mixup/cutmix augmenter, or None.
        criterion: training loss.
        world_size: number of distributed processes.
        scheduler: learning-rate scheduler (may be None).
        args: further arguments (DotDict).
        val_loader: validation data loader.
        val_criterion: validation loss.
        model_folder: folder to save checkpoints into.
        scaler: AMP grad scaler (ScalerGradNormReturn or NoScaler).
        do_metrics_calculation: compute efficiency metrics after training. (Default value = True)
        start_epoch: epoch to resume from. (Default value = 0)
        show_tqdm: show a progress bar on rank 0. (Default value = True)
        topk: ks for top-k accuracy. (Default value = (1, 5))
        acc_dict_key: format key for accuracy entries, e.g. "acc{}". (Default value = None)
        train_dali_server: optional DALI server started/stopped around each train epoch.
        val_dali_server: optional DALI server for validation.

    Returns:
        dict: evaluation metrics at the end of training
    """
    if acc_dict_key is None:
        acc_dict_key = "acc{}"
    training_start = time()
    # Drop ks that exceed the number of classes.
    topk = tuple(k for k in topk if k <= args.n_classes)
    time_spend_training = time_spend_validating = 0
    current_best_acc = 0.0
    if rank == 0:
        logger.info(f"Dataloader has {len(train_loader)} batches")
        logger.debug("Starting training with the following settings:")
        logger.debug(f"criterion: {criterion}")
        logger.debug(f"train_loader: {train_loader}, sampler: {train_loader.sampler}")
        logger.debug(f"dataset: {train_loader.dataset}")
        logger.debug(f"optimizer: {optimizer}")
        logger.debug(f"device: {device}")
        logger.debug(f"start epoch: {start_epoch}, epochs: {epochs}")
        logger.debug(f"scaler: {scaler}")
        logger.debug(f"max_grad_norm: {args.max_grad_norm}")
        # logger.debug(f"model_ema:\n{model_ema}\n{model_ema.decay}\n{model_ema.device}")
        if mixup:
            logger.debug(
                f"mixup: {mixup}; mixup_alpha: {mixup.mixup_alpha}, cutmix_alpha: {mixup.cutmix_alpha},"
                f" cutmix_minmax: {mixup.cutmix_minmax}, prob: {mixup.mix_prob}, switch_prob: {mixup.switch_prob},"
                f" label_smoothing: {mixup.label_smoothing}, num_classes: {mixup.num_classes}, correct_lam:"
                f" {mixup.correct_lam}, mixup_enabled: {mixup.mixup_enabled}"
            )
        else:
            logger.debug(f"mixup: {mixup}")
    for epoch in range(start_epoch, epochs):
        with logger.contextualize(epoch=str(epoch + 1)):
            if args.distributed:
                # Re-seed the distributed sampler so shuffling differs per epoch.
                train_loader.sampler.set_epoch(epoch)
            set_ep_func = getattr(train_loader.dataset, "set_epoch", None)
            if callable(set_ep_func):
                train_loader.dataset.set_epoch(epoch)
                val_loader.dataset.set_epoch(epoch)
            if train_dali_server:
                train_dali_server.start_thread()
                logger.info("started train dali server")
            epoch_time, epoch_stats = _train_one_epoch(
                model,
                train_loader,
                optimizer,
                rank,
                epoch,
                device,
                mixup,
                criterion,
                scheduler,
                scaler,
                args,
                topk,
                "train/" + acc_dict_key,
                show_tqdm,
            )
            time_spend_training += epoch_time
            if train_dali_server:
                train_dali_server.stop_thread()
            val_time, val_stats = _evaluate(
                model,
                val_loader,
                epoch,
                rank,
                device,
                val_criterion,
                args,
                topk,
                "val/" + acc_dict_key,
                dali_server=val_dali_server,
            )
            time_spend_validating += val_time
            if rank == 0:
                logger.info(f"total_time={time() - training_start}s")
            if rank == 0:
                top1_val_acc = val_stats["val/" + acc_dict_key.format(1)]
                # print metadata for grafana
                metadata = {
                    "epoch": epoch + 1,
                    "progress": (epoch + 1) / args.epochs,
                    **val_stats,
                    **epoch_stats,
                }
                # filter out Nan and infinity values
                metadata = {k: v for k, v in metadata.items() if isfinite(v)}
                print(json.dumps(metadata), flush=True)
                logger.debug(f"printed metadata: {json.dumps(metadata)}")
                if WANDB_AVAILABLE:
                    wandb.log(metadata, step=epoch + 1)
                # saving current state: either a new best model or a periodic save
                if top1_val_acc > current_best_acc or (epoch + 1) % args.save_epochs == 0:
                    reason = "top" if top1_val_acc > current_best_acc else ""  # min(...) will be the top-1 accuracy
                    if reason == "top":
                        current_best_acc = top1_val_acc
                        logger.info(f"found a new best model with acc: {current_best_acc}")
                    kwargs = dict(
                        model_state=model.state_dict(),
                        stats=metadata,
                        optimizer_state=optimizer.state_dict(),
                        additional_reason=reason,
                        regular_save=(epoch + 1) % args.save_epochs == 0,
                    )
                    if scheduler:
                        kwargs["scheduler_state"] = scheduler.state_dict()
                    save_model_state(
                        model_folder, epoch + 1, args, **kwargs, max_interm_ep_states=args.keep_interm_states
                    )
    if rank == 0:
        end_time = time()
        logger.info(
            f"training done: total time={end_time - training_start}, "
            f"time spend training={time_spend_training}, "
            f"time spend validating={time_spend_validating}"
        )
    results = {**val_stats, **epoch_stats, f"val/best_{acc_dict_key.format(1)}": current_best_acc}
    if rank == 0:
        # Always save the final state regardless of accuracy.
        save_model_state(
            model_folder,
            epoch + 1,
            args,
            model_state=model.state_dict(),
            stats=results,
            additional_reason="final",
            regular_save=False,
            max_interm_ep_states=args.keep_interm_states,
        )
    if do_metrics_calculation:
        # Calculate efficiency metrics
        inp = next(iter(train_loader))[0].to(device)
        metrics = calculate_metrics(
            args,
            model,
            rank=rank,
            input=inp,
            device=device,
            did_training=True,
            all_metrics=False,
            world_size=world_size,
            key_start="eval/",
        )
        if rank == 0:
            logger.info(f"Efficiency metrics: {json.dumps(metrics)}")
    return results
def _mask_preds(preds, cls_masks, mask_val=-100):
"""Mask the predictions by the mask.
Args:
preds: model predictions
cls_masks: class masks
mask_val: (Default value = -100)
Returns:
torch.Tensor: masked predictions
"""
if cls_masks is None:
return preds
return torch.where(cls_masks.bool(), mask_val, preds)
def _evaluate(
    model,
    val_loader,
    epoch,
    rank,
    device,
    val_criterion,
    args,
    topk=(1, 5),
    acc_dict_key=None,
    dali_server=None,
):
    """Evaluate the model on the validation set.

    Accumulates per-class correct/incorrect counts over all batches (summed
    across ranks in distributed mode) and derives top-k and mean-per-class
    accuracies from them.

    Args:
        model (nn.Module): the model to evaluate
        val_loader (DataLoader): loader for evaluation data
        epoch (int): the current epoch (for logger & tracking)
        rank (int): this processes rank (don't log n times)
        device (torch.device): device to evaluate on
        val_criterion (nn.Module): validation loss
        args (DotDict): further arguments
        topk (tuple[int], optional): ks for top-k accuracy, by default (1, 5)
        acc_dict_key (str, optional): format key for the accuracy dictionary entries,
            e.g. "val/acc{}"; defaults to "acc{}"
        dali_server (optional): DALI server started/stopped around evaluation, by default None

    Returns:
        tuple[float, dict]: validation wall time in seconds and a dict with the
            top-k accuracies, mean per-class ("m-acc") accuracies, and "val/loss"
    """
    if not acc_dict_key:
        acc_dict_key = "acc{}"
    if dali_server:
        dali_server.start_thread()
    # Drop ks that exceed the number of classes.
    topk = tuple(k for k in topk if k <= args.n_classes)
    model.eval()
    val_loss = 0
    val_accs = {acc_dict_key.format(k): 0.0 for k in topk}
    val_start = time()
    iterator = (
        tqdm(val_loader, total=len(val_loader), desc=f"Validating epoch {epoch + 1}")
        if rank == 0 and args.tqdm
        else val_loader
    )
    # Per k: [n_classes, 2] counts of (correct, incorrect) predictions.
    class_counts = torch.zeros(1 if isinstance(topk, int) else len(topk), args.n_classes, 2)
    for batch_data in iterator:
        xs, ys = batch_data[:2]
        # An optional third element carries per-sample class masks.
        cls_masks = batch_data[2].to(device, non_blocking=True) if len(batch_data) == 3 else None
        if args.debug:
            logger.debug(f"y_max = {ys.max()}, y_min = {ys.min()}, num_classes={args.n_classes}")
        xs, ys = xs.to(device, non_blocking=True), ys.to(device, non_blocking=True)
        with torch.no_grad(), torch.amp.autocast("cuda", enabled=args.eval_amp):
            preds = model(xs)
            preds = _mask_preds(preds, cls_masks)
            if args.multi_label:
                # labels are float for BCELoss
                ys = ys.float()
            val_loss += val_criterion(preds.transpose(1, -1), ys.transpose(1, -1) if len(ys.shape) > 1 else ys).item()
            class_counts += per_class_counts(preds, ys, args.n_classes, topk=topk, ignore_index=args.ignore_index)
    if args.distributed:
        dist.barrier()
    if dali_server:
        dali_server.stop_thread()
    val_end = time()
    if args.distributed:
        # Average the loss and sum the per-class counts across all ranks.
        gather_tensor = torch.Tensor([val_loss]).to(device)
        dist.barrier()
        dist.all_reduce(gather_tensor, op=dist.ReduceOp.AVG)
        gather_tensor = gather_tensor.tolist()
        val_loss = gather_tensor[0]
        class_counts = class_counts.to(device)
        dist.all_reduce(class_counts, op=dist.ReduceOp.SUM)
        class_counts = class_counts.cpu()
    for i, k in enumerate(topk):
        key = acc_dict_key.format(k)
        mkey = key.replace("acc", "m-acc")
        # Overall accuracy: total correct over total predictions.
        val_accs[key] = class_counts[i].sum(dim=0)[0].item() / class_counts[i].sum(dim=0).sum(dim=-1).item()
        # Mean per-class accuracy: average of each class's own accuracy.
        val_accs[mkey] = (class_counts[i, :, 0] / class_counts[i].sum(dim=-1)).mean().item()
    val_accs["val/loss"] = val_loss
    if rank == 0:
        log_s = f"val/time={val_end - val_start}s"
        for key, val in val_accs.items():
            log_s += f", {key}={val:.4f}"
        logger.info(log_s)
    return val_end - val_start, val_accs
def _train_one_epoch(
    model,
    train_loader,
    optimizer,
    rank,
    epoch,
    device,
    mixup,
    criterion,
    scheduler,
    scaler,
    args,
    topk=(1, 5),
    acc_dict_key=None,
    show_tqdm=True,
):
    """Train the model for one epoch.

    Runs the forward/backward pass per batch under AMP autocast, steps the
    optimizer through the scaler (which also returns the gradient norm),
    aborts on non-finite loss, and averages the loss across ranks at the end.

    Args:
        model: model to train.
        train_loader: training data loader.
        optimizer: optimizer.
        rank: rank of this process (only rank 0 shows tqdm/logs).
        epoch: current epoch index (for logging and dataset epoch hooks).
        device: device to train on.
        mixup: mixup/cutmix augmenter, or None.
        criterion: training loss.
        scheduler: learning-rate scheduler (stepped at epoch end; may be None).
        scaler: AMP grad scaler returning the gradient norm.
        args: further arguments (DotDict).
        topk: ks for top-k accuracy. (Default value = (1, 5))
        acc_dict_key: format key for accuracy entries. (Default value = None)
        show_tqdm: show a progress bar on rank 0. (Default value = True)

    Returns:
        tuple[float, dict]: time spent in this epoch and a stats dict
            ("train/lr", "train/loss") — empty dict if
            args.gather_stats_during_training is off.
    """
    if not acc_dict_key:
        acc_dict_key = "acc{}"
    model.train()
    iterator = (
        tqdm(train_loader, total=len(train_loader), desc=f"Training epoch {epoch + 1}")
        if rank == 0 and show_tqdm
        else train_loader
    )
    if not args.amp:
        scaler = NoScaler()
    epoch_loss = 0
    epoch_accs = {}
    epoch_start = time()
    grad_norms = []
    n_iters = 0
    if hasattr(train_loader.dataset, "epoch"):
        train_loader.dataset.epoch = epoch
    for i, batch_data in enumerate(iterator):
        xs, ys = batch_data[:2]
        # An optional third element carries per-sample class masks.
        cls_masks = batch_data[2].to(device, non_blocking=True) if len(batch_data) == 3 else None
        optimizer.zero_grad()
        n_iters += 1
        xs = xs.to(device, non_blocking=True)
        ys = ys.to(device, non_blocking=True)
        if args.debug and i == 0:
            logger.debug(f"y_max = {ys.max()}, y_min = {ys.min()}, num_classes={args.n_classes}")
        if mixup:
            if args.multi_label:
                xs, ys = mixup(xs, ys, cls_masks)
            else:
                xs, ys = mixup(xs, ys)
        if args.debug and i == 0:
            logger.debug(f"input x: {type(xs)}; {xs.shape}, y: {type(ys)}; {ys.shape}")
        with torch.amp.autocast("cuda", enabled=args.amp):
            preds = model(xs)
            preds = _mask_preds(preds, cls_masks)
            if args.multi_label:
                # labels are float for BCELoss
                ys = ys.float()
            # Add the model's internal (auxiliary) loss; under DDP it lives on .module.
            loss = criterion(preds.transpose(1, -1), ys.transpose(1, -1) if len(ys.shape) > 1 else ys) + (
                model.get_internal_loss() if hasattr(model, "get_internal_loss") else model.module.get_internal_loss()
            )
        if not isfinite(loss.item()):
            # Dump diagnostics (NaN locations, grad-norm history) and abort training.
            logger.error(f"Got loss value {loss.item()}. Stopping training.")
            logger.info(f"input has nan: {xs.isnan().any().item()}")
            logger.info(f"target has nan: {ys.isnan().any().item()}")
            logger.info(f"output has nan: {preds.isnan().any().item()}")
            for name, param in model.named_parameters():
                if param.isnan().any().item():
                    logger.error(f"parameter {name} has a nan value")
            if len(grad_norms) > 0:
                grad_norms = torch.Tensor(grad_norms)
                logger.info(
                    f"Gradient norms until now: min={grad_norms.min().item()}, 20th"
                    f" %tile={torch.quantile(grad_norms, .2).item()}, mean={torch.mean(grad_norms)}, 80th"
                    f" %tile={torch.quantile(grad_norms, .8).item()}, max={grad_norms.max()}"
                )
            sys.exit(1)
        # Backward + optimizer step via the scaler; returns the gradient norm.
        iter_grad_norm = scaler(
            loss,
            optimizer,
            parameters=model.parameters(),
            clip_grad=args.max_grad_norm if args.max_grad_norm > 0.0 else None,
        ).cpu()
        if args.gather_stats_during_training and isfinite(iter_grad_norm):
            grad_norms.append(iter_grad_norm)
        # if args.aug_cutmix:
        #     ys = ys.argmax(dim=-1)  # for accuracy with CutMix, just use the argmax for both
        #
        epoch_loss += loss.item()
        # accuracies = accuracy(preds, ys, topk=topk, dict_key=acc_dict_key, ignore_index=args.ignore_index)
        # for key in accuracies:
        #     epoch_accs[key] += accuracies[key]
    if args.distributed:
        dist.barrier()
    epoch_end = time()
    iterations = n_iters
    # epoch_accs = {key: val / iterations for key, val in epoch_accs.items()}
    epoch_loss = epoch_loss / iterations
    grad_norm_avrg = -1
    # Steps whose grad norm was non-finite (skipped by AMP) are not in grad_norms.
    inf_grads = iterations - len(grad_norms)
    if len(grad_norms) > 0 and args.gather_stats_during_training:
        grad_norm_max = max(grad_norms)
        grad_norms = torch.Tensor(grad_norms)
        grad_norm_20 = torch.quantile(grad_norms, 0.2).item()
        grad_norm_80 = torch.quantile(grad_norms, 0.8).item()
        grad_norm_avrg = torch.mean(grad_norms)
    if args.distributed:
        # grad norm is already synchronized
        # gather_tensor = torch.Tensor([epoch_loss, *[epoch_accs[acc_dict_key.format(k)] for k in topk]]).to(device)
        gather_tensor = torch.Tensor([epoch_loss]).to(device)
        dist.barrier()
        dist.all_reduce(gather_tensor, op=dist.ReduceOp.AVG)
        # gather_tensor = (gather_tensor / world_size).tolist()
        epoch_loss = gather_tensor.item()
        # for i, k in enumerate(topk):
        #     epoch_accs[acc_dict_key.format(k)] = gather_tensor[i + 1]
    lr = optimizer.param_groups[0]["lr"]
    epoch_accs["train/lr"] = lr
    epoch_accs["train/loss"] = epoch_loss
    if rank == 0:
        if args.gather_stats_during_training:
            print_s = f"train/time={epoch_end - epoch_start}s"
            logger.info(print_s)
            if len(grad_norms) > 0:
                logger.info(
                    f"grad norm avrg={grad_norm_avrg}, grad norm max={grad_norm_max}, "
                    f"inf grad norm={inf_grads}, grad norm 20%={grad_norm_20}, grad norm 80%={grad_norm_80}"
                )
            else:
                logger.warning(f"inf grad norm={inf_grads}")
                logger.error("100% of update steps with infinite grad norms!")
        else:
            logger.info(f"train/time={epoch_end - epoch_start}s")
    if scheduler:
        # timm schedulers step with the epoch index; LambdaLR steps without it.
        if isinstance(scheduler, optim.lr_scheduler.LambdaLR):
            scheduler.step()
        else:
            scheduler.step(epoch)
    if args.gather_stats_during_training:
        return epoch_end - epoch_start, epoch_accs
    return epoch_end - epoch_start, {}

View File

@@ -0,0 +1,710 @@
"""Module to evaluate trained models."""
import os
from contextlib import nullcontext
from datetime import datetime
from math import sqrt
from time import time
import torch
from loguru import logger
from timm.loss import LabelSmoothingCrossEntropy
from timm.models.resnet import ResNet as TimmResNet
from torch import distributed as dist
from torch.nn import functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from tqdm import tqdm
from engine import (
_evaluate,
setup_criteria_mixup,
setup_model_optim_sched_scaler,
setup_tracking_and_logging,
wandb_available,
)
from load_dataset import prepare_dataset
from metrics import calculate_metrics
from models import load_pretrained
from utils import (
RepeatedDataset,
ddp_cleanup,
ddp_setup,
denormalize,
get_cpu_name,
grad_cam_reshape_transform,
prep_kwargs,
set_filter_warnings,
)
def evaluate_metrics(model, dataset, **kwargs):
    """Evaluate efficiency metrics for a given model.

    Loads a checkpoint, restores identifying arguments from the training run,
    rebuilds the training data pipeline, and delegates the actual measurement
    to ``metrics.calculate_metrics``. Results are logged and, when a wandb run
    id is stored in the checkpoint, also pushed to wandb.

    Args:
        model (str): path to model state .tar
        dataset (str): name of the dataset to evaluate on
        **kwargs: further arguments
    """
    set_filter_warnings()
    # keep the original path around; `model` is re-bound to the module below
    model_path = model
    args = prep_kwargs(kwargs)
    if args.cuda:
        args.distributed, device, world_size, rank, _ = ddp_setup()
        torch.cuda.set_device(device)
    else:
        # CPU-only fallback: single process, no torch.compile
        args.distributed = False
        device = torch.device("cpu")
        rank = 0
        args.compile_model = False
    save_state = torch.load(model_path, map_location="cpu")
    old_args = prep_kwargs(save_state["args"])
    # carry over identifying information from the training run so logging /
    # wandb attach to the same experiment
    args.model = old_args.model
    args.dataset = dataset
    args.run_name = old_args.run_name
    args.experiment_name = old_args.experiment_name
    args.wandb_run_id = old_args.wandb_run_id
    setup_tracking_and_logging(args, rank=rank, append_model_path=model_path, log_wandb=args.wandb_run_id is not None)
    train_loader, args.n_classes, args.ignore_index, args.multi_label, _ = prepare_dataset(dataset, args)
    model, args, old_args, save_state = load_pretrained(model, args, new_dataset_params=True)
    # NOTE(review): item-style access while the rest uses attribute access —
    # presumably old_args is a dict-like namespace; confirm.
    old_args["eval_imsize"] = args.imsize
    args.model = model_name = old_args.model
    args.dataset = dataset
    # NOTE(review): args.epochs is forced to 5 but setup below receives
    # epochs=10 — looks intentional for metric timing, confirm.
    args.epochs = 5
    model, optim, _, scaler = setup_model_optim_sched_scaler(model, device, epochs=10, args=args, head_only=False)
    if rank == 0:
        logger.info(
            f"Evaluate metrics for model {model_name} on {dataset}. "
            f"It was {old_args.task.replace('-','')}d on {old_args.dataset} for {save_state['epoch']} "
            "epochs."
        )
        # logger.info(f"full set of arguments: {args}")
        logger.info(f"full set of training arguments: {old_args}")
        logger.info(f"full set of eval-metrics arguments: {args}")
        logger.info(
            f"evaluating on {device} -> {torch.cuda.get_device_name(device) if device.type != 'cpu' else get_cpu_name()}"
        )
    # the heavy lifting (throughput, FLOPs, memory, ...) happens here
    metrics = calculate_metrics(
        args, model, rank=rank, device=device, optim=optim, scaler=scaler, train_loader=train_loader, key_start="eval/"
    )
    if rank == 0:
        logger.info(f"Metrics: {metrics}")
        if wandb_available():
            import wandb

            wandb.log(metrics)
def evaluate(model, dataset=None, val_dataset=None, **kwargs):
    """Evaluate model accuracy.

    Loads a checkpoint, restores identifying arguments from the training run,
    builds the validation data pipeline, and runs one evaluation pass. On
    rank 0 the resulting statistics are logged and, when a wandb run id is
    stored in the checkpoint, also pushed to wandb.

    Args:
        model (str): path to model state .tar
        dataset (str, optional): name of the dataset to evaluate on (Default value = None)
        val_dataset (str, optional): name of the dataset to evaluate on (Default value = None)
        **kwargs: further arguments

    Note:
        If `val_dataset` is not provided, the model will be evaluated on `dataset`.
    """
    set_filter_warnings()
    # keep the original path around; `model` is re-bound to the module below
    model_path = model
    args = prep_kwargs(kwargs)
    if val_dataset is None:
        val_dataset = dataset
    if dataset is None:
        # mirror the fallback used by the other evaluation entry points
        dataset = val_dataset
    assert val_dataset is not None, "Specify validation dataset (-valds) or dataset (-ds)."
    args.dataset = dataset
    args.val_dataset = val_dataset
    if args.cuda:
        args.distributed, device, world_size, rank, _ = ddp_setup()
        torch.cuda.set_device(device)
    else:
        # CPU-only fallback: single process, no torch.compile
        args.distributed = False
        device = torch.device("cpu")
        world_size = 1
        rank = 0
        args.compile_model = False
    # the configured batch size is global; split it across processes
    args.batch_size = int(args.batch_size / world_size)
    save_state = torch.load(model_path, map_location="cpu")
    old_args = prep_kwargs(save_state["args"])
    # carry over identifying information from the training run so logging /
    # wandb attach to the same experiment
    args.model = old_args.model
    args.dataset = dataset
    args.run_name = old_args.run_name
    args.experiment_name = old_args.experiment_name
    args.wandb_run_id = old_args.wandb_run_id
    run_folder = setup_tracking_and_logging(
        args, rank, append_model_path=model_path, log_wandb=args.wandb_run_id is not None
    )
    val_loader, args.n_classes, args.ignore_index, args.multi_label, dali_server = prepare_dataset(
        val_dataset, args, train=False
    )
    model, args, old_args, save_state = load_pretrained(model, args, new_dataset_params=True)
    model = model.to(device)
    args.model = model_name = old_args.model
    args.dataset = dataset
    if rank == 0:
        logger.info(
            f"Evaluate model {model_name} on {val_dataset}. "
            f"It was pretrained on {old_args.dataset} for {save_state['epoch']} epochs."
        )
    if args.distributed:
        model = DDP(model)
    if args.compile_model:
        model = torch.compile(model)
    # log all devices
    logger.info(
        f"evaluating on {device} -> {torch.cuda.get_device_name(device) if device.type != 'cpu' else get_cpu_name()}"
    )
    if rank == 0:
        logger.info(f"torch version {torch.__version__}")
        logger.info(f"full set of arguments: {args}")
        logger.info(f"full set of old arguments: {old_args}")
    if args.seed:
        torch.manual_seed(args.seed)
    val_criterion = LabelSmoothingCrossEntropy(smoothing=args.label_smoothing)
    if rank == 0:
        logger.info("start evaluation")
        logger.info(f"Run info at: '{run_folder}'")
    # every rank runs the evaluation pass (the call was previously duplicated
    # in both branches of an `if rank == 0` — deduplicated); only rank 0
    # reports the results
    val_time, val_stats = _evaluate(
        model.to(device),
        val_loader,
        epoch=save_state["epoch"] - 1,
        rank=rank,
        device=device,
        val_criterion=val_criterion,
        args=args,
        dali_server=dali_server,
        acc_dict_key=f"eval_{val_dataset}/acc{{}}",
    )
    if rank == 0:
        log_s = f"Evaluation done in {val_time}s"
        for key, val in val_stats.items():
            log_s += f", {key}={val:.4f}"
        logger.info(log_s)
        if wandb_available():
            import wandb

            wandb.log(val_stats)
    ddp_cleanup(args=args, rank=rank)
def evaluate_center_bias(model, dataset=None, val_dataset=None, **kwargs):
    """Evaluate model accuracy in different nonants.

    Re-runs ForNet validation with the foreground forced into each of the
    nine image nonants (plus a nonant=-1 baseline run), five rounds each, and
    derives a scalar center-bias value from the per-nonant mean accuracies.

    Args:
        model (str): path to model state .tar
        dataset (str, optional): name of the dataset to evaluate on (Default value = None)
        val_dataset (str, optional): name of the dataset to evaluate on (Default value = None)
        **kwargs: further arguments

    Note:
        If `val_dataset` is not provided, the model will be evaluated on `dataset`.
    """
    set_filter_warnings()
    # keep the original path around; `model` is re-bound to the module below
    model_path = model
    args = prep_kwargs(kwargs)
    if val_dataset is None:
        val_dataset = dataset
    if dataset is None:
        dataset = val_dataset
    assert dataset is not None, "Specify validation dataset (-valds) or dataset (-ds)."
    args.dataset = dataset
    args.val_dataset = val_dataset
    if args.cuda:
        args.distributed, device, world_size, rank, _ = ddp_setup()
        torch.cuda.set_device(device)
    else:
        # CPU-only fallback: single process, no torch.compile
        args.distributed = False
        device = torch.device("cpu")
        world_size = 1
        rank = 0
        args.compile_model = False
    # the configured batch size is global; split it across processes
    args.batch_size = int(args.batch_size / world_size)
    save_state = torch.load(model_path, map_location="cpu")
    old_args = prep_kwargs(save_state["args"])
    # carry over identifying information from the training run so logging /
    # wandb attach to the same experiment
    args.model = old_args.model
    args.dataset = dataset
    args.run_name = old_args.run_name
    args.experiment_name = old_args.experiment_name
    args.wandb_run_id = old_args.wandb_run_id
    run_folder = setup_tracking_and_logging(
        args, rank, append_model_path=model_path, log_wandb=args.wandb_run_id is not None
    )
    assert "fornet" in val_dataset.lower(), "Only ForNet supported for center bias evaluation."
    # only the dataset metadata is needed here; loaders are rebuilt per nonant
    _, args.n_classes, args.ignore_index, args.multi_label, __ = prepare_dataset(val_dataset, args, train=False)
    model, args, old_args, save_state = load_pretrained(model, args, new_dataset_params=True)
    model = model.to(device)
    args.model = model_name = old_args.model
    args.dataset = dataset
    if rank == 0:
        logger.info(
            f"Evaluate model {model_name} on {val_dataset}. "
            f"It was pretrained on {old_args.dataset} for {save_state['epoch']} epochs."
        )
    if args.distributed:
        model = DDP(model)
    if args.compile_model:
        model = torch.compile(model)
    # log all devices
    logger.info(
        f"evaluating on {device} -> {torch.cuda.get_device_name(device) if device.type != 'cpu' else get_cpu_name()}"
    )
    if rank == 0:
        logger.info(f"torch version {torch.__version__}")
        logger.info(f"full set of arguments: {args}")
        logger.info(f"full set of old arguments: {old_args}")
    if args.seed:
        torch.manual_seed(args.seed)
    val_criterion = LabelSmoothingCrossEntropy(smoothing=args.label_smoothing)
    if rank == 0:
        logger.info("start evaluation")
        logger.info(f"Run info at: '{run_folder}'")
    if rank == 0:
        nonant_accs = []
        # fix: total evaluation time is now accumulated over all rounds
        # (previously only the last round's val_time was reported),
        # consistent with evaluate_size_bias
        val_times = 0
        for nonant in range(-1, 9):
            # fresh loader per nonant so the dataset attribute can be patched
            val_loader, _, __, ___, dali_server = prepare_dataset(val_dataset, args, train=False)
            val_loader.dataset.fg_in_nonant = nonant
            logger.info(f"Evaluate nonant {nonant} for 5 rounds.")
            round_accs = []
            for _ in range(5):
                val_time, val_stats = _evaluate(
                    model.to(device),
                    val_loader,
                    epoch=save_state["epoch"] - 1,
                    rank=rank,
                    device=device,
                    val_criterion=val_criterion,
                    args=args,
                    dali_server=dali_server,
                )
                round_accs.append(val_stats["acc1"])
                val_times += val_time
            nonant_accs.append(sum(round_accs) / len(round_accs))
        log_s = f"Evaluation done in {val_times}s: "
        for nonant, val in enumerate(nonant_accs[1:]):
            log_s += f", nonant {nonant}={val}% acc ({val / nonant_accs[0]} rel acc)"
        # center bias = 1 - (min corner acc + min edge acc) / (2 * center acc).
        # nonant_accs[0] is the nonant=-1 baseline run; indices 1..9 map to
        # nonants 0..8 of the 3x3 grid: corners -> indices 1, 3, 7, 9;
        # edges -> indices 2, 4, 6, 8; center -> index 5.
        center_bias_val = 1 - (
            min([nonant_accs[1], nonant_accs[3], nonant_accs[7], nonant_accs[9]])
            + min([nonant_accs[2], nonant_accs[4], nonant_accs[6], nonant_accs[8]])
        ) / (2 * nonant_accs[5])
        log_s += f", center_bias={center_bias_val:.4f}"
        logger.info(log_s)
        if wandb_available():
            import wandb

            wandb.log({f"eval_{args.val_dataset}/center_bias": center_bias_val})
    else:
        raise NotImplementedError("Center bias evaluation not supported in distributed mode.")
    ddp_cleanup(args=args, rank=rank)
def evaluate_size_bias(model, dataset=None, val_dataset=None, **kwargs):
    """Evaluate model accuracy for differently scaled foregrounds.

    Re-runs ForNet validation with the foreground pasted at a range of fixed
    scale factors (no scale jitter), five rounds each, and logs the mean
    accuracy per scale relative to the 1.0x baseline.

    Args:
        model (str): path to model state .tar
        dataset (str, optional): name of the dataset to evaluate on (Default value = None)
        val_dataset (str, optional): name of the dataset to evaluate on (Default value = None)
        **kwargs: further arguments

    Note:
        If `val_dataset` is not provided, the model will be evaluated on `dataset`.
    """
    set_filter_warnings()
    # keep the original path around; `model` is re-bound to the module below
    model_path = model
    args = prep_kwargs(kwargs)
    if val_dataset is None:
        val_dataset = dataset
    if dataset is None:
        dataset = val_dataset
    # message added for consistency with the other evaluation entry points
    assert val_dataset is not None and dataset is not None, "Specify validation dataset (-valds) or dataset (-ds)."
    args.dataset = dataset
    args.val_dataset = val_dataset
    if args.cuda:
        args.distributed, device, world_size, rank, _ = ddp_setup()
        torch.cuda.set_device(device)
    else:
        # CPU-only fallback: single process, no torch.compile
        args.distributed = False
        device = torch.device("cpu")
        world_size = 1
        rank = 0
        args.compile_model = False
    # the configured batch size is global; split it across processes
    args.batch_size = int(args.batch_size / world_size)
    save_state = torch.load(model_path, map_location="cpu")
    old_args = prep_kwargs(save_state["args"])
    # carry over identifying information from the training run
    args.model = old_args.model
    args.dataset = dataset
    args.run_name = old_args.run_name
    args.experiment_name = old_args.experiment_name
    args.wandb_run_id = old_args.wandb_run_id
    run_folder = setup_tracking_and_logging(args, rank, append_model_path=model_path, log_wandb=False)
    # fix: message previously said "center bias evaluation" (copy-paste)
    assert "fornet" in val_dataset.lower(), "Only ForNet supported for size bias evaluation."
    # only the dataset metadata is needed here; loaders are rebuilt per scale
    _, args.n_classes, args.ignore_index, args.multi_label, __ = prepare_dataset(val_dataset, args, train=False)
    model, args, old_args, save_state = load_pretrained(model, args, new_dataset_params=True)
    model = model.to(device)
    args.model = model_name = old_args.model
    args.dataset = dataset
    if rank == 0:
        logger.info(
            f"Evaluate model {model_name} on {val_dataset}. "
            f"It was pretrained on {old_args.dataset} for {save_state['epoch']} epochs."
        )
    if args.distributed:
        model = DDP(model)
    if args.compile_model:
        model = torch.compile(model)
    # log all devices
    logger.info(
        f"evaluating on {device} -> {torch.cuda.get_device_name(device) if device.type != 'cpu' else get_cpu_name()}"
    )
    if rank == 0:
        logger.info(f"torch version {torch.__version__}")
        logger.info(f"full set of arguments: {args}")
        logger.info(f"full set of old arguments: {old_args}")
    if args.seed:
        torch.manual_seed(args.seed)
    val_criterion = LabelSmoothingCrossEntropy(smoothing=args.label_smoothing)
    if rank == 0:
        logger.info("start evaluation")
        logger.info(f"Run info at: '{run_folder}'")
    # foreground scale factors to test; 1.0 is the unmodified baseline
    sizes = [0.1, 0.2, 0.4, 0.6, 0.8, 1.0, 1.25, 1.5, 1.75, 2.0]
    if rank == 0:
        size_accs = []
        val_times = 0
        for size in sizes:
            # fresh loader per scale so the dataset attributes can be patched
            val_loader, _, __, ___, dali_server = prepare_dataset(val_dataset, args, train=False)
            val_loader.dataset.size_fact = size
            val_loader.dataset.fg_scale_jitter = 0.0
            logger.info(f"Evaluate size factor {size} for 5 rounds.")
            round_accs = []
            for _ in range(5):
                val_time, val_stats = _evaluate(
                    model.to(device),
                    val_loader,
                    epoch=save_state["epoch"] - 1,
                    rank=rank,
                    device=device,
                    val_criterion=val_criterion,
                    args=args,
                    dali_server=dali_server,
                )
                round_accs.append(val_stats["acc1"])
                val_times += val_time
            size_accs.append(sum(round_accs) / len(round_accs))
        log_s = f"Evaluation done in {val_times}s: "
        for size, val in zip(sizes, size_accs):
            log_s += f", rel_size {size}={val}% acc ({val / size_accs[sizes.index(1.0)]} rel acc)"
        logger.info(log_s)
    else:
        # fix: message previously said "Center bias evaluation" (copy-paste)
        raise NotImplementedError("Size bias evaluation not supported in distributed mode.")
    ddp_cleanup(args=args, rank=rank)
def evaluate_attributions(model, dataset=None, val_dataset=None, **kwargs):
    """Evaluate model attributions using captum.

    Measures how strongly different attribution maps (Integrated Gradients,
    GradCAM, GradCAM++, and — for ViT models — the CLS-token attention of the
    last block) concentrate on the foreground mask of ForNet images, relative
    to the foreground's share of the image area. A weight of 1.0 means "no
    preference for the foreground".

    Args:
        model (str): path to model state .tar
        dataset (str, optional): name of the dataset to evaluate on (Default value = None)
        val_dataset (str, optional): name of the dataset to evaluate on (Default value = None)
        **kwargs: further arguments

    Note:
        If `val_dataset` is not provided, the model will be evaluated on `dataset`.
        The `captum` package is required.
    """
    # attribution libraries are imported lazily so the module imports without them
    from captum.attr import IntegratedGradients
    from pytorch_grad_cam import GradCAM, GradCAMPlusPlus
    from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

    set_filter_warnings()
    # keep the original path around; `model` is re-bound to the module below
    model_path = model
    args = prep_kwargs(kwargs)
    if val_dataset is None:
        val_dataset = dataset
    assert val_dataset is not None, "Please set dataset (-ds) or validation dataset (-valds)"
    args.dataset = val_dataset
    args.val_dataset = val_dataset
    if args.cuda:
        args.distributed, device, world_size, rank, _ = ddp_setup()
        torch.cuda.set_device(device)
    else:
        # CPU-only fallback: single process, no torch.compile
        args.distributed = False
        device = torch.device("cpu")
        world_size = 1
        rank = 0
        args.compile_model = False
    # the configured batch size is global; split it across processes
    args.batch_size = int(args.batch_size / world_size)
    save_state = torch.load(model_path, map_location="cpu")
    old_args = prep_kwargs(save_state["args"])
    # carry over identifying information from the training run
    args.model = old_args.model
    args.dataset = val_dataset
    args.run_name = old_args.run_name
    args.experiment_name = old_args.experiment_name
    args.wandb_run_id = old_args.wandb_run_id
    run_folder = setup_tracking_and_logging(
        args, rank, append_model_path=model_path, log_wandb=args.wandb_run_id is not None
    )
    assert "fornet" in val_dataset.lower(), "Only ForNet supported for attribution evaluation."
    val_loader, args.n_classes, args.ignore_index, args.multi_label, dali_server = prepare_dataset(
        val_dataset, args, train=False
    )
    # ForNet can additionally yield the foreground masks used for scoring
    val_loader.dataset.return_fg_masks = True
    model, args, old_args, save_state = load_pretrained(model, args, new_dataset_params=True)
    model = model.to(device)
    args.model = model_name = old_args.model
    args.dataset = dataset
    # assert (
    #     args.imsize == old_args.imsize
    # ), f"Model was trained on {old_args.imsize}x{old_args.imsize} images. Not {args.imsize}x{args.imsize}."
    epoch = save_state["epoch"]
    if rank == 0:
        logger.info(
            f"Evaluate attributions of model {model_name} on {dataset}. "
            f"It was pretrained on {old_args.dataset} for {epoch} epochs."
        )
    if args.distributed:
        model = DDP(model)
    if args.compile_model:
        model = torch.compile(model)
    # log all devices
    logger.info(
        f"evaluating on {device} -> {torch.cuda.get_device_name(device) if device.type != 'cpu' else get_cpu_name()}"
    )
    if rank == 0:
        logger.info(f"torch version {torch.__version__}")
        if args.new_log:
            logger.info(f"full set of arguments: {args}")
            logger.info(f"full set of old arguments: {old_args}")
        else:
            logger.info(f"full set of attribution evaluation arguments: {old_args}")
    if args.seed:
        torch.manual_seed(args.seed)
    if rank == 0:
        logger.info(f"Run info at: '{run_folder}'")
    # wrap in tqdm only on rank 0; NOTE(review): the set_description call near
    # the end of the loop assumes the tqdm wrapper (uses iterator.n) — confirm
    # this path is only run with args.tqdm enabled
    iterator = (
        tqdm(val_loader, total=len(val_loader), desc=f"Validating epoch {epoch}")
        if rank == 0 and args.tqdm
        else val_loader
    )
    if args.debug:
        from matplotlib import pyplot as plt
    # pick GradCAM target layers per architecture; ViTs additionally expose
    # their last-block attention matrix for the attention-importance score
    eval_attn_importance = False
    if isinstance(model, TimmResNet):
        reshape_transform = None
        target_layers = [model.layer4[-1]]
    elif model_name.lower().startswith("vit-"):
        reshape_transform = grad_cam_reshape_transform
        target_layers = [model.blocks[-1].norm1]
        eval_attn_importance = True
        from architectures.vit import _MatrixSaveAttn

        # swap the last attention module for one that stores its attention matrix
        model.blocks[-1].attn = _MatrixSaveAttn.cast(model.blocks[-1].attn)
    elif model_name.lower().startswith("swin_"):
        reshape_transform = grad_cam_reshape_transform
        target_layers = [model.layers[-1].blocks[-1].norm1]
    else:
        raise NotImplementedError(f"Model {model_name} not supported for attribution evaluation.")
    model.eval()
    val_start = time()
    # running sums of the per-batch mean foreground-importance weights
    rel_ig_weights = 0.0
    rel_attn_weights = 0.0
    rel_cam_weights = {"GradCAM": 0.0, "GradCAM++": 0.0}
    if rank == 0:
        logger.info("Start attribution evaluation")
    if dali_server:
        dali_server.start_thread()
    for batch_data in iterator:
        xs, ys, fg_masks = batch_data
        xs, ys, fg_masks = (
            xs.to(device, non_blocking=True),
            ys.to(device, non_blocking=True),
            fg_masks.float().to(device, non_blocking=True),
        )
        # ---- Integrated Gradients ----
        with torch.amp.autocast("cuda") if args.eval_amp else nullcontext():
            model.zero_grad()
            ig = IntegratedGradients(model)
            # we use attention temperature of 10 to make differences more apparent after exp
            attr_ig = (
                ig.attribute(xs, target=ys, baselines=0.0, internal_batch_size=args.batch_size * 4).sum(dim=1) * 10
            )  # B x W x H
            # softmax over all pixels -> per-image attribution distribution
            attr_probs = attr_ig.view(xs.shape[0], -1).softmax(dim=-1).view(xs.shape[0], *xs.shape[2:])
        fg_masks = fg_masks.view(attr_probs.shape)
        # attribution mass on the foreground, normalized by foreground area share;
        # images without a foreground get a neutral weight of 1.0
        fg_attrs = (attr_probs * fg_masks).sum(dim=(-1, -2))
        rel_attr_weight = fg_attrs / fg_masks.mean(dim=(-1, -2))
        rel_attr_weight = torch.where(fg_masks.mean(dim=(-1, -2)) > 0, rel_attr_weight, 1.0)
        if rel_attr_weight.isnan().any():
            logger.error(f"NaNs in rel_attr_weight: {rel_attr_weight}, fg_mask_weights: {fg_masks.mean(dim=(-1, -2))}")
            break
        rel_ig_weights += rel_attr_weight.mean().item()
        # ---- GradCAM / GradCAM++ ----
        cam_targets = [ClassifierOutputTarget(int(trgt)) for trgt in ys.tolist()]
        for method, name in zip([GradCAM, GradCAMPlusPlus], ["GradCAM", "GradCAM++"]):
            with method(model=model, target_layers=target_layers, reshape_transform=reshape_transform) as cam, (
                torch.amp.autocast("cuda") if args.eval_amp else nullcontext()
            ):
                cam_attr = cam(input_tensor=xs, targets=cam_targets)
            cam_attr = torch.from_numpy(cam_attr).to(device)
            # fraction of CAM mass on the foreground, again area-normalized;
            # degenerate cases (empty mask / all-zero CAM) get weight 1.0
            rel_cam_attr = (cam_attr * fg_masks).sum(dim=(-1, -2)) / cam_attr.sum(dim=(-1, -2))
            cam_attr_weight = rel_cam_attr / fg_masks.mean(dim=(-1, -2))
            cam_attr_weight = torch.where(
                (fg_masks.mean(dim=(-1, -2)) > 0) & (cam_attr.sum(dim=(-1, -2)) > 0), cam_attr_weight, 1.0
            )
            rel_cam_weights[name] += cam_attr_weight.mean().item()
            if cam_attr_weight.isnan().any():
                logger.error(
                    f"NaNs in cam_attr_weight ({name}): {cam_attr_weight}, fg_mask_weights:"
                    f" {fg_masks.mean(dim=(-1, -2))}"
                )
                break
        # ---- CLS-token attention (ViT only) ----
        if eval_attn_importance:
            with torch.amp.autocast("cuda") if args.eval_amp else nullcontext():
                pred = model(xs)  # noqa: F841  # forward pass populates attn_mat
            last_attn_mat = model.blocks[-1].attn.attn_mat
            cls_tkn_attn = last_attn_mat[:, :, 0, 1:].mean(dim=1).squeeze(dim=1)  # B x H x 1(CLS Token) X N -> B x N
            B, N = cls_tkn_attn.shape
            # assumes a square patch grid (N = att_HW**2) — holds for standard ViTs
            att_HW = int(sqrt(N))
            cls_tkn_attn = cls_tkn_attn.view(B, 1, att_HW, att_HW)
            # upsample the patch-level attention to pixel resolution
            attn_attr = F.interpolate(
                cls_tkn_attn, size=(xs.shape[-2], xs.shape[-1]), mode="bilinear", align_corners=False
            ).view(B, xs.shape[-2], xs.shape[-1])
            rel_attn_attr = (attn_attr * fg_masks).sum(dim=(-1, -2)) / attn_attr.sum(dim=(-1, -2))
            attn_attr_weight = rel_attn_attr / fg_masks.mean(dim=(-1, -2))
            attn_attr_weight = torch.where(
                (fg_masks.mean(dim=(-1, -2)) > 0) & (attn_attr.sum(dim=(-1, -2)) > 0), attn_attr_weight, 1.0
            )
            rel_attn_weights += attn_attr_weight.mean().item()
        if args.debug:
            # visual sanity check: image / mask / IG / CAM (/ attention) grids
            logger.debug(f"Attribution scores: IG: {rel_attr_weight[:5]}, GradCAM(++): {cam_attr_weight[:5]}")
            num_subplots = 5 if eval_attn_importance else 4
            fig, axs = plt.subplots(num_subplots, 4)
            for plt_i in range(4):
                axs[0][plt_i].imshow(denormalize(xs[plt_i]).permute(1, 2, 0).cpu().numpy())
                axs[1][plt_i].imshow(fg_masks[plt_i].cpu().numpy())
                axs[2][plt_i].imshow(attr_probs[plt_i].cpu().numpy())
                axs[3][plt_i].imshow(cam_attr[plt_i].cpu().numpy())
                if eval_attn_importance:
                    axs[4][plt_i].imshow(attn_attr[plt_i].cpu().numpy())
            plt.show()
        # running averages in the progress bar (iterator.n = batches done so far)
        iterator_desc = (
            f"IG weights: {rel_ig_weights / (iterator.n + 1):.4f}, GradCAM weights:"
            f" {rel_cam_weights['GradCAM'] / (iterator.n + 1):.4f}, GradCAM++ weights:"
            f" {rel_cam_weights['GradCAM++'] / (iterator.n + 1):.4f}"
        )
        if eval_attn_importance:
            iterator_desc += f", Attn weights: {rel_attn_weights / (iterator.n + 1):.4f}"
        iterator.set_description(iterator_desc)
    if args.distributed:
        dist.barrier()
    val_end = time()
    # batch-mean -> dataset-mean
    rel_ig_weights /= len(iterator)
    rel_grad_cam = rel_cam_weights["GradCAM"] / len(iterator)
    rel_grad_cam_pp = rel_cam_weights["GradCAM++"] / len(iterator)
    rel_attn_weights /= len(iterator)
    if dali_server:
        dali_server.stop_thread()
    if args.distributed:
        # average the four scores across processes
        gather_tensor = torch.Tensor([rel_ig_weights, rel_grad_cam, rel_grad_cam_pp, rel_attn_weights]).to(device)
        dist.barrier()
        dist.all_reduce(gather_tensor)
        gather_tensor = (gather_tensor / world_size).tolist()
        rel_ig_weights, rel_grad_cam, rel_grad_cam_pp, rel_attn_weights = gather_tensor
    if rank == 0:
        output_text = (
            f"epoch {epoch}: eval_{args.val_dataset}/rel_ig_weights={rel_ig_weights},"
            f" eval_{args.val_dataset}/rel_grad_cam={rel_grad_cam},"
            f" eval_{args.val_dataset}/rel_grad_cam_pp={rel_grad_cam_pp}"
        )
        if eval_attn_importance:
            output_text += f", eval_{args.val_dataset}/rel_attn_weights={rel_attn_weights}"
        output_text += f", eval_{args.val_dataset}/attribution_eval_time={val_end - val_start}s"
        logger.info(output_text)
        if wandb_available():
            import wandb

            wandb_data = {
                f"eval_{args.val_dataset}/importance_ig": rel_ig_weights,
                f"eval_{args.val_dataset}/importance_grad_cam": rel_grad_cam,
                f"eval_{args.val_dataset}/importance_grad_cam_pp": rel_grad_cam_pp,
            }
            if eval_attn_importance:
                wandb_data[f"eval_{args.val_dataset}/importance_attn"] = rel_attn_weights
            wandb.log(wandb_data)
    ddp_cleanup(args=args, rank=rank)

View File

@@ -0,0 +1,211 @@
import argparse
import os
from functools import partial
from math import log
import numpy as np
import torch
from PIL import Image
from torchvision.transforms import Compose, RandomCrop, Resize, ToTensor
from tqdm.auto import tqdm
try:
from models import load_pretrained
except ModuleNotFoundError:
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from models import load_pretrained
from utils import prep_kwargs
def score_5(  # noqa: D103
    idx,
    bg_probs,
    mean_probs,
    fg_ratio,
    max_idx,
    fg_ratio_max=0.9,
    fg_ratio_min=0.002,
    fg_ratio_exp=0.4,  # learned: 0.487
    idx_exp=0.01,  # learned: 0.043
    bg_probs_exp=0.2,  # learned: 0.24
    opt_fg_ratio=0.1,  # learned: 0.1
    mean_probs_exp=0.2,  # learned: 0.2
    fg_ratio_penalty=1,  # learned: -0.446 ???
):
    # Weighted sum of log-terms scoring an image version: foreground class
    # probability (higher is better), background-only class probability
    # (lower is better), deviation of the foreground ratio from its optimum,
    # and version index (earlier versions slightly preferred). A flat bonus
    # is added while the foreground ratio stays within (min, max) bounds.
    fg_term = log(mean_probs) * mean_probs_exp
    bg_term = log(1 - bg_probs) * bg_probs_exp
    ratio_term = log(1 - abs(fg_ratio - opt_fg_ratio)) * fg_ratio_exp
    idx_term = log(1 - idx / (max_idx + 1)) * idx_exp
    in_bounds_bonus = fg_ratio_penalty if fg_ratio_min < fg_ratio < fg_ratio_max else 0
    return fg_term + bg_term + ratio_term + idx_term + in_bounds_bonus  # TODO: KEEP IT LIKE IT IS
# --- CLI and model setup ----------------------------------------------------
parser = argparse.ArgumentParser(description="Inspect image versions")
parser.add_argument("-f", "--base_folder", type=str, required=True, help="Base folder to inspect")
parser.add_argument(
    "-batch_size",
    type=int,
    default=32,
    help="Batch size for model inspection. Will be 1 background and batch_size - 1 foregrounds",
)
parser.add_argument("-imsize", type=int, default=224, help="Image size")
parser.add_argument(
    "-score_f_weights", choices=["manual", "automatic"], default="manual", help="Score function hyperparameters"
)
parser.add_argument("-auto_fg_pen_val", type=float, default=-0.446, help="Automatic foreground penalty value")
parser.add_argument("-d", "--dataset", choices=["tinyimagenet", "imagenet"], default="tinyimagenet", help="Dataset")
args = parser.parse_args()
# "manual" keeps score_5's hand-tuned defaults; "automatic" swaps in the
# learned weights noted next to score_5's signature
score_f = (
    score_5
    if args.score_f_weights == "manual"
    else partial(
        score_5, **dict(fg_ratio_exp=0.487, idx_exp=0.043, bg_probs_exp=0.24, fg_ratio_penalty=args.auto_fg_pen_val)
    )
)
# dataset layout: <base_folder>/{backgrounds,foregrounds}/<class>/<image files>
bg_folder = os.path.join(args.base_folder, "backgrounds")
fg_folder = os.path.join(args.base_folder, "foregrounds")
classes = os.listdir(fg_folder)
# class folders are WordNet ids ("n01440764", ...); sort by the numeric part
classes = sorted(classes, key=lambda x: int(x[1:]))
assert len(classes) in [200, 1_000], f"Expected 200 or 1_000 classes, got {len(classes)}"
# collect every class-qualified image stem that still has versioned
# foreground files, i.e. names ending in "_v<k>"
total_images = set()
for in_cls in classes:
    cls_images = {
        os.path.join(in_cls, "_".join(img.split(".")[0].split("_")[:-1]))
        for img in os.listdir(os.path.join(fg_folder, in_cls))
        if img.split(".")[0].split("_")[-1].startswith("v")
    }
    total_images.update(cls_images)
total_images = list(total_images)
# base_folder = os.path.join(*(os.path.dirname(__file__).split("/")[:-1]))
# in_cls to print name/lemma
# NOTE(review): loaded unconditionally (even for -d imagenet) and not used in
# the main loop below — presumably kept for interactive inspection; confirm.
with open(os.path.join("data", "misc_dataset_files", "tinyimagenet_synset_names.txt"), "r") as f:
    in_cls_to_name = {line.split(":")[0].strip(): line.split(":")[1].strip() for line in f.readlines() if len(line) > 2}
# models used to judge the quality of each image version
if args.dataset == "tinyimagenet":
    inspection_model_paths = []  # PATHS TO MODEL WEIGHTS (.pt) PRETRAINED ON TinyImageNet
elif args.dataset == "imagenet":
    inspection_model_paths = []  # PATHS TO MODEL WEIGHTS (.pt) PRETRAINED ON ImageNet
else:
    raise ValueError(f"Unknown dataset {args.dataset}")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
inspection_models = [
    load_pretrained(path, prep_kwargs({}), new_dataset_params=False)[0].to(device) for path in inspection_model_paths
]
# resize to imsize x imsize (RandomCrop is then a no-op); ToTensor scales to [0, 1]
img_transform = Compose([Resize((args.imsize, args.imsize)), RandomCrop(args.imsize), ToTensor()])
total_versions = []  # NOTE(review): never appended to below — possibly dead; confirm before removing
# --- main loop: keep the best-scoring version of each image -----------------
# For every image with multiple "_v<k>" versions: score each version with the
# inspection models, delete the losers, and strip the version tag from the
# winner's file names (foreground .WEBP + background .JPEG).
for img_name in tqdm(total_images, desc="Image version computation"):
    in_cls, img_name = img_name.split("/")
    # gather all files in the class folder whose name starts with this stem
    versions = set()
    for img in os.listdir(os.path.join(fg_folder, in_cls)):
        if "_".join(img.split("_")[: len(img_name.split("_"))]) == img_name:
            versions.add(img)
    if len(versions) == 1:
        # a single leftover version: just strip the "_v<k>" tag and move on
        version = list(versions)[0]
        if version.split(".")[0].split("_")[-1].startswith("v"):
            tqdm.write(f"renaming single version image {version} to {img_name}.WEBP")
            os.rename(os.path.join(fg_folder, in_cls, version), os.path.join(fg_folder, in_cls, f"{img_name}.WEBP"))
            os.rename(
                os.path.join(bg_folder, in_cls, version.replace(".WEBP", ".JPEG")),
                os.path.join(bg_folder, in_cls, f"{img_name}.JPEG"),
            )
        continue
    elif len(versions) == 0:
        tqdm.write(f"Image {img_name} has no versions")
        continue
    # deterministic order so v_idx matches the version number
    versions = sorted(list(versions))
    assert all(
        [version.split(".")[0].split("_")[-1].startswith("v") for version in versions]
    ), f"Weird Versions: {versions} for image {img_name}"
    assert len(versions) <= 3, f"Too many versions for image {img_name}: {versions}"
    version_scores = []
    for v_idx, version in enumerate(versions):
        img = Image.open(os.path.join(fg_folder, in_cls, version))
        bg_img = Image.open(os.path.join(bg_folder, in_cls, f"{version.split('.')[0]}.JPEG"))
        # alpha channel of the foreground = segmentation mask
        img_mask = np.array(img.convert("RGBA").split()[-1])
        # foreground share of the original image area (mask values are 0..255)
        fg_ratio = np.sum(img_mask) / (255 * bg_img.size[0] * bg_img.size[1])
        fg_size = img.size
        # paste the foreground onto a gray ramp of batch_size - 1 backgrounds
        # (black ... white) to probe the classifier independent of context
        monochrome_backgrounds = [
            Image.new(
                "RGB",
                (max(args.imsize, fg_size[0]), max(args.imsize, fg_size[1])),
                (255 * i // (args.batch_size - 2), 255 * i // (args.batch_size - 2), 255 * i // (args.batch_size - 2)),
            )
            for i in range(args.batch_size - 1)
        ]
        pasting_error = False
        for mc_bg in monochrome_backgrounds:
            try:
                # center the foreground; the mask (img itself) keeps transparency
                mc_bg.paste(img, ((args.imsize - fg_size[0]) // 2, (args.imsize - fg_size[1]) // 2), img)
            except ValueError as e:
                tqdm.write(f"Image {img_name} could not be pasted into background: {e}")
                pasting_error = True
                break
        # batch = [background-only image] + [foreground on each gray background]
        inp_batch = torch.stack(
            [img_transform(bg_img)] + [img_transform(mc_bg) for mc_bg in monochrome_backgrounds], dim=0
        ).to(device)
        cls_idx = classes.index(in_cls)
        # per-model: prob of the true class for the bare background (index 0)
        # and the mean prob over all foreground-on-gray images
        bg_probs = []
        mean_probs = []
        for model in inspection_models:
            model.eval()
            with torch.no_grad():
                out_probs = model(inp_batch).softmax(dim=-1)[:, cls_idx].cpu().numpy()
            bg_probs.append(out_probs[0])
            mean_probs.append(np.mean(out_probs[1:]))
        # average the lists
        bg_probs = np.mean(bg_probs)
        mean_probs = np.mean(mean_probs)
        # versions that failed to paste get a sentinel score so they never win
        version_score = (
            score_f(
                idx=v_idx,
                bg_probs=float(bg_probs),
                mean_probs=float(mean_probs),
                fg_ratio=float(fg_ratio),
                max_idx=len(versions) - 1,
            )
            if not pasting_error
            else -100
        )
        version_scores.append(version_score)
    assert len(versions) == len(version_scores), f"Expected {len(versions)} scores, got {len(version_scores)}"
    if max(version_scores) > min(version_scores):
        # find best version
        best_version_idx = int(np.argmax(version_scores))
        best_version = versions[best_version_idx]
        # delete all other versions
        for version in versions:
            if version != best_version:
                os.remove(os.path.join(fg_folder, in_cls, version))
                os.remove(os.path.join(bg_folder, in_cls, f"{version.split('.')[0]}.JPEG"))
        # remove version tag in name
        new_version_name = "_".join(best_version.split("_")[:-1]) + "." + best_version.split(".")[-1]
        os.rename(os.path.join(fg_folder, in_cls, best_version), os.path.join(fg_folder, in_cls, new_version_name))
        os.rename(
            os.path.join(bg_folder, in_cls, f"{best_version.split('.')[0]}.JPEG"),
            os.path.join(bg_folder, in_cls, f"{new_version_name.split('.')[0]}.JPEG"),
        )
    else:
        # ties (e.g. all versions failed to paste) are left untouched for manual review
        tqdm.write(f"All versions have the same score for image {img_name}")

View File

@@ -0,0 +1,16 @@
#!/bin/bash
# Launch a Python script as a single-node, single-GPU Slurm job inside a
# pyxis/enroot container. All arguments are forwarded to python3, e.g.:
#   ./run.sh train.py -ds imagenet
# Notes on the flags (comments cannot go inside the continued command):
#   -K            : kill all tasks if any task exits with a non-zero status
#   --partition   : any of the listed GPU partitions may schedule the job
#   --time=1-0    : wall-clock limit of one day
#   --export      : points NLTK at a pre-downloaded data directory
# Replace the PATH/TO/... placeholders with site-specific values.
srun -K \
--container-image=PATH/TO/SLURM/IMAGE \
--container-workdir="$(pwd)" \
--container-mounts=/ALL/IMPORTANT/MOUNTS,"$(pwd)":"$(pwd)" \
--partition=RTXA6000,RTX3090,A100-40GB,A100-80GB,H100,H200 \
--job-name="python" \
--nodes=1 \
--gpus=1 \
--ntasks=1 \
--cpus-per-task=24 \
--mem=64G \
--time=1-0 \
--export="NLTK_DATA=/PATH/TO/NLTK_DATA" \
python3 "$@"

View File

@@ -0,0 +1,346 @@
"""Module to load the datasets, using torch and datadings."""
import contextlib
import os
from functools import partial
import torchvision.transforms as tv_transforms
from datadings.reader import MsgpackReader
from timm.data import create_transform
from torch.utils.data import DataLoader, DistributedSampler, WeightedRandomSampler
from torchvision.datasets import (
CIFAR10,
CIFAR100,
FGVCAircraft,
Flowers102,
Food101,
ImageFolder,
OxfordIIITPet,
StanfordCars,
)
from data.counter_animal import CounterAnimal
from data.data_utils import (
DDDecodeDataset,
ToOneHotSequence,
collate_imnet,
collate_listops,
get_hf_transform,
minimal_augment,
segment_augment,
three_augment,
)
from data.fornet import ForNet
from data.samplers import RASampler
from paths_config import ds_path
def prepare_dataset(dataset_name, args, transform=None, train=True, rank=None):
    """Load a dataset from disk, different formats are used for different datasets.

    Supported datasets include CIFAR10, ImageNet, TinyImageNet, fine-grained sets
    (Stanford-Cars, Oxford-Pet, Flowers102, Food-101, FGVC-Aircraft), ForNet /
    TinyForNet, CounterAnimal, and ImageNet-9.

    Args:
        dataset_name (str): name of the dataset
        args: further arguments
        transform (list[Module] | str, optional): transformations to use on the data; the list gets composed, or give args.augment_strategy (Default value = None)
        train (bool, optional): use the training split (or test/validation split) (Default value = True)
        rank (int, optional): global rank of this process in distributed training (Default value = None)

    Returns:
        DataLoader: data loader for the dataset
        int: number of classes in the dataset
        int: ignore index for the dataset
        bool: whether the dataset is multi-label
        DALIServer | None: DALI proxy server when the DALI augmentation engine built a pipeline, else None
    """
    compose = tv_transforms.Compose
    dali_server = None
    if transform is None:
        # Build the transform pipeline from args.augment_strategy with the selected engine.
        if args.augment_engine == "torchvision":
            if args.augment_strategy == "3-augment":
                transform = three_augment(args, as_list=False, test=not train)
            elif args.augment_strategy == "differentiable-transform":
                from data.distilled_dataset import differentiable_augment

                transform = differentiable_augment(args, as_list=False, test=not train)
            elif args.augment_strategy == "none":
                transform = []
            elif args.augment_strategy == "lm_one_hot":
                transform = [
                    tv_transforms.Grayscale(num_output_channels=1),
                    tv_transforms.ToTensor(),
                    ToOneHotSequence(),
                ]
            elif args.augment_strategy == "segment-augment":
                transform = segment_augment(args, test=not train)
            elif args.augment_strategy == "minimal":
                transform = minimal_augment(args, test=not train)
            elif args.augment_strategy == "deit":
                if train:
                    transform = create_transform(
                        input_size=args.imsize,
                        is_training=True,
                        color_jitter=args.aug_color_jitter_factor,
                        auto_augment=args.auto_augment_strategy,
                        interpolation="bicubic",
                        re_prob=args.aug_random_erase_prob,
                        re_mode=args.aug_random_erase_mode,
                        re_count=args.aug_random_erase_count,
                    )
                else:
                    transform = three_augment(args, test=True)  # only do resize, centercrop, and normalize
            else:
                raise NotImplementedError(
                    f"Augmentation strategy {args.augment_strategy} is not implemented for {args.augment_engine} (yet)."
                )
        elif args.augment_engine == "albumentations":
            from data import album_transf as ATf

            compose = ATf.AlbumTorchCompose
            if args.augment_strategy == "3-augment":
                transform = ATf.three_augment(args, as_list=False, test=not train)
            elif args.augment_strategy == "minimal":
                transform = ATf.minimal_augment(args, test=not train)
            else:
                raise NotImplementedError(
                    f"Augmentation strategy {args.augment_strategy} is not implemented for {args.augment_engine} (yet)."
                )
        elif args.augment_engine == "dali":
            from nvidia.dali.plugin.pytorch.experimental import proxy as dali_proxy

            from data import dali_transf as DTf

            dev_id = int(os.environ.get("LOCAL_RANK", 0))
            if args.augment_strategy == "3-augment":
                pipe = DTf.three_augment(
                    args,
                    test=not train,
                    batch_size=args.batch_size,
                    num_threads=args.num_workers,
                    device_id=dev_id,
                )
            elif args.augment_strategy == "minimal":
                pipe = DTf.minimal_augment(
                    args,
                    test=not train,
                    batch_size=args.batch_size,
                    num_threads=args.num_workers,
                    device_id=dev_id,
                )
            else:
                raise NotImplementedError(
                    f"Augmentation strategy {args.augment_strategy} is not implemented for {args.augment_engine} (yet)."
                )
            dali_server = dali_proxy.DALIServer(pipe)
            transform = dali_server.proxy
    dataset_name_case_sensitive = dataset_name  # keep the original name for AnimalNet folder
    dataset_name = dataset_name.lower()
    ignore_index = -100
    multi_label = False
    if isinstance(transform, list):
        transform = compose(transform)
    # Dataset dispatch: each branch sets `dataset`, `n_classes`, and `collate`.
    if dataset_name == "cifar10":
        dataset = CIFAR10(root=ds_path("cifar"), train=train, download=False, transform=transform)
        n_classes, collate = 10, None
    elif dataset_name == "stanford-cars":
        dataset = StanfordCars(
            root=ds_path("stanford_cars"),
            split="train" if train else "test",
            download=False,
            transform=transform,
        )
        n_classes, collate = 196, None
    elif dataset_name == "oxford-pet":
        dataset = OxfordIIITPet(
            root=ds_path("oxford_pet"),
            split="trainval" if train else "test",
            download=False,
            transform=transform,
        )
        n_classes, collate = 37, None
    elif dataset_name == "flowers102":
        dataset = Flowers102(
            root=ds_path("flowers102"),
            split="train" if train else "test",
            download=False,
            transform=transform,
        )
        n_classes, collate = 102, None
    elif dataset_name == "food-101":
        dataset = Food101(
            root=ds_path("food101"),
            split="train" if train else "test",
            download=False,
            transform=transform,
        )
        n_classes, collate = 101, None
    elif dataset_name == "fgvc-aircraft":
        dataset = FGVCAircraft(
            root=ds_path("aircraft"),
            split="train" if train else "test",
            annotation_level="variant",
            download=False,
            transform=transform,
        )
        n_classes, collate = 100, None
    elif dataset_name == "imagenet":
        dataset = ImageFolder(os.path.join(ds_path("imagenet1k"), "train" if train else "val"), transform=transform)
        n_classes, collate = 1000, None
    elif dataset_name == "tinyimagenet":
        dataset = ImageFolder(os.path.join(ds_path("tinyimagenet"), "train" if train else "val"), transform=transform)
        n_classes, collate = 200, None
    elif dataset_name.startswith("fornet"):
        # Dataset spec encodes options as "fornet/<comb>/<prune>/<size_mode>/<paste>/<orig_prob>/<sigma>".
        ds_def = dataset_name.split("/")
        comb_scheme = ds_def[1] if len(ds_def) > 1 else "same"
        pruning_ratio = 0.8 if len(ds_def) < 3 else float(ds_def[2])
        fg_size_mode = ("range" if train else "max") if len(ds_def) < 4 else ds_def[3]
        paste_pre_transform = True if len(ds_def) < 5 else ds_def[4] in ["y", "t"]
        orig_img_prob = (
            0.0 if len(ds_def) < 6 else (ds_def[5] if ds_def[5] in ["linear", "revlinear", "cos"] else float(ds_def[5]))
        )
        mask_smoothing_sigma = 0.0 if len(ds_def) < 7 else float(ds_def[6])
        assert len(ds_def) < 5 or ds_def[4] in [
            "y",
            "t",
            "n",
            "f",
        ], f"Invalid dataset definition: {ds_def[4]}; paste pre transform must be 'y'/'t' or 'n'/'f'"
        orig_ds = ds_path("imagenet1k")
        dataset = ForNet(
            ds_path("fornet"),
            train=train,
            background_combination=comb_scheme,
            pruning_ratio=pruning_ratio,
            transform=transform,
            fg_transform=(
                None if args.aug_rand_rot == 0 else tv_transforms.RandomRotation(args.aug_rand_rot, expand=True)
            ),
            fg_size_mode=fg_size_mode,
            paste_pre_transform=paste_pre_transform,
            orig_img_prob=orig_img_prob,
            orig_ds=orig_ds,
            mask_smoothing_sigma=mask_smoothing_sigma,
            epochs=args.epochs,
            _album_compose=args.augment_engine == "albumentations",
        )
        n_classes, collate = 1000, None
    elif dataset_name.startswith("tinyfornet"):
        # Same option encoding as "fornet" plus a bates-n slot and an optional version suffix.
        ds_def = dataset_name.split("/")
        comb_scheme = ds_def[1] if len(ds_def) > 1 else "same"
        pruning_ratio = 1.1 if len(ds_def) < 3 else float(ds_def[2])
        fg_size_mode = "range" if len(ds_def) < 4 else ds_def[3]
        fg_bates_n = 1 if len(ds_def) < 5 else int(ds_def[4])
        paste_pre_transform = False if len(ds_def) < 6 else ds_def[5] in ["y", "t"]
        orig_img_prob = (
            0.0 if len(ds_def) < 7 else (ds_def[6] if ds_def[6] in ["linear", "revlinear", "cos"] else float(ds_def[6]))
        )
        mask_smoothing_sigma = 0.0 if len(ds_def) < 8 else float(ds_def[7])
        assert len(ds_def) < 6 or ds_def[5] in [
            "y",
            "t",
            "n",
            "f",
        ], f"Invalid dataset definition: {ds_def[5]}; paste pre transform must be 'y'/'t' or 'n'/'f'"
        assert "-" not in ds_def[0] or len(ds_def[0].split("-")) == 3, f"Invalid dataset definition: {ds_def[0]}"
        version = "" if "-" not in ds_def[0] else f"_v{ds_def[0].split('-')[1]}_f{ds_def[0].split('-')[2]}"
        orig_ds = ds_path("tinyimagenet")
        dataset = ForNet(
            f"{ds_path('tinyimagenet')}{version}",
            train=train,
            background_combination=comb_scheme,
            pruning_ratio=pruning_ratio,
            transform=transform,
            fg_transform=(
                None if args.aug_rand_rot == 0 else tv_transforms.RandomRotation(args.aug_rand_rot, expand=True)
            ),
            fg_size_mode=fg_size_mode,
            fg_bates_n=fg_bates_n,
            paste_pre_transform=paste_pre_transform,
            orig_img_prob=orig_img_prob,
            orig_ds=orig_ds,
            mask_smoothing_sigma=mask_smoothing_sigma,
            epochs=args.epochs,
            _album_compose=args.augment_engine == "albumentations",
        )
        n_classes, collate = 200, None
    elif dataset_name.startswith("counteranimal/"):
        mode = dataset_name.split("/")[1]
        dataset = CounterAnimal(ds_path("counteranimal"), mode=mode, transform=transform, train=train)
        n_classes, collate = 1000, None
    elif dataset_name.startswith("imagenet9/"):
        variant = dataset_name.split("/")[1]
        assert variant in [
            "next",
            "same",
            "rand",
        ], f"ImageNet-9 has possible variants next, same, and rand, but not '{variant}'."
        dataset = ImageFolder(os.path.join(ds_path("imagenet9"), f"mixed_{variant}", "val"), transform=transform)
        n_classes, collate = 9, None
    else:
        raise NotImplementedError(f"Dataset {dataset_name} is not implemented (yet).")
    # Sampler selection: repeated-augment > class-weighted > distributed > none.
    if args.aug_repeated_augment_repeats > 1 and train:
        # use repeated augment sampler from DeiT
        sampler = RASampler(
            dataset,
            num_replicas=args.world_size,
            rank=rank,
            shuffle=args.shuffle,
            num_repeats=args.aug_repeated_augment_repeats,
        )
    elif args.weighted_sampler:
        assert hasattr(
            dataset, "per_sample_weights"
        ), f"Dataset {type(dataset)} should implement per_sample_weights function, but does not."
        sampler = WeightedRandomSampler(dataset.per_sample_weights(), num_samples=len(dataset) // args.world_size)
    elif args.distributed:
        sampler = DistributedSampler(dataset, num_replicas=args.world_size, rank=rank, shuffle=train and args.shuffle)
    else:
        sampler = None
    loader_batch_size = 1 if dataset_name.startswith("listops") else args.batch_size
    loader_kwargs = dict(
        batch_size=loader_batch_size,
        pin_memory=args.pin_memory,
        num_workers=args.num_workers,
        drop_last=train,
        prefetch_factor=args.prefetch_factor if args.num_workers > 0 else None,
        persistent_workers=False,
        collate_fn=collate,
        shuffle=None if sampler else train and args.shuffle,
        sampler=sampler,
    )
    if dali_server is not None:
        # A DALI pipeline was built above, so `dali_proxy` is in scope and the proxy
        # loader feeds batches through the server. Guarding on `dali_server` (rather
        # than on args.augment_engine) avoids a NameError when a transform was passed
        # in explicitly and no DALI pipeline/proxy was ever created.
        data_loader = dali_proxy.DataLoader(dali_server, dataset, **loader_kwargs)
    else:
        data_loader = DataLoader(dataset, **loader_kwargs)
    return data_loader, n_classes, ignore_index, multi_label, dali_server

View File

@@ -0,0 +1,679 @@
#!/usr/bin/env python3
"""Parse args and call the correct script inside slurm container.
Outside the container, on the head-node, create and call the correct srun command.
"""
import argparse
import os
import subprocess
from datetime import datetime
from config import default_kwargs, slurm_defaults
from paths_config import results_folder, slurm_output_folder
_EXPNAMES = ["EfficientCVBench", "test", "recombine_imagenet"]
def base_parser():
    """Create the argument parser with all the choices for the training / evaluation scripts.

    Returns:
        argparse.ArgumentParser: parser with all training/evaluation arguments registered.
    """
    parser = argparse.ArgumentParser("Transformer training and evaluation.")
    # Main
    group = parser.add_argument_group("Main")
    group.add_argument(
        "-t",
        "--task",
        nargs="?",
        choices=[
            "pre-train",
            "fine-tune",
            "fine-tune-head",
            "eval",
            "parser-test",
            "eval-metrics",
            "eval-attr",
            "continue",
            "eval-center-bias",
            "eval-size-bias",
            "load-images",
            "save-images",
        ],
        required=True,
        help="Task to perform.",
    )
    group.add_argument(
        "-m",
        "--model",
        nargs="?",
        type=str,
        required=True,
        help="Model to use. Either model name for a new model or weights and dicts to load for fine-tuning.",
    )
    group.add_argument("-ds", "--dataset", nargs="?", type=str, help="Dataset to train on.")
    group.add_argument(
        "-valds", "--val-dataset", nargs="?", type=str, help="Validation dataset. Defaults to same as training."
    )
    group.add_argument("-ep", "--epochs", nargs="?", type=int, help="Number of epochs to train.")
    group.add_argument(
        "-run",
        "--run-name",
        nargs="?",
        type=str,
        help="A name for the run. If not given, the model name is used instead.",
    )
    group.add_argument(
        "--defaults", nargs="?", choices=["DeiT", "DeiTIII"], default="DeiTIII", help="Default settings to use."
    )
    # Further model parameters
    group = parser.add_argument_group("Further model parameters")
    group.add_argument("--drop-path-rate", nargs="?", type=float, help="Drop path rate for ViT models.")
    group.add_argument("--layer-scale-init-values", nargs="?", type=float, help="LayerScale initial values.")
    group.add_argument("--layer-scale", action=argparse.BooleanOptionalAction, help="Use layer scale?")
    group.add_argument(
        "--qkv-bias",
        action=argparse.BooleanOptionalAction,
        help="Use bias in linear transformation to queries, keys, and values?",
    )
    group.add_argument("--pre-norm", action=argparse.BooleanOptionalAction, help="Use norm first architecture?")
    group.add_argument("--dropout", nargs="?", type=float, help="Model dropout.")
    group.add_argument("-heads", "--num-heads", nargs="?", type=int, help="Number of parallel attention heads.")
    group.add_argument("--input-dim", nargs="?", type=int, help="Dimensionality of text encoding.")
    group.add_argument("--max-seq-len", nargs="?", type=int, help="Maximum sequence length for text data.")
    group.add_argument(
        "--fused-attn",
        action=argparse.BooleanOptionalAction,
        help="Use fused attention (for ViT with Timm's attention only)?",
    )
    # Experiment management
    group = parser.add_argument_group("Experiment management")
    group.add_argument("--seed", nargs="?", type=int, help="Manual RNG seed.")
    group.add_argument(
        "-exp",
        "--experiment-name",
        nargs="?",
        choices=_EXPNAMES,
        help="Name for the experiment. Is used for grouping of runs.",
    )
    group.add_argument(
        "--save-epochs", nargs="?", type=int, help="Number of epochs after which to save the full training state."
    )
    group.add_argument(
        "--keep-interm-states",
        nargs="?",
        type=int,
        help="Number of intermediate states to keep. All others (earlier ones) will be deleted automatically.",
    )
    group.add_argument(
        "--custom-dataset-path", nargs="?", type=str, help="Overwrite the path to any dataset to this path."
    )
    group.add_argument(
        "--results-folder",
        nargs="?",
        default=results_folder,
        type=str,
        help="Folder to put script results (mlflow data, models, etc.).",
    )
    group.add_argument(
        "--gather-stats-during-training",
        action=argparse.BooleanOptionalAction,
        help="Gather training statistics from all GPUs?",
    )
    group.add_argument("--tqdm", action=argparse.BooleanOptionalAction, help="Show tqdm for every epoch?")
    group.add_argument(
        "--debug", action=argparse.BooleanOptionalAction, help="Debug mode: lots of intermediate prints."
    )
    group.add_argument("--wandb", action=argparse.BooleanOptionalAction, help="Use external logging via Wandb?")
    group.add_argument("--log-level", choices=["info", "debug"], help="Log level", metavar="LEVEL")
    group.add_argument("-out", "--out-dir", type=str, help="Output directory for additional outputs.")
    # Speedup
    group = parser.add_argument_group("Speedup")
    group.add_argument("--amp", action=argparse.BooleanOptionalAction, help="Use automatic mixed precision?")
    group.add_argument(
        "--eval-amp", action=argparse.BooleanOptionalAction, help="Use automatic mixed precision during evaluation?"
    )
    group.add_argument("--compile-model", action=argparse.BooleanOptionalAction, help="Use torch.compile?")
    group.add_argument("--cuda", action=argparse.BooleanOptionalAction, help="Use cuda?")
    # Data loading
    group = parser.add_argument_group("Data loading")
    group.add_argument("-bs", "--batch-size", nargs="?", type=int, help="Batch size over all graphics cards (together).")
    group.add_argument("--num-workers", nargs="?", type=int, help="Number of dataloader worker threads. Should be >0.")
    group.add_argument(
        "--pin-memory", action=argparse.BooleanOptionalAction, help="Use pin_memory of torch Dataloader?"
    )
    group.add_argument(
        "--prefetch-factor",
        nargs="?",
        type=int,
        help="Prefetch factor for dataloader workers (how many batches to fetch)",
    )
    group.add_argument("--shuffle", action=argparse.BooleanOptionalAction, help="Shuffle the training data?")
    group.add_argument(
        "--weighted-sampler",
        action=argparse.BooleanOptionalAction,
        help="Use a class-weighted sampler to sample evenly from all classes (train and val)?",
    )
    group.add_argument("--ipc", type=int, help="How many images per class to load and save.")
    # Optimizer
    group = parser.add_argument_group("Optimizer")
    group.add_argument("--opt", nargs="?", type=str, help="Optimizer to use.")
    group.add_argument("--weight-decay", nargs="?", type=float, help="Weight decay factor for use in optimizer.")
    group.add_argument("-lr", "--lr", nargs="?", type=float, help="Initial learning rate.")
    group.add_argument(
        "--max-grad-norm", nargs="?", type=float, help="Maximum norm for the gradients (used for cutoff)."
    )
    group.add_argument("--warmup-epochs", nargs="?", type=int, help="Number of epochs of linear warmup.")
    group.add_argument("--label-smoothing", nargs="?", type=float, help="Label smoothing factor.")
    group.add_argument("--loss", nargs="?", choices=["ce", "baikal"], type=str, help="Loss function to use.")
    group.add_argument(
        "--loss-weight", nargs="?", type=str, choices=["none", "linear", "log", "sqrt"], help="Per class loss weight."
    )
    group.add_argument("--sched", nargs="?", choices=["cosine", "const"], help="Learning rate schedule.")
    group.add_argument("--min-lr", nargs="?", type=float, help="Minimum learning rate to be hit by scheduler.")
    group.add_argument("--warmup-lr", nargs="?", type=float, help="Warmup learning rate.")
    group.add_argument("--warmup-sched", nargs="?", choices=["linear", "const"], help="Schedule for warmup")
    group.add_argument(
        "--opt-eps", nargs="?", type=float, help="Epsilon value added in the optimizer to stabilize training."
    )
    group.add_argument("--momentum", nargs="?", type=float, help="Optimizer momentum.")
    # Data augmentation
    group = parser.add_argument_group("Data augmentation")
    group.add_argument("--augment-strategy", nargs="?", type=str, help="Data augmentation strategy.")
    group.add_argument("--aug-rand-rot", nargs="?", type=int, help="Random rotation limit.")
    group.add_argument(
        "--aug-flip", action=argparse.BooleanOptionalAction, help="Use data augmentation: horizontal flip?"
    )
    group.add_argument(
        "--aug-crop",
        action=argparse.BooleanOptionalAction,
        help="Use data augmentation: cropping. This may break the script?",
    )
    group.add_argument("--aug-resize", action=argparse.BooleanOptionalAction, help="Use data augmentation: resize?")
    group.add_argument(
        "--aug-grayscale", action=argparse.BooleanOptionalAction, help="Use data augmentation: grayscale?"
    )
    group.add_argument("--aug-solarize", action=argparse.BooleanOptionalAction, help="Use data augmentation: solarize?")
    group.add_argument(
        "--aug-gauss-blur", action=argparse.BooleanOptionalAction, help="Use data augmentation: gaussian blur?"
    )
    group.add_argument(
        "--aug-cutmix-alpha",
        type=float,
        help="Alpha value for using CutMix. CutMix is active when aug_cutmix_alpha > 0.",
    )
    group.add_argument(
        "--aug-mixup-alpha", type=float, help="Alpha value for using Mixup. Mixup is active when aug_mixup_alpha > 0."
    )
    group.add_argument(
        "--aug-color-jitter-factor",
        nargs="?",
        type=float,
        help="Factor to use for the data augmentation: color jitter.",
    )
    group.add_argument(
        "--aug-normalize", action=argparse.BooleanOptionalAction, help="Use data augmentation: Normalization?"
    )
    group.add_argument(
        "--aug-repeated-augment-repeats",
        type=int,
        help="Number of image repeats with repeat-augment from DeiT. 1 is not using repeat-augment.",
    )
    group.add_argument(
        "--aug-random-erase-prob", type=float, help="For DeiT augment: Probability of RandomErase augmentation."
    )
    group.add_argument("--auto-augment-strategy", type=str, help="For DeiT augment: AutoAugment Policy to use.")
    group.add_argument("--imsize", nargs="?", type=int, help="Image size given to the model -> imsize x imsize.")
    group.add_argument(
        "--augment-engine",
        nargs="?",
        choices=["torchvision", "albumentations", "dali"],
        help="Which data augmentation engine to use.",
    )
    return parser
def partition_choices():
    """Automatically create a list of all possible slurm partitions.

    Parses the first column of ``sinfo`` output. When ``sinfo`` is unavailable or
    yields no useful lines, falls back to the configured default partitions.

    Returns:
        list[str]: partition names; the trailing ``*`` that marks the cluster's
        default partition is stripped.
    """
    # One entry per unique first token of each sinfo line; empty when sinfo is missing.
    potential = {line.split(" ")[0] for line in os.popen("sinfo")}
    if len(potential) <= 2:
        # sinfo failed or reported nothing beyond the header -> use configured defaults.
        return slurm_defaults["partition"]
    # Drop the header row and strip only a *trailing* '*' (the original sliced off the
    # last character whenever '*' appeared anywhere, which could corrupt odd names).
    return [p.rstrip("*") for p in potential if p != "PARTITION"]
def slurm_parser(parser=None):
    """Add srun arguments to the given parser.

    Args:
        parser (argparse.ArgumentParser, optional): base parser to extend; default is parser from *base_parser*

    Returns:
        parser (argparse.ArgumentParser): extended parser
    """
    if parser is None:
        parser = base_parser()
    group = parser.add_argument_group("Slurm arguments")
    group.add_argument(
        "--partition",
        nargs="*",
        default=slurm_defaults["partition"],
        choices=partition_choices(),
        help="Slurm partition to use",
    )
    group.add_argument(
        "--container-image",
        nargs="?",
        default=slurm_defaults["container_image"],
        type=str,
        help="Path to slurm container image (.sqsh)",
    )
    group.add_argument(
        "--container-workdir",
        nargs="?",
        default=slurm_defaults["container_workdir"],
        type=str,
        help="Working directory in container",
    )
    group.add_argument(
        "--container-mounts",
        nargs="?",
        default=slurm_defaults["container_mounts"],
        type=str,
        help="All slurm mounts separated by ','.",
    )
    group.add_argument(
        "--job-name",
        nargs="?",
        default=slurm_defaults["job_name"],
        type=str,
        help="Slurm job name. Will default to '<model> <task> <dataset>'.",
    )
    group.add_argument(
        "--nodes", nargs="?", default=slurm_defaults["nodes"], type=int, help="Number of cluster nodes to use."
    )
    group.add_argument(
        "--ntasks", nargs="?", default=slurm_defaults["ntasks"], type=int, help="Number of GPUs to use for the job."
    )
    group.add_argument("--gpus", action=argparse.BooleanOptionalAction, default=True, help="Use gpus for this job?")
    group.add_argument(
        "-cpus",
        "--cpus-per-task",
        "--cpus-per-gpu",
        nargs="?",
        default=slurm_defaults["cpus_per_task"],
        type=int,
        help="Number of CPUs per task/GPU.",
    )
    group.add_argument(
        "-mem",
        "--mem-per-gpu",
        "--mem-per-task",
        nargs="?",
        default=slurm_defaults["mem_per_gpu"],
        type=int,
        help="Ram per GPU (in GB) to use. Will be given as total mem in srun command.",
    )
    group.add_argument(
        "--task-prolog",
        nargs="?",
        default=slurm_defaults["task_prolog"],
        type=str,
        help="Shell script for task prolog (installing packages, etc.).",
    )
    group.add_argument("--time", nargs="?", default=slurm_defaults["time"], type=str, help="Slurm time limit.")
    group.add_argument(
        "--export",
        nargs="?",
        default=slurm_defaults["export"],
        type=str,
        help="Additional environment variables to export.",
    )
    group.add_argument("--exclude", nargs="?", default=slurm_defaults["exclude"], type=str, help="Nodes to exclude.")
    group.add_argument(
        "--after-job", nargs="?", default=slurm_defaults["after_job"], type=int, help="Job ID to wait for."
    )
    group.add_argument(
        "--interactive",
        action="store_true",
        help=(
            "Run using srun instead of sbatch. This will print the output into the terminal, not the slurm output file."
            " The logfile will still be created as usual."
        ),
        default=False,
    )
    group = parser.add_argument_group("Run locally")
    group.add_argument("--local", action="store_true", help="Run locally; not in slurm", default=False)
    return parser
def parse_args(args=None, parser=None):
    """Parse args from *base_parser* and insert defaults.

    Args:
        args (argparse.Namespace | dict, optional): pre-parsed arguments; parsed from the command line when None. (Default value = None)
        parser (argparse.ArgumentParser, optional): parser used for error reporting; created via *base_parser* when None. (Default value = None)

    Returns:
        dict: parsed arguments
    """
    if parser is None:
        # Previously a caller-supplied parser was discarded whenever args was None;
        # keep the given parser so its error reporting is used.
        parser = base_parser()
    if args is None:
        args = parser.parse_args()
    # `args` may already be a dict (local-execution path in __main__) or an
    # argparse.Namespace; vars() on a dict raises TypeError, so branch here.
    args = dict(args) if isinstance(args, dict) else dict(vars(args))
    check_arg_completeness(args, parser)
    return args
def check_arg_completeness(args, parser):
    """Check completeness of arguments.

    Args:
        args (dict): arguments to check
        parser (argparse.ArgumentParser): for raising the parser error

    Note:
        will raise a parser error if the arguments are not complete.
    """
    task = args["task"]

    def _missing(key):
        # A key that is absent or explicitly None counts as not provided.
        return args.get(key) is None

    if task in ("pre-train", "fine-tune", "fine-tune-head"):
        # Training tasks need a run name, an experiment name, and an epoch count.
        if not args.get("run_name"):
            parser.error(f"-run_name is required for task {task}")
        if not args.get("experiment_name"):
            parser.error(f"-experiment_name is required for task {task}. Choose from {_EXPNAMES}")
        if _missing("epochs"):
            parser.error(f"-epochs is required for task {task}")
    if _missing("dataset") and task in ("pre-train", "fine-tune", "fine-tune-head", "eval-metrics"):
        parser.error(f"-dataset is required for task {task}")
    if _missing("val_dataset") and _missing("dataset") and task == "eval":
        parser.error(f"-dataset or -val_dataset is required for task {task}")
    repeats = args["aug_repeated_augment_repeats"]
    if repeats is not None and repeats < 1:
        parser.error(
            "number of repeats for repeated augment has to be >= 1, but got -aug_repeated_augment_repeats ="
            f" {repeats}"
        )
    if task == "save-images" and _missing("out_dir"):
        parser.error("Need to set save directory (--out-dir) to save the images in.")
def inside_slurm():
    """Test for being inside a slurm container.

    Works by testing for environment variable 'RANK'.
    """
    # torchrun / slurm launchers export RANK for every worker process.
    return os.environ.get("RANK") is not None
# TODO: fix ./runscript.tmp: 18: Syntax error: Unterminated quoted string
def create_runscript(args, file_name=None):
    """Create a run script for a distributed training job using SLURM.

    Args:
        args (dict): A dictionary containing various arguments for the job, including parameters for SLURM and for training.
        file_name (str, optional, optional): The name of the file to create. Defaults to "runscript.tmp".

    Returns:
        str: The name of the created file.
        str: Additional command line arguments for sbatch.

    Example:
        >>> args = {"model": "vit_large_patch16_384", "task": "pre-train", "batch_size": 256, ...}
        >>> file_name = "my_run_script.sh"
        >>> create_runscript(args, file_name)
    """
    # Fill in slurm defaults that the caller did not supply.
    for key, val in slurm_defaults.items():
        if key not in args and val is not None:
            args[key] = val
    # Derive a run name: fall back to "<task> <model-stem>" when none was given.
    if "run_name" not in args or args["run_name"] is None:
        model_str = args["model"]
        if model_str.endswith(".pt"):
            model_str = os.path.dirname(model_str)
        run_name = args["task"] + " " + model_str.split(os.sep)[-1].split("_")[0]
    else:
        run_name = args["run_name"]
    # Sanitize characters that would be unsafe in file names / slurm job names.
    job_name = run_name.replace(" ", "_").replace("/", "_").replace(">", "_").replace("<", "_")
    if file_name is None:
        file_name = (
            f"experiments/sbatch/run_{args['task']}_{job_name}_at_{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}.sbatch"
        )
    task_args = ""
    # slurm_command = "echo run distributed:\necho python3 main.py {0}\n\nsrun -K \\\n" # " --gpus-per-task=1 \\\n --gpu-bind=none \\\n"
    srun_command = "\nsrun -K \\\n"
    sbatch_commands = (  # outfile name is job name, date, job id, node name
        "#!/bin/bash\n\n#SBATCH"
        f" --output={slurm_output_folder}/%x-{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}-%j-%N.out\n"
    )
    sbatch_cmd_args = ""  # additional command line arguments for sbatch
    python_command = " python3 main.py {0}\n"
    # Split args into slurm parameters (-> #SBATCH / srun flags) and training
    # parameters (-> command line of main.py).
    for key, val in args.items():
        if key == "local":
            continue
        if key == "interactive":
            continue
        if key == "gpus":
            continue
        if key in slurm_defaults:
            # it's a parameter for srun
            # slurm has - instead of _
            key = key.replace("_", "-")
            if key == "mem-per-gpu":
                # convert mem-per-gpu to mem
                # slurm_command += f" --mem={val * args['ntasks'] // args['nodes']}G \\\n" # that amount of memory is assigned on each node
                key = "mem"
                val = f"{val * args['ntasks'] // args['nodes']}G"  # that amount of memory is assigned on each node
                # continue
            if key == "job-name" and val is None:
                # # default jobname is '<task> <model> <dataset>'
                # model_str = args["model"]
                # task = args["task"]
                # if task == "pre-train":
                #     # it's just the model name...
                #     model = model_str.split("_")[0]
                # else:
                #     # it's a path to the tar file
                #     if not model_str.startswith(res_folder):
                #         model = "<vit model>"
                #     else:
                #         model = model_str[len(res_folder) :].split("_")[1].split(" ")[0]
                # if "dataset" in args and args["dataset"] is not None:
                #     dataset = args["dataset"]
                # else:
                #     dataset = ""
                val = run_name
            if key == "job-name" and not val.startswith('"'):
                # Quote the job name so spaces survive the shell.
                val = f'"{val}"'
            if key in ["task-prolog", "nodes", "exclude", "after-job"] and val is None:
                continue
            if key == "task-prolog":
                srun_command += f' --{key}="{val}" \\\n'
                continue
            if key == "after-job":
                # Job dependencies are passed to sbatch itself, not written into the script.
                sbatch_cmd_args += f"--dependency=afterany:{val} "
                continue
            if key == "partition" and isinstance(val, list):
                val = ",".join(val)
            if key == "ntasks":
                if args["nodes"] == 1:
                    gpus = val if args["gpus"] else 0
                    # slurm_command += f" --gpus={val} \\\n"
                    sbatch_commands += f"#SBATCH --gpus={gpus}\n"
                else:
                    # Multi-node: tasks must distribute evenly across nodes.
                    assert (
                        val % args["nodes"] == 0
                    ), f"Number of tasks ({val}) must be a multiple of the number of nodes ({args['nodes']})."
                    # slurm_command += f" --gpus-per-node={val // args['nodes']} \\\n"
                    sbatch_commands += f"#SBATCH --gpus-per-node={val // args['nodes']}\n"
                    sbatch_commands += "#SBATCH --ntasks-per-node=8\n"
            if "container" in key:
                # Container (pyxis) flags must go on srun, not on #SBATCH.
                srun_command += f" --{key}={val} \\\n"
            else:
                sbatch_commands += f"#SBATCH --{key}={val}\n"
                # slurm_command += f" --{key}={val} \\\n"
        else:
            # it's a parameter for the training
            if val is None:
                continue
            if key in ["results_folder"] and val == globals()[key]:
                # Skip values still at the module-level default.
                continue
            key = key.replace("_", "-")
            if isinstance(val, bool):
                # Booleans become --flag / --no-flag (argparse.BooleanOptionalAction).
                if val:
                    task_args += f"--{key} "
                else:
                    task_args += f"--no-{key} "
                continue
            if isinstance(val, str):
                task_args += f'--{key} "{val}" '
            else:
                task_args += f"--{key} {val} "
    # slurm_command += "python3 main.py {0}\n"
    # os.umask(0)  # make it possible to create an executable file
    # with open(file_name, "w+", opener=lambda pth, flgs: os.open(pth, flgs, 0o777)) as f:
    #     f.write(slurm_command.format(task_args))
    with open(file_name, "w+") as f:
        f.write(sbatch_commands + srun_command + python_command.format(task_args))
    # delete all runscripts older than a month
    n_old_files = int(os.popen("find experiments/sbatch/ -type f -mtime +30 | wc -l").read())
    if n_old_files > 0:
        print(f"Deleting {n_old_files} old runscripts.")
        os.system("find experiments/sbatch/ -type f -mtime +30 -delete")
    return file_name, sbatch_cmd_args
if __name__ == "__main__":
    if not inside_slurm():
        # Running on the head node (outside the slurm container): build the sbatch
        # runscript and submit it, unless local execution was requested.
        parser = slurm_parser()
        args = vars(parser.parse_args())
        if not args["local"]:
            script_name, cmd_args = create_runscript(args)
            # os.system("./" + script_name) # run srun to execute this script in slurm cluster
            # -> the following lines will be executed there
            if args["interactive"]:
                os.system(f"python3 srun-sbatch.py {script_name}")
            else:
                os.system(f"sbatch {cmd_args} {script_name}")  # sbatch to queue the job on the cluster
            exit(0)
        # local execution is wanted
        # Strip slurm-only keys so the training scripts only see training args.
        for key in list(args.keys()):
            if key.replace("_", "-") in slurm_defaults:
                args.pop(key)
        args = parse_args(args, parser)
    else:
        # Inside the slurm container: parse the training arguments normally.
        args = parse_args()
    # Record the git state for experiment tracking / reproducibility.
    args["branch"] = subprocess.check_output(["git", "branch", "--show-current"]).strip().decode("utf-8")
    args["commit"] = subprocess.check_output(["git", "rev-parse", "HEAD"]).strip().decode("utf-8")
    # Dispatch to the requested task; imports are deferred so each task only pays
    # for the modules it actually needs.
    if args["task"] == "pre-train":
        from train import pretrain

        pretrain(**args)
    elif args["task"] == "fine-tune":
        from train import finetune

        finetune(**args)
    elif args["task"] == "fine-tune-head":
        from train import finetune

        finetune(**args, head_only=True)
    elif args["task"] == "parser-test":
        from copy import copy

        from utils import prep_kwargs, log_args

        kwargs = prep_kwargs(copy(args))
        log_args(kwargs)
        # keys = sorted(list(args.keys()))
        # fill_len = max(len(k) for k in keys)
        # for key in keys:
        #     print(f"{key + ' ' * (fill_len - len(key))} = {args[key]} -> {kwargs[key]}")
    elif args["task"] == "eval-metrics":
        from evaluate import evaluate_metrics

        evaluate_metrics(**args)
    elif args["task"] == "eval":
        from evaluate import evaluate

        evaluate(**args)
    elif args["task"] == "eval-attr":
        from evaluate import evaluate_attributions

        evaluate_attributions(**args)
    elif args["task"] == "continue":
        from recover import continue_training

        continue_training(**args)
    elif args["task"] == "eval-center-bias":
        from evaluate import evaluate_center_bias

        evaluate_center_bias(**args)
    elif args["task"] == "eval-size-bias":
        from evaluate import evaluate_size_bias

        evaluate_size_bias(**args)
    elif args["task"] == "load-images":
        from test import load_images

        load_images(**args)
    elif args["task"] == "save-images":
        from test import save_images

        save_images(**args)
    else:
        raise NotImplementedError(f"Task {args['task']} is not implemented.")

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,154 @@
"""Model loading and preparation."""
import importlib
import os
import warnings
from functools import partial
import timm
import torch
import torch.nn as nn
from loguru import logger
import utils
from architectures.vit import TimmViT
from resizing_interface import vit_sizes
_ARCHITECTURES_IMPORTED = False
def _import_architectures():
    """Import every module in the ``architectures`` package exactly once.

    Importing triggers the modules' model-registration side effects. Import
    failures are logged and skipped so one broken file does not block the rest.
    """
    global _ARCHITECTURES_IMPORTED
    if _ARCHITECTURES_IMPORTED:
        return
    arch_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "architectures")
    for entry in os.listdir(arch_dir):
        if not entry.endswith(".py"):
            continue
        try:
            # Suppress warnings emitted at import time by individual architectures.
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                importlib.import_module(f"architectures.{entry[:-3]}")
            logger.debug(f"Imported architectures.{entry[:-3]}")
        except Exception as e:
            logger.error(f"\033[93mCould not import \033[0m\033[91m{entry}\033[0m")
            logger.error(e)
    _ARCHITECTURES_IMPORTED = True
def prepare_model(model_str, args):
    """Prepare a new model.

    If the name is of the format ViT-<size>/<patch_size>, use a *TimmViT*, else fall back to timm model loading.

    Args:
        model_str (str): model name
        args (utils.DotDict): further arguments, needs to have keys n_classes, drop_path_rate; key imsize or '_<imsize>' at the end of ViT specification

    Returns:
        torch.nn.Module: model
    """
    _import_architectures()
    kwargs = dict(args)
    # Drop unset options so the model constructors fall back to their own defaults.
    # (The comprehension materializes the keys first, so popping while iterating is safe.)
    for key in [key for key, val in kwargs.items() if val is None]:
        kwargs.pop(key)
    # Duplicate key names cover both the timm signatures ("init_values"/"drop_rate")
    # and the custom architecture signatures ("init_scale"/"drop").
    if args.layer_scale_init_values:
        kwargs["init_values"] = kwargs["init_scale"] = args.layer_scale_init_values
    if args.dropout and args.dropout > 0.0:
        kwargs["drop"] = kwargs["drop_rate"] = args.dropout
    if args.drop_path_rate and args.drop_path_rate > 0.0:
        # NOTE(review): this forwards drop_path_rate as "drop_block_rate"; in timm those are
        # different regularizers (stochastic depth vs. DropBlock) — confirm this is intended.
        kwargs["drop_block_rate"] = args.drop_path_rate
    kwargs["num_classes"] = args.n_classes
    kwargs["img_size"] = args.imsize
    if model_str.startswith("ViT"):
        # Format: ViT-{Ti,S,B,L}/<patch_size>[_<image_res>]
        h1, h2 = model_str.split("/")
        _, model_size = h1.split("-")
        if "_" in h2:
            patch_size, image_res = h2.split("_")
            assert args.imsize is None or args.imsize == int(
                image_res
            ), f"Got two different image sizes: {args.imsize} vs {image_res}"
        else:
            patch_size = h2
        # Explicitly given kwargs win over the size preset.
        kwargs = {**vit_sizes[model_size], **kwargs}
        model = TimmViT(patch_size=int(patch_size), in_chans=3, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    else:
        logger.debug(f"Loading model via timm api {model_str} with args {kwargs}")
        model = timm.create_model(model_str, pretrained=False, **kwargs)
    return model
def load_pretrained(model_path, args, new_dataset_params=False):
    """Load a pretrained model from .tar file.

    Rebuilds the model from the constructor arguments stored in the checkpoint and loads
    the saved weights. Two known legacy incompatibilities are repaired automatically:
    checkpoints saved without LayerScale, and checkpoints whose head was an nn.Sequential.

    Args:
        new_dataset_params (bool, optional): change model parameters (imsize, n_classes) to the ones from args. (Default value = False)
        model_path (str): path to .tar file
        args: new model parameters

    Returns:
        tuple: model, args, old_args, save_state
    """
    _import_architectures()
    save_state = torch.load(model_path, map_location="cpu")
    old_args = utils.prep_kwargs(save_state["args"])
    args.model = old_args.model
    old_args.cuda = args.cuda
    if old_args.model.startswith("flash_vit"):
        # flash_vit models do not accept LayerScale arguments
        args.pop("layer_scale_init_values", None)
        old_args.pop("layer_scale_init_values", None)
    # load the model (the old one first)
    model = prepare_model(old_args.model, old_args)
    logger.debug(f"loading model {old_args.model} from {model_path} with args {old_args}")
    # strip torch.compile / DDP wrapper prefixes from the saved state-dict keys
    file_save_state = utils.remove_prefix(save_state["model_state"], prefix="_orig_mod.")
    file_save_state = utils.remove_prefix(file_save_state)
    try:
        model.load_state_dict(file_save_state)
    except (UnboundLocalError, RuntimeError) as e:
        model_keys = set(model.state_dict().keys())
        file_keys = set(file_save_state.keys())
        logger.warning(f"Error loading state dict: {e}")
        model_minus_file = model_keys.difference(file_keys)
        file_minus_model = file_keys.difference(model_keys)
        logger.warning(f"model-file: {model_minus_file}\nfile-model: {file_minus_model}")
        if len(file_minus_model) == 0 and all([".ls" in key and key.endswith(".gamma") for key in model_minus_file]):
            # only LayerScale parameters are missing -> rebuild the model without LayerScale
            logger.info("Old model was without LayerScale -> replicating")
            try:
                args.pop("layer_scale_init_values")
                old_args.pop("layer_scale_init_values")
                model = prepare_model(old_args.model, old_args)
                model.load_state_dict(file_save_state)
            except (UnboundLocalError, RuntimeError) as e:
                logger.error("Could not resolve conflict")
                logger.error(f"Still got error {e}")
                exit(-1)
        elif any("head.0." in key for key in file_minus_model):
            # fixed typo in log message: "Seqeuntial" -> "Sequential"
            logger.info("Old model used nn.Sequential for head. Trying to fix -> nn.Linear")
            file_save_state = {key.replace("head.0.", "head."): val for key, val in file_save_state.items()}
            try:
                model.load_state_dict(file_save_state)
            except (UnboundLocalError, RuntimeError) as e:
                logger.error("Could not resolve conflict")
                logger.error(f"Still got error {e}")
                exit(-1)
        else:
            exit(-1)
    if new_dataset_params:
        # setup for finetuning parameters
        model.set_image_res(args.imsize)
        model.set_num_classes(args.n_classes)
        if args.max_seq_len is not None:
            model.set_max_seq_len(args.max_seq_len)
    return model, args, old_args, save_state

View File

@@ -0,0 +1,36 @@
import os
# Current login name; presumably interpolated into user-specific paths elsewhere — TODO confirm.
user = os.environ.get("USER")
# Base folder for checkpoints, run folders and logs (placeholder path; set per deployment).
results_folder = os.path.join("/BASE/FOLDER/TO/STORE/WEIGHTS/AND/LOGS", "EfficientCVBench")
# PATH: /netscratch/<user>/slurm
slurm_output_folder = os.path.join("/FOLDER/FOR/SLURM/TO/WRITE/LOGS/TO", "slurm")
# Dataset name -> root folder mapping (placeholder paths; set per deployment).
_ds_paths = {
    "cifar": "/PATH/TO/CIFAT",
    "tinyimagenet": "/PATH/TO/TINYIMAGENET",
    "stanford_cars": "/PATH/TO/CARS/SUPERFOLDER",
    "oxford_pet": "/PATH/TO/PET/SUPERFOLDER",
    "flowers102": "/PATH/TO/FLOWERS/SUPERFOLDER",
    "food101": "/PATH/TO/FOOD/SUPERFOLDER",
    "aircraft": "/PATH/TO/AIRCRAFT/SUPERFOLDER",
    "fornet": "/PATH/TO/FORNET",
    "counteranimal": "/PATH/TO/CounterAnimal/LAION-final",
}
def ds_path(dataset, args=None):
    """Get the (base) path for any dataset.

    Args:
        dataset (str): The dataset I'm looking for.
        args (DotDict, optional): Run args. If args.custom_dataset_path is set, this one is always returned.

    Returns:
        str: Path to the dataset root folder.
    """
    # an explicitly configured custom path always wins over the static mapping
    custom = None
    if args is not None and "custom_dataset_path" in args:
        custom = args.custom_dataset_path
    if custom is not None:
        return custom
    return _ds_paths[dataset]

View File

@@ -0,0 +1,136 @@
"""Continue pretraining / finetuning after something went wrong."""
import torch
from loguru import logger
from engine import _train, setup_criteria_mixup, setup_model_optim_sched_scaler, setup_tracking_and_logging
from load_dataset import prepare_dataset
from models import load_pretrained
from utils import ddp_cleanup, ddp_setup, log_args, prep_kwargs
def continue_training(model, **kwargs):
    """Continue training a model from a saved state.

    Restores run arguments, model weights, optimizer and scheduler state from the
    checkpoint and resumes ``_train`` at the stored epoch.

    Args:
        model (str): path to saved state.
        **kwargs: additional keyword arguments.
    """
    model_path = model
    save_state = torch.load(model, map_location="cpu")
    # state is of the form
    #
    # state = {'epoch': epochs,
    # 'model_state': model.state_dict(),
    # 'optimizer_state': optimizer.state_dict(),
    # 'scheduler_state': scheduler.state_dict(),
    # 'args': dict(args),
    # 'run_name': run_name,
    # 'stats': metrics}
    args = prep_kwargs(save_state["args"])
    args.distributed, device, world_size, rank, gpu_id = ddp_setup()
    torch.cuda.set_device(device)
    # recover the global batch size: the stored batch_size is per-process,
    # so multiply by the old world size before dividing by the current one
    if "world_size" in args and args.world_size is not None:
        global_bs = args.batch_size * args.world_size
    else:
        # assume global bs is given in kwargs
        global_bs = kwargs["batch_size"]
    args.batch_size = int(global_bs / world_size)
    args.world_size = world_size
    if "dataset" in args and args.dataset is not None:
        dataset = args.dataset
    else:
        # get default dataset for the task
        dataset = "ImageNet21k" if args.task == "pre-train" else "ImageNet"
        args.dataset = dataset
    if "val_dataset" in args and args.val_dataset is not None:
        val_dataset = args.val_dataset
    else:
        val_dataset = dataset
        args.val_dataset = val_dataset
    start_epoch = save_state["epoch"]
    # total epoch target: prefer the stored value unless it equals the resume
    # epoch (i.e. the old run already finished); then fall back to kwargs
    if "epochs" in args and args.epochs is not None and args.epochs != start_epoch:
        epochs = args.epochs
    else:
        epochs = kwargs["epochs"]
    run_folder = setup_tracking_and_logging(args, rank, append_model_path=model_path)
    logger.info(f"Logging run information to '{run_folder}'")
    # get the datasets & dataloaders
    train_loader, args.n_classes, args.ignore_index, args.multi_label, train_dali_server = prepare_dataset(
        dataset, args, rank=rank
    )
    val_loader, _, __, ___, val_dali_server = prepare_dataset(val_dataset, args, train=False, rank=rank)
    # model_name = args.model
    model, args, _, __ = load_pretrained(model_path, args)
    model, optimizer, scheduler, scaler = setup_model_optim_sched_scaler(model, device, epochs, args)
    try:
        optimizer.load_state_dict(save_state["optimizer_state"])
    except ValueError as e:
        # param-group mismatch between the rebuilt optimizer and the checkpoint
        logger.error(f"Could not load optimizer state: {e}")
        logger.error(
            f"optimizer state: {optimizer.state_dict().keys()}, param groups: {optimizer.state_dict()['param_groups']}"
        )
        logger.error(
            f"saved state: {save_state['optimizer_state'].keys()}, param groups:"
            f" {save_state['optimizer_state']['param_groups']}"
        )
        raise e
    scheduler.load_state_dict(save_state["scheduler_state"])
    # log all devices
    logger.info(f"training on {device} -> {torch.cuda.get_device_name(device) if args.device != 'cpu' else ''}")
    if rank == 0:
        logger.info(f"torch version {torch.__version__}")
        log_args(args)
    if args.seed:
        torch.manual_seed(args.seed)
    criterion, val_criterion, mixup = setup_criteria_mixup(args)
    if rank == 0:
        logger.info(f"start training at epoch {start_epoch}")
        logger.info(f"Run info at: '{run_folder}'")
    res = _train(
        model,
        train_loader,
        optimizer,
        rank,
        epochs,
        device,
        mixup,
        criterion,
        world_size,
        scheduler,
        args,
        val_loader,
        val_criterion,
        run_folder,
        scaler=scaler,
        do_metrics_calculation=True,
        start_epoch=start_epoch,
        show_tqdm=args.tqdm,
        train_dali_server=train_dali_server,
        val_dali_server=val_dali_server,
    )
    if rank == 0:
        best_acc_key = [key for key in res.keys() if key.startswith("val/best_")][0]
        logger.info(f"Run '{args.run_name}' is done. Top-1 validation accuracy: {res[best_acc_key] * 100:.2f}%")
    ddp_cleanup(args=args, rank=rank)

View File

@@ -0,0 +1,25 @@
captum==0.8.0
# datadings==3.4.6
einops==0.8.1
fvcore==0.1.5.post20221221
grad-cam==1.5.4
halonet-pytorch==0.0.4
numpy==1.26.4
opencv-python==4.11.0.86
pytorch_wavelets==1.3.0
pywavelets==1.8.0
reformer-pytorch==1.4.4
routing-transformer==1.6.1
sinkhorn-transformer==0.11.4
timm==1.0.15
torch==2.6.0
torcheval==0.0.7
torchprofile==0.0.4
torchvision==0.21.0
tqdm==4.67.1
nltk==3.9.1
# NOTE(review): numpy==1.26.4 is already pinned above — duplicate entry; confirm and deduplicate.
numpy==1.26.4
Pillow==11.1.0
# NOTE(review): 'psutils' looks like a typo for 'psutil' (pinned as 7.0.0 below) — confirm before installing.
psutils==3.3.9
wandb==0.19.9
psutil==7.0.0

View File

@@ -0,0 +1,108 @@
from copy import copy
import torch
from loguru import logger
from torch import nn
# Preset transformer dimensions keyed by size tag (used when building "ViT-<size>/<patch>" models).
vit_sizes = {
    "na": dict(embed_dim=96, depth=2, num_heads=2),
    "mu": dict(embed_dim=144, depth=6, num_heads=3),
    "Ti": dict(embed_dim=192, depth=12, num_heads=3),
    "S": dict(embed_dim=384, depth=12, num_heads=6),
    "B": dict(embed_dim=768, depth=12, num_heads=12),
    "L": dict(embed_dim=1024, depth=24, num_heads=16),
    "LRA_CIFAR": dict(embed_dim=256, depth=1, num_heads=4, mlp_ratio=1.0),
    "LRA_IMDB": dict(embed_dim=256, depth=4, num_heads=4, mlp_ratio=4.0),
    "LRA_ListOps": dict(embed_dim=512, depth=4, num_heads=8, mlp_ratio=4.0),
}
class ResizingInterface:
    """Interface for resizing parts of a Vision Transformer model.

    Mixed into ViT-style models; relies on the host class providing attributes such as
    ``img_size``, ``patch_size``, ``patch_embed``, ``pos_embed``, ``embed_dim``,
    ``num_classes`` and ``head``.
    """

    def get_internal_loss(self):
        """Add a term to the loss."""
        return 0.0

    def set_image_res(self, res):
        """Set a new image resolution.

        Resets the (learned) patch embedding.

        Args:
            res (int): new image resolution
        """
        self._set_input_strand(res=res)

    def _set_input_strand(self, res=None, patch_size=None):
        """Set a new image resolution and patch size.

        Rebuilds the patch embedding for the new geometry (reusing the old weights) and
        bicubically interpolates the learned position embeddings to the new token grid.
        The square-root computations below assume a square token grid — TODO confirm
        for non-square inputs.

        Args:
            res (int): (Default value = None)
            patch_size (int): (Default value = None)
        """
        if res is None:
            res = self.img_size
        if patch_size is None:
            patch_size = self.patch_size
        else:
            # TODO: implement interpolation of patch_embed weights to new patch size/input shape
            raise NotImplementedError("Interpolation of patch_embed weights to new patch size not implemented yet.")
        if res == self.img_size and patch_size == self.patch_size:
            return  # nothing to do here
        logger.info(f"Resizing input from {self.img_size} to {res} with patch size {self.patch_size} to {patch_size}.")
        # rebuild the patch embedding at the new resolution, then restore its weights
        old_patch_embed_state = copy(self.patch_embed.state_dict())
        self.patch_embed = self.embed_layer(
            img_size=res,
            patch_size=patch_size,
            in_chans=self.in_chans,
            embed_dim=self.embed_dim,
            bias=not self.pre_norm,  # disable bias if pre-norm is used (e.g. CLIP)
        )
        self.patch_embed.load_state_dict(old_patch_embed_state)
        # split off prefix tokens (cls/dist) whose position embeddings are kept unchanged
        num_extra_tokens = 0 if self.no_embed_class else self.num_prefix_tokens
        orig_size = int((self.pos_embed.shape[-2] - num_extra_tokens) ** 0.5)
        new_size = int(self.patch_embed.num_patches**0.5)
        extra_tokens = self.pos_embed[:, :num_extra_tokens]
        pos_tokens = self.pos_embed[:, num_extra_tokens:]
        # make it shape rest x embed_dim x orig_size x orig_size
        pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, self.embed_dim).permute(0, 3, 1, 2)
        pos_tokens = nn.functional.interpolate(
            pos_tokens, size=(new_size, new_size), mode="bicubic", align_corners=False
        )
        # make it shape rest x new_size^2 x embed_dim
        pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
        if num_extra_tokens > 0:
            pos_tokens = torch.cat((extra_tokens, pos_tokens), dim=1)
        self.pos_embed = nn.Parameter(pos_tokens.contiguous())
        self.img_size = res
        self.patch_size = patch_size

    def set_num_classes(self, n_classes):
        """Reset the classification head with a new number of classes.

        Args:
            n_classes (int): new number of classes
        """
        if n_classes == self.num_classes:
            return
        logger.info(f"Resizing classification head from {self.num_classes} to {n_classes}.")
        self.head = nn.Linear(self.embed_dim, n_classes) if n_classes > 0 else nn.Identity()
        self.num_classes = n_classes
        # NOTE(review): if n_classes <= 0 the head is nn.Identity and the init calls
        # below would fail (no weight/bias) — presumably n_classes > 0 here; confirm.
        # init weight + bias
        # nn.init.zeros_(self.head.weight)
        # nn.init.constant_(self.head.bias, -log(self.num_classes))
        nn.init.trunc_normal_(self.head.weight, std=0.02)
        nn.init.constant_(self.head.bias, 0)

View File

@@ -0,0 +1,46 @@
"""Run the srun commands of an sbatch script directly, forwarding its #SBATCH parameters.

Usage: python <this_script> <sbatch_script> [further args substituted for "$@"]
"""
import os
import sys

assert len(sys.argv) >= 2, "Add a sbatch script as the first argument"
assert os.path.isfile(
    sys.argv[1]
), f"First argument has to be an executable script (a file that exists), but got '{sys.argv[1]}'"
with open(sys.argv[1], "r") as f:
    script = f.readlines()
# drop blank lines and surrounding whitespace
script = [line.strip() for line in script if len(line.strip()) > 0]
# join lines ending with \
joined_script = []
current_line = ""
for line in script:
    current_line += line
    if current_line.endswith("\\"):
        current_line = current_line[:-1]
    else:
        joined_script.append(current_line)
        current_line = ""
script = joined_script
# collect #SBATCH parameters (to forward to srun) and the srun command lines
additional_srun_params = []
srun_lines = []
for line in script:
    if line.upper().startswith("#SBATCH "):
        param = line[len("#SBATCH ") :]
        # --output is a batch-only option; do not forward it to srun
        if param.startswith("--output="):
            continue
        additional_srun_params.append(param)
    elif line.startswith("srun "):
        srun_lines.append(line)
further_args = " ".join(sys.argv[2:])
# inject the collected #SBATCH parameters into each srun call and substitute "$@"
sruns = [
    line.replace("srun ", "srun " + " ".join(additional_srun_params) + " ").replace('"$@"', further_args)
    for line in srun_lines
]
for srun_line in sruns:
    # the commands come from a user-provided script, so running them via the shell is intentional
    print(f"I will run:\n{srun_line}", flush=True)
    os.system(srun_line)

View File

@@ -0,0 +1,97 @@
import math
import os
import sys
import numpy as np
from loguru import logger
from matplotlib import pyplot as plt
from PIL import Image
from tqdm.auto import tqdm
from load_dataset import prepare_dataset
from utils import log_args, log_formatter, prep_kwargs
def load_images(dataset, **kwargs):
    """Load one batch from *dataset* and display the images in a grid.

    Args:
        dataset (str): dataset name understood by ``prepare_dataset``.
        **kwargs: further run arguments (batch size, augmentation flags, ...).
    """
    args = prep_kwargs(kwargs)
    args.dataset = dataset
    args.aug_normalize = False  # keep raw pixel values for display
    loader, args.n_classes, args.ignore_index, args.multi_label, _ = prepare_dataset(dataset, args)
    images = next(iter(loader))[0]
    images = images.permute(0, 2, 3, 1).numpy()  # NCHW -> NHWC for imshow
    images = [images[i] for i in range(images.shape[0])]
    rows = math.ceil(math.sqrt(len(images) / 2))
    ims_per_row = len(images) // rows
    # squeeze=False keeps axs 2D even for a single row/column, so flattening below is safe
    fig, axs = plt.subplots(rows, ims_per_row, squeeze=False)
    axs = [ax for row in axs for ax in row]
    for img, ax in zip(images, axs):
        ax.imshow(img)
    fig.suptitle(f"Examples from {dataset}")
    fig.tight_layout(pad=0)
    plt.show()
def save_images(dataset, out_dir, ipc=None, **kwargs):
    """Export a dataset as class-sorted JPEG files.

    Writes up to *ipc* images per class to ``out_dir/<class_name>/``.

    Args:
        dataset (str): dataset name understood by ``prepare_dataset``.
        out_dir (str): output root directory.
        ipc (int, optional): maximum images per class; None exports everything. (Default value = None)
        **kwargs: further run arguments.
    """
    args = prep_kwargs(kwargs)
    args.dataset = dataset
    args.out_dir = out_dir
    args.ipc = ipc
    args.aug_normalize = False  # keep raw pixel values for export
    log_file = os.path.join(out_dir, "save_images.log")
    logger.remove()
    logger.configure(extra=dict(run_name=f"Save images of {dataset}", rank=0, world_size=-1))
    logger.add(sys.stderr, format=log_formatter, enqueue=True, colorize=True, level=args.log_level.upper())
    logger.add(log_file, format=log_formatter, enqueue=True, colorize=True, level=args.log_level.upper())
    logger.info(f"Out dir '{out_dir}'")
    log_args(args)
    loader, args.n_classes, args.ignore_index, args.multi_label, _ = prepare_dataset(dataset, args)
    n_ims = [0] * args.n_classes  # per-class counter of saved images
    if args.n_classes == 1000:
        # assume its ImageNet classes
        logger.info("1000 classes => assuming ImageNet class names")
        with open("data/misc_dataset_files/imagenet_labels.txt", "r") as f:
            lines = f.readlines()
        labels = [ln.strip().split(" ")[0] for ln in lines]
        lbl_to_cls_name = sorted(labels)
    else:
        # str() is required: the names are joined into filesystem paths below
        lbl_to_cls_name = [str(i) for i in range(args.n_classes)]
    for cls_name in lbl_to_cls_name:
        os.makedirs(os.path.join(args.out_dir, cls_name), exist_ok=True)
    skipped_ims = 0
    # TQDM_DISABLE unset, empty, or "0" keeps the bar enabled; any other value disables it
    tqdm_is_disabled = (not args.tqdm) or os.environ.get("TQDM_DISABLE", "0") not in ("", "0")
    for i, (images, labels) in (
        pbar := tqdm(enumerate(loader), desc="Loading and saving images", disable=tqdm_is_disabled, total=len(loader))
    ):
        images = (images.permute(0, 2, 3, 1).numpy() * 255).astype(np.uint8)
        images = [images[i] for i in range(images.shape[0])]
        labels = labels.tolist()
        for img, lbl in zip(images, labels):
            if ipc is not None and n_ims[lbl] >= ipc:
                skipped_ims += 1
                continue
            # Image.save returns None, so there is nothing to keep from this call
            Image.fromarray(img).save(
                os.path.join(args.out_dir, lbl_to_cls_name[lbl], f"{lbl_to_cls_name[lbl]}_{n_ims[lbl]}.JPEG")
            )
            n_ims[lbl] += 1
        if ipc is not None and sum(n_ims) >= args.n_classes * ipc:
            break
        if tqdm_is_disabled:
            if i % 1000 == 0:
                logger.info(f"Batch [{i+1}/{len(loader)}]: Saved {sum(n_ims)}, skipped {skipped_ims}")
        else:
            pbar.set_description(f"Loading and saving (saved {sum(n_ims)}, skipped {skipped_ims})")
    logger.success(f"Extracted {sum(n_ims)} images.")

View File

@@ -0,0 +1,303 @@
import os
import random
import sys
import numpy as np
import timm
import torch
from loguru import logger
from engine import (
_train,
setup_criteria_mixup,
setup_model_optim_sched_scaler,
setup_tracking_and_logging,
wandb_available,
)
from load_dataset import prepare_dataset
from models import load_pretrained, prepare_model
from utils import ddp_cleanup, ddp_setup, log_args, prep_kwargs, set_filter_warnings
def finetune(model, dataset, epochs, val_dataset=None, head_only=False, **kwargs):
    """Finetune a pretrained model on a given dataset for a specified number of epochs.

    Args:
        model (str): Path to the pretrained model state file (in .tar format).
        dataset (str): Name of the dataset to finetune on.
        val_dataset (str, optional): Name of the validation dataset. (Default value = None)
        epochs (int): Number of epochs to train for.
        head_only (bool, optional): Whether to train only the head of the model. Default: False.
        **kwargs (dict): Further arguments for model setup, training, evaluation,...

    Notes:
        This function assumes that the model was pretrained on a different dataset using a different set of hyperparameters.
        It fine-tunes the model on a new dataset by loading the pretrained weights and training for the specified number of
        epochs. The function supports distributed training using the PyTorch DistributedDataParallel module.
    """
    set_filter_warnings()
    # Add defaults & make keys properties
    args = prep_kwargs(kwargs)
    if val_dataset is None:
        val_dataset = dataset
    args.val_dataset = val_dataset
    args.dataset = dataset
    args.epochs = epochs
    args.distributed, device, world_size, rank, gpu_id = ddp_setup()
    args.world_size = world_size
    try:
        torch.cuda.set_device(device)
    except RuntimeError as e:
        # dump the cluster/GPU environment to make mis-scheduled jobs easier to debug
        logger.error(
            f"Could not set device {device} as current device; "
            f"environment parameters: RANK={os.environ.get('RANK')}, LOCAL_RANK={os.environ.get('LOCAL_RANK')}, "
            f"WORLD_SIZE={os.environ.get('WORLD_SIZE')}; SLURM_STEP_GPUS={os.environ.get('SLURM_STEP_GPUS')}, "
            f"GPU_DEVICE_ORDINAL={os.environ.get('GPU_DEVICE_ORDINAL')}, "
            f"CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')}, "
            f"SLURM_STEP_NODELIST={os.environ.get('SLURM_STEP_NODELIST')}, "
            f"SLURMD_NODENAME={os.environ.get('SLURMD_NODENAME')}"
        )
        raise e
    # args.batch_size arrives as the global batch size; convert to per-process size
    args.batch_size = int(args.batch_size / world_size)
    if args.seed is not None:
        # fix the seed for reproducibility
        seed = args.seed + rank
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)
    # get the datasets & dataloaders
    # transform only contains resize & crop here; everything else is handled on the GPU / in the training loop
    train_loader, args.n_classes, args.ignore_index, args.multi_label, train_dali_server = prepare_dataset(
        dataset, args, rank=rank
    )
    val_loader, _val_classes, _, __, val_dali_server = prepare_dataset(val_dataset, args, train=False, rank=rank)
    assert (
        args.n_classes == _val_classes
    ), f"Training and validation datasets have different numbers of classes: {args.n_classes} vs {_val_classes}"
    save_state = torch.load(model, map_location="cpu")
    old_args = prep_kwargs(save_state["args"])
    parent_folder = os.path.dirname(model)
    args.model = old_args.model
    run_folder = setup_tracking_and_logging(args, rank)
    if rank == 0:
        # link the finetuning run back to the pretraining run it started from
        if not os.path.exists(os.path.join(run_folder, "parent_run")):
            os.symlink(parent_folder, os.path.join(run_folder, "parent_run"), target_is_directory=True)
        logger.info(
            f"environment parameters: RANK={os.environ.get('RANK')}, LOCAL_RANK={os.environ.get('LOCAL_RANK')}, "
            f"WORLD_SIZE={os.environ.get('WORLD_SIZE')}; SLURM_STEP_GPUS={os.environ.get('SLURM_STEP_GPUS')}, "
            f"GPU_DEVICE_ORDINAL={os.environ.get('GPU_DEVICE_ORDINAL')}, "
            f"CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')}, "
            f"SLURM_STEP_NODELIST={os.environ.get('SLURM_STEP_NODELIST')}, "
            f"SLURMD_NODENAME={os.environ.get('SLURMD_NODENAME')}"
        )
    if args.seed:
        # args.seed truthy implies `seed` was assigned above (seed is args.seed + rank)
        logger.info(f"setting manual seed '{seed}' (arg: {args.seed} + rank: {rank})")
    model, args, old_args, save_state = load_pretrained(model, args, new_dataset_params=True)
    # model_name = old_args.model
    if rank == 0:
        logger.info(f"The model was pretrained on {old_args.dataset} for {save_state['epoch']} epochs.")
    model, optimizer, scheduler, scaler = setup_model_optim_sched_scaler(
        model, device, epochs, args, head_only=head_only
    )
    # log all devices
    logger.info(f"training on {device} -> {torch.cuda.get_device_name(device) if args.device != 'cpu' else ''}")
    if rank == 0:
        logger.info(f"torch version {torch.__version__}")
        logger.info(f"timm version {timm.__version__}")
        logger.info(f"full set of old arguments: {old_args}")
        log_args(args)
    if args.seed:
        torch.manual_seed(seed)
    criterion, val_criterion, mixup = setup_criteria_mixup(args)
    if rank == 0:
        logger.info(f"Run info at: '{run_folder}'")
    res = _train(
        model,
        train_loader,
        optimizer,
        rank,
        epochs,
        device,
        mixup,
        criterion,
        world_size,
        scheduler,
        args,
        val_loader,
        val_criterion,
        run_folder,
        scaler,
        do_metrics_calculation=True,
        show_tqdm=args.tqdm,
        train_dali_server=train_dali_server,
        val_dali_server=val_dali_server,
    )
    if rank == 0:
        best_acc_key = sorted([key for key in res.keys() if key.startswith("val/best_")])[0]
        logger.info(
            f"Run '{run_folder.split(os.sep)[-1]}' is done. Top-1 validation accuracy: {res[best_acc_key] * 100:.2f}%"
        )
    ddp_cleanup(args=args, sync_old_wandb=wandb_available(), rank=rank)
def pretrain(model, dataset, epochs, val_dataset=None, **kwargs):
    """Train or pretrain a model.

    Args:
        model (str): Name of the model to train.
        dataset (str): Name of the dataset to train the model on.
        epochs (int): Number of training epochs.
        val_dataset (str, optional): Name of the validation dataset, by default None
        **kwargs (dict): Additional keyword arguments.

    Notes:
        This function sets up logger, prepares the model, and trains the model on the given dataset.
    """
    set_filter_warnings()
    # Add defaults & make args properties
    args = prep_kwargs(kwargs)
    if val_dataset is None:
        val_dataset = dataset
    args.val_dataset = val_dataset
    args.dataset = dataset
    args.model = model
    args.epochs = epochs
    args.distributed, device, world_size, rank, gpu_id = ddp_setup(args.cuda)
    args.world_size = world_size
    # sleep(rank * 5)
    # logger.debug(f'running environment commands for rank {rank}')
    # os.system('env')
    # os.system('nvidia-smi')
    # sleep((world_size - rank) * 5)
    logger.debug(
        f"rank params: RANK={os.environ.get('RANK')}, LOCAL_RANK={os.environ.get('LOCAL_RANK')}, "
        f"WORLD_SIZE={os.environ.get('WORLD_SIZE')}; gpu params: "
        f"SLURM_STEP_GPUS={os.environ.get('SLURM_STEP_GPUS')}, "
        f"GPU_DEVICE_ORDINAL={os.environ.get('GPU_DEVICE_ORDINAL')}, "
        f"CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')}"
    )
    if args.cuda:
        try:
            torch.cuda.set_device(device)
        except RuntimeError as e:
            # dump the cluster/GPU environment to make mis-scheduled jobs easier to debug
            logger.error(f"Could not set device {device} as current device: {e}")
            logger.error(
                f"environment parameters: RANK={os.environ.get('RANK')}, LOCAL_RANK={os.environ.get('LOCAL_RANK')}, "
                f"WORLD_SIZE={os.environ.get('WORLD_SIZE')}; SLURM_STEP_GPUS={os.environ.get('SLURM_STEP_GPUS')}, "
                f"GPU_DEVICE_ORDINAL={os.environ.get('GPU_DEVICE_ORDINAL')}, "
                f"CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')}, "
                f"SLURM_STEP_NODELIST={os.environ.get('SLURM_STEP_NODELIST')}, "
                f"SLURMD_NODENAME={os.environ.get('SLURMD_NODENAME')}"
            )
            raise e
    # args.batch_size arrives as the global batch size; convert to per-process size
    args.batch_size = int(args.batch_size / world_size)
    run_folder = setup_tracking_and_logging(args, rank)
    if rank % world_size == 0:
        logger.info(
            f"environment parameters: RANK={os.environ.get('RANK')}, LOCAL_RANK={os.environ.get('LOCAL_RANK')}, "
            f"WORLD_SIZE={os.environ.get('WORLD_SIZE')}; SLURM_STEP_GPUS={os.environ.get('SLURM_STEP_GPUS')}, "
            f"GPU_DEVICE_ORDINAL={os.environ.get('GPU_DEVICE_ORDINAL')}, "
            f"CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')}, "
            f"SLURM_STEP_NODELIST={os.environ.get('SLURM_STEP_NODELIST')}, "
            f"SLURMD_NODENAME={os.environ.get('SLURMD_NODENAME')}"
        )
    if args.seed is not None:
        # fix the seed for reproducibility
        seed = args.seed + rank
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)
        logger.info(f"setting manual seed '{seed}' (arg: {args.seed} + rank: {rank})")
    # get the datasets & dataloaders
    train_loader, args.n_classes, args.ignore_index, args.multi_label, train_dali_server = prepare_dataset(
        dataset, args, rank=rank
    )
    val_loader, _val_classes, _, __, val_dali_server = prepare_dataset(val_dataset, args, train=False, rank=rank)
    assert (
        args.n_classes == _val_classes
    ), f"Training and validation datasets have different numbers of classes: {args.n_classes} vs {_val_classes}"
    # setup model with amp & DDP
    if isinstance(model, str):
        # bare "ViT-X/Y" names get the image resolution appended
        if model.startswith("ViT") and "_" not in model:
            model += f"_{args.imsize}"
        model_name = model
        model = prepare_model(model, args)
    # NOTE(review): if `model` is not a str, `model_name` is unbound here (NameError) —
    # presumably callers always pass a model name string; confirm.
    if not model_name:
        model_name = type(model).__name__
    model, optimizer, scheduler, scaler = setup_model_optim_sched_scaler(model, device, epochs, args)
    # log all devices
    logger.info(f"training on {device} -> {torch.cuda.get_device_name(device) if device != 'cpu' else ''}")
    if rank == 0:
        logger.info(f"python version {sys.version}")
        logger.info(f"torch version {torch.__version__}")
        logger.info(f"timm version {timm.__version__}")
        log_args(args)
    if args.seed:
        torch.manual_seed(seed)
    criterion, val_criterion, mixup = setup_criteria_mixup(args)
    if rank == 0:
        logger.info(f"Run info at: '{run_folder}'")
    res = _train(
        model,
        train_loader,
        optimizer,
        rank,
        epochs,
        device,
        mixup,
        criterion,
        world_size,
        scheduler,
        args,
        val_loader,
        val_criterion,
        run_folder,
        scaler,
        do_metrics_calculation=True,
        show_tqdm=args.tqdm,
        train_dali_server=train_dali_server,
        val_dali_server=val_dali_server,
    )
    if rank == 0:
        best_acc_key = [key for key in res.keys() if key.startswith("val/best_")][0]
        logger.info(
            f"Run '{run_folder.split(os.sep)[-1]}' is done. Top-1 validation accuracy: {res[best_acc_key] * 100:.2f}%"
        )
    ddp_cleanup(args=args, sync_old_wandb=wandb_available(), rank=rank)

View File

@@ -0,0 +1,594 @@
"""Utils and small helper functions."""
import collections.abc
import json
import os
import shutil
import sys
import warnings
from dataclasses import dataclass
from itertools import repeat
from math import cos, pi, sqrt
import numpy as np
import torch
import torch.distributed as dist
from loguru import logger
from timm.data import Mixup
from timm.utils import NativeScaler, dispatch_clip_grad
from torch.nn.modules.loss import _WeightedLoss
from torch.utils.data import Dataset
from torchvision.transforms import transforms
import paths_config
from config import default_kwargs, get_default_kwargs # noqa: F401 # is used in prep_kwargs
class RepeatedDataset(Dataset):
    """Dataset wrapper that yields every sample of the base dataset ``num_repeats`` times in a row."""

    def __init__(self, dataset, num_repeats):
        """Create repeated dataset.

        Args:
            dataset (Dataset): dataset to repeat.
            num_repeats (int): number of repeats.
        """
        self.dataset = dataset
        self.num_repeats = num_repeats

    def __getitem__(self, idx):
        # consecutive indices map onto the same underlying sample
        base_idx = idx // self.num_repeats
        return self.dataset[base_idx]

    def __len__(self):
        return self.num_repeats * len(self.dataset)
@dataclass
class SchedulerArgs:
    """Class for scheduler arguments."""

    # schedule type, e.g. "cosine" or "const"
    sched: str
    # total number of epochs the schedule spans
    epochs: int
    # minimum learning rate at the end of the schedule
    min_lr: float
    # (starting) learning rate during warmup
    warmup_lr: float
    # number of warmup epochs at the start
    warmup_epochs: int
    # epochs appended after the main schedule — semantics depend on the consumer; TODO confirm
    cooldown_epochs: int = 0
def scheduler_function_factory(
    epochs, sched, warmup_epochs=0, lr=None, min_lr=0.0, warmup_sched=None, warmup_lr=None, offset=0, **kwargs
):
    """Create a scheduler factor function.

    The returned function maps an epoch index to a multiplicative factor for the base
    learning rate: a warmup phase first, then the main schedule rescaled so it ends at
    ``min_lr / lr``.

    Args:
        sched (str): the learning rate schedule type
        epochs (int): length of the full schedule
        warmup_epochs (int, optional): number of epochs reserved for warmup (Default value = 0)
        lr (float, optional): learning rate (has to be given, when warmup or min_lr are set) (Default value = None)
        min_lr (float, optional): minimum learning rate (Default value = 0.0)
        warmup_sched (str, optional): the type of schedule during warmup (Default value = None)
        warmup_lr (float, optional): (starting) learning rate during warmup (Default value = None)
        offset (int, optional): offset for the schedule (to be the same as the timm scheduler) (Default value = 0)
        **kwargs: unused

    Returns:
        function: scheduler function

    Raises:
        NotImplementedError: for unknown ``sched`` / ``warmup_sched`` values.
    """
    sched = sched.lower()

    def warmup_f(ep):
        return 1.0

    if warmup_epochs > 0:
        assert warmup_lr is not None, "Need warmup_lr, but got None"
        warmup_lr_factor = warmup_lr / lr
        if warmup_sched == "linear":

            def warmup_f(ep):
                return warmup_lr_factor + (1 - warmup_lr_factor) * max(ep, 0.0) / warmup_epochs

        elif warmup_sched == "const":

            def warmup_f(ep):
                return warmup_lr_factor

        else:
            raise NotImplementedError(f"Warmup schedule {warmup_sched} not implemented")
    # the main schedule only covers the epochs left after warmup
    epochs = epochs - warmup_epochs + offset
    if sched == "cosine":
        # cos from 0 to pi
        def main_f(ep):
            return cos(pi * ep / epochs) / 2 + 0.5

    elif sched == "const":

        def main_f(ep):
            return 1.0

    else:
        raise NotImplementedError(f"Schedule {sched} is not implemented.")
    # rescale and add min_lr; `lr` is only needed when min_lr is actually set, so the
    # factory no longer fails with TypeError for lr=None and the default min_lr=0.0
    min_lr_fact = (min_lr / lr) if min_lr else 0.0

    def main_f_with_min_lr(ep):
        return (1 - min_lr_fact) * main_f(ep) + min_lr_fact

    return lambda ep: (
        warmup_f(ep + offset) if ep + offset < warmup_epochs else main_f_with_min_lr(ep + offset - warmup_epochs)
    )
class DotDict(dict):
    """Dictionary whose keys can also be read, written, and deleted as attributes."""

    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

    def __getattr__(self, item, default=None):
        """Get item from.

        Args:
            item: key
            default (optional): default value. Defaults to None.

        Returns:
            value
        """
        # missing keys resolve to the default instead of raising AttributeError
        return self.get(item, default)
def prep_kwargs(kwargs):
    """Prepare the arguments and add defaults.

    Fills in the selected default preset, the results folder, and the validation dataset.

    Args:
        kwargs (dict[str, Any]): dict of kwargs

    Returns:
        DotDict: prepared kwargs
    """
    if "defaults" not in kwargs:
        kwargs["defaults"] = "DeiTIII"
    defaults = get_default_kwargs(kwargs["defaults"])
    # defaults never override explicitly given (non-None) values
    for k, v in defaults.items():
        if k not in kwargs or kwargs[k] is None:
            kwargs[k] = v
    if "results_folder" not in kwargs:
        # fix: previously referenced an undefined name `var_name` (NameError)
        kwargs["results_folder"] = paths_config.results_folder
    # normalize: strip a trailing slash from the results folder
    if kwargs["results_folder"].endswith("/"):
        kwargs["results_folder"] = kwargs["results_folder"][:-1]
    if "val_dataset" not in kwargs and "dataset" in kwargs:
        kwargs["val_dataset"] = kwargs["dataset"]
    return DotDict(kwargs)
def denormalize(x, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Invert the normalize operation.

    Args:
        x (torch.Tensor): images to de-normalize
        mean (tuple, optional): normalization mean. Defaults to (0.485, 0.456, 0.406).
        std (tuple, optional): normalization std. Defaults to (0.229, 0.224, 0.225).

    Returns:
        torch.Tensor: de-normalized images
    """
    # Normalize((x - m) / s) with m' = -mean/std and s' = 1/std undoes the original transform
    inv_mean = [-mu / sigma for mu, sigma in zip(mean, std)]
    inv_std = [1 / sigma for sigma in std]
    return transforms.Normalize(mean=inv_mean, std=inv_std)(x)
def log_formatter(record):
    """Build a loguru format template for the given record.

    Args:
        record (dict): loguru record; ``record["extra"]`` may carry
            ``run_name``, ``rank``, ``world_size`` and ``epoch`` bound via
            ``logger.bind``/``logger.configure``.

    Returns:
        str: format template string (loguru itself substitutes the ``{...}`` fields).
    """
    # Before the DDP context has been bound, fall back to placeholder name/rank.
    if "run_name" not in record["extra"]:
        return (
            "<g>{time:YYYY-MM-DD HH:mm:ss.SSS}</g> <c>|</c> <y>name TBD</y> > <y>?</y>/<y>?</y> <c>|</c> <level>{level:"
            " <8}</level> <c>|</c> {message}\n"
        )
    # Only show the epoch once it has been bound to the logger context.
    epoch_str = "@ epoch <y>{extra[epoch]: >3}</y> " if "epoch" in record["extra"] else ""
    # Include module/function/line for severity >= 30 (WARNING and above).
    code_loc_str = "<r>{name}</r>.<r>{function}</r>:<r>{line}</r> - " if record["level"].no >= 30 else ""
    return (
        "<g>{time:YYYY-MM-DD HH:mm:ss.SSS}</g> <c>|</c> <m>{extra[run_name]}</m> >"
        " <m>{extra[rank]}</m>/<m>{extra[world_size]}</m> "
        + epoch_str
        + "<c>|</c> <level>{level: <8}</level> <c>|</c> "
        + code_loc_str
        + "{message}\n"
    )
def ddp_setup(use_cuda=True):
    """Set up the distributed environment.

    Args:
        use_cuda: (Default value = True)

    Returns:
        tuple: A tuple containing the following elements:
            * bool: Whether the training is distributed.
            * torch.device: The device to use for distributed training.
            * int: The total number of processes in the distributed setup.
            * int: The global rank of the current process in the distributed setup.
            * int: The local rank of the current process on its node.

    Notes:
        The 'nccl' backend is used.
    """
    # Drop loguru's default handler; a handler with our custom format is added below.
    logger.remove()
    # RANK / LOCAL_RANK / WORLD_SIZE are exported by the launcher (torchrun/SLURM);
    # default to single-process values when they are unset.
    rank = int(os.getenv("RANK", 0))
    local_rank = int(os.getenv("LOCAL_RANK", 0))
    num_gpus = int(os.getenv("WORLD_SIZE", 1))
    # Only treat the run as distributed if a launcher set RANK and there is
    # more than one process in total.
    distributed = "RANK" in os.environ and num_gpus > 1
    logger.add(sys.stderr, format=log_formatter, colorize=True, enqueue=True)
    if distributed:
        assert use_cuda, "Only use distributed mode with cuda."
        try:
            dist.init_process_group("nccl")
        except ValueError as e:
            # Dump the GPU-related environment so a mis-pinned SLURM job can be
            # diagnosed from the logs, then re-raise the original error.
            logger.critical(f"Value error while setting up nccl process group: {e}")
            logger.info(
                f" CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')},"
                f" SLURM_STEP_GPUS={os.environ.get('SLURM_STEP_GPUS')},"
                f" GPU_DEVICE_ORDINAL={os.environ.get('GPU_DEVICE_ORDINAL')} for process:"
                f" RANK={rank} (LOCAL_RANK={local_rank}) of WORLD_SIZE={num_gpus}. Shutting down now."
            )
            raise e
    assert torch.cuda.is_available() or not use_cuda, "CUDA is not available"
    # Sanity check the SLURM GPU bookkeeping: all three GPU lists must contain
    # exactly num_gpus entries (comma-separated), unless running on CPU.
    # NOTE(review): this assumes one GPU per process — confirm for multi-GPU-per-task setups.
    assert (
        len(str(os.environ.get("SLURM_STEP_GPUS")).split(","))
        == len(str(os.environ.get("CUDA_VISIBLE_DEVICES")).split(","))
        == len(str(os.environ.get("GPU_DEVICE_ORDINAL")).split(","))
        == num_gpus
    ) or not use_cuda, (
        f"SLURM GPU setup is incorrect: CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')},"
        f" SLURM_STEP_GPUS={os.environ.get('SLURM_STEP_GPUS')},"
        f" GPU_DEVICE_ORDINAL={os.environ.get('GPU_DEVICE_ORDINAL')} for process:"
        f" RANK={rank} (LOCAL_RANK={local_rank}) of WORLD_SIZE={num_gpus}"
    )
    return distributed, torch.device(f"cuda:{local_rank}") if use_cuda else "cpu", num_gpus, rank, local_rank
def ddp_cleanup(args, sync_old_wandb=False, rank=0):
    """Clean the distributed setup after use and terminate the process.

    Args:
        args (DotDict): arguments
        sync_old_wandb (bool, optional): Whether to sync and remove wandb runs older than 100 hours (>3 days). Defaults to False.
        rank (int, optional): The rank of the current process, so only one process syncs wandb. Defaults to 0.
    """
    if sync_old_wandb and rank == 0:
        os.system("wandb sync --clean --clean-old-hours 100 --clean-force")
    if args.distributed:
        logger.info("waiting for all processes to finish")
        # Barrier before teardown so no process destroys the group while
        # others are still using it.
        dist.barrier()
        dist.destroy_process_group()
    logger.info("exiting now")
    # BUGFIX: use sys.exit instead of the site-module builtin exit(), which is
    # not guaranteed to exist (e.g. under `python -S` or in frozen builds).
    sys.exit(0)
def set_filter_warnings():
    """Filter out some warnings to reduce spam."""
    ignored_message_patterns = (
        # DataLoader number-of-workers warning
        ".*worker processes in total. Our suggested max number of worker in current system is.*",
        # datadings only-varargs warning
        ".*only accepts varargs so.*",
        # gather: deprecated is_namedtuple helper
        ".*is_namedtuple is deprecated, please use the python checks instead.*",
        # meshgrid: upcoming mandatory `indexing` argument
        ".*in an upcoming release, it will be required to pass the indexing.*",
        # timm: overwriting registered models
        ".*UserWarning: Overwriting .*",
    )
    for pattern in ignored_message_patterns:
        warnings.filterwarnings("ignore", pattern)
def remove_prefix(state_dict, prefix="module."):
    """Remove a prefix from the keys in a state dictionary.

    Args:
        state_dict (dict[str, Any]): The state dictionary to remove the prefix from.
        prefix (str, optional): The prefix to remove from the keys. Default is 'module.'.

    Returns:
        dict[str, Any]: A new dictionary with the prefix removed from the keys.

    Examples:
        >>> state_dict = {'module.layer1.weight': 1, 'module.layer1.bias': 2}
        >>> remove_prefix(state_dict)
        {'layer1.weight': 1, 'layer1.bias': 2}
    """
    prefix_len = len(prefix)
    stripped = {}
    for key, value in state_dict.items():
        new_key = key[prefix_len:] if key.startswith(prefix) else key
        stripped[new_key] = value
    return stripped
def prime_factors(n):
    """Calculate the prime factors of a given integer.

    Args:
        n (int): The integer to find the prime factors of.

    Returns:
        list[int]: The prime factors of n, in ascending order with multiplicity.
    """
    factors = []
    candidate = 2
    # Trial division: only candidates up to sqrt(remaining n) can be factors.
    while candidate * candidate <= n:
        while n % candidate == 0:
            n //= candidate
            factors.append(candidate)
        candidate += 1
    # Whatever remains above 1 is itself prime.
    if n > 1:
        factors.append(n)
    return factors
def linear_regession(points):
    """Calculate an ordinary least-squares linear regression of the points.

    Note: despite being previously documented as an "interpolation", this is a
    least-squares *fit* — the returned line need not pass through the points.

    Args:
        points (dict[float, float]): points to fit in the format points[x] = y

    Returns:
        function: z -> a * z + b, the fitted line.
    """
    N = len(points)
    # dict preserves insertion order, so keys() and values() stay aligned.
    x = np.fromiter(points.keys(), dtype=float, count=N)
    y = np.fromiter(points.values(), dtype=float, count=N)
    # Closed-form OLS solution for slope a and intercept b.
    a = (N * (x * y).sum() - x.sum() * y.sum()) / (N * (x * x).sum() - x.sum() ** 2)
    b = (y.sum() - a * x.sum()) / N
    return lambda z: a * z + b
def save_model_state(
    model_folder,
    epoch,
    args,
    model_state,
    regular_save=True,
    stats=None,
    val_accs=None,
    epoch_accs=None,
    additional_reason="",
    max_interm_ep_states=2,
    **kwargs,
):
    """Save the model state.

    Args:
        model_folder: Folder to save model in
        epoch: current epoch
        args: arguments to guide and save
        model_state: state of the model
        regular_save: Is this a regular or a special save? (Default value = True)
        stats: model stats to save (Default value = None)
        val_accs: model accuracy to save (Default value = None)
        epoch_accs: training accuracy to save (Default value = None)
        additional_reason: save reason; in case it's not just a regular save interval. Would be "top" or "final", for example. (Default value = "")
        max_interm_ep_states: Number of regular epoch states to keep (Default value = 2)
        **kwargs: Further arguments to save
    """
    # make args dict, not DotDict to be able to save it
    state = {"epoch": epoch, "model_state": model_state, "run_name": args.run_name, "args": dict(args)}
    # Merge val/epoch accuracies into one flat "stats" dict; later dicts win on key clashes.
    if stats is None:
        stats = {}
    if val_accs is not None:
        stats = {**stats, **val_accs}
    if epoch_accs is not None:
        stats = {**stats, **epoch_accs}
    state["stats"] = stats
    # Any extra keyword arguments are stored top-level in the checkpoint.
    state = {**state, **kwargs}
    logger.info(f"saving model state at epoch {epoch} ({additional_reason})")
    # Special saves ("top", "final", ...) get their own file name; otherwise ep_<epoch>.pt.
    regular_file_name = f"ep_{epoch}.pt"
    save_name = additional_reason + ".pt" if len(additional_reason) > 0 else regular_file_name
    outfile = os.path.join(model_folder, save_name)
    torch.save(state, outfile)
    # Keep the per-epoch snapshot in sync with the special save when requested.
    if len(additional_reason) > 0 and regular_save:
        shutil.copyfile(outfile, os.path.join(model_folder, regular_file_name))
    # remove intermediate epoch states (all but the last max_interm_ep_states)
    # NOTE: only files matching "ep_*.pt" are pruned, so special saves like
    # "top.pt" / "final.pt" are never deleted here.
    if max_interm_ep_states > 0:
        epoch_states = [f for f in os.listdir(model_folder) if f.startswith("ep_") and f.endswith(".pt")]
        # Sort numerically by the epoch embedded in "ep_<epoch>.pt".
        epoch_states = sorted(epoch_states, key=lambda x: int(x.split("_")[1].split(".")[0]))
        if len(epoch_states) > max_interm_ep_states:
            for f in epoch_states[:-max_interm_ep_states]:
                os.remove(os.path.join(model_folder, f))
                logger.debug(f"removed intermediate epoch state {f}")
def log_args(args, rank=0):
    """Log the full set of arguments as sorted JSON, on the rank-0 process only.

    Args:
        args (dict-like): arguments to log; must be convertible via ``dict(args)``.
        rank (int, optional): process rank; only rank 0 logs. Defaults to 0.
    """
    if rank == 0:
        logger.info("full set of arguments: " + json.dumps(dict(args), sort_keys=True))
        # keys = sorted(list(args.keys()))
        # for key in keys:
        #     logger.info(f"arg: {key} = {args[key]}")
class ScalerGradNormReturn(NativeScaler):
    """A wrapper around PyTorch's NativeScaler that returns the gradient norm."""

    def __str__(self) -> str:
        return f"{type(self).__name__}(_scaler: {self._scaler})"

    def __call__(self, loss, optimizer, clip_grad=None, clip_mode="norm", parameters=None, create_graph=False):
        """Scale and backpropagate through the loss tensor, and return the gradient norm of the selected parameters.

        Does an optimizer step.

        Args:
            loss (torch.Tensor): The loss tensor to scale and backpropagate through.
            optimizer (torch.optim.Optimizer): The optimizer to use for the optimization step.
            clip_grad (float, optional): The maximum allowed norm of the gradients. If None, no clipping is performed.
            clip_mode (str, optional): The mode used for clipping the gradients. Only used if `clip_grad` is not None. Possible values are 'norm'
                (clipping the norm of the gradients) and 'value' (clipping the value of the gradients). (default='norm')
            parameters (iterable[torch.nn.Parameter], optional): The parameters to compute the gradient norm for. If None, the gradient norm is not computed.
            create_graph (bool, optional): Whether to create a computation graph for computing second-order gradients. (default=False)

        Returns:
            float: The gradient norm of the selected parameters, or the int -1 if `parameters` is None.
        """
        # Backward on the scaled loss; gradients are unscaled below before
        # they are measured or clipped.
        self._scaler.scale(loss).backward(create_graph=create_graph)
        # always unscale the gradients, since it's being done anyway
        self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
        if parameters is not None:
            grads = [p.grad for p in parameters if p.grad is not None]
            device = grads[0].device
            # 2-norm of the per-parameter gradient 2-norms == global gradient 2-norm.
            grad_norm = torch.norm(torch.stack([torch.norm(g.detach(), 2).to(device) for g in grads]), 2)
        else:
            # NOTE(review): -1 is a plain int here, while the branch above returns a
            # tensor — callers must handle both types.
            grad_norm = -1
        if clip_grad is not None:
            # NOTE(review): if clip_grad is set while parameters is None, this passes
            # None to dispatch_clip_grad — confirm callers always supply parameters
            # whenever clipping is requested.
            dispatch_clip_grad(parameters, clip_grad, mode=clip_mode)
        # step() skips the update if unscaled grads contain inf/NaN; update()
        # adjusts the scale factor for the next iteration.
        self._scaler.step(optimizer)
        self._scaler.update()
        return grad_norm
class NoScaler:
    """Dummy gradient scaler that doesn't scale gradients.

    This scaler performs a simple backward pass with the given loss, and then updates the model's parameters
    with the given optimizer. The resulting gradient norm is computed and returned.
    """

    def __str__(self) -> str:
        return f"{type(self).__name__}()"

    def __call__(self, loss, optimizer, parameters=None, **kwargs):
        """Perform backward pass with the given loss, updates the model's parameters with the given optimizer, and computes the resulting gradient norm.

        Args:
            loss (torch.Tensor): The loss tensor that the gradients will be computed from.
            optimizer (torch.optim.Optimizer): The optimizer that will be used to update the model's parameters.
            parameters (iterable[torch.Tensor], optional): An iterable of model parameters to compute gradients. If None, returns -1.
            **kwargs: Additional keyword arguments; nothing will be done with these.

        Returns:
            float: The gradient norm computed after the backward pass, or -1 if parameters is None.
        """
        loss.backward()
        grad_norm = -1
        if parameters is not None:
            grad_list = [param.grad for param in parameters if param.grad is not None]
            target_device = grad_list[0].device
            # Global 2-norm = 2-norm of the per-parameter gradient 2-norms.
            per_param_norms = torch.stack([torch.norm(g.detach(), 2).to(target_device) for g in grad_list])
            grad_norm = torch.norm(per_param_norms, 2)
        optimizer.step()
        return grad_norm
def get_cpu_name():
    """Get the model name of the CPU.

    Reads /proc/cpuinfo, which only exists on Linux.

    Returns:
        str: the CPU model name, or "unknown" if it cannot be determined.
    """
    try:
        with open("/proc/cpuinfo", "r") as f:
            for line in f:
                if line.startswith("model name"):
                    return line.split(":")[1].strip()
    except OSError:
        # /proc/cpuinfo is absent on non-Linux systems — fall through to "unknown"
        # instead of crashing (the original raised FileNotFoundError there).
        pass
    return "unknown"
# From PyTorch internals
def _ntuple(n):
"""Make a function to create n-tuples.
Args:
n (int): tuple length
Returns:
function: function to create n-tuples
"""
def parse(x):
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
return tuple(x)
return tuple(repeat(x, n))
return parse
to_2tuple = _ntuple(2)
def make_divisible(v, divisor=8, min_value=None, round_limit=0.9):
    """Calculate the smallest number >= v that is divisible by divisor.

    This function is primarily used to ensure that the output of a layer
    is divisible by a certain number, typically to align with hardware
    optimizations or memory layouts.

    Args:
        v (int): The input value.
        divisor (int, optional): The divisor. Defaults to 8.
        min_value (int, optional): The minimum value to return. If None, defaults to the divisor.
        round_limit (float, optional): A threshold for rounding down. If the result of rounding down is less than round_limit * v, the next multiple of the divisor is returned instead. Defaults to 0.9.

    Returns:
        int: The smallest number >= v that is divisible by divisor.
    """
    # Falsy min_value (None or 0) falls back to the divisor itself.
    if not min_value:
        min_value = divisor
    rounded = int(v + divisor / 2) // divisor * divisor
    result = max(min_value, rounded)
    # Don't let rounding shrink the value by more than ~10%.
    if result < round_limit * v:
        result += divisor
    return result
def grad_cam_reshape_transform(tensor):
    """Transform the tensor for Grad-CAM calculation.

    Args:
        tensor (torch.Tensor): input tensor of shape (batch, tokens, dim).

    Returns:
        torch.Tensor: reshaped tensor of shape (batch, dim, h, w) without [CLS] token.
    """
    num_tokens = tensor.shape[1]
    # A non-square token count implies a leading [CLS] token — strip it.
    has_cls_token = int(sqrt(num_tokens)) ** 2 != num_tokens
    patches = tensor[:, 1:] if has_cls_token else tensor
    batch, n_patches, dim = patches.shape
    side = int(sqrt(n_patches))
    grid = patches.reshape(batch, side, side, dim)
    # Bring the channels to the first spatial dimension, like in CNNs:
    # (b, h, w, c) -> (b, c, h, w).
    return grid.transpose(2, 3).transpose(1, 2)