diff --git a/README.md b/README.md index f24e388..6f01020 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,67 @@ -# tsr-torch -Traffic Sign Recognition with Torch [WIP] +# Traffic Sign Recognition +This is the code for the paper + +**[Deep neural network for traffic sign recognition systems: An analysis of spatial transformers and stochastic optimisation methods](https://doi.org/10.1016/j.neunet.2018.01.005)** +
+[Álvaro Arcos-García](https://scholar.google.com/citations?user=gjecl3cAAAAJ), +[Juan Antonio Álvarez-García](https://scholar.google.com/citations?user=Qk79xk8AAAAJ), +[Luis M. Soria-Morillo](https://scholar.google.com/citations?user=poBDpFkAAAAJ) +
+
+The paper addresses the problem of traffic sign classification using a deep neural network that combines convolutional layers and spatial transformer networks. The model achieves an accuracy of 99.71% on the [German Traffic Sign Recognition Benchmark](http://benchmark.ini.rub.de/?section=gtsrb&subsection=results).
+
+We provide:
+- A [pretrained model](#pretrained-model).
+- Test code to [run the model on new images](#running-on-new-images).
+- Instructions for [training the model](#training).
+
+If you find this code useful in your research, please cite:
+
+```
+"Deep neural network for traffic sign recognition systems: An analysis of spatial transformers and stochastic optimisation methods."
+Álvaro Arcos-García, Juan A. Álvarez-García, Luis M. Soria-Morillo. Neural Networks 99 (2018) 158-165.
+```
+\[[link](https://doi.org/10.1016/j.neunet.2018.01.005)\]\[[bibtex](
+https://scholar.googleusercontent.com/citations?view_op=export_citations&user=gjecl3cAAAAJ&citsig=AMstHGQAAAAAW8ngijLlAnMmG2C2_Weu4sB-TWRmdT1P&hl=en)\]
+
+## Requirements
+This project is implemented in [Torch](http://torch.ch/) and depends on the following packages: [torch/torch7](https://github.com/torch/torch7), [torch/nn](https://github.com/torch/nn), [torch/image](https://github.com/torch/image), [qassemoquab/stnbhwd](https://github.com/qassemoquab/stnbhwd), [torch/cutorch](https://github.com/torch/cutorch), [torch/cunn](https://github.com/torch/cunn), [cuDNN bindings for Torch](https://github.com/soumith/cudnn.torch) and [nninit](https://github.com/Kaixhin/nninit).
+
+After installing Torch, you can install or update these dependencies by running:
+```bash
+luarocks install torch
+luarocks install nn
+luarocks install image
+luarocks install cutorch
+luarocks install cunn
+luarocks install cudnn
+luarocks install https://raw.githubusercontent.com/qassemoquab/stnbhwd/master/stnbhwd-scm-1.rockspec
+luarocks install https://github.com/Kaixhin/nninit/blob/master/rocks/nninit-scm-1.rockspec
+```
+
+## Pretrained model
+You can download a pretrained model from [Google Drive](https://drive.google.com/uc?export=download&id=1iqX1TZBxJSEmaEoReK2ZDko03JHqw1bp).
+Unzip the file `gtsrb_cnn3st_pretrained.zip` and move its contents to the `pretrained` folder of this project. The archive contains the pretrained model that achieves 99.71% accuracy on the German Traffic Sign Recognition Benchmark (GTSRB), together with a second file holding the mean and standard deviation values computed during training.
+
+## Running on new images
+To run the model on new images, use the script `run_model.lua`. It runs the pretrained model on the images provided in the `sample_images` folder:
+```bash
+th run_model.lua
+```
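+
+To call the network from your own code instead, the following is a minimal sketch distilled from `run_model.lua`; the image path is a placeholder, and the CUDA packages listed under [Requirements](#requirements) are assumed to be installed:
+```lua
+require 'torch'; require 'image'; require 'nn'
+require 'cunn'; require 'cudnn'; require 'stn'
+
+-- Load the pretrained network and its normalization statistics
+local network = torch.load("pretrained/gtsrb_cnn3st_model.t7")
+local mean_std = torch.load("pretrained/gtsrb_cnn3st_mean_std.t7")
+local mean, std = mean_std[1], mean_std[2]
+
+-- Prepare one image (placeholder path): resize to 48x48, then normalize
+local img = image.scale(image.load("sample_images/img1.jpg"), 48, 48)
+img:add(-mean):div(std)  -- global normalization
+local lcn = nn.SpatialContrastiveNormalization(3, image.gaussian1D(7))
+img = lcn:forward(img)   -- local contrast normalization
+
+-- The network expects a batch, so add a leading dimension
+local scores = network:forward(img:view(1, 3, 48, 48):cuda())
+local _, prediction = scores[1]:max(1)
+print("Predicted class: " .. prediction[1] - 1)  -- GTSRB classes are 0-indexed
+```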
+
+## Training
+To train and experiment with new traffic sign classification models, follow these steps:
+1. Run the script `download_gtsrb_dataset.lua`, which creates a new folder called `GTSRB` containing two subfolders: `train` and `val`.
+   ```bash
+   th download_gtsrb_dataset.lua
+   ```
+2. Run the script `main.lua`, setting the training parameters described in the file `opts.lua`. For example, the following options generate the files needed to train a model with a single spatial transformer network placed at the beginning of the main network:
+   ```bash
+   th main.lua -data GTSRB -save GTSRB/checkpoints -dataset gtsrb -nClasses 43 -optimizer adam -LR 1e-4 -momentum 0 -weightDecay 0 -batchSize 50 -nEpochs 15 -weightInit default -netType cnn3st -globalNorm -localNorm -cNormConv -locnet2 '' -locnet3 '' -showFullOutput
+   ```
+   The next example creates a model with three spatial transformer networks:
+   ```bash
+   th main.lua -data GTSRB -save GTSRB/checkpoints -dataset gtsrb -nClasses 43 -optimizer rmsprop -LR 1e-5 -momentum 0 -weightDecay 0 -batchSize 50 -nEpochs 15 -weightInit default -netType cnn3st -globalNorm -localNorm -cNormConv -showFullOutput
+   ```
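+3. Optionally, use the checkpointing flags defined in `opts.lua`: `-checkpoint` saves `model_<epoch>.t7`, `optimState_<epoch>.t7` and `latest.t7` into the `-save` directory after every epoch, and `-resume <dir>` restarts from the `latest.t7` found there (combine it with `-epochNumber` to set the starting epoch). A sketch reusing the options from the first example above:
+   ```bash
+   th main.lua -data GTSRB -save GTSRB/checkpoints -dataset gtsrb -nClasses 43 -optimizer adam -LR 1e-4 -batchSize 50 -nEpochs 15 -netType cnn3st -globalNorm -localNorm -cNormConv -locnet2 '' -locnet3 '' -checkpoint
+   th main.lua -data GTSRB -save GTSRB/checkpoints -dataset gtsrb -nClasses 43 -optimizer adam -LR 1e-4 -batchSize 50 -nEpochs 15 -netType cnn3st -globalNorm -localNorm -cNormConv -locnet2 '' -locnet3 '' -resume GTSRB/checkpoints -epochNumber 10
+   ```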
+## Acknowledgements
+The source code of this project is mainly based on [fb.resnet.torch](https://github.com/facebook/fb.resnet.torch) and [gtsrb.torch](https://github.com/Moodstocks/gtsrb.torch).
\ No newline at end of file
diff --git a/checkpoints.lua b/checkpoints.lua
new file mode 100644
index 0000000..c8d8dd4
--- /dev/null
+++ b/checkpoints.lua
@@ -0,0 +1,79 @@
+--
+-- Copyright (c) 2016, Facebook, Inc.
+-- All rights reserved.
+--
+-- This source code is licensed under the BSD-style license found in the
+-- LICENSE file in the root directory of this source tree. An additional grant
+-- of patent rights can be found in the PATENTS file in the same directory.
+--
+local checkpoint = {}
+
+local function deepCopy(tbl)
+   -- creates a copy of a network with new modules and the same tensors
+   local copy = {}
+   for k, v in pairs(tbl) do
+      if type(v) == 'table' then
+         copy[k] = deepCopy(v)
+      else
+         copy[k] = v
+      end
+   end
+   if torch.typename(tbl) then
+      torch.setmetatable(copy, torch.typename(tbl))
+   end
+   return copy
+end
+
+function checkpoint.latest(opt)
+   if opt.resume == 'none' then
+      return nil
+   end
+
+   local latestPath = paths.concat(opt.resume, 'latest.t7')
+   if not paths.filep(latestPath) then
+      return nil
+   end
+
+   print('=> Loading checkpoint ' .. latestPath)
+   local latest = torch.load(latestPath)
+   local optimState = torch.load(paths.concat(opt.resume, latest.optimFile))
+
+   return latest, optimState
+end
+
+function checkpoint.save(epoch, model, optimState, isBestModel, opt)
+   -- don't save the DataParallelTable for easier loading on other machines
+   if torch.type(model) == 'nn.DataParallelTable' then
+      model = model:get(1)
+   end
+
+   -- create a clean copy on the CPU without modifying the original network
+   model = deepCopy(model):float():clearState()
+
+   local modelFile = 'model_' .. epoch .. '.t7'
+   local optimFile = 'optimState_' .. epoch .. '.t7'
+
+   -- opt.checkpoint is a boolean flag (see opts.lua), so test it directly
+   if opt.checkpoint then
+      torch.save(paths.concat(opt.save, modelFile), model)
+      torch.save(paths.concat(opt.save, optimFile), optimState)
+      torch.save(paths.concat(opt.save, 'latest.t7'), {
+         epoch = epoch,
+         modelFile = modelFile,
+         optimFile = optimFile,
+      })
+   end
+
+   if isBestModel then
+      local bestModelFile = 'model_best.t7'
+      local bestOptimFile = 'model_best_optimState.t7'
+      torch.save(paths.concat(opt.save, bestModelFile), model)
+      torch.save(paths.concat(opt.save, bestOptimFile), optimState)
+      torch.save(paths.concat(opt.save, 'latest.t7'), {
+         epoch = epoch,
+         modelFile = bestModelFile,
+         optimFile = bestOptimFile,
+      })
+   end
+end
+
+return checkpoint
\ No newline at end of file
diff --git a/dataloader.lua b/dataloader.lua
new file mode 100644
index 0000000..28ad269
--- /dev/null
+++ b/dataloader.lua
@@ -0,0 +1,127 @@
+--
+-- Copyright (c) 2016, Facebook, Inc.
+-- All rights reserved.
+--
+-- This source code is licensed under the BSD-style license found in the
+-- LICENSE file in the root directory of this source tree. An additional grant
+-- of patent rights can be found in the PATENTS file in the same directory.
+--
+-- Multi-threaded data loader
+--
+
+local datasets = require 'datasets/init'
+local Threads = require 'threads'
+Threads.serialization('threads.sharedserialize')
+
+local M = {}
+local DataLoader = torch.class('DataLoader', M)
+
+function DataLoader.create(opt)
+   -- The train and val loaders
+   local loaders = {}
+
+   for i, split in ipairs{'train', 'val'} do
+      local dataset = datasets.create(opt, split)
+      loaders[i] = M.DataLoader(dataset, opt, split)
+   end
+
+   return table.unpack(loaders)
+end
+
+function DataLoader:__init(dataset, opt, split)
+   local manualSeed = opt.manualSeed
+   local function init()
+      require('datasets/' .. 
opt.dataset)
+   end
+   local function main(idx)
+      if manualSeed ~= 0 then
+         torch.manualSeed(manualSeed + idx)
+      end
+      torch.setnumthreads(1)
+      _G.dataset = dataset
+      _G.preprocess = dataset:preprocess()
+      return dataset:size()
+   end
+
+   local threads, sizes = Threads(opt.nThreads, init, main)
+   self.nCrops = (split == 'val' and opt.tenCrop) and 10 or 1
+   self.threads = threads
+   self.__size = sizes[1][1]
+   self.batchSize = math.floor(opt.batchSize / self.nCrops)
+   local function getCPUType(tensorType)
+      if tensorType == 'torch.CudaHalfTensor' then
+         return 'HalfTensor'
+      elseif tensorType == 'torch.CudaDoubleTensor' then
+         return 'DoubleTensor'
+      else
+         return 'FloatTensor'
+      end
+   end
+   self.cpuType = getCPUType(opt.tensorType)
+end
+
+function DataLoader:size()
+   return math.ceil(self.__size / self.batchSize)
+end
+
+function DataLoader:run()
+   local threads = self.threads
+   local size, batchSize = self.__size, self.batchSize
+   local perm = torch.randperm(size)
+
+   local idx, sample = 1, nil
+   local function enqueue()
+      while idx <= size and threads:acceptsjob() do
+         local indices = perm:narrow(1, idx, math.min(batchSize, size - idx + 1))
+         threads:addjob(
+            function(indices, nCrops, cpuType)
+               local sz = indices:size(1)
+               local batch, imageSize
+               local target = torch.IntTensor(sz)
+               for i, idx in ipairs(indices:totable()) do
+                  local sample = _G.dataset:get(idx)
+                  local input = _G.preprocess(sample.input)
+                  if not batch then
+                     imageSize = input:size():totable()
+                     if nCrops > 1 then table.remove(imageSize, 1) end
+                     batch = torch[cpuType](sz, nCrops, table.unpack(imageSize))
+                  end
+                  batch[i]:copy(input)
+                  target[i] = sample.target
+               end
+               collectgarbage()
+               return {
+                  input = batch:view(sz * nCrops, table.unpack(imageSize)),
+                  target = target,
+               }
+            end,
+            function(_sample_)
+               sample = _sample_
+            end,
+            indices,
+            self.nCrops,
+            self.cpuType
+         )
+         idx = idx + batchSize
+      end
+   end
+
+   local n = 0
+   local function loop()
+      enqueue()
+      if not threads:hasjob() then
+         return nil
+      end
+      threads:dojob()
+      if threads:haserror() then
+         threads:synchronize()
+      end
+      enqueue()
+      n = n + 1
+      return n, sample
+   end
+
+   return loop
+end
+
+return M.DataLoader
\ No newline at end of file
diff --git a/datasets/gtsrb-gen.lua b/datasets/gtsrb-gen.lua
new file mode 100644
index 0000000..fd85f40
--- /dev/null
+++ b/datasets/gtsrb-gen.lua
@@ -0,0 +1,125 @@
+--
+-- Copyright (c) 2016, Facebook, Inc.
+-- All rights reserved.
+--
+-- This source code is licensed under the BSD-style license found in the
+-- LICENSE file in the root directory of this source tree. An additional grant
+-- of patent rights can be found in the PATENTS file in the same directory.
+--
+-- Script to compute the list of GTSRB filenames and classes
+--
+-- This generates a file gen/gtsrb.t7 which contains the list of all
+-- GTSRB training and validation images and their classes. This script also
+-- works for other datasets arranged with the same layout.
+--
+
+local sys = require 'sys'
+local ffi = require 'ffi'
+
+local M = {}
+
+local function findClasses(dir)
+   local dirs = paths.dir(dir)
+   table.sort(dirs)
+
+   local classList = {}
+   local classToIdx = {}
+   for _, class in ipairs(dirs) do
+      if not classToIdx[class] and class ~= '.' and class ~= '..' 
and class ~= '.DS_Store' then + table.insert(classList, class) + classToIdx[class] = #classList + end + end + + assert(#classList == 43, 'expected 43 GTSRB classes') + return classList, classToIdx +end + +local function findImages(dir, classToIdx) + local imagePath = torch.CharTensor() + local imageClass = torch.LongTensor() + + ---------------------------------------------------------------------- + -- Options for the GNU and BSD find command + local extensionList = {'jpg', 'png', 'jpeg', 'JPG', 'PNG', 'JPEG', 'ppm', 'PPM', 'bmp', 'BMP'} + local findOptions = ' -iname "*.' .. extensionList[1] .. '"' + for i=2,#extensionList do + findOptions = findOptions .. ' -o -iname "*.' .. extensionList[i] .. '"' + end + + -- Find all the images using the find command + local f = io.popen('find -L ' .. dir .. findOptions) + + local maxLength = -1 + local imagePaths = {} + local imageClasses = {} + + -- Generate a list of all the images and their class + while true do + local line = f:read('*line') + if not line then break end + + local className = paths.basename(paths.dirname(line)) + local filename = paths.basename(line) + local path = className .. '/' .. filename + + local classId = classToIdx[className] + assert(classId, 'class not found: ' .. className) + + table.insert(imagePaths, path) + table.insert(imageClasses, classId) + + maxLength = math.max(maxLength, #path + 1) + end + + f:close() + + -- Convert the generated list to a tensor for faster loading + local nImages = #imagePaths + local imagePath = torch.CharTensor(nImages, maxLength):zero() + for i, path in ipairs(imagePaths) do + ffi.copy(imagePath[i]:data(), path) + end + + local imageClass = torch.LongTensor(imageClasses) + return imagePath, imageClass +end + +function M.exec(opt, cacheFile) + -- find the image path names + local imagePath = torch.CharTensor() -- path to each image in dataset + local imageClass = torch.LongTensor() -- class index of each image (class index in self.classes) + + local trainDir = paths.concat(opt.data, 'train') + local valDir = paths.concat(opt.data, 'val') + assert(paths.dirp(trainDir), 'train directory not found: ' .. trainDir) + assert(paths.dirp(valDir), 'val directory not found: ' .. valDir) + + print("=> Generating list of images") + local classList, classToIdx = findClasses(trainDir) + + print(" | finding all validation images") + local valImagePath, valImageClass = findImages(valDir, classToIdx) + + print(" | finding all training images") + local trainImagePath, trainImageClass = findImages(trainDir, classToIdx) + + local info = { + basedir = opt.data, + classList = classList, + train = { + imagePath = trainImagePath, + imageClass = trainImageClass, + }, + val = { + imagePath = valImagePath, + imageClass = valImageClass, + }, + } + + print(" | saving list of images to " .. 
cacheFile) + torch.save(cacheFile, info) + return info +end + +return M \ No newline at end of file diff --git a/datasets/gtsrb.lua b/datasets/gtsrb.lua new file mode 100644 index 0000000..9b8f1c8 --- /dev/null +++ b/datasets/gtsrb.lua @@ -0,0 +1,81 @@ +-- +-- GTSRB dataset loader +-- + +local image = require 'image' +local paths = require 'paths' +local t = require 'datasets/transforms' +local ffi = require 'ffi' + +local M = {} +local GTSRBDataset = torch.class('GTSRBDataset', M) + +function GTSRBDataset:__init(imageInfo, opt, split) + -- imageInfo: result from gtsrb-gen.lua + -- opt: command-line arguments + -- split: "train" or "val" + self.imageInfo = imageInfo[split] + self.opt = opt + self.split = split + self.dir = paths.concat(opt.data, split) + assert(paths.dirp(self.dir), 'directory does not exist: ' .. self.dir) +end + +function GTSRBDataset:get(i) + local path = ffi.string(self.imageInfo.imagePath[i]:data()) + + local image = self:_loadImage(paths.concat(self.dir, path)) + local class = self.imageInfo.imageClass[i] + + return { + input = image, + target = class, + } +end + +function GTSRBDataset:_loadImage(path) + local ok, input = pcall(function() + return image.load(path, 3, 'float') + end) + + -- Sometimes image.load fails because the file extension does not match the + -- image format. In that case, use image.decompress on a ByteTensor. + if not ok then + local f = io.open(path, 'r') + assert(f, 'Error reading: ' .. tostring(path)) + local data = f:read('*a') + f:close() + + local b = torch.ByteTensor(string.len(data)) + ffi.copy(b:data(), data, b:size(1)) + + input = image.decompress(b, 3, 'float') + end + + return input +end + +function GTSRBDataset:size() + return self.imageInfo.imageClass:size(1) +end + +-- Computed from GTSRB training images +local meanstd = { + mean = { 0.341, 0.312, 0.321 }, + std = { 0.275, 0.264, 0.270 }, +} + +function GTSRBDataset:preprocess() + local transforms = {t.Resize(48, 48)} + if self.opt.globalNorm then + table.insert(transforms, t.ColorNormalize(meanstd)) + end + if self.opt.localNorm then + table.insert(transforms, t.LocalContrastNorm(7)) + end + + return t.Compose(transforms) + +end + +return M.GTSRBDataset \ No newline at end of file diff --git a/datasets/init.lua b/datasets/init.lua new file mode 100644 index 0000000..8bdd9f2 --- /dev/null +++ b/datasets/init.lua @@ -0,0 +1,34 @@ +-- +-- Copyright (c) 2016, Facebook, Inc. +-- All rights reserved. +-- +-- This source code is licensed under the BSD-style license found in the +-- LICENSE file in the root directory of this source tree. An additional grant +-- of patent rights can be found in the PATENTS file in the same directory. +-- + +local M = {} + +local function isvalid(opt, cachePath) + local imageInfo = torch.load(cachePath) + if imageInfo.basedir and imageInfo.basedir ~= opt.data then + return false + end + return true +end + +function M.create(opt, split) + local cachePath = paths.concat(opt.gen, opt.dataset .. '.t7') + if not paths.filep(cachePath) or not isvalid(opt, cachePath) then + paths.mkdir('gen') + + local script = paths.dofile(opt.dataset .. '-gen.lua') + script.exec(opt, cachePath) + end + local imageInfo = torch.load(cachePath) + + local Dataset = require('datasets/' .. 
opt.dataset) + return Dataset(imageInfo, opt, split) +end + +return M \ No newline at end of file diff --git a/datasets/transforms.lua b/datasets/transforms.lua new file mode 100644 index 0000000..ce1fad1 --- /dev/null +++ b/datasets/transforms.lua @@ -0,0 +1,318 @@ +-- +-- Copyright (c) 2016, Facebook, Inc. +-- All rights reserved. +-- +-- This source code is licensed under the BSD-style license found in the +-- LICENSE file in the root directory of this source tree. An additional grant +-- of patent rights can be found in the PATENTS file in the same directory. +-- +-- Image transforms for data augmentation and input normalization +-- + +require 'image' +require 'nn' +torch.setdefaulttensortype('torch.FloatTensor') + +local M = {} + +function M.Compose(transforms) + return function(input) + for _, transform in ipairs(transforms) do + input = transform(input) + end + return input + end +end + +function M.ColorNormalize(meanstd) + return function(img) + img = img:clone() + for i=1,3 do + img[i]:add(-meanstd.mean[i]) + img[i]:div(meanstd.std[i]) + end + return img + end +end + +-- Scales the image to have width X height size +function M.Resize(width, height, interpolation) + interpolation = interpolation or 'bicubic' + return function(input) + return image.scale(input, width, height, interpolation) + end +end + +-- Scales the smaller edge to size +function M.ScaleEdge(size, interpolation) + interpolation = interpolation or 'bicubic' + return function(input) + local w, h = input:size(3), input:size(2) + if (w <= h and w == size) or (h <= w and h == size) then + return input + end + if w < h then + return image.scale(input, size, h/w * size, interpolation) + else + return image.scale(input, w/h * size, size, interpolation) + end + end +end + +-- Crop to centered rectangle +function M.CenterCrop(size) + return function(input) + local w1 = math.ceil((input:size(3) - size)/2) + local h1 = math.ceil((input:size(2) - size)/2) + return image.crop(input, w1, h1, w1 + size, h1 + size) -- center patch + end +end + +-- Random crop form larger image with optional zero padding +function M.RandomCrop(size, padding) + padding = padding or 0 + + return function(input) + if padding > 0 then + local temp = input.new(3, input:size(2) + 2*padding, input:size(3) + 2*padding) + temp:zero() + :narrow(2, padding+1, input:size(2)) + :narrow(3, padding+1, input:size(3)) + :copy(input) + input = temp + end + + local w, h = input:size(3), input:size(2) + if w == size and h == size then + return input + end + + local x1, y1 = torch.random(0, w - size), torch.random(0, h - size) + local out = image.crop(input, x1, y1, x1 + size, y1 + size) + assert(out:size(2) == size and out:size(3) == size, 'wrong crop size') + return out + end +end + +-- Four corner patches and center crop from image and its horizontal reflection +function M.TenCrop(size) + local centerCrop = M.CenterCrop(size) + + return function(input) + local w, h = input:size(3), input:size(2) + + local output = {} + for _, img in ipairs{input, image.hflip(input)} do + table.insert(output, centerCrop(img)) + table.insert(output, image.crop(img, 0, 0, size, size)) + table.insert(output, image.crop(img, w-size, 0, w, size)) + table.insert(output, image.crop(img, 0, h-size, size, h)) + table.insert(output, image.crop(img, w-size, h-size, w, h)) + end + + -- View as mini-batch + for i, img in ipairs(output) do + output[i] = img:view(1, img:size(1), img:size(2), img:size(3)) + end + + return input.cat(output, 1) + end +end + +-- Resized with shorter side randomly 
sampled from [minSize, maxSize] (ResNet-style)
+function M.RandomScale(minSize, maxSize)
+   return function(input)
+      local w, h = input:size(3), input:size(2)
+
+      local targetSz = torch.random(minSize, maxSize)
+      local targetW, targetH = targetSz, targetSz
+      if w < h then
+         targetH = torch.round(h / w * targetW)
+      else
+         targetW = torch.round(w / h * targetH)
+      end
+
+      return image.scale(input, targetW, targetH, 'bicubic')
+   end
+end
+
+-- Random crop with size 8%-100% and aspect ratio 3/4 - 4/3 (Inception-style)
+function M.RandomSizedCrop(size)
+   -- the fallback path scales the smaller edge first; in this module that
+   -- transform is named ScaleEdge (the upstream fb.resnet code called it Scale)
+   local scale = M.ScaleEdge(size)
+   local crop = M.CenterCrop(size)
+
+   return function(input)
+      local attempt = 0
+      repeat
+         local area = input:size(2) * input:size(3)
+         local targetArea = torch.uniform(0.08, 1.0) * area
+
+         local aspectRatio = torch.uniform(3/4, 4/3)
+         local w = torch.round(math.sqrt(targetArea * aspectRatio))
+         local h = torch.round(math.sqrt(targetArea / aspectRatio))
+
+         if torch.uniform() < 0.5 then
+            w, h = h, w
+         end
+
+         if h <= input:size(2) and w <= input:size(3) then
+            local y1 = torch.random(0, input:size(2) - h)
+            local x1 = torch.random(0, input:size(3) - w)
+
+            local out = image.crop(input, x1, y1, x1 + w, y1 + h)
+            assert(out:size(2) == h and out:size(3) == w, 'wrong crop size')
+
+            return image.scale(out, size, size, 'bicubic')
+         end
+         attempt = attempt + 1
+      until attempt >= 10
+
+      -- fallback
+      return crop(scale(input))
+   end
+end
+
+function M.HorizontalFlip(prob)
+   return function(input)
+      if torch.uniform() < prob then
+         input = image.hflip(input)
+      end
+      return input
+   end
+end
+
+function M.Rotation(deg)
+   return function(input)
+      if deg ~= 0 then
+         input = image.rotate(input, (torch.uniform() - 0.5) * deg * math.pi / 180, 'bilinear')
+      end
+      return input
+   end
+end
+
+function M.Translate(tx, ty)
+   return function(input)
+      input = image.translate(input, (torch.uniform() - 0.5) * tx, (torch.uniform() - 0.5) * ty)
+      return input
+   end
+end
+
+-- Lighting noise (AlexNet-style PCA-based noise)
+function M.Lighting(alphastd, eigval, eigvec)
+   return function(input)
+      if alphastd == 0 then
+         return input
+      end
+
+      local alpha = torch.Tensor(3):normal(0, alphastd)
+      local rgb = eigvec:clone()
+         :cmul(alpha:view(1, 3):expand(3, 3))
+         :cmul(eigval:view(1, 3):expand(3, 3))
+         :sum(2)
+         :squeeze()
+
+      input = input:clone()
+      for i=1,3 do
+         input[i]:add(rgb[i])
+      end
+      return input
+   end
+end
+
+local function blend(img1, img2, alpha)
+   return img1:mul(alpha):add(1 - alpha, img2)
+end
+
+local function grayscale(dst, img)
+   dst:resizeAs(img)
+   dst[1]:zero()
+   dst[1]:add(0.299, img[1]):add(0.587, img[2]):add(0.114, img[3])
+   dst[2]:copy(dst[1])
+   dst[3]:copy(dst[1])
+   return dst
+end
+
+function M.Saturation(var)
+   local gs
+
+   return function(input)
+      gs = gs or input.new()
+      grayscale(gs, input)
+
+      local alpha = 1.0 + torch.uniform(-var, var)
+      blend(input, gs, alpha)
+      return input
+   end
+end
+
+function M.Brightness(var)
+   local gs
+
+   return function(input)
+      gs = gs or input.new()
+      gs:resizeAs(input):zero()
+
+      local alpha = 1.0 + torch.uniform(-var, var)
+      blend(input, gs, alpha)
+      return input
+   end
+end
+
+function M.Contrast(var)
+   local gs
+
+   return function(input)
+      gs = gs or input.new()
+      grayscale(gs, input)
+      gs:fill(gs[1]:mean())
+
+      local alpha = 1.0 + torch.uniform(-var, var)
+      blend(input, gs, alpha)
+      return input
+   end
+end
+
+function M.RandomOrder(ts)
+   return function(input)
+      local img = input.img or input
+      local order = torch.randperm(#ts)
+      for i=1,#ts do
+         img = ts[order[i]](img)
+      end
+      
return img + end +end + +function M.ColorJitter(opt) + local brightness = opt.brightness or 0 + local contrast = opt.contrast or 0 + local saturation = opt.saturation or 0 + + local ts = {} + if brightness ~= 0 then + table.insert(ts, M.Brightness(brightness)) + end + if contrast ~= 0 then + table.insert(ts, M.Contrast(contrast)) + end + if saturation ~= 0 then + table.insert(ts, M.Saturation(saturation)) + end + + if #ts == 0 then + return function(input) return input end + end + + return M.RandomOrder(ts) +end + +function M.LocalContrastNorm(gaussian_kernel_size) + return function(input) + local kernel = image.gaussian(gaussian_kernel_size) + local normalizer = nn.SpatialContrastiveNormalization(input:size(1), kernel) + input:copy(normalizer:forward(input)) + return input + end +end + +return M \ No newline at end of file diff --git a/download_gtsrb_dataset.lua b/download_gtsrb_dataset.lua new file mode 100755 index 0000000..fbe68b9 --- /dev/null +++ b/download_gtsrb_dataset.lua @@ -0,0 +1,63 @@ +local pl = (require 'pl.import_into')() + +local dataset = {} + +dataset.path_remote_train = "http://benchmark.ini.rub.de/Dataset/GTSRB_Final_Training_Images.zip" +dataset.path_remote_test = "http://benchmark.ini.rub.de/Dataset/GTSRB_Final_Test_Images.zip" +dataset.path_remote_test_gt = "http://benchmark.ini.rub.de/Dataset/GTSRB_Final_Test_GT.zip" + +dataset.train_folder = "train" +dataset.val_folder = "val" + +function dataset.download_dataset() + if not pl.path.isdir('GTSRB') then + local zip_train = paths.basename(dataset.path_remote_train) + local zip_test = paths.basename(dataset.path_remote_test) + local zip_test_gt = paths.basename(dataset.path_remote_test_gt) + + print('Downloading dataset...') + os.execute('wget ' .. dataset.path_remote_train .. '; ' .. + 'unzip ' .. zip_train .. '; '.. + 'rm ' .. zip_train .. '; ' .. + 'mv GTSRB/Final_Training/Images/ GTSRB/train;' .. + 'rm -r GTSRB/Final_Training') + os.execute('wget ' .. dataset.path_remote_test .. '; ' .. + 'unzip ' .. zip_test .. '; '.. + 'rm ' .. zip_test .. '; ' .. + 'mkdir GTSRB/val; ' .. + [[find GTSRB/Final_Test/Images/ -maxdepth 1 -name '*.ppm' -exec sh -c 'mv "$@" "$0"' GTSRB/val/ {} +;]] .. + 'rm -r GTSRB/Final_Test') + os.execute('wget ' .. dataset.path_remote_test_gt .. '; ' .. + 'unzip ' .. zip_test_gt .. '; '.. + 'rm ' .. zip_test_gt .. '; '.. + 'mv GT-final_test.csv GTSRB/GT-final_test.csv') + end +end + +function dataset.move_val_images() + print("Moving validation images to class folders") + local val_dir = pl.path.join("GTSRB", dataset.val_folder) + local csv_file_path = pl.path.join("GTSRB", "GT-final_test.csv") + local csv_data = pl.data.read(csv_file_path) + local filename_index = csv_data.fieldnames:index("Filename") + local class_id_index = csv_data.fieldnames:index("ClassId") + + for _, image_metadata in ipairs(csv_data) do + local image_name = image_metadata[filename_index] + local image_path = pl.path.join(val_dir, image_name) + local image_label = image_metadata[class_id_index] + local class_folder_name = string.format("%05d", image_label) + local class_folder_path = pl.path.join(val_dir, class_folder_name) + if not pl.path.exists(class_folder_path) then + pl.path.mkdir(class_folder_path) + end + pl.file.move(image_path, pl.path.join(class_folder_path, image_name)) + end +end + + +dataset.download_dataset() +dataset.move_val_images() + + + diff --git a/main.lua b/main.lua new file mode 100644 index 0000000..0bfae1a --- /dev/null +++ b/main.lua @@ -0,0 +1,110 @@ +-- +-- Copyright (c) 2016, Facebook, Inc. 
+-- All rights reserved. +-- +-- This source code is licensed under the BSD-style license found in the +-- LICENSE file in the root directory of this source tree. An additional grant +-- of patent rights can be found in the PATENTS file in the same directory. +-- + +require 'torch' +require 'paths' +require 'optim' +require 'nn' + +local DataLoader = require 'dataloader' +local models = require 'models/init' +local Trainer = require 'trainer' +local opts = require 'opts' +local checkpoints = require 'checkpoints' + +-- we don't change this to the 'correct' type (e.g. HalfTensor), because math +-- isn't supported on that type. Type conversion later will handle having +-- the correct type. + +local opt = opts.parse(arg) +torch.setdefaulttensortype('torch.FloatTensor') +torch.setnumthreads(opt.nThreads) +print(opt.manualSeed) + +torch.manualSeed(opt.manualSeed) +cutorch.manualSeedAll(opt.manualSeed) +math.randomseed(opt.manualSeed) + +-- Create unique checkpoint dir +dir_name = 'net-' .. opt.netType .. '__cnn-' .. opt.cnn .. + '__locnet1-' .. opt.locnet1 .. '__locnet2-' .. opt.locnet2 .. '__locnet3-' .. opt.locnet3 .. + '__optimizer-'.. opt.optimizer .. '__weightinit-' .. opt.weightInit .. + os.date("__%Y_%m_%d_%X") +opt.save = paths.concat(opt.save, dir_name) +if paths.dir(opt.save) == nil then + paths.mkdir(opt.save) +end + +-- Load previous checkpoint, if it exists +local checkpoint, optimState = checkpoints.latest(opt) + +-- Create model +local model, criterion = models.setup(opt, checkpoint) + +print("-- Model architecture --") +print(model) + +-- Data loading +local trainLoader, valLoader = DataLoader.create(opt) + +-- The trainer handles the training loop and evaluation on validation set +local trainer = Trainer(model, criterion, opt, optimState) + +-- Logger +local logger = optim.Logger(paths.concat(opt.save, 'history.log')) +logger:setNames{'Train Loss', 'Train LossAbs','Train Acc', + 'Test Loss', 'Test LossAbs', 'Test Acc' } +logger:style{'+-','+-','+-','+-','+-','+-'} +logger:display(false) + +if opt.testOnly then + local top1Err, top5Err = trainer:test(0, valLoader) + print(string.format(' * Results top1: %6.3f top5: %6.3f', top1Err, top5Err)) + return +end + +local startEpoch = opt.epochNumber --checkpoint and checkpoint.epoch + 1 or opt.epochNumber +local bestTop1 = 0 +local bestTop5 = 0 +local bestLoss = math.huge +local bestLossAbs = math.huge +local bestEpoch = math.huge +for epoch = startEpoch, opt.nEpochs do + -- Train for a single epoch + local trainTop1, trainTop5, trainLoss, trainLossAbs = trainer:train(epoch, trainLoader) + + -- Run model on validation set + local testTop1, testTop5, testLoss, testLossAbs = trainer:test(epoch, valLoader) + + -- Update logger + logger:add{trainLoss, trainLossAbs, trainTop1, + testLoss, testLossAbs, testTop1 } +-- logger:plot() + + local bestModel = false + if testTop1 > bestTop1 then + bestModel = true + bestTop1 = testTop1 + bestTop5 = testTop5 + bestLoss = testLoss + bestLossAbs = testLossAbs + bestEpoch = epoch + if opt.showFullOutput then + print(string.format(' * Best Model -- epoch:%i top1: %6.3f top5: %6.3f loss: %6.3f, lossabs: %6.3f', + bestEpoch, bestTop1, bestTop5, bestLoss, bestLossAbs)) + end + end + + checkpoints.save(epoch, model, trainer.optimState, bestModel, opt) +end + +--logger:plot() + +print(string.format(' * Finished Best Model -- epoch:%i top1: %6.3f top5: %6.3f loss: %6.3f, lossabs: %6.3f', + bestEpoch, bestTop1, bestTop5, bestLoss, bestLossAbs)) \ No newline at end of file diff --git a/models/cnn3st.lua 
b/models/cnn3st.lua new file mode 100644 index 0000000..a44d007 --- /dev/null +++ b/models/cnn3st.lua @@ -0,0 +1,276 @@ +local torch = require 'torch' +local nn = require 'nn' +require 'cunn' +local cudnn = require 'cudnn' +require 'stn' +local image = require 'image' + +local layers = {} + +-- NVIDIA CuDNN +layers.convolution = cudnn.SpatialConvolution +layers.maxPooling = cudnn.SpatialMaxPooling +layers.nonLinearity = cudnn.ReLU +layers.batchNorm = cudnn.SpatialBatchNormalization + +-- CPU +--layers.convolution = nn.SpatialConvolution +--layers.maxPooling = nn.SpatialMaxPooling +--layers.nonLinearity = nn.ReLU +--layers.batchNorm = nn.SpatialBatchNormalization + +-- Returns the number of output elements for a table of convolution layers and the new height and width of the image +local function convs_noutput(convs, input_size) + input_size = input_size or baseInputSize + -- Get the number of channels for conv that are multiscale or not + local nbr_input_channels = convs[1]:get(1).nInputPlane or + convs[1]:get(1):get(1).nInputPlane + local output = torch.CudaTensor(1, nbr_input_channels, input_size, input_size) + for _, conv in ipairs(convs) do + conv:cuda() + output = conv:forward(output) + end + return output:nElement(), output:size(3), output:size(4) +end + +-- Creates a conv module with the specified number of channels in input and output +-- If multiscale is true, the total number of output channels will be: +-- nbr_input_channels + nbr_output_channels +-- Using cnorm adds the spatial contrastive normalization module +-- The filter size for the convolution can be specified (default 5) +-- The stride of the convolutions is fixed at 1 +local function new_conv(nbr_input_channels,nbr_output_channels, multiscale, cnorm, filter_size) + multiscale = multiscale or false + cnorm = cnorm or false + filter_size = filter_size or 5 + local padding_size = 2 + local pooling_size = 2 + local norm_kernel = image.gaussian(7) + + local conv + + local first = nn.Sequential() + first:add(layers.convolution(nbr_input_channels, + nbr_output_channels, + filter_size, filter_size, + 1,1, + padding_size, padding_size)) + first:add(layers.nonLinearity(true)) + first:add(layers.maxPooling(pooling_size, pooling_size, + pooling_size, pooling_size)) + + if cnorm then + first:add(nn.SpatialContrastiveNormalization(nbr_output_channels, norm_kernel)) + end + + if multiscale then + conv = nn.Sequential() + local second = layers.maxPooling(pooling_size, pooling_size, + pooling_size, pooling_size) + + local parallel = nn.ConcatTable() + parallel:add(first) + parallel:add(second) + conv:add(parallel) + conv:add(nn.JoinTable(1,3)) + else + conv = first + end + + return conv +end + +-- Creates a fully connection layer with the specified size. +local function new_fc(nbr_input, nbr_output) + local fc = nn.Sequential() + fc:add(nn.View(nbr_input)) + fc:add(nn.Linear(nbr_input, nbr_output)) + fc:add(layers.nonLinearity(true)) + return fc +end + +-- Creates a classifier with the specified size. 
+local function new_classifier(nbr_input, nbr_output)
+   local classifier = nn.Sequential()
+   classifier:add(nn.View(nbr_input))
+   classifier:add(nn.Linear(nbr_input, nbr_output))
+   return classifier
+end
+
+-- Creates a spatial transformer module
+-- locnet are the parameters to create the localization network
+-- rot, sca, tra can be used to force specific transformations
+-- input_size is the height (= width) of the input
+-- input_channels is the number of channels in the input
+-- note: due to (1) below, the bilinear sampler is deliberately kept on the CPU
+local function new_spatial_tranformer(locnet, rot, sca, tra, input_size, input_channels)
+   local nbr_elements = {}
+   for c in string.gmatch(locnet, "%d+") do
+      nbr_elements[#nbr_elements + 1] = tonumber(c)
+   end
+
+   -- Get number of params and initial state
+   local init_bias = {}
+   local nbr_params = 0
+   if rot then
+      nbr_params = nbr_params + 1
+      init_bias[nbr_params] = 0
+   end
+   if sca then
+      nbr_params = nbr_params + 1
+      init_bias[nbr_params] = 1
+   end
+   if tra then
+      nbr_params = nbr_params + 2
+      init_bias[nbr_params-1] = 0
+      init_bias[nbr_params] = 0
+   end
+   if nbr_params == 0 then
+      -- fully parametrized case
+      nbr_params = 6
+      init_bias = {1,0,0,0,1,0}
+   end
+
+   local st = nn.Sequential()
+
+   -- Create a localization network same as cnn but with downsampled inputs
+   local localization_network = nn.Sequential()
+   local conv1 = new_conv(input_channels, nbr_elements[1], false, false)
+   local conv2 = new_conv(nbr_elements[1], nbr_elements[2], false, false)
+   local conv_output_size = convs_noutput({conv1, conv2}, input_size/2)
+   local fc = new_fc(conv_output_size, nbr_elements[3])
+   local classifier = new_classifier(nbr_elements[3], nbr_params)
+   -- Initialize the localization network (see paper, A.3 section)
+   classifier:get(2).weight:zero()
+   classifier:get(2).bias = torch.Tensor(init_bias)
+
+   localization_network:add(layers.maxPooling(2,2,2,2))
+   localization_network:add(conv1)
+   localization_network:add(conv2)
+   localization_network:add(fc)
+   localization_network:add(classifier)
+
+   -- Create the actual module structure
+   local ct = nn.ConcatTable()
+
+   local branch1 = nn.Sequential()
+   branch1:add(nn.Transpose({3,4},{2,4}))
+   branch1:add(nn.Copy('torch.CudaTensor', 'torch.FloatTensor', true, true))
+
+   local branch2 = nn.Sequential()
+   branch2:add(localization_network)
+   branch2:add(nn.AffineTransformMatrixGenerator(rot, sca, tra))
+   branch2:add(nn.AffineGridGeneratorBHWD(input_size, input_size))
+   branch2:add(nn.Copy('torch.CudaTensor', 'torch.FloatTensor', true, true))
+
+   ct:add(branch1)
+   ct:add(branch2)
+
+   st:add(ct)
+   local sampler = nn.BilinearSamplerBHWD()
+   -- (1)
+   -- The sampler leads to non-reproducible results on GPU
+   -- We want to always keep it on the CPU
+   -- This does not slow down training
+   sampler:type('torch.FloatTensor')
+   -- make sure it will not go back to the GPU when we call
+   -- ":cuda()" on the network later
+   sampler.type = function(self)
+      return self
+   end
+   st:add(sampler)
+   st:add(nn.Copy('torch.FloatTensor','torch.CudaTensor', true, true))
+   st:add(nn.Transpose({2,4},{3,4}))
+
+   return st
+end
+
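+-- For reference, a hypothetical call mirroring how createModel below wires the
+-- first transformer: new_spatial_tranformer('250,250,250', false, false, false, 48, 3)
+-- yields a module that takes a (batch, 3, 48, 48) input, predicts a full 6-parameter
+-- affine transform with its localization network, resamples the input accordingly,
+-- and returns a tensor of the same (batch, 3, 48, 48) shape.
+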
+local function createModel(opt)
+   nInputChannels = 3
+   baseInputSize = opt.baseInputSize or 48
+   local cnorm = opt.cNormConv or false
+   local nbr_elements = {}
+   for c in string.gmatch(opt.cnn, "%d+") do
+      nbr_elements[#nbr_elements + 1] = tonumber(c)
+   end
+   assert(#nbr_elements == 4,
+      'opt.cnn should contain 4 comma separated values, got '..#nbr_elements)
+
+   local conv1 = new_conv(nInputChannels, nbr_elements[1], false, cnorm, 7)
+   local conv2 = new_conv(nbr_elements[1], nbr_elements[2], false, cnorm, 4)
+   local conv3 = new_conv(nbr_elements[2], nbr_elements[3], false, cnorm, 4)
+
+   local convOutputSize, _, _ = convs_noutput({conv1, conv2, conv3})
+
+   local fc = new_fc(convOutputSize, nbr_elements[4])
+   local fc_class = new_classifier(nbr_elements[4], opt.nClasses)
+
+   local features = nn.Sequential()
+
+   if opt.locnet1 and opt.locnet1 ~= '' then
+      local st1 = new_spatial_tranformer(
+         opt.locnet1,         -- locnet
+         false, false, false, -- rot, sca, tra
+         baseInputSize,       -- input_size
+         nInputChannels       -- input_channels
+      )
+      features:add(st1)
+   end
+
+   features:add(conv1)
+
+   if opt.locnet2 and opt.locnet2 ~= '' then
+      local _, currentInputSize, _ = convs_noutput({conv1})
+      local st2 = new_spatial_tranformer(
+         opt.locnet2,
+         false, false, false,
+         currentInputSize,
+         nbr_elements[1]
+      )
+      features:add(st2)
+   end
+
+   features:add(conv2)
+
+   if opt.locnet3 and opt.locnet3 ~= '' then
+      local _, currentInputSize, _ = convs_noutput({conv1,conv2})
+
+      local st3 = new_spatial_tranformer(
+         opt.locnet3,
+         false, false, false,
+         currentInputSize,
+         nbr_elements[2]
+      )
+      features:add(st3)
+   end
+
+   features:add(conv3)
+
+   local classifier = nn.Sequential()
+   classifier:add(fc)
+   classifier:add(fc_class)
+
+   local model = nn.Sequential():add(features):add(classifier)
+
+   return model
+end
+
+return createModel
+
+-- Test parameters
+--opt = {}
+--opt.nClasses = 43
+--opt.cnn = '200,250,350,400'
+--opt.locnet1= '250,250,250'
+--opt.locnet2= '150,200,300'
+--opt.locnet3= '150,200,300'
+--opt.globalNorm = 'false'
+--opt.localNorm = 'false'
+--model = createModel(opt)
+--model:cuda()
+--print(model)
+--
+--parameters, gradParameters = model:getParameters()
+--print(parameters:size())
+--print(gradParameters:size())
+
diff --git a/models/init.lua b/models/init.lua
new file mode 100644
index 0000000..2e5eb21
--- /dev/null
+++ b/models/init.lua
@@ -0,0 +1,160 @@
+--
+-- Copyright (c) 2016, Facebook, Inc.
+-- All rights reserved.
+--
+-- This source code is licensed under the BSD-style license found in the
+-- LICENSE file in the root directory of this source tree. An additional grant
+-- of patent rights can be found in the PATENTS file in the same directory.
+--
+-- Generic model creating code. For the specific CNN3ST model see
+-- models/cnn3st.lua
+--
+
+require 'nn'
+require 'cunn'
+require 'cudnn'
+require 'stn'
+local nninit = require 'nninit'
+
+local M = {}
+
+function M.setup(opt, checkpoint)
+   local model
+   if checkpoint then
+      local modelPath = paths.concat(opt.resume, checkpoint.modelFile)
+      assert(paths.filep(modelPath), 'Saved model not found: ' .. modelPath)
+      print('=> Resuming model from ' .. modelPath)
+      model = torch.load(modelPath):type(opt.tensorType)
+      model.__memoryOptimized = nil
+   elseif opt.retrain ~= 'none' then
+      assert(paths.filep(opt.retrain), 'File not found: ' .. opt.retrain)
+      print('Loading model from file: ' .. opt.retrain)
+      model = torch.load(opt.retrain):type(opt.tensorType)
+      model.__memoryOptimized = nil
+   else
+      print('=> Creating model from file: models/' .. opt.netType .. '.lua')
+      model = require('models/' .. 
opt.netType)(opt) + end + + -- First remove any DataParallelTable + if torch.type(model) == 'nn.DataParallelTable' then + model = model:get(1) + end + + -- optnet is an general library for reducing memory usage in neural networks + if opt.optnet then + local optnet = require 'optnet' + local imsize = opt.dataset == 'imagenet' and 224 or 32 + local sampleInput = torch.zeros(4,3,imsize,imsize):type(opt.tensorType) + optnet.optimizeMemory(model, sampleInput, {inplace = false, mode = 'training'}) + end + + -- This is useful for fitting ResNet-50 on 4 GPUs, but requires that all + -- containers override backwards to call backwards recursively on submodules + if opt.shareGradInput then + M.shareGradInput(model, opt) + end + + -- For resetting the classifier when fine-tuning on a different Dataset + if opt.resetClassifier and not checkpoint then + print(' => Replacing classifier with ' .. opt.nClasses .. '-way classifier') + + local orig = model:get(#model.modules) + assert(torch.type(orig) == 'nn.Linear', + 'expected last layer to be fully connected') + + local linear = nn.Linear(orig.weight:size(2), opt.nClasses) + linear.bias:zero() + + model:remove(#model.modules) + model:add(linear:type(opt.tensorType)) + end + + -- Set the CUDNN flags + if opt.cudnn == 'fastest' then + cudnn.fastest = true + cudnn.benchmark = true + elseif opt.cudnn == 'deterministic' then + -- Use a deterministic convolution implementation + model:apply(function(m) + if m.setMode then m:setMode(1, 1, 1) end + end) + end + + -- Wrap the model with DataParallelTable, if using more than one GPU + if opt.nGPU > 1 then + local gpus = torch.range(1, opt.nGPU):totable() + local fastest, benchmark = cudnn.fastest, cudnn.benchmark + + local dpt = nn.DataParallelTable(1, true, true) + :add(model, gpus) + :threads(function() + local cudnn = require 'cudnn' + require 'stn' + cudnn.fastest, cudnn.benchmark = fastest, benchmark + end) + dpt.gradInput = nil + + model = dpt:type(opt.tensorType) + end + + -- Set model weights initialization + if opt.weightInit ~= 'default' then + -- Init weights of convolutional modules + initModelWeights(model, opt.weightInit) + end + + local criterion = nn.CrossEntropyCriterion():type(opt.tensorType) + model:cuda() + criterion:cuda() + return model, criterion +end + +function M.shareGradInput(model, opt) + local function sharingKey(m) + local key = torch.type(m) + if m.__shareGradInputKey then + key = key .. ':' .. 
m.__shareGradInputKey + end + return key + end + + -- Share gradInput for memory efficient backprop + local cache = {} + model:apply(function(m) + local moduleType = torch.type(m) + if torch.isTensor(m.gradInput) and moduleType ~= 'nn.ConcatTable' then + local key = sharingKey(m) + if cache[key] == nil then + cache[key] = torch[opt.tensorType:match('torch.(%a+)'):gsub('Tensor','Storage')](1) + end + m.gradInput = torch[opt.tensorType:match('torch.(%a+)')](cache[key], 1, 0) + end + end) + for i, m in ipairs(model:findModules('nn.ConcatTable')) do + if cache[i % 2] == nil then + cache[i % 2] = torch[opt.tensorType:match('torch.(%a+)'):gsub('Tensor','Storage')](1) + end + m.gradInput = torch[opt.tensorType:match('torch.(%a+)')](cache[i % 2], 1, 0) + end +end + +function initModelWeights(model, initializer) + local convModulesNames = {'cudnn.SpatialConvolution', 'nn.SpatialConvolution'} + for _,name in ipairs(convModulesNames) do + for _,v in pairs(model:findModules(name)) do + if initializer == 'kaiming' then + v:init('weight', nninit.kaiming, {gain = {'relu'}}) + v:noBias() + elseif initializer == 'glorot' then + v:init('weight', nninit.xavier, {gain = {'relu'}}) + elseif initializer == 'uniform' then + v:init('weight', nninit.uniform, -0.5, 0.5) + elseif initializer == 'conv_aware' then + v:init('weight', nninit.convolutionAware, {gain = {'relu'}}) + end + end + end +end + +return M \ No newline at end of file diff --git a/opts.lua b/opts.lua new file mode 100644 index 0000000..81d0b86 --- /dev/null +++ b/opts.lua @@ -0,0 +1,100 @@ +-- +-- Copyright (c) 2016, Facebook, Inc. +-- All rights reserved. +-- +-- This source code is licensed under the BSD-style license found in the +-- LICENSE file in the root directory of this source tree. An additional grant +-- of patent rights can be found in the PATENTS file in the same directory. 
+--
+local M = { }
+
+function M.parse(arg)
+   local cmd = torch.CmdLine()
+   cmd:text()
+   cmd:text('Torch-7 Training script')
+   cmd:text()
+   cmd:text('Options:')
+   ------------ General options --------------------
+   cmd:option('-data', '', 'Path to dataset')
+   cmd:option('-computeMeanStd', false, 'Compute mean and std')
+   cmd:option('-dataset', 'gtsrb', 'Options: gtsrb')
+   cmd:option('-manualSeed', 1, 'Manually set RNG seed')
+   cmd:option('-nGPU', 1, 'Number of GPUs to use by default')
+   cmd:option('-backend', 'cudnn', 'Options: cudnn | cunn')
+   cmd:option('-cudnn', 'default', 'Options: fastest | default | deterministic')
+   cmd:option('-gen', 'gen', 'Path to save generated files')
+   cmd:option('-precision', 'single', 'Options: single | double | half')
+   cmd:option('-showFullOutput', false, 'Whether to show the full training process (true) or just the final output (false)')
+   ------------- Data options ------------------------
+   cmd:option('-nThreads', 1, 'number of data loading threads')
+   ------------- Training options --------------------
+   cmd:option('-nEpochs', 0, 'Number of total epochs to run')
+   cmd:option('-epochNumber', 1, 'Manual epoch number (useful on restarts)')
+   cmd:option('-batchSize', 32, 'mini-batch size (1 = pure stochastic)')
+   cmd:option('-testOnly', false, 'Run on validation set only')
+   cmd:option('-tenCrop', false, 'Ten-crop testing')
+   ------------- Checkpointing options ---------------
+   cmd:option('-checkpoint', false, 'Save model after each epoch')
+   cmd:option('-save', 'checkpoints', 'Directory in which to save checkpoints')
+   cmd:option('-resume', 'none', 'Resume from the latest checkpoint in this directory')
+   ---------- Optimization options ----------------------
+   cmd:option('-optimizer', 'sgd', 'Options: sgd | adam | rmsprop | adagrad | lbfgs | nag')
+   cmd:option('-LR', 0.01, 'initial learning rate')
+   cmd:option('-momentum', 0.9, 'momentum')
+   cmd:option('-weightDecay', 1e-4, 'weight decay')
+   cmd:option('-nesterov', false, 'Nesterov')
+   cmd:option('-LRDecayStep', 10, 'number of steps to decay LR by 0.1')
+   ---------- Model options ----------------------------------
+   cmd:option('-netType', 'cnn3st', 'Options: cnn3st')
+   cmd:option('-retrain', 'none', 'Path to model to retrain with')
+   cmd:option('-optimState', 'none', 'Path to an optimState to reload from')
+   cmd:option('-weightInit', 'default', 'Options: default | kaiming | glorot | uniform | conv_aware')
+   ---------- Model cnn3st options ----------------------------------
+   cmd:option('-cnn', '200,250,350,400', 'Network parameters (conv1_out, conv2_out, conv3_out, fc1_out)')
+   cmd:option('-locnet1', '250,250,250', 'Localization network 1 parameters')
+   cmd:option('-locnet2', '150,200,300', 'Localization network 2 parameters')
+   cmd:option('-locnet3', '150,200,300', 'Localization network 3 parameters')
+   cmd:option('-globalNorm', false, 'Whether to perform global normalization')
+   cmd:option('-localNorm', false, 'Whether to perform local normalization')
+   cmd:option('-cNormConv', false, 'Whether to perform contrastive normalization in conv modules')
+   cmd:option('-dataAug', false, 'Whether to perform data augmentation on the training dataset')
+   ---------- Fine-tuning and memory options ----------------------------------
+   cmd:option('-shareGradInput', false, 'Share gradInput tensors to reduce memory usage')
+   cmd:option('-optnet', false, 'Use optnet to reduce memory usage')
+   cmd:option('-resetClassifier', false, 'Reset the fully connected layer for fine-tuning')
+   cmd:option('-nClasses', 0, 'Number of classes in the dataset')
+   
cmd:option('-baseInputSize', 48, 'Size of input images') + cmd:text() + + local opt = cmd:parse(arg or {}) + + if not paths.dirp(opt.save) and not paths.mkdir(opt.save) then + cmd:error('error: unable to create checkpoint directory: ' .. opt.save .. '\n') + end + + if opt.precision == nil or opt.precision == 'single' then + opt.tensorType = 'torch.CudaTensor' + elseif opt.precision == 'double' then + opt.tensorType = 'torch.CudaDoubleTensor' + elseif opt.precision == 'half' then + opt.tensorType = 'torch.CudaHalfTensor' + else + cmd:error('unknown precision: ' .. opt.precision) + end + + if opt.resetClassifier then + if opt.nClasses == 0 then + cmd:error('-nClasses required when resetClassifier is set') + end + end + if opt.shareGradInput and opt.optnet then + cmd:error('error: cannot use both -shareGradInput and -optnet') + end + + print('--- Options ---') + print(opt) + + return opt +end + +return M \ No newline at end of file diff --git a/run_model.lua b/run_model.lua new file mode 100644 index 0000000..b5f25be --- /dev/null +++ b/run_model.lua @@ -0,0 +1,46 @@ +require 'torch' +require 'image' +require 'cunn' +require 'cudnn' +require 'stn' +require 'nn' + +print("Loading network...") +local model_path = "pretrained/gtsrb_cnn3st_model.t7" +local mean_std_path = "pretrained/gtsrb_cnn3st_mean_std.t7" +local network = torch.load(model_path) +local mean_std = torch.load(mean_std_path) +print("--- Network ---") +print(network) +print("--- Mean/Std ---") +local mean, std = mean_std[1], mean_std[2] +print("Mean:"..mean, "Std:"..std) + + +print("Loading sample images...") +local sample_img1 = image.load("sample_images/img1.jpg") +sample_img1 = image.scale(sample_img1, 48, 48) +local sample_img2 = image.load("sample_images/img2.jpg") +sample_img2 = image.scale(sample_img2, 48, 48) +local samples_tensor = torch.Tensor(2,sample_img1:size(1), sample_img1:size(2), sample_img1:size(3)):fill(0) +samples_tensor[1]:copy(sample_img1) +samples_tensor[2]:copy(sample_img2) + +print("Applying global normalization to sample image") +samples_tensor:add(-mean) +samples_tensor:div(std) + +print("Applying local contrast normalization to sample image") +local norm_kernel = image.gaussian1D(7) +local local_normalizer = nn.SpatialContrastiveNormalization(3, norm_kernel) +samples_tensor:copy(local_normalizer:forward(samples_tensor)) + +print("Running the network...") +samples_tensor = samples_tensor:cuda() +local scores = network:forward(samples_tensor) +print("Scores...") +print(scores) +local _, prediction1 = scores[1]:max(1) +local _, prediction2 = scores[2]:max(1) +print("Prediction class sample img 1: " .. prediction1[1] - 1) +print("Prediction class sample img 2: " .. prediction2[1] - 1) \ No newline at end of file diff --git a/trainer.lua b/trainer.lua new file mode 100644 index 0000000..2c4508a --- /dev/null +++ b/trainer.lua @@ -0,0 +1,245 @@ +-- +-- Copyright (c) 2016, Facebook, Inc. +-- All rights reserved. +-- +-- This source code is licensed under the BSD-style license found in the +-- LICENSE file in the root directory of this source tree. An additional grant +-- of patent rights can be found in the PATENTS file in the same directory. 
+-- +-- The training loop and learning rate schedule +-- + +local optim = require 'optim' + +local M = {} +local Trainer = torch.class('Trainer', M) + +function Trainer:__init(model, criterion, opt, optimState) + self.model = model + self.criterion = criterion + self.optimState = optimState or { + learningRate = opt.LR, + learningRateDecay = 0.0, + momentum = opt.momentum, + nesterov = opt.nesterov, + dampening = 0.0, + weightDecay = opt.weightDecay, + } + self.opt = opt + self.params, self.gradParams = model:getParameters() + print("-- Model parameters --") + print(self.params:size(1)) +end + +function Trainer:train(epoch, dataloader) + -- Trains the model for a single epoch + self.optimState.learningRate = self:learningRate(epoch) + + local totalTime = 0 + local timer = torch.Timer() + local dataTimer = torch.Timer() + + local function feval() + return self.criterion.output, self.gradParams + end + + local absCriterion = nn.AbsCriterion():cuda() + + local trainSize = dataloader:size() + local top1Sum, top5Sum, lossSum, lossAbsSum = 0.0, 0.0, 0.0, 0.0 + local N = 0 + + print('=> Training epoch # ' .. epoch) + -- set the batch norm to training mode + self.model:training() + for n, sample in dataloader:run() do + local dataTime = dataTimer:time().real + + -- Copy input and target to the GPU + self:copyInputs(sample) + + local output = self.model:forward(self.input):float() + local batchSize = output:size(1) + local loss = self.criterion:forward(self.model.output, self.target) + + -- Average prediction for regression + local softmax = nn.SoftMax():cuda() + local outputSoft = softmax:forward(self.model.output) + local avgPred = outputSoft * torch.range(1,self.opt.nClasses):cuda() + local lossAbs = absCriterion:forward(avgPred, self.target) + + self.model:zeroGradParameters() + self.criterion:backward(self.model.output, self.target) + self.model:backward(self.input, self.criterion.gradInput) + + if self.opt.optimizer == 'sgd' then + optim.sgd(feval, self.params, self.optimState) + elseif self.opt.optimizer == 'adam' then + optim.adam(feval, self.params, self.optimState) + elseif self.opt.optimizer == 'rmsprop' then + optim.rmsprop(feval, self.params, self.optimState) + elseif self.opt.optimizer == 'adagrad' then + optim.adagrad(feval, self.params, self.optimState) + elseif self.opt.optimizer == 'lbfgs' then + optim.lbfgs(feval, self.params, self.optimState) + elseif self.opt.optimizer == 'nag' then + optim.nag(feval, self.params, self.optimState) + end + + local top1, top5 = self:computeScore(output, sample.target, 1) + top1Sum = top1Sum + top1*batchSize + top5Sum = top5Sum + top5*batchSize + lossSum = lossSum + loss*batchSize + lossAbsSum = lossAbsSum + lossAbs*batchSize + N = N + batchSize + + if self.opt.showFullOutput then + print((' | Epoch: [%d][%d/%d] Time %.3f Data %.3f Err %1.4f top1 %7.3f top5 %7.3f lossAbs %7.3f'):format( + epoch, n, trainSize, timer:time().real, dataTime, loss, top1, top5, lossAbs)) + end + + totalTime = totalTime + timer:time().real + + -- check that the storage didn't get changed due to an unfortunate getParameters call + assert(self.params:storage() == self.model:parameters()[1]:storage()) + + timer:reset() + dataTimer:reset() + end + + if self.opt.showFullOutput then + print((' * [Train] Finished epoch # %d top1: %7.3f top5: %7.3f loss: %7.3f lossAbs: %7.3f \n'):format( + epoch, top1Sum / N, top5Sum / N, lossSum / N, lossAbsSum / N)) + end + + print((' || Train total time: %2.5f'):format(totalTime)) + + return top1Sum / N, top5Sum / N, lossSum / N, lossAbsSum / 
N
+end
+
+function Trainer:test(epoch, dataloader)
+   -- Computes the top-1 and top-5 accuracy on the validation set
+
+   local totalTime = 0
+   local timer = torch.Timer()
+   local dataTimer = torch.Timer()
+   local size = dataloader:size()
+
+   local nCrops = self.opt.tenCrop and 10 or 1
+   local top1Sum, top5Sum, lossSum, lossAbsSum = 0.0, 0.0, 0.0, 0.0
+   local N = 0
+
+   local absCriterion = nn.AbsCriterion():cuda()
+
+   self.model:evaluate()
+
+   for n, sample in dataloader:run() do
+      local dataTime = dataTimer:time().real
+
+      -- Copy input and target to the GPU
+      self:copyInputs(sample)
+
+      local output = self.model:forward(self.input):float()
+      local batchSize = output:size(1) / nCrops
+      local loss = self.criterion:forward(self.model.output, self.target)
+
+      -- Average prediction for regression
+      local softmax = nn.SoftMax():cuda()
+      local outputSoft = softmax:forward(self.model.output)
+      local avgPred = outputSoft * torch.range(1,self.opt.nClasses):cuda()
+      local lossAbs = absCriterion:forward(avgPred, self.target)
+
+      local top1, top5 = self:computeScore(output, sample.target, nCrops)
+      top1Sum = top1Sum + top1*batchSize
+      top5Sum = top5Sum + top5*batchSize
+      lossSum = lossSum + loss*batchSize
+      lossAbsSum = lossAbsSum + lossAbs*batchSize
+      N = N + batchSize
+
+      if self.opt.showFullOutput then
+         print((' | Test: [%d][%d/%d] Time %.3f Data %.3f top1 %7.3f (%7.3f) top5 %7.3f (%7.3f) loss %7.3f (%7.3f) lossAbs %7.3f (%7.3f)'):format(
+            epoch, n, size, timer:time().real, dataTime, top1, top1Sum / N, top5, top5Sum / N, loss, lossSum / N, lossAbs, lossAbsSum / N))
+      end
+
+      totalTime = totalTime + timer:time().real
+
+      timer:reset()
+      dataTimer:reset()
+   end
+   self.model:training()
+
+   if self.opt.showFullOutput then
+      print((' * [Test] Finished epoch # %d top1: %7.3f top5: %7.3f loss: %7.3f lossAbs: %7.3f \n'):format(
+         epoch, top1Sum / N, top5Sum / N, lossSum / N, lossAbsSum / N))
+   end
+
+   print((' || Test total time: %2.5f'):format(totalTime))
+
+   return top1Sum / N, top5Sum / N, lossSum / N, lossAbsSum / N
+end
+
+function Trainer:computeScore(output, target, nCrops)
+   if nCrops > 1 then
+      -- Sum over crops
+      output = output:view(output:size(1) / nCrops, nCrops, output:size(2))
+         --:exp()
+         :sum(2):squeeze(2)
+   end
+
+   -- Computes the top-1 and top-5 accuracy (in %)
+   local batchSize = output:size(1)
+
+   local _ , predictions = output:float():topk(5, 2, true, true) -- descending
+
+   -- Find which predictions match the target
+   local correct = predictions:eq(target:long():view(batchSize, 1):expandAs(predictions))
+
+   -- Top-1 score
+   local top1 = correct:narrow(2, 1, 1):sum() / batchSize
+
+   -- Top-5 score, if there are at least 5 classes
+   local len = math.min(5, correct:size(2))
+   local top5 = correct:narrow(2, 1, len):sum() / batchSize
+
+   return top1 * 100, top5 * 100
+end
+
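+-- A worked example of the scoring above (hypothetical numbers): with a batch of
+-- two samples and 43 classes, topk(5, 2, true, true) returns each row's five
+-- highest-scoring class indices; a sample counts towards top-1 when its target
+-- appears in column 1 and towards top-5 when it appears in any of the five
+-- columns, so one top-1 hit and two top-5 hits yield top1 = 50, top5 = 100.
+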
+local function getCudaTensorType(tensorType)
+   if tensorType == 'torch.CudaHalfTensor' then
+      return cutorch.createCudaHostHalfTensor()
+   elseif tensorType == 'torch.CudaDoubleTensor' then
+      return cutorch.createCudaHostDoubleTensor()
+   else
+      return cutorch.createCudaHostTensor()
+   end
+end
+
+function Trainer:copyInputs(sample)
+   -- Copies the input to a CUDA tensor, if using 1 GPU, or to pinned memory,
+   -- if using DataParallelTable. The target is always copied to a CUDA tensor
+   self.input = self.input or (self.opt.nGPU == 1
+      and torch[self.opt.tensorType:match('torch.(%a+)')]()
+      or getCudaTensorType(self.opt.tensorType))
+   self.target = self.target or torch.CudaTensor()
+   self.input:resize(sample.input:size()):copy(sample.input)
+   self.target:resize(sample.target:size()):copy(sample.target)
+end
+
+function Trainer:learningRate(epoch)
+   -- Training schedule: the learning rate is kept constant; the commented-out
+   -- alternatives below implement a step decay and a 1/sqrt(epoch) decay for Adam
+--   local decay = math.floor((epoch - 1) / self.opt.LRDecayStep)
+--   local newLR = self.opt.LR * math.pow(0.1, decay)
+--   print('New LR: ' .. newLR)
+--   return newLR
+--   local decay = 0
+--   if self.opt.optimizer == 'adam' then
+--      decay = 1.0/math.sqrt(epoch)
+--      return self.opt.LR * decay
+--   else
+--      decay = math.floor((epoch - 1) / self.opt.LRDecayStep)
+--      return self.opt.LR * math.pow(0.1, decay)
+--   end
+   return self.opt.LR
+end
+
+return M.Trainer
\ No newline at end of file
diff --git a/utils.lua b/utils.lua
new file mode 100644
index 0000000..a7719c1
--- /dev/null
+++ b/utils.lua
@@ -0,0 +1,101 @@
+require 'cunn'
+local ffi=require 'ffi'
+local util = {}
+
+-- Functions from https://github.com/soumith/imagenet-multiGPU.torch/blob/master/util.lua
+
+function util.makeDataParallel(model, nGPU)
+
+   if nGPU > 1 then
+      print('converting module to nn.DataParallelTable')
+      assert(nGPU <= cutorch.getDeviceCount(), 'fewer GPUs available than the nGPU specified')
+      -- opt.backend holds a string, so compare against the string 'cudnn'
+      if opt.backend == 'cudnn' and opt.cudnnAutotune == 1 then
+         local gpu_table = torch.range(1, nGPU):totable()
+         local dpt = nn.DataParallelTable(1, true):add(model, gpu_table):threads(function()
+            require 'cudnn'
+            require 'stn'
+            cudnn.benchmark = true
+         end)
+
+         dpt.gradInput = nil
+         model = dpt:cuda()
+      else
+         local model_single = model
+         model = nn.DataParallelTable(1)
+         for i=1, nGPU do
+            cutorch.setDevice(i)
+            model:add(model_single:clone():cuda(), i)
+         end
+         cutorch.setDevice(opt.GPU)
+      end
+   else
+      if (opt.backend == 'cudnn' and opt.cudnnAutotune == 1) then
+         require 'cudnn'
+         cudnn.benchmark = true
+      end
+   end
+
+   return model
+end
+
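+-- Usage sketch (assumes a global `opt` with GPU/backend fields, as in the
+-- original imagenet-multiGPU code these helpers were taken from):
+--   model = util.makeDataParallel(model, opt.nGPU)
+--   ...
+--   util.saveDataParallel(paths.concat(opt.save, 'model.t7'), model)
+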
+local function cleanDPT(module)
+   -- This assumes this DPT was created by the function above: all the
+   -- module.modules are clones of the same network on different GPUs
+   -- hence we only need to keep one when saving the model to the disk.
+   local newDPT = nn.DataParallelTable(1)
+   cutorch.setDevice(opt.GPU)
+   newDPT:add(module:get(1), opt.GPU)
+   return newDPT
+end
+
+function util.saveDataParallel(filename, model)
+   if torch.type(model) == 'nn.DataParallelTable' then
+      torch.save(filename, cleanDPT(model))
+   elseif torch.type(model) == 'nn.Sequential' then
+      local temp_model = nn.Sequential()
+      for i, module in ipairs(model.modules) do
+         if torch.type(module) == 'nn.DataParallelTable' then
+            temp_model:add(cleanDPT(module))
+         else
+            temp_model:add(module)
+         end
+      end
+      torch.save(filename, temp_model)
+   else
+      error('This saving function only works with Sequential or DataParallelTable modules.')
+   end
+end
+
+function util.loadDataParallel(filename, nGPU)
+   if opt.backend == 'cudnn' then
+      require 'cudnn'
+   end
+   local model = torch.load(filename)
+   if torch.type(model) == 'nn.DataParallelTable' then
+      return util.makeDataParallel(model:get(1):float(), nGPU)
+   elseif torch.type(model) == 'nn.Sequential' then
+      for i,module in ipairs(model.modules) do
+         if torch.type(module) == 'nn.DataParallelTable' then
+            model.modules[i] = util.makeDataParallel(module:get(1):float(), nGPU)
+         end
+      end
+      return model
+   else
+      error('The loaded model is not a Sequential or DataParallelTable module.')
+   end
+end
+
+function util.countParameters(model)
+   local n_parameters = 0
+   for i=1, model:size() do
+      local params = model:get(i):parameters()
+      if params then
+         local weights = params[1]
+         local biases = params[2]
+         n_parameters = n_parameters + weights:nElement() + biases:nElement()
+      end
+   end
+   return n_parameters
+end
+
+return util
\ No newline at end of file