Source code for towhee.models.layers.mixed_conv2d

# Copyright 2021 Ross Wightman. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This code is modified by Zilliz.
from typing import List, Union

import torch
from torch import nn

from towhee.models.utils.create_conv2d_pad import create_conv2d_pad

def _split_channels(num_chan: int, num_groups: int) -> List[int]:
    """Split `num_chan` channels into `num_groups` roughly equal groups.

    Any remainder is added to the first group, e.g. 13 channels over 3 groups
    yields [5, 4, 4].
    """
    split = [num_chan // num_groups for _ in range(num_groups)]
    split[0] += num_chan - sum(split)
    return split

class MixedConv2d(nn.ModuleDict):
    """
    Mixed Grouped Convolution

    Based on MDConv and GroupedConv in MixNet impl:
    https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py

    Args:
        in_channels (`int`):
            Number of channels in the input image.
        out_channels (`int`):
            Number of channels produced by the convolution.
        kernel_size (`int` or `List[int]`):
            Size of the convolving kernel, or one kernel size per group.
        stride (`int`):
            Stride of the convolution.
        padding (`str`):
            Padding added to all four sides of the input.
        dilation (`int`):
            Spacing between kernel elements.
        depthwise (`bool`):
            If True, use depthwise convolution.
    """
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, List[int]] = 3,
                 stride: int = 1,
                 padding: str = '',
                 dilation: int = 1,
                 depthwise: bool = False,
                 **kwargs) -> None:
        super().__init__()

        kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size]
        num_groups = len(kernel_size)
        in_splits = _split_channels(in_channels, num_groups)
        out_splits = _split_channels(out_channels, num_groups)
        self.in_channels = sum(in_splits)
        self.out_channels = sum(out_splits)
        for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)):
            conv_groups = in_ch if depthwise else 1
            # use add_module to keep key space clean
            self.add_module(
                str(idx),
                create_conv2d_pad(
                    in_ch, out_ch, k, stride=stride,
                    padding=padding, dilation=dilation, groups=conv_groups, **kwargs)
            )
        self.splits = in_splits
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Split the input along the channel dimension, run each chunk through
        # its own grouped convolution, then concatenate the results.
        x_split = torch.split(x, self.splits, 1)
        x_out = [c(x_split[i]) for i, c in enumerate(self.values())]
        x = torch.cat(x_out, 1)
        return x
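
# ---------------------------------------------------------------------------
# Illustrative usage sketch, not part of the original module. It assumes that
# create_conv2d_pad falls back to 'same'-style padding when padding='' (the
# default), so the spatial size is preserved at stride 1.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    # Channels are split as evenly as possible, remainder to the first group.
    print(_split_channels(13, 3))  # [5, 4, 4]

    # Mix 3x3 and 5x5 depthwise kernels over a 32-channel feature map.
    mixed = MixedConv2d(32, 32, kernel_size=[3, 5], depthwise=True)
    out = mixed(torch.randn(1, 32, 56, 56))
    print(out.shape)  # expected: torch.Size([1, 32, 56, 56])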