Source code for towhee.models.drl.until_module

# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BERT model."""

import logging
import torch
from torch import nn
import torch.nn.functional as F
import math

logger = logging.getLogger(__name__)


def gelu(x):
    """Implementation of the gelu activation function.

    For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
    0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    """
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


def swish(x):
    return x * torch.sigmoid(x)


ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
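# Illustrative sketch, not part of the original module: shows how an activation can be
# looked up from ACT2FN by name and how the erf-based gelu above compares with the tanh
# approximation mentioned in its docstring. The `_example_activation_usage` helper name
# is hypothetical.
def _example_activation_usage():
    x = torch.randn(2, 4)
    act = ACT2FN["gelu"]          # select the activation by config key
    exact = act(x)                # erf-based gelu defined above
    approx = 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    return exact, approx          # the two agree to within a small tolerance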
class LayerNorm(nn.Module):
    """
    LayerNorm for DRL.
    """

    def __init__(self, hidden_size, eps=1e-12):
        """
        Construct a layernorm module in the TF style (epsilon inside the square root).
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
        return self.weight * x + self.bias
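# Illustrative usage sketch, not part of the original module: applying the TF-style
# LayerNorm above to a batch of hidden states. hidden_size=8 is an arbitrary example
# value and `_example_layernorm_usage` is a hypothetical helper.
def _example_layernorm_usage():
    ln = LayerNorm(hidden_size=8)
    hidden = torch.randn(2, 5, 8)    # (batch, sequence, hidden)
    normed = ln(hidden)              # per-token zero mean / unit variance, then scale and shift
    return normed.shape              # torch.Size([2, 5, 8])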
class CrossEn(nn.Module):
    """
    Cross-entropy loss over the diagonal of a similarity matrix, where the diagonal
    entries correspond to the matched (positive) pairs.
    """

    def forward(self, sim_matrix):
        logpt = F.log_softmax(sim_matrix, dim=-1)
        logpt = torch.diag(logpt)
        nce_loss = -logpt
        sim_loss = nce_loss.mean()
        return sim_loss
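# Illustrative usage sketch, not part of the original module: CrossEn assumes the i-th
# row and i-th column of the similarity matrix belong to the same pair, so applying it
# to the matrix and its transpose gives a symmetric retrieval loss.
# `_example_crossen_usage` is a hypothetical helper.
def _example_crossen_usage():
    loss_fn = CrossEn()
    sim_matrix = torch.randn(4, 4)    # similarities between 4 paired text/video features
    loss = loss_fn(sim_matrix) + loss_fn(sim_matrix.t())
    return loss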
class AllGather(torch.autograd.Function):
    """An autograd function that performs allgather on a tensor."""

    @staticmethod
    def forward(ctx, tensor, args):
        if args.world_size == 1:
            ctx.rank = args.local_rank
            ctx.batch_size = tensor.shape[0]
            return tensor
        else:
            output = [torch.empty_like(tensor) for _ in range(args.world_size)]
            torch.distributed.all_gather(output, tensor)
            ctx.rank = args.local_rank
            ctx.batch_size = tensor.shape[0]
            return torch.cat(output, dim=0)

    @staticmethod
    def backward(ctx, grad_output):
        return (
            grad_output[ctx.batch_size * ctx.rank: ctx.batch_size * (ctx.rank + 1)],
            None,
        )
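# Illustrative usage sketch, not part of the original module. It assumes
# torch.distributed has already been initialized and that `args` exposes the
# `world_size` and `local_rank` attributes that forward() above reads;
# `_example_allgather_usage` is a hypothetical helper.
def _example_allgather_usage(local_features, args):
    # Gathers every rank's feature batch into one tensor along dim 0; backward() hands
    # each rank only the gradient slice for its own batch.
    all_features = AllGather.apply(local_features, args)
    return all_features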
class AllGather2(torch.autograd.Function):
    """An autograd function that performs allgather on a tensor."""
    # https://github.com/PyTorchLightning/lightning-bolts/blob/8d3fbf7782e3d3937ab8a1775a7092d7567f2933/pl_bolts/models/self_supervised/simclr/simclr_module.py#L20

    @staticmethod
    def forward(ctx, tensor, args):
        if args.world_size == 1:
            ctx.rank = args.local_rank
            ctx.batch_size = tensor.shape[0]
            return tensor
        else:
            output = [torch.empty_like(tensor) for _ in range(args.world_size)]
            torch.distributed.all_gather(output, tensor)
            ctx.rank = args.local_rank
            ctx.batch_size = tensor.shape[0]
            return torch.cat(output, dim=0)

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = grad_output.clone()
        torch.distributed.all_reduce(grad_input, op=torch.distributed.ReduceOp.SUM, async_op=False)
        return grad_input[ctx.rank * ctx.batch_size:(ctx.rank + 1) * ctx.batch_size], None
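# Note added for clarity, not in the original module: AllGather2 differs from AllGather
# only in backward(), where it first all-reduces the gathered gradient across ranks and
# then slices out this rank's portion, instead of slicing the local grad_output directly.
# It is invoked the same way, e.g. AllGather2.apply(local_features, args).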
def convert_weights(model: nn.Module):
    """Convert applicable model parameters to fp16"""

    def _convert_weights_to_fp16(l):
        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)):
            l.weight.data = l.weight.data.half()
            if l.bias is not None:
                l.bias.data = l.bias.data.half()

        if isinstance(l, nn.MultiheadAttention):
            for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
                tensor = getattr(l, attr)
                if tensor is not None:
                    tensor.data = tensor.data.half()

        for name in ["text_projection", "proj"]:
            if hasattr(l, name):
                attr = getattr(l, name)
                if attr is not None:
                    attr.data = attr.data.half()

    model.apply(_convert_weights_to_fp16)
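# Illustrative usage sketch, not part of the original module: convert_weights() halves
# conv/linear/attention parameters in place, which is how CLIP-style models are usually
# switched to fp16 for inference. The toy model and `_example_convert_weights_usage`
# helper are assumptions.
def _example_convert_weights_usage():
    model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))
    convert_weights(model)
    return model[0].weight.dtype    # torch.float16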