train_nets.py
import os
import time
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.optim as optim
from nets.coco_dataset import COCODataset
from nets.sesf_net import SESFuseNet
from nets.nets_utility import *
import torch.nn as nn
from nets.lp_lssim_loss import LpLssimLoss
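# Trains SESFuseNet on COCO2014 with the combined Lp + LSSIM loss.
# The best (lowest validation loss) weights and a per-epoch snapshot are
# saved under nets/parameters; progress is logged to a text file.
# nets_utility (imported with *) supplies print_and_log, training_setup_seed,
# adjust_learning_rate, plot_iteration_loss and plot_loss used below.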
# network and training hyper-parameters
experiment_name = 'lp+lssim_se_sf_net_times30'
gpu_device = "cuda:0"
# gpu_device_for_parallel = [2, 3]
learning_rate = 1e-4
epochs = 30
batch_size = 48
display_step = 100
shuffle = True
attention = 'cse'
# paths
project_address = os.getcwd()
train_dir = os.path.join(project_address, "data", "coco2014", "train2014")
val_dir = os.path.join(project_address, "data", "coco2014", "val2014")
log_address = os.path.join(project_address, "nets", "train_record", experiment_name + "_log_file.txt")
is_out_log_file = True
parameter_address = os.path.join(project_address, "nets", "parameters")
# datasets
data_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.4500517361627943], [0.26465333914691797]),
])
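# Single-channel mean/std, so the pipeline expects grayscale input; the
# statistics are presumably precomputed over the COCO training images.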
image_datasets = {}
image_datasets['train'] = COCODataset(train_dir, transform=data_transforms, need_crop=False, need_augment=False)
image_datasets['val'] = COCODataset(val_dir, transform=data_transforms, need_crop=False, need_augment=False)
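# Cropping and augmentation are disabled, so training uses the images as
# loaded, with only the tensor conversion and normalization above.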
dataloaders = {}
dataloaders['train'] = DataLoader(
    image_datasets['train'],
    batch_size=batch_size,
    shuffle=shuffle,
    num_workers=1)
dataloaders['val'] = DataLoader(
    image_datasets['val'],
    batch_size=batch_size,
    shuffle=shuffle,
    num_workers=1)
datasets_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
print_and_log("datasets size: {}".format(datasets_sizes), is_out_log_file, log_address)
# models
training_setup_seed(1) # setup seed for all parameters, numpy.random, random, pytorch
# model = UNet(in_channel=1, out_channel=1)
model = SESFuseNet(attention)
model.to(gpu_device)
# model = nn.DataParallel(model, device_ids=gpu_device_for_parallel)
criterion = LpLssimLoss().to(gpu_device)
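# LpLssimLoss returns a (total_loss, lp_loss, lssim_loss) triple -- see its
# use in train() and val() below.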
optimizer = optim.Adam(model.parameters(), learning_rate)
# optimizer = nn.DataParallel(optimizer, device_ids=gpu_device_for_parallel)
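# adjust_learning_rate(optimizer, learning_rate, epoch) is imported from
# nets.nets_utility and called once per epoch in train(). A typical per-epoch
# step-decay implementation (an assumption for illustration -- the real one
# lives in nets_utility) would look like:
#
#     def adjust_learning_rate(optimizer, base_lr, epoch, decay=0.1, step=10):
#         lr = base_lr * (decay ** (epoch // step))
#         for param_group in optimizer.param_groups:
#             param_group['lr'] = lr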
def val():
    model.eval()
    running_loss = 0.0
    with torch.no_grad():
        for data in dataloaders['val']:
            inputs = data.to(gpu_device)
            output = model('train', inputs)
            loss, lp_loss, lssim_loss = criterion(image_in=inputs, image_out=output)
            running_loss += loss.item()
    epoch_loss_val = running_loss / datasets_sizes['val']
    return epoch_loss_val
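# Note: running_loss accumulates one loss value per batch, so dividing by the
# dataset size assumes the criterion sums over the batch (consistent with the
# per-sample loss.item() / batch_size display in train()).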
def train(epoch):
    iterations_loss_list = []
    iterations_lp_loss_list = []
    iterations_lssim_loss_list = []
    model.train()
    adjust_learning_rate(optimizer, learning_rate, epoch)
    print_and_log('Train Epoch {}/{}:'.format(epoch + 1, epochs), is_out_log_file, log_address)
    running_loss = 0.0
    # iterate over the training data
    for i, data in enumerate(dataloaders['train']):
        inputs = data.to(gpu_device)
        output = model('train', inputs)
        loss, lp_loss, lssim_loss = criterion(image_in=inputs, image_out=output)
        running_loss += loss.item()
        if i % display_step == 0:
            print_and_log('\t{} {}-{}: Loss: {:.4f}'.format('train', epoch + 1, i, loss.item() / batch_size),
                          is_out_log_file, log_address)
            iterations_loss_list.append(loss.item() / batch_size)
            iterations_lp_loss_list.append(lp_loss.item() / batch_size)
            iterations_lssim_loss_list.append(lssim_loss.item() / batch_size)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    epoch_loss_train = running_loss / datasets_sizes['train']
    plot_iteration_loss(experiment_name, epoch + 1, iterations_loss_list, iterations_lp_loss_list,
                        iterations_lssim_loss_list)
    return epoch_loss_train
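# The loss compares the network output against its own input, so this stage is
# self-supervised reconstruction (autoencoder-style); no fusion ground truth
# is required during training.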
def main():
    min_loss = float('inf')
    loss_train = []
    loss_val = []
    since = time.time()
    for epoch in range(epochs):
        epoch_loss_train = train(epoch)
        loss_train.append(epoch_loss_train)
        epoch_loss_val = val()
        loss_val.append(epoch_loss_val)
        print_and_log('\ttrain Loss: {:.6f}'.format(epoch_loss_train), is_out_log_file, log_address)
        print_and_log('\tvalidation Loss: {:.6f}'.format(epoch_loss_val), is_out_log_file, log_address)
        # keep the weights with the lowest validation loss so far
        if epoch_loss_val < min_loss:
            min_loss = epoch_loss_val
            best_model_wts = model.state_dict()
            print_and_log("Updating", is_out_log_file, log_address)
            torch.save(best_model_wts,
                       os.path.join(parameter_address, experiment_name + '.pkl'))
        plot_loss(experiment_name, epoch, loss_train, loss_val)
        # also save a snapshot of this epoch's weights
        model_wts = model.state_dict()
        torch.save(model_wts,
                   os.path.join(parameter_address, experiment_name + "_" + str(epoch) + '.pkl'))
        time_elapsed = time.time() - since
        print_and_log('Time elapsed {:.0f}h {:.0f}m {:.0f}s'.
                      format(time_elapsed // 3600, (time_elapsed % 3600) // 60, time_elapsed % 60), is_out_log_file,
                      log_address)
        print_and_log('-' * 20, is_out_log_file, log_address)
    print_and_log("train loss: {}".format(loss_train), is_out_log_file, log_address)
    print_and_log("val loss: {}".format(loss_val), is_out_log_file, log_address)
    print_and_log("min val loss: {}".format(min(loss_val)), is_out_log_file, log_address)
if __name__ == "__main__":
    main()
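# Usage (an assumption based on the paths above): place COCO2014 under
# data/coco2014/{train2014,val2014} relative to the working directory, then run
#     python train_nets.py
# A CUDA device is expected; set gpu_device = "cpu" to run without a GPU.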