Skip to content

Commit

Permalink
Add TPUVM python disable test list (#2984)
Browse files Browse the repository at this point in the history
* Add TPUVM python disable test list

* Handle the XRT_TPU_CONFIG not set case
  • Loading branch information
JackCaoG authored Jun 8, 2021
1 parent e1b2dd2 commit bcc59d6
Showing 1 changed file with 20 additions and 1 deletion.
21 changes: 20 additions & 1 deletion test/pytorch_test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,17 @@
}
}

DISABLED_TORCH_TESTS_TPUVM_ONLY = {
# test_nn.py
'TestNNDeviceTypeXLA': {
'test_AdaptiveMaxPool1d_indices_xla', # TODO: segfualt on TPUVM
'test_AdaptiveMaxPool2d_indices_xla', # TODO: segfualt on TPUVM
'test_AdaptiveMaxPool3d_indices_xla', # TODO: segfualt on TPUVM
'test_MaxPool3d_indices_xla', # TODO: segfualt on TPUVM
'test_multi_margin_loss_errors_xla', # TODO: segfualt on TPUVM
},
}

DISABLED_TORCH_TESTS_GPU_ONLY = {
# test_torch.py
'TestTorchDeviceTypeXLA': {
Expand Down Expand Up @@ -406,10 +417,18 @@ def union_of_disabled_tests(sets):
return union


def on_tpuvm():
config = os.getenv('XRT_TPU_CONFIG')
return config and re.match('^localservice;[0-9]+;localhost:[0-9]+', config)


DISABLED_TORCH_TESTS_CPU = DISABLED_TORCH_TESTS_ANY
DISABLED_TORCH_TESTS_GPU = union_of_disabled_tests(
[DISABLED_TORCH_TESTS_ANY, DISABLED_TORCH_TESTS_GPU_ONLY])
DISABLED_TORCH_TESTS_TPU = union_of_disabled_tests(
DISABLED_TORCH_TESTS_TPU = union_of_disabled_tests([
DISABLED_TORCH_TESTS_ANY, DISABLED_TORCH_TESTS_TPU_ONLY,
DISABLED_TORCH_TESTS_TPUVM_ONLY
]) if on_tpuvm() else union_of_disabled_tests(
[DISABLED_TORCH_TESTS_ANY, DISABLED_TORCH_TESTS_TPU_ONLY])

DISABLED_TORCH_TESTS = {
Expand Down

0 comments on commit bcc59d6

Please sign in to comment.