From 3ec956b5b1023cbd8602bb3a6891d65083360df2 Mon Sep 17 00:00:00 2001
From: Sergey Zagoruyko
Date: Sun, 15 May 2016 13:13:31 +0200
Subject: [PATCH] add dropout

---
 Dropout.lua   | 94 +++++++++++++++++++++++++++++++++++++++++++++++++++
 convert.lua   |  1 +
 init.lua      |  1 +
 test/test.lua | 17 ++++++++++
 4 files changed, 113 insertions(+)
 create mode 100644 Dropout.lua

diff --git a/Dropout.lua b/Dropout.lua
new file mode 100644
index 0000000..4066a5c
--- /dev/null
+++ b/Dropout.lua
@@ -0,0 +1,94 @@
+local Dropout, parent = torch.class('cudnn.Dropout','nn.Dropout')
+
+local errcheck = cudnn.errcheck
+local ffi = require 'ffi'
+
+local function getSize(f, desc)
+   local size = ffi.new'size_t[1]'
+   errcheck(f, desc, size)
+   return tonumber(size[0])
+end
+
+function Dropout:createIODescriptors(input)
+   assert(input:isContiguous(), 'Non-contiguous inputs not supported yet')
+   if not self.inplace then
+      self.output:resizeAs(input)
+   end
+
+   local nElem = input:nElement()
+   self.nElem = self.nElem or nElem -- self.nElem is nil only before the first call
+   if self.iDesc and nElem == self.nElem then return end
+   self.nElem = nElem
+   self.iDesc = cudnn.toDescriptor(input:view(1,1,1,nElem))
+
+   -- initialize RNG states for dropout lazily (one state buffer per device)
+   cudnn.dropout_rng_states = cudnn.dropout_rng_states or {}
+   local dev = cutorch.getDevice()
+   if not cudnn.dropout_rng_states[dev] then
+      local states_size = getSize('cudnnDropoutGetStatesSize', cudnn.getHandle())
+      cudnn.dropout_rng_states[dev] = torch.CudaByteTensor(states_size)
+   end
+
+   if not self.dropDesc then
+      self.dropDesc = ffi.new('struct cudnnDropoutStruct*[1]')
+      errcheck('cudnnCreateDropoutDescriptor', self.dropDesc)
+      local reserves_size = getSize('cudnnDropoutGetReserveSpaceSize', self.iDesc[0])
+      self.reserves = self.reserves or torch.CudaByteTensor()
+      self.reserves = self.reserves:cudaByte():resize(reserves_size)
+      local state = cudnn.dropout_rng_states[dev]
+      errcheck('cudnnSetDropoutDescriptor', self.dropDesc[0],
+               cudnn.getHandle(), self.p,
+               state:data(), state:nElement(), torch.seed())
+
+      local function destroyADesc(a)
+         if (a[0]) then
+            errcheck('cudnnDestroyDropoutDescriptor', a[0])
+            a[0] = nil
+         end
+      end
+      ffi.gc(self.dropDesc, destroyADesc)
+   end
+end
+
+function Dropout:updateOutput(input)
+   assert(self.v2)
+   if self.inplace then
+      self.output:set(input)
+   else
+      self.output:resizeAs(input)
+   end
+   self:createIODescriptors(input)
+   local train = self.p > 0 and self.train
+   if train then
+      errcheck('cudnnDropoutForward', cudnn.getHandle(),
+               self.dropDesc[0],
+               self.iDesc[0], input:data(),
+               self.iDesc[0], self.output:data(),
+               self.reserves:data(),
+               self.reserves:nElement())
+   elseif not self.inplace then
+      self.output:copy(input)
+   end
+   return self.output
+end
+
+function Dropout:updateGradInput(input, gradOutput)
+   assert(self.train)
+   if self.inplace then
+      self.gradInput:set(gradOutput)
+   else
+      self.gradInput:resizeAs(gradOutput)
+   end
+   if self.p > 0 then
+      errcheck('cudnnDropoutBackward', cudnn.getHandle(),
+               self.dropDesc[0],
+               self.iDesc[0], gradOutput:data(),
+               self.iDesc[0], self.gradInput:data(),
+               self.reserves:data(),
+               self.reserves:nElement())
+   elseif not self.inplace then
+      self.gradInput:copy(gradOutput)
+   end
+   return self.gradInput
+end
+
diff --git a/convert.lua b/convert.lua
index 638928b..c08a9b8 100644
--- a/convert.lua
+++ b/convert.lua
@@ -1,6 +1,7 @@
 -- modules that can be converted to nn seamlessly
 local layer_list = {
   'BatchNormalization',
+  'Dropout',
   'SpatialBatchNormalization',
   'SpatialConvolution',
  'SpatialCrossMapLRN',
diff --git a/init.lua b/init.lua
index 318570b..420b8fd 100644
--- a/init.lua
+++ b/init.lua
@@ -128,6 +128,7 @@ require('cudnn.RNNReLU')
 require('cudnn.BLSTM')
 require('cudnn.LSTM')
 require('cudnn.GRU')
+require('cudnn.Dropout')
 require('cudnn.functional')
 require('cudnn.convert')
 
diff --git a/test/test.lua b/test/test.lua
index ba1b5a1..3dbe8cd 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -1234,6 +1234,23 @@ function cudnntest.SpatialLogSoftMax()
    mytester:assertlt(err, precision_backward, 'error in difference between central difference and :backward')
 end
 
+function cudnntest.Dropout()
+   local p = 0.2 -- probability of dropping out a neuron
+   local input = torch.Tensor(1000):fill(1 - p):cuda()
+   local module = cudnn.Dropout(p):cuda()
+   -- v2 dropout: outputs are scaled by 1/(1-p) at training time
+   local output = module:forward(input)
+   mytester:assert(math.abs(output:mean() - (1-p)) < 0.05, 'dropout output')
+   local gradInput = module:backward(input, input)
+   mytester:assert(math.abs(gradInput:mean() - (1-p)) < 0.05, 'dropout gradInput')
+   -- test in-place version
+   local module = cudnn.Dropout(p, nil, true):cuda()
+   local output = module:forward(input:clone())
+   mytester:assert(math.abs(output:mean() - (1-p)) < 0.05, 'dropout output')
+   local gradInput = module:backward(input:clone(), input:clone())
+   mytester:assert(math.abs(gradInput:mean() - (1-p)) < 0.05, 'dropout gradInput')
+end
+
 local function testBatchNormalization(moduleName, inputSize)
    local input = torch.randn(table.unpack(inputSize)):cuda()
    local gradOutput = torch.randn(table.unpack(inputSize)):cuda()
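
A minimal usage sketch (illustrative, not part of the diff above), assuming cutorch/cunn/cudnn are installed; the model layout and tensor sizes are arbitrary. Because 'Dropout' is now in convert.lua's layer_list, cudnn.convert() swaps nn.Dropout for cudnn.Dropout (and back when converting to the nn backend):

require 'cunn'
require 'cudnn'

-- plain nn model; nn.Dropout defaults to v2 (scale by 1/(1-p) at train time)
local net = nn.Sequential()
   :add(nn.Linear(100, 100))
   :add(nn.ReLU(true))
   :add(nn.Dropout(0.5))
   :cuda()

-- convert layers listed in convert.lua to their cudnn counterparts
net = cudnn.convert(net, cudnn)

net:training()
local y_train = net:forward(torch.CudaTensor(32, 100):normal())

net:evaluate()  -- v2 dropout is the identity at test time
local y_test = net:forward(torch.CudaTensor(32, 100):normal())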