Source code for towhee.models.layers.pool_attention

# Copyright 2021 Facebook. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This code is modified by Zilliz.


from typing import List, Tuple

import torch
from torch import nn


class AttentionPool(nn.Module):
    """
    A pooling block used in a transformer model after the attention block.
    It reshapes a flattened token sequence back to its (t, h, w) grid,
    applies a pooling operation, flattens the result, and optionally
    normalizes it:
    ::

        Input -> Reshape -> Pool -> Reshape -> Norm

    Args:
        thw_shape (List): the shape of the input tensor (before flattening).
        pool (Callable): pool operation applied to the input tensor. If pool
            is None, the input tensor is returned unchanged.
        has_cls_embed (bool): whether the input tensor contains a cls token.
            The pool operation excludes the cls token.
        norm (Callable): optional normalization operation applied to the
            tensor after pooling.

    Returns:
        tensor (torch.Tensor): input tensor after pooling.
        thw_shape (List[int]): output tensor shape (before flattening).
    """

    def __init__(
            self,
            thw_shape,
            pool=None,
            has_cls_embed=True,
            norm=None
    ) -> None:
        super().__init__()
        self.pool = pool
        self.thw_shape = thw_shape
        self.has_cls_embed = has_cls_embed
        self.norm = norm

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[int]]:
        if self.pool is None:
            return x, self.thw_shape

        tensor_dim = x.ndim
        if tensor_dim == 4:
            pass
        elif tensor_dim == 3:
            # Insert a singleton head dimension: (b, s, c) -> (b, 1, s, c).
            x = x.unsqueeze(1)
        else:
            raise NotImplementedError(f"Unsupported input dimension {x.shape}")

        if self.has_cls_embed:
            # Split off the cls token so it is excluded from pooling.
            cls_tok, x = x[:, :, :1, :], x[:, :, 1:, :]

        b, n, _, c = x.shape
        t, h, w = self.thw_shape
        # Unflatten the token sequence into a 3D grid:
        # (b, n, t*h*w, c) -> (b*n, c, t, h, w).
        x = x.reshape(b * n, t, h, w, c).permute(0, 4, 1, 2, 3).contiguous()

        x = self.pool(x)

        thw_shape = [x.shape[2], x.shape[3], x.shape[4]]
        l_pooled = x.shape[2] * x.shape[3] * x.shape[4]
        # Flatten the pooled grid back into a token sequence:
        # (b*n, c, t', h', w') -> (b, n, t'*h'*w', c).
        x = x.reshape(b, n, c, l_pooled).transpose(2, 3)
        if self.has_cls_embed:
            # Re-attach the cls token in front of the pooled tokens.
            x = torch.cat((cls_tok, x), dim=2)
        if self.norm is not None:
            x = self.norm(x)
        if tensor_dim == 4:
            pass
        else:  # For the case tensor_dim == 3.
            x = x.squeeze(1)
        return x, thw_shape
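

# A minimal usage sketch (not part of the original module). The MaxPool3d
# configuration below is an illustrative assumption, not a Towhee default:
# it halves the spatial resolution while leaving the temporal dimension
# unchanged, and the cls token is carried through the pool untouched.
if __name__ == "__main__":
    # Tokens for a (t, h, w) = (4, 8, 8) grid plus one cls token, laid out
    # as (batch, heads, 1 + t*h*w, channels).
    pool = nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2))
    attn_pool = AttentionPool(thw_shape=[4, 8, 8], pool=pool, has_cls_embed=True)
    x = torch.randn(2, 3, 1 + 4 * 8 * 8, 16)
    out, out_thw = attn_pool(x)
    # out.shape == (2, 3, 1 + 4 * 4 * 4, 16); out_thw == [4, 4, 4]
    print(out.shape, out_thw)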