Merge pull request BVLC#2511 from flx42/fix_illegal_mode_changes
Fix invalid mode changes during tests
shelhamer committed May 30, 2015
2 parents aeef453 + 68133e7 commit 3cc9bac
Showing 12 changed files with 80 additions and 90 deletions.
28 changes: 15 additions & 13 deletions include/caffe/test/test_caffe_main.hpp
@@ -40,34 +40,36 @@ class MultiDeviceTest : public ::testing::Test {

 typedef ::testing::Types<float, double> TestDtypes;

-struct FloatCPU {
-  typedef float Dtype;
+template <typename TypeParam>
+struct CPUDevice {
+  typedef TypeParam Dtype;
   static const Caffe::Brew device = Caffe::CPU;
 };

-struct DoubleCPU {
-  typedef double Dtype;
-  static const Caffe::Brew device = Caffe::CPU;
+template <typename Dtype>
+class CPUDeviceTest : public MultiDeviceTest<CPUDevice<Dtype> > {
 };

 #ifdef CPU_ONLY

-typedef ::testing::Types<FloatCPU, DoubleCPU> TestDtypesAndDevices;
+typedef ::testing::Types<CPUDevice<float>,
+                         CPUDevice<double> > TestDtypesAndDevices;

 #else

-struct FloatGPU {
-  typedef float Dtype;
+template <typename TypeParam>
+struct GPUDevice {
+  typedef TypeParam Dtype;
   static const Caffe::Brew device = Caffe::GPU;
 };

-struct DoubleGPU {
-  typedef double Dtype;
-  static const Caffe::Brew device = Caffe::GPU;
+template <typename Dtype>
+class GPUDeviceTest : public MultiDeviceTest<GPUDevice<Dtype> > {
 };

-typedef ::testing::Types<FloatCPU, DoubleCPU, FloatGPU, DoubleGPU>
-    TestDtypesAndDevices;
+typedef ::testing::Types<CPUDevice<float>, CPUDevice<double>,
+                         GPUDevice<float>, GPUDevice<double> >
+                         TestDtypesAndDevices;

 #endif
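For context: the MultiDeviceTest base named in the hunk header above sits earlier in test_caffe_main.hpp and is not part of this diff. A rough sketch reconstructed from the surrounding definitions (not from this commit) shows why the per-test Caffe::set_mode() calls in the files below can be deleted — the fixture pins the mode once, before any test body runs:

    // Sketch of the pre-existing base fixture (not shown in this diff):
    // the constructor sets the global mode from the device carried by the
    // CPUDevice<Dtype>/GPUDevice<Dtype> type parameter.
    template <typename TypeParam>
    class MultiDeviceTest : public ::testing::Test {
     public:
      typedef typename TypeParam::Dtype Dtype;
     protected:
      MultiDeviceTest() {
        Caffe::set_mode(TypeParam::device);
      }
      virtual ~MultiDeviceTest() {}
    };

With this in place, a test fixture only has to pick its base (CPUDeviceTest, GPUDeviceTest, or MultiDeviceTest itself) and the correct mode is guaranteed for every test in the suite.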
5 changes: 1 addition & 4 deletions src/caffe/test/test_accuracy_layer.cpp
@@ -16,7 +16,7 @@
 namespace caffe {

 template <typename Dtype>
-class AccuracyLayerTest : public ::testing::Test {
+class AccuracyLayerTest : public CPUDeviceTest<Dtype> {
  protected:
   AccuracyLayerTest()
       : blob_bottom_data_(new Blob<Dtype>()),
@@ -92,7 +92,6 @@ TYPED_TEST(AccuracyLayerTest, TestSetupTopK) {

 TYPED_TEST(AccuracyLayerTest, TestForwardCPU) {
   LayerParameter layer_param;
-  Caffe::set_mode(Caffe::CPU);
   AccuracyLayer<TypeParam> layer(layer_param);
   layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
   layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
@@ -118,7 +117,6 @@ TYPED_TEST(AccuracyLayerTest, TestForwardCPU) {
 }

 TYPED_TEST(AccuracyLayerTest, TestForwardWithSpatialAxes) {
-  Caffe::set_mode(Caffe::CPU);
   this->blob_bottom_data_->Reshape(2, 10, 4, 5);
   vector<int> label_shape(3);
   label_shape[0] = 2; label_shape[1] = 4; label_shape[2] = 5;
@@ -162,7 +160,6 @@ TYPED_TEST(AccuracyLayerTest, TestForwardWithSpatialAxes) {
 }

 TYPED_TEST(AccuracyLayerTest, TestForwardIgnoreLabel) {
-  Caffe::set_mode(Caffe::CPU);
   LayerParameter layer_param;
   const TypeParam kIgnoreLabelValue = -1;
   layer_param.mutable_accuracy_param()->set_ignore_label(kIgnoreLabelValue);
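The conversion pattern is the same in each file below: derive the fixture from CPUDeviceTest<Dtype> or GPUDeviceTest<Dtype> instead of ::testing::Test, then drop the Caffe::set_mode() call from every test body. A minimal sketch with a hypothetical ExampleLayerTest (illustrative only, not part of this commit):

    // Hypothetical fixture following the pattern this commit introduces:
    // deriving from CPUDeviceTest<Dtype> pins the CPU mode in the base
    // constructor, so no test body calls Caffe::set_mode() itself and no
    // test can leak a mode change into another.
    template <typename Dtype>
    class ExampleLayerTest : public CPUDeviceTest<Dtype> {
     protected:
      ExampleLayerTest() : blob_bottom_(new Blob<Dtype>(2, 3, 4, 5)) {}
      virtual ~ExampleLayerTest() { delete blob_bottom_; }
      Blob<Dtype>* const blob_bottom_;
    };

    TYPED_TEST_CASE(ExampleLayerTest, TestDtypes);

    TYPED_TEST(ExampleLayerTest, RunsInCPUMode) {
      // TypeParam is the plain scalar type (float or double) here,
      // because CPUDeviceTest is parameterized on Dtype directly.
      EXPECT_EQ(Caffe::CPU, Caffe::mode());
      EXPECT_EQ(2, this->blob_bottom_->num());
    }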
3 changes: 1 addition & 2 deletions src/caffe/test/test_argmax_layer.cpp
@@ -13,13 +13,12 @@
 namespace caffe {

 template <typename Dtype>
-class ArgMaxLayerTest : public ::testing::Test {
+class ArgMaxLayerTest : public CPUDeviceTest<Dtype> {
  protected:
   ArgMaxLayerTest()
       : blob_bottom_(new Blob<Dtype>(10, 20, 1, 1)),
         blob_top_(new Blob<Dtype>()),
         top_k_(5) {
-    Caffe::set_mode(Caffe::CPU);
     Caffe::set_random_seed(1701);
     // fill the values
     FillerParameter filler_param;
9 changes: 2 additions & 7 deletions src/caffe/test/test_convolution_layer.cpp
@@ -424,7 +424,7 @@ TYPED_TEST(ConvolutionLayerTest, TestGradientGroup) {
 #ifdef USE_CUDNN

 template <typename Dtype>
-class CuDNNConvolutionLayerTest : public ::testing::Test {
+class CuDNNConvolutionLayerTest : public GPUDeviceTest<Dtype> {
  protected:
   CuDNNConvolutionLayerTest()
       : blob_bottom_(new Blob<Dtype>(2, 3, 6, 4)),
@@ -467,7 +466,6 @@ class CuDNNConvolutionLayerTest : public ::testing::Test {
 TYPED_TEST_CASE(CuDNNConvolutionLayerTest, TestDtypes);

 TYPED_TEST(CuDNNConvolutionLayerTest, TestSetupCuDNN) {
-  Caffe::set_mode(Caffe::GPU);
   this->blob_bottom_vec_.push_back(this->blob_bottom_2_);
   this->blob_top_vec_.push_back(this->blob_top_2_);
   LayerParameter layer_param;
@@ -505,7 +504,6 @@ TYPED_TEST(CuDNNConvolutionLayerTest, TestSetupCuDNN) {
 }

 TYPED_TEST(CuDNNConvolutionLayerTest, TestSimpleConvolutionCuDNN) {
-  Caffe::set_mode(Caffe::GPU);
   this->blob_bottom_vec_.push_back(this->blob_bottom_2_);
   this->blob_top_vec_.push_back(this->blob_top_2_);
   LayerParameter layer_param;
@@ -541,7 +539,6 @@ TYPED_TEST(CuDNNConvolutionLayerTest, TestSimpleConvolutionCuDNN) {
 }

 TYPED_TEST(CuDNNConvolutionLayerTest, TestSimpleConvolutionGroupCuDNN) {
-  Caffe::set_mode(Caffe::GPU);
   LayerParameter layer_param;
   ConvolutionParameter* convolution_param =
       layer_param.mutable_convolution_param();
@@ -572,7 +569,7 @@ TYPED_TEST(CuDNNConvolutionLayerTest, TestSobelConvolutionCuDNN) {
   // Test separable convolution by computing the Sobel operator
   // as a single filter then comparing the result
   // as the convolution of two rectangular filters.
-  Caffe::set_mode(Caffe::GPU);
+
   // Fill bottoms with identical Gaussian noise.
   shared_ptr<GaussianFiller<TypeParam> > filler;
   FillerParameter filler_param;
@@ -665,7 +662,6 @@ TYPED_TEST(CuDNNConvolutionLayerTest, TestSobelConvolutionCuDNN) {
 }

 TYPED_TEST(CuDNNConvolutionLayerTest, TestGradientCuDNN) {
-  Caffe::set_mode(Caffe::GPU);
   LayerParameter layer_param;
   ConvolutionParameter* convolution_param =
       layer_param.mutable_convolution_param();
@@ -683,7 +679,6 @@ TYPED_TEST(CuDNNConvolutionLayerTest, TestGradientCuDNN) {
 }

 TYPED_TEST(CuDNNConvolutionLayerTest, TestGradientGroupCuDNN) {
-  Caffe::set_mode(Caffe::GPU);
   LayerParameter layer_param;
   ConvolutionParameter* convolution_param =
       layer_param.mutable_convolution_param();
5 changes: 1 addition & 4 deletions src/caffe/test/test_dummy_data_layer.cpp
@@ -13,7 +13,7 @@
 namespace caffe {

 template <typename Dtype>
-class DummyDataLayerTest : public ::testing::Test {
+class DummyDataLayerTest : public CPUDeviceTest<Dtype> {
  protected:
   DummyDataLayerTest()
       : blob_top_a_(new Blob<Dtype>()),
@@ -44,7 +44,6 @@ class DummyDataLayerTest : public ::testing::Test {
 TYPED_TEST_CASE(DummyDataLayerTest, TestDtypes);

 TYPED_TEST(DummyDataLayerTest, TestOneTopConstant) {
-  Caffe::set_mode(Caffe::CPU);
   LayerParameter param;
   DummyDataParameter* dummy_data_param = param.mutable_dummy_data_param();
   dummy_data_param->add_num(5);
@@ -74,7 +73,6 @@ TYPED_TEST(DummyDataLayerTest, TestOneTopConstant) {
 }

 TYPED_TEST(DummyDataLayerTest, TestTwoTopConstant) {
-  Caffe::set_mode(Caffe::CPU);
   LayerParameter param;
   DummyDataParameter* dummy_data_param = param.mutable_dummy_data_param();
   dummy_data_param->add_num(5);
@@ -113,7 +111,6 @@ TYPED_TEST(DummyDataLayerTest, TestTwoTopConstant) {
 }

 TYPED_TEST(DummyDataLayerTest, TestThreeTopConstantGaussianConstant) {
-  Caffe::set_mode(Caffe::CPU);
   LayerParameter param;
   DummyDataParameter* dummy_data_param = param.mutable_dummy_data_param();
   dummy_data_param->add_num(5);
4 changes: 1 addition & 3 deletions src/caffe/test/test_im2col_kernel.cu
@@ -25,7 +25,7 @@ __global__ void im2col_gpu_kernel(const int n, const Dtype* data_im,
 extern cudaDeviceProp CAFFE_TEST_CUDA_PROP;

 template <typename Dtype>
-class Im2colKernelTest : public ::testing::Test {
+class Im2colKernelTest : public GPUDeviceTest<Dtype> {
  protected:
   Im2colKernelTest()
         // big so launches > 1024 threads
@@ -68,8 +68,6 @@ class Im2colKernelTest : public ::testing::Test {
 TYPED_TEST_CASE(Im2colKernelTest, TestDtypes);

 TYPED_TEST(Im2colKernelTest, TestGPU) {
-  Caffe::set_mode(Caffe::GPU);
-
   // Reshape the blobs to correct size for im2col output
   this->blob_top_->Reshape(this->blob_bottom_->num(),
       this->channels_ * this->kernel_size_ * this->kernel_size_,
51 changes: 31 additions & 20 deletions src/caffe/test/test_math_functions.cpp
@@ -15,8 +15,10 @@

 namespace caffe {

-template<typename Dtype>
-class MathFunctionsTest : public ::testing::Test {
+template <typename TypeParam>
+class MathFunctionsTest : public MultiDeviceTest<TypeParam> {
+  typedef typename TypeParam::Dtype Dtype;
+
  protected:
   MathFunctionsTest()
       : blob_bottom_(new Blob<Dtype>()),
@@ -64,22 +66,27 @@ class MathFunctionsTest : public ::testing::Test {
   Blob<Dtype>* const blob_top_;
 };

-TYPED_TEST_CASE(MathFunctionsTest, TestDtypes);
+template <typename Dtype>
+class CPUMathFunctionsTest
+    : public MathFunctionsTest<CPUDevice<Dtype> > {
+};
+
+TYPED_TEST_CASE(CPUMathFunctionsTest, TestDtypes);

-TYPED_TEST(MathFunctionsTest, TestNothing) {
+TYPED_TEST(CPUMathFunctionsTest, TestNothing) {
   // The first test case of a test suite takes the longest time
   // due to the set up overhead.
 }

-TYPED_TEST(MathFunctionsTest, TestHammingDistanceCPU) {
+TYPED_TEST(CPUMathFunctionsTest, TestHammingDistance) {
   int n = this->blob_bottom_->count();
   const TypeParam* x = this->blob_bottom_->cpu_data();
   const TypeParam* y = this->blob_top_->cpu_data();
   EXPECT_EQ(this->ReferenceHammingDistance(n, x, y),
             caffe_cpu_hamming_distance<TypeParam>(n, x, y));
 }

-TYPED_TEST(MathFunctionsTest, TestAsumCPU) {
+TYPED_TEST(CPUMathFunctionsTest, TestAsum) {
   int n = this->blob_bottom_->count();
   const TypeParam* x = this->blob_bottom_->cpu_data();
   TypeParam std_asum = 0;
@@ -90,7 +97,7 @@ TYPED_TEST(MathFunctionsTest, TestAsumCPU) {
   EXPECT_LT((cpu_asum - std_asum) / std_asum, 1e-2);
 }

-TYPED_TEST(MathFunctionsTest, TestSignCPU) {
+TYPED_TEST(CPUMathFunctionsTest, TestSign) {
   int n = this->blob_bottom_->count();
   const TypeParam* x = this->blob_bottom_->cpu_data();
   caffe_cpu_sign<TypeParam>(n, x, this->blob_bottom_->mutable_cpu_diff());
@@ -100,7 +107,7 @@ TYPED_TEST(MathFunctionsTest, TestSignCPU) {
   }
 }

-TYPED_TEST(MathFunctionsTest, TestSgnbitCPU) {
+TYPED_TEST(CPUMathFunctionsTest, TestSgnbit) {
   int n = this->blob_bottom_->count();
   const TypeParam* x = this->blob_bottom_->cpu_data();
   caffe_cpu_sgnbit<TypeParam>(n, x, this->blob_bottom_->mutable_cpu_diff());
@@ -110,7 +117,7 @@ TYPED_TEST(MathFunctionsTest, TestSgnbitCPU) {
   }
 }

-TYPED_TEST(MathFunctionsTest, TestFabsCPU) {
+TYPED_TEST(CPUMathFunctionsTest, TestFabs) {
   int n = this->blob_bottom_->count();
   const TypeParam* x = this->blob_bottom_->cpu_data();
   caffe_abs<TypeParam>(n, x, this->blob_bottom_->mutable_cpu_diff());
@@ -120,7 +127,7 @@ TYPED_TEST(MathFunctionsTest, TestFabsCPU) {
   }
 }

-TYPED_TEST(MathFunctionsTest, TestScaleCPU) {
+TYPED_TEST(CPUMathFunctionsTest, TestScale) {
   int n = this->blob_bottom_->count();
   TypeParam alpha = this->blob_bottom_->cpu_diff()[caffe_rng_rand() %
                                                    this->blob_bottom_->count()];
@@ -133,11 +140,10 @@ TYPED_TEST(MathFunctionsTest, TestScaleCPU) {
   }
 }

-TYPED_TEST(MathFunctionsTest, TestCopyCPU) {
+TYPED_TEST(CPUMathFunctionsTest, TestCopy) {
   const int n = this->blob_bottom_->count();
   const TypeParam* bottom_data = this->blob_bottom_->cpu_data();
   TypeParam* top_data = this->blob_top_->mutable_cpu_data();
-  Caffe::set_mode(Caffe::CPU);
   caffe_copy(n, bottom_data, top_data);
   for (int i = 0; i < n; ++i) {
     EXPECT_EQ(bottom_data[i], top_data[i]);
@@ -146,8 +152,14 @@ TYPED_TEST(MathFunctionsTest, TestCopyCPU) {

 #ifndef CPU_ONLY

+template <typename Dtype>
+class GPUMathFunctionsTest : public MathFunctionsTest<GPUDevice<Dtype> > {
+};
+
+TYPED_TEST_CASE(GPUMathFunctionsTest, TestDtypes);
+
 // TODO: Fix caffe_gpu_hamming_distance and re-enable this test.
-TYPED_TEST(MathFunctionsTest, DISABLED_TestHammingDistanceGPU) {
+TYPED_TEST(GPUMathFunctionsTest, DISABLED_TestHammingDistance) {
   int n = this->blob_bottom_->count();
   const TypeParam* x = this->blob_bottom_->cpu_data();
   const TypeParam* y = this->blob_top_->cpu_data();
@@ -158,7 +170,7 @@ TYPED_TEST(MathFunctionsTest, DISABLED_TestHammingDistanceGPU) {
   EXPECT_EQ(reference_distance, computed_distance);
 }

-TYPED_TEST(MathFunctionsTest, TestAsumGPU) {
+TYPED_TEST(GPUMathFunctionsTest, TestAsum) {
   int n = this->blob_bottom_->count();
   const TypeParam* x = this->blob_bottom_->cpu_data();
   TypeParam std_asum = 0;
@@ -170,7 +182,7 @@ TYPED_TEST(MathFunctionsTest, TestAsumGPU) {
   EXPECT_LT((gpu_asum - std_asum) / std_asum, 1e-2);
 }

-TYPED_TEST(MathFunctionsTest, TestSignGPU) {
+TYPED_TEST(GPUMathFunctionsTest, TestSign) {
   int n = this->blob_bottom_->count();
   caffe_gpu_sign<TypeParam>(n, this->blob_bottom_->gpu_data(),
                             this->blob_bottom_->mutable_gpu_diff());
@@ -181,7 +193,7 @@ TYPED_TEST(MathFunctionsTest, TestSignGPU) {
   }
 }

-TYPED_TEST(MathFunctionsTest, TestSgnbitGPU) {
+TYPED_TEST(GPUMathFunctionsTest, TestSgnbit) {
   int n = this->blob_bottom_->count();
   caffe_gpu_sgnbit<TypeParam>(n, this->blob_bottom_->gpu_data(),
                               this->blob_bottom_->mutable_gpu_diff());
@@ -192,7 +204,7 @@ TYPED_TEST(MathFunctionsTest, TestSgnbitGPU) {
   }
 }

-TYPED_TEST(MathFunctionsTest, TestFabsGPU) {
+TYPED_TEST(GPUMathFunctionsTest, TestFabs) {
   int n = this->blob_bottom_->count();
   caffe_gpu_abs<TypeParam>(n, this->blob_bottom_->gpu_data(),
                            this->blob_bottom_->mutable_gpu_diff());
@@ -203,7 +215,7 @@ TYPED_TEST(MathFunctionsTest, TestFabsGPU) {
   }
 }

-TYPED_TEST(MathFunctionsTest, TestScaleGPU) {
+TYPED_TEST(GPUMathFunctionsTest, TestScale) {
   int n = this->blob_bottom_->count();
   TypeParam alpha = this->blob_bottom_->cpu_diff()[caffe_rng_rand() %
                                                    this->blob_bottom_->count()];
@@ -216,11 +228,10 @@ TYPED_TEST(MathFunctionsTest, TestScaleGPU) {
   }
 }

-TYPED_TEST(MathFunctionsTest, TestCopyGPU) {
+TYPED_TEST(GPUMathFunctionsTest, TestCopy) {
   const int n = this->blob_bottom_->count();
   const TypeParam* bottom_data = this->blob_bottom_->gpu_data();
   TypeParam* top_data = this->blob_top_->mutable_gpu_data();
-  Caffe::set_mode(Caffe::GPU);
   caffe_copy(n, bottom_data, top_data);
   bottom_data = this->blob_bottom_->cpu_data();
   top_data = this->blob_top_->mutable_cpu_data();
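Note how the single MathFunctionsTest suite is split: the base keeps the TypeParam::Dtype typedef, while CPUMathFunctionsTest and GPUMathFunctionsTest re-expose it as a plain Dtype parameter, so the test bodies can keep using TypeParam as the scalar type. A suite that needs to run on every device/type pair can instead instantiate the MultiDeviceTest form over TestDtypesAndDevices; a sketch of that variant (illustrative only, not part of this commit):

    // Hypothetical suite enumerating CPUDevice<float>, CPUDevice<double>
    // (and the GPU pair outside CPU_ONLY builds) via TestDtypesAndDevices.
    template <typename TypeParam>
    class ExampleMultiDeviceTest : public MultiDeviceTest<TypeParam> {
      typedef typename TypeParam::Dtype Dtype;
    };

    TYPED_TEST_CASE(ExampleMultiDeviceTest, TestDtypesAndDevices);

    TYPED_TEST(ExampleMultiDeviceTest, RunsOnEachDevice) {
      typedef typename TypeParam::Dtype Dtype;
      // The base fixture already ran Caffe::set_mode(TypeParam::device),
      // so device-agnostic code simply runs in the selected mode.
      EXPECT_EQ(TypeParam::device, Caffe::mode());
      Blob<Dtype> blob(1, 1, 1, 1);
      EXPECT_EQ(1, blob.count());
    }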
3 changes: 1 addition & 2 deletions src/caffe/test/test_multinomial_logistic_loss_layer.cpp
@@ -16,7 +16,7 @@
 namespace caffe {

 template <typename Dtype>
-class MultinomialLogisticLossLayerTest : public ::testing::Test {
+class MultinomialLogisticLossLayerTest : public CPUDeviceTest<Dtype> {
  protected:
   MultinomialLogisticLossLayerTest()
       : blob_bottom_data_(new Blob<Dtype>(10, 5, 1, 1)),
@@ -51,7 +50,6 @@ TYPED_TEST_CASE(MultinomialLogisticLossLayerTest, TestDtypes);

 TYPED_TEST(MultinomialLogisticLossLayerTest, TestGradientCPU) {
   LayerParameter layer_param;
-  Caffe::set_mode(Caffe::CPU);
   MultinomialLogisticLossLayer<TypeParam> layer(layer_param);
   layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
   GradientChecker<TypeParam> checker(1e-2, 2*1e-2, 1701, 0, 0.05);
(The remaining 4 of the 12 changed files were not loaded in this view.)