# train_synthesis_cvusa.py -- joint synthesis + retrieval (RGAN) training on CVUSA
# (forked from aysim/comingdowntoearth)
from helper import parser
from data.custom_transforms import *
from data.utils import CVUSA
from networks.c_gan import *
from os.path import exists, join, basename, dirname
from utils import rgan_wrapper, base_wrapper
from utils.setup_helper import *
import time
from argparse import Namespace
import os

# Explicit imports for names used below; the star imports above may already
# provide torch, numpy and torchvision.transforms, but being explicit keeps
# the script self-contained.
import numpy as np
import torch
from torchvision import transforms
if __name__ == '__main__':
    parse = parser.Parser()
    opt, log_file = parse.parse()
    opt.is_Train = True
    make_deterministic(opt.seed)
    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(str(x) for x in opt.gpu_ids)
    log = open(log_file, 'a')
    log_print = lambda ms: parse.log(ms, log)
    # Define networks
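    # Three networks are trained jointly: a generator that synthesizes a
    # street view from the aerial image, a conditional discriminator for the
    # adversarial loss, and a retrieval network that embeds real and
    # synthesized views for cross-view matching.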
    generator = define_G(netG=opt.g_model, gpu_ids=opt.gpu_ids)
    log_print('Init {} as generator model'.format(opt.g_model))
    discriminator = define_D(input_c=opt.input_c, output_c=opt.realout_c, ndf=opt.feature_c, netD=opt.d_model,
                             condition=opt.condition, n_layers_D=opt.n_layers, gpu_ids=opt.gpu_ids)
    log_print('Init {} as discriminator model'.format(opt.d_model))
    retrieval = define_R(ret_method=opt.r_model, polar=opt.polar, gpu_ids=opt.gpu_ids)
    log_print('Init {} as retrieval model'.format(opt.r_model))
    rgan_wrapper = rgan_wrapper.RGANWrapper(opt, log_file, generator, discriminator, retrieval)
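    # RGANWrapper (note: this rebinds the imported module name, which works
    # but shadows it) presumably bundles the three networks with their
    # optimizers and losses; set_input()/optimize_parameters() below drive
    # one training step per batch.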
    # Configure data loader
    composed_transforms = transforms.Compose([RandomHorizontalFlip(),
                                              ToTensor()])
    train_dataset = CVUSA(root=opt.data_root, csv_file=opt.train_csv, use_polar=opt.polar, name=opt.name,
                          transform_op=composed_transforms)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=0)
    val_dataset = CVUSA(root=opt.data_root, csv_file=opt.val_csv, use_polar=opt.polar, name=opt.name,
                        transform_op=ToTensor())
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=opt.batch_size, shuffle=False, num_workers=0)
    log_print('Load datasets from {}: train_set={} val_set={}'.format(opt.data_root, len(train_dataset), len(val_dataset)))
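    # num_workers=0 loads data in the main process; raise it to use worker
    # subprocesses if I/O becomes the bottleneck. use_polar selects the
    # polar-transformed aerial images (handled inside the CVUSA dataset,
    # see data/utils.py).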
    ret_best_acc = rgan_wrapper.ret_best_acc
    log_print('Start training from epoch {} to {}, best acc: {}'.format(opt.start_epoch, opt.n_epochs, ret_best_acc))

    for epoch in range(opt.start_epoch, opt.n_epochs):
        start_time = time.time()
        batches_done = 0
        val_batches_done = 0
        street_batches_t = []
        fake_street_batches_t = []
        street_batches_v = []
        fake_street_batches_v = []
        epoch_retrieval_loss = []
        epoch_generator_loss = []
        epoch_discriminator_loss = []
        log_print('>>> RGAN Epoch {}'.format(epoch))
        rgan_wrapper.generator.train()
        rgan_wrapper.discriminator.train()
        rgan_wrapper.retrieval.train()
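        # One optimization step per batch: set_input() stores the batch on the
        # wrapper and optimize_parameters() updates discriminator, generator
        # and retrieval network (the exact update order lives in the wrapper).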
        for i, data in enumerate(train_loader):  # inner loop within one epoch
            rgan_wrapper.set_input(data)
            rgan_wrapper.optimize_parameters(epoch)
            fake_street_batches_t.append(rgan_wrapper.fake_street_out.cpu().data)
            street_batches_t.append(rgan_wrapper.street_out.cpu().data)
            epoch_retrieval_loss.append(rgan_wrapper.r_loss.item())
            epoch_discriminator_loss.append(rgan_wrapper.d_loss.item())
            epoch_generator_loss.append(rgan_wrapper.g_loss.item())
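            # Every 40 batches (and at the end of the epoch), report retrieval
            # recall on the embeddings accumulated since the last report.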
            if (i + 1) % 40 == 0 or (i + 1) == len(train_loader):
                fake_street_vec = torch.cat(fake_street_batches_t, dim=0)
                street_vec = torch.cat(street_batches_t, dim=0)
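                # For L2-normalized embeddings, ||a - b||^2 = 2 - 2 * <a, b>,
                # so this is the matrix of pairwise squared Euclidean distances
                # (assuming the retrieval network normalizes its outputs).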
                dists = 2 - 2 * torch.matmul(fake_street_vec, street_vec.permute(1, 0))
                tp1 = rgan_wrapper.mutual_topk_acc(dists, topk=1)
                tp5 = rgan_wrapper.mutual_topk_acc(dists, topk=5)
                tp10 = rgan_wrapper.mutual_topk_acc(dists, topk=10)
                log_print('Batch:{} loss={:.3f} samples:{} tp1={tp1[0]:.2f}/{tp1[1]:.2f} '
                          'tp5={tp5[0]:.2f}/{tp5[1]:.2f}'.format(i + 1, np.mean(epoch_retrieval_loss),
                                                                 len(dists), tp1=tp1, tp5=tp5))
                street_batches_t.clear()
                fake_street_batches_t.clear()
        rgan_wrapper.save_networks(epoch, dirname(log_file), best_acc=ret_best_acc,
                                   last_ckpt=True)  # Always save last ckpt
        # Save model periodically
        if (epoch + 1) % opt.save_step == 0:
            rgan_wrapper.save_networks(epoch, dirname(log_file), best_acc=ret_best_acc)
        rgan_wrapper.generator.eval()
        rgan_wrapper.retrieval.eval()
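        # Validation: only generator and retrieval are needed for the recall
        # metrics, so the discriminator is not used (nor switched to eval).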
        for i, data in enumerate(val_loader):
            rgan_wrapper.set_input(data)
            rgan_wrapper.eval_model()
            fake_street_batches_v.append(rgan_wrapper.fake_street_out_val.cpu().data)
            street_batches_v.append(rgan_wrapper.street_out_val.cpu().data)
        fake_street_vec = torch.cat(fake_street_batches_v, dim=0)
        street_vec = torch.cat(street_batches_v, dim=0)
        dists = 2 - 2 * torch.matmul(fake_street_vec, street_vec.permute(1, 0))
        tp1 = rgan_wrapper.mutual_topk_acc(dists, topk=1)
        tp5 = rgan_wrapper.mutual_topk_acc(dists, topk=5)
        tp10 = rgan_wrapper.mutual_topk_acc(dists, topk=10)
        num = len(dists)
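        # topk=0.01*num gives the usual "top-1%" recall reported on CVUSA. A
        # minimal sketch of mutual top-k recall from the distance matrix
        # (hypothetical -- the real implementation lives in the wrapper;
        # matching pairs sit on the diagonal):
        #   pos = dists.diag()
        #   p2s_r1 = ((dists < pos.unsqueeze(1)).sum(1) < 1).float().mean() * 100
        #   s2p_r1 = ((dists < pos.unsqueeze(0)).sum(0) < 1).float().mean() * 100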
        tp1p = rgan_wrapper.mutual_topk_acc(dists, topk=0.01 * num)
        acc = Namespace(num=len(dists), tp1=tp1, tp5=tp5, tp10=tp10, tp1p=tp1p)
        log_print('\nEvaluate Epoch {0} Samples:{num:d}\nRecall(p2s/s2p) tp1:{tp1[0]:.2f}/{tp1[1]:.2f} '
                  'tp5:{tp5[0]:.2f}/{tp5[1]:.2f} tp10:{tp10[0]:.2f}/{tp10[1]:.2f} '
                  'tp1%:{tp1p[0]:.2f}/{tp1p[1]:.2f}'.format(epoch + 1, num=acc.num, tp1=acc.tp1,
                                                            tp5=acc.tp5, tp10=acc.tp10, tp1p=acc.tp1p))
        # Save the best model
        tp1_p2s_acc = acc.tp1[0]
        if tp1_p2s_acc > ret_best_acc:
            ret_best_acc = tp1_p2s_acc
            rgan_wrapper.save_networks(epoch, dirname(log_file), best_acc=ret_best_acc, is_best=True)
            log_print('>>Save best model: epoch={} best_acc(tp1_p2s):{:.2f}'.format(epoch + 1, tp1_p2s_acc))
        # Program statistics
        rss, vms = get_sys_mem()
        log_print('Memory usage: rss={:.2f}GB vms={:.2f}GB Time:{:.2f}s'.format(rss, vms, time.time() - start_time))