Skip to content

Commit

Permalink
Merge pull request #811 from init-22/python311
Browse files Browse the repository at this point in the history
Moving from Python3.8 to Python 3.11
  • Loading branch information
priyakasimbeg authored Dec 20, 2024
2 parents ea66793 + 53eff1d commit 9d1c957
Show file tree
Hide file tree
Showing 18 changed files with 598 additions and 125 deletions.
48 changes: 24 additions & 24 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python 3.9
- name: Set up Python 3.11.10
uses: actions/setup-python@v4
with:
python-version: 3.9
python-version: 3.11.10
cache: 'pip' # Cache pip dependencies\.
cache-dependency-path: '**/setup.py'
- name: Install Modules and Run
Expand All @@ -25,10 +25,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python 3.9
- name: Set up Python 3.11.10
uses: actions/setup-python@v4
with:
python-version: 3.9
python-version: 3.11.10
cache: 'pip' # Cache pip dependencies\.
cache-dependency-path: '**/setup.py'
- name: Install Modules and Run
Expand All @@ -42,10 +42,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python 3.9
- name: Set up Python 3.11.10
uses: actions/setup-python@v4
with:
python-version: 3.9
python-version: 3.11.10
cache: 'pip' # Cache pip dependencies\.
cache-dependency-path: '**/setup.py'
- name: Install Modules and Run
Expand All @@ -59,10 +59,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python 3.9
- name: Set up Python 3.11.10
uses: actions/setup-python@v4
with:
python-version: 3.9
python-version: 3.11.10
cache: 'pip' # Cache pip dependencies\.
cache-dependency-path: '**/setup.py'
- name: Install Modules and Run
Expand All @@ -77,10 +77,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python 3.9
- name: Set up Python 3.11.10
uses: actions/setup-python@v4
with:
python-version: 3.9
python-version: 3.11.10
cache: 'pip' # Cache pip dependencies\.
cache-dependency-path: '**/setup.py'
- name: Install Modules and Run
Expand All @@ -96,10 +96,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python 3.9
- name: Set up Python 3.11.10
uses: actions/setup-python@v4
with:
python-version: 3.9
python-version: 3.11.10
cache: 'pip' # Cache pip dependencies\.
cache-dependency-path: '**/setup.py'
- name: Install Modules and Run
Expand All @@ -113,10 +113,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python 3.9
- name: Set up Python 3.11.10
uses: actions/setup-python@v4
with:
python-version: 3.9
python-version: 3.11.10
cache: 'pip' # Cache pip dependencies\.
cache-dependency-path: '**/setup.py'
- name: Install Modules and Run
Expand All @@ -130,10 +130,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python 3.9
- name: Set up Python 3.11.10
uses: actions/setup-python@v4
with:
python-version: 3.9
python-version: 3.11.10
cache: 'pip' # Cache pip dependencies\.
cache-dependency-path: '**/setup.py'
- name: Install Modules and Run
Expand All @@ -148,10 +148,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python 3.9
- name: Set up Python 3.11.10
uses: actions/setup-python@v4
with:
python-version: 3.9
python-version: 3.11.10
cache: 'pip' # Cache pip dependencies\.
cache-dependency-path: '**/setup.py'
- name: Install Modules and Run
Expand All @@ -166,10 +166,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python 3.9
- name: Set up Python 3.11.10
uses: actions/setup-python@v4
with:
python-version: 3.9
python-version: 3.11.10
cache: 'pip' # Cache pip dependencies\.
cache-dependency-path: '**/setup.py'
- name: Install Modules and Run
Expand All @@ -184,10 +184,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python 3.9
- name: Set up Python 3.11.10
uses: actions/setup-python@v4
with:
python-version: 3.9
python-version: 3.11.10
cache: 'pip' # Cache pip dependencies\.
cache-dependency-path: '**/setup.py'
- name: Install pytest
Expand All @@ -208,10 +208,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python 3.9
- name: Set up Python 3.11.10
uses: actions/setup-python@v4
with:
python-version: 3.9
python-version: 3.11.10
cache: 'pip' # Cache pip dependencies\.
cache-dependency-path: '**/setup.py'
- name: Install pytest
Expand Down
12 changes: 6 additions & 6 deletions .github/workflows/linting.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.9
- name: Set up Python 3.11.10
uses: actions/setup-python@v2
with:
python-version: 3.9
python-version: 3.11.10
- name: Install pylint
run: |
python -m pip install --upgrade pip
Expand All @@ -27,10 +27,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.9
- name: Set up Python 3.11.10
uses: actions/setup-python@v2
with:
python-version: 3.9
python-version: 3.11.10
- name: Install isort
run: |
python -m pip install --upgrade pip
Expand All @@ -43,10 +43,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.9
- name: Set up Python 3.11.10
uses: actions/setup-python@v2
with:
python-version: 3.9
python-version: 3.11.10
- name: Install yapf
run: |
python -m pip install --upgrade pip
Expand Down
2 changes: 1 addition & 1 deletion GETTING_STARTED.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ The specs on the benchmarking machines are:

