Merge pull request #542 from mlcommons/dev
dev -> main
priyakasimbeg authored Oct 17, 2023
2 parents e19dacf + 45a7730 commit 8fbdc3a
Showing 58 changed files with 46 additions and 2,408 deletions.
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
@@ -216,7 +216,7 @@ pylint tests

## Unit and integration tests
We run unit tests and integration tests as part of the of github actions as well.
You can also use `python tests/reference_algorithm_tests.py` to run a single model update and two model evals for each workload using the reference algorithm in `reference_algorithms/development_algorithms/`.
You can also use `python tests/reference_algorithm_tests.py` to run a single model update and two model evals for each workload using the reference algorithm in `reference_algorithms/target_setting_algorithms/`.

## Regression tests
We also have regression tests available in [.github/workflows/regression_tests.yml](https://github.com/mlcommons/algorithmic-efficiency/tree/main/.github/workflows/regression_tests.yml) that can be run semi-automatically.
27 changes: 12 additions & 15 deletions algorithmic_efficiency/logger_utils.py
@@ -9,7 +9,7 @@
import shutil
import subprocess
import sys
from typing import Any, Optional
from typing import Any, Dict, Optional

from absl import flags
from clu import metric_writers
@@ -96,14 +96,14 @@ def write_hparams(hparams: spec.Hyperparameters,
return hparams


def write_json(name: str, log_dict: dict, indent: int = 2) -> None:
def write_json(name: str, log_dict: Dict, indent: int = 2) -> None:
if RANK == 0:
with open(name, 'w') as f:
f.write(json.dumps(log_dict, indent=indent))


def write_to_csv(
metrics: dict,
metrics: Dict,
csv_path: str,
) -> None:
try:
@@ -120,7 +120,7 @@ def write_to_csv(
return


def _get_utilization() -> dict:
def _get_utilization() -> Dict:
util_data = {}

# CPU
@@ -180,7 +180,7 @@ def _get_utilization() -> dict:
return util_data


def _get_system_hardware_info() -> dict:
def _get_system_hardware_info() -> Dict:
system_hardware_info = {}
try:
system_hardware_info['cpu_model_name'] = _get_cpu_model_name()
@@ -200,7 +200,7 @@ def _get_system_hardware_info() -> dict:
return system_hardware_info


def _get_system_software_info() -> dict:
def _get_system_software_info() -> Dict:
system_software_info = {}

system_software_info['os_platform'] = \
@@ -243,7 +243,7 @@ def _is_primitive_type(item: Any) -> bool:
return isinstance(item, primitive)


def _get_workload_properties(workload: spec.Workload) -> dict:
def _get_workload_properties(workload: spec.Workload) -> Dict:
workload_properties = {}
skip_list = ['param_shapes', 'model_params_types']
keys = [
@@ -262,7 +262,8 @@ def _get_workload_properties(workload: spec.Workload) -> dict:
return workload_properties


def get_meta_data(workload: spec.Workload) -> dict:
def get_meta_data(workload: spec.Workload,
rng_seed: Optional[int] = None) -> Dict:
meta_data = {}
workload_properties = _get_workload_properties(workload)
meta_data.update(workload_properties)
@@ -272,15 +273,11 @@ def get_meta_data(workload: spec.Workload) -> dict:
meta_data.update(system_software_info)
system_hardware_info = _get_system_hardware_info()
meta_data.update(system_hardware_info)
if rng_seed is not None:
meta_data.update({'rng_seed': rng_seed})
return meta_data


def save_meta_data(workload: spec.Workload, rng_seed: int, meta_file_name: str):
meta_data = get_meta_data(workload)
meta_data.update({'rng_seed': rng_seed})
write_json(meta_file_name, meta_data)


class MetricLogger(object):
"""Used to log all measurements during training.
@@ -308,7 +305,7 @@ def __init__(self,
wandb.config.update(hyperparameters._asdict())

def append_scalar_metrics(self,
metrics: dict,
metrics: Dict,
global_step: int,
preemption_count: Optional[int] = None,
is_eval: bool = False) -> None:
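A note on the logger_utils.py changes above: the standalone save_meta_data helper is removed and get_meta_data now accepts an optional rng_seed, so callers compose get_meta_data with write_json directly. Below is a minimal sketch of that calling pattern; the wrapper function and its name are illustrative assumptions, not part of this commit.

from typing import Optional

from algorithmic_efficiency import logger_utils
from algorithmic_efficiency import spec


def record_run_metadata(workload: spec.Workload,
                        meta_file_name: str,
                        rng_seed: Optional[int] = None) -> None:
  # get_meta_data now folds the RNG seed into the metadata dict itself,
  # so persisting run metadata is a metadata build plus a JSON dump.
  meta_data = logger_utils.get_meta_data(workload, rng_seed=rng_seed)
  logger_utils.write_json(meta_file_name, meta_data)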
3 changes: 1 addition & 2 deletions algorithmic_efficiency/pytorch_utils.py
@@ -27,9 +27,8 @@ def pytorch_init(use_pytorch_ddp: bool, rank: int, profiler: Profiler) -> None:
# Make sure no GPU memory is preallocated to Jax.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
# Only use CPU for Jax to avoid memory issues.
# Setting the corresponding environment variable here has no effect; it has to
# be done before jax and tensorflow (!) are imported for the first time.
jax.config.update('jax_platforms', 'cpu')
jax.config.update('jax_platform_name', 'cpu')
# From the docs: "(...) causes cuDNN to benchmark multiple convolution
# algorithms and select the fastest."
torch.backends.cudnn.benchmark = True
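The pytorch_utils.py hunk above adjusts how pytorch_init pins JAX to the CPU backend so that JAX does not compete with PyTorch for GPU memory. Below is a self-contained sketch of that CPU-pinning pattern, assuming a JAX version that accepts the 'jax_platform_name' config key; it is an illustration of the technique, not the repository's initialization code.

import os

# Keep JAX from preallocating GPU memory; this must be set before JAX
# initializes its backends.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'

import jax  # noqa: E402

# Pin JAX to the CPU backend so PyTorch keeps the GPUs to itself.
jax.config.update('jax_platform_name', 'cpu')

print(jax.devices())  # Expect only CPU devices to be listed.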
@@ -6,6 +6,7 @@
from flax import jax_utils
import jax
import jax.numpy as jnp
import numpy as np

from algorithmic_efficiency import param_utils
from algorithmic_efficiency import spec
@@ -147,7 +148,8 @@ def _eval_batch(self,
batch: Dict[str, spec.Tensor]) -> spec.Tensor:
# We do NOT psum inside of _eval_batch_pmapped, so the returned tensor of
# shape (local_device_count,) will all be different values.
return self._eval_batch_pmapped(params, batch).sum()
return np.array(
self._eval_batch_pmapped(params, batch).sum(), dtype=np.float64)


class Criteo1TbDlrmSmallTestWorkload(Criteo1TbDlrmSmallWorkload):
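The criteo1tb eval change above converts the summed per-device losses from the pmapped eval function into a float64 NumPy value on the host. A toy sketch of that conversion follows; attributing it to float32 accumulation error is an assumption, since the diff itself only shows the dtype change.

import jax.numpy as jnp
import numpy as np

# Stand-in for the summed per-device eval losses returned by a pmapped fn.
per_device = jnp.array([1234.5, 2345.25, 3456.125, 4567.0625], dtype=jnp.float32)

# Pull the device-side sum to the host as a float64 scalar.
loss_sum = np.array(per_device.sum(), dtype=np.float64)

# Accumulating many such batch sums in float64 avoids float32 rounding drift
# once the running total grows large relative to individual batch losses.
running_total = np.float64(0.0)
for _ in range(100):
  running_total += loss_sum

print(type(loss_sum), loss_sum.dtype, running_total)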
4 changes: 2 additions & 2 deletions algorithmic_efficiency/workloads/criteo1tb/workload.py
@@ -35,14 +35,14 @@ def has_reached_validation_target(self, eval_result: Dict[str,

@property
def validation_target_value(self) -> float:
return 0.123649
return 0.123735

def has_reached_test_target(self, eval_result: Dict[str, float]) -> bool:
return eval_result['test/loss'] < self.test_target_value

@property
def test_target_value(self) -> float:
return 0.126060
return 0.126041

@property
def loss_type(self) -> spec.LossType:
@@ -19,14 +19,14 @@ def has_reached_validation_target(self, eval_result: Dict[str,

@property
def validation_target_value(self) -> float:
return 0.084952
return 0.085884

def has_reached_test_target(self, eval_result: Dict[str, float]) -> bool:
return eval_result['test/wer'] < self.test_target_value

@property
def test_target_value(self) -> float:
return 0.053000
return 0.052981

@property
def loss_type(self) -> spec.LossType:
@@ -53,11 +53,11 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool:

@property
def validation_target_value(self) -> float:
return 0.118232
return 0.119936

@property
def test_target_value(self) -> float:
return 0.073397
return 0.074143

@property
def step_hint(self) -> int:
@@ -62,11 +62,11 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool:

@property
def validation_target_value(self) -> float:
return 0.118232
return 0.119936

@property
def test_target_value(self) -> float:
return 0.073397
return 0.074143

@property
def step_hint(self) -> int:
5 changes: 0 additions & 5 deletions reference_algorithms/development_algorithms/README.md

This file was deleted.

(Further file entries — additional deleted and empty files — did not load in this view; the rest of the 58-file diff is not shown here.)
