Skip to content

Commit

Permalink
Add add_num_samples to LLMBlock config
Browse files Browse the repository at this point in the history
Two pipelines include an LLMBlock which use `{num_samples}` in their
instructions to the teacher model. There needs to be some way to
configure the LLMBlock so that `num_samples` will be included, but
as per #82 (commit a01b04e) the value of `num_samples` should be
based on the `num_instructions_to_generate` parameter.

Signed-off-by: Mark McLoughlin <[email protected]>
  • Loading branch information
markmc committed Jul 11, 2024
1 parent 8cb673b commit b956643
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 9 deletions.
8 changes: 2 additions & 6 deletions src/instructlab/sdg/default_flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,9 +184,7 @@ def get_flow(self) -> list:
"block_name": "gen_questions",
"config_path": "configs/skills/freeform_questions.yaml",
"output_cols": ["question"],
"batch_kwargs": {
"num_samples": self.ctx.num_instructions_to_generate,
},
"add_num_samples": True,
},
"drop_duplicates": ["question"],
},
Expand Down Expand Up @@ -262,9 +260,7 @@ def get_flow(self) -> list:
"block_name": "gen_grounded_questions",
"config_path": "configs/skills/grounded_questions.yaml",
"output_cols": ["question"],
"batch_kwargs": {
"num_samples": 3,
},
"add_num_samples": True,
},
"drop_duplicates": ["question"],
},
Expand Down
11 changes: 8 additions & 3 deletions src/instructlab/sdg/llmblock.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def __init__(
block_name,
config_path,
output_cols,
add_num_samples=False,
parser_kwargs={},
**batch_kwargs,
) -> None:
Expand All @@ -69,6 +70,7 @@ def __init__(
)
self.prompt_template = self.prompt_struct.format(**self.block_config)
self.model_prompt = _get_model_prompt(self.ctx.model_family)
self.add_num_samples = add_num_samples
self.output_cols = output_cols
self.batch_params = batch_kwargs.get("batch_kwargs", {})
self.parser_name = parser_kwargs.get("parser_name", None)
Expand Down Expand Up @@ -156,11 +158,12 @@ def generate(self, samples: Dataset, **gen_kwargs) -> Dataset:
:return: The parsed output after generation.
"""
num_samples = self.batch_params.get("num_samples", None)
logger.debug("Generating outputs for {} samples".format(len(samples)))

if (num_samples is not None) and ("num_samples" not in samples.column_names):
samples = samples.add_column("num_samples", [num_samples] * len(samples))
if self.add_num_samples and ("num_samples" not in samples.column_names):
samples = samples.add_column(
"num_samples", [self.ctx.num_instructions_to_generate] * len(samples)
)

# validate each sample
# Log errors and remove invalid samples
Expand Down Expand Up @@ -211,6 +214,7 @@ def __init__(
config_paths,
output_cols,
selector_column_name,
add_num_samples=False,
parser_kwargs={},
**batch_kwargs,
) -> None:
Expand All @@ -219,6 +223,7 @@ def __init__(
block_name,
config_paths[0][0],
output_cols,
add_num_samples=add_num_samples,
parser_kwargs=parser_kwargs,
**batch_kwargs,
)
Expand Down

0 comments on commit b956643

Please sign in to comment.