Skip to content

Commit

Permalink
augmentation made consistent across views; dirfun added for png->jpg …
Browse files Browse the repository at this point in the history
…conversion
  • Loading branch information
suhangpro committed Feb 18, 2016
1 parent bfdc769 commit cb3ff22
Show file tree
Hide file tree
Showing 5 changed files with 173 additions and 42 deletions.
25 changes: 16 additions & 9 deletions cnn_shape.m
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,13 @@
opts.useUprightAssumption = true;
opts.aug = 'stretch';
opts.pad = 32;
opts.numEpochs = [10 20];
opts.numEpochs = [5 10 20];
opts.includeVal = false;
[opts, varargin] = vl_argparse(opts, varargin) ;

if strcmpi(opts.baseModel(end-3:end),'.mat'),
opts.baseModel = load(opts.baseModel);
[~,modelNameStr] = fileparts(opts.baseModel);
opts.baseModel = load(opts.baseModel);
else
modelNameStr = opts.baseModel;
end
Expand All @@ -71,7 +71,7 @@
opts.expDir = fullfile(opts.dataRoot, opts.expDir);
[opts, varargin] = vl_argparse(opts,varargin) ;

opts.train.learningRate = [0.001*ones(1, 10) 0.0001*ones(1, 10) 0.00001*ones(1,10)];
opts.train.learningRate = [0.005*ones(1, 5) 0.001*ones(1, 5) 0.0001*ones(1,10) 0.00001*ones(1,10)];
opts.train.momentum = 0.9;
opts.train.batchSize = 5;
opts.train.gpus = [];
Expand Down Expand Up @@ -124,10 +124,10 @@
end

trainable_layers = find(cellfun(@(l) isfield(l,'weights'),net.layers));
fc_layers = find(cellfun(@(s) numel(s.name)>=2 && strcmp(s.name(1:2),'fc'),net.layers));
fc_layers = intersect(fc_layers, trainable_layers);
lr = cellfun(@(l) l.learningRate, net.layers(trainable_layers),'UniformOutput',false);
layers_for_update = cell(1,2);
layers_for_update{1} = trainable_layers(end);
layers_for_update{2} = trainable_layers;
layers_for_update = {trainable_layers(end), fc_layers, trainable_layers};

for s=1:numel(opts.numEpochs),
if opts.numEpochs(s)<1, continue; end
Expand All @@ -143,7 +143,7 @@
'expDir', opts.expDir, ...
net.meta.trainOpts, ...
opts.train, ...
'numEpochs', opts.numEpochs(s)) ;
'numEpochs', sum(opts.numEpochs(1:s))) ;
end

% -------------------------------------------------------------------------
Expand Down Expand Up @@ -193,9 +193,16 @@
images = strcat([imdb.imageDir filesep], imdb.images.name(batch)) ;