> **Prerequisites:**
>
> - Python minimum requirement >= 3.8
> - Python minimum requirement >= 3.11
> - CUDA 12.1
> - NVIDIA Driver version 535.104.05
Expand Down
2 changes: 1 addition & 1 deletion algorithmic_efficiency/checkpoint_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def save_checkpoint(framework: str,
target=checkpoint_state,
step=global_step,
overwrite=True,
keep=np.Inf if save_intermediate_checkpoints else 1)
keep=np.inf if save_intermediate_checkpoints else 1)
else:
if not save_intermediate_checkpoints:
checkpoint_files = gfile.glob(
Expand Down
16 changes: 8 additions & 8 deletions algorithmic_efficiency/halton.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@
import functools
import itertools
import math
from typing import Any, Callable, Dict, List, Sequence, Text, Tuple, Union
from typing import Any, Callable, Dict, List, Sequence, Tuple, Union

from absl import logging
from numpy import random

_SweepSequence = List[Dict[Text, Any]]
_GeneratorFn = Callable[[float], Tuple[Text, float]]
_SweepSequence = List[Dict[str, Any]]
_GeneratorFn = Callable[[float], Tuple[str, float]]


def generate_primes(n: int) -> List[int]:
Expand Down Expand Up @@ -195,10 +195,10 @@ def generate_sequence(num_samples: int,
return halton_sequence


def _generate_double_point(name: Text,
def _generate_double_point(name: str,
min_val: float,
max_val: float,
scaling: Text,
scaling: str,
halton_point: float) -> Tuple[str, float]:
"""Generate a float hyperparameter value from a Halton sequence point."""
if scaling not in ['linear', 'log']:
Expand Down Expand Up @@ -234,7 +234,7 @@ def interval(start: int, end: int) -> Tuple[int, int]:
return start, end


def loguniform(name: Text, range_endpoints: Tuple[int, int]) -> _GeneratorFn:
def loguniform(name: str, range_endpoints: Tuple[int, int]) -> _GeneratorFn:
min_val, max_val = range_endpoints
return functools.partial(_generate_double_point,
name,
Expand All @@ -244,8 +244,8 @@ def loguniform(name: Text, range_endpoints: Tuple[int, int]) -> _GeneratorFn:


def uniform(
name: Text, search_points: Union[_DiscretePoints,
Tuple[int, int]]) -> _GeneratorFn:
name: str, search_points: Union[_DiscretePoints,
Tuple[int, int]]) -> _GeneratorFn:
if isinstance(search_points, _DiscretePoints):
return functools.partial(_generate_discrete_point,
name,
Expand Down
2 changes: 1 addition & 1 deletion algorithmic_efficiency/logger_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def _get_system_software_info() -> Dict:
system_software_info['os_platform'] = \
platform.platform() # Ex. 'Linux-5.4.48-x86_64-with-glibc2.29'
system_software_info['python_version'] = platform.python_version(
) # Ex. '3.8.10'
) # Ex. '3.11.10'
system_software_info['python_compiler'] = platform.python_compiler(
) # Ex. 'GCC 9.3.0'
# Note: do not store hostname as that may be sensitive
Expand Down
16 changes: 8 additions & 8 deletions algorithmic_efficiency/random_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,30 +18,30 @@

# Annoyingly, RandomState(seed) requires seed to be in [0, 2 ** 32 - 1] (an
# unsigned int), while RandomState.randint only accepts and returns signed ints.
MAX_INT32 = 2**31
MIN_INT32 = -MAX_INT32
MAX_UINT32 = 2**32 - 1
MIN_UINT32 = 0

SeedType = Union[int, list, np.ndarray]


def _signed_to_unsigned(seed: SeedType) -> SeedType:
if isinstance(seed, int):
return seed % 2**32
return seed % MAX_UINT32
if isinstance(seed, list):
return [s % 2**32 for s in seed]
return [s % MAX_UINT32 for s in seed]
if isinstance(seed, np.ndarray):
return np.array([s % 2**32 for s in seed.tolist()])
return np.array([s % MAX_UINT32 for s in seed.tolist()])


def _fold_in(seed: SeedType, data: Any) -> List[Union[SeedType, Any]]:
rng = np.random.RandomState(seed=_signed_to_unsigned(seed))
new_seed = rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32)
new_seed = rng.randint(MIN_UINT32, MAX_UINT32, dtype=np.uint32)
return [new_seed, data]


def _split(seed: SeedType, num: int = 2) -> SeedType:
rng = np.random.RandomState(seed=_signed_to_unsigned(seed))
return rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32, size=[num, 2])
return rng.randint(MIN_UINT32, MAX_UINT32, dtype=np.uint32, size=[num, 2])


def _PRNGKey(seed: SeedType) -> SeedType: # pylint: disable=invalid-name
Expand Down Expand Up @@ -75,5 +75,5 @@ def split(seed: SeedType, num: int = 2) -> SeedType:
def PRNGKey(seed: SeedType) -> SeedType: # pylint: disable=invalid-name
if FLAGS.framework == 'jax':
_check_jax_install()
return jax_rng.PRNGKey(seed)
return jax_rng.key(seed)
return _PRNGKey(seed)
7 changes: 4 additions & 3 deletions algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from flax import jax_utils
from flax import linen as nn
from flax.core import pop
import jax
from jax import lax
import jax.numpy as jnp
Expand Down Expand Up @@ -75,8 +76,8 @@ def sync_batch_stats(
# In this case each device has its own version of the batch statistics
# and we average them.
avg_fn = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x')
new_model_state = model_state.copy(
{'batch_stats': avg_fn(model_state['batch_stats'])})
new_model_state = model_state.copy()
new_model_state['batch_stats'] = avg_fn(model_state['batch_stats'])
return new_model_state

def init_model_fn(
Expand All @@ -93,7 +94,7 @@ def init_model_fn(
input_shape = (1, 32, 32, 3)
variables = jax.jit(model.init)({'params': rng},
jnp.ones(input_shape, model.dtype))
model_state, params = variables.pop('params')
model_state, params = pop(variables, 'params')
self._param_shapes = param_utils.jax_param_shapes(params)
self._param_types = param_utils.jax_param_types(self._param_shapes)
model_state = jax_utils.replicate(model_state)
Expand Down
Loading

0 comments on commit 9d1c957

Please sign in to comment.