multiclass 대한 Dice coefficient 구현하기
배치 단위 이미지가 들어가면, 클래스별로 dice coefficient의 배치당 평균들의 클래스별 평균을 리턴하는 함수를 구현하고자 한다.
ex) batch size=2, num_classes = 3인 경우
클래스 0에 대한 dice coefficient의 평균
클래스 1에 대한 dice coefficient의 평균
클래스 2에 대한 dice coefficient의 평균
⇒ 셋의 합 / 3 ( 클래스 수로 평균)
def dice_coefficient(pred, target, num_classes, ignore_idx=None):
'''
softmax_pred: (N, C, H, W), ndarray
target : (N, H, W), ndarray
'''
assert pred.shape[0] == target.shape[0]
if num_classes == 2:
epsilon = 1e-6
# pred = np.around(pred) # 0.0 of 1.0
dice = 0
# if both a and b are 1-D arrays, it is inner product of vectors(without complex conjugation)
for batch in range(pred.shape[0]):
inter = np.dot(pred[batch].reshape((-1,)), target[batch].reshape((-1,)))
sum_sets = np.sum(pred[batch]) + np.sum(target[batch])
dice += (2*inter+epsilon) / (sum_sets + epsilon)
return dice / pred.shape[0]
else:
softmax = nn.Softmax(dim=1)
pred = softmax(torch.from_numpy(pred).type(torch.float64))
pred = np.array(pred)
dice = 0
for c in range(num_classes):
if c==ignore_idx:
continue
dice += dice_coefficient(pred[:, c, :, :], np.where(target==c, 1, 0), 2, ignore_idx)
return dice / num_classes
Python
복사
Dice Loss
보통 클래스 개수가 1개일 때, binary classification일 때 사용을 많이 하며 이 때는 sigmoid를,
multi class 대해서 사용하려면 softmax를 사용해야 한다
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
def dice_coefficient(pred:Tensor, target:Tensor, num_classes:int, ignore_idx=None):
assert pred.shape[0] == target.shape[0]
epsilon = 1e-6
if num_classes == 2:
dice = 0
# if both a and b are 1-D arrays, it is inner product of vectors(without complex conjugation)
for batch in range(pred.shape[0]):
pred_1d = pred[batch].view(-1)
target_1d = target[batch].view(-1)
inter = (pred_1d * target_1d).sum()
sum_sets = pred_1d.sum() + target_1d.sum()
dice += (2*inter+epsilon) / (sum_sets + epsilon)
return dice / pred.shape[0]
elif num_classes == 1:
dice = 0
pred = F.Sigmoid(pred)
for batch in range(pred.shape[0]):
pred_1d = pred[batch].view(-1)
target_1d = target[batch].view(-1)
inter = (pred_1d * target_1d).sum()
sum_sets = pred_1d.sum() + target_1d.sum()
dice += (2*inter+epsilon) / (sum_sets + epsilon)
else:
pred = F.softmax(pred, dim=1).float()
dice = 0
for c in range(num_classes):
if c==ignore_idx:
continue
dice += dice_coefficient(pred[:, c, :, :], torch.where(target==c, 1, 0), 2, ignore_idx)
return dice / num_classes
def dice_loss(pred, target, num_classes, ignore_idx=None):
dice = dice_coefficient(pred, target, num_classes, ignore_idx)
return 1 - dice
Python
복사
foreground에만 집중하는 loss라고 생각할 수 있는데, 응용하면 background까지도 고려하는 loss를 만들 수 있다.