# Source code for torch.nn.modules.normalization

import torch
import numbers
from torch.nn.parameter import Parameter
from .module import Module
from ._functions import CrossMapLRN2d as _cross_map_lrn2d
from .. import functional as F
from .. import init

from torch import Tensor, Size
from typing import Union, List, Tuple


class LocalResponseNorm(Module):
    r"""Applies local response normalization over an input signal composed
    of several input planes, where channels occupy the second dimension.
    Applies normalization across channels.

    .. math::
        b_{c} = a_{c}\left(k + \frac{\alpha}{n}
        \sum_{c'=\max(0, c-n/2)}^{\min(N-1,c+n/2)}a_{c'}^2\right)^{-\beta}

    Args:
        size: amount of neighbouring channels used for normalization
        alpha: multiplicative factor. Default: 0.0001
        beta: exponent. Default: 0.75
        k: additive factor. Default: 1

    Shape:
        - Input: :math:`(N, C, *)`
        - Output: :math:`(N, C, *)` (same shape as input)

    Examples::

        >>> lrn = nn.LocalResponseNorm(2)
        >>> signal_2d = torch.randn(32, 5, 24, 24)
        >>> signal_4d = torch.randn(16, 5, 7, 7, 7, 7)
        >>> output_2d = lrn(signal_2d)
        >>> output_4d = lrn(signal_4d)

    """
    __constants__ = ['size', 'alpha', 'beta', 'k']
    size: int
    alpha: float
    beta: float
    k: float

    def __init__(self, size: int, alpha: float = 1e-4, beta: float = 0.75, k: float = 1.) -> None:
        super(LocalResponseNorm, self).__init__()
        # Plain Python attributes (no parameters or buffers): the layer has
        # nothing to learn and only carries its hyper-parameters.
        self.size = size
        self.alpha = alpha
        self.beta = beta
        self.k = k

    def forward(self, input: Tensor) -> Tensor:
        """Return ``F.local_response_norm`` of ``input`` using the stored hyper-parameters."""
        return F.local_response_norm(input, self.size, self.alpha, self.beta,
                                     self.k)

    def extra_repr(self) -> str:
        """Return the hyper-parameter summary shown inside the module's ``repr``."""
        return '{size}, alpha={alpha}, beta={beta}, k={k}'.format(**self.__dict__)


class CrossMapLRN2d(Module):
    r"""Module wrapper around the ``CrossMapLRN2d`` autograd function
    imported from ``._functions``.

    The constructor only records its hyper-parameters; :meth:`forward` passes
    them, together with the input, to the function's ``apply``.

    NOTE(review): presumably the legacy cross-channel local response
    normalization for 2d feature maps -- confirm against ``._functions``.

    Args:
        size: forwarded unchanged to the autograd function
        alpha: forwarded unchanged to the autograd function. Default: 0.0001
        beta: forwarded unchanged to the autograd function. Default: 0.75
        k: forwarded unchanged to the autograd function. Default: 1
    """
    size: int
    alpha: float
    beta: float
    k: float

    def __init__(self, size: int, alpha: float = 1e-4, beta: float = 0.75, k: float = 1) -> None:
        super().__init__()
        self.size, self.alpha, self.beta, self.k = size, alpha, beta, k

    def forward(self, input: Tensor) -> Tensor:
        """Apply the wrapped autograd function to ``input``."""
        lrn_apply = _cross_map_lrn2d.apply
        return lrn_apply(input, self.size, self.alpha, self.beta, self.k)

    def extra_repr(self) -> str:
        """Return the hyper-parameter summary shown inside the module's ``repr``."""
        return f'{self.size}, alpha={self.alpha}, beta={self.beta}, k={self.k}'


# Type alias for a shape argument: a single int, a list of ints, or a
# torch.Size. Used to annotate LayerNorm's ``normalized_shape`` parameter.
_shape_t = Union[int, List[int], Size]


