AAAI Version
This commit is contained in:
@@ -0,0 +1,111 @@
|
||||
from timm.models import register_model
|
||||
from timm.models.resnet import BasicBlock, Bottleneck
|
||||
from timm.models.resnet import ResNet as ResNetTimm
|
||||
from torch import nn
|
||||
|
||||
from resizing_interface import ResizingInterface
|
||||
|
||||
|
||||
class ResNet(ResNetTimm, ResizingInterface):
|
||||
"""The popular ResNet model with a ResizingInterface."""
|
||||
|
||||
def __init__(self, *args, global_pool="avg", **kwargs):
|
||||
"""Create ResNet model with resizing capabilities.
|
||||
|
||||
Args:
|
||||
*args: Arguments for the ResNet model.
|
||||
global_pool (str, optional): _description_. Defaults to "avg".
|
||||
**kwargs: Keyword arguments for the ResNet model (from Timm).
|
||||
|
||||
Keyword Args:
|
||||
block, layers, num_classes, in_chans, output_stride, cardinality, base_width, stem_width, stem_type, replace_stem_pool, block_reduce_first, down_kernel_size, avg_down, act_layer, norm_layer, aa_layer, drop_rate, drop_path_rate, drop_block_rate, zero_init_last, block_args
|
||||
|
||||
"""
|
||||
admissible_kwargs = [
|
||||
"block",
|
||||
"layers",
|
||||
"num_classes",
|
||||
"in_chans",
|
||||
"output_stride",
|
||||
"cardinality",
|
||||
"base_width",
|
||||
"stem_width",
|
||||
"stem_type",
|
||||
"replace_stem_pool",
|
||||
"block_reduce_first",
|
||||
"down_kernel_size",
|
||||
"avg_down",
|
||||
"act_layer",
|
||||
"norm_layer",
|
||||
"aa_layer",
|
||||
"drop_rate",
|
||||
"drop_path_rate",
|
||||
"drop_block_rate",
|
||||
"zero_init_last",
|
||||
"block_args",
|
||||
]
|
||||
for key in list(kwargs.keys()):
|
||||
if key not in admissible_kwargs:
|
||||
kwargs.pop(key)
|
||||
super().__init__(*args, global_pool=global_pool, **kwargs)
|
||||
self.global_pool_str = global_pool
|
||||
|
||||
def set_image_res(self, res):
|
||||
# resizing not needed for CNNs with pooling
|
||||
return
|
||||
|
||||
def set_num_classes(self, n_classes):
|
||||
if self.num_classes == n_classes:
|
||||
return
|
||||
if n_classes > 0:
|
||||
self.reset_classifier(num_classes=n_classes, global_pool=self.global_pool_str)
|
||||
else:
|
||||
self.fc = nn.Identity()
|
||||
|
||||
|
||||
@register_model
|
||||
def resnet18(pretrained=False, **kwargs):
|
||||
"""Construct a ResNet-18 model."""
|
||||
model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2])
|
||||
return ResNet(**model_args, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def resnet34(pretrained=False, **kwargs):
|
||||
"""Construct a ResNet-34 model."""
|
||||
model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3])
|
||||
return ResNet(**model_args, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def resnet26(pretrained=False, **kwargs):
|
||||
"""Construct a ResNet-26 model."""
|
||||
model_args = dict(block=Bottleneck, layers=[2, 2, 2, 2])
|
||||
return ResNet(**model_args, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def resnet50(pretrained=False, **kwargs):
|
||||
"""Construct a ResNet-50 model."""
|
||||
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3])
|
||||
return ResNet(**model_args, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def resnet101(pretrained=False, **kwargs):
|
||||
"""Construct a ResNet-101 model."""
|
||||
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3])
|
||||
return ResNet(**model_args, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def wide_resnet50_2(pretrained=False, **kwargs):
|
||||
"""Construct a Wide ResNet-50-2 model.
|
||||
|
||||
The model is the same as ResNet except for the bottleneck number of channels
|
||||
which is twice larger in every block. The number of channels in outer 1x1
|
||||
convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
|
||||
channels, and in Wide ResNet-50-2 has 2048-1024-2048.
|
||||
"""
|
||||
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], base_width=128)
|
||||
return ResNet(**model_args, **kwargs)
|
||||
Reference in New Issue
Block a user