Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Export public APIs in top-level package #73

Merged
merged 3 commits into from
Jul 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions src/instructlab/sdg/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,42 @@
# SPDX-License-Identifier: Apache-2.0

# NOTE: This package imports Torch and other heavy packages.
__all__ = (
"Block",
"CombineColumnsBlock",
"ConditionalLLMBlock",
"EmptyDatasetError",
"FilterByValueBlock",
"FilterByValueBlockError",
"GenerateException",
"ImportBlock",
"LLMBlock",
"Pipeline",
"PipelineConfigParserError",
"PipelineContext",
"SamplePopulatorBlock",
"SelectorBlock",
"SDG",
"SIMPLE_PIPELINES_PACKAGE",
"FULL_PIPELINES_PACKAGE",
"generate_data",
markmc marked this conversation as resolved.
Show resolved Hide resolved
)

# Local
from .block import Block
from .filterblock import FilterByValueBlock, FilterByValueBlockError
from .generate_data import generate_data
from .importblock import ImportBlock
from .llmblock import ConditionalLLMBlock, LLMBlock
from .pipeline import (
FULL_PIPELINES_PACKAGE,
SIMPLE_PIPELINES_PACKAGE,
EmptyDatasetError,
Pipeline,
PipelineConfigParserError,
PipelineContext,
)
from .sdg import SDG
from .utilblocks import CombineColumnsBlock, SamplePopulatorBlock, SelectorBlock
from .utils import GenerateException
from .utils.taxonomy import TaxonomyReadingException
1 change: 1 addition & 0 deletions src/instructlab/sdg/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
logger = setup_logger(__name__)


# This is part of the public API.
class Block(ABC):
def __init__(self, ctx, pipe, block_name: str) -> None:
self.ctx = ctx
Expand Down
2 changes: 2 additions & 0 deletions src/instructlab/sdg/filterblock.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
logger = setup_logger(__name__)


# This is part of the public API.
class FilterByValueBlockError(Exception):
"""An exception raised by the FilterByValue block."""

Expand Down Expand Up @@ -73,6 +74,7 @@ def convert_column(sample):
return samples.map(convert_column, num_proc=num_proc)


