# Original pytorch implementation by:
# 'Correlation Verification for Image Retrieval'
# - https://arxiv.org/abs/2204.01458
# Original code by / Copyright 2022, Seongwon Lee.
# Modifications & additions by / Copyright 2022 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import nn
# Stage depths for ImageNet models: maps a supported ResNet depth to the
# number of residual blocks in each of its four stages.
_IN_STAGE_DS = {50: (3, 4, 6, 3), 101: (3, 4, 23, 3), 152: (3, 8, 36, 3)}
# Name of the block transformation that ResStage resolves via get_trans_fun.
TRANS_FUN = "bottleneck_transform"
# Group count and per-group width; ResNet._construct multiplies them to get
# the base bottleneck width (w_b) handed to the first stage.
NUM_GROUPS = 1
WIDTH_PER_GROUP = 64
# NOTE(review): not referenced in the visible code - presumably consumed by
# the transformation classes to choose where the stride is applied; confirm.
STRIDE_1X1 = False
# eps and momentum passed to the nn.BatchNorm2d layers built below.
BN_EPS = 1e-5
BN_MOM = 0.1
# Passed as the inplace argument of the nn.ReLU layers built below.
RELU_INPLACE = True
def get_trans_fun(name):
    """Retrieve the transformation callable registered under ``name``.

    Args:
        name: Either ``"basic_transform"`` or ``"bottleneck_transform"``.

    Returns:
        The object mapped to ``name`` (``BasicTransform`` or
        ``BottleneckTransform``); it is returned as-is, not called.

    Raises:
        AssertionError: If ``name`` is not one of the supported keys.
    """
    # NOTE(review): BasicTransform / BottleneckTransform are not defined in the
    # visible part of this module; they must be defined or imported elsewhere.
    trans_funs = {
        "basic_transform": BasicTransform,
        "bottleneck_transform": BottleneckTransform,
    }
    err_str = "Transformation function '{}' not supported"
    # Kept as an assert so callers still see the same exception type.
    assert name in trans_funs, err_str.format(name)
    return trans_funs[name]
class GlobalHead(nn.Module):
    """Global descriptor head: generalized-mean pooling followed by a linear layer.

    Args:
        w_in: Number of input features of the linear layer (the channel count
            of the pooled feature map).
        nc: Number of output features of the linear layer.
        pp: Initial exponent handed to ``GeneralizedMeanPoolingP``.
    """

    def __init__(self, w_in, nc, pp=3):
        super().__init__()
        self.pool = GeneralizedMeanPoolingP(norm=pp)
        self.fc = nn.Linear(w_in, nc, bias=True)

    def forward(self, x):
        """Pool ``x``, flatten everything after the batch dimension, and project it."""
        x = self.pool(x)
        # Collapse all non-batch dimensions so the result can feed nn.Linear.
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x
class GeneralizedMeanPooling(nn.Module):
    r"""Applies a 2D power-average adaptive pooling over an input signal composed of several input planes.

    The function computed is: :math:`f(X) = pow(mean(pow(X, p)), 1/p)`,
    where the mean is an adaptive average pool and ``X`` is first clamped
    from below at ``eps``.

        - At p = infinity, one gets Max Pooling
        - At p = 1, one gets Average Pooling

    The output is of size H x W, for any input size.
    The number of output features is equal to the number of input planes.

    Args:
        norm: the exponent ``p``; must be strictly positive.
        output_size: the target output size of the image of the form H x W.
            Can be a tuple (H, W) or a single H for a square image H x H
            H and W can be either a ``int``, or ``None`` which means the size will
            be the same as that of the input.
        eps: lower bound applied to the input before it is raised to ``p``.
    """

    def __init__(self, norm, output_size=1, eps=1e-6):
        super().__init__()
        assert norm > 0
        self.p = float(norm)
        self.output_size = output_size
        self.eps = eps

    def forward(self, x):
        """Return the generalized mean of ``x`` over each adaptive pooling window."""
        # Clamp first so the power is only ever taken of values >= eps.
        x = x.clamp(min=self.eps).pow(self.p)
        pooled = torch.nn.functional.adaptive_avg_pool2d(x, self.output_size)
        return pooled.pow(1. / self.p)

    def __repr__(self):
        return f"{self.__class__.__name__}({self.p!s}, output_size={self.output_size!s})"
class GeneralizedMeanPoolingP(GeneralizedMeanPooling):
    """Generalized-mean pooling whose exponent ``p`` is a trainable parameter.

    Args:
        norm: Initial value of the exponent ``p``.
        output_size: Target output size, forwarded to ``GeneralizedMeanPooling``.
        eps: Lower clamp bound, forwarded to ``GeneralizedMeanPooling``.
    """

    def __init__(self, norm=3, output_size=1, eps=1e-6):
        super().__init__(norm, output_size, eps)
        # Replace the plain float set by the parent with a one-element
        # nn.Parameter initialised to ``norm``.
        self.p = nn.Parameter(torch.ones(1) * norm)
