Skip to content

Commit

Permalink
Merge branch 'batch-tests' into main
Browse files Browse the repository at this point in the history
Signed-off-by: Oindrilla Chatterjee <[email protected]>
  • Loading branch information
oindrillac authored Jul 8, 2024
2 parents 53c14f9 + c2e19a4 commit e0e5e8d
Showing 1 changed file with 75 additions and 13 deletions.
88 changes: 75 additions & 13 deletions src/instructlab/sdg/generate_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,32 +77,75 @@ def _get_response(logger, synth_example):
return parts[1].strip() if len(parts) == 2 else parts[0].strip()


def _gen_train_data(logger, machine_instruction_data, output_file_train):
def _convert_to_messages(sample):
    """
    Convert a sample dictionary to the 'messages' / 'metadata' training format.

    Args:
        sample: Mapping that contains at least the keys "inputs" (used as the
            user turn) and "targets" (used as the assistant turn). Every other
            key is carried over as metadata.

    Returns:
        A new dict with exactly two keys:
        - "messages": a list with the user message followed by the assistant
          message, each a dict with "content" and "role".
        - "metadata": a JSON-encoded string of all remaining key/value pairs.

    Raises:
        KeyError: If "inputs" or "targets" is missing from ``sample``.

    The mapping passed in is not modified.
    """
    # TODO: we can remove the combinecolumnsblock and instead prepend
    # sample["context"] (when present) to the user query here for simplicity.
    user_query = sample["inputs"]

    messages = [
        {"content": user_query, "role": "user"},
        {"content": sample["targets"], "role": "assistant"},
    ]

    # "messages" is excluded as well: an entry already present under that key
    # is superseded by the list built above and must not leak into metadata.
    metadata = {
        key: value
        for key, value in sample.items()
        if key not in ("messages", "inputs", "targets")
    }

    # Only the keys required by the messages training format are returned.
    return {"messages": messages, "metadata": json.dumps(metadata)}


def _gen_train_data(
    logger, machine_instruction_data, output_file_train, output_file_messages
):
    """
    Write the generated examples to disk in two JSONL training formats.

    For every synthetic example one record is produced per format:
    - the "system"/"user"/"assistant" record, written to ``output_file_train``;
    - the "messages"/"metadata" record built by ``_convert_to_messages``,
      written to ``output_file_messages``.

    Args:
        logger: Logger used for debug output; also forwarded to
            ``_get_question`` and ``_get_response``.
        machine_instruction_data: Iterable of dict-like synthetic examples. An
            optional, non-empty "context" entry is appended to the question on
            a new line.
        output_file_train: Path of the JSONL file for the
            system/user/assistant records.
        output_file_messages: Path of the JSONL file for the messages-format
            records.

    Both files are opened in "w" mode, so existing content is overwritten.
    """
    train_data = []
    messages_data = []
    for synth_example in machine_instruction_data:
        logger.debug(synth_example)

        user = _get_question(logger, synth_example)
        # A missing, empty or None context contributes nothing to the prompt.
        context = synth_example.get("context")
        if context:
            user += "\n" + context

        # Compute each derived value once and share it between both formats.
        user = _unescape(user)
        assistant = _unescape(_get_response(logger, synth_example))
        sysprompt = get_sysprompt()

        train_data.append(
            {
                "system": sysprompt,
                "user": user,
                "assistant": assistant,
            }
        )
        messages_data.append(
            _convert_to_messages(
                {
                    "inputs": user,
                    "targets": assistant,
                    "system": sysprompt,
                }
            )
        )

    # Same JSONL layout for both outputs: one JSON document per line, with
    # non-ASCII characters written as-is rather than \u-escaped.
    for path, entries in (
        (output_file_train, train_data),
        (output_file_messages, messages_data),
    ):
        with open(path, "w", encoding="utf-8") as outfile:
            for entry in entries:
                json.dump(entry, outfile, ensure_ascii=False)
                outfile.write("\n")


def _gen_test_data(
leaf_nodes,
output_file_test,
output_file_messages,
):
test_data = []
messages_data = []
for _, leaf_node in leaf_nodes.items():
for seed_example in leaf_node:
user = seed_example["instruction"] # question
Expand All @@ -117,12 +160,23 @@ def _gen_test_data(
"assistant": _unescape(seed_example["output"]), # answer
}
)
sample = {
"inputs": _unescape(user),
"targets": _unescape(seed_example["output"]),
"system": get_sysprompt(),
}
messages_data.append(_convert_to_messages(sample))

with open(output_file_test, "w", encoding="utf-8") as outfile:
for entry in test_data:
json.dump(entry, outfile, ensure_ascii=False)
outfile.write("\n")

with open(output_file_messages, "w", encoding="utf-8") as outfile:
for entry in messages_data:
json.dump(entry, outfile, ensure_ascii=False)
outfile.write("\n")


def _sdg_init(pipeline, client, model_family, model_name, num_instructions_to_generate):
knowledge_flow_types = []
Expand Down Expand Up @@ -220,12 +274,14 @@ def generate_data(
output_file_generated = f"generated_{name}_{date_suffix}.json"
output_file_test = f"test_{name}_{date_suffix}.jsonl"
output_file_train = f"train_{name}_{date_suffix}.jsonl"
output_file_messages_train = f"train_messages_{name}_{date_suffix}.jsonl"
output_file_messages_test = f"test_messages_{name}_{date_suffix}.jsonl"

_gen_test_data(
leaf_nodes,
os.path.join(output_dir, output_file_test),
os.path.join(output_dir, output_file_messages_test),
)

logger.debug(f"Generating to: {os.path.join(output_dir, output_file_generated)}")

orig_cert = (tls_client_cert, tls_client_key, tls_client_passwd)
Expand All @@ -242,6 +298,10 @@ def generate_data(
else:
model_family = MODEL_FAMILY_MERLINITE

# TODO -- llama-cpp doesn't support batching, we need to get a hint from the CLI
# about whether we can turn this on (whether vllm is used or not)

batched = False

Check warning on line 304 in src/instructlab/sdg/generate_data.py

View workflow job for this annotation

GitHub Actions / lint

W0612: Unused variable 'batched' (unused-variable)
sdg_knowledge, sdg_freeform_skill, sdg_grounded_skill = _sdg_init(
pipeline,
client,
Expand Down Expand Up @@ -284,16 +344,18 @@ def generate_data(
if generated_data is None:
generated_data = []

_gen_train_data(logger, generated_data, os.path.join(output_dir, output_file_train))
_gen_train_data(
logger,
generated_data,
os.path.join(output_dir, output_file_train),
os.path.join(output_dir, output_file_messages_train),
)

# TODO
# This is for backwards compatibility. The file existed previously, so we'll keep it for now.
# I believe the github bot assumes it is present for presenting generated data to a taxonomy
# reviewer or contributor. Otherwise, I don't see a consumer of it in this repo or the
# `ilab` CLI.
_gen_train_data(
logger, generated_data, os.path.join(output_dir, output_file_generated)
)

generate_duration = time.time() - generate_start
logger.info(f"Generation took {generate_duration:.2f}s")

0 comments on commit e0e5e8d

Please sign in to comment.