From 6e02cee711650e4a5febb88487f6e7589adc7ebe Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Tue, 9 Jul 2024 19:04:57 +0530 Subject: [PATCH 1/8] Update visualize trace --- neuralmagic/tools/profiler/visualize_trace.py | 457 ++++++++++++------ 1 file changed, 315 insertions(+), 142 deletions(-) diff --git a/neuralmagic/tools/profiler/visualize_trace.py b/neuralmagic/tools/profiler/visualize_trace.py index fd5659161b046..5579a439f843a 100644 --- a/neuralmagic/tools/profiler/visualize_trace.py +++ b/neuralmagic/tools/profiler/visualize_trace.py @@ -1,11 +1,48 @@ import argparse import json +import math +import copy +from pathlib import Path +from typing import Optional, List, Tuple, Any import matplotlib.pyplot as plt import pandas as pd +## JSON parsing utils #### + +def largest_dist_from_leaf(node: dict, depth:int=0): + if len(node["children"]) == 0: + return depth + return max([ + largest_dist_from_leaf(child, depth=depth + 1) + for child in node["children"] + ]) + +def get_entries_at_depth(depth: int, + entries_and_traces: List[Tuple[Any, Any]], + node: dict, + ignore_sampler: bool, + curr_depth:int =0, + trace=()): + if ignore_sampler and node["entry"]["name"] == "Sampler": + return + + if (depth >= 0 and depth == curr_depth) or ( + depth < 0 + and largest_dist_from_leaf(node) == (abs(depth) - 1)): + entries_and_traces.append((node["entry"], trace)) + trace = (node["entry"]["name"], ) + trace + for child in node["children"]: + get_entries_at_depth(depth, + entries_and_traces, + child, + ignore_sampler, + curr_depth=curr_depth + 1, + trace=trace) + +## Operation name cleanup utils #### -def trim_string_back(string: str, width: int): +def trim_string_back(string: str, width: int) -> str: if len(string) > width: offset = len(string) - width + 3 string = string[:-offset] @@ -13,8 +50,13 @@ def trim_string_back(string: str, width: int): string = string + "..." 
 return string
 
+def shorten_plot_legend_strings(legend, max_char_len: int):
+    for t in legend.get_texts():
+        t.set_text(
+            trim_string_back(abbreviate_known_names(t.get_text()),
+                             max_char_len))
 
-def abbreviate_known_names(name: str):
+def abbreviate_known_names(name: str) -> str:
     abbreviations = {
         "MergedColumnParallelLinear": "MCPLinear",
         "QKVParallelLinear": "QKVPLinear",
@@ -27,13 +69,266 @@
         name = name.replace(key, value)
     return name
 
+def attempt_to_make_names_unique(entries_and_traces):
+    names, non_unique_names = (set(), set())
+
+    def all_the_same(items) -> bool:
+        return all(i == items[0] for i in items)
 
-def shorten_plot_legend_strings(legend, max_char_len: int):
-    for t in legend.get_texts():
-        t.set_text(
-            trim_string_back(abbreviate_known_names(t.get_text()),
-                             max_char_len))
+    for entry, _ in entries_and_traces:
+        if entry["name"] in names:
+            non_unique_names.add(entry["name"])
+        else:
+            names.add(entry["name"])
+
+    for name in non_unique_names:
+        entries_and_traces_with_name = [
+            (entry, trace) for entry, trace in entries_and_traces
+            if entry["name"] == name
+        ]
+
+        zipped_traces = list(
+            zip(*[trace for _, trace in entries_and_traces_with_name]))
+        first_trace_difference = next(
+            (i for i, trace_eles in enumerate(zipped_traces)
+             if not all_the_same(trace_eles)), None)
+
+        if first_trace_difference is None:
+            # can't create a unique name, so leave the names as they are;
+            # they will get aggregated by the pivot_table call
+            continue
+
+        for entry, trace in entries_and_traces_with_name:
+            entry["name"] = " <- ".join((entry["name"], ) +
+                                        trace[:first_trace_difference + 1])
+
+## Operation grouping utils ####
+
+'''
+  Group operations in the given dataframe by some high-level ops like,
+  - gemms
+  - attention
+  - rms_norm
+  etc.
+''' +def group_trace_by_operations(trace_df : pd.DataFrame) -> pd.DataFrame: + + def is_rms_norm(op_name: str): + if "rms_norm_kernel" in op_name: + return True + + def is_attention_block(op_name:str): + if "flash_fwd" in op_name or \ + "reshape_and_cache_flash_kernel" in op_name: + return True + + def is_quant(op_name: str): + if "scaled_fp8_quant" in op_name or \ + "scaled_int8_quant" in op_name: + return True + + def is_gemm_op(op_name : str): + if is_quant(op_name): + return False + if "xmma_gemm" in op_name or \ + "gemv2T_kernel" in op_name or \ + "splitKreduce" in op_name or \ + "void cutlass::Kernel" in op_name or \ + "void cutlass::device_kernel" in op_name or \ + "s16816gemm" in op_name: + return True + + def is_elementwise_op(op_name : str): + return "elementwise_kernel" in op_name + + def is_mem_op(op_name: str): + return "memcpy" in op_name.lower() or \ + "memset" in op_name.lower() + + def is_vocab_embedding_op(op_name: str): + return "vocabparallelembed" in op_name.lower() + + headers = list(trace_df) + ops = copy.deepcopy(headers) + + attention_ops = list(filter(lambda x: is_attention_block(x), ops)) + ops = list(filter(lambda x: x not in attention_ops, ops)) + + quant_ops = list(filter(lambda x: is_quant(x), ops)) + ops = list(filter(lambda x: x not in quant_ops, ops)) + + gemm_ops = list(filter(lambda x: is_gemm_op(x), ops)) + ops = list(filter(lambda x: x not in gemm_ops, ops)) + + rms_norm_ops = list(filter(lambda x: is_rms_norm(x), ops)) + ops = list(filter(lambda x: x not in rms_norm_ops, ops)) + + vocab_embed_ops = list(filter(lambda x: is_vocab_embedding_op(x), ops)) + ops = list(filter(lambda x: x not in vocab_embed_ops, ops)) + + mem_ops = list(filter(lambda x: is_mem_op(x), ops)) + ops = list(filter(lambda x: x not in mem_ops, ops)) + + elementwise_ops = list(filter(lambda x: is_elementwise_op(x), ops)) + ops = list(filter(lambda x: x not in elementwise_ops, ops)) + + remaining_ops = ops + + if len(attention_ops): + trace_df['attention'] = trace_df[attention_ops].agg("sum", axis=1) + if len(quant_ops): + trace_df['quant_ops'] = trace_df[quant_ops].agg("sum", axis=1) + if len(gemm_ops): + trace_df['gemm_ops'] = trace_df[gemm_ops].agg("sum", axis=1) + if len(rms_norm_ops): + trace_df['rms_norm_ops'] = trace_df[rms_norm_ops].agg("sum", axis=1) + if len(vocab_embed_ops): + trace_df['vocab_embed_ops'] = trace_df[vocab_embed_ops].agg("sum", axis=1) + if len(mem_ops): + trace_df['mem_ops'] = trace_df[mem_ops].agg("sum", axis=1) + if len(elementwise_ops): + trace_df['elementwise_ops'] = trace_df[elementwise_ops].agg("sum", axis=1) + + trace_df.drop(attention_ops + quant_ops + gemm_ops + rms_norm_ops + vocab_embed_ops + mem_ops + elementwise_ops, + axis=1, inplace=True) + return trace_df + +## Data plotting utils #### + +def plot_trace_df(traces_df: pd.DataFrame, + plot_title: str, + output: Optional[Path] = None): + + traces_df["cuda_time_ms"] = traces_df["cuda_time_us"] / 1000 + traces_df = traces_df.fillna(0) + + phases = traces_df['phase'].unique() + traces_df = traces_df.pivot_table(index="phase", + columns="name", + values="cuda_time_ms", + aggfunc="sum") + + traces_df = group_trace_by_operations(traces_df) + + # Make the figure + fig, ax = plt.subplots(1, figsize=(5, 8), sharex=True) + + # Draw the stacked bars + ops = list(traces_df) + bottom = [0] * len(phases) + for op in ops: + values = [traces_df[op][phase] for phase in phases] + values = list(map(lambda x: 0.0 if math.isnan(x) else x, values)) + ax.bar(phases, values, label=op, bottom=bottom) + bottom = 
[bottom[j] + values[j] for j in range(len(phases))] + + # Write the values as text on the bars + for bar in ax.patches: + if bar.get_height() != 0: + v = round(bar.get_height(), 2) + ax.text(bar.get_x() + bar.get_width() / 2, + bar.get_height() / 2 + bar.get_y(), + f"{round(bar.get_height(), 2)}", + ha = 'center', color = 'w', weight = 'bold', + size = 5) + + # Setup legend + handles, labels = plt.gca().get_legend_handles_labels() + legend = fig.legend(handles, + labels, + loc='center left', + bbox_to_anchor=(1, 1)) + shorten_plot_legend_strings(legend, 50) + + # Setup labels and title + plt.setp(ax.get_xticklabels(), rotation=90) + ax.set_ylabel("time ms") + plt.suptitle(plot_title) + + plt.savefig(output, bbox_inches='tight') + print("Created: ", output) + +def main(json_trace: Path, + output_directory: Path, + depth: int, # Fetch/Plot operations at this depth of the Json tree + make_names_unique: bool, + top_k: int, + ignore_sampler: bool): + + def prepare_data(profile_json: dict, step_keys: List[str]) -> pd.DataFrame: + + def get_entries_and_traces(key: str): + entries_and_traces : List[Tuple[Any, Any]] = [] + for root in profile_json[key]["summary_stats"]: + get_entries_at_depth(depth, entries_and_traces, root, ignore_sampler) + return entries_and_traces + + def keep_only_top_entries(df : pd.DataFrame, metric: str, top_k:int =9) -> pd.DataFrame: + df.loc[df.nsmallest(len(df) - top_k + 1, metric).index, + ["name"]] = "others" + return df + + # Get data for each key + traces = list(map(lambda x: get_entries_and_traces(x), step_keys)) + + # Attempt some cleanup + if make_names_unique: + traces = list(map(lambda x: attempt_to_make_names_unique(x), traces)) + + # To pandas dataframe + trace_dfs = list(map(lambda t: pd.DataFrame( + [entry for entry, _ in t]).fillna(0), + traces)) + + # Respect top_k + if top_k: + trace_dfs = list(map(lambda trace_df: keep_only_top_entries( + trace_df, "cuda_time_us", top_k), trace_dfs)) + + # Fill in information about the step-keys + for trace_df, step_key in zip(trace_dfs, step_keys): + trace_df['phase'] = step_key + + # Combine all data frames so they can be put in a single plot + traces_df = pd.concat(trace_dfs) + return traces_df + + def make_plot_title_suffix(profile_json: dict) -> str: + context = profile_json["context"] + sparsity = context.get('sparsity', None) + return (f"{context['model']}\n" + f"Batch={context['batch_size']}, " + f"PromptLen={context['prompt_len']}, " + f"OutputLen={context['output_len']}," + f"NumGpus={context['tensor_parallel_size']}" + f"{', Sparsity ' + sparsity if sparsity else ''}") + + profile_json = None + with open(json_trace, "r") as f: + profile_json = json.load(f) + assert profile_json is not None + # Get all `llm.generate.step()` profile + step_traces = list(profile_json.keys()) + assert (step_traces[0] == 'context') + step_traces = step_traces[1:] # have only prefill and decodes + prefills = list(filter(lambda x: "prefill" in x, step_traces)) + all_decodes = list(filter(lambda x: "decode" in x, step_traces)) + assert len(prefills) + len(all_decodes) == len(step_traces) + assert len(prefills) == 1 + + decodes = all_decodes[::args.step_plot_interval] + if decodes[-1] != all_decodes[-1]: + # Always have the last decode + decodes.append(all_decodes[-1]) + + prefill_traces = prepare_data(profile_json, prefills) + decode_traces = prepare_data(profile_json, decodes) + + plot_title_suffix = make_plot_title_suffix(profile_json) + + plot_trace_df(prefill_traces, "prefill " + plot_title_suffix, output_directory / 
Path("prefill.png")) + plot_trace_df(decode_traces, "decodes " + plot_title_suffix, output_directory / Path("decode_steps.png")) if __name__ == "__main__": parser = argparse.ArgumentParser() @@ -44,11 +339,10 @@ def shorten_plot_legend_strings(legend, max_char_len: int): required=True, help="json trace file output by examples/offline_profile.py") parser.add_argument( - "--output", + "--output-directory", type=str, required=False, - help="Output figure file, should be a image file such as pdf, " - "jpeg, png, etc., defaults to .pdf") + help="Directory to output plots") parser.add_argument("--level", type=str, default="module", @@ -60,13 +354,15 @@ def shorten_plot_legend_strings(legend, max_char_len: int): parser.add_argument("--ignore_sampler", action='store_true', help="Ignore everything under the \"Sampler\" module") + parser.add_argument("--step-plot-interval", + type=int, + default=4, + help="For every `step_plot_interval` steps, plot 1 step") args = parser.parse_args() - ignore_sampler = args.ignore_sampler + # Prepare/Extract relevant args make_names_unique = False - top_k = args.top_k - if args.level == "module": depth = -2 make_names_unique = True @@ -75,135 +371,12 @@ def shorten_plot_legend_strings(legend, max_char_len: int): else: raise Exception(f"Unexpected level value ({args.level})") - if ignore_sampler: + output_directory = args.output_directory if args.output_directory else Path(args.json_trace).parent + + if args.ignore_sampler: print("WARNING: ignoring Sampler time so the pct_cuda_time will not " "add up to 100%") - json_trace = args.json_trace - output = args.output if args.output else json_trace.strip(".json") + ".pdf" - - with open(json_trace, "r") as f: - profile_data = json.load(f) - - prefill_entries_and_traces = [] - decode_entries_and_traces = [] - - def largest_dist_from_leaf(node, depth=0): - if len(node["children"]) == 0: - return depth - return max([ - largest_dist_from_leaf(child, depth=depth + 1) - for child in node["children"] - ]) - - def get_entries_at_depth(depth, - entries_and_traces, - node, - curr_depth=0, - trace=()): - if ignore_sampler and node["entry"]["name"] == "Sampler": - return - - if (depth >= 0 and depth == curr_depth) or ( - depth < 0 - and largest_dist_from_leaf(node) == (abs(depth) - 1)): - entries_and_traces.append((node["entry"], trace)) - trace = (node["entry"]["name"], ) + trace - for child in node["children"]: - get_entries_at_depth(depth, - entries_and_traces, - child, - curr_depth=curr_depth + 1, - trace=trace) - - for root in profile_data["prefill"]["summary_stats"]: - get_entries_at_depth(depth, prefill_entries_and_traces, root) - for root in profile_data["decode_1"]["summary_stats"]: - get_entries_at_depth(depth, decode_entries_and_traces, root) - - def attempt_to_make_names_unique(entries_and_traces): - names, non_unique_names = (set(), set()) - - def all_the_same(items) -> bool: - return all(i == items[0] for i in items) - - for entry, _ in entries_and_traces: - if entry["name"] in names: - non_unique_names.add(entry["name"]) - else: - names.add(entry["name"]) - - for name in non_unique_names: - entries_and_traces_with_name = [ - (entry, trace) for entry, trace in entries_and_traces - if entry["name"] == name - ] - - zipped_traces = list( - zip(*[trace for _, trace in entries_and_traces_with_name])) - first_trace_difference = next( - (i for i, trace_eles in enumerate(zipped_traces) - if not all_the_same(trace_eles)), None) - - if first_trace_difference is None: - # can't create a unique name, leave them names as the - # are 
they will get aggregated by the pivot_table call - continue - - for entry, trace in entries_and_traces_with_name: - entry["name"] = " <- ".join((entry["name"], ) + - trace[:first_trace_difference + 1]) - - if make_names_unique: - attempt_to_make_names_unique(prefill_entries_and_traces) - attempt_to_make_names_unique(decode_entries_and_traces) - - def keep_only_top_entries(df, metric, top_k=9): - df.loc[df.nsmallest(len(df) - top_k + 1, metric).index, - ["name"]] = "others" - - prefill_df = pd.DataFrame( - [entry for entry, _ in prefill_entries_and_traces]) - prefill_df["phase"] = "prefill" - decode_df = pd.DataFrame([entry for entry, _ in decode_entries_and_traces]) - decode_df["phase"] = "decode" - - if top_k: - keep_only_top_entries(prefill_df, "cuda_time_us", top_k) - keep_only_top_entries(decode_df, "cuda_time_us", top_k) - - df = pd.concat([prefill_df, decode_df]) - df["cuda_time_ms"] = df["cuda_time_us"] / 1000 - - fig, axes = plt.subplots(2, figsize=(5, 8), sharex=True) - - def plot_metric(metric: str, ax, add_totals=False): - pivoted_df = df.pivot_table(index="phase", - columns="name", - values=metric, - aggfunc="sum") - pivoted_df.plot.bar(stacked=True, legend=False, ax=ax) - ax.set_ylabel(metric) - - if add_totals: - ax.bar_label(ax.containers[-1]) - - plot_metric("cuda_time_ms", ax=axes[0], add_totals=True) - plot_metric("pct_cuda_time", ax=axes[1]) - - handles, labels = plt.gca().get_legend_handles_labels() - legend = fig.legend(handles, - labels, - loc='center left', - bbox_to_anchor=(0.93, 0.5)) - shorten_plot_legend_strings(legend, 50) - - context = profile_data["context"] - sparsity = context.get('sparsity', None) - plt.suptitle(f"{context['model']}\n" - f"Batch={context['batch_size']}, " - f"PromptLen={context['prompt_len']}, " - f"NumGpus={context['tensor_parallel_size']}" - f"{', Sparsity ' + sparsity if sparsity else ''}") - plt.savefig(output, bbox_inches='tight') - print("Created: ", output) + main (Path(args.json_trace), output_directory, + depth, make_names_unique, + args.top_k, args.ignore_sampler) From 1fa7a826c5f76131fb3ac15c44d6c7c0e0cf051c Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Tue, 9 Jul 2024 14:18:34 +0000 Subject: [PATCH 2/8] format.sh --- neuralmagic/tools/profiler/visualize_trace.py | 159 ++++++++++-------- 1 file changed, 91 insertions(+), 68 deletions(-) diff --git a/neuralmagic/tools/profiler/visualize_trace.py b/neuralmagic/tools/profiler/visualize_trace.py index 5579a439f843a..8def1ade1598b 100644 --- a/neuralmagic/tools/profiler/visualize_trace.py +++ b/neuralmagic/tools/profiler/visualize_trace.py @@ -1,16 +1,17 @@ import argparse +import copy import json import math -import copy from pathlib import Path -from typing import Optional, List, Tuple, Any +from typing import Any, List, Optional, Tuple import matplotlib.pyplot as plt import pandas as pd ## JSON parsing utils #### -def largest_dist_from_leaf(node: dict, depth:int=0): + +def largest_dist_from_leaf(node: dict, depth: int = 0): if len(node["children"]) == 0: return depth return max([ @@ -18,18 +19,18 @@ def largest_dist_from_leaf(node: dict, depth:int=0): for child in node["children"] ]) + def get_entries_at_depth(depth: int, entries_and_traces: List[Tuple[Any, Any]], node: dict, ignore_sampler: bool, - curr_depth:int =0, + curr_depth: int = 0, trace=()): if ignore_sampler and node["entry"]["name"] == "Sampler": return if (depth >= 0 and depth == curr_depth) or ( - depth < 0 - and largest_dist_from_leaf(node) == (abs(depth) - 1)): + depth < 0 and 
largest_dist_from_leaf(node) == (abs(depth) - 1)): entries_and_traces.append((node["entry"], trace)) trace = (node["entry"]["name"], ) + trace for child in node["children"]: @@ -40,8 +41,10 @@ def get_entries_at_depth(depth: int, curr_depth=curr_depth + 1, trace=trace) + ## Operation name cleanup utils #### + def trim_string_back(string: str, width: int) -> str: if len(string) > width: offset = len(string) - width + 3 @@ -50,12 +53,14 @@ def trim_string_back(string: str, width: int) -> str: string = string + "..." return string + def shorten_plot_legend_strings(legend, max_char_len: int): for t in legend.get_texts(): t.set_text( trim_string_back(abbreviate_known_names(t.get_text()), max_char_len)) + def abbreviate_known_names(name: str) -> str: abbreviations = { "MergedColumnParallelLinear": "MCPLinear", @@ -69,6 +74,7 @@ def abbreviate_known_names(name: str) -> str: name = name.replace(key, value) return name + def attempt_to_make_names_unique(entries_and_traces): names, non_unique_names = (set(), set()) @@ -82,10 +88,9 @@ def all_the_same(items) -> bool: names.add(entry["name"]) for name in non_unique_names: - entries_and_traces_with_name = [ - (entry, trace) for entry, trace in entries_and_traces - if entry["name"] == name - ] + entries_and_traces_with_name = [(entry, trace) + for entry, trace in entries_and_traces + if entry["name"] == name] zipped_traces = list( zip(*[trace for _, trace in entries_and_traces_with_name])) @@ -102,8 +107,8 @@ def all_the_same(items) -> bool: entry["name"] = " <- ".join((entry["name"], ) + trace[:first_trace_difference + 1]) -## Operation grouping utils #### +## Operation grouping utils #### ''' Group operations in the given dataframe by some high-level ops like, - gemms @@ -111,13 +116,15 @@ def all_the_same(items) -> bool: - rms_norm etc. 
''' -def group_trace_by_operations(trace_df : pd.DataFrame) -> pd.DataFrame: + + +def group_trace_by_operations(trace_df: pd.DataFrame) -> pd.DataFrame: def is_rms_norm(op_name: str): if "rms_norm_kernel" in op_name: return True - def is_attention_block(op_name:str): + def is_attention_block(op_name: str): if "flash_fwd" in op_name or \ "reshape_and_cache_flash_kernel" in op_name: return True @@ -127,7 +134,7 @@ def is_quant(op_name: str): "scaled_int8_quant" in op_name: return True - def is_gemm_op(op_name : str): + def is_gemm_op(op_name: str): if is_quant(op_name): return False if "xmma_gemm" in op_name or \ @@ -138,7 +145,7 @@ def is_gemm_op(op_name : str): "s16816gemm" in op_name: return True - def is_elementwise_op(op_name : str): + def is_elementwise_op(op_name: str): return "elementwise_kernel" in op_name def is_mem_op(op_name: str): @@ -148,7 +155,7 @@ def is_mem_op(op_name: str): def is_vocab_embedding_op(op_name: str): return "vocabparallelembed" in op_name.lower() - headers = list(trace_df) + headers = list(trace_df) ops = copy.deepcopy(headers) attention_ops = list(filter(lambda x: is_attention_block(x), ops)) @@ -172,8 +179,6 @@ def is_vocab_embedding_op(op_name: str): elementwise_ops = list(filter(lambda x: is_elementwise_op(x), ops)) ops = list(filter(lambda x: x not in elementwise_ops, ops)) - remaining_ops = ops - if len(attention_ops): trace_df['attention'] = trace_df[attention_ops].agg("sum", axis=1) if len(quant_ops): @@ -183,18 +188,24 @@ def is_vocab_embedding_op(op_name: str): if len(rms_norm_ops): trace_df['rms_norm_ops'] = trace_df[rms_norm_ops].agg("sum", axis=1) if len(vocab_embed_ops): - trace_df['vocab_embed_ops'] = trace_df[vocab_embed_ops].agg("sum", axis=1) + trace_df['vocab_embed_ops'] = trace_df[vocab_embed_ops].agg("sum", + axis=1) if len(mem_ops): trace_df['mem_ops'] = trace_df[mem_ops].agg("sum", axis=1) if len(elementwise_ops): - trace_df['elementwise_ops'] = trace_df[elementwise_ops].agg("sum", axis=1) + trace_df['elementwise_ops'] = trace_df[elementwise_ops].agg("sum", + axis=1) - trace_df.drop(attention_ops + quant_ops + gemm_ops + rms_norm_ops + vocab_embed_ops + mem_ops + elementwise_ops, - axis=1, inplace=True) + trace_df.drop(attention_ops + quant_ops + gemm_ops + rms_norm_ops + + vocab_embed_ops + mem_ops + elementwise_ops, + axis=1, + inplace=True) return trace_df + ## Data plotting utils #### + def plot_trace_df(traces_df: pd.DataFrame, plot_title: str, output: Optional[Path] = None): @@ -204,9 +215,9 @@ def plot_trace_df(traces_df: pd.DataFrame, phases = traces_df['phase'].unique() traces_df = traces_df.pivot_table(index="phase", - columns="name", - values="cuda_time_ms", - aggfunc="sum") + columns="name", + values="cuda_time_ms", + aggfunc="sum") traces_df = group_trace_by_operations(traces_df) @@ -225,12 +236,13 @@ def plot_trace_df(traces_df: pd.DataFrame, # Write the values as text on the bars for bar in ax.patches: if bar.get_height() != 0: - v = round(bar.get_height(), 2) ax.text(bar.get_x() + bar.get_width() / 2, - bar.get_height() / 2 + bar.get_y(), - f"{round(bar.get_height(), 2)}", - ha = 'center', color = 'w', weight = 'bold', - size = 5) + bar.get_height() / 2 + bar.get_y(), + f"{round(bar.get_height(), 2)}", + ha='center', + color='w', + weight='bold', + size=5) # Setup legend handles, labels = plt.gca().get_legend_handles_labels() @@ -248,22 +260,27 @@ def plot_trace_df(traces_df: pd.DataFrame, plt.savefig(output, bbox_inches='tight') print("Created: ", output) -def main(json_trace: Path, - output_directory: Path, - depth: 
int, # Fetch/Plot operations at this depth of the Json tree - make_names_unique: bool, - top_k: int, - ignore_sampler: bool): + +def main( + json_trace: Path, + output_directory: Path, + depth: int, # Fetch/Plot operations at this depth of the Json tree + make_names_unique: bool, + top_k: int, + ignore_sampler: bool): def prepare_data(profile_json: dict, step_keys: List[str]) -> pd.DataFrame: def get_entries_and_traces(key: str): - entries_and_traces : List[Tuple[Any, Any]] = [] + entries_and_traces: List[Tuple[Any, Any]] = [] for root in profile_json[key]["summary_stats"]: - get_entries_at_depth(depth, entries_and_traces, root, ignore_sampler) + get_entries_at_depth(depth, entries_and_traces, root, + ignore_sampler) return entries_and_traces - def keep_only_top_entries(df : pd.DataFrame, metric: str, top_k:int =9) -> pd.DataFrame: + def keep_only_top_entries(df: pd.DataFrame, + metric: str, + top_k: int = 9) -> pd.DataFrame: df.loc[df.nsmallest(len(df) - top_k + 1, metric).index, ["name"]] = "others" return df @@ -273,17 +290,20 @@ def keep_only_top_entries(df : pd.DataFrame, metric: str, top_k:int =9) -> pd.Da # Attempt some cleanup if make_names_unique: - traces = list(map(lambda x: attempt_to_make_names_unique(x), traces)) + traces = list( + map(lambda x: attempt_to_make_names_unique(x), traces)) # To pandas dataframe - trace_dfs = list(map(lambda t: pd.DataFrame( - [entry for entry, _ in t]).fillna(0), - traces)) + trace_dfs = list( + map(lambda t: pd.DataFrame([entry for entry, _ in t]).fillna(0), + traces)) # Respect top_k if top_k: - trace_dfs = list(map(lambda trace_df: keep_only_top_entries( - trace_df, "cuda_time_us", top_k), trace_dfs)) + trace_dfs = list( + map( + lambda trace_df: keep_only_top_entries( + trace_df, "cuda_time_us", top_k), trace_dfs)) # Fill in information about the step-keys for trace_df, step_key in zip(trace_dfs, step_keys): @@ -296,12 +316,12 @@ def keep_only_top_entries(df : pd.DataFrame, metric: str, top_k:int =9) -> pd.Da def make_plot_title_suffix(profile_json: dict) -> str: context = profile_json["context"] sparsity = context.get('sparsity', None) - return (f"{context['model']}\n" - f"Batch={context['batch_size']}, " - f"PromptLen={context['prompt_len']}, " - f"OutputLen={context['output_len']}," - f"NumGpus={context['tensor_parallel_size']}" - f"{', Sparsity ' + sparsity if sparsity else ''}") + return (f"{context['model']}\n" + f"Batch={context['batch_size']}, " + f"PromptLen={context['prompt_len']}, " + f"OutputLen={context['output_len']}," + f"NumGpus={context['tensor_parallel_size']}" + f"{', Sparsity ' + sparsity if sparsity else ''}") profile_json = None with open(json_trace, "r") as f: @@ -311,7 +331,7 @@ def make_plot_title_suffix(profile_json: dict) -> str: # Get all `llm.generate.step()` profile step_traces = list(profile_json.keys()) assert (step_traces[0] == 'context') - step_traces = step_traces[1:] # have only prefill and decodes + step_traces = step_traces[1:] # have only prefill and decodes prefills = list(filter(lambda x: "prefill" in x, step_traces)) all_decodes = list(filter(lambda x: "decode" in x, step_traces)) assert len(prefills) + len(all_decodes) == len(step_traces) @@ -327,8 +347,11 @@ def make_plot_title_suffix(profile_json: dict) -> str: plot_title_suffix = make_plot_title_suffix(profile_json) - plot_trace_df(prefill_traces, "prefill " + plot_title_suffix, output_directory / Path("prefill.png")) - plot_trace_df(decode_traces, "decodes " + plot_title_suffix, output_directory / Path("decode_steps.png")) + 
plot_trace_df(prefill_traces, "prefill " + plot_title_suffix, + output_directory / Path("prefill.png")) + plot_trace_df(decode_traces, "decodes " + plot_title_suffix, + output_directory / Path("decode_steps.png")) + if __name__ == "__main__": parser = argparse.ArgumentParser() @@ -338,26 +361,26 @@ def make_plot_title_suffix(profile_json: dict) -> str: type=str, required=True, help="json trace file output by examples/offline_profile.py") - parser.add_argument( - "--output-directory", - type=str, - required=False, - help="Directory to output plots") + parser.add_argument("--output-directory", + type=str, + required=False, + help="Directory to output plots") parser.add_argument("--level", type=str, default="module", choices=["module", "kernel"]) - parser.add_argument("--top_k", + parser.add_argument("--top-k", type=int, default=9, help="Only graph the top `top_k` entries by time.") parser.add_argument("--ignore_sampler", action='store_true', help="Ignore everything under the \"Sampler\" module") - parser.add_argument("--step-plot-interval", - type=int, - default=4, - help="For every `step_plot_interval` steps, plot 1 step") + parser.add_argument( + "--step-plot-interval", + type=int, + default=4, + help="For every `step_plot_interval` steps, plot 1 step") args = parser.parse_args() @@ -371,12 +394,12 @@ def make_plot_title_suffix(profile_json: dict) -> str: else: raise Exception(f"Unexpected level value ({args.level})") - output_directory = args.output_directory if args.output_directory else Path(args.json_trace).parent + output_directory = args.output_directory if args.output_directory else Path( + args.json_trace).parent if args.ignore_sampler: print("WARNING: ignoring Sampler time so the pct_cuda_time will not " "add up to 100%") - main (Path(args.json_trace), output_directory, - depth, make_names_unique, - args.top_k, args.ignore_sampler) + main(Path(args.json_trace), output_directory, depth, make_names_unique, + args.top_k, args.ignore_sampler) From 6c7070bb83e8b6a7603d4de217ba9951c36c0d9a Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Tue, 9 Jul 2024 17:18:35 +0000 Subject: [PATCH 3/8] add fold nodes --- neuralmagic/tools/profiler/visualize_trace.py | 53 +++++++++++-------- 1 file changed, 32 insertions(+), 21 deletions(-) diff --git a/neuralmagic/tools/profiler/visualize_trace.py b/neuralmagic/tools/profiler/visualize_trace.py index 8def1ade1598b..e4e678c92d629 100644 --- a/neuralmagic/tools/profiler/visualize_trace.py +++ b/neuralmagic/tools/profiler/visualize_trace.py @@ -23,24 +23,38 @@ def largest_dist_from_leaf(node: dict, depth: int = 0): def get_entries_at_depth(depth: int, entries_and_traces: List[Tuple[Any, Any]], node: dict, - ignore_sampler: bool, curr_depth: int = 0, trace=()): - if ignore_sampler and node["entry"]["name"] == "Sampler": - return + # assert that the query is at kernel or module level + assert depth == -1 or depth == -2 - if (depth >= 0 and depth == curr_depth) or ( - depth < 0 and largest_dist_from_leaf(node) == (abs(depth) - 1)): + if curr_depth == 0 and largest_dist_from_leaf(node) <= (abs(depth) - 1): + # The tree is not tall enough! 
+ entries_and_traces.append((node["entry"] , trace)) + return + + if largest_dist_from_leaf(node) == (abs(depth) - 1): entries_and_traces.append((node["entry"], trace)) + trace = (node["entry"]["name"], ) + trace for child in node["children"]: get_entries_at_depth(depth, entries_and_traces, child, - ignore_sampler, curr_depth=curr_depth + 1, trace=trace) +def fold_nodes(root: dict, nodes_to_fold : List[str]): + + stack : List[dict] = [root] + while len(stack) != 0: + node = stack.pop() + if node['entry']['name'] in nodes_to_fold: + node["children"] = [] + continue + for child in node["children"]: + stack.append(child) + return root ## Operation name cleanup utils #### @@ -260,22 +274,23 @@ def plot_trace_df(traces_df: pd.DataFrame, plt.savefig(output, bbox_inches='tight') print("Created: ", output) - def main( json_trace: Path, output_directory: Path, depth: int, # Fetch/Plot operations at this depth of the Json tree make_names_unique: bool, top_k: int, - ignore_sampler: bool): + json_nodes_to_fold: List[str]): def prepare_data(profile_json: dict, step_keys: List[str]) -> pd.DataFrame: def get_entries_and_traces(key: str): entries_and_traces: List[Tuple[Any, Any]] = [] for root in profile_json[key]["summary_stats"]: - get_entries_at_depth(depth, entries_and_traces, root, - ignore_sampler) + # Fold nodes in the traces as per user request. i.e. simply + # make the requested nodes leaf-nodes. + root = fold_nodes(root, json_nodes_to_fold) + get_entries_at_depth(depth, entries_and_traces, root) return entries_and_traces def keep_only_top_entries(df: pd.DataFrame, @@ -290,8 +305,8 @@ def keep_only_top_entries(df: pd.DataFrame, # Attempt some cleanup if make_names_unique: - traces = list( - map(lambda x: attempt_to_make_names_unique(x), traces)) + for trace in traces: + attempt_to_make_names_unique(trace) # To pandas dataframe trace_dfs = list( @@ -371,11 +386,11 @@ def make_plot_title_suffix(profile_json: dict) -> str: choices=["module", "kernel"]) parser.add_argument("--top-k", type=int, - default=9, + default=12, help="Only graph the top `top_k` entries by time.") - parser.add_argument("--ignore_sampler", - action='store_true', - help="Ignore everything under the \"Sampler\" module") + parser.add_argument("--fold-json-node", + nargs='+', + default=['Sampler', 'LogitsProcessor']) parser.add_argument( "--step-plot-interval", type=int, @@ -397,9 +412,5 @@ def make_plot_title_suffix(profile_json: dict) -> str: output_directory = args.output_directory if args.output_directory else Path( args.json_trace).parent - if args.ignore_sampler: - print("WARNING: ignoring Sampler time so the pct_cuda_time will not " - "add up to 100%") - main(Path(args.json_trace), output_directory, depth, make_names_unique, - args.top_k, args.ignore_sampler) + args.top_k, args.fold_json_node) From 7728378afbbbb466ea19fde1a29700bdb52e683a Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Tue, 9 Jul 2024 17:18:58 +0000 Subject: [PATCH 4/8] format --- neuralmagic/tools/profiler/visualize_trace.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/neuralmagic/tools/profiler/visualize_trace.py b/neuralmagic/tools/profiler/visualize_trace.py index e4e678c92d629..b8ccdf612aca7 100644 --- a/neuralmagic/tools/profiler/visualize_trace.py +++ b/neuralmagic/tools/profiler/visualize_trace.py @@ -30,8 +30,8 @@ def get_entries_at_depth(depth: int, if curr_depth == 0 and largest_dist_from_leaf(node) <= (abs(depth) - 1): # The tree is not tall enough! 
- entries_and_traces.append((node["entry"] , trace)) - return + entries_and_traces.append((node["entry"], trace)) + return if largest_dist_from_leaf(node) == (abs(depth) - 1): entries_and_traces.append((node["entry"], trace)) @@ -44,9 +44,10 @@ def get_entries_at_depth(depth: int, curr_depth=curr_depth + 1, trace=trace) -def fold_nodes(root: dict, nodes_to_fold : List[str]): - stack : List[dict] = [root] +def fold_nodes(root: dict, nodes_to_fold: List[str]): + + stack: List[dict] = [root] while len(stack) != 0: node = stack.pop() if node['entry']['name'] in nodes_to_fold: @@ -56,6 +57,7 @@ def fold_nodes(root: dict, nodes_to_fold : List[str]): stack.append(child) return root + ## Operation name cleanup utils #### @@ -274,6 +276,7 @@ def plot_trace_df(traces_df: pd.DataFrame, plt.savefig(output, bbox_inches='tight') print("Created: ", output) + def main( json_trace: Path, output_directory: Path, From 2b9a9f314d77202b265c2d9ce18d74699ab3818d Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Tue, 9 Jul 2024 23:18:32 +0530 Subject: [PATCH 5/8] Add plot-metric arc --- neuralmagic/tools/profiler/visualize_trace.py | 24 ++++++++++++------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/neuralmagic/tools/profiler/visualize_trace.py b/neuralmagic/tools/profiler/visualize_trace.py index b8ccdf612aca7..844a53c4ac5a3 100644 --- a/neuralmagic/tools/profiler/visualize_trace.py +++ b/neuralmagic/tools/profiler/visualize_trace.py @@ -223,16 +223,14 @@ def is_vocab_embedding_op(op_name: str): def plot_trace_df(traces_df: pd.DataFrame, + plot_metric: str, plot_title: str, output: Optional[Path] = None): - traces_df["cuda_time_ms"] = traces_df["cuda_time_us"] / 1000 - traces_df = traces_df.fillna(0) - phases = traces_df['phase'].unique() traces_df = traces_df.pivot_table(index="phase", columns="name", - values="cuda_time_ms", + values=plot_metric, aggfunc="sum") traces_df = group_trace_by_operations(traces_df) @@ -281,6 +279,7 @@ def main( json_trace: Path, output_directory: Path, depth: int, # Fetch/Plot operations at this depth of the Json tree + plot_metric: str, make_names_unique: bool, top_k: int, json_nodes_to_fold: List[str]): @@ -329,6 +328,11 @@ def keep_only_top_entries(df: pd.DataFrame, # Combine all data frames so they can be put in a single plot traces_df = pd.concat(trace_dfs) + + # Add a derived metric `cuda_time_ms` + traces_df["cuda_time_ms"] = traces_df["cuda_time_us"] / 1000 + traces_df = traces_df.fillna(0) + return traces_df def make_plot_title_suffix(profile_json: dict) -> str: @@ -365,9 +369,9 @@ def make_plot_title_suffix(profile_json: dict) -> str: plot_title_suffix = make_plot_title_suffix(profile_json) - plot_trace_df(prefill_traces, "prefill " + plot_title_suffix, + plot_trace_df(prefill_traces, plot_metric, "prefill " + plot_title_suffix, output_directory / Path("prefill.png")) - plot_trace_df(decode_traces, "decodes " + plot_title_suffix, + plot_trace_df(decode_traces, plot_metric, "decodes " + plot_title_suffix, output_directory / Path("decode_steps.png")) @@ -394,6 +398,10 @@ def make_plot_title_suffix(profile_json: dict) -> str: parser.add_argument("--fold-json-node", nargs='+', default=['Sampler', 'LogitsProcessor']) + parser.add_argument("--plot-metric", + type=str, + default="cuda_time_ms", + help='Metric to plot. 
some options are cuda_time_us, cuda_time_ms, pct_cuda_time')
     parser.add_argument(
         "--step-plot-interval",
         type=int,
@@ -415,5 +423,5 @@ def make_plot_title_suffix(profile_json: dict) -> str:
     output_directory = args.output_directory if args.output_directory else Path(
         args.json_trace).parent
 
-    main(Path(args.json_trace), output_directory, depth, make_names_unique,
-         args.top_k, args.fold_json_node)
+    main(Path(args.json_trace), output_directory, depth,
+         args.plot_metric, make_names_unique, args.top_k, args.fold_json_node)

From 68f649bcd0b44db9e81e3eb2c88e7aed62240011 Mon Sep 17 00:00:00 2001
From: Varun Sundar Rabindranath
Date: Tue, 9 Jul 2024 18:51:15 +0000
Subject: [PATCH 6/8] ruff

---
 neuralmagic/tools/profiler/visualize_trace.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/neuralmagic/tools/profiler/visualize_trace.py b/neuralmagic/tools/profiler/visualize_trace.py
index 844a53c4ac5a3..b45d0f83fdbf4 100644
--- a/neuralmagic/tools/profiler/visualize_trace.py
+++ b/neuralmagic/tools/profiler/visualize_trace.py
@@ -401,7 +401,8 @@ def make_plot_title_suffix(profile_json: dict) -> str:
     parser.add_argument("--plot-metric",
                         type=str,
                         default="cuda_time_ms",
-                        help='Metric to plot. some options are cuda_time_us, cuda_time_ms, pct_cuda_time')
+                        help='Metric to plot. Some options are cuda_time_ms, \
+                            pct_cuda_time')
     parser.add_argument(
         "--step-plot-interval",
         type=int,

From cd81bcc9787660757e6a27b89a0d58590c02b791 Mon Sep 17 00:00:00 2001
From: Varun Sundar Rabindranath
Date: Tue, 9 Jul 2024 18:53:11 +0000
Subject: [PATCH 7/8] fix plot y label

---
 neuralmagic/tools/profiler/visualize_trace.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/neuralmagic/tools/profiler/visualize_trace.py b/neuralmagic/tools/profiler/visualize_trace.py
index b45d0f83fdbf4..8d9ff24916b8a 100644
--- a/neuralmagic/tools/profiler/visualize_trace.py
+++ b/neuralmagic/tools/profiler/visualize_trace.py
@@ -268,7 +268,7 @@ def plot_trace_df(traces_df: pd.DataFrame,
 
     # Setup labels and title
     plt.setp(ax.get_xticklabels(), rotation=90)
-    ax.set_ylabel("time ms")
+    ax.set_ylabel(plot_metric)
     plt.suptitle(plot_title)
 
     plt.savefig(output, bbox_inches='tight')

From 1ebec11f21a0d0de4b8ea1a055636fa43218ada8 Mon Sep 17 00:00:00 2001
From: Varun Sundar Rabindranath
Date: Mon, 5 Aug 2024 13:19:21 +0000
Subject: [PATCH 8/8] add help desc

---
 neuralmagic/tools/profiler/visualize_trace.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/neuralmagic/tools/profiler/visualize_trace.py b/neuralmagic/tools/profiler/visualize_trace.py
index 8d9ff24916b8a..f4b1449281f69 100644
--- a/neuralmagic/tools/profiler/visualize_trace.py
+++ b/neuralmagic/tools/profiler/visualize_trace.py
@@ -397,7 +397,10 @@ def make_plot_title_suffix(profile_json: dict) -> str:
                         help="Only graph the top `top_k` entries by time.")
     parser.add_argument("--fold-json-node",
                         nargs='+',
-                        default=['Sampler', 'LogitsProcessor'])
+                        default=['Sampler', 'LogitsProcessor'],
+                        help='Do not plot the children of these nodes. Let \
+                            the node represent the aggregate of all its \
+                            children')
     parser.add_argument("--plot-metric",
                         type=str,
                         default="cuda_time_ms",
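
Reviewer note (illustration, not part of the patch series): the fold/depth
machinery added above is easiest to sanity-check on a toy tree shaped like
the {"entry": ..., "children": [...]} nodes the script walks. The sketch
below is hypothetical: the import path, module/kernel names, and timings
are all made up, and it assumes visualize_trace.py is importable as a
module.

    # Hypothetical smoke test for fold_nodes + get_entries_at_depth.
    from visualize_trace import fold_nodes, get_entries_at_depth

    # Toy profile tree: a model root with a Sampler module and an RMSNorm
    # module, each wrapping a single kernel (names/times invented).
    root = {
        "entry": {"name": "LlamaForCausalLM", "cuda_time_us": 100},
        "children": [
            {"entry": {"name": "Sampler", "cuda_time_us": 10},
             "children": [{"entry": {"name": "softmax_kernel",
                                     "cuda_time_us": 10},
                           "children": []}]},
            {"entry": {"name": "RMSNorm", "cuda_time_us": 90},
             "children": [{"entry": {"name": "rms_norm_kernel",
                                     "cuda_time_us": 90},
                           "children": []}]},
        ],
    }

    # Folding turns "Sampler" into a leaf: its kernels vanish and the
    # module entry stands in for their aggregate time.
    root = fold_nodes(root, ["Sampler"])

    entries_and_traces = []
    get_entries_at_depth(-1, entries_and_traces, root)  # kernel level
    print([entry["name"] for entry, _ in entries_and_traces])
    # -> ['Sampler', 'rms_norm_kernel']; the folded module is reported at
    #    leaf level in place of its children.

This is the same folding that --fold-json-node drives from the CLI, with
the depth of -1 or -2 selected via --level kernel or --level module.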