diff --git a/src/instructlab/sdg/generate_data.py b/src/instructlab/sdg/generate_data.py index 3bfb8be9..7a926a5a 100644 --- a/src/instructlab/sdg/generate_data.py +++ b/src/instructlab/sdg/generate_data.py @@ -165,18 +165,31 @@ def _gen_test_data( def _sdg_init(pipeline, client, model_family, model_id, num_instructions_to_generate): + pipeline_pkg = None if pipeline == "full": pipeline_pkg = FULL_PIPELINES_PACKAGE elif pipeline == "simple": pipeline_pkg = SIMPLE_PIPELINES_PACKAGE else: - raise utils.GenerateException(f"Error: pipeline ({pipeline}) is not supported.") + # Validate that pipeline is a valid directory and that it contains the required files + if not os.path.exists(pipeline): + raise utils.GenerateException( + f"Error: pipeline directory ({pipeline}) does not exist." + ) + for file in ["knowledge.yaml", "freeform_skills.yaml", "grounded_skills.yaml"]: + if not os.path.exists(os.path.join(pipeline, file)): + raise utils.GenerateException( + f"Error: pipeline directory ({pipeline}) does not contain {file}." + ) ctx = PipelineContext(client, model_family, model_id, num_instructions_to_generate) def load_pipeline(yaml_basename): - with resources.path(pipeline_pkg, yaml_basename) as yaml_path: - return Pipeline.from_file(ctx, yaml_path) + if pipeline_pkg: + with resources.path(pipeline_pkg, yaml_basename) as yaml_path: + return Pipeline.from_file(ctx, yaml_path) + else: + return Pipeline.from_file(ctx, os.path.join(pipeline, yaml_basename)) return ( SDG([load_pipeline("knowledge.yaml")]), @@ -212,9 +225,21 @@ def generate_data( tls_client_cert: Optional[str] = None, tls_client_key: Optional[str] = None, tls_client_passwd: Optional[str] = None, - # TODO need to update the CLI to specify which pipeline to use (simple or full at the moment) pipeline: Optional[str] = "simple", ): + """Generate data for training and testing a model. + + This currently serves as the primary interface from the `ilab` CLI to the `sdg` library. 
+ It is something of a transitional measure, as this function existed back when all of the + functionality was embedded in the CLI. At some stage, we expect to evolve the CLI to + use the SDG library constructs directly, and this function will likely be removed. + + Args: + pipeline: This argument may be either an alias defined by the sdg library ("simple", "full"), + or an absolute path to a directory containing the pipeline YAML files. + We expect three files to be present in this directory: "knowledge.yaml", + "freeform_skills.yaml", and "grounded_skills.yaml". + """ generate_start = time.time() if not os.path.exists(output_dir):