Skip to content

Commit

Permalink
Merge branch 'main' into sc_xla
Browse files Browse the repository at this point in the history
  • Loading branch information
Obliviour authored Jan 28, 2025
2 parents 944f1d0 + 8623e3a commit e40f4ca
Show file tree
Hide file tree
Showing 7 changed files with 401 additions and 83 deletions.
39 changes: 20 additions & 19 deletions MaxText/inference_microbenchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,9 @@
import datetime
import jax
import json
import sys

from absl import app
from collections.abc import MutableMapping
from typing import Any, Dict, Optional

from jetstream.engine import token_utils

Expand All @@ -36,14 +35,15 @@
warnings.simplefilter("ignore", category=FutureWarning)

_WARMUP_ITERS = 2

_FLATTEN_MICROBENCHMARK_RESULTS = False
# pylint: disable=too-many-positional-arguments


def prefill_benchmark_loop(engine, params, tokens, true_length, iters):
"""Inner loop for benchmarking prefill step."""
start = datetime.datetime.now()
rng = jax.random.PRNGKey(1234)
prefill_result = None
for _ in range(iters):
rng, rng_prefill = jax.random.split(rng)
prefill_result, _ = engine.prefill(params=params, padded_tokens=tokens, true_length=true_length, rng=rng_prefill)
Expand All @@ -56,6 +56,7 @@ def prefill_benchmark_loop(engine, params, tokens, true_length, iters):
def prefill_benchmark(config, engine, params, tokens, true_length, num_model_params, iters):
"""Handles warmup, running prefill benchmark, and printing results."""
rng = jax.random.PRNGKey(1234)
prefill_result = None
for _ in range(_WARMUP_ITERS):
rng, rng_prefill = jax.random.split(rng)
prefill_result, _ = engine.prefill(params=params, padded_tokens=tokens, true_length=true_length, rng=rng_prefill)
Expand Down Expand Up @@ -163,7 +164,7 @@ def ar_benchmark(config, engine, params, decode_state, global_batch_size, cache_
"step_in_ms_per_seq": ar_average_ms / global_batch_size,
"global_batch_size": global_batch_size,
"total_throughput_tokens_per_second": total_throughput,
"device_bandwidth_GB_per_second": bw_per_device,
"bw_per_device_GB_per_second": bw_per_device,
}
return result_dict, decode_state

Expand Down Expand Up @@ -197,7 +198,7 @@ def write_results(results, filename, flatten_microbenchmark_results):
"""Write the microbenchmark results to a JSON file."""
if flatten_microbenchmark_results:
results["flattened_results"] = flatten_dict(results)
if filename != "":
if filename:
with open(filename, "w", encoding="utf-8") as f:
json.dump(results, f, indent=2)
return results
Expand Down Expand Up @@ -246,7 +247,8 @@ def summarize_prefill_result(engine, params, tokens, true_length):
}


def main(config, inference_metadata: Optional[Dict[str, Any]] = None):
def run_benchmarks(config):
"""Run microbenchmarks."""
engine = maxengine.MaxEngine(config)
rng = jax.random.PRNGKey(1234)
rng, rng_load_params = jax.random.split(rng)
Expand Down Expand Up @@ -313,21 +315,20 @@ def main(config, inference_metadata: Optional[Dict[str, Any]] = None):

