Skip to content

Commit

Permalink
Set gen_kwargs['n'] dynamically in the simple pipelines
Browse files Browse the repository at this point in the history
We need a way to allow `--num-instructions`, or in the future
`--sdg-scale-factor`, to influence how many instructions we generate
using the simple pipelines. The way to do this seems to be to set `n`
to this value. Since this is a runtime parameter, and we only want it
to drive `n` in certain cases, add a new value for gen_kwargs['n']
called `scaled`, which is a hint to use the runtime parameter here.

Closes instructlab#130

Signed-off-by: Russell Bryant <[email protected]>
  • Loading branch information
russellb committed Jul 16, 2024
1 parent fa3603e commit f7486f4
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 5 deletions.
12 changes: 9 additions & 3 deletions src/instructlab/sdg/llmblock.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,15 @@ def _gen_kwargs(self, **gen_kwargs):
gen_kwargs["max_tokens"] = int(gen_kwargs["max_tokens"])
if "temperature" in gen_kwargs:
gen_kwargs["temperature"] = float(gen_kwargs["temperature"])
gen_kwargs["n"] = self._get_n(gen_kwargs)
return gen_kwargs

def _get_n(self, gen_kwargs):
n = gen_kwargs.get("n", 1)
if isinstance(n, str) and n == "scaled":
n = self.ctx.num_instructions_to_generate
return n

def _generate(self, samples, **gen_kwargs) -> list:
prompts = [
self.model_prompt.format(prompt=self._format_prompt(sample))
Expand All @@ -148,10 +155,9 @@ def _generate(self, samples, **gen_kwargs) -> list:
)
return [choice.text.strip() for choice in response.choices]

n = gen_kwargs.get("n", 1)
results = []
for prompt in prompts:
for _ in range(n):
for _ in range(generate_args["n"]):
response = self.ctx.client.completions.create(
prompt=prompt, **generate_args
)
Expand Down Expand Up @@ -193,7 +199,7 @@ def generate(self, samples: Dataset, **gen_kwargs) -> Dataset:
outputs = self._generate(samples, **gen_kwargs)
logger.debug("Generated outputs: %s", outputs)

num_parallel_samples = gen_kwargs.get("n", 1)
num_parallel_samples = self._get_n(gen_kwargs)
extended_samples = []

# Duplicate each input sample n times, where n is the number
Expand Down
10 changes: 9 additions & 1 deletion src/instructlab/sdg/pipelines/schema/v1.json
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,15 @@
"type": "number"
},
"n": {
"type": "number"
"oneOf": [
{
"type": "number"
},
{
"type": "string",
"enum": ["scaled"]
}
]
},
"seed": {
"type": "number"
Expand Down
1 change: 1 addition & 0 deletions src/instructlab/sdg/pipelines/simple/freeform_skills.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,6 @@ blocks:
gen_kwargs:
max_tokens: 2048
temperature: 0.7
n: scaled
drop_duplicates:
- output
2 changes: 1 addition & 1 deletion src/instructlab/sdg/pipelines/simple/grounded_skills.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,6 @@ blocks:
gen_kwargs:
max_tokens: 2048
temperature: 0.7
n: 10
n: scaled
drop_duplicates:
- output
1 change: 1 addition & 0 deletions src/instructlab/sdg/pipelines/simple/knowledge.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,6 @@ blocks:
gen_kwargs:
max_tokens: 2048
temperature: 0.7
n: scaled
drop_duplicates:
- output

0 comments on commit f7486f4

Please sign in to comment.