```# Copyright 2021 Facebook and Zilliz. All rights reserved.
#
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
# Unless required by applicable law or agreed to in writing, software
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and

import math
from typing import Callable, Iterable, Tuple
import torch
from torch import nn
from torch.optim import Optimizer

"""
Implements Adam algorithm with weight decay fix as introduced in `Decoupled Weight Decay Regularization
<https://arxiv.org/abs/1711.05101>`_.
Parameters:
params (Iterable[nn.parameter.Parameter]):
Iterable of parameters to optimize or dictionaries defining parameter groups.
lr (float, optional):
The learning rate to use.
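betas (Tuple[float,float], optional):
Adam's beta coefficients (beta1, beta2) for the running averages of the gradient and of its square.
eps (float, optional):
Adam's epsilon, a small constant for numerical stability.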
weight_decay (float, optional):
Decoupled weight decay to apply.
correct_bias (bool, optional):
Whether or not to correct bias in Adam.
"""

def __init__(
    self,
    params: Iterable[nn.parameter.Parameter],
    lr: float = 1e-3,
    betas: Tuple[float, float] = (0.9, 0.999),
    eps: float = 1e-6,
    weight_decay: float = 0.0,
    correct_bias: bool = True,
):
    """
    Validates the hyperparameters and registers them as the optimizer defaults.

    Arguments:
        params (Iterable[nn.parameter.Parameter]):
            Iterable of parameters to optimize or dictionaries defining parameter groups.
        lr (float, optional):
            The learning rate to use. Must be >= 0.0.
        betas (Tuple[float,float], optional):
            Adam's beta coefficients (beta1, beta2). Each must be in [0.0, 1.0[.
        eps (float, optional):
            Adam's epsilon. Must be >= 0.0.
        weight_decay (float, optional):
            Decoupled weight decay to apply. Must be >= 0.0.
        correct_bias (bool, optional):
            Whether or not to correct bias in Adam.

    Raises:
        ValueError: If ``lr``, one of ``betas``, ``eps`` or ``weight_decay`` is outside
            its valid range.
    """
    if lr < 0.0:
        raise ValueError(f"Invalid learning rate: {lr} - should be >= 0.0")
    if not 0.0 <= betas[0] < 1.0:
        raise ValueError(f"Invalid beta parameter: {betas[0]} - should be in [0.0, 1.0[")
    if not 0.0 <= betas[1] < 1.0:
        raise ValueError(f"Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0[")
    if eps < 0.0:
        raise ValueError(f"Invalid epsilon value: {eps} - should be >= 0.0")
    # step() only applies the decay when weight_decay > 0.0, so a negative value
    # would otherwise be accepted and then silently ignored. Fail fast instead.
    if weight_decay < 0.0:
        raise ValueError(f"Invalid weight_decay value: {weight_decay} - should be >= 0.0")

    defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, correct_bias=correct_bias)
    super().__init__(params, defaults)

[docs]    def step(self, closure: Callable = None):
"""
Performs a single optimization step.
Arguments:
closure(Callable, optional):
A closure that reevaluates the model and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()

for group in self.param_groups:
for p in group["params"]:
continue

state = self.state[p]

# State initialization
if len(state) == 0:
state["step"] = 0
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(p.data)
# Exponential moving average of squared gradient values
state["exp_avg_sq"] = torch.zeros_like(p.data)

exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
beta1, beta2 = group["betas"]

state["step"] += 1

# Decay the first and second moment running average coefficient
# In-place operations to update the averages at the same time

step_size = group["lr"]
if group["correct_bias"]:  # No bias correction for Bert
bias_correction1 = 1.0 - beta1 ** state["step"]
bias_correction2 = 1.0 - beta2 ** state["step"]
step_size = step_size * math.sqrt(bias_correction2) / bias_correction1

# Just adding the square of the weights to the loss function is *not*
# the correct way of using L2 regularization/weight decay with Adam,
# since that will interact with the m and v parameters in strange ways.
#
# Instead we want to decay the weights in a manner that doesn't interact
# with the m/v parameters. This is equivalent to adding the square
# of the weights to the loss with plain (non-momentum) SGD.
# Add weight decay at the end (fixed version)
if group["weight_decay"] > 0.0: