-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
15 changed files
with
1,932 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,67 @@ | ||
# tsr-torch | ||
Traffic Sign Recognition with Torch [WIP] | ||
# Traffic Sign Recognition | ||
This is the code for the paper | ||
|
||
**[Deep neural network for traffic sign recognition systems: An analysis of spatial transformers and stochastic optimisation methods](https://doi.org/10.1016/j.neunet.2018.01.005)** | ||
<br> | ||
[Álvaro Arcos-García](https://scholar.google.com/citations?user=gjecl3cAAAAJ), | ||
[Juan Antonio Álvarez-García](https://scholar.google.com/citations?user=Qk79xk8AAAAJ), | ||
[Luis M. Soria-Morillo](https://scholar.google.com/citations?user=poBDpFkAAAAJ) | ||
<br> | ||
|
||
The paper addresses the problem of traffic sign classification using a Deep Neural Network which comprises Convolutional layers and Spatial Transformer Networks. The model reports an accuracy of 99.71% on the [German Traffic Sign Recognition Benchmark](http://benchmark.ini.rub.de/?section=gtsrb&subsection=results). | ||
|
||
We provide: | ||
- A [pretrained model](#pretrained-model). | ||
- Test code to [run the model on new images](#running-on-new-images). | ||
- Instructions for [training the model](#training). | ||
|
||
If you find this code useful in your research, please cite: | ||
|
||
``` | ||
"Deep neural network for traffic sign recognition systems: An analysis of spatial transformers and stochastic optimisation methods." | ||
Álvaro Arcos-García, Juan A. Álvarez-García, Luis M. Soria-Morillo. Neural Networks 99 (2018) 158-165. | ||
``` | ||
\[[link](https://doi.org/10.1016/j.neunet.2018.01.005)\]\[[bibtex]( | ||
https://scholar.googleusercontent.com/citations?view_op=export_citations&user=gjecl3cAAAAJ&citsig=AMstHGQAAAAAW8ngijLlAnMmG2C2_Weu4sB-TWRmdT1P&hl=en)\] | ||
|
||
## Requirements | ||
This project is implemented in [Torch](http://torch.ch/), and depends on the following packages: [torch/torch7](https://github.com/torch/torch7), [torch/nn](https://github.com/torch/nn), [torch/image](https://github.com/torch/image), [qassemoquab/stnbhwd](https://github.com/qassemoquab/stnbhwd), [torch/cutorch](https://github.com/torch/cutorch), [torch/cunn](https://github.com/torch/cunn), [cuDNN bindings for Torch](https://github.com/soumith/cudnn.torch) and [nninit](https://github.com/Kaixhin/nninit). | ||
|
||
After installing torch, you can install / update these dependencies by running the following: | ||
```bash | ||
luarocks install torch | ||
luarocks install nn | ||
luarocks install image | ||
luarocks install cutorch | ||
luarocks install cunn | ||
luarocks install cudnn | ||
luarocks install https://raw.githubusercontent.com/qassemoquab/stnbhwd/master/stnbhwd-scm-1.rockspec | ||
luarocks install https://raw.githubusercontent.com/Kaixhin/nninit/master/rocks/nninit-scm-1.rockspec | ||
``` | ||
|
||
## Pretrained model | ||
You can download a pretrained model from [Google Drive](https://drive.google.com/uc?export=download&id=1iqX1TZBxJSEmaEoReK2ZDko03JHqw1bp). | ||
Unzip the file `gtsrb_cnn3st_pretrained.zip` and move its content to the folder `pretrained` of this project. It contains the pretrained model that obtains an accuracy of 99.71% in the German Traffic Sign Recognition Benchmark (GTSRB) and a second file with the mean and standard deviation values computed during the training process. | ||
|
||
## Running on new images | ||
To run the model on new images, use the script `run_model.lua`. It will run the pretrained model on the images provided in the `sample_images` folder: | ||
```bash | ||
th run_model.lua | ||
``` | ||
|
||
## Training | ||
To train and experiment with new traffic sign classification models, follow these steps: | ||
1. Use the script `download_gtsrb_dataset.lua` which will create a new folder called `GTSRB` that will include two folders: `train` and `val`. | ||
```bash | ||
th download_gtsrb_dataset.lua | ||
``` | ||
2. Use the script `main.lua` setting the training parameters described in the file `opts.lua`. For example, the following options will generate the files needed to train a model with just one spatial transformer network localized at the beginning of the main network: | ||
```bash | ||
th main.lua -data GTSRB -save GTSRB/checkpoints -dataset gtsrb -nClasses 43 -optimizer adam -LR 1e-4 -momentum 0 -weightDecay 0 -batchSize 50 -nEpochs 15 -weightInit default -netType cnn3st -globalNorm -localNorm -cNormConv -locnet2 '' -locnet3 '' -showFullOutput | ||
``` | ||
The next example will create a model with three spatial transformer networks: | ||
```bash | ||
th main.lua -data GTSRB -save GTSRB/checkpoints -dataset gtsrb -nClasses 43 -optimizer rmsprop -LR 1e-5 -momentum 0 -weightDecay 0 -batchSize 50 -nEpochs 15 -weightInit default -netType cnn3st -globalNorm -localNorm -cNormConv -showFullOutput | ||
``` | ||
## Acknowledgements | ||
The source code of this project is mainly based on [fb.resnet.torch](https://github.com/facebook/fb.resnet.torch) and [gtsrb.torch](https://github.com/Moodstocks/gtsrb.torch). |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
-- | ||
-- Copyright (c) 2016, Facebook, Inc. | ||
-- All rights reserved. | ||
-- | ||
-- This source code is licensed under the BSD-style license found in the | ||
-- LICENSE file in the root directory of this source tree. An additional grant | ||
-- of patent rights can be found in the PATENTS file in the same directory. | ||
-- | ||
local checkpoint = {} | ||
|
||
local function deepCopy(tbl)
   -- Recursively clones a table. Every value whose Lua type is 'table' is
   -- cloned in turn; every other value is carried over by reference, so the
   -- clone is a new table structure pointing at the same non-table values.
   -- If torch reports a type name for `tbl`, the clone is tagged with it.
   local clone = {}
   for key, value in pairs(tbl) do
      -- deepCopy always returns a table, so the and/or form is safe here
      clone[key] = (type(value) == 'table') and deepCopy(value) or value
   end

   local className = torch.typename(tbl)
   if className then
      torch.setmetatable(clone, className)
   end
   return clone
end
|
||
function checkpoint.latest(opt)
   -- Loads the most recent checkpoint recorded in the directory opt.resume.
   -- Returns nil when opt.resume is 'none' or when no latest.t7 exists there;
   -- otherwise returns the latest.t7 table followed by the optimizer state
   -- loaded from the file that table names in its `optimFile` field.
   local resumeDir = opt.resume
   if resumeDir == 'none' then
      return nil
   end

   local latestPath = paths.concat(resumeDir, 'latest.t7')
   if paths.filep(latestPath) then
      print('=> Loading checkpoint ' .. latestPath)
      local latest = torch.load(latestPath)
      local optimPath = paths.concat(resumeDir, latest.optimFile)
      local optimState = torch.load(optimPath)
      return latest, optimState
   end

   return nil
end
|
||
function checkpoint.save(epoch, model, optimState, isBestModel, opt)
   -- Writes the model and optimizer state for `epoch` into the directory
   -- opt.save.
   --   epoch       : used in the per-epoch file names and stored in latest.t7
   --   model       : network to serialize; an nn.DataParallelTable is unwrapped
   --   optimState  : optimizer state, saved as-is next to the model
   --   isBestModel : when truthy, also writes model_best.t7 and
   --                 model_best_optimState.t7 and points latest.t7 at them
   --   opt         : options table; reads opt.checkpoint (the string 'true'
   --                 enables per-epoch snapshots) and opt.save
   -- Returns nothing.
   local saveEveryEpoch = opt.checkpoint == 'true'

   -- Nothing would be written for this epoch, so skip the copy of the whole
   -- network (and its conversion to float) that the code below performs.
   if not saveEveryEpoch and not isBestModel then
      return
   end

   -- don't save the DataParallelTable for easier loading on other machines
   if torch.type(model) == 'nn.DataParallelTable' then
      model = model:get(1)
   end

   -- create a clean float copy for serialization
   model = deepCopy(model):float():clearState()

   if saveEveryEpoch then
      local modelFile = 'model_' .. epoch .. '.t7'
      local optimFile = 'optimState_' .. epoch .. '.t7'
      torch.save(paths.concat(opt.save, modelFile), model)
      torch.save(paths.concat(opt.save, optimFile), optimState)
      torch.save(paths.concat(opt.save, 'latest.t7'), {
         epoch = epoch,
         modelFile = modelFile,
         optimFile = optimFile,
      })
   end

   if isBestModel then
      local bestModelFile = 'model_best.t7'
      local bestOptimFile = 'model_best_optimState.t7'
      torch.save(paths.concat(opt.save, bestModelFile), model)
      torch.save(paths.concat(opt.save, bestOptimFile), optimState)
      -- Written last, so when both branches run latest.t7 refers to the
      -- best-model files.
      torch.save(paths.concat(opt.save, 'latest.t7'), {
         epoch = epoch,
         modelFile = bestModelFile,
         optimFile = bestOptimFile,
      })
   end
end
|
||
return checkpoint |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
-- | ||
-- Copyright (c) 2016, Facebook, Inc. | ||
-- All rights reserved. | ||
-- | ||
-- This source code is licensed under the BSD-style license found in the | ||
-- LICENSE file in the root directory of this source tree. An additional grant | ||
-- of patent rights can be found in the PATENTS file in the same directory. | ||
-- | ||
-- Multi-threaded data loader | ||
-- | ||
|
||
local datasets = require 'datasets/init' | ||
local Threads = require 'threads' | ||
Threads.serialization('threads.sharedserialize') | ||
|
||
local M = {} | ||
local DataLoader = torch.class('DataLoader', M) | ||
|
||
function DataLoader.create(opt)
   -- Builds one loader per split and returns them as two values:
   -- the 'train' loader first, then the 'val' loader.
   local splits = { 'train', 'val' }
   local loaders = {}

   for position = 1, #splits do
      local name = splits[position]
      local dataset = datasets.create(opt, name)
      loaders[position] = M.DataLoader(dataset, opt, name)
   end

   return table.unpack(loaders)
end
|
||
function DataLoader:__init(dataset, opt, split)
   -- Creates the worker pool for `dataset` and records the batching settings.
   --   dataset : dataset object; :preprocess() and :size() are called on it
   --   opt     : options table; reads manualSeed, dataset, nThreads, tenCrop,
   --             batchSize and tensorType
   --   split   : split name; 'val' together with opt.tenCrop selects 10 crops
   local seed = opt.manualSeed

   -- First callback handed to Threads: loads the dataset module by name.
   local function loadDatasetModule()
      require('datasets/' .. opt.dataset)
   end

   -- Second callback handed to Threads: seeds the RNG (offset by the index it
   -- receives) unless the seed is 0, publishes the dataset and its
   -- preprocessing function through globals, and returns the dataset size.
   local function setupWorker(workerIdx)
      if seed ~= 0 then
         torch.manualSeed(seed + workerIdx)
      end
      torch.setnumthreads(1)
      _G.dataset = dataset
      _G.preprocess = dataset:preprocess()
      return dataset:size()
   end

   local pool, sizes = Threads(opt.nThreads, loadDatasetModule, setupWorker)

   local crops = 1
   if split == 'val' and opt.tenCrop then
      crops = 10
   end

   -- CUDA tensor type name -> CPU tensor constructor name; any other value
   -- falls back to FloatTensor.
   local cpuTypeFor = {
      ['torch.CudaHalfTensor'] = 'HalfTensor',
      ['torch.CudaDoubleTensor'] = 'DoubleTensor',
   }

   self.nCrops = crops
   self.threads = pool
   self.__size = sizes[1][1]
   self.batchSize = math.floor(opt.batchSize / crops)
   self.cpuType = cpuTypeFor[opt.tensorType] or 'FloatTensor'
end
|
||
function DataLoader:size()
   -- Number of batches in one pass over the data; rounding up counts a final
   -- partial batch.
   local batches = self.__size / self.batchSize
   return math.ceil(batches)
end
|
||
function DataLoader:run()
   -- Returns an iterator function for one pass over the dataset in a random
   -- order. Each call yields (batchNumber, sample), where sample is a table
   -- {input = ..., target = ...}; it yields nil once no jobs remain.
   local pool = self.threads
   local total, perBatch = self.__size, self.batchSize
   local order = torch.randperm(total)

   local cursor = 1     -- position in `order` where the next batch starts
   local latest = nil   -- sample delivered by the most recently finished job

   -- Job submitted to the pool. It reads only its arguments and the globals
   -- _G.dataset / _G.preprocess, never a local of the enclosing function.
   local function buildBatch(indices, nCrops, cpuType)
      local count = indices:size(1)
      local batch, imageSize
      local target = torch.IntTensor(count)
      for slot, dataIdx in ipairs(indices:totable()) do
         local item = _G.dataset:get(dataIdx)
         local input = _G.preprocess(item.input)
         if not batch then
            -- Allocate lazily: the image dimensions are taken from the first
            -- preprocessed input.
            imageSize = input:size():totable()
            if nCrops > 1 then
               table.remove(imageSize, 1)
            end
            batch = torch[cpuType](count, nCrops, table.unpack(imageSize))
         end
         batch[slot]:copy(input)
         target[slot] = item.target
      end
      collectgarbage()
      return {
         input = batch:view(count * nCrops, table.unpack(imageSize)),
         target = target,
      }
   end

   -- End callback: stores the job's result for the iterator to return.
   local function receive(result)
      latest = result
   end

   -- Submits jobs while data remains and the pool accepts more work.
   local function enqueue()
      while cursor <= total and pool:acceptsjob() do
         local remaining = total - cursor + 1
         local indices = order:narrow(1, cursor, math.min(perBatch, remaining))
         pool:addjob(buildBatch, receive, indices, self.nCrops, self.cpuType)
         cursor = cursor + perBatch
      end
   end

   local served = 0
   local function nextBatch()
      enqueue()
      if not pool:hasjob() then
         return nil
      end
      pool:dojob()
      if pool:haserror() then
         pool:synchronize()
      end
      enqueue()
      served = served + 1
      return served, latest
   end

   return nextBatch
end
|
||
return M.DataLoader |
Oops, something went wrong.