Source code for towhee.models.loss.focal_loss

# Copyright clcarwin. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This code is modified by Zilliz.

import torch
import torch.nn.functional as F

from torch import nn
from torch.autograd import Variable


class FocalLoss(nn.Module):
    """
    Focal loss for classification / dense prediction.

    For every sample the loss is ``-alpha_t * (1 - p_t) ** gamma * log(p_t)``,
    where ``p_t`` is the softmax probability assigned to the ground-truth
    class. With ``gamma=0`` and ``alpha=None`` this equals the ordinary
    cross-entropy loss.

    Args:
        gamma (`float`):
            Focusing exponent applied to ``(1 - p_t)``.
        alpha (`float`, `int`, `list`, `tuple` or `torch.Tensor`, optional):
            Per-class weights. A scalar ``a`` is expanded to the two-class
            weights ``[a, 1 - a]``; a list/tuple is converted to a tensor
            indexed by class id; a tensor is used as given; ``None``
            disables class weighting.
        size_average (`bool`):
            If True return the mean of the per-sample losses, otherwise
            return their sum.
    """

    def __init__(self, gamma: float = 0, alpha: float = None, size_average: bool = True):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        if isinstance(alpha, (float, int)):
            # Binary shorthand: class 0 is weighted `alpha`, class 1 `1 - alpha`.
            self.alpha = torch.Tensor([alpha, 1 - alpha])
        elif isinstance(alpha, (list, tuple)):
            # Tuples used to be stored unconverted and crashed later in forward().
            self.alpha = torch.Tensor(alpha)
        self.size_average = size_average

    def forward(self, input_s: torch.Tensor, target: torch.Tensor):
        """
        Compute the focal loss.

        Args:
            input_s (`torch.Tensor`):
                Unnormalised class scores of shape (N, C), or (N, C, d1, ...)
                in which case all trailing dimensions are flattened into the
                sample dimension.
            target (`torch.Tensor`):
                Ground-truth class indices, one per sample after flattening
                (e.g. shape (N,) or (N, d1, ...)). Used as a gather index.

        Returns:
            `torch.Tensor`: scalar loss, averaged or summed over all samples
            according to `size_average`.
        """
        if input_s.dim() > 2:
            # N,C,d1,... => N,C,L => N,L,C => N*L,C
            input_s = input_s.view(input_s.size(0), input_s.size(1), -1)
            input_s = input_s.transpose(1, 2)
            input_s = input_s.contiguous().view(-1, input_s.size(2))
        # reshape (not view) so non-contiguous targets, e.g. permuted masks, work.
        target = target.reshape(-1, 1)

        # Explicit class dimension; the implicit-dim form is deprecated.
        logpt = F.log_softmax(input_s, dim=1)
        logpt = logpt.gather(1, target).view(-1)
        # The modulating factor is deliberately kept out of the autograd graph,
        # matching the previous `.data.exp()` behaviour (gradients unchanged).
        pt = logpt.detach().exp()

        if self.alpha is not None:
            # Cache alpha on the dtype/device of the scores so later calls are free.
            if self.alpha.dtype != input_s.dtype or self.alpha.device != input_s.device:
                self.alpha = self.alpha.to(device=input_s.device, dtype=input_s.dtype)
            at = self.alpha.gather(0, target.view(-1))
            logpt = logpt * at

        loss = -1 * (1 - pt) ** self.gamma * logpt
        if self.size_average:
            return loss.mean()
        return loss.sum()