Fix data checkpointing #24

Merged · 2 commits · Jul 26, 2024
23 changes: 14 additions & 9 deletions src/instructlab/sdg/sdg.py
@@ -6,13 +6,14 @@
import uuid

# Third Party
from datasets import Dataset, concatenate_datasets, load_dataset
from datasets import Dataset, load_dataset
from datasets.data_files import EmptyDatasetError
from tqdm import tqdm

# Local
from .logger_config import setup_logger
from .pipeline import Pipeline
from .utils.datautils import safe_concatenate_datasets


logger = setup_logger(__name__)
@@ -74,13 +75,14 @@
for pipeline in pipelines:
input_split = pipeline.generate(input_split)
return input_split
except Exception as e:

[GitHub Actions / lint] Check warning on line 78 in src/instructlab/sdg/sdg.py: W0718: Catching too general exception Exception (broad-exception-caught)
logger.error(f"Error processing split {i}: {e}")
traceback.print_exc()
return None

def generate(self, dataset: Dataset, checkpoint_dir=None) -> Dataset:
# check if checkpoint_dir exists
pre_generated_data = []
if checkpoint_dir is not None:
try:
# check if there are any existing checkpoints
@@ -91,10 +93,13 @@
f"Loading existing checkpoints from {checkpoint_dir}, with {pre_generated_data.num_rows} rows"
)
seed_data = self._get_missing_data(dataset, pre_generated_data)
logger.info(
f"Found {seed_data.num_rows} missing rows in the dataset"
)

if seed_data.num_rows == 0:
logger.info(
f"All seed data has been generated, no missing rows found, returning data from {checkpoint_dir}"
)
return pre_generated_data
logger.info(f"Found {seed_data.num_rows} missing rows in the dataset")

except EmptyDatasetError:
logger.info(
f"No existing checkpoints found in {checkpoint_dir}, generating from scratch"
@@ -122,9 +127,9 @@
f"Generating dataset with {len(input_splits)} splits, batch size {self.batch_size}, and {self.num_workers} workers"
)

generated_data = []
generated_data = [pre_generated_data] if pre_generated_data else []
last_saved_split_index = 0 # To track the last saved split

with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
futures = [
executor.submit(self._generate_data, self.pipelines, input_split, i)
@@ -140,13 +145,13 @@
if self.save_freq and (i + 1) % self.save_freq == 0:
# Save only the new splits since the last checkpoint
new_splits = generated_data[last_saved_split_index : i + 1]
checkpoint_dataset = concatenate_datasets(new_splits)
checkpoint_dataset = safe_concatenate_datasets(new_splits)
self._save_intermediate_checkpoint(
checkpoint_dataset, checkpoint_dir
)

last_saved_split_index = i + 1

generated_dataset = concatenate_datasets(generated_data)
generated_dataset = safe_concatenate_datasets(generated_data)

return generated_dataset
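
Taken together, the sdg.py changes make generate() resume from a checkpoint: previously generated rows are loaded from checkpoint_dir, only the seed rows that are still missing get regenerated, and the checkpointed rows are seeded into generated_data so the final concatenation includes them. Below is a minimal sketch of that flow; the toy datasets are invented, get_missing_rows is a hypothetical stand-in for the class's _get_missing_data, and the import path assumes a build of instructlab-sdg that contains this PR.

```python
from datasets import Dataset

from instructlab.sdg.utils.datautils import safe_concatenate_datasets  # helper added by this PR

seed = Dataset.from_dict({"id": [1, 2, 3, 4]})      # full seed dataset
checkpointed = Dataset.from_dict({"id": [1, 2]})    # rows recovered from checkpoint_dir


def get_missing_rows(seed_ds, done_ds):
    # Hypothetical stand-in for SDG._get_missing_data: keep only seed rows not yet generated.
    done_ids = set(done_ds["id"])
    return seed_ds.filter(lambda rec: rec["id"] not in done_ids)


missing = get_missing_rows(seed, checkpointed)
if missing.num_rows == 0:
    final = checkpointed        # nothing left to generate, return the checkpoint as-is
else:
    newly_generated = missing   # pretend the pipelines ran over the missing split here
    # Seeding the result with the checkpointed rows is the core of the fix.
    final = safe_concatenate_datasets([checkpointed, newly_generated])

print(final["id"])  # [1, 2, 3, 4]
```

Using safe_concatenate_datasets for the final merge also means a None or empty split list no longer crashes the concatenation, which plain concatenate_datasets would.
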
6 changes: 4 additions & 2 deletions src/instructlab/sdg/utils/datamixing.py
@@ -3,11 +3,13 @@
import os

# Third Party
from datasets import Dataset, concatenate_datasets, load_dataset
from datasets import Dataset, load_dataset
import yaml

# First Party
from instructlab.sdg.logger_config import setup_logger
from .datautils import safe_concatenate_datasets


LOGGER = setup_logger(__name__)
ALLOWED_COLS = ["id", "messages", "metadata"]
Expand Down Expand Up @@ -107,7 +109,7 @@ def save_mixed_dataset(self, output_path):
for dataset in self.recipe["datasets"]
]

mixed_ds = concatenate_datasets(mixed_ds)
mixed_ds = safe_concatenate_datasets(mixed_ds)
mixed_ds = mixed_ds.map(
add_system_message, fn_kwargs={"sys_prompt": self.sys_prompt}, num_proc=8
)
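
The datamixing.py hunk swaps the same helper into the recipe-mixing path, where any dataset listed in the recipe may turn out empty. A rough sketch of that pattern with toy data; the sample datasets, the system prompt, and the simplified add_system_message below are illustrative stand-ins, not the project's exact recipe format or helper.

```python
from datasets import Dataset

from instructlab.sdg.utils.datautils import safe_concatenate_datasets

sys_prompt = "You are a helpful AI assistant."

# Toy stand-ins for the datasets a mixing recipe would point at.
knowledge_ds = Dataset.from_dict({"messages": [[{"role": "user", "content": "What is SDG?"}]]})
skills_ds = Dataset.from_dict({"messages": [[{"role": "user", "content": "Write a haiku."}]]})


def add_system_message(sample, sys_prompt):
    # Simplified stand-in for the real helper: prepend a system turn to each sample.
    sample["messages"] = [{"role": "system", "content": sys_prompt}] + sample["messages"]
    return sample


# None or zero-row entries are skipped instead of crashing the concatenation.
mixed_ds = safe_concatenate_datasets([knowledge_ds, skills_ds, None])
mixed_ds = mixed_ds.map(add_system_message, fn_kwargs={"sys_prompt": sys_prompt})
print(mixed_ds[0]["messages"][0]["role"])  # "system"
```
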
14 changes: 14 additions & 0 deletions src/instructlab/sdg/utils/datautils.py
@@ -0,0 +1,14 @@
# Third Party
from datasets import concatenate_datasets


def safe_concatenate_datasets(datasets: list):
"""
Concatenate datasets safely, ignoring any datasets that are None or empty.
"""
filtered_datasets = [ds for ds in datasets if ds is not None and ds.num_rows > 0]

if not filtered_datasets:
return None

return concatenate_datasets(filtered_datasets)
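
A quick usage sketch for the new helper (the import path assumes a build of instructlab-sdg that includes this file). It shows the two behaviours the wrapper adds over datasets.concatenate_datasets, which errors on an empty list and does not accept None entries:

```python
from datasets import Dataset

from instructlab.sdg.utils.datautils import safe_concatenate_datasets

a = Dataset.from_dict({"id": [1, 2]})
empty = Dataset.from_dict({"id": []})

# None and zero-row datasets are filtered out before concatenating.
print(safe_concatenate_datasets([a, None, empty]).num_rows)  # 2

# If nothing survives the filter, the helper returns None instead of raising.
print(safe_concatenate_datasets([None, empty]))  # None
```
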
11 changes: 6 additions & 5 deletions src/instructlab/sdg/utils/parse_and_convert.py
@@ -10,13 +10,14 @@
import uuid

# Third Party
from datasets import Dataset, concatenate_datasets
from datasets import Dataset
import yaml

# First Party
# pylint: disable=ungrouped-imports
from instructlab.sdg import utils
from instructlab.sdg.logger_config import setup_logger
from .datautils import safe_concatenate_datasets

logger = setup_logger(__name__)

@@ -211,7 +212,7 @@ def __create_qa_row(rec):
def build_raft_dataset(ds: Dataset, p, num_doc_in_context=4):
all_context = list(set(ds["context"]))

def __pick_documents(rec, p):
def _pick_documents(rec, p):
answer_document = [rec["context"]]
selected_docs = [e for e in all_context if e != answer_document]
if len(selected_docs) > 0:
@@ -254,7 +255,7 @@ def __pick_documents(rec, p):

return rec

ds = ds.map(__pick_documents, fn_kwargs={"p": p}, remove_columns=["context"])
ds = ds.map(_pick_documents, fn_kwargs={"p": p}, remove_columns=["context"])
return ds


@@ -277,7 +278,7 @@ def create_knowledge_regular_ds(generated_dataset: Dataset):

auxiliary_dataset = create_auxiliary_dataset(generated_dataset)
if auxiliary_dataset is not None:
transformed_data = concatenate_datasets([knowledge_ds, auxiliary_dataset])
transformed_data = safe_concatenate_datasets([knowledge_ds, auxiliary_dataset])
else:
transformed_data = knowledge_ds
return transformed_data
@@ -293,7 +294,7 @@ def create_knowledge_pretraining_ds(generated_dataset: Dataset):
auxiliary_dataset = create_auxiliary_dataset(generated_dataset)
if auxiliary_dataset is not None:
auxiliary_dataset = auxiliary_dataset.map(_conv_pretrain)
transformed_data = concatenate_datasets([knowledge_ds, auxiliary_dataset])
transformed_data = safe_concatenate_datasets([knowledge_ds, auxiliary_dataset])
else:
transformed_data = knowledge_ds
return transformed_data