Source code for towhee.models.drl.module_cross

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import logging
from collections import OrderedDict

import torch
from torch import nn

from towhee.models.drl.until_module import LayerNorm

logger = logging.getLogger(__name__)

PRETRAINED_MODEL_ARCHIVE_MAP = {}
CONFIG_NAME = "cross_config.json"
WEIGHTS_NAME = "cross_pytorch_model.bin"


class QuickGELU(nn.Module):
    """
    QuickGELU activation for DRL: ``x * sigmoid(1.702 * x)``, a fast approximation of GELU.
    """

    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)


class ResidualAttentionBlock(nn.Module):
    """
    Residual attention block for DRL.

    Args:
        d_model (int): hidden dimension of the block.
        n_head (int): number of attention heads.
    """

    def __init__(self, d_model: int, n_head: int):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        self.n_head = n_head

    def attention(self, x: torch.Tensor, attn_mask: torch.Tensor):
        # Expand the per-sample mask to one copy per head, as required by nn.MultiheadAttention.
        attn_mask_ = attn_mask.repeat_interleave(self.n_head, dim=0)
        return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask_)[0]

    def forward(self, para_tuple: tuple):
        # para_tuple: (x: torch.Tensor, attn_mask: torch.Tensor)
        x, attn_mask = para_tuple
        x = x + self.attention(self.ln_1(x), attn_mask)
        x = x + self.mlp(self.ln_2(x))
        return x, attn_mask


class Transformer(nn.Module):
    """
    Transformer for DRL.

    Args:
        width (int): hidden dimension of each block.
        layers (int): number of residual attention blocks.
        heads (int): number of attention heads per block.
    """

    def __init__(self, width: int, layers: int, heads: int):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads) for _ in range(layers)])

    def forward(self, x: torch.Tensor, attn_mask: torch.Tensor):
        return self.resblocks((x, attn_mask))[0]
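
# Shape sketch (illustrative, not part of the original module): the transformer expects
# sequence-first input of shape (seq_len, batch, width) together with an additive
# attention mask of shape (batch, seq_len, seq_len); each block repeat-interleaves the
# mask per head before passing it to nn.MultiheadAttention. The tensors below are
# hypothetical placeholders:
#
#     x = torch.zeros(32, 4, 512)                               # (L=32, N=4, D=512)
#     mask = torch.zeros(4, 32, 32)                             # 0 keeps a position, -1e6 masks it
#     out = Transformer(width=512, layers=4, heads=8)(x, mask)  # (32, 4, 512)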


class CrossEmbeddings(nn.Module):
    """
    Construct the embeddings from word, position and token_type embeddings.
    """

    def __init__(self, config):
        super().__init__()

        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        # self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
        # self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, concat_embeddings, concat_type=None):
        _ = concat_type
        batch_size, seq_length = concat_embeddings.size(0), concat_embeddings.size(1)

        position_ids = torch.arange(seq_length, dtype=torch.long, device=concat_embeddings.device)
        position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)

        # token_type_embeddings = self.token_type_embeddings(concat_type)
        position_embeddings = self.position_embeddings(position_ids)

        embeddings = concat_embeddings + position_embeddings  # + token_type_embeddings
        # embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings


class CrossPooler(nn.Module):
    """
    CrossPooler for DRL.
    """

    def __init__(self, config):
        super().__init__()
        self.ln_pool = LayerNorm(config.hidden_size)
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = QuickGELU()

    def forward(self, hidden_states, hidden_mask):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        _ = hidden_mask
        hidden_states = self.ln_pool(hidden_states)
        pooled_output = hidden_states[:, 0]
        pooled_output = self.dense(pooled_output)
        pooled_output = self.activation(pooled_output)
        return pooled_output


class CrossModel(nn.Module):
    """
    CrossModel for DRL.
    """

    def initialize_parameters(self):
        proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
        attn_std = self.transformer.width ** -0.5
        fc_std = (2 * self.transformer.width) ** -0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

    def __init__(self, config):
        super().__init__()
        self.config = config

        self.embeddings = CrossEmbeddings(config)

        transformer_width = config.hidden_size
        transformer_layers = config.num_hidden_layers
        transformer_heads = config.num_attention_heads
        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
        )
        self.pooler = CrossPooler(config)
        self.apply(self.init_weights)

    def build_attention_mask(self, attention_mask):
        # Turn a (batch, seq_len) padding mask into a (batch, seq_len, seq_len) additive mask:
        # 0 for positions to keep, -1e6 for positions to ignore.
        extended_attention_mask = attention_mask.unsqueeze(1)
        extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -1000000.0
        extended_attention_mask = extended_attention_mask.expand(-1, attention_mask.size(1), -1)
        return extended_attention_mask

    def forward(self, concat_input, concat_type=None, attention_mask=None, output_all_encoded_layers=True):
        _ = output_all_encoded_layers
        if attention_mask is None:
            attention_mask = torch.ones(concat_input.size(0), concat_input.size(1))
        if concat_type is None:
            concat_type = torch.zeros_like(attention_mask)

        extended_attention_mask = self.build_attention_mask(attention_mask)

        embedding_output = self.embeddings(concat_input, concat_type)
        embedding_output = embedding_output.permute(1, 0, 2)  # NLD -> LND
        embedding_output = self.transformer(embedding_output, extended_attention_mask)
        embedding_output = embedding_output.permute(1, 0, 2)  # LND -> NLD

        pooled_output = self.pooler(embedding_output, hidden_mask=attention_mask)

        return embedding_output, pooled_output

    @property
    def dtype(self):
        """
        :obj:`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
        """
        try:
            return next(self.parameters()).dtype
        except StopIteration:
            # For nn.DataParallel compatibility in PyTorch 1.5
            def find_tensor_attributes(module: nn.Module):
                tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
                return tuples

            gen = self._named_members(get_members_fn=find_tensor_attributes)
            first_tuple = next(gen)
            return first_tuple[1].dtype

    def init_weights(self, module):
        """
        Initialize the weights.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=0.02)
        elif isinstance(module, LayerNorm):
            if "beta" in dir(module) and "gamma" in dir(module):
                module.beta.data.zero_()
                module.gamma.data.fill_(1.0)
            else:
                module.bias.data.zero_()
                module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()
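

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. It assumes a config object
    # exposing the attributes read above (hidden_size, num_hidden_layers,
    # num_attention_heads, max_position_embeddings, hidden_dropout_prob); the values
    # below are placeholders.
    from types import SimpleNamespace

    config = SimpleNamespace(
        hidden_size=512,
        num_hidden_layers=4,
        num_attention_heads=8,
        max_position_embeddings=128,
        hidden_dropout_prob=0.1,
    )
    model = CrossModel(config)

    batch_size, seq_len = 2, 16
    concat_input = torch.randn(batch_size, seq_len, config.hidden_size)  # concatenated feature sequence (placeholder)
    attention_mask = torch.ones(batch_size, seq_len)                     # 1 = valid token, 0 = padding

    sequence_output, pooled_output = model(concat_input, attention_mask=attention_mask)
    print(sequence_output.shape)  # torch.Size([2, 16, 512])
    print(pooled_output.shape)    # torch.Size([2, 512])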