From 686123c23a31b20e95c512e150b14ff194b114de Mon Sep 17 00:00:00 2001 From: abhi1092 Date: Fri, 19 Jul 2024 20:29:22 +0000 Subject: [PATCH] Fixed duplicate context issue by taking set of all context, using sampling without replacement, and comparing text directly instead of row_idx Signed-off-by: abhi1092 --- src/instructlab/sdg/utils/parse_and_convert.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/instructlab/sdg/utils/parse_and_convert.py b/src/instructlab/sdg/utils/parse_and_convert.py index ea3cac9a..7ec1b8e1 100644 --- a/src/instructlab/sdg/utils/parse_and_convert.py +++ b/src/instructlab/sdg/utils/parse_and_convert.py @@ -7,6 +7,7 @@ import random import uuid + # Third Party from datasets import Dataset, concatenate_datasets import yaml @@ -165,19 +166,19 @@ def __create_qa_row(rec): def build_raft_dataset(ds: Dataset, p, num_doc_in_context=4): - all_context = ds["context"] - all_context = [" ".join(e.split(" ")[:random.randint(100, 500)]) for e in all_context] - ds = ds.add_column("row_idx", range(ds.num_rows)) + all_context = list(set(ds["context"])) + def __pick_documents(rec, p): while True: - selected_docs = random.choices(range(ds.num_rows), k=num_doc_in_context) - if rec["row_idx"] not in selected_docs: + selected_idx = random.sample(range(len(all_context)), num_doc_in_context) + selected_docs = [all_context[idx] for idx in selected_idx] + if rec['context'] not in selected_docs: break if random.uniform(0, 1) < p: - docs = [all_context[idx] for idx in selected_docs[:num_doc_in_context-1]] + [rec["context"]] + docs = [selected_doc_[:random.randint(100, 500)] for selected_doc_ in selected_docs[:num_doc_in_context-1]] + [rec["context"]] # rec['indicator'] ='golden' else: - docs = [all_context[idx] for idx in selected_docs] + docs = [selected_doc_[:random.randint(100, 500)] for selected_doc_ in selected_docs] # rec['indicator'] = 'distractor' random.shuffle(docs) docs = "\n".join(([f"Document:\n{e}\n\n" for idx, e in enumerate(docs)]))