Skip to content

Commit

Permalink
Merge pull request #840 from mlcommons/python_upgrades
Browse files Browse the repository at this point in the history
Python and package upgrades
  • Loading branch information
priyakasimbeg authored Feb 11, 2025
2 parents cabcc59 + f375099 commit 4345e8b
Show file tree
Hide file tree
Showing 61 changed files with 1,134 additions and 288 deletions.
48 changes: 24 additions & 24 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python 3.9
- name: Set up Python 3.11.10
uses: actions/setup-python@v4
with:
python-version: 3.9
python-version: 3.11.10
cache: 'pip' # Cache pip dependencies\.
cache-dependency-path: '**/setup.py'
- name: Install Modules and Run
Expand All @@ -25,10 +25,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python 3.9
- name: Set up Python 3.11.10
uses: actions/setup-python@v4
with:
python-version: 3.9
python-version: 3.11.10
cache: 'pip' # Cache pip dependencies\.
cache-dependency-path: '**/setup.py'
- name: Install Modules and Run
Expand All @@ -42,10 +42,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python 3.9
- name: Set up Python 3.11.10
uses: actions/setup-python@v4
with:
python-version: 3.9
python-version: 3.11.10
cache: 'pip' # Cache pip dependencies\.
cache-dependency-path: '**/setup.py'
- name: Install Modules and Run
Expand All @@ -59,10 +59,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python 3.9
- name: Set up Python 3.11.10
uses: actions/setup-python@v4
with:
python-version: 3.9
python-version: 3.11.10
cache: 'pip' # Cache pip dependencies\.
cache-dependency-path: '**/setup.py'
- name: Install Modules and Run
Expand All @@ -77,10 +77,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python 3.9
- name: Set up Python 3.11.10
uses: actions/setup-python@v4
with:
python-version: 3.9
python-version: 3.11.10
cache: 'pip' # Cache pip dependencies\.
cache-dependency-path: '**/setup.py'
- name: Install Modules and Run
Expand All @@ -96,10 +96,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python 3.9
- name: Set up Python 3.11.10
uses: actions/setup-python@v4
with:
python-version: 3.9
python-version: 3.11.10
cache: 'pip' # Cache pip dependencies\.
cache-dependency-path: '**/setup.py'
- name: Install Modules and Run
Expand All @@ -113,10 +113,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python 3.9
- name: Set up Python 3.11.10
uses: actions/setup-python@v4
with:
python-version: 3.9
python-version: 3.11.10
cache: 'pip' # Cache pip dependencies\.
cache-dependency-path: '**/setup.py'
- name: Install Modules and Run
Expand All @@ -130,10 +130,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python 3.9
- name: Set up Python 3.11.10
uses: actions/setup-python@v4
with:
python-version: 3.9
python-version: 3.11.10
cache: 'pip' # Cache pip dependencies\.
cache-dependency-path: '**/setup.py'
- name: Install Modules and Run
Expand All @@ -148,10 +148,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python 3.9
- name: Set up Python 3.11.10
uses: actions/setup-python@v4
with:
python-version: 3.9
python-version: 3.11.10
cache: 'pip' # Cache pip dependencies\.
cache-dependency-path: '**/setup.py'
- name: Install Modules and Run
Expand All @@ -166,10 +166,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python 3.9
- name: Set up Python 3.11.10
uses: actions/setup-python@v4
with:
python-version: 3.9
python-version: 3.11.10
cache: 'pip' # Cache pip dependencies\.
cache-dependency-path: '**/setup.py'
- name: Install Modules and Run
Expand All @@ -184,10 +184,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python 3.9
- name: Set up Python 3.11.10
uses: actions/setup-python@v4
with:
python-version: 3.9
python-version: 3.11.10
cache: 'pip' # Cache pip dependencies\.
cache-dependency-path: '**/setup.py'
- name: Install pytest
Expand All @@ -208,10 +208,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python 3.9
- name: Set up Python 3.11.10
uses: actions/setup-python@v4
with:
python-version: 3.9
python-version: 3.11.10
cache: 'pip' # Cache pip dependencies\.
cache-dependency-path: '**/setup.py'
- name: Install pytest
Expand Down
12 changes: 6 additions & 6 deletions .github/workflows/linting.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.9
- name: Set up Python 3.11.10
uses: actions/setup-python@v2
with:
python-version: 3.9
python-version: 3.11.10
- name: Install pylint
run: |
python -m pip install --upgrade pip
Expand All @@ -27,10 +27,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.9
- name: Set up Python 3.11.10
uses: actions/setup-python@v2
with:
python-version: 3.9
python-version: 3.11.10
- name: Install isort
run: |
python -m pip install --upgrade pip
Expand All @@ -43,10 +43,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.9
- name: Set up Python 3.11.10
uses: actions/setup-python@v2
with:
python-version: 3.9
python-version: 3.11.10
- name: Install yapf
run: |
python -m pip install --upgrade pip
Expand Down
2 changes: 1 addition & 1 deletion GETTING_STARTED.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ The specs on the benchmarking machines are:

