Skip to content

Commit

Permalink
Merge pull request #1073 from keshavb96:disable_tf_gpus
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 704508907
  • Loading branch information
maxtext authors committed Dec 10, 2024
2 parents 0fe43b7 + 37e3f44 commit 8e55ab1
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions MaxText/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import profiler
import pyconfig
import pathwaysutils # pylint: disable=unused-import
import tensorflow as tf

from vertex_tensorboard import VertexTensorboardManager
# Placeholder: internal
Expand Down Expand Up @@ -756,6 +757,9 @@ def train_loop(config, state=None):

def main(argv: Sequence[str]) -> None:
jax.config.update("jax_default_prng_impl", "unsafe_rbg")
# TF allocates extraneous GPU memory when using TFDS data
# this leads to CUDA OOMs. WAR for now is to hide GPUs from TF
tf.config.set_visible_devices([], 'GPU')
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
if "xla_tpu_spmd_rng_bit_generator_unsafe" not in os.environ.get("LIBTPU_INIT_ARGS", ""):
os.environ["LIBTPU_INIT_ARGS"] = os.environ.get("LIBTPU_INIT_ARGS", "") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true"
Expand Down

0 comments on commit 8e55ab1

Please sign in to comment.