Skip to content

Commit

Permalink
Merge pull request #73 from tiran/public-api
Browse files Browse the repository at this point in the history
Export public APIs in top-level package
  • Loading branch information
russellb authored Jul 16, 2024
2 parents c083f98 + 40b45ad commit b0b8096
Show file tree
Hide file tree
Showing 11 changed files with 67 additions and 11 deletions.
39 changes: 39 additions & 0 deletions src/instructlab/sdg/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,42 @@
# SPDX-License-Identifier: Apache-2.0

# NOTE: This package imports Torch and other heavy packages.
# Names re-exported as the public API of this package. Every entry must match
# a name imported below, so that `from instructlab.sdg import *` and static
# analysis tools agree on what is public.
__all__ = (
    "Block",
    "CombineColumnsBlock",
    "ConditionalLLMBlock",
    "EmptyDatasetError",
    "FilterByValueBlock",
    "FilterByValueBlockError",
    "GenerateException",
    "ImportBlock",
    "LLMBlock",
    "Pipeline",
    "PipelineConfigParserError",
    "PipelineContext",
    "SamplePopulatorBlock",
    "SelectorBlock",
    "SDG",
    "SIMPLE_PIPELINES_PACKAGE",
    "FULL_PIPELINES_PACKAGE",
    # Imported below from .utils.taxonomy but previously absent from this
    # tuple, which hid it from star-imports and made the import look unused.
    "TaxonomyReadingException",
    "generate_data",
)

# Local
from .block import Block
from .filterblock import FilterByValueBlock, FilterByValueBlockError
from .generate_data import generate_data
from .importblock import ImportBlock
from .llmblock import ConditionalLLMBlock, LLMBlock
from .pipeline import (
FULL_PIPELINES_PACKAGE,
SIMPLE_PIPELINES_PACKAGE,
EmptyDatasetError,
Pipeline,
PipelineConfigParserError,
PipelineContext,
)
from .sdg import SDG
from .utilblocks import CombineColumnsBlock, SamplePopulatorBlock, SelectorBlock
from .utils import GenerateException
from .utils.taxonomy import TaxonomyReadingException
1 change: 1 addition & 0 deletions src/instructlab/sdg/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
logger = setup_logger(__name__)


# This is part of the public API.
class Block(ABC):
def __init__(self, ctx, pipe, block_name: str) -> None:
self.ctx = ctx
Expand Down
2 changes: 2 additions & 0 deletions src/instructlab/sdg/filterblock.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
logger = setup_logger(__name__)


# This is part of the public API.
class FilterByValueBlockError(Exception):
"""An exception raised by the FilterByValue block."""

Expand Down Expand Up @@ -83,6 +84,7 @@ def convert_column(sample):
return samples.map(convert_column, num_proc=num_proc)


