-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathvisualize_attention.py
90 lines (63 loc) · 2.54 KB
/
visualize_attention.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import re
import os
import time
import numpy as np
import tensorflow as tf
import tensorflow.contrib.eager as tfe
from data import ImageNetDataset
from config import Configuration
from models.alexnet import AlexNet
import cv2
# Switch TensorFlow to eager mode before any model or dataset is built: the
# code below relies on eager-only features (tfe.Iterator, Tensor.numpy()).
tfe.enable_eager_execution()
class Tester(object):
    """Restore a trained network and display its attention on one test sample.

    The network is called on a single batch from the test set and is expected
    to return a 3-tuple ``(output, attention_1, attention_2)``. The second
    attention map (layer 5, reshaped to 6x6) is rendered as a JET heat map and
    blended over the input image.
    """

    # 1-based position of the test-set batch that gets visualised.
    SAMPLE_INDEX = 5

    def __init__(self, cfg, net, testset):
        """Store the collaborators and restore the latest checkpoint.

        Args:
            cfg: configuration object; ``LEARNING_RATE``, ``MOMENTUM``,
                ``CKPT_PATH`` and ``DATA_MEAN`` are read from it.
            net: callable model returning ``(output, att1, att2)``.
            testset: object exposing an iterable ``dataset`` attribute that
                yields ``(image, label)`` pairs.
        """
        self.cfg = cfg
        self.net = net
        self.testset = testset

        # Load the model back. The optimizer is rebuilt only so that the
        # checkpoint object mirrors the structure that was saved at train time.
        self.optimizer = tf.train.MomentumOptimizer(
            learning_rate=self.cfg.LEARNING_RATE,
            momentum=self.cfg.MOMENTUM)
        self.checkpoint_dir = self.cfg.CKPT_PATH
        self.checkpoint_encoder = os.path.join(self.checkpoint_dir, 'Model')
        self.root1 = tfe.Checkpoint(
            optimizer=self.optimizer,
            model=self.net,
            optimizer_step=tf.train.get_or_create_global_step())

        latest = tf.train.latest_checkpoint(self.checkpoint_dir)
        if latest is None:
            # restore(None) does not fail, so without this warning the
            # visualisation would silently run on untrained weights.
            import warnings
            warnings.warn(
                'No checkpoint found in %r; the model keeps its initial '
                'weights.' % (self.checkpoint_dir,),
                stacklevel=2)
        self.root1.restore(latest)

    @staticmethod
    def _nth_batch(batches, position):
        """Return the ``position``-th (1-based) item of ``batches``.

        Iteration stops as soon as the item is found.

        Raises:
            ValueError: if ``position`` is below 1 or ``batches`` is exhausted
                before reaching it.
        """
        if position < 1:
            raise ValueError('position must be >= 1, got %r' % (position,))
        for index, batch in enumerate(batches, 1):
            if index == position:
                return batch
        raise ValueError(
            'test set ended before batch %d was reached' % (position,))

    @staticmethod
    def _to_unit_range(array):
        """Min-max scale ``array`` into [0, 1] with OpenCV."""
        # dst=None lets OpenCV allocate the result itself.
        return cv2.normalize(array, None, 0.0, 1.0, cv2.NORM_MINMAX)

    def test(self, mode):
        """Show the layer-5 attention heat map over one test image.

        Opens three OpenCV windows ('map', 'GradCam', 'image') and blocks
        until a key is pressed.

        Args:
            mode: unused; kept so existing callers keep working.

        Raises:
            ValueError: if the test set holds fewer than ``SAMPLE_INDEX``
                batches.
        """
        # Get the image, label and the attention maps for layers 4 and 5.
        image, _label = self._nth_batch(
            tfe.Iterator(self.testset.dataset), self.SAMPLE_INDEX)
        _output, _att_layer4, att_layer5 = self.net(image)
        image = image[0]

        # Reshape for attention on layer 5.
        # NOTE(review): reshaping to [6, 6] presumably requires a batch of
        # exactly one sample - confirm against the dataset configuration.
        heatmap = self._to_unit_range(tf.reshape(att_layer5, [6, 6]).numpy())

        # Add the mean back to the image, then scale it for display.
        image = self._to_unit_range((image + self.cfg.DATA_MEAN).numpy())

        # Resize the attention map; cv2.resize expects (width, height).
        heatmap = cv2.resize(heatmap, (image.shape[1], image.shape[0]))

        # Bring image and map into the 0-255 range.
        heatmap = np.uint8(255 * heatmap)
        image = np.uint8(255 * image)

        # Colour-code the map and blend it with the image (60% / 40%).
        # NOTE(review): OpenCV displays arrays as BGR; if the dataset yields
        # RGB images the colours shown will be channel-swapped - confirm.
        heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
        superimposed_img = cv2.addWeighted(image, 0.6, heatmap, 0.4, 0)

        try:
            cv2.imshow('map', heatmap)
            cv2.imshow("GradCam", superimposed_img)
            cv2.imshow('image', image)
            cv2.waitKey(0)
        finally:
            # Close the windows even if display is interrupted.
            cv2.destroyAllWindows()
if __name__ == '__main__':
    # Script entry point: build the configuration, the network (with
    # training=False) and the 'test' dataset, then run the visualisation.
    config = Configuration()
    test_batch_path = 'cifar-10-batches-py/test_batch'
    model = AlexNet(config, training=False)
    dataset = ImageNetDataset(config, 'test', test_batch_path)
    Tester(config, model, dataset).test('test')