if ~isVal, % training
im = cnn_get_batch(images, opts, 'prefetch', nargout == 0);
im = cnn_shape_get_batch(images, opts, ...
'prefetch', nargout == 0, ...
'nViews', nViews);
else
im = cnn_get_batch(images, opts, 'prefetch', nargout == 0, ...
im = cnn_shape_get_batch(images, opts, ...
'prefetch', nargout == 0, ...
'nViews', nViews, ...
'transformation', 'none');
end

nAugs = numel(im)/numel(images);
if nargout > 1, labels = repmat(labels(:)',[1 nViews]); end

65 changes: 38 additions & 27 deletions cnn_get_batch.m → cnn_shape_get_batch.m
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
function imo = cnn_get_batch(images, varargin)
function imo = cnn_shape_get_batch(images, varargin)
% Modified from CNN_IMAGENET_GET_BATCH
%
% - added `pad` option
% - deals with images of types other than jpeg
% - augmentation made consistent across views

opts.imageSize = [227, 227] ;
opts.border = [29, 29] ;
opts.pad = 0; % [TOP BOTTOM LEFT RIGHT]
opts.nViews = 1;
opts.keepAspect = true ;
opts.numAugments = 1 ;
opts.transformation = 'none' ;
Expand All @@ -27,6 +29,10 @@
% isjpg is true if all images to fetch are of jpeg format
isjpg = fetch && strcmpi(images{1}(end-3:end),'.jpg');

assert(mod(numel(images),opts.nViews)==0, '''nViews'' is incompatible with input');
nViews = opts.nViews;
nShapes = numel(images)/nViews;

if opts.prefetch
if isjpg, vl_imreadjpeg(images, 'numThreads', opts.numThreads, 'prefetch'); end
imo = [] ;
Expand Down Expand Up @@ -77,46 +83,52 @@
imo = zeros(opts.imageSize(1), opts.imageSize(2), 3, ...
numel(images)*opts.numAugments, 'single') ;

si = 1 ;
for i=1:numel(images)

% acquire image
if isempty(im{i})
imt = imread(images{i}) ;
imt = single(imt) ; % faster than im2single (and multiplies by 255)
else
imt = im{i} ;
end
if size(imt,3) == 1
imt = cat(3, imt, imt, imt) ;
for i=1:nShapes,
for j=1:nViews,
% acquire image
idx = (i-1)*nViews + j;
if isempty(im{idx})
imt = imread(images{idx}) ;
imt = single(imt) ; % faster than im2single (and multiplies by 255)
else
imt = im{idx} ;
end
if size(imt,3) == 1
imt = cat(3, imt, imt, imt) ;
end
if j==1,
imArr = zeros(size(imt,1),size(imt,2),3,nViews,'single');
end
imArr(:,:,:,j) = imt;
end

% pad
if ~isempty(opts.pad) && any(opts.pad>0),
imtt = imt;
imt = 255*ones(size(imtt,1)+sum(opts.pad(1:2)), ...
size(imtt,2)+sum(opts.pad(3:4)), 3, 'like', imtt);
imt(opts.pad(1)+(1:size(imtt,1)), opts.pad(3)+(1:size(imtt,2)),:) = imtt;
w = size(imArr,2);
h = size(imArr,1);
imArrTmp = imArr;
imArr = 255*ones(h+sum(opts.pad(1:2)), w+sum(opts.pad(3:4)), 3, nViews, 'single');
imArr(opts.pad(1)+(1:h), opts.pad(3)+(1:w),:,:) = imArrTmp;
end

% resize
w = size(imt,2) ;
h = size(imt,1) ;
w = size(imArr,2) ;
h = size(imArr,1) ;
factor = [(opts.imageSize(1)+opts.border(1))/h ...
(opts.imageSize(2)+opts.border(2))/w];

if opts.keepAspect
factor = max(factor) ;
end
if any(abs(factor - 1) > 0.0001)
imt = imresize(imt, ...
'scale', factor, ...
'method', opts.interpolation) ;
imArr = imresize(imArr, ...
'scale', factor, ...
'method', opts.interpolation) ;
end

% crop & flip
w = size(imt,2) ;
h = size(imt,1) ;
w = size(imArr,2) ;
h = size(imArr,1) ;
for ai = 1:opts.numAugments
switch opts.transformation
case 'stretch'
Expand All @@ -140,10 +152,9 @@
if ~isempty(opts.rgbVariance)
offset = bsxfun(@plus, offset, reshape(opts.rgbVariance * randn(3,1), 1,1,3)) ;
end
imo(:,:,:,si) = bsxfun(@minus, imt(sy,sx,:), offset) ;
imo(:,:,:,(ai-1)*numel(images)+(i-1)*nViews+(1:nViews)) = bsxfun(@minus, imArr(sy,sx,:,:), offset) ;
else
imo(:,:,:,si) = imt(sy,sx,:) ;
imo(:,:,:,(ai-1)*numel(images)+(i-1)*nViews+(1:nViews)) = imArr(sy,sx,:,:) ;
end
si = si + 1 ;
end
end
7 changes: 5 additions & 2 deletions cnn_shape_init.m
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,11 @@

% Initiate the last but one layer w/ random weights
widthPrev = size(net.layers{end-1}.weights{1}, 3);
net.layers{end-1}.weights{1} = init_weight(opts, 1, 1, widthPrev, nClass, dataTyp);
net.layers{end-1}.weights{2} = zeros(nClass, 1, dataTyp);
nClass0 = size(net.layers{end-1}.weights{1},4);
if nClass0 ~= nClass || opts.restart,
net.layers{end-1}.weights{1} = init_weight(opts, 1, 1, widthPrev, nClass, dataTyp);
net.layers{end-1}.weights{2} = zeros(nClass, 1, dataTyp);
end

% Initiate other layers w/ random weights if training from scratch is desired
if opts.restart,
Expand Down
8 changes: 4 additions & 4 deletions run_experiments.m
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
'multiview',false,...
'batchSize',60,...
'gpus',[1],...
'numEpochs',[10 20],...
'learningRate',[0.1*ones(1,10) 0.01*ones(1,10) 0.001*ones(1,10)],...
'numEpochs',[5 5 10],...
'learningRate',[0.05*ones(1,5) 0.01*ones(1,5) 0.001*ones(1,5) 0.0001*ones(1,5)],...
'expDir',fullfile('data','exp-modelnet40v1','phase1')...
);

Expand All @@ -16,7 +16,7 @@
'multiview',true,...
'batchSize',5,...
'gpus',[1],...
'numEpochs',[10 20],...
'learningRate',[0.01*ones(1,10) 0.001*ones(1,10) 0.0001*ones(1,10)],...
'numEpochs',[5 5 10],...
'learningRate',[0.005*ones(1,5) 0.001*ones(1,5) 0.0001*ones(1,5) 0.00001*ones(1,5)],...
'expDir',fullfile('data','exp-modelnet40v1','phase2')...
);
110 changes: 110 additions & 0 deletions utils/dirfun.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
function dirfun( dir_path, processFn, save_path, imreadFn, file_pattern, save_pattern, cnt_limit )
% DIRFUN Apply a function to every image file found under a directory tree
%
%   dir_path:     root directory; it is traversed recursively
%   processFn:    function handle applied to each successfully loaded image
%   save_path:    (default:: dir_path) root directory the results are written
%                 to; with the default, results are written back into the
%                 source tree
%   imreadFn:     (default:: @imread_safe) function handle used to load an
%                 image from its path; an empty result marks the file as
%                 unreadable
%   file_pattern: (default:: '*') a pattern string, or a cell array of
%                 pattern strings, selecting the files to process
%   save_pattern: (default:: '') template for output file names; '%s' is
%                 replaced by the source file's base name and '%d' by a
%                 per-folder running count (e.g. '%d.png' gives 1.png,
%                 2.png, ...). When empty, the original file name is kept.
%   cnt_limit:    (default:: inf) maximum number of images processed in each
%                 individual folder
%
% Hang Su

    % Fill in every optional argument that was omitted or passed as empty.
    if nargin < 3 || isempty(save_path)
        save_path = dir_path;       % write results next to the originals
    end
    if nargin < 4 || isempty(imreadFn)
        imreadFn = @imread_safe;
    end
    if nargin < 5 || isempty(file_pattern)
        file_pattern = '*';
    end
    if nargin < 6 || isempty(save_pattern)
        save_pattern = '';
    end
    if nargin < 7 || isempty(cnt_limit)
        cnt_limit = inf;
    end

    % The worker iterates over a cell array of patterns; wrap a lone string.
    if ischar(file_pattern)
        file_pattern = {file_pattern};
    end

    % Walk the tree starting at the root.
    update_dir(dir_path, save_path, processFn, imreadFn, ...
        file_pattern, save_pattern, cnt_limit);

end

function update_dir(cur_dir, cur_save_dir, processFn, imreadFn, file_pattern, save_pattern, cnt_limit)
% UPDATE_DIR Process the images of one folder, then recurse into sub-folders
%
%   cur_dir:      folder whose files are processed
%   cur_save_dir: folder the results are written to (created if missing)
%   processFn:    function handle applied to each loaded image
%   imreadFn:     function handle that loads an image from a path; an empty
%                 result marks the file as unreadable
%   file_pattern: cell array of pattern strings selecting files in cur_dir
%   save_pattern: output name template. '%s' is replaced by the source base
%                 name; an integer conversion such as '%d' or '%02d' is
%                 filled with the running count of images saved from this
%                 folder. Empty keeps the original file name.
%   cnt_limit:    maximum number of images processed in this folder
%
%   NOTE: when cur_dir and cur_save_dir are the same string, every file that
%   imreadFn cannot load is deleted from disk.

    % Collect the regular files matching any of the patterns.
    file_names = {};
    for i = 1:numel(file_pattern)
        files = dir(fullfile(cur_dir, file_pattern{i}));
        file_names_cur = {files.name};
        file_names_cur = file_names_cur(~[files.isdir]);
        file_names = [file_names file_names_cur]; %#ok<AGROW>
    end
    % A file matching several patterns must be processed (and counted) once
    % only; otherwise an in-place run would apply processFn to it repeatedly.
    if numel(file_names) > 1
        file_names = unique(file_names, 'stable');
    end

    % Sub-folders to recurse into (all of them, regardless of file_pattern).
    entries = dir(cur_dir);
    dir_names = {entries.name};
    dir_names = setdiff(dir_names([entries.isdir]), {'.','..'});

    if ~exist(cur_save_dir,'dir'), mkdir(cur_save_dir); end

    % do work
    im_cnt = 0;
    for i = 1:numel(file_names)
        % Checked before loading so that a limit of 0 processes nothing.
        if im_cnt >= cnt_limit, break; end

        src_file = fullfile(cur_dir, file_names{i});
        im = imreadFn(src_file);
        if isempty(im)
            % Unreadable file: remove it only when working in place.
            if strcmp(cur_dir, cur_save_dir)
                delete(src_file);
            end
            continue;
        end
        im = processFn(im);
        im_cnt = im_cnt + 1;

        % Resolve the output name (possibly overwriting the original image).
        if isempty(save_pattern)
            out_name = file_names{i};
        else
            out_name = save_pattern;
            if ~isempty(strfind(save_pattern, '%s'))
                [~, base_name] = fileparts(file_names{i});
                out_name = strrep(save_pattern, '%s', base_name);
            end
            % Accept flags and a field width ('%02d', '%-3d', ...), not only
            % the bare '%d', so zero-padded numbering actually works.
            if ~isempty(regexp(out_name, '%[-+ #0]*\d*d', 'once'))
                out_name = sprintf(out_name, im_cnt);
            end
        end
        imwrite(im, fullfile(cur_save_dir, out_name));
    end

    % Recurse, mirroring the folder structure under the save directory.
    for d = 1:numel(dir_names)
        update_dir(fullfile(cur_dir, dir_names{d}), ...
            fullfile(cur_save_dir, dir_names{d}), ...
            processFn, imreadFn, file_pattern, save_pattern, cnt_limit);
    end

end

function im = imread_safe(path)
% IMREAD_SAFE Load an image, returning [] instead of raising an error
%
%   path: location of the image file passed to imread
%   im:   the image returned by imread, or [] when imread fails (in which
%         case a warning naming the file is issued)

    im = [];    % result reported when the read below fails
    try
        im = imread(path);
    catch
        warning('Unable to load image: %s', path);
    end
end

0 comments on commit cb3ff22

Please sign in to comment.