forked from ryankiros/skip-thoughts
-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathpenseur_utils.py
77 lines (67 loc) · 3.03 KB
/
penseur_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
# When running a script that uses this module, prefix the command with the
# THEANO_FLAGS environment variable, e.g.:
# THEANO_FLAGS=mode=FAST_RUN,device=gpu,floatX=float32 python <your_script>.py
import sys, os
import cPickle as pickle
# Modules evicted from sys.modules by _prepare_import, keyed by
# (absolute directory they were loaded from, module name).
_PARKED_MODULES = {}


def _prepare_import(directory):
    """Make the next bare-name ``import`` resolve to modules in ``directory``.

    training/ and decoding/ provide modules with the same names (``vocab``,
    ``tools`` and the generated ``train_temp`` are imported from both in this
    file). Python caches imports by name in ``sys.modules``. Without this,
    whichever directory was imported from first would keep winning, and the
    wrong package's code would run.

    ``directory`` is moved to the front of ``sys.path``. Duplicate entries are
    dropped, so repeated calls do not grow the list.

    Every cached module that ``directory`` also provides as a ``.py`` file,
    but that was loaded from a different directory, is parked in
    ``_PARKED_MODULES``. A module previously parked from ``directory`` is put
    back, so switching packages back and forth does not re-execute module
    code. Parking also keeps evicted modules alive: CPython 2 resets a
    module's globals to None when the module object is freed, which would
    break functions that are still in use.

    Modules without ``__file__`` (e.g. built-ins) are never evicted.
    """
    # Rebind in place so every holder of the sys.path list sees the change.
    sys.path[:] = [entry for entry in sys.path if entry != directory]
    sys.path.insert(0, directory)
    if not os.path.isdir(directory):
        # Leave the failure to the caller's import/open, exactly as before.
        return
    wanted = os.path.abspath(directory)
    for filename in os.listdir(directory):
        name, ext = os.path.splitext(filename)
        if ext != '.py':
            continue
        cached = sys.modules.get(name)
        source = getattr(cached, '__file__', None)
        if source is not None:
            origin = os.path.dirname(os.path.abspath(source))
            if origin != wanted:
                _PARKED_MODULES[(origin, name)] = sys.modules.pop(name)
        if name not in sys.modules and (wanted, name) in _PARKED_MODULES:
            sys.modules[name] = _PARKED_MODULES.pop((wanted, name))


def _build_trainer(package_dir, name_of_data, model_suffix, default_n_words,
                   n_words, max_epochs, save_frequency):
    """Generate ``<package_dir>/train_temp.py`` from ``train.py`` and import it.

    The defaults hard-coded in ``train.py`` are rewritten by plain textual
    substitution, in the order listed below. A default that does not appear
    verbatim is silently left unchanged. Because matching is textual,
    'max_epochs=5' would also match the start of 'max_epochs=50'.

    ``package_dir`` must already be first on ``sys.path`` (see
    ``_prepare_import``). Returns the freshly imported ``train_temp`` module.
    """
    with open(os.path.join(package_dir, 'train.py'), 'r') as f:
        text = f.read()
    replacements = [
        ('max_epochs=5', 'max_epochs=' + str(max_epochs)),
        ("saveto='/u/rkiros/research/semhash/models/toy.npz'",
         "saveto='data/" + name_of_data + model_suffix + ".npz'"),
        ("dictionary='/ais/gobi3/u/rkiros/bookgen/book_dictionary_large.pkl'",
         "dictionary='data/" + name_of_data + "_dictionary.pkl'"),
        ('n_words=' + str(default_n_words), 'n_words=' + str(n_words)),
        ('saveFreq=1000', 'saveFreq=' + str(save_frequency)),
    ]
    for old, new in replacements:
        text = text.replace(old, new)
    with open(os.path.join(package_dir, 'train_temp.py'), 'w') as f:
        f.write(text)
    # The file was just rewritten. Drop any cached train_temp (from an earlier
    # call, or from the other package) so the import executes the new text.
    sys.modules.pop('train_temp', None)
    import train_temp
    return train_temp