# This is part of the public API.
class FilterByValueBlock(Block):
def __init__(
self,
Expand Down
21 changes: 11 additions & 10 deletions src/instructlab/sdg/generate_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@

# First Party
# pylint: disable=ungrouped-imports
from instructlab.sdg import SDG, utils
from instructlab.sdg.llmblock import MODEL_FAMILY_MERLINITE, MODEL_FAMILY_MIXTRAL
from instructlab.sdg.pipeline import (
FULL_PIPELINES_PACKAGE,
SIMPLE_PIPELINES_PACKAGE,
Pipeline,
PipelineContext,
)
from instructlab.sdg.utils import models
from instructlab.sdg.sdg import SDG
from instructlab.sdg.utils import GenerateException, models
from instructlab.sdg.utils.taxonomy import (
leaf_node_to_samples,
read_taxonomy_leaf_nodes,
Expand All @@ -48,7 +48,7 @@ def _get_question(logger, synth_example):
return synth_example["question"]

if not synth_example.get("output"):
raise utils.GenerateException(
raise GenerateException(
f"Error: output not found in synth_example: {synth_example}"
)

Expand All @@ -64,7 +64,7 @@ def _get_response(logger, synth_example):
return synth_example["response"]

if "output" not in synth_example:
raise utils.GenerateException(
raise GenerateException(
f"Error: output not found in synth_example: {synth_example}"
)

Expand Down Expand Up @@ -173,12 +173,12 @@ def _sdg_init(pipeline, client, model_family, model_id, num_instructions_to_gene
else:
# Validate that pipeline is a valid directory and that it contains the required files
if not os.path.exists(pipeline):
raise utils.GenerateException(
raise GenerateException(
f"Error: pipeline directory ({pipeline}) does not exist."
)
for file in ["knowledge.yaml", "freeform_skills.yaml", "grounded_skills.yaml"]:
if not os.path.exists(os.path.join(pipeline, file)):
raise utils.GenerateException(
raise GenerateException(
f"Error: pipeline directory ({pipeline}) does not contain {file}."
)

Expand All @@ -198,6 +198,7 @@ def load_pipeline(yaml_basename):
)


# This is part of the public API, and used by instructlab.
# TODO - parameter removal needs to be done in sync with a CLI change.
# pylint: disable=unused-argument
def generate_data(
Expand Down Expand Up @@ -226,7 +227,7 @@ def generate_data(
tls_client_key: Optional[str] = None,
tls_client_passwd: Optional[str] = None,
pipeline: Optional[str] = "simple",
):
) -> None:
"""Generate data for training and testing a model.
This currently serves as the primary interface from the `ilab` CLI to the `sdg` library.
Expand All @@ -246,11 +247,11 @@ def generate_data(
os.mkdir(output_dir)

if not (taxonomy and os.path.exists(taxonomy)):
raise utils.GenerateException(f"Error: taxonomy ({taxonomy}) does not exist.")
raise GenerateException(f"Error: taxonomy ({taxonomy}) does not exist.")

leaf_nodes = read_taxonomy_leaf_nodes(taxonomy, taxonomy_base, yaml_rules)
if not leaf_nodes:
raise utils.GenerateException("Error: No new leaf nodes found in the taxonomy.")
raise GenerateException("Error: No new leaf nodes found in the taxonomy.")

name = Path(model_name).stem # Just in case it is a file path
date_suffix = datetime.now().replace(microsecond=0).isoformat().replace(":", "_")
Expand Down Expand Up @@ -301,7 +302,7 @@ def generate_data(
samples = leaf_node_to_samples(leaf_node, server_ctx_size, chunk_word_count)

if not samples:
raise utils.GenerateException("Error: No samples found in leaf node.")
raise GenerateException("Error: No samples found in leaf node.")

if samples[0].get("document"):
sdg = sdg_knowledge
Expand Down
1 change: 1 addition & 0 deletions src/instructlab/sdg/importblock.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
logger = setup_logger(__name__)


# This is part of the public API.
class ImportBlock(Block):
def __init__(
self,
Expand Down
2 changes: 2 additions & 0 deletions src/instructlab/sdg/llmblock.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def server_supports_batched(client, model_id: str) -> bool:
return supported


# This is part of the public API.
# pylint: disable=dangerous-default-value
class LLMBlock(Block):
# pylint: disable=too-many-instance-attributes
Expand Down Expand Up @@ -228,6 +229,7 @@ def generate(self, samples: Dataset, **gen_kwargs) -> Dataset:
return Dataset.from_list(new_data)


# This is part of the public API.
class ConditionalLLMBlock(LLMBlock):
def __init__(
self,
Expand Down
5 changes: 5 additions & 0 deletions src/instructlab/sdg/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@
logger = setup_logger(__name__)


# This is part of the public API.
class EmptyDatasetError(Exception):
pass


# This is part of the public API.
class PipelineContext:
def __init__(
self, client, model_family, model_id, num_instructions_to_generate
Expand All @@ -30,6 +32,7 @@ def __init__(
self.num_procs = 8


# This is part of the public API.
class Pipeline:
def __init__(self, ctx, config_path, chained_blocks: list) -> None:
"""
Expand Down Expand Up @@ -113,6 +116,7 @@ def _lookup_block_type(block_type):
_PIPELINE_CONFIG_PARSER_MINOR = 0


# This is part of the public API.
class PipelineConfigParserError(Exception):
"""An exception raised while parsing a pipline config file."""

Expand Down Expand Up @@ -141,5 +145,6 @@ def _parse_pipeline_config_file(pipeline_yaml):
return content["blocks"]


# This is part of the public API.
SIMPLE_PIPELINES_PACKAGE = "instructlab.sdg.pipelines.simple"
FULL_PIPELINES_PACKAGE = "instructlab.sdg.pipelines.full"
1 change: 1 addition & 0 deletions src/instructlab/sdg/sdg.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .pipeline import Pipeline


# This is part of the public API.
class SDG:
def __init__(self, pipelines: list[Pipeline]) -> None:
self.pipelines = pipelines
Expand Down
3 changes: 3 additions & 0 deletions src/instructlab/sdg/utilblocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
logger = setup_logger(__name__)


# This is part of the public API.
class SamplePopulatorBlock(Block):
def __init__(
self, ctx, pipe, block_name, config_paths, column_name, post_fix=""
Expand Down Expand Up @@ -38,6 +39,7 @@ def generate(self, samples) -> Dataset:
)


# This is part of the public API.
class SelectorBlock(Block):
def __init__(
self, ctx, pipe, block_name, choice_map, choice_col, output_col
Expand Down Expand Up @@ -66,6 +68,7 @@ def generate(self, samples: Dataset) -> Dataset:
)


# This is part of the public API.
class CombineColumnsBlock(Block):
def __init__(
self, ctx, pipe, block_name, columns, output_col, separator="\n\n"
Expand Down
2 changes: 1 addition & 1 deletion src/instructlab/sdg/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# SPDX-License-Identifier: Apache-2.0

# This is part of the public API, and used by instructlab
# This is part of the public API, and used by instructlab.
class GenerateException(Exception):
"""An exception raised during generate step."""
1 change: 1 addition & 0 deletions src/instructlab/sdg/utils/taxonomy.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
"""


# This is part of the public API.
class TaxonomyReadingException(Exception):
"""An exception raised during reading of the taxonomy."""

Expand Down

0 comments on commit b0b8096

Please sign in to comment.