diff --git a/src/instructlab/sdg/generate_data.py b/src/instructlab/sdg/generate_data.py index aac007b7..208819ea 100644 --- a/src/instructlab/sdg/generate_data.py +++ b/src/instructlab/sdg/generate_data.py @@ -286,11 +286,8 @@ def _mixer_init( # This is part of the public API, and used by instructlab. -# TODO - parameter removal needs to be done in sync with a CLI change. -# to be removed: logger def generate_data( client: openai.OpenAI, - logger: logging.Logger = logger, # pylint: disable=redefined-outer-name system_prompt: Optional[str] = None, use_legacy_pretraining_format: Optional[bool] = True, model_family: Optional[str] = None, diff --git a/tests/test_generate_data.py b/tests/test_generate_data.py index 0d04a80f..8a358d9e 100644 --- a/tests/test_generate_data.py +++ b/tests/test_generate_data.py @@ -20,7 +20,12 @@ import yaml # First Party -from instructlab.sdg.generate_data import _context_init, _sdg_init, generate_data +from instructlab.sdg.generate_data import ( + _context_init, + _sdg_init, + generate_data, + logger, +) from instructlab.sdg.llmblock import LLMBlock from instructlab.sdg.pipeline import PipelineContext @@ -309,19 +314,17 @@ def setUp(self): ) def test_generate(self): - with patch("logging.Logger.info") as mocked_logger: - generate_data( - client=MagicMock(), - logger=mocked_logger, - model_family="merlinite", - model_name="models/merlinite-7b-lab-Q4_K_M.gguf", - num_instructions_to_generate=10, - taxonomy=self.test_taxonomy.root, - taxonomy_base=TEST_TAXONOMY_BASE, - output_dir=self.tmp_path, - pipeline="simple", - system_prompt=TEST_SYS_PROMPT, - ) + generate_data( + client=MagicMock(), + model_family="merlinite", + model_name="models/merlinite-7b-lab-Q4_K_M.gguf", + num_instructions_to_generate=10, + taxonomy=self.test_taxonomy.root, + taxonomy_base=TEST_TAXONOMY_BASE, + output_dir=self.tmp_path, + pipeline="simple", + system_prompt=TEST_SYS_PROMPT, + ) for name in ["test_*.jsonl", "train_*.jsonl", "messages_*.jsonl"]: matches = glob.glob(os.path.join(self.tmp_path, name)) @@ -386,21 +389,19 @@ def setUp(self): self.expected_train_samples = generate_train_samples(test_valid_knowledge_skill) def test_generate(self): - with patch("logging.Logger.info") as mocked_logger: - generate_data( - client=MagicMock(), - logger=mocked_logger, - model_family="merlinite", - model_name="models/merlinite-7b-lab-Q4_K_M.gguf", - num_instructions_to_generate=10, - taxonomy=self.test_taxonomy.root, - taxonomy_base=TEST_TAXONOMY_BASE, - output_dir=self.tmp_path, - chunk_word_count=1000, - server_ctx_size=4096, - pipeline="simple", - system_prompt=TEST_SYS_PROMPT, - ) + generate_data( + client=MagicMock(), + model_family="merlinite", + model_name="models/merlinite-7b-lab-Q4_K_M.gguf", + num_instructions_to_generate=10, + taxonomy=self.test_taxonomy.root, + taxonomy_base=TEST_TAXONOMY_BASE, + output_dir=self.tmp_path, + chunk_word_count=1000, + server_ctx_size=4096, + pipeline="simple", + system_prompt=TEST_SYS_PROMPT, + ) for name in ["test_*.jsonl", "train_*.jsonl", "messages_*.jsonl"]: matches = glob.glob(os.path.join(self.tmp_path, name)) @@ -484,10 +485,9 @@ def setUp(self): ) def test_generate(self): - with patch("logging.Logger.info") as mocked_logger: + with patch("instructlab.sdg.generate_data.logger") as mocked_logger: generate_data( client=MagicMock(), - logger=mocked_logger, model_family="merlinite", model_name="models/merlinite-7b-lab-Q4_K_M.gguf", num_instructions_to_generate=10,