From 4515b4ff99630f781feabf411c30ed069ae44fa6 Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Thu, 27 Jun 2024 18:43:48 -0400 Subject: [PATCH] Split the question and answer before creating training data The training data needs the generated question and answer separately. The full pipeline handles this. The simple one at the moment generates it as a single text blob. This hack splits them apart. This is all because the small model used by default wasn't good enough to follow the strict formatting instructions used in the full pipeline. We could try asking it to generate the question and answer separately, but presumably that doubles the inference API calls, which doesn't sound great for the resource constrained environments the simple config is aimed at. Signed-off-by: Russell Bryant --- src/instructlab/sdg/generate_data.py | 44 +++++++++++++++++++++++++--- 1 file changed, 40 insertions(+), 4 deletions(-) diff --git a/src/instructlab/sdg/generate_data.py b/src/instructlab/sdg/generate_data.py index 6ffb1b64..92b25598 100644 --- a/src/instructlab/sdg/generate_data.py +++ b/src/instructlab/sdg/generate_data.py @@ -37,18 +37,54 @@ def _unescape(s): return bytes(s, "utf-8").decode("utf-8") +# This is a hack because the simple workflow returns a q/a pair as a single output. +# We could possibly try to ask for them separately, but it would cost twice the inference +# API calls. All of this is because the smallest models we use on small environments +# for testing and demos weren't good enough to follow the strict formatting instructions used +# in the full pipeline. +def _get_question(logger, synth_example): + if "question" in synth_example: + return synth_example["question"] + + if "output" not in synth_example: + raise utils.GenerateException( + f"Error: output not found in synth_example: {synth_example}" + ) + + parts = synth_example["output"].split("?", 1) + if len(parts) != 2: + logger.warning(f"Failed to split generated q&a: {synth_example['output']}") + return parts[0].strip() + "?" if len(parts) == 2 else "" + + +# This is also a hack. See the comment above _get_question. +def _get_response(logger, synth_example): + if "response" in synth_example: + return synth_example["response"] + + if "output" not in synth_example: + raise utils.GenerateException( + f"Error: output not found in synth_example: {synth_example}" + ) + + parts = synth_example["output"].split("?", 1) + if len(parts) != 2: + logger.warning(f"Failed to split generated q&a: {synth_example['output']}") + return parts[1].strip() if len(parts) == 2 else parts[0].strip() + + def _gen_train_data(logger, machine_instruction_data, output_file_train): train_data = [] for synth_example in machine_instruction_data: logger.debug(synth_example) - user = synth_example.get("instruction", "") - if len(synth_example.get("input", "")) > 0: - user += "\n" + synth_example["input"] + user = _get_question(logger, synth_example) + if len(synth_example.get("context", "")) > 0: + user += "\n" + synth_example["context"] train_data.append( { "system": get_sysprompt(), "user": _unescape(user), - "assistant": _unescape(synth_example["output"]), + "assistant": _unescape(_get_response(logger, synth_example)), } )