From 5388722646feb176b6ac5e162d065974ea648968 Mon Sep 17 00:00:00 2001 From: George Petterson Date: Tue, 24 Sep 2024 10:59:16 -0500 Subject: [PATCH 1/5] Add the general SD api framework --- apps/shark_studio/api/shark_api.py | 200 +++++++++++++++++++++++++++++ 1 file changed, 200 insertions(+) create mode 100644 apps/shark_studio/api/shark_api.py diff --git a/apps/shark_studio/api/shark_api.py b/apps/shark_studio/api/shark_api.py new file mode 100644 index 0000000000..93cf255e92 --- /dev/null +++ b/apps/shark_studio/api/shark_api.py @@ -0,0 +1,200 @@ + +# Internal API + +# Used for filenames as well as the key for the global cache +def safe_name(): + pass + +def local_path(): + pass + +def generate_sd_vmfb( + model: str, + height: int, + width: int, + steps: int, + strength: float, + guidance_scale: float, + batch_size: int = 1, + base_model_id: str, + precision: str, + controlled: bool, + **kwargs, +): + pass + +def load_sd_vmfb( + model: str, + weight_file: str, + height: int, + width: int, + steps: int, + strength: float, + guidance_scale: float, + batch_size: int = 1, + base_model: str, + precision: str, + controlled: bool, + try_download: bool, + **kwargs, +): + # Check if the file is already loaded and cached + # Check if the file already exists on disk + # Try to download from the web + # Generate the vmfb (generate_sd_vmfb) + # Load the vmfb and weights + # Return wrapper + pass + +# External API +def generate_images( + prompt: str, + negative_prompt: str, + *, + height: int = 512, + width: int = 512, + steps: int = 20, + strength: float = 0.8, + sd_init_image: list = None, + guidance_scale: float = 7.5, + seed: list = -1, + batch_count: int = 1, + batch_size: int = 1, + scheduler: str = "EulerDiscrete", + base_model: str = "sd2", + custom_weights: str = None, + custom_vae: str = None, + precision: str = "fp16", + device: str = "cpu", + target_triple: str = None, + ondemand: bool = False, + compiled_pipeline: bool = False, + resample_type: str = "Nearest Neighbor", + controlnets: dict = {}, + embeddings: dict = {}, + **kwargs, +): + sd_kwargs = locals() + + # Handle img2img + if not isinstance(sd_init_image, list): + sd_init_image = [sd_init_image] + is_img2img = True if sd_init_image[0] is not None else False + + # Generate seed if < 0 + # TODO + + # Sanity checks + # Scheduler + # Base model + # Custom weights + # Custom VAE + # Precision + # Device + # Target triple + # Resample type + # TODO + + adapters = {} + is_controlled = False + control_mode = None + hints = [] + num_loras = 0 + import_ir = True + + # Populate model map + if model == "sd1.5": + submodels = { + "clip": None, + "scheduler": None, + "unet": None, + "vae_decode": None, + } + elif model == "sd2": + submodels = { + "clip": None, + "scheduler": None, + "unet": None, + "vae_decode": None, + } + elif model == "sdxl": + submodels = { + "prompt_encoder": None, + "scheduled_unet": None, + "vae_decode": None, + "pipeline": None, + "full_pipeline": None, + } + elif model == "sd3": + pass + + # TODO: generate and load submodel vmfbs + for submodel in submodels: + submodels[submodel] = load_sd_vmfb( + submodel, + custom_weights, + height, + width, + steps, + strength, + guidance_scale, + batch_size, + model, + precision, + not controlnets.keys(), + True, + ) + + generated_imgs = [] + for current_batch in range(batch_count): + + # TODO: Batch size > 1 + + # TODO: random sample (or img2img input) + sample = None + + # TODO: encode input + prompt_embeds, negative_prompt_embeds = encode(prompt, negative_prompt) + + 
start_time = time.time() + for t in range(steps): + + # Prepare latents + + # Scale model input + latent_model_input = submodels["scheduler"].scale_model_input( + sample, + t + ) + + # Run unet + latents = submodels["unet"]( + latent_model_input, + t, + (negative_prompt_embeds, prompt_embeds), + guidance_scale, + ) + + # Step scheduler + sample = submodels["scheduler"].step( + latents, + t, + sample + ) + + # VAE decode + out_img = submodels["vae_decode"]( + sample + ) + + # Processing time + total_time = time.time() - start_time + # text_output = f"Total image(s) generation time: {total_time:.4f}sec" + # print(f"\n[LOG] {text_output}") + + # TODO: Add to output list + generated_imgs.append(out_img) + + # TODO: Allow the user to halt the process + + return generated_imgs \ No newline at end of file From 38b6248dd4b6707692ce85c22825c8752a722d80 Mon Sep 17 00:00:00 2001 From: George Petterson Date: Thu, 26 Sep 2024 11:07:57 -0500 Subject: [PATCH 2/5] Address comment --- apps/shark_studio/api/shark_api.py | 216 ++++++++++++----------------- 1 file changed, 87 insertions(+), 129 deletions(-) diff --git a/apps/shark_studio/api/shark_api.py b/apps/shark_studio/api/shark_api.py index 93cf255e92..f772169df6 100644 --- a/apps/shark_studio/api/shark_api.py +++ b/apps/shark_studio/api/shark_api.py @@ -1,51 +1,26 @@ - # Internal API +pipelines = { + "sd1.5": ("", None), + "sd2": ("", None), + "sdxl": ("", None), + "sd3": ("", None), +} -# Used for filenames as well as the key for the global cache -def safe_name(): - pass - -def local_path(): - pass -def generate_sd_vmfb( - model: str, +# Used for filenames as well as the key for the global cache +def safe_name( + model_name: str, height: int, width: int, - steps: int, - strength: float, - guidance_scale: float, - batch_size: int = 1, - base_model_id: str, - precision: str, - controlled: bool, - **kwargs, + batch_size: int, ): pass -def load_sd_vmfb( - model: str, - weight_file: str, - height: int, - width: int, - steps: int, - strength: float, - guidance_scale: float, - batch_size: int = 1, - base_model: str, - precision: str, - controlled: bool, - try_download: bool, - **kwargs, -): - # Check if the file is already loaded and cached - # Check if the file already exists on disk - # Try to download from the web - # Generate the vmfb (generate_sd_vmfb) - # Load the vmfb and weights - # Return wrapper + +def local_path(): pass + # External API def generate_images( prompt: str, @@ -78,123 +53,106 @@ def generate_images( # Handle img2img if not isinstance(sd_init_image, list): - sd_init_image = [sd_init_image] + sd_init_image = [sd_init_image] * batch_count is_img2img = True if sd_init_image[0] is not None else False # Generate seed if < 0 # TODO + # Cache dir + # TODO + pipeline_dir = None + # Sanity checks - # Scheduler - # Base model + assert scheduler in ["EulerDiscrete"] + assert base_model in ["sd1.5", "sd2", "sdxl", "sd3"] + assert precision in ["fp16", "fp32"] + assert device in [ + "cpu", + "vulkan", + "rocm", + "hip", + "cuda", + ] # and (IREE check if the device exists) + assert resample_type in ["Nearest Neighbor"] + # Custom weights + # TODO # Custom VAE - # Precision - # Device + # TODO # Target triple - # Resample type # TODO - adapters = {} - is_controlled = False - control_mode = None - hints = [] - num_loras = 0 - import_ir = True - - # Populate model map - if model == "sd1.5": - submodels = { - "clip": None, - "scheduler": None, - "unet": None, - "vae_decode": None, - } - elif model == "sd2": - submodels = { - "clip": None, - 
"scheduler": None, - "unet": None, - "vae_decode": None, - } - elif model == "sdxl": - submodels = { - "prompt_encoder": None, - "scheduled_unet": None, - "vae_decode": None, - "pipeline": None, - "full_pipeline": None, - } - elif model == "sd3": + # (Re)initialize pipeline + pipeline_args = { + "height": height, + "width": width, + "batch_size": batch_size, + "precision": precision, + "device": device, + "target_triple": target_triple, + } + (existing_args, pipeline) = pipelines[base_model] + if not existing_args or not pipeline or not pipeline_args == existing_args: + # TODO: Initialize new pipeline + if base_model == "sd1.5": + pass + elif base_model == "sd2": + new_pipeline = SharkSDPipeline( + hf_model_name="stabilityai/stable-diffusion-2-1", + scheduler_id=scheduler, + height=height, + width=width, + precision=precision, + max_length=64, + batch_size=batch_size, + num_inference_steps=steps, + device=device, # TODO: Get the IREE device ID? + iree_target_triple=target_triple, + ireec_flags={}, + attn_spec=None, # TODO: Find a better way to figure this out than hardcoding + decomp_attn=True, # TODO: Ditto + pipeline_dir=pipeline_dir, + external_weights_dir=weights, # TODO: Are both necessary still? + external_weights=weights, + custom_vae=custom_vae, + ) + elif base_model == "sdxl": + pass + elif base_model == "sd3": + pass + # existing_args = pipeline_args pass - # TODO: generate and load submodel vmfbs - for submodel in submodels: - submodels[submodel] = load_sd_vmfb( - submodel, - custom_weights, - height, - width, - steps, - strength, - guidance_scale, - batch_size, - model, - precision, - not controlnets.keys(), - True, - ) - - generated_imgs = [] + generated_images = [] for current_batch in range(batch_count): - # TODO: Batch size > 1 - - # TODO: random sample (or img2img input) - sample = None - - # TODO: encode input - prompt_embeds, negative_prompt_embeds = encode(prompt, negative_prompt) - start_time = time.time() for t in range(steps): - - # Prepare latents - - # Scale model input - latent_model_input = submodels["scheduler"].scale_model_input( - sample, - t - ) - # Run unet - latents = submodels["unet"]( - latent_model_input, - t, - (negative_prompt_embeds, prompt_embeds), - guidance_scale, + out_images = pipeline.generate_images( + prompt=prompt, + negative_prompt=negative_prompt, + image=sd_init_image[current_batch], + strength=strength, + guidance_scale=guidance_scale, + seed=seed, + ondemand=ondemand, + resample_type=resample_type, + control_mode=control_mode, + hints=hints, ) - # Step scheduler - sample = submodels["scheduler"].step( - latents, - t, - sample - ) - - # VAE decode - out_img = submodels["vae_decode"]( - sample - ) - # Processing time total_time = time.time() - start_time # text_output = f"Total image(s) generation time: {total_time:.4f}sec" # print(f"\n[LOG] {text_output}") # TODO: Add to output list - generated_imgs.append(out_img) + if not isinstance(out_images, list): + out_images = [out_images] + generated_images.extend(out_images) # TODO: Allow the user to halt the process - return generated_imgs \ No newline at end of file + return generated_images From 077815c4d8071d33c762de0ee14e515a193d98d6 Mon Sep 17 00:00:00 2001 From: George Petterson Date: Thu, 26 Sep 2024 11:09:12 -0500 Subject: [PATCH 3/5] Small fix --- apps/shark_studio/api/shark_api.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/apps/shark_studio/api/shark_api.py b/apps/shark_studio/api/shark_api.py index f772169df6..b171bb3d3a 100644 --- 
a/apps/shark_studio/api/shark_api.py
+++ b/apps/shark_studio/api/shark_api.py
@@ -121,8 +121,9 @@ def generate_images(
             pass
         elif base_model == "sd3":
             pass
-        # existing_args = pipeline_args
-        pass
+        existing_args = pipeline_args
+        pipeline = new_pipeline
+        pipelines[base_model] = (existing_args, pipeline)
 
     generated_images = []
     for current_batch in range(batch_count):

From acfa5617d747a50cc4c81be9b8140345bd70a24f Mon Sep 17 00:00:00 2001
From: George Petterson
Date: Thu, 26 Sep 2024 11:12:44 -0500
Subject: [PATCH 4/5] 1.5 support

---
 apps/shark_studio/api/shark_api.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/apps/shark_studio/api/shark_api.py b/apps/shark_studio/api/shark_api.py
index b171bb3d3a..ed13b4f697 100644
--- a/apps/shark_studio/api/shark_api.py
+++ b/apps/shark_studio/api/shark_api.py
@@ -95,11 +95,9 @@ def generate_images(
     (existing_args, pipeline) = pipelines[base_model]
     if not existing_args or not pipeline or not pipeline_args == existing_args:
         # TODO: Initialize new pipeline
-        if base_model == "sd1.5":
-            pass
-        elif base_model == "sd2":
+        if base_model in ["sd1.5", "sd2"]:
             new_pipeline = SharkSDPipeline(
-                hf_model_name="stabilityai/stable-diffusion-2-1",
+                hf_model_name=("stabilityai/stable-diffusion-2-1" if base_model == "sd2" else "stabilityai/stable-diffusion-1-5"),
                 scheduler_id=scheduler,
                 height=height,
                 width=width,

From f806b31cbce5e952f6c556b136a322d417917ed0 Mon Sep 17 00:00:00 2001
From: George Petterson
Date: Mon, 7 Oct 2024 10:42:27 -0500
Subject: [PATCH 5/5] Add Llama stuff (needs testing)

---
 apps/shark_studio/api/shark_api.py | 382 ++++++++++++++++++++++++++++-
 requirements.txt                   |   5 +-
 2 files changed, 380 insertions(+), 7 deletions(-)

diff --git a/apps/shark_studio/api/shark_api.py b/apps/shark_studio/api/shark_api.py
index ed13b4f697..0392a1c239 100644
--- a/apps/shark_studio/api/shark_api.py
+++ b/apps/shark_studio/api/shark_api.py
@@ -1,10 +1,43 @@
+# from turbine_models.custom_models import stateless_llama
+from turbine_models.model_runner import vmfbRunner
+
+# from turbine_models.gen_external_params.gen_external_params import gen_external_params
+from shark.iree_utils.compile_utils import compile_module_to_flatbuffer
+from transformers import AutoTokenizer, AutoModelForCausalLM
+import torch
+import time
+from itertools import chain
+
+from iree import runtime as ireert
+
+# Assumed import location for the SD pipeline wrapper used by generate_images() below.
+from turbine_models.custom_models.sd_inference.sd_pipeline import SharkSDPipeline
+
+from sharktank.layers import *
+from sharktank.types import *
+from shark_turbine.aot import *
+
+# from sharktank.models.mixtral.mixtral import *
+from sharktank.models.llama.llama import *
+from sharktank.utils.debugging import trace_tensor
+from sharktank.utils.tokenizer import InferenceTokenizer, load_tokenizer
+
+from shark_turbine.aot import *
+
+# from sharktank.models.llama.llama import LlamaModelConfig, PagedLlamaModelV1
+import sharktank
+
 # Internal API
-pipelines = {
+sd_pipelines = {
     "sd1.5": ("", None),
     "sd2": ("", None),
     "sdxl": ("", None),
     "sd3": ("", None),
 }
+language_models = {}
+system_prompt = """[INST] <<SYS>>
+Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n <</SYS>>\n\n
+"""
 
 
 # Used for filenames as well as the key for the global cache
@@ -92,12 +118,16 @@ def generate_images(
         "device": device,
         "target_triple": target_triple,
     }
-    (existing_args, pipeline) = pipelines[base_model]
+    (existing_args, pipeline) = sd_pipelines[base_model]
     if not existing_args or not pipeline or not pipeline_args == existing_args:
         # TODO: Initialize new pipeline
         if base_model in ["sd1.5", "sd2"]:
             new_pipeline = SharkSDPipeline(
-                hf_model_name=("stabilityai/stable-diffusion-2-1" if base_model == "sd2" else "stabilityai/stable-diffusion-1-5"),
+                hf_model_name=(
+                    "stabilityai/stable-diffusion-2-1"
+                    if base_model == "sd2"
+                    else "stabilityai/stable-diffusion-1-5"
+                ),
                 scheduler_id=scheduler,
                 height=height,
                 width=width,
@@ -121,7 +151,7 @@ def generate_images(
             pass
         elif base_model == "sd3":
             pass
         existing_args = pipeline_args
         pipeline = new_pipeline
-        pipelines[base_model] = (existing_args, pipeline)
+        sd_pipelines[base_model] = (existing_args, pipeline)
 
     generated_images = []
     for current_batch in range(batch_count):
@@ -155,3 +185,340 @@ def generate_images(
     # TODO: Allow the user to halt the process
 
     return generated_images
+
+
+def chat(
+    prompt,
+    model_name,
+    history: list = [],
+    hf_auth_token: str = None,
+    device=None,
+    target_triple=None,
+    max_tokens=4096,
+    quantization="int4",
+    precision="f16",
+    external_weights=None,
+    use_system_prompt=True,
+    streaming_llm=False,
+    batch_sizes=[4],
+):
+    # Compile model if necessary
+    if not model_name in language_models or language_models[model_name] is None:
+        language_models[model_name] = None
+        # gen_external_params(
+        #     hf_model_name=model_name,
+        #     quantization=quantization,
+        #     weight_path="llama.safetensors",
+        #     hf_auth_token=hf_auth_token,
+        #     precision=precision,
+        # )
+        # torch_ir, _ = stateless_llama.export_transformer_model(
+        #     model_name,
+        #     hf_auth_token,
+        #     compile_to="torch",
+        #     external_weights=None,  # external_weights="llama.safetensors",
+        #     precision=precision,
+        #     quantization=quantization,
+        #     streaming_llm=streaming_llm,
+        #     decomp_attn=True,
+        # )
+
+    # import pdb
+
+    # pdb.set_trace()  # debug breakpoint disabled
+    dataset = sharktank.types.Dataset.load("llama.gguf", file_type="gguf")
+    hp = sharktank.layers.configs.LlamaHParams.from_gguf_props(dataset.properties)
+    llama_config = sharktank.models.llama.llama.LlamaModelConfig(hp)
+    llama_config.kv_cache_type = "paged"
+    model = PagedLlamaModelV1(dataset.root_theta, llama_config)
+
+    def generate_params_json(hp, prefill_bs: list[int], decode_bs: list[int]):
+        return {
+            "module_name": "module",
+            "module_abi_version": 1,
+            "max_seq_len": hp.context_length,
+            "attn_head_count": hp.attention_head_count,
+            "attn_head_dim": hp.attn_head_dim,
+            "prefill_batch_sizes": prefill_bs,
+            "decode_batch_sizes": decode_bs,
+            "transformer_block_count": hp.block_count,
+            "block_seq_stride": llama_config.block_seq_stride,
+        }
+
+    import torch._dynamo.config as dynamo_config
+
+    fxb = FxProgramsBuilder(model)
+
+    def generate_batch_prefill(bs: int):
+        tokens = torch.empty(bs, 64, dtype=torch.int64)
+        seq_lens = torch.empty(bs, dtype=torch.int64)
+        seq_block_ids = torch.empty(bs, 4, dtype=torch.int64)
+        block_dim = torch.export.Dim(
+            "block", max=(hp.context_length - 1) // llama_config.block_seq_stride
+        )
+        # import pdb
+
+        # pdb.set_trace()  # debug breakpoint disabled
+        sl_dim = llama_config.block_seq_stride * block_dim
+
+        if model.config.kv_cache_type == "paged":
+            cache_state = model.cache.allocate(
+                page_count=hp.context_length // llama_config.block_seq_stride
+            )
+            page_dim = torch.export.Dim("page")
+            cache_state_dynamic_shapes = [{0: page_dim}]
+        elif model.config.kv_cache_type == "direct":
+            cache_state = model.cache.allocate(bs=1)
+            # Direct cache dimensions:
+            #   2 * transformer_block_count of...
+            #   [bs, seq_length, attn_head_count, attn_head_dim]
+            cache_state_dynamic_shapes = (2 * hp.block_count) * [{}]
+        else:
+            raise NotImplementedError(
+                f"Unsupported KV cache type: {type(model.cache)}"
+            )
+
+        dynamic_shapes = {
+            "tokens": {1: sl_dim},
+            "seq_lens": {},
+            "seq_block_ids": {1: block_dim},
+            "cache_state": cache_state_dynamic_shapes,
+        }
+
+        print(f"Exporting prefill_bs{bs}")
+
+        @fxb.export_program(
+            name=f"prefill_bs{bs}",
+            args=(tokens, seq_lens, seq_block_ids, cache_state),
+            dynamic_shapes=dynamic_shapes,
+            strict=True,
+        )
+        def _(model, tokens, seq_lens, seq_block_ids, cache_state):
+            sl = tokens.shape[1]
+            input_mask = model.input_mask(seq_lens, sl)
+            attention_mask = model.attention_mask(input_mask)
+            logits = model.prefill(
+                tokens,
+                attention_mask=attention_mask,
+                seq_block_ids=seq_block_ids,
+                cache_state=cache_state,
+            )
+            return logits
+
+    def generate_batch_decode(bs: int):
+        tokens = torch.ones(bs, 1, dtype=torch.int64)
+        seq_lens = torch.ones(bs, dtype=torch.int64)
+        start_positions = torch.ones(bs, dtype=torch.int64)
+        seq_block_ids = torch.zeros(bs, 4, dtype=torch.int64)
+        block_dim = torch.export.Dim(
+            "block", max=(hp.context_length - 1) // llama_config.block_seq_stride
+        )
+
+        if model.config.kv_cache_type == "paged":
+            cache_state = model.cache.allocate(
+                page_count=hp.context_length // llama_config.block_seq_stride
+            )
+            page_dim = torch.export.Dim("page")
+            cache_state_dynamic_shapes = [{0: page_dim}]
+        elif model.config.kv_cache_type == "direct":
+            cache_state = model.cache.allocate(bs=1)
+            # Direct cache dimensions:
+            #   2 * transformer_block_count of...
+            #   [bs, seq_length, attn_head_count, attn_head_dim]
+            cache_state_dynamic_shapes = (2 * hp.block_count) * [{}]
+        else:
+            raise NotImplementedError(
+                f"Unsupported KV cache type: {type(model.cache)}"
+            )
+
+        dynamic_shapes = {
+            "tokens": {},
+            "seq_lens": {},
+            "start_positions": {},
+            "seq_block_ids": {1: block_dim},
+            "cache_state": cache_state_dynamic_shapes,
+        }
+
+        print(f"Exporting decode_bs{bs}")
+
+        @fxb.export_program(
+            name=f"decode_bs{bs}",
+            args=(
+                tokens,
+                seq_lens,
+                start_positions,
+                seq_block_ids,
+                cache_state,
+            ),
+            dynamic_shapes=dynamic_shapes,
+            strict=True,
+        )
+        def _(
+            model,
+            tokens,
+            seq_lens,
+            start_positions,
+            seq_block_ids,
+            cache_state,
+        ):
+            input_mask = model.input_mask(
+                seq_lens, seq_block_ids.shape[1] * model.cache.block_seq_stride
+            )
+            attention_mask = model.decode_attention_mask(input_mask)
+            logits = model.decode(
+                tokens,
+                attention_mask=attention_mask,
+                start_positions=start_positions,
+                seq_block_ids=seq_block_ids,
+                cache_state=cache_state,
+            )
+            return logits
+
+    bsizes = []
+    for batch_size in batch_sizes:
+        generate_batch_prefill(batch_size)
+        generate_batch_decode(batch_size)
+        bsizes.append(batch_size)
+    config = generate_params_json(hp, bsizes, bsizes)
+
+    torch_ir = export(fxb)
+    torch_ir.save_mlir("llama.mlir")
+
+    # with open("llama.mlir", "w+") as f:
+    #     f.write(torch_ir)
+    del torch_ir
+    flags = []
+    if "cpu" in device:
+        flags.extend(
+            [
+                "--iree-global-opt-enable-quantized-matmul-reassociation",
+            ]
+        )
+    elif device == "vulkan":
+        flags.extend(["--iree-stream-resource-max-allocation-size=4294967296"])
+    elif device == "rocm":
+        flags.extend(
+            [
+                "--iree-codegen-llvmgpu-enable-transform-dialect-jit=false",
+                "--iree-llvmgpu-enable-prefetch=true",
+                "--iree-opt-outer-dim-concat=true",
+                "--iree-flow-enable-aggressive-fusion",
+            ]
+        )
+        # if "gfx9" in target_triple:
+        #     flags.extend(
+        #         [
+        #             f"--iree-codegen-transform-dialect-library={get_mfma_spec_path(target_triple, get_checkpoints_path())}",
+        #             "--iree-codegen-llvmgpu-use-vector-distribution=true",
+        #         ]
+        #     )
+        flags.extend(
+            [
+                "--iree-opt-const-expr-hoisting=False",
+                f"--iree-rocm-target-chip={target_triple}",
+            ]
+        )
+    flatbuffer_blob = compile_module_to_flatbuffer(
+        "llama.mlir",
+        device=device,
+        frontend="auto",
+        model_config_path=None,
+        extra_args=flags,
+        write_to="llama.vmfb",
+    )
+    model = language_models[model_name]
+    runner = vmfbRunner(
+        device=device,
+        vmfb_path="llama.vmfb",  # safe_name
+        external_weight_path="llama.safetensors",  # self.external_weight_file,
+    )
+
+    # Sanitize prompt
+    if isinstance(prompt, list):
+        prompt = list(chain.from_iterable(prompt))
+        prompt = " ".join([x for x in prompt if isinstance(x, str)])
+    prompt = prompt.replace("\n", " ")
+    prompt = prompt.replace("\t", " ")
+    prompt = prompt.replace("\r", " ")
+    if use_system_prompt and not history:
+        # Inline the module-level system prompt (no prompt helper exists in this module yet).
+        prompt = f"{system_prompt}{prompt} [/INST]"
+    else:
+        prompt = f"[INST] {prompt} [/INST]"
+
+    # Parse input
+    tokenizer = AutoTokenizer.from_pretrained(
+        model_name,
+        use_fast=False,
+        use_auth_token=hf_auth_token,
+    )
+    input_tensor = tokenizer(prompt, return_tensors="pt").input_ids
+
+    def format_out(results):
+        return torch.tensor(results.to_host()[0][0])
+
+    for iter in range(max_tokens):
+        if streaming_llm:
+            # token_slice = max(self.prev_token_len - 1, 0)
+            token_slice = max(len(history) - 1, 0)
+            input_tensor = input_tensor[:, token_slice:]
+        if streaming_llm and model["get_seq_step"]() > 600:
+            print("Evicting cache space!")
+            model["evict_kvcache_space"]()
+        token_len = input_tensor.shape[-1]
+        device_inputs = [ireert.asdevicearray(runner.config.device, input_tensor)]
+        # if self.first_input or not streaming_llm:
+        if not history or not streaming_llm:
+            st_time = time.time()
+            token = model["run_initialize"](*device_inputs)
+            total_time = time.time() - st_time
+            token_len += 1
+            # self.first_input = False
+        else:
+            st_time = time.time()
+            token = model["run_cached_initialize"](*device_inputs)
+            total_time = time.time() - st_time
+            token_len += 1
+
+        history.append(format_out(token))
+        while (
+            format_out(token) != llm_model_map[model_name]["stop_token"]
+            and len(history) < max_tokens
+        ):
+            dec_time = time.time()
+            if streaming_llm and model["get_seq_step"]() > 600:
+                print("Evicting cache space!")
+                model["evict_kvcache_space"]()
+            token = model["run_forward"](token)
+            history.append(format_out(token))
+            total_time = time.time() - dec_time
+            yield tokenizer.decode(history), total_time
+
+        # self.prev_token_len = token_len + len(history)
+        history.append(token)
+
+        if format_out(token) == llm_model_map[model_name]["stop_token"]:
+            break
+
+    for i in range(len(history)):
+        if type(history[i]) != int:
+            history[i] = int(history[i])
+    result_output = tokenizer.decode(history)
+    # self.global_iter += 1
+    return result_output, history, total_time
+
+
+if __name__ == "__main__":
+    output, history, time = chat(
+        "Hello.",
+        "Trelis/Llama-2-7b-chat-hf-function-calling-v2",
+        history=[],
+        hf_auth_token=None,
+        device="rocm",
+        target_triple="gfx942",
+        max_tokens=4096,
+        quantization="int4",
+        precision="f16",
+        external_weights=None,
+        use_system_prompt=True,
+        streaming_llm=True,
+    )
diff --git a/requirements.txt b/requirements.txt
index 404c1db9b1..ad710b97f4 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,10 +5,9 @@
 setuptools
 wheel
 
-
-torch==2.3.0
+torch>2.3.0.dev1
 shark-turbine @ git+https://github.com/iree-org/iree-turbine.git@main
-turbine-models @ git+https://github.com/nod-ai/SHARK-Turbine.git@main#subdirectory=models
+turbine-models @ git+https://github.com/nod-ai/SHARK-ModelDev.git@main#subdirectory=models
 diffusers @ git+https://github.com/nod-ai/diffusers@0.29.0.dev0-shark
 brevitas @ git+https://github.com/Xilinx/brevitas.git@6695e8df7f6a2c7715b9ed69c4b78157376bb60b
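
Usage sketch for the API added above: after patch 5, generate_images() and chat() are the two public entry points in apps/shark_studio/api/shark_api.py. A minimal caller might look like the following; the import path, the device strings, and the target triple are illustrative assumptions rather than tested values, and chat() is consumed as a generator because it yields partial decodes while streaming.

    from apps.shark_studio.api.shark_api import generate_images, chat

    # Text-to-image: returns a list of generated images.
    images = generate_images(
        "a photo of a lighthouse at dawn",
        "blurry, low quality",
        height=512,
        width=512,
        steps=20,
        base_model="sd2",
        precision="fp16",
        device="cpu",            # assumed device string; any IREE-supported backend
        target_triple=None,
    )
    for i, img in enumerate(images):
        print(f"image {i}: {img}")

    # Chat: iterate the generator to stream partial decodes and per-step timings.
    for partial_text, elapsed in chat(
        "Hello.",
        "Trelis/Llama-2-7b-chat-hf-function-calling-v2",
        device="rocm",           # assumed; mirrors the __main__ example in the patch
        target_triple="gfx942",
        streaming_llm=True,
    ):
        print(f"{partial_text}  ({elapsed:.2f}s)")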