From f73ed4dfcb2c5271da1268c0fa14db9acfdabf4f Mon Sep 17 00:00:00 2001 From: Simon Layton Date: Wed, 4 Mar 2015 15:30:06 -0500 Subject: [PATCH] Update cuDNN layers to use v2 API --- include/caffe/common_layers.hpp | 4 +- include/caffe/neuron_layers.hpp | 12 ++-- include/caffe/util/cudnn.hpp | 18 +++--- include/caffe/vision_layers.hpp | 9 ++- src/caffe/layers/cudnn_conv_layer.cpp | 15 +++-- src/caffe/layers/cudnn_conv_layer.cu | 75 +++++++++++++++++++----- src/caffe/layers/cudnn_pooling_layer.cpp | 4 +- src/caffe/layers/cudnn_pooling_layer.cu | 16 +++-- src/caffe/layers/cudnn_relu_layer.cpp | 4 +- src/caffe/layers/cudnn_relu_layer.cu | 16 +++-- src/caffe/layers/cudnn_sigmoid_layer.cpp | 4 +- src/caffe/layers/cudnn_sigmoid_layer.cu | 16 +++-- src/caffe/layers/cudnn_softmax_layer.cpp | 4 +- src/caffe/layers/cudnn_softmax_layer.cu | 16 +++-- src/caffe/layers/cudnn_tanh_layer.cpp | 4 +- src/caffe/layers/cudnn_tanh_layer.cu | 16 +++-- 16 files changed, 162 insertions(+), 71 deletions(-) diff --git a/include/caffe/common_layers.hpp b/include/caffe/common_layers.hpp index c67822c3738..d8867519ae5 100644 --- a/include/caffe/common_layers.hpp +++ b/include/caffe/common_layers.hpp @@ -392,8 +392,8 @@ class CuDNNSoftmaxLayer : public SoftmaxLayer { bool handles_setup_; cudnnHandle_t handle_; - cudnnTensor4dDescriptor_t bottom_desc_; - cudnnTensor4dDescriptor_t top_desc_; + cudnnTensorDescriptor_t bottom_desc_; + cudnnTensorDescriptor_t top_desc_; }; #endif diff --git a/include/caffe/neuron_layers.hpp b/include/caffe/neuron_layers.hpp index 0c306fb41bf..6cefc5d9396 100644 --- a/include/caffe/neuron_layers.hpp +++ b/include/caffe/neuron_layers.hpp @@ -433,8 +433,8 @@ class CuDNNReLULayer : public ReLULayer { bool handles_setup_; cudnnHandle_t handle_; - cudnnTensor4dDescriptor_t bottom_desc_; - cudnnTensor4dDescriptor_t top_desc_; + cudnnTensorDescriptor_t bottom_desc_; + cudnnTensorDescriptor_t top_desc_; }; #endif @@ -516,8 +516,8 @@ class CuDNNSigmoidLayer : public SigmoidLayer { bool handles_setup_; cudnnHandle_t handle_; - cudnnTensor4dDescriptor_t bottom_desc_; - cudnnTensor4dDescriptor_t top_desc_; + cudnnTensorDescriptor_t bottom_desc_; + cudnnTensorDescriptor_t top_desc_; }; #endif @@ -601,8 +601,8 @@ class CuDNNTanHLayer : public TanHLayer { bool handles_setup_; cudnnHandle_t handle_; - cudnnTensor4dDescriptor_t bottom_desc_; - cudnnTensor4dDescriptor_t top_desc_; + cudnnTensorDescriptor_t bottom_desc_; + cudnnTensorDescriptor_t top_desc_; }; #endif diff --git a/include/caffe/util/cudnn.hpp b/include/caffe/util/cudnn.hpp index eaed7333df8..e3575ece050 100644 --- a/include/caffe/util/cudnn.hpp +++ b/include/caffe/util/cudnn.hpp @@ -57,12 +57,12 @@ template<> class dataType { }; template -inline void createTensor4dDesc(cudnnTensor4dDescriptor_t* desc) { - CUDNN_CHECK(cudnnCreateTensor4dDescriptor(desc)); +inline void createTensor4dDesc(cudnnTensorDescriptor_t* desc) { + CUDNN_CHECK(cudnnCreateTensorDescriptor(desc)); } template -inline void setTensor4dDesc(cudnnTensor4dDescriptor_t* desc, +inline void setTensor4dDesc(cudnnTensorDescriptor_t* desc, int n, int c, int h, int w, int stride_n, int stride_c, int stride_h, int stride_w) { CUDNN_CHECK(cudnnSetTensor4dDescriptorEx(*desc, dataType::type, @@ -70,7 +70,7 @@ inline void setTensor4dDesc(cudnnTensor4dDescriptor_t* desc, } template -inline void setTensor4dDesc(cudnnTensor4dDescriptor_t* desc, +inline void setTensor4dDesc(cudnnTensorDescriptor_t* desc, int n, int c, int h, int w) { const int stride_w = 1; const int stride_h = w * stride_w; @@ -84,7 +84,7 @@ template inline void createFilterDesc(cudnnFilterDescriptor_t* desc, int n, int c, int h, int w) { CUDNN_CHECK(cudnnCreateFilterDescriptor(desc)); - CUDNN_CHECK(cudnnSetFilterDescriptor(*desc, dataType::type, + CUDNN_CHECK(cudnnSetFilter4dDescriptor(*desc, dataType::type, n, c, h, w)); } @@ -95,9 +95,9 @@ inline void createConvolutionDesc(cudnnConvolutionDescriptor_t* conv) { template inline void setConvolutionDesc(cudnnConvolutionDescriptor_t* conv, - cudnnTensor4dDescriptor_t bottom, cudnnFilterDescriptor_t filter, + cudnnTensorDescriptor_t bottom, cudnnFilterDescriptor_t filter, int pad_h, int pad_w, int stride_h, int stride_w) { - CUDNN_CHECK(cudnnSetConvolutionDescriptor(*conv, bottom, filter, + CUDNN_CHECK(cudnnSetConvolution2dDescriptor(*conv, pad_h, pad_w, stride_h, stride_w, 1, 1, CUDNN_CROSS_CORRELATION)); } @@ -110,13 +110,13 @@ inline void createPoolingDesc(cudnnPoolingDescriptor_t* conv, *mode = CUDNN_POOLING_MAX; break; case PoolingParameter_PoolMethod_AVE: - *mode = CUDNN_POOLING_AVERAGE; + *mode = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING; break; default: LOG(FATAL) << "Unknown pooling method."; } CUDNN_CHECK(cudnnCreatePoolingDescriptor(conv)); - CUDNN_CHECK(cudnnSetPoolingDescriptor(*conv, *mode, h, w, + CUDNN_CHECK(cudnnSetPooling2dDescriptor(*conv, *mode, h, w, 0, 0, stride_h, stride_w)); } diff --git a/include/caffe/vision_layers.hpp b/include/caffe/vision_layers.hpp index 6cb507a5780..8700430efce 100644 --- a/include/caffe/vision_layers.hpp +++ b/include/caffe/vision_layers.hpp @@ -246,11 +246,14 @@ class CuDNNConvolutionLayer : public ConvolutionLayer { bool handles_setup_; cudnnHandle_t* handle_; cudaStream_t* stream_; - vector bottom_descs_, top_descs_; - cudnnTensor4dDescriptor_t bias_desc_; + vector bottom_descs_, top_descs_; + cudnnTensorDescriptor_t bias_desc_; cudnnFilterDescriptor_t filter_desc_; vector conv_descs_; int bottom_offset_, top_offset_, weight_offset_, bias_offset_; + + size_t workspaceSizeInBytes; + void *workspace; }; #endif @@ -445,7 +448,7 @@ class CuDNNPoolingLayer : public PoolingLayer { bool handles_setup_; cudnnHandle_t handle_; - cudnnTensor4dDescriptor_t bottom_desc_, top_desc_; + cudnnTensorDescriptor_t bottom_desc_, top_desc_; cudnnPoolingDescriptor_t pooling_desc_; cudnnPoolingMode_t mode_; }; diff --git a/src/caffe/layers/cudnn_conv_layer.cpp b/src/caffe/layers/cudnn_conv_layer.cpp index 4a69ca20d0a..7c8e8977c7e 100644 --- a/src/caffe/layers/cudnn_conv_layer.cpp +++ b/src/caffe/layers/cudnn_conv_layer.cpp @@ -25,6 +25,9 @@ void CuDNNConvolutionLayer::LayerSetUp( stream_ = new cudaStream_t[this->group_ * CUDNN_STREAMS_PER_GROUP]; handle_ = new cudnnHandle_t[this->group_ * CUDNN_STREAMS_PER_GROUP]; + workspace = NULL; + workspaceSizeInBytes = (size_t)0; + for (int g = 0; g < this->group_ * CUDNN_STREAMS_PER_GROUP; g++) { CUDA_CHECK(cudaStreamCreate(&stream_[g])); CUDNN_CHECK(cudnnCreate(&handle_[g])); @@ -43,10 +46,10 @@ void CuDNNConvolutionLayer::LayerSetUp( // Create tensor descriptor(s) for data and corresponding convolution(s). for (int i = 0; i < bottom.size(); i++) { - cudnnTensor4dDescriptor_t bottom_desc; + cudnnTensorDescriptor_t bottom_desc; cudnn::createTensor4dDesc(&bottom_desc); bottom_descs_.push_back(bottom_desc); - cudnnTensor4dDescriptor_t top_desc; + cudnnTensorDescriptor_t top_desc; cudnn::createTensor4dDesc(&top_desc); top_descs_.push_back(top_desc); cudnnConvolutionDescriptor_t conv_desc; @@ -104,12 +107,12 @@ CuDNNConvolutionLayer::~CuDNNConvolutionLayer() { if (!handles_setup_) { return; } for (int i = 0; i < bottom_descs_.size(); i++) { - cudnnDestroyTensor4dDescriptor(bottom_descs_[i]); - cudnnDestroyTensor4dDescriptor(top_descs_[i]); + cudnnDestroyTensorDescriptor(bottom_descs_[i]); + cudnnDestroyTensorDescriptor(top_descs_[i]); cudnnDestroyConvolutionDescriptor(conv_descs_[i]); } if (this->bias_term_) { - cudnnDestroyTensor4dDescriptor(bias_desc_); + cudnnDestroyTensorDescriptor(bias_desc_); } cudnnDestroyFilterDescriptor(filter_desc_); @@ -118,6 +121,8 @@ CuDNNConvolutionLayer::~CuDNNConvolutionLayer() { cudnnDestroy(handle_[g]); } + if (workspace) cudaFree(workspace); + delete [] stream_; delete [] handle_; } diff --git a/src/caffe/layers/cudnn_conv_layer.cu b/src/caffe/layers/cudnn_conv_layer.cu index 071014e1b48..ae9033fc68b 100644 --- a/src/caffe/layers/cudnn_conv_layer.cu +++ b/src/caffe/layers/cudnn_conv_layer.cu @@ -21,20 +21,57 @@ void CuDNNConvolutionLayer::Forward_gpu( // Forward through cuDNN in parallel over groups. for (int g = 0; g < this->group_; g++) { + const Dtype alpha = 1.0; + const Dtype beta = 0.0; + + cudnnConvolutionFwdAlgo_t algo; + + CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm(handle_[g], + bottom_descs_[i], + filter_desc_, + conv_descs_[i], + top_descs_[i], + CUDNN_CONVOLUTION_FWD_PREFER_FASTEST, + 0, // memoryLimitInBytes, + &algo)); + + size_t workspaceSizeInBytes_temp = 0; + + CUDNN_CHECK(cudnnGetConvolutionForwardWorkspaceSize(handle_[g], + bottom_descs_[i], + filter_desc_, + conv_descs_[i], + top_descs_[i], + algo, + &workspaceSizeInBytes_temp)); + + if (workspaceSizeInBytes_temp > workspaceSizeInBytes) { + workspaceSizeInBytes = workspaceSizeInBytes_temp; + // free the existing workspace and allocate a new (larger) one + if (this->workspace != NULL) { + cudaFree(this->workspace); + } + cudaMalloc(&(this->workspace), workspaceSizeInBytes); + CUDA_POST_KERNEL_CHECK; + } + // Filters. - CUDNN_CHECK(cudnnConvolutionForward(handle_[g], + CUDNN_CHECK(cudnnConvolutionForward(handle_[g], (void *)(&alpha), bottom_descs_[i], bottom_data + bottom_offset_ * g, filter_desc_, weight + weight_offset_ * g, conv_descs_[i], - top_descs_[i], top_data + top_offset_ * g, - CUDNN_RESULT_NO_ACCUMULATE)); + algo, workspace, workspaceSizeInBytes, + (void *)(&beta), + top_descs_[i], top_data + top_offset_ * g) ); // Bias. if (this->bias_term_) { const Dtype* bias_data = this->blobs_[1]->gpu_data(); - Dtype alpha = 1.; - CUDNN_CHECK(cudnnAddTensor4d(handle_[g], CUDNN_ADD_SAME_C, &alpha, + Dtype alpha = 1.0; + Dtype beta = 1.0; + CUDNN_CHECK(cudnnAddTensor(handle_[g], CUDNN_ADD_SAME_C, (void *)(&alpha), bias_desc_, bias_data + bias_offset_ * g, + (void *)(&beta), top_descs_[i], top_data + top_offset_ * g)); } } @@ -66,36 +103,42 @@ void CuDNNConvolutionLayer::Backward_gpu(const vector*>& top, // Backward through cuDNN in parallel over groups and gradients. for (int g = 0; g < this->group_; g++) { // Gradient w.r.t. bias. + if (this->bias_term_ && this->param_propagate_down_[1]) { + Dtype alpha = 1.0; + Dtype beta = 1.0; CUDNN_CHECK(cudnnConvolutionBackwardBias(handle_[0*this->group_ + g], + (void *)(&alpha), top_descs_[i], top_diff + top_offset_ * g, - bias_desc_, bias_diff + bias_offset_ * g, - CUDNN_RESULT_ACCUMULATE)); + (void *)(&beta), + bias_desc_, bias_diff + bias_offset_ * g) ); } // Gradient w.r.t. weights. if (this->param_propagate_down_[0]) { + Dtype alpha = 1.0; + Dtype beta = 1.0; const Dtype* bottom_data = bottom[i]->gpu_data(); CUDNN_CHECK(cudnnConvolutionBackwardFilter(handle_[1*this->group_ + g], + (void *)(&alpha), bottom_descs_[i], bottom_data + bottom_offset_ * g, top_descs_[i], top_diff + top_offset_ * g, - conv_descs_[i], - filter_desc_, weight_diff + weight_offset_ * g, - CUDNN_RESULT_ACCUMULATE)); + conv_descs_[i], (void *)(&beta), + filter_desc_, weight_diff + weight_offset_ * g) ); } // Gradient w.r.t. bottom data. if (propagate_down[i]) { - if (weight == NULL) { - weight = this->blobs_[0]->gpu_data(); - } + Dtype alpha = 1.0; + Dtype beta = 0.0; + Dtype* bottom_diff = bottom[i]->mutable_gpu_diff(); CUDNN_CHECK(cudnnConvolutionBackwardData(handle_[2*this->group_ + g], + (void *)(&alpha), filter_desc_, weight + weight_offset_ * g, top_descs_[i], top_diff + top_offset_ * g, - conv_descs_[i], - bottom_descs_[i], bottom_diff + bottom_offset_ * g, - CUDNN_RESULT_NO_ACCUMULATE)); + conv_descs_[i], (void *)(&beta), + bottom_descs_[i], bottom_diff + bottom_offset_ * g) ); } } diff --git a/src/caffe/layers/cudnn_pooling_layer.cpp b/src/caffe/layers/cudnn_pooling_layer.cpp index dd90195637b..b447f19b426 100644 --- a/src/caffe/layers/cudnn_pooling_layer.cpp +++ b/src/caffe/layers/cudnn_pooling_layer.cpp @@ -40,8 +40,8 @@ CuDNNPoolingLayer::~CuDNNPoolingLayer() { // Check that handles have been setup before destroying. if (!handles_setup_) { return; } - cudnnDestroyTensor4dDescriptor(bottom_desc_); - cudnnDestroyTensor4dDescriptor(top_desc_); + cudnnDestroyTensorDescriptor(bottom_desc_); + cudnnDestroyTensorDescriptor(top_desc_); cudnnDestroyPoolingDescriptor(pooling_desc_); cudnnDestroy(handle_); } diff --git a/src/caffe/layers/cudnn_pooling_layer.cu b/src/caffe/layers/cudnn_pooling_layer.cu index 1c113aad75f..8fd0458c023 100644 --- a/src/caffe/layers/cudnn_pooling_layer.cu +++ b/src/caffe/layers/cudnn_pooling_layer.cu @@ -14,8 +14,12 @@ void CuDNNPoolingLayer::Forward_gpu(const vector*>& bottom, const vector*>& top) { const Dtype* bottom_data = bottom[0]->gpu_data(); Dtype* top_data = top[0]->mutable_gpu_data(); - CUDNN_CHECK(cudnnPoolingForward(handle_, pooling_desc_, - bottom_desc_, bottom_data, top_desc_, top_data)); + + Dtype alpha = 1.0; + Dtype beta = 0.0; + + CUDNN_CHECK(cudnnPoolingForward(handle_, pooling_desc_, (void *)(&alpha), + bottom_desc_, bottom_data, (void *)(&beta), top_desc_, top_data)); } template @@ -28,9 +32,13 @@ void CuDNNPoolingLayer::Backward_gpu(const vector*>& top, const Dtype* top_data = top[0]->gpu_data(); const Dtype* bottom_data = bottom[0]->gpu_data(); Dtype* bottom_diff = bottom[0]->mutable_gpu_diff(); - CUDNN_CHECK(cudnnPoolingBackward(handle_, pooling_desc_, + + Dtype alpha = 1.0; + Dtype beta = 0.0; + + CUDNN_CHECK(cudnnPoolingBackward(handle_, pooling_desc_, (void *)(&alpha), top_desc_, top_data, top_desc_, top_diff, - bottom_desc_, bottom_data, bottom_desc_, bottom_diff)); + bottom_desc_, bottom_data, (void *)(&beta), bottom_desc_, bottom_diff)); } INSTANTIATE_LAYER_GPU_FUNCS(CuDNNPoolingLayer); diff --git a/src/caffe/layers/cudnn_relu_layer.cpp b/src/caffe/layers/cudnn_relu_layer.cpp index 0b8a6bc3248..759d83984ef 100644 --- a/src/caffe/layers/cudnn_relu_layer.cpp +++ b/src/caffe/layers/cudnn_relu_layer.cpp @@ -35,8 +35,8 @@ CuDNNReLULayer::~CuDNNReLULayer() { // Check that handles have been setup before destroying. if (!handles_setup_) { return; } - cudnnDestroyTensor4dDescriptor(this->bottom_desc_); - cudnnDestroyTensor4dDescriptor(this->top_desc_); + cudnnDestroyTensorDescriptor(this->bottom_desc_); + cudnnDestroyTensorDescriptor(this->top_desc_); cudnnDestroy(this->handle_); } diff --git a/src/caffe/layers/cudnn_relu_layer.cu b/src/caffe/layers/cudnn_relu_layer.cu index 862508707a0..047c8d6c6bf 100644 --- a/src/caffe/layers/cudnn_relu_layer.cu +++ b/src/caffe/layers/cudnn_relu_layer.cu @@ -17,9 +17,13 @@ void CuDNNReLULayer::Forward_gpu(const vector*>& bottom, const Dtype* bottom_data = bottom[0]->gpu_data(); Dtype* top_data = top[0]->mutable_gpu_data(); + + Dtype alpha = 1.0; + Dtype beta = 0.0; + CUDNN_CHECK(cudnnActivationForward(this->handle_, - CUDNN_ACTIVATION_RELU, - this->bottom_desc_, bottom_data, this->top_desc_, top_data)); + CUDNN_ACTIVATION_RELU, (void *)(&alpha), + this->bottom_desc_, bottom_data, (void *)(&beta), this->top_desc_, top_data)); } template @@ -39,10 +43,14 @@ void CuDNNReLULayer::Backward_gpu(const vector*>& top, const Dtype* top_diff = top[0]->gpu_diff(); const Dtype* bottom_data = bottom[0]->gpu_data(); Dtype* bottom_diff = bottom[0]->mutable_gpu_diff(); + + Dtype alpha = 1.0; + Dtype beta = 0.0; + CUDNN_CHECK(cudnnActivationBackward(this->handle_, - CUDNN_ACTIVATION_RELU, + CUDNN_ACTIVATION_RELU, (void *)(&alpha), this->top_desc_, top_data, this->top_desc_, top_diff, - this->bottom_desc_, bottom_data, this->bottom_desc_, bottom_diff)); + this->bottom_desc_, bottom_data, (void *)(&beta), this->bottom_desc_, bottom_diff)); } INSTANTIATE_LAYER_GPU_FUNCS(CuDNNReLULayer); diff --git a/src/caffe/layers/cudnn_sigmoid_layer.cpp b/src/caffe/layers/cudnn_sigmoid_layer.cpp index 67bd9c373b0..32637873d46 100644 --- a/src/caffe/layers/cudnn_sigmoid_layer.cpp +++ b/src/caffe/layers/cudnn_sigmoid_layer.cpp @@ -35,8 +35,8 @@ CuDNNSigmoidLayer::~CuDNNSigmoidLayer() { // Check that handles have been setup before destroying. if (!handles_setup_) { return; } - cudnnDestroyTensor4dDescriptor(this->bottom_desc_); - cudnnDestroyTensor4dDescriptor(this->top_desc_); + cudnnDestroyTensorDescriptor(this->bottom_desc_); + cudnnDestroyTensorDescriptor(this->top_desc_); cudnnDestroy(this->handle_); } diff --git a/src/caffe/layers/cudnn_sigmoid_layer.cu b/src/caffe/layers/cudnn_sigmoid_layer.cu index 31b094e25d4..f3bbd067584 100644 --- a/src/caffe/layers/cudnn_sigmoid_layer.cu +++ b/src/caffe/layers/cudnn_sigmoid_layer.cu @@ -12,9 +12,13 @@ void CuDNNSigmoidLayer::Forward_gpu(const vector*>& bottom, const vector*>& top) { const Dtype* bottom_data = bottom[0]->gpu_data(); Dtype* top_data = top[0]->mutable_gpu_data(); + + Dtype alpha = 1.0; + Dtype beta = 0.0; + CUDNN_CHECK(cudnnActivationForward(this->handle_, - CUDNN_ACTIVATION_SIGMOID, - this->bottom_desc_, bottom_data, this->top_desc_, top_data)); + CUDNN_ACTIVATION_SIGMOID, (void *)(&alpha), + this->bottom_desc_, bottom_data, (void *)(&beta), this->top_desc_, top_data)); } template @@ -29,10 +33,14 @@ void CuDNNSigmoidLayer::Backward_gpu(const vector*>& top, const Dtype* top_diff = top[0]->gpu_diff(); const Dtype* bottom_data = bottom[0]->gpu_data(); Dtype* bottom_diff = bottom[0]->mutable_gpu_diff(); + + Dtype alpha = 1.0; + Dtype beta = 0.0; + CUDNN_CHECK(cudnnActivationBackward(this->handle_, - CUDNN_ACTIVATION_SIGMOID, + CUDNN_ACTIVATION_SIGMOID, (void *)(&alpha), this->top_desc_, top_data, this->top_desc_, top_diff, - this->bottom_desc_, bottom_data, this->bottom_desc_, bottom_diff)); + this->bottom_desc_, bottom_data, (void *)(&beta), this->bottom_desc_, bottom_diff)); } INSTANTIATE_LAYER_GPU_FUNCS(CuDNNSigmoidLayer); diff --git a/src/caffe/layers/cudnn_softmax_layer.cpp b/src/caffe/layers/cudnn_softmax_layer.cpp index 83a5b69a626..45346bc5687 100644 --- a/src/caffe/layers/cudnn_softmax_layer.cpp +++ b/src/caffe/layers/cudnn_softmax_layer.cpp @@ -39,8 +39,8 @@ CuDNNSoftmaxLayer::~CuDNNSoftmaxLayer() { // Check that handles have been setup before destroying. if (!handles_setup_) { return; } - cudnnDestroyTensor4dDescriptor(bottom_desc_); - cudnnDestroyTensor4dDescriptor(top_desc_); + cudnnDestroyTensorDescriptor(bottom_desc_); + cudnnDestroyTensorDescriptor(top_desc_); cudnnDestroy(handle_); } diff --git a/src/caffe/layers/cudnn_softmax_layer.cu b/src/caffe/layers/cudnn_softmax_layer.cu index f328afdd831..b08f4058552 100644 --- a/src/caffe/layers/cudnn_softmax_layer.cu +++ b/src/caffe/layers/cudnn_softmax_layer.cu @@ -16,9 +16,13 @@ void CuDNNSoftmaxLayer::Forward_gpu(const vector*>& bottom, const vector*>& top) { const Dtype* bottom_data = bottom[0]->gpu_data(); Dtype* top_data = top[0]->mutable_gpu_data(); + + Dtype alpha = 1.0; + Dtype beta = 0.0; + CUDNN_CHECK(cudnnSoftmaxForward(handle_, CUDNN_SOFTMAX_ACCURATE, - CUDNN_SOFTMAX_MODE_CHANNEL, - bottom_desc_, bottom_data, top_desc_, top_data)); + CUDNN_SOFTMAX_MODE_CHANNEL, (void *)(&alpha), + bottom_desc_, bottom_data, (void *)(&beta), top_desc_, top_data)); } template @@ -29,9 +33,13 @@ void CuDNNSoftmaxLayer::Backward_gpu(const vector*>& top, const Dtype* top_diff = top[0]->gpu_diff(); const Dtype* bottom_data = bottom[0]->gpu_data(); Dtype* bottom_diff = bottom[0]->mutable_gpu_diff(); + + Dtype alpha = 1.0; + Dtype beta = 0.0; + CUDNN_CHECK(cudnnSoftmaxBackward(handle_, CUDNN_SOFTMAX_ACCURATE, - CUDNN_SOFTMAX_MODE_CHANNEL, - top_desc_, top_data, top_desc_, top_diff, bottom_desc_, bottom_diff)); + CUDNN_SOFTMAX_MODE_CHANNEL, (void *)(&alpha), + top_desc_, top_data, top_desc_, top_diff, (void *)(&beta), bottom_desc_, bottom_diff)); } } diff --git a/src/caffe/layers/cudnn_tanh_layer.cpp b/src/caffe/layers/cudnn_tanh_layer.cpp index b1d2b86384e..376faad324d 100644 --- a/src/caffe/layers/cudnn_tanh_layer.cpp +++ b/src/caffe/layers/cudnn_tanh_layer.cpp @@ -35,8 +35,8 @@ CuDNNTanHLayer::~CuDNNTanHLayer() { // Check that handles have been setup before destroying. if (!handles_setup_) { return; } - cudnnDestroyTensor4dDescriptor(this->bottom_desc_); - cudnnDestroyTensor4dDescriptor(this->top_desc_); + cudnnDestroyTensorDescriptor(this->bottom_desc_); + cudnnDestroyTensorDescriptor(this->top_desc_); cudnnDestroy(this->handle_); } diff --git a/src/caffe/layers/cudnn_tanh_layer.cu b/src/caffe/layers/cudnn_tanh_layer.cu index bf9ec7cfac4..c70c064e81a 100644 --- a/src/caffe/layers/cudnn_tanh_layer.cu +++ b/src/caffe/layers/cudnn_tanh_layer.cu @@ -12,9 +12,13 @@ void CuDNNTanHLayer::Forward_gpu(const vector*>& bottom, const vector*>& top) { const Dtype* bottom_data = bottom[0]->gpu_data(); Dtype* top_data = top[0]->mutable_gpu_data(); + + Dtype alpha = 1.0; + Dtype beta = 0.0; + CUDNN_CHECK(cudnnActivationForward(this->handle_, - CUDNN_ACTIVATION_TANH, - this->bottom_desc_, bottom_data, this->top_desc_, top_data)); + CUDNN_ACTIVATION_TANH, (void *)(&alpha), + this->bottom_desc_, bottom_data, (void *)(&beta), this->top_desc_, top_data)); } template @@ -29,10 +33,14 @@ void CuDNNTanHLayer::Backward_gpu(const vector*>& top, const Dtype* top_diff = top[0]->gpu_diff(); const Dtype* bottom_data = bottom[0]->gpu_data(); Dtype* bottom_diff = bottom[0]->mutable_gpu_diff(); + + Dtype alpha = 1.0; + Dtype beta = 0.0; + CUDNN_CHECK(cudnnActivationBackward(this->handle_, - CUDNN_ACTIVATION_TANH, + CUDNN_ACTIVATION_TANH, (void *)(&alpha), this->top_desc_, top_data, this->top_desc_, top_diff, - this->bottom_desc_, bottom_data, this->bottom_desc_, bottom_diff)); + this->bottom_desc_, bottom_data, (void *)(&beta), this->bottom_desc_, bottom_diff)); } INSTANTIATE_LAYER_GPU_FUNCS(CuDNNTanHLayer);