Source code for towhee.models.layers.patch_merging

# Copyright 2021 Microsoft . All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This code is modified by Zilliz.
import torch
from torch import nn


class PatchMerging(nn.Module):
    r"""Patch Merging Layer.

    Gathers every 2x2 neighbourhood of tokens into one token by concatenating
    their channels (``C -> 4*C``), which halves each spatial side, then
    projects the result to ``2*dim`` channels with a bias-free linear layer.

    Args:
        input_resolution (tuple[int]): Resolution ``(H, W)`` of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        is_v2 (bool, optional): If True, apply the reduction first and then
            normalize over ``2*dim`` channels (Swin V2 style post-norm).
            Otherwise normalize over the ``4*dim`` concatenated channels
            before the reduction. Default: False
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm, is_v2=False):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        # Projects the concatenated 2x2 patch features from 4*dim to 2*dim.
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.is_v2 = is_v2
        # The norm's width depends on where it sits relative to the reduction:
        # after it (v2) it sees 2*dim channels, before it (v1) 4*dim channels.
        if self.is_v2:
            self.norm = norm_layer(2 * dim)
        else:
            self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """Merge each 2x2 group of tokens into a single token.

        Args:
            x: Tensor of shape ``(B, H*W, C)`` where ``(H, W)`` is
                ``self.input_resolution`` (both even) and ``C`` must equal
                ``self.dim`` for the ``4*dim`` linear reduction to apply.

        Returns:
            Tensor of shape ``(B, H/2*W/2, 2*dim)``.

        Raises:
            AssertionError: if ``H*W`` does not match the token count, or if
                ``H`` or ``W`` is odd.
        """
        h, w = self.input_resolution
        b, l, c = x.shape
        assert l == h * w, 'input feature has wrong size'
        assert h % 2 == 0 and w % 2 == 0, f'x size ({h}*{w}) are not even.'

        x = x.view(b, h, w, c)

        # The four interleaved sub-grids of each 2x2 window.
        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(b, -1, 4 * c)  # B H/2*W/2 4*C

        if self.is_v2:
            # Post-norm: reduce first, then normalize the 2*dim features.
            x = self.reduction(x)
            x = self.norm(x)
        else:
            # Pre-norm: normalize the 4*dim features, then reduce.
            x = self.norm(x)
            x = self.reduction(x)
        return x

    def extra_repr(self) -> str:
        return f'input_resolution={self.input_resolution}, dim={self.dim}'

    def flops(self):
        """Estimate the FLOPs of one forward pass for a single sample.

        Returns:
            int: normalization cost plus the cost of the 4*dim -> 2*dim
            linear reduction applied to each of the ``(H/2)*(W/2)`` merged tokens.
        """
        h, w = self.input_resolution
        merged_tokens = (h // 2) * (w // 2)
        if self.is_v2:
            # The norm runs on the reduced 2*dim features of the merged tokens:
            # (H/2)*(W/2)*2*dim == H*W*dim/2.
            flops = h * w * self.dim // 2
        else:
            # The norm runs on the 4*dim concatenated features of the merged
            # tokens: (H/2)*(W/2)*4*dim == H*W*dim.
            flops = h * w * self.dim
        flops += merged_tokens * 4 * self.dim * 2 * self.dim
        return flops