AAAI Version

This commit is contained in:
Tobias Christian Nauen
2026-02-24 12:22:44 +01:00
parent 5c08f9d31a
commit ff34712155
378 changed files with 19844 additions and 4780 deletions

View File

@@ -0,0 +1,111 @@
from timm.models import register_model
from timm.models.resnet import BasicBlock, Bottleneck
from timm.models.resnet import ResNet as ResNetTimm
from torch import nn
from resizing_interface import ResizingInterface
class ResNet(ResNetTimm, ResizingInterface):
    """The popular ResNet model with a ResizingInterface.

    Thin wrapper around timm's ResNet that (a) drops any keyword arguments
    timm's constructor does not accept and (b) implements the resizing hooks
    expected by ``ResizingInterface``.
    """

    # Keyword arguments understood by timm's ResNet constructor; anything
    # else passed via **kwargs is silently discarded.  frozenset gives O(1)
    # membership tests instead of a linear scan over a list.
    _ADMISSIBLE_KWARGS = frozenset({
        "block",
        "layers",
        "num_classes",
        "in_chans",
        "output_stride",
        "cardinality",
        "base_width",
        "stem_width",
        "stem_type",
        "replace_stem_pool",
        "block_reduce_first",
        "down_kernel_size",
        "avg_down",
        "act_layer",
        "norm_layer",
        "aa_layer",
        "drop_rate",
        "drop_path_rate",
        "drop_block_rate",
        "zero_init_last",
        "block_args",
    })

    def __init__(self, *args, global_pool="avg", **kwargs):
        """Create a ResNet model with resizing capabilities.

        Args:
            *args: Positional arguments forwarded to timm's ResNet.
            global_pool (str, optional): Global pooling type. Defaults to "avg".
            **kwargs: Keyword arguments for timm's ResNet; unrecognized keys
                are silently dropped (see ``_ADMISSIBLE_KWARGS``).
        """
        # Filter instead of mutating: keep only the kwargs timm accepts.
        kwargs = {k: v for k, v in kwargs.items() if k in self._ADMISSIBLE_KWARGS}
        super().__init__(*args, global_pool=global_pool, **kwargs)
        # Remember the pooling mode so the classifier can be rebuilt later.
        self.global_pool_str = global_pool

    def set_image_res(self, res):
        """No-op: CNNs with global pooling need no changes for a new input resolution."""
        return

    def set_num_classes(self, n_classes):
        """Replace the classification head to output ``n_classes`` classes.

        For ``n_classes <= 0`` the head is removed (replaced by an identity).
        """
        if self.num_classes == n_classes:
            return  # head already matches the requested size
        if n_classes > 0:
            self.reset_classifier(num_classes=n_classes, global_pool=self.global_pool_str)
        else:
            self.fc = nn.Identity()
            # Bug fix: keep num_classes in sync with the (removed) head so the
            # early-return guard above cannot compare against a stale value and
            # wrongly skip rebuilding the classifier on a later call.
            self.num_classes = n_classes
@register_model
def resnet18(pretrained=False, **kwargs):
    """Build a ResNet-18 (BasicBlock, 2-2-2-2 stages).

    NOTE(review): ``pretrained`` is accepted for timm-registry compatibility
    but is not used here — no pretrained weights are loaded.
    """
    return ResNet(block=BasicBlock, layers=[2, 2, 2, 2], **kwargs)
@register_model
def resnet34(pretrained=False, **kwargs):
    """Build a ResNet-34 (BasicBlock, 3-4-6-3 stages).

    NOTE(review): ``pretrained`` is accepted for timm-registry compatibility
    but is not used here — no pretrained weights are loaded.
    """
    return ResNet(block=BasicBlock, layers=[3, 4, 6, 3], **kwargs)
@register_model
def resnet26(pretrained=False, **kwargs):
    """Build a ResNet-26 (Bottleneck, 2-2-2-2 stages).

    NOTE(review): ``pretrained`` is accepted for timm-registry compatibility
    but is not used here — no pretrained weights are loaded.
    """
    return ResNet(block=Bottleneck, layers=[2, 2, 2, 2], **kwargs)
@register_model
def resnet50(pretrained=False, **kwargs):
    """Build a ResNet-50 (Bottleneck, 3-4-6-3 stages).

    NOTE(review): ``pretrained`` is accepted for timm-registry compatibility
    but is not used here — no pretrained weights are loaded.
    """
    return ResNet(block=Bottleneck, layers=[3, 4, 6, 3], **kwargs)
@register_model
def resnet101(pretrained=False, **kwargs):
    """Build a ResNet-101 (Bottleneck, 3-4-23-3 stages).

    NOTE(review): ``pretrained`` is accepted for timm-registry compatibility
    but is not used here — no pretrained weights are loaded.
    """
    return ResNet(block=Bottleneck, layers=[3, 4, 23, 3], **kwargs)
@register_model
def wide_resnet50_2(pretrained=False, **kwargs):
    """Build a Wide ResNet-50-2.

    Identical to ResNet-50 except every bottleneck uses twice the number of
    internal channels (``base_width=128``). The outer 1x1 convolutions keep
    their channel counts: the last ResNet-50 block is 2048-512-2048, whereas
    in Wide ResNet-50-2 it is 2048-1024-2048.

    NOTE(review): ``pretrained`` is accepted for timm-registry compatibility
    but is not used here — no pretrained weights are loaded.
    """
    return ResNet(block=Bottleneck, layers=[3, 4, 6, 3], base_width=128, **kwargs)