diff --git a/src/tanuki/__init__.py b/src/tanuki/__init__.py index 534f9b2..305225d 100644 --- a/src/tanuki/__init__.py +++ b/src/tanuki/__init__.py @@ -294,10 +294,13 @@ def wrapper(*args, **kwargs) -> Union[Embedding, Any]: # Configure the function modeler using incoming parameters function_modeler.environment_id = environment_id if ignore_finetuning: + logging.info(f"The flag for ignoring finetuning has been set True for {test_func.__name__}. No model distillation will be performed.") function_modeler.execute_finetune_blacklist.append(func_hash) if ignore_finetune_fetching: + logging.info(f"The flag for ignoring searching for finetuned models has been set True for {test_func.__name__}. No already finetuned models will be looked for.") function_modeler.check_finetune_blacklist.append(func_hash) if ignore_data_storage: + logging.info(f"The flag for ignoring data storage has been set True for {test_func.__name__}. No data will be read or saved and model distillation will not be performed.") function_modeler.store_data_blacklist.append(func_hash) task_type = function_description.type if len(teacher_models) > 0: diff --git a/src/tanuki/constants.py b/src/tanuki/constants.py index 72c0343..e16682d 100644 --- a/src/tanuki/constants.py +++ b/src/tanuki/constants.py @@ -30,7 +30,7 @@ # default models DEFAULT_TEACHER_MODEL_NAMES = ["gpt-4", "gpt-4-32k", ] -DEFAULT_DISTILLED_MODEL_NAME = "gpt-3.5-finetune" +DEFAULT_DISTILLED_MODEL_NAME = "gpt-3.5-turbo-1106" DEFAULT_EMBEDDING_MODEL_NAME = "ada-002" # provider names diff --git a/src/tanuki/function_modeler.py b/src/tanuki/function_modeler.py index 8760b8e..00256ce 100644 --- a/src/tanuki/function_modeler.py +++ b/src/tanuki/function_modeler.py @@ -4,12 +4,12 @@ import json from typing import List, Tuple, Dict, Union -import openai +import logging from tanuki.constants import EXAMPLE_ELEMENT_LIMIT, PATCHES, SYMBOLIC_ALIGNMENTS, POSITIVE_EMBEDDABLE_ALIGNMENTS, \ NEGATIVE_EMBEDDABLE_ALIGNMENTS, OPENAI_PROVIDER from tanuki.models.function_type import FunctionType -from tanuki.language_models.llm_configs import DEFAULT_GENERATIVE_MODELS, DEFAULT_EMBEDDING_MODELS +from tanuki.language_models.llm_configs import DEFAULT_TEACHER_MODELS, DEFAULT_EMBEDDING_MODELS from tanuki.language_models.llm_configs.abc_base_config import BaseModelConfig from tanuki.language_models.llm_finetune_api_abc import LLM_Finetune_API from tanuki.models.finetune_job import FinetuneJob @@ -42,6 +42,7 @@ def __init__(self, data_worker: DatasetWorker, self.store_data_blacklist = [] self.api_provider = api_provider self.teacher_models_override = {} + self.startup_logging_checker = {} def _get_dataset_info(self, dataset_type, func_hash, type="length"): """ @@ -65,7 +66,7 @@ def _configure_teacher_models(self, if task_type == FunctionType.EMBEDDABLE: preconfigured_models = DEFAULT_EMBEDDING_MODELS elif task_type == FunctionType.SYMBOLIC: - preconfigured_models = DEFAULT_GENERATIVE_MODELS + preconfigured_models = DEFAULT_TEACHER_MODELS for model in teacher_models: if isinstance(model, str): if model not in preconfigured_models: @@ -318,7 +319,7 @@ def load_function_config(self, func_hash, function_description): def _check_for_finetunes(self, function_description: FunctionDescription, finetune_provider : str) -> Tuple[bool, Dict]: # hash the function_hash into 16 characters (to embed it into the name of OpenAI finetunes, for later retrieval) - + logging.info(f"Checking for finetunes for {function_description.name} using {finetune_provider}") finetune_hash = function_description.__hash__(purpose="finetune") + encode_int(self.environment_id) # List 10 fine-tuning jobs finetunes: List[FinetuneJob] = self.api_provider[finetune_provider].list_finetuned(limit=1000) @@ -333,10 +334,12 @@ def _check_for_finetunes(self, function_description: FunctionDescription, finetu config = self._construct_config_from_finetune(finetune_hash, finetune) # save the config self.data_worker.update_function_config(function_description.__hash__(), config) + logging.info(f"Found finetuned model for {function_description.name} [{config.distilled_model.model_name}]") return True, config except: + logging.info(f"Found finetuned model for {function_description.name} [{finetune.fine_tuned_model.model_name}] but could not load it") return False, {} - + logging.info(f"No finetuned model found for {function_description.name}") return False, {} def _construct_config_from_finetune(self, finetune_hash: str, finetune: FinetuneJob): @@ -426,16 +429,16 @@ def check_for_finetuning(self, function_description, func_hash): # check if already finetuning if "job_id" in self.function_configs[func_hash].current_training_run: # check for job status - self._check_finetuning_status(func_hash) + self._check_finetuning_status(func_hash, function_description) else: # check for finetuning condition - if self._check_finetuning_condition(func_hash): + if self._check_finetuning_condition(func_hash, function_description): self._execute_finetuning(function_description, func_hash) except Exception as e: print(e) print("Error checking for finetuning") - def _check_finetuning_condition(self, func_hash): + def _check_finetuning_condition(self, func_hash, function_description): """ Check if the finetuning condition is met Currently finetuning condition is dependent on the number of symbolic datapoints since last finetuning @@ -453,6 +456,11 @@ def _check_finetuning_condition(self, func_hash): # if havent read in the patch dataset size, read it in patch_dataset_size = self._get_dataset_info(PATCHES, func_hash, type="length") self.dataset_sizes[PATCHES][func_hash] = patch_dataset_size + if func_hash not in self.startup_logging_checker: + logging.info(f"Function {function_description.name} [{align_dataset_size} aligns | {patch_dataset_size} runs] will be finetuned from"\ + f" {self.function_configs[func_hash].teacher_models[0].model_name} using {self.function_configs[func_hash].distilled_model.provider} in "\ + f"{training_threshold-(patch_dataset_size + align_dataset_size)} runs") + self.startup_logging_checker[func_hash] = True return (patch_dataset_size + align_dataset_size) > training_threshold @@ -529,8 +537,10 @@ def _execute_finetuning(self, function_description, func_hash): # Use the stream as a file try: finetune_provider = self.function_configs[func_hash].distilled_model.provider + logging.info(f"Starting finetuning for {function_description.name} using {finetune_provider}") finetuning_response: FinetuneJob = self.api_provider[finetune_provider].finetune(file=temp_file, suffix=finetune_hash) except Exception as e: + logging.info(f"Could not start finetuning for {function_description.name} using {finetune_provider}. Error: {e}") return self.function_configs[func_hash].current_training_run = {"job_id": finetuning_response.id, @@ -544,7 +554,7 @@ def _execute_finetuning(self, function_description, func_hash): print(e) print("Could not update config file to register a finetuning run") - def _check_finetuning_status(self, func_hash): + def _check_finetuning_status(self, func_hash, function_description): """ Check the status of the current finetuning job If the job is finished, update the config file to reflect the new model @@ -560,18 +570,18 @@ def _check_finetuning_status(self, func_hash): self.function_configs[func_hash].current_training_run["last_checked"] = datetime.datetime.now().strftime( "%Y-%m-%d %H:%M:%S") if response.status == "succeeded" or response.status == "failed": - self._update_finetune_config(response, func_hash) + self._update_finetune_config(response, func_hash, function_description) else: self._update_config_file(func_hash) - def _update_finetune_config(self, response: FinetuneJob, func_hash): + def _update_finetune_config(self, response: FinetuneJob, func_hash, function_description): """ Update the config file to reflect the new model and switch the current model to the finetuned model """ self.function_configs[func_hash].update_with_finetuned_response(response) + logging.info(f"Finetuning for {function_description.name} using {self.function_configs[func_hash].distilled_model.provider} finished with status: {response.status}") try: self._update_config_file(func_hash) except Exception as e: - print(e) - print("Could not update config file after a successful finetuning run") + logging.info(f"Could not update the function configuration file with the finetuned model for {function_description.name}. Error: {e}") pass diff --git a/src/tanuki/language_models/embedding_model_manager.py b/src/tanuki/language_models/embedding_model_manager.py index 4750e62..8ee7cf7 100644 --- a/src/tanuki/language_models/embedding_model_manager.py +++ b/src/tanuki/language_models/embedding_model_manager.py @@ -12,7 +12,7 @@ class EmbeddingModelManager(object): def __init__(self, function_modeler, api_provider: APIManager): self.function_modeler = function_modeler self.api_provider = api_provider - self.current_generators = {} + self.initialized_functions = {} def get_embedding_case(self, args, function_description: FunctionDescription, kwargs, examples=None): # example_input = f"Examples:{examples}\n" if examples else "" @@ -25,12 +25,12 @@ def get_embedding_case(self, args, function_description: FunctionDescription, kw # loggings - if function_hash not in self.current_generators: - logging.info(f"Generating function embeddings with {model.model_name}") - self.current_generators[function_hash] = model.model_name - elif self.current_generators[function_hash] != model.model_name: - logging.info(f"Switching embeddings generation from {self.current_generators[function_hash]} to {model.model_name}") - self.current_generators[function_hash] = model.model_name + if function_hash not in self.initialized_functions: + logging.info(f"Generating function embeddings for {function_description.name} with {model.model_name}") + self.initialized_functions[function_hash] = model.model_name + elif self.initialized_functions[function_hash] != model.model_name: + logging.info(f"Switching embeddings generation for {function_description.name} from {self.initialized_functions[function_hash]} to {model.model_name}") + self.initialized_functions[function_hash] = model.model_name return content, model diff --git a/src/tanuki/language_models/language_model_manager.py b/src/tanuki/language_models/language_model_manager.py index d7c1852..d3cde05 100644 --- a/src/tanuki/language_models/language_model_manager.py +++ b/src/tanuki/language_models/language_model_manager.py @@ -28,7 +28,7 @@ def __init__(self, self.api_provider = api_provider self.function_modeler = function_modeler self.default_generation_length = generation_token_limit - self.current_generators = {} + self.initialized_functions = {} self.token_counts = {} def __call__(self, @@ -83,17 +83,24 @@ def generate(self, args, kwargs, function_description, llm_parameters={}): The main generation function, given the args, kwargs, function description and model type, generate a response and check if the datapoint can be saved to the finetune dataset """ + func_hash = function_description.__hash__() prompt, model, save_to_finetune, is_distilled_model = self.get_generation_case(args, kwargs, function_description, - llm_parameters) - func_hash = function_description.__hash__() + llm_parameters, + func_hash) # loggings - if func_hash not in self.current_generators: - logging.info(f"Generating function outputs with {model.model_name}") - self.current_generators[func_hash] = model.model_name - elif self.current_generators[func_hash] != model.model_name: - logging.info(f"Switching output generation from {self.current_generators[func_hash]} to {model.model_name}") - self.current_generators[func_hash] = model.model_name + current_function_setup = self.initialized_functions.get(func_hash, None) # getting the current function setup - model and align statements + if current_function_setup: + generator_model = current_function_setup["model"] + if is_distilled_model: + logging.info(f"Generating function outputs for {function_description.name} with a finetuned model: {model.model_name}.") + self.initialized_functions[func_hash]["model"] = model.model_name + elif generator_model == "": + logging.info(f"Found {len(current_function_setup['examples'])} align statements for {function_description.name}. Generating function outputs with {model.model_name}.") + self.initialized_functions[func_hash]["model"] = model.model_name + elif generator_model != model.model_name: + logging.info(f"Switching output generation from {generator_model} to {model.model_name} for function {function_description.name}.") + self.initialized_functions[func_hash]["model"] = model.model_name choice = self._synthesise_answer(prompt, model, llm_parameters) output = LanguageModelOutput(choice, save_to_finetune, is_distilled_model) @@ -114,7 +121,7 @@ def _synthesise_answer(self, prompt, model, llm_parameters): return self.api_provider[model.provider].generate(model, system_message, prompt, **llm_parameters) - def get_generation_case(self, args, kwargs, function_description, llm_parameters): + def get_generation_case(self, args, kwargs, function_description, llm_parameters, func_hash): """ Get the generation case with the correct prompt and model First get the current model, then if distilled model, do zero-shot prompt and return False as suitable_for_finetune @@ -126,6 +133,9 @@ def get_generation_case(self, args, kwargs, function_description, llm_parameters is_distilled_model = distilled_model.model_name != "" suitable_for_distillation, input_prompt_token_count = self.suitable_for_finetuning_token_check(args, kwargs, f, distilled_model) + if func_hash not in self.initialized_functions: + # initialise the initialized_functions dict + self.initialized_functions[func_hash] = {"model": "", "examples": []} # no examples needed, using a finetuned model. Dont save to finetune dataset if is_distilled_model and suitable_for_distillation: prompt = self.construct_prompt(f, args, kwargs, [], distilled_model) @@ -136,13 +146,18 @@ def get_generation_case(self, args, kwargs, function_description, llm_parameters examples = [f"Inputs:\nArgs: {align['args']}\nKwargs: {align['kwargs']}\nOutput: {align['output']}" for align in aligns] + # update the examples in the initialized_functions dict + self.initialized_functions[func_hash]["examples"] = examples + examples_token_count = sum([approximate_token_count(example) for example in examples]) generation_tokens = llm_parameters.get("max_new_tokens", self.default_generation_length) model = self.choose_model_from_tokens(teacher_models, examples_token_count + input_prompt_token_count + generation_tokens, len(examples)) if model: - prompt = self.construct_prompt(f, args, kwargs, examples, model) + examples_with_parsing_tokens = [f"Inputs:\nArgs: {align['args']}\nKwargs: {align['kwargs']}\nOutput:{model.parsing_helper_tokens['start_token']}{align['output']}{model.parsing_helper_tokens['end_token']}" for align in + aligns] + prompt = self.construct_prompt(f, args, kwargs, examples_with_parsing_tokens, model) return prompt, model, suitable_for_distillation, False else: raise ValueError( @@ -179,14 +194,14 @@ def construct_prompt(self, f, args, kwargs, examples, model): """ if examples: final_examples = "\n".join( - [f"{model.parsing_helper_tokens['start_token']}{align}{model.parsing_helper_tokens['end_token']}" for align in + [f"{align}" for align in examples]) example_input = f"Examples:{final_examples}\n" else: example_input = "" instruction_prompt = model.instructions - content = f"{instruction_prompt}\nFunction: {f}\n{example_input}---\n{model.parsing_helper_tokens['start_token']}Inputs:\nArgs: {args}\nKwargs: {kwargs}\nOutput:" + content = f"{instruction_prompt}\nFunction: {f}\n{example_input}---\nInputs:\nArgs: {args}\nKwargs: {kwargs}\nOutput:{model.parsing_helper_tokens['start_token']}" return content def repair_generate(self, args, kwargs, f, failed_outputs_list, aligns, models, llm_parameters): diff --git a/src/tanuki/language_models/llama_bedrock_api.py b/src/tanuki/language_models/llama_bedrock_api.py index 929a895..4d883fb 100644 --- a/src/tanuki/language_models/llama_bedrock_api.py +++ b/src/tanuki/language_models/llama_bedrock_api.py @@ -55,5 +55,9 @@ def generate(self, model: BaseModelConfig, system_message: str, prompt: str, **k if model.parsing_helper_tokens["end_token"]: # remove the end token from the choice choice = choice.split(model.parsing_helper_tokens["end_token"])[0] + # check if starting token is in choice + if model.parsing_helper_tokens["start_token"] in choice: + # remove the starting token from the choice + choice = choice.split(model.parsing_helper_tokens["start_token"])[-1] - return choice + return choice.strip() diff --git a/src/tanuki/language_models/llm_configs/__init__.py b/src/tanuki/language_models/llm_configs/__init__.py index ab26e67..96fb45c 100644 --- a/src/tanuki/language_models/llm_configs/__init__.py +++ b/src/tanuki/language_models/llm_configs/__init__.py @@ -2,16 +2,26 @@ from tanuki.language_models.llm_configs.claude_config import ClaudeConfig from tanuki.language_models.llm_configs.llama_config import LlamaBedrockConfig from tanuki.language_models.llm_configs.titan_config import TitanBedrockConfig -DEFAULT_GENERATIVE_MODELS = { +DEFAULT_TEACHER_MODELS = { "gpt-4-1106-preview": OpenAIConfig(model_name = "gpt-4-1106-preview", context_length = 128000), "gpt-4": OpenAIConfig(model_name = "gpt-4", context_length = 8192), "gpt-4-32k": OpenAIConfig(model_name = "gpt-4-32k", context_length = 32768), - "gpt-3.5-finetune": OpenAIConfig(model_name = "", context_length = 3000), + "gpt-4-turbo": OpenAIConfig(model_name = "gpt-4-1106-preview", + context_length = 128000, + instructions="You are given below a function description and input data. The function description of what the function must carry out can be found in the Function section, with input and output type hints. The input data can be found in Input section. Using the function description, apply the function to the Input and return a valid output type, that is acceptable by the output_class_definition and output_class_hint.\nINCREDIBLY IMPORTANT: Only output a JSON-compatible string in the correct response format. Use the [END] tokens to specify when the output ends.", + parsing_helper_tokens={"start_token": "[START]", "end_token": "[END]"}), + "gpt-4-turbo-0125": OpenAIConfig(model_name = "gpt-4-0125-preview", + context_length = 128000, + instructions="You are given below a function description and input data. The function description of what the function must carry out can be found in the Function section, with input and output type hints. The input data can be found in Input section. Using the function description, apply the function to the Input and return a valid output type, that is acceptable by the output_class_definition and output_class_hint.\nINCREDIBLY IMPORTANT: Only output a JSON-compatible string in the correct response format. Use the [END] tokens to specify when the output ends.", + parsing_helper_tokens={"start_token": "[START]", "end_token": "[END]"}), "anthropic.claude-v2:1": ClaudeConfig(model_name = "anthropic.claude-v2:1", context_length = 200000), "llama_70b_chat_aws": LlamaBedrockConfig(model_name = "meta.llama2-70b-chat-v1", context_length = 4096), "llama_13b_chat_aws": LlamaBedrockConfig(model_name = "meta.llama2-13b-chat-v1", context_length = 4096), } +DEFAULT_STUDENT_MODELS = { + "gpt-3.5-turbo-1106": OpenAIConfig(model_name = "", context_length = 14000), + } DEFAULT_EMBEDDING_MODELS = { "ada-002": OpenAIConfig(model_name="text-embedding-ada-002", context_length=8191), diff --git a/src/tanuki/language_models/llm_configs/model_config_factory.py b/src/tanuki/language_models/llm_configs/model_config_factory.py index b763a6e..3b8c176 100644 --- a/src/tanuki/language_models/llm_configs/model_config_factory.py +++ b/src/tanuki/language_models/llm_configs/model_config_factory.py @@ -3,7 +3,7 @@ from tanuki.language_models.llm_configs.llama_config import LlamaBedrockConfig from tanuki.language_models.llm_configs.titan_config import TitanBedrockConfig from typing import Union -from tanuki.language_models.llm_configs import DEFAULT_GENERATIVE_MODELS +from tanuki.language_models.llm_configs import DEFAULT_TEACHER_MODELS, DEFAULT_STUDENT_MODELS from tanuki.constants import DEFAULT_DISTILLED_MODEL_NAME, OPENAI_PROVIDER, LLAMA_BEDROCK_PROVIDER, \ DISTILLED_MODEL, TEACHER_MODEL, TITAN_BEDROCK_PROVIDER @@ -23,13 +23,13 @@ def create_config(input_config: Union[str, dict, BaseModelConfig], type: str) -> if isinstance(input_config, str): # This is purely for backwards compatibility as we used to save the model as a string if type == DISTILLED_MODEL: - config = DEFAULT_GENERATIVE_MODELS[DEFAULT_DISTILLED_MODEL_NAME] + config = DEFAULT_STUDENT_MODELS[DEFAULT_DISTILLED_MODEL_NAME] config.model_name = input_config return config elif type == TEACHER_MODEL: - if input_config not in DEFAULT_GENERATIVE_MODELS: + if input_config not in DEFAULT_TEACHER_MODELS: raise Exception("Error loading the teacher model, saved config model was saved a string but is not a default model") - model = DEFAULT_GENERATIVE_MODELS[input_config] + model = DEFAULT_TEACHER_MODELS[input_config] return model else: if input_config["provider"] == OPENAI_PROVIDER: diff --git a/src/tanuki/language_models/openai_api.py b/src/tanuki/language_models/openai_api.py index 0ed6da8..0763862 100644 --- a/src/tanuki/language_models/openai_api.py +++ b/src/tanuki/language_models/openai_api.py @@ -12,7 +12,7 @@ from tanuki.language_models.embedding_api_abc import Embedding_API from tanuki.language_models.llm_api_abc import LLM_API import os -from tanuki.language_models.llm_configs import DEFAULT_GENERATIVE_MODELS +from tanuki.language_models.llm_configs import DEFAULT_STUDENT_MODELS from tanuki.constants import DEFAULT_DISTILLED_MODEL_NAME from tanuki.language_models.llm_configs.openai_config import OpenAIConfig from tanuki.models.finetune_job import FinetuneJob @@ -130,7 +130,14 @@ def generate(self, model, system_message, prompt, **kwargs): if not choice: raise Exception("OpenAI API failed to generate a response") - + + if model.parsing_helper_tokens["end_token"]: + # remove the end token from the choice + choice = choice.split(model.parsing_helper_tokens["end_token"])[0] + # check if starting token is in choice + if model.parsing_helper_tokens["start_token"] in choice: + # remove the starting token from the choice + choice = choice.split(model.parsing_helper_tokens["start_token"])[-1] return choice def list_finetuned(self, limit=100, **kwargs) -> List[FinetuneJob]: @@ -152,24 +159,18 @@ def get_finetuned(self, job_id) -> FinetuneJob: def finetune(self, file, suffix, **kwargs) -> FinetuneJob: self.check_api_key() # Use the stream as a file - try: - response = self.client.files.create(file=file, purpose='fine-tune') - except Exception as e: - return + response = self.client.files.create(file=file, purpose='fine-tune') training_file_id = response.id # submit the finetuning job - try: - finetuning_response: FineTuningJob = self.client.fine_tuning.jobs.create(training_file=training_file_id, - model="gpt-3.5-turbo", + finetuning_response: FineTuningJob = self.client.fine_tuning.jobs.create(training_file=training_file_id, + model=DEFAULT_DISTILLED_MODEL_NAME, suffix=suffix) - except Exception as e: - return finetune_job = self.create_finetune_job(finetuning_response) return finetune_job def create_finetune_job(self, response: FineTuningJob) -> FinetuneJob: - finetuned_model_config = copy.deepcopy(DEFAULT_GENERATIVE_MODELS[DEFAULT_DISTILLED_MODEL_NAME]) + finetuned_model_config = copy.deepcopy(DEFAULT_STUDENT_MODELS[DEFAULT_DISTILLED_MODEL_NAME]) finetuned_model_config.model_name = response.fine_tuned_model finetune_job = FinetuneJob(response.id, response.status, finetuned_model_config) return finetune_job diff --git a/src/tanuki/models/function_config.py b/src/tanuki/models/function_config.py index 202a12b..ab41a0b 100644 --- a/src/tanuki/models/function_config.py +++ b/src/tanuki/models/function_config.py @@ -1,7 +1,7 @@ from pydantic import BaseModel from typing import Dict, List from tanuki.language_models.llm_configs.abc_base_config import BaseModelConfig -from tanuki.language_models.llm_configs import DEFAULT_GENERATIVE_MODELS +from tanuki.language_models.llm_configs import DEFAULT_TEACHER_MODELS, DEFAULT_STUDENT_MODELS from tanuki.constants import DEFAULT_TEACHER_MODEL_NAMES, DEFAULT_DISTILLED_MODEL_NAME, \ DISTILLED_MODEL, TEACHER_MODEL from tanuki.language_models.llm_configs.model_config_factory import ModelConfigFactory @@ -22,13 +22,13 @@ class FunctionConfig(BaseModel): nr_of_training_runs : int -- the number of training runs """ - distilled_model: BaseModelConfig = DEFAULT_GENERATIVE_MODELS[DEFAULT_DISTILLED_MODEL_NAME] + distilled_model: BaseModelConfig = DEFAULT_STUDENT_MODELS[DEFAULT_DISTILLED_MODEL_NAME] current_model_stats : Dict = { "trained_on_datapoints": 0, "running_faults": []} last_training_run : Dict = {"trained_on_datapoints": 0} current_training_run : Dict = {} - teacher_models : List[BaseModelConfig] = [DEFAULT_GENERATIVE_MODELS[teacher_model_name] for teacher_model_name in DEFAULT_TEACHER_MODEL_NAMES] + teacher_models : List[BaseModelConfig] = [DEFAULT_TEACHER_MODELS[teacher_model_name] for teacher_model_name in DEFAULT_TEACHER_MODEL_NAMES] nr_of_training_runs : int = 0 def load_from_dict(self, json_dict): diff --git a/src/tanuki/register.py b/src/tanuki/register.py index a7eea4c..6941a0a 100644 --- a/src/tanuki/register.py +++ b/src/tanuki/register.py @@ -1,6 +1,6 @@ import inspect -from typing import get_type_hints, Literal, get_origin, Tuple, Callable, Optional, Dict - +from typing import get_type_hints, Literal, get_origin, Tuple, Callable, Optional, Dict, Union +import json from tanuki.models.embedding import Embedding from tanuki.models.function_description import FunctionDescription from tanuki.models.function_type import FunctionType @@ -147,14 +147,39 @@ def get_class_definition(class_type): # output_class_definition = get_class_definition(output_type_hint) output_class_definition = None function_type = FunctionType.SYMBOLIC - if inspect.isclass(output_type_hint): - # Check if the base class of the output type hint is Embedding - base_class = get_origin(output_type_hint) or output_type_hint - if issubclass(base_class, Embedding): - output_class_definition = None - function_type = FunctionType.EMBEDDABLE - else: - output_class_definition = get_class_definition(output_type_hint) + # check if the output type hint is a class or a subclass of a Union + if inspect.isclass(output_type_hint) or (hasattr(output_type_hint, "__origin__") and + output_type_hint.__origin__ == Union): + if (hasattr(output_type_hint, "__origin__") and output_type_hint.__origin__ == Union): # it's a union + # get all the types in the union + union_types = output_type_hint.__args__ + output_type_descriptions = {} + for output_type in union_types: + # check if it is a class Nonetype + if output_type is type(None): + output_type_descriptions["NoneType"] = "None" + elif inspect.isclass(output_type): + # Check if the base class of the output type hint is Embedding + base_class = get_origin(output_type) or output_type + if issubclass(base_class, Embedding): + output_class_definition = None + function_type = FunctionType.EMBEDDABLE + break + else: + class_type_description = get_class_definition(output_type) + if isinstance(class_type_description,str): + class_type_description = class_type_description.replace('"', "'") # less horrible prompt formatting when dump to json + output_type_descriptions[output_type.__name__] = class_type_description + output_class_definition = f"Union of following classes {json.dumps(output_type_descriptions)}" + + else: # it's a class + # Check if the base class of the output type hint is Embedding + base_class = get_origin(output_type_hint) or output_type_hint + if issubclass(base_class, Embedding): + output_class_definition = None + function_type = FunctionType.EMBEDDABLE + else: + output_class_definition = get_class_definition(output_type_hint) return FunctionDescription( name=func_object.__name__, diff --git a/src/tanuki/validator.py b/src/tanuki/validator.py index 1da2313..90c73aa 100644 --- a/src/tanuki/validator.py +++ b/src/tanuki/validator.py @@ -117,9 +117,10 @@ def check_type(self, value: Any, type_definition: Any) -> bool: # Handle tuples if origin == tuple: - if not isinstance(value, tuple) or (args and len(value) != len(args)): + if not isinstance(value, tuple): return False - return all(self.check_type(v, t) for v, t in zip(value, args)) + item_type = args[0] if args else Any + return all(self.check_type(v, item_type) for v in value) # Handle lists if origin == list: @@ -175,6 +176,8 @@ def check_type(self, value: Any, type_definition: Any) -> bool: if self.is_pydantic_model(origin): try: #temp_model = create_model('TempModel', **value) + if isinstance(value, origin): + return True #return isinstance(temp_model, origin) # check if value is dict if not isinstance(value, dict): @@ -485,6 +488,8 @@ def instantiate(self, data: Any, target_type: Type) -> Any: if self._is_subclass_of_generic(target_type, list) and not self._is_generic(target_type): return target_type(instantiated_items) + return instantiated_items + # Handle tuples if self._is_tuple_like(target_type) or (isinstance(origin, type) and issubclass(origin, tuple)): base, item_types = self._find_generic_base_and_args(target_type) diff --git a/tests/test_register/test_output_type_definitions.py b/tests/test_register/test_output_type_definitions.py new file mode 100644 index 0000000..1bfdadf --- /dev/null +++ b/tests/test_register/test_output_type_definitions.py @@ -0,0 +1,95 @@ + +from tanuki.register import Register +from tanuki.models.function_description import FunctionDescription +from pydantic import BaseModel +from typing import List, Dict, Union, Optional +import json + +def test_output_base_classes(): + def output_int(input: str) -> int: + """ + Does something random + """ + function_description: FunctionDescription = Register.load_function_description(output_int) + assert function_description.output_class_definition == "int" + assert function_description.output_type_hint is int + + def output_float(input: str) -> float: + """ + Does something random + """ + function_description: FunctionDescription = Register.load_function_description(output_float) + assert function_description.output_class_definition == "float" + assert function_description.output_type_hint is float + + def output_str(input: str) -> str: + """ + Does something random + """ + function_description: FunctionDescription = Register.load_function_description(output_str) + assert function_description.output_class_definition == "str" + assert function_description.output_type_hint is str + + + def output_bool(input: str) -> bool: + """ + Does something random + """ + function_description: FunctionDescription = Register.load_function_description(output_bool) + assert function_description.output_class_definition == "bool" + assert function_description.output_type_hint is bool + + def output_optional_bool(input: str) -> Optional[bool]: + """ + Does something random + """ + function_description: FunctionDescription = Register.load_function_description(output_optional_bool) + assert function_description.output_class_definition == f"Union of following classes {json.dumps({'bool': 'bool', 'NoneType': 'None'})}" + + + def output_union_bool(input: str) -> Union[bool, int]: + """ + Does something random + """ + function_description: FunctionDescription = Register.load_function_description(output_union_bool) + assert function_description.output_class_definition == f"Union of following classes {json.dumps({'bool': 'bool', 'int': 'int'})}" + +def test_output_pydantic_classes(): + class Person(BaseModel): + name: str + age: int + height: float + is_cool: bool + favourite_numbers: List[int] + even_more_favourite_numbers: tuple[int, ...] + favourite_dict: Dict[str, int] + + def output_person(input: str) -> Person: + """ + Does something random + """ + person_output_description = ' class Person(BaseModel):\n name: str\n age: int\n height: float\n is_cool: bool\n favourite_numbers: List[int]\n even_more_favourite_numbers: tuple[int, ...]\n favourite_dict: Dict[str, int]\n' + function_description: FunctionDescription = Register.load_function_description(output_person) + assert function_description.output_class_definition == person_output_description + + def output_optional_person(input: str) -> Optional[Person]: + """ + Does something random + """ + optional_person_description = f"Union of following classes {json.dumps({'Person': person_output_description, 'NoneType': 'None'})}" + function_description: FunctionDescription = Register.load_function_description(output_optional_person) + assert function_description.output_class_definition == optional_person_description + + def output_union_person(input: str) -> Union[Person, int]: + """ + Does something random + """ + union_person_description = f"Union of following classes {json.dumps({'Person': person_output_description, 'int': 'int'})}" + function_description: FunctionDescription = Register.load_function_description(output_union_person) + assert function_description.output_class_definition == union_person_description + + + +if __name__ == '__main__': + test_output_base_classes() + test_output_pydantic_classes() diff --git a/tests/test_token_counter.py b/tests/test_token_counter.py index 6f252fa..04a6bb2 100644 --- a/tests/test_token_counter.py +++ b/tests/test_token_counter.py @@ -36,7 +36,8 @@ def test_token_counter_finetunable(): prompt, distilled_model, suitable_for_distillation, is_distilled_model = lang_model.get_generation_case(args, kwargs, function_description, - {}) + {}, + "") assert suitable_for_distillation assert is_distilled_model assert distilled_model.model_name == "test_ft_1" @@ -54,7 +55,8 @@ def test_token_counter_non_finetunable_1(): prompt, distilled_model, suitable_for_distillation, is_distilled_model = lang_model.get_generation_case(args, kwargs, function_description, - {}) + {}, + "") assert not suitable_for_distillation assert not is_distilled_model assert distilled_model.model_name == "gpt-4" @@ -72,7 +74,8 @@ def test_token_counter_non_finetunable_2(): prompt, distilled_model, suitable_for_distillation, is_distilled_model = lang_model.get_generation_case(args, kwargs, function_description, - {}) + {}, + "") assert not suitable_for_distillation assert not is_distilled_model assert distilled_model.model_name == "gpt-4-32k" @@ -92,7 +95,8 @@ def test_error_raise(): prompt, distilled_model, suitable_for_distillation, is_distilled_model = lang_model.get_generation_case(args, kwargs, function_description, - {}) + {}, + "") except ValueError: error = True assert error diff --git a/tests/test_validator/test_instantiate.py b/tests/test_validator/test_instantiate.py index fe0c8dd..c1fdc1c 100644 --- a/tests/test_validator/test_instantiate.py +++ b/tests/test_validator/test_instantiate.py @@ -15,6 +15,10 @@ class Person(BaseModel): age: int height: float is_cool: bool + favourite_numbers: List[int] + even_more_favourite_numbers: tuple[int, ...] + favourite_dict: Dict[str, int] + def __eq__(self, other): return self.model_dump() == other.model_dump() @@ -27,10 +31,20 @@ def __hash__(self): "name": "John", "age": 20, "height": 1.8, - "is_cool": True + "is_cool": True, + "favourite_numbers": [1, 2, 3], + "even_more_favourite_numbers": (1, 2, 3), + "favourite_dict": {"a": 1, "b": 2}, } person_obj = validator.instantiate(person, Person) assert isinstance(person_obj, Person) + # test lists + list_pydantic = [person, person] + person_obj = validator.instantiate(list_pydantic, List[Person]) + assert isinstance(person_obj, list) + assert isinstance(person_obj[0], Person) + assert isinstance(person_obj[1], Person) + assert len(person_obj) == 2 # Nested data classes or Pydantic models. @dataclass @@ -76,8 +90,10 @@ def test_primitives(): assert validator.instantiate("1.0", str) != 1.0 assert validator.instantiate("true", str) != True assert validator.instantiate({}, dict) == {} + assert validator.instantiate({"asd": 2, "bb": "ad"}, dict) == {"asd": 2, "bb": "ad"} assert validator.instantiate([], list) == [] assert validator.instantiate((), tuple) == () + assert validator.instantiate((1,2), tuple) == (1, 2) assert validator.instantiate(set(), frozenset) == set() assert validator.instantiate((), frozenset) == () assert validator.instantiate((), set) == () @@ -231,4 +247,4 @@ class ExtendedList(List[int]): test_instantiate_dataclass() test_primitives() test_generics() - test_extended_generics() \ No newline at end of file + test_extended_generics(Validator()) \ No newline at end of file diff --git a/tests/test_validator/test_validate_value.py b/tests/test_validator/test_validate_value.py index 4029870..f9bd963 100644 --- a/tests/test_validator/test_validate_value.py +++ b/tests/test_validator/test_validate_value.py @@ -144,6 +144,9 @@ class Person: age: int height: float is_cool: bool + favourite_numbers: List[int] + even_more_favourite_numbers: tuple[int, ...] + favourite_dict: Dict[str, int] def __eq__(self, other): return self.dict() == other.dict() @@ -151,8 +154,15 @@ def __eq__(self, other): def __hash__(self): return hash(str(self.__dict__)) - person = Person('John', 20, 1.8, True) - person = {'name': 'John', 'age': 20, 'height': 1.8, 'is_cool': True} + person = { + "name": "John", + "age": 20, + "height": 1.8, + "is_cool": True, + "favourite_numbers": [1, 2, 3], + "even_more_favourite_numbers": (1, 2, 3), + "favourite_dict": {"a": 1, "b": 2}, + } assert validator.check_type(person, Person) assert not validator.check_type(person, str) @@ -198,6 +208,10 @@ class Person(BaseModel): age: int height: float is_cool: bool + favourite_numbers: List[int] + even_more_favourite_numbers: tuple[int, ...] + favourite_dict: Dict[str, int] + def __eq__(self, other): return self.model_dump() == other.model_dump() @@ -205,8 +219,15 @@ def __eq__(self, other): def __hash__(self): return hash(str(self.model_dump())) - person = Person(name='John', age=20, height=1.8, is_cool=True) - person = {'name': 'John', 'age': 20, 'height': 1.8, 'is_cool': True} + person = { + "name": "John", + "age": 20, + "height": 1.8, + "is_cool": True, + "favourite_numbers": [1, 2, 3], + "even_more_favourite_numbers": (1, 2, 3), + "favourite_dict": {"a": 1, "b": 2}, + } assert validator.check_type(person, Person) assert not validator.check_type(person, str) @@ -246,10 +267,10 @@ def __hash__(self): if __name__ == "__main__": test_validate_pydantic() - test_validate_dataclasses() - test_validate_literal_types() - test_validate_collection_list_types() - test_validate_collection_dict_types() - test_validate_type_annotations() - test_validate_complex_types() - test_validate_base_types() \ No newline at end of file + #test_validate_dataclasses() + #test_validate_literal_types() + #test_validate_collection_list_types() + #test_validate_collection_dict_types() + #test_validate_type_annotations() + #test_validate_complex_types() + #test_validate_base_types() \ No newline at end of file