Resolve the todo about using model_family
The model_family param is used to "force" a particular family,
overriding what might be guessed from the model filename.

Since utils.models is copied and pasted from instructlab,
let's isolate the use of utils.models to the generate_data()
function, so that if we move the generate_data() code to
instructlab we can get rid of the copy here.

In its place add MODEL_FAMILY_MIXTRAL/MERLINITE constants to
the API.

Signed-off-by: Mark McLoughlin <[email protected]>
markmc authored and russellb committed Jun 28, 2024
1 parent e02d33d commit 04f4fd5
Showing 2 changed files with 39 additions and 31 deletions.
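As an illustrative sketch (not part of this commit), this is roughly how a caller would use the new family constants with the updated Flow constructor; the OpenAI-compatible client setup, base URL, and model name below are hypothetical placeholders:

# Hypothetical usage sketch: the client configuration and model name are
# placeholders, not taken from this commit.
from openai import OpenAI

from instructlab.sdg import SDG
from instructlab.sdg.default_flows import MODEL_FAMILY_MIXTRAL, SynthKnowledgeFlow
from instructlab.sdg.pipeline import Pipeline

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# The model family is now passed explicitly instead of being guessed
# from the model filename inside default_flows.py.
flow = SynthKnowledgeFlow(client, MODEL_FAMILY_MIXTRAL, "mixtral-8x7b-instruct-v0.1", batched=False)
sdg = SDG([Pipeline(flow.get_flow())])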
58 changes: 31 additions & 27 deletions src/instructlab/sdg/default_flows.py
@@ -5,29 +5,33 @@
import operator
import os

- # First Party
- from instructlab.sdg.utils import models

# Local
from .filterblock import FilterByValueBlock
from .iterblock import IterBlock
from .llmblock import LLMBlock

- MODEL_PROMPT_MIXTRAL = "<s> [INST] {prompt} [/INST]"
- MODEL_PROMPT_MERLINITE = "'<|system|>\nYou are an AI language model developed by IBM Research. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior.\n<|user|>\n{prompt}\n<|assistant|>\n'"
+ MODEL_FAMILY_MIXTRAL = "mixtral"
+ MODEL_FAMILY_MERLINITE = "merlinite"
+
+ _MODEL_PROMPT_MIXTRAL = "<s> [INST] {prompt} [/INST]"
+ _MODEL_PROMPT_MERLINITE = "'<|system|>\nYou are an AI language model developed by IBM Research. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior.\n<|user|>\n{prompt}\n<|assistant|>\n'"
+
+ _MODEL_PROMPTS = {
+     MODEL_FAMILY_MIXTRAL: _MODEL_PROMPT_MIXTRAL,
+     MODEL_FAMILY_MERLINITE: _MODEL_PROMPT_MERLINITE,
+ }


- def _get_model_prompt(model_id):
-     return (
-         MODEL_PROMPT_MIXTRAL
-         if models.get_model_family(None, model_id) == "mixtral"
-         else MODEL_PROMPT_MERLINITE
-     )
+ def _get_model_prompt(model_family):
+     if model_family not in _MODEL_PROMPTS:
+         raise ValueError(f"Unknown model family: {model_family}")
+     return _MODEL_PROMPTS[model_family]


class Flow(ABC):
- def __init__(self, client, model_id, batched=True) -> None:
+ def __init__(self, client, model_family, model_id, batched=True) -> None:
self.client = client
+ self.model_family = model_family
self.model_id = model_id
self.batched = batched

@@ -49,7 +53,7 @@ def get_flow(self) -> list:
),
"client": self.client,
"model_id": self.model_id,
"model_prompt": _get_model_prompt(self.model_id),
"model_prompt": _get_model_prompt(self.model_family),
"output_cols": ["output"],
"batch_kwargs": {
"num_procs": 8,
@@ -77,7 +81,7 @@ def get_flow(self) -> list:
),
"client": self.client,
"model_id": self.model_id,
"model_prompt": _get_model_prompt(self.model_id),
"model_prompt": _get_model_prompt(self.model_family),
"output_cols": ["mmlubench_question", "mmlubench_answer"],
"batch_kwargs": {
"num_procs": 8,
@@ -106,7 +110,7 @@ def get_flow(self) -> list:
),
"client": self.client,
"model_id": self.model_id,
"model_prompt": _get_model_prompt(self.model_id),
"model_prompt": _get_model_prompt(self.model_family),
"output_cols": ["question", "response"],
"batch_kwargs": {
"num_procs": 8,
@@ -132,7 +136,7 @@ def get_flow(self) -> list:
),
"client": self.client,
"model_id": self.model_id,
"model_prompt": _get_model_prompt(self.model_id),
"model_prompt": _get_model_prompt(self.model_family),
"output_cols": ["explanation", "judgment"],
"batch_kwargs": {
"num_procs": 8,
@@ -165,7 +169,7 @@ def get_flow(self) -> list:
),
"client": self.client,
"model_id": self.model_id,
"model_prompt": _get_model_prompt(self.model_id),
"model_prompt": _get_model_prompt(self.model_family),
"output_cols": ["feedback", "score"],
"batch_kwargs": {
"num_procs": 8,
@@ -198,7 +202,7 @@ def get_flow(self) -> list:
),
"client": self.client,
"model_id": self.model_id,
"model_prompt": _get_model_prompt(self.model_id),
"model_prompt": _get_model_prompt(self.model_family),
"output_cols": ["explanation", "rating"],
"batch_kwargs": {
"num_procs": 8,
@@ -235,7 +239,7 @@ def get_flow(self) -> list:
"config_path": "src/instructlab/sdg/configs/skills/freeform_questions.yaml",
"client": self.client,
"model_id": self.model_id,
"model_prompt": _get_model_prompt(self.model_id),
"model_prompt": _get_model_prompt(self.model_family),
"output_cols": ["question"],
"batch_kwargs": {
"num_procs": 8,
@@ -252,7 +256,7 @@ def get_flow(self) -> list:
"config_path": "src/instructlab/sdg/configs/skills/evaluate_freeform_questions.yaml",
"client": self.client,
"model_id": self.model_id,
"model_prompt": _get_model_prompt(self.model_id),
"model_prompt": _get_model_prompt(self.model_family),
"output_cols": ["evaluation", "score"],
"batch_kwargs": {
"num_procs": 8,
@@ -281,7 +285,7 @@ def get_flow(self) -> list:
"config_path": "src/instructlab/sdg/configs/skills/freeform_responses.yaml",
"client": self.client,
"model_id": self.model_id,
"model_prompt": _get_model_prompt(self.model_id),
"model_prompt": _get_model_prompt(self.model_family),
"output_cols": ["answer"],
"batch_kwargs": {
"num_procs": 8,
@@ -296,7 +300,7 @@ def get_flow(self) -> list:
"config_path": "src/instructlab/sdg/configs/skills/evaluate_freeform_pair.yaml",
"client": self.client,
"model_id": self.model_id,
"model_prompt": _get_model_prompt(self.model_id),
"model_prompt": _get_model_prompt(self.model_family),
"output_cols": ["evaluation", "score"],
"batch_kwargs": {
"num_procs": 8,
@@ -335,7 +339,7 @@ def get_flow(self) -> list:
"config_path": "src/instructlab/sdg/configs/skills/contexts.yaml",
"client": self.client,
"model_id": self.model_id,
"model_prompt": _get_model_prompt(self.model_id),
"model_prompt": _get_model_prompt(self.model_family),
"output_cols": ["context"],
"batch_kwargs": {
"num_procs": 8,
@@ -355,7 +359,7 @@ def get_flow(self) -> list:
"config_path": "src/instructlab/sdg/configs/skills/grounded_questions.yaml",
"client": self.client,
"model_id": self.model_id,
"model_prompt": _get_model_prompt(self.model_id),
"model_prompt": _get_model_prompt(self.model_family),
"output_cols": ["question"],
"batch_kwargs": {
"num_procs": 8,
@@ -371,7 +375,7 @@ def get_flow(self) -> list:
"config_path": "src/instructlab/sdg/configs/skills/evaluate_grounded_questions.yaml",
"client": self.client,
"model_id": self.model_id,
"model_prompt": _get_model_prompt(self.model_id),
"model_prompt": _get_model_prompt(self.model_family),
"output_cols": ["evaluation", "score"],
"batch_kwargs": {
"num_procs": 8,
@@ -400,7 +404,7 @@ def get_flow(self) -> list:
"config_path": "src/instructlab/sdg/configs/skills/grounded_responses.yaml",
"client": self.client,
"model_id": self.model_id,
"model_prompt": _get_model_prompt(self.model_id),
"model_prompt": _get_model_prompt(self.model_family),
"output_cols": ["answer"],
"batch_kwargs": {
"num_procs": 8,
@@ -415,7 +419,7 @@ def get_flow(self) -> list:
"config_path": "src/instructlab/sdg/configs/skills/evaluate_grounded_pair.yaml",
"client": self.client,
"model_id": self.model_id,
"model_prompt": _get_model_prompt(self.model_id),
"model_prompt": _get_model_prompt(self.model_family),
"output_cols": ["evaluation", "score"],
"batch_kwargs": {
"num_procs": 8,
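Usage note, not part of the diff: prompt selection is now a plain dictionary lookup keyed on the new family constants, and an unknown family fails loudly instead of silently falling back to the merlinite prompt. A minimal sketch of the behaviour implied by the new _get_model_prompt():

# "granite" is just an arbitrary unknown value used for illustration.
_get_model_prompt(MODEL_FAMILY_MIXTRAL)   # -> "<s> [INST] {prompt} [/INST]"
_get_model_prompt("granite")              # -> ValueError: Unknown model family: granite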
12 changes: 8 additions & 4 deletions src/instructlab/sdg/generate_data.py
@@ -19,12 +19,14 @@
# pylint: disable=ungrouped-imports
from instructlab.sdg import SDG, utils
from instructlab.sdg.default_flows import (
+ MODEL_FAMILY_MERLINITE,
+ MODEL_FAMILY_MIXTRAL,
MMLUBenchFlow,
SimpleKnowledgeFlow,
SynthKnowledgeFlow,
)
from instructlab.sdg.pipeline import Pipeline
- from instructlab.sdg.utils import chunking
+ from instructlab.sdg.utils import chunking, models
from instructlab.sdg.utils.taxonomy import (
leaf_node_to_samples,
read_taxonomy_leaf_nodes,
@@ -88,8 +90,6 @@ def generate_data(
logger,
api_base,
tls_insecure,
- # TODO - not yet used. Right now the lib will guess based on the model name
- # but we should pass this along if specified
model_family: str,
yaml_rules: Optional[str] = None,
output_dir: Optional[str] = None,
@@ -157,6 +157,10 @@ def generate_data(
http_client=httpx.Client(cert=cert, verify=verify),
)

+ model_family = MODEL_FAMILY_MERLINITE
+ if models.get_model_family(model_family, model_name) == "mixtral":
+     model_family = MODEL_FAMILY_MIXTRAL

# TODO -- llama-cpp doesn't support batching, we need to get a hint from the CLI
# about whether we can turn this on (whether vllm is used or not)
batched = False
@@ -172,7 +176,7 @@

sdg = SDG(
[
- Pipeline(flow_type(client, model_name, batched).get_flow())
+ Pipeline(flow_type(client, model_family, model_name, batched).get_flow())
for flow_type in flow_types
]
)
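A note on the resolution step added to generate_data(): the caller-supplied model_family is now passed through to models.get_model_family(), which, per the commit message, forces that family and only guesses from the model filename when no family is given. A sketch of the assumed behaviour, with hypothetical model paths:

# Assumed semantics of utils.models.get_model_family(); the paths are
# hypothetical and the exact guessing rules live in the code copied from
# instructlab.
models.get_model_family(None, "models/mixtral-8x7b-instruct-v0.1.Q4_K_M.gguf")  # guessed from filename -> "mixtral"
models.get_model_family("mixtral", "models/merlinite-7b-lab-Q4_K_M.gguf")       # forced by the caller -> "mixtral"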
