Skip to content

Commit

Permalink
Split the question and answer before creating training data
Browse files Browse the repository at this point in the history
The training data needs the generated question and answer separately.
The full pipeline handles this. The simple one at the moment generates
it as a single text blob. This hack splits them apart.

This is all because the small model used by default wasn't good enough
to follow the strict formatting instructions used in the full
pipeline.

We could try asking it to generate the question and answer separately,
but presumably that doubles the inference API calls, which doesn't
sound great for the resource constrained environments the simple
config is aimed at.

Signed-off-by: Russell Bryant <[email protected]>
  • Loading branch information
russellb committed Jun 28, 2024
1 parent 04f4fd5 commit 4515b4f
Showing 1 changed file with 40 additions and 4 deletions.
44 changes: 40 additions & 4 deletions src/instructlab/sdg/generate_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,18 +37,54 @@ def _unescape(s):
return bytes(s, "utf-8").decode("utf-8")


# This is a hack because the simple workflow returns a q/a pair as a single output.
# We could possibly try to ask for them separately, but it would cost twice the inference
# API calls. All of this is because the smallest models we use on small environments
# for testing and demos weren't good enough to follow the strict formatting instructions used
# in the full pipeline.
def _get_question(logger, synth_example):
if "question" in synth_example:
return synth_example["question"]

if "output" not in synth_example:
raise utils.GenerateException(
f"Error: output not found in synth_example: {synth_example}"
)

parts = synth_example["output"].split("?", 1)
if len(parts) != 2:
logger.warning(f"Failed to split generated q&a: {synth_example['output']}")
return parts[0].strip() + "?" if len(parts) == 2 else ""


# This is also a hack. See the comment above _get_question.
def _get_response(logger, synth_example):
if "response" in synth_example:
return synth_example["response"]

if "output" not in synth_example:
raise utils.GenerateException(
f"Error: output not found in synth_example: {synth_example}"
)

parts = synth_example["output"].split("?", 1)
if len(parts) != 2:
logger.warning(f"Failed to split generated q&a: {synth_example['output']}")
return parts[1].strip() if len(parts) == 2 else parts[0].strip()


def _gen_train_data(logger, machine_instruction_data, output_file_train):
train_data = []
for synth_example in machine_instruction_data:
logger.debug(synth_example)
user = synth_example.get("instruction", "")
if len(synth_example.get("input", "")) > 0:
user += "\n" + synth_example["input"]
user = _get_question(logger, synth_example)
if len(synth_example.get("context", "")) > 0:
user += "\n" + synth_example["context"]
train_data.append(
{
"system": get_sysprompt(),
"user": _unescape(user),
"assistant": _unescape(synth_example["output"]),
"assistant": _unescape(_get_response(logger, synth_example)),
}
)

Expand Down

0 comments on commit 4515b4f

Please sign in to comment.