Updated README and source code
aarcosg committed Oct 18, 2018
1 parent 71ce2a7 commit 01935fe
Showing 15 changed files with 1,932 additions and 2 deletions.
69 changes: 67 additions & 2 deletions README.md
@@ -1,2 +1,67 @@
# tsr-torch
Traffic Sign Recognition with Torch [WIP]
# Traffic Sign Recognition
This is the code for the paper

**[Deep neural network for traffic sign recognition systems: An analysis of spatial transformers and stochastic optimisation methods](https://doi.org/10.1016/j.neunet.2018.01.005)**
<br>
[Álvaro Arcos-García](https://scholar.google.com/citations?user=gjecl3cAAAAJ),
[Juan Antonio Álvarez-García](https://scholar.google.com/citations?user=Qk79xk8AAAAJ),
[Luis M. Soria-Morillo](https://scholar.google.com/citations?user=poBDpFkAAAAJ)
<br>

The paper addresses the problem of traffic sign classification using a deep neural network that comprises convolutional layers and spatial transformer networks. The model achieves an accuracy of 99.71% on the [German Traffic Sign Recognition Benchmark](http://benchmark.ini.rub.de/?section=gtsrb&subsection=results).

We provide:
- A [pretrained model](#pretrained-model).
- Test code to [run the model on new images](#running-on-new-images).
- Instructions for [training the model](#training).

If you find this code useful in your research, please cite:

```
"Deep neural network for traffic sign recognition systems: An analysis of spatial transformers and stochastic optimisation methods."
Álvaro Arcos-García, Juan A. Álvarez-García, Luis M. Soria-Morillo. Neural Networks 99 (2018) 158-165.
```
\[[link](https://doi.org/10.1016/j.neunet.2018.01.005)\]\[[bibtex](https://scholar.googleusercontent.com/citations?view_op=export_citations&user=gjecl3cAAAAJ&citsig=AMstHGQAAAAAW8ngijLlAnMmG2C2_Weu4sB-TWRmdT1P&hl=en)\]

## Requirements
This project is implemented in [Torch](http://torch.ch/), and depends on the following packages: [torch/torch7](https://github.com/torch/torch7), [torch/nn](https://github.com/torch/nn), [torch/image](https://github.com/torch/image), [qassemoquab/stnbhwd](https://github.com/qassemoquab/stnbhwd), [torch/cutorch](https://github.com/torch/cutorch), [torch/cunn](https://github.com/torch/cunn), [cuDNN bindings for Torch](https://github.com/soumith/cudnn.torch) and [nninit](https://github.com/Kaixhin/nninit).

After installing Torch, you can install or update these dependencies by running the following:
```bash
luarocks install torch
luarocks install nn
luarocks install image
luarocks install cutorch
luarocks install cunn
luarocks install cudnn
luarocks install https://raw.githubusercontent.com/qassemoquab/stnbhwd/master/stnbhwd-scm-1.rockspec
luarocks install https://raw.githubusercontent.com/Kaixhin/nninit/master/rocks/nninit-scm-1.rockspec
```
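
To check that everything is available to Torch, a quick sanity test like the one below should run without errors. This is only a sketch: the module names loaded by `require` (in particular `stn` for `stnbhwd` and `nninit`) are assumed from each package's documentation.
```lua
-- sanity_check.lua: verify that all dependencies can be loaded
require 'torch'
require 'nn'
require 'image'
require 'cutorch'   -- CUDA backend for torch
require 'cunn'      -- CUDA backend for nn
require 'cudnn'     -- cuDNN bindings
require 'stn'       -- spatial transformer modules from stnbhwd (assumed module name)
require 'nninit'    -- weight initialisation schemes (assumed module name)
print('All dependencies loaded correctly')
```
Run it with `th sanity_check.lua`.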

## Pretrained model
You can download a pretrained model from [Google Drive](https://drive.google.com/uc?export=download&id=1iqX1TZBxJSEmaEoReK2ZDko03JHqw1bp).
Unzip `gtsrb_cnn3st_pretrained.zip` and move its contents to the `pretrained` folder of this project. The archive contains the pretrained model, which achieves an accuracy of 99.71% on the German Traffic Sign Recognition Benchmark (GTSRB), and a second file with the mean and standard deviation values computed during training.
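
As a rough sketch of how these files can be loaded in Torch (the file names below are hypothetical, so adjust them to whatever the archive actually contains):
```lua
-- Load the packages whose modules the serialized network may reference
require 'torch'
require 'cudnn'
require 'stn'

-- Hypothetical file names; check the contents of gtsrb_cnn3st_pretrained.zip
local model = torch.load('pretrained/gtsrb_cnn3st_model.t7')
local meanStd = torch.load('pretrained/gtsrb_cnn3st_mean_std.t7')

model:evaluate()  -- switch to inference mode (disables dropout, etc.)
print(model)
print('mean/std:', meanStd)
```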

## Running on new images
To run the model on new images, use the script `run_model.lua`. It applies the pretrained model to the images provided in the `sample_images` folder:
```bash
th run_model.lua
```
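
To classify a single image programmatically, a minimal sketch continuing from the loading snippet above could look like the following. The 48x48 input size, the layout of `meanStd`, and the sample file name are assumptions; the model was trained with its own normalisation pipeline, so check `run_model.lua` for the exact preprocessing.
```lua
require 'image'

-- Hypothetical preprocessing: resize and normalise with the stored statistics
local img = image.load('sample_images/my_sign.png', 3, 'float')  -- hypothetical file
img = image.scale(img, 48, 48)             -- assumed network input size
for c = 1, 3 do
   img[c]:add(-meanStd.mean[c]):div(meanStd.std[c])  -- assumed per-channel layout
end

local input = img:view(1, 3, 48, 48):cuda()
local output = model:forward(input)
local _, prediction = output:max(2)        -- class with the highest score
print('predicted class id: ' .. prediction[1][1])
```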

## Training
To train and experiment with new traffic sign classification models, follow these steps:
1. Run the script `download_gtsrb_dataset.lua`, which creates a new folder called `GTSRB` containing two subfolders, `train` and `val`:
```bash
th download_gtsrb_dataset.lua
```
2. Run the script `main.lua`, setting the training parameters described in `opts.lua`. For example, the following options train a model with a single spatial transformer network placed at the beginning of the main network:
```bash
th main.lua -data GTSRB -save GTSRB/checkpoints -dataset gtsrb -nClasses 43 -optimizer adam -LR 1e-4 -momentum 0 -weightDecay 0 -batchSize 50 -nEpochs 15 -weightInit default -netType cnn3st -globalNorm -localNorm -cNormConv -locnet2 '' -locnet3 '' -showFullOutput
```
The next example creates a model with three spatial transformer networks:
```bash
th main.lua -data GTSRB -save GTSRB/checkpoints -dataset gtsrb -nClasses 43 -optimizer rmsprop -LR 1e-5 -momentum 0 -weightDecay 0 -batchSize 50 -nEpochs 15 -weightInit default -netType cnn3st -globalNorm -localNorm -cNormConv -showFullOutput
```
## Acknowledgements
The source code of this project is mainly based on [fb.resnet.torch](https://github.com/facebook/fb.resnet.torch) and [gtsrb.torch](https://github.com/Moodstocks/gtsrb.torch).
79 changes: 79 additions & 0 deletions checkpoints.lua
@@ -0,0 +1,79 @@
--
-- Copyright (c) 2016, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the BSD-style license found in the
-- LICENSE file in the root directory of this source tree. An additional grant
-- of patent rights can be found in the PATENTS file in the same directory.
--
local checkpoint = {}

local function deepCopy(tbl)
   -- creates a copy of a network with new modules and the same tensors
   local copy = {}
   for k, v in pairs(tbl) do
      if type(v) == 'table' then
         copy[k] = deepCopy(v)
      else
         copy[k] = v
      end
   end
   if torch.typename(tbl) then
      torch.setmetatable(copy, torch.typename(tbl))
   end
   return copy
end

function checkpoint.latest(opt)
   if opt.resume == 'none' then
      return nil
   end

   local latestPath = paths.concat(opt.resume, 'latest.t7')
   if not paths.filep(latestPath) then
      return nil
   end

   print('=> Loading checkpoint ' .. latestPath)
   local latest = torch.load(latestPath)
   local optimState = torch.load(paths.concat(opt.resume, latest.optimFile))

   return latest, optimState
end

function checkpoint.save(epoch, model, optimState, isBestModel, opt)
   -- don't save the DataParallelTable for easier loading on other machines
   if torch.type(model) == 'nn.DataParallelTable' then
      model = model:get(1)
   end

   -- create a clean copy on the CPU without modifying the original network
   model = deepCopy(model):float():clearState()

   local modelFile = 'model_' .. epoch .. '.t7'
   local optimFile = 'optimState_' .. epoch .. '.t7'

   if opt.checkpoint == 'true' then
      torch.save(paths.concat(opt.save, modelFile), model)
      torch.save(paths.concat(opt.save, optimFile), optimState)
      torch.save(paths.concat(opt.save, 'latest.t7'), {
         epoch = epoch,
         modelFile = modelFile,
         optimFile = optimFile,
      })
   end

   if isBestModel then
      local bestModelFile = 'model_best.t7'
      local bestOptimFile = 'model_best_optimState.t7'
      torch.save(paths.concat(opt.save, bestModelFile), model)
      torch.save(paths.concat(opt.save, bestOptimFile), optimState)
      torch.save(paths.concat(opt.save, 'latest.t7'), {
         epoch = epoch,
         modelFile = bestModelFile,
         optimFile = bestOptimFile,
      })
   end
end

return checkpoint
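
For orientation, this checkpoints module is typically driven from `main.lua` roughly as in the sketch below. The fields of `opt` (`resume`, `save`, `checkpoint`, `nEpochs`) are those referenced above and in `opts.lua`; `model` and `optimState` are created elsewhere in the training script.
```lua
local checkpoints = require 'checkpoints'

-- Resume from the latest checkpoint, if any, when opt.resume points at a checkpoint folder
local latest, optimState = checkpoints.latest(opt)
local startEpoch = latest and latest.epoch + 1 or 1

for epoch = startEpoch, opt.nEpochs do
   -- ... train and validate for one epoch here ...
   local isBestModel = false  -- placeholder: set to true when validation accuracy improves
   checkpoints.save(epoch, model, optimState, isBestModel, opt)
end
```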
127 changes: 127 additions & 0 deletions dataloader.lua
@@ -0,0 +1,127 @@
--
-- Copyright (c) 2016, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the BSD-style license found in the
-- LICENSE file in the root directory of this source tree. An additional grant
-- of patent rights can be found in the PATENTS file in the same directory.
--
-- Multi-threaded data loader
--

local datasets = require 'datasets/init'
local Threads = require 'threads'
Threads.serialization('threads.sharedserialize')

local M = {}
local DataLoader = torch.class('DataLoader', M)

function DataLoader.create(opt)
   -- The train and val loader
   local loaders = {}

   for i, split in ipairs{'train', 'val'} do
      local dataset = datasets.create(opt, split)
      loaders[i] = M.DataLoader(dataset, opt, split)
   end

   return table.unpack(loaders)
end

function DataLoader:__init(dataset, opt, split)
   local manualSeed = opt.manualSeed
   local function init()
      require('datasets/' .. opt.dataset)
   end
   local function main(idx)
      if manualSeed ~= 0 then
         torch.manualSeed(manualSeed + idx)
      end
      torch.setnumthreads(1)
      _G.dataset = dataset
      _G.preprocess = dataset:preprocess()
      return dataset:size()
   end

   local threads, sizes = Threads(opt.nThreads, init, main)
   self.nCrops = (split == 'val' and opt.tenCrop) and 10 or 1
   self.threads = threads
   self.__size = sizes[1][1]
   self.batchSize = math.floor(opt.batchSize / self.nCrops)
   local function getCPUType(tensorType)
      if tensorType == 'torch.CudaHalfTensor' then
         return 'HalfTensor'
      elseif tensorType == 'torch.CudaDoubleTensor' then
         return 'DoubleTensor'
      else
         return 'FloatTensor'
      end
   end
   self.cpuType = getCPUType(opt.tensorType)
end

function DataLoader:size()
   return math.ceil(self.__size / self.batchSize)
end

function DataLoader:run()
   local threads = self.threads
   local size, batchSize = self.__size, self.batchSize
   local perm = torch.randperm(size)

   local idx, sample = 1, nil
   local function enqueue()
      while idx <= size and threads:acceptsjob() do
         local indices = perm:narrow(1, idx, math.min(batchSize, size - idx + 1))
         threads:addjob(
            function(indices, nCrops, cpuType)
               local sz = indices:size(1)
               local batch, imageSize
               local target = torch.IntTensor(sz)
               for i, idx in ipairs(indices:totable()) do
                  local sample = _G.dataset:get(idx)
                  local input = _G.preprocess(sample.input)
                  if not batch then
                     imageSize = input:size():totable()
                     if nCrops > 1 then table.remove(imageSize, 1) end
                     batch = torch[cpuType](sz, nCrops, table.unpack(imageSize))
                  end
                  batch[i]:copy(input)
                  target[i] = sample.target
               end
               collectgarbage()
               return {
                  input = batch:view(sz * nCrops, table.unpack(imageSize)),
                  target = target,
               }
            end,
            function(_sample_)
               sample = _sample_
            end,
            indices,
            self.nCrops,
            self.cpuType
         )
         idx = idx + batchSize
      end
   end

   local n = 0
   local function loop()
      enqueue()
      if not threads:hasjob() then
         return nil
      end
      threads:dojob()
      if threads:haserror() then
         threads:synchronize()
      end
      enqueue()
      n = n + 1
      return n, sample
   end

   return loop
end

return M.DataLoader
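
Similarly, a minimal sketch of how the data loader is consumed from the training code (assuming `opt` carries the fields referenced above, such as `nThreads`, `batchSize`, `dataset`, and `manualSeed`):
```lua
local DataLoader = require 'dataloader'

-- Creates one loader per split, in the order {'train', 'val'}
local trainLoader, valLoader = DataLoader.create(opt)

-- run() returns an iterator; each step yields the batch index and a sample table
for n, sample in trainLoader:run() do
   -- sample.input  : batchSize x channels x height x width tensor (FloatTensor by default)
   -- sample.target : IntTensor with one class label per image
   print(('batch %d/%d, %d images'):format(n, trainLoader:size(), sample.input:size(1)))
end
```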