MaxPool1d without indices optimization (#43745)
Summary:
Pull Request resolved: pytorch/pytorch#43745

This is part of a larger effort to refactor and optimize the pooling code. I previously started working on MaxPool2d in pytorch/pytorch#43267, but since it uses MaxPool1d as a subroutine, it made more sense to work on 1D first, get it tested and optimized, and then move up to 2D and then 3D.

Below are some benchmarking results; the Python script I used follows the results.

## Benchmarking
```
Name (time in us)                            Min                   Max                Mean             StdDev              Median                 IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_googlenet[(3, 2, 0, 1, 0)-new]      79.7659 (1.03)     1,059.6327 (5.32)      90.6280 (1.01)     19.1196 (1.41)      84.2176 (1.01)       2.4289 (1.0)     1079;2818       11.0341 (0.99)       9055           1
test_googlenet[(3, 2, 0, 1, 0)-old]     505.1531 (6.55)       830.8962 (4.17)     563.4763 (6.29)     65.3974 (4.81)     538.3361 (6.43)      80.5371 (33.16)      242;99        1.7747 (0.16)       1742           1
test_googlenet[(3, 2, 0, 1, 1)-new]      80.2949 (1.04)       233.0020 (1.17)      97.6498 (1.09)     19.1228 (1.41)      89.2282 (1.07)      18.5743 (7.65)     1858;741       10.2407 (0.92)       9587           1
test_googlenet[(3, 2, 0, 1, 1)-old]     513.5350 (6.66)       977.4677 (4.91)     594.4559 (6.63)     69.9372 (5.15)     577.9080 (6.90)      79.8218 (32.86)      503;84        1.6822 (0.15)       1675           1
test_googlenet[(3, 2, 1, 1, 0)-new]      77.1061 (1.0)        199.1168 (1.0)       89.6529 (1.0)      13.5864 (1.0)       83.7557 (1.0)        7.5139 (3.09)    1419;1556       11.1541 (1.0)        7434           1
test_googlenet[(3, 2, 1, 1, 0)-old]     543.6055 (7.05)       964.5708 (4.84)     636.9867 (7.11)     84.0732 (6.19)     616.7777 (7.36)     100.4562 (41.36)      434;65        1.5699 (0.14)       1552           1
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_inception[(3, 2, 0, 1, 0)-new]      84.5827 (1.00)       184.2827 (1.0)       90.5438 (1.01)      9.6324 (1.0)       89.3027 (1.05)      4.5672 (1.03)      637;759       11.0444 (0.99)       6274           1
test_inception[(3, 2, 0, 1, 0)-old]     641.2268 (7.59)     1,704.8977 (9.25)     686.9383 (7.65)     57.2499 (5.94)     682.5905 (8.01)     58.3753 (13.17)       86;21        1.4557 (0.13)        802           1
test_inception[(3, 2, 0, 1, 1)-new]      84.5008 (1.0)      1,093.6335 (5.93)      89.8233 (1.0)      14.0443 (1.46)      85.2682 (1.0)       4.4331 (1.0)      802;1106       11.1330 (1.0)        9190           1
test_inception[(3, 2, 0, 1, 1)-old]     643.7078 (7.62)       851.4188 (4.62)     687.4905 (7.65)     41.1116 (4.27)     685.1386 (8.04)     60.2733 (13.60)      286;14        1.4546 (0.13)       1300           1
test_inception[(3, 2, 1, 1, 0)-new]     106.0739 (1.26)       258.5649 (1.40)     115.3597 (1.28)     17.5436 (1.82)     106.9643 (1.25)      5.5470 (1.25)     894;1402        8.6685 (0.78)       7635           1
test_inception[(3, 2, 1, 1, 0)-old]     651.0504 (7.70)       955.2278 (5.18)     698.0295 (7.77)     45.5097 (4.72)     692.8109 (8.13)     64.6794 (14.59)      145;15        1.4326 (0.13)        909           1
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_large_batch_size[new]       2.9608 (1.0)        5.1127 (1.0)        3.3096 (1.0)      0.1936 (1.0)        3.3131 (1.0)      0.2093 (1.0)          71;6  302.1515 (1.0)         297           1
test_large_batch_size[old]     130.6583 (44.13)    152.9521 (29.92)    137.1385 (41.44)    7.4352 (38.40)    135.1784 (40.80)    5.1358 (24.53)         1;1    7.2919 (0.02)          7           1
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_large_channel_size[new]      2.9696 (1.0)       5.5595 (1.0)       3.5997 (1.0)      0.5836 (1.0)       3.3497 (1.0)      0.3445 (1.0)         58;54  277.8014 (1.0)         277           1
test_large_channel_size[old]     19.6838 (6.63)     22.6637 (4.08)     21.1775 (5.88)     0.8610 (1.48)     21.3739 (6.38)     1.4930 (4.33)         13;0   47.2199 (0.17)         36           1
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_large_width[new]      1.7714 (1.0)       2.4104 (1.0)       1.8988 (1.0)      0.0767 (1.0)       1.8911 (1.0)      0.0885 (1.0)         86;13  526.6454 (1.0)         373           1
test_large_width[old]     19.5708 (11.05)    22.8755 (9.49)     20.7987 (10.95)    0.7009 (9.14)     20.6623 (10.93)    0.8584 (9.70)         14;1   48.0799 (0.09)         46           1
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_multithreaded[new]      15.0560 (1.0)       24.2891 (1.0)       16.1627 (1.0)      1.5657 (1.0)       15.7182 (1.0)      0.7598 (1.0)           4;6  61.8709 (1.0)          65           1
test_multithreaded[old]     115.7614 (7.69)     120.9670 (4.98)     118.3004 (7.32)     1.6259 (1.04)     118.4164 (7.53)     1.9613 (2.58)          2;0   8.4531 (0.14)          8           1
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

Legend:
  Outliers: 1 Standard Deviation from Mean; 1.5 IQR (InterQuartile Range) from 1st Quartile and 3rd Quartile.
  OPS: Operations Per Second, computed as 1 / Mean
```

### Benchmarking script
To run the benchmark, make sure you have pytest-benchmark installed (`pip install pytest-benchmark`) and use the following command: `pytest benchmark.py --benchmark-sort='name'`

```
import torch
import pytest

def _test_speedup(benchmark, batches=1, channels=32, width=32,
                  kernel_size=2, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False):
    torch.set_num_threads(1)
    x = torch.randn((batches, channels, width))
    model = torch.nn.MaxPool1d(kernel_size, stride, padding, dilation, return_indices, ceil_mode)
    benchmark(model, x)

@pytest.mark.benchmark(group="inception")
@pytest.mark.parametrize("return_indices", [True, False], ids=["old", "new"])
@pytest.mark.parametrize("params", [(3, 2), (3, 2, 0, 1, True), (3, 2, 1)],
                         ids=["(3, 2, 0, 1, 0)",
                              "(3, 2, 0, 1, 1)",
                              "(3, 2, 1, 1, 0)"])
def test_inception(benchmark, params, return_indices):
    _test_speedup(benchmark, 10, 64, 147, *params, return_indices=return_indices)

@pytest.mark.benchmark(group="googlenet")
@pytest.mark.parametrize("return_indices", [True, False], ids=["old", "new"])
@pytest.mark.parametrize("params", [(3, 2), (3, 2, 0, 1, True), (3, 2, 1)],
                         ids=["(3, 2, 0, 1, 0)",
                              "(3, 2, 0, 1, 1)",
                              "(3, 2, 1, 1, 0)"])
def test_googlenet(benchmark, params, return_indices):
    _test_speedup(benchmark, 10, 64, 112, *params, return_indices=return_indices)

@pytest.mark.benchmark(group="large batch size")
@pytest.mark.parametrize("return_indices", [True, False], ids=["old", "new"])
def test_large_batch_size(benchmark, return_indices):
    _test_speedup(benchmark, 100000, 1, 32, return_indices=return_indices)

@pytest.mark.benchmark(group="large channel size")
@pytest.mark.parametrize("return_indices", [True, False], ids=["old", "new"])
def test_large_channel_size(benchmark, return_indices):
    _test_speedup(benchmark, 1, 100000, 32, return_indices=return_indices)

@pytest.mark.benchmark(group="large width")
@pytest.mark.parametrize("return_indices", [True, False], ids=["old", "new"])
def test_large_width(benchmark, return_indices):
    _test_speedup(benchmark, 1, 32, 100000, return_indices=return_indices)

@pytest.mark.benchmark(group="multithreading")
@pytest.mark.parametrize("return_indices", [True, False], ids=["old", "new"])
def test_multithreaded(benchmark, return_indices):
    x = torch.randn((40, 10000, 32))
    model = torch.nn.MaxPool1d(2, return_indices=return_indices)
    benchmark(model, x)
```

## Discussion

The new algorithm is on average about 7x faster than the old one (mean times in the googlenet, inception and multithreaded groups above improve by roughly 6x to 7.7x). However, because the old algorithm had many issues with how it parallelized the code and made use of the cache, one can come up with input parameters, such as a large batch size (about 41x above), for which the new algorithm is much faster still.
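
As a quick sanity check of these figures, the per-benchmark speedups can be recomputed from the Mean column of the table. The snippet below is only a sketch: the dictionary copies the mean times reported above, and the names `MEANS` and `speedups` are made up for illustration.

```python
# Mean times copied from the benchmark table above as (new, old) pairs.
# Units cancel in the ratio, so they do not matter here.
MEANS = {
    "googlenet (3, 2, 0, 1, 0)": (90.6280, 563.4763),
    "googlenet (3, 2, 0, 1, 1)": (97.6498, 594.4559),
    "googlenet (3, 2, 1, 1, 0)": (89.6529, 636.9867),
    "inception (3, 2, 0, 1, 0)": (90.5438, 686.9383),
    "inception (3, 2, 0, 1, 1)": (89.8233, 687.4905),
    "inception (3, 2, 1, 1, 0)": (115.3597, 698.0295),
    "large batch size": (3.3096, 137.1385),
    "large channel size": (3.5997, 21.1775),
    "large width": (1.8988, 20.7987),
    "multithreaded": (16.1627, 118.3004),
}


def speedups(means):
    """Return {benchmark name: old mean / new mean}."""
    return {name: old / new for name, (new, old) in means.items()}


if __name__ == "__main__":
    # Print the benchmarks from smallest to largest speedup.
    for name, ratio in sorted(speedups(MEANS).items(), key=lambda kv: kv[1]):
        print(f"{name:28s} {ratio:6.2f}x")
```

The ratios agree with the parenthesized relative values that pytest-benchmark prints next to each mean.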

Test Plan: Imported from OSS

Reviewed By: glaringlee

Differential Revision: D23425348

Pulled By: heitorschueroff

fbshipit-source-id: 3fa3f9b8e71200da48424a95510124a83f50d7b2
heitorschueroff authored and facebook-github-bot committed Sep 1, 2020
1 parent a044c03 commit 13a48ac
Showing 6 changed files with 299 additions and 24 deletions.
110 changes: 110 additions & 0 deletions aten/src/ATen/native/MaxPooling.cpp
@@ -0,0 +1,110 @@
#include <ATen/ATen.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/MaxPooling.h>
#include <ATen/native/Pool.h>

namespace at {
namespace native {

DEFINE_DISPATCH(max_pool1d_stub);

namespace {

Tensor max_pool1d_impl(
    const Tensor& self,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    bool ceil_mode) {
  NoNamesGuard guard;

  TORCH_CHECK(
      self.dim() == 2 || self.dim() == 3,
      "max_pool1d() input tensor must have 2 or 3 dimensions but got ",
      self.dim());
  TORCH_CHECK(
      kernel_size.size() == 1,
      "max_pool1d() kernel_size must be an int or int list of size 1 but got size ",
      kernel_size.size());
  TORCH_CHECK(
      stride.size() == 0 || stride.size() == 1,
      "max_pool1d() stride must be None, an int or int list of size 1 but got size ",
      stride.size());
  TORCH_CHECK(
      padding.size() == 1,
      "max_pool1d() padding must be an int or int list of size 1 but got size ",
      padding.size());
  TORCH_CHECK(
      dilation.size() == 1,
      "max_pool1d() dilation must be an int or int list of size 1 but got size ",
      dilation.size());

  // If stride=None then set it to kernel_size
  if (stride.empty()) {
    stride = kernel_size;
  }

  const int64_t NB = self.dim() == 3 ? self.size(-3) : 1;
  const int64_t NC = self.size(-2);
  const int64_t IW = self.size(-1);
  const int64_t KW = kernel_size[0];
  const int64_t SJ = stride[0];
  const int64_t PJ = padding[0];
  const int64_t DJ = dilation[0];

  TORCH_CHECK(
      KW > 0,
      "max_pool1d() kernel_size must be greater than zero, but got ",
      KW);
  TORCH_CHECK(
      SJ > 0, "max_pool1d() stride must be greater than zero, but got ", SJ);
  TORCH_CHECK(
      PJ >= 0, "max_pool1d() padding must be non-negative, but got ", PJ);
  TORCH_CHECK(
      PJ <= KW / 2,
      "max_pool1d() padding should be at most half of kernel size, but got padding=",
      PJ,
      " and kernel_size=",
      KW);
  TORCH_CHECK(
      DJ > 0, "max_pool1d() dilation must be greater than zero, but got ", DJ);

  const int64_t OW = pooling_output_shape(IW, KW, PJ, SJ, DJ, ceil_mode);
  TORCH_CHECK(OW >= 0, "max_pool1d() Invalid computed output size: ", OW);
  Tensor output = at::empty({NB, NC, OW}, self.options());

  PoolingParams1D params{NB, NC, IW, OW, KW, SJ, PJ, DJ};
  max_pool1d_stub(self.device().type(), output, self, params);

  if (self.dim() == 2) {
    output.squeeze_(0);
  }

  guard.reset();
  namedinference::propagate_names(output, self);

  return output;
}

} // namespace

Tensor max_pool1d(
    const Tensor& self,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    bool ceil_mode) {
  if (self.requires_grad() || !self.device().is_cpu()) {
    // Needs indices for grad and with_indices defines CUDA dispatch
    return std::get<0>(at::max_pool1d_with_indices(
        self, kernel_size, stride, padding, dilation, ceil_mode));
  }
  return max_pool1d_impl(
      self, kernel_size, stride, padding, dilation, ceil_mode);
}

} // namespace native
} // namespace at
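
The output width `OW` above comes from `pooling_output_shape`, which is not part of this diff (presumably it lives in the included `ATen/native/Pool.h`). As a rough Python sketch, assuming it follows the standard pooling shape formula; the function name `pooling_output_width` and its argument order are made up here to mirror the call above:

```python
def pooling_output_width(iw, kw, pj, sj, dj, ceil_mode=False):
    """Output width of a 1D pooling window (hypothetical helper).

    iw: input width, kw: kernel width, pj: padding,
    sj: stride, dj: dilation.
    """
    numerator = iw + 2 * pj - dj * (kw - 1) - 1
    if ceil_mode:
        numerator += sj - 1
    # Python's // rounds toward negative infinity, so negative
    # numerators are handled the same way as a floor division.
    ow = numerator // sj + 1
    if ceil_mode and (ow - 1) * sj >= iw + pj:
        # The last window must start inside the input or the left padding.
        ow -= 1
    return ow


if __name__ == "__main__":
    # Width-0 input with kernel_size=5: the computed size is negative.
    print(pooling_output_width(0, 5, 0, 1, 1))
    # The inception benchmark shape (width 147, kernel 3, stride 2).
    print(pooling_output_width(147, 3, 0, 2, 1))
```

For a width-0 input with `kernel_size=5` and `stride=1` this gives -4, which is the value the `Invalid computed output size` check is exercised with in `test_max_pool1d_errors`.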
44 changes: 44 additions & 0 deletions aten/src/ATen/native/MaxPooling.h
@@ -0,0 +1,44 @@
#pragma once

#include <ATen/ATen.h>
#include <ATen/Parallel.h>
#include <ATen/native/DispatchStub.h>

namespace at {
namespace native {

// TODO(Heitor) Template by dimension
struct PoolingParams1D {
  int64_t NB; // Number of batches
  int64_t NC; // Number of channels
  int64_t IW; // Input width
  int64_t OW; // Output width
  int64_t KW; // Kernel width
  int64_t SJ; // Column stride
  int64_t PJ; // Column padding
  int64_t DJ; // Column dilation

  // Return index of input element for the given kernel and output index
  inline int64_t index(int64_t kj, int64_t oj) const {
    return oj * SJ + kj * DJ - PJ;
  }

  // Return index of first output within bounds for this kernel index
  inline int64_t valid_output_start(int64_t kj) const {
    int64_t ij = index(kj, 0);;
    return ij < 0 ? at::divup(-ij, SJ) : 0;
  }

  // Return index one past last output within bounds for this kernel index
  inline int64_t valid_output_end(int64_t kj) const {
    int64_t ij = index(kj, OW - 1);
    return ij >= IW ? OW - at::divup(ij - (IW - 1), SJ) : OW;
  }
};

using pooling_fn = void (*)(Tensor&, const Tensor&, const PoolingParams1D&);

DECLARE_DISPATCH(pooling_fn, max_pool1d_stub);

} // namespace native
} // namespace at
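
To make the index arithmetic in `PoolingParams1D` easier to follow, here is a small Python transcription of the three helpers. It is an illustrative sketch only: the batch and channel fields are dropped, and `divup` is written out by hand on the assumption that `at::divup` is a plain integer ceiling division.

```python
from dataclasses import dataclass


def divup(a, b):
    """Integer ceiling division for a >= 0 and b > 0."""
    return (a + b - 1) // b


@dataclass(frozen=True)
class PoolingParams1D:
    IW: int  # input width
    OW: int  # output width
    KW: int  # kernel width
    SJ: int  # stride
    PJ: int  # padding
    DJ: int  # dilation

    def index(self, kj, oj):
        """Input index read by kernel offset kj for output oj."""
        return oj * self.SJ + kj * self.DJ - self.PJ

    def valid_output_start(self, kj):
        """First output whose input index is >= 0 for this offset."""
        ij = self.index(kj, 0)
        return divup(-ij, self.SJ) if ij < 0 else 0

    def valid_output_end(self, kj):
        """One past the last output whose input index is < IW."""
        ij = self.index(kj, self.OW - 1)
        if ij >= self.IW:
            return self.OW - divup(ij - (self.IW - 1), self.SJ)
        return self.OW


if __name__ == "__main__":
    # Input [1, 2] with kernel_size=2, stride=1, padding=1, dilation=2.
    p = PoolingParams1D(IW=2, OW=2, KW=2, SJ=1, PJ=1, DJ=2)
    for kj in range(p.KW):
        # kj=0 covers outputs [1, 2); kj=1 covers outputs [0, 1)
        print(kj, p.valid_output_start(kj), p.valid_output_end(kj))
```

For each kernel offset, the half-open range `[valid_output_start, valid_output_end)` is exactly the set of outputs that read a real input element, so the inner loop needs no per-element bounds check.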
12 changes: 0 additions & 12 deletions aten/src/ATen/native/Pooling.cpp
@@ -107,18 +107,6 @@ Tensor avg_pool1d(
   return output.squeeze(2);
 }

-Tensor max_pool1d(
-    const Tensor& self,
-    IntArrayRef kernel_size,
-    IntArrayRef stride,
-    IntArrayRef padding,
-    IntArrayRef dilation,
-    bool ceil_mode) {
-  auto output_and_indices = at::max_pool1d_with_indices(
-      self, kernel_size, stride, padding, dilation, ceil_mode);
-  return std::get<0>(output_and_indices);
-}
-
 Tensor max_pool2d(
     const Tensor& self,
     IntArrayRef kernel_size,
57 changes: 57 additions & 0 deletions aten/src/ATen/native/cpu/MaxPooling.cpp
@@ -0,0 +1,57 @@
#include <ATen/ATen.h>
#include <ATen/Parallel.h>
#include <ATen/cpu/vec256/vec256.h>
#include <ATen/native/MaxPooling.h>

namespace at {
namespace native {

namespace {

template <typename scalar_t>
inline void max_pool1d_kernel(
    scalar_t* C10_RESTRICT op,
    const scalar_t* C10_RESTRICT ip,
    const PoolingParams1D& p) {
  for (int64_t kj = 0; kj < p.KW; ++kj) {
    int64_t oj = p.valid_output_start(kj);
    int64_t oe = p.valid_output_end(kj);
    int64_t ij = p.index(kj, oj);
    for (; oj < oe; ++oj, ij += p.SJ) {
      scalar_t val = ip[ij];
      bool update_max = std::isnan(val) || op[oj] < val;
      op[oj] = update_max ? val : op[oj];
    }
  }
}

void max_pool1d_impl(
    Tensor& output,
    const Tensor& input,
    const PoolingParams1D& p) {
  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "max_pool1d_impl", [&] {
    scalar_t* const OP = output.data_ptr<scalar_t>();
    const scalar_t* const IP = input.contiguous().data_ptr<scalar_t>();

    // Value used for padding
    constexpr scalar_t FILL = std::numeric_limits<scalar_t>::has_infinity
        ? -std::numeric_limits<scalar_t>::infinity()
        : std::numeric_limits<scalar_t>::lowest();

    at::parallel_for(0, p.NB * p.NC, 0, [&](int64_t begin, int64_t end) {
      for (int64_t it = begin; it < end; ++it) {
        scalar_t* op = OP + it * p.OW;
        const scalar_t* ip = IP + it * p.IW;
        std::fill_n(op, p.OW, FILL);
        max_pool1d_kernel(op, ip, p);
      }
    });
  });
}

} // namespace

REGISTER_DISPATCH(max_pool1d_stub, &max_pool1d_impl);

} // namespace native
} // namespace at
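
The kernel above iterates over kernel offsets in the outer loop and over outputs in the inner loop, after filling the output row with negative infinity. A pure-Python sketch of the same loop order for a single row may help when reading it. The function name `max_pool1d_row` and its signature are invented for this illustration, and the output width `ow` is assumed to have been computed already.

```python
import math


def divup(a, b):
    """Integer ceiling division for a >= 0 and b > 0."""
    return (a + b - 1) // b


def max_pool1d_row(row, ow, kw, sj, pj=0, dj=1):
    """Max-pool one row of floats, looping kernel offset first."""
    iw = len(row)
    out = [-math.inf] * ow  # padding value, as in FILL above
    for kj in range(kw):
        first = kj * dj - pj           # input index read by output 0
        last = first + (ow - 1) * sj   # input index read by output ow - 1
        start = divup(-first, sj) if first < 0 else 0
        end = ow - divup(last - (iw - 1), sj) if last >= iw else ow
        ij = first + start * sj
        for oj in range(start, end):
            val = row[ij]
            # NaN always wins, so it propagates to the output.
            if math.isnan(val) or out[oj] < val:
                out[oj] = val
            ij += sj
    return out


if __name__ == "__main__":
    # Same case as check([[1, 2]], (2, 1, 1, 2, False, False), [[2, 1]]).
    print(max_pool1d_row([1.0, 2.0], ow=2, kw=2, sj=1, pj=1, dj=2))
```

With this ordering the inner loop walks both the input and the output with a fixed stride, which is the kind of access pattern that is friendly to caches and to compiler vectorization.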
81 changes: 78 additions & 3 deletions test/test_nn.py
@@ -42,7 +42,7 @@
     module_tests, criterion_tests, new_criterion_tests, loss_reference_fns, \
     ctcloss_reference, new_module_tests
 from torch.testing._internal.common_device_type import instantiate_device_type_tests, dtypes, \
-    dtypesIfCUDA, skipCUDAIfNoCudnn, skipCUDAIfCudnnVersionLessThan, onlyCUDA, \
+    dtypesIfCUDA, skipCUDAIfNoCudnn, skipCUDAIfCudnnVersionLessThan, onlyCUDA, onlyCPU, \
     skipCUDAIfRocm, skipCUDAIf, skipCUDAIfNotRocm, largeCUDATensorTest, onlyOnCPUAndCUDA, \
     deviceCountAtLeast, expectedAlertNondeterministic, largeTensorTest
 from torch.nn import MultiheadAttention
@@ -9853,6 +9853,70 @@ def helper(n, c, h, w, kernel_size, stride=None,
        helper(10, 512, 31, 31, 3, stride=2)
        helper(1, 129, 8, 8, 3, stride=2)

    @onlyCPU
    @dtypes(torch.float)
    def test_max_pool1d_errors(self, device, dtype):
        def check(x, args, message):
            model = torch.nn.MaxPool1d(*args)
            with self.assertRaisesRegex(RuntimeError, r'max_pool1d\(\) ' + message):
                model(torch.tensor(x, device=device, dtype=dtype))

        # Pooling args: (kernel_size, stride, padding, dilation, return_indices, ceil_mode)
        check(0, (1,), "input tensor must have 2 or 3 dimensions but got 0")
        check([], (1,), "input tensor must have 2 or 3 dimensions but got 1")
        check([[]], (1, 0), "stride must be greater than zero, but got 0")
        check([[]], (1, 1, -1), "padding must be non-negative, but got -1")
        check([[]], (1, 1, 2), "padding should be at most half of kernel size, but got padding=2 and kernel_size=1")
        check([[]], (1, 1, 0, 0), "dilation must be greater than zero, but got 0")
        check([[]], (5, 1, 0, 1), "Invalid computed output size: -4")

    @onlyCPU
    @dtypes(torch.float, torch.double)
    def test_max_pool1d_corner_cases(self, device, dtype):
        def check(x, args, expected):
            model = torch.nn.MaxPool1d(*args)
            tensor = torch.tensor(x, device=device, dtype=dtype)
            self.assertEqual(model(tensor), torch.tensor(expected, device=device, dtype=dtype))

        # Pooling args: (kernel_size, stride, padding, dilation, return_indices, ceil_mode)
        check([[]], (1, None, 0, 1, False, False), [[]])
        check([[[]]], (1, None, 0, 1, False, False), [[[]]])
        check([[[]]], (2, 1, 1, 2, False, True), [[[]]])
        check([[1]], (1, None, 0, 1, False, False), [[1]])
        check([[1]], (2, None, 1, 2, False, False), [[float('-inf')]])
        check([[1], [1]], (2, None, 1, 2, False, False), [[float('-inf')], [float('-inf')]])
        check([[1, 2]], (2, 1, 1, 2, False, False), [[2, 1]])
        check([[1, 2]], (2, 2, 1, 2, False, True), [[2, 2]])

        empty_tensor = torch.empty((2, 0, 1), dtype=torch.float32)
        check(empty_tensor, (1, None, 0, 1, False, False), empty_tensor)

    @onlyCPU
    @dtypes(torch.float, torch.double)
    def test_max_pool1d(self, device, dtype):
        # FIXME For now compare against max_pool1d with indices
        def check(x, *args, **kwargs):
            model = torch.nn.MaxPool1d(*args, **kwargs)
            ref_model = torch.nn.MaxPool1d(*args, **kwargs, return_indices=True)
            tensor = torch.tensor(x, device=device, dtype=dtype)
            self.assertEqual(model(tensor), ref_model(tensor)[0])

        sizes = [random.sample(range(8, 128), 3) for _ in range(3)]
        kernel_sizes = random.sample(range(1, 5), 3)
        strides = random.sample(range(1, 5), 3)
        dilations = random.sample(range(1, 5), 3)
        ceil_modes = [True, False]

        for size, kernel_size, stride, dilation, ceil_mode in \
                itertools.product(sizes, kernel_sizes, strides, dilations, ceil_modes):
            padding = random.sample(range(0, math.floor(kernel_size / 2) + 1), 1)
            check(torch.randn(size), kernel_size, stride, padding, dilation, ceil_mode=ceil_mode)

        # Non-contiguous test
        tensor = torch.randn(5, 151, 33)[::2, ::3, ::2]
        check(tensor, 3, 2, 1, 2, ceil_mode=True)
        check(tensor.transpose(1, 2), 3, 2, 1, 2, ceil_mode=True)
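
The FIXME in `test_max_pool1d` notes that the new code is only compared against the `return_indices=True` path. One possible independent oracle is a naive window-by-window reference in plain Python. The sketch below is not part of the commit: `reference_max_pool1d` is a hypothetical helper, it handles a single row, and it assumes the standard pooling output-shape formula.

```python
import math


def reference_max_pool1d(row, kernel_size, stride=None, padding=0,
                         dilation=1, ceil_mode=False):
    """Naive 1D max pooling over one row, one output window at a time."""
    stride = kernel_size if stride is None else stride
    iw = len(row)
    # Output width (assumed formula, see the lead-in).
    numerator = iw + 2 * padding - dilation * (kernel_size - 1) - 1
    if ceil_mode:
        numerator += stride - 1
    ow = numerator // stride + 1
    if ceil_mode and (ow - 1) * stride >= iw + padding:
        ow -= 1
    out = []
    for oj in range(max(ow, 0)):
        best = -math.inf  # padded positions behave like -inf
        for kj in range(kernel_size):
            ij = oj * stride + kj * dilation - padding
            if 0 <= ij < iw:
                val = row[ij]
                if math.isnan(val) or best < val:
                    best = val
        out.append(best)
    return out


if __name__ == "__main__":
    # Two of the corner cases from test_max_pool1d_corner_cases.
    print(reference_max_pool1d([1.0, 2.0], 2, 1, 1, 2))
    print(reference_max_pool1d([1.0], 2, None, 1, 2))
```

On the single-row corner cases listed in `test_max_pool1d_corner_cases` this reference produces the same expected values.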

    @onlyCUDA
    def test_max_pool2d(self, device):
        def helper(n, c, h, w, ks):
@@ -11371,15 +11435,22 @@ def test_max_pool_nan_inf(self, device, dtype):
            for num_dim in [1, 2, 3]:
                fn_name = '{}max_pool{}d'.format(adaptive, num_dim)
                fn = getattr(F, fn_name)

                x = torch.full([1, 1] + num_dim * [3], nan, device=device, dtype=dtype, requires_grad=True)
                res = fn(x, 1 if adaptive else 3)
                res.backward(torch.randn_like(res))
                self.assertTrue(math.isnan(res.item()))
                x.requires_grad_(False)
                res = fn(x, 1 if adaptive else 3)
                self.assertTrue(math.isnan(res.item()))

                x2 = torch.full([1, 1] + num_dim * [3], -inf, device=device, dtype=dtype, requires_grad=True)
                res2 = fn(x2, 1 if adaptive else 3)
                res2.backward(torch.randn_like(res2))
                self.assertTrue(math.isinf(res2.item()))
                x2.requires_grad_(False)
                res2 = fn(x2, 1 if adaptive else 3)
                self.assertTrue(math.isinf(res2.item()))

    @onlyOnCPUAndCUDA
    @dtypes(torch.float, torch.double)
@@ -11416,12 +11487,12 @@ def test_pooling_zero_stride(self, device):
                 fn_name = '{}_pool{}d'.format(op, num_dim)
                 fn = getattr(F, fn_name)
                 x = torch.ones([1, 2] + num_dim * [4], device=device, dtype=torch.float)
-                self.assertRaisesRegex(RuntimeError, "stride should not be zero",
+                self.assertRaisesRegex(RuntimeError, r"stride should not be zero|stride must be greater than zero",
                                        lambda: fn(x, kernel_size=2, stride=0))

                 fn_module_name = '{}Pool{}d'.format(op.title(), num_dim)
                 fn_module = getattr(nn, fn_module_name)(kernel_size=2, stride=0)
-                self.assertRaisesRegex(RuntimeError, "stride should not be zero",
+                self.assertRaisesRegex(RuntimeError, r"stride should not be zero|stride must be greater than zero",
                                        lambda: fn_module(x))

     @dtypesIfCUDA(*ALL_TENSORTYPES2)
@@ -11444,6 +11515,10 @@ def test_pool_invalid_size(self, device, dtype):
        for op in ('max', 'avg'):
            for num_dim in [1, 2, 3]:
                fn_name = '{}_pool{}d'.format(op, num_dim)
                if op == 'max':
                    # New implementation without indices supports empty tensors
                    # TODO(Heitor) change once with_indices code is updated
                    fn_name += '_with_indices'
                fn = getattr(F, fn_name)
                # use a configuration that gives zero outputs only
                # when doing a correct floor division by the stride
19 changes: 10 additions & 9 deletions torch/nn/modules/pooling.py
@@ -46,18 +46,19 @@ class MaxPool1d(_MaxPoolNd):
         out(N_i, C_j, k) = \max_{m=0, \ldots, \text{kernel\_size} - 1}
                 input(N_i, C_j, stride \times k + m)

-    If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides
-    for :attr:`padding` number of points. :attr:`dilation` controls the spacing between the kernel points.
-    It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does.
+    If :attr:`padding` is non-zero, then the input is implicitly padded with negative infinity on both sides
+    for :attr:`padding` number of points. :attr:`dilation` is the stride between the elements within the
+    sliding window. This `link`_ has a nice visualization of the pooling parameters.

     Args:
-        kernel_size: the size of the window to take a max over
-        stride: the stride of the window. Default value is :attr:`kernel_size`
-        padding: implicit zero padding to be added on both sides
-        dilation: a parameter that controls the stride of elements in the window
-        return_indices: if ``True``, will return the max indices along with the outputs.
+        kernel_size: The size of the sliding window, must be > 0.
+        stride: The stride of the sliding window, must be > 0. Default value is :attr:`kernel_size`.
+        padding: Implicit negative infinity padding to be added on both sides, must be >= 0 and <= kernel_size / 2.
+        dilation: The stride between elements within a sliding window, must be > 0.
+        return_indices: If ``True``, will return the argmax along with the max values.
                         Useful for :class:`torch.nn.MaxUnpool1d` later
-        ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape
+        ceil_mode: If ``True``, will use `ceil` instead of `floor` to compute the output shape. This
+                   ensures that every element in the input tensor is covered by a sliding window.

     Shape:
         - Input: :math:`(N, C, L_{in})`

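The documentation change from "zero padding" to "negative infinity padding" matters whenever a window contains only negative inputs. A tiny pure-Python sketch shows the difference; the helper `pool_with_explicit_padding` is hypothetical, pads the row explicitly, and only supports `dilation=1` in floor mode.

```python
import math


def pool_with_explicit_padding(row, kernel_size, stride, padding, pad_value):
    """Max-pool one row after padding both sides with pad_value."""
    padded = [pad_value] * padding + list(row) + [pad_value] * padding
    ow = (len(padded) - kernel_size) // stride + 1
    return [
        max(padded[oj * stride: oj * stride + kernel_size])
        for oj in range(max(ow, 0))
    ]


if __name__ == "__main__":
    row = [-3.0, -1.0, -2.0]
    # Padding with -inf never changes the maximum of a window.
    print(pool_with_explicit_padding(row, 3, 2, 1, -math.inf))
    # Padding with 0 would leak the pad value into the result.
    print(pool_with_explicit_padding(row, 3, 2, 1, 0.0))
```

With negative infinity as the pad value both windows return -1.0, whereas zero padding would return 0.0, a value that never appears in the input.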