# Source code for towhee.models.layers.conv2d_separable

# Inspired by https://github.com/stdio2016/pfann
#
# Modifications by Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from torch import nn

from typing import Union, Tuple


class SeparableConv2d(nn.Module):
    """
    Separable Conv2D with Keras-style "same" padding.

    The 2D convolution is factorized into a (1 x k) convolution along the
    last (time) axis followed by a (k x 1) convolution along the second-to-last
    (frequency) axis. Each stage has its own LayerNorm and activation; the
    second stage is depthwise (``groups=out_c``) unless ``fuller`` is set.

    Args:
        in_c (`int`):
            Number of channels in the input image.
        out_c (`int`):
            Number of channels produced by the convolution.
        kernel_size (`int`):
            Size of the convolving kernel. Must be an int: it is used as a
            scalar in the padding arithmetic and split across the two convs.
        stride (`Union[int, Tuple]`):
            Stride of the convolution. An int ``s`` is treated as ``(s, s)``;
            a tuple is ``(time_stride, freq_stride)``.
        in_f (`int`):
            Input height (frequency bins); sizes the padding and LayerNorms.
        in_t (`int`):
            Input width (time steps); sizes the padding and LayerNorms.
        fuller (`bool=False`):
            If False, conv2 is depthwise (``groups=out_c``); if True it is a
            full (grouped-by-1) convolution.
        activation (`nn.Module=nn.ReLU()`):
            Activation layer. NOTE(review): the same module instance is reused
            for both stages and (via the default) shared across
            SeparableConv2d instances — harmless for a stateless ReLU, but a
            stateful activation would be shared; pass a fresh module if that
            matters.
        relu_after_bn (`bool=False`):
            If True, apply LayerNorm before the activation; otherwise apply
            the activation first.
    """

    def __init__(
            self,
            in_c: int,
            out_c: int,
            kernel_size: Union[int, Tuple],
            stride: Union[int, Tuple],
            in_f: int,
            in_t: int,
            fuller: bool = False,
            activation: nn.Module = nn.ReLU(),
            relu_after_bn: bool = False
    ):
        super().__init__()
        # Accept an int stride (the annotation allows it); the padding and
        # stride arithmetic below indexes stride[0] / stride[1].
        if isinstance(stride, int):
            stride = (stride, stride)
        # Same padding along the time axis (width): total padding such that
        # out_t = ceil(in_t / stride[0]).
        padding = (in_t - 1) // stride[0] * stride[0] + kernel_size - in_t
        self.pad1 = nn.ZeroPad2d((padding // 2, padding - padding // 2, 0, 0))
        self.conv1 = nn.Conv2d(in_c, out_c,
                               kernel_size=(1, kernel_size),
                               stride=(1, stride[0]))
        self.ln1 = nn.LayerNorm((out_c, in_f, (in_t - 1) // stride[0] + 1))
        self.relu1 = activation

        # Same padding along the frequency axis (height).
        padding = (in_f - 1) // stride[1] * stride[1] + kernel_size - in_f
        self.pad2 = nn.ZeroPad2d((0, 0, padding // 2, padding - padding // 2))
        if fuller:
            self.conv2 = nn.Conv2d(out_c, out_c,
                                   kernel_size=(kernel_size, 1),
                                   stride=(stride[1], 1))
        else:
            # Depthwise: one filter per channel.
            self.conv2 = nn.Conv2d(out_c, out_c,
                                   kernel_size=(kernel_size, 1),
                                   stride=(stride[1], 1),
                                   groups=out_c)
        self.ln2 = nn.LayerNorm((out_c,
                                 (in_f - 1) // stride[1] + 1,
                                 (in_t - 1) // stride[0] + 1))
        self.relu2 = activation
        self.relu_after_bn = relu_after_bn
        self.hacked = False

    def hack(self):
        """Toggle an equivalent formulation of Keras same padding for stride=2.

        Asymmetric "same" padding cannot be expressed via Conv2d's symmetric
        ``padding`` argument directly, so this flips the conv kernels (and the
        LayerNorm affine parameters to match the flipped layout) and moves the
        padding into the conv layers; calling it again restores the explicit
        ZeroPad2d path. Calling ``hack()`` twice is a no-op overall.
        """
        self.hacked = not self.hacked
        with torch.no_grad():
            # Flip kernels spatially; flip LayerNorm weight/bias over the two
            # normalized spatial dims so the affine transform follows.
            self.conv1.weight.set_(self.conv1.weight.flip([2, 3]))
            self.ln1.weight.set_(self.ln1.weight.flip([1, 2]))
            self.ln1.bias.set_(self.ln1.bias.flip([1, 2]))
            self.conv2.weight.set_(self.conv2.weight.flip([2, 3]))
            self.ln2.weight.set_(self.ln2.weight.flip([1, 2]))
            self.ln2.bias.set_(self.ln2.bias.flip([1, 2]))
        if self.hacked:
            # ZeroPad2d.padding is (left, right, top, bottom); [3::-2] picks
            # (bottom, right) -> Conv2d's (padH, padW).
            self.conv1.padding = self.pad1.padding[3::-2]
            self.conv2.padding = self.pad2.padding[3::-2]
        else:
            self.conv1.padding = (0, 0)
            self.conv2.padding = (0, 0)

    def forward(self, x):
        """Apply pad/conv/norm/activation twice (time axis, then freq axis).

        Args:
            x (`torch.Tensor`): input of shape (batch, in_c, in_f, in_t).

        Returns:
            (`torch.Tensor`): output of shape
            (batch, out_c, ceil(in_f / stride[1]), ceil(in_t / stride[0])).
        """
        if not self.hacked:
            x = self.pad1(x)
        x = self.conv1(x)
        # Order of norm vs. activation is configurable.
        if self.relu_after_bn:
            x = self.ln1(x)
            x = self.relu1(x)
        else:
            x = self.relu1(x)
            x = self.ln1(x)
        if not self.hacked:
            x = self.pad2(x)
        x = self.conv2(x)
        if self.relu_after_bn:
            x = self.ln2(x)
            x = self.relu2(x)
        else:
            x = self.relu2(x)
            x = self.ln2(x)
        return x