results = collate_results(config, benchmark_results, model_size, cache_size, num_model_params)
print_results_for_analyze(results)
if inference_metadata:
flatten_microbenchmark_results = pyconfig.string_to_bool(
inference_metadata.get("flatten_microbenchmark_results", "false")
if config.inference_microbenchmark_log_file_path:
write_results(
results,
filename=config.inference_microbenchmark_log_file_path,
flatten_microbenchmark_results=_FLATTEN_MICROBENCHMARK_RESULTS,
)
else:
flatten_microbenchmark_results = "false"
results = write_results(
results,
filename=config.inference_microbenchmark_log_file_path,
flatten_microbenchmark_results=flatten_microbenchmark_results,
)
return results


if __name__ == "__main__":
def main(argv):
jax.config.update("jax_default_prng_impl", "unsafe_rbg")
pyconfig.initialize(sys.argv)
main(pyconfig.config)
pyconfig.initialize(argv)
run_benchmarks(pyconfig.config)


if __name__ == "__main__":
app.run(main)
40 changes: 14 additions & 26 deletions MaxText/maxengine.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,12 @@
"""Implementation of Engine API for MaxText"""
import copy as cp
import functools
from typing import Any, Optional, Tuple, Callable
from typing import Any, List, Optional, Tuple, Callable
from collections import defaultdict

import flax
from flax import linen as nn
from flax.linen import partitioning as nn_partitioning
from flax import struct

from layers import models, quantizations

Expand All @@ -39,26 +38,15 @@
import max_utils
import inference_utils
import pyconfig
import jaxlib

import warnings

warnings.simplefilter("ignore", category=FutureWarning)

DecodeState = Any
Prefix = Any
PackedPrefix = Any
Params = Any


@struct.dataclass
class DecodeState:
"""The inputs into a generation step."""

prefill_cache: jax.Array
generate_cache: jax.Array
generate_cache_index: int
generate_lengths: jax.Array
generated_token: jax.Array
PRNGKeyType = Any


class MaxEngineConfig:
Expand Down Expand Up @@ -110,7 +98,7 @@ def __init__(self, config: Any, devices: config_lib.Devices | None = None):
self.kv_cache_shardings = None
self.state_mesh_annotations = None

def load_params(self, *args, rng: Optional[jax.random.PRNGKey] = None, **kwargs) -> Params:
def load_params(self, *args, rng: Optional[PRNGKeyType] = None, **kwargs) -> Params:
"""Load Parameters, typically from GCS"""
# pylint: disable=unused-argument

Expand All @@ -126,7 +114,7 @@ def load_params(self, *args, rng: Optional[jax.random.PRNGKey] = None, **kwargs)
# pylint: disable=isinstance-second-argument-not-valid-type
self.abstract_params = jax.tree_util.tree_map(
lambda x: jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding)
if isinstance(x, jaxlib.xla_extension.ArrayImpl)
if isinstance(x, jax.Array)
else None,
state.params,
)
Expand Down Expand Up @@ -158,7 +146,7 @@ def load_params(self, *args, rng: Optional[jax.random.PRNGKey] = None, **kwargs)
max_utils.print_mem_stats("After load_params")
return params

def quantize_params(self, state, rng: Optional[jax.random.PRNGKey] = None):
def quantize_params(self, state, rng: Optional[PRNGKeyType] = None):
"""Forward pass to quantize decode params."""
if rng is None:
rng = jax.random.PRNGKey(0)
Expand Down Expand Up @@ -227,7 +215,7 @@ def prefill(
padded_tokens: jax.Array,
true_length: int,
sampler: Optional[Callable[[Any], Any]] = None, # pylint: disable=unused-argument
rng: Optional[jax.random.PRNGKey] = None,
rng: Optional[PRNGKeyType] = None,
) -> Tuple[Prefix, engine_api.ResultTokens]:
"""Computes a kv-cache for a new generate request.
Expand Down Expand Up @@ -325,8 +313,8 @@ def prefill_concat(
true_lengths: jax.Array,
num_prompts: int,
sampler: Optional[Callable[[Any], Any]] = None, # pylint: disable=unused-argument
rng: Optional[jax.random.PRNGKey] = None,
) -> Tuple[Any, PackedPrefix, engine_api.ResultTokens]:
rng: Optional[PRNGKeyType] = None,
) -> Tuple[Any, PackedPrefix, List[engine_api.ResultTokens]]:
"""Computes a kv-cache for a new packed generate request, which is a
concatenation of several shorter prompts. Experimentation shows that
longer prefill sequences give an approximately 15% boost in time per prefilled
Expand Down Expand Up @@ -424,7 +412,7 @@ def generate(
params: Params,
decode_state: DecodeState,
sampler: Optional[Callable[[Any], Any]] = None, # pylint: disable=unused-argument
rng: Optional[jax.random.PRNGKey] = None,
rng: Optional[PRNGKeyType] = None,
) -> Tuple[DecodeState, engine_api.ResultTokens]:
"""Run one generate step"""
if rng is None:
Expand Down Expand Up @@ -718,7 +706,7 @@ def build_tokenizer(self, metadata: tokenizer_pb2.TokenizerParameters) -> tokeni
def init_decode_state(
self,
*args, # pylint: disable=unused-argument
rng: Optional[jax.random.PRNGKey] = None,
rng: Optional[PRNGKeyType] = None,
**kwargs, # pylint: disable=unused-argument
) -> DecodeState:
"""Initialises any state which a generation step transforms."""
Expand Down Expand Up @@ -820,9 +808,9 @@ def colocated_cpus(self) -> None:


