Skip to content

Commit

Permalink
Allow parallel gpu tests (#910)
Browse files Browse the repository at this point in the history
* Parallel gpu tests

* Fix name

* Move to top level
  • Loading branch information
hanzhi713 authored Jan 8, 2025
1 parent c40b39a commit b8219e9
Showing 1 changed file with 19 additions and 0 deletions.
19 changes: 19 additions & 0 deletions conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Copyright © 2024 Apple Inc.

"""Configures pytest to distribute tests to multiple GPUs.
This is not enabled by default and requires explicit opt-in by setting the environment variable
PARALLEL_GPU_TEST. This is because not all GPU tests are single-GPU tests.
Example usage on 8 GPU machines:
PARALLEL_GPU_TEST=1 pytest -n 8 axlearn/common/flash_attention/gpu_attention_test.py
"""
import os


# pylint: disable-next=unused-argument
def pytest_configure(config):
if "PARALLEL_GPU_TEST" not in os.environ:
return
worker_id = os.getenv("PYTEST_XDIST_WORKER", "gw0")
os.environ["CUDA_VISIBLE_DEVICES"] = worker_id.lstrip("gw")

0 comments on commit b8219e9

Please sign in to comment.