Skip to content

Commit

Permalink
added number of iterations to generate
Browse files Browse the repository at this point in the history
Signed-off-by: Oindrilla Chatterjee <[email protected]>
  • Loading branch information
oindrillac authored and russellb committed Jun 28, 2024
1 parent 9ee7f70 commit 54b065a
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 27 deletions.
47 changes: 28 additions & 19 deletions src/instructlab/sdg/default_flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,11 @@ def _get_model_prompt(model_family):


class Flow(ABC):
def __init__(self, client, model_family, model_id, batched=True) -> None:
def __init__(self, client, model_family, model_id, num_iters, batched=True) -> None:
self.client = client
self.model_family = model_family
self.model_id = model_id
self.num_iters = num_iters
self.batched = batched

@abstractmethod
Expand All @@ -42,37 +43,43 @@ def get_flow(self) -> list:

class _SimpleFlow(Flow):
def get_flow(self) -> list:
sdg_base = resources.files(__package__)
return [
{
"block_type": LLMBlock,
"block_type": IterBlock,
"block_config": {
"block_name": "", # must be set by subclass
"config_path": "", # must be set by subclass
"client": self.client,
"model_id": self.model_id,
"model_prompt": _get_model_prompt(self.model_family),
"output_cols": ["output"],
"batch_kwargs": {
"num_procs": 8,
"batched": self.batched,
"num_iters": self.num_iters,
"block_type": LLMBlock,
"block_kwargs": {
"block_name": "", # must be set by subclass
"config_path": "", # must be set by subclass
"client": self.client,
"model_id": self.model_id,
"model_prompt": _get_model_prompt(self.model_family),
"output_cols": ["output"],
"batch_kwargs": {
"num_procs": 8,
"batched": self.batched,
},
},
"gen_kwargs": {
"max_tokens": 2048,
"temperature": 0.7,
},
"drop_duplicates": ["output"],
},
"gen_kwargs": {
"max_tokens": 2048,
},
"drop_duplicates": ["output"],
},
}
]


class SimpleKnowledgeFlow(_SimpleFlow):
def get_flow(self) -> list:
flow = super().get_flow()
sdg_base = resources.files(__package__)
flow[0]["block_config"]["config_path"] = os.path.join(
flow[0]["block_config"]["block_kwargs"]["config_path"] = os.path.join(
sdg_base, "configs/knowledge/simple_generate_qa.yaml"
)
flow[0]["block_config"]["block_kwargs"]["block_name"] = "gen_knowledge"
flow[0]["block_config"]["block_name"] = "gen_knowledge"
return flow

Expand All @@ -81,9 +88,10 @@ class SimpleFreeformSkillFlow(_SimpleFlow):
def get_flow(self) -> list:
flow = super().get_flow()
sdg_base = resources.files(__package__)
flow[0]["block_config"]["config_path"] = os.path.join(
flow[0]["block_config"]["block_kwargs"]["config_path"] = os.path.join(
sdg_base, "configs/skills/simple_generate_qa_freeform.yaml"
)
flow[0]["block_config"]["block_kwargs"]["block_name"] = "gen_skill_freeform"
flow[0]["block_config"]["block_name"] = "gen_skill_freeform"
return flow

Expand All @@ -92,9 +100,10 @@ class SimpleGroundedSkillFlow(_SimpleFlow):
def get_flow(self) -> list:
flow = super().get_flow()
sdg_base = resources.files(__package__)
flow[0]["block_config"]["config_path"] = os.path.join(
flow[0]["block_config"]["block_kwargs"]["config_path"] = os.path.join(
sdg_base, "configs/skills/simple_generate_qa_grounded.yaml"
)
flow[0]["block_config"]["block_kwargs"]["block_name"] = "gen_skill_grounded"
flow[0]["block_config"]["block_name"] = "gen_skill_grounded"
return flow

Expand Down
27 changes: 19 additions & 8 deletions src/instructlab/sdg/generate_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def _gen_test_data(
outfile.write("\n")


def _sdg_init(profile, client, model_family, model_name, batched):
def _sdg_init(profile, client, model_family, model_name, num_iters, batched):
knowledge_flow_types = []
freeform_skill_flow_types = []
grounded_skill_flow_types = []
Expand All @@ -138,19 +138,31 @@ def _sdg_init(profile, client, model_family, model_name, batched):

sdg_knowledge = SDG(
[
Pipeline(flow_type(client, model_family, model_name, batched).get_flow())
Pipeline(
flow_type(
client, model_family, model_name, num_iters, batched
).get_flow()
)
for flow_type in knowledge_flow_types
]
)
sdg_freeform_skill = SDG(
[
Pipeline(flow_type(client, model_family, model_name, batched).get_flow())
Pipeline(
flow_type(
client, model_family, model_name, num_iters, batched
).get_flow()
)
for flow_type in freeform_skill_flow_types
]
)
sdg_grounded_skill = SDG(
[
Pipeline(flow_type(client, model_family, model_name, batched).get_flow())
Pipeline(
flow_type(
client, model_family, model_name, num_iters, batched
).get_flow()
)
for flow_type in grounded_skill_flow_types
]
)
Expand All @@ -174,14 +186,13 @@ def generate_data(
# TODO - not used -- when batching is enabled, this is relevant.
# Right now the code hard codes 8 cpus for batching
num_cpus: Optional[int] = None,
# TODO - not yet used, but should be presumably
num_instructions_to_generate: Optional[int] = None,
num_instructions_to_generate: Optional[int] = 30,
# TODO - not used, can probably be removed
num_prompt_instructions=2,
# TODO - determine if this is relevant
request_batch_size=5,
# TODO - probably should be removed
temperature=1.0,
temperature=1.0, # temperature per step is provided in the config file
# TODO - probably should be removed
top_p=1.0,
# TODO - probably should be removed
Expand Down Expand Up @@ -240,7 +251,7 @@ def generate_data(
batched = False

sdg_knowledge, sdg_freeform_skill, sdg_grounded_skill = _sdg_init(
profile, client, model_family, model_name, batched
profile, client, model_family, model_name, num_instructions_to_generate, batched
)

if console_output:
Expand Down

0 comments on commit 54b065a

Please sign in to comment.