forked from b06b01073/go_thesis
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path: train_nature2016.py
82 lines (58 loc) · 2.32 KB
/
train_nature2016.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
from argparse import ArgumentParser
import torch.nn as nn
from torch.optim.lr_scheduler import StepLR
from NNFactory import Nature2016Factory
from OptimFactory import SGDFactory
from Trainer import Trainer
from config.nature2016_config import *
from config.training_config import *
import GoDataset
from seed import set_seed
from TrainingChock import settings_checker, path_checker
import os
if __name__ == '__main__':
    # Entry point: train the Nature-2016 policy network with SGD + StepLR,
    # checkpointing the best model to --save_path/--file_name and the most
    # recent iteration to --latest_path.
    set_seed()  # fix RNG seeds for reproducible training runs

    parser = ArgumentParser()
    # parser.add_argument('--config_path', type=str, default='./hyperparams.json', help='path to the hyperparams file')
    parser.add_argument('--save_path', type=str, default='./trained_model')
    parser.add_argument('--train', type=str, default='./dataset/train.txt', help='path to the train set')
    parser.add_argument('--test', type=str, default='./dataset/test.txt', help='path to the test set')

    # we want the users to explicitly type out the path
    parser.add_argument('--file_name', type=str, required=True, help='name of the saved model (best performing model)')
    parser.add_argument('--log_path', type=str, required=True, help='file name of the log (accuracy during training)')
    parser.add_argument('--latest_path', type=str, required=True, help='file name of the latest iteration of the model')

    args = parser.parse_args()

    # create the folder to save the model if the path does not exist
    # (exist_ok avoids the check-then-create race of os.path.exists + makedirs)
    os.makedirs(args.save_path, exist_ok=True)

    # fail fast if any output path is invalid before training starts
    path_checker(
        args.log_path,
        args.latest_path,
        os.path.join(args.save_path, args.file_name)
    )

    # sanity-check the CLI args against the imported config dicts
    settings_checker(
        args,
        optim_config,
        net_config,
        scheduler_config,
        training_config
    )

    net = Nature2016Factory().createModel()
    optim = SGDFactory().create_optim(net, optim_config)
    scheduler = StepLR(optim, step_size=scheduler_config['step_size'], gamma=scheduler_config['gamma'])
    loss_func = nn.CrossEntropyLoss()

    trainer = Trainer(
        net,
        optim,
        loss_func,
        os.path.join(args.save_path, args.file_name),  # where the best model is saved
        scheduler,
    )

    train_set = GoDataset.get_loader(args.train, 'train')
    test_set = GoDataset.get_loader(args.test, 'test')

    trainer.fit(
        train_set,
        test_set,
        args.log_path,
        args.latest_path
    )