-
Notifications
You must be signed in to change notification settings - Fork 203
/
Copy pathtest.lua
130 lines (115 loc) · 5.98 KB
/
test.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
-- Torch packages used by this script: `image` (image.save / image.toDisplayTensor
-- are called when writing the result) and `nn`.
require 'image'
require 'nn'
-- Project helpers loaded from util.lua and stored in the global `util`;
-- this script calls util.load and util.cudnn.
util = paths.dofile('util.lua')
-- Make bare torch.Tensor(...) constructors below produce FloatTensors.
torch.setdefaulttensortype('torch.FloatTensor')
-- Default settings. Each key may be overridden by an environment variable
-- carrying the same name (for example `gpu=0 batchSize=10`).
opt = {
  batchSize = 30,         -- how many samples to produce
  net = '',               -- path to the generator network
  name = 'test1',         -- experiment name, also the prefix of the saved file
  gpu = 1,                -- gpu mode: 0 = CPU, 1 = 1st GPU, etc.
  nc = 3,                 -- number of channels in the input
  display = 1,            -- show images: 0 = false, 1 = true
  loadSize = 0,           -- resize loaded image to loadSize keeping aspect ratio; 0 = no resize, -1 = random scale in [0.5,2] (see donkey_folder.lua)
  fineSize = 128,         -- size of random crops
  nThreads = 1,           -- number of data loading threads
  manualSeed = 0,         -- 0 means pick a random seed
  overlapPred = 0,        -- overlapping edges of center with context
  -- Extra Options:
  noiseGen = 0,           -- 0 means false, anything else true; only meaningful if the network was trained with noise
  noisetype = 'normal',   -- noise distribution (uniform / normal)
  nz = 100,               -- length of the noise vector, if used
}

-- Environment overrides: a numeric string becomes a number, any other string
-- is taken verbatim, and an unset variable leaves the default in place.
for key, default in pairs(opt) do
  local raw = os.getenv(key)
  opt[key] = tonumber(raw) or raw or default
end
print(opt)

-- These options are used as booleans later on; 0 is truthy in Lua, so turn
-- an explicit 0 into false.
for _, flag in ipairs({ 'display', 'noiseGen' }) do
  if opt[flag] == 0 then
    opt[flag] = false
  end
end
-- set seed
-- A manualSeed of 0 asks for a seed drawn via torch.random(1, 10000); the
-- seed actually used is printed before it is applied.
do
  local seed = opt.manualSeed
  if seed == 0 then
    seed = torch.random(1, 10000)
    opt.manualSeed = seed
  end
  print("Seed: " .. opt.manualSeed)
  torch.manualSeed(opt.manualSeed)
end
-- load Context-Encoder
-- The generator path has no usable default: stop with a plain message when
-- it is still the empty string.
if opt.net == '' then
  error('provide a generator model', 0)
end
net = util.load(opt.net, opt.gpu)
-- Put the loaded network into evaluate mode.
net:evaluate()
-- initialize variables
-- Buffer that will hold the masked context images passed to the network.
input_image_ctx = torch.Tensor(opt.batchSize, opt.nc, opt.fineSize, opt.fineSize)

-- Optional noise input; remains nil unless opt.noiseGen is enabled.
local noise
if opt.noiseGen then
  noise = torch.Tensor(opt.batchSize, opt.nz, 1, 1)
  if opt.noisetype == 'uniform' then
    noise:uniform(-1, 1)
  elseif opt.noisetype == 'normal' then
    noise:normal(0, 1)
  else
    -- torch.Tensor(sizes...) does not initialise its storage, so an
    -- unrecognised noisetype would otherwise hand arbitrary memory to the
    -- generator without any warning. Fail loudly instead.
    error("unknown noisetype '" .. tostring(opt.noisetype)
      .. "' (expected 'uniform' or 'normal')")
  end
end
-- port to GPU
-- opt.gpu > 0 moves the network and its inputs to CUDA; otherwise the
-- network is converted to float and stays where it is.
local use_gpu = opt.gpu > 0
if use_gpu then
  require 'cunn'
  -- cudnn is optional: the network is only converted when the package loads.
  local cudnn_available = pcall(require, 'cudnn')
  if cudnn_available then
    print('Using CUDNN !')
    require 'cudnn'
    net = util.cudnn(net)
  end
  net:cuda()
  input_image_ctx = input_image_ctx:cuda()
  if opt.noiseGen then
    noise = noise:cuda()
  end
else
  net:float()
end
print(net)
-- load data
local DataLoader = paths.dofile('data/data.lua')
local data = DataLoader.new(opt.nThreads, opt)
print("Dataset Size: ", data:size())

-- Pull one batch and report its four dimensions.
local image_ctx = data:getBatch()
local batch_dims = {}
for d = 1, 4 do
  batch_dims[d] = image_ctx:size(d)
