Merge pull request #105 from aakankshaduggal/batch-support-vllm
Batch support with vllm
shivchander authored Jul 9, 2024
2 parents 72a7be4 + 1182c80 commit 7ef628f
Showing 5 changed files with 122 additions and 185 deletions.
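
This commit removes the IterBlock wrapper, which produced multiple outputs by re-running an LLMBlock num_iters times, one request per pass. Each flow now passes n (the number of completions to sample) through gen_kwargs, so a vLLM backend can batch the sampling inside a single request. Below is a rough sketch of the idea outside the SDG block machinery, assuming a local vLLM server exposing an OpenAI-compatible endpoint; the URL, API key, model name, and prompt are placeholders, not taken from this diff.

# Illustrative only -- not code from this commit. Assumes a vLLM server
# with an OpenAI-compatible API; URL, key, and model are placeholders.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# Before: ~30 sequential requests driven by IterBlock(num_iters=30).
# After: one request with n=30; vLLM fans out the sampling server-side.
response = client.completions.create(
    model="mistralai/Mixtral-8x7B-Instruct-v0.1",
    prompt="Generate a question about photosynthesis.",
    n=30,
    temperature=0.7,
    max_tokens=2048,
)
outputs = [choice.text for choice in response.choices]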
141 changes: 38 additions & 103 deletions src/instructlab/sdg/default_flows.py
@@ -7,7 +7,6 @@

# Local
from .filterblock import FilterByValueBlock
from .iterblock import IterBlock
from .llmblock import LLMBlock
from .utilblocks import CombineColumnsBlock

@@ -30,12 +29,13 @@ def _get_model_prompt(model_family):


class Flow(ABC):
def __init__(self, client, model_family, model_id, num_iters, batched=True) -> None:
def __init__(
self, client, model_family, model_id, num_instructions_to_generate
) -> None:
self.client = client
self.model_family = model_family
self.model_id = model_id
self.num_iters = num_iters
self.batched = batched
self.num_instructions_to_generate = num_instructions_to_generate
self.sdg_base = resources.files(__package__)

@abstractmethod
@@ -47,62 +47,52 @@ class _SimpleFlow(Flow):
def get_flow(self) -> list:
return [
{
"block_type": IterBlock,
"block_type": LLMBlock,
"block_config": {
"block_name": "", # must be set by subclass
"num_iters": self.num_iters,
"block_type": LLMBlock,
"block_kwargs": {
"block_name": "", # must be set by subclass
"config_path": "", # must be set by subclass
"client": self.client,
"model_id": self.model_id,
"model_prompt": _get_model_prompt(self.model_family),
"output_cols": ["output"],
"batch_kwargs": {
"num_procs": 8,
"batched": self.batched,
},
},
"gen_kwargs": {
"max_tokens": 2048,
"temperature": 0.7,
},
"drop_duplicates": ["output"],
"config_path": "", # must be set by subclass
"client": self.client,
"model_id": self.model_id,
"model_prompt": _get_model_prompt(self.model_family),
"output_cols": ["output"],
},
"gen_kwargs": {
"max_tokens": 2048,
"temperature": 0.7,
"n": self.num_instructions_to_generate,
},
"drop_duplicates": ["output"],
}
]


class SimpleKnowledgeFlow(_SimpleFlow):
def get_flow(self) -> list:
flow = super().get_flow()
flow[0]["block_config"]["block_kwargs"]["config_path"] = os.path.join(
flow[0]["block_config"]["config_path"] = os.path.join(
self.sdg_base, "configs/knowledge/simple_generate_qa.yaml"
)
flow[0]["block_config"]["block_kwargs"]["block_name"] = "gen_knowledge"
flow[0]["block_config"]["block_name"] = "gen_knowledge"
return flow


class SimpleFreeformSkillFlow(_SimpleFlow):
def get_flow(self) -> list:
flow = super().get_flow()
flow[0]["block_config"]["block_kwargs"]["config_path"] = os.path.join(
flow[0]["block_config"]["config_path"] = os.path.join(
self.sdg_base, "configs/skills/simple_generate_qa_freeform.yaml"
)
flow[0]["block_config"]["block_kwargs"]["block_name"] = "gen_skill_freeform"
flow[0]["block_config"]["block_name"] = "gen_skill_freeform"
flow[0]["block_config"]["block_name"] = "gen_skill_freeform"
return flow


class SimpleGroundedSkillFlow(_SimpleFlow):
def get_flow(self) -> list:
flow = super().get_flow()
flow[0]["block_config"]["block_kwargs"]["config_path"] = os.path.join(
flow[0]["block_config"]["config_path"] = os.path.join(
self.sdg_base, "configs/skills/simple_generate_qa_grounded.yaml"
)
flow[0]["block_config"]["block_kwargs"]["block_name"] = "gen_skill_grounded"
flow[0]["block_config"]["block_name"] = "gen_skill_grounded"
return flow

Expand All @@ -122,10 +112,6 @@ def get_flow(self) -> list:
"model_id": self.model_id,
"model_prompt": _get_model_prompt(self.model_family),
"output_cols": ["mmlubench_question", "mmlubench_answer"],
"batch_kwargs": {
"num_procs": 8,
"batched": self.batched,
},
},
"gen_kwargs": {
"temperature": 0,
@@ -151,10 +137,6 @@ def get_flow(self) -> list:
"model_id": self.model_id,
"model_prompt": _get_model_prompt(self.model_family),
"output_cols": ["question", "response"],
"batch_kwargs": {
"num_procs": 8,
"batched": self.batched,
},
"parser_kwargs": {
"parser_name": "custom",
"parsing_pattern": r"\[(?:Question|QUESTION)\]\s*(.*?)\s*\[(?:Answer|ANSWER)\]\s*(.*?)\s*(?=\[(?:Question|QUESTION)\]|$)",
@@ -177,10 +159,6 @@ def get_flow(self) -> list:
"model_id": self.model_id,
"model_prompt": _get_model_prompt(self.model_family),
"output_cols": ["explanation", "judgment"],
"batch_kwargs": {
"num_procs": 8,
"batched": self.batched,
},
},
"gen_kwargs": {
"max_tokens": 2048,
@@ -210,10 +188,6 @@ def get_flow(self) -> list:
"model_id": self.model_id,
"model_prompt": _get_model_prompt(self.model_family),
"output_cols": ["feedback", "score"],
"batch_kwargs": {
"num_procs": 8,
"batched": self.batched,
},
},
"gen_kwargs": {
"max_tokens": 2048,
@@ -244,10 +218,6 @@ def get_flow(self) -> list:
"model_id": self.model_id,
"model_prompt": _get_model_prompt(self.model_family),
"output_cols": ["explanation", "rating"],
"batch_kwargs": {
"num_procs": 8,
"batched": self.batched,
},
},
"gen_kwargs": {
"max_tokens": 2048,
@@ -286,9 +256,7 @@ def get_flow(self) -> list:
"model_prompt": _get_model_prompt(self.model_family),
"output_cols": ["question"],
"batch_kwargs": {
"num_procs": 8,
"num_samples": self.num_iters,
"batched": self.batched,
"num_samples": self.num_instructions_to_generate,
},
},
"drop_duplicates": ["question"],
@@ -305,10 +273,6 @@ def get_flow(self) -> list:
"model_id": self.model_id,
"model_prompt": _get_model_prompt(self.model_family),
"output_cols": ["evaluation", "score"],
"batch_kwargs": {
"num_procs": 8,
"batched": self.batched,
},
},
},
{
@@ -337,10 +301,6 @@ def get_flow(self) -> list:
"model_id": self.model_id,
"model_prompt": _get_model_prompt(self.model_family),
"output_cols": ["response"],
"batch_kwargs": {
"num_procs": 8,
"batched": self.batched,
},
},
},
{
@@ -355,10 +315,6 @@ def get_flow(self) -> list:
"model_id": self.model_id,
"model_prompt": _get_model_prompt(self.model_family),
"output_cols": ["evaluation", "score"],
"batch_kwargs": {
"num_procs": 8,
"batched": self.batched,
},
},
},
{
@@ -382,31 +338,24 @@ class SynthGroundedSkillsFlow(Flow):
def get_flow(self) -> list:
return [
{
"block_type": IterBlock,
"block_type": LLMBlock,
"block_config": {
"block_name": "context_iter",
"num_iters": 10,
"block_type": LLMBlock,
"block_kwargs": {
"block_name": "gen_contexts",
"config_path": os.path.join(
self.sdg_base,
"configs/skills/contexts.yaml",
),
"client": self.client,
"model_id": self.model_id,
"model_prompt": _get_model_prompt(self.model_family),
"output_cols": ["context"],
"batch_kwargs": {
"num_procs": 8,
"batched": self.batched,
},
},
"gen_kwargs": {
"temperature": 0.7,
"max_tokens": 2048,
},
"block_name": "gen_contexts",
"config_path": os.path.join(
self.sdg_base,
"configs/skills/contexts.yaml",
),
"client": self.client,
"model_id": self.model_id,
"model_prompt": _get_model_prompt(self.model_family),
"output_cols": ["context"],
},
"gen_kwargs": {
"temperature": 0.7,
"max_tokens": 2048,
"n": self.num_instructions_to_generate,
},
"drop_duplicates": ["context"],
},
{
"block_type": LLMBlock,
@@ -421,8 +370,7 @@ def get_flow(self) -> list:
"model_prompt": _get_model_prompt(self.model_family),
"output_cols": ["question"],
"batch_kwargs": {
"num_procs": 8,
"batched": self.batched,
"num_samples": 3,
},
},
"drop_duplicates": ["question"],
@@ -439,11 +387,6 @@ def get_flow(self) -> list:
"model_id": self.model_id,
"model_prompt": _get_model_prompt(self.model_family),
"output_cols": ["evaluation", "score"],
"batch_kwargs": {
"num_procs": 8,
"batched": self.batched,
"num_samples": 10,
},
},
},
{
@@ -472,10 +415,6 @@ def get_flow(self) -> list:
"model_id": self.model_id,
"model_prompt": _get_model_prompt(self.model_family),
"output_cols": ["response"],
"batch_kwargs": {
"num_procs": 8,
"batched": self.batched,
},
},
},
{
Expand All @@ -490,10 +429,6 @@ def get_flow(self) -> list:
"model_id": self.model_id,
"model_prompt": _get_model_prompt(self.model_family),
"output_cols": ["evaluation", "score"],
"batch_kwargs": {
"num_procs": 8,
"batched": self.batched,
},
},
},
{
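
With the constructor change at the top of this file, callers drop num_iters and batched and pass a single num_instructions_to_generate. A hypothetical wiring of the new signature follows; the import paths, client setup, and model-family string are assumptions for illustration, not taken from this diff.

# Hypothetical usage of the new Flow constructor; names are illustrative.
from openai import OpenAI
from instructlab.sdg.default_flows import SimpleFreeformSkillFlow
from instructlab.sdg.pipeline import Pipeline

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# Old: SimpleFreeformSkillFlow(client, model_family, model_id, num_iters, batched=True)
# New: the iteration count becomes the number of completions requested via n.
flow = SimpleFreeformSkillFlow(
    client, "mixtral", "mistralai/Mixtral-8x7B-Instruct-v0.1", 30
).get_flow()
pipeline = Pipeline(flow)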
11 changes: 4 additions & 7 deletions src/instructlab/sdg/generate_data.py
@@ -126,7 +126,7 @@ def _gen_test_data(
outfile.write("\n")


def _sdg_init(pipeline, client, model_family, model_name, num_iters, batched):
def _sdg_init(pipeline, client, model_family, model_name, num_instructions_to_generate):
knowledge_flow_types = []
freeform_skill_flow_types = []
grounded_skill_flow_types = []
@@ -146,7 +146,7 @@ def _sdg_init(pipeline, client, model_family, model_name, num_iters, batched):
[
Pipeline(
flow_type(
client, model_family, model_name, num_iters, batched
client, model_family, model_name, num_instructions_to_generate
).get_flow()
)
for flow_type in knowledge_flow_types
@@ -156,7 +156,7 @@ def _sdg_init(pipeline, client, model_family, model_name, num_iters, batched):
[
Pipeline(
flow_type(
client, model_family, model_name, num_iters, batched
client, model_family, model_name, num_instructions_to_generate
).get_flow()
)
for flow_type in freeform_skill_flow_types
@@ -166,7 +166,7 @@ def _sdg_init(pipeline, client, model_family, model_name, num_iters, batched):
[
Pipeline(
flow_type(
client, model_family, model_name, num_iters, batched
client, model_family, model_name, num_instructions_to_generate
).get_flow()
)
for flow_type in grounded_skill_flow_types
@@ -246,15 +246,13 @@ def generate_data(

# TODO -- llama-cpp doesn't support batching, we need to get a hint from the CLI
# about whether we can turn this on (whether vllm is used or not)
batched = False

sdg_knowledge, sdg_freeform_skill, sdg_grounded_skill = _sdg_init(
pipeline,
client,
model_family,
model_name,
num_instructions_to_generate,
batched,
)

if console_output:
@@ -269,7 +267,6 @@
if not samples:
raise utils.GenerateException("Error: No samples found in leaf node.")

sdg = None
if samples[0].get("document"):
sdg = sdg_knowledge
elif samples[0].get("seed_context"):
28 changes: 0 additions & 28 deletions src/instructlab/sdg/iterblock.py

This file was deleted.
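
For reference, the deleted iterblock.py held the looping wrapper that the n sampling parameter replaces. Its contents are not shown in this view; the sketch below is a hypothetical reconstruction inferred from how the flows above configured it (block_name, num_iters, block_type, block_kwargs, gen_kwargs), not the actual removed code.

# Hypothetical reconstruction of the removed IterBlock, inferred from its
# call sites in default_flows.py -- not the deleted implementation itself.
class IterBlock:
    def __init__(self, block_name, num_iters, block_type, block_kwargs, gen_kwargs=None):
        self.block_name = block_name
        self.num_iters = num_iters
        # Build the wrapped block (an LLMBlock in every flow above) once.
        self.block = block_type(**block_kwargs)
        self.gen_kwargs = gen_kwargs or {}

    def generate(self, samples):
        results = []
        # Issue one generation request per iteration -- the sequential
        # pattern that a single request with n=num_iters now replaces.
        for _ in range(self.num_iters):
            results.extend(self.block.generate(samples, **self.gen_kwargs))
        return results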