# This is part of the public API.
class FilterByValueBlock(Block):
def __init__(
self,
Expand Down
21 changes: 11 additions & 10 deletions src/instructlab/sdg/generate_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@

# First Party
# pylint: disable=ungrouped-imports
from instructlab.sdg import SDG, utils
from instructlab.sdg.llmblock import MODEL_FAMILY_MERLINITE, MODEL_FAMILY_MIXTRAL
from instructlab.sdg.pipeline import (
FULL_PIPELINES_PACKAGE,
SIMPLE_PIPELINES_PACKAGE,
Pipeline,
PipelineContext,
)
from instructlab.sdg.utils import models
from instructlab.sdg.sdg import SDG
markmc marked this conversation as resolved.
Show resolved Hide resolved
from instructlab.sdg.utils import GenerateException, models
markmc marked this conversation as resolved.
Show resolved Hide resolved
from instructlab.sdg.utils.taxonomy import (
leaf_node_to_samples,
read_taxonomy_leaf_nodes,
Expand All @@ -48,7 +48,7 @@ def _get_question(logger, synth_example):
return synth_example["question"]

if not synth_example.get("output"):
raise utils.GenerateException(
raise GenerateException(
f"Error: output not found in synth_example: {synth_example}"
)

Expand All @@ -64,7 +64,7 @@ def _get_response(logger, synth_example):
return synth_example["response"]

if "output" not in synth_example:
raise utils.GenerateException(
raise GenerateException(
f"Error: output not found in synth_example: {synth_example}"
)

Expand Down Expand Up @@ -173,12 +173,12 @@ def _sdg_init(pipeline, client, model_family, model_id, num_instructions_to_gene
else:
# Validate that pipeline is a valid directory and that it contains the required files
if not os.path.exists(pipeline):
raise utils.GenerateException(
raise GenerateException(
f"Error: pipeline directory ({pipeline}) does not exist."
)
for file in ["knowledge.yaml", "freeform_skills.yaml", "grounded_skills.yaml"]:
if not os.path.exists(os.path.join(pipeline, file)):
raise utils.GenerateException(
raise GenerateException(
f"Error: pipeline directory ({pipeline}) does not contain {file}."
)

Expand All @@ -198,6 +198,7 @@ def load_pipeline(yaml_basename):
)


# This is part of the public API, and used by instructlab.
# TODO - parameter removal needs to be done in sync with a CLI change.
# pylint: disable=unused-argument
def generate_data(
Expand Down Expand Up @@ -226,7 +227,7 @@ def generate_data(
tls_client_key: Optional[str] = None,
tls_client_passwd: Optional[str] = None,
pipeline: Optional[str] = "simple",
):
) -> None:
"""Generate data for training and testing a model.
This currently serves as the primary interface from the `ilab` CLI to the `sdg` library.
Expand All @@ -246,11 +247,11 @@ def generate_data(
os.mkdir(output_dir)

if not (taxonomy and os.path.exists(taxonomy)):
raise utils.GenerateException(f"Error: taxonomy ({taxonomy}) does not exist.")
raise GenerateException(f"Error: taxonomy ({taxonomy}) does not exist.")

leaf_nodes = read_taxonomy_leaf_nodes(taxonomy, taxonomy_base, yaml_rules)
if not leaf_nodes:
raise utils.GenerateException("Error: No new leaf nodes found in the taxonomy.")
raise GenerateException("Error: No new leaf nodes found in the taxonomy.")

name = Path(model_name).stem # Just in case it is a file path
date_suffix = datetime.now().replace(microsecond=0).isoformat().replace(":", "_")
Expand Down Expand Up @@ -301,7 +302,7 @@ def generate_data(
samples = leaf_node_to_samples(leaf_node, server_ctx_size, chunk_word_count)

if not samples:
raise utils.GenerateException("Error: No samples found in leaf node.")
raise GenerateException("Error: No samples found in leaf node.")

if samples[0].get("document"):
sdg = sdg_knowledge
Expand Down
1 change: 1 addition & 0 deletions src/instructlab/sdg/importblock.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
logger = setup_logger(__name__)


# This is part of the public API.
class ImportBlock(Block):
def __init__(
self,
Expand Down
2 changes: 2 additions & 0 deletions src/instructlab/sdg/llmblock.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def server_supports_batched(client, model_id: str) -> bool:
return supported


# This is part of the public API.
# pylint: disable=dangerous-default-value
class LLMBlock(Block):
# pylint: disable=too-many-instance-attributes
Expand Down Expand Up @@ -212,6 +213,7 @@ def generate(self, samples: Dataset, **gen_kwargs) -> Dataset:
return Dataset.from_list(new_data)


# This is part of the public API.
class ConditionalLLMBlock(LLMBlock):
def __init__(
self,
Expand Down
5 changes: 5 additions & 0 deletions src/instructlab/sdg/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@
logger = setup_logger(__name__)


# This is part of the public API.
class EmptyDatasetError(Exception):
pass


# This is part of the public API.
class PipelineContext:
def __init__(
self, client, model_family, model_id, num_instructions_to_generate
Expand All @@ -30,6 +32,7 @@ def __init__(
self.num_procs = 8


# This is part of the public API.
class Pipeline:
def __init__(self, ctx, config_path, chained_blocks: list) -> None:
"""
Expand Down Expand Up @@ -113,6 +116,7 @@ def _lookup_block_type(block_type):
_PIPELINE_CONFIG_PARSER_MINOR = 0


# This is part of the public API.
class PipelineConfigParserError(Exception):
"""An exception raised while parsing a pipline config file."""

Expand Down Expand Up @@ -141,5 +145,6 @@ def _parse_pipeline_config_file(pipeline_yaml):
return content["blocks"]


# This is part of the public API.
SIMPLE_PIPELINES_PACKAGE = "instructlab.sdg.pipelines.simple"
FULL_PIPELINES_PACKAGE = "instructlab.sdg.pipelines.full"
1 change: 1 addition & 0 deletions src/instructlab/sdg/sdg.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .pipeline import Pipeline


# This is part of the public API.
class SDG:
def __init__(self, pipelines: list[Pipeline]) -> None:
self.pipelines = pipelines
Expand Down
3 changes: 3 additions & 0 deletions src/instructlab/sdg/utilblocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
logger = setup_logger(__name__)


# This is part of the public API.
class SamplePopulatorBlock(Block):
def __init__(
self, ctx, pipe, block_name, config_paths, column_name, post_fix=""
Expand Down Expand Up @@ -38,6 +39,7 @@ def generate(self, samples) -> Dataset:
)


# This is part of the public API.
class SelectorBlock(Block):
def __init__(
self, ctx, pipe, block_name, choice_map, choice_col, output_col
Expand Down Expand Up @@ -66,6 +68,7 @@ def generate(self, samples: Dataset) -> Dataset:
)


# This is part of the public API.
class CombineColumnsBlock(Block):
def __init__(
self, ctx, pipe, block_name, columns, output_col, separator="\n\n"
Expand Down
2 changes: 1 addition & 1 deletion src/instructlab/sdg/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# SPDX-License-Identifier: Apache-2.0

# This is part of the public API, and used by instructlab
# This is part of the public API, and used by instructlab.
class GenerateException(Exception):
"""An exception raised during generate step."""
1 change: 1 addition & 0 deletions src/instructlab/sdg/utils/taxonomy.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
"""


# This is part of the public API.
class TaxonomyReadingException(Exception):
"""An exception raised during reading of the taxonomy."""

Expand Down