def train_encoder(name_of_data, sentences, max_epochs=5, save_frequency=1000):
    """Build a vocabulary from ``sentences`` and run the training/ encoder trainer.

    Writes ``data/<name_of_data>_dictionary.pkl`` (via
    ``vocab.save_dictionary``) and ``data/<name_of_data>_sen.p`` (the pickled
    sentences). It then generates ``training/train_temp.py`` from
    ``training/train.py`` with this run's settings and calls its ``trainer``.
    The ``saveto`` substituted into the script is
    ``data/<name_of_data>_encoder.npz``, which is the path ``load_encoder``
    reads.

    Args:
        name_of_data: prefix for every file written under ``data/``.
        sentences: the corpus. It is passed to ``vocab.build_dictionary``,
            pickled, and passed to the trainer.
        max_epochs: substituted for ``max_epochs=5`` in train.py.
        save_frequency: substituted for ``saveFreq=1000`` in train.py.

    Returns:
        Whatever ``train_temp.trainer(sentences)`` returns.
    """
    if not os.path.exists('data/'):
        os.makedirs('data')
    _prepare_import('training/')
    import vocab
    worddict, wordcount = vocab.build_dictionary(sentences)
    vocab.save_dictionary(worddict, wordcount,
                          'data/' + name_of_data + '_dictionary.pkl')
    # Close the handle deterministically. Pickle output is bytes, so use
    # binary mode.
    with open('data/' + name_of_data + '_sen.p', 'wb') as f:
        pickle.dump(sentences, f)
    train_temp = _build_trainer('training/', name_of_data, '_encoder', 20000,
                                len(wordcount.keys()), max_epochs,
                                save_frequency)
    return train_temp.trainer(sentences)


def load_encoder(model_name,
                 word2vec_path='data/GoogleNews-vectors-negative300.bin'):
    """Load the encoder that ``train_encoder`` saved under ``model_name``.

    Calls training/tools.load_model with ``data/<model_name>_encoder.npz``,
    ``data/<model_name>_dictionary.pkl`` and ``word2vec_path``.
    ``word2vec_path`` is presumably the pretrained word2vec binary; it
    defaults to the file that used to be hard-coded here.
    """
    _prepare_import('training/')
    import tools
    return tools.load_model('data/' + model_name + '_encoder.npz',
                            'data/' + model_name + '_dictionary.pkl',
                            word2vec_path)


def encode(encoder, sentences, verbose=False):
    """Encode ``sentences`` with an encoder returned by ``load_encoder``.

    Delegates to ``training/tools.encode(encoder, sentences)`` and returns
    its result.

    NOTE(review): ``verbose`` is accepted but not forwarded, so it currently
    has no effect. Confirm tools.encode's signature before wiring it through.
    """
    _prepare_import('training/')
    import tools
    return tools.encode(encoder, sentences)


def train_decoder(name_of_data, sentences, model, max_epochs=5,
                  save_frequency=1000):
    """Build a vocabulary from ``sentences`` and run the decoding/ decoder trainer.

    Writes ``data/<name_of_data>_dictionary.pkl`` (via
    ``vocab.save_dictionary``). It then generates ``decoding/train_temp.py``
    from ``decoding/train.py`` with this run's settings and returns
    ``trainer(sentences, sentences, model)``. The ``saveto`` substituted into
    the script is ``data/<name_of_data>_decoder.npz``, which is the path
    ``load_decoder`` reads.

    NOTE(review): this is the same dictionary path ``train_encoder`` writes.
    Reusing one ``name_of_data`` for both overwrites the encoder's vocabulary
    file with the one built by decoding/vocab. Confirm this is intended.

    Args:
        name_of_data: prefix for every file written under ``data/``.
        sentences: the corpus. It is passed to ``vocab.build_dictionary`` and
            twice to the trainer.
        model: passed through as the trainer's third argument.
        max_epochs: substituted for ``max_epochs=5`` in train.py.
        save_frequency: substituted for ``saveFreq=1000`` in train.py.
    """
    if not os.path.exists('data/'):
        os.makedirs('data')
    _prepare_import('decoding/')
    import vocab
    worddict, wordcount = vocab.build_dictionary(sentences)
    vocab.save_dictionary(worddict, wordcount,
                          'data/' + name_of_data + '_dictionary.pkl')
    train_temp = _build_trainer('decoding/', name_of_data, '_decoder', 40000,
                                len(wordcount.keys()), max_epochs,
                                save_frequency)
    return train_temp.trainer(sentences, sentences, model)


def load_decoder(decoder_name):
    """Load the decoder that ``train_decoder`` saved under ``decoder_name``.

    Calls decoding/tools.load_model with ``data/<decoder_name>_decoder.npz``
    and ``data/<decoder_name>_dictionary.pkl``.
    """
    _prepare_import('decoding/')
    import tools
    return tools.load_model('data/' + decoder_name + '_decoder.npz',
                            'data/' + decoder_name + '_dictionary.pkl')


def decode(decoder, vector, num_results=1):
    """Generate sentences from ``vector`` with a decoder from ``load_decoder``.

    Calls decoding/tools.run_sampler with ``beam_width=num_results``. When
    ``num_results`` is 1, only the first result is returned. Otherwise the
    sampler's full result is returned unchanged.
    """
    _prepare_import('decoding/')
    import tools
    sentences = tools.run_sampler(decoder, vector, beam_width=num_results)
    if num_results == 1:
        return sentences[0]
    return sentences