end
print('Loaded Image Block: ', table.concat(batch_dims, ' x '))
-- remove center region from input image
do
  local quarter = opt.fineSize/4
  local half = opt.fineSize/2
  local margin = opt.overlapPred

  -- Independent copy (clone) of the central fineSize/2 x fineSize/2 window,
  -- taken before the window is overwritten below.
  real_center = image_ctx[{{}, {}, {1 + quarter, half + quarter}, {1 + quarter, half + quarter}}]:clone()

  -- fill center region with mean value
  -- The hole is the central window shrunk by `margin` on each side.
  -- NOTE(review): 117/104/123 are presumably per-channel means on a 0-255
  -- scale, mapped here to [-1, 1] -- confirm against the training code.
  local lo = 1 + quarter + margin
  local hi = half + quarter - margin
  local channel_mean = { 117.0, 104.0, 123.0 }
  for ch = 1, 3 do
    image_ctx[{{}, {ch}, {lo, hi}, {lo, hi}}] = 2*channel_mean[ch]/255.0 - 1.0
  end

  input_image_ctx:copy(image_ctx)
end
-- run Context-Encoder to inpaint center
-- With noise enabled the network receives the pair {context, noise};
-- otherwise it receives the context batch alone.
local net_input = input_image_ctx
if opt.noiseGen then
  net_input = { input_image_ctx, noise }
end
local pred_center = net:forward(net_input)

-- Report the prediction's four dimensions and basic statistics.
local pred_dims = {}
for d = 1, 4 do
  pred_dims[d] = pred_center:size(d)
end
print('Prediction: size: ', table.concat(pred_dims, ' x '))
print('Prediction: Min, Max, Mean, Stdv: ', pred_center:min(), pred_center:max(), pred_center:mean(), pred_center:std())
-- paste predicted center in the context
do
  local quarter, half, margin = opt.fineSize/4, opt.fineSize/2, opt.overlapPred
  -- Destination: the hole inside the full image.
  local dst_lo, dst_hi = 1 + quarter + margin, half + quarter - margin
  -- Source: the prediction with `margin` pixels trimmed from every edge.
  local src_lo, src_hi = 1 + margin, half - margin

  local hole = image_ctx[{{}, {}, {dst_lo, dst_hi}, {dst_lo, dst_hi}}]
  local patch = pred_center[{{}, {}, {src_lo, src_hi}, {src_lo, src_hi}}]
  hole:copy(patch)
end
-- re-transform scale back to normal
-- Every tensor gets the same update, x -> (x + 1) * 0.5, applied through
-- the tensor's own add/mul methods and in the order listed.
do
  local tensors = { input_image_ctx, image_ctx, pred_center, real_center }
  for i = 1, 4 do
    tensors[i]:add(1):mul(0.5)
  end
end
-- save outputs
-- image.save(opt.name .. '_predWithContext.png', image.toDisplayTensor(image_ctx))
-- image.save(opt.name .. '_realCenter.png', image.toDisplayTensor(real_center))
-- image.save(opt.name .. '_predCenter.png', image.toDisplayTensor(pred_center))
-- Optionally show the prediction, the ground-truth center and the
-- composited image, each in its own fixed window id.
if opt.display then
  disp = require 'display'
  local panels = {
    { pred_center, 1000 },
    { real_center, 1001 },
    { image_ctx, 1002 },
  }
  for i = 1, 3 do
    local tensor, window_id = panels[i][1], panels[i][2]
    disp.image(tensor, {win=window_id, title=opt.name})
  end
  print('Displayed image in browser !')
end
-- save outputs in a pretty manner
-- Drop this script's references to the center tensors; only the full-size
-- images are used from here on.
real_center=nil; pred_center=nil;
pretty_output = torch.Tensor(2*opt.batchSize, opt.nc, opt.fineSize, opt.fineSize)

-- Paint the hole of the network input with the value 1 in every channel.
-- The channel range follows the tensor (opt.nc channels) rather than a
-- hard-coded 1..3, which is identical for nc = 3 and no longer indexes
-- out of range for smaller nc.
do
  local quarter, half, margin = opt.fineSize/4, opt.fineSize/2, opt.overlapPred
  local lo, hi = 1 + quarter + margin, half + quarter - margin
  input_image_ctx[{{}, {}, {lo, hi}, {lo, hi}}] = 1
end

-- Interleave the images: odd slots hold the masked input, even slots the
-- image with the predicted center pasted in.
for i = 1, opt.batchSize do
  pretty_output[2*i-1]:copy(input_image_ctx[i])
  pretty_output[2*i]:copy(image_ctx[i])
end

local out_path = opt.name .. '.png'
image.save(out_path, image.toDisplayTensor(pretty_output))
-- Concatenate instead of passing two arguments: print separates its
-- arguments with a tab, which used to split the reported path.
print('Saved predictions to: ./' .. out_path)