demo.py
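# Summary (derived from the code below, not from upstream documentation):
# demo.py loads a pretrained SENet-based depth estimation model, predicts a
# depth map for one input image, writes grey and colourised depth maps plus a
# volume report to the output directory, and estimates per-item food volumes
# from the depth map and a JSON annotation file via get_volume() and get_mask().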
import argparse
import os

import cv2
import numpy as np
import torch
import torch.nn.parallel

# Project-local modules: model definitions, data loading, and post-processing.
import modules, net, resnet, densenet, senet
import loaddata_demo as loaddata
from volume import get_volume
from mask import get_mask
parser = argparse.ArgumentParser(description='KD-network')
parser.add_argument('--img', metavar='DIR', default="./input/test.jpg",
                    help='path to the input image')
parser.add_argument('--json', metavar='DIR', default="./input/test.json",
                    help='path to the JSON annotation file for the image')
parser.add_argument('--output', metavar='DIR', default="./output",
                    help='directory to write results to')
args = parser.parse_args()
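# Example invocation (the paths below are just the argparse defaults above;
# substitute your own image, annotation file, and output directory):
#   python demo.py --img ./input/test.jpg --json ./input/test.json --output ./output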
def define_model(is_resnet, is_densenet, is_senet):
    # Build the encoder-decoder depth estimation network with the chosen
    # backbone. Exactly one of the three flags is expected to be True;
    # main() selects the SENet backbone.
    if is_resnet:
        original_model = resnet.resnet50(pretrained=True)
        Encoder = modules.E_resnet(original_model)
        model = net.model(Encoder, num_features=2048, block_channel=[256, 512, 1024, 2048])
    if is_densenet:
        original_model = densenet.densenet161(pretrained=True)
        Encoder = modules.E_densenet(original_model)
        model = net.model(Encoder, num_features=2208, block_channel=[192, 384, 1056, 2208])
    if is_senet:
        original_model = senet.senet154(pretrained='imagenet')
        Encoder = modules.E_senet(original_model)
        model = net.model(Encoder, num_features=2048, block_channel=[256, 512, 1024, 2048])
    return model
def main():
    if not os.path.exists(args.output):
        print("Output directory doesn't exist! Creating...")
        os.makedirs(args.output)

    # Wrap the model in DataParallel before loading so the checkpoint's
    # parameter names (presumably saved with a 'module.' prefix) match.
    device = torch.device('cpu')
    model = define_model(is_resnet=False, is_densenet=False, is_senet=True)
    model = torch.nn.DataParallel(model)
    model.load_state_dict(torch.load('./pretrained_model/model_senet', map_location=device))
    model.eval()

    # Read the original image only to recover its width/height for resizing the output.
    img = cv2.imread(args.img)
    nyu2_loader = loaddata.readNyu2(args.img)
    return test(nyu2_loader, model, img.shape[1], img.shape[0])
def test(nyu2_loader, model, width, height):
    volumes = []
    with torch.no_grad():
        for i, image in enumerate(nyu2_loader):
            # Keep only the RGB channels and run the depth estimation network.
            image = image[:, :3, :, :]
            out = model(image)

            # Normalise the predicted depth map to 0-255 and resize it back to
            # the original image resolution.
            out = out.view(out.size(2), out.size(3)).data.cpu().numpy()
            max_pix = out.max()
            min_pix = out.min()
            out = (out - min_pix) / (max_pix - min_pix) * 255
            out = cv2.resize(out, (width, height), interpolation=cv2.INTER_CUBIC)

            # Save the grey depth map and a JET-colourised version of it.
            cv2.imwrite(os.path.join(args.output, "out_grey.png"), out)
            out_grey = cv2.imread(os.path.join(args.output, "out_grey.png"), 0)
            out_color = cv2.applyColorMap(out_grey, cv2.COLORMAP_JET)
            cv2.imwrite(os.path.join(args.output, "out_color.png"), out_color)

            # Estimate volumes (in cm^3) from the depth map and the JSON
            # annotations, and write them to out.txt.
            vol = get_volume(out_grey, args.json)
            volumes.append(vol)
            with open(os.path.join(args.output, "out.txt"), "w") as out_file:
                out_file.write("Volume:\n")
                out_file.write(str(vol))
                out_file.write("\n")
                out_file.write("unit: cm^3")
            get_mask(out_grey, args.json)

    # Divide each estimated volume by 3 and print the per-item results.
    for v in volumes[0]:
        volumes[0][v] //= 3
        print(v, volumes[0][v])
    return volumes
if __name__ == '__main__':
    main()