diff --git a/README.md b/README.md
index f24e388..6f01020 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,67 @@
-# tsr-torch
-Traffic Sign Recognition with Torch [WIP]
+# Traffic Sign Recognition
+This is the code for the paper
+
+**[Deep neural network for traffic sign recognition systems: An analysis of spatial transformers and stochastic optimisation methods](https://doi.org/10.1016/j.neunet.2018.01.005)**
+
+[Álvaro Arcos-García](https://scholar.google.com/citations?user=gjecl3cAAAAJ),
+[Juan Antonio Álvarez-García](https://scholar.google.com/citations?user=Qk79xk8AAAAJ),
+[Luis M. Soria-Morillo](https://scholar.google.com/citations?user=poBDpFkAAAAJ)
+
+
+The paper addresses the problem of traffic sign classification using a deep neural network that combines convolutional layers and spatial transformer networks. The model achieves an accuracy of 99.71% on the [German Traffic Sign Recognition Benchmark](http://benchmark.ini.rub.de/?section=gtsrb&subsection=results).
+
+We provide:
+- A [pretrained model](#pretrained-model).
+- Test code to [run the model on new images](#running-on-new-images).
+- Instructions for [training the model](#training).
+
+If you find this code useful in your research, please cite:
+
+```
+"Deep neural network for traffic sign recognition systems: An analysis of spatial transformers and stochastic optimisation methods."
+Álvaro Arcos-García, Juan A. Álvarez-García, Luis M. Soria-Morillo. Neural Networks 99 (2018) 158-165.
+```
+\[[link](https://doi.org/10.1016/j.neunet.2018.01.005)\]\[[bibtex](
+https://scholar.googleusercontent.com/citations?view_op=export_citations&user=gjecl3cAAAAJ&citsig=AMstHGQAAAAAW8ngijLlAnMmG2C2_Weu4sB-TWRmdT1P&hl=en)\]
+
+## Requirements
+This project is implemented in [Torch](http://torch.ch/), and depends on the following packages: [torch/torch7](https://github.com/torch/torch7), [torch/nn](https://github.com/torch/nn), [torch/image](https://github.com/torch/image), [qassemoquab/stnbhwd](https://github.com/qassemoquab/stnbhwd), [torch/cutorch](https://github.com/torch/cutorch), [torch/cunn](https://github.com/torch/cunn), [cuDNN bindings for Torch](https://github.com/soumith/cudnn.torch) and [nninit](https://github.com/Kaixhin/nninit).
+
+After installing Torch, you can install or update these dependencies by running:
+```bash
+luarocks install torch
+luarocks install nn
+luarocks install image
+luarocks install cutorch
+luarocks install cunn
+luarocks install cudnn
+luarocks install https://raw.githubusercontent.com/qassemoquab/stnbhwd/master/stnbhwd-scm-1.rockspec
+luarocks install https://raw.githubusercontent.com/Kaixhin/nninit/master/rocks/nninit-scm-1.rockspec
+```
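+
+As a quick sanity check (optional; this one-liner assumes a working CUDA setup), verify that the packages load:
+```bash
+th -e "require 'cutorch'; require 'cunn'; require 'cudnn'; require 'stn'; require 'nninit'; print('all dependencies OK')"
+```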
+
+## Pretrained model
+You can download a pretrained model from [Google Drive](https://drive.google.com/uc?export=download&id=1iqX1TZBxJSEmaEoReK2ZDko03JHqw1bp).
+Unzip the file `gtsrb_cnn3st_pretrained.zip` and move its contents to the `pretrained` folder of this project. It contains the pretrained model, which achieves an accuracy of 99.71% on the German Traffic Sign Recognition Benchmark (GTSRB), and a second file with the mean and standard deviation values computed during training.
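+
+If you want to load these files yourself, a minimal sketch (the filenames below are the ones `run_model.lua` expects):
+```lua
+local network = torch.load('pretrained/gtsrb_cnn3st_model.t7')
+local mean_std = torch.load('pretrained/gtsrb_cnn3st_mean_std.t7')
+local mean, std = mean_std[1], mean_std[2]
+```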
+
+## Running on new images
+To run the model on new images, use the script `run_model.lua`. It runs the pretrained model on the images in the `sample_images` folder and prints, for each image, the class scores and the predicted GTSRB class index (0-42):
+```bash
+th run_model.lua
+```
+
+## Training
+To train and experiment with new traffic sign classification models, follow these steps:
+1. Run the script `download_gtsrb_dataset.lua`, which creates a new folder called `GTSRB` containing two subfolders, `train` and `val`:
+ ```bash
+ th download_gtsrb_dataset.lua
+ ```
+2. Run the script `main.lua`, setting the training parameters described in `opts.lua`. For example, the following options train a model with a single spatial transformer network placed at the beginning of the main network:
+ ```bash
+ th main.lua -data GTSRB -save GTSRB/checkpoints -dataset gtsrb -nClasses 43 -optimizer adam -LR 1e-4 -momentum 0 -weightDecay 0 -batchSize 50 -nEpochs 15 -weightInit default -netType cnn3st -globalNorm -localNorm -cNormConv -locnet2 '' -locnet3 '' -showFullOutput
+ ```
+   The next example creates a model with three spatial transformer networks:
+ ```bash
+ th main.lua -data GTSRB -save GTSRB/checkpoints -dataset gtsrb -nClasses 43 -optimizer rmsprop -LR 1e-5 -momentum 0 -weightDecay 0 -batchSize 50 -nEpochs 15 -weightInit default -netType cnn3st -globalNorm -localNorm -cNormConv -showFullOutput
+ ```
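+3. Optionally, pass `-checkpoint` to save the model and optimizer state after every epoch. Each run writes its checkpoints to a timestamped subdirectory of the `-save` folder, and a later run can reload the latest saved model from that directory with `-resume` (the run directory name below is a placeholder):
+   ```bash
+   th main.lua -data GTSRB -save GTSRB/checkpoints -dataset gtsrb -nClasses 43 -netType cnn3st -checkpoint -resume GTSRB/checkpoints/<run-directory>
+   ```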
+## Acknowledgements
+The source code of this project is mainly based on [fb.resnet.torch](https://github.com/facebook/fb.resnet.torch) and [gtsrb.torch](https://github.com/Moodstocks/gtsrb.torch).
\ No newline at end of file
diff --git a/checkpoints.lua b/checkpoints.lua
new file mode 100644
index 0000000..c8d8dd4
--- /dev/null
+++ b/checkpoints.lua
@@ -0,0 +1,79 @@
+--
+-- Copyright (c) 2016, Facebook, Inc.
+-- All rights reserved.
+--
+-- This source code is licensed under the BSD-style license found in the
+-- LICENSE file in the root directory of this source tree. An additional grant
+-- of patent rights can be found in the PATENTS file in the same directory.
+--
+local checkpoint = {}
+
+local function deepCopy(tbl)
+ -- creates a copy of a network with new modules and the same tensors
+ local copy = {}
+ for k, v in pairs(tbl) do
+ if type(v) == 'table' then
+ copy[k] = deepCopy(v)
+ else
+ copy[k] = v
+ end
+ end
+ if torch.typename(tbl) then
+ torch.setmetatable(copy, torch.typename(tbl))
+ end
+ return copy
+end
+
+function checkpoint.latest(opt)
+ if opt.resume == 'none' then
+ return nil
+ end
+
+ local latestPath = paths.concat(opt.resume, 'latest.t7')
+ if not paths.filep(latestPath) then
+ return nil
+ end
+
+ print('=> Loading checkpoint ' .. latestPath)
+ local latest = torch.load(latestPath)
+ local optimState = torch.load(paths.concat(opt.resume, latest.optimFile))
+
+ return latest, optimState
+end
+
+function checkpoint.save(epoch, model, optimState, isBestModel, opt)
+ -- don't save the DataParallelTable for easier loading on other machines
+ if torch.type(model) == 'nn.DataParallelTable' then
+ model = model:get(1)
+ end
+
+ -- create a clean copy on the CPU without modifying the original network
+ model = deepCopy(model):float():clearState()
+
+ local modelFile = 'model_' .. epoch .. '.t7'
+ local optimFile = 'optimState_' .. epoch .. '.t7'
+
+   if opt.checkpoint then
+ torch.save(paths.concat(opt.save, modelFile), model)
+ torch.save(paths.concat(opt.save, optimFile), optimState)
+ torch.save(paths.concat(opt.save, 'latest.t7'), {
+ epoch = epoch,
+ modelFile = modelFile,
+ optimFile = optimFile,
+ })
+ end
+
+ if isBestModel then
+ local bestModelFile = 'model_best.t7'
+ local bestOptimFile = 'model_best_optimState.t7'
+ torch.save(paths.concat(opt.save, bestModelFile), model)
+ torch.save(paths.concat(opt.save, bestOptimFile), optimState)
+ torch.save(paths.concat(opt.save, 'latest.t7'), {
+ epoch = epoch,
+ modelFile = bestModelFile,
+ optimFile = bestOptimFile,
+ })
+ end
+end
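+
+-- Files written to opt.save by checkpoint.save:
+--   model_<epoch>.t7 / optimState_<epoch>.t7    (only when -checkpoint is set)
+--   model_best.t7 / model_best_optimState.t7    (whenever the model improves)
+--   latest.t7                                   (pointer read back by checkpoint.latest)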
+
+return checkpoint
\ No newline at end of file
diff --git a/dataloader.lua b/dataloader.lua
new file mode 100644
index 0000000..28ad269
--- /dev/null
+++ b/dataloader.lua
@@ -0,0 +1,127 @@
+--
+-- Copyright (c) 2016, Facebook, Inc.
+-- All rights reserved.
+--
+-- This source code is licensed under the BSD-style license found in the
+-- LICENSE file in the root directory of this source tree. An additional grant
+-- of patent rights can be found in the PATENTS file in the same directory.
+--
+-- Multi-threaded data loader
+--
+
+local datasets = require 'datasets/init'
+local Threads = require 'threads'
+Threads.serialization('threads.sharedserialize')
+
+local M = {}
+local DataLoader = torch.class('DataLoader', M)
+
+function DataLoader.create(opt)
+   -- The train and val loaders
+ local loaders = {}
+
+ for i, split in ipairs{'train', 'val'} do
+ local dataset = datasets.create(opt, split)
+ loaders[i] = M.DataLoader(dataset, opt, split)
+ end
+
+ return table.unpack(loaders)
+end
+
+function DataLoader:__init(dataset, opt, split)
+ local manualSeed = opt.manualSeed
+ local function init()
+ require('datasets/' .. opt.dataset)
+ end
+ local function main(idx)
+ if manualSeed ~= 0 then
+ torch.manualSeed(manualSeed + idx)
+ end
+ torch.setnumthreads(1)
+ _G.dataset = dataset
+ _G.preprocess = dataset:preprocess()
+ return dataset:size()
+ end
+
+ local threads, sizes = Threads(opt.nThreads, init, main)
+ self.nCrops = (split == 'val' and opt.tenCrop) and 10 or 1
+ self.threads = threads
+ self.__size = sizes[1][1]
+ self.batchSize = math.floor(opt.batchSize / self.nCrops)
+ local function getCPUType(tensorType)
+ if tensorType == 'torch.CudaHalfTensor' then
+ return 'HalfTensor'
+ elseif tensorType == 'torch.CudaDoubleTensor' then
+ return 'DoubleTensor'
+ else
+ return 'FloatTensor'
+ end
+ end
+ self.cpuType = getCPUType(opt.tensorType)
+end
+
+function DataLoader:size()
+ return math.ceil(self.__size / self.batchSize)
+end
+
+function DataLoader:run()
+ local threads = self.threads
+ local size, batchSize = self.__size, self.batchSize
+ local perm = torch.randperm(size)
+
+ local idx, sample = 1, nil
+ local function enqueue()
+ while idx <= size and threads:acceptsjob() do
+ local indices = perm:narrow(1, idx, math.min(batchSize, size - idx + 1))
+ threads:addjob(
+ function(indices, nCrops, cpuType)
+ local sz = indices:size(1)
+ local batch, imageSize
+ local target = torch.IntTensor(sz)
+ for i, idx in ipairs(indices:totable()) do
+ local sample = _G.dataset:get(idx)
+ local input = _G.preprocess(sample.input)
+ if not batch then
+ imageSize = input:size():totable()
+ if nCrops > 1 then table.remove(imageSize, 1) end
+ batch = torch[cpuType](sz, nCrops, table.unpack(imageSize))
+ end
+ batch[i]:copy(input)
+ target[i] = sample.target
+ end
+ collectgarbage()
+ return {
+ input = batch:view(sz * nCrops, table.unpack(imageSize)),
+ target = target,
+ }
+ end,
+ function(_sample_)
+ sample = _sample_
+ end,
+ indices,
+ self.nCrops,
+ self.cpuType
+ )
+ idx = idx + batchSize
+ end
+ end
+
+ local n = 0
+ local function loop()
+ enqueue()
+ if not threads:hasjob() then
+ return nil
+ end
+ threads:dojob()
+ if threads:haserror() then
+ threads:synchronize()
+ end
+ enqueue()
+ n = n + 1
+ return n, sample
+ end
+
+ return loop
+end
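+
+-- Minimal usage sketch (assumes an `opt` table as produced by opts.lua; with the
+-- GTSRB pipeline, inputs are 48x48 after preprocessing):
+--
+--   local DataLoader = require 'dataloader'
+--   local trainLoader, valLoader = DataLoader.create(opt)
+--   for n, sample in trainLoader:run() do
+--      -- sample.input:  (batchSize * nCrops) x 3 x 48 x 48 FloatTensor
+--      -- sample.target: IntTensor of 1-based class indices
+--   end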
+
+return M.DataLoader
\ No newline at end of file
diff --git a/datasets/gtsrb-gen.lua b/datasets/gtsrb-gen.lua
new file mode 100644
index 0000000..fd85f40
--- /dev/null
+++ b/datasets/gtsrb-gen.lua
@@ -0,0 +1,125 @@
+--
+-- Copyright (c) 2016, Facebook, Inc.
+-- All rights reserved.
+--
+-- This source code is licensed under the BSD-style license found in the
+-- LICENSE file in the root directory of this source tree. An additional grant
+-- of patent rights can be found in the PATENTS file in the same directory.
+--
+-- Script to compute the list of GTSRB filenames and classes
+--
+-- This generates a file gen/gtsrb.t7 which contains the list of all
+-- GTSRB training and validation images and their classes. The script also
+-- works for other datasets arranged with the same layout.
+--
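+-- Expected layout, as produced by download_gtsrb_dataset.lua (one folder per
+-- class, named 00000..00042):
+--   GTSRB/train/<class>/<image>.ppm
+--   GTSRB/val/<class>/<image>.ppm
+--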
+
+local sys = require 'sys'
+local ffi = require 'ffi'
+
+local M = {}
+
+local function findClasses(dir)
+ local dirs = paths.dir(dir)
+ table.sort(dirs)
+
+ local classList = {}
+ local classToIdx = {}
+ for _ ,class in ipairs(dirs) do
+ if not classToIdx[class] and class ~= '.' and class ~= '..' and class ~= '.DS_Store' then
+ table.insert(classList, class)
+ classToIdx[class] = #classList
+ end
+ end
+
+ assert(#classList == 43, 'expected 43 GTSRB classes')
+ return classList, classToIdx
+end
+
+local function findImages(dir, classToIdx)
+
+ ----------------------------------------------------------------------
+ -- Options for the GNU and BSD find command
+ local extensionList = {'jpg', 'png', 'jpeg', 'JPG', 'PNG', 'JPEG', 'ppm', 'PPM', 'bmp', 'BMP'}
+ local findOptions = ' -iname "*.' .. extensionList[1] .. '"'
+ for i=2,#extensionList do
+ findOptions = findOptions .. ' -o -iname "*.' .. extensionList[i] .. '"'
+ end
+
+ -- Find all the images using the find command
+ local f = io.popen('find -L ' .. dir .. findOptions)
+
+ local maxLength = -1
+ local imagePaths = {}
+ local imageClasses = {}
+
+ -- Generate a list of all the images and their class
+ while true do
+ local line = f:read('*line')
+ if not line then break end
+
+ local className = paths.basename(paths.dirname(line))
+ local filename = paths.basename(line)
+ local path = className .. '/' .. filename
+
+ local classId = classToIdx[className]
+ assert(classId, 'class not found: ' .. className)
+
+ table.insert(imagePaths, path)
+ table.insert(imageClasses, classId)
+
+ maxLength = math.max(maxLength, #path + 1)
+ end
+
+ f:close()
+
+ -- Convert the generated list to a tensor for faster loading
+ local nImages = #imagePaths
+ local imagePath = torch.CharTensor(nImages, maxLength):zero()
+ for i, path in ipairs(imagePaths) do
+ ffi.copy(imagePath[i]:data(), path)
+ end
+
+ local imageClass = torch.LongTensor(imageClasses)
+ return imagePath, imageClass
+end
+
+function M.exec(opt, cacheFile)
+
+ local trainDir = paths.concat(opt.data, 'train')
+ local valDir = paths.concat(opt.data, 'val')
+ assert(paths.dirp(trainDir), 'train directory not found: ' .. trainDir)
+ assert(paths.dirp(valDir), 'val directory not found: ' .. valDir)
+
+ print("=> Generating list of images")
+ local classList, classToIdx = findClasses(trainDir)
+
+ print(" | finding all validation images")
+ local valImagePath, valImageClass = findImages(valDir, classToIdx)
+
+ print(" | finding all training images")
+ local trainImagePath, trainImageClass = findImages(trainDir, classToIdx)
+
+ local info = {
+ basedir = opt.data,
+ classList = classList,
+ train = {
+ imagePath = trainImagePath,
+ imageClass = trainImageClass,
+ },
+ val = {
+ imagePath = valImagePath,
+ imageClass = valImageClass,
+ },
+ }
+
+ print(" | saving list of images to " .. cacheFile)
+ torch.save(cacheFile, info)
+ return info
+end
+
+return M
\ No newline at end of file
diff --git a/datasets/gtsrb.lua b/datasets/gtsrb.lua
new file mode 100644
index 0000000..9b8f1c8
--- /dev/null
+++ b/datasets/gtsrb.lua
@@ -0,0 +1,81 @@
+--
+-- GTSRB dataset loader
+--
+
+local image = require 'image'
+local paths = require 'paths'
+local t = require 'datasets/transforms'
+local ffi = require 'ffi'
+
+local M = {}
+local GTSRBDataset = torch.class('GTSRBDataset', M)
+
+function GTSRBDataset:__init(imageInfo, opt, split)
+ -- imageInfo: result from gtsrb-gen.lua
+ -- opt: command-line arguments
+ -- split: "train" or "val"
+ self.imageInfo = imageInfo[split]
+ self.opt = opt
+ self.split = split
+ self.dir = paths.concat(opt.data, split)
+ assert(paths.dirp(self.dir), 'directory does not exist: ' .. self.dir)
+end
+
+function GTSRBDataset:get(i)
+   local path = ffi.string(self.imageInfo.imagePath[i]:data())
+
+   local img = self:_loadImage(paths.concat(self.dir, path))
+   local class = self.imageInfo.imageClass[i]
+
+   return {
+      input = img,
+      target = class,
+   }
+end
+
+function GTSRBDataset:_loadImage(path)
+ local ok, input = pcall(function()
+ return image.load(path, 3, 'float')
+ end)
+
+ -- Sometimes image.load fails because the file extension does not match the
+ -- image format. In that case, use image.decompress on a ByteTensor.
+ if not ok then
+ local f = io.open(path, 'r')
+ assert(f, 'Error reading: ' .. tostring(path))
+ local data = f:read('*a')
+ f:close()
+
+ local b = torch.ByteTensor(string.len(data))
+ ffi.copy(b:data(), data, b:size(1))
+
+ input = image.decompress(b, 3, 'float')
+ end
+
+ return input
+end
+
+function GTSRBDataset:size()
+ return self.imageInfo.imageClass:size(1)
+end
+
+-- Computed from GTSRB training images
+local meanstd = {
+ mean = { 0.341, 0.312, 0.321 },
+ std = { 0.275, 0.264, 0.270 },
+}
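+-- Note: run_model.lua normalizes with the scalar mean/std stored alongside the
+-- pretrained model, rather than these per-channel values.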
+
+function GTSRBDataset:preprocess()
+ local transforms = {t.Resize(48, 48)}
+ if self.opt.globalNorm then
+ table.insert(transforms, t.ColorNormalize(meanstd))
+ end
+ if self.opt.localNorm then
+ table.insert(transforms, t.LocalContrastNorm(7))
+ end
+
+   return t.Compose(transforms)
+end
+
+return M.GTSRBDataset
\ No newline at end of file
diff --git a/datasets/init.lua b/datasets/init.lua
new file mode 100644
index 0000000..8bdd9f2
--- /dev/null
+++ b/datasets/init.lua
@@ -0,0 +1,34 @@
+--
+-- Copyright (c) 2016, Facebook, Inc.
+-- All rights reserved.
+--
+-- This source code is licensed under the BSD-style license found in the
+-- LICENSE file in the root directory of this source tree. An additional grant
+-- of patent rights can be found in the PATENTS file in the same directory.
+--
+
+local M = {}
+
+local function isvalid(opt, cachePath)
+ local imageInfo = torch.load(cachePath)
+ if imageInfo.basedir and imageInfo.basedir ~= opt.data then
+ return false
+ end
+ return true
+end
+
+function M.create(opt, split)
+ local cachePath = paths.concat(opt.gen, opt.dataset .. '.t7')
+ if not paths.filep(cachePath) or not isvalid(opt, cachePath) then
+      paths.mkdir(opt.gen)
+
+ local script = paths.dofile(opt.dataset .. '-gen.lua')
+ script.exec(opt, cachePath)
+ end
+ local imageInfo = torch.load(cachePath)
+
+ local Dataset = require('datasets/' .. opt.dataset)
+ return Dataset(imageInfo, opt, split)
+end
+
+return M
\ No newline at end of file
diff --git a/datasets/transforms.lua b/datasets/transforms.lua
new file mode 100644
index 0000000..ce1fad1
--- /dev/null
+++ b/datasets/transforms.lua
@@ -0,0 +1,318 @@
+--
+-- Copyright (c) 2016, Facebook, Inc.
+-- All rights reserved.
+--
+-- This source code is licensed under the BSD-style license found in the
+-- LICENSE file in the root directory of this source tree. An additional grant
+-- of patent rights can be found in the PATENTS file in the same directory.
+--
+-- Image transforms for data augmentation and input normalization
+--
+
+require 'image'
+require 'nn'
+torch.setdefaulttensortype('torch.FloatTensor')
+
+local M = {}
+
+function M.Compose(transforms)
+ return function(input)
+ for _, transform in ipairs(transforms) do
+ input = transform(input)
+ end
+ return input
+ end
+end
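+-- e.g. M.Compose{M.Resize(48, 48), M.ColorNormalize(meanstd)} applies both in
+-- order, with meanstd as defined in datasets/gtsrb.lua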
+
+function M.ColorNormalize(meanstd)
+ return function(img)
+ img = img:clone()
+ for i=1,3 do
+ img[i]:add(-meanstd.mean[i])
+ img[i]:div(meanstd.std[i])
+ end
+ return img
+ end
+end
+
+-- Scales the image to width x height
+function M.Resize(width, height, interpolation)
+ interpolation = interpolation or 'bicubic'
+ return function(input)
+ return image.scale(input, width, height, interpolation)
+ end
+end
+
+-- Scales the smaller edge to size
+function M.ScaleEdge(size, interpolation)
+ interpolation = interpolation or 'bicubic'
+ return function(input)
+ local w, h = input:size(3), input:size(2)
+ if (w <= h and w == size) or (h <= w and h == size) then
+ return input
+ end
+ if w < h then
+ return image.scale(input, size, h/w * size, interpolation)
+ else
+ return image.scale(input, w/h * size, size, interpolation)
+ end
+ end
+end
+
+-- Crop to centered rectangle
+function M.CenterCrop(size)
+ return function(input)
+ local w1 = math.ceil((input:size(3) - size)/2)
+ local h1 = math.ceil((input:size(2) - size)/2)
+ return image.crop(input, w1, h1, w1 + size, h1 + size) -- center patch
+ end
+end
+
+-- Random crop from a larger image with optional zero padding
+function M.RandomCrop(size, padding)
+ padding = padding or 0
+
+ return function(input)
+ if padding > 0 then
+ local temp = input.new(3, input:size(2) + 2*padding, input:size(3) + 2*padding)
+ temp:zero()
+ :narrow(2, padding+1, input:size(2))
+ :narrow(3, padding+1, input:size(3))
+ :copy(input)
+ input = temp
+ end
+
+ local w, h = input:size(3), input:size(2)
+ if w == size and h == size then
+ return input
+ end
+
+ local x1, y1 = torch.random(0, w - size), torch.random(0, h - size)
+ local out = image.crop(input, x1, y1, x1 + size, y1 + size)
+ assert(out:size(2) == size and out:size(3) == size, 'wrong crop size')
+ return out
+ end
+end
+
+-- Four corner patches and center crop from image and its horizontal reflection
+function M.TenCrop(size)
+ local centerCrop = M.CenterCrop(size)
+
+ return function(input)
+ local w, h = input:size(3), input:size(2)
+
+ local output = {}
+ for _, img in ipairs{input, image.hflip(input)} do
+ table.insert(output, centerCrop(img))
+ table.insert(output, image.crop(img, 0, 0, size, size))
+ table.insert(output, image.crop(img, w-size, 0, w, size))
+ table.insert(output, image.crop(img, 0, h-size, size, h))
+ table.insert(output, image.crop(img, w-size, h-size, w, h))
+ end
+
+ -- View as mini-batch
+ for i, img in ipairs(output) do
+ output[i] = img:view(1, img:size(1), img:size(2), img:size(3))
+ end
+
+ return input.cat(output, 1)
+ end
+end
+
+-- Resized with shorter side randomly sampled from [minSize, maxSize] (ResNet-style)
+function M.RandomScale(minSize, maxSize)
+ return function(input)
+ local w, h = input:size(3), input:size(2)
+
+ local targetSz = torch.random(minSize, maxSize)
+ local targetW, targetH = targetSz, targetSz
+ if w < h then
+ targetH = torch.round(h / w * targetW)
+ else
+ targetW = torch.round(w / h * targetH)
+ end
+
+ return image.scale(input, targetW, targetH, 'bicubic')
+ end
+end
+
+-- Random crop with size 8%-100% and aspect ratio 3/4 - 4/3 (Inception-style)
+function M.RandomSizedCrop(size)
+   local scale = M.ScaleEdge(size)
+ local crop = M.CenterCrop(size)
+
+ return function(input)
+ local attempt = 0
+ repeat
+ local area = input:size(2) * input:size(3)
+ local targetArea = torch.uniform(0.08, 1.0) * area
+
+ local aspectRatio = torch.uniform(3/4, 4/3)
+ local w = torch.round(math.sqrt(targetArea * aspectRatio))
+ local h = torch.round(math.sqrt(targetArea / aspectRatio))
+
+ if torch.uniform() < 0.5 then
+ w, h = h, w
+ end
+
+ if h <= input:size(2) and w <= input:size(3) then
+ local y1 = torch.random(0, input:size(2) - h)
+ local x1 = torch.random(0, input:size(3) - w)
+
+ local out = image.crop(input, x1, y1, x1 + w, y1 + h)
+ assert(out:size(2) == h and out:size(3) == w, 'wrong crop size')
+
+ return image.scale(out, size, size, 'bicubic')
+ end
+ attempt = attempt + 1
+ until attempt >= 10
+
+ -- fallback
+ return crop(scale(input))
+ end
+end
+
+function M.HorizontalFlip(prob)
+ return function(input)
+ if torch.uniform() < prob then
+ input = image.hflip(input)
+ end
+ return input
+ end
+end
+
+function M.Rotation(deg)
+ return function(input)
+ if deg ~= 0 then
+ input = image.rotate(input, (torch.uniform() - 0.5) * deg * math.pi / 180, 'bilinear')
+ end
+ return input
+ end
+end
+
+function M.Translate(tx, ty)
+ return function(input)
+ input = image.translate(input, (torch.uniform() - 0.5) * tx, (torch.uniform() - 0.5) * ty)
+ return input
+ end
+end
+
+-- Lighting noise (AlexNet-style PCA-based noise)
+function M.Lighting(alphastd, eigval, eigvec)
+ return function(input)
+ if alphastd == 0 then
+ return input
+ end
+
+ local alpha = torch.Tensor(3):normal(0, alphastd)
+ local rgb = eigvec:clone()
+ :cmul(alpha:view(1, 3):expand(3, 3))
+ :cmul(eigval:view(1, 3):expand(3, 3))
+ :sum(2)
+ :squeeze()
+
+ input = input:clone()
+ for i=1,3 do
+ input[i]:add(rgb[i])
+ end
+ return input
+ end
+end
+
+local function blend(img1, img2, alpha)
+ return img1:mul(alpha):add(1 - alpha, img2)
+end
+
+local function grayscale(dst, img)
+ dst:resizeAs(img)
+ dst[1]:zero()
+ dst[1]:add(0.299, img[1]):add(0.587, img[2]):add(0.114, img[3])
+ dst[2]:copy(dst[1])
+ dst[3]:copy(dst[1])
+ return dst
+end
+
+function M.Saturation(var)
+ local gs
+
+ return function(input)
+ gs = gs or input.new()
+ grayscale(gs, input)
+
+ local alpha = 1.0 + torch.uniform(-var, var)
+ blend(input, gs, alpha)
+ return input
+ end
+end
+
+function M.Brightness(var)
+ local gs
+
+ return function(input)
+ gs = gs or input.new()
+ gs:resizeAs(input):zero()
+
+ local alpha = 1.0 + torch.uniform(-var, var)
+ blend(input, gs, alpha)
+ return input
+ end
+end
+
+function M.Contrast(var)
+ local gs
+
+ return function(input)
+ gs = gs or input.new()
+ grayscale(gs, input)
+ gs:fill(gs[1]:mean())
+
+ local alpha = 1.0 + torch.uniform(-var, var)
+ blend(input, gs, alpha)
+ return input
+ end
+end
+
+function M.RandomOrder(ts)
+ return function(input)
+ local img = input.img or input
+ local order = torch.randperm(#ts)
+ for i=1,#ts do
+ img = ts[order[i]](img)
+ end
+ return img
+ end
+end
+
+function M.ColorJitter(opt)
+ local brightness = opt.brightness or 0
+ local contrast = opt.contrast or 0
+ local saturation = opt.saturation or 0
+
+ local ts = {}
+ if brightness ~= 0 then
+ table.insert(ts, M.Brightness(brightness))
+ end
+ if contrast ~= 0 then
+ table.insert(ts, M.Contrast(contrast))
+ end
+ if saturation ~= 0 then
+ table.insert(ts, M.Saturation(saturation))
+ end
+
+ if #ts == 0 then
+ return function(input) return input end
+ end
+
+ return M.RandomOrder(ts)
+end
+
+function M.LocalContrastNorm(gaussian_kernel_size)
+ return function(input)
+ local kernel = image.gaussian(gaussian_kernel_size)
+ local normalizer = nn.SpatialContrastiveNormalization(input:size(1), kernel)
+ input:copy(normalizer:forward(input))
+ return input
+ end
+end
+
+return M
\ No newline at end of file
diff --git a/download_gtsrb_dataset.lua b/download_gtsrb_dataset.lua
new file mode 100755
index 0000000..fbe68b9
--- /dev/null
+++ b/download_gtsrb_dataset.lua
@@ -0,0 +1,63 @@
+local pl = (require 'pl.import_into')()
+
+local dataset = {}
+
+dataset.path_remote_train = "http://benchmark.ini.rub.de/Dataset/GTSRB_Final_Training_Images.zip"
+dataset.path_remote_test = "http://benchmark.ini.rub.de/Dataset/GTSRB_Final_Test_Images.zip"
+dataset.path_remote_test_gt = "http://benchmark.ini.rub.de/Dataset/GTSRB_Final_Test_GT.zip"
+
+dataset.train_folder = "train"
+dataset.val_folder = "val"
+
+function dataset.download_dataset()
+ if not pl.path.isdir('GTSRB') then
+ local zip_train = paths.basename(dataset.path_remote_train)
+ local zip_test = paths.basename(dataset.path_remote_test)
+ local zip_test_gt = paths.basename(dataset.path_remote_test_gt)
+
+ print('Downloading dataset...')
+ os.execute('wget ' .. dataset.path_remote_train .. '; ' ..
+ 'unzip ' .. zip_train .. '; '..
+ 'rm ' .. zip_train .. '; ' ..
+ 'mv GTSRB/Final_Training/Images/ GTSRB/train;' ..
+ 'rm -r GTSRB/Final_Training')
+ os.execute('wget ' .. dataset.path_remote_test .. '; ' ..
+ 'unzip ' .. zip_test .. '; '..
+ 'rm ' .. zip_test .. '; ' ..
+ 'mkdir GTSRB/val; ' ..
+ [[find GTSRB/Final_Test/Images/ -maxdepth 1 -name '*.ppm' -exec sh -c 'mv "$@" "$0"' GTSRB/val/ {} +;]] ..
+ 'rm -r GTSRB/Final_Test')
+ os.execute('wget ' .. dataset.path_remote_test_gt .. '; ' ..
+ 'unzip ' .. zip_test_gt .. '; '..
+ 'rm ' .. zip_test_gt .. '; '..
+ 'mv GT-final_test.csv GTSRB/GT-final_test.csv')
+ end
+end
+
+function dataset.move_val_images()
+ print("Moving validation images to class folders")
+ local val_dir = pl.path.join("GTSRB", dataset.val_folder)
+ local csv_file_path = pl.path.join("GTSRB", "GT-final_test.csv")
+ local csv_data = pl.data.read(csv_file_path)
+ local filename_index = csv_data.fieldnames:index("Filename")
+ local class_id_index = csv_data.fieldnames:index("ClassId")
+
+ for _, image_metadata in ipairs(csv_data) do
+ local image_name = image_metadata[filename_index]
+ local image_path = pl.path.join(val_dir, image_name)
+ local image_label = image_metadata[class_id_index]
+ local class_folder_name = string.format("%05d", image_label)
+ local class_folder_path = pl.path.join(val_dir, class_folder_name)
+ if not pl.path.exists(class_folder_path) then
+ pl.path.mkdir(class_folder_path)
+ end
+ pl.file.move(image_path, pl.path.join(class_folder_path, image_name))
+ end
+end
+
+dataset.download_dataset()
+dataset.move_val_images()
diff --git a/main.lua b/main.lua
new file mode 100644
index 0000000..0bfae1a
--- /dev/null
+++ b/main.lua
@@ -0,0 +1,110 @@
+--
+-- Copyright (c) 2016, Facebook, Inc.
+-- All rights reserved.
+--
+-- This source code is licensed under the BSD-style license found in the
+-- LICENSE file in the root directory of this source tree. An additional grant
+-- of patent rights can be found in the PATENTS file in the same directory.
+--
+
+require 'torch'
+require 'paths'
+require 'optim'
+require 'nn'
+
+local DataLoader = require 'dataloader'
+local models = require 'models/init'
+local Trainer = require 'trainer'
+local opts = require 'opts'
+local checkpoints = require 'checkpoints'
+
+-- we don't change this to the 'correct' type (e.g. HalfTensor), because math
+-- isn't supported on that type. Type conversion later will handle having
+-- the correct type.
+
+local opt = opts.parse(arg)
+torch.setdefaulttensortype('torch.FloatTensor')
+torch.setnumthreads(opt.nThreads)
+print('Manual seed: ' .. opt.manualSeed)
+
+torch.manualSeed(opt.manualSeed)
+cutorch.manualSeedAll(opt.manualSeed)
+math.randomseed(opt.manualSeed)
+
+-- Create unique checkpoint dir
+local dir_name = 'net-' .. opt.netType .. '__cnn-' .. opt.cnn ..
+ '__locnet1-' .. opt.locnet1 .. '__locnet2-' .. opt.locnet2 .. '__locnet3-' .. opt.locnet3 ..
+ '__optimizer-'.. opt.optimizer .. '__weightinit-' .. opt.weightInit ..
+ os.date("__%Y_%m_%d_%X")
+opt.save = paths.concat(opt.save, dir_name)
+if paths.dir(opt.save) == nil then
+ paths.mkdir(opt.save)
+end
+
+-- Load previous checkpoint, if it exists
+local checkpoint, optimState = checkpoints.latest(opt)
+
+-- Create model
+local model, criterion = models.setup(opt, checkpoint)
+
+print("-- Model architecture --")
+print(model)
+
+-- Data loading
+local trainLoader, valLoader = DataLoader.create(opt)
+
+-- The trainer handles the training loop and evaluation on validation set
+local trainer = Trainer(model, criterion, opt, optimState)
+
+-- Logger
+local logger = optim.Logger(paths.concat(opt.save, 'history.log'))
+logger:setNames{'Train Loss', 'Train LossAbs','Train Acc',
+ 'Test Loss', 'Test LossAbs', 'Test Acc' }
+logger:style{'+-','+-','+-','+-','+-','+-'}
+logger:display(false)
+
+if opt.testOnly then
+   local top1, top5 = trainer:test(0, valLoader)
+   print(string.format(' * Results top1: %6.3f top5: %6.3f', top1, top5))
+ return
+end
+
+local startEpoch = opt.epochNumber --checkpoint and checkpoint.epoch + 1 or opt.epochNumber
+local bestTop1 = 0
+local bestTop5 = 0
+local bestLoss = math.huge
+local bestLossAbs = math.huge
+local bestEpoch = math.huge
+for epoch = startEpoch, opt.nEpochs do
+ -- Train for a single epoch
+ local trainTop1, trainTop5, trainLoss, trainLossAbs = trainer:train(epoch, trainLoader)
+
+ -- Run model on validation set
+ local testTop1, testTop5, testLoss, testLossAbs = trainer:test(epoch, valLoader)
+
+ -- Update logger
+ logger:add{trainLoss, trainLossAbs, trainTop1,
+ testLoss, testLossAbs, testTop1 }
+-- logger:plot()
+
+ local bestModel = false
+ if testTop1 > bestTop1 then
+ bestModel = true
+ bestTop1 = testTop1
+ bestTop5 = testTop5
+ bestLoss = testLoss
+ bestLossAbs = testLossAbs
+ bestEpoch = epoch
+ if opt.showFullOutput then
+ print(string.format(' * Best Model -- epoch:%i top1: %6.3f top5: %6.3f loss: %6.3f, lossabs: %6.3f',
+ bestEpoch, bestTop1, bestTop5, bestLoss, bestLossAbs))
+ end
+ end
+
+ checkpoints.save(epoch, model, trainer.optimState, bestModel, opt)
+end
+
+--logger:plot()
+
+print(string.format(' * Finished Best Model -- epoch:%i top1: %6.3f top5: %6.3f loss: %6.3f, lossabs: %6.3f',
+ bestEpoch, bestTop1, bestTop5, bestLoss, bestLossAbs))
\ No newline at end of file
diff --git a/models/cnn3st.lua b/models/cnn3st.lua
new file mode 100644
index 0000000..a44d007
--- /dev/null
+++ b/models/cnn3st.lua
@@ -0,0 +1,276 @@
+local torch = require 'torch'
+local nn = require 'nn'
+require 'cunn'
+local cudnn = require 'cudnn'
+require 'stn'
+local image = require 'image'
+
+local layers = {}
+
+-- NVIDIA CuDNN
+layers.convolution = cudnn.SpatialConvolution
+layers.maxPooling = cudnn.SpatialMaxPooling
+layers.nonLinearity = cudnn.ReLU
+layers.batchNorm = cudnn.SpatialBatchNormalization
+
+-- CPU
+--layers.convolution = nn.SpatialConvolution
+--layers.maxPooling = nn.SpatialMaxPooling
+--layers.nonLinearity = nn.ReLU
+--layers.batchNorm = nn.SpatialBatchNormalization
+
+-- Returns the number of output elements for a table of convolution layers and the new height and width of the image
+local function convs_noutput(convs, input_size)
+ input_size = input_size or baseInputSize
+ -- Get the number of channels for conv that are multiscale or not
+ local nbr_input_channels = convs[1]:get(1).nInputPlane or
+ convs[1]:get(1):get(1).nInputPlane
+ local output = torch.CudaTensor(1, nbr_input_channels, input_size, input_size)
+ for _, conv in ipairs(convs) do
+ conv:cuda()
+ output = conv:forward(output)
+ end
+ return output:nElement(), output:size(3), output:size(4)
+end
+
+-- Creates a conv module with the specified number of channels in input and output
+-- If multiscale is true, the total number of output channels will be:
+-- nbr_input_channels + nbr_output_channels
+-- Using cnorm adds the spatial contrastive normalization module
+-- The filter size for the convolution can be specified (default 5)
+-- The stride of the convolutions is fixed at 1
+local function new_conv(nbr_input_channels,nbr_output_channels, multiscale, cnorm, filter_size)
+ multiscale = multiscale or false
+ cnorm = cnorm or false
+ filter_size = filter_size or 5
+ local padding_size = 2
+ local pooling_size = 2
+ local norm_kernel = image.gaussian(7)
+
+ local conv
+
+ local first = nn.Sequential()
+ first:add(layers.convolution(nbr_input_channels,
+ nbr_output_channels,
+ filter_size, filter_size,
+ 1,1,
+ padding_size, padding_size))
+ first:add(layers.nonLinearity(true))
+ first:add(layers.maxPooling(pooling_size, pooling_size,
+ pooling_size, pooling_size))
+
+ if cnorm then
+ first:add(nn.SpatialContrastiveNormalization(nbr_output_channels, norm_kernel))
+ end
+
+ if multiscale then
+ conv = nn.Sequential()
+ local second = layers.maxPooling(pooling_size, pooling_size,
+ pooling_size, pooling_size)
+
+ local parallel = nn.ConcatTable()
+ parallel:add(first)
+ parallel:add(second)
+ conv:add(parallel)
+ conv:add(nn.JoinTable(1,3))
+ else
+ conv = first
+ end
+
+ return conv
+end
+
+-- Creates a fully connected layer with the specified size.
+local function new_fc(nbr_input, nbr_output)
+ local fc = nn.Sequential()
+ fc:add(nn.View(nbr_input))
+ fc:add(nn.Linear(nbr_input, nbr_output))
+ fc:add(layers.nonLinearity(true))
+ return fc
+end
+
+-- Creates a classifier with the specified size.
+local function new_classifier(nbr_input, nbr_output)
+ local classifier = nn.Sequential()
+ classifier:add(nn.View(nbr_input))
+ classifier:add(nn.Linear(nbr_input, nbr_output))
+ return classifier
+end
+
+-- Creates a spatial transformer module
+-- locnet is a comma-separated string with the localization network parameters
+-- rot, sca, tra can be used to force specific transformations
+-- input_size is the height (= width) of the input
+-- input_channels is the number of channels in the input
+-- see (1) below: the bilinear sampler is deliberately kept on the CPU
+local function new_spatial_tranformer(locnet, rot, sca, tra, input_size, input_channels)
+ local nbr_elements = {}
+ for c in string.gmatch(locnet, "%d+") do
+ nbr_elements[#nbr_elements + 1] = tonumber(c)
+ end
+
+ -- Get number of params and initial state
+ local init_bias = {}
+ local nbr_params = 0
+ if rot then
+ nbr_params = nbr_params + 1
+ init_bias[nbr_params] = 0
+ end
+ if sca then
+ nbr_params = nbr_params + 1
+ init_bias[nbr_params] = 1
+ end
+ if tra then
+ nbr_params = nbr_params + 2
+ init_bias[nbr_params-1] = 0
+ init_bias[nbr_params] = 0
+ end
+ if nbr_params == 0 then
+ -- fully parametrized case
+ nbr_params = 6
+ init_bias = {1,0,0,0,1,0}
+ end
+
+ local st = nn.Sequential()
+
+   -- Create a localization network, similar to the main CNN but with downsampled inputs
+ local localization_network = nn.Sequential()
+ local conv1 = new_conv(input_channels, nbr_elements[1], false, false)
+ local conv2 = new_conv(nbr_elements[1], nbr_elements[2], false, false)
+ local conv_output_size = convs_noutput({conv1, conv2}, input_size/2)
+ local fc = new_fc(conv_output_size, nbr_elements[3])
+ local classifier = new_classifier(nbr_elements[3], nbr_params)
+ -- Initialize the localization network (see paper, A.3 section)
+ classifier:get(2).weight:zero()
+ classifier:get(2).bias = torch.Tensor(init_bias)
+
+ localization_network:add(layers.maxPooling(2,2,2,2))
+ localization_network:add(conv1)
+ localization_network:add(conv2)
+ localization_network:add(fc)
+ localization_network:add(classifier)
+
+ -- Create the actual module structure
+ local ct = nn.ConcatTable()
+
+ local branch1 = nn.Sequential()
+ branch1:add(nn.Transpose({3,4},{2,4}))
+ branch1:add(nn.Copy('torch.CudaTensor', 'torch.FloatTensor', true, true))
+
+ local branch2 = nn.Sequential()
+ branch2:add(localization_network)
+ branch2:add(nn.AffineTransformMatrixGenerator(rot, sca, tra))
+ branch2:add(nn.AffineGridGeneratorBHWD(input_size, input_size))
+ branch2:add(nn.Copy('torch.CudaTensor', 'torch.FloatTensor', true, true))
+
+ ct:add(branch1)
+ ct:add(branch2)
+
+ st:add(ct)
+ local sampler = nn.BilinearSamplerBHWD()
+ -- (1)
+   -- The sampler leads to non-reproducible results on the GPU,
+   -- so we always keep it on the CPU.
+   -- This does not lead to a slowdown of the training.
+ sampler:type('torch.FloatTensor')
+ -- make sure it will not go back to the GPU when we call
+ -- ":cuda()" on the network later
+   sampler.type = function(self, type)
+      return self
+   end
+ st:add(sampler)
+ st:add(nn.Copy('torch.FloatTensor','torch.CudaTensor', true, true))
+ st:add(nn.Transpose({2,4},{3,4}))
+
+ return st
+end
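+
+-- Data flow of the module built above, for a BDHW input:
+--   branch1 transposes the input to BHWD, the layout the stn sampler expects;
+--   branch2 runs the localization network and generates an affine sampling grid;
+--   the bilinear sampler (kept on the CPU, see (1)) samples branch1 at the grid,
+--   and the result is copied back to the GPU and transposed to BDHW.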
+
+local function createModel(opt)
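+   -- note: nInputChannels and baseInputSize are intentionally left global, since
+   -- convs_noutput above uses baseInputSize as its default input size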
+ nInputChannels = 3
+ baseInputSize = opt.baseInputSize or 48
+ local cnorm = opt.cNormConv or false
+ local nbr_elements = {}
+ for c in string.gmatch(opt.cnn, "%d+") do
+ nbr_elements[#nbr_elements + 1] = tonumber(c)
+ end
+ assert(#nbr_elements == 4,
+ 'opt.cnn should contain 4 comma separated values, got '..#nbr_elements)
+
+ local conv1 = new_conv(nInputChannels, nbr_elements[1], false, cnorm, 7)
+ local conv2 = new_conv(nbr_elements[1], nbr_elements[2], false, cnorm, 4)
+ local conv3 = new_conv(nbr_elements[2], nbr_elements[3], false, cnorm, 4)
+
+ local convOutputSize, _, _ = convs_noutput({conv1, conv2, conv3})
+
+ local fc = new_fc(convOutputSize, nbr_elements[4])
+ local fc_class = new_classifier(nbr_elements[4], opt.nClasses)
+
+ local features = nn.Sequential()
+
+ if opt.locnet1 and opt.locnet1 ~= '' then
+ local st1 = new_spatial_tranformer(
+ opt.locnet1, -- locnet
+ false, false, false, -- rot, sca, tra
+ baseInputSize, -- input_size
+ nInputChannels -- input_channels
+ )
+ features:add(st1)
+ end
+
+ features:add(conv1)
+
+ if opt.locnet2 and opt.locnet2 ~= '' then
+ local _, currentInputSize, _ = convs_noutput({conv1})
+ local st2 = new_spatial_tranformer(
+ opt.locnet2, -- locnet
+ false, false, false, -- rot, sca, tra
+ currentInputSize, -- input_size
+ nbr_elements[1] -- input_channels
+ )
+ features:add(st2)
+ end
+
+ features:add(conv2)
+
+ if opt.locnet3 and opt.locnet3 ~= '' then
+ local _, currentInputSize, _ = convs_noutput({conv1,conv2})
+
+ local st3 = new_spatial_tranformer(
+ opt.locnet3, -- locnet
+ false, false, false, -- rot, sca, tra
+ currentInputSize, -- input_size
+ nbr_elements[2] -- input_channels
+ )
+ features:add(st3)
+ end
+
+ features:add(conv3)
+
+ local classifier = nn.Sequential()
+ classifier:add(fc)
+ classifier:add(fc_class)
+
+ local model = nn.Sequential():add(features):add(classifier)
+
+ return model
+end
+
+return createModel
+
+-- Test parameters
+--opt = {}
+--opt.nClasses = 43
+--opt.cnn = '200,250,350,400'
+--opt.locnet1= '250,250,250'
+--opt.locnet2= '150,200,300'
+--opt.locnet3= '150,200,300'
+--opt.globalNorm = 'false'
+--opt.localNorm = 'false'
+--model = createModel(opt)
+--model:cuda()
+--print(model)
+--
+--parameters, gradParameters = model:getParameters()
+--print(parameters:size())
+--print(gradParameters:size())
+
diff --git a/models/init.lua b/models/init.lua
new file mode 100644
index 0000000..2e5eb21
--- /dev/null
+++ b/models/init.lua
@@ -0,0 +1,160 @@
+--
+-- Copyright (c) 2016, Facebook, Inc.
+-- All rights reserved.
+--
+-- This source code is licensed under the BSD-style license found in the
+-- LICENSE file in the root directory of this source tree. An additional grant
+-- of patent rights can be found in the PATENTS file in the same directory.
+--
+-- Generic model creating code. For the specific CNN3ST model see
+-- models/cnn3st.lua
+--
+
+require 'nn'
+require 'cunn'
+require 'cudnn'
+require 'stn'
+local nninit = require 'nninit'
+
+local M = {}
+
+function M.setup(opt, checkpoint)
+ local model
+ if checkpoint then
+ local modelPath = paths.concat(opt.resume, checkpoint.modelFile)
+ assert(paths.filep(modelPath), 'Saved model not found: ' .. modelPath)
+ print('=> Resuming model from ' .. modelPath)
+ model = torch.load(modelPath):type(opt.tensorType)
+ model.__memoryOptimized = nil
+ elseif opt.retrain ~= 'none' then
+ assert(paths.filep(opt.retrain), 'File not found: ' .. opt.retrain)
+ print('Loading model from file: ' .. opt.retrain)
+ model = torch.load(opt.retrain):type(opt.tensorType)
+ model.__memoryOptimized = nil
+ else
+ print('=> Creating model from file: models/' .. opt.netType .. '.lua')
+ model = require('models/' .. opt.netType)(opt)
+ end
+
+ -- First remove any DataParallelTable
+ if torch.type(model) == 'nn.DataParallelTable' then
+ model = model:get(1)
+ end
+
+   -- optnet is a general library for reducing memory usage in neural networks
+ if opt.optnet then
+ local optnet = require 'optnet'
+      local imsize = opt.baseInputSize or 48
+ local sampleInput = torch.zeros(4,3,imsize,imsize):type(opt.tensorType)
+ optnet.optimizeMemory(model, sampleInput, {inplace = false, mode = 'training'})
+ end
+
+ -- This is useful for fitting ResNet-50 on 4 GPUs, but requires that all
+ -- containers override backwards to call backwards recursively on submodules
+ if opt.shareGradInput then
+ M.shareGradInput(model, opt)
+ end
+
+ -- For resetting the classifier when fine-tuning on a different Dataset
+ if opt.resetClassifier and not checkpoint then
+ print(' => Replacing classifier with ' .. opt.nClasses .. '-way classifier')
+
+ local orig = model:get(#model.modules)
+ assert(torch.type(orig) == 'nn.Linear',
+ 'expected last layer to be fully connected')
+
+ local linear = nn.Linear(orig.weight:size(2), opt.nClasses)
+ linear.bias:zero()
+
+ model:remove(#model.modules)
+ model:add(linear:type(opt.tensorType))
+ end
+
+ -- Set the CUDNN flags
+ if opt.cudnn == 'fastest' then
+ cudnn.fastest = true
+ cudnn.benchmark = true
+ elseif opt.cudnn == 'deterministic' then
+ -- Use a deterministic convolution implementation
+ model:apply(function(m)
+ if m.setMode then m:setMode(1, 1, 1) end
+ end)
+ end
+
+ -- Wrap the model with DataParallelTable, if using more than one GPU
+ if opt.nGPU > 1 then
+ local gpus = torch.range(1, opt.nGPU):totable()
+ local fastest, benchmark = cudnn.fastest, cudnn.benchmark
+
+ local dpt = nn.DataParallelTable(1, true, true)
+ :add(model, gpus)
+ :threads(function()
+ local cudnn = require 'cudnn'
+ require 'stn'
+ cudnn.fastest, cudnn.benchmark = fastest, benchmark
+ end)
+ dpt.gradInput = nil
+
+ model = dpt:type(opt.tensorType)
+ end
+
+ -- Set model weights initialization
+ if opt.weightInit ~= 'default' then
+ -- Init weights of convolutional modules
+      M.initModelWeights(model, opt.weightInit)
+ end
+
+ local criterion = nn.CrossEntropyCriterion():type(opt.tensorType)
+ model:cuda()
+ criterion:cuda()
+ return model, criterion
+end
+
+function M.shareGradInput(model, opt)
+ local function sharingKey(m)
+ local key = torch.type(m)
+ if m.__shareGradInputKey then
+ key = key .. ':' .. m.__shareGradInputKey
+ end
+ return key
+ end
+
+ -- Share gradInput for memory efficient backprop
+ local cache = {}
+ model:apply(function(m)
+ local moduleType = torch.type(m)
+ if torch.isTensor(m.gradInput) and moduleType ~= 'nn.ConcatTable' then
+ local key = sharingKey(m)
+ if cache[key] == nil then
+ cache[key] = torch[opt.tensorType:match('torch.(%a+)'):gsub('Tensor','Storage')](1)
+ end
+ m.gradInput = torch[opt.tensorType:match('torch.(%a+)')](cache[key], 1, 0)
+ end
+ end)
+ for i, m in ipairs(model:findModules('nn.ConcatTable')) do
+ if cache[i % 2] == nil then
+ cache[i % 2] = torch[opt.tensorType:match('torch.(%a+)'):gsub('Tensor','Storage')](1)
+ end
+ m.gradInput = torch[opt.tensorType:match('torch.(%a+)')](cache[i % 2], 1, 0)
+ end
+end
+
+function M.initModelWeights(model, initializer)
+ local convModulesNames = {'cudnn.SpatialConvolution', 'nn.SpatialConvolution'}
+ for _,name in ipairs(convModulesNames) do
+ for _,v in pairs(model:findModules(name)) do
+ if initializer == 'kaiming' then
+ v:init('weight', nninit.kaiming, {gain = {'relu'}})
+ v:noBias()
+ elseif initializer == 'glorot' then
+ v:init('weight', nninit.xavier, {gain = {'relu'}})
+ elseif initializer == 'uniform' then
+ v:init('weight', nninit.uniform, -0.5, 0.5)
+ elseif initializer == 'conv_aware' then
+ v:init('weight', nninit.convolutionAware, {gain = {'relu'}})
+ end
+ end
+ end
+end
+
+return M
\ No newline at end of file
diff --git a/opts.lua b/opts.lua
new file mode 100644
index 0000000..81d0b86
--- /dev/null
+++ b/opts.lua
@@ -0,0 +1,100 @@
+--
+-- Copyright (c) 2016, Facebook, Inc.
+-- All rights reserved.
+--
+-- This source code is licensed under the BSD-style license found in the
+-- LICENSE file in the root directory of this source tree. An additional grant
+-- of patent rights can be found in the PATENTS file in the same directory.
+--
+local M = { }
+
+function M.parse(arg)
+ local cmd = torch.CmdLine()
+ cmd:text()
+ cmd:text('Torch-7 Training script')
+ cmd:text()
+ cmd:text('Options:')
+ ------------ General options --------------------
+ cmd:option('-data', '', 'Path to dataset')
+ cmd:option('-computeMeanStd', false, 'Compute mean and std')
+ cmd:option('-dataset', 'gtsrb', 'Options: gtsrb')
+ cmd:option('-manualSeed', 1, 'Manually set RNG seed')
+ cmd:option('-nGPU', 1, 'Number of GPUs to use by default')
+ cmd:option('-backend', 'cudnn', 'Options: cudnn | cunn')
+ cmd:option('-cudnn', 'default', 'Options: fastest | default | deterministic')
+ cmd:option('-gen', 'gen', 'Path to save generated files')
+ cmd:option('-precision', 'single', 'Options: single | double | half')
+   cmd:option('-showFullOutput', false,      'Whether to show the full training process (true) or just the final output (false)')
+ ------------- Data options ------------------------
+ cmd:option('-nThreads', 1, 'number of data loading threads')
+ ------------- Training options --------------------
+ cmd:option('-nEpochs', 0, 'Number of total epochs to run')
+ cmd:option('-epochNumber', 1, 'Manual epoch number (useful on restarts)')
+ cmd:option('-batchSize', 32, 'mini-batch size (1 = pure stochastic)')
+ cmd:option('-testOnly', false, 'Run on validation set only')
+ cmd:option('-tenCrop', false, 'Ten-crop testing')
+ ------------- Checkpointing options ---------------
+ cmd:option('-checkpoint', false, 'Save model after each epoch')
+ cmd:option('-save', 'checkpoints', 'Directory in which to save checkpoints')
+ cmd:option('-resume', 'none', 'Resume from the latest checkpoint in this directory')
+ ---------- Optimization options ----------------------
+ cmd:option('-optimizer', 'sgd', 'Options: sgd | adam | rmsprop | adagrad | lbfgs | nag')
+ cmd:option('-LR', 0.01, 'initial learning rate')
+ cmd:option('-momentum', 0.9, 'momentum')
+ cmd:option('-weightDecay', 1e-4, 'weight decay')
+ cmd:option('-nesterov', false , 'Nesterov')
+ cmd:option('-LRDecayStep', 10, 'number of steps to decay LR by 0.1')
+ ---------- Model options ----------------------------------
+ cmd:option('-netType', 'cnn3st', 'Options: cnn3st')
+ cmd:option('-retrain', 'none', 'Path to model to retrain with')
+ cmd:option('-optimState', 'none', 'Path to an optimState to reload from')
+ cmd:option('-weightInit', 'default', 'Options: default | kaiming | glorot | uniform | conv_aware')
+   ---------- Model cnn3st options ----------------------------------
+ cmd:option('-cnn', '200,250,350,400', 'Network parameters (conv1_out, conv2_out, conv3_out, fc1_out)')
+ cmd:option('-locnet1', '250,250,250', 'Localization network 1 parameters')
+ cmd:option('-locnet2', '150,200,300', 'Localization network 2 parameters')
+ cmd:option('-locnet3', '150,200,300', 'Localization network 3 parameters')
+   cmd:option('-globalNorm',      false,      'Whether to perform global normalization')
+   cmd:option('-localNorm',       false,      'Whether to perform local normalization')
+   cmd:option('-cNormConv',       false,      'Whether to perform contrastive normalization in conv modules')
+   cmd:option('-dataAug',         false,      'Whether to perform data augmentation on the training dataset')
+ ---------- Model options ----------------------------------
+ cmd:option('-shareGradInput', false, 'Share gradInput tensors to reduce memory usage')
+ cmd:option('-optnet', false, 'Use optnet to reduce memory usage')
+ cmd:option('-resetClassifier', false, 'Reset the fully connected layer for fine-tuning')
+ cmd:option('-nClasses', 0, 'Number of classes in the dataset')
+ cmd:option('-baseInputSize', 48, 'Size of input images')
+ cmd:text()
+
+ local opt = cmd:parse(arg or {})
+
+ if not paths.dirp(opt.save) and not paths.mkdir(opt.save) then
+ cmd:error('error: unable to create checkpoint directory: ' .. opt.save .. '\n')
+ end
+
+ if opt.precision == nil or opt.precision == 'single' then
+ opt.tensorType = 'torch.CudaTensor'
+ elseif opt.precision == 'double' then
+ opt.tensorType = 'torch.CudaDoubleTensor'
+ elseif opt.precision == 'half' then
+ opt.tensorType = 'torch.CudaHalfTensor'
+ else
+ cmd:error('unknown precision: ' .. opt.precision)
+ end
+
+ if opt.resetClassifier then
+ if opt.nClasses == 0 then
+ cmd:error('-nClasses required when resetClassifier is set')
+ end
+ end
+ if opt.shareGradInput and opt.optnet then
+ cmd:error('error: cannot use both -shareGradInput and -optnet')
+ end
+
+ print('--- Options ---')
+ print(opt)
+
+ return opt
+end
+
+return M
\ No newline at end of file
diff --git a/run_model.lua b/run_model.lua
new file mode 100644
index 0000000..b5f25be
--- /dev/null
+++ b/run_model.lua
@@ -0,0 +1,46 @@
+require 'torch'
+require 'image'
+require 'cunn'
+require 'cudnn'
+require 'stn'
+require 'nn'
+
+print("Loading network...")
+local model_path = "pretrained/gtsrb_cnn3st_model.t7"
+local mean_std_path = "pretrained/gtsrb_cnn3st_mean_std.t7"
+local network = torch.load(model_path)
+local mean_std = torch.load(mean_std_path)
+print("--- Network ---")
+print(network)
+print("--- Mean/Std ---")
+local mean, std = mean_std[1], mean_std[2]
+print("Mean:"..mean, "Std:"..std)
+
+
+print("Loading sample images...")
+local sample_img1 = image.load("sample_images/img1.jpg")
+sample_img1 = image.scale(sample_img1, 48, 48)
+local sample_img2 = image.load("sample_images/img2.jpg")
+sample_img2 = image.scale(sample_img2, 48, 48)
+local samples_tensor = torch.Tensor(2,sample_img1:size(1), sample_img1:size(2), sample_img1:size(3)):fill(0)
+samples_tensor[1]:copy(sample_img1)
+samples_tensor[2]:copy(sample_img2)
+
+print("Applying global normalization to sample image")
+samples_tensor:add(-mean)
+samples_tensor:div(std)
+
+print("Applying local contrast normalization to sample image")
+local norm_kernel = image.gaussian1D(7)
+local local_normalizer = nn.SpatialContrastiveNormalization(3, norm_kernel)
+samples_tensor:copy(local_normalizer:forward(samples_tensor))
+
+print("Running the network...")
+samples_tensor = samples_tensor:cuda()
+local scores = network:forward(samples_tensor)
+print("Scores...")
+print(scores)
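+-- GTSRB class ids are 0-based while torch indexing is 1-based, hence the "- 1" below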
+local _, prediction1 = scores[1]:max(1)
+local _, prediction2 = scores[2]:max(1)
+print("Prediction class sample img 1: " .. prediction1[1] - 1)
+print("Prediction class sample img 2: " .. prediction2[1] - 1)
\ No newline at end of file
diff --git a/trainer.lua b/trainer.lua
new file mode 100644
index 0000000..2c4508a
--- /dev/null
+++ b/trainer.lua
@@ -0,0 +1,245 @@
+--
+-- Copyright (c) 2016, Facebook, Inc.
+-- All rights reserved.
+--
+-- This source code is licensed under the BSD-style license found in the
+-- LICENSE file in the root directory of this source tree. An additional grant
+-- of patent rights can be found in the PATENTS file in the same directory.
+--
+-- The training loop and learning rate schedule
+--
+
+local optim = require 'optim'
+
+local M = {}
+local Trainer = torch.class('Trainer', M)
+
+function Trainer:__init(model, criterion, opt, optimState)
+ self.model = model
+ self.criterion = criterion
+ self.optimState = optimState or {
+ learningRate = opt.LR,
+ learningRateDecay = 0.0,
+ momentum = opt.momentum,
+ nesterov = opt.nesterov,
+ dampening = 0.0,
+ weightDecay = opt.weightDecay,
+ }
+ self.opt = opt
+ self.params, self.gradParams = model:getParameters()
+ print("-- Model parameters --")
+ print(self.params:size(1))
+end
+
+function Trainer:train(epoch, dataloader)
+ -- Trains the model for a single epoch
+ self.optimState.learningRate = self:learningRate(epoch)
+
+ local totalTime = 0
+ local timer = torch.Timer()
+ local dataTimer = torch.Timer()
+
+ local function feval()
+ return self.criterion.output, self.gradParams
+ end
+
+ local absCriterion = nn.AbsCriterion():cuda()
+
+ local trainSize = dataloader:size()
+ local top1Sum, top5Sum, lossSum, lossAbsSum = 0.0, 0.0, 0.0, 0.0
+ local N = 0
+
+ print('=> Training epoch # ' .. epoch)
+ -- set the batch norm to training mode
+ self.model:training()
+ for n, sample in dataloader:run() do
+ local dataTime = dataTimer:time().real
+
+ -- Copy input and target to the GPU
+ self:copyInputs(sample)
+
+ local output = self.model:forward(self.input):float()
+ local batchSize = output:size(1)
+ local loss = self.criterion:forward(self.model.output, self.target)
+
+      -- Expected class index under the softmax, used for the absolute-error metric (lossAbs)
+ local softmax = nn.SoftMax():cuda()
+ local outputSoft = softmax:forward(self.model.output)
+ local avgPred = outputSoft * torch.range(1,self.opt.nClasses):cuda()
+ local lossAbs = absCriterion:forward(avgPred, self.target)
+
+ self.model:zeroGradParameters()
+ self.criterion:backward(self.model.output, self.target)
+ self.model:backward(self.input, self.criterion.gradInput)
+
+ if self.opt.optimizer == 'sgd' then
+ optim.sgd(feval, self.params, self.optimState)
+ elseif self.opt.optimizer == 'adam' then
+ optim.adam(feval, self.params, self.optimState)
+ elseif self.opt.optimizer == 'rmsprop' then
+ optim.rmsprop(feval, self.params, self.optimState)
+ elseif self.opt.optimizer == 'adagrad' then
+ optim.adagrad(feval, self.params, self.optimState)
+ elseif self.opt.optimizer == 'lbfgs' then
+ optim.lbfgs(feval, self.params, self.optimState)
+ elseif self.opt.optimizer == 'nag' then
+ optim.nag(feval, self.params, self.optimState)
+ end
+
+ local top1, top5 = self:computeScore(output, sample.target, 1)
+ top1Sum = top1Sum + top1*batchSize
+ top5Sum = top5Sum + top5*batchSize
+ lossSum = lossSum + loss*batchSize
+ lossAbsSum = lossAbsSum + lossAbs*batchSize
+ N = N + batchSize
+
+ if self.opt.showFullOutput then
+ print((' | Epoch: [%d][%d/%d] Time %.3f Data %.3f Err %1.4f top1 %7.3f top5 %7.3f lossAbs %7.3f'):format(
+ epoch, n, trainSize, timer:time().real, dataTime, loss, top1, top5, lossAbs))
+ end
+
+ totalTime = totalTime + timer:time().real
+
+ -- check that the storage didn't get changed due to an unfortunate getParameters call
+ assert(self.params:storage() == self.model:parameters()[1]:storage())
+
+ timer:reset()
+ dataTimer:reset()
+ end
+
+ if self.opt.showFullOutput then
+ print((' * [Train] Finished epoch # %d top1: %7.3f top5: %7.3f loss: %7.3f lossAbs: %7.3f \n'):format(
+ epoch, top1Sum / N, top5Sum / N, lossSum / N, lossAbsSum / N))
+ end
+
+ print((' || Train total time: %2.5f'):format(totalTime))
+
+ return top1Sum / N, top5Sum / N, lossSum / N, lossAbsSum / N
+end
+
+function Trainer:test(epoch, dataloader)
+   -- Computes the top-1 and top-5 accuracy on the validation set
+
+ local totalTime = 0
+ local timer = torch.Timer()
+ local dataTimer = torch.Timer()
+ local size = dataloader:size()
+
+ local nCrops = self.opt.tenCrop and 10 or 1
+ local top1Sum, top5Sum, lossSum, lossAbsSum = 0.0, 0.0, 0.0, 0.0
+ local N = 0
+
+ local absCriterion = nn.AbsCriterion():cuda()
+
+ self.model:evaluate()
+
+ for n, sample in dataloader:run() do
+ local dataTime = dataTimer:time().real
+
+ -- Copy input and target to the GPU
+ self:copyInputs(sample)
+
+ local output = self.model:forward(self.input):float()
+ local batchSize = output:size(1) / nCrops
+ local loss = self.criterion:forward(self.model.output, self.target)
+
+      -- Expected class index under the softmax, used for the absolute-error metric (lossAbs)
+ local softmax = nn.SoftMax():cuda()
+ local outputSoft = softmax:forward(self.model.output)
+ local avgPred = outputSoft * torch.range(1,self.opt.nClasses):cuda()
+ local lossAbs = absCriterion:forward(avgPred, self.target)
+
+ local top1, top5 = self:computeScore(output, sample.target, nCrops)
+ top1Sum = top1Sum + top1*batchSize
+ top5Sum = top5Sum + top5*batchSize
+ lossSum = lossSum + loss*batchSize
+ lossAbsSum = lossAbsSum + lossAbs*batchSize
+ N = N + batchSize
+
+ if self.opt.showFullOutput then
+ print((' | Test: [%d][%d/%d] Time %.3f Data %.3f top1 %7.3f (%7.3f) top5 %7.3f (%7.3f) loss %7.3f (%7.3f) lossAbs %7.3f (%7.3f)'):format(
+ epoch, n, size, timer:time().real, dataTime, top1, top1Sum / N, top5, top5Sum / N, loss, lossSum / N, lossAbs, lossAbsSum / N))
+ end
+
+ totalTime = totalTime + timer:time().real
+
+ timer:reset()
+ dataTimer:reset()
+ end
+ self.model:training()
+
+ if self.opt.showFullOutput then
+ print((' * [Test] Finished epoch # %d top1: %7.3f top5: %7.3f loss: %7.3f lossAbs: %7.3f \n'):format(
+ epoch, top1Sum / N, top5Sum / N, lossSum / N, lossAbsSum / N))
+ end
+
+ print((' || Test total time: %2.5f'):format(totalTime))
+
+ return top1Sum / N, top5Sum / N, lossSum / N, lossAbsSum / N
+end
+
+function Trainer:computeScore(output, target, nCrops)
+ if nCrops > 1 then
+ -- Sum over crops
+ output = output:view(output:size(1) / nCrops, nCrops, output:size(2))
+ --:exp()
+ :sum(2):squeeze(2)
+ end
+
+   -- Computes the top-1 and top-5 accuracy (in percent)
+ local batchSize = output:size(1)
+
+ local _ , predictions = output:float():topk(5, 2, true, true) -- descending
+
+ -- Find which predictions match the target
+ local correct = predictions:eq(target:long():view(batchSize, 1):expandAs(predictions))
+
+ -- Top-1 score
+ local top1 = correct:narrow(2, 1, 1):sum() / batchSize
+
+ -- Top-5 score, if there are at least 5 classes
+ local len = math.min(5, correct:size(2))
+ local top5 = correct:narrow(2, 1, len):sum() / batchSize
+
+ return top1 * 100, top5 * 100
+end
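+
+-- Note: these scores are accuracies in percent (higher is better); main.lua
+-- relies on this when it tracks the best model with `>`.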
+
+local function getCudaTensorType(tensorType)
+ if tensorType == 'torch.CudaHalfTensor' then
+ return cutorch.createCudaHostHalfTensor()
+ elseif tensorType == 'torch.CudaDoubleTensor' then
+ return cutorch.createCudaHostDoubleTensor()
+ else
+ return cutorch.createCudaHostTensor()
+ end
+end
+
+function Trainer:copyInputs(sample)
+ -- Copies the input to a CUDA tensor, if using 1 GPU, or to pinned memory,
+ -- if using DataParallelTable. The target is always copied to a CUDA tensor
+ self.input = self.input or (self.opt.nGPU == 1
+ and torch[self.opt.tensorType:match('torch.(%a+)')]()
+ or getCudaTensorType(self.opt.tensorType))
+ self.target = self.target or torch.CudaTensor()
+ self.input:resize(sample.input:size()):copy(sample.input)
+ self.target:resize(sample.target:size()):copy(sample.target)
+end
+
+function Trainer:learningRate(epoch)
+ -- Training schedule
+-- local decay = math.floor((epoch - 1) / self.opt.LRDecayStep)
+-- local newLR = self.opt.LR * math.pow(0.1, decay)
+-- print('New LR: ' .. newLR)
+-- return newLR
+-- local decay = 0
+-- if self.opt.optimizer == 'adam' then
+-- decay = 1.0/math.sqrt(epoch)
+-- return self.opt.LR * decay
+-- else
+-- decay = math.floor((epoch - 1) / self.opt.LRDecayStep)
+-- return self.opt.LR * math.pow(0.1, decay)
+-- end
+ return self.opt.LR
+end
+
+return M.Trainer
\ No newline at end of file
diff --git a/utils.lua b/utils.lua
new file mode 100644
index 0000000..a7719c1
--- /dev/null
+++ b/utils.lua
@@ -0,0 +1,101 @@
+require 'cunn'
+local ffi=require 'ffi'
+local util = {}
+
+-- Functions from https://github.com/soumith/imagenet-multiGPU.torch/blob/master/util.lua
+
+function util.makeDataParallel(model, nGPU)
+
+ if nGPU > 1 then
+ print('converting module to nn.DataParallelTable')
+      assert(nGPU <= cutorch.getDeviceCount(), 'fewer GPUs available than the nGPU specified')
+      if opt.backend == 'cudnn' and opt.cudnnAutotune == 1 then
+ local gpu_table = torch.range(1, nGPU):totable()
+ local dpt = nn.DataParallelTable(1, true):add(model, gpu_table):threads(function()
+ require 'cudnn'
+ require 'stn'
+ cudnn.benchmark = true
+ end)
+
+ dpt.gradInput = nil
+ model = dpt:cuda()
+ else
+ local model_single = model
+ model = nn.DataParallelTable(1)
+ for i=1, nGPU do
+ cutorch.setDevice(i)
+ model:add(model_single:clone():cuda(), i)
+ end
+ cutorch.setDevice(opt.GPU)
+ end
+ else
+      if (opt.backend == 'cudnn' and opt.cudnnAutotune == 1) then
+ require 'cudnn'
+ cudnn.benchmark = true
+ end
+ end
+
+ return model
+end
+
+local function cleanDPT(module)
+ -- This assumes this DPT was created by the function above: all the
+ -- module.modules are clones of the same network on different GPUs
+ -- hence we only need to keep one when saving the model to the disk.
+ local newDPT = nn.DataParallelTable(1)
+ cutorch.setDevice(opt.GPU)
+ newDPT:add(module:get(1), opt.GPU)
+ return newDPT
+end
+
+function util.saveDataParallel(filename, model)
+ if torch.type(model) == 'nn.DataParallelTable' then
+ torch.save(filename, cleanDPT(model))
+ elseif torch.type(model) == 'nn.Sequential' then
+ local temp_model = nn.Sequential()
+ for i, module in ipairs(model.modules) do
+ if torch.type(module) == 'nn.DataParallelTable' then
+ temp_model:add(cleanDPT(module))
+ else
+ temp_model:add(module)
+ end
+ end
+ torch.save(filename, temp_model)
+ else
+ error('This saving function only works with Sequential or DataParallelTable modules.')
+ end
+end
+
+function util.loadDataParallel(filename, nGPU)
+   if opt.backend == 'cudnn' then
+ require 'cudnn'
+ end
+ local model = torch.load(filename)
+ if torch.type(model) == 'nn.DataParallelTable' then
+      return util.makeDataParallel(model:get(1):float(), nGPU)
+ elseif torch.type(model) == 'nn.Sequential' then
+ for i,module in ipairs(model.modules) do
+ if torch.type(module) == 'nn.DataParallelTable' then
+            model.modules[i] = util.makeDataParallel(module:get(1):float(), nGPU)
+ end
+ end
+ return model
+ else
+ error('The loaded model is not a Sequential or DataParallelTable module.')
+ end
+end
+
+function util.countParameters(model)
+ local n_parameters = 0
+ for i=1, model:size() do
+ local params = model:get(i):parameters()
+ if params then
+ local weights = params[1]
+         local biases = params[2]
+         n_parameters = n_parameters + weights:nElement()
+         if biases then
+            n_parameters = n_parameters + biases:nElement()
+         end
+ end
+ end
+ return n_parameters
+end
+
+return util
\ No newline at end of file