-
Notifications
You must be signed in to change notification settings - Fork 41
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #409 from bbrowning/research-sync
Reconcile core data generation features with latest research advances
- Loading branch information
Showing
52 changed files
with
1,060 additions
and
432 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
# Standard | ||
from abc import ABC | ||
from typing import Any, Dict, Union | ||
import logging | ||
import os.path | ||
|
||
# Third Party | ||
from jinja2 import Template, UndefinedError | ||
import yaml | ||
|
||
# Local | ||
from ..registry import BlockRegistry | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
# This is part of the public API. | ||
@BlockRegistry.register("Block") | ||
class Block(ABC): | ||
def __init__(self, ctx, pipe, block_name: str) -> None: | ||
self.ctx = ctx | ||
self.pipe = pipe | ||
self.block_name = block_name | ||
|
||
def _validate(self, prompt_template: Template, input_dict: Dict[str, Any]) -> bool: | ||
""" | ||
Validate the input data for this block. This method validates whether all required | ||
variables in the Jinja template are provided in the input_dict. | ||
Args: | ||
prompt_template (Template): The Jinja2 template object. | ||
input_dict (Dict[str, Any]): A dictionary of input values to check against | ||
the template. | ||
Returns: | ||
True if the input data is valid (i.e., no missing variables), False otherwise. | ||
""" | ||
|
||
try: | ||
# Try rendering the template with the input_dict | ||
prompt_template.render(input_dict) | ||
return True | ||
except UndefinedError as e: | ||
# Jinja throws an UndefinedError for any undefnined template variables, | ||
# assuming the prompt_template was created using StrictUndefined. This | ||
# is the case for anything using PromptRegistry.template_from_string. | ||
logger.error(f"Missing key: {e}") | ||
return False | ||
|
||
def _load_config(self, config_path: str) -> Union[Dict[str, Any], None]: | ||
""" | ||
Load the configuration file for this block. | ||
If the supplied configuration file is a relative path, it is assumed | ||
to be part of this Python package. | ||
Args: | ||
config_path (str): The path to the configuration file. | ||
Returns: | ||
The loaded configuration. | ||
""" | ||
if not os.path.isabs(config_path): | ||
config_path = os.path.join( | ||
os.path.dirname(self.pipe.config_path), config_path | ||
) | ||
with open(config_path, "r", encoding="utf-8") as config_file: | ||
return yaml.safe_load(config_file) | ||
|
||
|
||
# This is part of the public API. | ||
class BlockConfigParserError(Exception): | ||
"""An exception raised while parsing a block's configuration.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
# Standard | ||
import logging | ||
|
||
# Third Party | ||
from datasets import Dataset | ||
|
||
# Local | ||
from ..pipeline import _lookup_block_type | ||
from ..registry import BlockRegistry | ||
from .block import Block | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
# This is part of the public API. | ||
@BlockRegistry.register("IterBlock") | ||
class IterBlock(Block): | ||
""" | ||
Call another block multiple times for a single set of input | ||
samples, concatening the results of each iteration's call to that | ||
other block in the final returned output. | ||
Args: | ||
num_iters: The number of times to iterate over the block | ||
block_type: The type of the other block to call (ie LLMBlock) | ||
block_config: Any necessary configuration that will get passed to the | ||
other block to properly configure it. | ||
Returns: | ||
A Dataset containing all output samples from each iteration | ||
""" | ||
|
||
def __init__( | ||
self, | ||
ctx, | ||
pipe, | ||
block_name, | ||
num_iters, | ||
block_type, | ||
**block_config, | ||
) -> None: | ||
super().__init__(ctx, pipe, block_name) | ||
self.num_iters = num_iters | ||
block_type = _lookup_block_type(block_type) | ||
self.block = block_type(ctx, pipe, block_name, **block_config) | ||
|
||
def generate(self, samples: Dataset) -> Dataset: | ||
generated_samples = [] | ||
num_iters = self.num_iters | ||
|
||
for _ in range(num_iters): | ||
batch_generated = self.block.generate(samples) | ||
generated_samples.extend(batch_generated) | ||
|
||
return Dataset.from_list(generated_samples) |
Oops, something went wrong.