diff --git a/aten/src/ATen/native/MaxPooling.cpp b/aten/src/ATen/native/MaxPooling.cpp
new file mode 100644
index 00000000000..a0298ea937d
--- /dev/null
+++ b/aten/src/ATen/native/MaxPooling.cpp
@@ -0,0 +1,110 @@
+#include <ATen/ATen.h>
+#include <ATen/NamedTensorUtils.h>
+#include <ATen/NativeFunctions.h>
+#include <ATen/native/MaxPooling.h>
+#include <ATen/native/Pool.h>
+
+namespace at {
+namespace native {
+
+DEFINE_DISPATCH(max_pool1d_stub);
+
+namespace {
+
+Tensor max_pool1d_impl(
+    const Tensor& self,
+    IntArrayRef kernel_size,
+    IntArrayRef stride,
+    IntArrayRef padding,
+    IntArrayRef dilation,
+    bool ceil_mode) {
+  NoNamesGuard guard;
+
+  TORCH_CHECK(
+      self.dim() == 2 || self.dim() == 3,
+      "max_pool1d() input tensor must have 2 or 3 dimensions but got ",
+      self.dim());
+  TORCH_CHECK(
+      kernel_size.size() == 1,
+      "max_pool1d() kernel_size must be an int or int list of size 1 but got size ",
+      kernel_size.size());
+  TORCH_CHECK(
+      stride.size() == 0 || stride.size() == 1,
+      "max_pool1d() stride must be None, an int or int list of size 1 but got size ",
+      stride.size());
+  TORCH_CHECK(
+      padding.size() == 1,
+      "max_pool1d() padding must be an int or int list of size 1 but got size ",
+      padding.size());
+  TORCH_CHECK(
+      dilation.size() == 1,
+      "max_pool1d() dilation must be an int or int list of size 1 but got size ",
+      dilation.size());
+
+  // If stride=None then set it to kernel_size
+  if (stride.empty()) {
+    stride = kernel_size;
+  }
+
+  const int64_t NB = self.dim() == 3 ? self.size(-3) : 1;
+  const int64_t NC = self.size(-2);
+  const int64_t IW = self.size(-1);
+  const int64_t KW = kernel_size[0];
+  const int64_t SJ = stride[0];
+  const int64_t PJ = padding[0];
+  const int64_t DJ = dilation[0];
+
+  TORCH_CHECK(
+      KW > 0,
+      "max_pool1d() kernel_size must be greater than zero, but got ",
+      KW);
+  TORCH_CHECK(
+      SJ > 0, "max_pool1d() stride must be greater than zero, but got ", SJ);
+  TORCH_CHECK(
+      PJ >= 0, "max_pool1d() padding must be non-negative, but got ", PJ);
+  TORCH_CHECK(
+      PJ <= KW / 2,
+      "max_pool1d() padding should be at most half of kernel size, but got padding=",
+      PJ,
+      " and kernel_size=",
+      KW);
+  TORCH_CHECK(
+      DJ > 0, "max_pool1d() dilation must be greater than zero, but got ", DJ);
+
+  const int64_t OW = pooling_output_shape(IW, KW, PJ, SJ, DJ, ceil_mode);
+  TORCH_CHECK(OW >= 0, "max_pool1d() Invalid computed output size: ", OW);
+  Tensor output = at::empty({NB, NC, OW}, self.options());
+
+  PoolingParams1D params{NB, NC, IW, OW, KW, SJ, PJ, DJ};
+  max_pool1d_stub(self.device().type(), output, self, params);
+
+  if (self.dim() == 2) {
+    output.squeeze_(0);
+  }
+
+  guard.reset();
+  namedinference::propagate_names(output, self);
+
+  return output;
+}
+
+} // namespace
+
+Tensor max_pool1d(
+    const Tensor& self,
+    IntArrayRef kernel_size,
+    IntArrayRef stride,
+    IntArrayRef padding,
+    IntArrayRef dilation,
+    bool ceil_mode) {
+  if (self.requires_grad() || !self.device().is_cpu()) {
+    // Needs indices for grad and with_indices defines CUDA dispatch
+    return std::get<0>(at::max_pool1d_with_indices(
+        self, kernel_size, stride, padding, dilation, ceil_mode));
+  }
+  return max_pool1d_impl(
+      self, kernel_size, stride, padding, dilation, ceil_mode);
+}
+
+} // namespace native
+} // namespace at
diff --git a/aten/src/ATen/native/MaxPooling.h b/aten/src/ATen/native/MaxPooling.h
new file mode 100644
index 00000000000..c429c8e667b
--- /dev/null
+++ b/aten/src/ATen/native/MaxPooling.h
@@ -0,0 +1,44 @@
+#pragma once
+
+#include <ATen/ATen.h>
+#include <ATen/Parallel.h>
+#include <ATen/native/DispatchStub.h>
+
+namespace at {
+namespace native {
+
+// TODO(Heitor) Template by dimension
+struct PoolingParams1D {
+  int64_t NB; // Number of batches
+  int64_t NC; // Number of channels
+  int64_t IW; // Input width
+  int64_t OW; // Output width
+  int64_t KW; // Kernel width
+  int64_t SJ; // Column stride
+  int64_t PJ; // Column padding
+  int64_t DJ; // Column dilation
+
+  // Return index of input element for the given kernel and output index
+  inline int64_t index(int64_t kj, int64_t oj) const {
+    return oj * SJ + kj * DJ - PJ;
+  }
+
+  // Return index of first output within bounds for this kernel index
+  inline int64_t valid_output_start(int64_t kj) const {
+    int64_t ij = index(kj, 0);
+    return ij < 0 ? at::divup(-ij, SJ) : 0;
+  }
+
+  // Return index one past last output within bounds for this kernel index
+  inline int64_t valid_output_end(int64_t kj) const {
+    int64_t ij = index(kj, OW - 1);
+    return ij >= IW ? OW - at::divup(ij - (IW - 1), SJ) : OW;
+  }
+};
+
+using pooling_fn = void (*)(Tensor&, const Tensor&, const PoolingParams1D&);
+
+DECLARE_DISPATCH(pooling_fn, max_pool1d_stub);
+
+} // namespace native
+} // namespace at
diff --git a/aten/src/ATen/native/Pooling.cpp b/aten/src/ATen/native/Pooling.cpp
index 89024a23fba..750089e8d4f 100644
--- a/aten/src/ATen/native/Pooling.cpp
+++ b/aten/src/ATen/native/Pooling.cpp
@@ -107,18 +107,6 @@ Tensor avg_pool1d(
   return output.squeeze(2);
 }
 
