-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathvisualize_trained.py
135 lines (107 loc) · 5.67 KB
/
visualize_trained.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
# This script loads the trained quality network from disk and saves qualitative result images to disk.
# The images are written to a qual_results folder inside cfg.train.out_dir. If that folder already exists,
# the script raises an error on purpose, to avoid overwriting previous results.
import torch
import matplotlib.pyplot as plt
import os
from config import cfg
from data_factory import get_dataset
import numpy as np
import torch.nn.functional as F
# Checklist before running this script:
#   i.  Make sure cfg.train.out_dir in config.py points at the folder holding the trained model;
#       results are written to a qual_results sub-folder there.
#   ii. model_best is meant to choose between the best model w.r.t. validation loss (True) and the
#       model saved at the end of training (False).
# NOTE(review): main() does not currently read model_best; it always loads
# trained_basemap_checkpoint.pth from cfg.train.out_dir. Changing this flag has no effect until
# main() is wired to the corresponding checkpoint file.
model_best = True
# If True, the Potsdam test set is used (cfg.train.mode = 'test_potsdam'); if False, the
# 'test' mode is used, which the original author describes as the Berlin validation set.
eval_potsdam = True
def _to_numpy(tensor):
    """Return *tensor* as a NumPy array (detached and moved to the CPU) for plotting."""
    return tensor.detach().cpu().numpy()


def _save_image(array, results_dir, fname):
    """Draw *array* with ``imshow`` (axes hidden) and save it to ``results_dir/fname``."""
    fig, ax = plt.subplots()
    ax.imshow(array)
    ax.axis('off')
    fig.savefig(os.path.join(results_dir, fname), bbox_inches='tight')
    # Close every figure once saved: pyplot keeps references to open figures, so
    # leaving them open makes memory grow with each image written.
    plt.close(fig)


def _fuse_with_quality(qual_net, occluded_imgs):
    """Fuse a sequence of occluded images using per-pixel quality weights.

    Args:
        qual_net: network applied to each occluded image; channel 0 of its output
            is used as the per-pixel quality logit.
        occluded_imgs: sequence of 4-D image tensors indexed as
            (batch, H, W, channels). NOTE(review): the channels-last layout is
            inferred from how the images are weighted and shown with imshow;
            confirm against the dataset.

    Returns:
        Tuple ``(base_map, average_image, q_final)``:
            - the quality-weighted fusion;
            - the unweighted mean, used as a baseline;
            - the softmax-normalised quality masks, shaped (batch, num_images, H, W).
    """
    imgs = [img.type('torch.cuda.FloatTensor') for img in occluded_imgs]
    # Quality logits for every input image, stacked -> (batch, num_images, H, W).
    q_pre = torch.stack([qual_net(img)[:, 0, :, :] for img in imgs], dim=1)
    # Softmax across the image dimension so the weights at each pixel sum to 1.
    q_final = F.softmax(q_pre, dim=1)
    stacked = torch.stack(imgs, dim=1)  # (batch, num_images, H, W, channels)
    # Broadcast each pixel weight over the channel axis and sum over images.
    base_map = (q_final.unsqueeze(-1) * stacked).sum(dim=1)
    average_image = stacked.mean(dim=1)
    return base_map, average_image, q_final


def main():
    """Load the trained quality network and save qualitative results to disk.

    Images for the last batch of the selected evaluation split are written to
    ``<cfg.train.out_dir>/qual_results``. For each of (at most) 18 samples it saves:
    the target, the quality-weighted basemap, the unweighted average, every
    occluded input and every quality mask.

    Raises:
        ValueError: if ``cfg.train.out_dir`` does not exist, if ``qual_results``
            already exists (to avoid overwriting earlier results), or if the
            selected dataset yields no batches.
    """
    out_dir = cfg.train.out_dir
    if not os.path.exists(out_dir):
        raise ValueError('The folder does not exist. Make sure to set the correct folder variable cfg.train.out_dir in config.py')
    results_dir = os.path.join(out_dir, 'qual_results')
    try:
        # Creating the folder is itself the "already exists" check, so there is
        # no gap between checking and creating.
        os.makedirs(results_dir)
    except FileExistsError:
        raise ValueError('The validation folder qual_results already exists. Delete the folder if those results are not needed') from None

    # NOTE(review): the module-level model_best flag is not consulted; this always
    # loads trained_basemap_checkpoint.pth.
    # NOTE(review): torch.load unpickles a whole model object, so only load trusted
    # checkpoints. On PyTorch versions where weights_only defaults to True this call
    # may need weights_only=False.
    qual_net = torch.load(os.path.join(out_dir, "trained_basemap_checkpoint.pth"))
    print('Network loaded...')
    print(cfg)

    # Data loader: only the evaluation split is needed.
    cfg.train.mode = 'test_potsdam' if eval_potsdam else 'test'
    ds_test = get_dataset(cfg.train.mode)
    print('Data loaders have been prepared!')

    # Only the last batch is visualised, so keep a reference to it and fuse just that one.
    last_batch = None
    for last_batch in ds_test:
        pass
    if last_batch is None:
        raise ValueError('The evaluation dataset is empty; there is nothing to visualize')

    qual_net.eval()
    with torch.no_grad():
        images = last_batch[0].type('torch.cuda.FloatTensor')  # clean target images
        occluded_imgs = last_batch[2]  # sequence of occluded input image tensors
        base_map, average_image, q_final = _fuse_with_quality(qual_net, occluded_imgs)

    num_fig = min(base_map.shape[0], 18)  # cap the number of samples written
    plt.ioff()
    for k in range(num_fig):
        _save_image(_to_numpy(images[k]), results_dir, f'{k}_target.png')
        _save_image(_to_numpy(base_map[k]), results_dir, f'{k}_out_basemap.png')
        _save_image(_to_numpy(average_image[k]), results_dir, f'{k}_out_average.png')
        for j, img in enumerate(occluded_imgs):
            _save_image(_to_numpy(img[k]), results_dir, f'{k}_image{j}.png')
        for j in range(q_final.shape[1]):
            _save_image(_to_numpy(q_final[k, j]), results_dir, f'{k}_mask{j}.png')
# Script entry point: run the visualisation only when executed directly, not on import.
if __name__ == '__main__':
    main()