Skip to content

Commit

Permalink
Adopt new dispatch macro so it compiles with pytorch 1.8
Browse files Browse the repository at this point in the history
  • Loading branch information
tridao committed Jan 13, 2021
1 parent a5d7ca6 commit e417c92
Showing 1 changed file with 13 additions and 11 deletions.
24 changes: 13 additions & 11 deletions csrc/cuda/butterfly_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,20 @@
#include "map.h"

// Only support float (not double) for now to speed up compilation time
// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
#undef AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES
#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(TYPE, NAME, ...) \
[&] { \
const auto& the_type = TYPE; \
/* don't use TYPE again in case it is an expensive or side-effect op */ \
at::ScalarType _st = ::detail::scalar_type(the_type); \
switch (_st) { \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::ComplexFloat, c10::complex<float>, __VA_ARGS__) \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \
} \
#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(TYPE, NAME, ...) \
[&] { \
const auto& the_type = TYPE; \
/* don't use TYPE again in case it is an expensive or side-effect op */ \
at::ScalarType _st = ::detail::scalar_type(the_type); \
RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \
switch (_st) { \
AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::ComplexFloat, c10::complex<float>, __VA_ARGS__) \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \
} \
}()


Expand Down

0 comments on commit e417c92

Please sign in to comment.