# udc_model.py (forked from dennybritz/chatbot-retrieval)
import tensorflow as tf
import sys


def get_id_feature(features, key, len_key, max_len):
  """Looks up a token-id tensor and its sequence length (clipped to max_len) from `features`."""
  ids = features[key]
  ids_len = tf.squeeze(features[len_key], [1])
  ids_len = tf.minimum(ids_len, tf.constant(max_len, dtype=tf.int64))
  return ids, ids_len

def create_train_op(loss, hparams):
  """Builds the training op using the configured optimizer with gradients clipped at 10.0."""
  train_op = tf.contrib.layers.optimize_loss(
      loss=loss,
      global_step=tf.contrib.framework.get_global_step(),
      learning_rate=hparams.learning_rate,
      clip_gradients=10.0,
      optimizer=hparams.optimizer)
  return train_op
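
# A note on the `model_impl` argument used by create_model_fn below: it is expected
# to be a callable with the signature (hparams, mode, context, context_len,
# utterance, utterance_len, targets) that returns a (probs, loss) pair. The stub
# below is a minimal, hypothetical sketch of that interface only; it is not the
# dual-encoder model normally used with this file.
def _example_model_impl(hparams, mode, context, context_len,
                        utterance, utterance_len, targets):
  # Hypothetical placeholder: score every context/utterance pair with probability
  # 0.5 and return a constant zero loss, just to illustrate the expected shapes.
  batch_size = tf.shape(context)[0]
  probs = tf.fill(tf.stack([batch_size, 1]), 0.5)
  loss = None if targets is None else tf.constant(0.0)
  return probs, loss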

def create_model_fn(hparams, model_impl):

  def model_fn(features, targets, mode):
    context, context_len = get_id_feature(
        features, "context", "context_len", hparams.max_context_len)

    utterance, utterance_len = get_id_feature(
        features, "utterance", "utterance_len", hparams.max_utterance_len)

    # Changed for TF 1.2 (jw, 2017-10-13): `targets` is None in INFER mode, so the
    # batch size is read from it only in EVAL mode, where it is actually needed.
    # See https://github.com/dennybritz/chatbot-retrieval/issues/32#issuecomment-265819109
    if mode == tf.contrib.learn.ModeKeys.EVAL:
      batch_size = targets.get_shape().as_list()[0]
    if mode == tf.contrib.learn.ModeKeys.TRAIN:
      probs, loss = model_impl(
          hparams,
          mode,
          context,
          context_len,
          utterance,
          utterance_len,
          targets)
      train_op = create_train_op(loss, hparams)
      return probs, loss, train_op

    if mode == tf.contrib.learn.ModeKeys.INFER:
      probs, loss = model_impl(
          hparams,
          mode,
          context,
          context_len,
          utterance,
          utterance_len,
          None)
      return probs, 0.0, None
    if mode == tf.contrib.learn.ModeKeys.EVAL:
      # Each eval record holds 10 examples (the ground-truth utterance plus 9
      # distractors), so we accumulate them into a single batch.
      all_contexts = [context]
      all_context_lens = [context_len]
      all_utterances = [utterance]
      all_utterance_lens = [utterance_len]
      all_targets = [tf.ones([batch_size, 1], dtype=tf.int64)]

      for i in range(9):
        distractor, distractor_len = get_id_feature(
            features,
            "distractor_{}".format(i),
            "distractor_{}_len".format(i),
            hparams.max_utterance_len)
        all_contexts.append(context)
        all_context_lens.append(context_len)
        all_utterances.append(distractor)
        all_utterance_lens.append(distractor_len)
        all_targets.append(
            tf.zeros([batch_size, 1], dtype=tf.int64)
        )
      # tf.concat / tf.split updated to the TF 1.x argument order (values, axis). jw, 2017-10-09.
      probs, loss = model_impl(
          hparams,
          mode,
          tf.concat(all_contexts, 0),
          tf.concat(all_context_lens, 0),
          tf.concat(all_utterances, 0),
          tf.concat(all_utterance_lens, 0),
          tf.concat(all_targets, 0))

      split_probs = tf.split(probs, 10, 0)
      shaped_probs = tf.concat(split_probs, 1)

      # Add summaries: tf.histogram_summary / tf.scalar_summary replaced by
      # tf.summary.histogram / tf.summary.scalar (TF 1.x API). jw, 2017-10-10.
      tf.summary.histogram("eval_correct_probs_hist", split_probs[0])
      tf.summary.scalar("eval_correct_probs_average", tf.reduce_mean(split_probs[0]))
      tf.summary.histogram("eval_incorrect_probs_hist", split_probs[1])
      tf.summary.scalar("eval_incorrect_probs_average", tf.reduce_mean(split_probs[1]))

      return shaped_probs, loss, None

  return model_fn
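
# Rough usage sketch (assumption for illustration only: the hparams helper, the
# dual_encoder_model implementation, and the input functions live in other files
# of this repository and are not defined here):
#
#   hparams = udc_hparams.create_hparams()
#   model_fn = create_model_fn(hparams, model_impl=dual_encoder_model)
#   estimator = tf.contrib.learn.Estimator(model_fn=model_fn, model_dir="./runs")
#   estimator.fit(input_fn=train_input_fn, steps=None)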