test_save_scripts.py
import os.path as osp
import time

import numpy as np
import torch
from torch import nn
from tqdm import tqdm
from torchvision import utils as vutils

from advent.utils.func import per_class_iu, fast_hist
from advent.utils.tools import write_images


def eval_best(cfg, model,
              device, test_loader, comet_exp,
              verbose):
    """Evaluate a single saved checkpoint on the test set, log a running mIoU
    to Comet, and save qualitative overlays of the predicted masks."""
    fixed_test_size = cfg.TEST.Model.fixed_test_size
    if fixed_test_size:
        interp = nn.Upsample(size=(cfg.TEST.Model.OUTPUT_SIZE_TARGET[1],
                                   cfg.TEST.Model.OUTPUT_SIZE_TARGET[0]),
                             mode='bilinear', align_corners=True)
    cur_best_miou = -1
    cur_best_model = ''
    i_iter = cfg.TEST.Model.test_iter
    restore_from = osp.join(cfg.TEST.Model.SNAPSHOT_DIR, f'model_{i_iter}.pth')
    if not osp.exists(restore_from):
        # Optionally wait for the checkpoint to appear; if WAIT_MODEL is off,
        # loading below will fail while the file is still missing.
        if cfg.TEST.WAIT_MODEL:
            print('Waiting for model..!')
            while not osp.exists(restore_from):
                time.sleep(5)
    print("Evaluating model", restore_from)
    load_checkpoint_for_evaluation(model, restore_from, device)

    # Evaluation: accumulate the confusion matrix over the whole test set.
    hist = np.zeros((cfg.NUM_CLASSES, cfg.NUM_CLASSES))
    test_iter = enumerate(test_loader)
    for index in tqdm(range(len(test_loader))):
        _, batch = next(test_iter)
        image, label, path = batch['data']['x'][0], batch['data']['m'][0], batch['paths']['x'][0]
        if cfg.data.real_files.base == "/network/tmp1/ccai/data/mayCogSciData/Provinces-Mila-complete":
            print("The dataset to eval is mayCogSciData!")
            savePath = changeSubString(path)
        image = image[None, :, :, :]
        if not fixed_test_size:
            interp = nn.Upsample(size=(label.shape[1], label.shape[2]),
                                 mode='bilinear', align_corners=True)
        with torch.no_grad():
            pred_main = model(image.cuda(device))[1]
            output = interp(pred_main).cpu().data[0].numpy()
            output = output.transpose(1, 2, 0)
            output = np.argmax(output, axis=2)
        label0 = label.numpy()[0]
        hist += fast_hist(label0.flatten(), output.flatten(), cfg.NUM_CLASSES)
        output = torch.tensor(output, dtype=torch.float32)
        output = output[None, :, :]
        output_RGB = output.repeat(3, 1, 1)

        # Running mIoU over the images processed so far.
        inters_over_union_classes = per_class_iu(hist)
        computed_miou = round(np.nanmean(inters_over_union_classes) * 100, 2)
        if cur_best_miou < computed_miou:
            cur_best_miou = computed_miou
            cur_best_model = f'model_{i_iter}.pth'
        print('\tCurrent mIoU:', computed_miou)
        print('\tCurrent best model:', cur_best_model)
        print('\tCurrent best mIoU:', cur_best_miou)
        mious = {'Current mIoU': computed_miou,
                 'Current best model': cur_best_model,
                 'Current best mIoU': cur_best_miou}
        comet_exp.log_metrics(mious)

        image = image[0]  # change size from [1, x, y, z] to [x, y, z]
        save_images = []
        save_images.append(image)
        # Overlay the ground-truth and predicted masks on the input image:
        save_mask = (
            image
            - (image * label.repeat(3, 1, 1))
            + label.repeat(3, 1, 1)
        )
        save_fake_mask = (
            image
            - (image * output_RGB)
            + output_RGB
        )
        if cfg.data.real_files.base == "/network/tmp1/ccai/data/mayCogSciData/Provinces-Mila-complete":
            print("Saving the overlay pictures!")
            print("SavePath: ", savePath)
            vutils.save_image(output_RGB, savePath, normalize=True)
        save_images.append(save_mask)
        save_images.append(save_fake_mask)
        save_images.append(label.repeat(3, 1, 1))
        save_images.append(output_RGB)
        write_images(
            save_images,
            i_iter,
            comet_exp=comet_exp,
            store_im=cfg.TEST.store_images
        )
    return computed_miou, cur_best_model, cur_best_miou


def load_checkpoint_for_evaluation(model, checkpoint, device):
    saved_state_dict = torch.load(checkpoint)
    model.load_state_dict(saved_state_dict)
    model.eval()
    model.cuda(device)


def display_stats(cfg, name_classes, inters_over_union_classes):
    for ind_class in range(cfg.NUM_CLASSES):
        print(name_classes[ind_class]
              + '\t' + str(round(inters_over_union_classes[ind_class] * 100, 2)))


def changeSubString(Str):
    # Build the output path for a predicted mask by swapping the dataset folder
    # "Provinces-Mila-complete" for "Provinces-Mila-complete-MaskGen" in the
    # input image path.
    tmp = (Str.split("Provinces-Mila-complete")[0]
           + "Provinces-Mila-complete-MaskGen"
           + Str.split("Provinces-Mila-complete")[1])
    return tmp
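

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original script): a minimal, hypothetical way
# to call eval_best. The DataLoader settings and the Comet project name below
# are assumptions; `cfg` is expected to be the project's config object exposing
# the fields read above (cfg.TEST.Model.*, cfg.TEST.WAIT_MODEL,
# cfg.TEST.store_images, cfg.NUM_CLASSES, cfg.data.real_files.base).
def run_eval_sketch(cfg, model, test_dataset):
    from comet_ml import Experiment
    from torch.utils.data import DataLoader

    device = 0  # CUDA device index, as expected by image.cuda(device) above
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
    comet_exp = Experiment(project_name="advent-eval")  # hypothetical project name
    return eval_best(cfg, model, device, test_loader, comet_exp, verbose=True)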