-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfcnInitializeModel.m
113 lines (92 loc) · 3.89 KB
/
fcnInitializeModel.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
function net = fcnInitializeModel(varargin)
%FCNINITIALIZEMODEL Initialize the FCN-32 model from VGG-VD-16
%   NET = FCNINITIALIZEMODEL() loads the VGG-VD-16 SimpleNN model from
%   disk (downloading it first if the file is missing), converts it to a
%   DagNN and edits it into an FCN-32 segmentation network with a
%   bilinear upsampling layer, a loss layer and an accuracy layer.
%
%   NET = FCNINITIALIZEMODEL('OPTION', VALUE, ...) accepts the options
%   below (parsed with VL_ARGPARSE):
%
%   sourceModelUrl:: 'http://www.vlfeat.org/matconvnet/models/imagenet-vgg-verydeep-16.mat'
%     URL the source model is downloaded from when sourceModelPath does
%     not exist.
%
%   sourceModelPath:: 'data/models/imagenet-vgg-verydeep-16.mat'
%     Local file holding the source model.
%
%   numClasses:: 21
%     Number of output classes: the last fully-connected layer is
%     re-initialized (to zeros) with this many outputs, and the bilinear
%     upsampling layer uses one group per class.
%
%   The returned DagNN exposes the variable 'prediction' (marked precious)
%   and the layers 'objective' and 'accuracy', both of which read
%   {'prediction', 'label'}.

opts.sourceModelUrl = 'http://www.vlfeat.org/matconvnet/models/imagenet-vgg-verydeep-16.mat' ;
opts.sourceModelPath = 'data/models/imagenet-vgg-verydeep-16.mat' ;
opts.numClasses = 21 ;
opts = vl_argparse(opts, varargin) ;
numClasses = opts.numClasses ;

% -------------------------------------------------------------------------
% Load & download the source model if needed (VGG VD 16)
% -------------------------------------------------------------------------

% Check specifically for a file: a bare exist() would also succeed for a
% variable or function that happens to share the name.
if ~exist(opts.sourceModelPath, 'file')
  fprintf('%s: downloading %s\n', mfilename, opts.sourceModelUrl) ;
  modelDir = fileparts(opts.sourceModelPath) ;
  % mkdir warns if the folder exists and errors on an empty path (model
  % stored in the current folder), so only create it when needed.
  if ~isempty(modelDir) && ~exist(modelDir, 'dir')
    mkdir(modelDir) ;
  end
  urlwrite(opts.sourceModelUrl, opts.sourceModelPath) ;
end
net = load(opts.sourceModelPath) ;

% -------------------------------------------------------------------------
% Edit the model to create the FCN version
% -------------------------------------------------------------------------

% Add dropout to the fully-connected layers in the source model.
% NOTE(review): the split points 33 and 35 assume the standard VGG-VD-16
% SimpleNN layout (fc6/relu6 end at 33, fc7/relu7 end at 35) -- confirm
% if a different source model is used.
drop1 = struct('name', 'dropout1', 'type', 'dropout', 'rate' , 0.5) ;
drop2 = struct('name', 'dropout2', 'type', 'dropout', 'rate' , 0.5) ;
net.layers = [net.layers(1:33) drop1 net.layers(34:35) drop2 net.layers(36:end)] ;

% Convert the model from SimpleNN to DagNN
net = dagnn.DagNN.fromSimpleNN(net, 'canonicalNames', true) ;

% Adjust padding. (Padding the input layer by 100 is left disabled.)
%net.layers(1).block.pad = 100 ;
% Asymmetric [top bottom left right] = [0 1 0 1] padding on layers 5, 10,
% 17, 24 and 31 -- presumably pool1..pool5 of VGG-VD-16; confirm.
net.layers(5).block.pad = [0 1 0 1] ;
net.layers(10).block.pad = [0 1 0 1] ;
net.layers(17).block.pad = [0 1 0 1] ;
net.layers(24).block.pad = [0 1 0 1] ;
net.layers(31).block.pad = [0 1 0 1] ;
net.layers(32).block.pad = [3 3 3 3] ;
% ^-- we could do [2 3 2 3] but that would not use CuDNN

% Give every biased convolution a bias learning rate twice the learning
% rate of its filters (the last layer is excluded from the scan).
for i = 1:numel(net.layers)-1
  if (isa(net.layers(i).block, 'dagnn.Conv') && net.layers(i).block.hasBias)
    filt = net.getParamIndex(net.layers(i).params{1}) ;
    bias = net.getParamIndex(net.layers(i).params{2}) ;
    net.params(bias).learningRate = 2 * net.params(filt).learningRate ;
  end
end

% Modify the last fully-connected layer to have numClasses output classes
% and initialize all of its parameters (filters and biases) to zero.
for i = net.getParamIndex(net.layers(end-1).params)
  sz = size(net.params(i).value) ;
  sz(end) = numClasses ;
  net.params(i).value = zeros(sz, 'single') ;
end
% Keep the block's declared filter size in sync with the new filters.
net.layers(end-1).block.size = size(...
  net.params(net.getParamIndex(net.layers(end-1).params{1})).value) ;

% Remove the last loss layer and name fc8's output 'x38' so the
% upsampling layer below can consume it.
net.removeLayer('prob') ;
net.setLayerOutputs('fc8', {'x38'}) ;

% -------------------------------------------------------------------------
% Upsampling and prediction layer
% -------------------------------------------------------------------------

% Fixed (learningRate = 0) bilinear filters: one 64x64 kernel per class,
% applied group-wise, with 32x upsampling and a 16-pixel crop on each side.
filters = single(bilinear_u(64, numClasses, numClasses)) ;
net.addLayer('deconv32', ...
  dagnn.ConvTranspose(...
  'size', size(filters), ...
  'upsample', 32, ...
  'crop', [16 16 16 16], ...
  'numGroups', numClasses, ...
  'hasBias', false), ...
  'x38', 'prediction', 'deconvf') ;

f = net.getParamIndex('deconvf') ;
net.params(f).value = filters ;
net.params(f).learningRate = 0 ;
net.params(f).weightDecay = 1 ;

% Make sure the output of the bilinear interpolator is not discarded, for
% visualization purposes
net.vars(net.getVarIndex('prediction')).precious = 1 ;

% -------------------------------------------------------------------------
% Losses and statistics
% -------------------------------------------------------------------------

% Add loss layer
net.addLayer('objective', ...
  SegmentationLoss('loss', 'softmaxlog'), ...
  {'prediction', 'label'}, 'objective') ;

% Add accuracy layer
net.addLayer('accuracy', ...
  SegmentationAccuracy(), ...
  {'prediction', 'label'}, 'accuracy') ;

% Disabled debugging aid: plots the receptive field of every variable
% with respect to 'input'. Change to `if 1` to enable.
if 0
  figure(100) ; clf ;
  n = numel(net.vars) ;
  for i=1:n
    vl_tightsubplot(n,i) ;
    showRF(net, 'input', net.vars(i).name) ;
    title(sprintf('%s', net.vars(i).name)) ;
    drawnow ;
  end
end