AAAI Version

AAAI Supplementary Material/Model Training Code/.ruff.toml (new file, 76 lines)
@@ -0,0 +1,76 @@
# Exclude a variety of commonly ignored directories.
exclude = [
    ".bzr",
    ".direnv",
    ".eggs",
    ".git",
    ".git-rewrite",
    ".hg",
    ".ipynb_checkpoints",
    ".mypy_cache",
    ".nox",
    ".pants.d",
    ".pyenv",
    ".pytest_cache",
    ".pytype",
    ".ruff_cache",
    ".svn",
    ".tox",
    ".venv",
    ".vscode",
    "__pypackages__",
    "_build",
    "buck-out",
    "build",
    "dist",
    "node_modules",
    "site-packages",
    "venv",
]

# Same as Black.
line-length = 120
indent-width = 4

[lint]
# Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default.
# Unlike Flake8, Ruff doesn't enable pycodestyle warnings (`W`) or
# McCabe complexity (`C901`) by default.
select = ["E", "F", "D", "B", "NPY", "PD", "TD005", "TD006", "TD007", "SIM", "RET", "Q", "ICN", "I"]
ignore = ["D203", "D213", "E501", "D100", "NPY002", "D102", "B008", "PD011", "D105", "SIM118", "D417"]

# Allow fixes for all enabled rules (when `--fix` is provided).
unfixable = ["B"]

# Allow unused variables when underscore-prefixed.
dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"

[lint.pydocstyle]
convention = "google"

[format]
# Like Black, use double quotes for strings.
quote-style = "double"

# Like Black, indent with spaces, rather than tabs.
indent-style = "space"

# Like Black, respect magic trailing commas.
skip-magic-trailing-comma = true

# Like Black, automatically detect the appropriate line ending.
line-ending = "auto"

# Enable auto-formatting of code examples in docstrings. Markdown,
# reStructuredText code/literal blocks and doctests are all supported.
#
# This is currently disabled by default, but it is planned for this
# to be opt-out in the future.
docstring-code-format = true

# Set the line length limit used when formatting code snippets in
# docstrings.
#
# This only has an effect when the `docstring-code-format` setting is
# enabled.
docstring-code-line-length = "dynamic"
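
With this configuration in place, linting and formatting would typically be invoked from the folder containing `.ruff.toml`, roughly as follows (standard Ruff CLI usage, not a command taken from this repository):

```commandline
ruff check --fix .
ruff format .
```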

AAAI Supplementary Material/Model Training Code/README.md (new file, 107 lines)
@@ -0,0 +1,107 @@
# ForNet

This is the training code for the ForNet paper.
All our experiments and evaluations were run using this codebase.

## Requirements

This project heavily builds on [timm](https://github.com/huggingface/pytorch-image-models) and open-source implementations of the models that are tested.
All requirements are listed in [requirements.txt](./requirements.txt).
To install them, run

```commandline
pip install -r requirements.txt
```

## Usage

After **cloning this repository**, you can train and test a wide range of models.
By default, an `srun` command is executed to run the code on a Slurm cluster.
To run on the local machine, append the `--local` flag to the command.

### General Preparation

After cloning the repository on a Slurm cluster, make sure `main.py` is executable (e.g., via `chmod a+x main.py`).

To run the project on a Slurm cluster, you need to create a Docker image from the requirements file.
You will also want to adapt the default Slurm parameters in `config.py`.

Next, adjust the paths in `paths_config.py` for your system, specifically `results_folder`, `slurm_output_folder`, and the dataset folders.

Finally, if you want to use Weights & Biases for tracking, create the file `.wandb.apikey` in this folder and paste your API key into it.

### Training

#### Pretraining

To pretrain a `ViT-S` on a given dataset, run

```commandline
./main.py --task pre-train --model ViT-S/16 --epochs 300 --run_name <name_or_description_of_the_run> --experiment_name recombine_imagenet --lr 3e-3 (--local)
```

This will save a checkpoint (`.pt` file) every `<save_epochs>` epochs (the default is 10), containing all the model weights along with the optimizer and scheduler state and the current training stats.
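
For example, such a checkpoint could later be restored roughly as follows (a minimal sketch; the key names inside the checkpoint are assumptions, not taken from this codebase):

```python
import torch
from torch import nn, optim

# Toy stand-ins; in practice these are the trained model and its optimizer.
model = nn.Linear(8, 2)
optimizer = optim.AdamW(model.parameters(), lr=3e-3, weight_decay=0.02)

# The key names ("model", "optimizer", "epoch") are assumptions and may
# differ from what this codebase actually writes into its .pt files.
checkpoint = torch.load("checkpoint.pt", map_location="cpu")
model.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])
start_epoch = checkpoint.get("epoch", 0) + 1
```
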
#### Finetuning

A model (checkpoint) can be finetuned on another dataset using the following command:

```commandline
./main.py --task fine-tune --model <model_checkpoint.pt> --epochs 300 --run_name <name_or_description_of_the_run> --experiment_name recombine_imagenet --lr 3e-3 (--local)
```

This will also save new checkpoints during training.

### Evaluation

It is also possible to evaluate the models.
To evaluate a model's accuracy on a specific dataset, run

```commandline
./main.py -t eval -ds <dataset name> -m <model_checkpoint.pt> --ntasks 1 -bs 512 --num-workers 10 --cpus-per-task 10 --time 10:00 (--local)
```

You can run our center-bias, size-bias, and foreground-focus evaluations using the `eval-attr`, `eval-center-bias`, and `eval-size-bias` tasks (`-t` or `--task` argument), for example as shown below.
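
A hypothetical center-bias evaluation call could look like this (constructed by analogy with the accuracy command above; the exact flags each bias task accepts may differ, see `./main.py --help`):

```commandline
./main.py -t eval-center-bias -ds <dataset name> -m <model_checkpoint.pt> --ntasks 1 -bs 512 (--local)
```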

### Further Arguments

Several further arguments and flags can be passed to the script.
The most important ones are:

| Arg                             | Description                                              |
| :------------------------------ | :------------------------------------------------------- |
| `--model <model>`               | Model name or checkpoint.                                 |
| `--run_name <name for the run>` | Name or description of this training run.                 |
| `--dataset <dataset>`           | Specifies a dataset to use.                               |
| `--task <task>`                 | Specifies a task. The default is `pre-train`.             |
| `--local`                       | Run on the local machine, not on a Slurm cluster.         |
| `--epochs <epochs>`             | Epochs to train.                                          |
| `--lr <lr>`                     | Learning rate. Default is 3e-3.                           |
| `--batch_size <bs>`             | Batch size. Default is 2048.                              |
| `--weight_decay <wd>`           | Weight decay. Default is 0.02.                            |
| `--imsize <image resolution>`   | Resolution of the images to train with. Default is 224.   |

For a list of all arguments, run

```commandline
./main.py --help
```

## Supported Models

These are the models we support. Links point to the original code sources; where no link is given, we implemented the architecture from scratch, following the respective paper.

| Architecture | Versions |
| :----------- | :------- |
| [DeiT](https://github.com/facebookresearch/deit) | `deit_tiny_patch16_LS`, `deit_small_patch16_LS`, `deit_medium_patch16_LS`, `deit_base_patch16_LS`, `deit_large_patch16_LS`, `deit_huge_patch14_LS`, `deit_huge_patch14_52_LS`, `deit_huge_patch14_26x2_LS`, `deit_Giant_48_patch14_LS`, `deit_giant_40_patch14_LS`, `deit_small_patch16_36_LS`, `deit_small_patch16_36`, `deit_small_patch16_18x2_LS`, `deit_small_patch16_18x2`, `deit_base_patch16_18x2_LS`, `deit_base_patch16_18x2`, `deit_base_patch16_36x1_LS`, `deit_base_patch16_36x1` |
| [ResNet](https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/resnet.py) | `resnet18`, `resnet34`, `resnet26`, `resnet50`, `resnet101`, `wide_resnet50_2` |
| [Swin](https://github.com/microsoft/Swin-Transformer) | `swin_tiny_patch4_window7`, `swin_small_patch4_window7`, `swin_base_patch4_window7`, `swin_large_patch4_window7` |
| [ViT](https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py) | `ViT-{Ti,S,B,L}/<patch_size>` |

## License

We release this code under the [MIT license](./LICENSE).

## Citation

If you use this codebase in your project, please cite:

AAAI Supplementary Material/Model Training Code/requirements.txt (new file)
@@ -0,0 +1,3 @@
albumentations==2.0.5
datasets==3.5.0
nvidia-dali-cuda120==1.47.0

@@ -0,0 +1,238 @@
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
import torch
import torch.nn as nn
from functools import partial

from timm.models.vision_transformer import VisionTransformer, _cfg
from timm.models.registry import register_model
from timm.models.layers import trunc_normal_

from architectures.vit import TimmViT

__all__ = [
    "deit_tiny_patch16_224",
    "deit_small_patch16_224",
    "deit_base_patch16_224",
    "deit_tiny_distilled_patch16_224",
    "deit_small_distilled_patch16_224",
    "deit_base_distilled_patch16_224",
    "deit_base_patch16_384",
    "deit_base_distilled_patch16_384",
]


class DistilledVisionTransformer(TimmViT):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        num_patches = self.patch_embed.num_patches
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.embed_dim))
        self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()

        trunc_normal_(self.dist_token, std=0.02)
        trunc_normal_(self.pos_embed, std=0.02)
        self.head_dist.apply(self._init_weights)

    def forward_features(self, x):
        # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
        # with slight modifications to add the dist_token
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        dist_token = self.dist_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, dist_token, x), dim=1)

        x = x + self.pos_embed
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)
        return x[:, 0], x[:, 1]

    def forward(self, x):
        x, x_dist = self.forward_features(x)
        x = self.head(x)
        x_dist = self.head_dist(x_dist)
        if self.training:
            return x, x_dist
        else:
            # during inference, return the average of both classifier predictions
            return (x + x_dist) / 2


def _clean_kwargs(kwargs):
    allowed_keys = {key for key in kwargs.keys() if not key.startswith("pretrain")}
    allowed_keys = {key for key in allowed_keys if not key.startswith("cache")}
    return {key: kwargs[key] for key in allowed_keys}
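
# Note: _clean_kwargs strips timm bookkeeping arguments (anything starting with
# "pretrain" or "cache", e.g. `pretrained_cfg`) before they reach the model
# constructor:
#   _clean_kwargs({"img_size": 224, "pretrained_cfg": None}) -> {"img_size": 224}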


@register_model
def deit_tiny_patch16(pretrained=False, img_size=224, drop_path_rate=0.1, num_classes=1000, drop_rate=0.0, **kwargs):
    kwargs = _clean_kwargs(kwargs)
    model = TimmViT(
        patch_size=16,
        embed_dim=192,
        depth=12,
        num_heads=3,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        img_size=img_size,
        drop_path_rate=drop_path_rate,
        num_classes=num_classes,
        drop_rate=drop_rate,
    )
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth",
            map_location="cpu",
            check_hash=True,
        )
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def deit_small_patch16(pretrained=False, img_size=224, drop_path_rate=0.1, num_classes=1000, drop_rate=0.0, **kwargs):
    kwargs = _clean_kwargs(kwargs)
    model = TimmViT(
        patch_size=16,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        img_size=img_size,
        drop_path_rate=drop_path_rate,
        num_classes=num_classes,
        drop_rate=drop_rate,
    )
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth",
            map_location="cpu",
            check_hash=True,
        )
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def deit_base_patch16(pretrained=False, img_size=224, drop_path_rate=0.1, num_classes=1000, drop_rate=0.0, **kwargs):
    kwargs = _clean_kwargs(kwargs)
    model = TimmViT(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        img_size=img_size,
        drop_path_rate=drop_path_rate,
        num_classes=num_classes,
        drop_rate=drop_rate,
    )
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
            map_location="cpu",
            check_hash=True,
        )
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def deit_tiny_distilled_patch16(
    pretrained=False, img_size=224, drop_path_rate=0.1, num_classes=1000, drop_rate=0.0, **kwargs
):
    kwargs = _clean_kwargs(kwargs)
    model = DistilledVisionTransformer(
        patch_size=16,
        embed_dim=192,
        depth=12,
        num_heads=3,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        img_size=img_size,
        drop_path_rate=drop_path_rate,
        num_classes=num_classes,
        drop_rate=drop_rate,
    )
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth",
            map_location="cpu",
            check_hash=True,
        )
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def deit_small_distilled_patch16(
    pretrained=False, img_size=224, drop_path_rate=0.1, num_classes=1000, drop_rate=0.0, **kwargs
):
    kwargs = _clean_kwargs(kwargs)
    model = DistilledVisionTransformer(
        patch_size=16,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        img_size=img_size,
        drop_path_rate=drop_path_rate,
        num_classes=num_classes,
        drop_rate=drop_rate,
    )
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth",
            map_location="cpu",
            check_hash=True,
        )
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def deit_base_distilled_patch16(
    pretrained=False, img_size=224, drop_path_rate=0.1, num_classes=1000, drop_rate=0.0, **kwargs
):
    kwargs = _clean_kwargs(kwargs)
    model = DistilledVisionTransformer(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        img_size=img_size,
        drop_path_rate=drop_path_rate,
        num_classes=num_classes,
        drop_rate=drop_rate,
    )
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth",
            map_location="cpu",
            check_hash=True,
        )
        model.load_state_dict(checkpoint["model"])
    return model
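
Because every builder above is decorated with `@register_model`, timm can instantiate these models by name once the module has been imported. A minimal sketch (illustrative only; nothing here is specific to the ForNet training loop):

```python
import timm
import torch

# Assumes the module above has been imported somewhere, so that the
# @register_model decorators have run and timm knows the entry points.
model = timm.create_model("deit_tiny_patch16", pretrained=False, img_size=224)
model.eval()

with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000])
```
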
@@ -0,0 +1,850 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Taken from https://github.com/facebookresearch/deit with slight modifications

from loguru import logger

import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial

from timm.models.vision_transformer import Mlp, PatchEmbed, _cfg

from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from timm.models.registry import register_model

from resizing_interface import ResizingInterface


class Attention(nn.Module):
    # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        q = q * self.scale

        attn = q @ k.transpose(-2, -1)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Block(nn.Module):
    # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        Attention_block=Attention,
        Mlp_block=Mlp,
        init_values=1e-4,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention_block(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp_block(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class Layer_scale_init_Block(nn.Module):
    # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
    # with slight modifications
    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        Attention_block=Attention,
        Mlp_block=Mlp,
        init_values=1e-4,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention_block(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp_block(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )
        self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
        self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)

    def forward(self, x):
        x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
        x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
        return x
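
# Note on LayerScale (Touvron et al., "Going deeper with Image Transformers",
# https://arxiv.org/abs/2103.17239): gamma_1 and gamma_2 are learnable
# per-channel scales initialised to the small value `init_values`, so each
# residual branch starts close to zero and very deep ViTs train more stably.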


class Layer_scale_init_Block_paralx2(nn.Module):
    # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
    # with slight modifications
    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        Attention_block=Attention,
        Mlp_block=Mlp,
        init_values=1e-4,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.norm11 = norm_layer(dim)
        self.attn = Attention_block(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.attn1 = Attention_block(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.norm21 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp_block(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )
        self.mlp1 = Mlp_block(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )
        self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
        self.gamma_1_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
        self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
        self.gamma_2_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)

    def forward(self, x):
        x = (
            x
            + self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
            + self.drop_path(self.gamma_1_1 * self.attn1(self.norm11(x)))
        )
        x = (
            x
            + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
            + self.drop_path(self.gamma_2_1 * self.mlp1(self.norm21(x)))
        )
        return x


class Block_paralx2(nn.Module):
    # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
    # with slight modifications
    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        Attention_block=Attention,
        Mlp_block=Mlp,
        init_values=1e-4,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.norm11 = norm_layer(dim)
        self.attn = Attention_block(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.attn1 = Attention_block(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.norm21 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp_block(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )
        self.mlp1 = Mlp_block(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x))) + self.drop_path(self.attn1(self.norm11(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x))) + self.drop_path(self.mlp1(self.norm21(x)))
        return x


class hMLP_stem(nn.Module):
    """hMLP_stem: https://arxiv.org/pdf/2203.09795.pdf
    taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
    with slight modifications
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        norm_layer=nn.SyncBatchNorm,
    ):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches
        self.proj = torch.nn.Sequential(
            *[
                nn.Conv2d(in_chans, embed_dim // 4, kernel_size=4, stride=4),
                norm_layer(embed_dim // 4),
                nn.GELU(),
                nn.Conv2d(embed_dim // 4, embed_dim // 4, kernel_size=2, stride=2),
                norm_layer(embed_dim // 4),
                nn.GELU(),
                nn.Conv2d(embed_dim // 4, embed_dim, kernel_size=2, stride=2),
                norm_layer(embed_dim),
            ]
        )

    def forward(self, x):
        B, C, H, W = x.shape
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x
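
# Note: the hMLP stem replaces the single 16x16 patch-embedding convolution with
# three small stages (4x4, then 2x2, then 2x2 strided convs) whose strides
# multiply to 16, so it yields the same number of patches as a 16x16 patchify.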


class vit_models(nn.Module, ResizingInterface):
    """Vision Transformer with LayerScale (https://arxiv.org/abs/2103.17239) support
    taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
    with slight modifications
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        num_classes=1000,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        norm_layer=nn.LayerNorm,
        global_pool=None,
        block_layers=Block,
        Patch_layer=PatchEmbed,
        act_layer=nn.GELU,
        Attention_block=Attention,
        Mlp_block=Mlp,
        dpr_constant=True,
        init_scale=1e-4,
        mlp_ratio_clstk=4.0,
        **kwargs,
    ):
        super().__init__()

        self.dropout_rate = drop_rate

        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim

        self.img_size = img_size
        self.patch_embed = Patch_layer(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
        num_patches = self.patch_embed.num_patches
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.embed_layer = Patch_layer
        self.pre_norm = False

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
        self.no_embed_class = True

        dpr = [drop_path_rate for i in range(depth)]
        self.blocks = nn.ModuleList(
            [
                block_layers(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=0.0,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[i],
                    norm_layer=norm_layer,
                    act_layer=act_layer,
                    Attention_block=Attention_block,
                    Mlp_block=Mlp_block,
                    init_values=init_scale,
                )
                for i in range(depth)
            ]
        )

        self.norm = norm_layer(embed_dim)

        self.feature_info = [dict(num_chs=embed_dim, reduction=0, module="head")]
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        trunc_normal_(self.pos_embed, std=0.02)
        trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init_weights)

    def set_num_classes(self, n_classes):
        super().set_num_classes(n_classes)
        self._init_weights(self.head)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {"pos_embed", "cls_token"}

    def get_classifier(self):
        return self.head

    def get_num_layers(self):
        return len(self.blocks)

    def reset_classifier(self, num_classes, global_pool=""):
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x, test=False):
        B = x.shape[0]
        x = self.patch_embed(x)

        if test and x.isnan().any().item():
            logger.error("patch embedded input has a nan value")

        cls_tokens = self.cls_token.expand(B, -1, -1)

        x = x + self.pos_embed

        if test and x.isnan().any().item():
            logger.error("position embedded input has a nan value")

        x = torch.cat((cls_tokens, x), dim=1)

        if test and x.isnan().any().item():
            logger.error("input with [CLS] has a nan value")

        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if test and x.isnan().any().item():
                logger.error(f"output of block {i} has a nan value")

        x = self.norm(x)
        return x[:, 0]

    def forward(self, x, test=False):
        x = self.forward_features(x, test=test)

        if self.dropout_rate:
            x = F.dropout(x, p=float(self.dropout_rate), training=self.training)
        x = self.head(x)

        return x


# DeiT III: Revenge of the ViT (https://arxiv.org/abs/2204.07118)


@register_model
def deit_tiny_patch16_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    model = vit_models(
        img_size=img_size,
        patch_size=16,
        embed_dim=192,
        depth=12,
        num_heads=3,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
        **kwargs,
    )

    return model


@register_model
def deit_small_patch16_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    if "pretrained_cfg" in kwargs:
        kwargs.pop("pretrained_cfg")
    model = vit_models(
        img_size=img_size,
        patch_size=16,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
        **kwargs,
    )
    model.default_cfg = _cfg()
    if pretrained:
        name = "https://dl.fbaipublicfiles.com/deit/deit_3_small_" + str(img_size) + "_"
        if pretrained_21k:
            name += "21k.pth"
        else:
            name += "1k.pth"

        checkpoint = torch.hub.load_state_dict_from_url(url=name, map_location="cpu", check_hash=True)
        model.load_state_dict(checkpoint["model"])

    return model


@register_model
def deit_medium_patch16_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    model = vit_models(
        patch_size=16,
        embed_dim=512,
        depth=12,
        num_heads=8,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
        **kwargs,
    )
    model.default_cfg = _cfg()
    if pretrained:
        name = "https://dl.fbaipublicfiles.com/deit/deit_3_medium_" + str(img_size) + "_"
        if pretrained_21k:
            name += "21k.pth"
        else:
            name += "1k.pth"

        checkpoint = torch.hub.load_state_dict_from_url(url=name, map_location="cpu", check_hash=True)
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def deit_base_patch16_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    model = vit_models(
        img_size=img_size,
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
        **kwargs,
    )
    if pretrained:
        name = "https://dl.fbaipublicfiles.com/deit/deit_3_base_" + str(img_size) + "_"
        if pretrained_21k:
            name += "21k.pth"
        else:
            name += "1k.pth"

        checkpoint = torch.hub.load_state_dict_from_url(url=name, map_location="cpu", check_hash=True)
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def deit_large_patch16_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    model = vit_models(
        img_size=img_size,
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
        **kwargs,
    )
    if pretrained:
        name = "https://dl.fbaipublicfiles.com/deit/deit_3_large_" + str(img_size) + "_"
        if pretrained_21k:
            name += "21k.pth"
        else:
            name += "1k.pth"

        checkpoint = torch.hub.load_state_dict_from_url(url=name, map_location="cpu", check_hash=True)
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def deit_huge_patch14_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    model = vit_models(
        img_size=img_size,
        patch_size=14,
        embed_dim=1280,
        depth=32,
        num_heads=16,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
        **kwargs,
    )
    if pretrained:
        name = "https://dl.fbaipublicfiles.com/deit/deit_3_huge_" + str(img_size) + "_"
        if pretrained_21k:
            name += "21k_v1.pth"
        else:
            name += "1k_v1.pth"

        checkpoint = torch.hub.load_state_dict_from_url(url=name, map_location="cpu", check_hash=True)
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def deit_huge_patch14_52_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    model = vit_models(
        img_size=img_size,
        patch_size=14,
        embed_dim=1280,
        depth=52,
        num_heads=16,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
        **kwargs,
    )

    return model


@register_model
def deit_huge_patch14_26x2_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    model = vit_models(
        img_size=img_size,
        patch_size=14,
        embed_dim=1280,
        depth=26,
        num_heads=16,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block_paralx2,
        **kwargs,
    )

    return model


# @register_model
# def deit_Giant_48x2_patch14_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
#     model = vit_models(
#         img_size=img_size, patch_size=14, embed_dim=1664, depth=48, num_heads=16, mlp_ratio=4,
#         norm_layer=partial(nn.LayerNorm, eps=1e-6), block_layers=Block_paral_LS, **kwargs)
#
#     return model


# @register_model
# def deit_giant_40x2_patch14_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
#     model = vit_models(
#         img_size=img_size, patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=4,
#         norm_layer=partial(nn.LayerNorm, eps=1e-6), block_layers=Block_paral_LS, **kwargs)
#     return model


@register_model
def deit_Giant_48_patch14_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    model = vit_models(
        img_size=img_size,
        patch_size=14,
        embed_dim=1664,
        depth=48,
        num_heads=16,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
        **kwargs,
    )
    return model


@register_model
def deit_giant_40_patch14_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    model = vit_models(
        img_size=img_size,
        patch_size=14,
        embed_dim=1408,
        depth=40,
        num_heads=16,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
        **kwargs,
    )
    # model.default_cfg = _cfg()

    return model


# Models from Three things everyone should know about Vision Transformers (https://arxiv.org/pdf/2203.09795.pdf)


@register_model
def deit_small_patch16_36_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    model = vit_models(
        img_size=img_size,
        patch_size=16,
        embed_dim=384,
        depth=36,
        num_heads=6,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
        **kwargs,
    )

    return model


@register_model
def deit_small_patch16_36(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    model = vit_models(
        img_size=img_size,
        patch_size=16,
        embed_dim=384,
        depth=36,
        num_heads=6,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )

    return model


@register_model
def deit_small_patch16_18x2_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    model = vit_models(
        img_size=img_size,
        patch_size=16,
        embed_dim=384,
        depth=18,
        num_heads=6,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block_paralx2,
        **kwargs,
    )

    return model


@register_model
def deit_small_patch16_18x2(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    model = vit_models(
        img_size=img_size,
        patch_size=16,
        embed_dim=384,
        depth=18,
        num_heads=6,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Block_paralx2,
        **kwargs,
    )

    return model


@register_model
def deit_base_patch16_18x2_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    model = vit_models(
        img_size=img_size,
        patch_size=16,
        embed_dim=768,
        depth=18,
        num_heads=12,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block_paralx2,
        **kwargs,
    )

    return model


@register_model
def deit_base_patch16_18x2(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    model = vit_models(
        img_size=img_size,
        patch_size=16,
        embed_dim=768,
        depth=18,
        num_heads=12,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Block_paralx2,
        **kwargs,
    )

    return model


@register_model
def deit_base_patch16_36x1_LS(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    model = vit_models(
        img_size=img_size,
        patch_size=16,
        embed_dim=768,
        depth=36,
        num_heads=12,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        block_layers=Layer_scale_init_Block,
        **kwargs,
    )

    return model


@register_model
def deit_base_patch16_36x1(pretrained=False, img_size=224, pretrained_21k=False, **kwargs):
    model = vit_models(
        img_size=img_size,
        patch_size=16,
        embed_dim=768,
        depth=36,
        num_heads=12,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )

    return model

@@ -0,0 +1,111 @@
from timm.models import register_model
from timm.models.resnet import BasicBlock, Bottleneck
from timm.models.resnet import ResNet as ResNetTimm
from torch import nn

from resizing_interface import ResizingInterface


class ResNet(ResNetTimm, ResizingInterface):
    """The popular ResNet model with a ResizingInterface."""

    def __init__(self, *args, global_pool="avg", **kwargs):
        """Create a ResNet model with resizing capabilities.

        Args:
            *args: Arguments for the ResNet model.
            global_pool (str, optional): Global pooling type. Defaults to "avg".
            **kwargs: Keyword arguments for the ResNet model (from timm).

        Keyword Args:
            block, layers, num_classes, in_chans, output_stride, cardinality, base_width, stem_width, stem_type, replace_stem_pool, block_reduce_first, down_kernel_size, avg_down, act_layer, norm_layer, aa_layer, drop_rate, drop_path_rate, drop_block_rate, zero_init_last, block_args

        """
        admissible_kwargs = [
            "block",
            "layers",
            "num_classes",
            "in_chans",
            "output_stride",
            "cardinality",
            "base_width",
            "stem_width",
            "stem_type",
            "replace_stem_pool",
            "block_reduce_first",
            "down_kernel_size",
            "avg_down",
            "act_layer",
            "norm_layer",
            "aa_layer",
            "drop_rate",
            "drop_path_rate",
            "drop_block_rate",
            "zero_init_last",
            "block_args",
        ]
        for key in list(kwargs.keys()):
            if key not in admissible_kwargs:
                kwargs.pop(key)
        super().__init__(*args, global_pool=global_pool, **kwargs)
        self.global_pool_str = global_pool

    def set_image_res(self, res):
        # resizing is not needed for CNNs with pooling
        return

    def set_num_classes(self, n_classes):
        if self.num_classes == n_classes:
            return
        if n_classes > 0:
            self.reset_classifier(num_classes=n_classes, global_pool=self.global_pool_str)
        else:
            self.fc = nn.Identity()


@register_model
def resnet18(pretrained=False, **kwargs):
    """Construct a ResNet-18 model."""
    model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2])
    return ResNet(**model_args, **kwargs)


@register_model
def resnet34(pretrained=False, **kwargs):
    """Construct a ResNet-34 model."""
    model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3])
    return ResNet(**model_args, **kwargs)


@register_model
def resnet26(pretrained=False, **kwargs):
    """Construct a ResNet-26 model."""
    model_args = dict(block=Bottleneck, layers=[2, 2, 2, 2])
    return ResNet(**model_args, **kwargs)


@register_model
def resnet50(pretrained=False, **kwargs):
    """Construct a ResNet-50 model."""
    model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3])
    return ResNet(**model_args, **kwargs)


@register_model
def resnet101(pretrained=False, **kwargs):
    """Construct a ResNet-101 model."""
    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3])
    return ResNet(**model_args, **kwargs)


@register_model
def wide_resnet50_2(pretrained=False, **kwargs):
    """Construct a Wide ResNet-50-2 model.

    The model is the same as ResNet except for the bottleneck number of channels,
    which is twice as large in every block. The number of channels in outer 1x1
    convolutions is the same, e.g. the last block in ResNet-50 has 2048-512-2048
    channels, while in Wide ResNet-50-2 it has 2048-1024-2048.
    """
    model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], base_width=128)
    return ResNet(**model_args, **kwargs)
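
The `ResNet` wrapper above silently drops constructor kwargs that timm's `ResNet` does not accept, which keeps shared training configs reusable across architectures. A minimal sketch of that effect (the `img_size` kwarg here is an illustrative ViT-style extra, not an argument this file documents):

```python
import torch

# `resnet18` refers to the wrapper defined above; `img_size` would be rejected
# by timm's ResNet, so the wrapper filters it out via its admissible_kwargs list.
model = resnet18(num_classes=10, img_size=224)

with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 10])
```
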
@@ -0,0 +1,893 @@
# Taken from https://github.com/microsoft/Swin-Transformer with slight modifications

# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu
# --------------------------------------------------------
from copy import copy

import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from loguru import logger
from timm.models import register_model
from timm.models.layers import DropPath, to_2tuple, trunc_normal_

from resizing_interface import ResizingInterface

try:
    import os
    import sys

    kernel_path = os.path.abspath(os.path.join(".."))
    sys.path.append(kernel_path)
    from kernels.window_process.window_process import (
        WindowProcess,
        WindowProcessReverse,
    )

except ImportError:
    WindowProcess = None
    WindowProcessReverse = None
    logger.warning("Fused window process has not been installed. Please refer to get_started.md for installation.")


class Mlp(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x
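
# Worked shape example: with H = W = 56, C = 96 and window_size = 7,
# window_partition maps (B, 56, 56, 96) to (64 * B, 7, 7, 96), i.e. 8 x 8 = 64
# non-overlapping 7x7 windows per image; window_reverse inverts this exactly.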
|
||||
|
||||
class WindowAttention(nn.Module):
|
||||
r"""Window based multi-head self attention (W-MSA) module with relative position bias.
|
||||
It supports both of shifted and non-shifted window.
|
||||
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
window_size (tuple[int]): The height and width of the window.
|
||||
num_heads (int): Number of attention heads.
|
||||
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
||||
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
|
||||
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
|
||||
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
window_size,
|
||||
num_heads,
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
attn_drop=0.0,
|
||||
proj_drop=0.0,
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.window_size = window_size # Wh, Ww
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
self.scale = qk_scale or head_dim**-0.5
|
||||
|
||||
# define a parameter table of relative position bias
|
||||
self.relative_position_bias_table = nn.Parameter(
|
||||
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
|
||||
) # 2*Wh-1 * 2*Ww-1, nH
|
||||
|
||||
# get pair-wise relative position index for each token inside the window
|
||||
coords_h = torch.arange(self.window_size[0])
|
||||
coords_w = torch.arange(self.window_size[1])
|
||||
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
||||
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
||||
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
||||
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
||||
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
|
||||
relative_coords[:, :, 1] += self.window_size[1] - 1
|
||||
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
|
||||
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
||||
self.register_buffer("relative_position_index", relative_position_index)
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
trunc_normal_(self.relative_position_bias_table, std=0.02)
|
||||
self.softmax = nn.Softmax(dim=-1)
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
"""
|
||||
Args:
|
||||
x: input features with shape of (num_windows*B, N, C)
|
||||
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
|
||||
"""
|
||||
B_, N, C = x.shape
|
||||
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = (
|
||||
qkv[0],
|
||||
qkv[1],
|
||||
qkv[2],
|
||||
) # make torchscript happy (cannot use tensor as tuple)
|
||||
|
||||
q = q * self.scale
|
||||
attn = q @ k.transpose(-2, -1)
|
||||
|
||||
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
|
||||
self.window_size[0] * self.window_size[1],
|
||||
self.window_size[0] * self.window_size[1],
|
||||
-1,
|
||||
) # Wh*Ww,Wh*Ww,nH
|
||||
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
||||
attn = attn + relative_position_bias.unsqueeze(0)
|
||||
|
||||
if mask is not None:
|
||||
nW = mask.shape[0]
|
||||
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
|
||||
attn = attn.view(-1, self.num_heads, N, N)
|
||||
attn = self.softmax(attn)
|
||||
else:
|
||||
attn = self.softmax(attn)
|
||||
|
||||
attn = self.attn_drop(attn)
|
||||
|
||||
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
return f"dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}"
|
||||
|
||||
def flops(self, N):
|
||||
# calculate flops for 1 window with token length of N
|
||||
flops = 0
|
||||
# qkv = self.qkv(x)
|
||||
flops += N * self.dim * 3 * self.dim
|
||||
# attn = (q @ k.transpose(-2, -1))
|
||||
flops += self.num_heads * N * (self.dim // self.num_heads) * N
|
||||
# x = (attn @ v)
|
||||
flops += self.num_heads * N * N * (self.dim // self.num_heads)
|
||||
# x = self.proj(x)
|
||||
flops += N * self.dim * self.dim
|
||||
return flops
|
||||
|
||||
|
||||
class SwinTransformerBlock(nn.Module):
    r"""Swin Transformer Block.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        fused_window_process (bool, optional): If True, use one fused kernel for window shift & window partition for acceleration, and similarly for the reverse. Default: False
    """

    def __init__(
        self,
        dim,
        input_resolution,
        num_heads,
        window_size=7,
        shift_size=0,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        fused_window_process=False,
    ):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # if the window size is larger than the input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim,
            window_size=to_2tuple(self.window_size),
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )

        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

        if self.shift_size > 0:
            # calculate attention mask for SW-MSA
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
            h_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None),
            )
            w_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None),
            )
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None

        self.register_buffer("attn_mask", attn_mask)
        self.fused_window_process = fused_window_process

    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # cyclic shift
        if self.shift_size > 0:
            if not self.fused_window_process:
                shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
                # partition windows
                x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
            else:
                x_windows = WindowProcess.apply(x, B, H, W, C, -self.shift_size, self.window_size)
        else:
            shifted_x = x
            # partition windows
            x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C

        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)

        # reverse cyclic shift
        if self.shift_size > 0:
            if not self.fused_window_process:
                shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C
                x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
            else:
                x = WindowProcessReverse.apply(attn_windows, B, H, W, C, self.shift_size, self.window_size)
        else:
            shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C
            x = shifted_x
        x = x.view(B, H * W, C)
        x = shortcut + self.drop_path(x)

        # FFN
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x

    def extra_repr(self) -> str:
        return (
            f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, "
            f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
        )

    def flops(self):
        flops = 0
        H, W = self.input_resolution
        # norm1
        flops += self.dim * H * W
        # W-MSA/SW-MSA
        nW = H * W / self.window_size / self.window_size
        flops += nW * self.attn.flops(self.window_size * self.window_size)
        # mlp
        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
        # norm2
        flops += self.dim * H * W
        return flops

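The shifted-window attention mask built in `__init__` above can be hard to visualize. The following stand-alone sketch reproduces the same construction with a local copy of `window_partition` (so it runs without the rest of this file); the 8×8 grid, 4×4 windows, and shift of 2 are illustrative values.

```python
import torch


def _window_partition(x, ws):
    # (1, H, W, 1) -> (num_windows, ws, ws, 1), mirroring this file's helper
    _, H, W, C = x.shape
    x = x.view(1, H // ws, ws, W // ws, ws, C)
    return x.permute(0, 1, 3, 2, 4, 5).reshape(-1, ws, ws, C)


H = W = 8
window_size, shift_size = 4, 2

# label the nine shifted regions with distinct ids
img_mask = torch.zeros((1, H, W, 1))
slices = (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None))
cnt = 0
for h in slices:
    for w in slices:
        img_mask[:, h, w, :] = cnt
        cnt += 1

# tokens from different regions must not attend to each other -> -100 bias
mask_windows = _window_partition(img_mask, window_size).view(-1, window_size * window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, -100.0).masked_fill(attn_mask == 0, 0.0)
print(attn_mask.shape)  # torch.Size([4, 16, 16]): one mask per window
```
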
class PatchMerging(nn.Module):
    r"""Patch Merging Layer.

    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."

        x = x.view(B, H, W, C)

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)

        return x

    def extra_repr(self) -> str:
        return f"input_resolution={self.input_resolution}, dim={self.dim}"

    def flops(self):
        H, W = self.input_resolution
        flops = H * W * self.dim
        flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
        return flops

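A quick shape check for `PatchMerging`, assuming the class and imports above are in scope; the batch size, grid size, and channel count are arbitrary examples.

```python
import torch

# Illustrative shapes only: B=2 samples, an 8x8 token grid, C=96 channels.
merge = PatchMerging(input_resolution=(8, 8), dim=96)
tokens = torch.randn(2, 8 * 8, 96)
out = merge(tokens)
print(out.shape)  # torch.Size([2, 16, 192]): 4x fewer tokens, 2x wider channels
```
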
class BasicLayer(nn.Module):
    """A basic Swin Transformer layer for one stage.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
        fused_window_process (bool, optional): If True, use one fused kernel for window shift & window partition for acceleration, and similarly for the reverse. Default: False
    """

    def __init__(
        self,
        dim,
        input_resolution,
        depth,
        num_heads,
        window_size,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        norm_layer=nn.LayerNorm,
        downsample=None,
        use_checkpoint=False,
        fused_window_process=False,
    ):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks
        self.blocks = nn.ModuleList(
            [
                SwinTransformerBlock(
                    dim=dim,
                    input_resolution=input_resolution,
                    num_heads=num_heads,
                    window_size=window_size,
                    shift_size=0 if (i % 2 == 0) else window_size // 2,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop,
                    attn_drop=attn_drop,
                    drop_path=(drop_path[i] if isinstance(drop_path, list) else drop_path),
                    norm_layer=norm_layer,
                    fused_window_process=fused_window_process,
                )
                for i in range(depth)
            ]
        )

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x):
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)
        if self.downsample is not None:
            x = self.downsample(x)
        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"

    def flops(self):
        flops = 0
        for blk in self.blocks:
            flops += blk.flops()
        if self.downsample is not None:
            flops += self.downsample.flops()
        return flops

class PatchEmbed(nn.Module):
    r"""Image to Patch Embedding.

    Args:
        img_size (int): Image size. Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [
            img_size[0] // patch_size[0],
            img_size[1] // patch_size[1],
        ]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert (
            H == self.img_size[0] and W == self.img_size[1]
        ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
        if self.norm is not None:
            x = self.norm(x)
        return x

    def flops(self):
        Ho, Wo = self.patches_resolution
        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        if self.norm is not None:
            flops += Ho * Wo * self.embed_dim
        return flops

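The patch-embedding shape math in a nutshell, assuming the class above is in scope; a 224-pixel input with 4-pixel patches gives a 56×56 token grid.

```python
import torch

# Illustrative only: default 224px input with 4px patches.
embed = PatchEmbed(img_size=224, patch_size=4, in_chans=3, embed_dim=96)
imgs = torch.randn(2, 3, 224, 224)
print(embed(imgs).shape)         # torch.Size([2, 3136, 96]); 3136 = 56 * 56
print(embed.patches_resolution)  # [56, 56]
```
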
class SwinTransformer(nn.Module, ResizingInterface):
    r"""Swin Transformer.

    A PyTorch impl of: `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
    https://arxiv.org/pdf/2103.14030

    Args:
        img_size (int | tuple(int)): Input image size. Default 224
        patch_size (int | tuple(int)): Patch size. Default: 4
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 1000
        embed_dim (int): Patch embedding dimension. Default: 96
        depths (tuple(int)): Depth of each Swin Transformer layer.
        num_heads (tuple(int)): Number of attention heads in different layers.
        window_size (int): Window size. Default: 7
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
        drop_rate (float): Dropout rate. Default: 0
        attn_drop_rate (float): Attention dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
        patch_norm (bool): If True, add normalization after patch embedding. Default: True
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
        fused_window_process (bool, optional): If True, use one fused kernel for window shift & window partition for acceleration, and similarly for the reverse. Default: False
    """

    def __init__(
        self,
        img_size=224,
        patch_size=4,
        in_chans=3,
        num_classes=1000,
        embed_dim=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        window_size=7,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.1,
        norm_layer=nn.LayerNorm,
        ape=False,
        patch_norm=True,
        use_checkpoint=False,
        fused_window_process=False,
        **kwargs,
    ):
        super().__init__()

        self.num_classes = num_classes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        self.mlp_ratio = mlp_ratio
        self.img_size = img_size
        self.fused_window_process = fused_window_process

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None,
        )
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution
        self.embed_layer = PatchEmbed
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.norm_layer = norm_layer

        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=0.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(
                dim=int(embed_dim * 2**i_layer),
                input_resolution=(
                    patches_resolution[0] // (2**i_layer),
                    patches_resolution[1] // (2**i_layer),
                ),
                depth=depths[i_layer],
                num_heads=num_heads[i_layer],
                window_size=window_size,
                mlp_ratio=self.mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
                norm_layer=norm_layer,
                downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                use_checkpoint=use_checkpoint,
                fused_window_process=fused_window_process,
            )
            self.layers.append(layer)

        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        self.apply(self._init_weights)

    def set_num_classes(self, n_classes):
        """Reset the classification head with a new number of classes.

        Parameters
        ----------
        n_classes : int
            new number of classes
        """
        if n_classes == self.num_classes:
            return
        self.head = nn.Linear(self.num_features, n_classes) if n_classes > 0 else nn.Identity()
        self.num_classes = n_classes

        if n_classes > 0:
            # nn.Identity has no parameters to initialize
            nn.init.trunc_normal_(self.head.weight, std=0.02)
            nn.init.constant_(self.head.bias, 0)

    def set_image_res(self, res):
        if res == self.img_size:
            return

        old_patch_embed_state = copy(self.patch_embed.state_dict())
        self.patch_embed = self.embed_layer(
            img_size=res,
            patch_size=self.patch_size,
            in_chans=self.in_chans,
            embed_dim=self.embed_dim,
            norm_layer=self.norm_layer if self.patch_norm else None,
        )
        self.patch_embed.load_state_dict(old_patch_embed_state)
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        for i_layer, layer in enumerate(self.layers):
            input_resolution = (
                patches_resolution[0] // (2**i_layer),
                patches_resolution[1] // (2**i_layer),
            )
            layer.input_resolution = input_resolution
            downsample = PatchMerging if (i_layer < self.num_layers - 1) else None
            if downsample is not None:
                layer.downsample = downsample(input_resolution, dim=layer.dim, norm_layer=self.norm_layer)

            for block in layer.blocks:
                block.input_resolution = input_resolution

                if min(input_resolution) <= block.window_size:
                    # if the window size is larger than the input resolution, we don't partition windows
                    block.shift_size = 0
                    block.window_size = min(block.input_resolution)
                assert 0 <= block.shift_size < block.window_size, "shift_size must be in [0, window_size)"

                if block.shift_size > 0:
                    # calculate attention mask for SW-MSA
                    H, W = block.input_resolution
                    img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
                    h_slices = (
                        slice(0, -block.window_size),
                        slice(-block.window_size, -block.shift_size),
                        slice(-block.shift_size, None),
                    )
                    w_slices = (
                        slice(0, -block.window_size),
                        slice(-block.window_size, -block.shift_size),
                        slice(-block.shift_size, None),
                    )
                    cnt = 0
                    for h in h_slices:
                        for w in w_slices:
                            img_mask[:, h, w, :] = cnt
                            cnt += 1

                    mask_windows = window_partition(img_mask, block.window_size)  # nW, window_size, window_size, 1
                    mask_windows = mask_windows.view(-1, block.window_size * block.window_size)
                    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
                    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
                        attn_mask == 0, float(0.0)
                    )
                else:
                    attn_mask = None

                block.register_buffer("attn_mask", attn_mask)

        if self.ape:
            orig_size = int((self.absolute_pos_embed.shape[-2]) ** 0.5)
            new_size = int(self.patch_embed.num_patches**0.5)
            pos_tokens = self.absolute_pos_embed[:, :]
            # reshape to rest x embed_dim x orig_size x orig_size
            pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, self.embed_dim).permute(0, 3, 1, 2)
            pos_tokens = nn.functional.interpolate(
                pos_tokens,
                size=(new_size, new_size),
                mode="bicubic",
                align_corners=False,
            )
            # reshape back to rest x new_size^2 x embed_dim
            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
            self.absolute_pos_embed = nn.Parameter(pos_tokens.contiguous())

        self.img_size = res

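The position-embedding resizing above reduces to a reshape, a bicubic `interpolate`, and a reshape back. A stand-alone sketch with hypothetical sizes (a 14×14 grid grown to 28×28):

```python
import torch
from torch import nn

# Hypothetical values: grow a 14x14 (=196 token) embedding to 28x28 (=784).
embed_dim, orig_size, new_size = 96, 14, 28
pos = torch.randn(1, orig_size * orig_size, embed_dim)

# tokens -> 2D grid -> bicubic resize -> tokens
grid = pos.reshape(-1, orig_size, orig_size, embed_dim).permute(0, 3, 1, 2)  # 1 x D x 14 x 14
grid = nn.functional.interpolate(grid, size=(new_size, new_size), mode="bicubic", align_corners=False)
pos_resized = grid.permute(0, 2, 3, 1).flatten(1, 2)  # back to 1 x 784 x D
print(pos_resized.shape)  # torch.Size([1, 784, 96])
```
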
    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {"absolute_pos_embed"}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {"relative_position_bias_table"}

    def forward_features(self, x):
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        for layer in self.layers:
            x = layer(x)

        x = self.norm(x)  # B L C
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.avgpool(x.transpose(1, 2))  # B C 1
        x = torch.flatten(x, 1)
        x = self.head(x)
        return x

    def flops(self):
        flops = 0
        flops += self.patch_embed.flops()
        for layer in self.layers:
            flops += layer.flops()
        flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2**self.num_layers)
        flops += self.num_features * self.num_classes
        return flops


swin_sizes = {
    "Ti": dict(embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24]),
    "S": dict(embed_dim=96, depths=[2, 2, 18, 2], num_heads=[3, 6, 12, 24]),
    "B": dict(embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32]),
    "L": dict(embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48]),
}

@register_model
def swin_tiny_patch4_window7(pretrained=False, img_size=224, **kwargs):
    if "pretrained_cfg" in kwargs:
        kwargs.pop("pretrained_cfg")
    size = swin_sizes["Ti"]
    model = SwinTransformer(img_size=img_size, patch_size=4, window_size=7, **size, **kwargs)
    return model


@register_model
def swin_small_patch4_window7(pretrained=False, img_size=224, **kwargs):
    if "pretrained_cfg" in kwargs:
        kwargs.pop("pretrained_cfg")
    size = swin_sizes["S"]
    model = SwinTransformer(img_size=img_size, patch_size=4, window_size=7, **size, **kwargs)
    return model


@register_model
def swin_base_patch4_window7(pretrained=False, img_size=224, **kwargs):
    if "pretrained_cfg" in kwargs:
        kwargs.pop("pretrained_cfg")
    size = swin_sizes["B"]
    model = SwinTransformer(img_size=img_size, patch_size=4, window_size=7, **size, **kwargs)
    return model


@register_model
def swin_large_patch4_window7(pretrained=False, img_size=224, **kwargs):
    if "pretrained_cfg" in kwargs:
        kwargs.pop("pretrained_cfg")
    size = swin_sizes["L"]
    model = SwinTransformer(img_size=img_size, patch_size=4, window_size=7, **size, **kwargs)
    return model


@register_model
def swin_ashim(pretrained=False, img_size=112, **kwargs):
    if "pretrained_cfg" in kwargs:
        kwargs.pop("pretrained_cfg")
    size = dict(embed_dim=384, depths=[12], num_heads=[12])
    if "num_heads" in kwargs:
        kwargs["num_heads"] = [kwargs["num_heads"]]
    return SwinTransformer(img_size=img_size, in_chans=3, patch_size=2, window_size=7, **{**size, **kwargs})
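Because these factories carry timm's `@register_model` decorator, they should be constructible through `timm.create_model` once this module has been imported; a minimal usage sketch (the 10-class head is an arbitrary choice, and the exact keyword handling may vary with the timm version):

```python
import timm
import torch

# Assumes this module has been imported so the @register_model decorators ran.
model = timm.create_model("swin_tiny_patch4_window7", img_size=224, num_classes=10)
logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 10])
```
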
@@ -0,0 +1,225 @@
from functools import partial

import torch
from loguru import logger
from timm.models.vision_transformer import (
    Attention,
    Block,
    PatchEmbed,
    VisionTransformer,
)

from resizing_interface import ResizingInterface


class _MatrixSaveAttn(Attention):
    attn_mat = None

    @classmethod
    def cast(cls, attn: Attention):
        assert isinstance(attn, Attention), "Can only save attention from timm's Attention class"
        attn.__class__ = cls
        assert isinstance(attn, _MatrixSaveAttn)
        return attn

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        self.attn_mat = attn.detach()
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        return self.proj_drop(x)

def _index_picker(tensor, idx=-1):
    """Pick a specific index from a tensor.

    tensor: B x N x D -> B x D, by picking idx from N.

    Args:
        tensor (torch.Tensor): tensor to pick from.
        idx (int, optional): index to pick. Defaults to -1.

    Returns:
        torch.Tensor: the selected index of the tensor.

    """
    return tensor[..., idx, :]  # B x N x D -> B x D

class TimmViT(VisionTransformer, ResizingInterface):
    """Wrapper for *VisionTransformer* from the *timm* library (https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py)."""

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        num_classes=1000,
        global_pool="token",
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_norm=False,
        init_values=None,
        class_token=True,
        no_embed_class=True,
        pre_norm=False,
        fc_norm=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        weight_init="",
        embed_layer=PatchEmbed,
        norm_layer=None,
        act_layer=None,
        block_fn=Block,
        save_attention_maps=False,
        fused_attn=True,
        **kwargs,
    ):
        """Create a Vision Transformer model.

        Parameters
        ----------
        img_size : int
            image dimensions -> img_size x img_size
        patch_size : int
            patch size
        in_chans : int
            number of image channels
        num_classes : int
            number of classes for the classification head
        global_pool : str | int
            type of global pooling for the final sequence (default: 'token'), or the index of the token to take
        embed_dim : int
            embedding dimension
        depth : int
            number of transformer layers
        num_heads : int
            number of transformer heads
        mlp_ratio : float
            ratio of feed-forward (mlp) hidden dimension to embedding dimension
        qkv_bias : bool
            enable bias for query, key, and value (qkv) embeddings
        qk_norm : bool
            normalize query and key embeddings
        init_values : float
            layer scale initial values
        class_token : bool
            use a class token [CLS]
        no_embed_class : bool
            no positional embedding for the class token
        pre_norm : bool
            use pre-norm architecture (norm before the blocks, not after)
        fc_norm : bool
            norm after pool (used when global_pool == 'avg')
        drop_rate : float
            dropout rate
        attn_drop_rate : float
            dropout rate in the attention module
        drop_path_rate : float
            drop path rate (stochastic depth)
        weight_init : str
            scheme for weight initialization
        embed_layer : nn.Module
            patch embedding layer
        norm_layer : nn.Module
            normalization layer
        act_layer : nn.Module
            activation function
        block_fn : nn.Module
            which block structure to use, e.g. for parallel attention layers
        save_attention_maps : bool
            save attention maps for each block
        fused_attn : bool
            use fused attention
        kwargs : dict
            additional arguments (will be ignored)

        """
        init_kwargs = dict(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            num_classes=num_classes,
            global_pool=global_pool if isinstance(global_pool, str) else "avg",
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            init_values=init_values,
            class_token=class_token,
            no_embed_class=no_embed_class,
            pre_norm=pre_norm,
            fc_norm=fc_norm,
            drop_rate=drop_rate,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=drop_path_rate,
            weight_init=weight_init,
            embed_layer=embed_layer,
            norm_layer=norm_layer,
            act_layer=act_layer,
            block_fn=block_fn,
        )

        self.new_version = False  # TODO: check based on the timm version
        self.global_pool = global_pool
        if isinstance(global_pool, int):
            self.attn_pool = partial(_index_picker, idx=global_pool)

        if self.new_version:
            init_kwargs["qk_norm"] = qk_norm
            init_kwargs["proj_drop_rate"] = drop_rate
        super(TimmViT, self).__init__(**init_kwargs)
        self.embed_layer = embed_layer
        self.img_size = img_size
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.pre_norm = pre_norm
        self.class_token = class_token
        self.no_embed_class = no_embed_class
        self.num_classes = num_classes
        self.num_heads = num_heads
        self.depth = depth
        self.save_attention_maps = save_attention_maps
        if save_attention_maps:
            self.do_save_attention_maps()
        try:
            for block in self.blocks:
                block.attn.fused_attn = fused_attn and self.blocks[0].attn.fused_attn
            use_fused = self.blocks[0].attn.fused_attn
            logger.info(f"Use fused attention: {use_fused}")
        except AttributeError:
            # non-standard block types may not expose .attn / .fused_attn
            pass

    def do_save_attention_maps(self):
        self.save_attention_maps = True
        for block in self.blocks:
            block.attn = _MatrixSaveAttn.cast(block.attn)

    def attention_maps(self):
        assert self.save_attention_maps, "Have to save attention maps first"
        return [getattr(block.attn, "attn_mat", None) for block in self.blocks]

    def forward_features(self, x):
        x = self.patch_embed(x)
        x = self._pos_embed(x)
        if self.new_version:
            x = self.patch_drop(x)
        x = self.norm_pre(x)
        x = self.blocks(x)
        return self.norm(x)

    def forward(self, x):
        x = self.forward_features(x)
        return self.forward_head(x)
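A usage sketch for the attention-map hook, assuming this module's imports resolve (in particular `resizing_interface`) and a timm version matching the constructor above. The tiny model configuration is arbitrary, and the printed shape assumes the default class token (196 patches + 1 = 197 tokens):

```python
import torch

# Illustrative: a small ViT with attention-map saving enabled.
vit = TimmViT(img_size=224, patch_size=16, embed_dim=192, depth=4, num_heads=3,
              num_classes=10, save_attention_maps=True, fused_attn=False)
_ = vit(torch.randn(1, 3, 224, 224))
maps = vit.attention_maps()      # one saved attention matrix per block
print(len(maps), maps[0].shape)  # 4 torch.Size([1, 3, 197, 197])
```
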
118
AAAI Supplementary Material/Model Training Code/config.py
Normal file
@@ -0,0 +1,118 @@
"""Default configuration parameters.

Attributes:
    default_kwargs (dict): Default hyperparameters for the training process.
    slurm_defaults (dict): Default values for SLURM batch job settings.

"""

from paths_config import user

default_kwargs = {
    "amp": True,
    "aug_color_jitter_factor": 0.3,
    "aug_crop": True,
    "aug_cutmix_alpha": 1.0,
    "aug_flip": True,
    "aug_gauss_blur": True,
    "aug_grayscale": True,
    "aug_mixup_alpha": 0.0,
    "aug_normalize": True,
    "aug_rand_rot": 0,
    "aug_random_erase_count": 1,
    "aug_random_erase_mode": "pixel",
    "aug_random_erase_prob": 0.0,
    "aug_repeated_augment_repeats": 1,
    "aug_resize": True,
    "aug_solarize": True,
    "augment_engine": "torchvision",
    "augment_strategy": "3-augment",
    "auto_augment_strategy": "rand-m9-mstd0.5-inc1",
    "batch_size": 2048,
    "compile_model": False,
    "cuda": True,
    "custom_dataset_path": None,
    "debug": False,
    "drop_path_rate": 0.05,
    "dropout": 0.0,
    "eval_amp": True,
    "experiment_name": "none",
    "fused_attn": True,
    "gather_stats_during_training": True,
    "imsize": 224,
    "input_dim": None,
    "keep_interm_states": 2,
    "label_smoothing": 0.1,
    "layer_scale": True,
    "layer_scale_init_values": 1e-4,
    "log_level": "info",
    "loss": "ce",
    "loss_weight": "none",
    "lr": 3e-3,
    "max_grad_norm": 1.0,
    "max_seq_len": None,
    "min_lr": 1e-5,
    "momentum": 0.0,
    "num_heads": None,
    "num_workers": 44,
    "opt": "fusedlamb",
    "opt_eps": 1e-7,
    "pin_memory": False,
    "pre_norm": False,
    "prefetch_factor": 2,
    "qkv_bias": True,
    "run_name": None,
    "save_epochs": 10,
    "sched": "cosine",
    "seed": None,
    "shuffle": True,
    "tqdm": True,
    "wandb": True,
    "warmup_epochs": 5,
    "warmup_lr": 1e-6,
    "warmup_sched": "linear",
    "weight_decay": 0.02,
    "weighted_sampler": False,
}
# alternatively: 'model_ema': True, 'model_ema_decay': 0.99996


deit_kwargs = {
    "aug_mixup_alpha": 0.8,
    "aug_repeated_augment_repeats": 3,
    "augment_strategy": "deit",
    "aug_random_erase_prob": 0.25,
    "batch_size": 1024,
    "lr": 1e-3,
    "max_grad_norm": 0.0,
    "num_workers": 10,
    "opt": "adamw",
    "opt_eps": 1e-8,
    "weight_decay": 0.05,
}


def get_default_kwargs(settings="deitiii"):
    if settings.lower() == "deitiii":
        return default_kwargs
    if settings.lower() == "deit":
        return {**default_kwargs, **deit_kwargs}
    raise NotImplementedError(f"No such defaults setting: {settings}")


slurm_defaults = {
    "after_job": None,
    "container_image": "PATH/TO/ENROOT/IMAGE",
    "container_mounts": 'MOUNT_ALL_IMPORTANT_STORAGE_SERVERS_HERE,"`pwd`":"`pwd`"',
    "container_workdir": '"`pwd`"',
    "cpus_per_task": 24,
    "exclude": None,
    "export": "ALL,TQDM_DISABLE=1",
    "job_name": None,
    "mem_per_gpu": 90,
    "nodes": 1,
    "ntasks": 4,
    "partition": ["A100", "H100", "H200"],
    "task_prolog": None,
    "time": "1-0",
}
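A sketch of how these defaults are typically consumed, assuming this file is importable as `config` (which requires a valid `paths_config`); the overridden values are arbitrary examples:

```python
from argparse import Namespace

from config import get_default_kwargs

# Start from the DeiT-style defaults and override a couple of fields.
args = Namespace(**{**get_default_kwargs("deit"), "batch_size": 256, "imsize": 192})
print(args.lr, args.batch_size)  # 0.001 256
```
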
@@ -0,0 +1,142 @@
import albumentations as A
import cv2
import numpy as np
from albumentations.pytorch import ToTensorV2
from datadings.torch import CompressedToPIL


class AlbumTorchCompose(A.Compose):
    """Compose albumentation augmentations in a way that works with PIL images and datadings."""

    def __init__(self, *args, **kwargs):
        """Pass to A.Compose."""
        super().__init__(*args, **kwargs)
        self.to_pil = CompressedToPIL()

    def __call__(self, image, mask=None, **kwargs):
        if isinstance(image, bytes):
            image = self.to_pil(image)
        if mask is not None and len(mask) == 0:
            mask = None
        if not isinstance(image, np.ndarray):
            image = np.array(image)
        if mask is not None and not isinstance(mask, np.ndarray):
            mask = np.array(mask)
        if mask is None:
            return super().__call__(image=image, **kwargs)["image"]
        return super().__call__(image=image, mask=mask, **kwargs)


class PILToNP(A.DualTransform):
    """Convert PIL image to numpy array."""

    def apply(self, image, **params):
        return np.array(image)

    def apply_to_mask(self, image, **params):
        return np.array(image)

    def get_transform_init_args_names(self):
        return ()


class AlbumCompressedToPIL(A.DualTransform):
    """Convert compressed image to PIL image."""

    def apply(self, img, **params):
        return self.to_pil(img)

    def apply_to_mask(self, img, **params):
        return self.to_pil(img)

    def get_transform_init_args_names(self):
        return ()


def minimal_augment(args, test=False):
    """Get minimal augmentations for training or testing.

    Args:
        args (argparse.Namespace): arguments
        test (bool, optional): if True, return test augmentations. Defaults to False.

    Returns:
        List: Augmentation list
    """
    augs = []

    if args.aug_resize:
        augs.append(A.SmallestMaxSize(args.imsize, interpolation=cv2.INTER_CUBIC))

    if test and args.aug_crop:
        augs.append(A.CenterCrop(args.imsize, args.imsize))
    elif args.aug_crop:
        augs.append(A.RandomCrop(args.imsize, args.imsize))

    if not test and args.aug_flip:
        augs.append(A.HorizontalFlip(p=0.5))

    if args.aug_normalize:
        augs.append(A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)))

    augs.append(ToTensorV2())
    return augs


def three_augment(args, as_list=False, test=False):
    """Create the data augmentation.

    Args:
        args (Namespace): arguments
        as_list (bool): return list of transformations, not composed transformation
        test (bool): In eval mode? If False => train mode

    Returns:
        torch.nn.Module | list[torch.nn.Module]: composed transformation or list of transformations

    """
    augs = []

    if args.aug_resize:
        augs.append(A.SmallestMaxSize(args.imsize, interpolation=cv2.INTER_CUBIC))

    if test and args.aug_crop:
        augs.append(A.CenterCrop(args.imsize, args.imsize))
    elif args.aug_crop:
        augs.append(A.RandomCrop(args.imsize, args.imsize, pad_if_needed=True, border_mode=cv2.BORDER_REFLECT))

    if not test:
        if args.aug_flip:
            augs.append(A.HorizontalFlip(p=0.5))

        augs_choice = []
        if args.aug_grayscale:
            augs_choice.append(A.ToGray(p=1, num_output_channels=3))

        if args.aug_solarize:
            augs_choice.append(A.Solarize(p=1, threshold_range=(0.5, 0.5)))

        if args.aug_gauss_blur:
            augs_choice.append(A.GaussianBlur(p=1, sigma_limit=(0.2, 2.0), blur_limit=(7, 7)))

        if len(augs_choice) > 0:
            augs.append(A.OneOf(augs_choice))

    if args.aug_color_jitter_factor > 0.0:
        augs.append(
            A.ColorJitter(
                brightness=args.aug_color_jitter_factor,
                contrast=args.aug_color_jitter_factor,
                saturation=args.aug_color_jitter_factor,
                hue=0.0,
            )
        )

    if args.aug_normalize:
        augs.append(A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)))

    augs.append(ToTensorV2())

    if as_list:
        return augs
    return AlbumTorchCompose(augs)
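A minimal usage sketch for `three_augment`, assuming a recent albumentations version (the `threshold_range` and `num_output_channels` arguments above require one). The argument values mirror the repository defaults:

```python
from types import SimpleNamespace

import numpy as np

# Illustrative args covering the flags three_augment reads.
args = SimpleNamespace(
    imsize=224, aug_resize=True, aug_crop=True, aug_flip=True,
    aug_grayscale=True, aug_solarize=True, aug_gauss_blur=True,
    aug_color_jitter_factor=0.3, aug_normalize=True,
)
transform = three_augment(args)  # AlbumTorchCompose
img = (np.random.rand(256, 320, 3) * 255).astype(np.uint8)
print(transform(image=img).shape)  # torch.Size([3, 224, 224])
```
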
@@ -0,0 +1,63 @@
import os

from PIL import Image
from torch.utils.data import Dataset


class CounterAnimal(Dataset):
    """Dataset to load the CounterAnimal dataset with ImageNet labels."""

    def __init__(self, base_path, mode, transform=None, target_transform=None, train=False):
        """Create the dataset.

        Args:
            base_path (str): path to the base folder (the one containing the class folders)
            mode (str): mode/variant of the dataset (common/counter)
            transform: image augmentation
            target_transform: label augmentation
            train: train or test set. The train set is not supported.
        """
        super().__init__()
        self.base = base_path
        assert mode in ["counter", "common"], f"Supported modes are counter and common, but got '{mode}'"
        assert not train, "CounterAnimal only consists of test data, not training data."
        self.transform = transform
        self.target_transform = target_transform

        # build an index of (file path, class index) pairs
        self.index = []
        for class_folder in os.listdir(self.base):
            if not os.path.isdir(os.path.join(self.base, class_folder)):
                continue
            class_idx = int(class_folder.split(" ")[0])
            for variant_folder in os.listdir(os.path.join(self.base, class_folder)):
                if not variant_folder.startswith(mode):
                    continue

                _folder = os.path.join(self.base, class_folder, variant_folder)
                for file in os.listdir(_folder):
                    if file.lower().split(".")[-1] in ["jpg", "jpeg", "png"]:
                        self.index.append((os.path.join(_folder, file), class_idx))

        assert len(self.index) > 0, "did not find any images :("

    def __len__(self):
        return len(self.index)

    def __getitem__(self, idx):
        path, label = self.index[idx]

        img = Image.open(path).convert("RGB")

        if self.transform:
            img = self.transform(img)

        if self.target_transform:
            label = self.target_transform(label)

        return img, label
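A usage sketch with a hypothetical dataset path; the folder layout expected by the indexer is `<class idx> <class name>/<common|counter>.../image.jpg`:

```python
from torchvision import transforms as T

# "/path/to/CounterAnimal" is a placeholder; point it at the extracted dataset.
ds = CounterAnimal(
    "/path/to/CounterAnimal",
    mode="counter",
    transform=T.Compose([T.Resize((224, 224)), T.ToTensor()]),
)
img, label = ds[0]
print(len(ds), img.shape, label)
```
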
@@ -0,0 +1,145 @@
from nvidia.dali import fn, pipeline_def, types

# see https://docs.nvidia.com/deeplearning/dali/user-guide/docs/plugins/pytorch_dali_proxy.html


@pipeline_def
def minimal_augment(args, test=False):
    """Minimal augmentation set for images.

    Contains only resize, crop, flip, to tensor and normalize.

    Args:
        args (DotDict): Arguments: aug_resize, aug_crop, aug_flip, aug_normalize to turn on/off the respective augmentation.
        test (bool, optional): On the test set? Defaults to False.

    Returns:
        images: augmented images.

    """
    images = fn.external_source(name="images", no_copy=True)

    if args.aug_resize:
        images = fn.resize(images, size=args.imsize, mode="not_smaller")

    if test and args.aug_crop:
        images = fn.crop(images, crop=(args.imsize, args.imsize), crop_pos_x=0.5, crop_pos_y=0.5)
    elif args.aug_crop:
        images = fn.crop(
            images,
            crop=(args.imsize, args.imsize),
            crop_pos_x=fn.random.uniform(range=(0, 1)),
            crop_pos_y=fn.random.uniform(range=(0, 1)),
        )

    # flip and normalization are folded into the fused crop_mirror_normalize below
    return fn.crop_mirror_normalize(
        images,
        dtype=types.FLOAT,
        output_layout="CHW",
        crop=(args.imsize, args.imsize),
        mean=[0.485 * 255, 0.456 * 255, 0.406 * 255] if args.aug_normalize else [0, 0, 0],
        std=[0.229 * 255, 0.224 * 255, 0.225 * 255] if args.aug_normalize else [1, 1, 1],
        mirror=fn.random.coin_flip(probability=0.5) if args.aug_flip and not test else False,
    )


def dali_solarize(images, threshold=128):
    """Solarize implementation for NVIDIA DALI.

    Args:
        images (DALI Tensor): Images to solarize.
        threshold (int, optional): Threshold for solarization. Defaults to 128.

    Returns:
        images: solarized images.

    """
    inv_images = types.Constant(255).uint8() - images
    mask = (images >= threshold) * types.Constant(1).uint8()
    return mask * inv_images + (types.Constant(1).uint8() ^ mask) * images


@pipeline_def(enable_conditionals=True)
def three_augment(args, test=False):
    """3-augment data augmentation pipeline for NVIDIA DALI.

    Args:
        args (namespace): augmentation arguments.
        test (bool, optional): Test (or train) split. Defaults to False.

    Returns:
        images: augmented images.

    """
    images = fn.external_source(name="images", no_copy=True)

    if args.aug_resize:
        images = fn.resize(images, size=args.imsize, mode="not_smaller")

    if test and args.aug_crop:
        images = fn.crop(images, crop=(args.imsize, args.imsize), crop_pos_x=0.5, crop_pos_y=0.5)
    elif args.aug_crop:
        images = fn.crop(
            images,
            crop=(args.imsize, args.imsize),
            crop_pos_x=fn.random.uniform(range=(0, 1)),
            crop_pos_y=fn.random.uniform(range=(0, 1)),
        )

    if not test:
        # randomly pick one of grayscale / solarize / gaussian blur, weighted by the enabled flags
        choice_ps = [1 * args.aug_grayscale, 1 * args.aug_solarize, 1 * args.aug_gauss_blur]
        choice_ps = [c / sum(choice_ps) for c in choice_ps]
        choice = fn.random.choice(
            [0, 1, 2],
            p=choice_ps,
        )

        if choice == 0:
            images = fn.color_space_conversion(
                fn.color_space_conversion(images, image_type=types.RGB, output_type=types.GRAY),
                image_type=types.GRAY,
                output_type=types.RGB,
            )
        elif choice == 1:
            images = dali_solarize(images, threshold=128)
        elif choice == 2:
            images = fn.gaussian_blur(images, window_size=7, sigma=fn.random.uniform(range=(0.2, 2.0)))

    if args.aug_color_jitter_factor > 0.0:
        images = fn.color_twist(
            images,
            brightness=fn.random.uniform(
                range=(1 - args.aug_color_jitter_factor, 1 + args.aug_color_jitter_factor)
            ),
            contrast=fn.random.uniform(range=(1 - args.aug_color_jitter_factor, 1 + args.aug_color_jitter_factor)),
            saturation=fn.random.uniform(
                range=(1 - args.aug_color_jitter_factor, 1 + args.aug_color_jitter_factor)
            ),
        )

    return fn.crop_mirror_normalize(
        images,
        dtype=types.FLOAT,
        output_layout="CHW",
        crop=(args.imsize, args.imsize),
        mean=[0.485 * 255, 0.456 * 255, 0.406 * 255] if args.aug_normalize else [0, 0, 0],
        std=[0.229 * 255, 0.224 * 255, 0.225 * 255] if args.aug_normalize else [1, 1, 1],
        mirror=fn.random.coin_flip(probability=0.5) if args.aug_flip and not test else False,
    )
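The mask arithmetic in `dali_solarize` can be checked outside DALI; this NumPy re-implementation (plain NumPy, not DALI code) applies the same formula to uint8 values:

```python
import numpy as np


def np_solarize(images, threshold=128):
    # Same mask trick as dali_solarize: invert only where the pixel >= threshold.
    inv = (255 - images).astype(np.uint8)
    mask = (images >= threshold).astype(np.uint8)
    return mask * inv + (1 ^ mask) * images


x = np.array([[0, 100, 127, 128, 200, 255]], dtype=np.uint8)
print(np_solarize(x))  # [[  0 100 127 127  55   0]]
```
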
@@ -0,0 +1,407 @@
|
||||
from random import uniform
|
||||
|
||||
import msgpack
|
||||
import torch
|
||||
import torchvision
|
||||
from datadings.torch import CompressedToPIL
|
||||
from datadings.torch import Dataset as DDDataset
|
||||
from PIL import ImageFilter
|
||||
from torchvision.transforms import (
|
||||
CenterCrop,
|
||||
ColorJitter,
|
||||
Compose,
|
||||
GaussianBlur,
|
||||
Grayscale,
|
||||
InterpolationMode,
|
||||
Normalize,
|
||||
RandomChoice,
|
||||
RandomCrop,
|
||||
RandomHorizontalFlip,
|
||||
RandomResizedCrop,
|
||||
RandomSolarize,
|
||||
Resize,
|
||||
ToTensor,
|
||||
)
|
||||
from torchvision.transforms import functional as F
|
||||
|
||||
_image_and_target_transforms = [
|
||||
torchvision.transforms.RandomCrop,
|
||||
torchvision.transforms.RandomHorizontalFlip,
|
||||
torchvision.transforms.CenterCrop,
|
||||
torchvision.transforms.RandomRotation,
|
||||
torchvision.transforms.RandomAffine,
|
||||
torchvision.transforms.RandomResizedCrop,
|
||||
torchvision.transforms.RandomRotation,
|
||||
]
|
||||
|
||||
|
||||
def apply_dense_transforms(x, y, transforms: torchvision.transforms.transforms.Compose):
|
||||
"""Apply some transfomations to both image and target.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): image
|
||||
y (torch.Tensor): target (image)
|
||||
transforms (torchvision.transforms.transforms.Compose): transformations to apply
|
||||
|
||||
Returns:
|
||||
tuple[torch.Tensor, torch.Tensor]: (x, y) with applyed transformations
|
||||
|
||||
"""
|
||||
for trans in transforms.transforms:
|
||||
if isinstance(trans, torchvision.transforms.RandomResizedCrop):
|
||||
params = trans.get_params(x, trans.scale, trans.ratio)
|
||||
x = F.resized_crop(x, *params, trans.size, trans.interpolation, antialias=trans.antialias)
|
||||
y = F.resized_crop(y.unsqueeze(0), *params, trans.size, 0).squeeze(0) # nearest neighbor interpolation
|
||||
elif isinstance(trans, Resize):
|
||||
pre_shape = x.shape
|
||||
x = trans(x)
|
||||
if x.shape != pre_shape:
|
||||
y = F.resize(y.unsqueeze(0), trans.size, 0, trans.max_size, trans.antialias).squeeze(
|
||||
0
|
||||
) # nearest neighbor interpolation
|
||||
elif any(isinstance(trans, simul_transform) for simul_transform in _image_and_target_transforms):
|
||||
xy = torch.cat([x, y.unsqueeze(0).float()], dim=0)
|
||||
xy = trans(xy)
|
||||
x, y = xy[:-1], xy[-1].long()
|
||||
elif isinstance(trans, torchvision.transforms.ToTensor):
|
||||
if not isinstance(x, torch.Tensor):
|
||||
x = trans(x)
|
||||
else:
|
||||
x = trans(x)
|
||||
|
||||
return x, y
|
||||
|
||||
|
||||
def get_hf_transform(transform_f, trgt_transform_f=None, image_key="image"):
|
||||
"""Convert the transform function to a huggingface compatible transform function.
|
||||
|
||||
Args:
|
||||
transform_f (callable): Image transform.
|
||||
trgt_transform (callable, optional): Target transform. Defaults to None.
|
||||
image_key (str, optional): Key for the image in the hf ds return dict. Defaults to "image".
|
||||
"""
|
||||
|
||||
def _transform(samples):
|
||||
try:
|
||||
samples[image_key] = [transform_f(im) for im in samples[image_key]]
|
||||
if trgt_transform_f is not None:
|
||||
samples["label"] = [trgt_transform_f(tgt) for tgt in samples["label"]]
|
||||
except TypeError as e:
|
||||
print(f"Type error when transforming samples: {samples}")
|
||||
raise e
|
||||
return samples
|
||||
|
||||
return _transform
|
||||
|
||||
|
||||
class DDDecodeDataset(DDDataset):
|
||||
"""Datadings dataset with image decoding before transform."""
|
||||
|
||||
def __init__(self, *args, transform=None, target_transform=None, transforms=None, **kwargs):
|
||||
"""Create datadings dataset.
|
||||
|
||||
Args:
|
||||
transform (callable, optional): Image transform. Overrides transforms['image']. Defaults to None.
|
||||
target_transform (callable, optional): Label transform. Overrides transforms['label']. Defaults to None.
|
||||
transforms (dict[str, callable], optional): Dict of transforms for each key. Defaults to None.
|
||||
"""
|
||||
super().__init__(*args, **kwargs)
|
||||
if transforms is None:
|
||||
transforms = {}
|
||||
self._decode_transform = transform if transform is not None else transforms.get("image", None)
|
||||
self._decode_target_transform = (
|
||||
target_transform if target_transform is not None else transforms.get("label", None)
|
||||
)
|
||||
|
||||
self.ctp = CompressedToPIL()
|
||||
|
||||
def __getitem__(self, idx):
|
||||
sample = super().__getitem__(idx)
|
||||
img, lbl = sample["image"], sample["label"]
|
||||
if isinstance(img, bytes):
|
||||
img = self.ctp(img)
|
||||
if self._decode_transform is not None:
|
||||
img = self._decode_transform(img)
|
||||
if self._decode_target_transform is not None:
|
||||
lbl = self._decode_target_transform(lbl)
|
||||
return img, lbl
|
||||
|
||||
|
||||
def minimal_augment(args, test=False):
|
||||
"""Minimal Augmentation set for images.
|
||||
|
||||
Contains only resize, crop, flip, to tensor and normalize.
|
||||
|
||||
Args:
|
||||
args (DotDict): Arguments: aug_resize, aug_crop, aug_flip, aug_normalize to turn on/off the respective augmentation.
|
||||
test (bool, optional): On the test set? Defaults to False.
|
||||
|
||||
Returns:
|
||||
List: Augmentation list
|
||||
|
||||
"""
|
||||
augs = []
|
||||
augs.append(ToTensor())
|
||||
|
||||
if args.aug_resize:
|
||||
augs.append(Resize(args.imsize, interpolation=InterpolationMode.BICUBIC))
|
||||
|
||||
if test and args.aug_crop:
|
||||
augs.append(CenterCrop(args.imsize))
|
||||
elif args.aug_crop:
|
||||
augs.append(RandomCrop(args.imsize, padding=4, padding_mode="reflect"))
|
||||
|
||||
if not test and args.aug_flip:
|
||||
augs.append(RandomHorizontalFlip(p=0.5))
|
||||
|
||||
if args.aug_normalize:
|
||||
augs.append(
|
||||
Normalize(
|
||||
mean=torch.tensor([0.485, 0.456, 0.406]),
|
||||
std=torch.tensor([0.229, 0.224, 0.225]),
|
||||
)
|
||||
)
|
||||
|
||||
return augs
|
||||
|
||||
|
||||
def three_augment(args, as_list=False, test=False):
|
||||
"""Create the data augmentation.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
||||
Args:
|
||||
arguments
|
||||
as_list : bool
|
||||
return list of transformations, not composed transformation
|
||||
test : bool
|
||||
In eval mode? If False => train mode
|
||||
|
||||
Returns:
|
||||
-------
|
||||
torch.nn.Module | list[torch.nn.Module]
|
||||
composed transformation of list of transformations
|
||||
|
||||
"""
|
||||
augs = []
|
||||
augs.append(ToTensor())
|
||||
|
||||
if args.aug_resize:
|
||||
augs.append(Resize(args.imsize, interpolation=InterpolationMode.BICUBIC))
|
||||
|
||||
if test and args.aug_crop:
|
||||
augs.append(CenterCrop(args.imsize))
|
||||
elif args.aug_crop:
|
||||
augs.append(RandomCrop(args.imsize, padding=4, padding_mode="reflect"))
|
||||
|
||||
if not test:
|
||||
if args.aug_flip:
|
||||
augs.append(RandomHorizontalFlip(p=0.5))
|
||||
|
||||
augs_choice = []
|
||||
if args.aug_grayscale:
|
||||
augs_choice.append(Grayscale(num_output_channels=3))
|
||||
if args.aug_solarize:
|
||||
augs_choice.append(RandomSolarize(threshold=0.5, p=1.0))
|
||||
if args.aug_gauss_blur:
|
||||
# TODO: check kernel size?
|
||||
augs_choice.append(GaussianBlur(kernel_size=7, sigma=(0.2, 2.0)))
|
||||
# augs_choice.append(QuickGaussBlur())
|
||||
|
||||
if len(augs_choice) > 0:
|
||||
augs.append(RandomChoice(augs_choice))
|
||||
|
||||
if args.aug_color_jitter_factor > 0.0:
|
||||
augs.append(
|
||||
ColorJitter(
|
||||
args.aug_color_jitter_factor,
|
||||
args.aug_color_jitter_factor,
|
||||
args.aug_color_jitter_factor,
|
||||
)
|
||||
)
|
||||
|
||||
if args.aug_normalize:
|
||||
augs.append(
|
||||
Normalize(
|
||||
mean=torch.tensor([0.485, 0.456, 0.406]),
|
||||
std=torch.tensor([0.229, 0.224, 0.225]),
|
||||
)
|
||||
)
|
||||
|
||||
if as_list:
|
||||
return augs
|
||||
return Compose(augs)
|
||||
|
||||
|
||||
def segment_augment(args, test=False):
|
||||
"""Create the data augmentation for segmentation.
|
||||
|
||||
No cropping in this part, as cropping has to be done for the image and labels simultaneously.
|
||||
|
||||
Args:
|
||||
args (DotDict): arguments
|
||||
test (bool, optional): In eval mode? If False => train mode. Defaults to False.
|
||||
|
||||
Returns:
|
||||
list[torch.nn.Module]: list of transformations
|
||||
|
||||
"""
|
||||
augs = []
|
||||
|
||||
if test:
|
||||
augs.append(ResizeUp(args.imsize))
|
||||
augs.append(CenterCrop(args.imsize))
|
||||
else:
|
||||
augs.append(RandomResizedCrop(args.imsize, scale=(0.08, 1.0), ratio=(3.0 / 4.0, 4.0 / 3.0)))
|
||||
|
||||
if not test and args.aug_flip:
|
||||
augs.append(RandomHorizontalFlip(p=0.5))
|
||||
|
||||
if args.aug_color_jitter_factor > 0.0:
|
||||
augs.append(
|
||||
ColorJitter(
|
||||
args.aug_color_jitter_factor,
|
||||
args.aug_color_jitter_factor,
|
||||
args.aug_color_jitter_factor,
|
||||
)
|
||||
)
|
||||
|
||||
if args.aug_normalize:
|
||||
augs.append(
|
||||
Normalize(
|
||||
mean=torch.tensor([0.485, 0.456, 0.406]),
|
||||
std=torch.tensor([0.229, 0.224, 0.225]),
|
||||
)
|
||||
)
|
||||
|
||||
return augs
|
||||
|
||||
|
||||
class QuickGaussBlur:
|
||||
"""Gaussian blur transformation using PIL ImageFilter."""
|
||||
|
||||
def __init__(self, sigma=(0.2, 2.0)):
|
||||
"""Create Gaussian blur operator.
|
||||
|
||||
Args:
|
||||
-----
|
||||
sigma : tuple[float, float]
|
||||
range of sigma for blur
|
||||
|
||||
"""
|
||||
self.sigma_min, self.sigma_max = sigma
|
||||
|
||||
def __call__(self, img):
|
||||
return img.filter(ImageFilter.GaussianBlur(radius=uniform(self.sigma_min, self.sigma_max)))
|
||||
|
||||
|
||||
class RemoveTransform:
|
||||
"""Remove data from transformation.
|
||||
|
||||
To use with default collate function.
|
||||
"""
|
||||
|
||||
def __call__(self, x, y=None):
|
||||
if y is None:
|
||||
return [1]
|
||||
return [1], y
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__name__}()"
|
||||
|
||||
|
||||
def collate_imnet(data, image_key="image"):
|
||||
"""Collate function for imagenet(1k / 21k) with datadings.
|
||||
|
||||
Args:
|
||||
----
|
||||
data : list[dict[str, Any]]
|
||||
images for a batch
|
||||
|
||||
Returns:
|
||||
-------
|
||||
tuple[torch.Tensor, torch.Tensor]
|
||||
images, labels
|
||||
|
||||
"""
|
||||
if isinstance(data[0][image_key], torch.Tensor):
|
||||
ims = torch.stack([d[image_key] for d in data], dim=0)
|
||||
else:
|
||||
ims = [d[image_key] for d in data]
|
||||
labels = torch.tensor([d["label"] for d in data])
|
||||
# keys = [d['key'] for d in data]
|
||||
return ims, labels # , keys
|
||||
|
||||
|
||||
def collate_listops(data):
|
||||
"""Collate function for ListOps with datadings.
|
||||
|
||||
Args:
|
||||
----
|
||||
data : list[tuple[torch.Tensor, torch.Tensor]]
|
||||
list of samples
|
||||
|
||||
Returns:
|
||||
-------
|
||||
tuple[torch.Tensor, torch.Tensor]
|
||||
images, labels
|
||||
|
||||
"""
|
||||
return data[0][0], data[0][1]
|
||||
|
||||
|
||||
def no_param_transf(self, sample):
    """Call transformation without extra parameter.

    To use with datadings QuasiShuffler.

    Args:
        self (object): use this as a method ( <obj>.<method_name> = MethodType(no_param_transf, <obj>) )
        sample (Any): sample to transform

    Returns:
        Any: transformed sample

    """
    if isinstance(sample, tuple):
        # sample of type (name (str), data (bytes encoded))
        sample = sample[1]
    if isinstance(sample, bytes):
        # decode msgpack bytes
        sample = msgpack.loads(sample)
    params = self._rng(sample)
    for k, f in self._transforms.items():
        sample[k] = f(sample[k], params)
    return sample

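A short sketch of the `MethodType` binding the docstring mentions. The stand-in class below is a made-up minimal object exposing the two private attributes `no_param_transf` relies on; the real target would be a datadings compose object:

```python
from types import MethodType


class _FakeCompose:
    """Minimal stand-in with the attributes no_param_transf relies on."""

    def __init__(self):
        self._transforms = {"image": lambda x, params: x}  # identity transform
        self._rng = lambda sample: {}                      # no random params


obj = _FakeCompose()
obj.transform = MethodType(no_param_transf, obj)  # bind as a method, as the docstring suggests
print(obj.transform({"image": 42}))  # {'image': 42}
```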
class ToOneHotSequence:
    """Convert a sequence of grayscale values (range 0 to 1) to a one-hot encoded sequence based on 8-bit rounded values."""

    def __call__(self, x, y=None):
        # x is 1 x 32 x 32
        x = (255 * x).round().to(torch.int64).view(-1)
        assert x.max() < 256, f"Found max value {x.max()} in {x}."
        x = torch.nn.functional.one_hot(x, num_classes=256).float()
        if y is None:
            return x
        return x, y

    def __repr__(self):
        return f"{self.__class__.__name__}()"

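A quick shape check, as an illustration of what the transform produces:

```python
import torch

to_one_hot = ToOneHotSequence()
x = torch.rand(1, 32, 32)  # grayscale image in [0, 1]
seq = to_one_hot(x)
print(seq.shape)           # torch.Size([1024, 256]): one 256-way one-hot per pixel
```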
class ResizeUp(Resize):
    """Resize up only if the image is smaller than the target size."""

    def forward(self, img):
        h, w = img.shape[-2], img.shape[-1]
        if h < self.size or w < self.size:
            img = super().forward(img)
        return img

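A small sketch of the pass-through behavior, assuming tensor inputs and an integer target size (exact resize semantics may vary slightly across torchvision versions):

```python
import torch

up = ResizeUp(224)
small = torch.rand(3, 100, 150)
big = torch.rand(3, 300, 400)
print(up(small).shape)  # torch.Size([3, 224, 336]): smaller edge scaled up to 224
print(up(big) is big)   # True: large images pass through unchanged
```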
484
AAAI Supplementary Material/Model Training Code/data/fornet.py
Normal file
@@ -0,0 +1,484 @@
import json
import os
import zipfile
from io import BytesIO
from math import floor

import numpy as np
import PIL
from datadings.torch import Compose
from loguru import logger
from PIL import Image, ImageFilter
from torch.utils.data import Dataset, get_worker_info
from torchvision import transforms as T
from tqdm.auto import tqdm

from data.data_utils import apply_dense_transforms


class ForNet(Dataset):
    """Recombine ImageNet foregrounds and backgrounds.

    Note:
        This dataset has exactly the ImageNet classes.

    """

    _back_combs = ["same", "all", "original"]
    _bg_transforms = {T.RandomCrop, T.CenterCrop, T.Resize, T.RandomResizedCrop}

    def __init__(
        self,
        root,
        transform=None,
        train=True,
        target_transform=None,
        background_combination="all",
        fg_scale_jitter=0.3,
        fg_transform=None,
        pruning_ratio=0.8,
        return_fg_masks=False,
        fg_size_mode="range",
        fg_bates_n=1,
        paste_pre_transform=True,
        mask_smoothing_sigma=4.0,
        rel_jut_out=0.0,
        fg_in_nonant=None,
        size_fact=1.0,
        orig_img_prob=0.0,
        orig_ds=None,
        _orig_ds_file_type="JPEG",
        epochs=0,
        _album_compose=False,
    ):
        """Create the ForNet dataset.

        Args:
            root (str): Root folder for the dataset.
            transform (T.Compose | list, optional): Transform to apply to the image. Defaults to None.
            train (bool, optional): On the train set (False -> val set). Defaults to True.
            target_transform (T.Compose | list, optional): Transform to apply to the target values. Defaults to None.
            background_combination (str, optional): Which backgrounds to combine with foregrounds. Defaults to "all".
            fg_scale_jitter (float, optional): How much the size of the foreground may be changed (random ratio). Defaults to 0.3.
            fg_transform (T.Compose | list, optional): Transform to apply to the foreground before pasting it onto the background. This is mainly meant to be a random rotation. Defaults to None.
            pruning_ratio (float, optional): Prune backgrounds with (foreground size/background size) >= <pruning_ratio>. Backgrounds from images that contain very large foreground objects are mostly computer generated and therefore relatively unnatural. Defaults to 0.8.
            return_fg_masks (bool, optional): Return the foreground masks. Defaults to False.
            fg_size_mode (str, optional): How to determine the size of the foreground, based on the foreground sizes of the foreground and background images. Defaults to "range".
            fg_bates_n (int, optional): Bates parameter for the distribution of the object position in the foreground. Defaults to 1 (uniform distribution). The higher the value, the more likely the object is in the center. For fg_bates_n = 0, the object is always in the center.
            paste_pre_transform (bool, optional): Paste the foreground onto the background before applying the transform. If False, the background will be cropped and resized before pasting the foreground. Defaults to True.
            mask_smoothing_sigma (float, optional): Sigma for the Gaussian blur of the mask edge. Defaults to 4.0.
            rel_jut_out (float, optional): How much the foreground is allowed to jut out of the background (and then be cut off). Defaults to 0.0.
            fg_in_nonant (int, optional): If not None, the foreground will be placed in a specific nonant (0-8) of the image. Defaults to None.
            size_fact (float, optional): Factor to multiply the size of the foreground with. Defaults to 1.0.
            orig_img_prob (float | str, optional): Probability to use the original image instead of the fg-bg recombinations. Defaults to 0.0.
            orig_ds (Dataset, optional): Original dataset (without transforms) to use for the original images. Defaults to None.
            _orig_ds_file_type (str, optional): File type of the original dataset. Defaults to "JPEG".
            epochs (int, optional): Number of epochs to train on. Used for the linear increase of orig_img_prob.
            _album_compose (bool, optional): Compose the transforms with the Albumentations-based wrapper instead of torchvision's. Defaults to False.

        Note:
            For more information on the Bates distribution, see https://en.wikipedia.org/wiki/Bates_distribution.
            For fg_bates_n < 0, we extend the Bates distribution to focus more and more on the edges. This is done by sampling B ~ Bates(|fg_bates_n|) and then passing it through f(x) = x + 0.5 - floor(x + 0.5).

            For the list of transformations that will be applied to the background only (if paste_pre_transform=False), see ForNet._bg_transforms.

            A nonant in this case refers to one square in a 3x3 grid dividing the image.

        """
        assert (
            background_combination in self._back_combs
        ), f"background_combination={background_combination} is not supported. Use one of {self._back_combs}"

        if (not os.path.exists(f"{root}/backgrounds_{'train' if train else 'val'}.zip")) and os.path.exists(
            os.path.join(root, "train" if train else "val", "backgrounds")
        ):
            self._mode = "folder"
        else:
            self._mode = "zip"

        if self._mode == "zip":
            try:
                with zipfile.ZipFile(f"{root}/backgrounds_{'train' if train else 'val'}.zip", "r") as bg_zip:
                    self.backgrounds = [f for f in bg_zip.namelist() if f.endswith(".JPEG")]
                with zipfile.ZipFile(f"{root}/foregrounds_{'train' if train else 'val'}.zip", "r") as fg_zip:
                    self.foregrounds = [f for f in fg_zip.namelist() if f.endswith(".WEBP")]
            except FileNotFoundError as e:
                logger.error(
                    f"ForNet: {e}. Make sure to have the background and foreground zips in the root"
                    f" directory: found {os.listdir(root)}"
                )
                raise e
            classes = set([f.split("/")[-2] for f in self.foregrounds])
        else:
            logger.info("ForNet folder mode: loading classes")
            classes = set(os.listdir(os.path.join(root, "train" if train else "val", "foregrounds")))
            foregrounds = []
            backgrounds = []
            for cls in tqdm(classes, desc="Loading files"):
                foregrounds.extend(
                    [
                        f"{cls}/{f}"
                        for f in os.listdir(os.path.join(root, "train" if train else "val", "foregrounds", cls))
                    ]
                )
                backgrounds.extend(
                    [
                        f"{cls}/{f}"
                        for f in os.listdir(os.path.join(root, "train" if train else "val", "backgrounds", cls))
                    ]
                )
            self.foregrounds = foregrounds
            self.backgrounds = backgrounds

        self.classes = sorted(list(classes), key=lambda x: int(x[1:]))

        assert os.path.exists(f"{root}/fg_bg_ratios_{'train' if train else 'val'}.json"), (
            f"{root}/fg_bg_ratios_{'train' if train else 'val'}.json not found, provide the information or set"
            " pruning_ratio=1.0"
        )
        with open(f"{root}/fg_bg_ratios_{'train' if train else 'val'}.json", "r") as f:
            self.fg_bg_ratios = json.load(f)
        if self._mode == "folder":
            self.fg_bg_ratios = {"/".join(key.split("/")[-2:]): val for key, val in self.fg_bg_ratios.items()}
            logger.info(f"Renamed fg_bg_ratios keys to {list(self.fg_bg_ratios.keys())[:3]}...")

        if pruning_ratio <= 1.0:
            backup_backgrounds = {}
            for bg_file in self.backgrounds:
                bg_cls = bg_file.split("/")[-2]
                backup_backgrounds[bg_cls] = bg_file
            self.backgrounds = [
                bg for bg in self.backgrounds if bg in self.fg_bg_ratios and self.fg_bg_ratios[bg] < pruning_ratio
            ]
            # logger.info(
            #     f"ForNet: keep {len(self.backgrounds)} of {len(self.fg_bg_ratios)} backgrounds with pr {pruning_ratio}"
            # )

        self.root = root
        self.train = train
        self.background_combination = background_combination
        self.fg_scale_jitter = fg_scale_jitter
        self.fg_transform = fg_transform
        self.return_fg_masks = return_fg_masks
        self.paste_pre_transform = paste_pre_transform
        self.mask_smoothing_sigma = mask_smoothing_sigma
        self.rel_jut_out = rel_jut_out
        self.size_fact = size_fact
        self.fg_in_nonant = fg_in_nonant
        assert fg_in_nonant is None or -1 <= fg_in_nonant < 9, f"fg_in_nonant={fg_in_nonant} not in [0, 8] or None"

        self.orig_img_prob = orig_img_prob
        if orig_img_prob != 0.0:
            assert (isinstance(orig_img_prob, float) and orig_img_prob > 0.0) or orig_img_prob in [
                "linear",
                "cos",
                "revlinear",
            ]
            assert orig_ds is not None, "orig_ds must be provided if orig_img_prob > 0.0"
            assert not return_fg_masks, "can't provide fg masks for original images (yet)"
            assert os.path.exists(os.path.join(root, f"{'train' if train else 'val'}_indices.json")) or isinstance(
                orig_ds, str
            ), f"{root}/{'train' if train else 'val'}_indices.json must be provided if orig_ds is a dataset"
            if not isinstance(orig_ds, str):
                with open(os.path.join(root, f"{'train' if train else 'val'}_indices.json"), "r") as f:
                    self.key_to_orig_idx = json.load(f)
            else:
                if not (
                    orig_ds.endswith("train" if train else "val") or orig_ds.endswith("train/" if train else "val/")
                ):
                    orig_ds = f"{orig_ds}/{'train' if train else 'val'}"
                self.key_to_orig_idx = None
        self._orig_ds_file_type = _orig_ds_file_type
        self.orig_ds = orig_ds
        self.epochs = epochs
        self._epoch = 0

        assert fg_size_mode in [
            "max",
            "min",
            "mean",
            "range",
        ], f"fg_size_mode={fg_size_mode} not supported; use one of ['max', 'min', 'mean', 'range']"
        self.fg_size_mode = fg_size_mode
        self.fg_bates_n = fg_bates_n

        if not paste_pre_transform:
            if isinstance(transform, (T.Compose, Compose)):
                transform = transform.transforms

            # do cropping and resizing mainly on the background; paste the foreground on top later
            self.bg_transform = []
            self.join_transform = []
            for tf in transform:
                if isinstance(tf, tuple(self._bg_transforms)):
                    self.bg_transform.append(tf)
                else:
                    self.join_transform.append(tf)

            if _album_compose:
                from data.album_transf import AlbumTorchCompose

                self.bg_transform = AlbumTorchCompose(self.bg_transform)
                self.join_transform = AlbumTorchCompose(self.join_transform)
            else:
                self.bg_transform = T.Compose(self.bg_transform)
                self.join_transform = T.Compose(self.join_transform)

        else:
            if isinstance(transform, list):
                if _album_compose:
                    from data.album_transf import AlbumTorchCompose

                    self.join_transform = AlbumTorchCompose(transform)
                else:
                    self.join_transform = T.Compose(transform)
            else:
                self.join_transform = transform
            self.bg_transform = None

        self.trgt_map = {cls: i for i, cls in enumerate(self.classes)}

        self.target_transform = target_transform

        self.cls_to_allowed_bg = {}
        for bg_file in self.backgrounds:
            if background_combination == "same":
                bg_cls = bg_file.split("/")[-2]
                if bg_cls not in self.cls_to_allowed_bg:
                    self.cls_to_allowed_bg[bg_cls] = []
                self.cls_to_allowed_bg[bg_cls].append(bg_file)

        if background_combination == "same":
            for cls_code in classes:
                if cls_code not in self.cls_to_allowed_bg or len(self.cls_to_allowed_bg[cls_code]) == 0:
                    self.cls_to_allowed_bg[cls_code] = [backup_backgrounds[cls_code]]
                    logger.warning(f"No background for class {cls_code}, using {backup_backgrounds[cls_code]}")

        self._zf = {}

    @property
    def epoch(self):
        return self._epoch

    @epoch.setter
    def epoch(self, value):
        self._epoch = value

    def __len__(self):
        """Size of the dataset.

        Returns:
            int: number of foregrounds

        """
        return len(self.foregrounds)

    def num_classes(self):
        return len(self.classes)

    def _get_fg(self, idx):
        worker_id = self._wrkr_info()

        fg_file = self.foregrounds[idx]
        with self._zf[worker_id]["fg"].open(fg_file) as f:
            fg_data = BytesIO(f.read())
        return Image.open(fg_data)

    def _wrkr_info(self):
        worker_id = get_worker_info().id if get_worker_info() else 0

        if worker_id not in self._zf and self._mode == "zip":
            self._zf[worker_id] = {
                "bg": zipfile.ZipFile(f"{self.root}/backgrounds_{'train' if self.train else 'val'}.zip", "r"),
                "fg": zipfile.ZipFile(f"{self.root}/foregrounds_{'train' if self.train else 'val'}.zip", "r"),
            }
        return worker_id

    def __getitem__(self, idx):
        """Get the foreground at index idx and combine it with a (random) background.

        Args:
            idx (int): foreground index

        Returns:
            torch.Tensor, torch.Tensor: image, target

        """
        worker_id = self._wrkr_info()
        fg_file = self.foregrounds[idx]
        trgt_cls = fg_file.split("/")[-2]

        if (
            (self.orig_img_prob == "linear" and np.random.rand() < self._epoch / self.epochs)
            or (self.orig_img_prob == "revlinear" and np.random.rand() < (self._epoch - self.epochs) / self.epochs)
            or (self.orig_img_prob == "cos" and np.random.rand() > np.cos(np.pi * self._epoch / (2 * self.epochs)))
            or (
                isinstance(self.orig_img_prob, float)
                and self.orig_img_prob > 0.0
                and np.random.rand() < self.orig_img_prob
            )
        ):
            data_key = f"{trgt_cls}/{fg_file.split('/')[-1].split('.')[0]}"
            if isinstance(self.orig_ds, str):
                image_file = os.path.join(self.orig_ds, f"{data_key}.{self._orig_ds_file_type}")
                orig_img = Image.open(image_file).convert("RGB")
            else:
                orig_data = self.orig_ds[self.key_to_orig_idx[data_key]]
                orig_img = orig_data["image"] if isinstance(orig_data, dict) else orig_data[0]

            if self.bg_transform:
                orig_img = self.bg_transform(orig_img)
            if self.join_transform:
                orig_img = self.join_transform(orig_img)
            trgt = self.trgt_map[trgt_cls]
            if self.target_transform:
                trgt = self.target_transform(trgt)
            return orig_img, trgt

        if self._mode == "zip":
            with self._zf[worker_id]["fg"].open(fg_file) as f:
                fg_data = BytesIO(f.read())
            try:
                fg_img = Image.open(fg_data).convert("RGBA")
            except PIL.UnidentifiedImageError as e:
                logger.error(f"Error with idx={idx}, file={self.foregrounds[idx]}")
                raise e
        else:
            fg_img = Image.open(
                os.path.join(self.root, "train" if self.train else "val", "foregrounds", fg_file)
            ).convert("RGBA")

        if self.fg_transform:
            fg_img = self.fg_transform(fg_img)
        fg_size_factor = T.ToTensor()(fg_img.split()[-1]).mean().item()

        if self.background_combination == "all":
            bg_idx = np.random.randint(len(self.backgrounds))
            bg_file = self.backgrounds[bg_idx]
        elif self.background_combination == "original":
            bg_file = fg_file.replace("foregrounds", "backgrounds").replace("WEBP", "JPEG")
        else:
            bg_idx = np.random.randint(len(self.cls_to_allowed_bg[trgt_cls]))
            bg_file = self.cls_to_allowed_bg[trgt_cls][bg_idx]

        if self._mode == "zip":
            with self._zf[worker_id]["bg"].open(bg_file) as f:
                bg_data = BytesIO(f.read())
            bg_img = Image.open(bg_data).convert("RGB")
        else:
            bg_img = Image.open(
                os.path.join(self.root, "train" if self.train else "val", "backgrounds", bg_file)
            ).convert("RGB")

        if not self.paste_pre_transform:
            bg_img = self.bg_transform(bg_img)

        bg_size = bg_img.size

        # choose the scale factor such that the relative area is in fg_scale
        bg_area = bg_size[0] * bg_size[1]
        if self.fg_in_nonant is not None:
            bg_area = bg_area / 9

        orig_fg_ratio = self.fg_bg_ratios[fg_file.replace("foregrounds", "backgrounds").replace("WEBP", "JPEG")]
        bg_fg_ratio = self.fg_bg_ratios[bg_file]

        if self.fg_size_mode == "max":
            goal_fg_ratio_lower = goal_fg_ratio_upper = max(orig_fg_ratio, bg_fg_ratio)
        elif self.fg_size_mode == "min":
            goal_fg_ratio_lower = goal_fg_ratio_upper = min(orig_fg_ratio, bg_fg_ratio)
        elif self.fg_size_mode == "mean":
            goal_fg_ratio_lower = goal_fg_ratio_upper = (orig_fg_ratio + bg_fg_ratio) / 2
        else:
            # range
            goal_fg_ratio_lower = min(orig_fg_ratio, bg_fg_ratio)
            goal_fg_ratio_upper = max(orig_fg_ratio, bg_fg_ratio)

        fg_scale = (
            np.random.uniform(
                goal_fg_ratio_lower * (1 - self.fg_scale_jitter), goal_fg_ratio_upper * (1 + self.fg_scale_jitter)
            )
            / fg_size_factor
            * self.size_fact
        )

        goal_shape_y = round(np.sqrt(bg_area * fg_scale * fg_img.size[1] / fg_img.size[0]))
        goal_shape_x = round(np.sqrt(bg_area * fg_scale * fg_img.size[0] / fg_img.size[1]))

        fg_img = fg_img.resize((goal_shape_x, goal_shape_y))

        if fg_img.size[0] > bg_size[0] or fg_img.size[1] > bg_size[1]:
            # random crop to fit
            goal_w, goal_h = (min(fg_img.size[0], bg_size[0]), min(fg_img.size[1], bg_size[1]))
            fg_img = T.RandomCrop((goal_h, goal_w))(fg_img) if self.train else T.CenterCrop((goal_h, goal_w))(fg_img)

        # paste fg on bg
        z1, z2 = (
            (
                np.random.uniform(0, 1, abs(self.fg_bates_n)).mean(),  # Bates distribution; n=1 => uniform
                np.random.uniform(0, 1, abs(self.fg_bates_n)).mean(),
            )
            if self.fg_bates_n != 0
            else (0.5, 0.5)
        )
        if self.fg_bates_n < 0:
            z1 = z1 + 0.5 - floor(z1 + 0.5)
            z2 = z2 + 0.5 - floor(z2 + 0.5)

        x_min = -self.rel_jut_out * fg_img.size[0]
        x_max = bg_size[0] - fg_img.size[0] * (1 - self.rel_jut_out)
        y_min = -self.rel_jut_out * fg_img.size[1]
        y_max = bg_size[1] - fg_img.size[1] * (1 - self.rel_jut_out)

        if self.fg_in_nonant is not None and self.fg_in_nonant >= 0:
            x_min = (self.fg_in_nonant % 3) * bg_size[0] / 3
            x_max = ((self.fg_in_nonant % 3) + 1) * bg_size[0] / 3 - fg_img.size[0]
            y_min = (self.fg_in_nonant // 3) * bg_size[1] / 3
            y_max = ((self.fg_in_nonant // 3) + 1) * bg_size[1] / 3 - fg_img.size[1]

        if x_min > x_max:
            x_min = x_max = (x_min + x_max) / 2
        if y_min > y_max:
            y_min = y_max = (y_min + y_max) / 2

        offs_x = round(z1 * (x_max - x_min) + x_min)
        offs_y = round(z2 * (y_max - y_min) + y_min)

        paste_mask = fg_img.split()[-1]
        if self.mask_smoothing_sigma > 0.0:
            sigma = (np.random.rand() * 0.9 + 0.1) * self.mask_smoothing_sigma
            paste_mask = paste_mask.filter(ImageFilter.GaussianBlur(radius=sigma))
            paste_mask = paste_mask.point(lambda p: 2 * p - 255 if p > 128 else 0)

        bg_img.paste(fg_img.convert("RGB"), (offs_x, offs_y), paste_mask)
        bg_img = bg_img.convert("RGB")

        if self.return_fg_masks:
            fg_mask = Image.new("L", bg_size, 0)
            fg_mask.paste(paste_mask, (offs_x, offs_y))

            fg_mask = T.ToTensor()(fg_mask)[0]

        bg_img = T.ToTensor()(bg_img)

        if self.join_transform:
            if self.return_fg_masks:
                # transform the image and mask together, so random transforms stay in sync
                bg_img, fg_mask = apply_dense_transforms(bg_img, fg_mask, self.join_transform)
            else:
                bg_img = self.join_transform(bg_img)

        if trgt_cls not in self.trgt_map:
            raise ValueError(f"trgt_cls={trgt_cls} not in trgt_map: {self.trgt_map}")
        trgt = self.trgt_map[trgt_cls]
        if self.target_transform:
            trgt = self.target_transform(trgt)

        if self.return_fg_masks:
            return bg_img, trgt, fg_mask

        return bg_img, trgt
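A hedged usage sketch of the dataset; the root path and transform choices below are illustrative assumptions, not the exact experiment configuration:

```python
from torch.utils.data import DataLoader
from torchvision import transforms as T

# hypothetical root layout: foregrounds_train.zip, backgrounds_train.zip, fg_bg_ratios_train.json
dataset = ForNet(
    root="/data/fornet",
    train=True,
    transform=[T.RandomResizedCrop(224), T.RandomHorizontalFlip()],
    background_combination="all",
)
loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=8)
img, target = dataset[0]  # a recombined image tensor and its ImageNet class index
```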
@@ -0,0 +1,151 @@
import argparse
import shutil
import zipfile
from os import listdir, makedirs, path
from random import choice

from datadings.reader import MsgpackReader
from datadings.writer import FileWriter
from tqdm.auto import tqdm

parser = argparse.ArgumentParser()
parser.add_argument("-tiny_imagenet_zip", type=str, required=True, help="Path to the Tiny ImageNet zip file")
parser.add_argument("-output_dir", type=str, required=True, help="Directory to extract the image names to")
parser.add_argument("-in_segment_dir", type=str, required=True, help="Directory that holds the segmented ImageNet")
parser.add_argument(
    "-imagenet_path", type=str, nargs="?", required=True, help="Path to the original ImageNet dataset (datadings)"
)
args = parser.parse_args()


images = {"train": set(), "val": set()}

with zipfile.ZipFile(args.tiny_imagenet_zip, "r") as zip_ref:
    for info in tqdm(zip_ref.infolist(), desc="Gathering Images"):
        if info.filename.endswith(".JPEG"):
            if "/val/" in info.filename:
                images["val"].add(info.filename.split("/")[-1])
            elif "/train/" in info.filename:
                images["train"].add(info.filename.split("/")[-1])

with open(path.join(args.output_dir, "tiny_imagenet_train_images.txt"), "w+") as f:
    f.write("\n".join(images["train"]))
with open(path.join(args.output_dir, "tiny_imagenet_val_images.txt"), "w+") as f:
    f.write("\n".join(images["val"]))

print(f"Found {len(images['train'])} training images and {len(images['val'])} validation images")
classes = {img_name.split("_")[0] for img_name in images["train"]}

classes = sorted(list(classes), key=lambda x: int(x[1:]))
assert len(classes) == 200, f"Expected 200 classes, found {len(classes)}"
assert len(images["train"]) == len(classes) * 500, f"Expected 100000 training images, found {len(images['train'])}"
assert len(images["val"]) == len(classes) * 50, f"Expected 10000 validation images, found {len(images['val'])}"
with open(path.join(args.output_dir, "tiny_imagenet_classes.txt"), "w+") as f:
    f.write("\n".join(classes))

# copy over the relevant images
for split in ["train", "val"]:
    ipc = 500 if split == "train" else 50
    part = "foregrounds_WEBP"
    with tqdm(total=len(images[split]), desc=f"Copying images for {split} > {part}") as pbar:
        for synset in classes:
            makedirs(path.join(args.output_dir, split, part, synset), exist_ok=True)
            if len(listdir(path.join(args.output_dir, split, part, synset))) >= ipc:
                tqdm.write(
                    f"skip {split} > {part} > {synset} with"
                    f" {len(listdir(path.join(args.output_dir, split, part, synset)))} ims"
                )
                pbar.update(ipc)
                continue
            for img in listdir(path.join(args.in_segment_dir, split, part, synset)):
                orig_name = (
                    img.split(".")[0] + ".JPEG"
                    if split == "train"
                    else f"val_{int(img.split('_')[-1].split('.')[0])}.JPEG"
                )
                if orig_name in images[split]:
                    shutil.copy(
                        path.join(args.in_segment_dir, split, part, synset, img),
                        path.join(args.output_dir, split, part, synset, img),
                    )
                    pbar.update(1)

            while len(listdir(path.join(args.output_dir, split, part, synset))) < min(
                ipc, len(listdir(path.join(args.in_segment_dir, split, part, synset)))
            ):
                # copy over more random images
                image_names = [
                    (
                        img,
                        (
                            img.split(".")[0] + ".JPEG"
                            if split == "train"
                            else f"val_{int(img.split('_')[-1].split('.')[0])}.JPEG"
                        ),
                    )
                    for img in listdir(path.join(args.in_segment_dir, split, part, synset))
                ]
                # exclude files that were already copied to the output directory
                image_names = [
                    img for img in image_names if img[0] not in listdir(path.join(args.output_dir, split, part, synset))
                ]
                img = choice(image_names)[0]
                shutil.copy(
                    path.join(args.in_segment_dir, split, part, synset, img),
                    path.join(args.output_dir, split, part, synset, img),
                )
                pbar.update(1)
                tqdm.write(f"Extra image: {img} to {split}/{part}/{synset}")

    # copy over the background images corresponding to those foregrounds
    part = "backgrounds_JPEG"
    with tqdm(total=len(images[split]), desc=f"Copying images for {split} > {part}") as pbar:
        for synset in classes:
            makedirs(path.join(args.output_dir, split, part, synset), exist_ok=True)
            if len(listdir(path.join(args.output_dir, split, part, synset))) >= ipc:
                tqdm.write(f"skip {split} > {part} > {synset}")
                pbar.update(ipc)
                continue
            for img in listdir(path.join(args.output_dir, split, "foregrounds_WEBP", synset)):
                bg_name = img.replace(".WEBP", ".JPEG")
                shutil.copy(
                    path.join(args.in_segment_dir, split, part, synset, bg_name),
                    path.join(args.output_dir, split, part, synset, bg_name),
                )
                pbar.update(1)

            assert len(listdir(path.join(args.output_dir, split, part, synset))) == len(
                listdir(path.join(args.output_dir, split, "foregrounds_WEBP", synset))
            ), (
                f"Expected {len(listdir(path.join(args.output_dir, split, 'foregrounds_WEBP', synset)))} background"
                f" images, found {len(listdir(path.join(args.output_dir, split, part, synset)))}"
            )

# write the original dataset to datadings
for part in ["train", "val"]:
    reader = MsgpackReader(path.join(args.imagenet_path, f"{part}.msgpack"))
    with FileWriter(path.join(args.output_dir, f"TinyIN_{part}.msgpack")) as writer:
        for data in tqdm(reader, desc=f"Writing {part} to datadings"):
            key = data["key"].split("/")[-1]
            allowed_synsets = [key.split("_")[0]] if part == "train" else classes

            if part == "train" and allowed_synsets[0] not in classes:
                continue

            keep_image = False
            label_synset = None
            for synset in allowed_synsets:
                for img in listdir(path.join(args.output_dir, part, "foregrounds_WEBP", synset)):
                    if img.split(".")[0] == key.split(".")[0]:
                        keep_image = True
                        label_synset = synset
                        break
                if keep_image:
                    # stop scanning further synsets once a match is found
                    break

            if not keep_image:
                continue

            data["label"] = classes.index(label_synset)

            writer.write(data)
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
@@ -0,0 +1,48 @@
import argparse
import json
import os

import torch

parser = argparse.ArgumentParser("Script to convert ImageNet trained models to ImageNet-9")
parser.add_argument("-m", "--model", type=str, required=True, help="Model weights (.pt file).")
parser.add_argument(
    "--in_to_in9", type=str, default="/ds-sds/images/ImageNet-9/in_to_in9.json", help="Path to in_to_in9.json"
)
args = parser.parse_args()


checkpoint = torch.load(args.model, map_location="cpu")

model_state = checkpoint["model_state"]
head_keys = [k for k in model_state.keys() if ".head." in k or ".fc." in k]
print("weights that will be modified:", head_keys)
assert len(head_keys) > 0, "no head keys found :("

with open(args.in_to_in9, "r") as f:
    in_to_in9_classes = json.load(f)
print(f"{len([k for k, v in in_to_in9_classes.items() if v == -1])} classes get mapped to -1")

print("map", len(in_to_in9_classes), "classes to", set(in_to_in9_classes.values()))

print("Building conversion matrix")
conversion_matrix = torch.zeros((9, 1000))
for in_idx, in9_idx in in_to_in9_classes.items():
    if in9_idx == -1:
        continue
    in_idx = int(in_idx)
    conversion_matrix[in9_idx, in_idx] = 1
print(f"Conversion matrix ({conversion_matrix.shape}) has {int(conversion_matrix.sum().item())} non-zero values")

for head_key in head_keys:
    print(f"converting {head_key} of shape {model_state[head_key].shape}", end=" ")
    model_state[head_key] = conversion_matrix @ model_state[head_key]
    print(f"\tto {model_state[head_key].shape}")

checkpoint["model_state"] = model_state
checkpoint["args"]["n_classes"] = 9
save_folder = os.path.dirname(args.model)
orig_model_name = args.model.split(os.sep)[-1]
new_model_name = ".".join(orig_model_name.split(".")[:-1]) + "_to_in9." + orig_model_name.split(".")[-1]
print(f"saving model as {new_model_name} in {save_folder}")
torch.save(checkpoint, os.path.join(save_folder, new_model_name))
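As a sanity check (not part of the script): because the head is linear, multiplying its weights by the 9x1000 conversion matrix is the same as summing the original 1000-way logits within each ImageNet-9 group. A toy verification with made-up shapes and a made-up mapping:

```python
import torch

W = torch.randn(1000, 512)   # toy head weight: (out_features, in_features)
C = torch.zeros(9, 1000)
C[0, 10] = C[0, 42] = 1.0    # pretend IN classes 10 and 42 map to IN-9 class 0
feat = torch.randn(512)

logits_in9 = (C @ W) @ feat  # converted head applied to the features
assert torch.allclose(logits_in9, C @ (W @ feat), atol=1e-5)  # == grouped sum of original logits
```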
@@ -0,0 +1,200 @@
n07695742: pretzel
n03902125: pay-phone, pay-station
n03980874: poncho
n01644900: tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui
n02730930: apron
n02699494: altar
n03201208: dining table, board
n02056570: king penguin, Aptenodytes patagonica
n04099969: rocking chair, rocker
n04366367: suspension bridge
n04067472: reel
n02808440: bathtub, bathing tub, bath, tub
n04540053: volleyball
n02403003: ox
n03100240: convertible
n04562935: water tower
n02788148: bannister, banister, balustrade, balusters, handrail
n02988304: CD player
n02423022: gazelle
n03637318: lampshade, lamp shade
n01774384: black widow, Latrodectus mactans
n01768244: trilobite
n07614500: ice cream, icecream
n04254777: sock
n02085620: Chihuahua
n01443537: goldfish, Carassius auratus
n01629819: European fire salamander, Salamandra salamandra
n02099601: golden retriever
n02321529: sea cucumber, holothurian
n03837869: obelisk
n02002724: black stork, Ciconia nigra
n02841315: binoculars, field glasses, opera glasses
n04560804: water jug
n02364673: guinea pig, Cavia cobaya
n03706229: magnetic compass
n09256479: coral reef
n09332890: lakeside, lakeshore
n03544143: hourglass
n02124075: Egyptian cat
n02948072: candle, taper, wax light
n01950731: sea slug, nudibranch
n02791270: barbershop
n03179701: desk
n02190166: fly
n04275548: spider web, spider's web
n04417672: thatch, thatched roof
n03930313: picket fence, paling
n02236044: mantis, mantid
n03976657: pole
n01774750: tarantula
n04376876: syringe
n04133789: sandal
n02099712: Labrador retriever
n04532670: viaduct
n04487081: trolleybus, trolley coach, trackless trolley
n09428293: seashore, coast, seacoast, sea-coast
n03160309: dam, dike, dyke
n03250847: drumstick
n02843684: birdhouse
n07768694: pomegranate
n03670208: limousine, limo
n03085013: computer keyboard, keypad
n02892201: brass, memorial tablet, plaque
n02233338: cockroach, roach
n03649909: lawn mower, mower
n03388043: fountain
n02917067: bullet train, bullet
n02486410: baboon
n04596742: wok
n03255030: dumbbell
n03937543: pill bottle
n02113799: standard poodle
n03977966: police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria
n02906734: broom
n07920052: espresso
n01698640: American alligator, Alligator mississipiensis
n02123394: Persian cat
n03424325: gasmask, respirator, gas helmet
n02129165: lion, king of beasts, Panthera leo
n04008634: projectile, missile
n03042490: cliff dwelling
n02415577: bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis
n02815834: beaker
n02395406: hog, pig, grunter, squealer, Sus scrofa
n01784675: centipede
n03126707: crane
n04399382: teddy, teddy bear
n07875152: potpie
n03733131: maypole
n02802426: basketball
n03891332: parking meter
n01910747: jellyfish
n03838899: oboe, hautboy, hautbois
n03770439: miniskirt, mini
n02281406: sulphur butterfly, sulfur butterfly
n03970156: plunger, plumber's helper
n09246464: cliff, drop, drop-off
n02206856: bee
n02074367: dugong, Dugong dugon
n03584254: iPod
n04179913: sewing machine
n04328186: stopwatch, stop watch
n07583066: guacamole
n01917289: brain coral
n03447447: gondola
n02823428: beer bottle
n03854065: organ, pipe organ
n02793495: barn
n04285008: sports car, sport car
n02231487: walking stick, walkingstick, stick insect
n04465501: tractor
n02814860: beacon, lighthouse, beacon light, pharos
n02883205: bow tie, bow-tie, bowtie
n02165456: ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle
n04149813: scoreboard
n04023962: punching bag, punch bag, punching ball, punchball
n02226429: grasshopper, hopper
n02279972: monarch, monarch butterfly, milkweed butterfly, Danaus plexippus
n02669723: academic gown, academic robe, judge's robe
n04486054: triumphal arch
n04070727: refrigerator, icebox
n03444034: go-kart
n02666196: abacus
n01945685: slug
n04251144: snorkel
n03617480: kimono
n03599486: jinrikisha, ricksha, rickshaw
n02437312: Arabian camel, dromedary, Camelus dromedarius
n01984695: spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish
n04118538: rugby ball
n01770393: scorpion
n04356056: sunglasses, dark glasses, shades
n03804744: nail
n02132136: brown bear, bruin, Ursus arctos
n03400231: frying pan, frypan, skillet
n03983396: pop bottle, soda bottle
n07734744: mushroom
n02480495: orangutan, orang, orangutang, Pongo pygmaeus
n02410509: bison
n03404251: fur coat
n04456115: torch
n02123045: tabby, tabby cat
n03026506: Christmas stocking
n07715103: cauliflower
n04398044: teapot
n02927161: butcher shop, meat market
n07749582: lemon
n07615774: ice lolly, lolly, lollipop, popsicle
n02795169: barrel, cask
n04532106: vestment
n02837789: bikini, two-piece
n02814533: beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon
n04265275: space heater
n02481823: chimpanzee, chimp, Pan troglodytes
n02509815: lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens
n06596364: comic book
n01983481: American lobster, Northern lobster, Maine lobster, Homarus americanus
n02504458: African elephant, Loxodonta africana
n03014705: chest
n01944390: snail
n04146614: school bus
n01641577: bullfrog, Rana catesbeiana
n07720875: bell pepper
n02999410: chain
n01855672: goose
n02125311: cougar, puma, catamount, mountain lion, painter, panther, Felis concolor
n07753592: banana
n07871810: meat loaf, meatloaf
n04501370: turnstile
n04311004: steel arch bridge
n02977058: cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM
n04074963: remote control, remote
n03662601: lifeboat
n02106662: German shepherd, German shepherd dog, German police dog, alsatian
n03089624: confectionery, confectionary, candy store
n04259630: sombrero
n03393912: freight car
n04597913: wooden spoon
n07711569: mashed potato
n03355925: flagpole, flagstaff
n02963159: cardigan
n07579787: plate
n02950826: cannon
n01882714: koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus
n02094433: Yorkshire terrier
n02909870: bucket, pail
n02058221: albatross, mollymawk
n01742172: boa constrictor, Constrictor constrictor
n09193705: alp
n04371430: swimming trunks, bathing trunks
n07747607: orange
n03814639: neck brace
n04507155: umbrella
n02268443: dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk
n02769748: backpack, back pack, knapsack, packsack, rucksack, haversack
n03763968: military uniform
n07873807: pizza, pizza pie
n03992509: potter's wheel
n03796401: moving van
n12267677: acorn
@@ -0,0 +1,73 @@
# Repeat Augment sampler taken from DeiT: https://github.com/facebookresearch/deit/blob/main/samplers.py
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
import math

import torch
import torch.distributed as dist


class RASampler(torch.utils.data.Sampler):
    """Sampler that restricts data loading to a subset of the dataset for distributed training, with repeated augmentation.

    It ensures that each augmented version of a sample will be visible to a
    different process (GPU).
    Heavily based on torch.utils.data.DistributedSampler.
    """

    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, num_repeats: int = 3):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        if num_repeats < 1:
            raise ValueError("num_repeats should be greater than 0")
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.num_repeats = num_repeats
        self.epoch = 0
        self.num_samples = int(math.ceil(len(self.dataset) * self.num_repeats / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        # self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas))
        self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
        self.shuffle = shuffle

    def __iter__(self):
        if self.shuffle:
            # deterministically shuffle based on epoch
            g = torch.Generator()
            g.manual_seed(self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g)
        else:
            indices = torch.arange(start=0, end=len(self.dataset))

        # add extra samples to make it evenly divisible
        indices = torch.repeat_interleave(indices, repeats=self.num_repeats, dim=0).tolist()
        padding_size: int = self.total_size - len(indices)
        if padding_size > 0:
            indices += indices[:padding_size]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank : self.total_size : self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices[: self.num_selected_samples])

    def __len__(self):
        return self.num_selected_samples

    def set_epoch(self, epoch):
        self.epoch = epoch

    def __str__(self) -> str:
        return (
            f"{type(self).__name__}(num_replicas: {self.num_replicas}, rank: {self.rank}, num_repeats:"
            f" {self.num_repeats}, epoch: {self.epoch}, num_samples: {self.num_samples}, total_size: {self.total_size},"
            f" num_selected_samples: {self.num_selected_samples}, shuffle: {self.shuffle})"
        )
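A hedged usage sketch (assumes `torch.distributed` has already been initialized, e.g. via `torchrun`; `train_dataset` and `num_epochs` are placeholders):

```python
from torch.utils.data import DataLoader

sampler = RASampler(train_dataset, shuffle=True, num_repeats=3)
loader = DataLoader(train_dataset, batch_size=128, sampler=sampler, num_workers=8)

for epoch in range(num_epochs):
    sampler.set_epoch(epoch)  # re-seed the shuffle so every epoch sees a different order
    for images, targets in loader:
        ...
```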
@@ -0,0 +1,381 @@
import argparse
import json
import os
from copy import copy

from loguru import logger
from nltk.corpus import wordnet as wn


class bcolors:
    HEADER = "\033[95m"
    OKBLUE = "\033[94m"
    OKCYAN = "\033[96m"
    OKGREEN = "\033[92m"
    WARNING = "\033[93m"
    FAIL = "\033[91m"
    ENDC = "\033[0m"
    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"


def _lemmas_str(synset):
    return ", ".join([lemma.name() for lemma in synset.lemmas()])


class WNEntry:
    def __init__(
        self,
        name: str,
        id: int,
        lemmas: str,
        parent_id: int,
        depth: int = None,
        in_image_net: bool = False,
        child_ids: list = None,
        in_main_tree: bool = True,
        _n_images: int = 0,
        _description: str = None,
        _name: str = None,
        _pruned: bool = False,
    ):
        self.name = name
        self.id = id
        self.lemmas = lemmas
        self.parent_id = parent_id
        self.depth = depth
        self.in_image_net = in_image_net
        self.child_ids = child_ids
        self.in_main_tree = in_main_tree
        self._n_images = _n_images
        self._description = _description
        self._name = _name
        self._pruned = _pruned

    def __str__(self, tree=None, accumulate=True):
        start_symb = f"{bcolors.OKGREEN}+{bcolors.ENDC}" if self.in_image_net else f"{bcolors.FAIL}-{bcolors.ENDC}"
        n_ims = f"{self._n_images} of Σ {self.n_images(tree)}" if accumulate and tree is not None else self._n_images
        if self.child_ids is None or tree is None:
            return f"{start_symb}{self.name} ({self.id}) > {n_ims}"
        else:
            return f"{start_symb}{self.name} ({self.id}) > {n_ims}\n  " + "\n  ".join(
                ["\n  ".join(tree.nodes[child_id].__str__(tree).split("\n")) for child_id in self.child_ids]
            )

    def tree_diff(self, tree_1, tree_2):
        if tree_2[self.id]._n_images > tree_1[self.id]._n_images:
            start_symb = f"{bcolors.OKGREEN}+{bcolors.ENDC}"
        elif tree_2[self.id]._n_images < tree_1[self.id]._n_images:
            start_symb = f"{bcolors.FAIL}-{bcolors.ENDC}"
        else:
            start_symb = f"{bcolors.OKBLUE}={bcolors.ENDC}"
        n_ims = (
            f"{tree_1[self.id]._n_images} + {tree_2[self.id]._n_images - tree_1[self.id]._n_images} of Σ"
            f" {tree_1[self.id].n_images(tree_2)}/{tree_2[self.id].n_images(tree_2)}"
        )

        if self.child_ids is None:
            return f"{start_symb}{self.name} ({self.id}) > {n_ims}"

        return f"{start_symb}{self.name} ({self.id}) > {n_ims}\n  " + "\n  ".join(
            ["\n  ".join(tree_1.nodes[child_id].tree_diff(tree_1, tree_2).split("\n")) for child_id in self.child_ids]
        )

    def prune(self, tree):
        if self._pruned:
            return

        if self.child_ids is not None:
            for child_id in self.child_ids:
                tree[child_id].prune(tree)

        self._pruned = True
        parent_node = tree.nodes[self.parent_id]
        try:
            parent_node.child_ids.remove(self.id)
        except ValueError as e:
            print(
                f"Error removing {self.name} from"
                f" {parent_node.name} ({[tree[cid].name for cid in parent_node.child_ids]}): {e}"
            )
        while parent_node._pruned:
            parent_node = tree.nodes[parent_node.parent_id]
        parent_node._n_images += self._n_images
        self._n_images = 0

    @property
    def description(self):
        if not self._description:
            self._description = wn.synset_from_pos_and_offset("n", self.id).definition()
        return self._description

    @property
    def print_name(self):
        return self.name.split(".")[0]

    def get_branch(self, tree=None):
        if self.parent_id is None or tree is None:
            return self.print_name

        parent = tree.nodes[self.parent_id]
        return parent.get_branch(tree) + " > " + self.print_name

    def get_branch_list(self, tree):
        if self.parent_id is None:
            return [self]
        parent = tree.nodes[self.parent_id]
        return parent.get_branch_list(tree) + [self]

    def to_dict(self):
        return {
            "name": self.name,
            "id": self.id,
            "lemmas": self.lemmas,
            "parent_id": self.parent_id,
            "depth": self.depth,
            "in_image_net": self.in_image_net,
            "child_ids": self.child_ids,
            "in_main_tree": self.in_main_tree,
            "_n_images": self._n_images,
            "_description": self._description,
            "_name": self._name,
            "_pruned": self._pruned,
        }

    @classmethod
    def from_dict(cls, d):
        return cls(**d)

    def n_images(self, tree=None):
        if tree is None or self.child_ids is None or len(self.child_ids) == 0:
            return self._n_images
        return sum([tree.nodes[child_id].n_images(tree) for child_id in self.child_ids]) + self._n_images

    def n_children(self, tree=None):
        if self.child_ids is None:
            return 0
        if tree is None or len(self.child_ids) == 0:
            return len(self.child_ids)
        return len(self.child_ids) + sum([tree.nodes[child_id].n_children(tree) for child_id in self.child_ids])

    def get_examples(self, tree, n_examples=3):
        if self.child_ids is None or len(self.child_ids) == 0:
            return ""
        child_images = {child_id: tree.nodes[child_id].n_images(tree) for child_id in self.child_ids}
        max_images = max(child_images.values())
        if max_images == 0:
            # go by the number of child nodes instead
            child_images = {child_id: tree.nodes[child_id].n_children(tree) for child_id in self.child_ids}
        # sort child ids by number of images, descending
        top_children = [
            child_id for child_id, n_images in sorted(child_images.items(), key=lambda x: x[1], reverse=True)
        ]
        top_children = top_children[: min(n_examples, len(top_children))]
        return ", ".join(
            [f"{tree.nodes[child_id].print_name} ({tree.nodes[child_id].description})" for child_id in top_children]
        )


class WNTree:
    def __init__(self, root=1740, nodes=None):
        if isinstance(root, int):
            root_id = root
            root_synset = wn.synset_from_pos_and_offset("n", root)
            root_node = WNEntry(
                root_synset.name(),
                root_id,
                _lemmas_str(root_synset),
                parent_id=None,
                depth=0,
            )
        else:
            assert isinstance(root, WNEntry)
            root_id = root.id
            root_node = root

        self.root = root_node
        self.nodes = {root_id: self.root} if nodes is None else nodes
        self.parentless = []
        self.label_index = None
        self.pruned = set()

    def to_dict(self):
        return {
            "root": self.root.to_dict(),
            "nodes": {node_id: node.to_dict() for node_id, node in self.nodes.items()},
            "parentless": self.parentless,
            "pruned": list(self.pruned),
        }

    def prune(self, min_images):
        pruned_nodes = set()

        # prune all nodes that have fewer than min_images below them
        for node_id, node in self.nodes.items():
            if node.n_images(self) < min_images:
                pruned_nodes.add(node_id)
                node.prune(self)

        # prune all nodes that have fewer than min_images inside them, after all nodes below have been pruned
        node_stack = [self.root]
        node_idx = 0
        while node_idx < len(node_stack):
            node = node_stack[node_idx]
            if node.child_ids is not None:
                for child_id in node.child_ids:
                    child = self.nodes[child_id]
                    node_stack.append(child)
            node_idx += 1

        # now prune the stack from the bottom up
        for node in node_stack[::-1]:
            # only look at images of that class, not of additional children
            if node.n_images() < min_images:
                pruned_nodes.add(node.id)
                node.prune(self)

        self.pruned = pruned_nodes
        return pruned_nodes

    @classmethod
    def from_dict(cls, d):
        tree = cls()
        tree.root = WNEntry.from_dict(d["root"])
        tree.nodes = {int(node_id): WNEntry.from_dict(node) for node_id, node in d["nodes"].items()}
        tree.parentless = d["parentless"]
        if "pruned" in d:
            tree.pruned = set(d["pruned"])
        return tree

    def add_node(self, node_id, in_in=True):
        if node_id in self.nodes:
            self.nodes[node_id].in_image_net = in_in or self.nodes[node_id].in_image_net
            return

        synset = wn.synset_from_pos_and_offset("n", node_id)

        hypernyms = synset.hypernyms()
        if len(hypernyms) == 0:
            parent_id = None
            self.parentless.append(node_id)
            main_tree = False
            print(f"--------- no hypernyms for {synset.name()} ({synset.offset()}) ------------")
        else:
            parent_id = synset.hypernyms()[0].offset()
            if parent_id not in self.nodes:
                self.add_node(parent_id, in_in=False)
            parent = self.nodes[parent_id]

            if parent.child_ids is None:
                parent.child_ids = []
            parent.child_ids.append(node_id)
            main_tree = parent.in_main_tree

        depth = self.nodes[parent_id].depth + 1 if parent_id is not None else 0
        node = WNEntry(
            synset.name(),
            node_id,
            _lemmas_str(synset),
            parent_id=parent_id,
            in_image_net=in_in,
            depth=depth,
            in_main_tree=main_tree,
        )

        self.nodes[node_id] = node

    def __len__(self):
        return len(self.nodes)

    def image_net_len(self, only_main_tree=False):
        return sum([node.in_image_net for node in self.nodes.values() if node.in_main_tree or not only_main_tree])

    def max_depth(self, only_main_tree=False):
        return max([node.depth for node in self.nodes.values() if node.in_main_tree or not only_main_tree])

    def __str__(self):
        return (
            f"WordNet Tree with {len(self)} nodes, {self.image_net_len()} in ImageNet21k;"
            f" {len(self.parentless)} parentless nodes:\n{self.root.__str__(tree=self)}\nParentless:\n"
            + "\n".join([self.nodes[node_id].__str__(tree=self) for node_id in self.parentless])
        )

    def save(self, path):
        with open(path, "w") as f:
            json.dump(self.to_dict(), f)

    @classmethod
    def load(cls, path):
        with open(path, "r") as f:
            tree_dict = json.load(f)
        return cls.from_dict(tree_dict)

    def subtree(self, node_id):
        if node_id not in self.nodes:
            return None
        node_queue = [self.nodes[node_id]]
        subtree_ids = set()
        while len(node_queue) > 0:
            node = node_queue.pop(0)
            subtree_ids.add(node.id)
            if node.child_ids is not None:
                node_queue += [self.nodes[child_id] for child_id in node.child_ids]
        subtree_nodes = {node_id: copy(self.nodes[node_id]) for node_id in subtree_ids}
        subtree_root = subtree_nodes[node_id]
        subtree_root.parent_id = None
        depth_diff = subtree_root.depth
        for node in subtree_nodes.values():
            node.depth -= depth_diff
        return WNTree(root=subtree_root, nodes=subtree_nodes)

    def _make_label_index(self, include_merged=False):
        self.label_index = sorted(
            [
                node_id
                for node_id, node in self.nodes.items()
                if node.n_images(self if include_merged else None) > 0 and not node._pruned
            ]
        )

    def get_label(self, node_id):
        if self.label_index is None:
            self._make_label_index()
        while self.nodes[node_id]._pruned:
            node_id = self.nodes[node_id].parent_id
        return self.label_index.index(node_id)

    def n_labels(self):
        if self.label_index is None:
            self._make_label_index()
        return len(self.label_index)

    def __contains__(self, item):
        if isinstance(item, str):
            if item[0] == "n":
                item = int(item[1:])
            else:
                return False
        if isinstance(item, int):
            return item in self.nodes
        if isinstance(item, WNEntry):
            return item.id in self.nodes
        return False

    def __getitem__(self, item):
        if isinstance(item, str) and item.startswith("n"):
            try:
                item = int(item[1:])
            except ValueError:
                pass
        if isinstance(item, str) and ".n." in item:
            for node in self.nodes.values():
                if item == node.name:
                    return node
            raise KeyError(f"Item {item} not found in tree")
        if isinstance(item, int):
            return self.nodes[item]
        if isinstance(item, WNEntry):
            return self.nodes[item.id]
        raise KeyError(f"Item {item} not found in tree")
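A hedged sketch of building and querying a tree (assumes the NLTK WordNet corpus is installed; the offsets used are standard WordNet noun offsets, e.g. 1740 for entity.n.01):

```python
import nltk

nltk.download("wordnet")  # one-time setup for the nltk corpus

tree = WNTree()          # root defaults to offset 1740, i.e. entity.n.01
tree.add_node(2084071)   # n02084071: dog.n.01; the hypernym chain is added automatically
print(tree[2084071].get_branch(tree))  # entity > ... > dog
print(len(tree), tree.max_depth())     # number of nodes, depth of the deepest node
```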
834
AAAI Supplementary Material/Model Training Code/engine.py
Normal file
@@ -0,0 +1,834 @@
import json
import os
import sys
from datetime import datetime
from functools import partial
from math import isfinite
from time import time

import torch
import torch.nn as nn
import torch.optim as optim
from loguru import logger
from timm.data import Mixup
from timm.optim import create_optimizer
from timm.scheduler import create_scheduler
from torch import distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from tqdm.auto import tqdm

from metrics import calculate_metrics, per_class_counts
from utils import (
    NoScaler,
    ScalerGradNormReturn,
    SchedulerArgs,
    log_formatter,
    save_model_state,
)

try:
    from apex.optimizers import FusedLAMB  # noqa: F401

    apex_available = True
except ImportError:
    logger.error("Nvidia apex not available")
    apex_available = False
try:
    from lion_pytorch import Lion

    lion_available = True
except ImportError:
    logger.error("Lion not available")
    lion_available = False


WANDB_AVAILABLE = False
try:
    import wandb

    WANDB_AVAILABLE = True
except ImportError:
    logger.error("wandb not available")


def wandb_available(turn_off=False):
    """Return whether wandb is available.

    Args:
        turn_off (bool, optional): set wandb to be unavailable manually.

    Returns:
        bool: wandb is available
    """
    global WANDB_AVAILABLE
    if turn_off:
        WANDB_AVAILABLE = False
    return WANDB_AVAILABLE
|
||||
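# A minimal usage sketch for wandb_available(): gate optional logging on it,
# and use turn_off=True to disable wandb for the whole process (e.g. on
# non-zero ranks of a distributed run). The logged values are illustrative.
if wandb_available():
    import wandb

    wandb.log({"train/loss": 0.123})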
|
||||
|
||||
tqdm = partial(tqdm, leave=True, position=0) # noqa: F405
|
||||
|
||||
|
||||
def setup_tracking_and_logging(args, rank, append_model_path=None, log_wandb=True):
    """Set up logging and tracking for an experiment.

    Args:
        args (DotDict): Parsed command-line arguments
        rank (int): The rank of the current process
        append_model_path (str, optional): Path of an existing model, by default None
        log_wandb (bool, optional): Whether to log to wandb, by default True

    Returns:
        str: the folder where all run data is saved.

    Raises:
        AssertionError: If `dataset` or `model` is `None`.

    Notes:
        This function sets up logging to stdout and to a file, as well as wandb tracking for an experiment.
        For wandb logging, provide a .wandb.apikey file in the current directory.
    """
|
||||
dataset, model, epochs = args.dataset.replace(os.sep, "_").lower(), args.model.replace(os.sep, "_"), args.epochs
|
||||
_base_folder = (
|
||||
os.path.join(args.results_folder, args.experiment_name, args.task.replace("-", ""), dataset)
|
||||
if args.out_dir is None
|
||||
else args.out_dir
|
||||
)
|
||||
run_folder = os.path.join(
|
||||
_base_folder,
|
||||
f"{args.run_name.replace(os.sep, '_')}_{model}_{datetime.now().strftime('%d.%m.%Y_%H:%M:%S')}",
|
||||
)
|
||||
assert dataset is not None and model is not None
|
||||
|
||||
if os.name == "nt":
|
||||
run_folder = run_folder.replace("@", "_").replace(" ", "_").replace(":", ".")
|
||||
|
||||
if append_model_path is not None:
|
||||
run_folder = os.path.dirname(append_model_path)
|
||||
if "run_name" not in args or args.run_name is None:
|
||||
args.run_name = run_folder.split(os.sep)[-1].split("_")[0]
|
||||
elif args.distributed:
|
||||
obj_list = [None]
|
||||
if rank == 0:
|
||||
obj_list[0] = run_folder
|
||||
dist.broadcast_object_list(obj_list, src=0)
|
||||
run_folder = obj_list[0]
|
||||
if rank == 0:
|
||||
os.makedirs(run_folder, exist_ok=True)
|
||||
dist.barrier()
|
||||
elif rank == 0:
|
||||
os.makedirs(run_folder, exist_ok=True)
|
||||
|
||||
assert "%" not in args.run_name, f"found '%' in run_name '{args.run_name}'. This messes with string formatting..."
|
||||
|
||||
if args.debug:
|
||||
args.log_level = "debug"
|
||||
|
||||
# logger to stdout & file
|
||||
log_name = args.task.replace("-", "")
|
||||
if args.task not in ["pre-train", "fine-tune", "fine-tune-head"]:
|
||||
log_name += f"_{dataset}_{datetime.now().strftime('%d.%m.%Y_%H:%M:%S')}"
|
||||
log_file = os.path.join(run_folder, f"{log_name}.log")
|
||||
logger.remove()
|
||||
logger.configure(extra=dict(run_name=args.run_name, rank=rank, world_size=args.world_size))
|
||||
logger.add(sys.stderr, format=log_formatter, enqueue=True, colorize=True, level=args.log_level.upper())
|
||||
logger.add(log_file, format=log_formatter, enqueue=True, colorize=True, level=args.log_level.upper())
|
||||
logger.info(f"Run folder '{run_folder}'")
|
||||
|
||||
if rank == 0:
|
||||
logger.info(f"{args.task.replace('-', '').capitalize()} {model} on {dataset} for {epochs} epochs")
|
||||
|
||||
global WANDB_AVAILABLE
|
||||
WANDB_AVAILABLE = WANDB_AVAILABLE and log_wandb and os.path.isfile(".wandb.apikey") and args.wandb
|
||||
if WANDB_AVAILABLE:
|
||||
with open(".wandb.apikey", "r") as f:
|
||||
__wandb_api_key = f.read().strip()
|
||||
wandb.login(key=__wandb_api_key)
|
||||
if args.wandb_run_id is not None:
|
||||
wandb_args = dict(project=args.experiment_name, resume="must", id=args.wandb_run_id)
|
||||
else:
|
||||
wandb_args = dict(
|
||||
project=args.experiment_name,
|
||||
name=args.run_name.replace("_", "-").replace(" ", "-"),
|
||||
config={"logfile": log_file, **dict(args)},
|
||||
job_type=args.task,
|
||||
tags=[model, dataset],
|
||||
resume="allow",
|
||||
id=args.wandb_run_id,
|
||||
)
|
||||
wandb.init(**wandb_args)
|
||||
args["wandb_run_id"] = wandb.run.id
|
||||
if rank == 0:
|
||||
logger.info(f"wandb run initialized with id {args['wandb_run_id']}.")
|
||||
else:
|
||||
logger.info(
|
||||
f"Not using wandb. (args.wandb={args.wandb}, .wandb.apikey exists={os.path.isfile('.wandb.apikey')},"
|
||||
f" function declaration log_wandb={log_wandb})"
|
||||
)
|
||||
|
||||
if args.distributed:
|
||||
dist.barrier()
|
||||
|
||||
if args.debug:
|
||||
torch.autograd.set_detect_anomaly(True)
|
||||
logger.warning("torch.autograd anomaly detection enabled. Will slow down model.")
|
||||
|
||||
return run_folder
|
||||
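# A sketch of what a log_formatter compatible with the configure()/
# contextualize() calls above could look like; the real formatter lives in
# utils and may differ. loguru passes a record dict whose "extra" entry holds
# the fields bound via logger.configure(extra=...) and logger.contextualize().
def example_log_formatter(record):
    extra = record["extra"]
    return (
        "{time:HH:mm:ss} | {level: <8} | "
        f"run={extra.get('run_name', '?')} rank={extra.get('rank', 0)}/"
        f"{extra.get('world_size', 1)} ep={extra.get('epoch', '-')} | "
        "{message}\n"  # callable formats must append the newline themselves
    )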
|
||||
|
||||
def setup_model_optim_sched_scaler(model, device, epochs, args, head_only=False):
|
||||
"""Set up model, optimizer, and scheduler with automatic mixed precision (amp) and distributed data parallel (DDP).
|
||||
|
||||
Args:
|
||||
model (nn.Module): the loaded model
|
||||
device (torch.device): the current device
|
||||
epochs (int): total number of epochs to learn for (for scheduler)
|
||||
args: further arguments
|
||||
head_only (bool, optional): train only the linear head. Default: False
|
||||
|
||||
Returns:
|
||||
tuple[nn.Module, optim.Optimizer, optim.lr_scheduler._LRScheduler, ScalerGradNormReturn]: model, optimizer, scheduler, scaler
|
||||
|
||||
"""
|
||||
model = model.to(device)
|
||||
|
||||
if head_only:
    # freeze everything except parameters whose name contains "head"
    for name, param in model.named_parameters():
        param.requires_grad = "head" in name
    params = model.head.parameters()
|
||||
else:
|
||||
params = model  # pass the model itself and let timm's create_optimizer collect the parameters
|
||||
|
||||
if args.opt == "lion" and not lion_available:
|
||||
args.opt = "fusedlamb"
|
||||
logger.warning("Falling back from lion to fusedlamb")
|
||||
if args.opt == "fusedlamb" and not apex_available:
|
||||
args.opt = "adamw"
|
||||
logger.warning("Falling back from fusedlamb to adamw")
|
||||
if args.opt == "lion":
|
||||
optimizer = Lion(params, lr=args["lr"], weight_decay=args["weight_decay"])
|
||||
else:
|
||||
optimizer = create_optimizer(args, params)
|
||||
|
||||
scaler = ScalerGradNormReturn() if args.amp else NoScaler()
|
||||
|
||||
# if args.model_ema:
|
||||
# # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
|
||||
# ema_model = ModelEma(model, decay=args.model_ema_decay, resume='')
|
||||
|
||||
if args.distributed:
|
||||
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
|
||||
model = DDP(model, device_ids=[device])
|
||||
|
||||
if args.compile_model:
|
||||
model = torch.compile(model)
|
||||
|
||||
# scheduler = optim.lr_scheduler.LambdaLR(optimizer,
|
||||
# lr_lambda=scheduler_function_factory(**args))
|
||||
sched_args = SchedulerArgs(args.sched, args.epochs, args.min_lr, args.warmup_lr, args.warmup_epochs)
|
||||
scheduler, _ = create_scheduler(sched_args, optimizer)
|
||||
|
||||
return model, optimizer, scheduler, scaler
|
||||
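# A sketch of the field bundle timm's create_scheduler() reads; SchedulerArgs
# above presumably wraps exactly these attributes. The names follow timm's
# conventions, but the required set can vary between timm versions.
from types import SimpleNamespace

example_sched_args = SimpleNamespace(
    sched="cosine",      # schedule type
    epochs=100,          # total epochs (length of the cosine cycle)
    min_lr=1e-5,         # floor learning rate at the end of the schedule
    warmup_lr=1e-6,      # learning rate at the start of the warmup
    warmup_epochs=5,     # length of the linear warmup
    cooldown_epochs=0,   # extra epochs held at min_lr after the schedule
)
# scheduler, num_epochs = create_scheduler(example_sched_args, optimizer)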
|
||||
|
||||
def setup_criteria_mixup(args, dataset=None, **criterion_kwargs):
|
||||
"""Set up further objects that are needed for training.
|
||||
|
||||
Args:
|
||||
args: arguments
|
||||
dataset (torch.data.Dataset, optional): dataset that implements images_per_class, for class weights (Default value = None)
|
||||
criterion_kwargs: further arguments for the criterion
|
||||
**criterion_kwargs:
|
||||
|
||||
Returns:
|
||||
tuple[nn.Module, nn.Module, Mixup]: criterion, val_criterion, mixup
|
||||
|
||||
"""
|
||||
weight = None
|
||||
if args.loss_weight != "none":
|
||||
if dataset is not None and hasattr(dataset, "images_per_class"):
|
||||
ipc = dataset.images_per_class
|
||||
total_ims = sum(ipc)
|
||||
|
||||
if args.loss_weight == "linear":
|
||||
weight = torch.tensor([total_ims / (ims * args.n_classes) for ims in ipc])
|
||||
elif args.loss_weight == "log":
|
||||
p_c = torch.tensor([ims / total_ims for ims in ipc])
|
||||
log_p_c = torch.where(p_c > 0, p_c.log(), torch.zeros_like(p_c))
|
||||
entr = -(p_c * log_p_c).sum()
|
||||
weight = -log_p_c / entr
|
||||
elif args.loss_weight == "sqrt":
|
||||
p_c = torch.tensor([ims / total_ims for ims in ipc])
|
||||
weight = 1 / (p_c.sqrt() * p_c.sqrt().sum())
|
||||
|
||||
else:
|
||||
logger.warning("Could not find images_per_class in dataset. Using uniform weights.")
|
||||
|
||||
if args.aug_cutmix or args.multi_label:
|
||||
# criterion = SoftTargetCrossEntropy()
|
||||
if args.ignore_index >= 0:
|
||||
if weight is None:
|
||||
weight = torch.ones(args.n_classes)
|
||||
weight[args.ignore_index] = 0
|
||||
if args.multi_label:
|
||||
if args.loss == "ce":
|
||||
criterion = nn.BCEWithLogitsLoss(pos_weight=weight, **criterion_kwargs)
|
||||
val_criterion = nn.BCEWithLogitsLoss(pos_weight=weight, **criterion_kwargs)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Only BCEWithLogitsLoss (ce) is implemented for multi-label classification, not {args.loss}."
|
||||
)
|
||||
else:
|
||||
if args.loss == "ce":
|
||||
loss_cls = nn.CrossEntropyLoss
|
||||
elif args.loss == "baikal":
    loss_cls = BaikalLoss  # note: BaikalLoss is not imported at the top of this file
|
||||
else:
|
||||
raise NotImplementedError(f"'{args.loss}'-loss is not implemented.")
|
||||
criterion = loss_cls(weight=weight, **criterion_kwargs)
|
||||
val_criterion = loss_cls(weight=weight, **criterion_kwargs)
|
||||
else:
|
||||
if args.loss == "ce":
|
||||
loss_cls = nn.CrossEntropyLoss
|
||||
elif args.loss == "baikal":
|
||||
loss_cls = BaikalLoss
|
||||
else:
|
||||
raise NotImplementedError(f"'{args.loss}'-loss is not implemented.")
|
||||
criterion = loss_cls(
|
||||
label_smoothing=args.label_smoothing,
|
||||
ignore_index=args.ignore_index if weight is None else -100,
|
||||
weight=weight,
|
||||
**criterion_kwargs,
|
||||
) # LabelSmoothingCrossEntropy(smoothing=args.label_smoothing)
|
||||
# criterion = LabelSmoothingCrossEntropy(smoothing=args.label_smoothing)
|
||||
val_criterion = loss_cls(
|
||||
label_smoothing=args.label_smoothing,
|
||||
ignore_index=args.ignore_index if weight is None else -100,
|
||||
weight=weight,
|
||||
**criterion_kwargs,
|
||||
) # LabelSmoothingCrossEntropy(smoothing=0.)
|
||||
|
||||
mixup_kwargs = dict(
|
||||
mixup_alpha=args.aug_mixup_alpha,
|
||||
cutmix_alpha=args.aug_cutmix_alpha,
|
||||
label_smoothing=args.label_smoothing,
|
||||
num_classes=args.n_classes,
|
||||
)
|
||||
mixup = Mixup(**mixup_kwargs) if abs(args.aug_cutmix_alpha) + abs(args.aug_mixup_alpha) > 0.0 else None
|
||||
|
||||
return criterion, val_criterion, mixup
|
||||
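# A worked toy example of the three loss-weighting schemes above for an
# imbalanced 3-class dataset (counts are illustrative).
_ipc = [100, 10, 1]  # images per class
_total, _n_classes = sum(_ipc), len(_ipc)

# "linear": inverse-frequency weights, normalized by the class count
_linear_w = torch.tensor([_total / (ims * _n_classes) for ims in _ipc])
# -> tensor([ 0.37,  3.70, 37.00]): rare classes get proportionally larger weights

# "log": negative log-probabilities, normalized by the label entropy
_p_c = torch.tensor([ims / _total for ims in _ipc])
_log_w = -_p_c.log() / (-(_p_c * _p_c.log()).sum())

# "sqrt": inverse square-root frequency, a milder reweighting than "linear"
_sqrt_w = 1 / (_p_c.sqrt() * _p_c.sqrt().sum())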
|
||||
|
||||
def _train(
|
||||
model,
|
||||
train_loader,
|
||||
optimizer,
|
||||
rank,
|
||||
epochs,
|
||||
device,
|
||||
mixup,
|
||||
criterion,
|
||||
world_size,
|
||||
scheduler,
|
||||
args,
|
||||
val_loader,
|
||||
val_criterion,
|
||||
model_folder,
|
||||
scaler,
|
||||
do_metrics_calculation=True,
|
||||
start_epoch=0,
|
||||
show_tqdm=True,
|
||||
topk=(1, 5),
|
||||
acc_dict_key=None,
|
||||
train_dali_server=None,
|
||||
val_dali_server=None,
|
||||
):
|
||||
"""Train the model.
|
||||
|
||||
Args:
|
||||
model:
|
||||
train_loader:
|
||||
optimizer:
|
||||
rank:
|
||||
epochs:
|
||||
device:
|
||||
mixup:
|
||||
criterion:
|
||||
world_size:
|
||||
scheduler:
|
||||
args:
|
||||
val_loader:
|
||||
val_criterion:
|
||||
model_folder:
|
||||
scaler:
|
||||
do_metrics_calculation: (Default value = True)
|
||||
start_epoch: (Default value = 0)
|
||||
show_tqdm: (Default value = True)
|
||||
topk: (Default value = (1, 5))
|
||||
acc_dict_key: (Default value = None)
|
||||
|
||||
Returns:
|
||||
dict: evaluation metrics at the end of training
|
||||
|
||||
"""
|
||||
if acc_dict_key is None:
|
||||
acc_dict_key = "acc{}"
|
||||
training_start = time()
|
||||
topk = tuple(k for k in topk if k <= args.n_classes)
|
||||
time_spend_training = time_spend_validating = 0
|
||||
current_best_acc = 0.0
|
||||
if rank == 0:
|
||||
logger.info(f"Dataloader has {len(train_loader)} batches")
|
||||
|
||||
logger.debug("Starting training with the following settings:")
|
||||
logger.debug(f"criterion: {criterion}")
|
||||
logger.debug(f"train_loader: {train_loader}, sampler: {train_loader.sampler}")
|
||||
logger.debug(f"dataset: {train_loader.dataset}")
|
||||
logger.debug(f"optimizer: {optimizer}")
|
||||
logger.debug(f"device: {device}")
|
||||
logger.debug(f"start epoch: {start_epoch}, epochs: {epochs}")
|
||||
logger.debug(f"scaler: {scaler}")
|
||||
logger.debug(f"max_grad_norm: {args.max_grad_norm}")
|
||||
# logger.debug(f"model_ema:\n{model_ema}\n{model_ema.decay}\n{model_ema.device}")
|
||||
if mixup:
|
||||
logger.debug(
|
||||
f"mixup: {mixup}; mixup_alpha: {mixup.mixup_alpha}, cutmix_alpha: {mixup.cutmix_alpha},"
|
||||
f" cutmix_minmax: {mixup.cutmix_minmax}, prob: {mixup.mix_prob}, switch_prob: {mixup.switch_prob},"
|
||||
f" label_smoothing: {mixup.label_smoothing}, num_classes: {mixup.num_classes}, correct_lam:"
|
||||
f" {mixup.correct_lam}, mixup_enabled: {mixup.mixup_enabled}"
|
||||
)
|
||||
else:
|
||||
logger.debug(f"mixup: {mixup}")
|
||||
|
||||
for epoch in range(start_epoch, epochs):
|
||||
with logger.contextualize(epoch=str(epoch + 1)):
|
||||
if args.distributed:
|
||||
train_loader.sampler.set_epoch(epoch)
|
||||
|
||||
set_ep_func = getattr(train_loader.dataset, "set_epoch", None)
|
||||
if callable(set_ep_func):
|
||||
train_loader.dataset.set_epoch(epoch)
|
||||
val_loader.dataset.set_epoch(epoch)
|
||||
|
||||
if train_dali_server:
|
||||
train_dali_server.start_thread()
|
||||
logger.info("started train dali server")
|
||||
|
||||
epoch_time, epoch_stats = _train_one_epoch(
|
||||
model,
|
||||
train_loader,
|
||||
optimizer,
|
||||
rank,
|
||||
epoch,
|
||||
device,
|
||||
mixup,
|
||||
criterion,
|
||||
scheduler,
|
||||
scaler,
|
||||
args,
|
||||
topk,
|
||||
"train/" + acc_dict_key,
|
||||
show_tqdm,
|
||||
)
|
||||
time_spend_training += epoch_time
|
||||
|
||||
if train_dali_server:
|
||||
train_dali_server.stop_thread()
|
||||
|
||||
val_time, val_stats = _evaluate(
|
||||
model,
|
||||
val_loader,
|
||||
epoch,
|
||||
rank,
|
||||
device,
|
||||
val_criterion,
|
||||
args,
|
||||
topk,
|
||||
"val/" + acc_dict_key,
|
||||
dali_server=val_dali_server,
|
||||
)
|
||||
time_spend_validating += val_time
|
||||
|
||||
if rank == 0:
|
||||
logger.info(f"total_time={time() - training_start}s")
|
||||
|
||||
if rank == 0:
|
||||
top1_val_acc = val_stats["val/" + acc_dict_key.format(1)]
|
||||
# print metadata for grafana
|
||||
metadata = {
|
||||
"epoch": epoch + 1,
|
||||
"progress": (epoch + 1) / args.epochs,
|
||||
**val_stats,
|
||||
**epoch_stats,
|
||||
}
|
||||
# filter out NaN and infinity values
|
||||
metadata = {k: v for k, v in metadata.items() if isfinite(v)}
|
||||
print(json.dumps(metadata), flush=True)
|
||||
logger.debug(f"printed metadata: {json.dumps(metadata)}")
|
||||
if WANDB_AVAILABLE:
|
||||
wandb.log(metadata, step=epoch + 1)
|
||||
|
||||
# saving current state
|
||||
if top1_val_acc > current_best_acc or (epoch + 1) % args.save_epochs == 0:
|
||||
reason = "top" if top1_val_acc > current_best_acc else "" # min(...) will be the top-1 accuracy
|
||||
if reason == "top":
|
||||
current_best_acc = top1_val_acc
|
||||
logger.info(f"found a new best model with acc: {current_best_acc}")
|
||||
kwargs = dict(
|
||||
model_state=model.state_dict(),
|
||||
stats=metadata,
|
||||
optimizer_state=optimizer.state_dict(),
|
||||
additional_reason=reason,
|
||||
regular_save=(epoch + 1) % args.save_epochs == 0,
|
||||
)
|
||||
if scheduler:
|
||||
kwargs["scheduler_state"] = scheduler.state_dict()
|
||||
save_model_state(
|
||||
model_folder, epoch + 1, args, **kwargs, max_interm_ep_states=args.keep_interm_states
|
||||
)
|
||||
|
||||
if rank == 0:
|
||||
end_time = time()
|
||||
logger.info(
|
||||
f"training done: total time={end_time - training_start}, "
|
||||
f"time spend training={time_spend_training}, "
|
||||
f"time spend validating={time_spend_validating}"
|
||||
)
|
||||
|
||||
results = {**val_stats, **epoch_stats, f"val/best_{acc_dict_key.format(1)}": current_best_acc}
|
||||
|
||||
if rank == 0:
|
||||
save_model_state(
|
||||
model_folder,
|
||||
epoch + 1,
|
||||
args,
|
||||
model_state=model.state_dict(),
|
||||
stats=results,
|
||||
additional_reason="final",
|
||||
regular_save=False,
|
||||
max_interm_ep_states=args.keep_interm_states,
|
||||
)
|
||||
|
||||
if do_metrics_calculation:
|
||||
# Calculate efficiency metrics
|
||||
inp = next(iter(train_loader))[0].to(device)
|
||||
metrics = calculate_metrics(
|
||||
args,
|
||||
model,
|
||||
rank=rank,
|
||||
input=inp,
|
||||
device=device,
|
||||
did_training=True,
|
||||
all_metrics=False,
|
||||
world_size=world_size,
|
||||
key_start="eval/",
|
||||
)
|
||||
|
||||
if rank == 0:
|
||||
logger.info(f"Efficiency metrics: {json.dumps(metrics)}")
|
||||
return results
|
||||
|
||||
|
||||
def _mask_preds(preds, cls_masks, mask_val=-100):
|
||||
"""Mask the predictions by the mask.
|
||||
|
||||
Args:
|
||||
preds: model predictions
|
||||
cls_masks: class masks
|
||||
mask_val: (Default value = -100)
|
||||
|
||||
Returns:
|
||||
torch.Tensor: masked predictions
|
||||
|
||||
"""
|
||||
if cls_masks is None:
|
||||
return preds
|
||||
return torch.where(cls_masks.bool(), mask_val, preds)
|
||||
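# A tiny illustration of _mask_preds: masked-out classes are pushed to a large
# negative logit so they can never win the argmax (toy tensors).
_preds = torch.tensor([[2.0, 5.0, 1.0]])
_cls_masks = torch.tensor([[0, 1, 0]])     # 1 = class is masked out
_masked = _mask_preds(_preds, _cls_masks)  # tensor([[2., -100., 1.]])
assert _masked.argmax(dim=-1).item() == 0  # class 1 can no longer be predicted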
|
||||
|
||||
def _evaluate(
|
||||
model,
|
||||
val_loader,
|
||||
epoch,
|
||||
rank,
|
||||
device,
|
||||
val_criterion,
|
||||
args,
|
||||
topk=(1, 5),
|
||||
acc_dict_key=None,
|
||||
dali_server=None,
|
||||
):
|
||||
"""Evaluate the model.
|
||||
|
||||
Args:
|
||||
model (nn.Module): the model to evaluate
|
||||
val_loader (DataLoader): loader for evaluation data
|
||||
epoch (int): the current epoch (for logger & tracking)
|
||||
rank (int): this process's rank (don't log n times)
|
||||
device (torch.device): device to evaluate on
|
||||
val_criterion (nn.Module): validation loss
|
||||
args (DotDict): further arguments
|
||||
topk (tuple[int], optional): top-k accuracy, by default (1, 5)
|
||||
acc_dict_key (str, optional): format key for the accuracy dictionary, by default "acc{}".
|
||||
|
||||
Returns:
|
||||
tuple[float, dict]: validation time, validation stats (loss and accuracies)
|
||||
|
||||
"""
|
||||
if not acc_dict_key:
|
||||
acc_dict_key = "acc{}"
|
||||
|
||||
if dali_server:
|
||||
dali_server.start_thread()
|
||||
topk = tuple(k for k in topk if k <= args.n_classes)
|
||||
model.eval()
|
||||
val_loss = 0
|
||||
val_accs = {acc_dict_key.format(k): 0.0 for k in topk}
|
||||
val_start = time()
|
||||
n_iters = 0
|
||||
iterator = (
|
||||
tqdm(val_loader, total=len(val_loader), desc=f"Validating epoch {epoch + 1}")
|
||||
if rank == 0 and args.tqdm
|
||||
else val_loader
|
||||
)
|
||||
class_counts = torch.zeros(1 if isinstance(topk, int) else len(topk), args.n_classes, 2)
|
||||
for batch_data in iterator:
|
||||
xs, ys = batch_data[:2]
|
||||
cls_masks = batch_data[2].to(device, non_blocking=True) if len(batch_data) == 3 else None
|
||||
|
||||
if args.debug:
|
||||
logger.debug(f"y_max = {ys.max()}, y_min = {ys.min()}, num_classes={args.n_classes}")
|
||||
|
||||
xs, ys = xs.to(device, non_blocking=True), ys.to(device, non_blocking=True)
|
||||
with torch.no_grad(), torch.amp.autocast("cuda", enabled=args.eval_amp):
|
||||
preds = model(xs)
|
||||
preds = _mask_preds(preds, cls_masks)
|
||||
|
||||
if args.multi_label:
|
||||
# labels are float for BCELoss
|
||||
ys = ys.float()
|
||||
val_loss += val_criterion(preds.transpose(1, -1), ys.transpose(1, -1) if len(ys.shape) > 1 else ys).item()
|
||||
class_counts += per_class_counts(preds, ys, args.n_classes, topk=topk, ignore_index=args.ignore_index)
|
||||
n_iters += 1
|
||||
|
||||
if args.distributed:
|
||||
dist.barrier()
|
||||
|
||||
if dali_server:
|
||||
dali_server.stop_thread()
|
||||
val_end = time()
|
||||
iterations = n_iters
|
||||
|
||||
if args.distributed:
|
||||
gather_tensor = torch.Tensor([val_loss]).to(device)
|
||||
dist.barrier()
|
||||
dist.all_reduce(gather_tensor, op=dist.ReduceOp.AVG)
|
||||
gather_tensor = gather_tensor.tolist()
|
||||
val_loss = gather_tensor[0]
|
||||
class_counts = class_counts.to(device)
|
||||
dist.all_reduce(class_counts, op=dist.ReduceOp.SUM)
|
||||
class_counts = class_counts.cpu()
|
||||
|
||||
for i, k in enumerate(topk):
|
||||
key = acc_dict_key.format(k)
|
||||
mkey = key.replace("acc", "m-acc")
|
||||
val_accs[key] = class_counts[i].sum(dim=0)[0].item() / class_counts[i].sum(dim=0).sum(dim=-1).item()
|
||||
val_accs[mkey] = (class_counts[i, :, 0] / class_counts[i].sum(dim=-1)).mean().item()
|
||||
|
||||
val_accs["val/loss"] = val_loss
|
||||
|
||||
if rank == 0:
|
||||
log_s = f"val/time={val_end - val_start}s"
|
||||
for key, val in val_accs.items():
|
||||
log_s += f", {key}={val:.4f}"
|
||||
logger.info(log_s)
|
||||
|
||||
return val_end - val_start, val_accs
|
||||
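# A worked example of how the [len(topk), n_classes, 2] counts tensor above is
# reduced to overall vs. class-balanced ("m-acc") accuracy; per_class_counts is
# assumed to store [correct, wrong] per class in the last dimension.
_counts = torch.tensor([[[8.0, 2.0],    # class 0: 8 of 10 correct
                         [1.0, 1.0]]])  # class 1: 1 of  2 correct
_overall = _counts[0].sum(dim=0)[0] / _counts[0].sum()          # 9/12 = 0.75
_balanced = (_counts[0, :, 0] / _counts[0].sum(dim=-1)).mean()  # (0.8+0.5)/2 = 0.65
# the balanced score drops because the rare class is predicted poorly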
|
||||
|
||||
def _train_one_epoch(
|
||||
model,
|
||||
train_loader,
|
||||
optimizer,
|
||||
rank,
|
||||
epoch,
|
||||
device,
|
||||
mixup,
|
||||
criterion,
|
||||
scheduler,
|
||||
scaler,
|
||||
args,
|
||||
topk=(1, 5),
|
||||
acc_dict_key=None,
|
||||
show_tqdm=True,
|
||||
):
|
||||
"""Train the model for one epoch.
|
||||
|
||||
Args:
|
||||
model:
|
||||
train_loader:
|
||||
optimizer:
|
||||
rank:
|
||||
epoch:
|
||||
device:
|
||||
mixup:
|
||||
criterion:
|
||||
scheduler:
|
||||
scaler:
|
||||
args:
|
||||
topk: (Default value = (1, 5))
|
||||
acc_dict_key: (Default value = None)
|
||||
show_tqdm: (Default value = True)
|
||||
|
||||
Returns:
|
||||
tuple[float, dict]: time spent in the epoch, epoch stats (lr and loss; empty if stats gathering is off)
|
||||
|
||||
"""
|
||||
if not acc_dict_key:
|
||||
acc_dict_key = "acc{}"
|
||||
|
||||
model.train()
|
||||
iterator = (
|
||||
tqdm(train_loader, total=len(train_loader), desc=f"Training epoch {epoch + 1}")
|
||||
if rank == 0 and show_tqdm
|
||||
else train_loader
|
||||
)
|
||||
|
||||
if not args.amp:
|
||||
scaler = NoScaler()
|
||||
|
||||
epoch_loss = 0
|
||||
epoch_accs = {}
|
||||
epoch_start = time()
|
||||
grad_norms = []
|
||||
n_iters = 0
|
||||
if hasattr(train_loader.dataset, "epoch"):
|
||||
train_loader.dataset.epoch = epoch
|
||||
for i, batch_data in enumerate(iterator):
|
||||
xs, ys = batch_data[:2]
|
||||
cls_masks = batch_data[2].to(device, non_blocking=True) if len(batch_data) == 3 else None
|
||||
optimizer.zero_grad()
|
||||
n_iters += 1
|
||||
xs = xs.to(device, non_blocking=True)
|
||||
ys = ys.to(device, non_blocking=True)
|
||||
|
||||
if args.debug and i == 0:
|
||||
logger.debug(f"y_max = {ys.max()}, y_min = {ys.min()}, num_classes={args.n_classes}")
|
||||
|
||||
if mixup:
|
||||
if args.multi_label:
|
||||
xs, ys = mixup(xs, ys, cls_masks)
|
||||
else:
|
||||
xs, ys = mixup(xs, ys)
|
||||
|
||||
if args.debug and i == 0:
|
||||
logger.debug(f"input x: {type(xs)}; {xs.shape}, y: {type(ys)}; {ys.shape}")
|
||||
|
||||
with torch.amp.autocast("cuda", enabled=args.amp):
|
||||
preds = model(xs)
|
||||
preds = _mask_preds(preds, cls_masks)
|
||||
if args.multi_label:
|
||||
# labels are float for BCELoss
|
||||
ys = ys.float()
|
||||
loss = criterion(preds.transpose(1, -1), ys.transpose(1, -1) if len(ys.shape) > 1 else ys) + (
|
||||
model.get_internal_loss() if hasattr(model, "get_internal_loss") else model.module.get_internal_loss()
|
||||
)
|
||||
|
||||
if not isfinite(loss.item()):
|
||||
logger.error(f"Got loss value {loss.item()}. Stopping training.")
|
||||
logger.info(f"input has nan: {xs.isnan().any().item()}")
|
||||
logger.info(f"target has nan: {ys.isnan().any().item()}")
|
||||
logger.info(f"output has nan: {preds.isnan().any().item()}")
|
||||
for name, param in model.named_parameters():
|
||||
if param.isnan().any().item():
|
||||
logger.error(f"parameter {name} has a nan value")
|
||||
if len(grad_norms) > 0:
|
||||
grad_norms = torch.Tensor(grad_norms)
|
||||
logger.info(
|
||||
f"Gradient norms until now: min={grad_norms.min().item()}, 20th"
|
||||
f" %tile={torch.quantile(grad_norms, .2).item()}, mean={torch.mean(grad_norms)}, 80th"
|
||||
f" %tile={torch.quantile(grad_norms, .8).item()}, max={grad_norms.max()}"
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
iter_grad_norm = scaler(
|
||||
loss,
|
||||
optimizer,
|
||||
parameters=model.parameters(),
|
||||
clip_grad=args.max_grad_norm if args.max_grad_norm > 0.0 else None,
|
||||
).cpu()
|
||||
|
||||
if args.gather_stats_during_training and isfinite(iter_grad_norm):
|
||||
grad_norms.append(iter_grad_norm)
|
||||
|
||||
# if args.aug_cutmix:
|
||||
# ys = ys.argmax(dim=-1) # for accuracy with CutMix, just use the argmax for both
|
||||
#
|
||||
epoch_loss += loss.item()
|
||||
# accuracies = accuracy(preds, ys, topk=topk, dict_key=acc_dict_key, ignore_index=args.ignore_index)
|
||||
# for key in accuracies:
|
||||
# epoch_accs[key] += accuracies[key]
|
||||
|
||||
if args.distributed:
|
||||
dist.barrier()
|
||||
epoch_end = time()
|
||||
|
||||
iterations = n_iters
|
||||
# epoch_accs = {key: val / iterations for key, val in epoch_accs.items()}
|
||||
epoch_loss = epoch_loss / iterations
|
||||
grad_norm_avrg = -1
|
||||
inf_grads = iterations - len(grad_norms)
|
||||
if len(grad_norms) > 0 and args.gather_stats_during_training:
|
||||
grad_norm_max = max(grad_norms)
|
||||
grad_norms = torch.Tensor(grad_norms)
|
||||
grad_norm_20 = torch.quantile(grad_norms, 0.2).item()
|
||||
grad_norm_80 = torch.quantile(grad_norms, 0.8).item()
|
||||
grad_norm_avrg = torch.mean(grad_norms)
|
||||
|
||||
if args.distributed:
|
||||
# grad norm is already synchronized
|
||||
# gather_tensor = torch.Tensor([epoch_loss, *[epoch_accs[acc_dict_key.format(k)] for k in topk]]).to(device)
|
||||
gather_tensor = torch.Tensor([epoch_loss]).to(device)
|
||||
dist.barrier()
|
||||
dist.all_reduce(gather_tensor, op=dist.ReduceOp.AVG)
|
||||
# gather_tensor = (gather_tensor / world_size).tolist()
|
||||
epoch_loss = gather_tensor.item()
|
||||
# for i, k in enumerate(topk):
|
||||
# epoch_accs[acc_dict_key.format(k)] = gather_tensor[i + 1]
|
||||
|
||||
lr = optimizer.param_groups[0]["lr"]
|
||||
epoch_accs["train/lr"] = lr
|
||||
epoch_accs["train/loss"] = epoch_loss
|
||||
|
||||
if rank == 0:
|
||||
if args.gather_stats_during_training:
|
||||
print_s = f"train/time={epoch_end - epoch_start}s"
|
||||
logger.info(print_s)
|
||||
if len(grad_norms) > 0:
|
||||
logger.info(
|
||||
f"grad norm avrg={grad_norm_avrg}, grad norm max={grad_norm_max}, "
|
||||
f"inf grad norm={inf_grads}, grad norm 20%={grad_norm_20}, grad norm 80%={grad_norm_80}"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"inf grad norm={inf_grads}")
|
||||
logger.error("100% of update steps with infinite grad norms!")
|
||||
else:
|
||||
logger.info(f"train/time={epoch_end - epoch_start}s")
|
||||
|
||||
if scheduler:
|
||||
if isinstance(scheduler, optim.lr_scheduler.LambdaLR):
|
||||
scheduler.step()
|
||||
else:
|
||||
scheduler.step(epoch)
|
||||
|
||||
if args.gather_stats_during_training:
|
||||
return epoch_end - epoch_start, epoch_accs
|
||||
return epoch_end - epoch_start, {}
|
||||
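# A minimal sketch of the scaler contract used by _train_one_epoch above:
# called as scaler(loss, optimizer, parameters=..., clip_grad=...) and
# returning the gradient norm. The real ScalerGradNormReturn lives in utils
# and may differ; this version simply wraps torch's AMP GradScaler.
class ExampleGradNormScaler:
    def __init__(self):
        self._scaler = torch.amp.GradScaler("cuda")

    def __call__(self, loss, optimizer, parameters, clip_grad=None):
        self._scaler.scale(loss).backward()
        self._scaler.unscale_(optimizer)  # report the norm in unscaled units
        max_norm = clip_grad if clip_grad is not None else float("inf")
        grad_norm = nn.utils.clip_grad_norm_(parameters, max_norm)
        self._scaler.step(optimizer)
        self._scaler.update()
        return grad_norm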
710
AAAI Supplementary Material/Model Training Code/evaluate.py
Normal file
@@ -0,0 +1,710 @@
|
||||
"""Module to evaluate trained models."""
|
||||
|
||||
import os
|
||||
from contextlib import nullcontext
|
||||
from datetime import datetime
|
||||
from math import sqrt
|
||||
from time import time
|
||||
|
||||
import torch
|
||||
from loguru import logger
|
||||
from timm.loss import LabelSmoothingCrossEntropy
|
||||
from timm.models.resnet import ResNet as TimmResNet
|
||||
from torch import distributed as dist
|
||||
from torch.nn import functional as F
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
|
||||
from engine import (
|
||||
_evaluate,
|
||||
setup_criteria_mixup,
|
||||
setup_model_optim_sched_scaler,
|
||||
setup_tracking_and_logging,
|
||||
wandb_available,
|
||||
)
|
||||
from load_dataset import prepare_dataset
|
||||
from metrics import calculate_metrics
|
||||
from models import load_pretrained
|
||||
from utils import (
|
||||
RepeatedDataset,
|
||||
ddp_cleanup,
|
||||
ddp_setup,
|
||||
denormalize,
|
||||
get_cpu_name,
|
||||
grad_cam_reshape_transform,
|
||||
prep_kwargs,
|
||||
set_filter_warnings,
|
||||
)
|
||||
|
||||
|
||||
def evaluate_metrics(model, dataset, **kwargs):
|
||||
"""Evaluate efficiency metrics for a given model.
|
||||
|
||||
Args:
|
||||
model (str): path to model state .tar
|
||||
dataset (str): name of the dataset to evaluate on
|
||||
**kwargs: further arguments
|
||||
|
||||
"""
|
||||
set_filter_warnings()
|
||||
model_path = model
|
||||
args = prep_kwargs(kwargs)
|
||||
if args.cuda:
|
||||
args.distributed, device, world_size, rank, _ = ddp_setup()
|
||||
torch.cuda.set_device(device)
|
||||
else:
|
||||
args.distributed = False
|
||||
device = torch.device("cpu")
|
||||
rank = 0
|
||||
args.compile_model = False
|
||||
|
||||
save_state = torch.load(model_path, map_location="cpu")
|
||||
old_args = prep_kwargs(save_state["args"])
|
||||
args.model = old_args.model
|
||||
args.dataset = dataset
|
||||
args.run_name = old_args.run_name
|
||||
args.experiment_name = old_args.experiment_name
|
||||
args.wandb_run_id = old_args.wandb_run_id
|
||||
setup_tracking_and_logging(args, rank=rank, append_model_path=model_path, log_wandb=args.wandb_run_id is not None)
|
||||
|
||||
train_loader, args.n_classes, args.ignore_index, args.multi_label, _ = prepare_dataset(dataset, args)
|
||||
|
||||
model, args, old_args, save_state = load_pretrained(model, args, new_dataset_params=True)
|
||||
old_args["eval_imsize"] = args.imsize
|
||||
args.model = model_name = old_args.model
|
||||
args.dataset = dataset
|
||||
args.epochs = 5
|
||||
|
||||
model, optim, _, scaler = setup_model_optim_sched_scaler(model, device, epochs=10, args=args, head_only=False)
|
||||
|
||||
if rank == 0:
|
||||
logger.info(
|
||||
f"Evaluate metrics for model {model_name} on {dataset}. "
|
||||
f"It was {old_args.task.replace('-','')}d on {old_args.dataset} for {save_state['epoch']} "
|
||||
"epochs."
|
||||
)
|
||||
# logger.info(f"full set of arguments: {args}")
|
||||
logger.info(f"full set of training arguments: {old_args}")
|
||||
logger.info(f"full set of eval-metrics arguments: {args}")
|
||||
|
||||
logger.info(
|
||||
f"evaluating on {device} -> {torch.cuda.get_device_name(device) if device.type != 'cpu' else get_cpu_name()}"
|
||||
)
|
||||
metrics = calculate_metrics(
|
||||
args, model, rank=rank, device=device, optim=optim, scaler=scaler, train_loader=train_loader, key_start="eval/"
|
||||
)
|
||||
if rank == 0:
|
||||
logger.info(f"Metrics: {metrics}")
|
||||
if wandb_available():
|
||||
import wandb
|
||||
|
||||
wandb.log(metrics)
|
||||
|
||||
|
||||
def evaluate(model, dataset=None, val_dataset=None, **kwargs):
|
||||
"""Evaluate model accuracy.
|
||||
|
||||
Args:
|
||||
model (str): path to model state .tar
|
||||
dataset (str, optional): name of the dataset to evaluate on (Default value = None)
|
||||
val_dataset (str, optional): name of the dataset to evaluate on (Default value = None)
|
||||
**kwargs: further arguments
|
||||
Note:
|
||||
If `val_dataset` is not provided, the model will be evaluated on `dataset`.
|
||||
|
||||
"""
|
||||
set_filter_warnings()
|
||||
model_path = model
|
||||
args = prep_kwargs(kwargs)
|
||||
if val_dataset is None:
|
||||
val_dataset = dataset
|
||||
args.dataset = dataset
|
||||
args.val_dataset = val_dataset
|
||||
if args.cuda:
|
||||
args.distributed, device, world_size, rank, _ = ddp_setup()
|
||||
torch.cuda.set_device(device)
|
||||
else:
|
||||
args.distributed = False
|
||||
device = torch.device("cpu")
|
||||
world_size = 1
|
||||
rank = 0
|
||||
args.compile_model = False
|
||||
args.batch_size = int(args.batch_size / world_size)
|
||||
|
||||
save_state = torch.load(model_path, map_location="cpu")
|
||||
old_args = prep_kwargs(save_state["args"])
|
||||
args.model = old_args.model
|
||||
args.dataset = dataset
|
||||
args.run_name = old_args.run_name
|
||||
args.experiment_name = old_args.experiment_name
|
||||
args.wandb_run_id = old_args.wandb_run_id
|
||||
run_folder = setup_tracking_and_logging(
|
||||
args, rank, append_model_path=model_path, log_wandb=args.wandb_run_id is not None
|
||||
)
|
||||
|
||||
val_loader, args.n_classes, args.ignore_index, args.multi_label, dali_server = prepare_dataset(
|
||||
val_dataset, args, train=False
|
||||
)
|
||||
|
||||
model, args, old_args, save_state = load_pretrained(model, args, new_dataset_params=True)
|
||||
model = model.to(device)
|
||||
args.model = model_name = old_args.model
|
||||
args.dataset = dataset
|
||||
|
||||
if rank == 0:
|
||||
logger.info(
|
||||
f"Evaluate model {model_name} on {val_dataset}. "
|
||||
f"It was pretrained on {old_args.dataset} for {save_state['epoch']} epochs."
|
||||
)
|
||||
|
||||
if args.distributed:
|
||||
model = DDP(model)
|
||||
|
||||
if args.compile_model:
|
||||
model = torch.compile(model)
|
||||
|
||||
# log all devices
|
||||
logger.info(
|
||||
f"evaluating on {device} -> {torch.cuda.get_device_name(device) if device.type != 'cpu' else get_cpu_name()}"
|
||||
)
|
||||
if rank == 0:
|
||||
logger.info(f"torch version {torch.__version__}")
|
||||
logger.info(f"full set of arguments: {args}")
|
||||
logger.info(f"full set of old arguments: {old_args}")
|
||||
|
||||
if args.seed:
|
||||
torch.manual_seed(args.seed)
|
||||
|
||||
val_criterion = LabelSmoothingCrossEntropy(smoothing=args.label_smoothing)
|
||||
if rank == 0:
|
||||
logger.info("start evaluation")
|
||||
logger.info(f"Run info at: '{run_folder}'")
|
||||
|
||||
if rank == 0:
|
||||
val_time, val_stats = _evaluate(
|
||||
model.to(device),
|
||||
val_loader,
|
||||
epoch=save_state["epoch"] - 1,
|
||||
rank=rank,
|
||||
device=device,
|
||||
val_criterion=val_criterion,
|
||||
args=args,
|
||||
dali_server=dali_server,
|
||||
acc_dict_key=f"eval_{val_dataset}/acc{{}}",
|
||||
)
|
||||
log_s = f"Evaluation done in {val_time}s"
|
||||
for key, val in val_stats.items():
|
||||
log_s += f", {key}={val:.4f}"
|
||||
logger.info(log_s)
|
||||
if wandb_available():
|
||||
import wandb
|
||||
|
||||
wandb.log(val_stats)
|
||||
else:
|
||||
_evaluate(
|
||||
model.to(device),
|
||||
val_loader,
|
||||
epoch=save_state["epoch"] - 1,
|
||||
rank=rank,
|
||||
device=device,
|
||||
val_criterion=val_criterion,
|
||||
args=args,
|
||||
dali_server=dali_server,
|
||||
acc_dict_key=f"eval_{val_dataset}/acc{{}}",
|
||||
)
|
||||
|
||||
ddp_cleanup(args=args, rank=rank)
|
||||
|
||||
|
||||
def evaluate_center_bias(model, dataset=None, val_dataset=None, **kwargs):
|
||||
"""Evaluate model accuracy in different nonants.
|
||||
|
||||
Args:
|
||||
model (str): path to model state .tar
|
||||
dataset (str, optional): name of the dataset to evaluate on (Default value = None)
|
||||
val_dataset (str, optional): name of the dataset to evaluate on (Default value = None)
|
||||
**kwargs: further arguments
|
||||
Note:
|
||||
If `val_dataset` is not provided, the model will be evaluated on `dataset`.
|
||||
|
||||
"""
|
||||
set_filter_warnings()
|
||||
model_path = model
|
||||
args = prep_kwargs(kwargs)
|
||||
if val_dataset is None:
|
||||
val_dataset = dataset
|
||||
if dataset is None:
|
||||
dataset = val_dataset
|
||||
assert dataset is not None, "Specify validation dataset (-valds) or dataset (-ds)."
|
||||
args.dataset = dataset
|
||||
args.val_dataset = val_dataset
|
||||
if args.cuda:
|
||||
args.distributed, device, world_size, rank, _ = ddp_setup()
|
||||
torch.cuda.set_device(device)
|
||||
else:
|
||||
args.distributed = False
|
||||
device = torch.device("cpu")
|
||||
world_size = 1
|
||||
rank = 0
|
||||
args.compile_model = False
|
||||
args.batch_size = int(args.batch_size / world_size)
|
||||
|
||||
save_state = torch.load(model_path, map_location="cpu")
|
||||
old_args = prep_kwargs(save_state["args"])
|
||||
args.model = old_args.model
|
||||
args.dataset = dataset
|
||||
args.run_name = old_args.run_name
|
||||
args.experiment_name = old_args.experiment_name
|
||||
args.wandb_run_id = old_args.wandb_run_id
|
||||
run_folder = setup_tracking_and_logging(
|
||||
args, rank, append_model_path=model_path, log_wandb=args.wandb_run_id is not None
|
||||
)
|
||||
|
||||
assert "fornet" in val_dataset.lower(), "Only ForNet supported for center bias evaluation."
|
||||
_, args.n_classes, args.ignore_index, args.multi_label, __ = prepare_dataset(val_dataset, args, train=False)
|
||||
|
||||
model, args, old_args, save_state = load_pretrained(model, args, new_dataset_params=True)
|
||||
model = model.to(device)
|
||||
args.model = model_name = old_args.model
|
||||
args.dataset = dataset
|
||||
|
||||
if rank == 0:
|
||||
logger.info(
|
||||
f"Evaluate model {model_name} on {val_dataset}. "
|
||||
f"It was pretrained on {old_args.dataset} for {save_state['epoch']} epochs."
|
||||
)
|
||||
|
||||
if args.distributed:
|
||||
model = DDP(model)
|
||||
|
||||
if args.compile_model:
|
||||
model = torch.compile(model)
|
||||
|
||||
# log all devices
|
||||
logger.info(
|
||||
f"evaluating on {device} -> {torch.cuda.get_device_name(device) if device.type != 'cpu' else get_cpu_name()}"
|
||||
)
|
||||
if rank == 0:
|
||||
logger.info(f"torch version {torch.__version__}")
|
||||
logger.info(f"full set of arguments: {args}")
|
||||
logger.info(f"full set of old arguments: {old_args}")
|
||||
|
||||
if args.seed:
|
||||
torch.manual_seed(args.seed)
|
||||
|
||||
val_criterion = LabelSmoothingCrossEntropy(smoothing=args.label_smoothing)
|
||||
if rank == 0:
|
||||
logger.info("start evaluation")
|
||||
logger.info(f"Run info at: '{run_folder}'")
|
||||
|
||||
if rank == 0:
|
||||
nonant_accs = []
|
||||
for nonant in range(-1, 9):
|
||||
val_loader, _, __, ___, dali_server = prepare_dataset(val_dataset, args, train=False)
|
||||
val_loader.dataset.fg_in_nonant = nonant
|
||||
logger.info(f"Evaluate nonant {nonant} for 5 rounds.")
|
||||
round_accs = []
|
||||
for _ in range(5):
|
||||
val_time, val_stats = _evaluate(
|
||||
model.to(device),
|
||||
val_loader,
|
||||
epoch=save_state["epoch"] - 1,
|
||||
rank=rank,
|
||||
device=device,
|
||||
val_criterion=val_criterion,
|
||||
args=args,
|
||||
dali_server=dali_server,
|
||||
)
|
||||
round_accs.append(val_stats["acc1"])
|
||||
nonant_accs.append(sum(round_accs) / len(round_accs))
|
||||
log_s = f"Evaluation done in {val_time}s: "
|
||||
for nonant, val in enumerate(nonant_accs[1:]):
|
||||
log_s += f", nonant {nonant}={val}% acc ({val / nonant_accs[0]} rel acc)"
|
||||
center_bias_val = 1 - (
|
||||
min([nonant_accs[1], nonant_accs[3], nonant_accs[7], nonant_accs[9]])
|
||||
+ min([nonant_accs[2], nonant_accs[4], nonant_accs[6], nonant_accs[8]])
|
||||
) / (2 * nonant_accs[5])
|
||||
log_s += f", center_bias={center_bias_val:.4f}"
|
||||
logger.info(log_s)
|
||||
if wandb_available():
|
||||
import wandb
|
||||
|
||||
wandb.log({f"eval_{args.val_dataset}/center_bias": center_bias_val})
|
||||
else:
|
||||
raise NotImplementedError("Center bias evaluation not supported in distributed mode.")
|
||||
|
||||
ddp_cleanup(args=args, rank=rank)
|
||||
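# A worked toy example of the center-bias score above. Index 0 of nonant_accs
# holds the nonant=-1 baseline (presumably the dataset's default placement);
# indices 1..9 map to the 3x3 grid with 5 at the center, corners {1, 3, 7, 9}
# and edge-midpoints {2, 4, 6, 8}. The accuracies below are illustrative.
_nonant_accs = [0.80, 0.70, 0.72, 0.70, 0.72, 0.82, 0.72, 0.70, 0.72, 0.70]
_center_bias = 1 - (
    min(_nonant_accs[1], _nonant_accs[3], _nonant_accs[7], _nonant_accs[9])
    + min(_nonant_accs[2], _nonant_accs[4], _nonant_accs[6], _nonant_accs[8])
) / (2 * _nonant_accs[5])
# -> 1 - (0.70 + 0.72) / 1.64 = 0.134: the score grows as the worst off-center
# placements fall further behind the centered one; 0 means no center bias.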
|
||||
|
||||
def evaluate_size_bias(model, dataset=None, val_dataset=None, **kwargs):
|
||||
"""Evaluate model accuracy for differently scaled foregrounds.
|
||||
|
||||
Args:
|
||||
model (str): path to model state .tar
|
||||
dataset (str, optional): name of the dataset to evaluate on (Default value = None)
|
||||
val_dataset (str, optional): name of the dataset to evaluate on (Default value = None)
|
||||
**kwargs: further arguments
|
||||
Note:
|
||||
If `val_dataset` is not provided, the model will be evaluated on `dataset`.
|
||||
|
||||
"""
|
||||
set_filter_warnings()
|
||||
model_path = model
|
||||
args = prep_kwargs(kwargs)
|
||||
if val_dataset is None:
|
||||
val_dataset = dataset
|
||||
if dataset is None:
|
||||
dataset = val_dataset
|
||||
assert val_dataset is not None and dataset is not None
|
||||
args.dataset = dataset
|
||||
args.val_dataset = val_dataset
|
||||
if args.cuda:
|
||||
args.distributed, device, world_size, rank, _ = ddp_setup()
|
||||
torch.cuda.set_device(device)
|
||||
else:
|
||||
args.distributed = False
|
||||
device = torch.device("cpu")
|
||||
world_size = 1
|
||||
rank = 0
|
||||
args.compile_model = False
|
||||
args.batch_size = int(args.batch_size / world_size)
|
||||
|
||||
save_state = torch.load(model_path, map_location="cpu")
|
||||
old_args = prep_kwargs(save_state["args"])
|
||||
args.model = old_args.model
|
||||
args.dataset = dataset
|
||||
args.run_name = old_args.run_name
|
||||
args.experiment_name = old_args.experiment_name
|
||||
args.wandb_run_id = old_args.wandb_run_id
|
||||
run_folder = setup_tracking_and_logging(args, rank, append_model_path=model_path, log_wandb=False)
|
||||
|
||||
assert "fornet" in val_dataset.lower(), "Only ForNet supported for center bias evaluation."
|
||||
_, args.n_classes, args.ignore_index, args.multi_label, __ = prepare_dataset(val_dataset, args, train=False)
|
||||
|
||||
model, args, old_args, save_state = load_pretrained(model, args, new_dataset_params=True)
|
||||
model = model.to(device)
|
||||
args.model = model_name = old_args.model
|
||||
args.dataset = dataset
|
||||
|
||||
if rank == 0:
|
||||
logger.info(
|
||||
f"Evaluate model {model_name} on {val_dataset}. "
|
||||
f"It was pretrained on {old_args.dataset} for {save_state['epoch']} epochs."
|
||||
)
|
||||
|
||||
if args.distributed:
|
||||
model = DDP(model)
|
||||
|
||||
if args.compile_model:
|
||||
model = torch.compile(model)
|
||||
|
||||
# log all devices
|
||||
logger.info(
|
||||
f"evaluating on {device} -> {torch.cuda.get_device_name(device) if device.type != 'cpu' else get_cpu_name()}"
|
||||
)
|
||||
if rank == 0:
|
||||
logger.info(f"torch version {torch.__version__}")
|
||||
logger.info(f"full set of arguments: {args}")
|
||||
logger.info(f"full set of old arguments: {old_args}")
|
||||
|
||||
if args.seed:
|
||||
torch.manual_seed(args.seed)
|
||||
|
||||
val_criterion = LabelSmoothingCrossEntropy(smoothing=args.label_smoothing)
|
||||
if rank == 0:
|
||||
logger.info("start evaluation")
|
||||
logger.info(f"Run info at: '{run_folder}'")
|
||||
|
||||
sizes = [0.1, 0.2, 0.4, 0.6, 0.8, 1.0, 1.25, 1.5, 1.75, 2.0]
|
||||
if rank == 0:
|
||||
size_accs = []
|
||||
val_times = 0
|
||||
for size in sizes:
|
||||
val_loader, _, __, ___, dali_server = prepare_dataset(val_dataset, args, train=False)
|
||||
val_loader.dataset.size_fact = size
|
||||
val_loader.dataset.fg_scale_jitter = 0.0
|
||||
logger.info(f"Evaluate size factor {size} for 5 rounds.")
|
||||
round_accs = []
|
||||
for _ in range(5):
|
||||
val_time, val_stats = _evaluate(
|
||||
model.to(device),
|
||||
val_loader,
|
||||
epoch=save_state["epoch"] - 1,
|
||||
rank=rank,
|
||||
device=device,
|
||||
val_criterion=val_criterion,
|
||||
args=args,
|
||||
dali_server=dali_server,
|
||||
)
|
||||
round_accs.append(val_stats["acc1"])
|
||||
val_times += val_time
|
||||
size_accs.append(sum(round_accs) / len(round_accs))
|
||||
log_s = f"Evaluation done in {val_times}s: "
|
||||
for size, val in zip(sizes, size_accs):
|
||||
log_s += f", rel_size {size}={val}% acc ({val / size_accs[sizes.index(1.0)]} rel acc)"
|
||||
logger.info(log_s)
|
||||
else:
|
||||
raise NotImplementedError("Center bias evaluation not supported in distributed mode.")
|
||||
|
||||
ddp_cleanup(args=args, rank=rank)
|
||||
|
||||
|
||||
def evaluate_attributions(model, dataset=None, val_dataset=None, **kwargs):
|
||||
"""Evaluate model attributions using captum.
|
||||
|
||||
Args:
|
||||
model (str): path to model state .tar
|
||||
dataset (str, optional): name of the dataset to evaluate on (Default value = None)
|
||||
val_dataset (str, optional): name of the dataset to evaluate on (Default value = None)
|
||||
**kwargs: further arguments
|
||||
Note:
|
||||
If `val_dataset` is not provided, the model will be evaluated on `dataset`.
|
||||
The `captum` package is required.
|
||||
|
||||
"""
|
||||
from captum.attr import IntegratedGradients
|
||||
from pytorch_grad_cam import GradCAM, GradCAMPlusPlus
|
||||
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
|
||||
|
||||
set_filter_warnings()
|
||||
model_path = model
|
||||
args = prep_kwargs(kwargs)
|
||||
if val_dataset is None:
    val_dataset = dataset
if dataset is None:
    dataset = val_dataset
assert val_dataset is not None, "Please set dataset (-ds) or validation dataset (-valds)"
|
||||
args.dataset = val_dataset
|
||||
args.val_dataset = val_dataset
|
||||
if args.cuda:
|
||||
args.distributed, device, world_size, rank, _ = ddp_setup()
|
||||
torch.cuda.set_device(device)
|
||||
else:
|
||||
args.distributed = False
|
||||
device = torch.device("cpu")
|
||||
world_size = 1
|
||||
rank = 0
|
||||
args.compile_model = False
|
||||
args.batch_size = int(args.batch_size / world_size)
|
||||
|
||||
save_state = torch.load(model_path, map_location="cpu")
|
||||
old_args = prep_kwargs(save_state["args"])
|
||||
args.model = old_args.model
|
||||
args.dataset = val_dataset
|
||||
args.run_name = old_args.run_name
|
||||
args.experiment_name = old_args.experiment_name
|
||||
args.wandb_run_id = old_args.wandb_run_id
|
||||
run_folder = setup_tracking_and_logging(
|
||||
args, rank, append_model_path=model_path, log_wandb=args.wandb_run_id is not None
|
||||
)
|
||||
|
||||
assert "fornet" in val_dataset.lower(), "Only ForNet supported for attribution evaluation."
|
||||
val_loader, args.n_classes, args.ignore_index, args.multi_label, dali_server = prepare_dataset(
|
||||
val_dataset, args, train=False
|
||||
)
|
||||
val_loader.dataset.return_fg_masks = True
|
||||
|
||||
model, args, old_args, save_state = load_pretrained(model, args, new_dataset_params=True)
|
||||
model = model.to(device)
|
||||
args.model = model_name = old_args.model
|
||||
args.dataset = dataset
|
||||
# assert (
|
||||
# args.imsize == old_args.imsize
|
||||
# ), f"Model was trained on {old_args.imsize}x{old_args.imsize} images. Not {args.imsize}x{args.imsize}."
|
||||
epoch = save_state["epoch"]
|
||||
|
||||
if rank == 0:
|
||||
logger.info(
|
||||
f"Evaluate attributions of model {model_name} on {dataset}. "
|
||||
f"It was pretrained on {old_args.dataset} for {epoch} epochs."
|
||||
)
|
||||
|
||||
if args.distributed:
|
||||
model = DDP(model)
|
||||
|
||||
if args.compile_model:
|
||||
model = torch.compile(model)
|
||||
|
||||
# log all devices
|
||||
logger.info(
|
||||
f"evaluating on {device} -> {torch.cuda.get_device_name(device) if device.type != 'cpu' else get_cpu_name()}"
|
||||
)
|
||||
if rank == 0:
|
||||
logger.info(f"torch version {torch.__version__}")
|
||||
if args.new_log:
|
||||
logger.info(f"full set of arguments: {args}")
|
||||
logger.info(f"full set of old arguments: {old_args}")
|
||||
else:
|
||||
logger.info(f"full set of attribution evaluation arguments: {old_args}")
|
||||
|
||||
if args.seed:
|
||||
torch.manual_seed(args.seed)
|
||||
|
||||
if rank == 0:
|
||||
logger.info(f"Run info at: '{run_folder}'")
|
||||
|
||||
iterator = (
|
||||
tqdm(val_loader, total=len(val_loader), desc=f"Validating epoch {epoch}")
|
||||
if rank == 0 and args.tqdm
|
||||
else val_loader
|
||||
)
|
||||
|
||||
if args.debug:
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
eval_attn_importance = False
|
||||
if isinstance(model, TimmResNet):
|
||||
reshape_transform = None
|
||||
target_layers = [model.layer4[-1]]
|
||||
elif model_name.lower().startswith("vit-"):
|
||||
reshape_transform = grad_cam_reshape_transform
|
||||
target_layers = [model.blocks[-1].norm1]
|
||||
eval_attn_importance = True
|
||||
from architectures.vit import _MatrixSaveAttn
|
||||
|
||||
model.blocks[-1].attn = _MatrixSaveAttn.cast(model.blocks[-1].attn)
|
||||
elif model_name.lower().startswith("swin_"):
|
||||
reshape_transform = grad_cam_reshape_transform
|
||||
target_layers = [model.layers[-1].blocks[-1].norm1]
|
||||
else:
|
||||
raise NotImplementedError(f"Model {model_name} not supported for attribution evaluation.")
|
||||
|
||||
model.eval()
|
||||
val_start = time()
|
||||
rel_ig_weights = 0.0
|
||||
rel_attn_weights = 0.0
|
||||
rel_cam_weights = {"GradCAM": 0.0, "GradCAM++": 0.0}
|
||||
if rank == 0:
|
||||
logger.info("Start attribution evaluation")
|
||||
if dali_server:
|
||||
dali_server.start_thread()
|
||||
for batch_data in iterator:
|
||||
xs, ys, fg_masks = batch_data
|
||||
|
||||
xs, ys, fg_masks = (
|
||||
xs.to(device, non_blocking=True),
|
||||
ys.to(device, non_blocking=True),
|
||||
fg_masks.float().to(device, non_blocking=True),
|
||||
)
|
||||
|
||||
with torch.amp.autocast("cuda") if args.eval_amp else nullcontext():
|
||||
model.zero_grad()
|
||||
ig = IntegratedGradients(model)
|
||||
# scale attributions by a temperature of 10 to make differences more apparent after the softmax
|
||||
attr_ig = (
|
||||
ig.attribute(xs, target=ys, baselines=0.0, internal_batch_size=args.batch_size * 4).sum(dim=1) * 10
|
||||
) # B x W x H
|
||||
attr_probs = attr_ig.view(xs.shape[0], -1).softmax(dim=-1).view(xs.shape[0], *xs.shape[2:])
|
||||
fg_masks = fg_masks.view(attr_probs.shape)
|
||||
|
||||
fg_attrs = (attr_probs * fg_masks).sum(dim=(-1, -2))
|
||||
rel_attr_weight = fg_attrs / fg_masks.mean(dim=(-1, -2))
|
||||
rel_attr_weight = torch.where(fg_masks.mean(dim=(-1, -2)) > 0, rel_attr_weight, 1.0)
|
||||
if rel_attr_weight.isnan().any():
|
||||
logger.error(f"NaNs in rel_attr_weight: {rel_attr_weight}, fg_mask_weights: {fg_masks.mean(dim=(-1, -2))}")
|
||||
break
|
||||
rel_ig_weights += rel_attr_weight.mean().item()
|
||||
|
||||
cam_targets = [ClassifierOutputTarget(int(trgt)) for trgt in ys.tolist()]
|
||||
for method, name in zip([GradCAM, GradCAMPlusPlus], ["GradCAM", "GradCAM++"]):
|
||||
with method(model=model, target_layers=target_layers, reshape_transform=reshape_transform) as cam, (
|
||||
torch.amp.autocast("cuda") if args.eval_amp else nullcontext()
|
||||
):
|
||||
cam_attr = cam(input_tensor=xs, targets=cam_targets)
|
||||
|
||||
cam_attr = torch.from_numpy(cam_attr).to(device)
|
||||
rel_cam_attr = (cam_attr * fg_masks).sum(dim=(-1, -2)) / cam_attr.sum(dim=(-1, -2))
|
||||
cam_attr_weight = rel_cam_attr / fg_masks.mean(dim=(-1, -2))
|
||||
cam_attr_weight = torch.where(
|
||||
(fg_masks.mean(dim=(-1, -2)) > 0) & (cam_attr.sum(dim=(-1, -2)) > 0), cam_attr_weight, 1.0
|
||||
)
|
||||
rel_cam_weights[name] += cam_attr_weight.mean().item()
|
||||
if cam_attr_weight.isnan().any():
|
||||
logger.error(
|
||||
f"NaNs in cam_attr_weight ({name}): {cam_attr_weight}, fg_mask_weights:"
|
||||
f" {fg_masks.mean(dim=(-1, -2))}"
|
||||
)
|
||||
break
|
||||
|
||||
if eval_attn_importance:
|
||||
with torch.amp.autocast("cuda") if args.eval_amp else nullcontext():
|
||||
pred = model(xs) # noqa: F841
|
||||
last_attn_mat = model.blocks[-1].attn.attn_mat
|
||||
cls_tkn_attn = last_attn_mat[:, :, 0, 1:].mean(dim=1).squeeze(dim=1) # B x H x 1(CLS Token) X N -> B x N
|
||||
B, N = cls_tkn_attn.shape
|
||||
att_HW = int(sqrt(N))
|
||||
cls_tkn_attn = cls_tkn_attn.view(B, 1, att_HW, att_HW)
|
||||
attn_attr = F.interpolate(
|
||||
cls_tkn_attn, size=(xs.shape[-2], xs.shape[-1]), mode="bilinear", align_corners=False
|
||||
).view(B, xs.shape[-2], xs.shape[-1])
|
||||
rel_attn_attr = (attn_attr * fg_masks).sum(dim=(-1, -2)) / attn_attr.sum(dim=(-1, -2))
|
||||
attn_attr_weight = rel_attn_attr / fg_masks.mean(dim=(-1, -2))
|
||||
attn_attr_weight = torch.where(
|
||||
(fg_masks.mean(dim=(-1, -2)) > 0) & (attn_attr.sum(dim=(-1, -2)) > 0), attn_attr_weight, 1.0
|
||||
)
|
||||
rel_attn_weights += attn_attr_weight.mean().item()
|
||||
|
||||
if args.debug:
|
||||
logger.debug(f"Attribution scores: IG: {rel_attr_weight[:5]}, GradCAM(++): {cam_attr_weight[:5]}")
|
||||
num_subplots = 5 if eval_attn_importance else 4
|
||||
fig, axs = plt.subplots(num_subplots, 4)
|
||||
for plt_i in range(4):
|
||||
axs[0][plt_i].imshow(denormalize(xs[plt_i]).permute(1, 2, 0).cpu().numpy())
|
||||
axs[1][plt_i].imshow(fg_masks[plt_i].cpu().numpy())
|
||||
axs[2][plt_i].imshow(attr_probs[plt_i].cpu().numpy())
|
||||
axs[3][plt_i].imshow(cam_attr[plt_i].cpu().numpy())
|
||||
if eval_attn_importance:
|
||||
axs[4][plt_i].imshow(attn_attr[plt_i].cpu().numpy())
|
||||
plt.show()
|
||||
|
||||
if rank == 0 and args.tqdm:
    # iterator is only a tqdm bar (with .n and .set_description) on rank 0 with tqdm enabled
    iterator_desc = (
        f"IG weights: {rel_ig_weights / (iterator.n + 1):.4f}, GradCAM weights:"
        f" {rel_cam_weights['GradCAM'] / (iterator.n + 1):.4f}, GradCAM++ weights:"
        f" {rel_cam_weights['GradCAM++'] / (iterator.n + 1):.4f}"
    )
    if eval_attn_importance:
        iterator_desc += f", Attn weights: {rel_attn_weights / (iterator.n + 1):.4f}"
    iterator.set_description(iterator_desc)
|
||||
|
||||
if args.distributed:
|
||||
dist.barrier()
|
||||
|
||||
val_end = time()
|
||||
rel_ig_weights /= len(iterator)
|
||||
rel_grad_cam = rel_cam_weights["GradCAM"] / len(iterator)
|
||||
rel_grad_cam_pp = rel_cam_weights["GradCAM++"] / len(iterator)
|
||||
rel_attn_weights /= len(iterator)
|
||||
|
||||
if dali_server:
|
||||
dali_server.stop_thread()
|
||||
|
||||
if args.distributed:
|
||||
gather_tensor = torch.Tensor([rel_ig_weights, rel_grad_cam, rel_grad_cam_pp, rel_attn_weights]).to(device)
|
||||
dist.barrier()
|
||||
dist.all_reduce(gather_tensor)
|
||||
gather_tensor = (gather_tensor / world_size).tolist()
|
||||
rel_ig_weights, rel_grad_cam, rel_grad_cam_pp, rel_attn_weights = gather_tensor
|
||||
|
||||
if rank == 0:
|
||||
output_text = (
|
||||
f"epoch {epoch}: eval_{args.val_dataset}/rel_ig_weights={rel_ig_weights},"
|
||||
f" eval_{args.val_dataset}/rel_grad_cam={rel_grad_cam},"
|
||||
f" eval_{args.val_dataset}/rel_grad_cam_pp={rel_grad_cam_pp}"
|
||||
)
|
||||
if eval_attn_importance:
|
||||
output_text += f", eval_{args.val_dataset}/rel_attn_weights={rel_attn_weights}"
|
||||
output_text += f", eval_{args.val_dataset}/attribution_eval_time={val_end - val_start}s"
|
||||
logger.info(output_text)
|
||||
if wandb_available():
|
||||
import wandb
|
||||
|
||||
wandb_data = {
|
||||
f"eval_{args.val_dataset}/importance_ig": rel_ig_weights,
|
||||
f"eval_{args.val_dataset}/importance_grad_cam": rel_grad_cam,
|
||||
f"eval_{args.val_dataset}/importance_grad_cam_pp": rel_grad_cam_pp,
|
||||
}
|
||||
if eval_attn_importance:
|
||||
wandb_data[f"eval_{args.val_dataset}/importance_attn"] = rel_attn_weights
|
||||
wandb.log(wandb_data)
|
||||
|
||||
ddp_cleanup(args=args, rank=rank)
|
||||
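# A sketch of what a ViT/Swin grad_cam_reshape_transform typically does (the
# real helper is imported from utils above and may differ): Grad-CAM expects
# CNN-style [B, C, H, W] activations, while transformer blocks emit [B, N, C]
# token sequences, so the CLS token is dropped (for ViT) and the remaining N
# tokens are folded back into a square spatial grid.
import torch

def example_reshape_transform(tensor, height=14, width=14):
    # tensor: [B, 1 + H*W, C] -> [B, C, H, W]
    result = tensor[:, 1:, :].reshape(tensor.size(0), height, width, tensor.size(2))
    return result.permute(0, 3, 1, 2)  # channels first

_tokens = torch.randn(2, 197, 768)  # ViT-B/16 at 224x224: 196 patches + CLS
assert example_reshape_transform(_tokens).shape == (2, 768, 14, 14)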
@@ -0,0 +1,211 @@
|
||||
import argparse
|
||||
import os
|
||||
from functools import partial
|
||||
from math import log
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torchvision.transforms import Compose, RandomCrop, Resize, ToTensor
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
try:
|
||||
from models import load_pretrained
|
||||
except ModuleNotFoundError:
|
||||
import sys
|
||||
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
|
||||
from models import load_pretrained
|
||||
|
||||
from utils import prep_kwargs
|
||||
|
||||
|
||||
def score_5( # noqa: D103
|
||||
idx,
|
||||
bg_probs,
|
||||
mean_probs,
|
||||
fg_ratio,
|
||||
max_idx,
|
||||
fg_ratio_max=0.9,
|
||||
fg_ratio_min=0.002,
|
||||
fg_ratio_exp=0.4, # learned: 0.487
|
||||
idx_exp=0.01, # learned: 0.043
|
||||
bg_probs_exp=0.2, # learned: 0.24
|
||||
opt_fg_ratio=0.1, # learned: 0.1
|
||||
mean_probs_exp=0.2, # learned: 0.2
|
||||
fg_ratio_penalty=1, # learned: -0.446 ???
|
||||
):
|
||||
return (
|
||||
log(mean_probs) * mean_probs_exp
|
||||
+ log(1 - bg_probs) * bg_probs_exp
|
||||
+ log(1 - abs(fg_ratio - opt_fg_ratio)) * fg_ratio_exp
|
||||
+ log(1 - idx / (max_idx + 1)) * idx_exp
|
||||
+ (fg_ratio_min < fg_ratio < fg_ratio_max) * fg_ratio_penalty
|
||||
) # TODO: KEEP IT LIKE IT IS
|
||||
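# A worked example of score_5 with the manual weights (all inputs are
# illustrative). A mid-ranked version with a confident foreground, a low
# background-only probability, and a foreground ratio near the 0.1 optimum
# scores just below the in-range bonus of 1 (higher presumably means better):
_example_score = score_5(
    idx=1,           # version index (earlier versions get a small bonus)
    bg_probs=0.05,   # model confidence on the background-only image
    mean_probs=0.9,  # mean model confidence on the composited versions
    fg_ratio=0.12,   # fraction of pixels covered by the foreground
    max_idx=2,
)
# -> roughly 0.957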
|
||||
|
||||
parser = argparse.ArgumentParser(description="Inspect image versions")
|
||||
parser.add_argument("-f", "--base_folder", type=str, required=True, help="Base folder to inspect")
|
||||
parser.add_argument(
|
||||
"-batch_size",
|
||||
type=int,
|
||||
default=32,
|
||||
help="Batch size for model inspection. Will be 1 background and batch_size - 1 foregrounds",
|
||||
)
|
||||
parser.add_argument("-imsize", type=int, default=224, help="Image size")
|
||||
parser.add_argument(
|
||||
"-score_f_weights", choices=["manual", "automatic"], default="manual", help="Score function hyperparameters"
|
||||
)
|
||||
parser.add_argument("-auto_fg_pen_val", type=float, default=-0.446, help="Automatic foreground penalty value")
|
||||
parser.add_argument("-d", "--dataset", choices=["tinyimagenet", "imagenet"], default="tinyimagenet", help="Dataset")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
score_f = (
|
||||
score_5
|
||||
if args.score_f_weights == "manual"
|
||||
else partial(
|
||||
score_5, **dict(fg_ratio_exp=0.487, idx_exp=0.043, bg_probs_exp=0.24, fg_ratio_penalty=args.auto_fg_pen_val)
|
||||
)
|
||||
)
|
||||
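# For intuition, with the manual defaults (hypothetical numbers): a first
# version (idx=0) whose pasted foreground is recognized well on the
# monochrome backgrounds (mean_probs=0.6), whose background alone is not
# (bg_probs=0.05), and which covers ~12% of the image (fg_ratio=0.12, close
# to opt_fg_ratio=0.1) scores
#   score_5(idx=0, bg_probs=0.05, mean_probs=0.6, fg_ratio=0.12, max_idx=2) ~ 0.88,
# dominated by the flat fg_ratio_penalty term since 0.002 < 0.12 < 0.9.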
|
||||
bg_folder = os.path.join(args.base_folder, "backgrounds")
|
||||
fg_folder = os.path.join(args.base_folder, "foregrounds")
|
||||
|
||||
classes = os.listdir(fg_folder)
|
||||
classes = sorted(classes, key=lambda x: int(x[1:]))
|
||||
assert len(classes) in [200, 1_000], f"Expected 200 or 1_000 classes, got {len(classes)}"
|
||||
|
||||
total_images = set()
|
||||
for in_cls in classes:
|
||||
cls_images = {
|
||||
os.path.join(in_cls, "_".join(img.split(".")[0].split("_")[:-1]))
|
||||
for img in os.listdir(os.path.join(fg_folder, in_cls))
|
||||
if img.split(".")[0].split("_")[-1].startswith("v")
|
||||
}
|
||||
total_images.update(cls_images)
|
||||
total_images = list(total_images)
|
||||
|
||||
# base_folder = os.path.join(*(os.path.dirname(__file__).split("/")[:-1]))
|
||||
|
||||
# in_cls to print name/lemma
|
||||
with open(os.path.join("data", "misc_dataset_files", "tinyimagenet_synset_names.txt"), "r") as f:
|
||||
in_cls_to_name = {line.split(":")[0].strip(): line.split(":")[1].strip() for line in f.readlines() if len(line) > 2}
|
||||
|
||||
if args.dataset == "tinyimagenet":
|
||||
inspection_model_paths = [] # PATHS TO MODEL WEIGHTS (.pt) PRETRAINED ON TinyImageNet
|
||||
elif args.dataset == "imagenet":
|
||||
inspection_model_paths = [] # PATHS TO MODEL WEIGHTS (.pt) PRETRAINED ON ImageNet
|
||||
else:
|
||||
raise ValueError(f"Unknown dataset {args.dataset}")
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
inspection_models = [
|
||||
load_pretrained(path, prep_kwargs({}), new_dataset_params=False)[0].to(device) for path in inspection_model_paths
|
||||
]
|
||||
img_transform = Compose([Resize((args.imsize, args.imsize)), RandomCrop(args.imsize), ToTensor()])
|
||||
|
||||
total_versions = []
|
||||
|
||||
for img_name in tqdm(total_images, desc="Image version computation"):
|
||||
in_cls, img_name = img_name.split("/")
|
||||
versions = set()
|
||||
for img in os.listdir(os.path.join(fg_folder, in_cls)):
|
||||
if "_".join(img.split("_")[: len(img_name.split("_"))]) == img_name:
|
||||
versions.add(img)
|
||||
if len(versions) == 1:
|
||||
version = list(versions)[0]
|
||||
if version.split(".")[0].split("_")[-1].startswith("v"):
|
||||
tqdm.write(f"renaming single version image {version} to {img_name}.WEBP")
|
||||
os.rename(os.path.join(fg_folder, in_cls, version), os.path.join(fg_folder, in_cls, f"{img_name}.WEBP"))
|
||||
os.rename(
|
||||
os.path.join(bg_folder, in_cls, version.replace(".WEBP", ".JPEG")),
|
||||
os.path.join(bg_folder, in_cls, f"{img_name}.JPEG"),
|
||||
)
|
||||
continue
|
||||
elif len(versions) == 0:
|
||||
tqdm.write(f"Image {img_name} has no versions")
|
||||
continue
|
||||
versions = sorted(list(versions))
|
||||
assert all(
|
||||
[version.split(".")[0].split("_")[-1].startswith("v") for version in versions]
|
||||
), f"Weird Versions: {versions} for image {img_name}"
|
||||
assert len(versions) <= 3, f"Too many versions for image {img_name}: {versions}"
|
||||
|
||||
version_scores = []
|
||||
for v_idx, version in enumerate(versions):
|
||||
img = Image.open(os.path.join(fg_folder, in_cls, version))
|
||||
bg_img = Image.open(os.path.join(bg_folder, in_cls, f"{version.split('.')[0]}.JPEG"))
|
||||
img_mask = np.array(img.convert("RGBA").split()[-1])
|
||||
|
||||
fg_ratio = np.sum(img_mask) / (255 * bg_img.size[0] * bg_img.size[1])
|
||||
|
||||
fg_size = img.size
|
||||
monochrome_backgrounds = [
|
||||
Image.new(
|
||||
"RGB",
|
||||
(max(args.imsize, fg_size[0]), max(args.imsize, fg_size[1])),
|
||||
(255 * i // (args.batch_size - 2), 255 * i // (args.batch_size - 2), 255 * i // (args.batch_size - 2)),
|
||||
)
|
||||
for i in range(args.batch_size - 1)
|
||||
]
|
||||
pasting_error = False
|
||||
for mc_bg in monochrome_backgrounds:
|
||||
try:
|
||||
mc_bg.paste(img, ((args.imsize - fg_size[0]) // 2, (args.imsize - fg_size[1]) // 2), img)
|
||||
except ValueError as e:
|
||||
tqdm.write(f"Image {img_name} could not be pasted into background: {e}")
|
||||
pasting_error = True
|
||||
break
|
||||
|
||||
inp_batch = torch.stack(
|
||||
[img_transform(bg_img)] + [img_transform(mc_bg) for mc_bg in monochrome_backgrounds], dim=0
|
||||
).to(device)
|
||||
|
||||
cls_idx = classes.index(in_cls)
|
||||
bg_probs = []
|
||||
mean_probs = []
|
||||
for model in inspection_models:
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
out_probs = model(inp_batch).softmax(dim=-1)[:, cls_idx].cpu().numpy()
|
||||
bg_probs.append(out_probs[0])
|
||||
mean_probs.append(np.mean(out_probs[1:]))
|
||||
|
||||
# average the lists
|
||||
bg_probs = np.mean(bg_probs)
|
||||
mean_probs = np.mean(mean_probs)
|
||||
|
||||
version_score = (
|
||||
score_f(
|
||||
idx=v_idx,
|
||||
bg_probs=float(bg_probs),
|
||||
mean_probs=float(mean_probs),
|
||||
fg_ratio=float(fg_ratio),
|
||||
max_idx=len(versions) - 1,
|
||||
)
|
||||
if not pasting_error
|
||||
else -100
|
||||
)
|
||||
version_scores.append(version_score)
|
||||
|
||||
assert len(versions) == len(version_scores), f"Expected {len(versions)} scores, got {len(version_scores)}"
|
||||
|
||||
if max(version_scores) > min(version_scores):
|
||||
# find best version
|
||||
best_version_idx = int(np.argmax(version_scores))
|
||||
best_version = versions[best_version_idx]
|
||||
|
||||
# delete all other versions
|
||||
for version in versions:
|
||||
if version != best_version:
|
||||
os.remove(os.path.join(fg_folder, in_cls, version))
|
||||
os.remove(os.path.join(bg_folder, in_cls, f"{version.split('.')[0]}.JPEG"))
|
||||
# remove version tag in name
|
||||
new_version_name = "_".join(best_version.split("_")[:-1]) + "." + best_version.split(".")[-1]
|
||||
os.rename(os.path.join(fg_folder, in_cls, best_version), os.path.join(fg_folder, in_cls, new_version_name))
|
||||
os.rename(
|
||||
os.path.join(bg_folder, in_cls, f"{best_version.split('.')[0]}.JPEG"),
|
||||
os.path.join(bg_folder, in_cls, f"{new_version_name.split('.')[0]}.JPEG"),
|
||||
)
|
||||
else:
|
||||
tqdm.write(f"All versions have the same score for image {img_name}")
|
||||
@@ -0,0 +1,16 @@
|
||||
#!/bin/bash
|
||||
|
||||
srun -K \
|
||||
--container-image=PATH/TO/SLURM/IMAGE \
|
||||
--container-workdir="$(pwd)" \
|
||||
--container-mounts=/ALL/IMPORTANT/MOUNTS,"$(pwd)":"$(pwd)" \
|
||||
--partition=RTXA6000,RTX3090,A100-40GB,A100-80GB,H100,H200 \
|
||||
--job-name="python" \
|
||||
--nodes=1 \
|
||||
--gpus=1 \
|
||||
--ntasks=1 \
|
||||
--cpus-per-task=24 \
|
||||
--mem=64G \
|
||||
--time=1-0 \
|
||||
--export="NLTK_DATA=/PATH/TO/NLTK_DATA" \
|
||||
python3 "$@"
|
||||
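The wrapper forwards everything after `python3` to the script inside the container, so any of the standalone scripts can be launched through it. A hypothetical invocation (the wrapper's file name and the base folder are placeholders; the flags belong to the version-inspection script above):

```commandline
./run_in_slurm.sh inspect_versions.py -f /PATH/TO/FORNET/BASE -d imagenet
```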
346
AAAI Supplementary Material/Model Training Code/load_dataset.py
Normal file
@@ -0,0 +1,346 @@
|
||||
"""Module to load the datasets, using torch and datadings."""
|
||||
|
||||
import contextlib
|
||||
import os
|
||||
from functools import partial
|
||||
|
||||
import torchvision.transforms as tv_transforms
|
||||
from datadings.reader import MsgpackReader
|
||||
from timm.data import create_transform
|
||||
from torch.utils.data import DataLoader, DistributedSampler, WeightedRandomSampler
|
||||
from torchvision.datasets import (
|
||||
CIFAR10,
|
||||
CIFAR100,
|
||||
FGVCAircraft,
|
||||
Flowers102,
|
||||
Food101,
|
||||
ImageFolder,
|
||||
OxfordIIITPet,
|
||||
StanfordCars,
|
||||
)
|
||||
|
||||
from data.counter_animal import CounterAnimal
|
||||
from data.data_utils import (
|
||||
DDDecodeDataset,
|
||||
ToOneHotSequence,
|
||||
collate_imnet,
|
||||
collate_listops,
|
||||
get_hf_transform,
|
||||
minimal_augment,
|
||||
segment_augment,
|
||||
three_augment,
|
||||
)
|
||||
from data.fornet import ForNet
|
||||
from data.samplers import RASampler
|
||||
from paths_config import ds_path
|
||||
|
||||
|
||||
def prepare_dataset(dataset_name, args, transform=None, train=True, rank=None):
|
||||
"""Load a dataset from disk, different formats are used for different datasets.
|
||||
|
||||
Supported datasets: CIFAR10, Stanford Cars, Oxford-IIIT Pet, Flowers102, Food-101, FGVC Aircraft, ImageNet, TinyImageNet, ForNet, TinyForNet, CounterAnimal, ImageNet-9
|
||||
|
||||
Args:
|
||||
dataset_name (str): name of the dataset
|
||||
args: further arguments
|
||||
transform (list[Module] | str, optional): transformations to apply to the data; a list gets composed. If None, the transform is built from args.augment_strategy. (Default value = None)
|
||||
train (bool, optional): use the training split (or test/validation split) (Default value = True)
|
||||
rank (int, optional): global rank of this process in distributed training (Default value = None)
|
||||
|
||||
Returns:
|
||||
DataLoader: data loader for the dataset
|
||||
int: number of classes in the dataset
|
||||
int: ignore index for the dataset
|
||||
bool: whether the dataset is multi-label
DALIServer | None: DALI proxy server when args.augment_engine == "dali", else None
|
||||
|
||||
"""
|
||||
compose = tv_transforms.Compose
|
||||
dali_server = None
|
||||
if transform is None:
|
||||
if args.augment_engine == "torchvision":
|
||||
if args.augment_strategy == "3-augment":
|
||||
transform = three_augment(args, as_list=False, test=not train)
|
||||
elif args.augment_strategy == "differentiable-transform":
|
||||
from data.distilled_dataset import differentiable_augment
|
||||
|
||||
transform = differentiable_augment(args, as_list=False, test=not train)
|
||||
elif args.augment_strategy == "none":
|
||||
transform = []
|
||||
elif args.augment_strategy == "lm_one_hot":
|
||||
transform = [
|
||||
tv_transforms.Grayscale(num_output_channels=1),
|
||||
tv_transforms.ToTensor(),
|
||||
ToOneHotSequence(),
|
||||
]
|
||||
elif args.augment_strategy == "segment-augment":
|
||||
transform = segment_augment(args, test=not train)
|
||||
elif args.augment_strategy == "minimal":
|
||||
transform = minimal_augment(args, test=not train)
|
||||
elif args.augment_strategy == "deit":
|
||||
if train:
|
||||
transform = create_transform(
|
||||
input_size=args.imsize,
|
||||
is_training=True,
|
||||
color_jitter=args.aug_color_jitter_factor,
|
||||
auto_augment=args.auto_augment_strategy,
|
||||
interpolation="bicubic",
|
||||
re_prob=args.aug_random_erase_prob,
|
||||
re_mode=args.aug_random_erase_mode,
|
||||
re_count=args.aug_random_erase_count,
|
||||
)
|
||||
else:
|
||||
transform = three_augment(args, test=True) # only do resize, centercrop, and normalize
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Augmentation strategy {args.augment_strategy} is not implemented for {args.augment_engine} (yet)."
|
||||
)
|
||||
elif args.augment_engine == "albumentations":
|
||||
from data import album_transf as ATf
|
||||
|
||||
compose = ATf.AlbumTorchCompose
|
||||
|
||||
if args.augment_strategy == "3-augment":
|
||||
transform = ATf.three_augment(args, as_list=False, test=not train)
|
||||
elif args.augment_strategy == "minimal":
|
||||
transform = ATf.minimal_augment(args, test=not train)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Augmentation strategy {args.augment_strategy} is not implemented for {args.augment_engine} (yet)."
|
||||
)
|
||||
elif args.augment_engine == "dali":
|
||||
from nvidia.dali.plugin.pytorch.experimental import proxy as dali_proxy
|
||||
|
||||
from data import dali_transf as DTf
|
||||
|
||||
dev_id = int(os.environ.get("LOCAL_RANK", 0))
|
||||
|
||||
if args.augment_strategy == "3-augment":
|
||||
pipe = DTf.three_augment(
|
||||
args,
|
||||
test=not train,
|
||||
batch_size=args.batch_size,
|
||||
num_threads=args.num_workers,
|
||||
device_id=dev_id,
|
||||
)
|
||||
elif args.augment_strategy == "minimal":
|
||||
pipe = DTf.minimal_augment(
|
||||
args,
|
||||
test=not train,
|
||||
batch_size=args.batch_size,
|
||||
num_threads=args.num_workers,
|
||||
device_id=dev_id,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Augmentation strategy {args.augment_strategy} is not implemented for {args.augment_engine} (yet)."
|
||||
)
|
||||
|
||||
dali_server = dali_proxy.DALIServer(pipe)
|
||||
transform = dali_server.proxy
|
||||
|
||||
dataset_name_case_sensitive = dataset_name # keep the original name for AnimalNet folder
|
||||
dataset_name = dataset_name.lower()
|
||||
ignore_index = -100
|
||||
multi_label = False
|
||||
|
||||
if isinstance(transform, list):
|
||||
transform = compose(transform)
|
||||
|
||||
if dataset_name == "cifar10":
|
||||
dataset = CIFAR10(root=ds_path("cifar"), train=train, download=False, transform=transform)
|
||||
n_classes, collate = 10, None
|
||||
|
||||
elif dataset_name == "stanford-cars":
|
||||
dataset = StanfordCars(
|
||||
root=ds_path("stanford_cars"),
|
||||
split="train" if train else "test",
|
||||
download=False,
|
||||
transform=transform,
|
||||
)
|
||||
n_classes, collate = 196, None
|
||||
|
||||
elif dataset_name == "oxford-pet":
|
||||
dataset = OxfordIIITPet(
|
||||
root=ds_path("oxford_pet"),
|
||||
split="trainval" if train else "test",
|
||||
download=False,
|
||||
transform=transform,
|
||||
)
|
||||
n_classes, collate = 37, None
|
||||
|
||||
elif dataset_name == "flowers102":
|
||||
dataset = Flowers102(
|
||||
root=ds_path("flowers102"),
|
||||
split="train" if train else "test",
|
||||
download=False,
|
||||
transform=transform,
|
||||
)
|
||||
n_classes, collate = 102, None
|
||||
|
||||
elif dataset_name == "food-101":
|
||||
dataset = Food101(
|
||||
root=ds_path("food101"),
|
||||
split="train" if train else "test",
|
||||
download=False,
|
||||
transform=transform,
|
||||
)
|
||||
n_classes, collate = 101, None
|
||||
|
||||
elif dataset_name == "fgvc-aircraft":
|
||||
dataset = FGVCAircraft(
|
||||
root=ds_path("aircraft"),
|
||||
split="train" if train else "test",
|
||||
annotation_level="variant",
|
||||
download=False,
|
||||
transform=transform,
|
||||
)
|
||||
n_classes, collate = 100, None
|
||||
|
||||
elif dataset_name == "imagenet":
|
||||
dataset = ImageFolder(os.path.join(ds_path("imagenet1k"), "train" if train else "val"), transform=transform)
|
||||
n_classes, collate = 1000, None
|
||||
|
||||
elif dataset_name == "tinyimagenet":
|
||||
dataset = ImageFolder(os.path.join(ds_path("tinyimagenet"), "train" if train else "val"), transform=transform)
|
||||
n_classes, collate = 200, None
|
||||
|
||||
elif dataset_name.startswith("fornet"):
|
||||
ds_def = dataset_name.split("/")
|
||||
comb_scheme = ds_def[1] if len(ds_def) > 1 else "same"
|
||||
pruning_ratio = 0.8 if len(ds_def) < 3 else float(ds_def[2])
|
||||
fg_size_mode = ("range" if train else "max") if len(ds_def) < 4 else ds_def[3]
|
||||
paste_pre_transform = True if len(ds_def) < 5 else ds_def[4] in ["y", "t"]
|
||||
orig_img_prob = (
|
||||
0.0 if len(ds_def) < 6 else (ds_def[5] if ds_def[5] in ["linear", "revlinear", "cos"] else float(ds_def[5]))
|
||||
)
|
||||
mask_smoothing_sigma = 0.0 if len(ds_def) < 7 else float(ds_def[6])
|
||||
assert len(ds_def) < 5 or ds_def[4] in [
|
||||
"y",
|
||||
"t",
|
||||
"n",
|
||||
"f",
|
||||
], f"Invalid dataset definition: {ds_def[4]}; paste pre transform must be 'y'/'t' or 'n'/'f'"
|
||||
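# Decoding example (hypothetical string): "fornet/same/0.5/max/n/0.2/1.0"
# gives comb_scheme="same", pruning_ratio=0.5, fg_size_mode="max",
# paste_pre_transform=False, orig_img_prob=0.2, mask_smoothing_sigma=1.0;
# omitted trailing fields fall back to the defaults above.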
|
||||
orig_ds = ds_path("imagenet1k")
|
||||
|
||||
dataset = ForNet(
|
||||
ds_path("fornet"),
|
||||
train=train,
|
||||
background_combination=comb_scheme,
|
||||
pruning_ratio=pruning_ratio,
|
||||
transform=transform,
|
||||
fg_transform=(
|
||||
None if args.aug_rand_rot == 0 else tv_transforms.RandomRotation(args.aug_rand_rot, expand=True)
|
||||
),
|
||||
fg_size_mode=fg_size_mode,
|
||||
paste_pre_transform=paste_pre_transform,
|
||||
orig_img_prob=orig_img_prob,
|
||||
orig_ds=orig_ds,
|
||||
mask_smoothing_sigma=mask_smoothing_sigma,
|
||||
epochs=args.epochs,
|
||||
_album_compose=args.augment_engine == "albumentations",
|
||||
)
|
||||
n_classes, collate = 1000, None
|
||||
|
||||
elif dataset_name.startswith("tinyfornet"):
|
||||
ds_def = dataset_name.split("/")
|
||||
comb_scheme = ds_def[1] if len(ds_def) > 1 else "same"
|
||||
pruning_ratio = 1.1 if len(ds_def) < 3 else float(ds_def[2])
|
||||
fg_size_mode = "range" if len(ds_def) < 4 else ds_def[3]
|
||||
fg_bates_n = 1 if len(ds_def) < 5 else int(ds_def[4])
|
||||
paste_pre_transform = False if len(ds_def) < 6 else ds_def[5] in ["y", "t"]
|
||||
orig_img_prob = (
|
||||
0.0 if len(ds_def) < 7 else (ds_def[6] if ds_def[6] in ["linear", "revlinear", "cos"] else float(ds_def[6]))
|
||||
)
|
||||
mask_smoothing_sigma = 0.0 if len(ds_def) < 8 else float(ds_def[7])
|
||||
assert len(ds_def) < 6 or ds_def[5] in [
|
||||
"y",
|
||||
"t",
|
||||
"n",
|
||||
"f",
|
||||
], f"Invalid dataset definition: {ds_def[5]}; paste pre transform must be 'y'/'t' or 'n'/'f'"
|
||||
assert "-" not in ds_def[0] or len(ds_def[0].split("-")) == 3, f"Invalid dataset definition: {ds_def[0]}"
|
||||
version = "" if "-" not in ds_def[0] else f"_v{ds_def[0].split('-')[1]}_f{ds_def[0].split('-')[2]}"
|
||||
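# Decoding example (hypothetical string): "tinyfornet-2-1/all/0.9/range/2/y/cos/0.5"
# gives version="_v2_f1", comb_scheme="all", pruning_ratio=0.9,
# fg_size_mode="range", fg_bates_n=2, paste_pre_transform=True,
# orig_img_prob="cos" (a schedule name, not a probability), and
# mask_smoothing_sigma=0.5.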
|
||||
orig_ds = ds_path("tinyimagenet")
|
||||
|
||||
dataset = ForNet(
|
||||
f"{ds_path('tinyimagenet')}{version}",
|
||||
train=train,
|
||||
background_combination=comb_scheme,
|
||||
pruning_ratio=pruning_ratio,
|
||||
transform=transform,
|
||||
fg_transform=(
|
||||
None if args.aug_rand_rot == 0 else tv_transforms.RandomRotation(args.aug_rand_rot, expand=True)
|
||||
),
|
||||
fg_size_mode=fg_size_mode,
|
||||
fg_bates_n=fg_bates_n,
|
||||
paste_pre_transform=paste_pre_transform,
|
||||
orig_img_prob=orig_img_prob,
|
||||
orig_ds=orig_ds,
|
||||
mask_smoothing_sigma=mask_smoothing_sigma,
|
||||
epochs=args.epochs,
|
||||
_album_compose=args.augment_engine == "albumentations",
|
||||
)
|
||||
n_classes, collate = 200, None
|
||||
|
||||
elif dataset_name.startswith("counteranimal/"):
|
||||
mode = dataset_name.split("/")[1]
|
||||
|
||||
dataset = CounterAnimal(ds_path("counteranimal"), mode=mode, transform=transform, train=train)
|
||||
n_classes, collate = 1000, None
|
||||
|
||||
elif dataset_name.startswith("imagenet9/"):
|
||||
variant = dataset_name.split("/")[1]
|
||||
assert variant in [
|
||||
"next",
|
||||
"same",
|
||||
"rand",
|
||||
], f"ImageNet-9 has possible variants next, same, and rand, but not '{variant}'."
|
||||
|
||||
dataset = ImageFolder(os.path.join(ds_path("imagenet9"), f"mixed_{variant}", "val"), transform=transform)
|
||||
n_classes, collate = 9, None
|
||||
|
||||
else:
|
||||
raise NotImplementedError(f"Dataset {dataset_name} is not implemented (yet).")
|
||||
|
||||
if args.aug_repeated_augment_repeats > 1 and train:
|
||||
# use repeated augment sampler from DeiT
|
||||
sampler = RASampler(
|
||||
dataset,
|
||||
num_replicas=args.world_size,
|
||||
rank=rank,
|
||||
shuffle=args.shuffle,
|
||||
num_repeats=args.aug_repeated_augment_repeats,
|
||||
)
|
||||
elif args.weighted_sampler:
|
||||
assert hasattr(
|
||||
dataset, "per_sample_weights"
|
||||
), f"Dataset {type(dataset)} should implement per_sample_weights function, but does not."
|
||||
|
||||
sampler = WeightedRandomSampler(dataset.per_sample_weights(), num_samples=len(dataset) // args.world_size)
|
||||
elif args.distributed:
|
||||
sampler = DistributedSampler(dataset, num_replicas=args.world_size, rank=rank, shuffle=train and args.shuffle)
|
||||
else:
|
||||
sampler = None
|
||||
|
||||
loader_batch_size = 1 if dataset_name.startswith("listops") else args.batch_size
|
||||
|
||||
loader_kwargs = dict(
|
||||
batch_size=loader_batch_size,
|
||||
pin_memory=args.pin_memory,
|
||||
num_workers=args.num_workers,
|
||||
drop_last=train,
|
||||
prefetch_factor=args.prefetch_factor if args.num_workers > 0 else None,
|
||||
persistent_workers=False,
|
||||
collate_fn=collate,
|
||||
shuffle=None if sampler else train and args.shuffle,
|
||||
sampler=sampler,
|
||||
)
|
||||
|
||||
if args.augment_engine == "dali":
|
||||
data_loader = dali_proxy.DataLoader(dali_server, dataset, **loader_kwargs)
|
||||
else:
|
||||
data_loader = DataLoader(dataset, **loader_kwargs)
|
||||
|
||||
return data_loader, n_classes, ignore_index, multi_label, dali_server
|
||||
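A minimal sketch of calling `prepare_dataset` directly (not part of the repo; the `args` namespace normally comes from `main.py`'s parser, and the field list below is illustrative, since `minimal_augment` may read further `aug_*` flags):

```python
from types import SimpleNamespace

from load_dataset import prepare_dataset

args = SimpleNamespace(
    augment_engine="torchvision", augment_strategy="minimal", imsize=32,
    batch_size=128, num_workers=4, pin_memory=True, prefetch_factor=2,
    shuffle=True, distributed=False, weighted_sampler=False,
    aug_repeated_augment_repeats=1, world_size=1,
)
# assumes the dataset paths in paths_config.py are filled in
loader, n_classes, ignore_idx, multi_label, dali_server = prepare_dataset("cifar10", args, train=True)
```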
679
AAAI Supplementary Material/Model Training Code/main.py
Normal file
@@ -0,0 +1,679 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
"""Parse args and call the correct script inside slurm container.
|
||||
|
||||
Outside the container, on the head-node, create and call the correct srun command.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import subprocess
|
||||
from datetime import datetime
|
||||
|
||||
from config import default_kwargs, slurm_defaults
|
||||
from paths_config import results_folder, slurm_output_folder
|
||||
|
||||
_EXPNAMES = ["EfficientCVBench", "test", "recombine_imagenet"]
|
||||
|
||||
|
||||
def base_parser():
|
||||
"""Create the argument parser with all the choices for the training / evaluation scripts."""
|
||||
parser = argparse.ArgumentParser("Transformer training and evaluation.")
|
||||
|
||||
# Main
|
||||
group = parser.add_argument_group("Main")
|
||||
group.add_argument(
|
||||
"-t",
|
||||
"--task",
|
||||
nargs="?",
|
||||
choices=[
|
||||
"pre-train",
|
||||
"fine-tune",
|
||||
"fine-tune-head",
|
||||
"eval",
|
||||
"parser-test",
|
||||
"eval-metrics",
|
||||
"eval-attr",
|
||||
"continue",
|
||||
"eval-center-bias",
|
||||
"eval-size-bias",
|
||||
"load-images",
|
||||
"save-images",
|
||||
],
|
||||
required=True,
|
||||
help="Task to perform.",
|
||||
)
|
||||
group.add_argument(
|
||||
"-m",
|
||||
"--model",
|
||||
nargs="?",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Model to use. Either model name for a new model or weights and dicts to load for fine-tuning.",
|
||||
)
|
||||
group.add_argument("-ds", "--dataset", nargs="?", type=str, help="Dataset to train on.")
|
||||
group.add_argument(
|
||||
"-valds", "--val-dataset", nargs="?", type=str, help="Validation dataset. Defaults to same as training."
|
||||
)
|
||||
group.add_argument("-ep", "--epochs", nargs="?", type=int, help="Number of epochs to train.")
|
||||
group.add_argument(
|
||||
"-run",
|
||||
"--run-name",
|
||||
nargs="?",
|
||||
type=str,
|
||||
help="A name for the run. If not given, the model name is used instead.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--defaults", nargs="?", choices=["DeiT", "DeiTIII"], default="DeiTIII", help="Default settings to use."
|
||||
)
|
||||
|
||||
# Further model parameters
|
||||
group = parser.add_argument_group("Further model parameters")
|
||||
group.add_argument("--drop-path-rate", nargs="?", type=float, help="Drop path rate for ViT models.")
|
||||
group.add_argument("--layer-scale-init-values", nargs="?", type=float, help="LayerScale initial values.")
|
||||
group.add_argument("--layer-scale", action=argparse.BooleanOptionalAction, help="Use layer scale?")
|
||||
group.add_argument(
|
||||
"--qkv-bias",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Use bias in linear transformation to queries, keys, and values?",
|
||||
)
|
||||
group.add_argument("--pre-norm", action=argparse.BooleanOptionalAction, help="Use norm first architecture?")
|
||||
group.add_argument("--dropout", nargs="?", type=float, help="Model dropout.")
|
||||
group.add_argument("-heads", "--num-heads", nargs="?", type=int, help="Number of parallel attention heads.")
|
||||
group.add_argument("--input-dim", nargs="?", type=int, help="Dimensionality of text encoding.")
|
||||
group.add_argument("--max-seq-len", nargs="?", type=int, help="Maximum sequence length for text data.")
|
||||
group.add_argument(
|
||||
"--fused-attn",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Use fused attention (for ViT with Timm's attention only)?",
|
||||
)
|
||||
# group.add_argument(
|
||||
# "--perf-metric", nargs="?", choices=["acc", "mIoU"], help="Performance metric to use for evaluation."
|
||||
# )
|
||||
# group.add_argument("-no_model_ema", action="store_true",
|
||||
# help="Don't use an exponential moving average for model parameters")
|
||||
# group.add_argument("-model_ema_decay", nargs='?', type=float, default=default_kwargs["model_ema_decay"],
|
||||
# help="Decay rate for exponential moving average of model parameters")
|
||||
|
||||
# Experiment management
|
||||
group = parser.add_argument_group("Experiment management")
|
||||
group.add_argument("--seed", nargs="?", type=int, help="Manual RNG seed.")
|
||||
group.add_argument(
|
||||
"-exp",
|
||||
"--experiment-name",
|
||||
nargs="?",
|
||||
choices=_EXPNAMES,
|
||||
help="Name for the experiment. Is used for grouping of runs.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--save-epochs", nargs="?", type=int, help="Number of epochs after which to save the full training state."
|
||||
)
|
||||
group.add_argument(
|
||||
"--keep-interm-states",
|
||||
nargs="?",
|
||||
type=int,
|
||||
help="Number of intermediate states to keep. All others (earlier ones) will be deleted automatically.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--custom-dataset-path", nargs="?", type=str, help="Overwrite the path to any dataset to this path."
|
||||
)
|
||||
group.add_argument(
|
||||
"--results-folder",
|
||||
nargs="?",
|
||||
default=results_folder,
|
||||
type=str,
|
||||
help="Folder to put script results (mlflow data, models, etc.).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--gather-stats-during-training",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Gather training statistics from all GPUs?",
|
||||
)
|
||||
group.add_argument("--tqdm", action=argparse.BooleanOptionalAction, help="Show tqdm for every epoch?")
|
||||
group.add_argument(
|
||||
"--debug", action=argparse.BooleanOptionalAction, help="Debug mode: lots of intermediate prints."
|
||||
)
|
||||
group.add_argument("--wandb", action=argparse.BooleanOptionalAction, help="Use external logging via Wandb?")
|
||||
group.add_argument("--log-level", choices=["info", "debug"], help="Log level", metavar="LEVEL")
|
||||
group.add_argument("-out", "--out-dir", type=str, help="Output directory for additional outputs.")
|
||||
|
||||
# Speedup
|
||||
group = parser.add_argument_group("Speedup")
|
||||
group.add_argument("--amp", action=argparse.BooleanOptionalAction, help="Use automatic mixed precision?")
|
||||
group.add_argument(
|
||||
"--eval-amp", action=argparse.BooleanOptionalAction, help="Use automatic mixed precision during evaluation?"
|
||||
)
|
||||
group.add_argument("--compile-model", action=argparse.BooleanOptionalAction, help="Use torch.compile?")
|
||||
group.add_argument("--cuda", action=argparse.BooleanOptionalAction, help="Use cuda?")
|
||||
|
||||
# Data loading
|
||||
group = parser.add_argument_group("Data loading")
|
||||
group.add_argument("-bs", "--batch-size", nargs="?", type=int, help="Batch size over all GPUs together.")
|
||||
group.add_argument("--num-workers", nargs="?", type=int, help="Number of dataloader worker threads. Should be >0.")
|
||||
group.add_argument(
|
||||
"--pin-memory", action=argparse.BooleanOptionalAction, help="Use pin_memory of torch Dataloader?"
|
||||
)
|
||||
group.add_argument(
|
||||
"--prefetch-factor",
|
||||
nargs="?",
|
||||
type=int,
|
||||
help="Prefetch factor for dataloader workers (how many batches to fetch)",
|
||||
)
|
||||
group.add_argument("--shuffle", action=argparse.BooleanOptionalAction, help="Shuffle the training data?")
|
||||
group.add_argument(
|
||||
"--weighted-sampler",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Use a class-weighted sampler to sample evenly from all classes (train and val)?",
|
||||
)
|
||||
group.add_argument("--ipc", type=int, help="How many images per class to load and save.")
|
||||
|
||||
# Optimizer
|
||||
group = parser.add_argument_group("Optimizer")
|
||||
group.add_argument("--opt", nargs="?", type=str, help="Optimizer to use.")
|
||||
group.add_argument("--weight-decay", nargs="?", type=float, help="Weight decay factor for use in optimizer.")
|
||||
group.add_argument("-lr", "--lr", nargs="?", type=float, help="Initial learning rate.")
|
||||
group.add_argument(
|
||||
"--max-grad-norm", nargs="?", type=float, help="Maximum norm for the gradients (used for cutoff)."
|
||||
)
|
||||
group.add_argument("--warmup-epochs", nargs="?", type=int, help="Number of epochs of linear warmup.")
|
||||
group.add_argument("--label-smoothing", nargs="?", type=float, help="Label smoothing factor.")
|
||||
group.add_argument("--loss", nargs="?", choices=["ce", "baikal"], type=str, help="Loss function to use.")
|
||||
group.add_argument(
|
||||
"--loss-weight", nargs="?", type=str, choices=["none", "linear", "log", "sqrt"], help="Per class loss weight."
|
||||
)
|
||||
group.add_argument("--sched", nargs="?", choices=["cosine", "const"], help="Learning rate schedule.")
|
||||
group.add_argument("--min-lr", nargs="?", type=float, help="Minimum learning rate to be hit by scheduler.")
|
||||
group.add_argument("--warmup-lr", nargs="?", type=float, help="Warmup learning rate.")
|
||||
group.add_argument("--warmup-sched", nargs="?", choices=["linear", "const"], help="Schedule for warmup")
|
||||
group.add_argument(
|
||||
"--opt-eps", nargs="?", type=float, help="Epsilon value added in the optimizer to stabilize training."
|
||||
)
|
||||
group.add_argument("--momentum", nargs="?", type=float, help="Optimizer momentum.")
|
||||
|
||||
# Data augmentation
|
||||
group = parser.add_argument_group("Data augmentation")
|
||||
group.add_argument("--augment-strategy", nargs="?", type=str, help="Data augmentation strategy.")
|
||||
group.add_argument("--aug-rand-rot", nargs="?", type=int, help="Random rotation limit.")
|
||||
group.add_argument(
|
||||
"--aug-flip", action=argparse.BooleanOptionalAction, help="Use data augmentation: horizontal flip?"
|
||||
)
|
||||
group.add_argument(
|
||||
"--aug-crop",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Use data augmentation: cropping? This may break the script.",
|
||||
)
|
||||
group.add_argument("--aug-resize", action=argparse.BooleanOptionalAction, help="Use data augmentation: resize?")
|
||||
group.add_argument(
|
||||
"--aug-grayscale", action=argparse.BooleanOptionalAction, help="Use data augmentation: grayscale?"
|
||||
)
|
||||
group.add_argument("--aug-solarize", action=argparse.BooleanOptionalAction, help="Use data augmentation: solarize?")
|
||||
group.add_argument(
|
||||
"--aug-gauss-blur", action=argparse.BooleanOptionalAction, help="Use data augmentation: gaussian blur?"
|
||||
)
|
||||
group.add_argument(
|
||||
"--aug-cutmix-alpha",
|
||||
type=float,
|
||||
help="Alpha value for using CutMix. CutMix is active when aug_cutmix_alpha > 0.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--aug-mixup-alpha", type=float, help="Alpha value for using Mixup. Mixup is active when aug_mixup_alpha > 0."
|
||||
)
|
||||
group.add_argument(
|
||||
"--aug-color-jitter-factor",
|
||||
nargs="?",
|
||||
type=float,
|
||||
help="Factor to use for the data augmentation: color jitter.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--aug-normalize", action=argparse.BooleanOptionalAction, help="Use data augmentation: Normalization?"
|
||||
)
|
||||
group.add_argument(
|
||||
"--aug-repeated-augment-repeats",
|
||||
type=int,
|
||||
help="Number of image repeats with repeat-augment from DeiT. 1 is not using repeat-augment.",
|
||||
)
|
||||
group.add_argument(
"--aug-random-erase-prob", type=float, help="For DeiT augment: Probability of RandomErase augmentation."
|
||||
)
|
||||
group.add_argument("--auto-augment-strategy", type=str, help="For DeiT augment: AutoAugment Policy to use.")
|
||||
group.add_argument("--imsize", nargs="?", type=int, help="Image size given to the model -> imsize x imsize.")
|
||||
group.add_argument(
|
||||
"--augment-engine",
|
||||
nargs="?",
|
||||
choices=["torchvision", "albumentations", "dali"],
|
||||
help="Which data augmentation engine to use.",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def partition_choices():
|
||||
"""Automatically create a list of all possible slurm partitions."""
|
||||
potential = list(set([l.split(" ")[0] for l in os.popen("sinfo")])) # noqa: E741
|
||||
if len(potential) <= 2:
|
||||
return slurm_defaults["partition"]
|
||||
return [p[:-1] if "*" in p else p for p in potential if p != "PARTITION"]
|
||||
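# Example: if `sinfo` lists partitions "batch*" (the default), "RTXA6000" and
# "A100-40GB", this returns ["batch", "RTXA6000", "A100-40GB"] (order not
# guaranteed, since a set is used). With two or fewer distinct first tokens,
# i.e. sinfo unavailable or nearly empty, the slurm_defaults fallback is used.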
|
||||
|
||||
def slurm_parser(parser=None):
|
||||
"""Add srun arguments to the given parser.
|
||||
|
||||
Args:
|
||||
parser (argparse.ArgumentParser, optional): base parser to extend; default is parser from *base_parser*
|
||||
|
||||
Returns:
|
||||
parser (argparse.ArgumentParser): extended parser
|
||||
|
||||
"""
|
||||
if parser is None:
|
||||
parser = base_parser()
|
||||
group = parser.add_argument_group("Slurm arguments")
|
||||
group.add_argument(
|
||||
"--partition",
|
||||
nargs="*",
|
||||
default=slurm_defaults["partition"],
|
||||
choices=partition_choices(),
|
||||
help="Slurm partition to use",
|
||||
)
|
||||
group.add_argument(
|
||||
"--container-image",
|
||||
nargs="?",
|
||||
default=slurm_defaults["container_image"],
|
||||
type=str,
|
||||
help="Path to slurm container image (.sqsh)",
|
||||
)
|
||||
group.add_argument(
|
||||
"--container-workdir",
|
||||
nargs="?",
|
||||
default=slurm_defaults["container_workdir"],
|
||||
type=str,
|
||||
help="Working directory in container",
|
||||
)
|
||||
group.add_argument(
|
||||
"--container-mounts",
|
||||
nargs="?",
|
||||
default=slurm_defaults["container_mounts"],
|
||||
type=str,
|
||||
help="All slurm mounts separated by ','.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--job-name",
|
||||
nargs="?",
|
||||
default=slurm_defaults["job_name"],
|
||||
type=str,
|
||||
help="Slurm job name. Defaults to the run name, i.e. '<task> <model>' if no run name is given.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--nodes", nargs="?", default=slurm_defaults["nodes"], type=int, help="Number of cluster nodes to use."
|
||||
)
|
||||
group.add_argument(
|
||||
"--ntasks", nargs="?", default=slurm_defaults["ntasks"], type=int, help="Number of GPUs to use for the job."
|
||||
)
|
||||
group.add_argument("--gpus", action=argparse.BooleanOptionalAction, default=True, help="Use gpus for this job?")
|
||||
group.add_argument(
|
||||
"-cpus",
|
||||
"--cpus-per-task",
|
||||
"--cpus-per-gpu",
|
||||
nargs="?",
|
||||
default=slurm_defaults["cpus_per_task"],
|
||||
type=int,
|
||||
help="Number of CPUs per task/GPU.",
|
||||
)
|
||||
group.add_argument(
|
||||
"-mem",
|
||||
"--mem-per-gpu",
|
||||
"--mem-per-task",
|
||||
nargs="?",
|
||||
default=slurm_defaults["mem_per_gpu"],
|
||||
type=int,
|
||||
help="RAM per GPU (in GB) to use. Will be passed as total memory in the srun command.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--task-prolog",
|
||||
nargs="?",
|
||||
default=slurm_defaults["task_prolog"],
|
||||
type=str,
|
||||
help="Shell script for task prolog (installing packages, etc.).",
|
||||
)
|
||||
group.add_argument("--time", nargs="?", default=slurm_defaults["time"], type=str, help="Slurm time limit.")
|
||||
group.add_argument(
|
||||
"--export",
|
||||
nargs="?",
|
||||
default=slurm_defaults["export"],
|
||||
type=str,
|
||||
help="Additional environment variables to export.",
|
||||
)
|
||||
group.add_argument("--exclude", nargs="?", default=slurm_defaults["exclude"], type=str, help="Nodes to exclude.")
|
||||
group.add_argument(
|
||||
"--after-job", nargs="?", default=slurm_defaults["after_job"], type=int, help="Job ID to wait for."
|
||||
)
|
||||
group.add_argument(
|
||||
"--interactive",
|
||||
action="store_true",
|
||||
help=(
|
||||
"Run using srun instead of sbatch. This will print the output into the terminal, not the slurm output file."
|
||||
" The logfile will still be created as usual."
|
||||
),
|
||||
default=False,
|
||||
)
|
||||
|
||||
group = parser.add_argument_group("Run locally")
|
||||
group.add_argument("--local", action="store_true", help="Run locally; not in slurm", default=False)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def parse_args(args=None, parser=None):
|
||||
"""Parse args from *base_parser* and insert defaults.
|
||||
|
||||
Args:
|
||||
args: (Default value = None)
|
||||
parser: (Default value = None)
|
||||
|
||||
Returns:
|
||||
dict: parsed arguments
|
||||
|
||||
"""
|
||||
if args is None:
|
||||
parser = base_parser()
|
||||
args = parser.parse_args()
|
||||
args = dict(vars(args))
|
||||
|
||||
check_arg_completeness(args, parser)
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def check_arg_completeness(args, parser):
|
||||
"""Check completeness of arguments.
|
||||
|
||||
Args:
|
||||
args (dict): arguments to check
|
||||
parser (argparse.ArgumentParser): for raising the parser error
|
||||
Note:
|
||||
will raise a parser error if the arguments are not complete.
|
||||
"""
|
||||
if args["task"] in ["pre-train", "fine-tune", "fine-tune-head"]:
|
||||
if "run_name" not in args or args["run_name"] is None or len(args["run_name"]) == 0:
|
||||
parser.error(f"-run_name is required for task {args['task']}")
|
||||
|
||||
if "experiment_name" not in args or args["experiment_name"] is None or len(args["experiment_name"]) == 0:
|
||||
parser.error(f"-experiment_name is required for task {args['task']}. Choose from {_EXPNAMES}")
|
||||
|
||||
if "epochs" not in args or args["epochs"] is None:
|
||||
parser.error(f"-epochs is required for task {args['task']}")
|
||||
|
||||
if ("dataset" not in args or args["dataset"] is None) and args["task"] in [
|
||||
"pre-train",
|
||||
"fine-tune",
|
||||
"fine-tune-head",
|
||||
"eval-metrics",
|
||||
]:
|
||||
parser.error(f"-dataset is required for task {args['task']}")
|
||||
|
||||
if (
|
||||
("val_dataset" not in args or args["val_dataset"] is None)
|
||||
and ("dataset" not in args or args["dataset"] is None)
|
||||
and args["task"] in ["eval"]
|
||||
):
|
||||
parser.error(f"-dataset or -val_dataset is required for task {args['task']}")
|
||||
|
||||
if args["aug_repeated_augment_repeats"] is not None and args["aug_repeated_augment_repeats"] < 1:
|
||||
parser.error(
|
||||
"number of repeats for repeated augment has to be >= 1, but got -aug_repeated_augment_repeats ="
|
||||
f" {args['aug_repeated_augment_repeats']}"
|
||||
)
|
||||
|
||||
if args["task"] == "save-images" and ("out_dir" not in args or args["out_dir"] is None):
|
||||
parser.error("Need to set save directory (--out-dir) to save the images in.")
|
||||
|
||||
|
||||
def inside_slurm():
|
||||
"""Test for being inside a slurm container.
|
||||
|
||||
Works by testing for environment variable 'RANK'.
|
||||
"""
|
||||
return "RANK" in os.environ
|
||||
|
||||
|
||||
# TODO: fix ./runscript.tmp: 18: Syntax error: Unterminated quoted string
|
||||
def create_runscript(args, file_name=None):
|
||||
"""Create a run script for a distributed training job using SLURM.
|
||||
|
||||
Args:
|
||||
args (dict): A dictionary containing various arguments for the job, including parameters for SLURM and for training.
|
||||
file_name (str, optional): The name of the file to create. Defaults to a timestamped name under "experiments/sbatch/".
|
||||
|
||||
Returns:
|
||||
str: The name of the created file.
|
||||
str: Additional command line arguments for sbatch.
|
||||
|
||||
Example:
|
||||
>>> args = {"model": "vit_large_patch16_384", "task": "pre-train", "batch_size": 256, ...}
|
||||
|
||||
>>> file_name = "my_run_script.sh"
|
||||
|
||||
>>> create_runscript(args, file_name)
|
||||
|
||||
"""
|
||||
for key, val in slurm_defaults.items():
|
||||
if key not in args and val is not None:
|
||||
args[key] = val
|
||||
|
||||
if "run_name" not in args or args["run_name"] is None:
|
||||
model_str = args["model"]
|
||||
if model_str.endswith(".pt"):
|
||||
model_str = os.path.dirname(model_str)
|
||||
run_name = args["task"] + " " + model_str.split(os.sep)[-1].split("_")[0]
|
||||
else:
|
||||
run_name = args["run_name"]
|
||||
job_name = run_name.replace(" ", "_").replace("/", "_").replace(">", "_").replace("<", "_")
|
||||
if file_name is None:
|
||||
file_name = (
|
||||
f"experiments/sbatch/run_{args['task']}_{job_name}_at_{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}.sbatch"
|
||||
)
|
||||
|
||||
task_args = ""
|
||||
# slurm_command = "echo run distributed:\necho python3 main.py {0}\n\nsrun -K \\\n" # " --gpus-per-task=1 \\\n --gpu-bind=none \\\n"
|
||||
srun_command = "\nsrun -K \\\n"
|
||||
sbatch_commands = ( # outfile name is job name, date, job id, node name
|
||||
"#!/bin/bash\n\n#SBATCH"
|
||||
f" --output={slurm_output_folder}/%x-{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}-%j-%N.out\n"
|
||||
)
|
||||
sbatch_cmd_args = "" # additional command line arguments for sbatch
|
||||
python_command = " python3 main.py {0}\n"
|
||||
for key, val in args.items():
|
||||
if key == "local":
|
||||
continue
|
||||
if key == "interactive":
|
||||
continue
|
||||
if key == "gpus":
|
||||
continue
|
||||
if key in slurm_defaults:
|
||||
# it's a parameter for srun
|
||||
# slurm has - instead of _
|
||||
key = key.replace("_", "-")
|
||||
if key == "mem-per-gpu":
|
||||
# convert mem-per-gpu to mem
|
||||
# slurm_command += f" --mem={val * args['ntasks'] // args['nodes']}G \\\n" # that amount of memory is assigned on each node
|
||||
key = "mem"
|
||||
val = f"{val * args['ntasks'] // args['nodes']}G" # that amount of memory is assigned on each node
|
||||
# continue
|
||||
if key == "job-name" and val is None:
|
||||
# # default jobname is '<task> <model> <dataset>'
|
||||
# model_str = args["model"]
|
||||
# task = args["task"]
|
||||
# if task == "pre-train":
|
||||
# # it's just the model name...
|
||||
# model = model_str.split("_")[0]
|
||||
# else:
|
||||
# # it's a path to the tar file
|
||||
# if not model_str.startswith(res_folder):
|
||||
# model = "<vit model>"
|
||||
# else:
|
||||
# model = model_str[len(res_folder) :].split("_")[1].split(" ")[0]
|
||||
# if "dataset" in args and args["dataset"] is not None:
|
||||
# dataset = args["dataset"]
|
||||
# else:
|
||||
# dataset = ""
|
||||
val = run_name
|
||||
if key == "job-name" and not val.startswith('"'):
|
||||
val = f'"{val}"'
|
||||
if key in ["task-prolog", "nodes", "exclude", "after-job"] and val is None:
|
||||
continue
|
||||
if key == "task-prolog":
|
||||
srun_command += f' --{key}="{val}" \\\n'
|
||||
continue
|
||||
if key == "after-job":
|
||||
sbatch_cmd_args += f"--dependency=afterany:{val} "
|
||||
continue
|
||||
if key == "partition" and isinstance(val, list):
|
||||
val = ",".join(val)
|
||||
if key == "ntasks":
|
||||
if args["nodes"] == 1:
|
||||
gpus = val if args["gpus"] else 0
|
||||
# slurm_command += f" --gpus={val} \\\n"
|
||||
sbatch_commands += f"#SBATCH --gpus={gpus}\n"
|
||||
else:
|
||||
assert (
|
||||
val % args["nodes"] == 0
|
||||
), f"Number of tasks ({val}) must be a multiple of the number of nodes ({args['nodes']})."
|
||||
# slurm_command += f" --gpus-per-node={val // args['nodes']} \\\n"
|
||||
sbatch_commands += f"#SBATCH --gpus-per-node={val // args['nodes']}\n"
|
||||
sbatch_commands += "#SBATCH --ntasks-per-node=8\n"
|
||||
if "container" in key:
|
||||
srun_command += f" --{key}={val} \\\n"
|
||||
else:
|
||||
sbatch_commands += f"#SBATCH --{key}={val}\n"
|
||||
# slurm_command += f" --{key}={val} \\\n"
|
||||
else:
|
||||
# it's a parameter for the training
|
||||
if val is None:
|
||||
continue
|
||||
if key in ["results_folder"] and val == globals()[key]:
|
||||
continue
|
||||
key = key.replace("_", "-")
|
||||
if isinstance(val, bool):
|
||||
if val:
|
||||
task_args += f"--{key} "
|
||||
else:
|
||||
task_args += f"--no-{key} "
|
||||
continue
|
||||
if isinstance(val, str):
|
||||
task_args += f'--{key} "{val}" '
|
||||
else:
|
||||
task_args += f"--{key} {val} "
|
||||
|
||||
# slurm_command += "python3 main.py {0}\n"
|
||||
# os.umask(0) # make it possible to create an executable file
|
||||
# with open(file_name, "w+", opener=lambda pth, flgs: os.open(pth, flgs, 0o777)) as f:
|
||||
# f.write(slurm_command.format(task_args))
|
||||
with open(file_name, "w+") as f:
|
||||
f.write(sbatch_commands + srun_command + python_command.format(task_args))
|
||||
|
||||
# delete all runscripts older than a month
|
||||
n_old_files = int(os.popen("find experiments/sbatch/ -type f -mtime +30 | wc -l").read())
|
||||
if n_old_files > 0:
|
||||
print(f"Deleting {n_old_files} old runscripts.")
|
||||
os.system("find experiments/sbatch/ -type f -mtime +30 -delete")
|
||||
|
||||
return file_name, sbatch_cmd_args
|
||||
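# For illustration, a generated runscript roughly looks like this (values and
# paths depend on args and on the placeholders in paths_config):
#
#   #!/bin/bash
#
#   #SBATCH --output=<slurm_output_folder>/%x-<timestamp>-%j-%N.out
#   #SBATCH --partition=RTXA6000,RTX3090
#   #SBATCH --gpus=1
#   #SBATCH --job-name="pre-train ViT"
#   #SBATCH --mem=64G
#
#   srun -K \
#     --container-image=<image> \
#     --container-workdir=<workdir> \
#     --container-mounts=<mounts> \
#     python3 main.py --task pre-train --model "ViT-S/16" ...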
|
||||
|
||||
if __name__ == "__main__":
|
||||
if not inside_slurm():
|
||||
# Make execution script and execute it
|
||||
parser = slurm_parser()
|
||||
args = vars(parser.parse_args())
|
||||
|
||||
if not args["local"]:
|
||||
script_name, cmd_args = create_runscript(args)
|
||||
# os.system("./" + script_name) # run srun to execute this script in slurm cluster
|
||||
# -> the following lines will be executed there
|
||||
if args["interactive"]:
|
||||
os.system(f"python3 srun-sbatch.py {script_name}")
|
||||
else:
|
||||
os.system(f"sbatch {cmd_args} {script_name}") # sbatch to queue the job on the cluster
|
||||
exit(0)
|
||||
|
||||
# local execution is wanted
|
||||
for key in list(args.keys()):
|
||||
if key.replace("_", "-") in slurm_defaults:
|
||||
args.pop(key)
|
||||
args = parse_args(args, parser)
|
||||
|
||||
else:
|
||||
args = parse_args()
|
||||
|
||||
args["branch"] = subprocess.check_output(["git", "branch", "--show-current"]).strip().decode("utf-8")
|
||||
args["commit"] = subprocess.check_output(["git", "rev-parse", "HEAD"]).strip().decode("utf-8")
|
||||
|
||||
if args["task"] == "pre-train":
|
||||
from train import pretrain
|
||||
|
||||
pretrain(**args)
|
||||
|
||||
elif args["task"] == "fine-tune":
|
||||
from train import finetune
|
||||
|
||||
finetune(**args)
|
||||
|
||||
elif args["task"] == "fine-tune-head":
|
||||
from train import finetune
|
||||
|
||||
finetune(**args, head_only=True)
|
||||
|
||||
elif args["task"] == "parser-test":
|
||||
from copy import copy
|
||||
|
||||
from utils import prep_kwargs, log_args
|
||||
|
||||
kwargs = prep_kwargs(copy(args))
|
||||
log_args(kwargs)
|
||||
# keys = sorted(list(args.keys()))
|
||||
# fill_len = max(len(k) for k in keys)
|
||||
# for key in keys:
|
||||
# print(f"{key + ' ' * (fill_len - len(key))} = {args[key]} -> {kwargs[key]}")
|
||||
|
||||
elif args["task"] == "eval-metrics":
|
||||
from evaluate import evaluate_metrics
|
||||
|
||||
evaluate_metrics(**args)
|
||||
|
||||
elif args["task"] == "eval":
|
||||
from evaluate import evaluate
|
||||
|
||||
evaluate(**args)
|
||||
|
||||
elif args["task"] == "eval-attr":
|
||||
from evaluate import evaluate_attributions
|
||||
|
||||
evaluate_attributions(**args)
|
||||
|
||||
elif args["task"] == "continue":
|
||||
from recover import continue_training
|
||||
|
||||
continue_training(**args)
|
||||
|
||||
elif args["task"] == "eval-center-bias":
|
||||
from evaluate import evaluate_center_bias
|
||||
|
||||
evaluate_center_bias(**args)
|
||||
|
||||
elif args["task"] == "eval-size-bias":
|
||||
from evaluate import evaluate_size_bias
|
||||
|
||||
evaluate_size_bias(**args)
|
||||
|
||||
elif args["task"] == "load-images":
|
||||
from test import load_images
|
||||
|
||||
load_images(**args)
|
||||
|
||||
elif args["task"] == "save-images":
|
||||
from test import save_images
|
||||
|
||||
save_images(**args)
|
||||
|
||||
else:
|
||||
raise NotImplementedError(f"Task {args['task']} is not implemented.")
|
||||
1032
AAAI Supplementary Material/Model Training Code/metrics.py
Normal file
File diff suppressed because it is too large
154
AAAI Supplementary Material/Model Training Code/models.py
Normal file
@@ -0,0 +1,154 @@
|
||||
"""Model loading and preparation."""
|
||||
|
||||
import importlib
|
||||
import os
|
||||
import warnings
|
||||
from functools import partial
|
||||
|
||||
import timm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from loguru import logger
|
||||
|
||||
import utils
|
||||
from architectures.vit import TimmViT
|
||||
from resizing_interface import vit_sizes
|
||||
|
||||
_ARCHITECTURES_IMPORTED = False
|
||||
|
||||
|
||||
def _import_architectures():
|
||||
global _ARCHITECTURES_IMPORTED
|
||||
if not _ARCHITECTURES_IMPORTED:
|
||||
model_file_path = os.path.dirname(os.path.abspath(__file__))
|
||||
for file in os.listdir(os.path.join(model_file_path, "architectures")):
|
||||
if not file.endswith(".py"):
|
||||
continue
|
||||
try:
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
importlib.import_module(f"architectures.{file[:-3]}")
|
||||
logger.debug(f"Imported architectures.{file[:-3]}")
|
||||
except Exception as e:
|
||||
logger.error(f"\033[93mCould not import \033[0m\033[91m{file}\033[0m")
|
||||
logger.error(e)
|
||||
_ARCHITECTURES_IMPORTED = True
|
||||
|
||||
|
||||
def prepare_model(model_str, args):
|
||||
"""Prepare a new model.
|
||||
|
||||
If the name is of the format ViT-<size>/<patch_size>, use a *TimmViT*, else fall back to timm model loading.
|
||||
|
||||
Args:
|
||||
model_str (str): model name
|
||||
args (utils.DotDict): further arguments, needs to have keys n_classes, drop_path_rate; key imsize or '_<imsize>' at the end of ViT specification
|
||||
|
||||
Returns:
|
||||
torch.nn.Module: model
|
||||
|
||||
"""
|
||||
_import_architectures()
|
||||
|
||||
kwargs = dict(args)
|
||||
for key in list([key for key, val in kwargs.items() if val is None]):
|
||||
kwargs.pop(key)
|
||||
|
||||
if args.layer_scale_init_values:
|
||||
kwargs["init_values"] = kwargs["init_scale"] = args.layer_scale_init_values
|
||||
if args.dropout and args.dropout > 0.0:
|
||||
kwargs["drop"] = kwargs["drop_rate"] = args.dropout
|
||||
if args.drop_path_rate and args.drop_path_rate > 0.0:
|
||||
kwargs["drop_block_rate"] = args.drop_path_rate
|
||||
kwargs["num_classes"] = args.n_classes
|
||||
kwargs["img_size"] = args.imsize
|
||||
if model_str.startswith("ViT"):
|
||||
# Format: ViT-{Ti,S,B,L}/<patch_size>[_<image_res>]
|
||||
h1, h2 = model_str.split("/")
|
||||
_, model_size = h1.split("-")
|
||||
if "_" in h2:
|
||||
patch_size, image_res = h2.split("_")
|
||||
assert args.imsize is None or args.imsize == int(
|
||||
image_res
|
||||
), f"Got two different image sizes: {args.imsize} vs {image_res}"
|
||||
else:
|
||||
patch_size = h2
|
||||
|
||||
kwargs = {**vit_sizes[model_size], **kwargs}
|
||||
model = TimmViT(patch_size=int(patch_size), in_chans=3, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
||||
|
||||
else:
|
||||
logger.debug(f"Loading model via timm api {model_str} with args {kwargs}")
|
||||
model = timm.create_model(model_str, pretrained=False, **kwargs)
|
||||
return model
|
||||
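# Example: model_str "ViT-S/16_224" picks the ViT-S preset from vit_sizes,
# uses patch_size=16, and checks that args.imsize (if set) equals 224; any
# other string, e.g. "resnet50", falls through to timm.create_model.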
|
||||
|
||||
def load_pretrained(model_path, args, new_dataset_params=False):
|
||||
"""Load a pretrained model from .tar file.
|
||||
|
||||
Args:
|
||||
new_dataset_params (bool, optional): change model parameters (imsize, n_classes) to the ones from args. (Default value = False)
|
||||
model_path (str): path to .tar file
|
||||
args: new model parameters
|
||||
|
||||
Returns:
|
||||
tuple: model, args, old_args, save_state
|
||||
|
||||
"""
|
||||
_import_architectures()
|
||||
|
||||
save_state = torch.load(model_path, map_location="cpu")
|
||||
old_args = utils.prep_kwargs(save_state["args"])
|
||||
args.model = old_args.model
|
||||
old_args.cuda = args.cuda
|
||||
|
||||
if old_args.model.startswith("flash_vit"):
|
||||
args.pop("layer_scale_init_values", None)
|
||||
old_args.pop("layer_scale_init_values", None)
|
||||
|
||||
# load the model (the old one first)
|
||||
model = prepare_model(old_args.model, old_args)
|
||||
logger.debug(f"loading model {old_args.model} from {model_path} with args {old_args}")
|
||||
file_save_state = utils.remove_prefix(save_state["model_state"], prefix="_orig_mod.")
|
||||
file_save_state = utils.remove_prefix(file_save_state)
|
||||
try:
|
||||
model.load_state_dict(file_save_state)
|
||||
except (UnboundLocalError, RuntimeError) as e:
|
||||
model_keys = set(model.state_dict().keys())
|
||||
file_keys = set(file_save_state.keys())
|
||||
logger.warning(f"Error loading state dict: {e}")
|
||||
model_minus_file = model_keys.difference(file_keys)
|
||||
file_minus_model = file_keys.difference(model_keys)
|
||||
logger.warning(f"model-file: {model_minus_file}\nfile-model: {file_minus_model}")
|
||||
if len(file_minus_model) == 0 and all([".ls" in key and key.endswith(".gamma") for key in model_minus_file]):
|
||||
logger.info("Old model was without LayerScale -> replicating")
|
||||
try:
|
||||
args.pop("layer_scale_init_values")
|
||||
old_args.pop("layer_scale_init_values")
|
||||
model = prepare_model(old_args.model, old_args)
|
||||
model.load_state_dict(file_save_state)
|
||||
except (UnboundLocalError, RuntimeError) as e:
|
||||
logger.error("Could not resolve conflict")
|
||||
logger.error(f"Still got error {e}")
|
||||
exit(-1)
|
||||
elif any("head.0." in key for key in file_minus_model):
|
||||
logger.info("Old model used nn.Sequential for head. Trying to fix -> nn.Linear")
|
||||
file_save_state = {key.replace("head.0.", "head."): val for key, val in file_save_state.items()}
|
||||
try:
|
||||
model.load_state_dict(file_save_state)
|
||||
except (UnboundLocalError, RuntimeError) as e:
|
||||
logger.error("Could not resolve conflict")
|
||||
logger.error(f"Still got error {e}")
|
||||
exit(-1)
|
||||
else:
|
||||
exit(-1)
|
||||
|
||||
if new_dataset_params:
|
||||
# setup for finetuning parameters
|
||||
model.set_image_res(args.imsize)
|
||||
model.set_num_classes(args.n_classes)
|
||||
|
||||
if args.max_seq_len is not None:
|
||||
model.set_max_seq_len(args.max_seq_len)
|
||||
|
||||
return model, args, old_args, save_state
|
||||
@@ -0,0 +1,36 @@
|
||||
import os
|
||||
|
||||
user = os.environ.get("USER")
|
||||
results_folder = os.path.join("/BASE/FOLDER/TO/STORE/WEIGHTS/AND/LOGS", "EfficientCVBench")
|
||||
# PATH: /netscratch/<user>/slurm
|
||||
slurm_output_folder = os.path.join("/FOLDER/FOR/SLURM/TO/WRITE/LOGS/TO", "slurm")
|
||||
|
||||
|
||||
_ds_paths = {
"cifar": "/PATH/TO/CIFAR",
|
||||
"tinyimagenet": "/PATH/TO/TINYIMAGENET",
|
||||
"stanford_cars": "/PATH/TO/CARS/SUPERFOLDER",
|
||||
"oxford_pet": "/PATH/TO/PET/SUPERFOLDER",
|
||||
"flowers102": "/PATH/TO/FLOWERS/SUPERFOLDER",
|
||||
"food101": "/PATH/TO/FOOD/SUPERFOLDER",
|
||||
"aircraft": "/PATH/TO/AIRCRAFT/SUPERFOLDER",
|
||||
"fornet": "/PATH/TO/FORNET",
|
||||
"counteranimal": "/PATH/TO/CounterAnimal/LAION-final",
|
||||
}
|
||||
|
||||
|
||||
def ds_path(dataset, args=None):
|
||||
"""Get the (base) path for any dataset.
|
||||
|
||||
Args:
dataset (str): Name of the dataset to look up.
args (DotDict, optional): Run args. If args.custom_dataset_path is set, that path is always returned.

Returns:
str: Path to the dataset root folder.
"""
|
||||
if args is not None and "custom_dataset_path" in args and args.custom_dataset_path is not None:
|
||||
return args.custom_dataset_path
|
||||
return _ds_paths[dataset]
|
||||
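# Example: ds_path("fornet") returns the configured "/PATH/TO/FORNET"; when
# args.custom_dataset_path is set, that path is returned for every dataset.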
136
AAAI Supplementary Material/Model Training Code/recover.py
Normal file
@@ -0,0 +1,136 @@
"""Continue pretraining / finetuning after something went wrong."""

import torch
from loguru import logger

from engine import _train, setup_criteria_mixup, setup_model_optim_sched_scaler, setup_tracking_and_logging
from load_dataset import prepare_dataset
from models import load_pretrained
from utils import ddp_cleanup, ddp_setup, log_args, prep_kwargs


def continue_training(model, **kwargs):
    """Continue training a model from a saved state.

    Args:
        model (str): path to saved state.
        **kwargs: additional keyword arguments.

    """
    model_path = model
    save_state = torch.load(model, map_location="cpu")

    # state is of the form
    #
    # state = {'epoch': epochs,
    #          'model_state': model.state_dict(),
    #          'optimizer_state': optimizer.state_dict(),
    #          'scheduler_state': scheduler.state_dict(),
    #          'args': dict(args),
    #          'run_name': run_name,
    #          'stats': metrics}

    args = prep_kwargs(save_state["args"])

    args.distributed, device, world_size, rank, gpu_id = ddp_setup()
    torch.cuda.set_device(device)

    if "world_size" in args and args.world_size is not None:
        global_bs = args.batch_size * args.world_size
    else:
        # assume global bs is given in kwargs
        global_bs = kwargs["batch_size"]
    args.batch_size = int(global_bs / world_size)
    args.world_size = world_size

    if "dataset" in args and args.dataset is not None:
        dataset = args.dataset
    else:
        # get default dataset for the task
        dataset = "ImageNet21k" if args.task == "pre-train" else "ImageNet"
        args.dataset = dataset

    if "val_dataset" in args and args.val_dataset is not None:
        val_dataset = args.val_dataset
    else:
        val_dataset = dataset
        args.val_dataset = val_dataset

    start_epoch = save_state["epoch"]
    if "epochs" in args and args.epochs is not None and args.epochs != start_epoch:
        epochs = args.epochs
    else:
        epochs = kwargs["epochs"]

    run_folder = setup_tracking_and_logging(args, rank, append_model_path=model_path)
    logger.info(f"Logging run information to '{run_folder}'")

    # get the datasets & dataloaders
    train_loader, args.n_classes, args.ignore_index, args.multi_label, train_dali_server = prepare_dataset(
        dataset, args, rank=rank
    )
    val_loader, _, __, ___, val_dali_server = prepare_dataset(val_dataset, args, train=False, rank=rank)

    # model_name = args.model

    model, args, _, __ = load_pretrained(model_path, args)

    model, optimizer, scheduler, scaler = setup_model_optim_sched_scaler(model, device, epochs, args)

    try:
        optimizer.load_state_dict(save_state["optimizer_state"])
    except ValueError as e:
        logger.error(f"Could not load optimizer state: {e}")
        logger.error(
            f"optimizer state: {optimizer.state_dict().keys()}, param groups: {optimizer.state_dict()['param_groups']}"
        )
        logger.error(
            f"saved state: {save_state['optimizer_state'].keys()}, param groups:"
            f" {save_state['optimizer_state']['param_groups']}"
        )
        raise e

    scheduler.load_state_dict(save_state["scheduler_state"])

    # log all devices
    logger.info(f"training on {device} -> {torch.cuda.get_device_name(device) if args.device != 'cpu' else ''}")
    if rank == 0:
        logger.info(f"torch version {torch.__version__}")
        log_args(args)

    if args.seed:
        torch.manual_seed(args.seed)

    criterion, val_criterion, mixup = setup_criteria_mixup(args)

    if rank == 0:
        logger.info(f"start training at epoch {start_epoch}")
        logger.info(f"Run info at: '{run_folder}'")

    res = _train(
        model,
        train_loader,
        optimizer,
        rank,
        epochs,
        device,
        mixup,
        criterion,
        world_size,
        scheduler,
        args,
        val_loader,
        val_criterion,
        run_folder,
        scaler=scaler,
        do_metrics_calculation=True,
        start_epoch=start_epoch,
        show_tqdm=args.tqdm,
        train_dali_server=train_dali_server,
        val_dali_server=val_dali_server,
    )

    if rank == 0:
        best_acc_key = [key for key in res.keys() if key.startswith("val/best_")][0]
        logger.info(f"Run '{args.run_name}' is done. Top-1 validation accuracy: {res[best_acc_key] * 100:.2f}%")
    ddp_cleanup(args=args, rank=rank)
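Since `continue_training` relies entirely on the checkpoint layout documented in the comment above, here is a minimal, self-contained sketch of that contract (editor's example; the tiny model, file name, and values are hypothetical stand-ins):

```python
# Editor's sketch of the checkpoint dict recover.py expects; all values are placeholders.
import torch
from torch import nn, optim

# Dummy stand-ins; in real runs these come from setup_model_optim_sched_scaler.
model = nn.Linear(4, 2)
optimizer = optim.SGD(model.parameters(), lr=0.1)
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lambda ep: 1.0)

state = {
    "epoch": 90,
    "model_state": model.state_dict(),
    "optimizer_state": optimizer.state_dict(),
    "scheduler_state": scheduler.state_dict(),
    "args": {"run_name": "example_run", "batch_size": 1024},  # plain dict so it pickles cleanly
    "run_name": "example_run",
    "stats": {"val/acc": 0.81},
}
torch.save(state, "ep_90.pt")
# continue_training("ep_90.pt", epochs=300, batch_size=1024) would restore
# model/optimizer/scheduler from this file and resume at state["epoch"].
```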
@@ -0,0 +1,25 @@
captum==0.8.0
# datadings==3.4.6
einops==0.8.1
fvcore==0.1.5.post20221221
grad-cam==1.5.4
halonet-pytorch==0.0.4
numpy==1.26.4
opencv-python==4.11.0.86
pytorch_wavelets==1.3.0
pywavelets==1.8.0
reformer-pytorch==1.4.4
routing-transformer==1.6.1
sinkhorn-transformer==0.11.4
timm==1.0.15
torch==2.6.0
torcheval==0.0.7
torchprofile==0.0.4
torchvision==0.21.0
tqdm==4.67.1
nltk==3.9.1
Pillow==11.1.0
wandb==0.19.9
psutil==7.0.0
@@ -0,0 +1,108 @@
from copy import copy

import torch
from loguru import logger
from torch import nn

vit_sizes = {
    "na": dict(embed_dim=96, depth=2, num_heads=2),
    "mu": dict(embed_dim=144, depth=6, num_heads=3),
    "Ti": dict(embed_dim=192, depth=12, num_heads=3),
    "S": dict(embed_dim=384, depth=12, num_heads=6),
    "B": dict(embed_dim=768, depth=12, num_heads=12),
    "L": dict(embed_dim=1024, depth=24, num_heads=16),
    "LRA_CIFAR": dict(embed_dim=256, depth=1, num_heads=4, mlp_ratio=1.0),
    "LRA_IMDB": dict(embed_dim=256, depth=4, num_heads=4, mlp_ratio=4.0),
    "LRA_ListOps": dict(embed_dim=512, depth=4, num_heads=8, mlp_ratio=4.0),
}


class ResizingInterface:
    """Interface for resizing parts of a Vision Transformer model."""

    def get_internal_loss(self):
        """Add a term to the loss."""
        return 0.0

    def set_image_res(self, res):
        """Set a new image resolution.

        Resets the (learned) patch embedding.

        Args:
            res (int): new image resolution

        """
        self._set_input_strand(res=res)

    def _set_input_strand(self, res=None, patch_size=None):
        """Set a new image resolution and patch size.

        Args:
            res (int, optional): new image resolution (Default value = None)
            patch_size (int, optional): new patch size (Default value = None)

        """
        if res is None:
            res = self.img_size

        if patch_size is None:
            patch_size = self.patch_size
        else:
            # TODO: implement interpolation of patch_embed weights to new patch size/input shape
            raise NotImplementedError("Interpolation of patch_embed weights to new patch size not implemented yet.")

        if res == self.img_size and patch_size == self.patch_size:
            return  # nothing to do here

        logger.info(f"Resizing input from {self.img_size} to {res} and patch size from {self.patch_size} to {patch_size}.")

        old_patch_embed_state = copy(self.patch_embed.state_dict())
        self.patch_embed = self.embed_layer(
            img_size=res,
            patch_size=patch_size,
            in_chans=self.in_chans,
            embed_dim=self.embed_dim,
            bias=not self.pre_norm,  # disable bias if pre-norm is used (e.g. CLIP)
        )

        self.patch_embed.load_state_dict(old_patch_embed_state)

        num_extra_tokens = 0 if self.no_embed_class else self.num_prefix_tokens
        orig_size = int((self.pos_embed.shape[-2] - num_extra_tokens) ** 0.5)
        new_size = int(self.patch_embed.num_patches**0.5)
        extra_tokens = self.pos_embed[:, :num_extra_tokens]
        pos_tokens = self.pos_embed[:, num_extra_tokens:]
        # make it shape rest x embed_dim x orig_size x orig_size
        pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, self.embed_dim).permute(0, 3, 1, 2)
        pos_tokens = nn.functional.interpolate(
            pos_tokens, size=(new_size, new_size), mode="bicubic", align_corners=False
        )
        # make it shape rest x new_size^2 x embed_dim
        pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
        if num_extra_tokens > 0:
            pos_tokens = torch.cat((extra_tokens, pos_tokens), dim=1)
        self.pos_embed = nn.Parameter(pos_tokens.contiguous())

        self.img_size = res
        self.patch_size = patch_size

    def set_num_classes(self, n_classes):
        """Reset the classification head with a new number of classes.

        Args:
            n_classes (int): new number of classes

        """
        if n_classes == self.num_classes:
            return
        logger.info(f"Resizing classification head from {self.num_classes} to {n_classes}.")
        self.head = nn.Linear(self.embed_dim, n_classes) if n_classes > 0 else nn.Identity()
        self.num_classes = n_classes

        # init weight + bias
        # nn.init.zeros_(self.head.weight)
        # nn.init.constant_(self.head.bias, -log(self.num_classes))

        if isinstance(self.head, nn.Linear):  # nn.Identity (n_classes <= 0) has no weights to initialize
            nn.init.trunc_normal_(self.head.weight, std=0.02)
            nn.init.constant_(self.head.bias, 0)
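The positional-embedding resampling in `_set_input_strand` is the standard 2D bicubic trick. A self-contained sketch with made-up sizes (editor's example; one class token, a 14x14 grid resized to 24x24):

```python
# Editor's sketch of the resampling above; sizes are hypothetical.
import torch
from torch import nn

embed_dim, num_extra_tokens = 768, 1  # e.g. a single [CLS] token
orig_size, new_size = 14, 24          # 224/16 = 14 patches -> 384/16 = 24 patches
pos_embed = torch.randn(1, num_extra_tokens + orig_size**2, embed_dim)

extra = pos_embed[:, :num_extra_tokens]
grid = pos_embed[:, num_extra_tokens:]
grid = grid.reshape(1, orig_size, orig_size, embed_dim).permute(0, 3, 1, 2)  # -> 1 x C x H x W
grid = nn.functional.interpolate(grid, size=(new_size, new_size), mode="bicubic", align_corners=False)
grid = grid.permute(0, 2, 3, 1).flatten(1, 2)                                # -> 1 x H*W x C
pos_embed = torch.cat((extra, grid), dim=1)
print(pos_embed.shape)  # torch.Size([1, 577, 768]) = 1 + 24*24 tokens
```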
@@ -0,0 +1,46 @@
import os
import sys

assert len(sys.argv) >= 2, "Add an sbatch script as the first argument"
assert os.path.isfile(
    sys.argv[1]
), f"First argument has to be an executable script (a file that exists), but got '{sys.argv[1]}'"

with open(sys.argv[1], "r") as f:
    script = f.readlines()

script = [l.strip() for l in script if len(l.strip()) > 0]

# join lines ending with \
joined_script = []
current_line = ""
for line in script:
    current_line += line
    if current_line.endswith("\\"):
        current_line = current_line[:-1]
    else:
        joined_script.append(current_line)
        current_line = ""
script = joined_script

additional_srun_params = []
srun_lines = []
for line in script:
    if line.upper().startswith("#SBATCH "):
        param = line[len("#SBATCH ") :]
        if param.startswith("--output="):
            continue
        additional_srun_params.append(param)
    elif line.startswith("srun "):
        srun_lines.append(line)

further_args = " ".join(sys.argv[2:])

sruns = [
    line.replace("srun ", "srun " + " ".join(additional_srun_params) + " ").replace('"$@"', further_args)
    for line in srun_lines
]

for srun_line in sruns:
    print(f"I will run:\n{srun_line}", flush=True)
    os.system(srun_line)
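To make the rewriting above concrete, a toy example (editor's sketch; the flags and the script line are hypothetical):

```python
# Editor's sketch: what the loop above does to a single srun line.
additional_srun_params = ["--mem=64G", "--gpus-per-task=4"]  # harvested from #SBATCH lines
further_args = "--epochs 100"                                # " ".join(sys.argv[2:])
line = 'srun python train.py pretrain "$@"'

rewritten = line.replace("srun ", "srun " + " ".join(additional_srun_params) + " ").replace('"$@"', further_args)
print(rewritten)
# srun --mem=64G --gpus-per-task=4 python train.py pretrain --epochs 100
```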
97
AAAI Supplementary Material/Model Training Code/test.py
Normal file
@@ -0,0 +1,97 @@
import math
import os
import sys

import numpy as np
from loguru import logger
from matplotlib import pyplot as plt
from PIL import Image
from tqdm.auto import tqdm

from load_dataset import prepare_dataset
from utils import log_args, log_formatter, prep_kwargs


def load_images(dataset, **kwargs):
    args = prep_kwargs(kwargs)
    args.dataset = dataset

    args.aug_normalize = False

    loader, args.n_classes, args.ignore_index, args.multi_label, _ = prepare_dataset(dataset, args)
    images = next(iter(loader))[0]

    images = images.permute(0, 2, 3, 1).numpy()
    images = [images[i] for i in range(images.shape[0])]

    rows = math.ceil(math.sqrt(len(images) / 2))
    ims_per_row = len(images) // rows

    fig, axs = plt.subplots(rows, ims_per_row)
    axs = np.ravel(axs)  # flatten the axes grid; also handles the single-row case
    for img, ax in zip(images, axs):
        ax.imshow(img)
    fig.suptitle(f"Examples from {dataset}")
    fig.tight_layout(pad=0)

    plt.show()


def save_images(dataset, out_dir, ipc=None, **kwargs):
    args = prep_kwargs(kwargs)
    args.dataset = dataset
    args.out_dir = out_dir
    args.ipc = ipc
    args.aug_normalize = False

    log_file = os.path.join(out_dir, "save_images.log")
    logger.remove()
    logger.configure(extra=dict(run_name=f"Save images of {dataset}", rank=0, world_size=-1))
    logger.add(sys.stderr, format=log_formatter, enqueue=True, colorize=True, level=args.log_level.upper())
    logger.add(log_file, format=log_formatter, enqueue=True, colorize=True, level=args.log_level.upper())
    logger.info(f"Out dir '{out_dir}'")
    log_args(args)

    loader, args.n_classes, args.ignore_index, args.multi_label, _ = prepare_dataset(dataset, args)
    n_ims = [0 for _ in range(args.n_classes)]

    if args.n_classes == 1000:
        # assume it's ImageNet classes
        logger.info("1000 classes => assuming ImageNet class names")
        with open("data/misc_dataset_files/imagenet_labels.txt", "r") as f:
            lines = f.readlines()
        labels = [l.strip().split(" ")[0] for l in lines]
        lbl_to_cls_name = sorted(labels)
    else:
        lbl_to_cls_name = [str(i) for i in range(args.n_classes)]  # str, since the names become folder names

    for cls_name in lbl_to_cls_name:
        os.makedirs(os.path.join(args.out_dir, cls_name), exist_ok=True)

    skipped_ims = 0
    tqdm_is_disabled = (not args.tqdm) or os.environ.get("TQDM_DISABLE", "0") != "0"
    for i, (images, labels) in (
        pbar := tqdm(enumerate(loader), desc="Loading and saving images", disable=tqdm_is_disabled, total=len(loader))
    ):
        images = (images.permute(0, 2, 3, 1).numpy() * 255).astype(np.uint8)
        images = [images[i] for i in range(images.shape[0])]
        labels = labels.tolist()

        for img, lbl in zip(images, labels):
            if ipc is not None and n_ims[lbl] >= ipc:
                skipped_ims += 1
                continue

            Image.fromarray(img).save(
                os.path.join(args.out_dir, lbl_to_cls_name[lbl], f"{lbl_to_cls_name[lbl]}_{n_ims[lbl]}.JPEG")
            )
            n_ims[lbl] += 1

        if ipc is not None and sum(n_ims) >= args.n_classes * ipc:
            break
        if tqdm_is_disabled:
            if i % 1000 == 0:
                logger.info(f"Batch [{i+1}/{len(loader)}]: Saved {sum(n_ims)}, skipped {skipped_ims}")
        else:
            pbar.set_description(f"Loading and saving (saved {sum(n_ims)}, skipped {skipped_ims})")
    logger.success(f"Extracted {sum(n_ims)} images.")
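For orientation, the export layout `save_images` produces (editor's sketch; a hypothetical 3-class dataset with `ipc=2`, where class indices name the folders for non-ImageNet datasets):

```python
# Editor's sketch of the on-disk result of save_images(dataset, out_dir, ipc=2):
#
#   out_dir/0/0_0.JPEG   out_dir/0/0_1.JPEG
#   out_dir/1/1_0.JPEG   out_dir/1/1_1.JPEG
#   out_dir/2/2_0.JPEG   out_dir/2/2_1.JPEG
#
# and the grid arithmetic used by load_images for a batch of 32 images:
import math

rows = math.ceil(math.sqrt(32 / 2))  # 4
ims_per_row = 32 // rows             # 8 -> a 4 x 8 preview grid
```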
303
AAAI Supplementary Material/Model Training Code/train.py
Normal file
@@ -0,0 +1,303 @@
import os
import random
import sys

import numpy as np
import timm
import torch
from loguru import logger

from engine import (
    _train,
    setup_criteria_mixup,
    setup_model_optim_sched_scaler,
    setup_tracking_and_logging,
    wandb_available,
)
from load_dataset import prepare_dataset
from models import load_pretrained, prepare_model
from utils import ddp_cleanup, ddp_setup, log_args, prep_kwargs, set_filter_warnings


def finetune(model, dataset, epochs, val_dataset=None, head_only=False, **kwargs):
    """Finetune a pretrained model on a given dataset for a specified number of epochs.

    Args:
        model (str): Path to the pretrained model state file (in .tar format).
        dataset (str): Name of the dataset to finetune on.
        val_dataset (str, optional): Name of the validation dataset. (Default value = None)
        epochs (int): Number of epochs to train for.
        head_only (bool, optional): Whether to train only the head of the model. Default: False.
        **kwargs (dict): Further arguments for model setup, training, evaluation, ...

    Notes:
        This function assumes that the model was pretrained on a different dataset using a different set of hyperparameters.
        It fine-tunes the model on a new dataset by loading the pretrained weights and training for the specified number of
        epochs. The function supports distributed training using the PyTorch DistributedDataParallel module.

    """
    set_filter_warnings()

    # Add defaults & make keys properties
    args = prep_kwargs(kwargs)

    if val_dataset is None:
        val_dataset = dataset

    args.val_dataset = val_dataset
    args.dataset = dataset
    args.epochs = epochs

    args.distributed, device, world_size, rank, gpu_id = ddp_setup()
    args.world_size = world_size
    try:
        torch.cuda.set_device(device)
    except RuntimeError as e:
        logger.error(
            f"Could not set device {device} as current device; "
            f"environment parameters: RANK={os.environ.get('RANK')}, LOCAL_RANK={os.environ.get('LOCAL_RANK')}, "
            f"WORLD_SIZE={os.environ.get('WORLD_SIZE')}; SLURM_STEP_GPUS={os.environ.get('SLURM_STEP_GPUS')}, "
            f"GPU_DEVICE_ORDINAL={os.environ.get('GPU_DEVICE_ORDINAL')}, "
            f"CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')}, "
            f"SLURM_STEP_NODELIST={os.environ.get('SLURM_STEP_NODELIST')}, "
            f"SLURMD_NODENAME={os.environ.get('SLURMD_NODENAME')}"
        )
        raise e
    args.batch_size = int(args.batch_size / world_size)

    if args.seed is not None:
        # fix the seed for reproducibility
        seed = args.seed + rank
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)

    # get the datasets & dataloaders
    # transform only contains resize & crop here; everything else is handled on the GPU / in the training loop
    train_loader, args.n_classes, args.ignore_index, args.multi_label, train_dali_server = prepare_dataset(
        dataset, args, rank=rank
    )
    val_loader, _val_classes, _, __, val_dali_server = prepare_dataset(val_dataset, args, train=False, rank=rank)
    assert (
        args.n_classes == _val_classes
    ), f"Training and validation datasets have different numbers of classes: {args.n_classes} vs {_val_classes}"

    save_state = torch.load(model, map_location="cpu")
    old_args = prep_kwargs(save_state["args"])
    parent_folder = os.path.dirname(model)
    args.model = old_args.model
    run_folder = setup_tracking_and_logging(args, rank)
    if rank == 0:
        if not os.path.exists(os.path.join(run_folder, "parent_run")):
            os.symlink(parent_folder, os.path.join(run_folder, "parent_run"), target_is_directory=True)
        logger.info(
            f"environment parameters: RANK={os.environ.get('RANK')}, LOCAL_RANK={os.environ.get('LOCAL_RANK')}, "
            f"WORLD_SIZE={os.environ.get('WORLD_SIZE')}; SLURM_STEP_GPUS={os.environ.get('SLURM_STEP_GPUS')}, "
            f"GPU_DEVICE_ORDINAL={os.environ.get('GPU_DEVICE_ORDINAL')}, "
            f"CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')}, "
            f"SLURM_STEP_NODELIST={os.environ.get('SLURM_STEP_NODELIST')}, "
            f"SLURMD_NODENAME={os.environ.get('SLURMD_NODENAME')}"
        )

    if args.seed:
        logger.info(f"setting manual seed '{seed}' (arg: {args.seed} + rank: {rank})")

    model, args, old_args, save_state = load_pretrained(model, args, new_dataset_params=True)
    # model_name = old_args.model

    if rank == 0:
        logger.info(f"The model was pretrained on {old_args.dataset} for {save_state['epoch']} epochs.")

    model, optimizer, scheduler, scaler = setup_model_optim_sched_scaler(
        model, device, epochs, args, head_only=head_only
    )

    # log all devices
    logger.info(f"training on {device} -> {torch.cuda.get_device_name(device) if args.device != 'cpu' else ''}")
    if rank == 0:
        logger.info(f"torch version {torch.__version__}")
        logger.info(f"timm version {timm.__version__}")
        logger.info(f"full set of old arguments: {old_args}")
        log_args(args)

    if args.seed:
        torch.manual_seed(seed)

    criterion, val_criterion, mixup = setup_criteria_mixup(args)
    if rank == 0:
        logger.info(f"Run info at: '{run_folder}'")

    res = _train(
        model,
        train_loader,
        optimizer,
        rank,
        epochs,
        device,
        mixup,
        criterion,
        world_size,
        scheduler,
        args,
        val_loader,
        val_criterion,
        run_folder,
        scaler,
        do_metrics_calculation=True,
        show_tqdm=args.tqdm,
        train_dali_server=train_dali_server,
        val_dali_server=val_dali_server,
    )

    if rank == 0:
        best_acc_key = sorted([key for key in res.keys() if key.startswith("val/best_")])[0]
        logger.info(
            f"Run '{run_folder.split(os.sep)[-1]}' is done. Top-1 validation accuracy: {res[best_acc_key] * 100:.2f}%"
        )

    ddp_cleanup(args=args, sync_old_wandb=wandb_available(), rank=rank)


def pretrain(model, dataset, epochs, val_dataset=None, **kwargs):
    """Train or pretrain a model.

    Args:
        model (str): Name of the model to train.
        dataset (str): Name of the dataset to train the model on.
        epochs (int): Number of training epochs.
        val_dataset (str, optional): Name of the validation dataset, by default None
        **kwargs (dict): Additional keyword arguments.

    Notes:
        This function sets up the logger, prepares the model, and trains the model on the given dataset.

    """
    set_filter_warnings()

    # Add defaults & make args properties
    args = prep_kwargs(kwargs)

    if val_dataset is None:
        val_dataset = dataset

    args.val_dataset = val_dataset
    args.dataset = dataset
    args.model = model
    args.epochs = epochs

    args.distributed, device, world_size, rank, gpu_id = ddp_setup(args.cuda)
    args.world_size = world_size

    # sleep(rank * 5)
    # logger.debug(f'running environment commands for rank {rank}')
    # os.system('env')
    # os.system('nvidia-smi')
    # sleep((world_size - rank) * 5)

    logger.debug(
        f"rank params: RANK={os.environ.get('RANK')}, LOCAL_RANK={os.environ.get('LOCAL_RANK')}, "
        f"WORLD_SIZE={os.environ.get('WORLD_SIZE')}; gpu params: "
        f"SLURM_STEP_GPUS={os.environ.get('SLURM_STEP_GPUS')}, "
        f"GPU_DEVICE_ORDINAL={os.environ.get('GPU_DEVICE_ORDINAL')}, "
        f"CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')}"
    )

    if args.cuda:
        try:
            torch.cuda.set_device(device)
        except RuntimeError as e:
            logger.error(f"Could not set device {device} as current device: {e}")
            logger.error(
                f"environment parameters: RANK={os.environ.get('RANK')}, LOCAL_RANK={os.environ.get('LOCAL_RANK')}, "
                f"WORLD_SIZE={os.environ.get('WORLD_SIZE')}; SLURM_STEP_GPUS={os.environ.get('SLURM_STEP_GPUS')}, "
                f"GPU_DEVICE_ORDINAL={os.environ.get('GPU_DEVICE_ORDINAL')}, "
                f"CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')}, "
                f"SLURM_STEP_NODELIST={os.environ.get('SLURM_STEP_NODELIST')}, "
                f"SLURMD_NODENAME={os.environ.get('SLURMD_NODENAME')}"
            )
            raise e

    args.batch_size = int(args.batch_size / world_size)

    run_folder = setup_tracking_and_logging(args, rank)
    if rank % world_size == 0:
        logger.info(
            f"environment parameters: RANK={os.environ.get('RANK')}, LOCAL_RANK={os.environ.get('LOCAL_RANK')}, "
            f"WORLD_SIZE={os.environ.get('WORLD_SIZE')}; SLURM_STEP_GPUS={os.environ.get('SLURM_STEP_GPUS')}, "
            f"GPU_DEVICE_ORDINAL={os.environ.get('GPU_DEVICE_ORDINAL')}, "
            f"CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')}, "
            f"SLURM_STEP_NODELIST={os.environ.get('SLURM_STEP_NODELIST')}, "
            f"SLURMD_NODENAME={os.environ.get('SLURMD_NODENAME')}"
        )

    if args.seed is not None:
        # fix the seed for reproducibility
        seed = args.seed + rank
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)
        logger.info(f"setting manual seed '{seed}' (arg: {args.seed} + rank: {rank})")

    # get the datasets & dataloaders
    train_loader, args.n_classes, args.ignore_index, args.multi_label, train_dali_server = prepare_dataset(
        dataset, args, rank=rank
    )
    val_loader, _val_classes, _, __, val_dali_server = prepare_dataset(val_dataset, args, train=False, rank=rank)
    assert (
        args.n_classes == _val_classes
    ), f"Training and validation datasets have different numbers of classes: {args.n_classes} vs {_val_classes}"

    # setup model with amp & DDP
    model_name = None  # ensure model_name is defined even when a model instance is passed in
    if isinstance(model, str):
        if model.startswith("ViT") and "_" not in model:
            model += f"_{args.imsize}"
        model_name = model
        model = prepare_model(model, args)
    if not model_name:
        model_name = type(model).__name__

    model, optimizer, scheduler, scaler = setup_model_optim_sched_scaler(model, device, epochs, args)

    # log all devices
    logger.info(f"training on {device} -> {torch.cuda.get_device_name(device) if device != 'cpu' else ''}")
    if rank == 0:
        logger.info(f"python version {sys.version}")
        logger.info(f"torch version {torch.__version__}")
        logger.info(f"timm version {timm.__version__}")
        log_args(args)

    if args.seed:
        torch.manual_seed(seed)

    criterion, val_criterion, mixup = setup_criteria_mixup(args)

    if rank == 0:
        logger.info(f"Run info at: '{run_folder}'")

    res = _train(
        model,
        train_loader,
        optimizer,
        rank,
        epochs,
        device,
        mixup,
        criterion,
        world_size,
        scheduler,
        args,
        val_loader,
        val_criterion,
        run_folder,
        scaler,
        do_metrics_calculation=True,
        show_tqdm=args.tqdm,
        train_dali_server=train_dali_server,
        val_dali_server=val_dali_server,
    )

    if rank == 0:
        best_acc_key = [key for key in res.keys() if key.startswith("val/best_")][0]
        logger.info(
            f"Run '{run_folder.split(os.sep)[-1]}' is done. Top-1 validation accuracy: {res[best_acc_key] * 100:.2f}%"
        )

    ddp_cleanup(args=args, sync_old_wandb=wandb_available(), rank=rank)
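A hedged example of how these entry points are typically invoked from Python (editor's sketch; the model name, checkpoint path, and hyperparameters are hypothetical placeholders, and unset kwargs fall back to the defaults from prep_kwargs):

```python
# Editor's sketch; paths and hyperparameters are placeholders.
from train import finetune, pretrain

# Pretrain from scratch (the exact model-name strings are defined in prepare_model):
# pretrain("ViT-S", "ImageNet", epochs=300, batch_size=1024, seed=0)

# Then finetune the saved state on a smaller dataset, head only:
# finetune("/results/run_xyz/final.pt", "cifar", epochs=30, head_only=True, batch_size=512, seed=0)
```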
594
AAAI Supplementary Material/Model Training Code/utils.py
Normal file
@@ -0,0 +1,594 @@
"""Utils and small helper functions."""

import collections.abc
import json
import os
import shutil
import sys
import warnings
from dataclasses import dataclass
from itertools import repeat
from math import cos, pi, sqrt

import numpy as np
import torch
import torch.distributed as dist
from loguru import logger
from timm.data import Mixup
from timm.utils import NativeScaler, dispatch_clip_grad
from torch.nn.modules.loss import _WeightedLoss
from torch.utils.data import Dataset
from torchvision.transforms import transforms

import paths_config
from config import default_kwargs, get_default_kwargs  # noqa: F401 # is used in prep_kwargs


class RepeatedDataset(Dataset):
    """Dataset that repeats the given dataset a number of times."""

    def __init__(self, dataset, num_repeats):
        """Create repeated dataset.

        Args:
            dataset (Dataset): dataset to repeat.
            num_repeats (int): number of repeats.

        """
        self.dataset = dataset
        self.num_repeats = num_repeats

    def __getitem__(self, idx):
        return self.dataset[idx // self.num_repeats]

    def __len__(self):
        return len(self.dataset) * self.num_repeats


@dataclass
class SchedulerArgs:
    """Class for scheduler arguments."""

    sched: str
    epochs: int
    min_lr: float
    warmup_lr: float
    warmup_epochs: int
    cooldown_epochs: int = 0


def scheduler_function_factory(
    epochs, sched, warmup_epochs=0, lr=None, min_lr=0.0, warmup_sched=None, warmup_lr=None, offset=0, **kwargs
):
    """Create a learning rate factor function for the given schedule.

    Args:
        sched (str): the learning rate schedule type
        epochs (int): length of the full schedule
        warmup_epochs (int, optional): number of epochs reserved for warmup (Default value = 0)
        lr (float, optional): learning rate (has to be given when warmup or min_lr are set) (Default value = None)
        min_lr (float, optional): minimum learning rate (Default value = 0.0)
        warmup_sched (str, optional): the type of schedule during warmup (Default value = None)
        warmup_lr (float, optional): (starting) learning rate during warmup (Default value = None)
        offset (int, optional): offset for the schedule (to be the same as the timm scheduler) (Default value = 0)
        **kwargs: unused

    Returns:
        function: scheduler function

    """
    sched = sched.lower()

    def warmup_f(ep):
        return 1.0

    if warmup_epochs > 0:
        assert warmup_lr is not None, "Need warmup_lr, but got None"
        warmup_lr_factor = warmup_lr / lr
        if warmup_sched == "linear":

            def warmup_f(ep):
                return warmup_lr_factor + (1 - warmup_lr_factor) * max(ep, 0.0) / warmup_epochs

        elif warmup_sched == "const":

            def warmup_f(ep):
                return warmup_lr_factor

        else:
            raise NotImplementedError(f"Warmup schedule {warmup_sched} not implemented")

    epochs = epochs - warmup_epochs + offset
    if sched == "cosine":
        # cos from 0 to pi
        def main_f(ep):
            return cos(pi * ep / epochs) / 2 + 0.5

    elif sched == "const":

        def main_f(ep):
            return 1.0

    else:
        raise NotImplementedError(f"Schedule {sched} is not implemented.")

    # rescale and add min_lr
    min_lr_fact = min_lr / lr if min_lr else 0.0  # avoid dividing by lr=None when no min_lr is set

    def main_f_with_min_lr(ep):
        return (1 - min_lr_fact) * main_f(ep) + min_lr_fact

    return lambda ep: (
        warmup_f(ep + offset) if ep + offset < warmup_epochs else main_f_with_min_lr(ep + offset - warmup_epochs)
    )


class DotDict(dict):
    """Extension of a Python dictionary to access its keys using dot notation."""

    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

    def __getattr__(self, item, default=None):
        """Get an item from the dict, falling back to a default.

        Args:
            item: key
            default (optional): default value. Defaults to None.

        Returns:
            value

        """
        if item not in self:
            return default
        return self.get(item)


def prep_kwargs(kwargs):
    """Prepare the arguments and add defaults.

    Args:
        kwargs (dict[str, Any]): dict of kwargs

    Returns:
        DotDict: prepared kwargs

    """
    if "defaults" not in kwargs:
        kwargs["defaults"] = "DeiTIII"
    defaults = get_default_kwargs(kwargs["defaults"])
    for k, v in defaults.items():
        if k not in kwargs or kwargs[k] is None:
            kwargs[k] = v

    if "results_folder" not in kwargs:
        kwargs["results_folder"] = paths_config.results_folder

    if kwargs["results_folder"].endswith("/"):
        kwargs["results_folder"] = kwargs["results_folder"][:-1]

    if "val_dataset" not in kwargs and "dataset" in kwargs:
        kwargs["val_dataset"] = kwargs["dataset"]

    return DotDict(kwargs)


def denormalize(x, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Invert the normalize operation.

    Args:
        x (torch.Tensor): images to de-normalize
        mean (tuple, optional): normalization mean. Defaults to (0.485, 0.456, 0.406).
        std (tuple, optional): normalization std. Defaults to (0.229, 0.224, 0.225).

    Returns:
        torch.Tensor: de-normalized images

    """
    operation = transforms.Normalize(
        mean=[-mu / sigma for mu, sigma in zip(mean, std)], std=[1 / sigma for sigma in std]
    )
    return operation(x)


def log_formatter(record):
    """Format a loguru record, including run name, rank, and (if set) epoch."""
    if "run_name" not in record["extra"]:
        return (
            "<g>{time:YYYY-MM-DD HH:mm:ss.SSS}</g> <c>|</c> <y>name TBD</y> > <y>?</y>/<y>?</y> <c>|</c> <level>{level:"
            " <8}</level> <c>|</c> {message}\n"
        )

    epoch_str = "@ epoch <y>{extra[epoch]: >3}</y> " if "epoch" in record["extra"] else ""
    code_loc_str = "<r>{name}</r>.<r>{function}</r>:<r>{line}</r> - " if record["level"].no >= 30 else ""

    return (
        "<g>{time:YYYY-MM-DD HH:mm:ss.SSS}</g> <c>|</c> <m>{extra[run_name]}</m> >"
        " <m>{extra[rank]}</m>/<m>{extra[world_size]}</m> "
        + epoch_str
        + "<c>|</c> <level>{level: <8}</level> <c>|</c> "
        + code_loc_str
        + "{message}\n"
    )


def ddp_setup(use_cuda=True):
    """Set up the distributed environment.

    Args:
        use_cuda (bool, optional): whether to train on CUDA devices (Default value = True)

    Returns:
        tuple: A tuple containing the following elements:
            * bool: Whether the training is distributed.
            * torch.device: The device to use for distributed training.
            * int: The total number of processes in the distributed setup.
            * int: The global rank of the current process in the distributed setup.
            * int: The local rank of the current process on its node.

    Notes:
        The 'nccl' backend is used.

    """
    logger.remove()
    rank = int(os.getenv("RANK", 0))
    local_rank = int(os.getenv("LOCAL_RANK", 0))
    num_gpus = int(os.getenv("WORLD_SIZE", 1))
    distributed = "RANK" in os.environ and num_gpus > 1
    logger.add(sys.stderr, format=log_formatter, colorize=True, enqueue=True)
    if distributed:
        assert use_cuda, "Only use distributed mode with cuda."
        try:
            dist.init_process_group("nccl")
        except ValueError as e:
            logger.critical(f"Value error while setting up nccl process group: {e}")
            logger.info(
                f" CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')},"
                f" SLURM_STEP_GPUS={os.environ.get('SLURM_STEP_GPUS')},"
                f" GPU_DEVICE_ORDINAL={os.environ.get('GPU_DEVICE_ORDINAL')} for process:"
                f" RANK={rank} (LOCAL_RANK={local_rank}) of WORLD_SIZE={num_gpus}. Shutting down now."
            )
            raise e

    assert torch.cuda.is_available() or not use_cuda, "CUDA is not available"
    assert (
        len(str(os.environ.get("SLURM_STEP_GPUS")).split(","))
        == len(str(os.environ.get("CUDA_VISIBLE_DEVICES")).split(","))
        == len(str(os.environ.get("GPU_DEVICE_ORDINAL")).split(","))
        == num_gpus
    ) or not use_cuda, (
        f"SLURM GPU setup is incorrect: CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')},"
        f" SLURM_STEP_GPUS={os.environ.get('SLURM_STEP_GPUS')},"
        f" GPU_DEVICE_ORDINAL={os.environ.get('GPU_DEVICE_ORDINAL')} for process:"
        f" RANK={rank} (LOCAL_RANK={local_rank}) of WORLD_SIZE={num_gpus}"
    )
    return distributed, torch.device(f"cuda:{local_rank}") if use_cuda else "cpu", num_gpus, rank, local_rank


def ddp_cleanup(args, sync_old_wandb=False, rank=0):
    """Clean the distributed setup after use.

    Args:
        args (DotDict): arguments
        sync_old_wandb (bool, optional): Whether to sync and remove wandb runs older than 100 hours (> 4 days). Defaults to False.
        rank (int, optional): The rank of the current process, so only one process syncs wandb. Defaults to 0.

    """
    if sync_old_wandb and rank == 0:
        os.system("wandb sync --clean --clean-old-hours 100 --clean-force")

    if args.distributed:
        logger.info("waiting for all processes to finish")
        dist.barrier()
        dist.destroy_process_group()
        logger.info("exiting now")
        exit(0)


def set_filter_warnings():
    """Filter out some warnings to reduce spam."""
    # filter DataLoader number of workers warning
    warnings.filterwarnings(
        "ignore", ".*worker processes in total. Our suggested max number of worker in current system is.*"
    )

    # Filter datadings only varargs warning
    warnings.filterwarnings("ignore", ".*only accepts varargs so.*")

    # Filter warnings from calculation of MACs & FLOPs
    # warnings.filterwarnings("ignore", ".*No handlers found:.*")

    # Filter warnings from gather
    warnings.filterwarnings("ignore", ".*is_namedtuple is deprecated, please use the python checks instead.*")

    # Filter warnings from meshgrid
    warnings.filterwarnings("ignore", ".*in an upcoming release, it will be required to pass the indexing.*")

    # Filter warnings from timm when overwriting models
    warnings.filterwarnings("ignore", ".*UserWarning: Overwriting .*")


def remove_prefix(state_dict, prefix="module."):
    """Remove a prefix from the keys in a state dictionary.

    Args:
        state_dict (dict[str, Any]): The state dictionary to remove the prefix from.
        prefix (str, optional): The prefix to remove from the keys. Default is 'module.'.

    Returns:
        dict[str, Any]: A new dictionary with the prefix removed from the keys.

    Examples:
        >>> state_dict = {'module.layer1.weight': 1, 'module.layer1.bias': 2}
        >>> remove_prefix(state_dict)
        {'layer1.weight': 1, 'layer1.bias': 2}

    """
    return {k[len(prefix) :] if k.startswith(prefix) else k: v for k, v in state_dict.items()}


def prime_factors(n):
    """Calculate the prime factors of a given integer.

    Args:
        n (int): The integer to find the prime factors of.

    Returns:
        list[int]: The prime factors of n.

    """
    i = 2
    factors = []
    while i * i <= n:
        if n % i:
            i += 1
        else:
            n //= i
            factors.append(i)
    if n > 1:
        factors.append(n)
    return factors


def linear_regession(points):
    """Calculate a least-squares linear fit through the points.

    Args:
        points (dict[float, float]): points to fit in the format points[x] = y

    Returns:
        function: A function z -> a * z + b fitted to the points.

    """
    N = len(points)
    x = []
    y = []
    for x_i, y_i in points.items():
        x.append(x_i)
        y.append(y_i)
    x = np.array(x)
    y = np.array(y)

    a = (N * (x * y).sum() - x.sum() * y.sum()) / (N * (x * x).sum() - x.sum() ** 2)
    b = (y.sum() - a * x.sum()) / N
    return lambda z: a * z + b


def save_model_state(
    model_folder,
    epoch,
    args,
    model_state,
    regular_save=True,
    stats=None,
    val_accs=None,
    epoch_accs=None,
    additional_reason="",
    max_interm_ep_states=2,
    **kwargs,
):
    """Save the model state.

    Args:
        model_folder: Folder to save model in
        epoch: current epoch
        args: arguments to guide and save
        model_state: state of the model
        regular_save: Is this a regular or a special save? (Default value = True)
        stats: model stats to save (Default value = None)
        val_accs: model accuracy to save (Default value = None)
        epoch_accs: training accuracy to save (Default value = None)
        additional_reason: save reason; in case it's not just a regular save interval. Would be "top" or "final", for example. (Default value = "")
        max_interm_ep_states: Number of regular epoch states to keep (Default value = 2)
        **kwargs: Further arguments to save

    """
    # make args dict, not DotDict to be able to save it
    state = {"epoch": epoch, "model_state": model_state, "run_name": args.run_name, "args": dict(args)}
    if stats is None:
        stats = {}
    if val_accs is not None:
        stats = {**stats, **val_accs}
    if epoch_accs is not None:
        stats = {**stats, **epoch_accs}
    state["stats"] = stats
    state = {**state, **kwargs}
    logger.info(f"saving model state at epoch {epoch} ({additional_reason})")
    regular_file_name = f"ep_{epoch}.pt"
    save_name = additional_reason + ".pt" if len(additional_reason) > 0 else regular_file_name
    outfile = os.path.join(model_folder, save_name)
    torch.save(state, outfile)
    if len(additional_reason) > 0 and regular_save:
        shutil.copyfile(outfile, os.path.join(model_folder, regular_file_name))

    # remove intermediate epoch states (all but the last max_interm_ep_states)
    if max_interm_ep_states > 0:
        epoch_states = [f for f in os.listdir(model_folder) if f.startswith("ep_") and f.endswith(".pt")]
        epoch_states = sorted(epoch_states, key=lambda x: int(x.split("_")[1].split(".")[0]))
        if len(epoch_states) > max_interm_ep_states:
            for f in epoch_states[:-max_interm_ep_states]:
                os.remove(os.path.join(model_folder, f))
                logger.debug(f"removed intermediate epoch state {f}")


def log_args(args, rank=0):
    """Log the full set of arguments as one JSON line (on rank 0 only)."""
    if rank == 0:
        logger.info("full set of arguments: " + json.dumps(dict(args), sort_keys=True))
        # keys = sorted(list(args.keys()))
        # for key in keys:
        #     logger.info(f"arg: {key} = {args[key]}")


class ScalerGradNormReturn(NativeScaler):
    """A wrapper around PyTorch's NativeScaler that returns the gradient norm."""

    def __str__(self) -> str:
        return f"{type(self).__name__}(_scaler: {self._scaler})"

    def __call__(self, loss, optimizer, clip_grad=None, clip_mode="norm", parameters=None, create_graph=False):
        """Scale and backpropagate through the loss tensor, and return the gradient norm of the selected parameters.

        Does an optimizer step.

        Args:
            loss (torch.Tensor): The loss tensor to scale and backpropagate through.
            optimizer (torch.optim.Optimizer): The optimizer to use for the optimization step.
            clip_grad (float, optional): The maximum allowed norm of the gradients. If None, no clipping is performed.
            clip_mode (str, optional): The mode used for clipping the gradients. Only used if `clip_grad` is not None. Possible values are 'norm'
                (clipping the norm of the gradients) and 'value' (clipping the value of the gradients). (default='norm')
            parameters (iterable[torch.nn.Parameter], optional): The parameters to compute the gradient norm for. If None, the gradient norm is not computed.
            create_graph (bool, optional): Whether to create a computation graph for computing second-order gradients. (default=False)

        Returns:
            float: The gradient norm of the selected parameters.

        """
        self._scaler.scale(loss).backward(create_graph=create_graph)

        # always unscale the gradients, since it's being done anyway
        self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
        if parameters is not None:
            grads = [p.grad for p in parameters if p.grad is not None]
            device = grads[0].device
            grad_norm = torch.norm(torch.stack([torch.norm(g.detach(), 2).to(device) for g in grads]), 2)
        else:
            grad_norm = -1
        if clip_grad is not None:
            dispatch_clip_grad(parameters, clip_grad, mode=clip_mode)
        self._scaler.step(optimizer)
        self._scaler.update()
        return grad_norm


class NoScaler:
    """Dummy gradient scaler that doesn't scale gradients.

    This scaler performs a simple backward pass with the given loss, and then updates the model's parameters
    with the given optimizer. The resulting gradient norm is computed and returned.

    """

    def __str__(self) -> str:
        return f"{type(self).__name__}()"

    def __call__(self, loss, optimizer, parameters=None, **kwargs):
        """Perform backward pass with the given loss, update the model's parameters with the given optimizer, and compute the resulting gradient norm.

        Args:
            loss (torch.Tensor): The loss tensor that the gradients will be computed from.
            optimizer (torch.optim.Optimizer): The optimizer that will be used to update the model's parameters.
            parameters (iterable[torch.Tensor], optional): An iterable of model parameters to compute gradients. If None, returns -1.
            **kwargs: Additional keyword arguments; nothing will be done with these.

        Returns:
            float: The gradient norm computed after the optimizer step, if parameters is not None.

        """
        loss.backward()
        if parameters is not None:
            grads = [p.grad for p in parameters if p.grad is not None]
            device = grads[0].device
            grad_norm = torch.norm(torch.stack([torch.norm(g.detach(), 2).to(device) for g in grads]), 2)
        else:
            grad_norm = -1
        optimizer.step()
        return grad_norm


def get_cpu_name():
    """Get the name of the CPU."""
    with open("/proc/cpuinfo", "r") as f:
        for line in f:
            if line.startswith("model name"):
                return line.split(":")[1].strip()
    return "unknown"


# From PyTorch internals
def _ntuple(n):
    """Make a function to create n-tuples.

    Args:
        n (int): tuple length

    Returns:
        function: function to create n-tuples

    """

    def parse(x):
        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
            return tuple(x)
        return tuple(repeat(x, n))

    return parse


to_2tuple = _ntuple(2)


def make_divisible(v, divisor=8, min_value=None, round_limit=0.9):
    """Calculate the smallest number >= v that is divisible by divisor.

    This function is primarily used to ensure that the output of a layer
    is divisible by a certain number, typically to align with hardware
    optimizations or memory layouts.

    Args:
        v (int): The input value.
        divisor (int, optional): The divisor. Defaults to 8.
        min_value (int, optional): The minimum value to return. If None, defaults to the divisor.
        round_limit (float, optional): A threshold for rounding down. If the result of rounding down is less than round_limit * v, the next multiple of the divisor is returned instead. Defaults to 0.9.

    Returns:
        int: The smallest number >= v that is divisible by divisor.

    """
    min_value = min_value or divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < round_limit * v:
        new_v += divisor
    return new_v


def grad_cam_reshape_transform(tensor):
    """Transform the tensor for Grad-CAM calculation.

    Args:
        tensor (torch.Tensor): input tensor

    Returns:
        torch.Tensor: reshaped tensor without [CLS] token.

    """
    n_squ = tensor.shape[1]
    result = tensor[:, 1:] if int(sqrt(n_squ)) ** 2 != n_squ else tensor
    bs, n, dim = result.shape
    result = result.reshape(bs, int(sqrt(n)), int(sqrt(n)), dim)

    # Bring the channels to the first dimension,
    # like in CNNs.
    return result.transpose(2, 3).transpose(1, 2)
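Finally, a few quick sanity checks for the helpers above (editor's sketch; the expected values follow directly from the definitions, and the hyperparameters are made up):

```python
# Editor's sketch exercising some of the helpers defined in utils.py.
from utils import make_divisible, prime_factors, scheduler_function_factory

assert prime_factors(360) == [2, 2, 2, 3, 3, 5]
assert make_divisible(30, divisor=8) == 32  # rounds to a multiple of 8, bumping up if below 0.9 * v

f = scheduler_function_factory(epochs=10, sched="cosine", warmup_epochs=2, lr=1e-3,
                               warmup_lr=1e-5, warmup_sched="linear")
assert abs(f(0) - 0.01) < 1e-9  # warmup starts at warmup_lr / lr
assert abs(f(2) - 1.0) < 1e-9   # full base lr right after warmup, then cosine decay
```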