From 04f4fd548ce1f222f3f3ee4a93005366b4dc6776 Mon Sep 17 00:00:00 2001
From: Mark McLoughlin
Date: Thu, 27 Jun 2024 13:04:02 -0400
Subject: [PATCH] Resolve the todo about using model_family

The model_family param is used to "force" a particular family, overriding
what might be guessed from the model filename.

Since utils.models is copied and pasted from instructlab, let's isolate its
use to the generate_data() function, so that if we move the generate_data()
code to instructlab we can get rid of the copy here.

In its place, add MODEL_FAMILY_MIXTRAL/MERLINITE constants to the API.

Signed-off-by: Mark McLoughlin
---
 src/instructlab/sdg/default_flows.py | 58 +++++++++++++++-------------
 src/instructlab/sdg/generate_data.py | 12 ++++--
 2 files changed, 39 insertions(+), 31 deletions(-)

diff --git a/src/instructlab/sdg/default_flows.py b/src/instructlab/sdg/default_flows.py
index cff5f52c..d12ce4ff 100644
--- a/src/instructlab/sdg/default_flows.py
+++ b/src/instructlab/sdg/default_flows.py
@@ -5,29 +5,33 @@
 import operator
 import os
 
-# First Party
-from instructlab.sdg.utils import models
-
 # Local
 from .filterblock import FilterByValueBlock
 from .iterblock import IterBlock
 from .llmblock import LLMBlock
 
-MODEL_PROMPT_MIXTRAL = " [INST] {prompt} [/INST]"
-MODEL_PROMPT_MERLINITE = "'<|system|>\nYou are an AI language model developed by IBM Research. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior.\n<|user|>\n{prompt}\n<|assistant|>\n'"
+MODEL_FAMILY_MIXTRAL = "mixtral"
+MODEL_FAMILY_MERLINITE = "merlinite"
+
+_MODEL_PROMPT_MIXTRAL = " [INST] {prompt} [/INST]"
+_MODEL_PROMPT_MERLINITE = "'<|system|>\nYou are an AI language model developed by IBM Research. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior.\n<|user|>\n{prompt}\n<|assistant|>\n'"
+
+_MODEL_PROMPTS = {
+    MODEL_FAMILY_MIXTRAL: _MODEL_PROMPT_MIXTRAL,
+    MODEL_FAMILY_MERLINITE: _MODEL_PROMPT_MERLINITE,
+}
 
 
-def _get_model_prompt(model_id):
-    return (
-        MODEL_PROMPT_MIXTRAL
-        if models.get_model_family(None, model_id) == "mixtral"
-        else MODEL_PROMPT_MERLINITE
-    )
+def _get_model_prompt(model_family):
+    if model_family not in _MODEL_PROMPTS:
+        raise ValueError(f"Unknown model family: {model_family}")
+    return _MODEL_PROMPTS[model_family]
 
 
 class Flow(ABC):
-    def __init__(self, client, model_id, batched=True) -> None:
+    def __init__(self, client, model_family, model_id, batched=True) -> None:
         self.client = client
+        self.model_family = model_family
         self.model_id = model_id
         self.batched = batched
 
@@ -49,7 +53,7 @@ def get_flow(self) -> list:
                     ),
                     "client": self.client,
                     "model_id": self.model_id,
-                    "model_prompt": _get_model_prompt(self.model_id),
+                    "model_prompt": _get_model_prompt(self.model_family),
                     "output_cols": ["output"],
                     "batch_kwargs": {
                         "num_procs": 8,
@@ -77,7 +81,7 @@ def get_flow(self) -> list:
                     ),
                     "client": self.client,
                     "model_id": self.model_id,
-                    "model_prompt": _get_model_prompt(self.model_id),
+                    "model_prompt": _get_model_prompt(self.model_family),
                     "output_cols": ["mmlubench_question", "mmlubench_answer"],
                     "batch_kwargs": {
                         "num_procs": 8,
@@ -106,7 +110,7 @@ def get_flow(self) -> list:
                     ),
                     "client": self.client,
                     "model_id": self.model_id,
-                    "model_prompt": _get_model_prompt(self.model_id),
+                    "model_prompt": _get_model_prompt(self.model_family),
                     "output_cols": ["question", "response"],
                     "batch_kwargs": {
                         "num_procs": 8,
@@ -132,7 +136,7 @@ def get_flow(self) -> list:
                     ),
                     "client": self.client,
                     "model_id": self.model_id,
-                    "model_prompt": _get_model_prompt(self.model_id),
+                    "model_prompt": _get_model_prompt(self.model_family),
                     "output_cols": ["explanation", "judgment"],
                     "batch_kwargs": {
                         "num_procs": 8,
@@ -165,7 +169,7 @@ def get_flow(self) -> list:
                     ),
                     "client": self.client,
                     "model_id": self.model_id,
-                    "model_prompt": _get_model_prompt(self.model_id),
+                    "model_prompt": _get_model_prompt(self.model_family),
                     "output_cols": ["feedback", "score"],
                     "batch_kwargs": {
                         "num_procs": 8,
@@ -198,7 +202,7 @@ def get_flow(self) -> list:
                     ),
                     "client": self.client,
                     "model_id": self.model_id,
-                    "model_prompt": _get_model_prompt(self.model_id),
+                    "model_prompt": _get_model_prompt(self.model_family),
                     "output_cols": ["explanation", "rating"],
                     "batch_kwargs": {
                         "num_procs": 8,
@@ -235,7 +239,7 @@ def get_flow(self) -> list:
                     "config_path": "src/instructlab/sdg/configs/skills/freeform_questions.yaml",
                     "client": self.client,
                     "model_id": self.model_id,
-                    "model_prompt": _get_model_prompt(self.model_id),
+                    "model_prompt": _get_model_prompt(self.model_family),
                     "output_cols": ["question"],
                     "batch_kwargs": {
                         "num_procs": 8,
@@ -252,7 +256,7 @@ def get_flow(self) -> list:
                     "config_path": "src/instructlab/sdg/configs/skills/evaluate_freeform_questions.yaml",
                     "client": self.client,
                     "model_id": self.model_id,
-                    "model_prompt": _get_model_prompt(self.model_id),
+                    "model_prompt": _get_model_prompt(self.model_family),
                     "output_cols": ["evaluation", "score"],
                     "batch_kwargs": {
                         "num_procs": 8,
@@ -281,7 +285,7 @@ def get_flow(self) -> list:
                     "config_path": "src/instructlab/sdg/configs/skills/freeform_responses.yaml",
                     "client": self.client,
                     "model_id": self.model_id,
-                    "model_prompt": _get_model_prompt(self.model_id),
+                    "model_prompt": _get_model_prompt(self.model_family),
"output_cols": ["answer"], "batch_kwargs": { "num_procs": 8, @@ -296,7 +300,7 @@ def get_flow(self) -> list: "config_path": "src/instructlab/sdg/configs/skills/evaluate_freeform_pair.yaml", "client": self.client, "model_id": self.model_id, - "model_prompt": _get_model_prompt(self.model_id), + "model_prompt": _get_model_prompt(self.model_family), "output_cols": ["evaluation", "score"], "batch_kwargs": { "num_procs": 8, @@ -335,7 +339,7 @@ def get_flow(self) -> list: "config_path": "src/instructlab/sdg/configs/skills/contexts.yaml", "client": self.client, "model_id": self.model_id, - "model_prompt": _get_model_prompt(self.model_id), + "model_prompt": _get_model_prompt(self.model_family), "output_cols": ["context"], "batch_kwargs": { "num_procs": 8, @@ -355,7 +359,7 @@ def get_flow(self) -> list: "config_path": "src/instructlab/sdg/configs/skills/grounded_questions.yaml", "client": self.client, "model_id": self.model_id, - "model_prompt": _get_model_prompt(self.model_id), + "model_prompt": _get_model_prompt(self.model_family), "output_cols": ["question"], "batch_kwargs": { "num_procs": 8, @@ -371,7 +375,7 @@ def get_flow(self) -> list: "config_path": "src/instructlab/sdg/configs/skills/evaluate_grounded_questions.yaml", "client": self.client, "model_id": self.model_id, - "model_prompt": _get_model_prompt(self.model_id), + "model_prompt": _get_model_prompt(self.model_family), "output_cols": ["evaluation", "score"], "batch_kwargs": { "num_procs": 8, @@ -400,7 +404,7 @@ def get_flow(self) -> list: "config_path": "src/instructlab/sdg/configs/skills/grounded_responses.yaml", "client": self.client, "model_id": self.model_id, - "model_prompt": _get_model_prompt(self.model_id), + "model_prompt": _get_model_prompt(self.model_family), "output_cols": ["answer"], "batch_kwargs": { "num_procs": 8, @@ -415,7 +419,7 @@ def get_flow(self) -> list: "config_path": "src/instructlab/sdg/configs/skills/evaluate_grounded_pair.yaml", "client": self.client, "model_id": self.model_id, - "model_prompt": _get_model_prompt(self.model_id), + "model_prompt": _get_model_prompt(self.model_family), "output_cols": ["evaluation", "score"], "batch_kwargs": { "num_procs": 8, diff --git a/src/instructlab/sdg/generate_data.py b/src/instructlab/sdg/generate_data.py index 29586ff5..6ffb1b64 100644 --- a/src/instructlab/sdg/generate_data.py +++ b/src/instructlab/sdg/generate_data.py @@ -19,12 +19,14 @@ # pylint: disable=ungrouped-imports from instructlab.sdg import SDG, utils from instructlab.sdg.default_flows import ( + MODEL_FAMILY_MERLINITE, + MODEL_FAMILY_MIXTRAL, MMLUBenchFlow, SimpleKnowledgeFlow, SynthKnowledgeFlow, ) from instructlab.sdg.pipeline import Pipeline -from instructlab.sdg.utils import chunking +from instructlab.sdg.utils import chunking, models from instructlab.sdg.utils.taxonomy import ( leaf_node_to_samples, read_taxonomy_leaf_nodes, @@ -88,8 +90,6 @@ def generate_data( logger, api_base, tls_insecure, - # TODO - not yet used. 
Right now the lib will guess based on the model name - # but we should pass this along if specified model_family: str, yaml_rules: Optional[str] = None, output_dir: Optional[str] = None, @@ -157,6 +157,10 @@ def generate_data( http_client=httpx.Client(cert=cert, verify=verify), ) + model_family = MODEL_FAMILY_MERLINITE + if models.get_model_family(model_family, model_name) == "mixtral": + model_family = MODEL_FAMILY_MIXTRAL + # TODO -- llama-cpp doesn't support batching, we need to get a hint from the CLI # about whether we can turn this on (whether vllm is used or not) batched = False @@ -172,7 +176,7 @@ def generate_data( sdg = SDG( [ - Pipeline(flow_type(client, model_name, batched).get_flow()) + Pipeline(flow_type(client, model_family, model_name, batched).get_flow()) for flow_type in flow_types ] )
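
Usage sketch (illustrative only, not part of the patch): with model_family now
a constructor argument, a caller outside generate_data() selects the prompt
template explicitly instead of relying on a guess from the model filename. The
OpenAI client setup, endpoint, model path, and choice of flow below are
assumptions made for the example, not taken from this patch.

    # Hypothetical caller of the new API; the endpoint and model path are placeholders.
    from openai import OpenAI

    from instructlab.sdg import SDG
    from instructlab.sdg.default_flows import MODEL_FAMILY_MERLINITE, SynthKnowledgeFlow
    from instructlab.sdg.pipeline import Pipeline

    client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

    # Flow constructors now take the model family ahead of the model id;
    # _get_model_prompt() raises ValueError for an unknown family.
    flow = SynthKnowledgeFlow(
        client, MODEL_FAMILY_MERLINITE, "models/merlinite-7b-lab-Q4_K_M.gguf", batched=False
    )
    sdg = SDG([Pipeline(flow.get_flow())])

generate_data() itself still resolves the family via utils.models, defaulting
to MODEL_FAMILY_MERLINITE and switching to MODEL_FAMILY_MIXTRAL when
get_model_family() reports "mixtral", as the generate_data.py hunks above show.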