class LayerNorm(Module):
    r"""Applies Layer Normalization over a mini-batch of inputs as described in
    the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`__

    .. math::
        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated over the last `D` dimensions, where `D`
    is the dimension of :attr:`normalized_shape`. For example, if :attr:`normalized_shape`
    is ``(3, 5)`` (a 2-dimensional shape), the mean and standard-deviation are computed over
    the last 2 dimensions of the input (i.e. ``input.mean((-2, -1))``).
    :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
    :attr:`normalized_shape` if :attr:`elementwise_affine` is ``True``.
    The standard-deviation is calculated via the biased estimator, equivalent to
    `torch.var(input, unbiased=False)`.

    .. note::
        Unlike Batch Normalization and Instance Normalization, which apply
        scalar scale and bias for each entire channel/plane with the
        :attr:`affine` option, Layer Normalization applies per-element scale and
        bias with :attr:`elementwise_affine`.

    This layer uses statistics computed from input data in both training and
    evaluation modes.

    Args:
        normalized_shape (int or list or torch.Size): input shape from an expected input
            of size

            .. math::
                [* \times \text{normalized\_shape}[0] \times \text{normalized\_shape}[1]
                    \times \ldots \times \text{normalized\_shape}[-1]]

            If a single integer is used, it is treated as a singleton list, and this module will
            normalize over the last dimension which is expected to be of that specific size.
        eps: a value added to the denominator for numerical stability. Default: 1e-5
        elementwise_affine: a boolean value that when set to ``True``, this module
            has learnable per-element affine parameters initialized to ones (for weights)
            and zeros (for biases). Default: ``True``.
        device: forwarded to ``torch.empty`` when the affine parameters are allocated.
            Default: ``None``
        dtype: forwarded to ``torch.empty`` when the affine parameters are allocated.
            Default: ``None``

    Attributes:
        weight: the learnable weights of the module of shape
            :math:`\text{normalized\_shape}` when :attr:`elementwise_affine` is set to ``True``.
            The values are initialized to 1.
        bias:   the learnable bias of the module of shape
                :math:`\text{normalized\_shape}` when :attr:`elementwise_affine` is set to ``True``.
                The values are initialized to 0.

    Shape:
        - Input: :math:`(N, *)`
        - Output: :math:`(N, *)` (same shape as input)

    Examples::

        >>> # NLP Example
        >>> batch, sentence_length, embedding_dim = 20, 5, 10
        >>> embedding = torch.randn(batch, sentence_length, embedding_dim)
        >>> layer_norm = nn.LayerNorm(embedding_dim)
        >>> # Activate module
        >>> layer_norm(embedding)
        >>>
        >>> # Image Example
        >>> N, C, H, W = 20, 5, 10, 10
        >>> input = torch.randn(N, C, H, W)
        >>> # Normalize over the last three dimensions (i.e. the channel and spatial dimensions)
        >>> # as shown in the image below
        >>> layer_norm = nn.LayerNorm([C, H, W])
        >>> output = layer_norm(input)

    .. image:: ../_static/img/nn/layer_norm.jpg
        :scale: 50 %

    """
    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
    normalized_shape: Tuple[int, ...]
    eps: float
    elementwise_affine: bool

    def __init__(self, normalized_shape: _shape_t, eps: float = 1e-5, elementwise_affine: bool = True,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            # A bare integer means "normalize over one trailing dimension".
            # mypy error: incompatible types in assignment
            normalized_shape = (normalized_shape,)  # type: ignore[assignment]
        # Always stored as a tuple, whatever sequence type the caller passed.
        self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            # Allocated uninitialized; reset_parameters() fills in the values.
            self.weight = Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
            self.bias = Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
        else:
            # Register the names as None so ``self.weight`` / ``self.bias``
            # still resolve in forward().
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Set ``weight`` to ones and ``bias`` to zeros; no-op without affine parameters."""
        if self.elementwise_affine:
            init.ones_(self.weight)
            init.zeros_(self.bias)

    def forward(self, input: Tensor) -> Tensor:
        """Return ``F.layer_norm`` of ``input`` using the stored shape, parameters and eps."""
        return F.layer_norm(
            input, self.normalized_shape, self.weight, self.bias, self.eps)

    def extra_repr(self) -> str:
        """Return the configuration summary shown inside the module's ``repr``."""
        return '{normalized_shape}, eps={eps}, ' \
            'elementwise_affine={elementwise_affine}'.format(**self.__dict__)


class GroupNorm(Module):
    r"""Applies Group Normalization over a mini-batch of inputs as described in
    the paper `Group Normalization <https://arxiv.org/abs/1803.08494>`__

    .. math::
        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The input channels are separated into :attr:`num_groups` groups, each containing
    ``num_channels / num_groups`` channels. :attr:`num_channels` must be divisible by
    :attr:`num_groups`. The mean and standard-deviation are calculated
    separately over each group. :math:`\gamma` and :math:`\beta` are learnable
    per-channel affine transform parameter vectors of size :attr:`num_channels` if
    :attr:`affine` is ``True``.
    The standard-deviation is calculated via the biased estimator, equivalent to
    `torch.var(input, unbiased=False)`.

    This layer uses statistics computed from input data in both training and
    evaluation modes.

    Args:
        num_groups (int): number of groups to separate the channels into
        num_channels (int): number of channels expected in input
        eps: a value added to the denominator for numerical stability. Default: 1e-5
        affine: a boolean value that when set to ``True``, this module
            has learnable per-channel affine parameters initialized to ones (for weights)
            and zeros (for biases). Default: ``True``.
        device: forwarded to ``torch.empty`` when the affine parameters are allocated.
            Default: ``None``
        dtype: forwarded to ``torch.empty`` when the affine parameters are allocated.
            Default: ``None``

    Raises:
        ValueError: if :attr:`num_channels` is not divisible by :attr:`num_groups`.

    Shape:
        - Input: :math:`(N, C, *)` where :math:`C=\text{num\_channels}`
        - Output: :math:`(N, C, *)` (same shape as input)

    Examples::

        >>> input = torch.randn(20, 6, 10, 10)
        >>> # Separate 6 channels into 3 groups
        >>> m = nn.GroupNorm(3, 6)
        >>> # Separate 6 channels into 6 groups (equivalent with InstanceNorm)
        >>> m = nn.GroupNorm(6, 6)
        >>> # Put all 6 channels into a single group (equivalent with LayerNorm)
        >>> m = nn.GroupNorm(1, 6)
        >>> # Activating the module
        >>> output = m(input)
    """
    __constants__ = ['num_groups', 'num_channels', 'eps', 'affine']
    num_groups: int
    num_channels: int
    eps: float
    affine: bool

    def __init__(self, num_groups: int, num_channels: int, eps: float = 1e-5, affine: bool = True,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(GroupNorm, self).__init__()
        # Reject configurations where the channels cannot be split evenly.
        if num_channels % num_groups != 0:
            raise ValueError('num_channels must be divisible by num_groups')

        self.num_groups = num_groups
        self.num_channels = num_channels
        self.eps = eps
        self.affine = affine
        if self.affine:
            # Allocated uninitialized; reset_parameters() fills in the values.
            self.weight = Parameter(torch.empty(num_channels, **factory_kwargs))
            self.bias = Parameter(torch.empty(num_channels, **factory_kwargs))
        else:
            # Register the names as None so ``self.weight`` / ``self.bias``
            # still resolve in forward().
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Set ``weight`` to ones and ``bias`` to zeros; no-op without affine parameters."""
        if self.affine:
            init.ones_(self.weight)
            init.zeros_(self.bias)

    def forward(self, input: Tensor) -> Tensor:
        """Return ``F.group_norm`` of ``input`` using the stored groups, parameters and eps."""
        return F.group_norm(
            input, self.num_groups, self.weight, self.bias, self.eps)

    def extra_repr(self) -> str:
        """Return the configuration summary shown inside the module's ``repr``."""
        return '{num_groups}, {num_channels}, eps={eps}, ' \
            'affine={affine}'.format(**self.__dict__)


# TODO: ContrastiveNorm2d
# TODO: DivisiveNorm2d
# TODO: SubtractiveNorm2d