forked from krasserm/super-resolution
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathrun_SRGAN.py
86 lines (63 loc) · 2.68 KB
/
run_SRGAN.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import os
import matplotlib.pyplot as plt
from data import DIV2K
from model.srgan import generator, discriminator
from train import SrganTrainer, SrganGeneratorTrainer
# Directory holding the saved model weights (also used by the demo at the end).
weights_dir = 'weights/srgan'


def weights_file(filename):
    """Return the path of ``filename`` inside ``weights_dir``.

    ``weights_dir`` is looked up at call time, exactly as the previous
    lambda did.
    """
    return os.path.join(weights_dir, filename)


# Make sure the directory exists before any save_weights() call writes into it.
os.makedirs(weights_dir, exist_ok=True)
# Build the DIV2K x4 'bicubic' train and valid splits (train first, then valid).
div2k_train, div2k_valid = (
    DIV2K(scale=4, subset=split, downgrade='bicubic')
    for split in ('train', 'valid')
)

# Training pipeline: batches of 16 with random_transform enabled.
train_ds = div2k_train.dataset(batch_size=16, random_transform=True)

# Validation pipeline: single pass (repeat_count=1), batches of 16.
# NOTE(review): random_transform=True is also used here, so validation
# presumably sees randomly transformed samples rather than fixed full
# images — confirm this is intended for evaluation.
valid_ds = div2k_valid.dataset(
    batch_size=16,
    random_transform=True,
    repeat_count=1,
)
# Selects whether generator()/discriminator() are built with attention=True;
# the flag also picks the model name passed to the pre-trainer.
attention = True
model_name = "srgan_attention" if attention else "srgan"
# Generator pre-training: train generator() on its own with
# SrganGeneratorTrainer, passing the first 10 validation batches
# (evaluate_every=1000), then save the weights that GAN fine-tuning starts from.
# NOTE(review): checkpoint_dir does not depend on `attention`; if the trainer
# restores an existing checkpoint, toggling `attention` between runs may try
# to restore weights of the other architecture — confirm.
pre_trainer = SrganGeneratorTrainer(
    model=generator(attention=attention),
    checkpoint_dir='.ckpt/pre_generator',
)
pre_trainer.train(
    train_ds,
    valid_ds.take(10),
    steps=1_000_000,
    evaluate_every=1000,
    save_best_only=False,
    model_name=model_name,
)
pre_trainer.model.save_weights(weights_file('pre_generator.h5'))
# Generator fine-tuning: initialise a new generator from the pre-trained
# weights and train it together with a discriminator via SrganTrainer.
gan_generator = generator(attention=attention)
gan_generator.load_weights(weights_file('pre_generator.h5'))
gan_discriminator = discriminator(attention=attention)
gan_trainer = SrganTrainer(
    generator=gan_generator,
    discriminator=gan_discriminator,
)
gan_trainer.train(train_ds, steps=200_000)

# Persist both networks held by the trainer after fine-tuning.
gan_trainer.generator.save_weights(weights_file('gan_generator.h5'))
gan_trainer.discriminator.save_weights(weights_file('gan_discriminator.h5'))
# Rebuild two fresh generators and load the saved weights of each training
# stage, so resolve_and_plot can compare pre-trained vs. GAN output.
pre_generator, gan_generator = (
    generator(attention=attention),
    generator(attention=attention),
)
for _net, _weights_name in (
    (pre_generator, 'pre_generator.h5'),
    (gan_generator, 'gan_generator.h5'),
):
    _net.load_weights(weights_file(_weights_name))
print("evaluating SRGAN model...")
# Evaluate the fine-tuned GAN trainer on valid_ds.
# NOTE(review): evaluate/evaluate2 are defined outside this file; presumed to
# return mean PSNR and SSIM as tensors (.numpy() is called on them) — confirm.
psnrv = gan_trainer.evaluate(valid_ds)
ssimv = gan_trainer.evaluate2(valid_ds)
# '.3f' = three decimal places (the former '3f' meant width 3, six decimals).
print(f'PSNR = {psnrv.numpy():.3f}')
print(f'SSIM = {ssimv.numpy():.3f}')
from model import resolve_single
from utils import load_image
def resolve_and_plot(lr_image_path, output_path="srgan_sample.png"):
    """Super-resolve one low-resolution image and save a comparison figure.

    Loads the image at ``lr_image_path`` and runs it through the module-level
    ``pre_generator`` and ``gan_generator`` models via ``resolve_single``.
    It then saves a single 2x2 figure to ``output_path``:
    LR input (top-left), SR (PRE) (bottom-left), SR (GAN) (bottom-right).

    Args:
        lr_image_path: Path of the low-resolution input image.
        output_path: File the comparison figure is written to.
    """
    lr = load_image(lr_image_path)
    pre_sr = resolve_single(pre_generator, lr)
    gan_sr = resolve_single(gan_generator, lr)

    fig = plt.figure(figsize=(20, 20))
    images = [lr, pre_sr, gan_sr]
    titles = ['LR', 'SR (PRE)', 'SR (GAN)']
    # Subplot position 2 (top-right) is not used, so that cell stays empty.
    positions = [1, 3, 4]
    for img, title, pos in zip(images, titles, positions):
        plt.subplot(2, 2, pos)
        plt.imshow(img)
        plt.title(title)
        plt.xticks([])
        plt.yticks([])

    # Save once, after every panel is drawn (saving inside the loop wrote
    # partial figures), then close the figure so repeated calls don't
    # accumulate open figures.
    fig.savefig(output_path)
    plt.close(fig)
# Produce the comparison figure for demo/0869x4-crop.png.
resolve_and_plot(lr_image_path='demo/0869x4-crop.png')