forked from charlesq34/pointnet
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathregress.py
113 lines (91 loc) · 3.56 KB
/
regress.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import tensorflow as tf
from random import shuffle
import numpy as np
import time
import os
import sys
from os import listdir
from os.path import isfile, join
import argparse
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(BASE_DIR)
sys.path.append(os.path.join(BASE_DIR, 'models'))
sys.path.append(os.path.join(BASE_DIR, 'utils'))
sys.path.append(os.path.join(BASE_DIR, 'log'))
import scipy.misc
import provider
import pc_util
import importlib
from plyfile import (PlyData, PlyElement, make2d, PlyParseError, PlyProperty)
import leboudyNet as MODEL
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
# NOTE(review): this is set after tensorflow has already been imported above,
# so log messages emitted during the import itself may not be silenced.

# Command-line interface: either a single .ply file or a directory of them.
parser = argparse.ArgumentParser()
parser.add_argument('--ply_path', default='', help='ply file to classify')
parser.add_argument('--batch_ply_path', default='', help='folder where .ply files exist, if set, will classify the files in one go')
FLAGS = parser.parse_args()

# Network input size. BATCH_SIZE stays 2 in single-file mode (the single
# cloud is zero-padded up to the batch size) and becomes the number of
# discovered files in batch mode.
BATCH_SIZE = 2
NUM_POINT = 2048
MODEL_PATH = 'log/model.ckpt'

testFile = FLAGS.ply_path
testDir = FLAGS.batch_ply_path
onlyPlyfiles = []
if testDir:
    # Sorted so that batch output order is deterministic (listdir order is not).
    onlyPlyfiles = sorted(join(testDir, f) for f in listdir(testDir)
                          if f.endswith('.ply') and isfile(join(testDir, f)))
    if not onlyPlyfiles:
        # Without this, BATCH_SIZE would become 0 and evaluation would
        # silently fall back to the (possibly empty) --ply_path.
        parser.error('no .ply files found in --batch_ply_path %r' % testDir)
    BATCH_SIZE = len(onlyPlyfiles)
elif not testFile:
    parser.error('one of --ply_path or --batch_ply_path is required')

# Index -> class-name table.
# NOTE(review): not referenced by the visible code of this script; looks like
# a leftover from a classification variant — confirm before removing.
reverseDict = {0: "bird", 1: "bond", 2: "can", 3: "cracker", 4: "house", 5: "shoe", 6: "teapot"}
def evaluate(num_votes):
    """Build the model graph, restore its weights and run one evaluation pass.

    The graph is built via ``MODEL.placeholder_inputs`` / ``MODEL.get_model``
    with the module-level ``BATCH_SIZE`` and ``NUM_POINT``. Variables are
    restored from ``MODEL_PATH``, and ``eval_one_epoch`` is run inside a
    session that is closed afterwards.

    Args:
        num_votes: forwarded unchanged to ``eval_one_epoch``.
    """
    # Only the point-cloud placeholder is fed at inference time; the other
    # two placeholders returned by the model are not needed here.
    pointclouds_pl, _, _ = MODEL.placeholder_inputs(BATCH_SIZE, NUM_POINT)
    is_training_pl = tf.placeholder(tf.bool, shape=())

    predictedposesx, predictedposesq = MODEL.get_model(pointclouds_pl, is_training_pl)

    # Add ops to save and restore all the variables.
    saver = tf.train.Saver()

    ops = {'pointclouds_pl': pointclouds_pl,
           'is_training_pl': is_training_pl,
           'predictedposesx': predictedposesx,
           'predictedposesq': predictedposesq,
           }

    # Context-managed so the session's resources are released when
    # evaluation finishes (the session was previously never closed).
    with tf.Session() as sess:
        saver.restore(sess, MODEL_PATH)
        eval_one_epoch(sess, ops, num_votes)
def _pad_to_batch(samples, batch_size):
    """Stack ``samples`` on a new leading axis, zero-padded to ``batch_size`` rows.

    The input placeholder has a fixed batch dimension, so when there are
    fewer samples than that, zero-filled rows of the same shape and dtype
    are appended. Their predictions are meant to be ignored.

    Args:
        samples: non-empty sequence of equally-shaped array-likes.
        batch_size: required length of the leading axis.

    Returns:
        numpy array whose ``shape[0] == batch_size``.

    Raises:
        ValueError: if ``samples`` is empty or longer than ``batch_size``.
    """
    if len(samples) == 0:
        raise ValueError('no point clouds to evaluate')
    data = np.asarray(samples)
    missing = batch_size - data.shape[0]
    if missing < 0:
        raise ValueError('got %d point clouds for a batch of size %d'
                         % (data.shape[0], batch_size))
    if missing:
        padding = np.zeros((missing,) + data.shape[1:], dtype=data.dtype)
        data = np.concatenate([data, padding], axis=0)
    return data


def _format_prediction(name, predx, predq):
    """Return one output line: ``<name>,<str(predx)>,<str(predq)>``."""
    return name + "," + str(predx) + "," + str(predq)


def eval_one_epoch(sess, ops, num_votes=1, topk=1):
    """Run one forward pass over the configured .ply input(s) and print predictions.

    The inputs come from module-level settings: every file in
    ``onlyPlyfiles`` when a batch directory was given, otherwise the single
    ``testFile``. Each file is loaded with ``provider.load_ply_data``, and
    the batch is zero-padded to ``BATCH_SIZE``. One line
    ``<file>,<predictedposesx row>,<predictedposesq row>`` is printed per real
    input file; padding rows are not printed.

    Args:
        sess: session holding the restored model variables.
        ops: dict with the 'pointclouds_pl', 'is_training_pl',
            'predictedposesx' and 'predictedposesq' graph tensors.
        num_votes: unused; kept so existing callers keep working.
        topk: unused; kept so existing callers keep working.
    """
    ply_paths = onlyPlyfiles if onlyPlyfiles else [testFile]
    samples = [provider.load_ply_data(path) for path in ply_paths]
    batch = _pad_to_batch(samples, BATCH_SIZE)

    feed_dict = {ops['pointclouds_pl']: batch,
                 ops['is_training_pl']: False}
    predx, predq = sess.run([ops['predictedposesx'], ops['predictedposesq']],
                            feed_dict=feed_dict)

    # Row i of the outputs corresponds to ply_paths[i]. Previously only row 0
    # was printed, labelled with testFile even in batch mode.
    for i, path in enumerate(ply_paths):
        print(_format_prediction(path, predx[i], predq[i]))
if __name__ == '__main__':
    # Enter the fresh graph first, then the device scope. In TF1,
    # tf.device() attaches its scope to whichever graph is default at that
    # moment. The previous nesting therefore pinned the *outer* default
    # graph to the CPU and left the model's graph unconstrained.
    with tf.Graph().as_default(), tf.device('/cpu:0'):
        evaluate(num_votes=1)