-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathloss.py
86 lines (57 loc) · 2.14 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import segmentation_models_pytorch as smp
import torch
import numpy as np
import torch.nn as nn
import config
# Name of the loss chosen in the project config; the `criterion` definitions
# below branch on this value.
loss = config.LOSS

# Shared segmentation_models_pytorch loss objects. The region/ranking losses
# are all built in binary mode; BCE is the with-logits variant.
JaccardLoss = smp.losses.JaccardLoss(smooth=1.0, mode="binary")
DiceLoss = smp.losses.DiceLoss(smooth=1.0, mode="binary")
FocalLoss = smp.losses.FocalLoss(mode="binary")
BCELoss = smp.losses.SoftBCEWithLogitsLoss()
LovaszLoss = smp.losses.LovaszLoss(per_image=False, mode="binary")
TverskyLoss = smp.losses.TverskyLoss(log_loss=False, mode="binary")
### Declaration
# Training criterion selected by config.LOSS (stored in `loss`). Each entry
# maps a config name to a callable (y_pred, y_true) -> loss value.
# NOTE(review): BCELoss is SoftBCEWithLogitsLoss, so y_pred is presumably raw
# logits — confirm against the training loop.
_CRITERIA = {
    # Soft Dice on its own. This name used to compute 0.5 * BCE + 0.5 * Tversky,
    # an exact duplicate of "BCE_Tversky" that never used DiceLoss.
    "Dice": lambda y_pred, y_true: DiceLoss(y_pred, y_true),
    "BCE": lambda y_pred, y_true: BCELoss(y_pred, y_true),
    "BCE_Tversky": lambda y_pred, y_true: (
        0.5 * BCELoss(y_pred, y_true) + 0.5 * TverskyLoss(y_pred, y_true)
    ),
    "Lovasz": lambda y_pred, y_true: LovaszLoss(y_pred, y_true),
    "Dice_BCE": lambda y_pred, y_true: (
        0.5 * BCELoss(y_pred, y_true) + 0.5 * DiceLoss(y_pred, y_true)
    ),
    "FocalLoss": lambda y_pred, y_true: FocalLoss(y_pred, y_true),
}

# Resolved once at import time from the configured loss name.
_selected_criterion = _CRITERIA.get(loss)


def criterion(y_pred, y_true):
    """Return the configured training loss for ``y_pred`` against ``y_true``.

    The concrete loss is chosen by ``config.LOSS`` via ``_CRITERIA``.

    Raises:
        ValueError: if ``config.LOSS`` names no known criterion. The check is
            done on call rather than at import so that importing this module
            (e.g. only for ``dice_metric``) never fails.
    """
    if _selected_criterion is None:
        raise ValueError(
            f"Unknown config.LOSS {loss!r}; expected one of {sorted(_CRITERIA)}"
        )
    return _selected_criterion(y_pred, y_true)
def dice_metric(_mask1, _mask2):
    """Return the batch-averaged Dice coefficient between two mask tensors.

    Each sample's score is ``(2 * sum(a * b) + eps) / (sum(a) + sum(b) + eps)``
    taken over all of that sample's elements. The per-sample scores are
    averaged over the batch and rounded to 4 decimal places.

    Args:
        _mask1: torch tensor whose first dimension is the batch (e.g. predicted
            masks). Presumably binary 0/1 values — non-binary inputs are not
            rejected and give the soft Dice coefficient.
        _mask2: torch tensor with the same per-sample element count as
            ``_mask1`` (e.g. ground-truth masks); it may be on another device.

    Returns:
        Mean Dice over the batch as a Python ``float`` rounded to 4 decimals.
        A sample where both masks are empty scores 1.0 because of ``eps``.

    Raises:
        ValueError: if the batch is empty.
    """
    batch_size = _mask1.shape[0]
    if batch_size == 0:
        raise ValueError("dice_metric received an empty batch")

    # Flatten each sample to one row, so any per-sample shape works (1xHxW,
    # CxHxW, HxW, ...). detach() allows tensors that still track gradients.
    # float64 keeps the element counts accurate for large masks.
    pred = _mask1.detach().reshape(batch_size, -1).double()
    true = _mask2.detach().reshape(batch_size, -1).to(pred.device).double()

    eps = 1e-7  # keeps the ratio defined (== 1.0) when both masks are empty
    intersect = (pred * true).sum(dim=1)
    total = pred.sum(dim=1) + true.sum(dim=1)
    per_sample = (2 * intersect + eps) / (total + eps)

    # .item() yields a plain Python float; round for easy reading.
    return round(per_sample.mean().item(), 4)