def _maybe_select_tpu_version():
  """Set up the matching TPU runtime version for Colab and Kaggle.

  Does nothing unless the ``TPU_NAME`` environment variable starts with
  ``grpc://``. Otherwise it asks ``cloud_tpu_client`` to configure the TPU
  with version ``pytorch-{__version__}`` (``restart_type='ifNeeded'``) and
  then blocks until the TPU address accepts TCP connections.

  If ``cloud_tpu_client`` is not installed, a warning is logged and the
  function returns without waiting. If configuring the version fails for any
  other reason, the function still waits for the TPU, but silently.

  Raises:
    RuntimeError: if the TPU address does not accept connections before the
      wait deadline expires.
  """

  def _is_open(ip, port, connect_timeout=None):
    # Probe (ip, port) with a TCP connect; True when the connect succeeds.
    # The `with` block closes the probe socket on every path, so repeated
    # polling does not leak file descriptors.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
      if connect_timeout is not None:
        # Bound a single probe so an unresponsive address cannot stall the
        # caller far beyond its own deadline. connect_ex reports a timeout
        # as a non-zero error code rather than raising.
        s.settimeout(connect_timeout)
      return s.connect_ex((ip, int(port))) == 0

  def _wait_for_open(version, timeout=100, interval=10, log=True):
    # Poll the TPU address every `interval` seconds until it accepts a
    # connection, raising RuntimeError once `timeout` seconds have elapsed.
    tpu_addr = os.environ['TPU_NAME'].split('grpc://')[1]
    ip, port = tpu_addr.split(':')
    # Monotonic clock: the deadline must not move if the wall clock is
    # adjusted while we are waiting.
    deadline = time.monotonic() + timeout

    while not _is_open(ip, port, connect_timeout=interval):
      if log:
        logging.warning(
            f'Waiting for TPU to be start up with version pytorch-{version}...')
      if time.monotonic() > deadline:
        raise RuntimeError('Timed out waiting for TPU to start up')
      time.sleep(interval)

    if log:
      logging.warning(
          f'TPU has started up successfully with version pytorch-{version}')

  tpu_name = os.environ.get('TPU_NAME', '')
  if not tpu_name.startswith('grpc://'):
    # Not colab/kaggle
    return

  log = True
  try:
    import cloud_tpu_client
    client = cloud_tpu_client.Client(tpu_name)
    client.configure_tpu_version(
        f'pytorch-{__version__}', restart_type='ifNeeded')
  except ImportError:
    logging.warning((
        'Not selecting corresponding TPU runtime since cloud_tpu_client is not '
        'installed. Ignore if not running on Colab/Kaggle TPU.'))
    return
  except Exception:
    # This path is hit when we get throttled by the version changer
    # when we import torch_xla from xmp.spawn-ed processes. Still wait for
    # the TPU to come up, but without logging.
    log = False

  # client.wait_for_healthy() API doesn't work as we dont have TPU API access.
  # The wait is kept outside the try block so that a timeout propagates
  # immediately instead of being swallowed and silently retried.
  _wait_for_open(__version__, log=log)