Search

[Segmentation metric] Dice coefficient, Dice Loss

multiclass 대한 Dice coefficient 구현하기

배치 단위 이미지가 들어가면, 클래스별로 dice coefficient의 배치당 평균들의 클래스별 평균을 리턴하는 함수를 구현하고자 한다.
ex) batch size=2, num_classes = 3인 경우
클래스 0에 대한 dice coefficient의 평균
클래스 1에 대한 dice coefficient의 평균
클래스 2에 대한 dice coefficient의 평균 ⇒ 셋의 합 / 3 ( 클래스 수로 평균)
def dice_coefficient(pred, target, num_classes, ignore_idx=None): ''' softmax_pred: (N, C, H, W), ndarray target : (N, H, W), ndarray ''' assert pred.shape[0] == target.shape[0] if num_classes == 2: epsilon = 1e-6 # pred = np.around(pred) # 0.0 of 1.0 dice = 0 # if both a and b are 1-D arrays, it is inner product of vectors(without complex conjugation) for batch in range(pred.shape[0]): inter = np.dot(pred[batch].reshape((-1,)), target[batch].reshape((-1,))) sum_sets = np.sum(pred[batch]) + np.sum(target[batch]) dice += (2*inter+epsilon) / (sum_sets + epsilon) return dice / pred.shape[0] else: softmax = nn.Softmax(dim=1) pred = softmax(torch.from_numpy(pred).type(torch.float64)) pred = np.array(pred) dice = 0 for c in range(num_classes): if c==ignore_idx: continue dice += dice_coefficient(pred[:, c, :, :], np.where(target==c, 1, 0), 2, ignore_idx) return dice / num_classes
Python
복사

Dice Loss

DiceLoss=1DiceCoefficientDiceLoss = 1-DiceCoefficient
0DiceCoefficient1 0\leq DiceCoefficient \leq 1
보통 클래스 개수가 1개일 때, binary classification일 때 사용을 많이 하며 이 때는 sigmoid를, multi class 대해서 사용하려면 softmax를 사용해야 한다
import torch from torch import Tensor import torch.nn as nn import torch.nn.functional as F def dice_coefficient(pred:Tensor, target:Tensor, num_classes:int, ignore_idx=None): assert pred.shape[0] == target.shape[0] epsilon = 1e-6 if num_classes == 2: dice = 0 # if both a and b are 1-D arrays, it is inner product of vectors(without complex conjugation) for batch in range(pred.shape[0]): pred_1d = pred[batch].view(-1) target_1d = target[batch].view(-1) inter = (pred_1d * target_1d).sum() sum_sets = pred_1d.sum() + target_1d.sum() dice += (2*inter+epsilon) / (sum_sets + epsilon) return dice / pred.shape[0] elif num_classes == 1: dice = 0 pred = F.Sigmoid(pred) for batch in range(pred.shape[0]): pred_1d = pred[batch].view(-1) target_1d = target[batch].view(-1) inter = (pred_1d * target_1d).sum() sum_sets = pred_1d.sum() + target_1d.sum() dice += (2*inter+epsilon) / (sum_sets + epsilon) else: pred = F.softmax(pred, dim=1).float() dice = 0 for c in range(num_classes): if c==ignore_idx: continue dice += dice_coefficient(pred[:, c, :, :], torch.where(target==c, 1, 0), 2, ignore_idx) return dice / num_classes def dice_loss(pred, target, num_classes, ignore_idx=None): dice = dice_coefficient(pred, target, num_classes, ignore_idx) return 1 - dice
Python
복사
foreground에만 집중하는 loss라고 생각할 수 있는데, 응용하면 background까지도 고려하는 loss를 만들 수 있다.
2(AB+(1A)(1B))A+B+(2AB)\frac{2(|A\cap B|+|(1-A)\cap (1-B)|)}{|A|+|B|+(2-|A|-|B|)}

Compound Loss = CrossEntropy Loss + Dice Loss