| import math |
| import numpy as np |
|
|
| from typing import Optional |
|
|
| import torch |
| import torch.nn.functional as F |
|
|
| __all__ = [ |
| "focal_loss_with_logits", |
| "softmax_focal_loss_with_logits", |
| "soft_jaccard_score", |
| "soft_dice_score", |
| "wing_loss", |
| ] |
|
|
|
|
| def to_tensor(x, dtype=None) -> torch.Tensor: |
| if isinstance(x, torch.Tensor): |
| if dtype is not None: |
| x = x.type(dtype) |
| return x |
| if isinstance(x, np.ndarray): |
| x = torch.from_numpy(x) |
| if dtype is not None: |
| x = x.type(dtype) |
| return x |
| if isinstance(x, (list, tuple)): |
| x = np.array(x) |
| x = torch.from_numpy(x) |
| if dtype is not None: |
| x = x.type(dtype) |
| return x |
|
|
|
|
| def focal_loss_with_logits( |
| output: torch.Tensor, |
| target: torch.Tensor, |
| gamma: float = 2.0, |
| alpha: Optional[float] = 0.25, |
| reduction: str = "mean", |
| normalized: bool = False, |
| reduced_threshold: Optional[float] = None, |
| eps: float = 1e-6, |
| ) -> torch.Tensor: |
| """Compute binary focal loss between target and output logits. |
| See :class:`~pytorch_toolbelt.losses.FocalLoss` for details. |
| |
| Args: |
| output: Tensor of arbitrary shape (predictions of the model) |
| target: Tensor of the same shape as input |
| gamma: Focal loss power factor |
| alpha: Weight factor to balance positive and negative samples. Alpha must be in [0...1] range, |
| high values will give more weight to positive class. |
| reduction (string, optional): Specifies the reduction to apply to the output: |
| 'none' | 'mean' | 'sum' | 'batchwise_mean'. 'none': no reduction will be applied, |
| 'mean': the sum of the output will be divided by the number of |
| elements in the output, 'sum': the output will be summed. Note: :attr:`size_average` |
| and :attr:`reduce` are in the process of being deprecated, and in the meantime, |
| specifying either of those two args will override :attr:`reduction`. |
| 'batchwise_mean' computes mean loss per sample in batch. Default: 'mean' |
| normalized (bool): Compute normalized focal loss (https://arxiv.org/pdf/1909.07829.pdf). |
| reduced_threshold (float, optional): Compute reduced focal loss (https://arxiv.org/abs/1903.01347). |
| |
| References: |
| https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/loss/losses.py |
| """ |
| target = target.type(output.type()) |
|
|
| logpt = F.binary_cross_entropy_with_logits(output, target, reduction="none") |
| pt = torch.exp(-logpt) |
|
|
| |
| if reduced_threshold is None: |
| focal_term = (1.0 - pt).pow(gamma) |
| else: |
| focal_term = ((1.0 - pt) / reduced_threshold).pow(gamma) |
| focal_term[pt < reduced_threshold] = 1 |
|
|
| loss = focal_term * logpt |
|
|
| if alpha is not None: |
| loss *= alpha * target + (1 - alpha) * (1 - target) |
|
|
| if normalized: |
| norm_factor = focal_term.sum().clamp_min(eps) |
| loss /= norm_factor |
|
|
| if reduction == "mean": |
| loss = loss.mean() |
| if reduction == "sum": |
| loss = loss.sum() |
| if reduction == "batchwise_mean": |
| loss = loss.sum(0) |
|
|
| return loss |
|
|
|
|
| def softmax_focal_loss_with_logits( |
| output: torch.Tensor, |
| target: torch.Tensor, |
| gamma: float = 2.0, |
| reduction="mean", |
| normalized=False, |
| reduced_threshold: Optional[float] = None, |
| eps: float = 1e-6, |
| ) -> torch.Tensor: |
| """Softmax version of focal loss between target and output logits. |
| See :class:`~pytorch_toolbelt.losses.FocalLoss` for details. |
| |
| Args: |
| output: Tensor of shape [B, C, *] (Similar to nn.CrossEntropyLoss) |
| target: Tensor of shape [B, *] (Similar to nn.CrossEntropyLoss) |
| reduction (string, optional): Specifies the reduction to apply to the output: |
| 'none' | 'mean' | 'sum' | 'batchwise_mean'. 'none': no reduction will be applied, |
| 'mean': the sum of the output will be divided by the number of |
| elements in the output, 'sum': the output will be summed. Note: :attr:`size_average` |
| and :attr:`reduce` are in the process of being deprecated, and in the meantime, |
| specifying either of those two args will override :attr:`reduction`. |
| 'batchwise_mean' computes mean loss per sample in batch. Default: 'mean' |
| normalized (bool): Compute normalized focal loss (https://arxiv.org/pdf/1909.07829.pdf). |
| reduced_threshold (float, optional): Compute reduced focal loss (https://arxiv.org/abs/1903.01347). |
| """ |
| log_softmax = F.log_softmax(output, dim=1) |
|
|
| loss = F.nll_loss(log_softmax, target, reduction="none") |
| pt = torch.exp(-loss) |
|
|
| |
| if reduced_threshold is None: |
| focal_term = (1.0 - pt).pow(gamma) |
| else: |
| focal_term = ((1.0 - pt) / reduced_threshold).pow(gamma) |
| focal_term[pt < reduced_threshold] = 1 |
|
|
| loss = focal_term * loss |
|
|
| if normalized: |
| norm_factor = focal_term.sum().clamp_min(eps) |
| loss = loss / norm_factor |
|
|
| if reduction == "mean": |
| loss = loss.mean() |
| if reduction == "sum": |
| loss = loss.sum() |
| if reduction == "batchwise_mean": |
| loss = loss.sum(0) |
|
|
| return loss |
|
|
|
|
| def soft_jaccard_score( |
| output: torch.Tensor, target: torch.Tensor, smooth: float = 0.0, eps: float = 1e-7, dims=None |
| ) -> torch.Tensor: |
| assert output.size() == target.size() |
| if dims is not None: |
| intersection = torch.sum(output * target, dim=dims) |
| cardinality = torch.sum(output + target, dim=dims) |
| else: |
| intersection = torch.sum(output * target) |
| cardinality = torch.sum(output + target) |
|
|
| union = cardinality - intersection |
| jaccard_score = (intersection + smooth) / (union + smooth).clamp_min(eps) |
| return jaccard_score |
|
|
|
|
| def soft_dice_score( |
| output: torch.Tensor, target: torch.Tensor, smooth: float = 0.0, eps: float = 1e-7, dims=None |
| ) -> torch.Tensor: |
| assert output.size() == target.size() |
| if dims is not None: |
| intersection = torch.sum(output * target, dim=dims) |
| cardinality = torch.sum(output + target, dim=dims) |
| else: |
| intersection = torch.sum(output * target) |
| cardinality = torch.sum(output + target) |
| dice_score = (2.0 * intersection + smooth) / (cardinality + smooth).clamp_min(eps) |
| return dice_score |
|
|
|
|
| def soft_tversky_score(output: torch.Tensor, target: torch.Tensor, alpha: float, beta: float, |
| smooth: float = 0.0, eps: float = 1e-7, dims=None) -> torch.Tensor: |
| assert output.size() == target.size() |
| if dims is not None: |
| intersection = torch.sum(output * target, dim=dims) |
| fp = torch.sum(output * (1. - target), dim=dims) |
| fn = torch.sum((1 - output) * target, dim=dims) |
| else: |
| intersection = torch.sum(output * target) |
| fp = torch.sum(output * (1. - target)) |
| fn = torch.sum((1 - output) * target) |
|
|
| tversky_score = (intersection + smooth) / (intersection + alpha * fp + beta * fn + smooth).clamp_min(eps) |
| return tversky_score |
|
|
|
|
| def wing_loss(output: torch.Tensor, target: torch.Tensor, width=5, curvature=0.5, reduction="mean"): |
| """ |
| https://arxiv.org/pdf/1711.06753.pdf |
| :param output: |
| :param target: |
| :param width: |
| :param curvature: |
| :param reduction: |
| :return: |
| """ |
| diff_abs = (target - output).abs() |
| loss = diff_abs.clone() |
|
|
| idx_smaller = diff_abs < width |
| idx_bigger = diff_abs >= width |
|
|
| loss[idx_smaller] = width * torch.log(1 + diff_abs[idx_smaller] / curvature) |
|
|
| C = width - width * math.log(1 + width / curvature) |
| loss[idx_bigger] = loss[idx_bigger] - C |
|
|
| if reduction == "sum": |
| loss = loss.sum() |
|
|
| if reduction == "mean": |
| loss = loss.mean() |
|
|
| return loss |
|
|
|
|
| def label_smoothed_nll_loss( |
| lprobs: torch.Tensor, target: torch.Tensor, epsilon: float, ignore_index=None, reduction="mean", dim=-1 |
| ) -> torch.Tensor: |
| """ |
| Source: https://github.com/pytorch/fairseq/blob/master/fairseq/criterions/label_smoothed_cross_entropy.py |
| :param lprobs: Log-probabilities of predictions (e.g after log_softmax) |
| :param target: |
| :param epsilon: |
| :param ignore_index: |
| :param reduction: |
| :return: |
| """ |
| if target.dim() == lprobs.dim() - 1: |
| target = target.unsqueeze(dim) |
|
|
| if ignore_index is not None: |
| pad_mask = target.eq(ignore_index) |
| target = target.masked_fill(pad_mask, 0) |
| nll_loss = -lprobs.gather(dim=dim, index=target) |
| smooth_loss = -lprobs.sum(dim=dim, keepdim=True) |
|
|
| |
| |
| nll_loss = nll_loss.masked_fill(pad_mask, 0.0) |
| smooth_loss = smooth_loss.masked_fill(pad_mask, 0.0) |
| else: |
| nll_loss = -lprobs.gather(dim=dim, index=target) |
| smooth_loss = -lprobs.sum(dim=dim, keepdim=True) |
|
|
| nll_loss = nll_loss.squeeze(dim) |
| smooth_loss = smooth_loss.squeeze(dim) |
|
|
| if reduction == "sum": |
| nll_loss = nll_loss.sum() |
| smooth_loss = smooth_loss.sum() |
| if reduction == "mean": |
| nll_loss = nll_loss.mean() |
| smooth_loss = smooth_loss.mean() |
|
|
| eps_i = epsilon / lprobs.size(dim) |
| loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss |
| return loss |
|
|