Skip to content

Commit

Permalink
Add version selector code for pytorch-1.9. (#2982)
Browse files Browse the repository at this point in the history
  • Loading branch information
zcain117 authored Jun 8, 2021
1 parent b4e0dda commit e1b2dd2
Showing 1 changed file with 53 additions and 1 deletion.
54 changes: 53 additions & 1 deletion torch_xla/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,64 @@
import logging
import os
import re
import socket
import tempfile
import subprocess
import time

from .version import __version__

# Module path of the XRT server entry point, as it appears after `python3 -m`.
XRT_RUN_SERVER_PROCESS = 'torch_xla.core._xrt_run_server'
# Anchored pattern for a full command line of the form
# `python3 -m <module> <digits>`.
# NOTE(review): the dots in the module path are not escaped, so each one matches
# any character; presumably harmless for how the pattern is used -- confirm.
XRT_SERVER_REGEX = f'^python3 -m {XRT_RUN_SERVER_PROCESS} [0-9]+$'


def _maybe_select_tpu_version():
    """Set up the matching TPU runtime version for Colab and Kaggle.

    Does nothing unless the ``TPU_NAME`` environment variable starts with
    ``grpc://``. Otherwise it asks ``cloud_tpu_client`` to configure the TPU
    runtime as ``pytorch-<__version__>`` (restarting only if needed) and then
    polls the TPU address until it accepts TCP connections.

    If ``cloud_tpu_client`` is not installed, a warning is logged and no
    version is selected.

    Raises:
        RuntimeError: if the TPU port is still closed once the polling
            deadline has passed.
    """

    def _is_open(ip, port):
        """Return True if a TCP connection to ``ip:port`` can be established."""
        # This runs inside a polling loop; the context manager closes each
        # probe socket deterministically instead of leaving the descriptor
        # to be released by garbage collection.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            return s.connect_ex((ip, int(port))) == 0

    def _wait_for_open(version, timeout=100, interval=10, log=True):
        """Poll the address in ``TPU_NAME`` until its port accepts connections.

        Args:
            version: version string, used only in the log messages.
            timeout: seconds after which a still-closed port raises.
            interval: seconds to sleep between probes.
            log: whether to emit progress messages.

        Raises:
            RuntimeError: if the port is still closed after ``timeout`` seconds.
        """
        tpu_addr = os.environ['TPU_NAME'].split('grpc://')[1]
        # The address does not change while polling, so split it only once.
        host_and_port = tpu_addr.split(':')
        deadline = time.time() + timeout

        while not _is_open(*host_and_port):
            if log:
                logging.warning(
                    f'Waiting for TPU to be start up with version pytorch-{version}...')
            if time.time() > deadline:
                raise RuntimeError('Timed out waiting for TPU to start up')
            time.sleep(interval)

        if log:
            logging.warning(
                f'TPU has started up successfully with version pytorch-{version}')

    try:
        tpu_name = os.environ.get('TPU_NAME', '')
        if not tpu_name.startswith('grpc://'):
            # Not colab/kaggle
            return

        import cloud_tpu_client
        client = cloud_tpu_client.Client(tpu_name)
        client.configure_tpu_version(
            f'pytorch-{__version__}', restart_type='ifNeeded')
        # client.wait_for_healthy() API doesn't work as we don't have TPU API
        # access, so probe the TPU port directly instead.
        _wait_for_open(__version__)
    except ImportError:
        logging.warning((
            'Not selecting corresponding TPU runtime since cloud_tpu_client is not '
            'installed. Ignore if not running on Colab/Kaggle TPU.'))
    except Exception:
        # This path is hit when we get throttled by the version changer
        # when we import torch_xla from xmp.spawn-ed processes.
        # NOTE: a timeout raised by the wait above is also caught here, so in
        # that case the port is polled a second time (silently) before the
        # RuntimeError finally propagates.
        _wait_for_open(__version__, log=False)


def server_is_alive():
# pgrep returns 0 when at least one running process matches the requested name.
# Otherwise, the exit code is 1. If pgrep is not availiable in the system, it
Expand Down Expand Up @@ -77,6 +129,7 @@ def _summarize_fn_tracker():


# These need to be called before the _XLAC module is loaded.
# NOTE(review): the relative order of these calls presumably matters as well --
# confirm before reordering.
_maybe_select_tpu_version()
_setup_default_env()
_setup_grpc()
_setup_xla_flags()
Expand All @@ -86,7 +139,6 @@ def _summarize_fn_tracker():
import atexit
import torch
from ._patched_functions import _apply_patches
from .version import __version__
import _XLAC


Expand Down

0 comments on commit e1b2dd2

Please sign in to comment.