diff --git a/src/instructlab/sdg/default_flows.py b/src/instructlab/sdg/default_flows.py index 31edd3d6..f83ed60b 100644 --- a/src/instructlab/sdg/default_flows.py +++ b/src/instructlab/sdg/default_flows.py @@ -36,6 +36,7 @@ def __init__(self, client, model_family, model_id, num_iters, batched=True) -> N self.model_id = model_id self.num_iters = num_iters self.batched = batched + self.sdg_base = resources.files(__package__) @abstractmethod def get_flow(self) -> list: @@ -76,9 +77,8 @@ def get_flow(self) -> list: class SimpleKnowledgeFlow(_SimpleFlow): def get_flow(self) -> list: flow = super().get_flow() - sdg_base = resources.files(__package__) flow[0]["block_config"]["block_kwargs"]["config_path"] = os.path.join( - sdg_base, "configs/knowledge/simple_generate_qa.yaml" + self.sdg_base, "configs/knowledge/simple_generate_qa.yaml" ) flow[0]["block_config"]["block_kwargs"]["block_name"] = "gen_knowledge" flow[0]["block_config"]["block_name"] = "gen_knowledge" @@ -88,9 +88,8 @@ def get_flow(self) -> list: class SimpleFreeformSkillFlow(_SimpleFlow): def get_flow(self) -> list: flow = super().get_flow() - sdg_base = resources.files(__package__) flow[0]["block_config"]["block_kwargs"]["config_path"] = os.path.join( - sdg_base, "configs/skills/simple_generate_qa_freeform.yaml" + self.sdg_base, "configs/skills/simple_generate_qa_freeform.yaml" ) flow[0]["block_config"]["block_kwargs"]["block_name"] = "gen_skill_freeform" flow[0]["block_config"]["block_name"] = "gen_skill_freeform" @@ -100,9 +99,8 @@ def get_flow(self) -> list: class SimpleGroundedSkillFlow(_SimpleFlow): def get_flow(self) -> list: flow = super().get_flow() - sdg_base = resources.files(__package__) flow[0]["block_config"]["block_kwargs"]["config_path"] = os.path.join( - sdg_base, "configs/skills/simple_generate_qa_grounded.yaml" + self.sdg_base, "configs/skills/simple_generate_qa_grounded.yaml" ) flow[0]["block_config"]["block_kwargs"]["block_name"] = "gen_skill_grounded" flow[0]["block_config"]["block_name"] = "gen_skill_grounded" @@ -111,14 +109,14 @@ def get_flow(self) -> list: class MMLUBenchFlow(Flow): def get_flow(self) -> list: - sdg_base = resources.files(__package__) + self.sdg_base = resources.files(__package__) return [ { "block_type": LLMBlock, "block_config": { "block_name": "gen_mmlu_knowledge", "config_path": os.path.join( - sdg_base, "configs/knowledge/mcq_generation.yaml" + self.sdg_base, "configs/knowledge/mcq_generation.yaml" ), "client": self.client, "model_id": self.model_id, @@ -140,14 +138,14 @@ def get_flow(self) -> list: class SynthKnowledgeFlow(Flow): def get_flow(self) -> list: - sdg_base = resources.files(__package__) return [ { "block_type": LLMBlock, "block_config": { "block_name": "gen_knowledge", "config_path": os.path.join( - sdg_base, "configs/knowledge/generate_questions_responses.yaml" + self.sdg_base, + "configs/knowledge/generate_questions_responses.yaml", ), "client": self.client, "model_id": self.model_id, @@ -173,7 +171,7 @@ def get_flow(self) -> list: "block_config": { "block_name": "eval_faithfulness_qa_pair", "config_path": os.path.join( - sdg_base, "configs/knowledge/evaluate_faithfulness.yaml" + self.sdg_base, "configs/knowledge/evaluate_faithfulness.yaml" ), "client": self.client, "model_id": self.model_id, @@ -206,7 +204,7 @@ def get_flow(self) -> list: "block_config": { "block_name": "eval_relevancy_qa_pair", "config_path": os.path.join( - sdg_base, "configs/knowledge/evaluate_relevancy.yaml" + self.sdg_base, "configs/knowledge/evaluate_relevancy.yaml" ), "client": self.client, "model_id": self.model_id, @@ -240,7 +238,7 @@ def get_flow(self) -> list: "block_config": { "block_name": "eval_verify_question", "config_path": os.path.join( - sdg_base, "configs/knowledge/evaluate_question.yaml" + self.sdg_base, "configs/knowledge/evaluate_question.yaml" ), "client": self.client, "model_id": self.model_id, @@ -279,7 +277,10 @@ def get_flow(self) -> list: "block_type": LLMBlock, "block_config": { "block_name": "gen_questions", - "config_path": "src/instructlab/sdg/configs/skills/freeform_questions.yaml", + "config_path": os.path.join( + self.sdg_base, + "configs/skills/freeform_questions.yaml", + ), "client": self.client, "model_id": self.model_id, "model_prompt": _get_model_prompt(self.model_family), @@ -296,7 +297,10 @@ def get_flow(self) -> list: "block_type": LLMBlock, "block_config": { "block_name": "eval_questions", - "config_path": "src/instructlab/sdg/configs/skills/evaluate_freeform_questions.yaml", + "config_path": os.path.join( + self.sdg_base, + "configs/skills/evaluate_freeform_questions.yaml", + ), "client": self.client, "model_id": self.model_id, "model_prompt": _get_model_prompt(self.model_family), @@ -325,7 +329,10 @@ def get_flow(self) -> list: "block_type": LLMBlock, "block_config": { "block_name": "gen_responses", - "config_path": "src/instructlab/sdg/configs/skills/freeform_responses.yaml", + "config_path": os.path.join( + self.sdg_base, + "configs/skills/freeform_responses.yaml", + ), "client": self.client, "model_id": self.model_id, "model_prompt": _get_model_prompt(self.model_family), @@ -340,7 +347,10 @@ def get_flow(self) -> list: "block_type": LLMBlock, "block_config": { "block_name": "evaluate_qa_pair", - "config_path": "src/instructlab/sdg/configs/skills/evaluate_freeform_pair.yaml", + "config_path": os.path.join( + self.sdg_base, + "configs/skills/evaluate_freeform_pair.yaml", + ), "client": self.client, "model_id": self.model_id, "model_prompt": _get_model_prompt(self.model_family), @@ -379,7 +389,10 @@ def get_flow(self) -> list: "block_type": LLMBlock, "block_kwargs": { "block_name": "gen_contexts", - "config_path": "src/instructlab/sdg/configs/skills/contexts.yaml", + "config_path": os.path.join( + self.sdg_base, + "configs/skills/contexts.yaml", + ), "client": self.client, "model_id": self.model_id, "model_prompt": _get_model_prompt(self.model_family), @@ -399,7 +412,10 @@ def get_flow(self) -> list: "block_type": LLMBlock, "block_config": { "block_name": "gen_grounded_questions", - "config_path": "src/instructlab/sdg/configs/skills/grounded_questions.yaml", + "config_path": os.path.join( + self.sdg_base, + "configs/skills/grounded_questions.yaml", + ), "client": self.client, "model_id": self.model_id, "model_prompt": _get_model_prompt(self.model_family), @@ -415,7 +431,10 @@ def get_flow(self) -> list: "block_type": LLMBlock, "block_config": { "block_name": "eval_grounded_questions", - "config_path": "src/instructlab/sdg/configs/skills/evaluate_grounded_questions.yaml", + "config_path": os.path.join( + self.sdg_base, + "configs/skills/evaluate_grounded_questions.yaml", + ), "client": self.client, "model_id": self.model_id, "model_prompt": _get_model_prompt(self.model_family), @@ -445,7 +464,10 @@ def get_flow(self) -> list: "block_type": LLMBlock, "block_config": { "block_name": "gen_grounded_responses", - "config_path": "src/instructlab/sdg/configs/skills/grounded_responses.yaml", + "config_path": os.path.join( + self.sdg_base, + "configs/skills/grounded_responses.yaml", + ), "client": self.client, "model_id": self.model_id, "model_prompt": _get_model_prompt(self.model_family), @@ -460,7 +482,10 @@ def get_flow(self) -> list: "block_type": LLMBlock, "block_config": { "block_name": "evaluate_grounded_qa_pair", - "config_path": "src/instructlab/sdg/configs/skills/evaluate_grounded_pair.yaml", + "config_path": os.path.join( + self.sdg_base, + "configs/skills/evaluate_grounded_pair.yaml", + ), "client": self.client, "model_id": self.model_id, "model_prompt": _get_model_prompt(self.model_family),