Skip to content

Commit

Permalink
Add skills variants for the full pipeline
Browse files Browse the repository at this point in the history
Add the skills flow variants (SynthSkillsFlow and SynthGroundedSkillsFlow) so
that selecting the "full" pipeline also covers freeform and grounded skills.

I also renamed "profile" to "pipeline" to better reflect what is being
selected here. The term "profile" is overloaded after many past CLI UX
discussions, so it is better to avoid it here.

Signed-off-by: Russell Bryant <[email protected]>
  • Loading branch information
russellb committed Jun 28, 2024
1 parent 54b065a commit 31ecfda
Showing 1 changed file with 16 additions and 7 deletions.
23 changes: 16 additions & 7 deletions src/instructlab/sdg/generate_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@
SimpleFreeformSkillFlow,
SimpleGroundedSkillFlow,
SimpleKnowledgeFlow,
SynthGroundedSkillsFlow,
SynthKnowledgeFlow,
SynthSkillsFlow,
)
from instructlab.sdg.pipeline import Pipeline
from instructlab.sdg.utils import chunking, models
Expand Down Expand Up @@ -122,19 +124,21 @@ def _gen_test_data(
outfile.write("\n")


def _sdg_init(profile, client, model_family, model_name, num_iters, batched):
def _sdg_init(pipeline, client, model_family, model_name, num_iters, batched):
knowledge_flow_types = []
freeform_skill_flow_types = []
grounded_skill_flow_types = []
if profile == "full":
if pipeline == "full":
knowledge_flow_types.append(MMLUBenchFlow)
knowledge_flow_types.append(SynthKnowledgeFlow)
elif profile == "simple":
freeform_skill_flow_types.append(SynthSkillsFlow)
grounded_skill_flow_types.append(SynthGroundedSkillsFlow)
elif pipeline == "simple":
knowledge_flow_types.append(SimpleKnowledgeFlow)
freeform_skill_flow_types.append(SimpleFreeformSkillFlow)
grounded_skill_flow_types.append(SimpleGroundedSkillFlow)
else:
raise utils.GenerateException(f"Error: profile ({profile}) is not supported.")
raise utils.GenerateException(f"Error: pipeline ({pipeline}) is not supported.")

sdg_knowledge = SDG(
[
Expand Down Expand Up @@ -204,8 +208,8 @@ def generate_data(
tls_client_cert: Optional[str] = None,
tls_client_key: Optional[str] = None,
tls_client_passwd: Optional[str] = None,
# TODO need to update the CLI to specify which profile to use (simple or full at the moment)
profile: Optional[str] = "simple",
# TODO need to update the CLI to specify which pipeline to use (simple or full at the moment)
pipeline: Optional[str] = "simple",
):
generate_start = time.time()

Expand Down Expand Up @@ -251,7 +255,12 @@ def generate_data(
batched = False

sdg_knowledge, sdg_freeform_skill, sdg_grounded_skill = _sdg_init(
profile, client, model_family, model_name, num_instructions_to_generate, batched
pipeline,
client,
model_family,
model_name,
num_instructions_to_generate,
batched,
)

if console_output:
Expand Down

0 comments on commit 31ecfda

Please sign in to comment.