> **Prerequisites:**
>
> - Python minimum requirement >= 3.8
> - Python minimum requirement >= 3.11
> - CUDA 12.1
> - NVIDIA Driver version 535.104.05
Expand Down
2 changes: 1 addition & 1 deletion algoperf/checkpoint_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def save_checkpoint(framework: str,
target=checkpoint_state,
step=global_step,
overwrite=True,
keep=np.Inf if save_intermediate_checkpoints else 1)
keep=np.inf if save_intermediate_checkpoints else 1)
else:
if not save_intermediate_checkpoints:
checkpoint_files = gfile.glob(
Expand Down
2 changes: 1 addition & 1 deletion algoperf/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def _prepare(x):
# Assumes that `global_batch_size % local_device_count == 0`.
return x.reshape((local_device_count, -1, *x.shape[1:]))

return jax.tree_map(_prepare, batch)
return jax.tree.map(_prepare, batch)


def pad(tensor: np.ndarray,
Expand Down
16 changes: 8 additions & 8 deletions algoperf/halton.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@
import functools
import itertools
import math
from typing import Any, Callable, Dict, List, Sequence, Text, Tuple, Union
from typing import Any, Callable, Dict, List, Sequence, Tuple, Union

from absl import logging
from numpy import random

_SweepSequence = List[Dict[Text, Any]]
_GeneratorFn = Callable[[float], Tuple[Text, float]]
_SweepSequence = List[Dict[str, Any]]
_GeneratorFn = Callable[[float], Tuple[str, float]]


def generate_primes(n: int) -> List[int]:
Expand Down Expand Up @@ -195,10 +195,10 @@ def generate_sequence(num_samples: int,
return halton_sequence


def _generate_double_point(name: Text,
def _generate_double_point(name: str,
min_val: float,
max_val: float,
scaling: Text,
scaling: str,
halton_point: float) -> Tuple[str, float]:
"""Generate a float hyperparameter value from a Halton sequence point."""
if scaling not in ['linear', 'log']:
Expand Down Expand Up @@ -234,7 +234,7 @@ def interval(start: int, end: int) -> Tuple[int, int]:
return start, end


def loguniform(name: Text, range_endpoints: Tuple[int, int]) -> _GeneratorFn:
def loguniform(name: str, range_endpoints: Tuple[int, int]) -> _GeneratorFn:
min_val, max_val = range_endpoints
return functools.partial(_generate_double_point,
name,
Expand All @@ -244,8 +244,8 @@ def loguniform(name: Text, range_endpoints: Tuple[int, int]) -> _GeneratorFn:


def uniform(
name: Text, search_points: Union[_DiscretePoints,
Tuple[int, int]]) -> _GeneratorFn:
name: str, search_points: Union[_DiscretePoints,
Tuple[int, int]]) -> _GeneratorFn:
if isinstance(search_points, _DiscretePoints):
return functools.partial(_generate_discrete_point,
name,
Expand Down
2 changes: 1 addition & 1 deletion algoperf/logger_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def _get_system_software_info() -> Dict:
system_software_info['os_platform'] = \
platform.platform() # Ex. 'Linux-5.4.48-x86_64-with-glibc2.29'
system_software_info['python_version'] = platform.python_version(
) # Ex. '3.8.10'
) # Ex. '3.11.10'
system_software_info['python_compiler'] = platform.python_compiler(
) # Ex. 'GCC 9.3.0'
# Note: do not store hostname as that may be sensitive
Expand Down
2 changes: 1 addition & 1 deletion algoperf/param_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def pytorch_param_types(

def jax_param_shapes(
params: spec.ParameterContainer) -> spec.ParameterShapeTree:
return jax.tree_map(lambda x: spec.ShapeTuple(x.shape), params)
return jax.tree.map(lambda x: spec.ShapeTuple(x.shape), params)


def jax_param_types(param_shapes: spec.ParameterShapeTree,
Expand Down
9 changes: 5 additions & 4 deletions algoperf/workloads/cifar/cifar_jax/workload.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from flax import jax_utils
from flax import linen as nn
from flax.core import pop
import jax
from jax import lax
import jax.numpy as jnp
Expand Down Expand Up @@ -74,8 +75,8 @@ def sync_batch_stats(
# In this case each device has its own version of the batch statistics
# and we average them.
avg_fn = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x')
new_model_state = model_state.copy(
{'batch_stats': avg_fn(model_state['batch_stats'])})
new_model_state = model_state.copy()
new_model_state['batch_stats'] = avg_fn(model_state['batch_stats'])
return new_model_state

def init_model_fn(
Expand All @@ -92,7 +93,7 @@ def init_model_fn(
input_shape = (1, 32, 32, 3)
variables = jax.jit(model.init)({'params': rng},
jnp.ones(input_shape, model.dtype))
model_state, params = variables.pop('params')
model_state, params = pop(variables, 'params')
self._param_shapes = param_utils.jax_param_shapes(params)
self._param_types = param_utils.jax_param_types(self._param_shapes)
model_state = jax_utils.replicate(model_state)
Expand Down Expand Up @@ -205,4 +206,4 @@ def _normalize_eval_metrics(
self, num_examples: int, total_metrics: Dict[str,
Any]) -> Dict[str, float]:
"""Normalize eval metrics."""
return jax.tree_map(lambda x: float(x[0] / num_examples), total_metrics)
return jax.tree.map(lambda x: float(x[0] / num_examples), total_metrics)
2 changes: 1 addition & 1 deletion algoperf/workloads/cifar/cifar_pytorch/workload.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def _build_dataset(
}
if split == 'eval_train':
train_indices = indices_split['train']
random.Random(data_rng[0]).shuffle(train_indices)
random.Random(int(data_rng[0])).shuffle(train_indices)
indices_split['eval_train'] = train_indices[:self.num_eval_train_examples]
if split in indices_split:
dataset = torch.utils.data.Subset(dataset, indices_split[split])
Expand Down
Loading

0 comments on commit 4345e8b

Please sign in to comment.