diff --git a/google3/third_party/py/jetstream/.github/CODEOWNERS b/.github/CODEOWNERS similarity index 100% rename from google3/third_party/py/jetstream/.github/CODEOWNERS rename to .github/CODEOWNERS diff --git a/google3/third_party/py/jetstream/.github/workflows/e2e_tests.yaml b/.github/workflows/e2e_tests.yaml similarity index 100% rename from google3/third_party/py/jetstream/.github/workflows/e2e_tests.yaml rename to .github/workflows/e2e_tests.yaml diff --git a/google3/third_party/py/jetstream/.github/workflows/release.yaml b/.github/workflows/release.yaml similarity index 100% rename from google3/third_party/py/jetstream/.github/workflows/release.yaml rename to .github/workflows/release.yaml diff --git a/google3/third_party/py/jetstream/.github/workflows/scripts/create_release.js b/.github/workflows/scripts/create_release.js similarity index 100% rename from google3/third_party/py/jetstream/.github/workflows/scripts/create_release.js rename to .github/workflows/scripts/create_release.js diff --git a/google3/third_party/py/jetstream/.github/workflows/unit_tests.yaml b/.github/workflows/unit_tests.yaml similarity index 100% rename from google3/third_party/py/jetstream/.github/workflows/unit_tests.yaml rename to .github/workflows/unit_tests.yaml diff --git a/google3/third_party/py/jetstream/.gitignore b/.gitignore similarity index 100% rename from google3/third_party/py/jetstream/.gitignore rename to .gitignore diff --git a/google3/third_party/py/jetstream/AUTHORS b/AUTHORS similarity index 100% rename from google3/third_party/py/jetstream/AUTHORS rename to AUTHORS diff --git a/google3/third_party/py/jetstream/CONTRIBUTING.md b/CONTRIBUTING.md similarity index 100% rename from google3/third_party/py/jetstream/CONTRIBUTING.md rename to CONTRIBUTING.md diff --git a/google3/third_party/py/jetstream/LICENSE b/LICENSE similarity index 100% rename from google3/third_party/py/jetstream/LICENSE rename to LICENSE diff --git a/google3/third_party/py/jetstream/MANIFEST.in b/MANIFEST.in similarity index 100% rename from google3/third_party/py/jetstream/MANIFEST.in rename to MANIFEST.in diff --git a/google3/third_party/py/jetstream/Makefile b/Makefile similarity index 100% rename from google3/third_party/py/jetstream/Makefile rename to Makefile diff --git a/google3/third_party/py/jetstream/README.md b/README.md similarity index 100% rename from google3/third_party/py/jetstream/README.md rename to README.md diff --git a/google3/third_party/py/jetstream/benchmarks/README.md b/benchmarks/README.md similarity index 100% rename from google3/third_party/py/jetstream/benchmarks/README.md rename to benchmarks/README.md diff --git a/google3/third_party/py/jetstream/benchmarks/__init__.py b/benchmarks/__init__.py similarity index 100% rename from google3/third_party/py/jetstream/benchmarks/__init__.py rename to benchmarks/__init__.py diff --git a/google3/third_party/py/jetstream/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py similarity index 71% rename from google3/third_party/py/jetstream/benchmarks/benchmark_serving.py rename to benchmarks/benchmark_serving.py index 97628372..68f0d296 100644 --- a/google3/third_party/py/jetstream/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -62,6 +62,7 @@ import asyncio from dataclasses import dataclass, field from datetime import datetime +import gc import json import random import time @@ -70,6 +71,7 @@ import grpc +from benchmarks.metrics import EventMetric, CounterMetric from jetstream.core.proto import jetstream_pb2 from 
jetstream.core.proto import jetstream_pb2_grpc from jetstream.engine.token_utils import load_vocab @@ -107,6 +109,40 @@ def str2bool(v: str) -> bool: raise ValueError(f"Invalid value '{v}'!") +class AsyncCounter: + """A counter class for counting and quota management with asyncio; + not thread-safe. It is safe with asyncio because value changes are done + outside of await statements. + """ + + def __init__(self, init_value: int, block_on_zero_seconds=0.002): + """ + Args: + init_value: Initial value for the counter. + block_on_zero_seconds: if greater than 0, the counter will spin when the + value hits 0, so it can be used for quota management. + """ + self._init_value = init_value + self._value = init_value + self._block_on_zero_seconds = block_on_zero_seconds + + async def inc(self): + self._value += 1 + + async def dec(self): + while True: + if self._value > 0 or self._block_on_zero_seconds <= 0.0: + self._value -= 1 + return + await asyncio.sleep(self._block_on_zero_seconds) + + def value(self): + return self._value + + def delta(self): + return self._init_value - self._value + + @dataclass class BenchmarkMetrics: """Data class to store benchmark metrics.""" @@ -117,12 +153,10 @@ class BenchmarkMetrics: request_throughput: float input_throughput: float output_throughput: float - mean_ttft_ms: float - median_ttft_ms: float - p99_ttft_ms: float - mean_tpot_ms: float - median_tpot_ms: float - p99_tpot_ms: float + + ttft: EventMetric # Time-to-first-token + ttst: EventMetric # Time-to-second-token + tpot: EventMetric # Time-per-output-token @dataclass @@ -136,24 +170,37 @@ class InputRequest: @dataclass class RequestFuncOutput: + """Data class to store the response of a request.""" + input_request: Optional[InputRequest] = None - generated_token_list: list[str] = field(default_factory=list) + generated_token_list: list[int] = field(default_factory=list) generated_text: str = "" success: bool = False - latency: float = 0 - ttft: float = 0 + latency_sec: float = 0 + ttft_sec: float = 0 + ttst_sec: float = 0 prompt_len: int = 0 # Flatten the structure and return only the necessary results def to_dict(self): + if self.input_request: + prompt = self.input_request.prompt + original_output = self.input_request.output + sample_idx = self.input_request.sample_idx + else: + prompt = None + original_output = None + sample_idx = None return { - "prompt": self.input_request.prompt, - "original_output": self.input_request.output, + "prompt": prompt, + "original_output": original_output, "generated_text": self.generated_text, "success": self.success, - "latency": self.latency, + "latency_sec": self.latency_sec, + "ttft_sec": self.ttft_sec, + "ttst_sec": self.ttst_sec, "prompt_len": self.prompt_len, - "sample_idx": self.input_request.sample_idx, + "sample_idx": sample_idx, } @@ -210,12 +257,16 @@ def load_sharegpt_dataset( return dataset -def load_openorca_dataset_pkl(): +def load_openorca_dataset_pkl( + dataset_path: str, +) -> list[tuple[Any, Any]]: + if not dataset_path: + dataset_path = "open_orca_gpt4_tokenized_llama.calibration_1000.pkl" # read pickle file samples = pandas.read_pickle( os.path.join( os.path.dirname(os.path.relpath(__file__)), - "open_orca_gpt4_tokenized_llama.calibration_1000.pkl", + dataset_path, ) ) @@ -376,15 +427,19 @@ def calculate_metrics( total_output = 0 total_input = 0 completed = 0 - per_token_latencies = [] - ttfts = [] + ttft = EventMetric("ttft", "Time-to-first-token", "ms") + ttst = EventMetric("ttst", "Time-to-second-token", "ms") + per_out_token_lat = EventMetric("TPOT", 
"Time-per-output-token", "ms") + output_sizes = [] for i in range(len(outputs)): if outputs[i].success: + completed += 1 output_len = len( outputs[i].generated_token_list if tokenizer != "test" else ["Ċ", "Ō", "Ɵ"] ) + output_sizes.append(output_len) total_output += output_len total_input += input_requests[i].prompt_len if output_len == 0: @@ -393,9 +448,13 @@ def calculate_metrics( output: {outputs[i]}""" ) continue - per_token_latencies.append(outputs[i].latency / output_len) - ttfts.append(outputs[i].ttft) - completed += 1 + ttft.record(outputs[i].ttft_sec * 1000) + ttst.record(outputs[i].ttst_sec * 1000) + per_out_token_lat.record(outputs[i].latency_sec / output_len * 1000) + + print("Mean output size:", float(np.mean(output_sizes))) + print("Median output size:", float(np.median(output_sizes))) + print("P99 output size:", float(np.percentile(output_sizes, 99))) metrics = BenchmarkMetrics( completed=completed, @@ -404,65 +463,99 @@ def calculate_metrics( request_throughput=completed / dur_s, input_throughput=total_input / dur_s, output_throughput=total_output / dur_s, - mean_ttft_ms=float(np.mean(ttfts) * 1000), - median_ttft_ms=float(np.median(ttfts) * 1000), - p99_ttft_ms=float(np.percentile(ttfts, 99) * 1000), - mean_tpot_ms=float(np.mean(per_token_latencies) * 1000), - median_tpot_ms=float(np.median(per_token_latencies) * 1000), - p99_tpot_ms=float(np.percentile(per_token_latencies, 99) * 1000), + ttft=ttft, + ttst=ttst, + tpot=per_out_token_lat, ) return metrics async def grpc_async_request( - api_url: str, request: Any -) -> tuple[list[str], float, float]: + api_url: str, + request: Any, + prefill_quota: AsyncCounter, + active_req_quota: AsyncCounter, + out_token_cnt: CounterMetric, +) -> tuple[list[int], float, float, float]: """Send grpc synchronous request since the current grpc server is sync.""" options = [("grpc.keepalive_timeout_ms", 10000)] async with grpc.aio.insecure_channel(api_url, options=options) as channel: stub = jetstream_pb2_grpc.OrchestratorStub(channel) - print("Making request") - ttft = 0 - token_list = [] request_start_time = time.perf_counter() response = stub.Decode(request) + token_list = [] + ttft = 0 + ttst = 0 + stream_resp_cnt = 0 async for resp in response: - if ttft == 0: + stream_resp_cnt += 1 + if stream_resp_cnt == 1: + await prefill_quota.inc() ttft = time.perf_counter() - request_start_time - token_list.extend(resp.stream_content.samples[0].token_ids) - latency = time.perf_counter() - request_start_time - return token_list, ttft, latency + if ttft > 2.0: + print(datetime.now(), f"slow TTFT {ttft:.2f}", prefill_quota.value()) + elif stream_resp_cnt == 2: + ttst = time.perf_counter() - request_start_time + resp_tokens = resp.stream_content.samples[0].token_ids + token_list.extend(resp_tokens) + out_token_cnt.increment(len(resp_tokens)) + await active_req_quota.inc() + req_latency = time.perf_counter() - request_start_time + return token_list, ttft, ttst, req_latency async def send_request( api_url: str, tokenizer: Any, input_request: InputRequest, + prefill_quota: AsyncCounter, + active_req_quota: AsyncCounter, + req_complete_cnt: CounterMetric, + out_token_cnt: CounterMetric, pbar: tqdm, ) -> RequestFuncOutput: """Send the request to JetStream server.""" - # Tokenization on client side following MLPerf standard. + # Tokenize on client side following MLPerf standard. 
token_ids = tokenizer.encode(input_request.prompt) + + # Send the request request = jetstream_pb2.DecodeRequest( token_content=jetstream_pb2.DecodeRequest.TokenContent( token_ids=token_ids ), max_tokens=input_request.output_len, + metadata=jetstream_pb2.DecodeRequest.Metadata( + start_time=time.perf_counter() + ), + ) + out_tokens, ttft_sec, ttst_sec, latency_sec = await grpc_async_request( + api_url, + request, + prefill_quota, + active_req_quota, + out_token_cnt, ) + req_complete_cnt.increment() + + # Collect per-request output and metrics. output = RequestFuncOutput() output.input_request = input_request output.prompt_len = input_request.prompt_len - generated_token_list, ttft, latency = await grpc_async_request( - api_url, request - ) - output.ttft = ttft - output.latency = latency - output.generated_token_list = generated_token_list + output.ttft_sec = ttft_sec + output.ttst_sec = ttst_sec + output.latency_sec = latency_sec + output.generated_token_list = out_tokens # generated_token_list is a list of token ids, decode it to generated_text. - output.generated_text = tokenizer.decode(generated_token_list) + output.generated_text = tokenizer.decode(out_tokens) output.success = True if pbar: + pbar.postfix = ( + f"#reqs: {active_req_quota.delta()}/" + f"{active_req_quota.value()}; " + f"#prefill: {prefill_quota.delta()}/" + f"{prefill_quota.value()}" + ) pbar.update(1) return output @@ -473,69 +566,112 @@ async def benchmark( input_requests: list[InputRequest], request_rate: float, disable_tqdm: bool, -): - """Benchmark the online serving performance.""" - pbar = None if disable_tqdm else tqdm(total=len(input_requests)) + prefill_quota: AsyncCounter, + active_req_quota: AsyncCounter, + is_warmup: bool = False, +) -> tuple[dict[str, float | int], list[RequestFuncOutput]]: + """Benchmark the online serving performance. - print(f"Traffic request rate: {request_rate}") + Args: + api_url: URL (e.g. host:port) of the JetStream server to send requests to. + tokenizer: The tokenizer used to convert texts into tokens that will be set + in requests. + input_requests: A list of requests to send. + request_rate: The number of requests to send per second. + disable_tqdm: Whether progress bar should be disabled or not. + prefill_quota: Quota for limiting pending prefill operations. + active_req_quota: Quota for limiting inflight requests. + is_warmup: Whether this run is to warm up the server. + + Return: + A tuple containing the performance statistics for all requests and a list + of responses from the executed requests. 
+ """ + print(f"Benchmarking with a total number of {len(input_requests)} requests") + print(f"Benchmarking with request rate of {request_rate}") + pbar = None if disable_tqdm else tqdm(total=len(input_requests)) + req_complete_cnt = CounterMetric( + "ReqCompleteCount", "Request Completion Counter" + ) + out_token_cnt = CounterMetric("OutTokenCount", "OutToken Counter") - benchmark_start_time = time.perf_counter() + # Run benchmarking tasks = [] + benchmark_start_time = time.perf_counter() async for request in get_request(input_requests, request_rate): + await prefill_quota.dec() + await active_req_quota.dec() tasks.append( asyncio.create_task( send_request( api_url=api_url, tokenizer=tokenizer, input_request=request, + prefill_quota=prefill_quota, + active_req_quota=active_req_quota, + req_complete_cnt=req_complete_cnt, + out_token_cnt=out_token_cnt, pbar=pbar, ) ) ) outputs = await asyncio.gather(*tasks) - - if not disable_tqdm and pbar: + if pbar is not None: pbar.close() - benchmark_duration = time.perf_counter() - benchmark_start_time - - metrics = calculate_metrics( - input_requests=input_requests, - outputs=outputs, - dur_s=benchmark_duration, - tokenizer=tokenizer, - ) - - print(f"Successful requests: {metrics.completed}") - print(f"Benchmark duration: {benchmark_duration:2f} s") - print(f"Total input tokens: {metrics.total_input}") - print(f"Total generated tokens: {metrics.total_output}") - print(f"Request throughput: {metrics.request_throughput:.2f} requests/s") - print(f"Input token throughput: {metrics.input_throughput:.2f} tokens/s") - print(f"Output token throughput: {metrics.output_throughput:.2f} tokens/s") - print(f"Mean TTFT: {metrics.mean_ttft_ms:.2f} ms") - print(f"Median TTFT: {metrics.median_ttft_ms:.2f} ms") - print(f"P99 TTFT: {metrics.p99_ttft_ms:.2f} ms") - print(f"Mean TPOT: {metrics.mean_tpot_ms:.2f} ms") - print(f"Median TPOT: {metrics.median_tpot_ms:.2f} ms") - print(f"P99 TPOT: {metrics.p99_tpot_ms:.2f} ms") - - result = { - "duration": benchmark_duration, - "completed": metrics.completed, - "total_input_tokens": metrics.total_input, - "total_output_tokens": metrics.total_output, - "request_throughput": metrics.request_throughput, - "input_throughput": metrics.input_throughput, - "output_throughput": metrics.output_throughput, - "mean_ttft_ms": metrics.mean_ttft_ms, - "median_ttft_ms": metrics.median_ttft_ms, - "p99_ttft_ms": metrics.p99_ttft_ms, - "mean_tpot_ms": metrics.mean_tpot_ms, - "median_tpot_ms": metrics.median_tpot_ms, - "p99_tpot_ms": metrics.p99_tpot_ms, - } - return result, outputs + # Compute metrics + output_metrics = {} + if not is_warmup: + # No need to calculate metrics when executing warmup requests + benchmark_duration = time.perf_counter() - benchmark_start_time + metrics = calculate_metrics( + input_requests=input_requests, + outputs=outputs, + dur_s=benchmark_duration, + tokenizer=tokenizer, + ) + print(f"Successful requests: {metrics.completed}") + print(f"Benchmark duration: {benchmark_duration:2f} s") + print(f"Total input tokens: {metrics.total_input}") + print(f"Total generated tokens: {metrics.total_output}") + print(f"Request throughput: {metrics.request_throughput:.2f} requests/s") + print(f"Input token throughput: {metrics.input_throughput:.2f} tokens/s") + print(f"Output token throughput: {metrics.output_throughput:.2f} tokens/s") + + print(f"{metrics.ttft.distribution_summary_str()}") + print(f"{metrics.ttst.distribution_summary_str()}") + print(f"{metrics.tpot.distribution_summary_str()}") + + # Calculate one rate for each 
10 sec window. Adjusts the window size if + # needed to use csv output below for plotting the rate over time. + window_size_sec = 10 + print( + f"----- Request complete rate time series " + f"(window_size = {window_size_sec} sec) -----" + ) + print(f"{req_complete_cnt.rate_over_window_to_csv(window_size_sec)}") + print( + f"----- Output token rate time series " + f"(window_size = {window_size_sec} sec) -----" + ) + print(f"{out_token_cnt.rate_over_window_to_csv(window_size_sec)}") + + output_metrics = { + "duration": benchmark_duration, + "completed": metrics.completed, + "total_input_tokens": metrics.total_input, + "total_output_tokens": metrics.total_output, + "request_throughput": metrics.request_throughput, + "input_throughput": metrics.input_throughput, + "output_throughput": metrics.output_throughput, + } + output_metrics = { + **output_metrics, + **metrics.ttft.distribution_summary_dict(), + **metrics.ttst.distribution_summary_dict(), + **metrics.tpot.distribution_summary_dict(), + } + return output_metrics, outputs def mock_requests(total_mock_requests: int): @@ -579,6 +715,9 @@ def main(args: argparse.Namespace): tokenizer_id = args.tokenizer use_hf_tokenizer = args.use_hf_tokenizer + prefill_quota = AsyncCounter(init_value=3) + active_req_quota = AsyncCounter(init_value=450) + api_url = f"{args.server}:{args.port}" tokenizer = get_tokenizer(model_id, tokenizer_id, use_hf_tokenizer) @@ -589,7 +728,7 @@ def main(args: argparse.Namespace): else: dataset = [] if args.dataset == "openorca": - dataset = load_openorca_dataset_pkl() + dataset = load_openorca_dataset_pkl(args.dataset_path) elif args.dataset == "sharegpt": dataset = load_sharegpt_dataset( args.dataset_path, @@ -613,17 +752,20 @@ def main(args: argparse.Namespace): warmup_requests = list(sample_warmup_requests(input_requests)) * 2 if warmup_requests: - print(f"Starting {args.warmup_mode} warmup:") - benchmark_result, request_outputs = asyncio.run( + print(f"Warmup (mode: {args.warmup_mode}) is starting.") + _, _ = asyncio.run( benchmark( api_url=api_url, tokenizer=tokenizer, input_requests=warmup_requests, request_rate=args.request_rate, disable_tqdm=args.disable_tqdm, + prefill_quota=prefill_quota, + active_req_quota=active_req_quota, + is_warmup=True, ) ) - print(f"{args.warmup_mode} warmup completed.") + print(f"Warmup (mode: {args.warmup_mode}) has completed.") # TODO: Replace this with warmup complete signal once supported. # Wait for server completely warmup before running the benchmark. @@ -636,6 +778,8 @@ def main(args: argparse.Namespace): input_requests=input_requests, request_rate=args.request_rate, disable_tqdm=args.disable_tqdm, + prefill_quota=prefill_quota, + active_req_quota=active_req_quota, ) ) @@ -692,7 +836,6 @@ def main(args: argparse.Namespace): if __name__ == "__main__": - parser = argparse.ArgumentParser( description="Benchmark the online serving throughput." 
) @@ -836,4 +979,5 @@ def main(args: argparse.Namespace): ) parsed_args = parser.parse_args() + gc.disable() main(parsed_args) diff --git a/google3/third_party/py/jetstream/benchmarks/eval_accuracy.py b/benchmarks/eval_accuracy.py similarity index 99% rename from google3/third_party/py/jetstream/benchmarks/eval_accuracy.py rename to benchmarks/eval_accuracy.py index 559cd2a8..f84562be 100644 --- a/google3/third_party/py/jetstream/benchmarks/eval_accuracy.py +++ b/benchmarks/eval_accuracy.py @@ -64,6 +64,7 @@ def main(args): eval_accuracy(request_outputs_dict) + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( diff --git a/benchmarks/metrics.py b/benchmarks/metrics.py new file mode 100644 index 00000000..1cd122c4 --- /dev/null +++ b/benchmarks/metrics.py @@ -0,0 +1,244 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Metrics util classes for collecting and managing metrics.""" + +import datetime + +import numpy as np +from typing import Tuple, List, Dict + + +def _floor_datetime_to_sec(timestamp: datetime.datetime) -> datetime.datetime: + """ "Floor the timestamp to the nearest most recent second""" + return timestamp.replace(microsecond=0) + + +def _now_floored_to_second() -> datetime.datetime: + """Return the current timestamp floored to the nearest most recent second a""" + now = datetime.datetime.now() + return _floor_datetime_to_sec(now) + + +class EventMetric: + """An event metric for distribution stats reporting. Not thread-safe.""" + + def __init__(self, name: str, description: str, unit: str = ""): + self._name = name + self._description = description + self._unit = unit + self._data = [] + + def data(self) -> List[float]: + """Returns all stored data points. + + Returns: + A list of data points in the order that was stored + """ + return self._data + + def record(self, value: float): + """Record a data point + + Args: + value: The data point to be stored. + """ + self._data.append(value) + + def percentile(self, percentile: int) -> float: + """Computes and returns the specified percentile of the collected data. + + Args: + percentile: The percentile to compute. + + Returns: + The computed percentile. + """ + if not 0 <= percentile <= 100: + raise ValueError(f"Percentile {percentile} is not in [0, 100]") + if not self._data: + raise ValueError( + f"No data points in metric {self._name} to compute percentile" + ) + return np.percentile(self._data, percentile) + + def mean(self) -> float: + """Calculates and returns the mean value of the collected data. + + Returns: + The mean value of the collected data + """ + if not self._data: + raise ValueError(f"No data points in metric {self._name} to compute mean") + return np.mean(self._data) + + def distribution_summary_str(self) -> str: + """Generates a string representation of the distribution summary + + Returns: + The string representation of the distribution summary including + mean, p50, p90 and p99. 
+ """ + s = "" + s += f"Mean {self._name}: {self.mean():.2f} {self._unit}\n" + s += f"Median {self._name}: {self.percentile(50):.2f} {self._unit}\n" + s += f"P99 {self._name}: {self.percentile(99):.2f} {self._unit}" + return s + + def distribution_summary_dict(self) -> dict[str, float]: + """Generates a dictionary representation of the distribution summary + + Returns: + A dictionary containing of the distribution summary including mean, + p50, p90 and p99. + """ + return { + f"mean_{self._name}_{self._unit}": self.mean(), + f"median_{self._name}_{self._unit}": self.percentile(50), + f"p99_{self._name}_{self._unit}": self.percentile(99), + } + + +class CounterMetric: + """A count metric for computing rates over time. Not thread-safe.""" + + def __init__(self, name: str, description: str): + self._name = name + self._description = description + self._data: dict[datetime.datetime, int] = {} + + def data(self) -> Dict[datetime.datetime, int]: + """Returns all stored data points. + + Returns: + A dictionary of data points where the key is the timestamp and the value + is the aggregated counts within the second of the timestamp. + """ + return self._data + + def total_count(self) -> int: + """Returns aggregated counts + + Returns: + The aggregated counts. + """ + return sum(self._data.values()) + + def total_duration_sec(self) -> int: + """Returns the duration between the first and last count increment + + Returns: + The duration (in seconds) between the first and last increment + (inclusive of both ends). + """ + start_time = min(self._data.keys()) + end_time = max(self._data.keys()) + return int((end_time - start_time).total_seconds() + 1) + + def increment( + self, count: int = 1, timestamp: datetime.datetime | None = None + ): + """Increment the counter by count + + Args: + count: The amount to increment + timestamp: The timestamp for the increment. Default to now if none is + provided. + """ + if timestamp is None: + cur_time = _now_floored_to_second() + else: + cur_time = _floor_datetime_to_sec(timestamp) + # Add timestamp with default value 0 if doesn't exist + cur_count = self._data.setdefault(cur_time, 0) + self._data[cur_time] = cur_count + count + return + + def rate(self) -> float: + """Calculates the rate of change between the first and last increments. + + Returns: + The rate of change between the first and last increments. + """ + if len(self._data.keys()) < 2: + raise ValueError( + "At least 2 data points are required to compute the rate" + ) + start_time = min(self._data.keys()) + end_time = max(self._data.keys()) + delta_time_sec = (end_time - start_time).total_seconds() + sorted_counts = [count for timestamp, count in sorted(self._data.items())] + delta_count = sum(sorted_counts[1:]) + return delta_count / delta_time_sec + + def rate_over_window( + self, window_size_sec: int + ) -> List[Tuple[datetime.datetime, float]]: + """Calculate the rates over time." + + Args: + window_size_sec: The size of the window in seconds for computing each + individual rate + + Returns: + A list of rates over time, where each element represents the rate of + change for the specified window size. 
+ """ + if len(self._data.keys()) < 2: + raise ValueError( + f"At least 2 different timestamp values are required to calculate " + f"the rate, but have only {len(self._data.keys())}" + ) + rates: List[Tuple[datetime.datetime, float]] = [] + sorted_data = sorted(self._data.items()) + + start_time, _ = sorted_data[0] + end_time, _ = sorted_data[-1] + cur_start_time = start_time + cur_end_time = cur_start_time + datetime.timedelta(seconds=window_size_sec) + cur_total_count = 0 + for data_point in sorted_data: + timestamp, count = data_point + if timestamp >= cur_end_time: + while timestamp >= cur_end_time: + rates.append((cur_start_time, cur_total_count / window_size_sec)) + cur_start_time = cur_end_time + cur_end_time = cur_start_time + datetime.timedelta( + seconds=window_size_sec + ) + cur_total_count = 0 + cur_total_count += count + if cur_start_time <= end_time: + delta_time_sec = (end_time - cur_start_time).total_seconds() + 1 + rates.append((cur_start_time, cur_total_count / delta_time_sec)) + return rates + + def rate_over_window_to_csv(self, window_size_sec: int) -> str: + """Compute and return the rates over time and return them in csv string + + Args: + window_size_sec: The size of the window in seconds for computing each + individual rate + + Returns: + A CSV string representation of the rates over time, with two rows: + the first row contains timestamps, and the second row contains rate + values. + """ + rates = self.rate_over_window(window_size_sec) + # Generate CSV string with two rows + timestamps = "TimeStamp," + ",".join([str(e[0]) for e in rates]) + values = "Value," + ",".join([f"{e[1]:.2f}" for e in rates]) + csv_output = timestamps + "\n" + values + return csv_output diff --git a/benchmarks/mlperf/README.md b/benchmarks/mlperf/README.md new file mode 100644 index 00000000..7a139b12 --- /dev/null +++ b/benchmarks/mlperf/README.md @@ -0,0 +1,155 @@ + +## Create TPU VM. +Follow these [instructions](https://cloud.google.com/tpu/docs/v5e-inference#tpu-vm) to create TPU v5e-8 VM and ssh into the VM + + +## Clone repo +``` +git clone https://github.com/mlcommons/inference.git +``` + +## Install loadgen +``` +apt-get install python3-dev +apt-get install build-essential -y +cd loadgen/ && pip install . +``` + +## Install eval dependencies +``` +pip install \ +transformers==4.31.0 \ +nltk==3.8.1 \ +evaluate==0.4.0 \ +absl-py==1.4.0 \ +rouge-score==0.1.2 \ +sentencepiece==0.1.99 \ +accelerate==0.21.0 +``` + +## Download data file +``` +cd / +export DATA_DISK_DIR=/loadgen_run_data +mkdir -p ${DATA_DISK_DIR} +cd ${DATA_DISK_DIR} +gsutil cp gs://cloud-tpu-inference-public/mlcommons/inference/language/llama2-70b/data/processed-openorca/open_orca_gpt4_tokenized_llama.calibration_1000.pkl . +mv open_orca_gpt4_tokenized_llama.calibration_1000.pkl processed-calibration-data.pkl + +gsutil cp gs://cloud-tpu-inference-public/mlcommons/inference/language/llama2-70b/data/processed-openorca/open_orca_gpt4_tokenized_llama.sampled_24576.pkl . +mv open_orca_gpt4_tokenized_llama.sampled_24576.pkl processed-data.pkl +cd /inference_mlperf4.1 +``` + +## Install Maxtext +``` +cd / +git clone git@github.com:google/maxtext.git +cd maxtext +git checkout offline_inf +cd maxtext/MaxText +``` + +## Checkpoint generation + +Steps to get a quantized llama2-70B checkpoint for v5e-8 + +Note llama2-70B model takes about 140G of memory and will not fit into a v5e-8. It must be downloaded onto a large machine (such as v5p-8) and quantized to a smaller quantized checkpoint to be loaded onto a v5e-8 machine. 
+ +* Obtain a llama2-70b checkpoint and convert it to a maxtext inference checkpoint. Please follow maxtext instructions specified here: https://github.com/google/maxtext/blob/main/getting_started/Run_Llama2.md + +* Convert the checkpoint into a quantized checkpoint + +To create an int8 DRQ checkpoint run the following steps: + +1. Define paths to load maxtext checkpoint from and save quantized checkpoint to. + +``` +export LOAD_PARAMS_PATH=gs://${USER}-bkt/llama2-70b-chat/param-only-decode-ckpt-maxtext/checkpoints/0/items + +export SAVE_QUANT_PARAMS_PATH=gs://${USER}-bkt/quantized/llama2-70b-chat +``` + +2. Run the following maxtext script to generate and save an int8 quantized checkpoint + +``` +export TOKENIZER_PATH=maxtext/assets/tokenizer.llama2 +cd maxtext && \ +python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${LOAD_PARAMS_PATH} max_prefill_predict_length=1024 max_target_length=2048 model_name=llama2-70b ici_fsdp_parallelism=1 ici_autoregressive_parallelism=1 ici_tensor_parallelism=-1 scan_layers=false weight_dtype=bfloat16 per_device_batch_size=11 attention=dot_product quantization=int8 save_quantized_params_path=${SAVE_QUANT_PARAMS_PATH} +``` + +Your checkpoint is generated at `$SAVE_QUANT_PARAMS_PATH`. This is used to set the `load_parameters_path` param in the `MAXENGINE_ARGS` env variable below. + +## HF login +``` +huggingface-cli login +``` + +## Loadgen settings +``` +cd Google/code/llama2-70b/tpu_v5e_8_jetstream_maxtext/scripts/ +export API_URL=0.0.0.0:9000 +export DATA_DISK_DIR=/loadgen_run_data +export DATASET_TYPE=full # for calibration run, DATASET_TYPE=calibration + +export MODEL_NAME=llama70b +export TOTAL_SAMPLE_COUNT=24576 # for calibration run, TOTAL_SAMPLE_COUNT=1000 +export LOG_INTERVAL=1000 +export BATCH_SIZE_EXP=8 +export USER_CONFIG=user.conf +``` + +## Offline Setup +``` +cd / +git clone git@github.com:google/maxtext.git +cd maxtext +git checkout offline_inf +cd maxtext/MaxText + +# For v5e use +export BATCH_AND_PREFILL_LEN="256,80|512,40|1024,20" + +# For v6e use +export BATCH_AND_PREFILL_LEN="256,216|512,108|1024,54" +export TOKENIZER_PATH=maxtext/assets/tokenizer.llama2 + +export MAXENGINE_ARGS="model_name=llama2-70b tokenizer_path=${TOKENIZER_PATH} quantization=int8 quantize_kvcache=True load_parameters_path=${SAVE_QUANT_PARAMS_PATH} checkpoint_is_quantized=True compute_axis_order=0,1,2,3 ar_cache_axis_order=0,1,2,3" +``` + +## Run offline performance + +``` +cd /maxtext/MaxText +bash ./llama_offline_performance_run.sh +``` + +## Run offline accuracy +``` +cd /maxtext/MaxText +bash ./llama_offline_accuracy_run.sh +``` + +## Run offline audit +``` +cd /maxtext/MaxText +bash ./llama_offline_audit_run.sh +``` + +## Run server performance +``` +cd Google/code/llama2-70b/tpu_v5e_8_jetstream_maxtext/scripts/ +bash ./generate_server_performance_run.sh +``` + +## Run server accuracy +``` +cd Google/code/llama2-70b/tpu_v5e_8_jetstream_maxtext/scripts/ +bash ./generate_server_accuracy_run.sh +``` + +## Run server audit +``` +cd Google/code/llama2-70b/tpu_v5e_8_jetstream_maxtext/scripts/ +bash ./generate_server_audit_run.sh +``` diff --git a/benchmarks/mlperf/backend.py b/benchmarks/mlperf/backend.py new file mode 100644 index 00000000..f6a3859c --- /dev/null +++ b/benchmarks/mlperf/backend.py @@ -0,0 +1,323 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""mlperf loadgen interface for LLama2.""" +import array +import concurrent.futures +import dataclasses +import json +import logging +from operator import itemgetter # pylint: disable=g-importing-member +import time +from typing import List, Optional, Any + +import numpy as np + +import dataset + +import mlperf_loadgen as lg + +import grpc +from jetstream.core.proto import jetstream_pb2 +from jetstream.core.proto import jetstream_pb2_grpc + +from transformers import AutoTokenizer + + +logging.basicConfig(level=logging.INFO) +log = logging.getLogger("backend.py") + + +@dataclasses.dataclass +class WarmupSample: + id: int + index: int + + +@dataclasses.dataclass +class StreamResponse: + result: str = "" + + +class ThreadedLMClient: + """Holds a thread pool and a loadgen client for LM inference.""" + + _thread_pool: concurrent.futures.ThreadPoolExecutor + _dataset: dataset.Dataset + _futures = List[concurrent.futures.Future] + + def __init__( + self, + is_stream: bool, + num_threads: int, + api_url: str, + dataset_object: dataset.Dataset, + input_mode: str, + output_mode: str, + tokenizer: Optional[AutoTokenizer] = None, + max_output_len: int = 1024, + log_interval: int = 1000, + ): + log.info(f"Initiating {self.__class__.__name__} ...") + self._is_stream = is_stream + self._input_mode = dataset.validate_sample_mode(input_mode) + self._output_mode = dataset.validate_sample_mode(output_mode) + if self._input_mode == "text" or self._output_mode == "text": + assert tokenizer is not None + self._tokenizer = tokenizer + self._max_output_len = max_output_len + + self._log_interval = log_interval + + self._thread_pool = concurrent.futures.ThreadPoolExecutor(num_threads) + self._api_url = api_url + self._dataset = dataset_object + self._futures = [] + self.pred_outputs = {} + self._resp_cnt = 0 + + log.info("Creating grpc channel with api_url {}".format(api_url)) + options = [("grpc.keepalive_timeout_ms", 10000)] + self._grpc_channel = grpc.insecure_channel(api_url, options=options) + + @property + def tokenizer(self): + return self._tokenizer + + def _log_resp_cnt(self): + self._resp_cnt += 1 + if self._resp_cnt % self._log_interval == 0: + log.info("Completed %d queries", self._resp_cnt) + + def process_single_sample_async(self, query_sample, warmup): + """Executes a single query and marks responses complete asynchronously. + + Args: + query_sample: Single prompt + warmup: Indicates that this is a warmup request. 
+ """ + future = self._thread_pool.submit( + self._process_sample, query_sample, warmup + ) + self._futures.append(future) + + def flush(self): + concurrent.futures.wait(self._futures) + self._futures = [] + + def _grpc_request(self, request, sample, warmup): + """Send grpc synchronous request since the current grpc server is sync.""" + stub = jetstream_pb2_grpc.OrchestratorStub(self._grpc_channel) + token_list = [] + ttft = 0 + start_time = time.perf_counter() + response = stub.Decode(request) + for resp in response: + if not warmup and self._is_stream and ttft == 0: + # TTFT for online mode + ttft = time.perf_counter() - start_time + log.info("TTFT {}ms".format(ttft * 1000)) + response_token_ids = resp.stream_content.samples[0].token_ids + assert len(response_token_ids) == 1 + response_token_ids = np.array(response_token_ids, dtype=np.int64) + response_array = array.array("B", response_token_ids.tobytes()) + response_info = response_array.buffer_info() + first_token_response = lg.QuerySampleResponse( + sample.id, response_info[0], response_info[1] + ) + lg.FirstTokenComplete([first_token_response]) + log.info("mark first token complete") + token_list.extend(resp.stream_content.samples[0].token_ids) + return token_list + + def _process_sample(self, sample, warmup): + """Processes a single sample.""" + sample_data = self._dataset.inputs[sample.index] + if self._input_mode == "text": + token_ids = self._tokenizer.encode(sample_data) + else: + assert self._input_mode == "tokenized" + token_ids = [int(token_id_str) for token_id_str in sample_data.split(",")] + + request = jetstream_pb2.DecodeRequest( + token_content=jetstream_pb2.DecodeRequest.TokenContent( + token_ids=token_ids + ), + max_tokens=self._max_output_len, + ) + generated_token_list = self._grpc_request(request, sample, warmup) + if not warmup: + response_token_ids = generated_token_list + n_tokens = len(response_token_ids) + response_token_ids = np.array(response_token_ids, dtype=np.int64) + response_array = array.array("B", response_token_ids.tobytes()) + response_info = response_array.buffer_info() + response_data = response_info[0] + response_size = response_info[1] * response_array.itemsize + query_sample_response = lg.QuerySampleResponse( + sample.id, response_data, response_size, n_tokens + ) + lg.QuerySamplesComplete([query_sample_response]) + log.info("mark query complete") + + pred_output = self._tokenizer.decode(response_token_ids) + self.pred_outputs[sample.index] = pred_output + self._log_resp_cnt() + + +class SUT: + """SUT.""" + + def __init__( + self, + scenario, + api_url, + is_stream, + input_mode, + output_mode, + max_output_len, + dataset_path, + total_sample_count, + tokenizer_path=None, + perf_count_override=None, + num_client_threads=200, + log_interval=1000, + batch_size_exp=5, + pred_outputs_log_path=None, + dataset_rename_cols="", + ): + log.info(f"Starting {scenario} SUT with {api_url}.") + self._is_stream = is_stream + self._input_mode = dataset.validate_sample_mode(input_mode) + self._output_mode = dataset.validate_sample_mode(output_mode) + assert tokenizer_path is not None + self._tokenizer = self.load_tokenizer(tokenizer_path) + self._max_output_len = max_output_len + self._api_url = api_url + self._dataset_path = dataset_path + self._total_sample_count = total_sample_count + self._perf_count_override = perf_count_override + self._num_client_threads = num_client_threads + self._log_interval = log_interval + self._batch_size_exp = batch_size_exp + self._pred_outputs_log_path = pred_outputs_log_path + 
+ log.info("Loading Dataset ... ") + self.dataset = dataset.Dataset( + dataset_path=self._dataset_path, + input_mode=self._input_mode, + total_sample_count=self._total_sample_count, + perf_count_override=self._perf_count_override, + dataset_rename_cols=dataset_rename_cols, + ) + + client_cls = ThreadedLMClient + self._client = client_cls( + is_stream=self._is_stream, + num_threads=self._num_client_threads, + api_url=self._api_url, + dataset_object=self.dataset, + input_mode=self._input_mode, + output_mode=self._output_mode, + tokenizer=self._tokenizer, + max_output_len=self._max_output_len, + log_interval=self._log_interval, + ) + + self.qsl = lg.ConstructQSL( + self.dataset.total_sample_count, + self.dataset.perf_count, + self.dataset.LoadSamplesToRam, + self.dataset.UnloadSamplesFromRam, + ) + + # We need to add some warmup to improve throughput estimation + log.info("Starting warmup....") + # Warm up with exponentially increasing batch sizes up to 32. + for batch_size_exp in range(self._batch_size_exp): + batch_size = 2**batch_size_exp + for warmup_id, warmup_idx in enumerate(range(batch_size)): + warmup_sample = WarmupSample(id=warmup_id, index=warmup_idx) + self._client.process_single_sample_async(warmup_sample, True) + self._client.flush() + + log.info("Warmup done....") + time.sleep(30) + self.sut = lg.ConstructSUT(self.issue_queries, self.flush_queries) + + def load_tokenizer( + self, tokenizer_path: Optional[str] = None + ) -> Optional[AutoTokenizer]: + """Returns tokenizer""" + if tokenizer_path is not None: + tokenizer = AutoTokenizer.from_pretrained( + tokenizer_path, + model_max_length=1024, + padding_side="left", + use_fast=True, + ) + tokenizer.pad_token = tokenizer.eos_token + return tokenizer + + def _sort_issue_queries(self, query_samples): + """Issue queries.""" + query_samples_with_length = [] + for query_sample in query_samples: + query_sample_token_length = self.dataset.inputs_with_token_lengths[ + query_sample.index + ][1] + query_samples_with_length.append( + (query_sample_token_length, query_sample) + ) + sorted_query_samples_with_length = sorted( + query_samples_with_length, key=itemgetter(0) + ) + sorted_query_samples = [x[1] for x in sorted_query_samples_with_length] + return sorted_query_samples + + def issue_queries(self, query_samples): + """Issue queries.""" + num_query_samples = len(query_samples) + if num_query_samples > 1: + log.info(f"Issuing {num_query_samples} queries. ") + query_samples = self._sort_issue_queries(query_samples) + for query_sample in query_samples: + self._client.process_single_sample_async(query_sample, False) + + def flush_queries(self): + """Flush queries.""" + log.info("Loadgen has completed issuing queries... ") + self._client.flush() + + if self._pred_outputs_log_path is not None: + + pred_outputs = [] + for idx, x in self._client.pred_outputs.items(): + pred_output = { + "qsl_idx": idx, + "intput": self._client._dataset.inputs[idx], + "data": x, + } + pred_outputs.append(pred_output) + log.info(f"Generated {len(pred_outputs)} prediction outputs") + + if pred_outputs: + self.accuracy_log = open(self._pred_outputs_log_path, "w") + self.accuracy_log.write(json.dumps(pred_outputs)) + self.accuracy_log.flush() + self.accuracy_log.close() + log.info("Dumpped prediction outputs to accuracy log... 
") + + def __del__(self): + print("Finished destroying SUT.") diff --git a/benchmarks/mlperf/dataset.py b/benchmarks/mlperf/dataset.py new file mode 100644 index 00000000..318ab9a1 --- /dev/null +++ b/benchmarks/mlperf/dataset.py @@ -0,0 +1,122 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import logging +import os + +import pandas as pd + +logging.basicConfig(level=logging.INFO) +log = logging.getLogger("dataset.py") + + +class Dataset: + + def __init__( + self, + dataset_path: str, + input_mode: str, + total_sample_count: int = 24576, + perf_count_override: int = 0, + dataset_rename_cols: str = "", + ): + if not os.path.isfile(dataset_path): + log.warn( + "Processed pickle file {} not found. Please check that the path is correct".format( + dataset_path + ) + ) + self.dataset_path = dataset_path + + self._input_mode = validate_sample_mode(input_mode) + self.dataset_rename_cols = dataset_rename_cols + self.load_processed_dataset() + + self.total_sample_count = min(len(self.input_ids_strs), total_sample_count) + self.perf_count = perf_count_override or self.total_sample_count + + @property + def input_ids_strs(self): + return self._input_ids_strs + + @property + def input_texts(self): + return self._input_texts + + @property + def input_token_lengths(self): + return self._input_token_lengths + + @property + def inputs(self): + return self._inputs + + @property + def inputs_with_token_lengths(self): + return self._inputs_with_token_lengths + + def load_processed_dataset(self): + processed_data = pd.read_pickle(self.dataset_path) + if self.dataset_rename_cols: + rename_dict = json.loads(self.dataset_rename_cols) + processed_data.rename(columns=rename_dict, inplace=True) + log.info(f"Renaming columns of dataset with mapping: {rename_dict}") + + self._input_ids_strs = [] + for input_ids in processed_data["tok_input"]: + input_ids_str = ",".join([str(input_id) for input_id in input_ids]) + self._input_ids_strs.append(input_ids_str) + + self._input_texts = [] + for input_text in processed_data["input"]: + self._input_texts.append(input_text) + + self._input_token_lengths = [] + for token_length in processed_data["tok_input_length"]: + self._input_token_lengths.append(token_length) + + log.info(f"input_mode is {self._input_mode}") + self._inputs = ( + self._input_ids_strs + if self._input_mode == "tokenized" + else self._input_texts + ) + log.info(f"example sample input is {self._inputs[0]}") + self._inputs_with_token_lengths = [ + (input_ids_str_or_input_text, token_length) + for input_ids_str_or_input_text, token_length in zip( + self._inputs, self._input_token_lengths + ) + ] + + def LoadSamplesToRam(self, sample_list): + pass + + def UnloadSamplesFromRam(self, sample_list): + pass + + def __del__(self): + pass + + +SAMPLE_MODE_CHOICES = ["tokenized", "text"] + + +def validate_sample_mode(sample_mode: str) -> str: + if sample_mode not in SAMPLE_MODE_CHOICES: + raise ValueError( + "The sample_mode should be set to either `tokenized` or `text`." 
+ ) + return sample_mode diff --git a/benchmarks/mlperf/evaluate-accuracy.py b/benchmarks/mlperf/evaluate-accuracy.py new file mode 100644 index 00000000..2fe79030 --- /dev/null +++ b/benchmarks/mlperf/evaluate-accuracy.py @@ -0,0 +1,129 @@ +import argparse +from transformers import AutoTokenizer +import nltk +import evaluate +import numpy as np +import json + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--checkpoint-path", + required=True, + help="Path to Llama2-70b-hf-chat checkpoint", + ) + parser.add_argument( + "--mlperf-accuracy-file", + required=True, + help="path to mlperf_log_accuracy.json", + ) + parser.add_argument( + "--dataset-file", + required=True, + help="path to processed openorca validation set", + ) + parser.add_argument("--verbose", action="store_true", help="verbose messages") + parser.add_argument( + "--dtype", + default="int64", + help="dtype of the accuracy log", + choices=["int32", "int64", "float"], + ) + args = parser.parse_args() + return args + + +def get_groundtruth(processed_dataset_file): + import pandas as pd + + data = pd.read_pickle(processed_dataset_file) + ground_truths = data["output"] + return ground_truths + + +def postprocess_text(preds, targets): + preds = [pred.strip() for pred in preds] + targets = [target.strip() for target in targets] + + # rougeLSum expects newline after each sentence + preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds] + targets = ["\n".join(nltk.sent_tokenize(target)) for target in targets] + + return preds, targets + + +def main(): + + args = get_args() + dataset_path = args.dataset_file + checkpoint_path = args.checkpoint_path + metric = evaluate.load("rouge") + nltk.download("punkt") + + tokenizer = AutoTokenizer.from_pretrained( + checkpoint_path, + model_max_length=2048, + padding_side="left", + use_fast=False, + ) + + targets = get_groundtruth(args.dataset_file) + + target_required = [] + preds_token_ids = [] + + eval_dtype = np.int64 + if args.dtype == "int32": + eval_dtype = np.int32 + elif args.dtype == "float": + eval_dtype = np.float32 + + with open(args.mlperf_accuracy_file, "r") as f: + results = json.load(f) + + seen = set() + gen_tok_len = 0 + for pred in results: + qsl_idx = pred["qsl_idx"] + if qsl_idx in seen: + continue + + seen.add(qsl_idx) + target = targets[qsl_idx] + target_required.append(target) + pred = np.frombuffer(bytes.fromhex(pred["data"]), eval_dtype) + + gen_tok_len += len(pred) + preds_token_ids.append(pred) + + preds_decoded_text = tokenizer.batch_decode( + preds_token_ids, skip_special_tokens=True + ) + + preds, targets = postprocess_text(preds_decoded_text, target_required) + + result = metric.compute( + predictions=preds, + references=targets, + use_stemmer=True, + use_aggregator=False, + ) + result = {k: round(np.mean(v) * 100, 4) for k, v in result.items()} + prediction_lens = [len(pred) for pred in preds] + gen_num = len(preds) + + result = { + **result, + "gen_len": np.sum(prediction_lens), + "gen_num": gen_num, + "gen_tok_len": gen_tok_len, + "tokens_per_sample": round(gen_tok_len / gen_num, 1), + } + + print("\nResults\n") + print(result) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/mlperf/main.py b/benchmarks/mlperf/main.py new file mode 100644 index 00000000..4c3fd7c4 --- /dev/null +++ b/benchmarks/mlperf/main.py @@ -0,0 +1,223 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse + +import gc +import logging +import os +import sys + +import backend + +import mlperf_loadgen as lg + +_MLPERF_ID = "mixtral-8x7b" + +sys.path.insert(0, os.getcwd()) + +logging.basicConfig(level=logging.INFO) +log = logging.getLogger("main.py") + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--scenario", + type=str, + choices=["Offline", "Server"], + default="Offline", + help="Scenario", + ) + parser.add_argument( + "--api-url", type=str, default=None, help="SAX published model path." + ) + parser.add_argument("--dataset-path", type=str, default=None, help="") + parser.add_argument("--tokenizer-path", type=str, default=None, help="") + parser.add_argument( + "--accuracy", action="store_true", help="Run accuracy mode" + ) + parser.add_argument("--is-stream", action="store_true", help="") + parser.add_argument( + "--input-mode", + type=str, + choices=["text", "tokenized"], + default="tokenized", + ) + parser.add_argument( + "--output-mode", + type=str, + choices=["text", "tokenized"], + default="tokenized", + ) + parser.add_argument( + "--max-output-len", type=int, default=1024, help="Maximum output len" + ) + parser.add_argument( + "--audit-conf", + type=str, + default="audit.conf", + help="audit config for LoadGen settings during compliance runs", + ) + parser.add_argument( + "--mlperf-conf", + type=str, + default="mlperf.conf", + help="mlperf rules config", + ) + parser.add_argument( + "--user-conf", + type=str, + default="user.conf", + help="user config for user LoadGen settings such as target QPS", + ) + parser.add_argument( + "--total-sample-count", + type=int, + default=24576, + help="Number of samples to use in benchmark.", + ) + parser.add_argument( + "--perf-count-override", + type=int, + default=None, + help="Overwrite number of samples to use in benchmark.", + ) + parser.add_argument( + "--output-log-dir", + type=str, + default="output-logs", + help="Where logs are saved.", + ) + parser.add_argument( + "--enable-log-trace", + action="store_true", + help="Enable log tracing. This file can become quite large", + ) + parser.add_argument( + "--num-client-threads", + type=int, + default=200, + help="Number of client threads to use", + ) + parser.add_argument("--batch-size-exp", type=int, default=6, help="") + parser.add_argument("--log-pred-outputs", action="store_true", help="") + parser.add_argument( + "--log-interval", + type=int, + default=1000, + help="Logging interval in seconds", + ) + parser.add_argument( + "--user-conf-override-path", + type=str, + default="", + help="When given overrides the default user.conf path", + ) + parser.add_argument( + "--rename-dataset-cols", + type=str, + default="", + help=( + "Rename some of the dataset columns to whats expected by code. For example, " + "mixtral dataset uses ref_token_length instead of ref_token_len. Format is a string dict " + 'eg. 
{"tok_input_len": "tok_input_length"}' + ), + ) + + args = parser.parse_args() + return args + + +scenario_map = { + "offline": lg.TestScenario.Offline, + "server": lg.TestScenario.Server, +} + + +def main(): + args = get_args() + + settings = lg.TestSettings() + settings.scenario = scenario_map[args.scenario.lower()] + if args.user_conf_override_path: + user_conf = args.user_conf_override_path + else: + user_conf = args.user_conf + + settings.FromConfig(args.mlperf_conf, _MLPERF_ID, args.scenario) + settings.FromConfig(user_conf, _MLPERF_ID, args.scenario) + log.info("Mlperf config: %s", args.mlperf_conf) + log.info("User config: %s", user_conf) + + if args.accuracy: + settings.mode = lg.TestMode.AccuracyOnly + log.warning( + "Accuracy run will generate the accuracy logs, but the evaluation of the log is not completed yet" + ) + else: + settings.mode = lg.TestMode.PerformanceOnly + settings.print_timestamps = True + + settings.use_token_latencies = True + + os.makedirs(args.output_log_dir, exist_ok=True) + log_output_settings = lg.LogOutputSettings() + log_output_settings.outdir = args.output_log_dir + log_output_settings.copy_summary_to_stdout = True + log_settings = lg.LogSettings() + log_settings.log_output = log_output_settings + log_settings.enable_trace = args.enable_log_trace + + sut = backend.SUT( + scenario=args.scenario.lower(), + api_url=args.api_url, + is_stream=args.is_stream, + input_mode=args.input_mode, + output_mode=args.output_mode, + max_output_len=args.max_output_len, + dataset_path=args.dataset_path, + total_sample_count=args.total_sample_count, + tokenizer_path=args.tokenizer_path, + perf_count_override=args.perf_count_override, + num_client_threads=args.num_client_threads, + log_interval=args.log_interval, + batch_size_exp=args.batch_size_exp, + pred_outputs_log_path=os.path.join( + args.output_log_dir, "pred_outputs_logger.json" + ) + if args.log_pred_outputs + else None, + dataset_rename_cols=args.rename_dataset_cols, + ) + + lgSUT = lg.ConstructSUT(sut.issue_queries, sut.flush_queries) + log.info("Starting Benchmark run") + lg.StartTestWithLogSettings( + lgSUT, sut.qsl, settings, log_settings, args.audit_conf + ) + + log.info("Run Completed!") + + log.info("Destroying SUT...") + lg.DestroySUT(lgSUT) + + log.info("Destroying QSL...") + lg.DestroyQSL(sut.qsl) + + +if __name__ == "__main__": + # Disable garbage collection to avoid stalls when running tests. + gc.disable() + main() diff --git a/benchmarks/mlperf/mlperf.conf b/benchmarks/mlperf/mlperf.conf new file mode 100644 index 00000000..1d036f4b --- /dev/null +++ b/benchmarks/mlperf/mlperf.conf @@ -0,0 +1,111 @@ +# The format of this config file is 'key = value'. +# The key has the format 'model.scenario.key'. Value is mostly int64_t. +# Model maybe '*' as wildcard. In that case the value applies to all models. +# All times are in milli seconds + +# Set performance_sample_count for each model. +# User can optionally set this to higher values in user.conf. 
+resnet50.*.performance_sample_count_override = 1024 +ssd-mobilenet.*.performance_sample_count_override = 256 +retinanet.*.performance_sample_count_override = 64 +bert.*.performance_sample_count_override = 10833 +dlrm.*.performance_sample_count_override = 204800 +dlrm-v2.*.performance_sample_count_override = 204800 +rnnt.*.performance_sample_count_override = 2513 +gptj.*.performance_sample_count_override = 13368 +llama2-70b.*.performance_sample_count_override = 24576 +llama3_1-405b.*.performance_sample_count_override = 8313 +stable-diffusion-xl.*.performance_sample_count_override = 5000 +rgat.*.performance_sample_count_override = 788379 +# set to 0 to let entire sample set to be performance sample +3d-unet.*.performance_sample_count_override = 0 + +# Set seeds. The seeds will be distributed two weeks before the submission. +*.*.qsl_rng_seed = 3066443479025735752 +*.*.sample_index_rng_seed = 10688027786191513374 +*.*.schedule_rng_seed = 14962580496156340209 +# Set seeds for TEST_05. The seeds will be distributed two weeks before the submission. +*.*.test05_qsl_rng_seed = 16799458546791641818 +*.*.test05_sample_index_rng_seed = 5453809927556429288 +*.*.test05_schedule_rng_seed = 5435552105434836064 + + +*.SingleStream.target_latency_percentile = 90 +*.SingleStream.min_duration = 600000 + +*.MultiStream.target_latency_percentile = 99 +*.MultiStream.samples_per_query = 8 +*.MultiStream.min_duration = 600000 +*.MultiStream.min_query_count = 662 +retinanet.MultiStream.target_latency = 528 + +# 3D-UNet uses equal issue mode because it has non-uniform inputs +3d-unet.*.sample_concatenate_permutation = 1 + +# R-GAT uses equal issue mode because it may have non-uniform inputs +rgat.*.sample_concatenate_permutation = 1 + +# LLM benchmarks have non-uniform inputs and outputs, and use equal issue mode for all latency scenario +gptj.*.sample_concatenate_permutation = 1 +llama2-70b.*.sample_concatenate_permutation = 1 +mixtral-8x7b.*.sample_concatenate_permutation = 1 +llama3_1-405b.*.sample_concatenate_permutation = 1 + +*.Server.target_latency = 10 +*.Server.target_latency_percentile = 99 +*.Server.target_duration = 0 +*.Server.min_duration = 600000 +resnet50.Server.target_latency = 15 +retinanet.Server.target_latency = 100 +bert.Server.target_latency = 130 +dlrm.Server.target_latency = 60 +dlrm-v2.Server.target_latency = 60 +rnnt.Server.target_latency = 1000 +gptj.Server.target_latency = 20000 +stable-diffusion-xl.Server.target_latency = 20000 +# Benchmarks that measure token latencies +llama2-70b.*.use_token_latencies = 1 +mixtral-8x7b.*.use_token_latencies = 1 +llama3_1-405b.*.use_token_latencies = 1 +# gptj benchmark infers token latencies +gptj.*.infer_token_latencies = 1 +gptj.*.token_latency_scaling_factor = 69 +# Only ttft and tpot are tracked for the llama2-70b, mixtral-8x7B & llama3_1-405b benchmark therefore target_latency = 0 +llama2-70b.Server.target_latency = 0 +llama2-70b.Server.ttft_latency = 2000 +llama2-70b.Server.tpot_latency = 200 + +mixtral-8x7b.Server.target_latency = 0 +mixtral-8x7b.Server.ttft_latency = 2000 +mixtral-8x7b.Server.tpot_latency = 200 + +llama3_1-405b.Server.target_latency = 0 +llama3_1-405b.Server.ttft_latency = 6000 +llama3_1-405b.Server.tpot_latency = 175 + +*.Offline.target_latency_percentile = 90 +*.Offline.min_duration = 600000 + +# In Offline scenario, we always have one query. But LoadGen maps this to +# min_sample_count internally in Offline scenario. 
If the dataset size is larger +# than 24576 we limit the min_query_count to 24576 and otherwise we use +# the dataset size as the limit + +resnet50.Offline.min_query_count = 24576 +retinanet.Offline.min_query_count = 24576 +dlrm-v2.Offline.min_query_count = 24576 +bert.Offline.min_query_count = 10833 +gptj.Offline.min_query_count = 13368 +rnnt.Offline.min_query_count = 2513 +3d-unet.Offline.min_query_count = 43 +stable-diffusion-xl.Offline.min_query_count = 5000 +llama2-70b.Offline.min_query_count = 24576 +llama3_1-405b.Offline.min_query_count = 8313 +mixtral-8x7b.Offline.min_query_count = 15000 +rgat.Offline.min_query_count = 788379 + +# These fields should be defined and overridden by user.conf. +*.SingleStream.target_latency = 10 +*.MultiStream.target_latency = 80 +*.Server.target_qps = 1.0 +*.Offline.target_qps = 1.0 diff --git a/benchmarks/mlperf/scripts/config_utils.sh b/benchmarks/mlperf/scripts/config_utils.sh new file mode 100755 index 00000000..866cd585 --- /dev/null +++ b/benchmarks/mlperf/scripts/config_utils.sh @@ -0,0 +1,16 @@ +export base_output_dir=gs://${USER}-tpu/mlperf-4.1 +export experiment_time=$(date +%Y-%m-%d-%H-%M) + +export tpu=v5e-16 +export model_name=llama2-70b +export attention=dot_product +export reshape_q=${reshape_q:=False} +export compute_axis_order=${compute_axis_order:=0,2,1,3} +export prefill_cache_axis_order=${prefill_cache_axis_order:=0,2,1,3} +export ar_cache_axis_order=${ar_cache_axis_order:=0,2,1,3} + +export config_file_path=MaxText/configs/v5e/inference/llama2_70b_v5e-16.yml +export ici_fsdp_parallelism=1 +export ici_autoregressive_parallelism=${ici_autoregressive_parallelism:=2} +export ici_tensor_parallelism=${ici_tensor_parallelism:=8} +export allow_split_physical_axes=True diff --git a/benchmarks/mlperf/scripts/config_w-b16-kv-b16.sh b/benchmarks/mlperf/scripts/config_w-b16-kv-b16.sh new file mode 100755 index 00000000..405f7f01 --- /dev/null +++ b/benchmarks/mlperf/scripts/config_w-b16-kv-b16.sh @@ -0,0 +1,12 @@ + +source config_utils.sh + +export checkpoint_path=gs://runner-maxtext-logs/2024-05-07-23-34/unscanned_chkpt/checkpoints/0/items + +export quant_mode=w-b16-kv-b16 +export checkpoint_is_quantized=False +export quantization= +export quantize_kvcache=False +export kv_quant_axis= +export kv_quant_dtype= +export per_device_batch_size=${per_device_batch_size:=6} diff --git a/benchmarks/mlperf/scripts/config_w-i8-kv-b16.sh b/benchmarks/mlperf/scripts/config_w-i8-kv-b16.sh new file mode 100755 index 00000000..baaa5e04 --- /dev/null +++ b/benchmarks/mlperf/scripts/config_w-i8-kv-b16.sh @@ -0,0 +1,12 @@ + +source config_utils.sh + +export checkpoint_path=gs://morgandu-tpu/checkpoints/quantized/aqt/llama2-70b-chat + +export quant_mode=w-i8-kv-b16 +export checkpoint_is_quantized=True +export quantization=int8 +export quantize_kvcache=False +export kv_quant_axis= +export kv_quant_dtype= +export per_device_batch_size=${per_device_batch_size:=14} diff --git a/benchmarks/mlperf/scripts/config_w-i8-kv-i8.sh b/benchmarks/mlperf/scripts/config_w-i8-kv-i8.sh new file mode 100755 index 00000000..883901e0 --- /dev/null +++ b/benchmarks/mlperf/scripts/config_w-i8-kv-i8.sh @@ -0,0 +1,12 @@ + +source config_utils.sh + +export checkpoint_path=gs://morgandu-tpu/checkpoints/quantized/aqt/llama2-70b-chat + +export quant_mode=w-i8-kv-i8 +export checkpoint_is_quantized=True +export quantization=int8 +export quantize_kvcache=True +export kv_quant_axis=heads_and_dkv +export kv_quant_dtype=int8 +export per_device_batch_size=${per_device_batch_size:=28} diff --git 
a/benchmarks/mlperf/scripts/config_w-i8w-kv-b16.sh b/benchmarks/mlperf/scripts/config_w-i8w-kv-b16.sh new file mode 100755 index 00000000..daf04912 --- /dev/null +++ b/benchmarks/mlperf/scripts/config_w-i8w-kv-b16.sh @@ -0,0 +1,12 @@ + +source config_utils.sh + +export checkpoint_path=gs://msingh-bkt/checkpoints/quant_llama2-70b-chat/mlperf_070924/int8w_ + +export quant_mode=w-i8w-kv-b16 +export checkpoint_is_quantized=True +export quantization=int8w +export quantize_kvcache=False +export kv_quant_axis= +export kv_quant_dtype= +export per_device_batch_size=${per_device_batch_size:=14} diff --git a/benchmarks/mlperf/scripts/config_w-i8w-kv-i8.sh b/benchmarks/mlperf/scripts/config_w-i8w-kv-i8.sh new file mode 100755 index 00000000..1f8a830d --- /dev/null +++ b/benchmarks/mlperf/scripts/config_w-i8w-kv-i8.sh @@ -0,0 +1,12 @@ + +source config_utils.sh + +export checkpoint_path=gs://msingh-bkt/checkpoints/quant_llama2-70b-chat/mlperf_070924/int8w_ + +export quant_mode=w-i8w-kv-i8 +export checkpoint_is_quantized=True +export quantization=int8w +export quantize_kvcache=True +export kv_quant_axis=heads_and_dkv +export kv_quant_dtype=int8 +export per_device_batch_size=${per_device_batch_size:=28} diff --git a/benchmarks/mlperf/scripts/download_loadgen_data.sh b/benchmarks/mlperf/scripts/download_loadgen_data.sh new file mode 100755 index 00000000..46c7eb23 --- /dev/null +++ b/benchmarks/mlperf/scripts/download_loadgen_data.sh @@ -0,0 +1,9 @@ +export DATA_DISK_DIR=/loadgen_run_data + +mkdir -p ${DATA_DISK_DIR} +cd ${DATA_DISK_DIR} +gsutil cp gs://cloud-tpu-inference-public/mlcommons/inference/language/llama2-70b/data/processed-openorca/open_orca_gpt4_tokenized_llama.calibration_1000.pkl . +mv open_orca_gpt4_tokenized_llama.calibration_1000.pkl processed-calibration-data.pkl + +gsutil cp gs://cloud-tpu-inference-public/mlcommons/inference/language/llama2-70b/data/processed-openorca/open_orca_gpt4_tokenized_llama.sampled_24576.pkl . +mv open_orca_gpt4_tokenized_llama.sampled_24576.pkl processed-data.pkl diff --git a/benchmarks/mlperf/scripts/generate_server_accuracy_run.sh b/benchmarks/mlperf/scripts/generate_server_accuracy_run.sh new file mode 100755 index 00000000..23d495b3 --- /dev/null +++ b/benchmarks/mlperf/scripts/generate_server_accuracy_run.sh @@ -0,0 +1,64 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +source run_utils.sh + +DATASET_NAME=$(get_dataset_name ${DATASET_TYPE}) +export DATASET_PATH=${DATA_DISK_DIR}/${DATASET_NAME}.pkl +export API_URL=${API_URL} +export LOADGEN_RUN_TYPE=server-accuracy +export OUTPUT_LOG_ID=${MODEL_NAME}-${DATASET_TYPE}-${LOADGEN_RUN_TYPE}-${LOADGEN_RUN_TIMESTAMP} +export OUTPUT_LOG_DIR=${DATA_DISK_DIR}/logs/${OUTPUT_LOG_ID} +export OUTPUT_ACCURACY_JSON_PATH=${OUTPUT_LOG_DIR}/mlperf_log_accuracy.json + +echo "LOADGEN_RUN_TYPE: ${LOADGEN_RUN_TYPE}" +echo "LOADGEN_RUN_TIMESTAMP: ${LOADGEN_RUN_TIMESTAMP}" +echo "DATASET_PATH: ${DATASET_PATH}" +echo "TOTAL_SAMPLE_COUNT: ${TOTAL_SAMPLE_COUNT}" +echo "API_URL: ${API_URL}" +echo "BATCH_SIZE_EXP: ${BATCH_SIZE_EXP}" +echo "OUTPUT_LOG_DIR: ${OUTPUT_LOG_DIR}" +echo "OUTPUT_ACCURACY_JSON_PATH: ${OUTPUT_ACCURACY_JSON_PATH}" +echo "USER_CONFIG: ${USER_CONFIG}" + +mkdir -p ${OUTPUT_LOG_DIR} && cp ../${USER_CONFIG} ${OUTPUT_LOG_DIR} + +# Accuracy Run +cd ../ && python3 main.py \ + --api-url ${API_URL} \ + --is-stream \ + --accuracy \ + --log-pred-outputs \ + --scenario Server \ + --input-mode tokenized \ + --output-mode tokenized \ + --max-output-len 1024 \ + --mlperf-conf ../mlperf.conf \ + --user-conf ${USER_CONFIG} \ + --audit-conf no-audit \ + --total-sample-count ${TOTAL_SAMPLE_COUNT} \ + --batch-size-exp ${BATCH_SIZE_EXP} \ + --dataset-path ${DATASET_PATH} \ + --tokenizer-path ${TOKENIZER_PATH} \ + --log-interval ${LOG_INTERVAL} \ + --num-client-threads ${NUM_CLIENT_THREADS} \ + --output-log-dir ${OUTPUT_LOG_DIR} 2>&1 | tee ${OUTPUT_LOG_DIR}/server_accuracy_log.log + +# Eval Run +if [ -e ${OUTPUT_ACCURACY_JSON_PATH} ]; then + python3 evaluate-accuracy.py \ + --checkpoint-path meta-llama/Llama-2-70b-chat-hf \ + --mlperf-accuracy-file ${OUTPUT_ACCURACY_JSON_PATH} \ + --dataset-file ${DATASET_PATH} 2>&1 | tee ${OUTPUT_LOG_DIR}/evaluate_server_accuracy_log.log +fi diff --git a/benchmarks/mlperf/scripts/generate_server_audit_run.sh b/benchmarks/mlperf/scripts/generate_server_audit_run.sh new file mode 100755 index 00000000..1a01dff5 --- /dev/null +++ b/benchmarks/mlperf/scripts/generate_server_audit_run.sh @@ -0,0 +1,53 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +source run_utils.sh + +DATASET_NAME=$(get_dataset_name ${DATASET_TYPE}) +export DATASET_PATH=${DATA_DISK_DIR}/${DATASET_NAME}.pkl +export API_URL=${API_URL} +export LOADGEN_RUN_TYPE=server-audit +export OUTPUT_LOG_ID=${MODEL_NAME}-${DATASET_TYPE}-${LOADGEN_RUN_TYPE}-${LOADGEN_RUN_TIMESTAMP} +export OUTPUT_LOG_DIR=${DATA_DISK_DIR}/logs/${OUTPUT_LOG_ID} + +echo "LOADGEN_RUN_TYPE: ${LOADGEN_RUN_TYPE}" +echo "LOADGEN_RUN_TIMESTAMP: ${LOADGEN_RUN_TIMESTAMP}" +echo "DATASET_PATH: ${DATASET_PATH}" +echo "TOTAL_SAMPLE_COUNT: ${TOTAL_SAMPLE_COUNT}" +echo "API_URL: ${API_URL}" +echo "BATCH_SIZE_EXP: ${BATCH_SIZE_EXP}" +echo "OUTPUT_LOG_DIR: ${OUTPUT_LOG_DIR}" +echo "USER_CONFIG: ${USER_CONFIG}" + +mkdir -p ${OUTPUT_LOG_DIR} && cp ../${USER_CONFIG} ${OUTPUT_LOG_DIR} + +# Audit Run +cd ../ && python3 main.py \ + --api-url ${API_URL} \ + --is-stream \ + --log-pred-outputs \ + --scenario Server \ + --input-mode tokenized \ + --output-mode tokenized \ + --max-output-len 1024 \ + --mlperf-conf ../mlperf.conf \ + --user-conf ${USER_CONFIG} \ + --audit-conf ../../../compliance/nvidia/TEST06/audit.config \ + --total-sample-count ${TOTAL_SAMPLE_COUNT} \ + --batch-size-exp ${BATCH_SIZE_EXP} \ + --dataset-path ${DATASET_PATH} \ + --tokenizer-path ${TOKENIZER_PATH} \ + --log-interval ${LOG_INTERVAL} \ + --num-client-threads ${NUM_CLIENT_THREADS} \ + --output-log-dir ${OUTPUT_LOG_DIR} 2>&1 | tee ${OUTPUT_LOG_DIR}/server_audit_log.log diff --git a/benchmarks/mlperf/scripts/generate_server_performance_run.sh b/benchmarks/mlperf/scripts/generate_server_performance_run.sh new file mode 100755 index 00000000..e8720180 --- /dev/null +++ b/benchmarks/mlperf/scripts/generate_server_performance_run.sh @@ -0,0 +1,55 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +source run_utils.sh + +DATASET_NAME=$(get_dataset_name ${DATASET_TYPE}) +export DATASET_PATH=${DATA_DISK_DIR}/${DATASET_NAME}.pkl +export API_URL=${API_URL} +export LOADGEN_RUN_TYPE=server-performance +export OUTPUT_LOG_ID=${MODEL_NAME}-${DATASET_TYPE}-${LOADGEN_RUN_TYPE}-${LOADGEN_RUN_TIMESTAMP} +export OUTPUT_LOG_DIR=${DATA_DISK_DIR}/logs/${OUTPUT_LOG_ID} + +echo "LOADGEN_RUN_TYPE: ${LOADGEN_RUN_TYPE}" +echo "LOADGEN_RUN_TIMESTAMP: ${LOADGEN_RUN_TIMESTAMP}" +echo "DATASET_PATH: ${DATASET_PATH}" +echo "TOTAL_SAMPLE_COUNT: ${TOTAL_SAMPLE_COUNT}" +echo "API_URL: ${API_URL}" +echo "BATCH_SIZE_EXP: ${BATCH_SIZE_EXP}" +echo "OUTPUT_LOG_DIR: ${OUTPUT_LOG_DIR}" +echo "USER_CONFIG: ${USER_CONFIG}" + +mkdir -p ${OUTPUT_LOG_DIR} && cp ../${USER_CONFIG} ${OUTPUT_LOG_DIR} +MIXTRAL_COLS_RENAME="{\"tok_input_len\": \"tok_input_length\", \"tok_ref_output_len\": \"tok_output_length\"}" + +# Perf Run +cd ../ && python3 main.py \ + --api-url ${API_URL} \ + --is-stream \ + --log-pred-outputs \ + --scenario Server \ + --input-mode tokenized \ + --output-mode tokenized \ + --max-output-len 1024 \ + --mlperf-conf ../mlperf.conf \ + --user-conf ${USER_CONFIG} \ + --audit-conf no-audit \ + --total-sample-count ${TOTAL_SAMPLE_COUNT} \ + --batch-size-exp ${BATCH_SIZE_EXP} \ + --dataset-path ${DATASET_PATH} \ + --tokenizer-path ${TOKENIZER_PATH} \ + --log-interval ${LOG_INTERVAL} \ + --num-client-threads ${NUM_CLIENT_THREADS} \ + --rename-dataset-cols "${MIXTRAL_COLS_RENAME}" \ + --output-log-dir ${OUTPUT_LOG_DIR} 2>&1 | tee ${OUTPUT_LOG_DIR}/server_performance_log.log diff --git a/benchmarks/mlperf/scripts/init.sh b/benchmarks/mlperf/scripts/init.sh new file mode 100755 index 00000000..4781d690 --- /dev/null +++ b/benchmarks/mlperf/scripts/init.sh @@ -0,0 +1,7 @@ +git clone https://github.com/google/jax.git +cd jax +git reset 44359cb30ab5cdbe97e6b78c2c64fe9f8add29ca --hard +pip install -e . +gsutil cp gs://zhihaoshan-maxtext-profiling/jax_proxy_stream_buffer/jaxlib-0.4.31.dev20240719-cp310-cp310-manylinux2014_x86_64-mlperf_version_3.whl . 
+mv jaxlib-0.4.31.dev20240719-cp310-cp310-manylinux2014_x86_64-mlperf_version_3.whl jaxlib-0.4.31.dev20240719-cp310-cp310-manylinux2014_x86_64.whl +pip install jaxlib-0.4.31.dev20240719-cp310-cp310-manylinux2014_x86_64.whl diff --git a/benchmarks/mlperf/scripts/init_loadgen.sh b/benchmarks/mlperf/scripts/init_loadgen.sh new file mode 100755 index 00000000..b62f1e87 --- /dev/null +++ b/benchmarks/mlperf/scripts/init_loadgen.sh @@ -0,0 +1,13 @@ +cd /inference_mlperf4.1/language/llama2-70b/tpu/scripts/ +export API_URL=0.0.0.0:9000 + +export DATA_DISK_DIR=/loadgen_run_data +export MODEL_NAME=llama70b +export LOG_INTERVAL=1000 +export BATCH_SIZE_EXP=10 +export USER_CONFIG=user.conf + +export DATASET_TYPE=full +export TOTAL_SAMPLE_COUNT=24576 +export NUM_CLIENT_THREADS=600 + diff --git a/benchmarks/mlperf/scripts/init_xprof.sh b/benchmarks/mlperf/scripts/init_xprof.sh new file mode 100755 index 00000000..347d4636 --- /dev/null +++ b/benchmarks/mlperf/scripts/init_xprof.sh @@ -0,0 +1,2 @@ +pip install keyring keyrings.google-artifactregistry-auth # install keyring +pip install --extra-index-url https://us-central1-python.pkg.dev/cloud-tpu-multipod-dev/multipod-python-repo/simple/ previewutilities \ No newline at end of file diff --git a/benchmarks/mlperf/scripts/launch_microbenchmark.sh b/benchmarks/mlperf/scripts/launch_microbenchmark.sh new file mode 100755 index 00000000..3764ff71 --- /dev/null +++ b/benchmarks/mlperf/scripts/launch_microbenchmark.sh @@ -0,0 +1,46 @@ +echo "config: ${config}" +source ./${config}.sh +export run_name=${model_name}_${tpu}_${attention}_ici_${ici_tensor_parallelism}-${ici_autoregressive_parallelism}_${reshape_q}_${quant_mode}_pbs${per_device_batch_size}_${compute_axis_order//,/}-${prefill_cache_axis_order//,/}-${ar_cache_axis_order//,/} + +export inference_microbenchmark_stages=${inference_microbenchmark_stages:="prefill,generate"} +export inference_microbenchmark_prefill_lengths=${inference_microbenchmark_prefill_lengths:="64,128,256,512,1024"} +echo "inference_microbenchmark_stages: ${inference_microbenchmark_stages}" +echo "inference_microbenchmark_prefill_lengths: ${inference_microbenchmark_prefill_lengths}" + +cd /maxtext +export run_dir=${base_output_dir}/microbenchmark/${run_name}/${experiment_time}/ +echo "run_dir: ${run_dir}" +gsutil cp ${config_file_path} ${run_dir}/ + +python3 MaxText/inference_microbenchmark.py \ + ${config_file_path} \ + model_name=${model_name} \ + tokenizer_path=assets/tokenizer.llama2 \ + load_parameters_path=${checkpoint_path} \ + async_checkpointing=false \ + weight_dtype=bfloat16 \ + attention=dot_product \ + reshape_q=${reshape_q} \ + scan_layers=false \ + max_prefill_predict_length=1024 \ + max_target_length=2048 \ + base_output_directory=${base_output_dir}/microbenchmark \ + run_name=${run_name}/${experiment_time} \ + save_config_to_gcs=true \ + profiler=xplane \ + enable_single_controller=true \ + ici_tensor_parallelism=${ici_tensor_parallelism} \ + ici_autoregressive_parallelism=${ici_autoregressive_parallelism} \ + allow_split_physical_axes=${allow_split_physical_axes} \ + inference_microbenchmark_prefill_lengths=${inference_microbenchmark_prefill_lengths} \ + inference_microbenchmark_stages=${inference_microbenchmark_stages} \ + inference_microbenchmark_loop_iters=10 \ + per_device_batch_size=${per_device_batch_size} \ + quantization=${quantization} \ + quantize_kvcache=${quantize_kvcache} \ + kv_quant_axis=${kv_quant_axis} \ + kv_quant_dtype=${kv_quant_dtype} \ + checkpoint_is_quantized=${checkpoint_is_quantized} \ + 
compute_axis_order=${compute_axis_order} \ + prefill_cache_axis_order=${prefill_cache_axis_order} \ + ar_cache_axis_order=${ar_cache_axis_order} 2>&1 | tee results.log && gsutil mv results.log ${run_dir}/ diff --git a/benchmarks/mlperf/scripts/launch_server.sh b/benchmarks/mlperf/scripts/launch_server.sh new file mode 100755 index 00000000..e433b605 --- /dev/null +++ b/benchmarks/mlperf/scripts/launch_server.sh @@ -0,0 +1,33 @@ +echo "config: ${config}" +source ./${config}.sh +export run_name=${model_name}_${tpu}_${attention}_ici_${ici_tensor_parallelism}-${ici_autoregressive_parallelism}_${reshape_q}_${quant_mode}_pbs${per_device_batch_size}_${compute_axis_order//,/}-${prefill_cache_axis_order//,/}-${ar_cache_axis_order//,/} + +cd /maxtext +python3 MaxText/maxengine_server.py \ + ${config_file_path} \ + model_name=${model_name} \ + tokenizer_path=assets/tokenizer.llama2 \ + load_parameters_path=${checkpoint_path} \ + async_checkpointing=false \ + weight_dtype=bfloat16 \ + attention=dot_product \ + reshape_q=${reshape_q} \ + scan_layers=false \ + max_prefill_predict_length=1024 \ + max_target_length=2048 \ + base_output_directory=${base_output_dir}/server \ + run_name=${run_name}/${experiment_time} \ + save_config_to_gcs=true \ + enable_single_controller=true \ + ici_tensor_parallelism=${ici_tensor_parallelism} \ + ici_autoregressive_parallelism=${ici_autoregressive_parallelism} \ + allow_split_physical_axes=${allow_split_physical_axes} \ + per_device_batch_size=${per_device_batch_size} \ + quantization=${quantization} \ + quantize_kvcache=${quantize_kvcache} \ + kv_quant_axis=${kv_quant_axis} \ + kv_quant_dtype=${kv_quant_dtype} \ + checkpoint_is_quantized=${checkpoint_is_quantized} \ + compute_axis_order=${compute_axis_order} \ + prefill_cache_axis_order=${prefill_cache_axis_order} \ + ar_cache_axis_order=${ar_cache_axis_order} \ No newline at end of file diff --git a/benchmarks/mlperf/scripts/run_utils.sh b/benchmarks/mlperf/scripts/run_utils.sh new file mode 100644 index 00000000..8271f10b --- /dev/null +++ b/benchmarks/mlperf/scripts/run_utils.sh @@ -0,0 +1,32 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +# Tokenizer +# export TOKENIZER_PATH=meta-llama/Llama-2-70b-chat-hf +export DATASET_PREFIX=mixtral +export TOKENIZER_PATH=mistralai/Mixtral-8x7B-Instruct-v0.1 +export NUM_CLIENT_THREADS=${NUM_CLIENT_THREADS:=600} + +# Loadgen +export LOADGEN_RUN_TIMESTAMP=$(TZ=America/Los_Angeles date +%Y%m%d%H%M%S%Z) + +get_dataset_name() { + dataset_type=$1 + if [ ${dataset_type} = "full" ] + then echo "${DATASET_PREFIX}-processed-data" + elif [ ${dataset_type} = "calibration" ] + then echo "${DATASET_PREFIX}-processed-calibration-data" + fi +} diff --git a/benchmarks/mlperf/scripts/tpu_script.sh b/benchmarks/mlperf/scripts/tpu_script.sh new file mode 100644 index 00000000..47a88a7c --- /dev/null +++ b/benchmarks/mlperf/scripts/tpu_script.sh @@ -0,0 +1,352 @@ +#!/bin/bash + +# Multi-Host vlp (TODO: replace these params for your own config) +NAME="jwyang-tpu-sh2" +# NAME="jwyang-v5p8-vm" +# ACCELERATOR_TYPE="v5litepod-4" +ACCELERATOR_TYPE="v5litepod-8" +# ACCELERATOR_TYPE="v5p-8" +RUNTIME_VERSION="v2-alpha-tpuv5-lite" +# PROJECT="tpu-prod-env-automated" +PROJECT="cloud-tpu-inference-test" +# PROJECT="tpu-prod-env-small" +# PROJECT="tpu-prod-env-large-cont" +# ZONE="us-east1-c" +ZONE="us-west1-c" +# ZONE="us-east5-a" + +USER=jwyang + +# (TODO: replace these params to your own config) +NUM_WORKERS=1 +TPU_NAME="t1v-n-63d3a09c" + +create_tpu() { + # A temporary solution to clean up the failed and suspended queued resources. + # Otherwise, there will be a quota error. + existing_qr=$(gcloud alpha compute tpus queued-resources list \ + --project ${PROJECT} \ + --zone ${ZONE} \ + --quiet) + while read -r line; do + name=$(echo $line | awk '{print $1}') + status=$(echo $line | awk '{print $5}') + echo ${name} + echo ${status} + if [[ ${status} == "SUSPENDED" || ${status} == "FAILED" ]]; then + gcloud alpha compute tpus queued-resources delete ${name} \ + --project ${PROJECT} \ + --zone ${ZONE} \ + --quiet + fi + done <<< ${existing_qr} + + gcloud alpha compute tpus queued-resources create ${NAME} \ + --description noteardown \ + --node-id ${NAME} \ + --project=${PROJECT} \ + --zone=${ZONE} \ + --accelerator-type=${ACCELERATOR_TYPE} \ + --runtime-version=${RUNTIME_VERSION} \ + --reserved; +} + +list_tpu() { + gcloud compute tpus tpu-vm list --project=${PROJECT} --zone=${ZONE}; +} + +list_queue_resource() { + gcloud alpha compute tpus queued-resources list --project=${PROJECT} --zone=${ZONE}; +} + +delete_tpu() { + gcloud alpha compute tpus tpu-vm delete ${NAME} --project=${PROJECT} --zone=${ZONE}; + gcloud alpha compute tpus queued-resources delete ${NAME} --project=${PROJECT} --zone=${ZONE}; +} + +ssh_to_tpu() { + gcloud compute tpus tpu-vm ssh ${NAME} --zone ${ZONE} --worker ${1} --project ${PROJECT} -- -o ProxyCommand='corp-ssh-helper %h %p' +} + +create_disk() { + for ((i = 0; i < ${NUM_WORKERS}; i++)); do + TPU_WORKER_NAME=${TPU_NAME}-w-${i} + DISK_NAME=${NAME}-w${i}-ssd + + SIZE=35 + if [[ ${i} == 0 ]] + then + SIZE=512 + fi + + gcloud compute disks create ${DISK_NAME} \ + --size ${SIZE} \ + --zone ${ZONE} \ + --type pd-ssd \ + --project=${PROJECT} + + # attach disk to tpu + gcloud alpha compute instances attach-disk ${TPU_WORKER_NAME} \ + --zone=${ZONE} \ + --disk=${DISK_NAME} \ + --mode=rw \ + --project=${PROJECT} + + gcloud compute instances set-disk-auto-delete ${TPU_WORKER_NAME} \ + --zone=${ZONE} \ + --auto-delete \ + --disk=${DISK_NAME} \ + --project=${PROJECT} + + gcloud compute tpus tpu-vm ssh ${NAME} --zone ${ZONE} --worker ${i} --project=${PROJECT} \ + --command="sudo mkfs.ext4 -m 0 -E 
lazy_itable_init=0,lazy_journal_init=0,discard /dev/sdb && + sudo mkdir -p /mnt/disks/persist && + sudo mount -o discard,defaults /dev/sdb /mnt/disks/persist" \ + -- -o ProxyCommand='corp-ssh-helper %h %p' + done +} + +detach_disks() { + for ((i = 0; i < ${NUM_WORKERS}; i++)); do + TPU_WORKER_NAME=${TPU_NAME}-w-${i} + DISK_NAME=${NAME}-w${i}-ssd + + # attach disk to tpu + gcloud alpha compute instances detach-disk ${TPU_WORKER_NAME} \ + --zone=${ZONE} \ + --disk=${DISK_NAME} \ + --project=${PROJECT} + done +} + +check_disks() { + set -o xtrace + dir_checks="" + for ((i = 0; i < ${NUM_WORKERS}; i++)); do + dir_checks="$dir_checks $( + gcloud compute tpus tpu-vm ssh ${NAME} --zone ${ZONE} --worker ${i} --project=${PROJECT} \ + --command="if [ -d /mnt/disks/persist ]; then echo "exists"; fi" \ + -- -o ProxyCommand='corp-ssh-helper %h %p' + )" + done + num_dir_exists=$(echo "$dir_checks" | wc -w) + echo "Number of workers with disks: $num_dir_exists" + set +o xtrace +} + +copy_relevant_files() { + gcloud compute tpus tpu-vm \ + scp --zone=${ZONE} --project=${PROJECT} --worker=all \ + $PWD/Maxtext/checkpointing.py \ + ${NAME}:~/maxtext/MaxText/checkpointing.py \ + --scp-flag "-o ProxyCommand=corp-ssh-helper %h %p" + + gcloud compute tpus tpu-vm \ + scp --zone=${ZONE} --project=${PROJECT} --worker=all \ + $PWD/Maxtext/maxengine.py \ + ${NAME}:~/maxtext/MaxText/maxengine.py \ + --scp-flag "-o ProxyCommand=corp-ssh-helper %h %p" + + gcloud compute tpus tpu-vm \ + scp --zone=${ZONE} --project=${PROJECT} --worker=all \ + $PWD/Maxtext/layers/quantizations.py \ + ${NAME}:~/maxtext/MaxText/layers/quantizations.py \ + --scp-flag "-o ProxyCommand=corp-ssh-helper %h %p" + +} + + +# LLaMA2-7B JetStream/Maxtext commands + +# # source .env/bin/activate +# your_run_name=jwyang_bs1_llama7b +# python MaxText/inference_microbenchmark.py \ +# MaxText/configs/base.yml \ +# base_output_directory=gs://jwyang-data/maxtext-llama2-7b/microbenchmark \ +# run_name=${your_run_name} \ +# per_device_batch_size=12 \ +# save_config_to_gcs=true \ +# model_name=llama2-7b \ +# tokenizer_path=assets/tokenizer.llama2 \ +# inference_microbenchmark_prefill_lengths=1024 \ +# max_prefill_predict_length=1024 \ +# max_target_length=2048 \ +# ici_fsdp_parallelism=1 \ +# ici_tensor_parallelism=-1 \ +# ici_autoregressive_parallelism=1 \ +# weight_dtype=bfloat16 \ +# enable_profiler=true \ +# scan_layers=false \ +# quantization=int8 \ +# quantize_kvcache=true +# inference_mode=true + + +export model_name=llama2-7b +export tokenizer_path=assets/tokenizer.llama2 +export XLA_FLAGS="--xla_disable_hlo_passes=rematerialization" +export ici_tensor_parallelism=-1 +export ici_autoregressive_parallelism=1 +export per_device_batch_size=1 +export load_parameters_path_chat=gs://jwyang-runner-maxtext-logs/llama2-7b_unscanned_chkpt_2024-04-26-18-28/checkpoints/0/items +export load_parameters_path=gs://jwyang-runner-maxtext-logs/llama2-7b_unscanned_chkpt_2024-04-26-19-40/checkpoints/0/items +export load_parameters_path_chat_quantized=gs://jwyang-data/llama7b-chat-quantized-fixed/0/items + +python MaxText/maxengine_server.py \ + MaxText/configs/base.yml \ + base_output_directory=gs://jwyang-data/maxtext-llama2-7b/microbenchmark \ + load_parameters_path=${load_parameters_path_chat} \ + run_name=$(date +%Y-%m-%d-%H-%M) \ + save_config_to_gcs=true \ + model_name=${model_name} \ + tokenizer_path=${tokenizer_path} \ + inference_microbenchmark_log_file_path=microbenchmark.json \ + inference_microbenchmark_prefill_lengths=1024 \ + 
inference_microbenchmark_stages=prefill,generate \ + inference_microbenchmark_loop_iters=1000 \ + max_prefill_predict_length=1024 \ + max_target_length=2048 \ + per_device_batch_size=${per_device_batch_size} \ + ici_fsdp_parallelism=1 \ + ici_tensor_parallelism=${ici_tensor_parallelism} \ + ici_autoregressive_parallelism=${ici_autoregressive_parallelism} \ + enable_profiler=false \ + scan_layers=false \ + weight_dtype=bfloat16 + # quantization=int8 + # quantize_kvcache=True + + +export model_name=llama2-7b +export dataset_path=/home/jwyang/llama7b_chat_openorca_input.json +python JetStream/benchmarks/benchmark_serving.py \ + --tokenizer ~/maxtext/assets/tokenizer.llama2 \ + --warmup-first true \ + --save-result \ + --save-request-outputs \ + --request-outputs-file-path /home/jwyang/outputs.json \ + --num-prompts 1000 \ + --max-output-length 1024 \ + --dataset openorca \ + --dataset-path ${dataset_path} + + + +# # 13b model +export model_name=llama2-13b +export tokenizer_path=assets/tokenizer.llama2 +export XLA_FLAGS="--xla_disable_hlo_passes=rematerialization" +export ici_tensor_parallelism=-1 +export ici_autoregressive_parallelism=1 +export per_device_batch_size=1 +export load_parameters_path=gs://runner-maxtext-logs/2024-05-16-23-59/unscanned_chkpt/checkpoints/0/items + + +export experiment_time=$(date +%Y-%m-%d-%H-%M) +echo "export experiment_time=${experiment_time}" +python MaxText/maxengine_server.py \ + MaxText/configs/base.yml \ + base_output_directory=gs://morgandu-tpu/maxtext-logs/microbenchmark/${experiment_time} \ + model_name=llama2-13b \ + async_checkpointing=false \ + load_parameters_path=gs://runner-maxtext-logs/2024-05-16-23-59/unscanned_chkpt/checkpoints/0/items \ + run_name=${experiment_time} \ + inference_microbenchmark_log_file_path=${run_name}.json \ + tokenizer_path=assets/tokenizer.llama2 \ + weight_dtype=bfloat16 \ + inference_microbenchmark_prefill_lengths=1024 \ + inference_microbenchmark_stages=prefill,generate \ + inference_microbenchmark_loop_iters=10 \ + max_prefill_predict_length=1024 \ + max_target_length=2048 \ + ici_fsdp_parallelism=1 \ + ici_tensor_parallelism=-1 \ + ici_autoregressive_parallelism=1 \ + enable_profiler=false \ + scan_layers=false \ + attention=dot_product \ + save_config_to_gcs=true \ + per_device_batch_size=1 + + +python MaxText/inference_microbenchmark.py \ + MaxText/configs/base.yml \ + base_output_directory=gs://morgandu-tpu/maxtext-logs/microbenchmark/${experiment_time} \ + model_name=llama2-13b \ + async_checkpointing=false \ + load_parameters_path=gs://runner-maxtext-logs/2024-05-16-23-59/unscanned_chkpt/checkpoints/0/items \ + run_name=${experiment_time} \ + inference_microbenchmark_log_file_path=${run_name}.json \ + tokenizer_path=assets/tokenizer.llama2 \ + weight_dtype=bfloat16 \ + inference_microbenchmark_prefill_lengths=1024 \ + inference_microbenchmark_stages=prefill,generate \ + inference_microbenchmark_loop_iters=10 \ + max_prefill_predict_length=1024 \ + max_target_length=2048 \ + ici_fsdp_parallelism=1 \ + ici_tensor_parallelism=-1 \ + ici_autoregressive_parallelism=1 \ + enable_profiler=false \ + scan_layers=false \ + attention=dot_product \ + save_config_to_gcs=true \ + per_device_batch_size=1 + + + +# # LLaMA2-70B commands +# # source .env/bin/activate +# your_run_name=jwyang_bs1_llama70b +# python MaxText/inference_microbenchmark.py \ +# MaxText/configs/base.yml \ +# base_output_directory=gs://jwyang-data/maxtext-llama2-70b/microbenchmark \ +# run_name=${your_run_name} \ +# per_device_batch_size=1 \ +# 
save_config_to_gcs=true \ +# model_name=llama2-70b \ +# tokenizer_path=assets/tokenizer.llama2 \ +# inference_microbenchmark_prefill_lengths=32 \ +# max_prefill_predict_length=32 \ +# max_target_length=64 \ +# ici_fsdp_parallelism=1 \ +# ici_tensor_parallelism=-1 \ +# ici_autoregressive_parallelism=1 \ +# weight_dtype=bfloat16 \ +# enable_profiler=true \ +# scan_layers=false \ +# quantization=int8 \ +# quantize_kvcache=true + + +export model_name=llama2-70b +export tokenizer_path=assets/tokenizer.llama2 +export XLA_FLAGS="--xla_disable_hlo_passes=rematerialization" +export ici_tensor_parallelism=-1 +export ici_autoregressive_parallelism=1 +export per_device_batch_size=1 +export prefill_length=16 +export target_length=32 + +python MaxText/maxengine_server.py \ + MaxText/configs/base.yml \ + base_output_directory=gs://jwyang-data/maxtext-llama2-70b/microbenchmark \ + run_name=$(date +%Y-%m-%d-%H-%M) \ + save_config_to_gcs=true \ + model_name=${model_name} \ + tokenizer_path=${tokenizer_path} \ + inference_microbenchmark_log_file_path=microbenchmark.json \ + inference_microbenchmark_prefill_lengths=${prefill_length} \ + inference_microbenchmark_stages=prefill,generate \ + inference_microbenchmark_loop_iters=1000 \ + max_prefill_predict_length=${prefill_length} \ + max_target_length=${target_length} \ + per_device_batch_size=${per_device_batch_size} \ + ici_fsdp_parallelism=1 \ + ici_tensor_parallelism=${ici_tensor_parallelism} \ + ici_autoregressive_parallelism=${ici_autoregressive_parallelism} \ + enable_profiler=false \ + scan_layers=false \ + weight_dtype=bfloat16 \ + quantization=int8 \ + quantize_kvcache=True \ No newline at end of file diff --git a/benchmarks/mlperf/user.conf b/benchmarks/mlperf/user.conf new file mode 100644 index 00000000..4a53a70b --- /dev/null +++ b/benchmarks/mlperf/user.conf @@ -0,0 +1,29 @@ +# The format of this config file is 'key = value'. +# The key has the format 'model.scenario.key'. Value is mostly int64_t. +# Model maybe '*' as wildcard. In that case the value applies to all models. +# All times are in milli seconds + +# Set performance_sample_count for each model. +llama2-70b.*.performance_sample_count_override = 24576 +mixtral-8x7b.*.performance_sample_count_override = 15000 + +llama2-70b.*.min_duration = 600000 +mixtral-8x7b.*.min_duration = 30000 + + +# In Offline scenario, we always have one query. But LoadGen maps this to +# min_sample_count internally in Offline scenario. If the dataset size is larger +# than 24576 we limit the min_query_count to 24576 and otherwise we use +# the dataset size as the limit +llama2-70b.Offline.min_query_count = 24576 +llama2-70b.Server.min_query_count = 24576 +mixtral-8x7b.Offline.min_query_count = 15000 +mixtral-8x7b.Server.min_query_count = 15000 + +# These fields should be defined and overridden by user.conf. +*.Offline.target_qps = 5.0 +llama2-70b.Server.target_qps = 1.0 +mixtral-8x7b.Server.target_qps = 11.0 + + +*.sample_concatenate_permutation = 1 diff --git a/benchmarks/mlperf/user100.conf b/benchmarks/mlperf/user100.conf new file mode 100644 index 00000000..995de1e2 --- /dev/null +++ b/benchmarks/mlperf/user100.conf @@ -0,0 +1,28 @@ +# The format of this config file is 'key = value'. +# The key has the format 'model.scenario.key'. Value is mostly int64_t. +# Model maybe '*' as wildcard. In that case the value applies to all models. +# All times are in milli seconds + +# Set performance_sample_count for each model. 
+llama2-70b.*.performance_sample_count_override = 100 +mixtral-8x7b.*.performance_sample_count_override = 100 + +llama2-70b.*.min_duration = 60 +mixtral-8x7b.*.min_duration = 30 + +# In Offline scenario, we always have one query. But LoadGen maps this to +# min_sample_count internally in Offline scenario. If the dataset size is larger +# than 24576 we limit the min_query_count to 24576 and otherwise we use +# the dataset size as the limit +llama2-70b.Offline.min_query_count = 100 +llama2-70b.Server.min_query_count = 100 +mixtral-8x7b.Offline.min_query_count = 100 +mixtral-8x7b.Server.min_query_count = 100 + + +# These fields should be defined and overridden by user.conf. +*.Offline.target_qps = 5.0 +llama2-70b.Server.target_qps = 1.0 +mixtral-8x7b.Server.target_qps = 9.0 + +*.sample_concatenate_permutation = 1 diff --git a/benchmarks/mlperf/user2000.conf b/benchmarks/mlperf/user2000.conf new file mode 100644 index 00000000..51855597 --- /dev/null +++ b/benchmarks/mlperf/user2000.conf @@ -0,0 +1,28 @@ +# The format of this config file is 'key = value'. +# The key has the format 'model.scenario.key'. Value is mostly int64_t. +# Model maybe '*' as wildcard. In that case the value applies to all models. +# All times are in milli seconds + +# Set performance_sample_count for each model. +llama2-70b.*.performance_sample_count_override = 2000 +mixtral-8x7b.*.performance_sample_count_override = 2000 + +llama2-70b.*.min_duration = 60000 +mixtral-8x7b.*.min_duration = 30000 + +# In Offline scenario, we always have one query. But LoadGen maps this to +# min_sample_count internally in Offline scenario. If the dataset size is larger +# than 24576 we limit the min_query_count to 24576 and otherwise we use +# the dataset size as the limit +llama2-70b.Offline.min_query_count = 2000 +llama2-70b.Server.min_query_count = 2000 +mixtral-8x7b.Offline.min_query_count = 2000 +mixtral-8x7b.Server.min_query_count = 2000 + + +# These fields should be defined and overridden by user.conf. +*.Offline.target_qps = 5.0 +llama2-70b.Server.target_qps = 1.0 +mixtral-8x7b.Server.target_qps = 9.0 + +*.sample_concatenate_permutation = 1 diff --git a/google3/third_party/py/jetstream/benchmarks/open_orca_gpt4_tokenized_llama.calibration_1000.pkl b/benchmarks/open_orca_gpt4_tokenized_llama.calibration_1000.pkl similarity index 100% rename from google3/third_party/py/jetstream/benchmarks/open_orca_gpt4_tokenized_llama.calibration_1000.pkl rename to benchmarks/open_orca_gpt4_tokenized_llama.calibration_1000.pkl diff --git a/google3/third_party/py/jetstream/benchmarks/requirements.in b/benchmarks/requirements.in similarity index 100% rename from google3/third_party/py/jetstream/benchmarks/requirements.in rename to benchmarks/requirements.in diff --git a/google3/third_party/py/jetstream/external_tokenizers/llama3/__init__.py b/benchmarks/tests/__init__.py similarity index 100% rename from google3/third_party/py/jetstream/external_tokenizers/llama3/__init__.py rename to benchmarks/tests/__init__.py diff --git a/benchmarks/tests/test_metrics.py b/benchmarks/tests/test_metrics.py new file mode 100644 index 00000000..08288f93 --- /dev/null +++ b/benchmarks/tests/test_metrics.py @@ -0,0 +1,177 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for metrics.""" + +import unittest + +import benchmarks.metrics as metrics +import datetime +import re + + +class TestEventMetric(unittest.TestCase): + """ "Tests for event metric (i.e. distribution).""" + + def setUp(self): + self._metric = metrics.EventMetric( + "requestLatency", "Latency of requests", "ms" + ) + + def test_record_adds_a_data_event(self): + m = self._metric + m.record(1.0) + data_points = m.data() + self.assertEqual(1, len(data_points)) + self.assertEqual(1.0, data_points[0]) + + def test_percentile_returns_correct_percentile(self): + m = self._metric + n = 11 + for i in range(0, n): + m.record(i) + self.assertEqual(m.percentile(50), 5) + self.assertEqual(m.percentile(90), 9) + self.assertEqual(m.percentile(100), 10) + + def test_mean_returns_correct_mean_value(self): + m = self._metric + n = 3 + for i in range(0, n): + m.record(i) + self.assertEqual(sum(range(0, n)) / n, m.mean()) + + def test_distribution_summary_str_returns_expected_str(self): + m = self._metric + n = 100 + for i in range(0, n): + m.record(i) + summary = m.distribution_summary_str() + self.assertTrue(re.search(r"Mean requestLatency", summary)) + self.assertTrue(re.search(r"Median requestLatency", summary)) + self.assertTrue(re.search(r"P99 requestLatency", summary)) + + def test_distribution_summary_dict_returns_expected_dict(self): + m = self._metric + n = 100 + for i in range(0, n): + m.record(i) + summary = m.distribution_summary_dict() + self.assertIn("mean_requestLatency_ms", summary) + self.assertIn("median_requestLatency_ms", summary) + self.assertIn("p99_requestLatency_ms", summary) + + +class TestCounterMetric(unittest.TestCase): + """Tests for counter metric (i.e. 
monotonically increasing counter).""" + + def setUp(self): + self._counter = metrics.CounterMetric( + "RequestCompleteCount", "Number of completed requests" + ) + + def test_increment_increases_total_count(self): + m = self._counter + old_total_cnt = m.total_count() + m.increment() + new_total_cnt = m.total_count() + self.assertEqual(1, new_total_cnt - old_total_cnt) + + def test_increment_within_same_second_update_counts_for_the_second(self): + """Test to ensure one entry used to cumulate counts within a second""" + m = self._counter + timestamp = datetime.datetime.strptime( + "2025-01-01 00:00:00", "%Y-%m-%d %H:%M:%S" + ) + m.increment(1, timestamp) + m.increment(2, timestamp) + data = m.data() + self.assertEqual(1, len(data.keys())) + self.assertEqual(3, data[timestamp]) + + def test_increment_at_different_seconds_creates_separate_entries(self): + """Test to ensure separate entries used for different seconds""" + m = self._counter + timestamp_first = datetime.datetime.strptime( + "2025-01-01 00:00:00", "%Y-%m-%d %H:%M:%S" + ) + m.increment(1, timestamp_first) + timestamp_second = datetime.datetime.strptime( + "2025-01-01 00:00:01", "%Y-%m-%d %H:%M:%S" + ) + m.increment(2, timestamp_second) + data = m.data() + self.assertEqual(2, len(data.keys())) + self.assertEqual(1, data[timestamp_first]) + self.assertEqual(2, data[timestamp_second]) + + def test_rate_returns_expected(self): + m = self._counter + n = 10 + start_time = datetime.datetime.strptime( + "2025-01-01 00:00:00", "%Y-%m-%d %H:%M:%S" + ) + delta_time_sec = 1 + for i in range(0, n): + m.increment( + 1, start_time + datetime.timedelta(seconds=delta_time_sec * i) + ) + # n counts across n seconds, thus rate = 1 + self.assertEqual(1, m.rate()) + + def test_rate_over_window_returns_expected(self): + m = self._counter + n = 10 + start_time = datetime.datetime.strptime( + "2025-01-01 00:00:00", "%Y-%m-%d %H:%M:%S" + ) + delta_time_sec = 1 + for i in range(0, n): + m.increment( + 1, start_time + datetime.timedelta(seconds=delta_time_sec * i) + ) + + rates_with_timestamp = m.rate_over_window(window_size_sec=5) + + rates = [rate for timestamp, rate in rates_with_timestamp] + # 10 seconds with 1 count in each second. 
One rate per window_size_sec=5 sec + # so rate = 1 for [0, 5) and rate = 1 [5, 10) + self.assertEqual([1, 1], rates) + + def test_rate_over_window_to_csv_returns_correct(self): + m = self._counter + n = 10 + start_time = datetime.datetime.strptime( + "2025-01-01 00:00:00", "%Y-%m-%d %H:%M:%S" + ) + delta_time_sec = 1 + for i in range(0, n): + m.increment( + 1, start_time + datetime.timedelta(seconds=delta_time_sec * i) + ) + + csv_output = m.rate_over_window_to_csv(window_size_sec=5) + + rows = csv_output.split("\n") + self.assertEqual(2, len(rows)) + expected_timestamps = "TimeStamp,2025-01-01 00:00:00,2025-01-01 00:00:05" + got_timestamps = rows[0] + self.assertEqual(expected_timestamps, got_timestamps) + expected_values = "Value,1.00,1.00" + got_values = rows[1] + self.assertEqual(expected_values, got_values) + + +if __name__ == "__main__": + unittest.main() diff --git a/google3/third_party/py/jetstream/docs/observability-prometheus-metrics-in-jetstream-server.md b/docs/observability-prometheus-metrics-in-jetstream-server.md similarity index 94% rename from google3/third_party/py/jetstream/docs/observability-prometheus-metrics-in-jetstream-server.md rename to docs/observability-prometheus-metrics-in-jetstream-server.md index 079b132a..04d7be4c 100644 --- a/google3/third_party/py/jetstream/docs/observability-prometheus-metrics-in-jetstream-server.md +++ b/docs/observability-prometheus-metrics-in-jetstream-server.md @@ -80,6 +80,6 @@ echo '{ }' | kubectl apply -f - ``` -The metrics can now be queried in the Google Cloud Metrics Explorer. When adding a metrics query with the `+Add Query` button the new metrics should be found under the `Prometheus Target > Jetstream` submenu. +The metrics can now be queried in the [Google Cloud Metrics Explorer](https://pantheon.corp.google.com/monitoring/metrics-explorer). When adding a metrics query with the `+Add Query` button the new metrics should be found under the `Prometheus Target > Jetstream` submenu. Additional guides on the metrics explorer can be found [here](https://cloud.google.com/monitoring/charts/metrics-selector). \ No newline at end of file diff --git a/google3/third_party/py/jetstream/docs/online-inference-with-maxtext-engine.md b/docs/online-inference-with-maxtext-engine.md similarity index 100% rename from google3/third_party/py/jetstream/docs/online-inference-with-maxtext-engine.md rename to docs/online-inference-with-maxtext-engine.md diff --git a/google3/third_party/py/jetstream/docs/profiling-with-jax-profiler-and-tensorboard.md b/docs/profiling-with-jax-profiler-and-tensorboard.md similarity index 100% rename from google3/third_party/py/jetstream/docs/profiling-with-jax-profiler-and-tensorboard.md rename to docs/profiling-with-jax-profiler-and-tensorboard.md diff --git a/experimental/jax/README.md b/experimental/jax/README.md new file mode 100644 index 00000000..7b6d7698 --- /dev/null +++ b/experimental/jax/README.md @@ -0,0 +1,104 @@ +# An experimental JAX inference framework for prototyping new ideas. + +## About + + It has the following features (some of them are limited version): + +``` + Performance: + 1. Paged Attention + 2. Chunked Prefill and Piggybacking Decode + 3. Collective Matmul + + Framework: + 1. Pythonic model builder + 2. JAX manual sharding + 3. Interface for different hardware supports + 4. On-the-flying HF model conversion and deployment +``` + +## Quick Start + +So far, the experimental code only works for llama2 7b and TPU v5e-8. 
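Before the step-by-step setup below, the snippet that follows is a minimal, illustrative sketch of the offline API exercised by `inference/entrypoint/mini_offline_benchmarking.py`; the `OfflineInference` and `Response` names come from `inference.runtime`, while the example prompts and printed fields are assumptions made only for illustration.

```
# Minimal offline sketch (assumes the dependencies from the Quick Start below are installed).
from inference.runtime import offline_inference
from inference.runtime.request_type import Response

prompts = ["Today is a good day", "Summarize paged attention in one sentence."]

engine = offline_inference.OfflineInference()   # construct the offline engine
responses: list[Response] = engine(prompts)     # runs prefill + decode for every prompt

for res in responses:
    print(len(res.input_tokens), len(res.generated_tokens))
```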
The whole process takes less than 10 minutes if you have a Cloud TPU v5e-8 ready. + +### 1. Create Cloud TPU v5e-8 on Google Cloud: + +``` +gcloud alpha compute tpus queued-resources create ${QR_NAME} \ + --node-id ${NODE_NAME} \ + --project ${PROJECT_ID} \ + --zone ${ZONE} \ + --accelerator-type v5litepod-8 \ + --runtime-version v2-alpha-tpuv5-lite +``` + +For more information, see the [queued resources documentation](https://cloud.google.com/tpu/docs/queued-resources). + + +### 2. Set up the LLM server and serve requests: +SSH into your Cloud TPU VM first and run the following commands: + +Set up a new Python environment. +``` +virtualenv jax-inference +source jax-inference/bin/activate +``` + +Clone the repo and install the dependencies. +``` +git clone https://github.com/AI-Hypercomputer/JetStream.git + +cd JetStream/experimental/jax + +pip install -r requirements.txt +``` + +Log in to Hugging Face (make sure your account has permission to access `meta-llama/Llama-2-7b-chat-hf`): + +``` +huggingface-cli login +``` + + +### 3. Offline Benchmarking: + +Note: the current setup uses 8-way tensor parallelism; this is only for experimentation and for comparison with the current JetStream + MaxText numbers. + +``` +export PYTHONPATH=$(pwd) +export JAX_COMPILATION_CACHE_DIR="/tmp/jax_cache" +python inference/entrypoint/mini_offline_benchmarking.py +``` + +Offline benchmarking result: + +This result is around `45%` better than the current MaxText and JetStream numbers (as of 2024/08/16) under the same setup. + + +``` +Benchmarking result: + Total requests: 1000 + Total input tokens: 218743 + Total output tokens: 291740 + Input token throughput: 2980.654636529649 tokens/sec + Output token throughput: 3975.332621666338 tokens/sec +``` + +Note: The online numbers should be even better than current MaxText and JetStream, since the experimental framework runs prefill and decode together in a single model forward pass. + +### 4. Online Serving Example: + +Start the server: + +``` +python inference/entrypoint/run_simple_server.py & +``` + +Send a request: + +``` +curl --no-buffer -H 'Content-Type: application/json' \ + -d '{ "prompt": "Today is a good day" }' \ + -X POST \ + localhost:8000/generate +``` \ No newline at end of file diff --git a/experimental/jax/inference/entrypoint/__init__.py b/experimental/jax/inference/entrypoint/__init__.py new file mode 100644 index 00000000..e7c0b714 --- /dev/null +++ b/experimental/jax/inference/entrypoint/__init__.py @@ -0,0 +1,15 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" diff --git a/experimental/jax/inference/entrypoint/mini_offline_benchmarking.py b/experimental/jax/inference/entrypoint/mini_offline_benchmarking.py new file mode 100644 index 00000000..53a85663 --- /dev/null +++ b/experimental/jax/inference/entrypoint/mini_offline_benchmarking.py @@ -0,0 +1,75 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import os +import time +import pandas +from inference.runtime.request_type import * +from inference.runtime import offline_inference + + +def load_openorca_dataset_pkl(): + # Read pickle file + current_dir = os.path.dirname(__file__) + samples = pandas.read_pickle( + f"{current_dir}/open_orca_gpt4_tokenized_llama.calibration_1000.pkl" + ) + + prompts = [] + outputs = [] + for _, row in samples.iterrows(): + prompts.append(row["input"]) + outputs.append(row["output"]) + + return [(prompt, output) for prompt, output in zip(prompts, outputs)] + + +def benchmarking(): + dataset = load_openorca_dataset_pkl() + + ds = dataset[:1000] + ds = [d[0] for d in ds] + + inference_instance = offline_inference.OfflineInference() + + start_time = time.perf_counter() + res_list: list[Response] = inference_instance(ds) + end_time = time.perf_counter() + duration = end_time - start_time + + input_tokens = [] + for res in res_list: + input_tokens = input_tokens + res.input_tokens + + output_tokens = [] + for res in res_list: + output_tokens = output_tokens + res.generated_tokens + + num_input_tokens = len(input_tokens) + num_output_tokens = len(output_tokens) + + print("Benchmarking result: ") + # Hardcode the number of requests as 1000 based on the test + # dataset. + print(" Total requests: 1000") + print(" Total input tokens:", num_input_tokens) + print(" Total output tokens:", num_output_tokens) + print(f" Input token throughput: {num_input_tokens/duration} tokens/sec") + print(f" Output token throughput: {num_output_tokens/duration} tokens/sec") + + +if __name__ == "__main__": + benchmarking() diff --git a/experimental/jax/inference/entrypoint/open_orca_gpt4_tokenized_llama.calibration_1000.pkl b/experimental/jax/inference/entrypoint/open_orca_gpt4_tokenized_llama.calibration_1000.pkl new file mode 100644 index 00000000..cde4330e Binary files /dev/null and b/experimental/jax/inference/entrypoint/open_orca_gpt4_tokenized_llama.calibration_1000.pkl differ diff --git a/experimental/jax/inference/entrypoint/run_simple_server.py b/experimental/jax/inference/entrypoint/run_simple_server.py new file mode 100644 index 00000000..5533ac10 --- /dev/null +++ b/experimental/jax/inference/entrypoint/run_simple_server.py @@ -0,0 +1,31 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import uvicorn +import os + + +if __name__ == "__main__": + print("start") + current_dir = os.path.dirname(__file__) + parent_dir = os.path.dirname(current_dir) + + uvicorn.run( + app_dir=f"{parent_dir}/server", + app="simple_server:app", + host="0.0.0.0", + port=8000, + ) diff --git a/experimental/jax/inference/kernel/__init__.py b/experimental/jax/inference/kernel/__init__.py new file mode 100644 index 00000000..43588f5f --- /dev/null +++ b/experimental/jax/inference/kernel/__init__.py @@ -0,0 +1,23 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from .attention_ops import * +from .attention.tpu.quantization_utils import * +from .collective_matmul_ops import * +from .linear.tpu.collective_matmul import ( + prepare_rhs_for_all_gather_collective_matmul, + prepare_rhs_for_collective_matmul_reduce_scatter, +) diff --git a/experimental/jax/inference/kernel/attention/tpu/__init__.py b/experimental/jax/inference/kernel/attention/tpu/__init__.py new file mode 100644 index 00000000..320d9bc8 --- /dev/null +++ b/experimental/jax/inference/kernel/attention/tpu/__init__.py @@ -0,0 +1,18 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from .chunked_prefill_attention import * +from .paged_attention import * diff --git a/experimental/jax/inference/kernel/attention/tpu/chunked_prefill_attention.py b/experimental/jax/inference/kernel/attention/tpu/chunked_prefill_attention.py new file mode 100644 index 00000000..6649e994 --- /dev/null +++ b/experimental/jax/inference/kernel/attention/tpu/chunked_prefill_attention.py @@ -0,0 +1,275 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +"""Chunked prefill TPU kernel with paged kv cache.""" + +import jax +from jax import numpy as jnp +from jax.experimental import pallas as pl +from jax.experimental.pallas import tpu as pltpu +from inference.kernel.attention.tpu.paged_attention import * + + +DEFAULT_MASK_VALUE = -2.3819763e38 # Set to a large negative number. 
+ + +def chunked_prefill_attention_impl( + length_ref, # shape: (1,), smem, + page_indices_ref, # shape: (max_seq_len // page_size), smem, + buffer_index_ref, # shape: (1,), smem, + q_ref, # shape: (group_size, chunk, head_dim), vmem, + k_pages_hbm_ref, # shape: (num_kv_heads, num_pages, page_size, head_dim), hbm + v_pages_hbm_ref, # shape: (num_kv_heads, num_pages, page_size, head_dim), hbm + out_ref, # shape: (group_size, chunk, head_dim), vmem, + l_ref, # shape: (group_size, chunk, 1), vmem, + m_ref, # shape: (group_size, chunk, 1), vmem, + k_vmem_buffer, # shape: (2, page_per_chunk, page_size, head_dim), vmem, + v_vmem_buffer, # shape: (2, page_per_chunk, page_size, head_dim), vmem, + sem, +): + h = pl.program_id(0) + page_size = k_pages_hbm_ref.shape[2] + head_dim = k_pages_hbm_ref.shape[3] + group_size = q_ref.shape[0] + num_kv_heads = k_pages_hbm_ref.shape[0] + chunk_size = q_ref.shape[1] + length = length_ref[0] + q_chunk_idx = jax.lax.div(length, chunk_size) + reminder = jax.lax.rem(length, chunk_size) + q_chunk_idx -= jnp.where(reminder > 0, 0, 1) + out_ref[...] = jnp.zeros_like(out_ref) + + def create_kv_async_copy_descriptors(h, i, buffer_index): + pages_to_load = chunk_size // page_size + page_offset = i * pages_to_load + async_copy_k = MultiPageAsyncCopyDescriptor( + k_pages_hbm_ref, + None, + k_vmem_buffer.at[buffer_index], + None, + sem, + page_indices_ref, + page_offset, + pages_to_load, + head_index=h, + ) + async_copy_v = MultiPageAsyncCopyDescriptor( + v_pages_hbm_ref, + None, + v_vmem_buffer.at[buffer_index], + None, + sem, + page_indices_ref, + page_offset, + pages_to_load, + head_index=h, + ) + return async_copy_k, async_copy_v + + def next_block_indice(h, i): + return jax.lax.cond( + (i + 1) * chunk_size < length, lambda: (h, i + 1), lambda: (h + 1, 0) + ) + + def per_kv_chunk_body(i, _): + @pl.when((i * chunk_size) < length) + def body(): + buffer_index = buffer_index_ref[0] + + @pl.when(i == 0) + def init(): + m_ref[...] = jnp.full_like(m_ref, -jnp.inf) + l_ref[...] = jnp.zeros_like(l_ref) + + @pl.when(h == 0) + def prefetch_first_kv(): + # prefetch the first kv chunk. + async_copy_k, async_copy_v = create_kv_async_copy_descriptors( + h, i, buffer_index + ) + async_copy_k.start() + async_copy_v.start() + + next_h, next_i = next_block_indice(h, i) + + @pl.when((next_h < num_kv_heads) & (next_i <= q_chunk_idx)) + def prefetch_next_block(): + # prefetch the kv chunk for next iteration. + next_buffer_index = jnp.where(buffer_index == 0, 1, 0) + async_copy_next_k, async_copy_next_v = create_kv_async_copy_descriptors( + next_h, next_i, next_buffer_index + ) + + async_copy_next_k.start() + async_copy_next_v.start() + buffer_index_ref[0] = next_buffer_index + + async_copy_k, async_copy_v = create_kv_async_copy_descriptors( + h, i, buffer_index + ) + + k = async_copy_k.wait_and_get_loaded() + v = async_copy_v.wait_and_get_loaded() + + mask_shape = (chunk_size, chunk_size) + row_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 0) + row_ids += q_chunk_idx * chunk_size + col_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 1) + col_ids += i * chunk_size + causal_mask = col_ids <= row_ids + causal_mask_value = jnp.where(causal_mask, 0.0, DEFAULT_MASK_VALUE) + + def per_group_body(group_idx, _): + q = q_ref[group_idx] + s = ( + jnp.einsum("td,sd->ts", q, k, preferred_element_type=jnp.float32) + + causal_mask_value + ) + # mask. 
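      # (Clarifying note: `causal_mask_value` above adds DEFAULT_MASK_VALUE to every
      # position whose KV column index exceeds the query row index, so those logits
      # effectively vanish after the exp() in the online-softmax update below.)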
+ s_max = jnp.max(s, axis=1, keepdims=True) + + prev_m = m_ref[group_idx] + prev_l = l_ref[group_idx] + + cur_m = jnp.maximum(prev_m, s_max) + cur_m_to_attn_size = jax.lax.broadcast_in_dim( + cur_m, (chunk_size, chunk_size), (0, 1) + ) + + p = jnp.exp(s - cur_m_to_attn_size) + + cur_l = jnp.exp(prev_m - cur_m) * prev_l + jnp.sum( + p, axis=1, keepdims=True + ) + + out = out_ref[group_idx] + + out_ref[group_idx, :, :] = ( + out + * jax.lax.broadcast_in_dim( + jnp.exp(prev_m - cur_m), (chunk_size, head_dim), (0, 1) + ) + + p @ v + ).astype( + out_ref.dtype + ) # p @ v "ts,sd->td" + + m_ref[group_idx, :, :] = cur_m + l_ref[group_idx, :, :] = cur_l + return () + + jax.lax.fori_loop(0, group_size, per_group_body, ()) + + @pl.when(((i + 1) * chunk_size) >= length) + def rescale(): + l = jax.lax.broadcast_in_dim( + l_ref[...], (group_size, chunk_size, head_dim), (0, 1, 2) + ) + out_ref[...] = (out_ref[...] / l).astype(out_ref.dtype) + + return () + + # loop over k, v cache chunk. + jax.lax.fori_loop( + 0, lax.div(length + chunk_size - 1, chunk_size), per_kv_chunk_body, () + ) + + +# TODO: Change to firstly attend to the current chunk kv +# and then write to the KV Cache storage to avoid redundant +# KV Cache reading. +def chunked_prefill_attention( + q: jax.Array, + k_pages: jax.Array, + v_pages: jax.Array, + length: jax.Array, + page_indices: jax.Array, +): + """TPU chunked prefill attention.""" + chunk_size, num_attn_heads, head_dim = q.shape + num_kv_heads, _, page_size, _ = k_pages.shape + attn_group_size = num_attn_heads // num_kv_heads + page_per_chunk = chunk_size // page_size + + # q shape as (num_attn_heads, chunk_size, head_dim) + q = q.transpose((1, 0, 2)) + q = q / jnp.sqrt(head_dim) + + q_block_spec = pl.BlockSpec( + (attn_group_size, chunk_size, head_dim), lambda i, *_: (i, 0, 0) + ) + lm_block_spec = pl.BlockSpec( + (attn_group_size, chunk_size, 1), lambda *_: (0, 0, 0) + ) + lm_shape = jax.ShapeDtypeStruct( + shape=(attn_group_size, chunk_size, 1), dtype=jnp.float32 + ) + # loop over Q chunk and num kv heads dimension. + out, _, _ = pl.pallas_call( + chunked_prefill_attention_impl, + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=3, + in_specs=[ + q_block_spec, + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + ], + out_specs=[ + q_block_spec, + lm_block_spec, + lm_block_spec, + ], + scratch_shapes=[ + pltpu.VMEM( + ( + 2, # For double buffering during DMA copies. + page_per_chunk, + page_size, + head_dim, + ), + k_pages.dtype, + ), # k_pages buffer + pltpu.VMEM( + ( + 2, # For double buffering during DMA copies. + page_per_chunk, + page_size, + head_dim, + ), + v_pages.dtype, + ), # v_pages buffer + pltpu.SemaphoreType.DMA, + ], + grid=(num_kv_heads,), + ), + out_shape=[ + jax.ShapeDtypeStruct(q.shape, q.dtype), + lm_shape, + lm_shape, + ], + # interpret=True, + # debug=True + )( + jnp.reshape(length, (1,)), + page_indices, + jnp.asarray([0], jnp.int32), + q, + k_pages, + v_pages, + ) + out = out.transpose((1, 0, 2)).reshape(chunk_size, -1).astype(q.dtype) + + return out diff --git a/experimental/jax/inference/kernel/attention/tpu/paged_attention.py b/experimental/jax/inference/kernel/attention/tpu/paged_attention.py new file mode 100644 index 00000000..51bec6d9 --- /dev/null +++ b/experimental/jax/inference/kernel/attention/tpu/paged_attention.py @@ -0,0 +1,666 @@ +# Copyright 2024 The JAX Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""PagedAttention TPU kernel for decode phase (ported from jax 0.4.33 for quick experiment)""" + +from collections.abc import Sequence +import functools +from typing import Literal + +import jax +from jax import lax +from jax.experimental import pallas as pl +from jax.experimental.pallas import tpu as pltpu +import jax.numpy as jnp +import numpy as np +from inference.kernel.attention.tpu import quantization_utils + + +DEFAULT_MASK_VALUE = -0.7 * float(np.finfo(np.dtype("float32")).max) + + +class MultiPageAsyncCopyDescriptor: + """Descriptor for async copy of multiple K/V pages from HBM.""" + + def __init__( + self, + pages_hbm_ref, + scales_pages_hbm_ref, + vmem_buffer, + scales_vmem_buffer, + sem, + page_indices, + page_indices_start_offset, + num_pages_to_load, + head_index, + ): + self._vmem_buffer = vmem_buffer + self._scales_vmem_buffer = scales_vmem_buffer + self._num_pages_to_load = num_pages_to_load + if head_index is not None: + self._pages_hbm_ref = pages_hbm_ref.at[head_index] + if scales_pages_hbm_ref is not None: + self._scales_pages_hbm_ref = scales_pages_hbm_ref.at[head_index] + else: + self._scales_pages_hbm_ref = None + else: + self._pages_hbm_ref = pages_hbm_ref + self._scales_pages_hbm_ref = scales_pages_hbm_ref + self._sem = sem + self._page_indices = page_indices + self._page_indices_start_offset = page_indices_start_offset + self._async_copies = [ + self._make_async_copy(i) for i in range(self._num_pages_to_load) + ] + if ( + self._scales_pages_hbm_ref is not None + and self._scales_vmem_buffer is not None + ): + self._async_copies += [ + self._make_scales_async_copy(i) + for i in range(self._num_pages_to_load) + ] + + def _make_async_copy(self, i): + page_index = self._page_indices[self._page_indices_start_offset + i] + return pltpu.make_async_copy( + self._pages_hbm_ref.at[page_index], self._vmem_buffer.at[i], self._sem + ) + + def _make_scales_async_copy(self, i): + page_index = self._page_indices[self._page_indices_start_offset + i] + return pltpu.make_async_copy( + self._scales_pages_hbm_ref.at[ + page_index + ], # pytype: disable=attribute-error + self._scales_vmem_buffer.at[i], # pytype: disable=attribute-error + self._sem, + ) + + def start(self): + """Starts the async copies.""" + for async_copy in self._async_copies: + async_copy.start() + + def _maybe_dequantize(self, x, x_scale, dtype=jnp.bfloat16): + if x_scale is None: + return x.astype(dtype) + return quantization_utils.from_int8(x, x_scale, dtype=dtype) + + def wait_and_get_loaded(self) -> jax.Array: + """Wait async copies and gets the loaded buffer as a jax.Array.""" + for async_copy in self._async_copies: + async_copy.wait() + head_dim = self._vmem_buffer.shape[-1] + jax_array = self._vmem_buffer[...].astype(jnp.float32) + if self._scales_vmem_buffer is not None: + scales_jax_array = self._scales_vmem_buffer[...].astype(jnp.float32) + else: + scales_jax_array = None + jax_array = self._maybe_dequantize(jax_array, scales_jax_array) + return 
jax_array.reshape(-1, head_dim) + + +def paged_flash_attention_kernel( + lengths_ref, + page_indices_ref, + buffer_index_ref, + step_ref, + q_ref, + k_pages_hbm_ref, + k_scales_pages_hbm_ref, + v_pages_hbm_ref, + v_scales_pages_hbm_ref, + o_ref, + m_ref, + l_ref, + k_vmem_buffer, + k_scales_vmem_buffer, + v_vmem_buffer, + v_scales_vmem_buffer, + sem, + *, + batch_size: int, + pages_per_compute_block: int, + pages_per_sequence: int, + mask_value: float, + attn_logits_soft_cap: float | None, + megacore_mode: str | None, + program_ids=(), +): + """Pallas kernel for paged attention.""" + if program_ids: + core_index, b, h, i = program_ids + else: + core_index, b, h, i = ( + pl.program_id(0), + pl.program_id(1), + pl.program_id(2), + pl.program_id(3), + ) + num_kv_heads, _, page_size, _ = k_pages_hbm_ref.shape + bk = page_size * pages_per_compute_block + num_cores = pl.num_programs(0) + + b_step = num_cores if megacore_mode == "batch" else 1 + b_start = core_index if megacore_mode == "batch" else 0 + h_step = num_cores if megacore_mode == "kv_head" else 1 + h_start = core_index if megacore_mode == "kv_head" else 0 + + h = h * h_step + h_start + b = b * b_step + b_start + length = lengths_ref[b] + + def compute_block_indices(b, h, i): + + def advance_b(): + next_b = b + b_step + + def advance_to_next_non_zero_length(): + next_next_b = next_b + b_step + return lax.fori_loop( + lax.div(next_next_b, b_step), + lax.div(batch_size, b_step), + lambda _, b: jnp.where(lengths_ref[b] == 0, b + b_step, b), + next_next_b, + ) + + return ( + lax.cond( + jnp.logical_and(next_b < batch_size, lengths_ref[next_b] == 0), + advance_to_next_non_zero_length, + lambda: next_b, + ), + h_start, + 0, + ) + + def advance_h(): + next_h = h + h_step + return lax.cond(next_h < num_kv_heads, lambda: (b, next_h, 0), advance_b) + + return lax.cond(i * bk < lengths_ref[b], lambda: (b, h, i), advance_h) + + def create_kv_async_copy_descriptors(b, h, i, buffer_index): + page_offset = b * pages_per_sequence + i * pages_per_compute_block + pages_to_load = pages_per_compute_block + async_copy_k = MultiPageAsyncCopyDescriptor( + k_pages_hbm_ref, + k_scales_pages_hbm_ref, + k_vmem_buffer.at[buffer_index], + k_scales_vmem_buffer.at[buffer_index] + if k_scales_vmem_buffer is not None + else None, + sem, + page_indices_ref, + page_offset, + pages_to_load, + h, + ) + async_copy_v = MultiPageAsyncCopyDescriptor( + v_pages_hbm_ref, + v_scales_pages_hbm_ref, + v_vmem_buffer.at[buffer_index], + v_scales_vmem_buffer.at[buffer_index] + if v_scales_vmem_buffer is not None + else None, + sem, + page_indices_ref, + page_offset, + pages_to_load, + h, + ) + return async_copy_k, async_copy_v + + @pl.when(i * bk < length) + def flash_attention(): # pylint: disable=unused-variable + step = step_ref[0] + buffer_index = buffer_index_ref[0] + + @pl.when(i == 0) + def init(): # pylint: disable=unused-variable + m_ref[...] = jnp.full_like(m_ref, -jnp.inf) + l_ref[...] = jnp.zeros_like(l_ref) + o_ref[...] 
= jnp.zeros_like(o_ref) + + @pl.when(step == 0) + def prefetch_first_block(): # pylint: disable=unused-variable + async_copy_k, async_copy_v = create_kv_async_copy_descriptors( + b, h, i, buffer_index + ) + async_copy_k.start() + async_copy_v.start() + + next_b, next_h, next_i = compute_block_indices(b, h, i + 1) + + @pl.when(next_b < batch_size) + def prefetch_next_block(): # pylint: disable=unused-variable + next_buffer_index = jnp.where(buffer_index == 0, 1, 0) + async_copy_next_k, async_copy_next_v = create_kv_async_copy_descriptors( + next_b, next_h, next_i, next_buffer_index + ) + async_copy_next_k.start() + async_copy_next_v.start() + buffer_index_ref[0] = next_buffer_index + + async_copy_k, async_copy_v = create_kv_async_copy_descriptors( + b, h, i, buffer_index + ) + q = q_ref[...].astype(jnp.float32) + k = async_copy_k.wait_and_get_loaded() + qk = jnp.einsum("hd,td->ht", q, k, preferred_element_type=jnp.float32) + if attn_logits_soft_cap is not None: + capped_qk = jnp.tanh(qk / attn_logits_soft_cap) + qk = capped_qk * attn_logits_soft_cap + + mask = i * bk + jax.lax.broadcasted_iota(jnp.int32, qk.shape, 1) < length + qk = qk + jnp.where(mask, 0.0, mask_value) + m_curr = qk.max(axis=-1) + + s_curr = jnp.exp(qk - m_curr[..., None]) + m_prev, l_prev = m_ref[...], l_ref[...] + l_curr = jax.lax.broadcast_in_dim(s_curr.sum(axis=-1), l_prev.shape, (0,)) + m_curr = jax.lax.broadcast_in_dim(m_curr, m_prev.shape, (0,)) + m_next = jnp.maximum(m_prev, m_curr) + alpha = jnp.exp(m_prev - m_next) + beta = jnp.exp(m_curr - m_next) + l_next = alpha * l_prev + beta * l_curr + l_next_safe = jnp.where(l_next == 0.0, 1.0, l_next) + + v = async_copy_v.wait_and_get_loaded() + o_curr_times_l_curr = jnp.dot(s_curr, v) + + m_ref[...], l_ref[...] = m_next, l_next_safe + o_ref[...] = ( + (l_prev * alpha * o_ref[...] + beta * o_curr_times_l_curr) / l_next_safe + ).astype(o_ref.dtype) + + step_ref[0] = step + 1 + + +def paged_flash_attention_kernel_inline_seq_dim( + lengths_ref, + page_indices_ref, + buffer_index_ref, + step_ref, + q_ref, + k_pages_hbm_ref, + k_scales_pages_hbm_ref, + v_pages_hbm_ref, + v_scales_pages_hbm_ref, + o_ref, + m_ref, + l_ref, + k_vmem_buffer, + k_scales_vmem_buffer, + v_vmem_buffer, + v_scales_vmem_buffer, + sem, + *, + batch_size: int, + pages_per_compute_block: int, + pages_per_sequence: int, + mask_value: float, + attn_logits_soft_cap: float | None, + megacore_mode: str | None, +): + core_index, b, h = pl.program_id(0), pl.program_id(1), pl.program_id(2) + + # Initialize the output HBM buffers to avoid accessing garbage memory inside + # the kernel body below. + m_ref[...] = jnp.full_like(m_ref, -jnp.inf) + l_ref[...] = jnp.zeros_like(l_ref) + o_ref[...] 
= jnp.zeros_like(o_ref) + + def body(i, _): + paged_flash_attention_kernel( + lengths_ref, + page_indices_ref, + buffer_index_ref, + step_ref, + q_ref, + k_pages_hbm_ref, + k_scales_pages_hbm_ref, + v_pages_hbm_ref, + v_scales_pages_hbm_ref, + o_ref, + m_ref, + l_ref, + k_vmem_buffer, + k_scales_vmem_buffer, + v_vmem_buffer, + v_scales_vmem_buffer, + sem, + batch_size=batch_size, + pages_per_compute_block=pages_per_compute_block, + pages_per_sequence=pages_per_sequence, + mask_value=mask_value, + attn_logits_soft_cap=attn_logits_soft_cap, + megacore_mode=megacore_mode, + program_ids=(core_index, b, h, i), + ) + return () + + bk = pages_per_compute_block * k_pages_hbm_ref.shape[-2] + + if megacore_mode == "batch": + num_cores = pl.num_programs(0) + length = lengths_ref[b * num_cores + core_index] + else: + length = lengths_ref[b] + + lax.fori_loop(0, lax.div(length + bk - 1, bk), body, ()) + + +@functools.partial( + jax.jit, + static_argnames=[ + "pages_per_compute_block", + "attn_logits_soft_cap", + "mask_value", + "megacore_mode", + "inline_seq_dim", + ], +) +def paged_attention( + q: jax.Array, + k_pages: jax.Array | quantization_utils.QuantizedTensor, + v_pages: jax.Array | quantization_utils.QuantizedTensor, + lengths: jax.Array, + page_indices: jax.Array, + *, + mask_value: float = DEFAULT_MASK_VALUE, + attn_logits_soft_cap: float | None = None, + pages_per_compute_block: int, + megacore_mode: str | None = None, + inline_seq_dim: bool = True, +) -> jax.Array: + """Paged grouped query attention. + + Args: + q: A [batch_size, num_heads, head_dim] jax.Array. + k_pages: A [num_kv_heads, total_num_pages, page_size, head_dim] jax.Array. + v_pages: A [num_kv_heads, total_num_pages, page_size, head_dim] jax.Array. + lengths: A i32[batch_size] jax.Array the length of each example. + page_indices: A i32[batch_size, pages_per_sequence] jax.Array. Each entry + should be in the range of [0, total_num_pages), indicating where to locate + the page in `k_pages` or `v_pages`. + mask_value: The value used for padding in attention. By default it is a very + negative floating point number. + attn_logits_soft_cap: The value used for soft capping the attention logits. + pages_per_compute_block: how many pages to be processed in one flash + attention block in the pallas kernel. + megacore_mode: if set, enable megacore to parallelize the computation. Must + be one of ['kv_head', 'batch', None]. Caveat: set this only if megacore is + enabled, otherwise the kernel may hang. If you are not sure, leave it to + None. + * None: disable megacore parallelism. + * kv_head: megacore parallelism on KV heads; requires number of KV heads + divisible by 2. + * batch: megacore parallelism on batch dimension; requires batch divisible + by 2. + inline_seq_dim: whether to fuse kernel instances along the sequence dim into + one kernel. + + Returns: + The output of attention([batch_size, num_heads, head_dim]). + """ + if isinstance(k_pages, quantization_utils.QuantizedTensor): + k_pages, k_scales_pages = k_pages.weight, k_pages.scales + assert isinstance(k_scales_pages, jax.Array) # For typing. + k_scales_pages = jnp.broadcast_to( + k_scales_pages, (*k_scales_pages.shape[:-1], k_pages.shape[-1]) + ) + else: + k_scales_pages = None + if isinstance(v_pages, quantization_utils.QuantizedTensor): + v_pages, v_scales_pages = v_pages.weight, v_pages.scales + assert isinstance(v_scales_pages, jax.Array) # For typing. 
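+    # As with k_scales_pages above, broadcast the per-block scale along
+    # head_dim so dequantization inside the kernel is a plain elementwise op
+    # (see quantization_utils.from_int8).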
+    v_scales_pages = jnp.broadcast_to(
+        v_scales_pages, (*v_scales_pages.shape[:-1], v_pages.shape[-1])
+    )
+  else:
+    v_scales_pages = None
+
+  batch_size, num_heads, head_dim = q.shape
+  num_kv_heads, _, page_size, head_dim_k = k_pages.shape
+  batch_size_paged_indices, pages_per_sequence = page_indices.shape
+
+  if k_pages.shape != v_pages.shape:
+    raise ValueError(
+        f"k_pages and v_pages must have the same shape. Got {k_pages.shape} and"
+        f" {v_pages.shape}"  # pytype: disable=attribute-error
+    )
+  if num_heads % num_kv_heads != 0:
+    raise ValueError(
+        "Number of Q heads must be divisible by number of KV heads. Got"
+        f" {num_heads} and {num_kv_heads}."
+    )
+  if head_dim_k != head_dim:
+    raise ValueError(
+        "head_dim of Q must be the same as that of K/V. Got"
+        f" {head_dim} and {head_dim_k}."
+    )
+  if pages_per_sequence % pages_per_compute_block != 0:
+    raise ValueError(
+        "pages_per_sequence must be divisible by pages_per_compute_block. Got"
+        f" {pages_per_sequence} and {pages_per_compute_block}."
+    )
+  if lengths.shape != (batch_size,):
+    raise ValueError("`lengths` and `q` must have the same batch size")
+  if batch_size_paged_indices != batch_size:
+    raise ValueError("`page_indices` and `q` must have the same batch size")
+  if lengths.dtype != jnp.int32:
+    raise ValueError(
+        f"The dtype of `lengths` must be int32. Got {lengths.dtype}"
+    )
+
+  # TODO(dinghua): get the actual cores per chip once there's an official API.
+  if megacore_mode == "kv_head":
+    if num_kv_heads % 2 != 0:
+      raise ValueError(
+          "number of KV heads must be even when megacore_mode is 'kv_head'"
+      )
+    num_cores = 2
+  elif megacore_mode == "batch":
+    if batch_size % 2 != 0:
+      raise ValueError("batch size must be even when megacore_mode is 'batch'")
+    num_cores = 2
+  elif megacore_mode is None:
+    num_cores = 1
+  else:
+    raise ValueError("megacore_mode must be one of ['kv_head', 'batch', None]")
+
+  if (num_heads // num_kv_heads) % 8 != 0:
+    # Reshape q to hint XLA to pick a <1x128> layout; otherwise it will pick an
+    # <8x128> layout for a <1x128> memref inside the kernel and error out.
+ q = q.reshape(batch_size, num_heads, 1, head_dim) + if megacore_mode == "kv_head": + q_block_spec = pl.BlockSpec( + (None, num_heads // num_kv_heads, None, head_dim), + lambda core_index, b, h, *_: (b, h * num_cores + core_index, 0, 0), + ) + elif megacore_mode == "batch": + q_block_spec = pl.BlockSpec( + (None, num_heads // num_kv_heads, None, head_dim), + lambda core_index, b, h, *_: (b * num_cores + core_index, h, 0, 0), + ) + else: + q_block_spec = pl.BlockSpec( + (None, num_heads // num_kv_heads, None, head_dim), + lambda core_index, b, h, *_: (b, h, 0, 0), + ) + q_dtype_for_kernel_launch = jnp.float32 + else: + if megacore_mode == "kv_head": + q_block_spec = pl.BlockSpec( + (None, num_heads // num_kv_heads, head_dim), + lambda core_index, b, h, *_: (b, h * num_cores + core_index, 0), + ) + elif megacore_mode == "batch": + q_block_spec = pl.BlockSpec( + (None, num_heads // num_kv_heads, head_dim), + lambda core_index, b, h, *_: (b * num_cores + core_index, h, 0), + ) + else: + q_block_spec = pl.BlockSpec( + (None, num_heads // num_kv_heads, head_dim), + lambda core_index, b, h, *_: (b, h, 0), + ) + q_dtype_for_kernel_launch = q.dtype + + dimension_semantics: Sequence[Literal["parallel", "arbitrary"]] + if inline_seq_dim: + kernel = paged_flash_attention_kernel_inline_seq_dim + grid = ( + num_cores, + batch_size // num_cores if megacore_mode == "batch" else batch_size, + num_kv_heads // num_cores + if megacore_mode == "kv_head" + else num_kv_heads, + ) + dimension_semantics = ("parallel", "arbitrary", "arbitrary") + else: + kernel = paged_flash_attention_kernel + grid = ( + num_cores, + batch_size // num_cores if megacore_mode == "batch" else batch_size, + num_kv_heads // num_cores + if megacore_mode == "kv_head" + else num_kv_heads, + pages_per_sequence // pages_per_compute_block, + ) # type: ignore + dimension_semantics = ("parallel", "arbitrary", "arbitrary", "arbitrary") + + if k_scales_pages is not None and v_scales_pages is not None: + in_specs = [ + q_block_spec, + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + ] + scratch_shapes = ( + pltpu.VMEM( + ( + 2, # For double buffering during DMA copies. + pages_per_compute_block, + page_size, + head_dim, + ), + k_pages.dtype, + ), # k_pages buffer + pltpu.VMEM( + ( + 2, # For double buffering during DMA copies. + pages_per_compute_block, + page_size, + head_dim, + ), + k_scales_pages.dtype, # pytype: disable=attribute-error + ), # k_scales_pages buffer + pltpu.VMEM( + ( + 2, # For double buffering during DMA copies. + pages_per_compute_block, + page_size, + head_dim, + ), + v_pages.dtype, + ), # v_pages buffer + pltpu.VMEM( + ( + 2, # For double buffering during DMA copies. + pages_per_compute_block, + page_size, + head_dim, + ), + v_scales_pages.dtype, # pytype: disable=attribute-error + ), # v_scales_pages buffer + pltpu.SemaphoreType.DMA, + ) + else: + in_specs = [ + q_block_spec, + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + None, # type: ignore[list-item] + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + None, # type: ignore[list-item] + ] + scratch_shapes = ( + pltpu.VMEM( + ( + 2, # For double buffering during DMA copies. + pages_per_compute_block, + page_size, + head_dim, + ), + k_pages.dtype, + ), # k_pages buffer + None, + pltpu.VMEM( + ( + 2, # For double buffering during DMA copies. 
+ pages_per_compute_block, + page_size, + head_dim, + ), + v_pages.dtype, + ), # v_pages buffer + None, + pltpu.SemaphoreType.DMA, + ) + + out, _, _ = pl.pallas_call( + functools.partial( + kernel, + pages_per_sequence=pages_per_sequence, + batch_size=batch_size, + pages_per_compute_block=pages_per_compute_block, + mask_value=mask_value, + attn_logits_soft_cap=attn_logits_soft_cap, + megacore_mode=megacore_mode, + ), + grid_spec=pltpu.PrefetchScalarGridSpec( + # There are 4 scalars prefetched per kernel call: `lengths_ref`, + # `page_indices_ref`, `buffer_index_ref`, `step_ref` + num_scalar_prefetch=4, + in_specs=in_specs, + out_specs=[ + q_block_spec, + q_block_spec, + q_block_spec, + ], + grid=grid, + scratch_shapes=scratch_shapes, + ), + # compiler_params=pltpu.TPUCompilerParams( + # dimension_semantics=dimension_semantics), + out_shape=[ + jax.ShapeDtypeStruct(q.shape, q_dtype_for_kernel_launch), + jax.ShapeDtypeStruct((*q.shape[:-1], 1), jnp.float32), + jax.ShapeDtypeStruct((*q.shape[:-1], 1), jnp.float32), + ], + )( + lengths, + page_indices.reshape(-1), + jnp.zeros((1,), jnp.int32), # buffer index + jnp.zeros((1,), jnp.int32), # step + q.astype(q_dtype_for_kernel_launch), + k_pages, + k_scales_pages, + v_pages, + v_scales_pages, + ) + return out.reshape(batch_size, num_heads, head_dim).astype(q.dtype) diff --git a/experimental/jax/inference/kernel/attention/tpu/quantization_utils.py b/experimental/jax/inference/kernel/attention/tpu/quantization_utils.py new file mode 100644 index 00000000..d6c6ed7c --- /dev/null +++ b/experimental/jax/inference/kernel/attention/tpu/quantization_utils.py @@ -0,0 +1,112 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import NamedTuple +import jax +from jax import numpy as jnp + +P = jax.sharding.PartitionSpec +MAX_INT8 = 127 +MIN_INT8 = -128 + + +class QuantizedTensor(NamedTuple): + """A tensor which has been quantized to int8 and its scales. + + Attributes: + weight: Weight + scales: Scales + """ + + weight: jnp.ndarray + scales: jnp.ndarray + + +def to_int8(x: jnp.ndarray, h: jnp.ndarray) -> jnp.ndarray: + """Converts a float array to an int8 array with a scale. + + Args: + x: Float array. + h: Quantization scale. + + Returns: + Int8 array. + """ + x = x * h + return jnp.clip(jnp.round(x), MIN_INT8, MAX_INT8).astype(jnp.int8) + + +def from_int8( + x: jnp.ndarray, h: jnp.ndarray, dtype: jnp.dtype = jnp.bfloat16 +) -> jnp.ndarray: + """Converts an int8 array to a float array with a scale. + + Args: + x: Int8 array. + h: Quantization scale. + dtype: Float dtype to convert to. + + Returns: + Float array. + """ + x = x.astype(dtype=dtype) / h + return x.astype(dtype=dtype) + + +def get_quantization_scales(x: jnp.ndarray, axis=-1) -> jnp.ndarray: + """Computes the quantization scales for a float array. + + These are the maximum values of the trailing dimension. + + Args: + x: Float array to quantize. 
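+    axis: Axis along which the absolute max is taken (defaults to the last
+      axis).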
+ + Returns: + Array of the same shape as input but with the trailing dimension reduced to + a size 1 absolute max value. + """ + scale_reciprocal = MAX_INT8 / jnp.max(jnp.abs(x), axis=axis, keepdims=True) + return scale_reciprocal.astype(jnp.float32) + + +def quantize_to_int8( + x: jnp.ndarray, + axis=-1, +) -> QuantizedTensor: + """Quantizes a float array to an int8 QuantizedTensor. + + Args: + x: Float array to quantize. + + Returns: + Int8 QuantizedTensor. + """ + x_scales = get_quantization_scales(x, axis=axis) + return QuantizedTensor(weight=to_int8(x, x_scales), scales=x_scales) + + +def unquantize_from_int8( + x: QuantizedTensor, + dtype: jnp.dtype = jnp.bfloat16, +) -> jnp.ndarray: + """Unquantizes an int8 QuantizedTensor to a float array. + + Args: + x: Int8 QuantizedTensor to unquantize. + dtype: Float dtype to unquantize to. + + Returns: + Float array. + """ + return from_int8(x.weight, x.scales, dtype) diff --git a/experimental/jax/inference/kernel/attention_ops.py b/experimental/jax/inference/kernel/attention_ops.py new file mode 100644 index 00000000..24940e94 --- /dev/null +++ b/experimental/jax/inference/kernel/attention_ops.py @@ -0,0 +1,139 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +"""Attention kernel builder. + +TODO: change to builder pattern instead of direct function call. +""" + +import jax +from jax import numpy as jnp +from inference.kernel.attention import tpu + +K_MASK = -2.3819763e38 # Set to a large negative number. + + +def vanilla_prefill_mha(q, k, v): + """multi-head attention. + + q, k, v shape is (total_len, num_kv_heads, head_dim). + output shape is (total_len, num_kv_heads * head_dim). + """ + total_len = q.shape[0] + num_kv_heads = q.shape[1] + head_dim = q.shape[2] + causal_mask = jnp.tril(jnp.ones(shape=(total_len, total_len))).astype( + jnp.bool + ) + mask = causal_mask[:, None, :] + + wei = jnp.einsum("tkh,skh->tks", q, k) / jnp.sqrt(head_dim) + wei = jnp.where(mask, wei, K_MASK).astype(jnp.float32) + wei = jax.nn.softmax(wei, axis=-1) + out = jnp.einsum("tks,skh->tkh", wei, v) + out = out.astype(q.dtype).reshape(total_len, num_kv_heads * head_dim) + return out + + +def vanilla_prefill_mqa(q, k, v): + """multi-query attention. + + q is (total_len, num_attn, head_dim). + k/v shape is (total_len, 1, head_dim). + output shape is (total_len, num_kv_heads * head_dim). + """ + total_len = q.shape[0] + head_dim = q.shape[2] + k = jnp.squeeze(k, axis=1) + v = jnp.squeeze(v, axis=1) + causal_mask = jnp.tril(jnp.ones(shape=(total_len, total_len))).astype( + jnp.bool + ) + mask = causal_mask[:, None, :] + + wei = jnp.einsum("tah,sh->tas", q, k) / jnp.sqrt(head_dim) + wei = jnp.where(mask, wei, K_MASK) + wei = jax.nn.softmax(wei, axis=-1) + out = jnp.einsum("tas,sh->tah", wei, v) + out = out.astype(q.dtype).reshape((total_len, -1)) + return out + + +def vanilla_prefill_gqa(q, k, v): + """group-query attention. + + q shape is (total_len, num_attn_heads, head_dim). + k/v shape is (total_len, num_kv_heads, head_dim). 
+ output shape is (total_len, num_attn_heads * head_dim). + """ + total_len, num_attn_heads, head_dim = q.shape + num_kv_heads = k.shape[1] + q = q.reshape( + (total_len, num_kv_heads, num_attn_heads // num_kv_heads, head_dim) + ) + causal_mask = jnp.tril(jnp.ones(shape=(total_len, total_len))).astype( + jnp.bool + ) + # padding_mask_row = jnp.where(jnp.arange(0, total_len) < len, 1, 0).astype(jnp.bool).reshape((-1, 1)) + # padding_mask_col = jnp.reshape(padding_mask_row, (1, -1)) + # padding_mask = jnp.logical_and(padding_mask_row, padding_mask_col) + # mask = jnp.logical_and(padding_mask, causal_mask)[:, None, None, :] + mask = causal_mask[:, None, None, :] + + wei = jnp.einsum("tkgh,skh->tkgs", q, k) / jnp.sqrt(head_dim) + wei = jnp.where(mask, wei, K_MASK) + wei = jax.nn.softmax(wei, axis=-1) + out = jnp.einsum("tkgs,skh->tkgh", wei, v) + out = out.astype(q.dtype).reshape((total_len, -1)) + return out + + +def chunked_prefill_attention( + q, cache_k, cache_v, length, page_indices, accelerator="tpu" +): + if accelerator == "tpu": + return tpu.chunked_prefill_attention( + q, + cache_k, + cache_v, + length, + page_indices, + ) + else: + raise NotImplementedError(f"not supported accelerate {accelerator}") + + +def decode_attention(q, cache_k, cache_v, pos, page_table, accelerator="tpu"): + if accelerator == "tpu": + # Heuristically set the pages per compute block. + # TODO: tune the setup. + pages_per_compute_block = 8 + _, _, _, head_dim = cache_k.shape + q = q / jnp.sqrt(head_dim) + seq_len = pos + 1 + + output = tpu.paged_attention( + q, + cache_k, + cache_v, + seq_len, + page_table, + pages_per_compute_block=pages_per_compute_block, + ) + return output + + else: + raise NotImplementedError(f"not supported accelerate {accelerator}") diff --git a/experimental/jax/inference/kernel/collective_matmul_ops.py b/experimental/jax/inference/kernel/collective_matmul_ops.py new file mode 100644 index 00000000..ca7bb261 --- /dev/null +++ b/experimental/jax/inference/kernel/collective_matmul_ops.py @@ -0,0 +1,40 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +from collections.abc import Callable +from typing import Sequence +import functools +import jax +from inference import parallel +from .linear.tpu.collective_matmul import * + + +def build_collective_matmul( + type: parallel.CollectiveMatmulType, + axis_names: str | Sequence[str], +) -> Callable[[jax.Array, jax.Array], jax.Array]: + if type == parallel.CollectiveMatmulType.ALL_GATHER: + return functools.partial( + all_gather_collective_matmul, + axis_names=axis_names, + ) + elif type == parallel.CollectiveMatmulType.REDUCE_SCATTER: + return functools.partial( + collective_matmul_reduce_scatter, + axis_names=axis_names, + ) + else: + raise ValueError(f"Unsupported collective matmul type {type}") diff --git a/experimental/jax/inference/kernel/linear/tpu/collective_matmul.py b/experimental/jax/inference/kernel/linear/tpu/collective_matmul.py new file mode 100644 index 00000000..4bbb9ebd --- /dev/null +++ b/experimental/jax/inference/kernel/linear/tpu/collective_matmul.py @@ -0,0 +1,223 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +""" collective matmul for linear layer (1d ring topology). + +Reference: +https://dl.acm.org/doi/pdf/10.1145/3567955.3567959 + +The implementation is throughput-bound. Latency-bound implementation +should also be added. + +The implementation is via the JAX/XLA collective-permute API, and it doesn't +work for GPU well: https://github.com/openxla/xla/issues/10640#issuecomment-2246448416 +""" + +import jax +from jax import numpy as jnp +from jax.sharding import Mesh +from jax.experimental.shard_map import shard_map +from inference.parallel import get_num_partitions, get_partition_index, tp_axis_names + + +def prepare_rhs_for_all_gather_collective_matmul(rhs: jax.Array, mesh: Mesh): + """Prepare rhs for all gather collective matmul. + + For bidirectional collective matmul, the two lhs chunks + received from the neighbor are not contiguous in the + all-gathered lhs. Reshuffle the rhs ahead to process + the non-contiguous lhs chunks together in the collective + matmul. 
+ """ + axis_names = tp_axis_names() + + def reshuffle(rhs: jax.Array): + idx = get_partition_index(axis_names=axis_names) + num_partitions = get_num_partitions(axis_names=axis_names) + rhs_chunk_row_size = rhs.shape[0] // num_partitions + half_rhs_chunk_row_size = rhs_chunk_row_size // 2 + + def swap(i, carry): + rhs = carry + idx_1 = ((idx + i) % num_partitions) * rhs_chunk_row_size + idx_2 = ((idx - i) % num_partitions) * rhs_chunk_row_size + operand_1 = jax.lax.dynamic_slice_in_dim( + rhs, idx_1, half_rhs_chunk_row_size, axis=0 + ) + operand_2 = jax.lax.dynamic_slice_in_dim( + rhs, idx_2, half_rhs_chunk_row_size, axis=0 + ) + rhs = jax.lax.dynamic_update_slice_in_dim(rhs, operand_1, idx_2, axis=0) + rhs = jax.lax.dynamic_update_slice_in_dim(rhs, operand_2, idx_1, axis=0) + return rhs + + rhs = jax.lax.fori_loop(1, num_partitions // 2, swap, rhs) + return rhs + + return shard_map( + f=reshuffle, + mesh=mesh, + in_specs=rhs.sharding.spec, + out_specs=rhs.sharding.spec, + )(rhs) + + +def all_gather_collective_matmul(lhs, rhs, axis_names): + """All gather collective matmul. + + The function works for matmul where the lhs is partitioned at + the contracting dimension. + """ + idx = get_partition_index(axis_names=axis_names) + num_partitions = get_num_partitions(axis_names=axis_names) + rhs_chunk_row_size = rhs.shape[0] // num_partitions + + def step(i, carry): + accum, fwd_lhs, bwd_lhs = carry + rhs_row_idx = ((idx + i) % num_partitions) * rhs_chunk_row_size + cur_lhs = jnp.concatenate((fwd_lhs, bwd_lhs), axis=1) + rhs_chunk = jax.lax.dynamic_slice_in_dim( + rhs, rhs_row_idx, rhs_chunk_row_size + ) + output = cur_lhs @ rhs_chunk + accum += output + fwd_lhs = jax.lax.ppermute( + fwd_lhs, + axis_names, + [(j, (j + 1) % num_partitions) for j in range(num_partitions)], + ) + bwd_lhs = jax.lax.ppermute( + bwd_lhs, + axis_names, + [(j, (j - 1) % num_partitions) for j in range(num_partitions)], + ) + return accum, fwd_lhs, bwd_lhs + + res_shape = (lhs.shape[0], rhs.shape[1]) + accum = jnp.zeros(shape=res_shape, dtype=lhs.dtype) + fwd_lhs, bwd_lhs = jnp.split(lhs, 2, 1) + accum, fwd_lhs, bwd_lhs = jax.lax.fori_loop( + 0, num_partitions - 1, step, (accum, fwd_lhs, bwd_lhs) + ) + + # Last round which doesn't need collective permute. + rhs_row_idx = ( + (idx + num_partitions - 1) % num_partitions + ) * rhs_chunk_row_size + cur_lhs = jnp.concatenate((fwd_lhs, bwd_lhs), axis=1) + rhs_chunk = jax.lax.dynamic_slice_in_dim(rhs, rhs_row_idx, rhs_chunk_row_size) + output = cur_lhs @ rhs_chunk + accum += output + + return accum + + +def prepare_rhs_for_collective_matmul_reduce_scatter( + rhs: jax.Array, mesh: Mesh +): + """Prepare rhs for collective matmul with reduce scatter. + + For bidirectional collective matmul, the two accum chunks + received from the neighbor are not contiguous in the + final accum. Reshuffle the rhs ahead to process + the non-contiguous accum chunks together in the collective + matmul. 
+ """ + axis_names = tp_axis_names() + + def reshuffle(rhs): + idx = get_partition_index(axis_names=axis_names) + num_partitions = get_num_partitions(axis_names=axis_names) + rhs_chunk_col_size = rhs.shape[1] // num_partitions + half_rhs_chunk_col_size = rhs_chunk_col_size // 2 + + def swap(i, carry): + rhs = carry + idx_1 = ((idx + i) % num_partitions) * rhs_chunk_col_size + idx_2 = ((idx - i) % num_partitions) * rhs_chunk_col_size + operand_1 = jax.lax.dynamic_slice_in_dim( + rhs, idx_1, half_rhs_chunk_col_size, axis=1 + ) + operand_2 = jax.lax.dynamic_slice_in_dim( + rhs, idx_2, half_rhs_chunk_col_size, axis=1 + ) + rhs = jax.lax.dynamic_update_slice_in_dim(rhs, operand_1, idx_2, axis=1) + rhs = jax.lax.dynamic_update_slice_in_dim(rhs, operand_2, idx_1, axis=1) + return rhs + + rhs = jax.lax.fori_loop(1, num_partitions // 2, swap, rhs) + return rhs + + pspec = rhs.sharding.spec + + return shard_map( + f=reshuffle, + mesh=mesh, + in_specs=pspec, + out_specs=pspec, + )(rhs) + + +def collective_matmul_reduce_scatter(lhs, rhs, axis_names): + """Collective matmul with reduce scatter at the output column axis.""" + idx = get_partition_index(axis_names=axis_names) + num_partitions = get_num_partitions(axis_names=axis_names) + rhs_chunk_col_size = rhs.shape[1] // num_partitions + # Compute the partial result for the chip at the last step. + rhs_col_idx = ((idx + 1) % num_partitions) * rhs_chunk_col_size + rhs_chunk = jax.lax.dynamic_slice_in_dim( + rhs, + start_index=rhs_col_idx, + slice_size=rhs_chunk_col_size, + axis=1, + ) + partial_res = lhs @ rhs_chunk + accum_to_send_shape = (lhs.shape[0], rhs_chunk_col_size // 2) + fwd_accum = jnp.zeros(shape=accum_to_send_shape) + bwd_accum = jnp.zeros(shape=accum_to_send_shape) + + def step(i, carry): + fwd_accum, bwd_accum, partial_res = carry + accum = jnp.concatenate((fwd_accum, bwd_accum), axis=1) + accum += partial_res + fwd_accum, bwd_accum = jnp.split(accum, 2, axis=1) + + rhs_col_idx = ((idx + 1 + i) % num_partitions) * rhs_chunk_col_size + rhs_chunk = jax.lax.dynamic_slice_in_dim( + rhs, + start_index=rhs_col_idx, + slice_size=rhs_chunk_col_size, + axis=1, + ) + partial_res = lhs @ rhs_chunk + fwd_accum = jax.lax.ppermute( + fwd_accum, + axis_names, + [(j, (j + 1) % num_partitions) for j in range(num_partitions)], + ) + bwd_accum = jax.lax.ppermute( + bwd_accum, + axis_names, + [(j, (j - 1) % num_partitions) for j in range(num_partitions)], + ) + return fwd_accum, bwd_accum, partial_res + + fwd_accum, bwd_accum, partial_res = jax.lax.fori_loop( + 1, num_partitions, step, (fwd_accum, bwd_accum, partial_res) + ) + accum = jnp.concatenate((fwd_accum, bwd_accum), axis=1) + accum += partial_res + return accum diff --git a/experimental/jax/inference/model/__init__.py b/experimental/jax/inference/model/__init__.py new file mode 100644 index 00000000..6f7c5884 --- /dev/null +++ b/experimental/jax/inference/model/__init__.py @@ -0,0 +1,20 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +from .llama import LlamaForCausalLM, LlamaModel +from .management.registry import ModelSource, ModelRegistry +from .sampling.sampler import SamplingParams +from .postprocess import ModelOutput diff --git a/experimental/jax/inference/model/llama.py b/experimental/jax/inference/model/llama.py new file mode 100644 index 00000000..84f239bc --- /dev/null +++ b/experimental/jax/inference/model/llama.py @@ -0,0 +1,382 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +"""Llama model family.""" + +import jax +from transformers import LlamaConfig +from inference import nn +from inference import parallel +from .sampling.sampler import Sampler, SamplingParams +from inference.model.postprocess import * + + +class LlamaFeedForward(nn.Module): + + def __init__( + self, + config: LlamaConfig, + parallel_config: parallel.FeedForwardParallelConfig, + ): + super().__init__() + self.config = config + self.parallel_config = parallel_config + + hidden_size = config.hidden_size + intermediate_size = config.intermediate_size + + gate_up_proj_parallel = parallel.LinearParallelConfig( + mesh=parallel_config.mesh, + parallel_type=parallel.LinearParallelType.COLUMN, + ) + + if parallel_config.enable_collective_matmul: + gate_up_proj_parallel.collective_matmul_type = ( + parallel.CollectiveMatmulType.ALL_GATHER + ) + + self.gate_up_proj = nn.Linear( + in_features=hidden_size, + out_features=[intermediate_size] * 2, + parallel_config=gate_up_proj_parallel, + ) + + down_proj_parallel = parallel.LinearParallelConfig( + mesh=parallel_config.mesh, + parallel_type=parallel.LinearParallelType.ROW, + ) + + if parallel_config.enable_collective_matmul: + down_proj_parallel.collective_matmul_type = ( + parallel.CollectiveMatmulType.REDUCE_SCATTER + ) + else: + down_proj_parallel.reduce_output = True + + self.down_proj = nn.Linear( + in_features=intermediate_size, + out_features=hidden_size, + parallel_config=down_proj_parallel, + ) + + self.silu = jax.nn.silu + + def __call__(self, x: jax.Array) -> jax.Array: + gate, up = self.gate_up_proj(x) + x = self.silu(gate) * up + x = self.down_proj(x) + return x + + +class LlamaAttention(nn.Module): + + def __init__( + self, + config: LlamaConfig, + parallel_config: parallel.AttentionParallelConfig, + ): + super().__init__() + self.config = config + self.parallel_config = parallel_config + self.hidden_size = config.hidden_size + self.num_attn_heads = config.num_attention_heads + self.head_dim = getattr( + config, "head_dim", self.hidden_size // self.num_attn_heads + ) + self.num_kv_heads = config.num_key_value_heads + self.num_kv_groups = self.num_attn_heads // self.num_kv_heads + self.rope_theta = config.rope_theta + self.parallel_config = parallel_config + + column_parallel_config = parallel.config.LinearParallelConfig( + mesh=parallel_config.mesh, + parallel_type=parallel.LinearParallelType.COLUMN, + ) + + self.q_proj = nn.Linear( + self.hidden_size, + self.num_attn_heads * self.head_dim, + parallel_config=column_parallel_config, + ) + + self.k_proj = 
nn.Linear( + self.hidden_size, + self.num_kv_heads * self.head_dim, + parallel_config=column_parallel_config, + ) + + self.v_proj = nn.Linear( + self.hidden_size, + self.num_kv_heads * self.head_dim, + parallel_config=column_parallel_config, + ) + + out_proj_parallel = parallel.config.LinearParallelConfig( + mesh=parallel_config.mesh, + parallel_type=parallel.LinearParallelType.ROW, + ) + + if parallel_config.reduce_output: + out_proj_parallel.reduce_output = True + else: + out_proj_parallel.reduce_scatter_output = True + + self.o_proj = nn.Linear( + self.num_attn_heads * self.head_dim, + self.hidden_size, + out_proj_parallel, + ) + + self.rotary_emb = nn.apply_rope_embedding + self.attn = nn.AttentionOps( + self.num_attn_heads, + self.num_kv_heads, + self.head_dim, + ) + + def __call__( + self, + hidden_states, + positions, + kv_cache: nn.KVCache, + attn_metadata: nn.AttentionMetadata, + ) -> tuple[jax.Array, nn.KVCache]: + if self.parallel_config.gather_input: + hidden_states = parallel.ops.all_gather( + hidden_states, + len(hidden_states.shape) - 1, + parallel.tp_major_axis_names(), + ) + + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + + # reshape as (num_tokens, num_heads, head_dim) + q = q.reshape((q.shape[0], -1, self.head_dim)) + k = k.reshape((k.shape[0], -1, self.head_dim)) + v = v.reshape((v.shape[0], -1, self.head_dim)) + + q = self.rotary_emb(q, positions, self.rope_theta) + k = self.rotary_emb(k, positions, self.rope_theta) + + output, kv_cache = self.attn( + q, + k, + v, + kv_cache, + attn_metadata, + ) + + output = self.o_proj(output) + return output, kv_cache + + +class LlamaDecoderLayer(nn.Module): + + def __init__( + self, + config: LlamaConfig, + parallel_config: parallel.DecoderLayerParallelConfig, + ): + super().__init__() + self.config = config + self.parallel_config = parallel_config + mesh = parallel_config.mesh + + if parallel.platform() == "tpu": + enable_collective_matmul = True + else: + enable_collective_matmul = False + + self.self_attn = LlamaAttention( + config, + parallel.AttentionParallelConfig( + mesh=mesh, + gather_input=enable_collective_matmul, + reduce_output=(not enable_collective_matmul), + ), + ) + + self.ffw = LlamaFeedForward( + config, + parallel_config=parallel.FeedForwardParallelConfig( + mesh, enable_collective_matmul=enable_collective_matmul + ), + ) + + self.input_layernorm = nn.RMSNorm( + config.hidden_size, + eps=config.rms_norm_eps, + parallel_config=parallel.RMSNormParallelConfig( + mesh=mesh, + activation_shared=enable_collective_matmul, + ), + ) + + self.post_attention_layernorm = nn.RMSNorm( + config.hidden_size, + eps=config.rms_norm_eps, + parallel_config=parallel.RMSNormParallelConfig( + mesh=mesh, + activation_shared=enable_collective_matmul, + ), + ) + + def __call__( + self, + hidden_states, + positions, + kv_cache, + attn_metadata: nn.AttentionMetadata, + ) -> tuple[jax.Array, nn.KVCache]: + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + hidden_states, kv_cache = self.self_attn( + hidden_states, positions, kv_cache, attn_metadata + ) + hidden_states += residual + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + + hidden_states = self.ffw(hidden_states) + hidden_states += residual + + return hidden_states, kv_cache + + +class LlamaModel(nn.Model): + + def __init__( + self, + config: LlamaConfig, + parallel_config: parallel.ModelParallelConfig, + ): + super().__init__() + self.config = config + 
self.parallel_config = parallel_config + self.vocab_size = config.vocab_size + mesh = parallel_config.mesh + + self.embed_tokens = nn.Embedding( + vocab_size=config.vocab_size, + embedding_dim=config.hidden_size, + parallel_config=parallel.EmbeddingParallelConfig( + mesh=mesh, + parallel_type=parallel.EmbeddingParallelType.COLUMN, + ), + ) + + self.layers = nn.ModuleList( + [ + LlamaDecoderLayer( + config=config, + parallel_config=parallel.DecoderLayerParallelConfig( + mesh=parallel_config.mesh, + ), + ) + for _ in range(config.num_hidden_layers) + ] + ) + + if parallel.platform() == "tpu": + enable_collective_matmul = True + else: + enable_collective_matmul = False + + self.norm = nn.RMSNorm( + config.hidden_size, + config.rms_norm_eps, + parallel_config=parallel.RMSNormParallelConfig( + mesh=mesh, + activation_shared=enable_collective_matmul, + ), + ) + lm_head_parallel = parallel.LinearParallelConfig( + mesh=mesh, + parallel_type=parallel.LinearParallelType.COLUMN, + ) + + if enable_collective_matmul: + lm_head_parallel.collective_matmul_type = ( + parallel.CollectiveMatmulType.ALL_GATHER + ) + + self.lm_head = nn.Linear( + config.hidden_size, + config.vocab_size, + parallel_config=lm_head_parallel, + ) + + def __call__( + self, + input_ids, + positions, + kv_caches, + attn_metadata, + ) -> tuple[jax.Array, list[nn.KVCache]]: + hidden_states = self.embed_tokens(input_ids) + for i in range(self.config.num_hidden_layers): + hidden_states, kv_cache = self.layers[i]( + hidden_states, + positions, + kv_caches[i], + attn_metadata, + ) + kv_caches[i] = kv_cache + + hidden_states = self.norm(hidden_states) + logits = self.lm_head(hidden_states) + + logits = parallel.ops.all_gather( + logits, axis=len(logits.shape) - 1, axis_names=parallel.tp_axis_names() + ) + + return logits, kv_caches + + +class LlamaForCausalLM(nn.CausalLM): + + def __init__( + self, + config: LlamaConfig, + parallel_config: parallel.ModelParallelConfig, + eos: int | None = None, + max_length: int | None = None, + ): + super().__init__() + self.config = config + self.model = LlamaModel(config=config, parallel_config=parallel_config) + self.sampler = Sampler(eos, max_length) + + def __call__( + self, + input_ids: jax.Array, + positions: jax.Array, + kv_caches: list[nn.KVCache], + attn_metadata: nn.AttentionMetadata, + sampling_params: SamplingParams, + ) -> tuple[ModelOutput, list[nn.KVCache]]: + logits, kv_caches = self.model( + input_ids, positions, kv_caches, attn_metadata + ) + tokens, done = self.sampler.sample( + logits, positions, attn_metadata, sampling_params + ) + + return postprocess(tokens, done, attn_metadata), kv_caches diff --git a/experimental/jax/inference/model/management/hf_llama_ckpt_conversion.py b/experimental/jax/inference/model/management/hf_llama_ckpt_conversion.py new file mode 100644 index 00000000..03bf597e --- /dev/null +++ b/experimental/jax/inference/model/management/hf_llama_ckpt_conversion.py @@ -0,0 +1,145 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +"""Hugging Face Llama2 ckpt conversion utility function.""" + +from jax import numpy as jnp +from transformers import LlamaConfig +from .util import convert_to_jax_array + + +def merge_gate_up_proj_weights(gate, up, num_devices): + col_chunk_size = gate.shape[1] // num_devices + weight = jnp.concatenate( + [gate[:, :col_chunk_size], up[:, :col_chunk_size]], axis=-1 + ) + for i in range(1, num_devices): + index = col_chunk_size * i + weight = jnp.concatenate( + [ + weight, + gate[:, index : index + col_chunk_size], + up[:, index : index + col_chunk_size], + ], + axis=-1, + ) + + return weight + + +def convert_hf_llama(torch_weight_state, num_devices, config: LlamaConfig): + jax_weight_state = { + "embed_tokens": { + "weight": convert_to_jax_array( + torch_weight_state["model.embed_tokens.weight"] + ), + }, + "layers": {}, + "norm": { + "weight": convert_to_jax_array( + torch_weight_state["model.norm.weight"] + ), + }, + "lm_head": { + "weight": convert_to_jax_array( + torch_weight_state["lm_head.weight"].T + ), + }, + } + del torch_weight_state["model.embed_tokens.weight"] + del torch_weight_state["model.norm.weight"] + del torch_weight_state["lm_head.weight"] + for i in range(config.num_hidden_layers): + gate_up_proj_weight = merge_gate_up_proj_weights( + gate=convert_to_jax_array( + torch_weight_state[f"model.layers.{i}.mlp.gate_proj.weight"].T + ), + up=convert_to_jax_array( + torch_weight_state[f"model.layers.{i}.mlp.up_proj.weight"].T + ), + num_devices=num_devices, + ) + del torch_weight_state[f"model.layers.{i}.mlp.gate_proj.weight"] + del torch_weight_state[f"model.layers.{i}.mlp.up_proj.weight"] + jax_weight_state["layers"][i] = { + "self_attn": { + "q_proj": { + "weight": convert_to_jax_array( + torch_weight_state[ + f"model.layers.{i}.self_attn.q_proj.weight" + ].T + ), + }, + "k_proj": { + "weight": convert_to_jax_array( + torch_weight_state[ + f"model.layers.{i}.self_attn.k_proj.weight" + ].T + ), + }, + "v_proj": { + "weight": convert_to_jax_array( + torch_weight_state[ + f"model.layers.{i}.self_attn.v_proj.weight" + ].T + ), + }, + "o_proj": { + "weight": convert_to_jax_array( + torch_weight_state[ + f"model.layers.{i}.self_attn.o_proj.weight" + ].T + ), + }, + }, + "ffw": { + "gate_up_proj": { + "weight": gate_up_proj_weight, + }, + "down_proj": { + "weight": convert_to_jax_array( + torch_weight_state[ + f"model.layers.{i}.mlp.down_proj.weight" + ].T + ), + }, + }, + "input_layernorm": { + "weight": convert_to_jax_array( + torch_weight_state[f"model.layers.{i}.input_layernorm.weight"] + ), + }, + "post_attention_layernorm": { + "weight": convert_to_jax_array( + torch_weight_state[ + f"model.layers.{i}.post_attention_layernorm.weight" + ] + ), + }, + } + del torch_weight_state[f"model.layers.{i}.self_attn.q_proj.weight"] + del torch_weight_state[f"model.layers.{i}.self_attn.k_proj.weight"] + del torch_weight_state[f"model.layers.{i}.self_attn.v_proj.weight"] + del torch_weight_state[f"model.layers.{i}.self_attn.o_proj.weight"] + del torch_weight_state[f"model.layers.{i}.mlp.down_proj.weight"] + del torch_weight_state[f"model.layers.{i}.input_layernorm.weight"] + del torch_weight_state[f"model.layers.{i}.post_attention_layernorm.weight"] + + del torch_weight_state + + jax_causal_lm_weight_state = {} + jax_causal_lm_weight_state["model"] = jax_weight_state + return jax_causal_lm_weight_state diff --git a/experimental/jax/inference/model/management/registry.py b/experimental/jax/inference/model/management/registry.py new file mode 100644 index 00000000..67e62968 --- 
/dev/null
+++ b/experimental/jax/inference/model/management/registry.py
@@ -0,0 +1,181 @@
+"""
+Copyright 2024 Google LLC
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+"""Model registry."""
+
+import enum
+from typing import Any
+from jax.sharding import Mesh
+from jax import numpy as jnp
+from transformers import (
+    logging as hf_logging,
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    AutoConfig,
+    PretrainedConfig,
+    PreTrainedTokenizer,
+)
+from inference import nn
+from inference.model.llama import LlamaForCausalLM
+from .hf_llama_ckpt_conversion import convert_hf_llama
+from .util import torch_jax_dtype_map
+
+hf_logging.set_verbosity_error()
+
+supported_model = [
+    "meta-llama/Llama-2-7b-chat-hf",
+    "meta-llama/Llama-2-7b-hf",
+    "maxtext/Llama2-7b",
+]
+
+maxtext_config_map = {
+    "maxtext/Llama2-7b": "meta-llama/Llama-2-7b-chat-hf",
+}
+
+hf_ckpt_conversion = {
+    "meta-llama/Llama-2-7b-chat-hf": convert_hf_llama,
+    "meta-llama/Llama-2-7b-hf": convert_hf_llama,
+}
+
+hf_model_class_map: dict[str, nn.CausalLM] = {
+    "meta-llama/Llama-2-7b-chat-hf": LlamaForCausalLM,
+    "meta-llama/Llama-2-7b-hf": LlamaForCausalLM,
+}
+
+
+@enum.unique
+class ModelSource(enum.Enum):
+  NATIVE = enum.auto()
+  HUGGINGFACE = enum.auto()
+  MAXTEXT = enum.auto()
+
+
+class ModelRegistry:
+
+  def load_model_config(
+      self,
+      model_id: str,
+      source: ModelSource = ModelSource.HUGGINGFACE,
+  ) -> PretrainedConfig:
+    if model_id not in supported_model:
+      raise ValueError(f"{model_id} is not supported")
+
+    if source == ModelSource.MAXTEXT:
+      model_id = maxtext_config_map[model_id]
+
+    config = AutoConfig.from_pretrained(model_id)
+    return config
+
+  def load_tokenizer(
+      self,
+      model_id: str,
+      source: ModelSource = ModelSource.HUGGINGFACE,
+      path: str | None = None,
+  ) -> PreTrainedTokenizer:
+    if model_id not in supported_model:
+      raise ValueError(f"{model_id} is not supported")
+
+    if path:
+      raise ValueError(
+          f"Loading the tokenizer from a given path ({path}) is not supported"
+      )
+
+    if source == ModelSource.MAXTEXT:
+      model_id = maxtext_config_map[model_id]
+
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+    return tokenizer
+
+  def model_cls(self, model_id: str):
+    model_cls = hf_model_class_map[model_id]
+    return model_cls
+
+  def load_model(
+      self,
+      mesh: Mesh,
+      model_id: str,
+      model_config: PretrainedConfig | None = None,
+      path: str | None = None,
+      source: ModelSource = ModelSource.HUGGINGFACE,
+      dtype: jnp.dtype = jnp.bfloat16,
+  ) -> tuple[nn.Module, dict]:
+    if model_config:
+      config = model_config
+    else:
+      config = self.load_model_config(model_id)
+
+    weights_on_host = self.load_weights_to_host(
+        model_id=model_id,
+        num_devices=mesh.devices.size,
+        model_config=config,
+        path=path,
+        source=source,
+        dtype=dtype,
+    )
+    print("loaded to host")
+    if model_id not in hf_model_class_map:
+      raise ValueError(f"cannot find class for model {model_id}")
+    model_cls = hf_model_class_map[model_id]
+    model: nn.Module = model_cls(config, mesh)
+    print("loading to device")
+    weight_dict = model.load_weights_dict(weights_on_host)
+    print("loaded to device")
+    return model, weight_dict
weight_dict + + def load_weights_to_host( + self, + model_id: str, + num_devices: int, + model_config: PretrainedConfig | None = None, + path: str | None = None, + source: ModelSource = ModelSource.HUGGINGFACE, + dtype: jnp.dtype = jnp.bfloat16, + ) -> Any: + """Load the ckpt to the host DRAM.""" + + if model_id not in supported_model: + raise ValueError(f"{model_id} is not supported") + if dtype not in torch_jax_dtype_map: + raise ValueError(f"Unknown jax dtype for weight to load {dtype}") + + if source == ModelSource.HUGGINGFACE: + if model_id not in hf_ckpt_conversion: + raise ValueError( + f"No weight conversion function for HF model {model_id}" + ) + + if model_config: + config = model_config + else: + config = AutoConfig.from_pretrained(model_id) + + if not path: + hg_model = AutoModelForCausalLM.from_pretrained( + model_id, + torch_dtype=torch_jax_dtype_map[dtype], + config=config, + ) + ckpt_conversion_func = hf_ckpt_conversion[model_id] + state_dict = hg_model.state_dict() + del hg_model + weights = ckpt_conversion_func(state_dict, num_devices, config) + return weights + else: + raise NotImplemented( + "Loading from path for HF model is not supported yet" + ) + + raise NotImplemented(f"Loading from {source} is not supported yet") diff --git a/experimental/jax/inference/model/management/util.py b/experimental/jax/inference/model/management/util.py new file mode 100644 index 00000000..29d80972 --- /dev/null +++ b/experimental/jax/inference/model/management/util.py @@ -0,0 +1,34 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +"""Model management utility module """ +import jax +from jax import numpy as jnp +import torch + +torch_jax_dtype_map = { + torch.bfloat16: jnp.bfloat16, + torch.float32: jnp.float32, + jnp.bfloat16: torch.bfloat16, + jnp.float32: torch.float32, +} + + +def convert_to_jax_array(x: torch.Tensor): + return jax.device_put( + jnp.asarray(x.float().numpy(), dtype=torch_jax_dtype_map[x.dtype]), + device=jax.devices("cpu")[0], + ) diff --git a/experimental/jax/inference/model/postprocess.py b/experimental/jax/inference/model/postprocess.py new file mode 100644 index 00000000..0c1b32d3 --- /dev/null +++ b/experimental/jax/inference/model/postprocess.py @@ -0,0 +1,88 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
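A hypothetical end-to-end use of ModelRegistry (an illustration, not code from this change): it assumes the `inference` package is on PYTHONPATH and that the Hugging Face checkpoint for the chosen model id is accessible (a multi-gigabyte download for Llama-2-7B).

import jax
from inference.model import ModelRegistry

registry = ModelRegistry()
model_id = "meta-llama/Llama-2-7b-chat-hf"

config = registry.load_model_config(model_id)   # transformers LlamaConfig
tokenizer = registry.load_tokenizer(model_id)   # transformers tokenizer

# Pulls the torch checkpoint and converts it into a nested dict of
# host-resident jax arrays (with gate/up projections pre-merged as shown
# earlier), ready to be handed to the model's load_weights_dict.
weights_on_host = registry.load_weights_to_host(
    model_id, num_devices=len(jax.devices()), model_config=config
)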
+""" + +"""Post-process utilities""" + +import jax +from jax import numpy as jnp +import dataclasses +from jax.sharding import NamedSharding +from inference.nn import AttentionMetadata +from inference.utils import register_flat_dataclass_as_pytree + + +@register_flat_dataclass_as_pytree +@dataclasses.dataclass +class ModelOutput: + prefill_token: jax.Array | NamedSharding + prefill_done: jax.Array | NamedSharding + prefill_next_pos: jax.Array | NamedSharding + generate_tokens: jax.Array | NamedSharding + generate_done: jax.Array | NamedSharding + generate_next_pos: jax.Array | NamedSharding + + +def postprocess( + tokens: jax.Array, done: jax.Array, attn_meta: AttentionMetadata +) -> ModelOutput: + dummy_scalar = jnp.asarray(-1, dtype=jnp.int32) + dummy_vec = jnp.asarray([-1], dtype=jnp.int32) + output = ModelOutput( + prefill_token=dummy_scalar, + prefill_done=dummy_scalar, + prefill_next_pos=dummy_scalar, + generate_tokens=dummy_vec, + generate_done=dummy_vec, + generate_next_pos=dummy_vec, + ) + + has_prefill = False + has_generate = False + if len(attn_meta.prefill_pos.shape) != 0: + has_prefill = True + if len(attn_meta.generate_pos.shape) != 0: + has_generate = True + + if has_prefill and not has_generate: + output.prefill_token = tokens[0] + output.prefill_done = done[0] + output.prefill_next_pos = attn_meta.prefill_length + + if not has_prefill and has_generate: + output.generate_tokens = tokens + output.generate_done = done + output.generate_next_pos = jnp.where( + output.generate_done, -1, attn_meta.generate_pos + 1 + ) + output.generate_next_pos = jnp.where( + output.generate_next_pos, output.generate_next_pos, -1 + ) + + if has_prefill and has_generate: + output.prefill_token = tokens[0] + output.prefill_done = done[0] + output.prefill_next_pos = attn_meta.prefill_length + + output.generate_tokens = tokens[1:] + output.generate_done = done[1:] + output.generate_next_pos = jnp.where( + output.generate_done, -1, attn_meta.generate_pos + 1 + ) + output.generate_next_pos = jnp.where( + output.generate_next_pos, output.generate_next_pos, -1 + ) + + return output diff --git a/experimental/jax/inference/model/sampling/sampler.py b/experimental/jax/inference/model/sampling/sampler.py new file mode 100644 index 00000000..ccb75ddc --- /dev/null +++ b/experimental/jax/inference/model/sampling/sampler.py @@ -0,0 +1,82 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import dataclasses +import jax +from jax import numpy as jnp +from inference.nn import AttentionMetadata +from inference.utils import register_flat_dataclass_as_pytree + + +@register_flat_dataclass_as_pytree +@dataclasses.dataclass +class SamplingParams: + temperature: jax.Array + top_k: jax.Array + rng: jax.Array + + +class Sampler: + + def __init__( + self, + eos: int | None = None, + max_length: int | None = None, + ) -> None: + self.eos = eos + self.max_length = max_length + + def sample( + self, + logits: jax.Array, # [num_tokens, vocab_size] + positions: jax.Array, + attn_metadata: AttentionMetadata, + sampling_params: SamplingParams, + ) -> jax.Array: + sampling_rng = sampling_params.rng + probabilities = jax.nn.softmax( + logits / sampling_params.temperature, axis=-1 + ) + top_k_prob, top_k_indices = jax.lax.top_k(probabilities, 1) + selected_index = jax.random.categorical(sampling_rng, top_k_prob) + + tokens = top_k_indices[ + jnp.arange(0, top_k_indices.shape[0]), selected_index + ] + done = jnp.equal(tokens, self.eos) + done = jnp.logical_or( + done, jnp.greater_equal(positions, self.max_length - 1) + ) + + if len(attn_metadata.prefill_pos.shape) == 0: + padded_prefill_len = 0 + else: + padded_prefill_len = attn_metadata.prefill_pos.shape[0] + + if len(attn_metadata.generate_pos.shape) != 0 and padded_prefill_len != 0: + prefill_token = tokens.at[attn_metadata.prefill_length - 1].get()[None] + generate_tokens = tokens.at[padded_prefill_len:].get() + + prefill_done = done.at[attn_metadata.prefill_length - 1].get()[None] + generate_done = done.at[padded_prefill_len:].get() + + tokens = jnp.concatenate((prefill_token, generate_tokens)) + done = jnp.concatenate((prefill_done, generate_done)) + elif padded_prefill_len != 0: + tokens = tokens.at[attn_metadata.prefill_length - 1].get()[None] + done = done.at[attn_metadata.prefill_length - 1].get()[None] + + return tokens, done diff --git a/experimental/jax/inference/nn/__init__.py b/experimental/jax/inference/nn/__init__.py new file mode 100644 index 00000000..ba70eda7 --- /dev/null +++ b/experimental/jax/inference/nn/__init__.py @@ -0,0 +1,22 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from .module import Module, ModuleList, Model, CausalLM +from .parameter import Parameter +from .linear import Linear +from .embedding import Embedding, apply_rope_embedding +from .norm import RMSNorm +from .attention import AttentionOps, AttentionMetadata, KVCache diff --git a/experimental/jax/inference/nn/attention.py b/experimental/jax/inference/nn/attention.py new file mode 100644 index 00000000..af28f5a3 --- /dev/null +++ b/experimental/jax/inference/nn/attention.py @@ -0,0 +1,253 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +"""AttentionOps Module""" + +import dataclasses +import jax +from jax import numpy as jnp +import jax.experimental +from jax.sharding import NamedSharding +from inference import kernel +from inference.nn.module import Module +from inference.utils import * + + +@register_flat_dataclass_as_pytree +@dataclasses.dataclass +class KVCache: + k: jax.Array | NamedSharding + v: jax.Array | NamedSharding + + +@register_flat_dataclass_as_pytree +@dataclasses.dataclass +class AttentionMetadata: + prefill_length: ( + jax.Array | NamedSharding + ) # shape: []; Prefill True length without padding + prefill_pos: jax.Array | NamedSharding # shape: [padded length] + prefill_page_table: jax.Array | NamedSharding # shape: [max_len // page_size] + + generate_pos: jax.Array | NamedSharding # shape: [generate_batch_size] + generate_page_table: ( + jax.Array | NamedSharding + ) # shape: [generate_batch, max_len // page_size] + + +class AttentionOps(Module): + + def __init__(self, num_attn_heads, num_kv_heads, head_dim): + super().__init__() + self.num_attn_heads = num_attn_heads + self.num_kv_heads = num_kv_heads + self.head_dim = head_dim + + def _write_prefill_kv_to_kv_cache( + self, k, v, kv_cache: KVCache, unpadded_len, page_table + ): + padded_prefill_len = k.shape[0] + num_kv_heads_per_device = kv_cache.k.shape[0] + page_size = kv_cache.k.shape[2] + num_pages = padded_prefill_len // page_size + num_pages = jnp.where(num_pages < 1, 1, num_pages) + num_active_pages, reminder = jnp.divmod(unpadded_len, page_size) + num_active_pages += jnp.where(reminder > 0, 1, 0) + + k = k.transpose((1, 0, 2)) + v = v.transpose((1, 0, 2)) + # kv shape after the change: (num_kv_heads, num_pages, page_size, head_dim) + k = k.reshape( + (num_kv_heads_per_device, -1, page_size, self.head_dim) + ).astype(kv_cache.k.dtype) + v = v.reshape( + (num_kv_heads_per_device, -1, page_size, self.head_dim) + ).astype(kv_cache.v.dtype) + + def update_cond(carry): + _, idx = carry + return idx < num_active_pages + + def per_page_update(carry): + kv_cache, idx = carry + page_k = k[:, idx, :, :][:, None, :, :] + page_v = v[:, idx, :, :][:, None, :, :] + mapped_idx = page_table[idx] + kv_cache.k = jax.lax.dynamic_update_slice_in_dim( + kv_cache.k, + page_k, + mapped_idx, + axis=1, + ) + kv_cache.v = jax.lax.dynamic_update_slice_in_dim( + kv_cache.v, + page_v, + mapped_idx, + axis=1, + ) + idx += 1 + return kv_cache, idx + + idx = 0 + kv_cache, idx = jax.lax.while_loop( + update_cond, per_page_update, (kv_cache, idx) + ) + + return kv_cache + + def _write_generate_kv_to_kv_cache(self, k, v, kv_cache, pos, page_table): + k = k.transpose((1, 0, 2)) + v = v.transpose((1, 0, 2)) + + k = k.astype(kv_cache.k.dtype) + v = v.astype(kv_cache.v.dtype) + + num_tokens = k.shape[1] + num_kv_heads_per_device, num_pages, page_size, head_dim = kv_cache.k.shape + page_idx, offset = jnp.divmod(pos, page_size) + page_to_update = page_table[jnp.arange(0, num_tokens), page_idx] + + mapped_page_to_update = page_to_update * page_size + offset + mapped_page_to_update = jnp.tile( + mapped_page_to_update, num_kv_heads_per_device + ) + + kv_heads_axis_stride = ( + 
jnp.repeat(jnp.arange(0, num_kv_heads_per_device), num_tokens) + * num_pages + * page_size + ) + mapped_page_to_update = kv_heads_axis_stride + mapped_page_to_update + + k = k.reshape(-1, head_dim) + v = v.reshape(-1, head_dim) + + kv_cache.k = kv_cache.k.reshape(-1, head_dim) + kv_cache.v = kv_cache.v.reshape(-1, head_dim) + + kv_cache.k = kv_cache.k.at[mapped_page_to_update, :].set(k) + kv_cache.v = kv_cache.v.at[mapped_page_to_update, :].set(v) + + kv_cache.k = kv_cache.k.reshape( + num_kv_heads_per_device, num_pages, page_size, head_dim + ) + kv_cache.v = kv_cache.v.reshape( + num_kv_heads_per_device, num_pages, page_size, head_dim + ) + + return kv_cache + + def _prefill( + self, q, k, v, kv_cache: KVCache, attn_metadata: AttentionMetadata + ): + kv_cache = self._write_prefill_kv_to_kv_cache( + k, + v, + kv_cache, + attn_metadata.prefill_length, + attn_metadata.prefill_page_table, + ) + output = kernel.chunked_prefill_attention( + q, + kv_cache.k, + kv_cache.v, + attn_metadata.prefill_length, + attn_metadata.prefill_page_table, + ) + # if self.num_attn_heads == self.num_kv_heads: + # output = kernel.vanilla_prefill_mha(q, k, v, attn_metadata.prefill_length) + # elif self.num_kv_heads == 1: + # output = kernel.vanilla_prefill_mqa(q, k, v, attn_metadata.prefill_length) + # else: + # output = kernel.vanilla_prefill_gqa(q, k, v, attn_metadata.prefill_length) + return output, kv_cache + + def _generate( + self, q, k, v, kv_cache: KVCache, attn_metadata: AttentionMetadata + ): + kv_cache = self._write_generate_kv_to_kv_cache( + k, + v, + kv_cache, + attn_metadata.generate_pos, + attn_metadata.generate_page_table, + ) + + batch = q.shape[0] + + output = kernel.decode_attention( + q, + kv_cache.k, + kv_cache.v, + attn_metadata.generate_pos, + attn_metadata.generate_page_table, + ) + + output = output.reshape((batch, -1)) + + return output, kv_cache + + def _mixed_prefill_generate( + self, q, k, v, kv_cache: KVCache, attn_metadata: AttentionMetadata + ): + total_len, num_attn_heads_per_device, head_dim = q.shape + output = jnp.zeros( + shape=(total_len, num_attn_heads_per_device * head_dim), + dtype=q.dtype, + ) + padded_prompt_length = attn_metadata.prefill_pos.shape[0] + prefill_output, kv_cache = self._prefill( + q[:padded_prompt_length, :, :], + k[:padded_prompt_length, :, :], + v[:padded_prompt_length, :, :], + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + + generate_output, kv_cache = self._generate( + q[padded_prompt_length:, :, :], + k[padded_prompt_length:, :, :], + v[padded_prompt_length:, :, :], + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + + output = jax.lax.dynamic_update_slice_in_dim( + output, + prefill_output, + start_index=0, + axis=0, + ) + + output = jax.lax.dynamic_update_slice_in_dim( + output, + generate_output, + start_index=padded_prompt_length, + axis=0, + ) + + return output, kv_cache + + def __call__( + self, q, k, v, kv_cache: KVCache, attn_metadata: AttentionMetadata + ): + # q, k, v has shape as (tokens, num_heads/devices, head_dim) + if len(attn_metadata.generate_pos.shape) == 0: + return self._prefill(q, k, v, kv_cache, attn_metadata) + elif len(attn_metadata.prefill_pos.shape) == 0: + return self._generate(q, k, v, kv_cache, attn_metadata) + else: + return self._mixed_prefill_generate(q, k, v, kv_cache, attn_metadata) diff --git a/experimental/jax/inference/nn/embedding.py b/experimental/jax/inference/nn/embedding.py new file mode 100644 index 00000000..431a6c3a --- /dev/null +++ b/experimental/jax/inference/nn/embedding.py @@ -0,0 +1,66 
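The flattened-index arithmetic in _write_generate_kv_to_kv_cache is easy to check by hand. The standalone sketch below (toy sizes, not from the change) reproduces the mapping from (head, logical page, offset) to a row of the KV cache viewed as (-1, head_dim).

from jax import numpy as jnp

num_kv_heads, num_pages, page_size = 2, 8, 4
pos = jnp.asarray([5, 9])                   # current decode position per sequence
page_table = jnp.asarray([[3, 6, 0, 0],     # logical -> physical page mapping
                          [1, 7, 2, 0]])

page_idx, offset = jnp.divmod(pos, page_size)
physical_page = page_table[jnp.arange(pos.shape[0]), page_idx]
rows = physical_page * page_size + offset           # row index within one head
rows = jnp.tile(rows, num_kv_heads)
head_stride = (
    jnp.repeat(jnp.arange(num_kv_heads), pos.shape[0]) * num_pages * page_size
)
print(rows + head_stride)  # [25  9 57 41]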
@@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +"""Embedding Module""" + +from jax import numpy as jnp +from jax.sharding import NamedSharding, PartitionSpec as P +from inference.nn import Module, Parameter +from inference.parallel import EmbeddingParallelType +from inference import parallel + + +class Embedding(Module): + + def __init__( + self, + vocab_size, + embedding_dim, + parallel_config: parallel.EmbeddingParallelConfig, + ): + super().__init__() + self.vocab_size = vocab_size + self.embedding_dim = embedding_dim + self.parallel_config = parallel_config + + self.weight = Parameter( + jnp.zeros((vocab_size, embedding_dim), dtype=jnp.bfloat16) + ) + + mesh = parallel_config.mesh + if parallel_config.parallel_type == EmbeddingParallelType.COLUMN: + weight_pspec = P(None, parallel.tp_axis_names()) + else: + weight_pspec = P(None, None) + + self.weight.sharding = NamedSharding(mesh, weight_pspec) + + def __call__(self, input): + return self.weight.value[input] + + +def apply_rope_embedding(input, position, theta=10000): + emb_dim = input.shape[-1] + fraction = jnp.arange(0, emb_dim, 2) / emb_dim + timescale = theta**fraction + position = position[:, None, None] + sinusoid_inp = position / timescale + sin = jnp.sin(sinusoid_inp).astype(input.dtype) + cos = jnp.cos(sinusoid_inp).astype(input.dtype) + first_half, second_half = jnp.split(input, 2, axis=-1) + first_part = first_half * cos - second_half * sin + second_part = second_half * cos + first_half * sin + return jnp.concatenate((first_part, second_part), axis=-1) diff --git a/experimental/jax/inference/nn/linear.py b/experimental/jax/inference/nn/linear.py new file mode 100644 index 00000000..b0e87756 --- /dev/null +++ b/experimental/jax/inference/nn/linear.py @@ -0,0 +1,143 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless reuired by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +"""Linear Module + +Consider break down the Linear layer module by the sharding strategy for +better readability. 
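apply_rope_embedding uses the split-half rotary convention (the first half of the head dimension is rotated against the second half). Two quick properties worth checking when touching this code: position 0 is the identity, and the rotation preserves each head vector's norm. A standalone check, assuming the `inference` package is importable:

from jax import numpy as jnp
from inference.nn import apply_rope_embedding

x = jnp.ones((2, 1, 8))                     # [tokens, heads, head_dim]
out = apply_rope_embedding(x, jnp.asarray([0, 7]))

print(jnp.allclose(out[0], x[0]))           # True: no rotation at position 0
print(jnp.allclose(jnp.linalg.norm(out, axis=-1),
                   jnp.linalg.norm(x, axis=-1)))  # True: rotation preserves norm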
+""" + +import logging +from typing import Sequence +import jax +from jax import numpy as jnp +from jax.sharding import NamedSharding, PartitionSpec as P +from inference.nn import Module, Parameter +from inference import kernel +from inference import parallel +from inference.parallel import LinearParallelConfig, LinearParallelType + + +class Linear(Module): + + def __init__( + self, + in_features: int, + out_features: int | Sequence[int], + parallel_config: LinearParallelConfig, + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.parallel_config = parallel_config + + self._num_merged = ( + len(out_features) if isinstance(out_features, Sequence) else 1 + ) + out_features_sum = ( + sum(out_features) if self._num_merged > 1 else out_features + ) + + self.weight = Parameter( + value=jnp.zeros((in_features, out_features_sum), dtype=jnp.bfloat16) + ) + + axis_names = parallel.tp_axis_names() + if parallel_config.parallel_type == LinearParallelType.COLUMN: + weight_pspec = P(None, axis_names) + elif parallel_config.parallel_type == LinearParallelType.ROW: + weight_pspec = P(axis_names, None) + else: + weight_pspec = P(None, None) + + self.weight.sharding = NamedSharding(parallel_config.mesh, weight_pspec) + + collective_matmul_type = parallel_config.collective_matmul_type + if collective_matmul_type: + self._collective_matmul = kernel.build_collective_matmul( + collective_matmul_type, + axis_names, + ) + + def __call__(self, input): + if ( + self.parallel_config.collective_matmul_type != None + and self._collective_matmul + ): + output = self._collective_matmul(input, self.weight.value) + if self._num_merged > 1: + return jnp.split(output, self._num_merged, 1) + return output + + preferred_type = input.dtype + + output = jnp.matmul( + input, self.weight.value, preferred_element_type=preferred_type + ) + + output = output.astype(input.dtype) + + parallel_config = self.parallel_config + axis_names = parallel.tp_axis_names() + if parallel_config.reduce_output: + output = parallel.ops.all_reduce(output, axis_names) + elif parallel_config.reduce_scatter_output: + output = jax.lax.psum_scatter( + output, axis_names, scatter_dimension=1, tiled=True + ) + + if self._num_merged > 1: + return jnp.split(output, self._num_merged, 1) + + return output + + def load_weights_dict(self, weights_dict): + res = {} + for k, v in weights_dict.items(): + attr = getattr(self, k) + if isinstance(attr, Parameter): + param = self._parameters[k] + if v.shape != param.shape: + logging.warning( + f"Not matched shape" + + f": defined {param.shape}," + + f"loaded {v.shape} for module {self.__class__}" + ) + param.value = v + if isinstance(param.sharding, NamedSharding): + param.to_device() + + if ( + self.parallel_config.collective_matmul_type + == parallel.CollectiveMatmulType.ALL_GATHER + ): + param.value = kernel.prepare_rhs_for_all_gather_collective_matmul( + param.value, self.parallel_config.mesh + ) + elif ( + self.parallel_config.collective_matmul_type + == parallel.CollectiveMatmulType.REDUCE_SCATTER + ): + param.value = kernel.prepare_rhs_for_collective_matmul_reduce_scatter( + param.value, self.parallel_config.mesh + ) + res[k] = param.value + else: + logging.warning( + f"Unknown checkpoint key {k} for module {self.__class__}" + ) + + return res diff --git a/experimental/jax/inference/nn/module.py b/experimental/jax/inference/nn/module.py new file mode 100644 index 00000000..cb0b89f3 --- /dev/null +++ b/experimental/jax/inference/nn/module.py @@ -0,0 +1,223 @@ +""" 
+Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +"""Simple NN Module. + +TODO: migrate to Flax NNX. +""" + +import logging +from typing import Any +import jax +from jax import numpy as jnp +from jax.sharding import NamedSharding +from inference.nn.parameter import Parameter + + +class Module: + """Simple NN module.""" + + def __init__(self): + self._parameters: dict[str, Parameter] = {} + self._submodules: dict[str, Module] = {} + + def __setattr__(self, name: str, value: Any): + if isinstance(value, Parameter): + self._parameters[name] = value + elif isinstance(value, Module | ModuleList): + self._submodules[name] = value + else: + self.__dict__[name] = value + + def __getattr__(self, name: str): + if name in self._parameters: + return self._parameters[name] + elif name in self._submodules: + return self._submodules[name] + elif name in self.__dict__: + return self.__dict__[name] + return None + + def init_weights(self): + res = {} + rng = jax.random.key(0) + for k, param in self._parameters.items(): + param.value = jax.random.uniform(rng, param.shape, dtype=jnp.bfloat16) + param.to_device() + res[k] = param.value + + for k, module in self._submodules.items(): + res[k] = module.init_weights() + return res + + def load_weights_dict(self, weights_dict): + res = {} + for k, v in weights_dict.items(): + attr = getattr(self, k) + if isinstance(attr, Parameter): + param = self._parameters[k] + if v.shape != param.shape: + logging.warning( + f"Not matched shape" + + f": defined {param.shape}," + + f"loaded {v.shape} for module {self.__class__}" + ) + param.value = v + if isinstance(param.sharding, NamedSharding): + param.to_device() + res[k] = param.value + elif isinstance(attr, Module) or isinstance(attr, ModuleList): + sub_weights_dict = self._submodules[k].load_weights_dict(v) + res[k] = sub_weights_dict + else: + logging.warning( + f"Unknown checkpoint key {k} for module {self.__class__}" + ) + return res + + def _weights_assignment_in_jit(self, weights_dict): + for key in self._parameters: + param = self._parameters[key] + if isinstance(param, Parameter): + param.value = weights_dict[key] + for key, module in self._submodules.items(): + if key in weights_dict: + module._weights_assignment_in_jit(weights_dict[key]) + + def jittable_call(self, weights_dict, *args): + self._weights_assignment_in_jit(weights_dict) + return self(*args) + + # Following methods "__repr__", "_repr_with_indent" and "_spec" are + # for debugging purpose which provide a clean string representation + # for the model. 
+ def _spec(self) -> str: + return "" + + def _repr_with_indent(self, indent) -> str: + indent = indent + " " + if len(self._parameters) == 0 and len(self._submodules) == 0: + return "{}" + res = "{" + for k, v in self._parameters.items(): + res += "\n" + (indent) + f"'{k}': {v}" + + for k, v in self._submodules.items(): + res += ( + "\n" + + (indent) + + f"'{k}': <{v.__class__.__name__}{v._spec()}> {v._repr_with_indent(indent)}" + ) + + res += "\n" + indent[:-2] + "}" + return res + + def __repr__(self) -> str: + return "\n" + self._repr_with_indent("") + + +class ModuleList: + + def __init__(self, modules: list[Module]) -> None: + self._modules: dict[int, Module] = {} + for i, m in enumerate(modules): + self._modules[i] = m + + def __getitem__(self, key): + return self._modules[key] + + def __setitem__(self, key, value): + self._modules[key] = value + + def _spec(self) -> str: + return "" + + def _repr_with_indent(self, indent) -> str: + indent = indent + " " + if len(self._modules) == 0: + return "{}" + res = "{" + + for k, v in self._modules.items(): + res += ( + "\n" + + (indent) + + f"'{k}': <{v.__class__.__name__}{v._spec()}> {v._repr_with_indent(indent)}" + ) + + res += "\n" + indent[:-2] + "}" + return res + + def __repr__(self) -> str: + return "\n" + self._repr_with_indent("") + + def init_weights(self): + res = {} + for k, module in self._modules.items(): + if isinstance(k, int): + res[k] = module.init_weights() + else: + logging.warning(f"Unknown checkpoint key {k} for module list") + return res + + def load_weights_dict(self, weights_dict): + res = {} + for k, v in weights_dict.items(): + if isinstance(k, int): + ws = self._modules[k].load_weights_dict(v) + res[k] = ws + else: + logging.warning(f"Unknown checkpoint key {k} for module list") + return res + + def _weights_assignment_in_jit(self, weights_dict): + for layer_num, module in self._modules.items(): + module._weights_assignment_in_jit(weights_dict[layer_num]) + + +class Model(Module): + + def jittable_call( + self, + weights_dict, + input_ids, + positions, + kv_caches, + attn_metadata, + ) -> tuple[jax.Array, list[Any]]: + self._weights_assignment_in_jit(weights_dict) + return self(input_ids, positions, kv_caches, attn_metadata) + + +class CausalLM(Module): + + def jittable_call( + self, + weights_dict, + input_ids: jax.Array, + positions: jax.Array, + kv_caches: Any, + attn_metadata: Any, + sampling_params: Any, + ) -> tuple[jax.Array, list[Any]]: + self._weights_assignment_in_jit(weights_dict) + return self( + input_ids, + positions, + kv_caches, + attn_metadata, + sampling_params, + ) diff --git a/experimental/jax/inference/nn/norm.py b/experimental/jax/inference/nn/norm.py new file mode 100644 index 00000000..184d082b --- /dev/null +++ b/experimental/jax/inference/nn/norm.py @@ -0,0 +1,55 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
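The registration and jit conventions above are easiest to see on a tiny model. The sketch below (standalone, assuming the `inference` package is importable) builds a two-layer module, initializes a weight pytree, and calls it through jittable_call so the weights stay explicit jit arguments rather than baked-in constants.

import jax
from jax import numpy as jnp
from inference import nn


class Scale(nn.Module):

  def __init__(self, dim):
    super().__init__()
    self.weight = nn.Parameter(jnp.zeros((dim,)))

  def __call__(self, x):
    return x * self.weight.value


class TwoLayer(nn.Module):

  def __init__(self, dim):
    super().__init__()
    self.first = Scale(dim)
    self.second = Scale(dim)

  def __call__(self, x):
    return self.second(self.first(x))


model = TwoLayer(4)
weights = model.init_weights()  # {'first': {'weight': ...}, 'second': {'weight': ...}}
y = jax.jit(model.jittable_call)(weights, jnp.ones((4,)))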
+""" + +"""Norm Module""" + +import jax +from jax import numpy as jnp +from jax.sharding import NamedSharding, PartitionSpec as P +from inference.nn import Module, Parameter +from inference import parallel + + +class RMSNorm(Module): + + def __init__(self, dim, eps, parallel_config: parallel.RMSNormParallelConfig): + super().__init__() + self.dim = dim + self.eps = eps + self.parallel_config = parallel_config + + self.weight = Parameter(jnp.zeros((dim,))) + self.variance_epsilon = eps + + mesh = parallel_config.mesh + if parallel_config.activation_shared: + self.weight.sharding = NamedSharding(mesh, P(parallel.tp_axis_names())) + else: + self.weight.sharding = NamedSharding(mesh, P(None)) + + def __call__(self, input): + input_dtype = input.dtype + input = input.astype(jnp.float32) + variance = jnp.mean(jax.lax.square(input), axis=-1, keepdims=True) + + if self.parallel_config.activation_shared: + axis_names = parallel.tp_axis_names() + variance = parallel.ops.all_reduce( + variance, axis_names + ) / parallel.get_num_partitions(axis_names) + + input = input * jax.lax.rsqrt(variance + self.variance_epsilon) + return self.weight * input.astype(input_dtype) diff --git a/experimental/jax/inference/nn/parameter.py b/experimental/jax/inference/nn/parameter.py new file mode 100644 index 00000000..bd74d135 --- /dev/null +++ b/experimental/jax/inference/nn/parameter.py @@ -0,0 +1,330 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +"""Simple Parameter. +The code is mainly from the flax.nnx.variables. +TODO: migrate to Flax nnx. +""" + +import jax +from jax import numpy as jnp +import typing as tp +from typing import Any +from jax.sharding import NamedSharding + +A = tp.TypeVar("A") +V = tp.TypeVar("V", bound="Parameter[Any]") + + +class Parameter: + """Parameter class. + + It composes of the jax.Array and routes the computation + function to the jax.Array itself. + The code is mainly from the flax.nnx.variables. + TODO: migrate to Flax nnx. 
+ """ + + def __init__(self, value: jax.Array): + self.value = jnp.ones((0,)) + self._defined_shape = value.shape + self._defined_dtype = value.dtype + self._defined_sharding = value.sharding + + def _shape(self): + return self._defined_shape + + def _set_shape(self, shape: tuple): + self._defined_shape = shape + + def _sharding(self): + return self._defined_sharding + + def _set_sharding(self, sharding: NamedSharding): + self._defined_sharding = sharding + + def _dtype(self): + return self._defined_dtype + + def _set_dtype(self, dtype): + self._dtype = dtype + + shape = property(_shape, _set_shape) + dtype = property(_dtype, _set_dtype) + sharding = property(_sharding, _set_sharding) + + def to_device(self): + self.value = jax.device_put(self.value, self._defined_sharding) + return self + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(shape: {self._defined_shape}, dtype: {self._defined_dtype}, sharding: {self._defined_sharding})" + + def __setattr__(self, name: str, value: Any) -> None: + object.__setattr__(self, name, value) + + # -------------------------------------------- + # proxy methods + # -------------------------------------------- + # NOTE: we dont override __setattr__ to avoid cases where + # you need to set an attribute on the variable instance + # def __getattr__(self, name: str) -> tp.Any: + # print("zhihaoshan",name) + # if name == "sharding": + # return self.sharding + # return getattr(self.value, name) + + def __getitem__(self, key) -> tp.Any: + return self.value[key] # type: ignore + + def __setitem__(self, key, value) -> None: + self.value[key] = value # type: ignore + + def __call__(self, *args, **kwargs) -> tp.Any: + return self.value(*args, **kwargs) # type: ignore + + def __len__(self) -> int: + return len(self.value) # type: ignore + + def __iter__(self) -> tp.Iterator: + return iter(self.value) # type: ignore + + def __contains__(self, item) -> bool: + return item in self.value # type: ignore + + def __add__(self, other) -> A: + return self.value.__add__(other) # type: ignore + + def __sub__(self, other) -> A: + return self.value.__sub__(other) # type: ignore + + def __mul__(self, other) -> A: + return self.value.__mul__(other) # type: ignore + + def __matmul__(self, other) -> A: + return self.value.__matmul__(other) # type: ignore + + def __truediv__(self, other) -> A: + return self.value.__truediv__(other) # type: ignore + + def __floordiv__(self, other) -> A: + return self.value.__floordiv__(other) # type: ignore + + def __mod__(self, other) -> A: + return self.value.__mod__(other) # type: ignore + + def __divmod__(self, other) -> A: + return self.value.__divmod__(other) # type: ignore + + def __pow__(self, other) -> A: + return self.value.__pow__(other) # type: ignore + + def __lshift__(self, other) -> A: + return self.value.__lshift__(other) # type: ignore + + def __rshift__(self, other) -> A: + return self.value.__rshift__(other) # type: ignore + + def __and__(self, other) -> A: + return self.value.__and__(other) # type: ignore + + def __xor__(self, other) -> A: + return self.value.__xor__(other) # type: ignore + + def __or__(self, other) -> A: + return self.value.__or__(other) # type: ignore + + def __radd__(self, other) -> A: + return self.value.__radd__(other) # type: ignore + + def __rsub__(self, other) -> A: + return self.value.__rsub__(other) # type: ignore + + def __rmul__(self, other) -> A: + return self.value.__rmul__(other) # type: ignore + + def __rmatmul__(self, other) -> A: + return self.value.__rmatmul__(other) # type: 
ignore + + def __rtruediv__(self, other) -> A: + return self.value.__rtruediv__(other) # type: ignore + + def __rfloordiv__(self, other) -> A: + return self.value.__rfloordiv__(other) # type: ignore + + def __rmod__(self, other) -> A: + return self.value.__rmod__(other) # type: ignore + + def __rdivmod__(self, other) -> A: + return self.value.__rdivmod__(other) # type: ignore + + def __rpow__(self, other) -> A: + return self.value.__rpow__(other) # type: ignore + + def __rlshift__(self, other) -> A: + return self.value.__rlshift__(other) # type: ignore + + def __rrshift__(self, other) -> A: + return self.value.__rrshift__(other) # type: ignore + + def __rand__(self, other) -> A: + return self.value.__rand__(other) # type: ignore + + def __rxor__(self, other) -> A: + return self.value.__rxor__(other) # type: ignore + + def __ror__(self, other) -> A: + return self.value.__ror__(other) # type: ignore + + def __iadd__(self: V, other) -> V: + value = self.value + if hasattr(value, "__iadd__"): + value.__iadd__(other) + else: + self.value = value.__add__(other) + return self + + def __isub__(self: V, other) -> V: + value = self.value + if hasattr(value, "__isub__"): + value.__isub__(other) + else: + self.value = value.__sub__(other) + return self + + def __imul__(self: V, other) -> V: + value = self.value + if hasattr(value, "__imul__"): + value.__imul__(other) + else: + self.value = value.__mul__(other) + return self + + def __imatmul__(self: V, other) -> V: + value = self.value + if hasattr(value, "__imatmul__"): + value.__imatmul__(other) + else: + self.value = value.__matmul__(other) + return self + + def __itruediv__(self: V, other) -> V: + value = self.value + if hasattr(value, "__itruediv__"): + value.__itruediv__(other) + else: + self.value = value.__truediv__(other) + return self + + def __ifloordiv__(self: V, other) -> V: + value = self.value + if hasattr(value, "__ifloordiv__"): + value.__ifloordiv__(other) + else: + self.value = value.__floordiv__(other) + return self + + def __imod__(self: V, other) -> V: + value = self.value + if hasattr(value, "__imod__"): + value.__imod__(other) + else: + self.value = value.__mod__(other) + return self + + def __ipow__(self: V, other) -> V: + value = self.value + if hasattr(value, "__ipow__"): + value.__ipow__(other) + else: + self.value = value.__pow__(other) + return self + + def __ilshift__(self: V, other) -> V: + value = self.value + if hasattr(value, "__ilshift__"): + value.__ilshift__(other) + else: + self.value = value.__lshift__(other) + return self + + def __irshift__(self: V, other) -> V: + value = self.value + if hasattr(value, "__irshift__"): + value.__irshift__(other) + else: + self.value = value.__rshift__(other) + return self + + def __iand__(self: V, other) -> V: + value = self.value + if hasattr(value, "__iand__"): + value.__iand__(other) + else: + self.value = value.__and__(other) + return self + + def __ixor__(self: V, other) -> V: + value = self.value + if hasattr(value, "__ixor__"): + value.__ixor__(other) + else: + self.value = value.__xor__(other) + return self + + def __ior__(self: V, other) -> V: + value = self.value + if hasattr(value, "__ior__"): + value.__ior__(other) + else: + self.value = value.__or__(other) + return self + + def __neg__(self) -> A: + return self.value.__neg__() # type: ignore + + def __pos__(self) -> A: + return self.value.__pos__() # type: ignore + + def __abs__(self) -> A: + return self.value.__abs__() # type: ignore + + def __invert__(self) -> A: + return self.value.__invert__() # type: ignore + 
+ def __complex__(self) -> A: + return self.value.__complex__() # type: ignore + + def __int__(self) -> A: + return self.value.__int__() # type: ignore + + def __float__(self) -> A: + return self.value.__float__() # type: ignore + + def __index__(self) -> A: + return self.value.__index__() # type: ignore + + def __round__(self, ndigits: int) -> A: + return self.value.__round__(ndigits) # type: ignore + + def __trunc__(self) -> A: + return self.value.__trunc__() # type: ignore + + def __floor__(self) -> A: + return self.value.__floor__() # type: ignore + + def __ceil__(self) -> A: + return self.value.__ceil__() # type: ignore diff --git a/experimental/jax/inference/parallel/__init__.py b/experimental/jax/inference/parallel/__init__.py new file mode 100644 index 00000000..bb111c7f --- /dev/null +++ b/experimental/jax/inference/parallel/__init__.py @@ -0,0 +1,21 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from .mesh import create_device_mesh, get_num_partitions, get_partition_index +from . import operations as ops +from .config import * +from .device import * +from .util import get_partition_spec diff --git a/experimental/jax/inference/parallel/config.py b/experimental/jax/inference/parallel/config.py new file mode 100644 index 00000000..552b7590 --- /dev/null +++ b/experimental/jax/inference/parallel/config.py @@ -0,0 +1,135 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import enum +import dataclasses +import jax + + +# TODO: change to enum.StrEnum after Python upgrading in Google Cloud TPU. +@enum.unique +class ParallelAxis(enum.Enum): + # X is used for Tensor Parallelism (as major axis) and Expert Parallelism. + X = enum.auto() + # Y is used for Sequence, Sequence Pipeline, Decode Batch and Tensor Parallelism (as minor axis). 
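One behavior of Parameter worth calling out: the constructor only records shape/dtype/sharding metadata from the defining array and leaves the value as an empty placeholder, while all arithmetic is proxied to .value. A small standalone sketch, assuming the `inference` package is importable:

from jax import numpy as jnp
from inference.nn import Parameter

p = Parameter(jnp.zeros((2, 3), dtype=jnp.bfloat16))
print(p.shape, p.dtype)   # (2, 3) bfloat16 -- metadata only; p.value is still empty

p.value = jnp.ones((2, 3), dtype=jnp.bfloat16)  # normally done by load_weights_dict
print((p * 2)[0, 0])      # arithmetic is forwarded to p.value.__mul__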
+ Y = enum.auto() + + +def dp_axis_names(): + return (ParallelAxis.Y.name,) + + +def sp_axis_names(): + return (ParallelAxis.Y.name,) + + +def spp_axis_names(): + return (ParallelAxis.Y.name,) + + +def tp_axis_names(): + return (ParallelAxis.X.name, ParallelAxis.Y.name) + + +def tp_major_axis_names(): + return (ParallelAxis.X.name,) + + +def tp_minor_axis_names(): + return (ParallelAxis.Y.name,) + + +@enum.unique +class ModelParallelStrategy(enum.Enum): + """Overall Transformer Parallel Strategy.""" + + TENSOR_PARALLEL = enum.auto() + + +@dataclasses.dataclass +class ModelParallelConfig: + mesh: jax.sharding.Mesh + parallel_type: ModelParallelStrategy = ModelParallelStrategy.TENSOR_PARALLEL + + +@enum.unique +class FFWParallelStrategy(enum.Enum): + """Overall Transformer FFW Layer Parallel Strategy. + + Please refer to https://arxiv.org/pdf/2211.05102""" + + ONE_D_WEIGHT_STATIONARY = enum.auto() + + +@dataclasses.dataclass +class FeedForwardParallelConfig: + mesh: jax.sharding.Mesh + parallel_type: FFWParallelStrategy | None = None + enable_collective_matmul: bool = False + + +@dataclasses.dataclass +class AttentionParallelConfig: + mesh: jax.sharding.Mesh + gather_input: bool = False + reduce_output: bool = False + + +@enum.unique +class LinearParallelType(enum.Enum): + """Parallel Type for Linear Layer weight.""" + + ROW = enum.auto() + COLUMN = enum.auto() + + +@enum.unique +class CollectiveMatmulType(enum.Enum): + ALL_GATHER = enum.auto() + REDUCE_SCATTER = enum.auto() + + +@dataclasses.dataclass +class LinearParallelConfig: + mesh: jax.sharding.Mesh + parallel_type: LinearParallelType | None = None + reduce_scatter_output: bool = False + reduce_output: bool = False + collective_matmul_type: CollectiveMatmulType | None = None + + +@dataclasses.dataclass +class RMSNormParallelConfig: + mesh: jax.sharding.Mesh + activation_shared: bool = False + + +@enum.unique +class EmbeddingParallelType(enum.Enum): + """Parallel Type for Embedding Layer weight.""" + + COLUMN = enum.auto() + + +@dataclasses.dataclass +class EmbeddingParallelConfig: + mesh: jax.sharding.Mesh + parallel_type: EmbeddingParallelType | None = None + + +@dataclasses.dataclass +class DecoderLayerParallelConfig: + mesh: jax.sharding.Mesh diff --git a/experimental/jax/inference/parallel/device.py b/experimental/jax/inference/parallel/device.py new file mode 100644 index 00000000..a55a02a8 --- /dev/null +++ b/experimental/jax/inference/parallel/device.py @@ -0,0 +1,21 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import jax + + +def platform(): + return jax.lib.xla_bridge.get_backend().platform diff --git a/experimental/jax/inference/parallel/mesh.py b/experimental/jax/inference/parallel/mesh.py new file mode 100644 index 00000000..8159d26e --- /dev/null +++ b/experimental/jax/inference/parallel/mesh.py @@ -0,0 +1,107 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +"""mesh module""" + +from collections.abc import Iterable +from typing import Sequence +import jax +from jax.sharding import Mesh +from jax.experimental import mesh_utils as jax_mesh_utils +from jaxlib.xla_client import Device +import numpy as np +from inference.parallel.config import ParallelAxis +from inference.parallel.device import platform + + +def create_device_mesh( + devices: Sequence[Device], + shape: int | Sequence[int], +) -> Mesh: + """Create a powerful mesh given the devices and shape. + + For fully connected topology, the devices topology defined in the mesh + usually not affect the performance. + For other cases, the devices topology defined in the mesh + will affect the performance. (it depends on the collective algorithm + implementation.) + """ + if len(devices) == 0: + raise ValueError("no devices is provided for mesh creation") + + axis_names = ( + ParallelAxis.X.name, + ParallelAxis.Y.name, + ) + + if not isinstance(shape, Sequence): + shape = (shape,) + + shape = tuple(list(shape) + [1 for _ in range(len(axis_names) - len(shape))]) + + if len(shape) != len(axis_names): + raise ValueError( + f"The number of mesh dimensions {len(shape)}" + + "doesn't match with number of axis_names {axis_names}" + ) + if platform == "gpu" or platform == "cpu": + devices = jax_mesh_utils.create_device_mesh( + mesh_shape=shape, + devices=devices, + allow_split_physical_axes=True, + ) + return Mesh(devices=devices, axis_names=axis_names) + + # TODO: Figure out a general method. + # Current mesh builder is very limited. + # only support (2,x) underlying topology shape to . + # form a 1D ring for the devices. + devices = devices[::2] + devices[1::2][::-1] + return Mesh(devices=np.reshape(devices, shape), axis_names=axis_names) + + +def get_num_partitions(axis_names): + """Get the total number of partitions across the axis. + Args: + axis_names: the name of the axis where the partition is + relative to. If the number of axis is greater than 1, + the axis names need to follow the "major to minor" order + as defined in the mesh setup for consistency. + """ + if not isinstance(axis_names, Iterable): + return jax.lax.psum(1, axis_name=axis_names) + product = 1 + for axis in axis_names[::-1]: + product = product * jax.lax.psum(1, axis_name=axis) + return product + + +def get_partition_index(axis_names): + """Get the partition index across the axis for the device. + Args: + axis_names: the names of the axis where the partition index is + relative to. If the number of axis is greater than 1, + the axis names need to follow the "major to minor" order + as defined in the mesh setup for consistency. 
+ """ + if not isinstance(axis_names, Iterable): + return jax.lax.axis_index(axis_name=axis_names) + cur_idx = 0 + multiplicand = 1 + for axis in axis_names[::-1]: + cur_idx += jax.lax.axis_index(axis) * multiplicand + multiplicand = multiplicand * jax.lax.psum(1, axis_name=axis) + return cur_idx diff --git a/experimental/jax/inference/parallel/operations.py b/experimental/jax/inference/parallel/operations.py new file mode 100644 index 00000000..f87c34ec --- /dev/null +++ b/experimental/jax/inference/parallel/operations.py @@ -0,0 +1,91 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +"""Basic Collective Operations.""" + +import jax +from jax import numpy as jnp +from .mesh import get_num_partitions, get_partition_index + + +def reduce_scatter(operand, scatter_dimension, axis_names): + """reduce-scatter sum operation via ppermute.""" + idx = get_partition_index(axis_names=axis_names) + num_partitions = get_num_partitions(axis_names=axis_names) + chunk_size = operand.shape[scatter_dimension] // num_partitions + half_chunk_size = chunk_size // 2 + half_accum_shape = ( + operand.shape[:scatter_dimension] + + (half_chunk_size,) + + operand.shape[scatter_dimension + 1 :] + ) + + def step(i, carry): + accum_fwd, accum_bwd, p_fwd_res, p_bwd_res = carry + accum_fwd += p_fwd_res + accum_bwd += p_bwd_res + + fwd_idx = ((idx - i - 1) % num_partitions) * chunk_size + bwd_idx = ((idx + i + 1) % num_partitions) * chunk_size + half_chunk_size + p_fwd_res = jax.lax.dynamic_slice_in_dim( + operand, fwd_idx, half_chunk_size, scatter_dimension + ) + p_bwd_res = jax.lax.dynamic_slice_in_dim( + operand, bwd_idx, half_chunk_size, scatter_dimension + ) + + accum_fwd = jax.lax.ppermute( + accum_fwd, + axis_name=axis_names, + perm=[(j, (j + 1) % num_partitions) for j in range(num_partitions)], + ) + accum_bwd = jax.lax.ppermute( + accum_bwd, + axis_name=axis_names, + perm=[(j, (j - 1) % num_partitions) for j in range(num_partitions)], + ) + return accum_fwd, accum_bwd, p_fwd_res, p_bwd_res + + accum_fwd = jnp.zeros(half_accum_shape, dtype=operand.dtype) + accum_bwd = jnp.zeros(half_accum_shape, dtype=operand.dtype) + initial_fwd_idx = ((idx - 1) % num_partitions) * chunk_size + initial_bwd_idx = ((idx + 1) % num_partitions) * chunk_size + half_chunk_size + p_fwd_res = jax.lax.dynamic_slice_in_dim( + operand, initial_fwd_idx, half_chunk_size, scatter_dimension + ) + p_bwd_res = jax.lax.dynamic_slice_in_dim( + operand, initial_bwd_idx, half_chunk_size, scatter_dimension + ) + + accum_fwd, accum_bwd, p_fwd_res, p_bwd_res = jax.lax.fori_loop( + 1, num_partitions, step, (accum_fwd, accum_bwd, p_fwd_res, p_bwd_res) + ) + + return jnp.concatenate( + (p_fwd_res, p_bwd_res), scatter_dimension + ) + jnp.concatenate((accum_fwd, accum_bwd), scatter_dimension) + + +def all_reduce(operand, axis_names): + """all-reduce sum operation""" + return jax.lax.psum(operand, axis_name=axis_names) + + +def all_gather(operand, axis, axis_names): + """all-gather operation""" + return jax.lax.all_gather( + operand, 
axis=axis, axis_name=axis_names, tiled=True + ) diff --git a/experimental/jax/inference/parallel/util.py b/experimental/jax/inference/parallel/util.py new file mode 100644 index 00000000..59c300b7 --- /dev/null +++ b/experimental/jax/inference/parallel/util.py @@ -0,0 +1,32 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +"""Utility module.""" + +import jax +from jax.sharding import PartitionSpec as P + + +def get_partition_spec(sharded_pytree): + def pspec(a): + if isinstance(a, jax.Array): + return a.sharding.spec + elif isinstance(a, int) or isinstance(a, float): + return P() + else: + raise ValueError(f"unknown parition spec for {a}") + + return jax.tree_util.tree_map(pspec, sharded_pytree) diff --git a/experimental/jax/inference/runtime/__init__.py b/experimental/jax/inference/runtime/__init__.py new file mode 100644 index 00000000..e7c0b714 --- /dev/null +++ b/experimental/jax/inference/runtime/__init__.py @@ -0,0 +1,15 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" diff --git a/experimental/jax/inference/runtime/batch_scheduler.py b/experimental/jax/inference/runtime/batch_scheduler.py new file mode 100644 index 00000000..16dd662c --- /dev/null +++ b/experimental/jax/inference/runtime/batch_scheduler.py @@ -0,0 +1,208 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
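These wrappers only make sense inside a sharded computation where the mesh axes are bound. A hypothetical driver (not code from this change) using jax.experimental.shard_map; it assumes that API is available in the installed JAX and that the `inference` package is importable.

import jax
import numpy as np
from jax import numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map
from inference import parallel

axis_names = (parallel.ParallelAxis.X.name, parallel.ParallelAxis.Y.name)
mesh = Mesh(np.asarray(jax.devices()).reshape(-1, 1), axis_names=axis_names)


def per_shard_sum(x):
  # x is this device's shard; psum over both mesh axes yields the global sum.
  return parallel.ops.all_reduce(x, parallel.tp_axis_names())


summed = shard_map(
    per_shard_sum,
    mesh=mesh,
    in_specs=P(parallel.tp_axis_names()),
    out_specs=P(parallel.tp_axis_names()),
)(jnp.arange(float(mesh.devices.size)))
print(summed)  # every position holds the same global sum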
+""" + +"""batch scheduler""" + +import enum +import queue +import dataclasses +from inference.runtime.kv_cache import KVCacheManager +from inference.runtime.request_type import PrefillRequest, GenerateRequest, GenerateState + + +@dataclasses.dataclass +class PrefillPagesUpdate: + page_indices: list[int] # length: chunk_size // page_size + + +@dataclasses.dataclass +class GenerateStatePageUpdate: + slot: int + page_idx: int + mapped_idx: int + + +@dataclasses.dataclass +class Schedule: + schedule_prefill: bool + prefill_request: PrefillRequest + prefill_pages_update: PrefillPagesUpdate + schedule_generate: bool + new_generate_requests: list[GenerateRequest] + generate_state_page_updates: list[GenerateStatePageUpdate] + + +@enum.unique +class SchedulePolicy(enum.Enum): + OFFLINE = enum.auto() + ONLINE = enum.auto() + + +class BatchScheduler: + + def __init__( + self, + kv_cache_manager: KVCacheManager, + max_num_seqs: int, + max_seq_len: int, + schedule_policy: SchedulePolicy = SchedulePolicy.OFFLINE, + ): + self.prefill_queue: queue.Queue[PrefillRequest] = queue.Queue() + self.generate_queue: queue.Queue[GenerateRequest] = queue.Queue() + self.kv_manager = kv_cache_manager + self.max_num_seqs = max_num_seqs + self.max_seq_len = max_seq_len + self.schedule_policy = schedule_policy + + def enqueue_prefill_req(self, req: PrefillRequest): + self.prefill_queue.put(req) + + def enqueue_generate_req(self, req: GenerateRequest): + self.generate_queue.put(req) + + def schedule( + self, + active_prefill: PrefillRequest | None, + generate_state: GenerateState, + ) -> Schedule | None: + """Schedule the workload for next iteration. Only host state is + updated in the schedule function. + """ + avail_slot_size = generate_state.available_slots.qsize() + next_prefill_req = active_prefill + prefill_pages_update = None + next_generate_reqs = [] + generate_state_page_updates = [] + + schedule_prefill = False + schedule_generate = False + + # Schedule new prefill req, if no active prefill request. + if not next_prefill_req: + if avail_slot_size > 0: + try: + next_prefill_req = self.prefill_queue.get_nowait() + if not next_prefill_req: + return None + except queue.Empty: + pass + + if next_prefill_req: + cur_prompt_chunk_len = next_prefill_req.chunk_size + total_len = len(next_prefill_req.unpadded_token_ids) + if ( + total_len + <= (next_prefill_req.chunk_idx + 1) * next_prefill_req.chunk_size + ): + cur_prompt_chunk_len = ( + total_len - next_prefill_req.chunk_idx * next_prefill_req.chunk_size + ) + alloced_pages = self.kv_manager.alloc_prefill_hbm_pages( + cur_prompt_chunk_len + ) + if len(alloced_pages) == 0: + # TODO: introduce priority for the request and better + # eviction algorithm. + raise NotImplementedError("Eviction is not supported yet") + else: + start_idx = ( + next_prefill_req.chunk_idx * next_prefill_req.chunk_size + ) // self.kv_manager.page_size + for i, page in enumerate(alloced_pages): + next_prefill_req.page_indices[start_idx + i] = page + prefill_pages_update = PrefillPagesUpdate(alloced_pages) + + # Schedule new generate reqs and allocate memory for all reqs. + with generate_state.map_mutex: + if ( + self.schedule_policy == SchedulePolicy.ONLINE + or not next_prefill_req + or ( + len(generate_state.active_slot_req_map) + + self.generate_queue.qsize() + > 0.95 * self.max_num_seqs + ) + ): + # Add new generate request to the slots. 
+ while ( + generate_state.available_slots.qsize() > 0 + and self.generate_queue.qsize() > 0 + ): + gr = self.generate_queue.get_nowait() + if not gr: + return None + slot = generate_state.available_slots.get_nowait() + gr.slot = slot + generate_state.active_slot_req_map[slot] = gr + next_generate_reqs.append(gr) + + # Check and alloc memory for generate. + alloced_pages = self.kv_manager.alloc_hbm_pages( + len(generate_state.active_slot_req_map) + ) + if ( + len(generate_state.active_slot_req_map) != 0 + and len(alloced_pages) == 0 + ): + raise NotImplementedError( + "Eviction isn't supported yet, please set a lower value for max_num_seqs" + ) + + page_to_use = 0 + for slot, req in generate_state.active_slot_req_map.items(): + idx = req.pos // self.kv_manager.page_size + if req.pos % self.kv_manager.page_size != 0: + continue + if idx >= len(req.page_indices): + continue + + req.page_indices[idx] = alloced_pages[page_to_use] + generate_state_page_updates.append( + GenerateStatePageUpdate( + slot=slot, + page_idx=idx, + mapped_idx=alloced_pages[page_to_use], + ) + ) + page_to_use += 1 + + self.kv_manager.free_hbm_pages(alloced_pages[page_to_use:]) + + if len(generate_state.active_slot_req_map) == 0: + schedule_generate = False + else: + schedule_generate = True + + if next_prefill_req: + schedule_prefill = True + else: + schedule_prefill = False + + if not schedule_prefill and not schedule_generate: + # Nothing got scheduled, busy waiting for either prefill + # or generate queue to have pending request. + while True: + if self.prefill_queue.qsize() > 0 or self.generate_queue.qsize() > 0: + return self.schedule(active_prefill, generate_state) + + return Schedule( + schedule_prefill=schedule_prefill, + prefill_request=next_prefill_req, + prefill_pages_update=prefill_pages_update, + schedule_generate=schedule_generate, + new_generate_requests=next_generate_reqs, + generate_state_page_updates=generate_state_page_updates, + ) diff --git a/experimental/jax/inference/runtime/engine.py b/experimental/jax/inference/runtime/engine.py new file mode 100644 index 00000000..7e902879 --- /dev/null +++ b/experimental/jax/inference/runtime/engine.py @@ -0,0 +1,553 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +"""engine module""" + +import enum +import dataclasses +import datetime +import queue +import threading +from typing import Any +import uuid +import jax +import jax.profiler +from jax import numpy as jnp +from jax.sharding import Mesh, NamedSharding, PartitionSpec as P +import numpy as np +from inference.model import ModelSource, ModelRegistry, SamplingParams +from inference import nn +from inference import parallel +from inference.runtime.kv_cache import * +from inference.runtime.request_type import * +from inference.runtime.kv_cache import KVCacheStorage, KVCacheManager +from inference.runtime.batch_scheduler import BatchScheduler, SchedulePolicy +from inference.runtime.model_executor import Executor + + +@dataclasses.dataclass +class ModelLoadParams: + model_id: str + tokenizer_path: str | None = None + weights_path: str | None = None + source: ModelSource = ModelSource.HUGGINGFACE + hf_model_config: Any | None = None + dummy_weights: bool = False + + +@dataclasses.dataclass +class InferenceParams: + max_num_seqs: int = 160 + max_seq_length: int = 2048 + max_input_length: int = 1024 + prefill_chunk_sizes: list[int] = dataclasses.field( + default_factory=lambda: [128, 256, 512, 1024] + ) + # prefill_chunk_sizes: list[int] = dataclasses.field(default_factory=lambda: [256]) + page_size: int = 128 + hbm_utilization: float = 0.8 + + +@enum.unique +class EngineMode(enum.Enum): + OFFLINE = enum.auto() + ONLINE = enum.auto() + + +@dataclasses.dataclass +class OfflineChannel: + req_queue: queue.Queue[OfflineRequest] + res_queue: queue.Queue[Response] + + +@dataclasses.dataclass +class OnlineChannel: + req_queue: asyncio.Queue[OnlineRequest] + aio_loop: asyncio.AbstractEventLoop + + +class Engine: + + def __init__( + self, + mesh: Mesh, + model_load_params: ModelLoadParams, + inference_params: InferenceParams, + mode: EngineMode, + channel: OfflineChannel | OnlineChannel, + debug_mode: bool = False, + ): + """Engine is a wrapper of the model for inference""" + print("Initializing engine") + self.mesh = mesh + self.inference_params = inference_params + model_registry = ModelRegistry() + self.tokenizer = model_registry.load_tokenizer( + model_id=model_load_params.model_id, + path=model_load_params.tokenizer_path, + ) + + print("Loading model config") + model_config = model_registry.load_model_config(model_load_params.model_id) + if debug_mode: + model_config.num_hidden_layers = 1 + + model_cls = model_registry.model_cls(model_load_params.model_id) + self.model: nn.CausalLM = model_cls( + model_config, + parallel.ModelParallelConfig(mesh=self.mesh), + self.tokenizer.eos_token_id, + self.inference_params.max_seq_length, + ) + + if model_load_params.dummy_weights: + print("Initializing random params") + self.weights_dict = self.model.init_weights() + else: + print("Loading model weights") + weights_on_host = model_registry.load_weights_to_host( + model_load_params.model_id, + self.mesh.devices.size, + model_config, + model_load_params.weights_path, + model_load_params.source, + ) + print("Loading model weights to devices") + self.weights_dict = self.model.load_weights_dict(weights_on_host) + + print("Initializing KV Cache storage") + # init kv cache + self.kv_storage = KVCacheStorage( + mesh=self.mesh, + model_config=model_config, + page_size=inference_params.page_size, + hbm_utilization=inference_params.hbm_utilization, + ) + num_hbm_pages = self.kv_storage.num_hbm_pages + self.kv_manager = KVCacheManager( + num_hbm_pages, self.inference_params.page_size + ) + + self.mode = mode + 
self.channel = channel + + if self.mode == EngineMode.OFFLINE: + mode = SchedulePolicy.OFFLINE + if self.mode == EngineMode.ONLINE: + mode = SchedulePolicy.ONLINE + + self.scheduler = BatchScheduler( + self.kv_manager, + self.inference_params.max_num_seqs, + self.inference_params.max_seq_length, + mode, + ) + + print("Initializing GenerateState") + self.active_prefill_request = None + slots = queue.SimpleQueue() + for i in range(self.inference_params.max_num_seqs): + slots.put(i) + self.generate_state = GenerateState( + token_ids=jnp.zeros( + shape=(self.inference_params.max_num_seqs,), + dtype=jnp.int32, + device=NamedSharding(self.mesh, P(None)), + ), + positions=jnp.full( + shape=(self.inference_params.max_num_seqs,), + fill_value=-1, + dtype=jnp.int32, + device=NamedSharding(self.mesh, P(None)), + ), + page_table=jnp.full( + shape=( + self.inference_params.max_num_seqs, + self.inference_params.max_seq_length + // self.inference_params.page_size, + ), + fill_value=self.kv_manager.dummy_page_idx, + dtype=jnp.int32, + device=NamedSharding(self.mesh, P(None, None)), + ), + available_slots=slots, + active_slot_req_map={}, + ) + self.sample_params = SamplingParams( + jax.device_put( + jnp.asarray((1.0), dtype=jnp.float32), NamedSharding(self.mesh, P()) + ), + jax.device_put( + jnp.asarray((1), dtype=jnp.int32), NamedSharding(self.mesh, P()) + ), + jax.device_put(jax.random.key(0), NamedSharding(self.mesh, P())), + ) + + self.num_pages_per_seq = ( + self.inference_params.max_seq_length // self.inference_params.page_size + ) + self.model_executor = Executor( + self.mesh, + self.weights_dict, + self.model.jittable_call, + self.num_pages_per_seq, + debug_mode=debug_mode, + ) + + self.model_executor.compile( + self.inference_params.prefill_chunk_sizes, + self.inference_params.max_num_seqs, + self.inference_params.max_seq_length, + self.inference_params.max_input_length, + self.kv_storage.hbm_kv_caches, + self.sample_params, + ) + + print("Engine compilation finished") + # running loop + self.requests_dict: dict[str, Request] = {} + + if self.mode == EngineMode.OFFLINE: + self._dequeue_offline_req_thread = threading.Thread( + name="dequeue_offline_request", target=self._dequeue_offline_request + ) + else: + self._dequeue_online_req_thread = threading.Thread( + name="_dequeue_online_request", target=self._dequeue_online_request + ) + + # TODO: Assign the max_device_requests_sem number by the + # device spec and cost model. + self._max_device_requests_sem = threading.Semaphore( + self.inference_params.max_num_seqs * 1.5 + ) + self._preprocess_queue: queue.Queue[Request] = queue.Queue() + # TODO: Seperate the running loop with the static inference model. + self._preprocess_thread = threading.Thread( + name="preprocess", target=self._preprocess + ) + # Add backpressure to prevent that the inference thread never releases + # the GIL and keeps dispatching the device program. 
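+        # The bounded queue below (maxsize=8) makes the inference thread block on
+        # put() once postprocessing falls behind, so it periodically yields to the
+        # other threads instead of dispatching device work without limit.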
+ self._postprocess_queue: queue.Queue[PostProcessRequest] = queue.Queue(8) + self._postprocess_thread = threading.Thread( + name="postprocess", target=self._postprocess + ) + + self._inference_thread = threading.Thread( + name="inference", target=self._inference + ) + self.total_reqs = 0 + self.complete_reqs = 0 + + def start(self): + jax.profiler.start_server(9999) + + if self.mode == EngineMode.OFFLINE: + self._dequeue_offline_req_thread.start() + else: + self._dequeue_online_req_thread.start() + + self._preprocess_thread.start() + self._postprocess_thread.start() + self._inference_thread.start() + + print("Engine starts: ", datetime.datetime.now()) + + def stop(self): + jax.profiler.stop_server() + # Stop listen to the queue when item is None. + self.channel.req_queue.put(None) + self._preprocess_queue.put(None) + self.scheduler.enqueue_prefill_req(None) + self.scheduler.enqueue_generate_req(None) + self._postprocess_queue.put(None) + + if self.mode == EngineMode.OFFLINE: + self._dequeue_offline_req_thread.join() + else: + self._dequeue_online_req_thread.join() + + self._preprocess_thread.join() + self._inference_thread.join() + self._postprocess_thread.join() + + print("Engine stops: ", datetime.datetime.now()) + + def _dequeue_online_request(self): + while True: + online_req: OnlineRequest = self.channel.req_queue.get() + if not online_req: + return + + req = Request( + id=uuid.uuid4().hex, + prompt=online_req.prompt, + aio_response_queue=online_req.res_queue, + ) + + self._preprocess_queue.put(req) + self.requests_dict[req.id] = req + + def _dequeue_offline_request(self): + while True: + offline_req: OfflineRequest = self.channel.req_queue.get() + if not offline_req: + return + + req = Request( + id=uuid.uuid4().hex, + prompt=offline_req.prompt, + ) + + self._preprocess_queue.put(req) + self.requests_dict[req.id] = req + + def _preprocess(self) -> jax.Array: + while True: + req: Request | None = self._preprocess_queue.get() + if not req: + return + + token_id_list = self.tokenizer.encode(req.prompt) + req.prompt_token_ids = token_id_list + + # Don't put too many pending requests + # to the HBM. + self._max_device_requests_sem.acquire() + + tokens = np.asarray(token_id_list) + token_len = tokens.size + num_paddings = self.inference_params.max_input_length - token_len + if num_paddings < 0: + padded_tokens = tokens[-self.inference_params.max_input_length :] + req.prompt_token_ids = token_id_list[ + -self.inference_params.max_input_length : + ] + else: + padded_tokens = np.pad( + tokens, (0, self.inference_params.max_input_length - token_len) + ) + padded_tokens = jax.device_put( + padded_tokens, NamedSharding(self.mesh, P(None)) + ) + + positions = jax.device_put( + np.arange(0, padded_tokens.shape[0]), + NamedSharding(self.mesh, P(None)), + ) + + dummy_page_indices = [ + self.kv_manager.dummy_page_idx for _ in range(self.num_pages_per_seq) + ] + + # Select chunk size. + # TODO: move it to a function. 
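+            # Pick the smallest configured chunk size that covers the prompt; if the
+            # prompt is longer than the largest chunk, use the largest and let the
+            # prefill run over multiple chunks in later iterations.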
+ chunk_size_idx = 0 + chunk_sizes = self.inference_params.prefill_chunk_sizes + while ( + chunk_size_idx < len(chunk_sizes) + and token_len > chunk_sizes[chunk_size_idx] + ): + chunk_size_idx += 1 + chunk_size_idx = ( + chunk_size_idx - 1 + if chunk_size_idx == len(chunk_sizes) + else chunk_size_idx + ) + + self.scheduler.enqueue_prefill_req( + PrefillRequest( + id=req.id, + unpadded_token_ids=req.prompt_token_ids, + page_indices=dummy_page_indices, + chunk_idx=0, + chunk_size=self.inference_params.prefill_chunk_sizes[ + chunk_size_idx + ], + device_token_ids=padded_tokens, + device_positions=positions, + ) + ) + + def _inference(self) -> jax.Array: + while True: + schedule = self.scheduler.schedule( + self.active_prefill_request, self.generate_state + ) + if not schedule: + return + + input = self.model_executor.prepare_input_and_update_generate_state( + schedule, + self.generate_state, + self.kv_storage.hbm_kv_caches, + self.sample_params, + self.inference_params.max_num_seqs, + ) + + output, self.kv_storage.hbm_kv_caches = self.model_executor.execute(input) + + # Prepare for next iteration and post-processed request. + post_req = PostProcessRequest( + prefill_request_id=None, + prefill_token_id=output.prefill_token, + prefill_done=output.prefill_done, + generate_active_slots=[], + generate_active_request_ids=[], + generate_token_ids=output.generate_tokens, + generate_done=output.generate_done, + ) + + if schedule.schedule_prefill: + prefill_req = schedule.prefill_request + prefill_req.chunk_idx += 1 + start_idx = prefill_req.chunk_idx * prefill_req.chunk_size + prefill_length = len(prefill_req.unpadded_token_ids) + + if start_idx < prefill_length: + self.active_prefill_request = prefill_req + else: + self.active_prefill_request = None + post_req.prefill_request_id = schedule.prefill_request.id + + generate_req = GenerateRequest( + id=prefill_req.id, + slot=-1, + pos=prefill_length, + page_indices=prefill_req.page_indices, + device_prefill_token_id=output.prefill_token, + ) + self.scheduler.enqueue_generate_req(generate_req) + + if schedule.schedule_generate: + self.generate_state.token_ids = output.generate_tokens + self.generate_state.positions = output.generate_next_pos + + with self.generate_state.map_mutex: + for ( + slot, + processed_gr, + ) in self.generate_state.active_slot_req_map.items(): + processed_gr.pos += 1 + post_req.generate_active_slots.append(slot) + post_req.generate_active_request_ids.append(processed_gr.id) + + self._postprocess_queue.put(post_req) + + def _postprocess(self) -> str: + while True: + p_req = self._postprocess_queue.get() + if not p_req: + return + + p_req.prefill_token_id = np.asarray(p_req.prefill_token_id).item() + p_req.prefill_done = np.asarray(p_req.prefill_done).item() + p_req.generate_token_ids = np.asarray(p_req.generate_token_ids).tolist() + p_req.generate_done = np.asarray(p_req.generate_done).tolist() + + # Free finished slot. + if len(p_req.generate_active_request_ids) > 0: + with self.generate_state.map_mutex: + slot_to_del = [] + for slot in self.generate_state.active_slot_req_map.keys(): + if p_req.generate_done[slot]: + self.generate_state.available_slots.put(slot) + pages_to_free = self.generate_state.active_slot_req_map[ + slot + ].page_indices + self.kv_manager.free_hbm_pages(pages_to_free) + slot_to_del.append(slot) + for slot in slot_to_del: + del self.generate_state.active_slot_req_map[slot] + + # Return generated tokens to the client. 
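+            # Once the last prefill chunk completes, the first sampled token is
+            # appended to the request and, in online mode, streamed to the client
+            # immediately; in offline mode the full response is only returned once
+            # the sequence finishes.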
+ if p_req.prefill_request_id: + req = self.requests_dict[p_req.prefill_request_id] + req.generated_token_ids.append(p_req.prefill_token_id) + generated_text = self.tokenizer._convert_id_to_token( + p_req.prefill_token_id + ).replace("▁", " ") + req.generated_text += generated_text + + if self.mode == EngineMode.ONLINE: + self.channel.aio_loop.call_soon_threadsafe( + req.aio_response_queue.put_nowait, + Response(generated_text, p_req.prefill_token_id), + ) + + if p_req.prefill_done: + req.completed = True + if self.mode == EngineMode.ONLINE: + self.channel.aio_loop.call_soon_threadsafe( + req.aio_response_queue.put_nowait, + Response(generated_text, p_req.prefill_token_id), + ) + + else: + self.channel.res_queue.put_nowait( + Response( + generated_text=req.generated_text, + generated_tokens=req.generated_token_ids, + input_tokens=req.prompt_token_ids, + ) + ) + + del self.requests_dict[p_req.prefill_request_id] + self._max_device_requests_sem.release() + + for slot, req_id in zip( + p_req.generate_active_slots, p_req.generate_active_request_ids + ): + if req_id not in self.requests_dict: + continue + req = self.requests_dict[req_id] + + req.generated_token_ids.append(p_req.generate_token_ids[slot]) + generated_text = self.tokenizer._convert_id_to_token( + p_req.generate_token_ids[slot] + ).replace("▁", " ") + req.generated_text += generated_text + + if self.mode == EngineMode.ONLINE: + self.channel.aio_loop.call_soon_threadsafe( + req.aio_response_queue.put_nowait, + Response( + generated_text=generated_text, + generated_tokens=p_req.generate_token_ids[slot], + ), + ) + + if p_req.generate_done[slot]: + req.completed = True + if self.mode == EngineMode.ONLINE: + self.channel.aio_loop.call_soon_threadsafe( + req.aio_response_queue.put_nowait, + None, + ) + else: + self.channel.res_queue.put_nowait( + Response( + generated_text=req.generated_text, + generated_tokens=req.generated_token_ids, + input_tokens=req.prompt_token_ids, + ) + ) + + self._max_device_requests_sem.release() + del self.requests_dict[req_id] + + def handle_request(self, request: OfflineRequest | OnlineRequest): + self.channel.req_queue.put(request) diff --git a/experimental/jax/inference/runtime/kv_cache.py b/experimental/jax/inference/runtime/kv_cache.py new file mode 100644 index 00000000..710e8eb4 --- /dev/null +++ b/experimental/jax/inference/runtime/kv_cache.py @@ -0,0 +1,150 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +"""kv cache module""" + +import math +import jax +from jax import numpy as jnp +from jax.sharding import Mesh, NamedSharding, PartitionSpec as P +import queue +from inference.nn import KVCache + + +class KVCacheStorage: + + def __init__( + self, + mesh: Mesh, + model_config, + page_size: int = 32, + hbm_utilization: float = 0.8, + ): + self.mesh = mesh + self.num_devices = mesh.devices.size + self.num_layers = model_config.num_hidden_layers + self.num_kv_heads = model_config.num_key_value_heads + self.head_dim = getattr( + model_config, + "head_dim", + model_config.hidden_size // model_config.num_attention_heads, + ) + self.page_size = page_size + + self.hbm_kv_caches: list[KVCache] = self.init_hbm_storage(hbm_utilization) + self.num_hbm_pages = self.hbm_kv_caches[0].k.shape[1] + + def init_hbm_storage( + self, + hbm_utilization: float = 0.8, + cache_dtype: jnp.dtype = jnp.bfloat16, + ): + memory_stats = jax.devices()[0].memory_stats() + if memory_stats: + print("per device memory_stats: ", memory_stats) + available_hbm_bytes = ( + memory_stats["bytes_reservable_limit"] * hbm_utilization + - memory_stats["bytes_in_use"] + ) + item_size = jnp.ones((1), dtype=cache_dtype).itemsize + kv_size_per_page = ( + item_size + * self.num_kv_heads + // self.num_devices + * self.head_dim + * self.page_size + * 2 + ) + num_pages_per_layer = int( + available_hbm_bytes // kv_size_per_page // self.num_layers + ) + else: + print("memory_stats not available, allocate 128 pages.") + num_pages_per_layer = 128 + + # TODO: support 2d sharding. + kv_storage = [ + KVCache( + k=jnp.ones( + shape=( + self.num_kv_heads, + num_pages_per_layer, + self.page_size, + self.head_dim, + ), + device=NamedSharding( + self.mesh, P(self.mesh.axis_names, None, None, None) + ), + dtype=cache_dtype, + ), + v=jnp.ones( + shape=( + self.num_kv_heads, + num_pages_per_layer, + self.page_size, + self.head_dim, + ), + device=NamedSharding( + self.mesh, P(self.mesh.axis_names, None, None, None) + ), + dtype=cache_dtype, + ), + ) + for _ in range(self.num_layers) + ] + + return kv_storage + + +class KVCacheManager: + """Logical KV Cache Manager""" + + def __init__( + self, + num_hbm_pages, + page_size, + ): + self.available_hbm_pages = queue.SimpleQueue() + self.page_size = page_size + self.dummy_page_idx = 0 + + for i in range(1, num_hbm_pages): + self.available_hbm_pages.put_nowait(i) + + @property + def num_available_hbm_pages(self): + return self.available_hbm_pages.qsize() + + def alloc_prefill_hbm_pages(self, prompt_len) -> list[int]: + num_pages = math.ceil(prompt_len / self.page_size) + if num_pages > self.num_available_hbm_pages: + return [] + else: + return self.alloc_hbm_pages(num_pages) + + def alloc_hbm_pages(self, num_pages: int) -> list[int]: + pages_to_use = [] + if num_pages > self.num_available_hbm_pages: + return pages_to_use + + for _ in range(num_pages): + pages_to_use.append(self.available_hbm_pages.get()) + return pages_to_use + + def free_hbm_pages(self, pages: list[int]): + for p in pages: + if p != self.dummy_page_idx: + self.available_hbm_pages.put_nowait(p) diff --git a/experimental/jax/inference/runtime/model_executor.py b/experimental/jax/inference/runtime/model_executor.py new file mode 100644 index 00000000..e1bd3b03 --- /dev/null +++ b/experimental/jax/inference/runtime/model_executor.py @@ -0,0 +1,530 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import datetime +import dataclasses +import jax +from jax import numpy as jnp +import numpy as np +from jax.experimental.shard_map import shard_map +from jax.sharding import Mesh, NamedSharding, PartitionSpec as P +from collections.abc import Callable +from inference.nn import AttentionMetadata, KVCache +from inference.model import ModelOutput, SamplingParams +from inference import parallel +from inference.runtime.batch_scheduler import Schedule, PrefillPagesUpdate +from inference.runtime.request_type import GenerateState, PrefillRequest, GenerateRequest + +ModelForwardFunc = Callable[ + [ + dict, + jax.Array, + jax.Array, + list[KVCache], + AttentionMetadata, + SamplingParams, + ], + tuple[ModelOutput, list[KVCache]], +] + +Executable = ModelForwardFunc + + +@dataclasses.dataclass +class ModelForwardInput: + input_ids: jax.Array + positions: jax.Array + kv_caches: list[KVCache] + attn_metadata: AttentionMetadata + sampling_params: SamplingParams + + +class Executor: + + def __init__( + self, + mesh: Mesh, + weights_dict: dict, + model_forward: ModelForwardFunc, + num_pages_per_seq: int, + cache_dir: str | None = "/tmp/jax_cache", + debug_mode: bool = False, + ): + self.mesh = mesh + self.weights_dict = weights_dict + self.executables_dict: dict[str:Executable] = {} + self._model_forward = model_forward + self.num_pages_per_seq = num_pages_per_seq + + # TODO: Understand why the following doesn't work. + # Currently, cache is saved by "export JAX_COMPILATION_CACHE_DIR="/tmp/jax_cache". + if cache_dir: + from jax.experimental.compilation_cache import compilation_cache as cc + + cc.set_cache_dir(cache_dir) + + self.dummy_scalar = jax.device_put( + jnp.asarray(1e7, dtype=jnp.int32), NamedSharding(self.mesh, P()) + ) + self.jitted_prepare_model_input_func = jax.jit( + self.prepare_model_input, static_argnames=("chunk_size",) + ) + self.debug_mode = debug_mode + + def compile( + self, + prefill_chunk_sizes: list[int], + max_num_seqs: int, + max_seq_len: int, + max_input_len: int, + kv_caches: list[KVCache], + sampling_params: SamplingParams, + compiler_options: dict[str, jax.stages.CompilerOptions] | None = None, + ): + if self.debug_mode: + print("No compilation under debug mode") + + page_size = kv_caches[0].k.shape[2] + + # Prefill only compile. 
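+    # Ahead-of-time compile a prefill-only executable for every configured chunk
+    # size with dummy inputs, so runtime prefill shapes hit the executable cache
+    # instead of triggering a fresh JIT compile.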
+ for size in prefill_chunk_sizes: + dummy_padded_tensor = jnp.ones( + (max_input_len,), + dtype=jnp.int32, + device=NamedSharding(self.mesh, P(None)), + ) + dummy_page_indices_tensor = jnp.ones( + (max_seq_len // page_size,), + dtype=jnp.int32, + device=NamedSharding(self.mesh, P(None)), + ) + dummy_page_indices = np.asarray(dummy_page_indices_tensor).tolist() + dummy_page_update_indices = [0 for _ in range(size // page_size)] + + dummy_schedule = Schedule( + schedule_prefill=True, + prefill_request=PrefillRequest( + id="0", + unpadded_token_ids=[1], + page_indices=dummy_page_indices, + chunk_idx=0, + chunk_size=size, + device_token_ids=dummy_padded_tensor, + device_positions=dummy_padded_tensor, + ), + prefill_pages_update=PrefillPagesUpdate( + page_indices=dummy_page_update_indices + ), + schedule_generate=False, + new_generate_requests=[], + generate_state_page_updates=[], + ) + + input = self.prepare_input_and_update_generate_state( + schedule=dummy_schedule, + generate_state=None, + kv_caches=kv_caches, + sampling_params=sampling_params, + max_num_seqs=max_num_seqs, + ) + + key = self.executable_key(input.attn_metadata) + options = None + if compiler_options and key in compiler_options: + options = compiler_options[key] + print(f"Compiling for {key}") + self.compile_once(key, input, options) + + # Generate only compile. + dummy_batch_tensor = jnp.ones( + (max_num_seqs), + dtype=jnp.int32, + device=NamedSharding(self.mesh, P(None)), + ) + dummy_page_table_tensor = jnp.ones( + (max_num_seqs, max_seq_len // page_size), + dtype=jnp.int32, + device=NamedSharding(self.mesh, P(None, None)), + ) + dummy_schedule = Schedule( + schedule_prefill=False, + prefill_request=None, + prefill_pages_update=None, + schedule_generate=True, + new_generate_requests=[], + generate_state_page_updates=[], + ) + dummy_generate_state = GenerateState( + token_ids=dummy_batch_tensor, + positions=dummy_batch_tensor, + page_table=dummy_page_table_tensor, + available_slots=0, + active_slot_req_map={}, + ) + input = self.prepare_input_and_update_generate_state( + schedule=dummy_schedule, + generate_state=dummy_generate_state, + kv_caches=kv_caches, + sampling_params=sampling_params, + max_num_seqs=max_num_seqs, + ) + + key = self.executable_key(input.attn_metadata) + options = None + if compiler_options and key in compiler_options: + options = compiler_options[key] + print(f"Compiling for {key}") + self.compile_once(key, input, options) + + # Mixed compile. 
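+    # Compile the mixed path (one prefill chunk batched together with the full
+    # set of generate slots) for every chunk size, since a runtime step commonly
+    # interleaves chunked prefill with ongoing decoding.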
+ for prefill_chunk_size in prefill_chunk_sizes: + dummy_padded_prompt_tensor = jnp.ones( + (max_input_len,), + dtype=jnp.int32, + device=NamedSharding(self.mesh, P(None)), + ) + dummy_prefill_page_indices = np.ones( + (max_seq_len // page_size,), + dtype=np.int32, + ).tolist() + dummy_prefill_page_update_indices = [ + 0 for _ in range(prefill_chunk_size // page_size) + ] + + dummy_schedule = Schedule( + schedule_prefill=True, + prefill_request=PrefillRequest( + id="0", + unpadded_token_ids=[1], + page_indices=dummy_prefill_page_indices, + chunk_idx=0, + chunk_size=prefill_chunk_size, + device_token_ids=dummy_padded_prompt_tensor, + device_positions=dummy_padded_prompt_tensor, + ), + prefill_pages_update=PrefillPagesUpdate( + page_indices=dummy_prefill_page_update_indices + ), + schedule_generate=True, + new_generate_requests=[], + generate_state_page_updates=[], + ) + dummy_generate_state = GenerateState( + token_ids=dummy_batch_tensor, + positions=dummy_batch_tensor, + page_table=dummy_page_table_tensor, + available_slots=0, + active_slot_req_map={}, + ) + input = self.prepare_input_and_update_generate_state( + schedule=dummy_schedule, + generate_state=dummy_generate_state, + kv_caches=kv_caches, + sampling_params=sampling_params, + max_num_seqs=max_num_seqs, + ) + + key = self.executable_key(input.attn_metadata) + options = None + if compiler_options and key in compiler_options: + options = compiler_options[key] + print(f"Compiling for {key}") + self.compile_once(key, input, options) + + def compile_once( + self, + key: str, + input: ModelForwardInput, + options: jax.stages.CompilerOptions, + ): + start_time = datetime.datetime.now() + jitted_func = self.jitted_model_forward_func(input) + + self.executables_dict[key] = jitted_func.lower( + self.weights_dict, + input.input_ids, + input.positions, + input.kv_caches, + input.attn_metadata, + input.sampling_params, + ).compile(options) + + end_time = datetime.datetime.now() + print( + f"Compilation for {key} completed, take {(end_time-start_time).total_seconds()} seconds" + ) + + def execute( + self, input: ModelForwardInput + ) -> tuple[ModelOutput, list[KVCache]]: + key = self.executable_key(input.attn_metadata) + if self.debug_mode: + return self.model_forward(input) + + if key not in self.executables_dict: + print( + ( + "Warning: the cache is missing, " + f"for {jax.tree.map(lambda a: a.shape, input.attn_metadata)}, compile and execute" + ) + ) + self.compile_once(key, input, None) + + executable = self.executables_dict[key] + + return executable( + self.weights_dict, + input.input_ids, + input.positions, + input.kv_caches, + input.attn_metadata, + input.sampling_params, + ) + + def executable_key(self, attn_meta: AttentionMetadata) -> str: + prefill_chunk_size = 0 + generate_batch_size = 0 + if len(attn_meta.prefill_pos.shape) > 0: + prefill_chunk_size = attn_meta.prefill_pos.shape[0] + if len(attn_meta.generate_pos.shape) > 0: + generate_batch_size = attn_meta.generate_pos.shape[0] + return f"prefill_chunk_size={prefill_chunk_size}, generate_batch_size={generate_batch_size}" + + def model_forward(self, input: ModelForwardInput): + return self.shard_mapped_model_forward_func(input)( + self.weights_dict, + input.input_ids, + input.positions, + input.kv_caches, + input.attn_metadata, + input.sampling_params, + ) + + def jitted_model_forward_func(self, input: ModelForwardInput): + return jax.jit( + self.shard_mapped_model_forward_func(input), donate_argnums=(3,) + ) + + def shard_mapped_model_forward_func(self, input: ModelForwardInput): 
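+    """Wrap the raw model forward with shard_map so each device computes on its
+    local shard; the partition specs are derived from the current sharding of
+    the weights, KV caches, attention metadata and sampling params."""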
+ return shard_map( + f=self._model_forward, + mesh=self.mesh, + in_specs=( + parallel.get_partition_spec(self.weights_dict), + P(None), + P(None), + parallel.get_partition_spec(input.kv_caches), + parallel.get_partition_spec(input.attn_metadata), + parallel.get_partition_spec(input.sampling_params), + ), + out_specs=( + ModelOutput( + prefill_token=P(), + prefill_done=P(), + prefill_next_pos=P(), + generate_tokens=P(None), + generate_done=P(None), + generate_next_pos=P(None), + ), + parallel.get_partition_spec(input.kv_caches), + ), + check_rep=False, + ) + + def prepare_input_and_update_generate_state( + self, + schedule: Schedule, + generate_state: GenerateState, + kv_caches: list[KVCache], + sampling_params: SamplingParams, + max_num_seqs: int, + ) -> ModelForwardInput: + attn_meta = AttentionMetadata( + prefill_length=self.dummy_scalar, + prefill_pos=self.dummy_scalar, + prefill_page_table=self.dummy_scalar, + generate_pos=self.dummy_scalar, + generate_page_table=self.dummy_scalar, + ) + + prefill_tpp = (self.dummy_scalar, self.dummy_scalar, self.dummy_scalar) + prefill_cur_length = self.dummy_scalar + chunk_id = self.dummy_scalar + chunk_size = 512 + + generate_tpp = (self.dummy_scalar, self.dummy_scalar, self.dummy_scalar) + update_generate_tpp = ( + self.dummy_scalar, + self.dummy_scalar, + self.dummy_scalar, + ) + insert_slots = self.dummy_scalar + generate_page_updates = self.dummy_scalar + generate_pt_update_slots = self.dummy_scalar + generate_pt_update_page_idxs = self.dummy_scalar + + if schedule.schedule_prefill: + pr = schedule.prefill_request + prefill_tpp = ( + pr.device_token_ids, + pr.device_positions, + np.array(pr.page_indices), + ) + prefill_cur_length = (pr.chunk_idx + 1) * pr.chunk_size + prefill_total_len = len(pr.unpadded_token_ids) + if prefill_cur_length > prefill_total_len: + prefill_cur_length = prefill_total_len + + chunk_id = pr.chunk_idx + chunk_size = pr.chunk_size + + if schedule.schedule_generate: + generate_tpp = ( + generate_state.token_ids, + generate_state.positions, + generate_state.page_table, + ) + update_token_ids = [] + update_pos = np.full((max_num_seqs,), 1e6, dtype=np.int32) + update_page_indices = np.full( + (max_num_seqs, self.num_pages_per_seq), 1e6, dtype=np.int32 + ) + slots = np.full((max_num_seqs,), 1e6, dtype=np.int32) + + for i, gr in enumerate(schedule.new_generate_requests): + update_token_ids.append(gr.device_prefill_token_id) + update_pos[i] = gr.pos + update_page_indices[i] = np.array(gr.page_indices) + slots[i] = gr.slot + + for i in range(max_num_seqs - len(schedule.new_generate_requests)): + update_token_ids.append(self.dummy_scalar) + + update_generate_tpp = (update_token_ids, update_pos, update_page_indices) + insert_slots = slots + + # Handle page indices update. 
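+      # Scatter the newly allocated KV pages into the device page table. Unused
+      # entries keep the 1e6 sentinel, which is out of range for the table and is
+      # therefore dropped by JAX's default out-of-bounds scatter behavior.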
+ page_update_slots = np.full((max_num_seqs,), 1e6, dtype=np.int32) + page_update_page_idxs = np.full((max_num_seqs,), 1e6, dtype=np.int32) + page_update_mapped_idxs = np.full((max_num_seqs,), 1e6, dtype=np.int32) + + for i, update in enumerate(schedule.generate_state_page_updates): + page_update_slots[i] = update.slot + page_update_page_idxs[i] = update.page_idx + page_update_mapped_idxs[i] = update.mapped_idx + + generate_page_updates = page_update_mapped_idxs + generate_pt_update_slots = page_update_slots + generate_pt_update_page_idxs = page_update_page_idxs + + input_ids, generate_tokens, positions, attn_meta = ( + self.jitted_prepare_model_input_func( + attn_meta, + prefill_tpp, + prefill_cur_length, + chunk_id, + chunk_size, + generate_tpp, + update_generate_tpp, + insert_slots, + generate_page_updates, + generate_pt_update_slots, + generate_pt_update_page_idxs, + ) + ) + _, new_key = jax.random.split(sampling_params.rng) + sampling_params.rng = new_key + + if schedule.schedule_generate: + generate_state.token_ids = generate_tokens + generate_state.positions = attn_meta.generate_pos + generate_state.page_table = attn_meta.generate_page_table + + return ModelForwardInput( + input_ids=input_ids, + positions=positions, + kv_caches=kv_caches, + attn_metadata=attn_meta, + sampling_params=sampling_params, + ) + + def prepare_model_input( + self, + attn_meta: AttentionMetadata, + prefill_tpp: tuple[jax.Array, jax.Array, jax.Array], + prefill_length: jax.Array, + chunk_id: jax.Array, + chunk_size: jax.Array, + generate_tpp: tuple[list[jax.Array], jax.Array, jax.Array], + update_generate_tpp: tuple[jax.Array, jax.Array, jax.Array], + insert_slots: jax.Array, + generate_page_updates: jax.Array, + generate_pt_update_slots: jax.Array, + generate_pt_update_page_idxs: jax.Array, + ): + p_tokens, p_positions, p_page_indices = prefill_tpp + if len(p_tokens.shape) > 0: + idx = chunk_id * chunk_size + p_tokens = jax.lax.dynamic_slice_in_dim(p_tokens, idx, chunk_size) + p_positions = jax.lax.dynamic_slice_in_dim(p_positions, idx, chunk_size) + + g_tokens, g_positions, g_page_table = generate_tpp + if len(g_tokens.shape) > 0: + update_g_tokens, update_g_positions, update_g_page_table = ( + update_generate_tpp + ) + update_g_tokens = jnp.asarray(update_g_tokens) + + g_tokens = g_tokens.at[insert_slots].set(update_g_tokens) + g_positions = g_positions.at[insert_slots].set(update_g_positions) + # Insert new request to the slot. + g_page_table = g_page_table.at[insert_slots, :].set(update_g_page_table) + # Add the new page for the existing slot. 
+ g_page_table = g_page_table.at[ + generate_pt_update_slots, generate_pt_update_page_idxs + ].set(generate_page_updates) + + if len(p_tokens.shape) > 0 and len(g_tokens.shape) > 0: + input_ids = jnp.concatenate((p_tokens, g_tokens)) + positions = jnp.concatenate((p_positions, g_positions)) + + attn_meta.prefill_length = prefill_length + attn_meta.prefill_pos = p_positions + attn_meta.prefill_page_table = p_page_indices + attn_meta.generate_pos = g_positions + attn_meta.generate_page_table = g_page_table + + elif len(p_tokens.shape) > 0: + input_ids = p_tokens + positions = p_positions + + attn_meta.prefill_length = prefill_length + attn_meta.prefill_pos = p_positions + attn_meta.prefill_page_table = p_page_indices + + elif len(g_tokens.shape) > 0: + input_ids = g_tokens + positions = g_positions + + attn_meta.generate_pos = g_positions + attn_meta.generate_page_table = g_page_table + + else: + raise ValueError( + "Failed to build the input as no prefill or generate gets scheduled" + ) + + return input_ids, g_tokens, positions, attn_meta diff --git a/experimental/jax/inference/runtime/offline_inference.py b/experimental/jax/inference/runtime/offline_inference.py new file mode 100644 index 00000000..fd0fc322 --- /dev/null +++ b/experimental/jax/inference/runtime/offline_inference.py @@ -0,0 +1,152 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +"""offline driver module""" + + +import datetime +import multiprocessing as mp +import threading +import queue +import math +from typing import Sequence +import jax +from inference import parallel +from inference.runtime.engine import Engine, ModelLoadParams, InferenceParams, EngineMode, OfflineChannel +from inference.runtime.request_type import * + + +class OfflineInference: + + def __init__( + self, + model_id: str = "meta-llama/Llama-2-7b-chat-hf", + num_engines: int = 1, + enable_multiprocessing: bool = False, + ): + self.num_engines = num_engines + self.req_queues: list[mp.Queue | queue.Queue] = [] + self.res_queues: list[mp.Queue | queue.Queue] = [] + self._next_pick_engine_index = 0 + self._running_pool: list[mp.Process | threading.Thread] = [] + self._engine_started_events: list = [] + self._completion_events: list = [] + for i in range(num_engines): + if enable_multiprocessing: + self.req_queue.append(mp.Queue()) + self.res_queue.append(mp.Queue()) + self._engine_started_events.append(mp.Event()) + self._completion_events.append(mp.Event()) + execution = mp.Process( + target=self.launch_engine, + args=( + self.req_queue[i], + self.res_queue[i], + ModelLoadParams(model_id=model_id), + self._engine_started_events[i], + self._completion_events[i], + ), + ) + else: + self.req_queues.append(queue.Queue()) + self.res_queues.append(queue.Queue()) + self._engine_started_events.append(threading.Event()) + self._completion_events.append(threading.Event()) + + execution = threading.Thread( + target=self.launch_engine, + args=( + self.req_queues[i], + self.res_queues[i], + ModelLoadParams(model_id=model_id), + self._engine_started_events[i], + self._completion_events[i], + ), + ) + self._running_pool.append(execution) + + for e in self._running_pool: + e.start() + + for started_event in self._engine_started_events: + while not started_event.is_set(): + started_event.wait() + + def launch_engine( + self, + req_queue, + res_queue, + model_load_params, + started_event, + completion_event, + ): + devices = jax.devices() + mesh = parallel.create_device_mesh( + devices, + (len(devices),), + ) + engine = Engine( + mesh=mesh, + model_load_params=model_load_params, + inference_params=InferenceParams(), + mode=EngineMode.OFFLINE, + channel=OfflineChannel( + req_queue=req_queue, + res_queue=res_queue, + ), + ) + engine.start() + started_event.set() + while not completion_event.is_set(): + completion_event.wait() + engine.stop() + return + + def __call__(self, prompts: Sequence[str]) -> tuple[list[Response], float]: + for i in range(self.num_engines): + while not self._engine_started_events[i].is_set(): + self._engine_started_events[i].wait() + + print( + f"All the engines started: {datetime.datetime.now()}, processing requests..." 
+ ) + + num_reqs_per_engine = math.ceil(len(prompts) / self.num_engines) + res = [] + + for i in range(self.num_engines): + idx = i * num_reqs_per_engine + prompts_slice = prompts[idx : idx + num_reqs_per_engine] + for p in prompts_slice: + self.req_queues[i].put(OfflineRequest(prompt=p)) + + for i in range(self.num_engines): + if i != self.num_engines - 1: + num_reqs = num_reqs_per_engine + else: + num_reqs = len(prompts) - (i * num_reqs_per_engine) + for _ in range(num_reqs): + res.append(self.res_queues[i].get()) + + print("Offline inference ends:", datetime.datetime.now()) + + for i in range(self.num_engines): + self._completion_events[i].set() + + for e in self._running_pool: + e.join() + + return res diff --git a/experimental/jax/inference/runtime/request_type.py b/experimental/jax/inference/runtime/request_type.py new file mode 100644 index 00000000..a0d075c4 --- /dev/null +++ b/experimental/jax/inference/runtime/request_type.py @@ -0,0 +1,108 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +"""Request and response classes""" + +# TODO: Clean up and merge the request to improve the readability. + +import asyncio +from dataclasses import dataclass, field +import numpy as np +import jax +import threading +import queue + + +@dataclass +class Response: + generated_text: str + generated_tokens: list[int] | int + input_tokens: list[int] | None = None + + +@dataclass +class OnlineRequest: + prompt: str + res_queue: asyncio.Queue[Response] + + +@dataclass +class OfflineRequest: + prompt: str + + +@dataclass +class Request: + """Request for holding the input and output information""" + + id: str + prompt: str + prompt_token_ids: list[int] = field(default_factory=lambda: []) + generated_text: str = "" + generated_token_ids: list[int] = field(default_factory=lambda: []) + aio_response_queue: asyncio.Queue[Response] | None = None + completed: bool = False + + +@dataclass +class PrefillRequest: + """class for new request need to be processed in the prefill phase""" + + id: str + unpadded_token_ids: list[int] + page_indices: list[int] + chunk_idx: int + chunk_size: int + + device_token_ids: jax.Array + device_positions: jax.Array + + +@dataclass +class GenerateRequest: + """class for new request need to be processed in the generate phase""" + + id: str + slot: int + pos: int + page_indices: list[int] + device_prefill_token_id: jax.Array + + +@dataclass +class GenerateState: + """generate phase state""" + + token_ids: jax.Array # num_max_seq + positions: jax.Array # num_max_seq + page_table: jax.Array # num_max_seq, num_pages_per_seq + available_slots: queue.SimpleQueue + active_slot_req_map: dict[int, GenerateRequest] + map_mutex: threading.Lock = threading.Lock() + + +@dataclass +class PostProcessRequest: + """Post process request""" + + prefill_request_id: str | None + prefill_token_id: jax.Array | np.ndarray + prefill_done: jax.Array | np.ndarray + + generate_active_slots: list[int] + generate_active_request_ids: list[str] + generate_token_ids: jax.Array | np.ndarray + 
generate_done: jax.Array | np.ndarray diff --git a/experimental/jax/inference/server/__init__.py b/experimental/jax/inference/server/__init__.py new file mode 100644 index 00000000..e7c0b714 --- /dev/null +++ b/experimental/jax/inference/server/__init__.py @@ -0,0 +1,15 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" diff --git a/experimental/jax/inference/server/simple_server.py b/experimental/jax/inference/server/simple_server.py new file mode 100644 index 00000000..0538ce03 --- /dev/null +++ b/experimental/jax/inference/server/simple_server.py @@ -0,0 +1,87 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +"""Server module""" + +from contextlib import asynccontextmanager +import queue +import uvicorn +from inference import parallel +from pydantic import BaseModel +from fastapi import FastAPI +from fastapi.responses import StreamingResponse +from inference.runtime.engine import * +from inference.runtime.request_type import * + + +@asynccontextmanager +async def lifespan(app: FastAPI): + devices = jax.devices() + mesh = parallel.create_device_mesh( + devices, + (len(devices),), + ) + loop = asyncio.get_running_loop() + app.state.req_queue = queue.Queue() + print("starting engine") + engine = Engine( + mesh=mesh, + model_load_params=ModelLoadParams( + model_id="meta-llama/Llama-2-7b-chat-hf" + ), + inference_params=InferenceParams(), + mode=EngineMode.ONLINE, + channel=OnlineChannel( + req_queue=app.state.req_queue, + aio_loop=loop, + ), + ) + engine.start() + yield + # Clean up the ML models and release the resources + engine.stop() + + +app = FastAPI(lifespan=lifespan) + + +async def streaming_tokens(res_queue): + while True: + res: Response | None = await res_queue.get() + if not res: + return + yield res.generated_text + + +class GenerateRequest(BaseModel): + prompt: str + + +@app.post("/generate") +async def generate(req: GenerateRequest): + res_queue = asyncio.Queue() + app.state.req_queue.put_nowait( + OnlineRequest( + prompt=req.prompt, + res_queue=res_queue, + ) + ) + return StreamingResponse(streaming_tokens(res_queue)) + + +if __name__ == "__main__": + print("start") + uvicorn.run("simple_server:app", host="0.0.0.0", port=8000) diff --git a/experimental/jax/inference/utils/__init__.py b/experimental/jax/inference/utils/__init__.py new file mode 100644 index 00000000..020512ad --- /dev/null +++ b/experimental/jax/inference/utils/__init__.py @@ -0,0 +1,17 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the 
"License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from .pytree_utils import * diff --git a/experimental/jax/inference/utils/pytree_utils.py b/experimental/jax/inference/utils/pytree_utils.py new file mode 100644 index 00000000..64157056 --- /dev/null +++ b/experimental/jax/inference/utils/pytree_utils.py @@ -0,0 +1,33 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from jax.tree_util import register_pytree_node + + +def register_flat_dataclass_as_pytree(cls): + def flatten(v): + fields = cls.__dataclass_fields__ + children = () + aux_data = None + for i in fields: + children += (getattr(v, i),) + return (children, aux_data) + + def unflatten(aux_data, children): + return cls(*children) + + register_pytree_node(cls, flatten, unflatten) + return cls diff --git a/experimental/jax/requirements.txt b/experimental/jax/requirements.txt new file mode 100644 index 00000000..72ba566c --- /dev/null +++ b/experimental/jax/requirements.txt @@ -0,0 +1,12 @@ +--find-links https://download.pytorch.org/whl/torch_stable.html +--find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html + +absl-py +torch==2.3.0+cpu +torchvision==0.18.0+cpu +jax[tpu]==0.4.33 +huggingface_hub[cli] +transformers +pandas +fastapi +uvicorn diff --git a/experimental/jax/tests/kernel/attention/tpu/test_chunked_prefill_attention.py b/experimental/jax/tests/kernel/attention/tpu/test_chunked_prefill_attention.py new file mode 100644 index 00000000..8b53576a --- /dev/null +++ b/experimental/jax/tests/kernel/attention/tpu/test_chunked_prefill_attention.py @@ -0,0 +1,126 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +from absl.testing import absltest +import numpy as np +import math +import jax +from jax import numpy as jnp +from inference import nn +from inference.kernel.attention_ops import vanilla_prefill_gqa +from inference.kernel.attention.tpu.chunked_prefill_attention import chunked_prefill_attention + + +class ChunkedPrefillTest(absltest.TestCase): + + def test(self): + num_attn_heads = 4 + num_kv_heads = 2 + head_dim = 128 + total_page_num = 16 + page_size = 16 + + rng = jax.random.key(0) + + attn_layer = nn.AttentionOps(num_attn_heads, num_kv_heads, head_dim) + k_hbm, v_hbm = ( + jnp.zeros((num_kv_heads, total_page_num, page_size, head_dim)), + jnp.zeros((num_kv_heads, total_page_num, page_size, head_dim)), + ) + + kv_cache = nn.KVCache(k=k_hbm, v=v_hbm) + + prefill_len = 6 * page_size + prefill_non_padding_len = 4 * page_size + 3 + + q = jax.random.uniform(rng, (prefill_len, num_attn_heads, head_dim)) + + rng_1, rng_2 = jax.random.split(rng) + k_to_save, v_to_save = ( + jax.random.uniform(rng_1, (prefill_len, num_kv_heads, head_dim)), + jax.random.uniform(rng_2, (prefill_len, num_kv_heads, head_dim)), + ) + + num_pages_with_padding = math.ceil(prefill_len / page_size) + page_table = jnp.array(([i for i in range(num_pages_with_padding)])) + kv_cache = attn_layer._write_prefill_kv_to_kv_cache( + k_to_save, + v_to_save, + kv_cache, + prefill_non_padding_len, + page_table, + ) + + chunk_size = 2 * page_size + + expected_output = vanilla_prefill_gqa(q, k_to_save, v_to_save) + + num_active_pages_per_prefill = math.ceil( + prefill_non_padding_len / page_size + ) + compute_times = math.ceil(prefill_non_padding_len / chunk_size) + + for i in range(compute_times): + idx = i * chunk_size + length = idx + chunk_size + length = min(length, prefill_non_padding_len) + chunk_output = chunked_prefill_attention( + q[idx : idx + chunk_size], kv_cache.k, kv_cache.v, length, page_table + ) + + if i < compute_times - 1: + np.testing.assert_allclose( + np.array(chunk_output), + np.array(expected_output[idx : idx + chunk_size]), + rtol=4e-03, + atol=1e-03, + ) + if i == compute_times - 1: + offset = prefill_non_padding_len % chunk_size + np.testing.assert_allclose( + np.array(chunk_output[:offset]), + np.array(expected_output[idx : idx + offset]), + rtol=4e-03, + atol=1e-03, + ) + + num_pages_per_prefill = prefill_len // page_size + for i in range(num_active_pages_per_prefill): + idx = i * page_size + np.testing.assert_equal( + np.array(kv_cache.k[:, page_table[i], :, :][:, :, :]), + np.array(k_to_save[idx : idx + page_size, :, :].transpose(1, 0, 2)), + ) + np.testing.assert_equal( + np.array(kv_cache.v[:, page_table[i], :, :][:, :, :]), + np.array(v_to_save[idx : idx + page_size, :, :].transpose(1, 0, 2)), + ) + + for i in range(num_active_pages_per_prefill, num_pages_per_prefill): + idx = i * page_size + np.testing.assert_equal( + np.array(kv_cache.v[:, page_table[i], :, :][:, None, :, :]), + np.array(jnp.zeros((num_kv_heads, 1, page_size, head_dim))), + ) + + np.testing.assert_equal( + np.array(kv_cache.v[:, page_table[i], :, :][:, None, :, :]), + np.array(jnp.zeros((num_kv_heads, 1, page_size, head_dim))), + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/experimental/jax/tests/kernel/linear/tpu/test_collective_matmul_impl.py b/experimental/jax/tests/kernel/linear/tpu/test_collective_matmul_impl.py new file mode 100644 index 00000000..f59581ab --- /dev/null +++ b/experimental/jax/tests/kernel/linear/tpu/test_collective_matmul_impl.py @@ -0,0 +1,87 @@ +""" +Copyright 2024 Google LLC + 
+Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from absl.testing import absltest +from functools import partial +import numpy as np +import jax.experimental +import jax.experimental.mesh_utils +import jax +from jax import numpy as jnp +from jax.experimental import shard_map +from jax.sharding import NamedSharding, PartitionSpec as P +from inference import kernel +from inference import parallel + + +class CollectiveMatmulTest(absltest.TestCase): + + def test_all_gather_collective_matmul(self): + key1, key2 = jax.random.PRNGKey(0), jax.random.PRNGKey(1) + lhs = jax.random.normal(key1, shape=(1, 32), dtype=jnp.float32) + rhs = jax.random.normal(key2, shape=(32, 16), dtype=jnp.float32) + expect = lhs @ rhs + + mesh = parallel.create_device_mesh(jax.devices(), (2, 4)) + axis_names = mesh.axis_names + rhs = jax.device_put(rhs, NamedSharding(mesh, P(None, axis_names))) + rhs = kernel.prepare_rhs_for_all_gather_collective_matmul(rhs, mesh) + + def agcm(lhs, rhs, type, axis_names): + return kernel.build_collective_matmul(type, axis_names)(lhs, rhs) + + got = shard_map.shard_map( + f=partial( + agcm, + type=parallel.CollectiveMatmulType.ALL_GATHER, + axis_names=axis_names, + ), + mesh=mesh, + in_specs=(P(None, axis_names), P(None, axis_names)), + out_specs=P(None, axis_names), + )(lhs, rhs) + np.testing.assert_allclose(got, expect, rtol=1e-6) + + def test_collective_matmul_reduce_scatter(self): + key1, key2 = jax.random.PRNGKey(0), jax.random.PRNGKey(1) + lhs = jax.random.uniform(key1, shape=(8, 64), dtype=jnp.float32) + rhs = jax.random.uniform(key2, shape=(64, 64), dtype=jnp.float32) + expect = lhs @ rhs + + mesh = parallel.create_device_mesh(jax.devices(), (2, 4)) + axis_names = mesh.axis_names + rhs = jax.device_put(rhs, NamedSharding(mesh, P(axis_names, None))) + + rhs = kernel.prepare_rhs_for_collective_matmul_reduce_scatter(rhs, mesh) + + def cmrc(lhs, rhs, type, axis_names): + return kernel.build_collective_matmul(type, axis_names)(lhs, rhs) + + got = shard_map.shard_map( + f=partial( + cmrc, + type=parallel.CollectiveMatmulType.REDUCE_SCATTER, + axis_names=axis_names, + ), + mesh=mesh, + in_specs=(P(None, axis_names), P(axis_names, None)), + out_specs=P(None, axis_names), + )(lhs, rhs) + np.testing.assert_allclose(got, expect, rtol=1e-6) + + +if __name__ == "__main__": + absltest.main() diff --git a/experimental/jax/tests/model/test_llama.py b/experimental/jax/tests/model/test_llama.py new file mode 100644 index 00000000..abc763f9 --- /dev/null +++ b/experimental/jax/tests/model/test_llama.py @@ -0,0 +1,161 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +""" + +from absl.testing import absltest +import numpy as np +import jax +from jax.experimental.shard_map import shard_map +from jax import numpy as jnp +from jax.sharding import NamedSharding, PartitionSpec as P +import torch +from transformers import AutoModelForCausalLM +from inference.model import LlamaModel, ModelRegistry +from inference import parallel +from inference import nn + + +class LlamaModelTest(absltest.TestCase): + + @classmethod + def setUpClass(cls): + super().setUpClass() + + def test_llama(self): + # TODO: make it as an accuracy test. + model_id = "meta-llama/Llama-2-7b-chat-hf" + mesh = parallel.create_device_mesh(jax.devices(), len(jax.devices())) + model_registry = ModelRegistry() + + config, tokenizer = model_registry.load_model_config( + model_id + ), model_registry.load_tokenizer(model_id) + config.num_hidden_layers = 1 + num_prefill_tokens = 16 + input_ids = tokenizer.encode("I have a dog that is", return_tensors="pt") + prompt_len = input_ids.shape[1] + hg_model = AutoModelForCausalLM.from_pretrained( + model_id, torch_dtype=torch.float32, config=config + ) + + outputs = hg_model(input_ids) + expected_logits = outputs.logits.detach().numpy()[0] + + tokens = jnp.asarray(input_ids)[0] + tokens = jax.lax.dynamic_update_index_in_dim( + jnp.zeros((num_prefill_tokens), dtype=jnp.int32), tokens, 0, 0 + ) + pos = jnp.arange(0, num_prefill_tokens) + kv_caches = [ + nn.KVCache( + k=jnp.zeros((config.num_key_value_heads, 32, 16, config.head_dim)), + v=jnp.zeros((config.num_key_value_heads, 32, 16, config.head_dim)), + ) + for _ in range(config.num_hidden_layers) + ] + kv_caches_sharding = [ + nn.KVCache( + k=NamedSharding( + mesh, P(parallel.tp_axis_names(), None, None, None) + ), + v=NamedSharding( + mesh, P(parallel.tp_axis_names(), None, None, None) + ), + ) + for _ in range(config.num_hidden_layers) + ] + kv_caches = jax.device_put(kv_caches, kv_caches_sharding) + attn_metadata = nn.AttentionMetadata( + prefill_length=prompt_len, + prefill_pos=pos, + prefill_page_table=jnp.asarray([0, 1, 2, 3]), + generate_pos=jnp.asarray(0), + generate_page_table=jnp.asarray(0), + ) + + attention_metadata_sharding = nn.AttentionMetadata( + prefill_length=NamedSharding(mesh, P()), + prefill_pos=NamedSharding(mesh, P(None)), + prefill_page_table=NamedSharding(mesh, P(None)), + generate_pos=NamedSharding(mesh, P()), + generate_page_table=NamedSharding(mesh, P()), + ) + attn_metadata = jax.device_put(attn_metadata, attention_metadata_sharding) + + casual_lm_weight_cpu = model_registry.load_weights_to_host( + model_id, + num_devices=np.prod(mesh.devices.shape), + dtype=jnp.float32, + model_config=config, + ) + model = LlamaModel(config, parallel.ModelParallelConfig(mesh=mesh)) + + weight_dict = model.load_weights_dict(casual_lm_weight_cpu["model"]) + weight_dict_pspec = jax.tree_util.tree_map( + lambda a: a.sharding.spec, weight_dict + ) + kv_caches_pspec = jax.tree_util.tree_map( + lambda a: a.sharding.spec, kv_caches + ) + attn_meta_pspec = jax.tree_util.tree_map( + lambda a: a.spec, attention_metadata_sharding + ) + + del casual_lm_weight_cpu + + infer_func = shard_map( + f=model.jittable_call, + mesh=mesh, + in_specs=( + weight_dict_pspec, + P(None), + P(None), + kv_caches_pspec, + attn_meta_pspec, + ), + out_specs=( + P(None, None), + kv_caches_pspec, + ), + check_rep=False, + ) + + executable = ( + jax.jit(infer_func, donate_argnums=(3,)) + .lower( + weight_dict, + 
jnp.asarray(tokens), + pos, + kv_caches, + attn_metadata, + ) + .compile() + ) + got_logits, _ = executable( + weight_dict, + jnp.asarray(tokens), + pos, + kv_caches, + attn_metadata, + ) + + got_logits = got_logits[:prompt_len] + np.testing.assert_allclose( + got_logits, expected_logits, atol=3e-02, rtol=1e-02 + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/experimental/jax/tests/nn/test_attention.py b/experimental/jax/tests/nn/test_attention.py new file mode 100644 index 00000000..a5311e3e --- /dev/null +++ b/experimental/jax/tests/nn/test_attention.py @@ -0,0 +1,230 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +# from absl.testing import absltest +# import numpy as np +# import jax.experimental +# import jax.experimental.mesh_utils +# import jax +# from jax import numpy as jnp +# from jax.sharding import NamedSharding, PartitionSpec as P +# from inference import nn +# from inference import parallel + +# """TODO: Enable the attention test""" + + +# class AttentionTest(absltest.TestCase): +# def test_kv_cache_sharding(self): +# axis = parallel.tp_axis_names() +# k_hbm = jnp.zeros((8, 8, 8, 128)) +# v_hbm = jnp.copy(k_hbm) +# num_layer = 3 +# mesh = parallel.create_device_mesh(jax.devices(), shape=len(jax.devices())) +# kv_cache = [nn.KVCache( +# k=k_hbm, +# v=v_hbm, +# ) for _ in range(num_layer)] +# kv_cache_sharding = [nn.KVCache( +# k=NamedSharding(mesh, P(axis, None, None, None)), +# v=NamedSharding(mesh, P(axis, None, None, None)), +# ) for _ in range(num_layer)] + +# def sharding(a, b): +# res = jax.device_put(a, b) +# return res + +# kv_cache = jax.tree.map(sharding, kv_cache, kv_cache_sharding) +# for i in range(num_layer): +# self.assertIsInstance(kv_cache[i].k.sharding, NamedSharding) +# self.assertIsInstance(kv_cache[i].v.sharding, NamedSharding) + +# def test_prefill(self): +# """Only available on TPU.""" +# num_attn_heads = 16 +# num_kv_heads = 8 +# head_dim = 128 +# total_page_num = 16 +# page_size = 8 +# attn_layer = nn.AttentionOps(num_attn_heads, num_kv_heads, head_dim) +# k_hbm, v_hbm = ( +# jnp.zeros((num_kv_heads, total_page_num, page_size, head_dim)), +# jnp.zeros((num_kv_heads, total_page_num, page_size, head_dim)) +# ) + +# kv_cache = nn.KVCache(k=k_hbm ,v=v_hbm) + +# prefill_len = 16 +# prefill_non_padding_len = 12 +# q = jnp.ones(((prefill_len, num_attn_heads, head_dim))) +# k_to_save, v_to_save = ( +# jnp.ones((prefill_len, num_kv_heads, head_dim)), +# jnp.ones((prefill_len, num_kv_heads, head_dim)) +# ) +# num_page_to_use = prefill_len // page_size +# # using the second and third page to save the kv cache. 
+# page_table = jnp.array([1, 3, 0, 0, 0, 0]) +# output, kv_cache = attn_layer._prefill( +# q, +# k_to_save, +# v_to_save, +# kv_cache, +# nn.AttentionMetadata( +# prefill_length=prefill_non_padding_len, +# prefill_pos=jnp.arange(0, 8), +# prefill_page_table=page_table, +# generate_pos=0, +# generate_page_table=0 +# ), +# ) +# kv_cache = attn_layer._write_prefill_kv_to_kv_cache( +# k_to_save, +# v_to_save, +# kv_cache, +# prefill_non_padding_len, +# page_table +# ) +# np.testing.assert_allclose( +# kv_cache.k[:, page_table[:2], :, :], +# jnp.ones((num_kv_heads, num_page_to_use , page_size, head_dim)) +# ) +# np.testing.assert_allclose( +# kv_cache.v[:, page_table[:2], :, :], +# jnp.ones((num_kv_heads, num_page_to_use, page_size, head_dim)) +# ) +# zero_index = [i for i in range(page_size)] +# zero_index = ( +# zero_index[0:page_table[0]] + +# zero_index[page_table[0]+1:page_table[1]] + +# zero_index[page_table[1]+1:] +# ) + +# np.testing.assert_allclose( +# kv_cache.k[:, zero_index, :, :], +# jnp.zeros((num_kv_heads, page_size-num_page_to_use , page_size, head_dim)) +# ) +# np.testing.assert_allclose( +# kv_cache.v[:, zero_index ,:,:], +# jnp.zeros((num_kv_heads, page_size-num_page_to_use , page_size, head_dim)) +# ) +# np.testing.assert_equal(output.shape, (prefill_len, num_attn_heads * head_dim)) + +# def test_generate_cache_update(self): +# num_attn_heads = 16 +# num_kv_heads = 8 +# head_dim = 128 +# total_page_num = 128 +# page_size = 8 +# max_len = 32 +# attn_layer = nn.AttentionOps(num_attn_heads, num_kv_heads, head_dim) +# k_hbm, v_hbm = ( +# jnp.zeros((num_kv_heads, total_page_num, page_size, head_dim)), +# jnp.zeros((num_kv_heads, total_page_num, page_size, head_dim)), +# ) + +# kv_cache = nn.KVCache(k=k_hbm ,v=v_hbm) + +# generate_len = 4 +# k_to_save, v_to_save = ( +# jnp.ones((generate_len, num_kv_heads, head_dim)), +# jnp.ones((generate_len, num_kv_heads, head_dim)) +# ) +# prng = jax.random.PRNGKey(99) +# page_table = jnp.asarray(np.random.choice(total_page_num, (generate_len, max_len//page_size), replace=False)) +# page_pos = jax.random.randint(prng, shape=(generate_len,), +# minval=0, maxval=total_page_num * page_size) +# kv_cache = attn_layer._write_generate_kv_to_kv_cache( +# k_to_save, +# v_to_save, +# kv_cache, +# page_pos, +# page_table +# ) + +# page_idx, offset = jnp.divmod(page_pos, page_size) +# page_to_update = page_table[jnp.arange(0, generate_len), page_idx] + +# np.testing.assert_allclose( +# kv_cache.k[:, page_to_update, offset,:], +# jnp.ones_like(kv_cache.k[:, page_to_update , offset,:]), +# ) + +# np.testing.assert_allclose( +# kv_cache.v[:, page_to_update, offset,:], +# jnp.ones_like(kv_cache.v[:, page_to_update , offset,:]), +# ) + +# np.testing.assert_allclose(jnp.sum(kv_cache.k), generate_len * num_kv_heads * head_dim) + +# def test_generate(self): +# """Only available on TPU.""" +# num_attn_heads = 16 +# num_kv_heads = 8 +# head_dim = 128 +# total_page_num = 128 +# page_size = 8 +# max_len = 64 +# attn_layer = nn.AttentionOps(num_attn_heads, num_kv_heads, head_dim) +# k_hbm, v_hbm = ( +# jnp.zeros((num_kv_heads, total_page_num, page_size, head_dim)), +# jnp.zeros((num_kv_heads, total_page_num, page_size, head_dim)), +# ) + +# kv_cache = nn.KVCache(k=k_hbm ,v=v_hbm) + +# num_generate_tokens = 4 +# q = jnp.ones((num_generate_tokens, num_attn_heads, head_dim)) +# k_to_save, v_to_save = ( +# jnp.ones((num_generate_tokens, num_kv_heads, head_dim)), +# jnp.ones((num_generate_tokens, num_kv_heads, head_dim)), +# ) +# prng = jax.random.PRNGKey(99) +# 
page_table = jnp.asarray(np.random.choice(total_page_num, (num_generate_tokens, max_len//page_size), replace=False)) +# page_pos = jax.random.randint(prng, shape=(num_generate_tokens,), +# minval=0, maxval=total_page_num * page_size) + +# output, kv_cache = attn_layer._generate( +# q, +# k_to_save, +# v_to_save, +# kv_cache, +# nn.AttentionMetadata( +# prefill_length=0, +# prefill_pos=0, +# prefill_page_table=0, +# generate_pos=page_pos, +# generate_page_table=page_table, +# ), +# ) + +# page_idx, offset = jnp.divmod(page_pos, page_size) +# page_to_update = page_table[jnp.arange(0, num_generate_tokens), page_idx] + +# np.testing.assert_allclose( +# kv_cache.k[:, page_to_update, offset,:], +# jnp.ones_like(kv_cache.k[:, page_to_update , offset,:]), +# ) + +# np.testing.assert_allclose( +# kv_cache.v[:, page_to_update, offset,:], +# jnp.ones_like(kv_cache.v[:, page_to_update , offset,:]), +# ) + +# np.testing.assert_allclose(jnp.sum(kv_cache.k), num_generate_tokens * num_kv_heads * head_dim) +# np.testing.assert_equal(output.shape, (num_generate_tokens, num_attn_heads * head_dim)) + +# if __name__ == "__main__": +# absltest.main() diff --git a/experimental/jax/tests/nn/test_embedding.py b/experimental/jax/tests/nn/test_embedding.py new file mode 100644 index 00000000..e6665b83 --- /dev/null +++ b/experimental/jax/tests/nn/test_embedding.py @@ -0,0 +1,54 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from absl.testing import absltest +import os +import numpy as np +import jax.experimental.mesh_utils +import jax +from inference import parallel +from inference import nn + + +class EmbeddingTest(absltest.TestCase): + + def test_embedding(self): + mesh = parallel.create_device_mesh( + devices=jax.devices(), + shape=len(jax.devices()), + ) + vocal_size = 2048 + emb_dim = 8192 + embedding_layer = nn.Embedding( + vocal_size, + emb_dim, + parallel_config=parallel.EmbeddingParallelConfig( + mesh=mesh, + parallel_type=parallel.EmbeddingParallelType.COLUMN, + ), + ) + key = jax.random.key(0) + emb_table = jax.random.uniform(key, (vocal_size, emb_dim)) + input = jax.random.randint(key, (96,), 0, 2048) + expect = emb_table[input] + + embedding_layer.load_weights_dict({"weight": emb_table}) + got = embedding_layer(input) + np.testing.assert_allclose(got, expect) + + +if __name__ == "__main__": + absltest.main() diff --git a/experimental/jax/tests/nn/test_module.py b/experimental/jax/tests/nn/test_module.py new file mode 100644 index 00000000..0c6f7bd1 --- /dev/null +++ b/experimental/jax/tests/nn/test_module.py @@ -0,0 +1,126 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +""" + +from absl.testing import absltest +import numpy as np +from jax import numpy as jnp +from inference.nn import Module, Parameter + + +class ModuleTest(absltest.TestCase): + + def test_random_code_initialize(self): + w0, w1, w2, w3 = ( + jnp.ones((1,)), + jnp.ones((2,)), + jnp.ones((3,)), + jnp.ones((4,)), + ) + parent_module = Module() + parent_module.w0 = Parameter(w0) + + h1_child_0_module = Module() + h1_child_0_module.w1 = Parameter(w1) + + h1_child_1_module = Module() + h1_child_1_module.w2 = Parameter(w2) + + h2_child_0_module = Module() + h2_child_0_module.w3 = Parameter(w3) + + parent_module.child0 = h1_child_0_module + parent_module.child1 = h1_child_1_module + h1_child_0_module.child0 = h2_child_0_module + + parent_module.init_weights() + + np.testing.assert_raises( + AssertionError, + np.testing.assert_array_equal, + w0, + parent_module.w0.value, + ) + np.testing.assert_raises( + AssertionError, + np.testing.assert_array_equal, + w1, + h1_child_0_module.w1.value, + ) + np.testing.assert_raises( + AssertionError, + np.testing.assert_array_equal, + w2, + h1_child_1_module.w2.value, + ) + np.testing.assert_raises( + AssertionError, + np.testing.assert_array_equal, + w3, + h2_child_0_module.w3.value, + ) + + def test_load_weights_dict(self): + w0, w1, w2, w3 = ( + jnp.ones((1,)), + jnp.ones((2,)), + jnp.ones((3,)), + jnp.ones((4,)), + ) + parent_module = Module() + parent_module.w0 = Parameter(w0) + + h1_child_0_module = Module() + h1_child_0_module.w1 = Parameter(w1) + + h1_child_1_module = Module() + h1_child_1_module.w2 = Parameter(w2) + + h2_child_0_module = Module() + h2_child_0_module.w3 = Parameter(w3) + + parent_module.child0 = h1_child_0_module + parent_module.child1 = h1_child_1_module + h1_child_0_module.child0 = h2_child_0_module + print(parent_module) + + partial_parent_weight_dict = { + "w0": jnp.zeros((1,)), + "child0": { + "w1": jnp.zeros((2,)), + "child0": { + "w3": jnp.zeros((4,)), + }, + }, + } + + child1_weight_dict = { + "w2": jnp.zeros((2,)), + "wrong_weight_not_load": jnp.zeros((2,)), + } + + parent_module.load_weights_dict(partial_parent_weight_dict) + h1_child_1_module.load_weights_dict(child1_weight_dict) + + np.testing.assert_array_equal(parent_module.w0, 0) + np.testing.assert_array_equal(h1_child_0_module.w1, 0) + np.testing.assert_array_equal(h1_child_1_module.w2, 0) + np.testing.assert_array_equal(h2_child_0_module.w3, 0) + + assert not h1_child_1_module.wrong_weight_not_load + + +if __name__ == "__main__": + absltest.main() diff --git a/experimental/jax/tests/nn/test_norm.py b/experimental/jax/tests/nn/test_norm.py new file mode 100644 index 00000000..2cd0946a --- /dev/null +++ b/experimental/jax/tests/nn/test_norm.py @@ -0,0 +1,90 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +from absl.testing import absltest +import os +import jax.experimental.shard_map +import numpy as np +import jax.experimental +import jax.experimental.mesh_utils +import jax +from jax import numpy as jnp +from jax.sharding import NamedSharding, PartitionSpec as P +from inference import parallel +from inference import nn + + +class NormTest(absltest.TestCase): + + def test_rmsnorm_per_device_forward(self): + mesh = parallel.create_device_mesh( + devices=jax.devices(), + shape=len(jax.devices()), + ) + hidden_state_size = 128 + eps = 1e-6 + rmsnorm_layer = nn.RMSNorm( + dim=hidden_state_size, + eps=eps, + parallel_config=parallel.RMSNormParallelConfig( + mesh=mesh, + activation_shared=False, + ), + ) + distributed_rmsnorm_layer = nn.RMSNorm( + dim=hidden_state_size, + eps=eps, + parallel_config=parallel.RMSNormParallelConfig( + mesh=mesh, + activation_shared=True, + ), + ) + + key = jax.random.PRNGKey(0) + input = jax.random.uniform(key, (96, hidden_state_size)) + weight = jax.random.uniform(key, (hidden_state_size,)) + sharded_weight = jnp.copy(weight) + + rmsnorm_layer.load_weights_dict({"weight": weight}) + dis_weight = distributed_rmsnorm_layer.load_weights_dict( + {"weight": sharded_weight} + ) + expect = rmsnorm_layer(input) + + sharded_input = jax.device_put( + input, NamedSharding(mesh, P(None, parallel.tp_axis_names())) + ) + + def distributed_rms_ag(weight, input): + output = distributed_rmsnorm_layer.jittable_call(weight, input) + return parallel.ops.all_gather(output, 1, parallel.tp_axis_names()) + + got = jax.experimental.shard_map.shard_map( + distributed_rms_ag, + mesh, + in_specs=( + P(parallel.tp_axis_names()), + P(None, parallel.tp_axis_names()), + ), + out_specs=P(None, None), + check_rep=False, + )(dis_weight, sharded_input) + + np.testing.assert_allclose(got, expect, atol=1e-6, rtol=1e-7) + + +if __name__ == "__main__": + absltest.main() diff --git a/experimental/jax/tests/parallel/test_operations.py b/experimental/jax/tests/parallel/test_operations.py new file mode 100644 index 00000000..2ca10ccb --- /dev/null +++ b/experimental/jax/tests/parallel/test_operations.py @@ -0,0 +1,70 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +from absl.testing import absltest +from functools import partial +import os +import numpy as np +import jax.experimental +import jax.experimental.mesh_utils +import jax +from jax import numpy as jnp +from jax.experimental import shard_map +from jax.sharding import PartitionSpec as P +from inference import parallel + + +class CollectiveOperationsTest(absltest.TestCase): + + def _build_mesh(self): + axis = "x" + device_mesh = jax.experimental.mesh_utils.create_device_mesh( + len(jax.devices()), jax.devices() + ) + mesh = jax.sharding.Mesh(device_mesh, ("x")) + return mesh, axis + + def test_reduce_scatter(self): + key = jax.random.key(99) + operand = jax.random.uniform(key, shape=(16 * 32, 1024), dtype=jnp.float32) + mesh, axis = self._build_mesh() + + expect = shard_map.shard_map( + f=partial( + jax.lax.psum_scatter, + axis_name=axis, + scatter_dimension=1, + tiled=True, + ), + mesh=mesh, + in_specs=P(axis, None), + out_specs=P(None, axis), + )(operand) + + got = shard_map.shard_map( + f=partial( + parallel.ops.reduce_scatter, axis_names=axis, scatter_dimension=1 + ), + mesh=mesh, + in_specs=P(axis, None), + out_specs=P(None, axis), + )(operand) + + np.testing.assert_allclose(got, expect, rtol=1e-6) + + +if __name__ == "__main__": + absltest.main() diff --git a/google3/third_party/py/jetstream/core/metrics/prometheus.py b/google3/third_party/py/jetstream/core/metrics/prometheus.py deleted file mode 100644 index dc8a00e9..00000000 --- a/google3/third_party/py/jetstream/core/metrics/prometheus.py +++ /dev/null @@ -1,257 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Contains common functions for configuring Jetstream server metrics""" - -import os -import shortuuid -from prometheus_client import Counter, Gauge, Histogram -from jetstream.engine.token_utils import DEFAULT_PREFILL_BUCKETS - - -class JetstreamMetricsCollector: - """Wrapper class should be used to assure all metrics have proper tags""" - - _id: str = os.getenv("HOSTNAME", shortuuid.uuid()) - - def __new__(cls): - if not hasattr(cls, "instance"): - cls.instance = super(JetstreamMetricsCollector, cls).__new__(cls) - return cls.instance - - # Metric definitions - _prefill_backlog = Gauge( - name="jetstream_prefill_backlog_size", - documentation="Size of prefill queue", - labelnames=["id"], - ) - - _transfer_backlog = Gauge( - name="jetstream_transfer_backlog_size", - documentation="Size of transfer queue", - labelnames=["id", "idx"], - ) - - _generate_backlog = Gauge( - name="jetstream_generate_backlog_size", - documentation="Size of generate queue", - labelnames=["id", "idx"], - ) - - _queue_duration = Histogram( - name="jetstream_queue_duration", - documentation="The total time each request spends enqueued in seconds", - labelnames=["id"], - buckets=[ - 0.01, - 0.02, - 0.05, - 0.1, - 0.2, - 0.5, - 1.0, - 2.0, - 5.0, - 10.0, - 20.0, - 50.0, - 100.0, - ], - ) - - _slots_used_percentage = Gauge( - name="jetstream_slots_used_percentage", - documentation="The percentage of decode slots currently being used", - labelnames=["id", "idx"], - ) - - _server_startup_latency = Gauge( - name="jetstream_server_startup_latency", - documentation="Total time taken to start the Jetstream server", - labelnames=["id"], - ) - _request_input_length = Histogram( - name="jetstream_request_input_length", - documentation="Number of input tokens per request", - labelnames=["id"], - buckets=DEFAULT_PREFILL_BUCKETS, - ) - _request_output_length = Histogram( - name="jetstream_request_output_length", - documentation="Number of output tokens per request", - labelnames=["id"], - buckets=[ - 1, - 2, - 5, - 10, - 20, - 50, - 100, - 200, - 500, - 1000, - 2000, - 5000, - 10000, - 20000, - 50000, - 100000, - 200000, - 500000, - 1000000, - 2000000, - ], - ) - _request_success_count = Counter( - name="jetstream_request_success_count", - documentation="Number of requests successfully completed", - labelnames=["id"], - ) - - _time_to_first_token = Histogram( - name="jetstream_time_to_first_token", - documentation="Time to first token per request in seconds", - labelnames=["id"], - buckets=[ - 0.001, - 0.005, - 0.01, - 0.02, - 0.04, - 0.06, - 0.08, - 0.1, - 0.25, - 0.5, - 0.75, - 1.0, - 2.5, - 5.0, - 7.5, - 10.0, - ], - ) - - _time_per_output_token = Histogram( - name="jetstream_time_per_output_token", - documentation="Average time per output token per request in seconds", - labelnames=["id"], - buckets=[ - 0.01, - 0.025, - 0.05, - 0.075, - 0.1, - 0.15, - 0.2, - 0.3, - 0.4, - 0.5, - 0.75, - 1.0, - 2.5, - ], - ) - - _time_per_prefill_token = Histogram( - name="jetstream_time_per_prefill_token", - documentation="Prefill time per token per request in seconds", - labelnames=["id"], - buckets=[ - 0.00001, - 0.00002, - 0.00005, - 0.0001, - 0.0002, - 0.0005, - 0.001, - 0.002, - 0.005, - 0.01, - 0.02, - 0.05, - 0.1, - ], - ) - - _time_per_request = Histogram( - name="jetstream_time_per_request", - documentation="End to end request latency in seconds", - labelnames=["id"], - buckets=[1.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0], - ) - - _wait_time_per_request = Histogram( - name="jetstream_wait_time_per_request", - 
documentation="Time each request is not being prefilled or decoded", - labelnames=["id"], - buckets=[ - 0.01, - 0.02, - 0.05, - 0.1, - 0.2, - 0.5, - 1.0, - 2.0, - 5.0, - 10.0, - 20.0, - 50.0, - 100.0, - ], - ) - - def get_prefill_backlog_metric(self): - return self._prefill_backlog.labels(id=self._id) - - def get_transfer_backlog_metric(self, idx: int): - return self._transfer_backlog.labels(id=self._id, idx=idx) - - def get_generate_backlog_metric(self, idx: int): - return self._generate_backlog.labels(id=self._id, idx=idx) - - def get_queue_duration(self): - return self._queue_duration.labels(id=self._id) - - def get_slots_used_percentage_metric(self, idx: int): - return self._slots_used_percentage.labels(id=self._id, idx=idx) - - def get_server_startup_latency_metric(self): - return self._server_startup_latency.labels(id=self._id) - - def get_time_to_first_token(self): - return self._time_to_first_token.labels(id=self._id) - - def get_time_per_output_token(self): - return self._time_per_output_token.labels(id=self._id) - - def get_time_per_prefill_token(self): - return self._time_per_prefill_token.labels(id=self._id) - - def get_time_per_request(self): - return self._time_per_request.labels(id=self._id) - - def get_wait_time_per_request(self): - return self._wait_time_per_request.labels(id=self._id) - - def get_request_input_length(self): - return self._request_input_length.labels(id=self._id) - - def get_request_output_length(self): - return self._request_output_length.labels(id=self._id) - - def get_request_success_count_metric(self): - return self._request_success_count.labels(id=self._id) diff --git a/google3/third_party/py/jetstream/__init__.py b/jetstream/__init__.py similarity index 100% rename from google3/third_party/py/jetstream/__init__.py rename to jetstream/__init__.py diff --git a/google3/third_party/py/jetstream/core/README.md b/jetstream/core/README.md similarity index 100% rename from google3/third_party/py/jetstream/core/README.md rename to jetstream/core/README.md diff --git a/google3/third_party/py/jetstream/core/__init__.py b/jetstream/core/__init__.py similarity index 100% rename from google3/third_party/py/jetstream/core/__init__.py rename to jetstream/core/__init__.py diff --git a/google3/third_party/py/jetstream/core/config_lib.py b/jetstream/core/config_lib.py similarity index 92% rename from google3/third_party/py/jetstream/core/config_lib.py rename to jetstream/core/config_lib.py index f3022d01..83059706 100644 --- a/google3/third_party/py/jetstream/core/config_lib.py +++ b/jetstream/core/config_lib.py @@ -16,7 +16,7 @@ import dataclasses import functools -from typing import Any, Callable, List, Tuple, Type +from typing import Any, Callable, List, Optional, Tuple, Type from numpy import uint16 from jetstream.engine import engine_api @@ -39,6 +39,11 @@ class ServerConfig: generate_engine_create_fns: Tuple[CreateEngineFn, ...] = () interleaved_engine_create_fns: Tuple[CreateEngineFn, ...] = () is_ray_backend: bool = False + # Parameters for customized gc config, increase the numbers here will + # potentially increase memory usage. + gc_gen0_allocs: int = 60000 # default is 700, too frequent sometimes. 
+ gc_gen1_multipler: int = 2 # Make gen1 gc runs less frequent + gc_gen2_multipler: int = 3 # Make gen2 gc runs less frequent @dataclasses.dataclass @@ -51,6 +56,7 @@ class InstantiatedEngines: @dataclasses.dataclass class MetricsServerConfig: port: uint16 + model_name: Optional[str] = None # ▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼# diff --git a/google3/third_party/py/jetstream/core/implementations/__init__.py b/jetstream/core/implementations/__init__.py similarity index 100% rename from google3/third_party/py/jetstream/core/implementations/__init__.py rename to jetstream/core/implementations/__init__.py diff --git a/google3/third_party/py/jetstream/core/implementations/mock/README.md b/jetstream/core/implementations/mock/README.md similarity index 100% rename from google3/third_party/py/jetstream/core/implementations/mock/README.md rename to jetstream/core/implementations/mock/README.md diff --git a/google3/third_party/py/jetstream/core/implementations/mock/__init__.py b/jetstream/core/implementations/mock/__init__.py similarity index 100% rename from google3/third_party/py/jetstream/core/implementations/mock/__init__.py rename to jetstream/core/implementations/mock/__init__.py diff --git a/google3/third_party/py/jetstream/core/implementations/mock/config.py b/jetstream/core/implementations/mock/config.py similarity index 100% rename from google3/third_party/py/jetstream/core/implementations/mock/config.py rename to jetstream/core/implementations/mock/config.py diff --git a/google3/third_party/py/jetstream/core/implementations/mock/server.py b/jetstream/core/implementations/mock/server.py similarity index 100% rename from google3/third_party/py/jetstream/core/implementations/mock/server.py rename to jetstream/core/implementations/mock/server.py diff --git a/google3/third_party/py/jetstream/core/metrics/__init__.py b/jetstream/core/metrics/__init__.py similarity index 100% rename from google3/third_party/py/jetstream/core/metrics/__init__.py rename to jetstream/core/metrics/__init__.py diff --git a/jetstream/core/metrics/prometheus.py b/jetstream/core/metrics/prometheus.py new file mode 100644 index 00000000..34475e23 --- /dev/null +++ b/jetstream/core/metrics/prometheus.py @@ -0,0 +1,291 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Contains common functions for configuring Jetstream server metrics""" + +import os +import re +from typing import Optional +import shortuuid +from prometheus_client import Counter, Gauge, Histogram +from jetstream.engine.token_utils import DEFAULT_PREFILL_BUCKETS + + +class JetstreamMetricsCollector: + """Wrapper class should be used to assure all metrics have proper tags""" + + _initialized: bool = False + _model_name: str + universal_labels = {"id": os.getenv("HOSTNAME", shortuuid.uuid())} + + def __new__(cls, model_name: Optional[str] = None): + if not hasattr(cls, "instance"): + cls.instance = super(JetstreamMetricsCollector, cls).__new__(cls) + return cls.instance + + def __init__(self, model_name: Optional[str] = None): + if hasattr(self, "_initialized") and self._initialized: + return + self._initialized = True + + # '-'s are common in model names but invalid in prometheus labels + # these are replaced with '_'s + if model_name is not None: + sanitized_model_name = model_name.replace("-", "_") + if sanitized_model_name == "": + print("No model name provided, omitting from metrics labels") + elif not bool( + re.match(r"^[a-zA-Z_:][a-zA-Z0-9_:]*$", sanitized_model_name) + ): + print( + "Provided model name cannot be used to label prometheus metrics", + "(does not match ^[a-zA-Z_:][a-zA-Z0-9_:]*$)", + "omitting from metrics labels", + ) + else: + self.universal_labels["model_name"] = sanitized_model_name + universal_label_names = list(self.universal_labels.keys()) + + # Metric definitions + self._prefill_backlog = Gauge( + name="jetstream_prefill_backlog_size", + documentation="Size of prefill queue", + labelnames=universal_label_names, + ) + + self._transfer_backlog = Gauge( + name="jetstream_transfer_backlog_size", + documentation="Size of transfer queue", + labelnames=universal_label_names + ["idx"], + ) + + self._generate_backlog = Gauge( + name="jetstream_generate_backlog_size", + documentation="Size of generate queue", + labelnames=universal_label_names + ["idx"], + ) + + self._queue_duration = Histogram( + name="jetstream_queue_duration", + documentation="The total time each request spends enqueued in seconds", + labelnames=universal_label_names, + buckets=[ + 0.01, + 0.02, + 0.05, + 0.1, + 0.2, + 0.5, + 1.0, + 2.0, + 5.0, + 10.0, + 20.0, + 50.0, + 100.0, + ], + ) + + self._slots_used_percentage = Gauge( + name="jetstream_slots_used_percentage", + documentation="The percentage of decode slots currently being used", + labelnames=universal_label_names + ["idx"], + ) + self._model_load_time = Gauge( + name="jetstream_model_load_time", + documentation="Total time taken to load the model", + labelnames=universal_label_names, + ) + self._server_startup_latency = Gauge( + name="jetstream_server_startup_latency", + documentation="Total time taken to start the Jetstream server", + labelnames=universal_label_names, + ) + self._request_input_length = Histogram( + name="jetstream_request_input_length", + documentation="Number of input tokens per request", + labelnames=universal_label_names, + buckets=DEFAULT_PREFILL_BUCKETS, + ) + self._request_output_length = Histogram( + name="jetstream_request_output_length", + documentation="Number of output tokens per request", + labelnames=universal_label_names, + buckets=[ + 1, + 2, + 5, + 10, + 20, + 50, + 100, + 200, + 500, + 1000, + 2000, + 5000, + 10000, + 20000, + 50000, + 100000, + 200000, + 500000, + 1000000, + 2000000, + ], + ) + self._request_success_count = Counter( + name="jetstream_request_success_count", + documentation="Number of 
requests successfully completed", + labelnames=universal_label_names, + ) + + self._time_to_first_token = Histogram( + name="jetstream_time_to_first_token", + documentation="Time to first token per request in seconds", + labelnames=universal_label_names, + buckets=[ + 0.001, + 0.005, + 0.01, + 0.02, + 0.04, + 0.06, + 0.08, + 0.1, + 0.25, + 0.5, + 0.75, + 1.0, + 2.5, + 5.0, + 7.5, + 10.0, + ], + ) + + self._time_per_output_token = Histogram( + name="jetstream_time_per_output_token", + documentation="Average time per output token per request in seconds", + labelnames=universal_label_names, + buckets=[ + 0.01, + 0.025, + 0.05, + 0.075, + 0.1, + 0.15, + 0.2, + 0.3, + 0.4, + 0.5, + 0.75, + 1.0, + 2.5, + ], + ) + + self._time_per_prefill_token = Histogram( + name="jetstream_time_per_prefill_token", + documentation="Prefill time per token per request in seconds", + labelnames=universal_label_names, + buckets=[ + 0.00001, + 0.00002, + 0.00005, + 0.0001, + 0.0002, + 0.0005, + 0.001, + 0.002, + 0.005, + 0.01, + 0.02, + 0.05, + 0.1, + ], + ) + + self._time_per_request = Histogram( + name="jetstream_time_per_request", + documentation="End to end request latency in seconds", + labelnames=universal_label_names, + buckets=[1.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0], + ) + + self._wait_time_per_request = Histogram( + name="jetstream_wait_time_per_request", + documentation="Time each request is not being prefilled or decoded", + labelnames=universal_label_names, + buckets=[ + 0.01, + 0.02, + 0.05, + 0.1, + 0.2, + 0.5, + 1.0, + 2.0, + 5.0, + 10.0, + 20.0, + 50.0, + 100.0, + ], + ) + + def get_prefill_backlog_metric(self): + return self._prefill_backlog.labels(**self.universal_labels) + + def get_transfer_backlog_metric(self, idx: int): + return self._transfer_backlog.labels(**self.universal_labels, idx=idx) + + def get_generate_backlog_metric(self, idx: int): + return self._generate_backlog.labels(**self.universal_labels, idx=idx) + + def get_queue_duration(self): + return self._queue_duration.labels(**self.universal_labels) + + def get_slots_used_percentage_metric(self, idx: int): + return self._slots_used_percentage.labels(**self.universal_labels, idx=idx) + + def get_server_startup_latency_metric(self): + return self._server_startup_latency.labels(**self.universal_labels) + + def get_model_load_time_metric(self): + return self._model_load_time.labels(**self.universal_labels) + + def get_time_to_first_token(self): + return self._time_to_first_token.labels(**self.universal_labels) + + def get_time_per_output_token(self): + return self._time_per_output_token.labels(**self.universal_labels) + + def get_time_per_prefill_token(self): + return self._time_per_prefill_token.labels(**self.universal_labels) + + def get_time_per_request(self): + return self._time_per_request.labels(**self.universal_labels) + + def get_wait_time_per_request(self): + return self._wait_time_per_request.labels(**self.universal_labels) + + def get_request_input_length(self): + return self._request_input_length.labels(**self.universal_labels) + + def get_request_output_length(self): + return self._request_output_length.labels(**self.universal_labels) + + def get_request_success_count_metric(self): + return self._request_success_count.labels(**self.universal_labels) diff --git a/google3/third_party/py/jetstream/core/orchestrator.py b/jetstream/core/orchestrator.py similarity index 96% rename from google3/third_party/py/jetstream/core/orchestrator.py rename to jetstream/core/orchestrator.py index 15fc36dd..1975de84 100644 --- 
a/google3/third_party/py/jetstream/core/orchestrator.py +++ b/jetstream/core/orchestrator.py @@ -74,6 +74,7 @@ to debug hangs due to bugs in threads (it is easier to debug with live logs). """ +from datetime import datetime import dataclasses import functools import itertools @@ -85,7 +86,7 @@ import threading import time import traceback -from typing import Any, AsyncIterator, Optional, Tuple, cast +from typing import Any, AsyncIterator, Optional, Tuple, cast, List import grpc import jax @@ -98,10 +99,10 @@ import numpy as np root = logging.getLogger() -root.setLevel(logging.INFO) +root.setLevel(logging.WARNING) handler = logging.StreamHandler(sys.stdout) -handler.setLevel(logging.INFO) +handler.setLevel(logging.WARNING) formatter = logging.Formatter( "%(asctime)s - %(name)s - %(levelname)s - %(message)s" ) @@ -113,18 +114,25 @@ class ActiveRequestMetadata: """Inference request metadata.""" - start_time: Optional[float] = None + start_time: float = 0.0 - prefill_enqueue_time: Optional[float] = None - prefill_dequeue_time: Optional[float] = None + prefill_enqueue_time: float = 0.0 + prefill_dequeue_time: float = 0.0 - transfer_enqueue_time: Optional[float] = None - transfer_dequeue_time: Optional[float] = None + transfer_enqueue_time: float = 0.0 + transfer_dequeue_time: float = 0.0 - generate_enqueue_time: Optional[float] = None - generate_dequeue_time: Optional[float] = None + generate_enqueue_time: float = 0.0 + generate_dequeue_time: float = 0.0 - complete_time: Optional[float] = None + complete_time: float = 0.0 + + def stats(self) -> str: + return ( + f"{self.prefill_enqueue_time - self.start_time:.2f};" + f"{self.prefill_dequeue_time - self.prefill_enqueue_time:.2f};" + f"{time.perf_counter() - self.prefill_dequeue_time:.2f}" + ) @dataclasses.dataclass @@ -245,7 +253,7 @@ def __init__( if generate_params is None: generate_params = [] - logging.info( + logging.warning( "Initialising driver with %d prefill engines and %d generate engines.", len(prefill_engines), len(generate_engines), @@ -396,7 +404,7 @@ def __init__( ), name=f"prefill_detokenize-{idx}", ) - for idx in range(len(self._generate_engines)) + for idx in range(len(self._prefill_engines)) ] self.generate_detokenize_threads = [ JetThread( @@ -476,6 +484,9 @@ def get_total_concurrent_requests(self) -> int: ) return total_max_concurrent_decodes + def prefill_backlog_size(self): + return self._prefill_backlog.qsize() + def place_request_on_prefill_queue(self, request: ActiveRequest): """Used to place new requests for prefilling and generation.""" # Don't block so we can fail and shed load when the queue is full. @@ -920,7 +931,9 @@ def _get_prefill_content( True, ) - def process_client_side_tokenization_response(self, response: Any): + def _process_client_side_tokenization_response( + self, response: list[ReturnSample] + ): samples = [] for sample in response: samples.append( @@ -934,15 +947,15 @@ def process_client_side_tokenization_response(self, response: Any): ) ) - def should_buffer_response(self, response: Any) -> bool: + def should_buffer_response(self, response: List[ReturnSample]) -> bool: for item in response: if item.text and token_utils.is_byte_token(item.text[-1]): # If any sample ends in bytes, this means we might still need to # decode more bytes to compose the string. 
return True - def process_server_side_tokenization_response( - self, response: Any, buffered_response_list + def _process_server_side_tokenization_response( + self, response: list[ReturnSample], buffered_response_list ): # Flush the buffered responses to each sample of current response. current_response_with_flushed_buffer = list( @@ -980,6 +993,8 @@ async def Decode( # pylint: disable=invalid-overridden-method context: Optional[grpc.aio.ServicerContext] = None, ) -> AsyncIterator[jetstream_pb2.DecodeResponse]: """Decode.""" + request_start_time = time.perf_counter() + ttft = 0 if context is None: logging.warning( "LLM orchestrator is being used in offline test mode, and will not" @@ -1031,11 +1046,20 @@ async def Decode( # pylint: disable=invalid-overridden-method buffered_response_list = [] async for response in active_request.return_channel: response = cast(list[ReturnSample], response) + if ttft == 0: + ttft = time.perf_counter() - request_start_time + if ttft > 2.0: + print( + datetime.now(), + f"Slow TTFT: {ttft:.2f}s," + f" stats={active_request.metadata.stats()}," + f" prefill_qsize={self._driver.prefill_backlog_size()}", + ) if is_client_side_tokenization: # If is_client_side_tokenization, the client should request with token # ids, and the JetStream server will return token ids as response. # The client should take care of tokenization and detokenization. - yield self.process_client_side_tokenization_response(response) + yield self._process_client_side_tokenization_response(response) else: # Buffer response mechanism is used to handle streaming # detokenization with special character (For some edge cases with @@ -1044,7 +1068,7 @@ async def Decode( # pylint: disable=invalid-overridden-method if self.should_buffer_response(response): buffered_response_list.append(response) continue - yield self.process_server_side_tokenization_response( + yield self._process_server_side_tokenization_response( response, buffered_response_list ) # Reset buffer after flushed. diff --git a/google3/third_party/py/jetstream/core/proto/__init__.py b/jetstream/core/proto/__init__.py similarity index 100% rename from google3/third_party/py/jetstream/core/proto/__init__.py rename to jetstream/core/proto/__init__.py diff --git a/google3/third_party/py/jetstream/core/proto/jetstream.proto b/jetstream/core/proto/jetstream.proto similarity index 100% rename from google3/third_party/py/jetstream/core/proto/jetstream.proto rename to jetstream/core/proto/jetstream.proto diff --git a/google3/third_party/py/jetstream/core/proto/jetstream_pb2.py b/jetstream/core/proto/jetstream_pb2.py similarity index 100% rename from google3/third_party/py/jetstream/core/proto/jetstream_pb2.py rename to jetstream/core/proto/jetstream_pb2.py diff --git a/google3/third_party/py/jetstream/core/proto/jetstream_pb2_grpc.py b/jetstream/core/proto/jetstream_pb2_grpc.py similarity index 100% rename from google3/third_party/py/jetstream/core/proto/jetstream_pb2_grpc.py rename to jetstream/core/proto/jetstream_pb2_grpc.py diff --git a/google3/third_party/py/jetstream/core/server_lib.py b/jetstream/core/server_lib.py similarity index 92% rename from google3/third_party/py/jetstream/core/server_lib.py rename to jetstream/core/server_lib.py index b323286a..d860ccab 100644 --- a/google3/third_party/py/jetstream/core/server_lib.py +++ b/jetstream/core/server_lib.py @@ -19,6 +19,7 @@ import asyncio from concurrent import futures +import gc import logging import os import signal @@ -113,10 +114,15 @@ def create_driver( An orchestrator driver. 
""" engines = config_lib.get_engines(config, devices=devices) + model_load_start_time = time.time() prefill_params = [pe.load_params() for pe in engines.prefill_engines] generate_params = [ge.load_params() for ge in engines.generate_engines] shared_params = [ie.load_params() for ie in engines.interleaved_engines] logging.info("Loaded all weights.") + if metrics_collector: + metrics_collector.get_model_load_time_metric().set( + time.time() - model_load_start_time + ) interleaved_mode = ( len(config.prefill_slices) + len(config.generate_slices) == 0 ) @@ -205,7 +211,9 @@ def run( "Starting Prometheus server on port %d", metrics_server_config.port ) start_http_server(metrics_server_config.port) - metrics_collector = JetstreamMetricsCollector() + metrics_collector = JetstreamMetricsCollector( + model_name=metrics_server_config.model_name + ) else: logging.info( "Not starting Prometheus server: --prometheus_port flag not set" @@ -218,8 +226,20 @@ def run( # to make sure we can fully saturate the model. Set default minimum to 64. threads = threads or max(driver.get_total_concurrent_requests(), 64) jetstream_server = JetStreamServer(driver, threads, port, credentials) - logging.info("Starting server on port %d with %d threads", port, threads) + # Tweak gc config. + # Force a gen 2 collection here. + gc.collect(generation=2) + # Freeze objects currently tracked and ignore them in future gc runs. + gc.freeze() + allocs, gen1, gen2 = gc.get_threshold() + allocs = config.gc_gen0_allocs + gen1 = gen1 * config.gc_gen1_multipler + gen2 = gen2 * config.gc_gen2_multipler + gc.set_threshold(allocs, gen1, gen2) + print("GC tweaked (allocs, gen1, gen2): ", allocs, gen1, gen2) + + logging.info("Starting server on port %d with %d threads", port, threads) jetstream_server.start() if metrics_collector: diff --git a/google3/third_party/py/jetstream/core/utils/__init__.py b/jetstream/core/utils/__init__.py similarity index 100% rename from google3/third_party/py/jetstream/core/utils/__init__.py rename to jetstream/core/utils/__init__.py diff --git a/google3/third_party/py/jetstream/core/utils/async_multifuture.py b/jetstream/core/utils/async_multifuture.py similarity index 100% rename from google3/third_party/py/jetstream/core/utils/async_multifuture.py rename to jetstream/core/utils/async_multifuture.py diff --git a/google3/third_party/py/jetstream/core/utils/proxy_util.py b/jetstream/core/utils/proxy_util.py similarity index 100% rename from google3/third_party/py/jetstream/core/utils/proxy_util.py rename to jetstream/core/utils/proxy_util.py diff --git a/google3/third_party/py/jetstream/core/utils/return_sample.py b/jetstream/core/utils/return_sample.py similarity index 100% rename from google3/third_party/py/jetstream/core/utils/return_sample.py rename to jetstream/core/utils/return_sample.py diff --git a/google3/third_party/py/jetstream/engine/README.md b/jetstream/engine/README.md similarity index 100% rename from google3/third_party/py/jetstream/engine/README.md rename to jetstream/engine/README.md diff --git a/google3/third_party/py/jetstream/engine/__init__.py b/jetstream/engine/__init__.py similarity index 100% rename from google3/third_party/py/jetstream/engine/__init__.py rename to jetstream/engine/__init__.py diff --git a/google3/third_party/py/jetstream/engine/engine_api.py b/jetstream/engine/engine_api.py similarity index 100% rename from google3/third_party/py/jetstream/engine/engine_api.py rename to jetstream/engine/engine_api.py diff --git a/google3/third_party/py/jetstream/engine/mock_engine.py 
b/jetstream/engine/mock_engine.py similarity index 100% rename from google3/third_party/py/jetstream/engine/mock_engine.py rename to jetstream/engine/mock_engine.py diff --git a/google3/third_party/py/jetstream/engine/mock_utils.py b/jetstream/engine/mock_utils.py similarity index 100% rename from google3/third_party/py/jetstream/engine/mock_utils.py rename to jetstream/engine/mock_utils.py diff --git a/google3/third_party/py/jetstream/engine/sampling_utils.py b/jetstream/engine/sampling_utils.py similarity index 100% rename from google3/third_party/py/jetstream/engine/sampling_utils.py rename to jetstream/engine/sampling_utils.py diff --git a/google3/third_party/py/jetstream/engine/token_utils.py b/jetstream/engine/token_utils.py similarity index 100% rename from google3/third_party/py/jetstream/engine/token_utils.py rename to jetstream/engine/token_utils.py diff --git a/google3/third_party/py/jetstream/engine/tokenizer.proto b/jetstream/engine/tokenizer.proto similarity index 100% rename from google3/third_party/py/jetstream/engine/tokenizer.proto rename to jetstream/engine/tokenizer.proto diff --git a/google3/third_party/py/jetstream/engine/tokenizer_api.py b/jetstream/engine/tokenizer_api.py similarity index 100% rename from google3/third_party/py/jetstream/engine/tokenizer_api.py rename to jetstream/engine/tokenizer_api.py diff --git a/google3/third_party/py/jetstream/engine/tokenizer_pb2.py b/jetstream/engine/tokenizer_pb2.py similarity index 100% rename from google3/third_party/py/jetstream/engine/tokenizer_pb2.py rename to jetstream/engine/tokenizer_pb2.py diff --git a/google3/third_party/py/jetstream/engine/tokenizer_pb2_grpc.py b/jetstream/engine/tokenizer_pb2_grpc.py similarity index 100% rename from google3/third_party/py/jetstream/engine/tokenizer_pb2_grpc.py rename to jetstream/engine/tokenizer_pb2_grpc.py diff --git a/google3/third_party/py/jetstream/engine/warmup_utils.py b/jetstream/engine/warmup_utils.py similarity index 100% rename from google3/third_party/py/jetstream/engine/warmup_utils.py rename to jetstream/engine/warmup_utils.py diff --git a/google3/third_party/py/jetstream/entrypoints/__init__.py b/jetstream/entrypoints/__init__.py similarity index 100% rename from google3/third_party/py/jetstream/entrypoints/__init__.py rename to jetstream/entrypoints/__init__.py diff --git a/google3/third_party/py/jetstream/entrypoints/config.py b/jetstream/entrypoints/config.py similarity index 100% rename from google3/third_party/py/jetstream/entrypoints/config.py rename to jetstream/entrypoints/config.py diff --git a/google3/third_party/py/jetstream/entrypoints/http/__init__.py b/jetstream/entrypoints/http/__init__.py similarity index 100% rename from google3/third_party/py/jetstream/entrypoints/http/__init__.py rename to jetstream/entrypoints/http/__init__.py diff --git a/google3/third_party/py/jetstream/entrypoints/http/api_server.py b/jetstream/entrypoints/http/api_server.py similarity index 97% rename from google3/third_party/py/jetstream/entrypoints/http/api_server.py rename to jetstream/entrypoints/http/api_server.py index aaced235..5bbf2411 100644 --- a/google3/third_party/py/jetstream/entrypoints/http/api_server.py +++ b/jetstream/entrypoints/http/api_server.py @@ -111,7 +111,9 @@ def server(argv: Sequence[str]): "Starting Prometheus server on port %d", metrics_server_config.port ) start_http_server(metrics_server_config.port) - metrics_collector = JetstreamMetricsCollector() + metrics_collector = JetstreamMetricsCollector( + model_name=metrics_server_config.model_name 
+ ) else: logging.info( "Not starting Prometheus server: --prometheus_port flag not set" diff --git a/google3/third_party/py/jetstream/entrypoints/http/protocol.py b/jetstream/entrypoints/http/protocol.py similarity index 100% rename from google3/third_party/py/jetstream/entrypoints/http/protocol.py rename to jetstream/entrypoints/http/protocol.py diff --git a/google3/third_party/py/jetstream/entrypoints/http/utils.py b/jetstream/entrypoints/http/utils.py similarity index 100% rename from google3/third_party/py/jetstream/entrypoints/http/utils.py rename to jetstream/entrypoints/http/utils.py diff --git a/google3/third_party/py/jetstream/external_tokenizers/__init__.py b/jetstream/external_tokenizers/__init__.py similarity index 100% rename from google3/third_party/py/jetstream/external_tokenizers/__init__.py rename to jetstream/external_tokenizers/__init__.py diff --git a/jetstream/external_tokenizers/llama3/__init__.py b/jetstream/external_tokenizers/llama3/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/google3/third_party/py/jetstream/external_tokenizers/llama3/llama3_tokenizer.py b/jetstream/external_tokenizers/llama3/llama3_tokenizer.py similarity index 100% rename from google3/third_party/py/jetstream/external_tokenizers/llama3/llama3_tokenizer.py rename to jetstream/external_tokenizers/llama3/llama3_tokenizer.py diff --git a/google3/third_party/py/jetstream/tests/__init__.py b/jetstream/tests/__init__.py similarity index 100% rename from google3/third_party/py/jetstream/tests/__init__.py rename to jetstream/tests/__init__.py diff --git a/google3/third_party/py/jetstream/tests/core/__init__.py b/jetstream/tests/core/__init__.py similarity index 100% rename from google3/third_party/py/jetstream/tests/core/__init__.py rename to jetstream/tests/core/__init__.py diff --git a/google3/third_party/py/jetstream/tests/core/test_config_lib.py b/jetstream/tests/core/test_config_lib.py similarity index 100% rename from google3/third_party/py/jetstream/tests/core/test_config_lib.py rename to jetstream/tests/core/test_config_lib.py diff --git a/google3/third_party/py/jetstream/tests/core/test_orchestrator.py b/jetstream/tests/core/test_orchestrator.py similarity index 98% rename from google3/third_party/py/jetstream/tests/core/test_orchestrator.py rename to jetstream/tests/core/test_orchestrator.py index 00e2e1c1..0861014c 100644 --- a/google3/third_party/py/jetstream/tests/core/test_orchestrator.py +++ b/jetstream/tests/core/test_orchestrator.py @@ -14,7 +14,7 @@ """Integration test of the orchestrator. -This test tests the multi-htreaded orchestrator, where a prefill request is +This test tests the multithreaded orchestrator, where a prefill request is popped onto a prefill queue, prefilled, sent to a generation queue and run for a number of decoding steps. 
diff --git a/google3/third_party/py/jetstream/tests/core/test_server.py b/jetstream/tests/core/test_server.py similarity index 97% rename from google3/third_party/py/jetstream/tests/core/test_server.py rename to jetstream/tests/core/test_server.py index 2fdddce9..9cab05cd 100644 --- a/google3/third_party/py/jetstream/tests/core/test_server.py +++ b/jetstream/tests/core/test_server.py @@ -84,7 +84,9 @@ async def test_server( config=config, devices=devices, credentials=credentials, - metrics_server_config=config_lib.MetricsServerConfig(port=metrics_port) + metrics_server_config=config_lib.MetricsServerConfig( + port=metrics_port, model_name="some_model_name" + ) if metrics_enabled is True else None, ) diff --git a/google3/third_party/py/jetstream/tests/engine/__init__.py b/jetstream/tests/engine/__init__.py similarity index 100% rename from google3/third_party/py/jetstream/tests/engine/__init__.py rename to jetstream/tests/engine/__init__.py diff --git a/google3/third_party/py/jetstream/tests/engine/external_tokenizers/llama2/tokenizer.model b/jetstream/tests/engine/external_tokenizers/llama2/tokenizer.model similarity index 100% rename from google3/third_party/py/jetstream/tests/engine/external_tokenizers/llama2/tokenizer.model rename to jetstream/tests/engine/external_tokenizers/llama2/tokenizer.model diff --git a/google3/third_party/py/jetstream/tests/engine/external_tokenizers/llama3/tokenizer.model b/jetstream/tests/engine/external_tokenizers/llama3/tokenizer.model similarity index 100% rename from google3/third_party/py/jetstream/tests/engine/external_tokenizers/llama3/tokenizer.model rename to jetstream/tests/engine/external_tokenizers/llama3/tokenizer.model diff --git a/google3/third_party/py/jetstream/tests/engine/test_mock_engine.py b/jetstream/tests/engine/test_mock_engine.py similarity index 100% rename from google3/third_party/py/jetstream/tests/engine/test_mock_engine.py rename to jetstream/tests/engine/test_mock_engine.py diff --git a/google3/third_party/py/jetstream/tests/engine/test_sampling_utils.py b/jetstream/tests/engine/test_sampling_utils.py similarity index 100% rename from google3/third_party/py/jetstream/tests/engine/test_sampling_utils.py rename to jetstream/tests/engine/test_sampling_utils.py diff --git a/google3/third_party/py/jetstream/tests/engine/test_token_utils.py b/jetstream/tests/engine/test_token_utils.py similarity index 100% rename from google3/third_party/py/jetstream/tests/engine/test_token_utils.py rename to jetstream/tests/engine/test_token_utils.py diff --git a/google3/third_party/py/jetstream/tests/engine/test_utils.py b/jetstream/tests/engine/test_utils.py similarity index 100% rename from google3/third_party/py/jetstream/tests/engine/test_utils.py rename to jetstream/tests/engine/test_utils.py diff --git a/google3/third_party/py/jetstream/tests/entrypoints/__init__.py b/jetstream/tests/entrypoints/__init__.py similarity index 100% rename from google3/third_party/py/jetstream/tests/entrypoints/__init__.py rename to jetstream/tests/entrypoints/__init__.py diff --git a/google3/third_party/py/jetstream/tests/entrypoints/http/__init__.py b/jetstream/tests/entrypoints/http/__init__.py similarity index 100% rename from google3/third_party/py/jetstream/tests/entrypoints/http/__init__.py rename to jetstream/tests/entrypoints/http/__init__.py diff --git a/google3/third_party/py/jetstream/tests/entrypoints/http/test_api_server.py b/jetstream/tests/entrypoints/http/test_api_server.py similarity index 100% rename from 
google3/third_party/py/jetstream/tests/entrypoints/http/test_api_server.py rename to jetstream/tests/entrypoints/http/test_api_server.py diff --git a/google3/third_party/py/jetstream/tools/load_tester.py b/jetstream/tools/load_tester.py similarity index 100% rename from google3/third_party/py/jetstream/tools/load_tester.py rename to jetstream/tools/load_tester.py diff --git a/google3/third_party/py/jetstream/tools/maxtext/model_ckpt_conversion.sh b/jetstream/tools/maxtext/model_ckpt_conversion.sh similarity index 100% rename from google3/third_party/py/jetstream/tools/maxtext/model_ckpt_conversion.sh rename to jetstream/tools/maxtext/model_ckpt_conversion.sh diff --git a/google3/third_party/py/jetstream/tools/maxtext/model_ckpt_finetune_with_aqt.sh b/jetstream/tools/maxtext/model_ckpt_finetune_with_aqt.sh similarity index 100% rename from google3/third_party/py/jetstream/tools/maxtext/model_ckpt_finetune_with_aqt.sh rename to jetstream/tools/maxtext/model_ckpt_finetune_with_aqt.sh diff --git a/google3/third_party/py/jetstream/tools/proxy_dev/base.Dockerfile b/jetstream/tools/proxy_dev/base.Dockerfile similarity index 100% rename from google3/third_party/py/jetstream/tools/proxy_dev/base.Dockerfile rename to jetstream/tools/proxy_dev/base.Dockerfile diff --git a/google3/third_party/py/jetstream/tools/proxy_dev/dev.Dockerfile b/jetstream/tools/proxy_dev/dev.Dockerfile similarity index 100% rename from google3/third_party/py/jetstream/tools/proxy_dev/dev.Dockerfile rename to jetstream/tools/proxy_dev/dev.Dockerfile diff --git a/google3/third_party/py/jetstream/tools/requester.py b/jetstream/tools/requester.py similarity index 100% rename from google3/third_party/py/jetstream/tools/requester.py rename to jetstream/tools/requester.py diff --git a/google3/third_party/py/jetstream/license_preamble.txt b/license_preamble.txt similarity index 100% rename from google3/third_party/py/jetstream/license_preamble.txt rename to license_preamble.txt diff --git a/google3/third_party/py/jetstream/pylintrc b/pylintrc similarity index 100% rename from google3/third_party/py/jetstream/pylintrc rename to pylintrc diff --git a/google3/third_party/py/jetstream/requirements.txt b/requirements.txt similarity index 100% rename from google3/third_party/py/jetstream/requirements.txt rename to requirements.txt diff --git a/google3/third_party/py/jetstream/setup.py b/setup.py similarity index 98% rename from google3/third_party/py/jetstream/setup.py rename to setup.py index 55b91c3b..658885cd 100644 --- a/google3/third_party/py/jetstream/setup.py +++ b/setup.py @@ -24,7 +24,7 @@ def parse_requirements(filename): setup( name="google-jetstream", - version="0.2.2", + version="0.2.3", description=( "JetStream is a throughput and memory optimized engine for LLM inference on XLA devices, starting with TPUs (and GPUs in future -- PRs welcome)." ),