# Source code for towhee.models.layers.patch_merging3d

# original code from
# modified by Zilliz.

import torch
from torch import nn
import torch.nn.functional as F

class PatchMerging3D(nn.Module):
    """
    3D Patch Merging Layer.

    Gathers every 2x2 neighbourhood along the H and W axes into one token by
    concatenating the four sub-sampled grids on the channel axis (``4 * dim``
    channels). It then normalizes them and linearly projects to ``2 * dim``
    channels. The depth axis D is not merged.

    Args:
        dim (`int`):
            Number of input channels.
        norm_layer (`nn.Module`):
            Normalization layer class, instantiated as ``norm_layer(4 * dim)``.
            Default: nn.LayerNorm
    """
    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        Forward function of 3D Patch Merging Layer.

        Args:
            x (`tensor`):
                Input tensor with size (B, D, H, W, C); C is expected to equal
                ``dim`` so the concatenated 4*C features match ``norm``/``reduction``.

        Returns:
            (`tensor`)
                Output tensor with size (B, D, ceil(H/2), ceil(W/2), 2*dim).
        """
        _, _, h, w, _ = x.shape

        # Zero-pad odd H/W up to an even size so every 2x2 window is complete.
        # F.pad pairs run from the last dim backwards: (C_l, C_r, W_l, W_r, H_l, H_r).
        pad_input = (h % 2 == 1) or (w % 2 == 1)
        if pad_input:
            x = F.pad(x, (0, 0, 0, w % 2, 0, h % 2))

        x0 = x[:, :, 0::2, 0::2, :]  # B D H/2 W/2 C  (even h, even w)
        x1 = x[:, :, 1::2, 0::2, :]  # B D H/2 W/2 C  (odd h,  even w)
        x2 = x[:, :, 0::2, 1::2, :]  # B D H/2 W/2 C  (even h, odd w)
        x3 = x[:, :, 1::2, 1::2, :]  # B D H/2 W/2 C  (odd h,  odd w)
        x = torch.cat([x0, x1, x2, x3], -1)  # B D H/2 W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)
        return x
class PatchMerging(nn.Module):
    """
    Patch Merging Layer.

    Concatenates each 2x2 neighbourhood along the H and W axes of a
    (B, D, H, W, C) tensor into ``4 * dim`` channels, normalizes, and projects
    them to ``2 * dim`` channels. The depth axis D is not merged.

    Args:
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer class,
            instantiated as ``norm_layer(4 * dim)``. Default: nn.LayerNorm
    """
    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        Forward function.

        Args:
            x: Input feature, tensor size (B, D, H, W, C); C is expected to
                equal ``dim``.

        Returns:
            Tensor of size (B, D, ceil(H/2), ceil(W/2), 2*dim).
        """
        _, _, h, w, _ = x.shape

        # Zero-pad odd H/W up to an even size so every 2x2 window is complete.
        # F.pad pairs run from the last dim backwards: (C_l, C_r, W_l, W_r, H_l, H_r).
        pad_input = (h % 2 == 1) or (w % 2 == 1)
        if pad_input:
            x = F.pad(x, (0, 0, 0, w % 2, 0, h % 2))

        x0 = x[:, :, 0::2, 0::2, :]  # B D H/2 W/2 C  (even h, even w)
        x1 = x[:, :, 1::2, 0::2, :]  # B D H/2 W/2 C  (odd h,  even w)
        x2 = x[:, :, 0::2, 1::2, :]  # B D H/2 W/2 C  (even h, odd w)
        x3 = x[:, :, 1::2, 1::2, :]  # B D H/2 W/2 C  (odd h,  odd w)
        x = torch.cat([x0, x1, x2, x3], -1)  # B D H/2 W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)
        return x