Source code for towhee.models.cvnet.resnet

# Original pytorch implementation by:
# 'Correlation Verification for Image Retrieval'
#       - https://arxiv.org/abs/2204.01458
# Original code by / Copyright 2022, Seongwon Lee.
# Modifications & additions by / Copyright 2022 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import nn

# Stage depths for ImageNet models
_IN_STAGE_DS = {50: (3, 4, 6, 3), 101: (3, 4, 23, 3), 152: (3, 8, 36, 3)}
TRANS_FUN = "bottleneck_transform"
NUM_GROUPS = 1
WIDTH_PER_GROUP = 64
STRIDE_1X1 = False
BN_EPS = 1e-5
BN_MOM = 0.1
RELU_INPLACE = True


def get_trans_fun(name):
    """Retrieves the transformation function by name."""
    trans_funs = {
        "basic_transform": BasicTransform,
        "bottleneck_transform": BottleneckTransform,
    }
    err_str = "Transformation function '{}' not supported"
    assert name in trans_funs.keys(), err_str.format(name)
    return trans_funs[name]

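# Example (illustrative): resolve the transform class used by ResStage below.
#     get_trans_fun("bottleneck_transform")  # -> BottleneckTransform
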
class GeneralizedMeanPooling(nn.Module):
    r"""Applies a 2D power-average adaptive pooling over an input signal
    composed of several input planes.

    The function computed is: :math:`f(X) = pow(mean(pow(X, p)), 1/p)`

    - At p = infinity, one gets Max Pooling
    - At p = 1, one gets Average Pooling

    The output is of size H x W, for any input size.
    The number of output features is equal to the number of input planes.

    Args:
        output_size: the target output size of the image of the form H x W.
            Can be a tuple (H, W) or a single H for a square image H x H.
            H and W can be either an ``int``, or ``None``, which means the
            size will be the same as that of the input.
    """

    def __init__(self, norm, output_size=1, eps=1e-6):
        super().__init__()
        assert norm > 0
        self.p = float(norm)
        self.output_size = output_size
        self.eps = eps

    def forward(self, x):
        # Clamp avoids pow() on non-positive activations, then apply GeM:
        # (mean(x^p))^(1/p) over each spatial window.
        x = x.clamp(min=self.eps).pow(self.p)
        return torch.nn.functional.adaptive_avg_pool2d(x, self.output_size).pow(1. / self.p)

    def __repr__(self):
        return self.__class__.__name__ + "(" \
            + str(self.p) + ", " \
            + "output_size=" + str(self.output_size) + ")"

class GeneralizedMeanPoolingP(GeneralizedMeanPooling):
    """Same as GeneralizedMeanPooling, but the norm ``p`` is trainable."""

    def __init__(self, norm=3, output_size=1, eps=1e-6):
        super().__init__(norm, output_size, eps)
        # Override the fixed float exponent with a learnable parameter.
        self.p = nn.Parameter(torch.ones(1) * norm)

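# Example (illustrative): a GeM layer collapses each channel's spatial map to
# a single value; with a learnable p (default 3) it interpolates between
# average pooling (p = 1) and max pooling (p -> infinity).
#     pool = GeneralizedMeanPoolingP(norm=3)
#     feats = torch.randn(2, 2048, 7, 7)
#     desc = pool(feats)  # shape: (2, 2048, 1, 1)
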
class BasicTransform(nn.Module):
    """Basic transformation: 3x3, BN, ReLU, 3x3, BN."""

    def __init__(self, w_in, w_out, stride, w_b=None, num_gs=1):
        err_str = "Basic transform does not support w_b and num_gs options"
        assert w_b is None and num_gs == 1, err_str
        super().__init__()
        self.a = nn.Conv2d(w_in, w_out, 3, stride=stride, padding=1, bias=False)
        self.a_bn = nn.BatchNorm2d(w_out, eps=BN_EPS, momentum=BN_MOM)
        self.a_relu = nn.ReLU(inplace=RELU_INPLACE)
        self.b = nn.Conv2d(w_out, w_out, 3, stride=1, padding=1, bias=False)
        self.b_bn = nn.BatchNorm2d(w_out, eps=BN_EPS, momentum=BN_MOM)
        # Marker consumed by weight-init schemes that zero-initialize the
        # gamma of each block's final BN.
        self.b_bn.final_bn = True

    def forward(self, x):
        for layer in self.children():
            x = layer(x)
        return x

class BottleneckTransform(nn.Module):
    """Bottleneck transformation: 1x1, BN, ReLU, 3x3, BN, ReLU, 1x1, BN."""

    def __init__(self, w_in, w_out, stride, w_b, num_gs):
        super().__init__()
        # MSRA -> stride=2 is on 1x1; TH/C2 -> stride=2 is on 3x3
        (s1, s3) = (stride, 1) if STRIDE_1X1 else (1, stride)
        self.a = nn.Conv2d(w_in, w_b, 1, stride=s1, padding=0, bias=False)
        self.a_bn = nn.BatchNorm2d(w_b, eps=BN_EPS, momentum=BN_MOM)
        self.a_relu = nn.ReLU(inplace=RELU_INPLACE)
        self.b = nn.Conv2d(w_b, w_b, 3, stride=s3, padding=1, groups=num_gs, bias=False)
        self.b_bn = nn.BatchNorm2d(w_b, eps=BN_EPS, momentum=BN_MOM)
        self.b_relu = nn.ReLU(inplace=RELU_INPLACE)
        self.c = nn.Conv2d(w_b, w_out, 1, stride=1, padding=0, bias=False)
        self.c_bn = nn.BatchNorm2d(w_out, eps=BN_EPS, momentum=BN_MOM)
        self.c_bn.final_bn = True

    def forward(self, x):
        for layer in self.children():
            x = layer(x)
        return x

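# Example widths (illustrative), first bottleneck block of stage s2 in the
# ResNet-50 built below:
#     BottleneckTransform(w_in=256, w_out=512, stride=2, w_b=128, num_gs=1)
#     -> 1x1: 256->128, 3x3 (stride 2): 128->128, 1x1: 128->512
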
class ResBlock(nn.Module):
    """Residual block: x + F(x)."""

    def __init__(self, w_in, w_out, stride, trans_fun, w_b=None, num_gs=1):
        super().__init__()
        # Use skip connection with projection if shape changes
        self.proj_block = (w_in != w_out) or (stride != 1)
        if self.proj_block:
            self.proj = nn.Conv2d(w_in, w_out, 1, stride=stride, padding=0, bias=False)
            self.bn = nn.BatchNorm2d(w_out, eps=BN_EPS, momentum=BN_MOM)
        self.f = trans_fun(w_in, w_out, stride, w_b, num_gs)
        self.relu = nn.ReLU(RELU_INPLACE)

    def forward(self, x):
        if self.proj_block:
            x = self.bn(self.proj(x)) + self.f(x)
        else:
            x = x + self.f(x)
        x = self.relu(x)
        return x

class ResStage(nn.Module):
    """Stage of ResNet."""

    def __init__(self, w_in, w_out, stride, d, w_b=None, num_gs=1):
        super().__init__()
        for i in range(d):
            # Downsample only in the first block; later blocks keep the shape.
            b_stride = stride if i == 0 else 1
            b_w_in = w_in if i == 0 else w_out
            trans_fun = get_trans_fun(TRANS_FUN)
            res_block = ResBlock(b_w_in, w_out, b_stride, trans_fun, w_b, num_gs)
            self.add_module("b{}".format(i + 1), res_block)

    def forward(self, x):
        for block in self.children():
            x = block(x)
        return x

class ResStemIN(nn.Module):
    """ResNet stem for ImageNet: 7x7, BN, ReLU, MaxPool."""

    def __init__(self, w_in, w_out):
        super().__init__()
        self.conv = nn.Conv2d(w_in, w_out, 7, stride=2, padding=3, bias=False)
        self.bn = nn.BatchNorm2d(w_out, eps=BN_EPS, momentum=BN_MOM)
        self.relu = nn.ReLU(RELU_INPLACE)
        self.pool = nn.MaxPool2d(3, stride=2, padding=1)

    def forward(self, x):
        for layer in self.children():
            x = layer(x)
        return x

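# ``GlobalHead`` is referenced by ``ResNet._construct`` below but is missing
# from this listing; a minimal reconstruction inferred from its usage
# (``head.pool`` + ``head.fc``) and the GeM pooling defined above, with the
# GeM norm ``pp`` assumed to default to 3:
class GlobalHead(nn.Module):
    """Global head: GeM pooling followed by a linear reduction layer."""

    def __init__(self, w_in, nc, pp=3.0):
        super().__init__()
        self.pool = GeneralizedMeanPoolingP(norm=pp)
        self.fc = nn.Linear(w_in, nc, bias=True)

    def forward(self, x):
        x = self.pool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x
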
class ResNet(nn.Module):
    """ResNet model."""

    def __init__(self, reset_depth, reduction_dim):
        super().__init__()
        self.reset_depth = reset_depth
        self.reduction_dim = reduction_dim
        self._construct()

    def _construct(self):
        g, gw = NUM_GROUPS, WIDTH_PER_GROUP
        (d1, d2, d3, d4) = _IN_STAGE_DS[self.reset_depth]
        w_b = gw * g
        self.stem = ResStemIN(3, 64)
        self.s1 = ResStage(64, 256, stride=1, d=d1, w_b=w_b, num_gs=g)
        self.s2 = ResStage(256, 512, stride=2, d=d2, w_b=w_b * 2, num_gs=g)
        self.s3 = ResStage(512, 1024, stride=2, d=d3, w_b=w_b * 4, num_gs=g)
        self.s4 = ResStage(1024, 2048, stride=2, d=d4, w_b=w_b * 8, num_gs=g)
        self.head = GlobalHead(2048, nc=self.reduction_dim)

    def forward(self, x):
        x = self.stem(x)
        x1 = self.s1(x)
        x2 = self.s2(x1)
        x3 = self.s3(x2)
        x4 = self.s4(x3)
        # Pool the final stage into a global descriptor; also return the s3
        # feature map alongside it.
        x4_p = self.head.pool(x4)
        x4_p = x4_p.view(x4_p.size(0), -1)
        x = self.head.fc(x4_p)
        return x, x3
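
# Example (illustrative): build a ResNet-50 backbone with a 2048-d descriptor
# and run a dummy batch.
#     model = ResNet(reset_depth=50, reduction_dim=2048)
#     desc, feat_s3 = model(torch.randn(1, 3, 224, 224))
#     # desc: (1, 2048) global descriptor; feat_s3: (1, 1024, 14, 14)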