-Tensor max_pool1d(
-    const Tensor& self,
-    IntArrayRef kernel_size,
-    IntArrayRef stride,
-    IntArrayRef padding,
-    IntArrayRef dilation,
-    bool ceil_mode) {
-  auto output_and_indices = at::max_pool1d_with_indices(
-      self, kernel_size, stride, padding, dilation, ceil_mode);
-  return std::get<0>(output_and_indices);
-}
-
 Tensor max_pool2d(
     const Tensor& self,
     IntArrayRef kernel_size,
diff --git a/aten/src/ATen/native/cpu/MaxPooling.cpp b/aten/src/ATen/native/cpu/MaxPooling.cpp
new file mode 100644
index 00000000000..35575091dcd
--- /dev/null
+++ b/aten/src/ATen/native/cpu/MaxPooling.cpp
@@ -0,0 +1,57 @@
+#include <ATen/ATen.h>
+#include <ATen/Dispatch.h>
+#include <ATen/Parallel.h>
+#include <ATen/native/MaxPooling.h>
+
+namespace at {
+namespace native {
+
+namespace {
+
+template <typename scalar_t>
+inline void max_pool1d_kernel(
+    scalar_t* C10_RESTRICT op,
+    const scalar_t* C10_RESTRICT ip,
+    const PoolingParams1D& p) {
+  for (int64_t kj = 0; kj < p.KW; ++kj) {
+    int64_t oj = p.valid_output_start(kj);
+    int64_t oe = p.valid_output_end(kj);
+    int64_t ij = p.index(kj, oj);
+    for (; oj < oe; ++oj, ij += p.SJ) {
+      scalar_t val = ip[ij];
+      bool update_max = std::isnan(val) || op[oj] < val;
+      op[oj] = update_max ? val : op[oj];
+    }
+  }
+}
+
+void max_pool1d_impl(
+    Tensor& output,
+    const Tensor& input,
+    const PoolingParams1D& p) {
+  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "max_pool1d_impl", [&] {
+    scalar_t* const OP = output.data_ptr<scalar_t>();
+    const scalar_t* const IP = input.contiguous().data_ptr<scalar_t>();
+
+    // Value used for padding
+    constexpr scalar_t FILL = std::numeric_limits<scalar_t>::has_infinity
+        ? -std::numeric_limits<scalar_t>::infinity()
+        : std::numeric_limits<scalar_t>::lowest();
+
+    at::parallel_for(0, p.NB * p.NC, 0, [&](int64_t begin, int64_t end) {
+      for (int64_t it = begin; it < end; ++it) {
+        scalar_t* op = OP + it * p.OW;
+        const scalar_t* ip = IP + it * p.IW;
+        std::fill_n(op, p.OW, FILL);
+        max_pool1d_kernel(op, ip, p);
+      }
+    });
+  });
+}
+
+} // namespace
+
+REGISTER_DISPATCH(max_pool1d_stub, &max_pool1d_impl);
+
+} // namespace native
+} // namespace at
diff --git a/test/test_nn.py b/test/test_nn.py
index de36abc0ff3..48f3f62459e 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -42,7 +42,7 @@
     module_tests, criterion_tests, new_criterion_tests, loss_reference_fns, \
     ctcloss_reference, new_module_tests
 from torch.testing._internal.common_device_type import instantiate_device_type_tests, dtypes, \
