diff --git a/csrc/cuda/butterfly_cuda.cu b/csrc/cuda/butterfly_cuda.cu
index 84488f0..217373b 100644
--- a/csrc/cuda/butterfly_cuda.cu
+++ b/csrc/cuda/butterfly_cuda.cu
@@ -5,18 +5,20 @@
 #include "map.h"
 
 // Only support float (not double) for now to speed up compilation time
+// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
 #undef AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES
-#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(TYPE, NAME, ...)                \
-  [&] {                                                                        \
-    const auto& the_type = TYPE;                                               \
-    /* don't use TYPE again in case it is an expensive or side-effect op */    \
-    at::ScalarType _st = ::detail::scalar_type(the_type);                      \
-    switch (_st) {                                                             \
-      AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__)          \
-      AT_PRIVATE_CASE_TYPE(at::ScalarType::ComplexFloat, c10::complex<float>, __VA_ARGS__) \
-      default:                                                                 \
-        AT_ERROR(#NAME, " not implemented for '", toString(_st), "'");         \
-    }                                                                          \
+#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(TYPE, NAME, ...)                \
+  [&] {                                                                        \
+    const auto& the_type = TYPE;                                               \
+    /* don't use TYPE again in case it is an expensive or side-effect op */    \
+    at::ScalarType _st = ::detail::scalar_type(the_type);                      \
+    RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st);                                   \
+    switch (_st) {                                                             \
+      AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__)    \
+      AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::ComplexFloat, c10::complex<float>, __VA_ARGS__) \
+      default:                                                                 \
+        AT_ERROR(#NAME, " not implemented for '", toString(_st), "'");         \
+    }                                                                          \
   }()
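
For context, here is a minimal sketch of how the redefined dispatch macro would be invoked by a kernel launcher in this file. The kernel and launcher names (example_copy_kernel, example_launcher) are illustrative placeholders, not part of the patch: the point is that the lambda body is instantiated only for float and c10::complex<float>, with scalar_t bound to the dispatched type, and any other dtype falls through to AT_ERROR.

    #include <ATen/ATen.h>

    // Hypothetical illustration only -- not part of butterfly_cuda.cu.
    template <typename scalar_t>
    __global__ void example_copy_kernel(const scalar_t* __restrict__ in,
                                        scalar_t* __restrict__ out,
                                        int64_t n) {
      const int64_t i = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
      if (i < n) out[i] = in[i];
    }

    void example_launcher(const at::Tensor& input, at::Tensor& output) {
      const int64_t n = input.numel();
      const int threads = 256;
      const int blocks = static_cast<int>((n + threads - 1) / threads);
      // Dispatches only on Float and ComplexFloat; other dtypes hit AT_ERROR.
      AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(input.scalar_type(), "example_launcher", [&] {
        example_copy_kernel<scalar_t><<<blocks, threads>>>(
            input.data_ptr<scalar_t>(), output.data_ptr<scalar_t>(), n);
      });
    }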