# Source code for towhee.models.convnext.convnext

# Pytorch implementation of models in [A ConvNet for 2020s](https://arxiv.org/abs/2201.03545).
# Inspired by https://github.com/facebookresearch/ConvNeXt
#
# Modifications & additions by Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from torch import nn

from towhee.models.convnext.utils import LayerNorm, Block
from towhee.models.convnext.configs import get_configs
from towhee.models.utils.weight_init import trunc_normal_


class ConvNeXt(nn.Module):
    """
    ConvNeXt model.

    The network is a stem (4x4 patchify conv + LayerNorm), then one stage per entry of ``depths``.
    Consecutive stages are separated by a LayerNorm + 2x2 stride-2 downsampling conv. It ends with
    global average pooling, a final LayerNorm and a linear classification head.

    Args:
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 1000
        depths (tuple(int)): Number of blocks at each stage. Default: (3, 3, 9, 3)
        dims (tuple(int)): Feature dimension at each stage; must have the same length as ``depths``.
            Default: (96, 192, 384, 768)
        drop_path_rate (float): Stochastic depth rate. Default: 0.
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
        head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.

    Raises:
        ValueError: if ``depths`` and ``dims`` have different lengths.
    """

    def __init__(
            self,
            in_chans=3,
            num_classes=1000,
            depths=(3, 3, 9, 3),
            dims=(96, 192, 384, 768),
            drop_path_rate=0.,
            layer_scale_init_value=1e-6,
            head_init_scale=1.,
    ):
        super().__init__()
        # Each stage needs exactly one feature dimension; a mismatch would otherwise only
        # surface later as an indexing or shape error.
        if len(depths) != len(dims):
            raise ValueError(
                f'depths and dims must have the same length, got {len(depths)} and {len(dims)}.'
            )
        num_stages = len(depths)
        self.depths = depths
        self.dims = dims
        self.num_classes = num_classes

        # Stem plus one downsampling layer in front of every later stage.
        self.downsample_layers = nn.ModuleList()
        stem = nn.Sequential(
            nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
            LayerNorm(dims[0], eps=1e-6, data_format='channels_first'),
        )
        self.downsample_layers.append(stem)
        for i in range(num_stages - 1):
            downsample_layer = nn.Sequential(
                LayerNorm(dims[i], eps=1e-6, data_format='channels_first'),
                nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2),
            )
            self.downsample_layers.append(downsample_layer)

        # Feature resolution stages, each a sequence of residual blocks. The drop-path rate
        # increases linearly from 0 to drop_path_rate over all blocks of all stages.
        self.stages = nn.ModuleList()
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        cur = 0
        for dim, depth in zip(dims, depths):
            stage = nn.Sequential(
                *[Block(dim=dim, drop_path=dp_rates[cur + j], layer_scale_init_value=layer_scale_init_value)
                  for j in range(depth)]
            )
            self.stages.append(stage)
            cur += depth

        self.norm = nn.LayerNorm(dims[-1], eps=1e-6)  # final norm layer
        self.head = nn.Linear(dims[-1], num_classes)

        self.apply(self._init_weights)
        # Rescale the classifier after the generic init; done without autograd tracking.
        with torch.no_grad():
            self.head.weight.mul_(head_init_scale)
            self.head.bias.mul_(head_init_scale)

    def _init_weights(self, m):
        """Initialize conv/linear weights with truncated normal (std=0.02) and zero their biases."""
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            # Layers built with bias=False have m.bias set to None.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward_features(self, x):
        """
        Run downsampling + stage pairs, then global-average-pool and normalize.

        Args:
            x (torch.Tensor): input images, (N, C, H, W).

        Returns:
            torch.Tensor: pooled, normalized features of shape (N, dims[-1]).
        """
        for downsample, stage in zip(self.downsample_layers, self.stages):
            x = downsample(x)
            x = stage(x)
        return self.norm(x.mean([-2, -1]))  # global average pooling, (N, C, H, W) -> (N, C)

    def forward(self, x):
        """Return classification logits of shape (N, num_classes) for input images ``x``."""
        x = self.forward_features(x)
        return self.head(x)
def create_model(
        model_name: str = None,
        pretrained: bool = False,
        weights_path: str = None,
        device: str = None,
        **kwargs,
):
    """
    Build a ConvNeXt model from a named configuration, optionally loading pretrained weights.

    Args:
        model_name (str): Name passed to ``get_configs`` to look up the model configuration.
        pretrained (bool): If True, load weights from ``weights_path``. When ``weights_path``
            is None, load them from the config's ``'url'`` entry instead.
        weights_path (str): Local checkpoint path; takes precedence over the config url.
        device (str): Target device; defaults to 'cuda' when available, otherwise 'cpu'.
        **kwargs: Overrides merged into the configuration before it is passed to ``ConvNeXt``.

    Returns:
        ConvNeXt: the model moved to ``device`` and switched to eval mode.

    Raises:
        AssertionError: if ``pretrained`` is True but neither ``weights_path`` nor a config url
            is available.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Work on a copy. The overrides and the removal of 'url' below must not mutate a dict
    # that get_configs may hand out to other callers (its implementation is not visible here).
    configs = get_configs(model_name).copy()
    configs.update(**kwargs)
    # 'url' is loader metadata, not a ConvNeXt constructor argument.
    url = configs.pop('url', None)

    model = ConvNeXt(**configs).to(device)
    if pretrained:
        if weights_path is None:
            if not url:
                # Explicit raise instead of `assert` so the check survives `python -O`;
                # the AssertionError type is kept for existing callers.
                raise AssertionError('No default url or weights path is provided for the pretrained model.')
            state_dict = torch.hub.load_state_dict_from_url(url=url, map_location=device)
        else:
            # NOTE(security): torch.load unpickles the file; only load checkpoints from trusted sources.
            state_dict = torch.load(weights_path, map_location=device)
        # Some checkpoints nest the weights under a 'model' key.
        if 'model' in state_dict:
            state_dict = state_dict['model']
        model.load_state_dict(state_dict)
    model.eval()
    return model