From e417c920cd77d8d10a8d8cef80df7ea1a6bbf576 Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Wed, 13 Jan 2021 13:46:10 -0800
Subject: [PATCH] Adopt new dispatch macro so it compiles with pytorch 1.8

---
 csrc/cuda/butterfly_cuda.cu | 24 +++++++++++++-----------
 1 file changed, 13 insertions(+), 11 deletions(-)

diff --git a/csrc/cuda/butterfly_cuda.cu b/csrc/cuda/butterfly_cuda.cu
index 84488f0..217373b 100644
--- a/csrc/cuda/butterfly_cuda.cu
+++ b/csrc/cuda/butterfly_cuda.cu
@@ -5,18 +5,20 @@
 #include "map.h"
 
 // Only support float (not double) for now to speed up compilation time
+// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
 #undef AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES
-#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(TYPE, NAME, ...)                            \
-  [&] {                                                                                    \
-    const auto& the_type = TYPE;                                                           \
-    /* don't use TYPE again in case it is an expensive or side-effect op */                \
-    at::ScalarType _st = ::detail::scalar_type(the_type);                                  \
-    switch (_st) {                                                                         \
-      AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__)                      \
-      AT_PRIVATE_CASE_TYPE(at::ScalarType::ComplexFloat, c10::complex<float>, __VA_ARGS__) \
-      default:                                                                             \
-        AT_ERROR(#NAME, " not implemented for '", toString(_st), "'");                     \
-    }                                                                                      \
+#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(TYPE, NAME, ...)                                  \
+  [&] {                                                                                          \
+    const auto& the_type = TYPE;                                                                 \
+    /* don't use TYPE again in case it is an expensive or side-effect op */                      \
+    at::ScalarType _st = ::detail::scalar_type(the_type);                                        \
+    RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st);                                                     \
+    switch (_st) {                                                                               \
+      AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__)                      \
+      AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::ComplexFloat, c10::complex<float>, __VA_ARGS__) \
+      default:                                                                                   \
+        AT_ERROR(#NAME, " not implemented for '", toString(_st), "'");                           \
+    }                                                                                            \
   }()
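
Note (not part of the patch): a minimal sketch of how the redefined AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES macro is typically invoked from a CUDA launcher in an extension like this one. The kernel and launcher names below are hypothetical placeholders, and the kernel body is a stand-in rather than the repository's actual butterfly multiply; only the dispatch pattern itself is the point.

// Hypothetical usage sketch: the macro dispatches on the tensor's dtype and
// defines scalar_t inside the lambda. With this redefinition, only float and
// c10::complex<float> are handled; any other dtype falls through to AT_ERROR.
#include <torch/extension.h>

template <typename scalar_t>
__global__ void copy_kernel(const scalar_t* in, scalar_t* out, int n) {
  // Placeholder kernel body, just to make the sketch self-contained.
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = in[i];
}

void example_launcher(const torch::Tensor& input, torch::Tensor& output) {
  const int n = static_cast<int>(input.numel());
  const int threads = 256;
  const int blocks = (n + threads - 1) / threads;
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(input.scalar_type(), "example_launcher", [&] {
    copy_kernel<scalar_t><<<blocks, threads>>>(
        input.data_ptr<scalar_t>(), output.data_ptr<scalar_t>(), n);
  });
}

Restricting the dispatch to float and complex float (rather than also double and complex double) is what keeps compilation time down, per the comment in the file; the patch only updates the macro body to match the AT_PRIVATE_CASE_TYPE signature and RECORD_KERNEL_FUNCTION_DTYPE call expected by pytorch 1.8.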