diff --git a/src/instructlab/sdg/default_flows.py b/src/instructlab/sdg/default_flows.py
index 56933712..818c4972 100644
--- a/src/instructlab/sdg/default_flows.py
+++ b/src/instructlab/sdg/default_flows.py
@@ -7,7 +7,6 @@
 
 # Local
 from .filterblock import FilterByValueBlock
-from .iterblock import IterBlock
 from .llmblock import LLMBlock
 from .utilblocks import CombineColumnsBlock
 
@@ -30,12 +29,13 @@ def _get_model_prompt(model_family):
 
 
 class Flow(ABC):
-    def __init__(self, client, model_family, model_id, num_iters, batched=True) -> None:
+    def __init__(
+        self, client, model_family, model_id, num_instructions_to_generate
+    ) -> None:
         self.client = client
         self.model_family = model_family
         self.model_id = model_id
-        self.num_iters = num_iters
-        self.batched = batched
+        self.num_instructions_to_generate = num_instructions_to_generate
         self.sdg_base = resources.files(__package__)
 
     @abstractmethod
@@ -47,29 +47,21 @@ class _SimpleFlow(Flow):
     def get_flow(self) -> list:
         return [
             {
-                "block_type": IterBlock,
+                "block_type": LLMBlock,
                 "block_config": {
                     "block_name": "",  # must be set by subclass
-                    "num_iters": self.num_iters,
-                    "block_type": LLMBlock,
-                    "block_kwargs": {
-                        "block_name": "",  # must be set by subclass
-                        "config_path": "",  # must be set by subclass
-                        "client": self.client,
-                        "model_id": self.model_id,
-                        "model_prompt": _get_model_prompt(self.model_family),
-                        "output_cols": ["output"],
-                        "batch_kwargs": {
-                            "num_procs": 8,
-                            "batched": self.batched,
-                        },
-                    },
-                    "gen_kwargs": {
-                        "max_tokens": 2048,
-                        "temperature": 0.7,
-                    },
-                    "drop_duplicates": ["output"],
+                    "config_path": "",  # must be set by subclass
+                    "client": self.client,
+                    "model_id": self.model_id,
+                    "model_prompt": _get_model_prompt(self.model_family),
+                    "output_cols": ["output"],
                 },
+                "gen_kwargs": {
+                    "max_tokens": 2048,
+                    "temperature": 0.7,
+                    "n": self.num_instructions_to_generate,
+                },
+                "drop_duplicates": ["output"],
             }
         ]
 
@@ -77,10 +69,9 @@ def get_flow(self) -> list:
 
 class SimpleKnowledgeFlow(_SimpleFlow):
     def get_flow(self) -> list:
         flow = super().get_flow()
-        flow[0]["block_config"]["block_kwargs"]["config_path"] = os.path.join(
+        flow[0]["block_config"]["config_path"] = os.path.join(
             self.sdg_base, "configs/knowledge/simple_generate_qa.yaml"
         )
-        flow[0]["block_config"]["block_kwargs"]["block_name"] = "gen_knowledge"
         flow[0]["block_config"]["block_name"] = "gen_knowledge"
         return flow
 
@@ -88,10 +79,10 @@ def get_flow(self) -> list:
 
 class SimpleFreeformSkillFlow(_SimpleFlow):
     def get_flow(self) -> list:
         flow = super().get_flow()
-        flow[0]["block_config"]["block_kwargs"]["config_path"] = os.path.join(
+        flow[0]["block_config"]["config_path"] = os.path.join(
             self.sdg_base, "configs/skills/simple_generate_qa_freeform.yaml"
         )
-        flow[0]["block_config"]["block_kwargs"]["block_name"] = "gen_skill_freeform"
+        flow[0]["block_config"]["block_name"] = "gen_skill_freeform"
         flow[0]["block_config"]["block_name"] = "gen_skill_freeform"
         return flow
 
@@ -99,10 +90,9 @@ def get_flow(self) -> list:
 
 class SimpleGroundedSkillFlow(_SimpleFlow):
     def get_flow(self) -> list:
         flow = super().get_flow()
-        flow[0]["block_config"]["block_kwargs"]["config_path"] = os.path.join(
+        flow[0]["block_config"]["config_path"] = os.path.join(
             self.sdg_base, "configs/skills/simple_generate_qa_grounded.yaml"
         )
-        flow[0]["block_config"]["block_kwargs"]["block_name"] = "gen_skill_grounded"
         flow[0]["block_config"]["block_name"] = "gen_skill_grounded"
         return flow
 
@@ -122,10 +112,6 @@ def get_flow(self) -> list:
                     "model_id": self.model_id,
                     "model_prompt": _get_model_prompt(self.model_family),
                     "output_cols": ["mmlubench_question", "mmlubench_answer"],
-                    "batch_kwargs": {
-                        "num_procs": 8,
-                        "batched": self.batched,
-                    },
                 },
                 "gen_kwargs": {
                     "temperature": 0,
@@ -151,10 +137,6 @@ def get_flow(self) -> list:
                     "model_id": self.model_id,
                     "model_prompt": _get_model_prompt(self.model_family),
                     "output_cols": ["question", "response"],
-                    "batch_kwargs": {
-                        "num_procs": 8,
-                        "batched": self.batched,
-                    },
                     "parser_kwargs": {
                         "parser_name": "custom",
                         "parsing_pattern": r"\[(?:Question|QUESTION)\]\s*(.*?)\s*\[(?:Answer|ANSWER)\]\s*(.*?)\s*(?=\[(?:Question|QUESTION)\]|$)",
@@ -177,10 +159,6 @@ def get_flow(self) -> list:
                     "model_id": self.model_id,
                     "model_prompt": _get_model_prompt(self.model_family),
                     "output_cols": ["explanation", "judgment"],
-                    "batch_kwargs": {
-                        "num_procs": 8,
-                        "batched": self.batched,
-                    },
                 },
                 "gen_kwargs": {
                     "max_tokens": 2048,
@@ -210,10 +188,6 @@ def get_flow(self) -> list:
                     "model_id": self.model_id,
                     "model_prompt": _get_model_prompt(self.model_family),
                     "output_cols": ["feedback", "score"],
-                    "batch_kwargs": {
-                        "num_procs": 8,
-                        "batched": self.batched,
-                    },
                 },
                 "gen_kwargs": {
                     "max_tokens": 2048,
@@ -244,10 +218,6 @@ def get_flow(self) -> list:
                     "model_id": self.model_id,
                     "model_prompt": _get_model_prompt(self.model_family),
                     "output_cols": ["explanation", "rating"],
-                    "batch_kwargs": {
-                        "num_procs": 8,
-                        "batched": self.batched,
-                    },
                 },
                 "gen_kwargs": {
                     "max_tokens": 2048,
@@ -286,9 +256,7 @@ def get_flow(self) -> list:
                     "model_prompt": _get_model_prompt(self.model_family),
                     "output_cols": ["question"],
                     "batch_kwargs": {
-                        "num_procs": 8,
-                        "num_samples": self.num_iters,
-                        "batched": self.batched,
+                        "num_samples": self.num_instructions_to_generate,
                     },
                 },
                 "drop_duplicates": ["question"],
@@ -305,10 +273,6 @@ def get_flow(self) -> list:
                     "model_id": self.model_id,
                     "model_prompt": _get_model_prompt(self.model_family),
                     "output_cols": ["evaluation", "score"],
-                    "batch_kwargs": {
-                        "num_procs": 8,
-                        "batched": self.batched,
-                    },
                 },
             },
             {
@@ -337,10 +301,6 @@ def get_flow(self) -> list:
                     "model_id": self.model_id,
                     "model_prompt": _get_model_prompt(self.model_family),
                     "output_cols": ["response"],
-                    "batch_kwargs": {
-                        "num_procs": 8,
-                        "batched": self.batched,
-                    },
                 },
             },
             {
@@ -355,10 +315,6 @@ def get_flow(self) -> list:
                     "model_id": self.model_id,
                     "model_prompt": _get_model_prompt(self.model_family),
                     "output_cols": ["evaluation", "score"],
-                    "batch_kwargs": {
-                        "num_procs": 8,
-                        "batched": self.batched,
-                    },
                 },
             },
             {
@@ -382,31 +338,24 @@ class SynthGroundedSkillsFlow(Flow):
     def get_flow(self) -> list:
         return [
             {
-                "block_type": IterBlock,
+                "block_type": LLMBlock,
                 "block_config": {
-                    "block_name": "context_iter",
-                    "num_iters": 10,
-                    "block_type": LLMBlock,
-                    "block_kwargs": {
-                        "block_name": "gen_contexts",
-                        "config_path": os.path.join(
-                            self.sdg_base,
-                            "configs/skills/contexts.yaml",
-                        ),
-                        "client": self.client,
-                        "model_id": self.model_id,
-                        "model_prompt": _get_model_prompt(self.model_family),
-                        "output_cols": ["context"],
-                        "batch_kwargs": {
-                            "num_procs": 8,
-                            "batched": self.batched,
-                        },
-                    },
-                    "gen_kwargs": {
-                        "temperature": 0.7,
-                        "max_tokens": 2048,
-                    },
+                    "block_name": "gen_contexts",
+                    "config_path": os.path.join(
+                        self.sdg_base,
+                        "configs/skills/contexts.yaml",
+                    ),
+                    "client": self.client,
+                    "model_id": self.model_id,
+                    "model_prompt": _get_model_prompt(self.model_family),
+                    "output_cols": ["context"],
+                },
+                "gen_kwargs": {
+                    "temperature": 0.7,
+                    "max_tokens": 2048,
+                    "n": self.num_instructions_to_generate,
                 },
+                "drop_duplicates": ["context"],
             },
             {
                 "block_type": LLMBlock,
@@ -421,8 +370,7 @@ def get_flow(self) -> list:
                     "model_prompt": _get_model_prompt(self.model_family),
                     "output_cols": ["question"],
                     "batch_kwargs": {
-                        "num_procs": 8,
-                        "batched": self.batched,
+                        "num_samples": 3,
                     },
                 },
                 "drop_duplicates": ["question"],
@@ -439,11 +387,6 @@ def get_flow(self) -> list:
                     "model_id": self.model_id,
                     "model_prompt": _get_model_prompt(self.model_family),
                     "output_cols": ["evaluation", "score"],
-                    "batch_kwargs": {
-                        "num_procs": 8,
-                        "batched": self.batched,
-                        "num_samples": 10,
-                    },
                 },
             },
             {
@@ -472,10 +415,6 @@ def get_flow(self) -> list:
                     "model_id": self.model_id,
                     "model_prompt": _get_model_prompt(self.model_family),
                     "output_cols": ["response"],
-                    "batch_kwargs": {
-                        "num_procs": 8,
-                        "batched": self.batched,
-                    },
                 },
             },
             {
@@ -490,10 +429,6 @@ def get_flow(self) -> list:
                     "model_id": self.model_id,
                     "model_prompt": _get_model_prompt(self.model_family),
                     "output_cols": ["evaluation", "score"],
-                    "batch_kwargs": {
-                        "num_procs": 8,
-                        "batched": self.batched,
-                    },
                 },
             },
             {
diff --git a/src/instructlab/sdg/generate_data.py b/src/instructlab/sdg/generate_data.py
index b32acfbd..ad5a6864 100644
--- a/src/instructlab/sdg/generate_data.py
+++ b/src/instructlab/sdg/generate_data.py
@@ -126,7 +126,7 @@ def _gen_test_data(
         outfile.write("\n")
 
 
-def _sdg_init(pipeline, client, model_family, model_name, num_iters, batched):
+def _sdg_init(pipeline, client, model_family, model_name, num_instructions_to_generate):
     knowledge_flow_types = []
     freeform_skill_flow_types = []
     grounded_skill_flow_types = []
@@ -146,7 +146,7 @@ def _sdg_init(pipeline, client, model_family, model_name, num_iters, batched):
         [
             Pipeline(
                 flow_type(
-                    client, model_family, model_name, num_iters, batched
+                    client, model_family, model_name, num_instructions_to_generate
                 ).get_flow()
             )
             for flow_type in knowledge_flow_types
@@ -156,7 +156,7 @@ def _sdg_init(pipeline, client, model_family, model_name, num_iters, batched):
         [
             Pipeline(
                 flow_type(
-                    client, model_family, model_name, num_iters, batched
+                    client, model_family, model_name, num_instructions_to_generate
                 ).get_flow()
            )
             for flow_type in freeform_skill_flow_types
@@ -166,7 +166,7 @@ def _sdg_init(pipeline, client, model_family, model_name, num_iters, batched):
         [
             Pipeline(
                 flow_type(
-                    client, model_family, model_name, num_iters, batched
+                    client, model_family, model_name, num_instructions_to_generate
                 ).get_flow()
             )
             for flow_type in grounded_skill_flow_types
@@ -246,7 +246,6 @@ def generate_data(
 
     # TODO -- llama-cpp doesn't support batching, we need to get a hint from the CLI
     # about whether we can turn this on (whether vllm is used or not)
-    batched = False
 
     sdg_knowledge, sdg_freeform_skill, sdg_grounded_skill = _sdg_init(
         pipeline,
@@ -254,7 +253,6 @@ def generate_data(
         model_family,
         model_name,
         num_instructions_to_generate,
-        batched,
     )
 
     if console_output:
@@ -269,7 +267,6 @@ def generate_data(
         if not samples:
             raise utils.GenerateException("Error: No samples found in leaf node.")
 
-        sdg = None
         if samples[0].get("document"):
             sdg = sdg_knowledge
         elif samples[0].get("seed_context"):
diff --git a/src/instructlab/sdg/iterblock.py b/src/instructlab/sdg/iterblock.py
deleted file mode 100644
index 31726b12..00000000
--- a/src/instructlab/sdg/iterblock.py
+++ /dev/null
@@ -1,28 +0,0 @@
-# Third Party
-from datasets import Dataset
-
-# Local
-from .block import Block
-from .logger_config import setup_logger
-
-logger = setup_logger(__name__)
-
-
-class IterBlock(Block):
-    def __init__(self, block_name, num_iters, block_type, block_kwargs, **kwargs):
-        super().__init__(block_name)
-        self.num_iters = num_iters
-        self.block = block_type(**block_kwargs)
-        self.gen_kwargs = kwargs.get("gen_kwargs", {})
-
-    def generate(self, samples, **gen_kwargs) -> Dataset:
-        generated_samples = []
-        num_iters = self.num_iters
-
-        for _ in range(num_iters):
-            batch_generated = self.block.generate(
-                samples, **{**self.gen_kwargs, **gen_kwargs}
-            )
-            generated_samples.extend(batch_generated)
-
-        return Dataset.from_list(generated_samples)
diff --git a/src/instructlab/sdg/llmblock.py b/src/instructlab/sdg/llmblock.py
index 7952609a..4153a191 100644
--- a/src/instructlab/sdg/llmblock.py
+++ b/src/instructlab/sdg/llmblock.py
@@ -5,6 +5,7 @@
 
 # Third Party
 from datasets import Dataset
+import openai
 
 # Local
 from .block import Block
@@ -13,6 +14,25 @@
 logger = setup_logger(__name__)
 
 
+def server_supports_batched(client, model_id: str) -> bool:
+    supported = getattr(client, "server_supports_batched", None)
+    if supported is not None:
+        return supported
+    try:
+        # Make a test call to the server to determine whether it supports
+        # multiple input prompts per request and also the n parameter
+        response = client.completions.create(
+            model=model_id, prompt=["test1", "test2"], max_tokens=1, n=3
+        )
+        # Number outputs should be 2 * 3 = 6
+        supported = len(response.choices) == 6
+    except openai.InternalServerError:
+        supported = False
+    client.server_supports_batched = supported
+    logger.info(f"LLM server supports batched inputs: {client.server_supports_batched}")
+    return supported
+
+
 # pylint: disable=dangerous-default-value
 class LLMBlock(Block):
     # pylint: disable=too-many-instance-attributes
@@ -47,6 +67,10 @@ def __init__(
             "max_tokens": 12000,
         }
 
+        # Whether the LLM server supports a list of input prompts
+        # and supports the n parameter to generate n outputs per input
+        self.server_supports_batched = server_supports_batched(client, model_id)
+
     def _parse(self, generated_string) -> dict:
         matches = {}
@@ -84,19 +108,31 @@ def _parse(self, generated_string) -> dict:
 
         return matches
 
+    def _format_prompt(self, sample: Dict) -> str:
+        return self.prompt_template.format(**sample).strip()
+
     def _generate(self, samples, **gen_kwargs) -> list:
         prompts = [
-            self.model_prompt.format(
-                prompt=self.prompt_template.format(**sample).strip()
-            )
+            self.model_prompt.format(prompt=self._format_prompt(sample))
             for sample in samples
         ]
-        response = self.client.completions.create(
-            prompt=prompts, **{**self.defaults, **gen_kwargs}
-        )
-        return [choice.text.strip() for choice in response.choices]
+        generate_args = {**self.defaults, **gen_kwargs}
+
+        if self.server_supports_batched:
+            response = self.client.completions.create(prompt=prompts, **generate_args)
+            return [choice.text.strip() for choice in response.choices]
+
+        n = gen_kwargs.get("n", 1)
+        results = []
+        for prompt in prompts:
+            for _ in range(n):
+                response = self.client.completions.create(
+                    prompt=prompt, **generate_args
+                )
+                results.append(response.choices[0].text.strip())
+        return results
 
-    def generate(self, samples, **gen_kwargs) -> Dataset:
+    def generate(self, samples: Dataset, **gen_kwargs) -> Dataset:
         """
         Generate the output from the block. This method should first validate the input data,
         then generate the output, and finally parse the generated output before returning it.
@@ -104,30 +140,46 @@ def generate(self, samples, **gen_kwargs) -> Dataset:
         :return: The parsed output after generation.
         """
         num_samples = self.batch_params.get("num_samples", None)
-        batched = self.batch_params.get("batched", False)
         logger.debug("Generating outputs for {} samples".format(len(samples)))
 
         if (num_samples is not None) and ("num_samples" not in samples.column_names):
             samples = samples.add_column("num_samples", [num_samples] * len(samples))
 
         # validate each sample
+        # Log errors and remove invalid samples
+        valid_samples = []
+
         for sample in samples:
-            if not self._validate(self.prompt_template, sample):
-                return None
+            if self._validate(self.prompt_template, sample):
+                valid_samples.append(sample)
+            else:
+                logger.warning(
+                    f"Sample failed validation: {sample}"
+                )  # Log details of the failed sample
+
+        samples = valid_samples
+
+        if len(samples) == 0:
+            return Dataset.from_list([])
 
         # generate the output
-        outputs = []
-        if batched:
-            outputs = self._generate(samples, **gen_kwargs)
-        else:
-            outputs = [self._generate([sample], **gen_kwargs)[0] for sample in samples]
-        logger.debug("Generated outputs: {}".format(outputs))
+
+        outputs = self._generate(samples, **gen_kwargs)
+        logger.debug("Generated outputs: %s", outputs)
+
+        num_parallel_samples = gen_kwargs.get("n", 1)
+        extended_samples = []
+
+        # Duplicate each input sample n times, where n is the number
+        # of output sequences generated per input, so that we can
+        # pair up the inputs and outputs.
+        for item in samples:
+            extended_samples.extend([item] * num_parallel_samples)
 
         new_data = []
-        for sample, output in zip(samples, outputs):
+        for sample, output in zip(extended_samples, outputs):
             parsed_outputs = self._parse(output)
-            # pylint: disable=consider-using-generator
-            max_length = max([len(value) for value in parsed_outputs.values()])
+            max_length = max(len(value) for value in parsed_outputs.values())
             for values in zip(*(lst[:max_length] for lst in parsed_outputs.values())):
                 new_data.append({**sample, **dict(zip(parsed_outputs.keys(), values))})
@@ -167,27 +219,15 @@ def __init__(
             **self._load_config(config)
         )
 
-    def _generate(self, samples, **gen_kwargs) -> str:
+    def _format_prompt(self, sample: Dict) -> str:
         if isinstance(self.prompt_template, dict):
-            prompts = [
-                self.model_prompt.format(
-                    prompt=self.prompt_template[sample[self.selector_column_name]]
-                    .format(**sample)
-                    .strip()
-                )
-                for sample in samples
-            ]
-        else:
-            prompts = [
-                self.model_prompt.format(
-                    prompt=self.prompt_template.format(**sample).strip()
-                )
-                for sample in samples
-            ]
-        response = self.client.completions.create(
-            prompt=prompts, **{**self.defaults, **gen_kwargs}
-        )
-        return [choice.text.strip() for choice in response.choices]
+            return (
+                self.prompt_template[sample[self.selector_column_name]]
+                .format(**sample)
+                .strip()
+            )
+
+        return self.prompt_template.format(**sample).strip()
 
     def validate(self, prompt_template: str, input_dict: Dict[str, Any]) -> bool:
         if isinstance(prompt_template, dict):
diff --git a/src/instructlab/sdg/pipeline.py b/src/instructlab/sdg/pipeline.py
index fc93f78d..bc570a83 100644
--- a/src/instructlab/sdg/pipeline.py
+++ b/src/instructlab/sdg/pipeline.py
@@ -3,7 +3,6 @@
 from datasets import Dataset
 
 # Local
-from .iterblock import IterBlock
 from .logger_config import setup_logger
 
 logger = setup_logger(__name__)
@@ -39,12 +38,6 @@ def generate(self, dataset) -> Dataset:
             drop_duplicates_cols = block_prop.get("drop_duplicates", False)
             block = block_type(**block_config)
 
-            if block_type == IterBlock:
-                block_kwargs = block_config.pop("block_kwargs")
-                block = block_type(**block_config, block_kwargs=block_kwargs)
-            else:
-                block = block_type(**block_config)
-
             logger.info("Running block: %s", block_config["block_name"])
             logger.info(dataset)
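
A minimal usage sketch of the entry points this patch changes, assuming an OpenAI-compatible completions endpoint; the base URL, API key, model-family string, and model path below are illustrative placeholders, not values taken from the patch:

# Sketch only: exercises the new Flow signature, where num_instructions_to_generate
# replaces the old num_iters/batched pair and is forwarded to the LLM as
# gen_kwargs["n"]. LLMBlock probes the server once via server_supports_batched()
# and falls back to one request per prompt per output when batched prompts and
# the n parameter are not supported.
from openai import OpenAI

from instructlab.sdg.default_flows import SimpleKnowledgeFlow
from instructlab.sdg.pipeline import Pipeline

# Placeholder endpoint and credentials for a local OpenAI-compatible server
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

flow = SimpleKnowledgeFlow(
    client,
    "merlinite",  # placeholder model-family string resolved by _get_model_prompt()
    "models/merlinite-7b-lab-Q4_K_M.gguf",  # placeholder model id
    30,  # num_instructions_to_generate, emitted as gen_kwargs["n"]
)
pipeline = Pipeline(flow.get_flow())
# generated = pipeline.generate(seed_dataset)  # seed_dataset: a datasets.Dataset of seed samples

Because the n-based approach relies on the runtime probe in server_supports_batched(), servers without batched-prompt support (for example llama-cpp, per the TODO retained in generate_data.py) transparently take the per-prompt fallback path instead of requiring a separate batched flag.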