def set_engine_vars_from_base_engine(
engine: engine_api.Engine,
base_engine: engine_api.Engine,
rng: jax.random.PRNGKey,
engine: MaxEngine,
base_engine: MaxEngine,
rng: PRNGKeyType,
):
"""Set internal vars from base_engine, which has already loaded the checkpoint and has sharding,
mesh, and kv cache related vars set.
Expand Down
4 changes: 2 additions & 2 deletions MaxText/tests/inference_microbenchmark_smoke_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import pytest
import unittest
from absl.testing import absltest
from inference_microbenchmark import main as inference_microbenchmark_main
from inference_microbenchmark import run_benchmarks


class Inference_Microbenchmark(unittest.TestCase):
Expand All @@ -38,7 +38,7 @@ def test(self):
"weight_dtype=bfloat16",
]
)
inference_microbenchmark_main(pyconfig.config)
run_benchmarks(pyconfig.config)


if __name__ == "__main__":
Expand Down
14 changes: 10 additions & 4 deletions benchmarks/Getting_Started_Benchmarking.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,26 @@ Two approaches are here:
- **benchmark_runner.py**: A CLI for running a specific model recipe with one command, on Pathways or McJax, either directly on-device or through an orchestrator such as xpk.

```shell
# McJax
# McJax with XPK
CLUSTER=my-cluster
ZONE=my-zone
PROJECT=my-project
python3 benchmarks/benchmark_runner.py --project $PROJECT --zone $ZONE --cluster_name $CLUSTER --device_type v6e-256 --base_output_directory gs://maxtext-experiments-tpem/ --num_steps=5
python3 benchmarks/benchmark_runner.py xpk --project $PROJECT --zone $ZONE --cluster_name $CLUSTER --device_type v6e-256 --base_output_directory gs://maxtext-experiments-tpem/ --num_steps=5
```

```shell
# Pathways
# Pathways with XPK
export RUNNER=us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/maxtext_jax_stable
export PROXY_IMAGE=us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/proxy_server:latest
export SERVER_IMAGE=us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/server:latest

python3 benchmarks/benchmark_runner.py --project $PROJECT --zone $ZONE --cluster_name $CLUSTER --device_type v6e-256 --base_output_directory gs://maxtext-experiments-tpem/ --num_steps=5 --pathways_server_image="${SERVER_IMAGE}" --pathways_proxy_image="${PROXY_IMAGE}" --pathways_runner_image="${RUNNER}"
python3 benchmarks/benchmark_runner.py xpk --project $PROJECT --zone $ZONE --cluster_name $CLUSTER --device_type v6e-256 --base_output_directory gs://maxtext-experiments-tpem/ --num_steps=5 --pathways_server_image="${SERVER_IMAGE}" --pathways_proxy_image="${PROXY_IMAGE}" --pathways_runner_image="${RUNNER}"
```

```shell
# On-device
# Run model benchmark on current device (must run same command on all workers).
python3 benchmarks/benchmark_runner.py on-device --base_output_directory gs://maxtext-experiments-tpem/ --run_name="test-run" --num_steps=5
```

- **maxtext_xpk_runner.py**: A Pythonic way to run xpk workloads! With the magic of for loops and Python code, run several xpk workloads across a sweep of parameters including libtpu version, GKE clusters, and MaxText parameters with one Python script.
Expand Down
Loading

0 comments on commit e40f4ca

Please sign in to comment.