-
Notifications
You must be signed in to change notification settings - Fork 119
/
Copy pathsuper_resolve.py
102 lines (83 loc) · 4.52 KB
/
super_resolve.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import torch
from utils import *
from PIL import Image, ImageDraw, ImageFont
# Run on the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model checkpoints
srgan_checkpoint = "./checkpoint_srgan.pth.tar"
srresnet_checkpoint = "./checkpoint_srresnet.pth.tar"

# Load models
# map_location remaps the stored tensors onto `device`, so a checkpoint that was saved on a GPU can still be
# loaded when only a CPU is available (without it, torch.load tries to restore tensors to their original device).
# NOTE(review): torch.load unpickles the checkpoint, which can execute arbitrary code - only load checkpoints you
# trust. The checkpoints appear to hold whole model objects (the loaded entries are used directly as modules), so on
# PyTorch >= 2.6, where weights_only defaults to True, you may need to pass weights_only=False - confirm.
srresnet = torch.load(srresnet_checkpoint, map_location=device)['model'].to(device)
srresnet.eval()
srgan_generator = torch.load(srgan_checkpoint, map_location=device)['generator'].to(device)
srgan_generator.eval()
def _text_extent(font, text):
    """
    Measures how far a piece of text extends to the right of and below its drawing origin.

    :param font: a PIL font object
    :param text: the text to measure
    :return: a (width, height) pair
    """
    if hasattr(font, "getbbox"):
        # font.getsize() was removed in Pillow 10; the right and bottom edges of the bounding box describe the
        # same extent, measured from the point at which the text is drawn
        _, _, right, bottom = font.getbbox(text)
        return right, bottom
    # Older Pillow releases whose fonts do not provide getbbox()
    return font.getsize(text)


def _paste_with_label(grid, draw, font, image, label, left, top, gap=5):
    """
    Pastes an image into the grid and writes a label, horizontally centred on the image, just above it.

    :param grid: the PIL image being composed
    :param draw: the ImageDraw object bound to the grid
    :param font: the font for the label
    :param image: the PIL image to paste
    :param label: the text to write above the image
    :param left: x-coordinate of the image's top-left corner in the grid
    :param top: y-coordinate of the image's top-left corner in the grid
    :param gap: vertical space, in pixels, between the bottom of the label and the top of the image
    """
    grid.paste(image, (left, top))
    text_width, text_height = _text_extent(font, label)
    draw.text(xy=[left + image.width / 2 - text_width / 2, top - text_height - gap], text=label, font=font,
              fill='black')


def visualize_sr(img: str, halve: bool = False):
    """
    Visualizes the super-resolved images from the SRResNet and SRGAN for comparison with the bicubic-upsampled image
    and the original high-resolution (HR) image, as done in the paper.

    The HR image is downsampled by a factor of 4 to create the low-resolution (LR) input. The four results are
    arranged in a 2 x 2 grid (Bicubic, SRResNet / SRGAN, Original HR), which is displayed and returned.

    :param img: filepath of the HR image
    :param halve: halve each dimension of the HR image to make sure it's not greater than the dimensions of your screen?
                  For instance, without halving, a 2160p HR image gives a 540p (2160p/4) LR image. On a 1080p screen,
                  you would then be looking at a comparison between a 540p LR image and a 1080p SR/HR image, because
                  your 1080p screen can only display the 2160p SR/HR image at a downsampled 1080p. This is only an
                  APPARENT rescaling of 2x.
                  If you want to reduce HR resolution by a different extent, modify accordingly.
    :return: the comparison grid, a PIL image
    """
    # Load image, downsample to obtain low-res version
    # convert() returns a new image, so the file can be closed as soon as the conversion is done
    with Image.open(img, mode="r") as source:
        hr_img = source.convert('RGB')
    if halve:
        hr_img = hr_img.resize((int(hr_img.width / 2), int(hr_img.height / 2)), Image.LANCZOS)
    lr_img = hr_img.resize((int(hr_img.width / 4), int(hr_img.height / 4)), Image.BICUBIC)

    # Bicubic Upsampling
    bicubic_img = lr_img.resize((hr_img.width, hr_img.height), Image.BICUBIC)

    # Both networks receive the same input, so it is prepared once: the LR image converted with the
    # 'imagenet-norm' target, with a batch dimension added
    lr_batch = convert_image(lr_img, source='pil', target='imagenet-norm').unsqueeze(0).to(device)

    # Super-resolution (SR) with SRResNet and with SRGAN
    # This is inference only, so gradient tracking is switched off to avoid building the autograd graph
    with torch.no_grad():
        sr_srresnet = srresnet(lr_batch).squeeze(0).cpu()
        sr_srgan = srgan_generator(lr_batch).squeeze(0).cpu()
    sr_img_srresnet = convert_image(sr_srresnet, source='[-1, 1]', target='pil')
    sr_img_srgan = convert_image(sr_srgan, source='[-1, 1]', target='pil')

    # Create grid - white canvas with room for two rows and two columns of HR-sized images and the margins around them
    margin = 40
    grid_img = Image.new('RGB', (2 * hr_img.width + 3 * margin, 2 * hr_img.height + 3 * margin), (255, 255, 255))

    # Font
    draw = ImageDraw.Draw(grid_img)
    try:
        font = ImageFont.truetype("calibril.ttf", size=23)
        # It will also look for this file in your OS's default fonts directory, where you may have the Calibri Light font installed if you have MS Office
        # Otherwise, use any TTF font of your choice
    except OSError:
        print(
            "Defaulting to a terrible font. To use a font of your choice, include the link to its TTF file in the function.")
        font = ImageFont.load_default()

    # Top-left corners of the second column and of the second row
    second_column = 2 * margin + bicubic_img.width
    second_row = 2 * margin + sr_img_srresnet.height

    # Place bicubic-upsampled image
    _paste_with_label(grid_img, draw, font, bicubic_img, "Bicubic", margin, margin)

    # Place SRResNet image
    _paste_with_label(grid_img, draw, font, sr_img_srresnet, "SRResNet", second_column, margin)

    # Place SRGAN image
    _paste_with_label(grid_img, draw, font, sr_img_srgan, "SRGAN", margin, second_row)

    # Place original HR image
    # NOTE(review): this label has always been drawn 1 px above its image rather than 5 px like the others;
    # kept as-is - confirm whether that is intentional
    _paste_with_label(grid_img, draw, font, hr_img, "Original HR", second_column, second_row, gap=1)

    # Display grid
    grid_img.show()

    return grid_img
if __name__ == "__main__":
    # Example run: build and display the comparison grid for a single image
    sample_image_path = "/media/ssd/sr data/Set14/baboon.png"
    grid_img = visualize_sr(sample_image_path)