-    dtypesIfCUDA, skipCUDAIfNoCudnn, skipCUDAIfCudnnVersionLessThan, onlyCUDA, \
+    dtypesIfCUDA, skipCUDAIfNoCudnn, skipCUDAIfCudnnVersionLessThan, onlyCUDA, onlyCPU, \
     skipCUDAIfRocm, skipCUDAIf, skipCUDAIfNotRocm, largeCUDATensorTest, onlyOnCPUAndCUDA, \
     deviceCountAtLeast, expectedAlertNondeterministic, largeTensorTest
 from torch.nn import MultiheadAttention
@@ -9853,6 +9853,70 @@ def helper(n, c, h, w, kernel_size, stride=None,
         helper(10, 512, 31, 31, 3, stride=2)
         helper(1, 129, 8, 8, 3, stride=2)
 
+    @onlyCPU
+    @dtypes(torch.float)
+    def test_max_pool1d_errors(self, device, dtype):
+        def check(x, args, message):
+            model = torch.nn.MaxPool1d(*args)
+            with self.assertRaisesRegex(RuntimeError, r'max_pool1d\(\) ' + message):
+                model(torch.tensor(x, device=device, dtype=dtype))
+
+        # Pooling args: (kernel_size, stride, padding, dilation, return_indices, ceil_mode)
+        check(0, (1,), "input tensor must have 2 or 3 dimensions but got 0")
+        check([], (1,), "input tensor must have 2 or 3 dimensions but got 1")
+        check([[]], (1, 0), "stride must be greater than zero, but got 0")
+        check([[]], (1, 1, -1), "padding must be non-negative, but got -1")
+        check([[]], (1, 1, 2), "padding should be at most half of kernel size, but got padding=2 and kernel_size=1")
+        check([[]], (1, 1, 0, 0), "dilation must be greater than zero, but got 0")
+        check([[]], (5, 1, 0, 1), "Invalid computed output size: -4")
+
+    @onlyCPU
+    @dtypes(torch.float, torch.double)
+    def test_max_pool1d_corner_cases(self, device, dtype):
+        def check(x, args, expected):
+            model = torch.nn.MaxPool1d(*args)
+            tensor = torch.tensor(x, device=device, dtype=dtype)
+            self.assertEqual(model(tensor), torch.tensor(expected, device=device, dtype=dtype))
+
+        # Pooling args: (kernel_size, stride, padding, dilation, return_indices, ceil_mode)
+        check([[]], (1, None, 0, 1, False, False), [[]])
+        check([[[]]], (1, None, 0, 1, False, False), [[[]]])
+        check([[[]]], (2, 1, 1, 2, False, True), [[[]]])
+        check([[1]], (1, None, 0, 1, False, False), [[1]])
+        check([[1]], (2, None, 1, 2, False, False), [[float('-inf')]])
+        check([[1], [1]], (2, None, 1, 2, False, False), [[float('-inf')], [float('-inf')]])
+        check([[1, 2]], (2, 1, 1, 2, False, False), [[2, 1]])
+        check([[1, 2]], (2, 2, 1, 2, False, True), [[2, 2]])
+
+        empty_tensor = torch.empty((2, 0, 1), dtype=torch.float32)
+        check(empty_tensor, (1, None, 0, 1, False, False), empty_tensor)
+
+    @onlyCPU
+    @dtypes(torch.float, torch.double)
+    def test_max_pool1d(self, device, dtype):
+        # FIXME: For now compare against max_pool1d with indices
+        def check(x, *args, **kwargs):
+            model = torch.nn.MaxPool1d(*args, **kwargs)
+            ref_model = torch.nn.MaxPool1d(*args, **kwargs, return_indices=True)
+            tensor = torch.tensor(x, device=device, dtype=dtype)
+            self.assertEqual(model(tensor), ref_model(tensor)[0])
+
+        sizes = [random.sample(range(8, 128), 3) for _ in range(3)]
+        kernel_sizes = random.sample(range(1, 5), 3)
+        strides = random.sample(range(1, 5), 3)
+        dilations = random.sample(range(1, 5), 3)
+        ceil_modes = [True, False]
+
+        for size, kernel_size, stride, dilation, ceil_mode in \
+                itertools.product(sizes, kernel_sizes, strides, dilations, ceil_modes):
+            padding = random.sample(range(0, math.floor(kernel_size / 2) + 1), 1)
+            check(torch.randn(size), kernel_size, stride, padding, dilation, ceil_mode=ceil_mode)
+
+        # Non-contiguous test
+        tensor = torch.randn(5, 151, 33)[::2, ::3, ::2]
+        check(tensor, 3, 2, 1, 2, ceil_mode=True)
+        check(tensor.transpose(1, 2), 3, 2, 1, 2, ceil_mode=True)
+
     @onlyCUDA
     def test_max_pool2d(self, device):
         def helper(n, c, h, w, ks):
@@ -11371,15 +11435,22 @@ def test_max_pool_nan_inf(self, device, dtype):
         for num_dim in [1, 2, 3]:
             fn_name = '{}max_pool{}d'.format(adaptive, num_dim)
             fn = getattr(F, fn_name)
+
             x = torch.full([1, 1] + num_dim * [3], nan, device=device, dtype=dtype, requires_grad=True)
             res = fn(x, 1 if adaptive else 3)
             res.backward(torch.randn_like(res))
             self.assertTrue(math.isnan(res.item()))
+            x.requires_grad_(False)
+            res = fn(x, 1 if adaptive else 3)
+            self.assertTrue(math.isnan(res.item()))
 
             x2 = torch.full([1, 1] + num_dim * [3], -inf, device=device, dtype=dtype, requires_grad=True)
             res2 = fn(x2, 1 if adaptive else 3)
             res2.backward(torch.randn_like(res2))
             self.assertTrue(math.isinf(res2.item()))
+            x2.requires_grad_(False)
+            res2 = fn(x2, 1 if adaptive else 3)
+            self.assertTrue(math.isinf(res2.item()))
 
     @onlyOnCPUAndCUDA
     @dtypes(torch.float, torch.double)
@@ -11416,12 +11487,12 @@ def test_pooling_zero_stride(self, device):
                 fn_name = '{}_pool{}d'.format(op, num_dim)
                 fn = getattr(F, fn_name)
                 x = torch.ones([1, 2] + num_dim * [4], device=device, dtype=torch.float)
-                self.assertRaisesRegex(RuntimeError, "stride should not be zero",
+                self.assertRaisesRegex(RuntimeError, r"stride should not be zero|stride must be greater than zero",
                                        lambda: fn(x, kernel_size=2, stride=0))
 
                 fn_module_name = '{}Pool{}d'.format(op.title(), num_dim)
                 fn_module = getattr(nn, fn_module_name)(kernel_size=2, stride=0)
-                self.assertRaisesRegex(RuntimeError, "stride should not be zero",
+                self.assertRaisesRegex(RuntimeError, r"stride should not be zero|stride must be greater than zero",
                                        lambda: fn_module(x))
 
     @dtypesIfCUDA(*ALL_TENSORTYPES2)
@@ -11444,6 +11515,10 @@ def test_pool_invalid_size(self, device, dtype):
         for op in ('max', 'avg'):
             for num_dim in [1, 2, 3]:
                 fn_name = '{}_pool{}d'.format(op, num_dim)
+                if op == 'max':
+                    # New implementation without indices supports empty tensors
+                    # TODO(Heitor) change once with_indices code is updated
+                    fn_name += '_with_indices'
                 fn = getattr(F, fn_name)
                 # use a configuration that gives zero outputs only
                 # when doing a correct floor division by the stride
diff --git a/torch/nn/modules/pooling.py b/torch/nn/modules/pooling.py
index 4ce555250fe..573b71aa9b9 100644
--- a/torch/nn/modules/pooling.py
+++ b/torch/nn/modules/pooling.py
@@ -46,18 +46,19 @@ class MaxPool1d(_MaxPoolNd):
         out(N_i, C_j, k) = \max_{m=0, \ldots, \text{kernel\_size} - 1}
                 input(N_i, C_j, stride \times k + m)
 
-    If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides
-    for :attr:`padding` number of points. :attr:`dilation` controls the spacing between the kernel points.
-    It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does.
+    If :attr:`padding` is non-zero, then the input is implicitly padded with negative infinity on both sides
+    for :attr:`padding` number of points. :attr:`dilation` is the stride between the elements within the
+    sliding window. This `link`_ has a nice visualization of the pooling parameters.
 
     Args:
-        kernel_size: the size of the window to take a max over
-        stride: the stride of the window. Default value is :attr:`kernel_size`
-        padding: implicit zero padding to be added on both sides
-        dilation: a parameter that controls the stride of elements in the window
-        return_indices: if ``True``, will return the max indices along with the outputs.
+        kernel_size: The size of the sliding window, must be > 0.
+        stride: The stride of the sliding window, must be > 0. Default value is :attr:`kernel_size`.
+        padding: Implicit negative infinity padding to be added on both sides, must be >= 0 and <= kernel_size / 2.
+        dilation: The stride between elements within a sliding window, must be > 0.
+        return_indices: If ``True``, will return the argmax along with the max values.
                         Useful for :class:`torch.nn.MaxUnpool1d` later
-        ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape
+        ceil_mode: If ``True``, will use `ceil` instead of `floor` to compute the output shape. This
+                   ensures that every element in the input tensor is covered by a sliding window.
 
     Shape:
         - Input: :math:`(N, C, L_{in})`