# PyTorch implementation of the models in [A ConvNet for the 2020s](https://arxiv.org/abs/2201.03545).
# Inspired by https://github.com/facebookresearch/ConvNeXt
#
# Modifications & additions by Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import nn
from towhee.models.convnext.utils import LayerNorm, Block
from towhee.models.convnext.configs import get_configs
from towhee.models.utils.weight_init import trunc_normal_
class ConvNeXt(nn.Module):
    """
    ConvNeXt model.

    The network is a sequence of (downsampling layer, stage of residual blocks) pairs,
    followed by global average pooling, a final LayerNorm and a linear classification head.
    The number of stages is taken from ``len(dims)``.

    Args:
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 1000
        depths (tuple(int)): Number of blocks at each stage; must provide at least one
            entry per stage. Default: (3, 3, 9, 3)
        dims (tuple(int)): Feature dimension at each stage. Default: (96, 192, 384, 768)
        drop_path_rate (float): Stochastic depth rate. Default: 0.
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
        head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
    """

    def __init__(self,
                 in_chans=3,
                 num_classes=1000,
                 depths=(3, 3, 9, 3),
                 dims=(96, 192, 384, 768),
                 drop_path_rate=0.,
                 layer_scale_init_value=1e-6,
                 head_init_scale=1.,
                 ):
        super().__init__()
        self.depths = depths
        self.dims = dims
        self.num_classes = num_classes
        # One stage per entry of `dims` instead of a hard-coded 4.
        num_stages = len(dims)

        # Stem plus one intermediate downsampling conv layer between consecutive stages.
        self.downsample_layers = nn.ModuleList()
        stem = nn.Sequential(
            nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
            LayerNorm(dims[0], eps=1e-6, data_format='channels_first')
        )
        self.downsample_layers.append(stem)
        for i in range(num_stages - 1):
            downsample_layer = nn.Sequential(
                LayerNorm(dims[i], eps=1e-6, data_format='channels_first'),
                nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2),
            )
            self.downsample_layers.append(downsample_layer)

        # Feature resolution stages, each consisting of multiple residual blocks.
        self.stages = nn.ModuleList()
        # Drop-path rate grows linearly from 0 to `drop_path_rate` across all blocks.
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        cur = 0  # index into dp_rates of the first block of the current stage
        for i in range(num_stages):
            stage = nn.Sequential(
                *[Block(dim=dims[i], drop_path=dp_rates[cur + j],
                        layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])]
            )
            self.stages.append(stage)
            cur += depths[i]

        self.norm = nn.LayerNorm(dims[-1], eps=1e-6)  # final norm layer
        self.head = nn.Linear(dims[-1], num_classes)

        self.apply(self._init_weights)
        # Rescale the classifier after the generic initialisation above.
        self.head.weight.data.mul_(head_init_scale)
        self.head.bias.data.mul_(head_init_scale)

    def _init_weights(self, m):
        """Initialise Conv2d/Linear weights with a truncated normal (std=0.02) and zero their biases."""
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            # Layers created with bias=False have no bias tensor to initialise.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward_features(self, x):
        """Run every downsampling layer and stage in order, then pool and normalise the features."""
        for downsample_layer, stage in zip(self.downsample_layers, self.stages):
            x = downsample_layer(x)
            x = stage(x)
        return self.norm(x.mean([-2, -1]))  # global average pooling, (N, C, H, W) -> (N, C)

    def forward(self, x):
        """Extract features from `x` and apply the classification head."""
        x = self.forward_features(x)
        x = self.head(x)
        return x
def create_model(
        model_name: str = None,
        pretrained: bool = False,
        weights_path: str = None,
        device: str = None,
        **kwargs,
) -> nn.Module:
    """
    Build a ConvNeXt model from a named configuration, optionally loading pretrained weights.

    Args:
        model_name (str): Name passed to `get_configs` to look up the model configuration.
        pretrained (bool): Whether to load weights into the model. Default: False
        weights_path (str): Path of a local checkpoint. If None and `pretrained` is True,
            the weights are downloaded from the configuration's `url` entry.
        device (str): Device to place the model on. Defaults to 'cuda' when available, else 'cpu'.
        **kwargs: Overrides merged on top of the looked-up configuration (including `url`);
            everything except `url` is forwarded to `ConvNeXt`.

    Returns:
        The ConvNeXt model, moved to `device` and switched to eval mode.

    Raises:
        AssertionError: If `pretrained` is True but neither `weights_path` nor a `url` is available.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Work on a shallow copy so that the overrides and the removal of 'url' below never
    # leak into whatever object `get_configs` returns (it may be shared between calls).
    configs = dict(get_configs(model_name))
    configs.update(**kwargs)
    # 'url' is not a ConvNeXt constructor argument, so strip it before building the model.
    url = configs.pop('url', None)

    model = ConvNeXt(**configs).to(device)
    if pretrained:
        if weights_path is None:
            if not url:
                # Raised explicitly (rather than via `assert`) so the check is not stripped
                # under `python -O`; the exception type is kept for backward compatibility.
                raise AssertionError('No default url or weights path is provided for the pretrained model.')
            state_dict = torch.hub.load_state_dict_from_url(url=url, map_location=device)
        else:
            # NOTE(security): torch.load unpickles the file; only load checkpoints from trusted sources.
            state_dict = torch.load(weights_path, map_location=device)
        # Some checkpoints wrap the weights under a 'model' key.
        if 'model' in state_dict:
            state_dict = state_dict['model']
        model.load_state_dict(state_dict)

    model.eval()
    return model