towhee.models.utils.weight_init.trunc_normal_(tensor: Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0) → Tensor

Fills the input Tensor with values drawn from a truncated normal distribution. The values are effectively drawn from the normal distribution \(\mathcal{N}(\text{mean}, \text{std}^2)\) with values outside \([a, b]\) redrawn until they are within the bounds. The method used for generating the random values works best when \(a \leq \text{mean} \leq b\).

:param tensor: an n-dimensional torch.Tensor
:param mean: the mean of the normal distribution
:param std: the standard deviation of the normal distribution
:param a: the minimum cutoff value
:param b: the maximum cutoff value


>>> import torch
>>> from towhee.models.utils.weight_init import trunc_normal_
>>> w = torch.empty(3, 5)
>>> trunc_normal_(w)
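
The "redrawn until they are within the bounds" behavior described above can be illustrated with a small pure-Python sketch. This is not the library's implementation (which fills a torch.Tensor in place and may use a different sampling method); the helper names `trunc_normal_sample` and `fill_trunc_normal` are hypothetical and exist only for this illustration. It assumes `std > 0` and `a < b`.

```python
import random


def trunc_normal_sample(mean=0.0, std=1.0, a=-2.0, b=2.0, rng=random):
    """Draw one value from N(mean, std^2), redrawing until it lies in [a, b].

    Hypothetical helper for illustration only; not part of towhee or torch.
    """
    if std <= 0:
        raise ValueError("std must be positive")
    if a >= b:
        raise ValueError("a must be smaller than b")
    while True:
        x = rng.gauss(mean, std)
        if a <= x <= b:  # keep the draw only if it falls inside the cutoffs
            return x


def fill_trunc_normal(n, mean=0.0, std=1.0, a=-2.0, b=2.0, seed=None):
    """Return a list of n truncated-normal samples (a stand-in for a tensor)."""
    rng = random.Random(seed)  # private generator so results are reproducible
    return [trunc_normal_sample(mean, std, a, b, rng) for _ in range(n)]


values = fill_trunc_normal(1000, seed=0)
print(min(values) >= -2.0 and max(values) <= 2.0)
```

In this sketch, the number of redraws grows quickly when `mean` lies far outside `[a, b]`, because almost every draw is rejected. Keeping `mean` between the two cutoffs avoids that cost.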