Fix bug in checkpoint resumption and seed data handling
- Corrected the logic so that the checkpointed data is returned when every seed example already has corresponding synthetic data in the checkpoint.
- Enhanced checkpoint resumption to load the pre-generated data into memory, so that it is included in the final dataset returned.

Signed-off-by: shiv <[email protected]>
shivchander committed Jul 25, 2024
1 parent 4dd6fe6 commit a8aadf5
Showing 1 changed file with 14 additions and 9 deletions.
src/instructlab/sdg/sdg.py: 23 changes (14 additions & 9 deletions)

@@ -6,13 +6,14 @@
 import uuid
 
 # Third Party
-from datasets import Dataset, concatenate_datasets, load_dataset
+from datasets import Dataset, load_dataset
 from datasets.data_files import EmptyDatasetError
 from tqdm import tqdm
 
 # Local
 from .logger_config import setup_logger
 from .pipeline import Pipeline
+from .utils.datautils import safe_concatenate_datasets
 
 
 logger = setup_logger(__name__)
@@ -81,6 +82,7 @@ def _generate_data(pipelines, input_split, i=None):
 
     def generate(self, dataset: Dataset, checkpoint_dir=None) -> Dataset:
         # check if checkpoint_dir exists
+        pre_generated_data = []
         if checkpoint_dir is not None:
             try:
                 # check if there are any existing checkpoints
@@ -91,10 +93,13 @@ def generate(self, dataset: Dataset, checkpoint_dir=None) -> Dataset:
                     f"Loading existing checkpoints from {checkpoint_dir}, with {pre_generated_data.num_rows} rows"
                 )
                 seed_data = self._get_missing_data(dataset, pre_generated_data)
-                logger.info(
-                    f"Found {seed_data.num_rows} missing rows in the dataset"
-                )
-
+                if seed_data.num_rows == 0:
+                    logger.info(
+                        f"All seed data has been generated, no missing rows found, returning data from {checkpoint_dir}"
+                    )
+                    return pre_generated_data
+                logger.info(f"Found {seed_data.num_rows} missing rows in the dataset")
+
             except EmptyDatasetError:
                 logger.info(
                     f"No existing checkpoints found in {checkpoint_dir}, generating from scratch"
@@ -122,9 +127,9 @@ def generate(self, dataset: Dataset, checkpoint_dir=None) -> Dataset:
             f"Generating dataset with {len(input_splits)} splits, batch size {self.batch_size}, and {self.num_workers} workers"
         )
 
-        generated_data = []
+        generated_data = [pre_generated_data] if pre_generated_data else []
         last_saved_split_index = 0  # To track the last saved split
 
         with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
             futures = [
                 executor.submit(self._generate_data, self.pipelines, input_split, i)
@@ -140,13 +145,13 @@ def generate(self, dataset: Dataset, checkpoint_dir=None) -> Dataset:
                 if self.save_freq and (i + 1) % self.save_freq == 0:
                     # Save only the new splits since the last checkpoint
                     new_splits = generated_data[last_saved_split_index : i + 1]
-                    checkpoint_dataset = concatenate_datasets(new_splits)
+                    checkpoint_dataset = safe_concatenate_datasets(new_splits)
                     self._save_intermediate_checkpoint(
                         checkpoint_dataset, checkpoint_dir
                     )
 
                     last_saved_split_index = i + 1
 
-        generated_dataset = concatenate_datasets(generated_data)
+        generated_dataset = safe_concatenate_datasets(generated_data)
 
         return generated_dataset
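
Note: the new safe_concatenate_datasets helper is imported from .utils.datautils, but its body is not part of this diff. A minimal sketch of what such a helper presumably does, assuming it simply drops None and zero-row entries before delegating to datasets.concatenate_datasets (which raises on an empty list):

    # Hypothetical sketch; the real helper lives in .utils.datautils and is
    # not shown in this commit.
    from datasets import concatenate_datasets

    def safe_concatenate_datasets(datasets_list):
        """Concatenate datasets, ignoring entries that are None or empty."""
        filtered = [ds for ds in datasets_list if ds is not None and ds.num_rows > 0]
        if not filtered:
            return None  # nothing to concatenate
        return concatenate_datasets(filtered)

_get_missing_data is likewise defined outside the changed hunks. Conceptually it returns the seed rows that have no synthetic counterpart in the checkpoint; a sketch under the assumption that generated rows retain the seed columns, so membership can be tested on the columns the two datasets share (the actual matching key is not visible here):

    # Hypothetical sketch; the real matching logic is not shown in this commit.
    def _get_missing_data(self, seed_data, generated_data):
        # Compare on shared columns, assuming their values are hashable,
        # and keep only seed rows that never appear in the generated data.
        common = [c for c in seed_data.column_names if c in generated_data.column_names]
        seen = {tuple(row[c] for c in common) for row in generated_data}
        return seed_data.filter(lambda row: tuple(row[c] for c in common) not in seen)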

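For context, the resumption flow this commit fixes looks roughly as follows. This is a usage sketch, not code from the repository: the class name SDG, the import path, and the constructor arguments are assumptions inferred from the attributes the diff uses (self.pipelines, self.num_workers, self.batch_size, self.save_freq).

    from datasets import Dataset

    from instructlab.sdg.sdg import SDG  # assumed import path

    seed = Dataset.from_list([{"instruction": f"task {i}"} for i in range(100)])

    # pipelines would be a list of Pipeline objects built elsewhere;
    # all constructor arguments here are illustrative.
    sdg = SDG(pipelines=pipelines, num_workers=4, batch_size=8, save_freq=2)

    # First run: a checkpoint is written every save_freq completed splits
    # via _save_intermediate_checkpoint.
    data = sdg.generate(seed, checkpoint_dir="ckpts")

    # Re-run after an interruption: rows found in ckpts are loaded into
    # pre_generated_data, only the missing seed rows are regenerated, and
    # the returned dataset now includes the checkpointed rows as well
    # (before this fix they were left out of the final concatenation).
    data = sdg.generate(seed, checkpoint_dir="ckpts")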