# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BERT model."""
import logging
import torch
from torch import nn
import torch.nn.functional as F
import math
logger = logging.getLogger(__name__)
def gelu(x):
    """Implementation of the gelu activation function.

    For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
    0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    """
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


def swish(x):
    return x * torch.sigmoid(x)


ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
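
# A minimal sketch (not part of the original module): the tanh-based GPT
# approximation quoted in gelu()'s docstring, made callable so the two
# variants can be compared side by side.
def _gelu_gpt_approx(x):
    """OpenAI GPT's tanh approximation of GELU, from the docstring above."""
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
# e.g. torch.allclose(gelu(t), _gelu_gpt_approx(t), atol=1e-3) for moderate t.
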
class LayerNorm(nn.Module):
    """
    LayerNorm as used in DRL (TF-style, with epsilon inside the square root).
    """
    def __init__(self, hidden_size, eps=1e-12):
        """
        Construct a layernorm module in the TF style (epsilon inside the square root).
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        u = x.mean(-1, keepdim=True)               # per-position mean over the hidden dim
        s = (x - u).pow(2).mean(-1, keepdim=True)  # biased variance over the hidden dim
        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
        return self.weight * x + self.bias
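
# Illustrative sanity check (an assumption, not in the original file):
# torch.nn.LayerNorm also uses the biased variance with eps inside the
# square root, so with identical parameters the two modules should agree.
def _check_layernorm_matches_torch(hidden_size=8, eps=1e-12):
    x = torch.randn(4, hidden_size)
    ours = LayerNorm(hidden_size, eps=eps)
    ref = nn.LayerNorm(hidden_size, eps=eps)
    ref.weight.data.copy_(ours.weight.data)  # align the affine parameters
    ref.bias.data.copy_(ours.bias.data)
    assert torch.allclose(ours(x), ref(x), atol=1e-6)
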
class CrossEn(nn.Module):
    """Cross-entropy over a similarity matrix whose diagonal holds the matched pairs."""
    def forward(self, sim_matrix):
        logpt = F.log_softmax(sim_matrix, dim=-1)
        logpt = torch.diag(logpt)  # log-probability of each positive (diagonal) pair
        nce_loss = -logpt
        sim_loss = nce_loss.mean()
        return sim_loss
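
# Usage sketch with hypothetical tensors: for an (N, N) similarity matrix
# built from matched pairs, CrossEn is equivalent to
# F.cross_entropy(sim_matrix, torch.arange(N)) over the rows.
def _cross_en_example(batch_size=4, dim=16):
    text = F.normalize(torch.randn(batch_size, dim), dim=-1)
    video = F.normalize(torch.randn(batch_size, dim), dim=-1)
    sim_matrix = text @ video.t()  # diagonal entries are the positive pairs
    loss = CrossEn()(sim_matrix)
    ref = F.cross_entropy(sim_matrix, torch.arange(batch_size))
    assert torch.allclose(loss, ref, atol=1e-6)
    return loss
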
class AllGather(torch.autograd.Function):
    """An autograd function that performs allgather on a tensor.

    The backward pass returns only the local rank's slice of the incoming
    gradient; unlike AllGather2 below, it does not all-reduce gradients
    across ranks first.
    """
    @staticmethod
    def forward(ctx, tensor, args):
        if args.world_size == 1:
            ctx.rank = args.local_rank
            ctx.batch_size = tensor.shape[0]
            return tensor
        else:
            output = [torch.empty_like(tensor) for _ in range(args.world_size)]
            torch.distributed.all_gather(output, tensor)
            ctx.rank = args.local_rank
            ctx.batch_size = tensor.shape[0]
            return torch.cat(output, dim=0)

    @staticmethod
    def backward(ctx, grad_output):
        return (
            grad_output[ctx.batch_size * ctx.rank: ctx.batch_size * (ctx.rank + 1)],
            None,
        )

class AllGather2(torch.autograd.Function):
    """An autograd function that performs allgather on a tensor.

    Unlike AllGather, the backward pass all-reduces the gradient across
    ranks before slicing out the local portion.
    """
    # https://github.com/PyTorchLightning/lightning-bolts/blob/8d3fbf7782e3d3937ab8a1775a7092d7567f2933/pl_bolts/models/self_supervised/simclr/simclr_module.py#L20
    @staticmethod
    def forward(ctx, tensor, args):
        if args.world_size == 1:
            ctx.rank = args.local_rank
            ctx.batch_size = tensor.shape[0]
            return tensor
        else:
            output = [torch.empty_like(tensor) for _ in range(args.world_size)]
            torch.distributed.all_gather(output, tensor)
            ctx.rank = args.local_rank
            ctx.batch_size = tensor.shape[0]
            return torch.cat(output, dim=0)

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = grad_output.clone()
        torch.distributed.all_reduce(grad_input, op=torch.distributed.ReduceOp.SUM, async_op=False)
        return grad_input[ctx.rank * ctx.batch_size:(ctx.rank + 1) * ctx.batch_size], None
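
# Usage sketch under an assumed distributed setup: `args` is a hypothetical
# namespace carrying world_size and local_rank, matching the forward()
# signatures above. Autograd Functions are invoked via .apply.
allgather = AllGather.apply
allgather2 = AllGather2.apply
# e.g. inside a training step, with local_features of shape (B, D):
#   gathered = allgather2(local_features, args)  # shape (world_size * B, D)
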
def convert_weights(model: nn.Module):
    """Convert applicable model parameters to fp16"""
    def _convert_weights_to_fp16(l):
        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)):
            l.weight.data = l.weight.data.half()
            if l.bias is not None:
                l.bias.data = l.bias.data.half()
        if isinstance(l, nn.MultiheadAttention):
            # Projection weights/biases live as attributes on the module itself.
            for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
                tensor = getattr(l, attr)
                if tensor is not None:
                    tensor.data = tensor.data.half()
        for name in ["text_projection", "proj"]:
            if hasattr(l, name):
                attr = getattr(l, name)
                if attr is not None:
                    attr.data = attr.data.half()

    model.apply(_convert_weights_to_fp16)
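
# Illustrative usage (a sketch, not from the original file): convert the
# eligible parameters of a small model to fp16 in place.
def _convert_weights_example():
    model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))
    convert_weights(model)
    assert model[0].weight.dtype == torch.float16
    return model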