-
Notifications
You must be signed in to change notification settings - Fork 28
/
Copy path: trainer.py
83 lines (74 loc) · 3.19 KB
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import time
import torch
import os
from util import log_display, accuracy, AverageMeter
# Select the training device once at import time and tune cuDNN when a GPU
# is present. Every tensor transfer below references this module-level name.
if torch.cuda.is_available():
    device = torch.device('cuda')
    torch.backends.cudnn.enabled = True
    # NOTE(review): benchmark=True lets cuDNN auto-tune algorithms, which is
    # normally at odds with deterministic=True (PyTorch's reproducibility
    # docs recommend benchmark=False for determinism). Preserved as-is —
    # confirm which of the two the project actually wants.
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = True
else:
    device = torch.device('cpu')
class Trainer():
    """Drives one-epoch training loops over a data loader.

    Per batch it runs forward/backward, clips gradients, steps the
    optimizer, and accumulates metrics into `AverageMeter`s; every
    `config.log_frequency` global steps it emits a formatted log line.
    """

    def __init__(self, data_loader, logger, config, name='Trainer', metrics='classfication'):
        """
        Args:
            data_loader: iterable yielding (images, labels) batches.
            logger: object with an `info(str)` method.
            config: must expose `log_frequency` (int) and `grad_bound`
                (float, gradient-norm clip threshold).
            name: display name for this trainer.
            metrics: 'classfication' (sic — spelling kept for backward
                compatibility with existing callers) selects top-1/top-5
                accuracy reporting; any other value selects regression
                (mean absolute difference) reporting.
        """
        self.data_loader = data_loader
        self.logger = logger
        self.name = name
        self.step = 0  # number of batches processed by this trainer
        self.config = config
        self.log_frequency = config.log_frequency
        self.loss_meters = AverageMeter()
        self.acc_meters = AverageMeter()
        self.acc5_meters = AverageMeter()
        # Initialize eagerly so log() can never hit a missing attribute,
        # even if it were invoked before the first train_batch().
        self.logger_payload = {}
        self.time_used = 0.0
        self.report_metrics = self.classfication_metrics if metrics == 'classfication' else self.regression_metrics

    def train(self, epoch, GLOBAL_STEP, model, optimizer, criterion):
        """Train `model` for one epoch; return the updated global step."""
        model.train()
        for images, labels in self.data_loader:
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            self.train_batch(images, labels, model, criterion, optimizer)
            self.log(epoch, GLOBAL_STEP)
            GLOBAL_STEP += 1
        return GLOBAL_STEP

    def train_batch(self, x, y, model, criterion, optimizer):
        """Run one optimization step on batch (x, y) and record metrics."""
        start = time.time()
        model.zero_grad()
        optimizer.zero_grad()
        pred = model(x)
        loss = criterion(pred, y)
        loss.backward()
        # Clip the global gradient norm; keep the pre-clip norm for logging.
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), self.config.grad_bound)
        optimizer.step()
        self.report_metrics(pred, y, loss)
        # BUG FIX: the original line ended with a stray trailing comma,
        # which stored the learning rate as a 1-tuple (lr,) instead of a
        # float, corrupting the logged value.
        self.logger_payload['lr'] = optimizer.param_groups[0]['lr']
        self.logger_payload['|gn|'] = grad_norm
        self.step += 1
        self.time_used = time.time() - start

    def log(self, epoch, GLOBAL_STEP):
        """Emit a formatted progress line every `log_frequency` steps."""
        if GLOBAL_STEP % self.log_frequency == 0:
            display = log_display(epoch=epoch,
                                  global_step=GLOBAL_STEP,
                                  time_elapse=self.time_used,
                                  **self.logger_payload)
            self.logger.info(display)

    def classfication_metrics(self, x, y, loss):
        """Update top-1/top-5 accuracy and loss meters from logits `x`."""
        acc, acc5 = accuracy(x, y, topk=(1, 5))
        self.loss_meters.update(loss.item(), y.shape[0])
        self.acc_meters.update(acc.item(), y.shape[0])
        self.acc5_meters.update(acc5.item(), y.shape[0])
        self.logger_payload = {"acc": acc,
                               "acc_avg": self.acc_meters.avg,
                               "loss": loss,
                               "loss_avg": self.loss_meters.avg}

    def regression_metrics(self, x, y, loss):
        """Update mean-absolute-difference and loss meters from predictions `x`."""
        diff = abs((x - y).mean().detach().item())
        self.loss_meters.update(loss.item(), y.shape[0])
        self.acc_meters.update(diff, y.shape[0])
        self.logger_payload = {"|diff|": diff,
                               "|diff_avg|": self.acc_meters.avg,
                               "loss": loss,
                               "loss_avg": self.loss_meters.avg}

    def _reset_stats(self):
        """Zero all running meters (e.g. between epochs or eval phases)."""
        self.loss_meters.reset()
        self.acc_meters.reset()
        self.acc5_meters.reset()