class ResBlock(nn.Module):
    """Residual block: x + F(x).

    When the input and output widths differ, or the stride is not 1, the skip
    path is a 1x1 convolution followed by batch norm instead of the identity.

    Args:
        w_in: Number of input channels.
        w_out: Number of output channels.
        stride: Stride used by the projection and passed to the transformation.
        trans_fun: Callable invoked as ``trans_fun(w_in, w_out, stride, w_b, num_gs)``
            to build the residual branch F.
        w_b: Forwarded to ``trans_fun`` unchanged.
        num_gs: Forwarded to ``trans_fun`` unchanged.
    """

    def __init__(self, w_in, w_out, stride, trans_fun, w_b=None, num_gs=1):
        super().__init__()
        # Use skip connection with projection if shape changes
        self.proj_block = (w_in != w_out) or (stride != 1)
        if self.proj_block:
            self.proj = nn.Conv2d(w_in, w_out, 1, stride=stride, padding=0, bias=False)
            self.bn = nn.BatchNorm2d(w_out, eps=BN_EPS, momentum=BN_MOM)
        self.f = trans_fun(w_in, w_out, stride, w_b, num_gs)
        self.relu = nn.ReLU(RELU_INPLACE)

    def forward(self, x):
        """Add the (possibly projected) input to F(x) and apply ReLU."""
        if self.proj_block:
            x = self.bn(self.proj(x)) + self.f(x)
        else:
            x = x + self.f(x)
        x = self.relu(x)
        return x
class ResStage(nn.Module):
    """Stage of ResNet: ``d`` residual blocks applied in sequence.

    Only the first block receives ``w_in`` and ``stride``; every later block
    maps ``w_out`` to ``w_out`` with stride 1. Blocks are registered as
    ``b1`` ... ``b{d}``.

    Args:
        w_in: Input channels of the first block.
        w_out: Output channels of every block.
        stride: Stride of the first block.
        d: Number of blocks in the stage.
        w_b: Forwarded to each ``ResBlock``.
        num_gs: Forwarded to each ``ResBlock``.
    """

    def __init__(self, w_in, w_out, stride, d, w_b=None, num_gs=1):
        super().__init__()
        # The transformation depends only on the module-level TRANS_FUN, so
        # resolve it once rather than on every iteration.
        trans_fun = get_trans_fun(TRANS_FUN)
        for i in range(d):
            b_stride = stride if i == 0 else 1
            b_w_in = w_in if i == 0 else w_out
            res_block = ResBlock(b_w_in, w_out, b_stride, trans_fun, w_b, num_gs)
            self.add_module("b{}".format(i + 1), res_block)

    def forward(self, x):
        """Run ``x`` through the blocks in registration order."""
        for block in self.children():
            x = block(x)
        return x
class ResStemIN(nn.Module):
    """ResNet stem for ImageNet: 7x7, BN, ReLU, MaxPool.

    Args:
        w_in: Number of input channels of the 7x7 convolution.
        w_out: Number of output channels of the 7x7 convolution.
    """

    def __init__(self, w_in, w_out):
        super().__init__()
        # Registration order matters: forward() applies children in this order.
        self.conv = nn.Conv2d(w_in, w_out, 7, stride=2, padding=3, bias=False)
        self.bn = nn.BatchNorm2d(w_out, eps=BN_EPS, momentum=BN_MOM)
        self.relu = nn.ReLU(RELU_INPLACE)
        self.pool = nn.MaxPool2d(3, stride=2, padding=1)

    def forward(self, x):
        """Apply conv, batch norm, ReLU and max-pool in sequence."""
        for layer in self.children():
            x = layer(x)
        return x
class ResNet(nn.Module):
    """ResNet model: stem, four residual stages, and a global pooling head.

    Args:
        reset_depth: Network depth used to look up the per-stage block counts
            in ``_IN_STAGE_DS`` (50, 101 or 152).
        reduction_dim: Number of output features of the head's linear layer.
    """

    def __init__(self, reset_depth, reduction_dim):
        super().__init__()
        self.reset_depth = reset_depth
        self.reduction_dim = reduction_dim
        self._construct()

    def _construct(self):
        """Build the stem, the four stages and the head."""
        g, gw = NUM_GROUPS, WIDTH_PER_GROUP
        # A depth missing from _IN_STAGE_DS raises KeyError here.
        (d1, d2, d3, d4) = _IN_STAGE_DS[self.reset_depth]
        w_b = gw * g
        self.stem = ResStemIN(3, 64)
        # Output width and w_b both double from one stage to the next.
        self.s1 = ResStage(64, 256, stride=1, d=d1, w_b=w_b, num_gs=g)
        self.s2 = ResStage(256, 512, stride=2, d=d2, w_b=w_b * 2, num_gs=g)
        self.s3 = ResStage(512, 1024, stride=2, d=d3, w_b=w_b * 4, num_gs=g)
        self.s4 = ResStage(1024, 2048, stride=2, d=d4, w_b=w_b * 8, num_gs=g)
        self.head = GlobalHead(2048, nc=self.reduction_dim)

    def forward(self, x):
        """Run the backbone.

        Returns:
            A tuple ``(x, x3)``: ``x`` is the output of the head's linear
            layer applied to the pooled stage-4 features, and ``x3`` is the
            output of the third stage.
        """
        x = self.stem(x)
        x1 = self.s1(x)
        x2 = self.s2(x1)
        x3 = self.s3(x2)
        x4 = self.s4(x3)
        # Pool, flatten and project the stage-4 features (the same steps as
        # GlobalHead.forward, invoked on the head's submodules directly).
        x4_p = self.head.pool(x4)
        x4_p = x4_p.view(x4_p.size(0), -1)
        x = self.head.fc(x4_p)
        return x, x3