diff --git a/src/instructlab/sdg/sdg.py b/src/instructlab/sdg/sdg.py
index ef0b9ef6..30b57ccd 100644
--- a/src/instructlab/sdg/sdg.py
+++ b/src/instructlab/sdg/sdg.py
@@ -6,13 +6,14 @@
 import uuid
 
 # Third Party
-from datasets import Dataset, concatenate_datasets, load_dataset
+from datasets import Dataset, load_dataset
 from datasets.data_files import EmptyDatasetError
 from tqdm import tqdm
 
 # Local
 from .logger_config import setup_logger
 from .pipeline import Pipeline
+from .utils.datautils import safe_concatenate_datasets
 
 logger = setup_logger(__name__)
 
@@ -81,6 +82,7 @@ def _generate_data(pipelines, input_split, i=None):
 
     def generate(self, dataset: Dataset, checkpoint_dir=None) -> Dataset:
         # check if checkpoint_dir exists
+        pre_generated_data = []
         if checkpoint_dir is not None:
             try:
                 # check if there are any existing checkpoints
@@ -91,10 +93,13 @@ def generate(self, dataset: Dataset, checkpoint_dir=None) -> Dataset:
                     f"Loading existing checkpoints from {checkpoint_dir}, with {pre_generated_data.num_rows} rows"
                 )
                 seed_data = self._get_missing_data(dataset, pre_generated_data)
-                logger.info(
-                    f"Found {seed_data.num_rows} missing rows in the dataset"
-                )
-
+                if seed_data.num_rows == 0:
+                    logger.info(
+                        f"All seed data has been generated, no missing rows found, returning data from {checkpoint_dir}"
+                    )
+                    return pre_generated_data
+                logger.info(f"Found {seed_data.num_rows} missing rows in the dataset")
+
             except EmptyDatasetError:
                 logger.info(
                     f"No existing checkpoints found in {checkpoint_dir}, generating from scratch"
@@ -122,9 +127,9 @@ def generate(self, dataset: Dataset, checkpoint_dir=None) -> Dataset:
             f"Generating dataset with {len(input_splits)} splits, batch size {self.batch_size}, and {self.num_workers} workers"
        )
 
-        generated_data = []
+        generated_data = [pre_generated_data] if pre_generated_data else []
         last_saved_split_index = 0  # To track the last saved split
-
+
         with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
             futures = [
                 executor.submit(self._generate_data, self.pipelines, input_split, i)
@@ -140,13 +145,13 @@ def generate(self, dataset: Dataset, checkpoint_dir=None) -> Dataset:
                 if self.save_freq and (i + 1) % self.save_freq == 0:
                     # Save only the new splits since the last checkpoint
                     new_splits = generated_data[last_saved_split_index : i + 1]
-                    checkpoint_dataset = concatenate_datasets(new_splits)
+                    checkpoint_dataset = safe_concatenate_datasets(new_splits)
                     self._save_intermediate_checkpoint(
                         checkpoint_dataset, checkpoint_dir
                     )
                     last_saved_split_index = i + 1
 
-        generated_dataset = concatenate_datasets(generated_data)
+        generated_dataset = safe_concatenate_datasets(generated_data)
 
         return generated_dataset
 