forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_custom_op_testing.py
328 lines (263 loc) · 10.6 KB
/
test_custom_op_testing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
# Owner(s): ["module: custom-operators"]
from torch.testing._internal.common_utils import * # noqa: F403
from torch.testing._internal.common_device_type import * # noqa: F403
from torch.testing._internal.optests.compile_check import operator_compile_check
from torch.testing._internal.custom_op_db import custom_op_db
from torch._custom_op.impl import custom_op
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
class TestCustomOpTesting(TestCase):
    """Tests for ``operator_compile_check``.

    Most tests register a deliberately broken operator ``foo`` in the
    ``_test_custom_op`` namespace and assert that ``operator_compile_check``
    raises the expected error. The exceptions are:

    * ``test_operator_compile_check_op``, which runs the check over
      ``custom_op_db`` and expects it to pass;
    * ``test_assert_raises_regex``, which tests the ``assert_raises_regex``
      helper itself.
    """

    def setUp(self):
        # Run the base-class setup first so whatever per-test initialization
        # torch's TestCase performs is not skipped by this override.
        super().setUp()
        self.test_ns = '_test_custom_op'
        # torch.library.Library fragments created via self.lib(); released in
        # tearDown.
        self.libraries = []

    def tearDown(self):
        """Undo every registration this test made under ``self.test_ns``.

        The base-class tearDown always runs, even if cleanup fails.
        """
        try:
            import torch._custom_op
            registry = torch._custom_op.impl.global_registry
            # Snapshot the keys first; presumably _destroy() removes its own
            # entry from the registry while we iterate.
            for key in list(registry.keys()):
                if key.startswith(f'{self.test_ns}::'):
                    registry[key]._destroy()
            # Drop the cached torch.ops namespace object. Use self.test_ns
            # (not a hard-coded name) so the check and the delete agree.
            if hasattr(torch.ops, self.test_ns):
                delattr(torch.ops, self.test_ns)
            # Deleting ``lib.m`` is how this file releases each Library's
            # registrations (presumably the underlying C++ handle).
            for lib in self.libraries:
                del lib.m
            del self.libraries
        finally:
            super().tearDown()

    def ns(self):
        """Return the ``torch.ops`` namespace object for ``self.test_ns``."""
        return getattr(torch.ops, self.test_ns)

    def lib(self):
        """Create a FRAGMENT library in ``self.test_ns`` and track it.

        Tracked libraries are released in tearDown.
        """
        result = torch.library.Library(self.test_ns, 'FRAGMENT')
        self.libraries.append(result)
        return result

    def test_incorrect_schema_mutation(self, device):
        # The schema declares ``x`` immutable, but the kernel mutates it in
        # place with ``sin_``.
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                # Redispatch below the Autograd key so the backend kernel runs.
                with torch._C._AutoDispatchBelowAutograd():
                    return op(x)

            @staticmethod
            def backward(ctx, gx):
                return gx

        def foo_impl(x):
            x.sin_()
            return x.clone()

        lib.impl("foo", Foo.apply, "Autograd")
        lib.impl("foo", foo_impl, "CPU")
        lib.impl("foo", foo_impl, "CUDA")

        def f(x):
            x = x.clone()
            v = x.view_as(x)
            op(v)
            return x

        x = torch.tensor(3.14159 / 3, requires_grad=True, device=device)
        with self.assertRaisesRegex(
                RuntimeError,
                'Argument x is not defined as mutable but was mutated'):
            operator_compile_check(f, (x,), {})

    def test_incorrect_schema_view(self, device):
        # The schema declares no aliasing, but the kernel returns a view of
        # its input.
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                # Emulate AutoDispatchBelowADInplaceOrView, which is not bound into python
                with torch._C._AutoDispatchBelowAutograd():
                    with torch._C._ExcludeDispatchKeyGuard(
                            torch._C.DispatchKeySet(torch._C.DispatchKey.ADInplaceOrView)):
                        return op(x)

            @staticmethod
            def backward(ctx, gx):
                return gx

        def foo_impl(x):
            return x.view_as(x)

        def foo_meta(x):
            return x.view_as(x)

        lib.impl("foo", Foo.apply, "Autograd")
        lib.impl("foo", foo_impl, "CPU")
        lib.impl("foo", foo_meta, "Meta")

        def f(x):
            x = x.clone()
            y = op(x)
            x.sin_()
            return y

        # Only a CPU kernel is registered, so the input is created on CPU
        # regardless of ``device``.
        x = torch.tensor(3.14159 / 3, requires_grad=True)
        with self.assertRaisesRegex(
                RuntimeError,
                'Argument x is not defined to alias output but was aliasing'):
            operator_compile_check(f, (x,), {})

    def test_missing_abstract_impl(self, device):
        # No Meta kernel is registered, and the real kernel round-trips
        # through numpy, so fake-tensor tracing cannot handle the op.
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                with torch._C._AutoDispatchBelowAutograd():
                    return op(x)

            @staticmethod
            def backward(ctx, gx):
                return 2 * gx

        def foo_impl(x):
            return torch.tensor(x.cpu().numpy() ** 2, device=x.device)

        lib.impl("foo", Foo.apply, "Autograd")
        lib.impl("foo", foo_impl, "CPU")
        lib.impl("foo", foo_impl, "CUDA")

        def f(x):
            y = op(x)
            return y.sum(0)

        x = torch.tensor([0, 1.], requires_grad=True)
        with self.assertRaisesRegex(
                torch._subclasses.fake_tensor.UnsupportedOperatorException,
                '_test_custom_op.foo.default'):
            operator_compile_check(f, (x,), {})

    def test_incorrect_abstract_impl(self, device):
        # The Meta kernel's output shape (``unsqueeze(1)``) disagrees with the
        # real kernel's.
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                # Emulate AutoDispatchBelowADInplaceOrView, which is not bound into python.
                # Context managers guarantee both guards are released even if
                # entering the second one fails.
                with torch._C._AutoDispatchBelowAutograd():
                    with torch._C._ExcludeDispatchKeyGuard(
                            torch._C.DispatchKeySet(torch._C.DispatchKey.ADInplaceOrView)):
                        return op(x)

            @staticmethod
            def backward(ctx, gx):
                return gx

        def foo_impl(x):
            return x ** 2

        def foo_meta(x):
            return x.unsqueeze(1) ** 2

        lib.impl("foo", Foo.apply, "Autograd")
        lib.impl("foo", foo_impl, "CPU")
        lib.impl("foo", foo_impl, "CUDA")
        lib.impl("foo", foo_meta, "Meta")

        def f(x):
            y = op(x)
            return y.sum(0)

        x = torch.tensor([0, 1.], requires_grad=True)
        with self.assertRaisesRegex(
                RuntimeError,
                'Shapes .* are not equal'):
            operator_compile_check(f, (x,), {})

    def test_missing_functionalization(self, device):
        # The op is declared in-place (``Tensor(a!)``) but has no
        # functionalization kernel.
        lib = self.lib()
        lib.define("foo(Tensor(a!) x) -> Tensor(a!)")
        op = self.ns().foo.default

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                ctx.mark_dirty(x)
                with torch._C._AutoDispatchBelowAutograd():
                    return op(x)

            @staticmethod
            def backward(ctx, gx):
                return gx

        def foo_impl(x):
            return x.sin_()

        def foo_meta(x):
            return x

        lib.impl("foo", Foo.apply, "Autograd")
        lib.impl("foo", foo_impl, "CPU")
        lib.impl("foo", foo_impl, "CUDA")
        lib.impl("foo", foo_meta, "Meta")

        def f(x):
            x = x.clone()
            y = op(x)
            return y.sum(0)

        x = torch.tensor([0, 1.], requires_grad=True)
        with self.assertRaisesRegex(
                RuntimeError,
                'Getting these operators to work with functionalization requires some extra work'):
            operator_compile_check(f, (x,), {})

    def test_autograd_registered_at_backend(self, device):
        # The autograd.Function is registered at the CPU/CUDA backend keys
        # instead of the Autograd key.
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                return x.clone()

            @staticmethod
            def backward(ctx, gx):
                return gx * 0.5

        lib.impl("foo", Foo.apply, "CPU")
        lib.impl("foo", Foo.apply, "CUDA")
        lib.impl("foo", lambda x: x.clone(), "Meta")

        def f(x):
            y = op(x)
            return x + y

        x = torch.randn([], requires_grad=True)
        with self.assertRaisesRegex(AssertionError, 'mismatched requires_grad-ness'):
            operator_compile_check(f, (x,), {})
        # NOTE(review): the original author was unsure why this is necessary;
        # the library is also tracked in self.libraries and released in
        # tearDown.
        del lib

    def test_global_state_mutation(self, device):
        # The kernel's result depends on a class-level counter that changes on
        # every call.
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default

        class Foo(torch.autograd.Function):
            invoked = 0

            @staticmethod
            def forward(ctx, x):
                Foo.invoked += 1
                return x.clone() * Foo.invoked

            @staticmethod
            def backward(ctx, gx):
                return gx

        lib.impl("foo", Foo.apply, "CompositeImplicitAutograd")

        def f(x):
            return op(x)

        x = torch.tensor(3.14159 / 3, requires_grad=True)
        with self.assertRaisesRegex(AssertionError, "not completely traceable"):
            operator_compile_check(f, (x,), {})

    @ops(custom_op_db, dtypes=OpDTypes.any_one)
    def test_operator_compile_check_op(self, device, dtype, op):
        # Every op in custom_op_db is expected to pass the check.
        # The NMS and Nonzero ops only run under dynamic shapes (presumably
        # because their output shapes are data-dependent). The answer does not
        # change per sample, so it is computed once outside the loop.
        dynamic_only = op.name in ("NumpyNMSCustomOp", "NumpyNonzeroCustomOp")
        for sample_input in op.sample_inputs(device, dtype, requires_grad=op.supports_autograd):
            args = [sample_input.input] + list(sample_input.args)
            kwargs = sample_input.kwargs
            operator_compile_check(
                op.op, args, kwargs,
                supports_autograd=op.supports_autograd,
                dynamic_only=dynamic_only,
                fullgraph=False,  # Dynamo graph breaks on CustomOp today
            )

    def test_operator_compile_check_fails_basic(self, device):
        # A CustomOp with no autograd formula, called on an input that
        # requires grad.
        @custom_op(f'{self.test_ns}::foo')
        def foo(x: torch.Tensor) -> torch.Tensor:
            ...

        @foo.impl(['cpu', 'cuda'])
        def foo_impl(x):
            return x.sum()

        x = torch.randn(3, device=device, requires_grad=True)
        # Triggers the CustomOp autograd NYI error
        with self.assertRaisesRegex(RuntimeError, "Autograd has not been implemented for operator"):
            operator_compile_check(lambda x: foo(x), (x,), {})

    def test_assert_raises_regex(self, device):
        # Tests the assert_raises_regex helper itself. Cases, in order:
        #   1-2. a matching exception is accepted;
        #   3.   the wrong exception type fails;
        #   4.   no exception fails;
        #   5.   a non-matching message fails.
        from torch.testing._internal.optests.aot_autograd import assert_raises_regex
        with assert_raises_regex(RuntimeError, 'c'):
            raise RuntimeError("abcd")
        with assert_raises_regex(RuntimeError, 'c.*'):
            raise RuntimeError("abcd")
        with self.assertRaisesRegex(AssertionError, 'instead got'):
            with assert_raises_regex(RuntimeError, 'c.*'):
                raise ValueError("abcd")
        with self.assertRaisesRegex(AssertionError, 'Expected exception'):
            with assert_raises_regex(RuntimeError, 'c.*'):
                pass
        with self.assertRaisesRegex(AssertionError, 'to match regex'):
            with assert_raises_regex(RuntimeError, 'f'):
                raise RuntimeError("abcd")
# Generate per-device copies of TestCustomOpTesting, restricted to CPU and
# CUDA, and write them into this module's globals().
only_for = ("cpu", "cuda")
instantiate_device_type_tests(
    TestCustomOpTesting,
    globals(),
    only_for=only_for,
)

if __name__ == '__main__':
    run_tests()