# PyTorch implementation of HorNet from
# [HorNet: Efficient High-Order Spatial Interactions with Recursive Gated Convolutions]
# (https://arxiv.org/abs/2207.14284).
#
# Inspired by https://github.com/raoyongming/HorNet
#
# Modifications & additions by Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import nn
from towhee.models.layers.droppath import DropPath
from towhee.models.utils.weight_init import trunc_normal_
from towhee.models.convnext.utils import LayerNorm
class GatedConv(nn.Module):
    """
    gnConv: Recursive Gated Convolution

    The input is projected to ``2 * dim`` channels, split into a first gate of
    ``dim // 2 ** (order - 1)`` channels plus ``order`` feature groups whose widths
    double up to ``dim``, and then combined by repeated element-wise multiplication
    interleaved with 1x1 convolutions.

    Args:
        dim (`int`): Feature dimension. Must be divisible by ``2 ** (order - 1)``.
        order (`int`): Order to expand dimensions (number of gated groups, at least 1).
        gflayer (`nn.Module`): Global filter. Called as ``gflayer(channels, h=h, w=w)``;
            when `None`, a 7x7 depthwise convolution is used instead.
        h (`int`): Height, forwarded to ``gflayer`` only.
        w (`int`): Width, forwarded to ``gflayer`` only.
        s (`float`): Scale ratio multiplied onto the spatial filter output.

    Raises:
        ValueError: If ``order`` is smaller than 1 or ``dim`` is not divisible by
            ``2 ** (order - 1)``.
    """

    def __init__(self, dim, order=5, gflayer=None, h=14, w=8, s=1.0):
        super().__init__()
        if order < 1:
            raise ValueError(f'order must be a positive integer, got {order}')
        # proj_in emits 2 * dim channels, and forward() splits them into
        # dims[0] + sum(dims). The two only agree when every halving is exact;
        # otherwise torch.split fails inside forward() with an opaque size error,
        # so reject the configuration up front.
        if dim % (2 ** (order - 1)) != 0:
            raise ValueError(
                f'dim ({dim}) must be divisible by 2 ** (order - 1) = {2 ** (order - 1)}'
            )

        self.order = order
        # Channel widths in ascending order: dim / 2^(order-1), ..., dim / 2, dim.
        self.dims = [dim // 2 ** i for i in range(order)]
        self.dims.reverse()
        total_dims = sum(self.dims)

        self.proj_in = nn.Conv2d(dim, 2 * dim, 1)
        if gflayer is None:
            # Depthwise 7x7 convolution; padding 3 keeps the spatial size.
            self.dwconv = nn.Conv2d(total_dims, total_dims,
                                    kernel_size=7, padding=(7 - 1) // 2,
                                    bias=True, groups=total_dims)
        else:
            self.dwconv = gflayer(total_dims, h=h, w=w)
        self.proj_out = nn.Conv2d(dim, dim, 1)
        # 1x1 convolutions lifting the running product from dims[i] to dims[i + 1].
        self.pws = nn.ModuleList(
            [nn.Conv2d(self.dims[i], self.dims[i + 1], 1) for i in range(order - 1)]
        )
        self.scale = s

    def forward(self, x):
        """
        Apply the recursive gated convolution.

        Args:
            x (`torch.Tensor`): Input in BCHW layout with ``dim`` channels.

        Returns:
            (`torch.Tensor`): Output of ``proj_out``, with ``dim`` channels.
        """
        # x: BCHW
        fused_x = self.proj_in(x)
        pwa, abc = torch.split(fused_x, (self.dims[0], sum(self.dims)), dim=1)
        dw_abc = self.dwconv(abc) * self.scale
        dw_list = torch.split(dw_abc, self.dims, dim=1)
        x = pwa * dw_list[0]
        # Each step widens the running product with a 1x1 conv and gates it with
        # the next (wider) slice of the spatially filtered features.
        for pw, gate in zip(self.pws, dw_list[1:]):
            x = pw(x) * gate
        return self.proj_out(x)
class GlobalLocalFilter(nn.Module):
    """
    Global local filter.

    After normalisation the channels are split in two halves: the first half goes
    through a 3x3 depthwise convolution (local branch), the second half is multiplied
    by a learnable complex weight in the 2D real-FFT domain (global branch). The two
    halves are interleaved channel-wise and normalised again.

    Args:
        dim (`int`): Feature dimension.
        h (`int`): Height of the learnable frequency-domain weight.
        w (`int`): Width of the learnable frequency-domain weight.
    """

    def __init__(self, dim, h=14, w=8):
        super().__init__()
        half_dim = dim // 2
        self.dw = nn.Conv2d(half_dim, half_dim, kernel_size=3, padding=1, bias=False, groups=half_dim)
        # The trailing axis of size 2 holds the (real, imaginary) parts; it is
        # reinterpreted as a complex tensor in forward().
        self.complex_weight = nn.Parameter(torch.randn(half_dim, h, w, 2, dtype=torch.float32) * 0.02)
        trunc_normal_(self.complex_weight, std=.02)
        self.pre_norm = LayerNorm(dim, eps=1e-6, data_format='channels_first')
        self.post_norm = LayerNorm(dim, eps=1e-6, data_format='channels_first')

    def forward(self, x):
        """
        Apply the local and global filters.

        Args:
            x (`torch.Tensor`): Input whose channel axis is dim 1 and whose spatial
                axes are dims 2 and 3.

        Returns:
            (`torch.Tensor`): Normalised tensor reshaped to
            ``(batch, 2 * (channels of the global half), height, width)``.
        """
        x = self.pre_norm(x)
        x1, x2 = torch.chunk(x, 2, dim=1)

        # Local branch.
        x1 = self.dw(x1)

        # Global branch: the FFT is always computed in float32.
        x2 = x2.to(torch.float32)
        num_b, num_c, a, b = x2.shape
        x2 = torch.fft.rfft2(x2, dim=(2, 3), norm='ortho')

        weight = self.complex_weight
        if weight.shape[1:3] != x2.shape[2:4]:
            # Resize the stored weight to the spectrum's spatial size. The
            # real/imaginary axis is moved to the front so it acts as the batch
            # axis for interpolate, then moved back afterwards.
            weight = nn.functional.interpolate(
                weight.permute(3, 0, 1, 2), size=x2.shape[2:4],
                mode='bilinear', align_corners=True).permute(1, 2, 3, 0)
        weight = torch.view_as_complex(weight.contiguous())

        x2 = x2 * weight
        x2 = torch.fft.irfft2(x2, s=(a, b), dim=(2, 3), norm='ortho')

        # Interleave the channels as x1[0], x2[0], x1[1], x2[1], ...
        x = torch.cat([x1.unsqueeze(2), x2.unsqueeze(2)], dim=2).reshape(num_b, 2 * num_c, a, b)
        x = self.post_norm(x)
        return x
class Block(nn.Module):
    """
    HorNet block

    Two residual branches applied in sequence: a spatial-mixing branch
    (``norm1`` -> ``gnconv``) followed by a channel-mixing branch
    (``norm2`` -> ``pwconv1`` -> GELU -> ``pwconv2``).

    Args:
        dim (`int`): Feature dimension.
        drop_rate (`float`): Drop ratio passed to `DropPath`; when not greater than 0
            an identity layer is used instead.
        layer_scale_init_value (`float`): initial value to scale layer. When not
            greater than 0, the learnable per-channel scales are disabled.
        gnconv (`nn.Module`): gnConv layer. Called as ``gnconv(dim)``.
    """

    def __init__(self, dim, drop_rate=0., layer_scale_init_value=1e-6, gnconv=GatedConv):
        super().__init__()
        self.norm1 = LayerNorm(dim, eps=1e-6, data_format='channels_first')
        self.gnconv = gnconv(dim)  # spatial mixing layer (GatedConv by default)
        self.norm2 = LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, 4 * dim)  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4 * dim, dim)

        use_layer_scale = layer_scale_init_value > 0
        self.gamma1 = nn.Parameter(layer_scale_init_value * torch.ones(dim),
                                   requires_grad=True) if use_layer_scale else None
        self.gamma2 = nn.Parameter(layer_scale_init_value * torch.ones(dim),
                                   requires_grad=True) if use_layer_scale else None
        self.drop_path = DropPath(drop_rate) if drop_rate > 0. else nn.Identity()

    def forward(self, x):
        """
        Run both residual branches.

        Args:
            x (`torch.Tensor`): Input in (N, C, H, W) layout.

        Returns:
            (`torch.Tensor`): Tensor in (N, C, H, W) layout.
        """
        # x: BCHW
        _, num_c, _, _ = x.shape
        if self.gamma1 is not None:
            # Reshape to (C, 1, 1) so the scale broadcasts over the spatial axes.
            gamma1 = self.gamma1.view(num_c, 1, 1)
        else:
            gamma1 = 1
        x = x + self.drop_path(gamma1 * self.gnconv(self.norm1(x)))

        x_input = x
        x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
        x = self.norm2(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma2 is not None:
            # Channels are last here, so the (C,) scale broadcasts directly.
            x = self.gamma2 * x
        x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)

        x = x_input + self.drop_path(x)
        return x