env variable for conformer set at the top
init-22 committed Feb 3, 2025
1 parent 81bc93d commit f6ca2bc
Showing 1 changed file with 11 additions and 9 deletions: submission_runner.py
@@ -21,6 +21,12 @@
 import itertools
 import json
 import os
+
+os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"  # Disables tensorRT, cuda warnings.
+# disable only for deepspeech if it works fine for other workloads.
+os.environ['XLA_FLAGS'] = '--xla_gpu_enable_triton_gemm=false'
+
 import struct
 import time
 from types import MappingProxyType
@@ -30,12 +36,10 @@
 from absl import flags
 from absl import logging
 import jax
+import tensorflow as tf
 import torch
 import torch.distributed as dist
 
-os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"  # Disables tensorRT, cuda warnings.
-import tensorflow as tf
-
 # Hide any GPUs form TensorFlow. Otherwise TF might reserve memory and make
 # it unavailable to JAX.
 tf.config.set_visible_devices([], 'GPU')
@@ -52,9 +56,6 @@
 from algorithmic_efficiency.pytorch_utils import sync_ddp_time
 from algorithmic_efficiency.workloads import workloads
 
-# disable only for deepspeech if it works fine for other workloads.
-os.environ['XLA_FLAGS'] = '--xla_gpu_enable_triton_gemm=false'
-
 # TODO(znado): make a nicer registry of workloads that lookup in.
 BASE_WORKLOADS_DIR = workloads.BASE_WORKLOADS_DIR
 
@@ -702,12 +703,13 @@ def main(_):
   ]:
     os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.80'
 
+  if base_workload != 'librispeech_conformer':
+    # Remove the environment variable (only for workloads other than librispeech conformer).
+    del os.environ['PYTORCH_CUDA_ALLOC_CONF']
+
   if FLAGS.set_pytorch_max_split_size:
     os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256'
 
-  if FLAGS.framework == 'pytorch' and base_workload == 'librispeech_conformer':
-    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
-
   # Extend path according to framework.
   workload_metadata['workload_path'] = os.path.join(
       BASE_WORKLOADS_DIR,

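For readers outside the repository, the pattern this commit adopts is sketched below: allocator, logging, and XLA settings are exported at module import time, before the frameworks that read them are initialized, and the allocator setting is then trimmed per workload inside main(). This is a minimal, illustrative sketch, not the actual submission_runner.py; the configure_allocator helper and its arguments are hypothetical.

# Minimal sketch of the pattern in this commit (hypothetical helper; not the
# real submission_runner.py).
import os

# Set these before the frameworks that consume them are imported/initialized,
# hence placing them at the very top of the module.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
os.environ['XLA_FLAGS'] = '--xla_gpu_enable_triton_gemm=false'

# The commit relies on PYTORCH_CUDA_ALLOC_CONF being read lazily (when the
# CUDA allocator initializes), so removing it in main() before any CUDA work
# still takes effect even though torch is imported here.
import torch


def configure_allocator(base_workload: str,
                        set_pytorch_max_split_size: bool) -> None:
  """Keep expandable_segments only for the librispeech_conformer workload."""
  if base_workload != 'librispeech_conformer':
    # Mirror the commit: the top-of-file setting is dropped again for every
    # other workload.
    os.environ.pop('PYTORCH_CUDA_ALLOC_CONF', None)
  if set_pytorch_max_split_size:
    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256'


if __name__ == '__main__':
  configure_allocator('librispeech_conformer', set_pytorch_max_split_size=False)
  print(os.environ.get('PYTORCH_CUDA_ALLOC_CONF'))  # -> expandable_segments:True

The commit itself uses a bare del, which is safe there because the variable is always set at the top of the file; the pop(..., None) in the sketch simply makes the hypothetical helper safe to call when the variable is unset.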