Validator, logging and modelling improvements (#127)
Fixed a bug regarding lists of pydantic objects in the validator
Added logging
Added OpenAI turbo models
Fixed a Llama 70B bug
MartBakler authored Jan 29, 2024
1 parent 2ca8d47 commit 21672ec
Showing 16 changed files with 294 additions and 85 deletions.
3 changes: 3 additions & 0 deletions src/tanuki/__init__.py
@@ -294,10 +294,13 @@ def wrapper(*args, **kwargs) -> Union[Embedding, Any]:
    # Configure the function modeler using incoming parameters
    function_modeler.environment_id = environment_id
    if ignore_finetuning:
        logging.info(f"The flag for ignoring finetuning has been set True for {test_func.__name__}. No model distillation will be performed.")
        function_modeler.execute_finetune_blacklist.append(func_hash)
    if ignore_finetune_fetching:
        logging.info(f"The flag for ignoring searching for finetuned models has been set True for {test_func.__name__}. No already finetuned models will be looked for.")
        function_modeler.check_finetune_blacklist.append(func_hash)
    if ignore_data_storage:
        logging.info(f"The flag for ignoring data storage has been set True for {test_func.__name__}. No data will be read or saved and model distillation will not be performed.")
        function_modeler.store_data_blacklist.append(func_hash)
    task_type = function_description.type
    if len(teacher_models) > 0:
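Note: for context, here is a minimal usage sketch of where these new log lines fire. It assumes the public `@tanuki.patch` decorator forwards the `ignore_finetuning`, `ignore_finetune_fetching` and `ignore_data_storage` flags shown in the hunk above as keyword arguments; the function name and chosen flags are hypothetical.

```python
import logging
import tanuki

logging.basicConfig(level=logging.INFO)

# Hypothetical patched function; only the flag names are taken from the hunk above.
@tanuki.patch(ignore_finetuning=True, ignore_data_storage=True)
def classify_sentiment(text: str) -> str:
    """Classify the input text as 'positive' or 'negative'."""

# When the wrapper configures the function modeler, the new logging.info calls emit e.g.:
# INFO The flag for ignoring finetuning has been set True for classify_sentiment. No model distillation will be performed.
# INFO The flag for ignoring data storage has been set True for classify_sentiment. No data will be read or saved and model distillation will not be performed.
```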
2 changes: 1 addition & 1 deletion src/tanuki/constants.py
@@ -30,7 +30,7 @@

# default models
DEFAULT_TEACHER_MODEL_NAMES = ["gpt-4", "gpt-4-32k", ]
DEFAULT_DISTILLED_MODEL_NAME = "gpt-3.5-finetune"
DEFAULT_DISTILLED_MODEL_NAME = "gpt-3.5-turbo-1106"
DEFAULT_EMBEDDING_MODEL_NAME = "ada-002"

# provider names
36 changes: 23 additions & 13 deletions src/tanuki/function_modeler.py
@@ -4,12 +4,12 @@
import json
from typing import List, Tuple, Dict, Union

import openai
import logging

from tanuki.constants import EXAMPLE_ELEMENT_LIMIT, PATCHES, SYMBOLIC_ALIGNMENTS, POSITIVE_EMBEDDABLE_ALIGNMENTS, \
NEGATIVE_EMBEDDABLE_ALIGNMENTS, OPENAI_PROVIDER
from tanuki.models.function_type import FunctionType
from tanuki.language_models.llm_configs import DEFAULT_GENERATIVE_MODELS, DEFAULT_EMBEDDING_MODELS
from tanuki.language_models.llm_configs import DEFAULT_TEACHER_MODELS, DEFAULT_EMBEDDING_MODELS
from tanuki.language_models.llm_configs.abc_base_config import BaseModelConfig
from tanuki.language_models.llm_finetune_api_abc import LLM_Finetune_API
from tanuki.models.finetune_job import FinetuneJob
@@ -42,6 +42,7 @@ def __init__(self, data_worker: DatasetWorker,
        self.store_data_blacklist = []
        self.api_provider = api_provider
        self.teacher_models_override = {}
        self.startup_logging_checker = {}

    def _get_dataset_info(self, dataset_type, func_hash, type="length"):
        """
@@ -65,7 +66,7 @@ def _configure_teacher_models(self,
        if task_type == FunctionType.EMBEDDABLE:
            preconfigured_models = DEFAULT_EMBEDDING_MODELS
        elif task_type == FunctionType.SYMBOLIC:
            preconfigured_models = DEFAULT_GENERATIVE_MODELS
            preconfigured_models = DEFAULT_TEACHER_MODELS
        for model in teacher_models:
            if isinstance(model, str):
                if model not in preconfigured_models:
@@ -318,7 +319,7 @@ def load_function_config(self, func_hash, function_description):

    def _check_for_finetunes(self, function_description: FunctionDescription, finetune_provider : str) -> Tuple[bool, Dict]:
        # hash the function_hash into 16 characters (to embed it into the name of OpenAI finetunes, for later retrieval)

        logging.info(f"Checking for finetunes for {function_description.name} using {finetune_provider}")
        finetune_hash = function_description.__hash__(purpose="finetune") + encode_int(self.environment_id)
        # List 10 fine-tuning jobs
        finetunes: List[FinetuneJob] = self.api_provider[finetune_provider].list_finetuned(limit=1000)
@@ -333,10 +334,12 @@ def _check_for_finetunes(self, function_description: FunctionDescription, finetu
                    config = self._construct_config_from_finetune(finetune_hash, finetune)
                    # save the config
                    self.data_worker.update_function_config(function_description.__hash__(), config)
                    logging.info(f"Found finetuned model for {function_description.name} [{config.distilled_model.model_name}]")
                    return True, config
                except:
                    logging.info(f"Found finetuned model for {function_description.name} [{finetune.fine_tuned_model.model_name}] but could not load it")
                    return False, {}

        logging.info(f"No finetuned model found for {function_description.name}")
        return False, {}

    def _construct_config_from_finetune(self, finetune_hash: str, finetune: FinetuneJob):
@@ -426,16 +429,16 @@ def check_for_finetuning(self, function_description, func_hash):
            # check if already finetuning
            if "job_id" in self.function_configs[func_hash].current_training_run:
                # check for job status
                self._check_finetuning_status(func_hash)
                self._check_finetuning_status(func_hash, function_description)
            else:
                # check for finetuning condition
                if self._check_finetuning_condition(func_hash):
                if self._check_finetuning_condition(func_hash, function_description):
                    self._execute_finetuning(function_description, func_hash)
        except Exception as e:
            print(e)
            print("Error checking for finetuning")

    def _check_finetuning_condition(self, func_hash):
    def _check_finetuning_condition(self, func_hash, function_description):
        """
        Check if the finetuning condition is met
        Currently finetuning condition is dependent on the number of symbolic datapoints since last finetuning
@@ -453,6 +456,11 @@ def _check_finetuning_condition(self, func_hash):
            # if havent read in the patch dataset size, read it in
            patch_dataset_size = self._get_dataset_info(PATCHES, func_hash, type="length")
            self.dataset_sizes[PATCHES][func_hash] = patch_dataset_size
        if func_hash not in self.startup_logging_checker:
            logging.info(f"Function {function_description.name} [{align_dataset_size} aligns | {patch_dataset_size} runs] will be finetuned from"\
                         f" {self.function_configs[func_hash].teacher_models[0].model_name} using {self.function_configs[func_hash].distilled_model.provider} in "\
                         f"{training_threshold-(patch_dataset_size + align_dataset_size)} runs")
            self.startup_logging_checker[func_hash] = True

        return (patch_dataset_size + align_dataset_size) > training_threshold

@@ -529,8 +537,10 @@ def _execute_finetuning(self, function_description, func_hash):
        # Use the stream as a file
        try:
            finetune_provider = self.function_configs[func_hash].distilled_model.provider
            logging.info(f"Starting finetuning for {function_description.name} using {finetune_provider}")
            finetuning_response: FinetuneJob = self.api_provider[finetune_provider].finetune(file=temp_file, suffix=finetune_hash)
        except Exception as e:
            logging.info(f"Could not start finetuning for {function_description.name} using {finetune_provider}. Error: {e}")
            return

        self.function_configs[func_hash].current_training_run = {"job_id": finetuning_response.id,
@@ -544,7 +554,7 @@ def _check_finetuning_status(self, func_hash):
            print(e)
            print("Could not update config file to register a finetuning run")

    def _check_finetuning_status(self, func_hash):
    def _check_finetuning_status(self, func_hash, function_description):
        """
        Check the status of the current finetuning job
        If the job is finished, update the config file to reflect the new model
@@ -560,18 +570,18 @@ def _check_finetuning_status(self, func_hash):
            self.function_configs[func_hash].current_training_run["last_checked"] = datetime.datetime.now().strftime(
                "%Y-%m-%d %H:%M:%S")
            if response.status == "succeeded" or response.status == "failed":
                self._update_finetune_config(response, func_hash)
                self._update_finetune_config(response, func_hash, function_description)
            else:
                self._update_config_file(func_hash)

    def _update_finetune_config(self, response: FinetuneJob, func_hash):
    def _update_finetune_config(self, response: FinetuneJob, func_hash, function_description):
        """
        Update the config file to reflect the new model and switch the current model to the finetuned model
        """
        self.function_configs[func_hash].update_with_finetuned_response(response)
        logging.info(f"Finetuning for {function_description.name} using {self.function_configs[func_hash].distilled_model.provider} finished with status: {response.status}")
        try:
            self._update_config_file(func_hash)
        except Exception as e:
            print(e)
            print("Could not update config file after a successful finetuning run")
            logging.info(f"Could not update the function configuration file with the finetuned model for {function_description.name}. Error: {e}")
            pass
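
Note: taken on its own, the new startup log in `_check_finetuning_condition` boils down to a "log the distillation countdown once per function" guard. A standalone sketch of that logic follows; the helper name and the simplified message format are illustrative, not the exact library internals.

```python
import logging

def log_finetune_countdown_once(
    func_hash: str,
    function_name: str,
    align_dataset_size: int,
    patch_dataset_size: int,
    training_threshold: int,
    startup_logging_checker: dict,
) -> bool:
    """Log how many more runs are needed before finetuning, once per function hash,
    and return whether the finetuning condition is already met."""
    if func_hash not in startup_logging_checker:
        remaining = training_threshold - (patch_dataset_size + align_dataset_size)
        logging.info(
            f"Function {function_name} [{align_dataset_size} aligns | {patch_dataset_size} runs] "
            f"will be finetuned in {remaining} runs"
        )
        startup_logging_checker[func_hash] = True
    return (patch_dataset_size + align_dataset_size) > training_threshold
```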
14 changes: 7 additions & 7 deletions src/tanuki/language_models/embedding_model_manager.py
@@ -12,7 +12,7 @@ class EmbeddingModelManager(object):
    def __init__(self, function_modeler, api_provider: APIManager):
        self.function_modeler = function_modeler
        self.api_provider = api_provider
        self.current_generators = {}
        self.initialized_functions = {}

    def get_embedding_case(self, args, function_description: FunctionDescription, kwargs, examples=None):
        # example_input = f"Examples:{examples}\n" if examples else ""
@@ -25,12 +25,12 @@ def get_embedding_case(self, args, function_description: FunctionDescription, kw


        # loggings
        if function_hash not in self.current_generators:
            logging.info(f"Generating function embeddings with {model.model_name}")
            self.current_generators[function_hash] = model.model_name
        elif self.current_generators[function_hash] != model.model_name:
            logging.info(f"Switching embeddings generation from {self.current_generators[function_hash]} to {model.model_name}")
            self.current_generators[function_hash] = model.model_name
        if function_hash not in self.initialized_functions:
            logging.info(f"Generating function embeddings for {function_description.name} with {model.model_name}")
            self.initialized_functions[function_hash] = model.model_name
        elif self.initialized_functions[function_hash] != model.model_name:
            logging.info(f"Switching embeddings generation for {function_description.name} from {self.initialized_functions[function_hash]} to {model.model_name}")
            self.initialized_functions[function_hash] = model.model_name

        return content, model

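Note: the renamed `initialized_functions` dict implements a simple "log on first use or on model switch" pattern. A self-contained sketch of that bookkeeping is below; the helper name is illustrative.

```python
import logging

def log_embedding_model_choice(
    initialized_functions: dict,
    function_hash: str,
    function_name: str,
    model_name: str,
) -> None:
    """Log the embedding model only when a function is first seen or its model changes."""
    previous = initialized_functions.get(function_hash)
    if previous is None:
        logging.info(f"Generating function embeddings for {function_name} with {model_name}")
    elif previous != model_name:
        logging.info(
            f"Switching embeddings generation for {function_name} from {previous} to {model_name}"
        )
    # record the model that will be used for this function going forward
    initialized_functions[function_hash] = model_name
```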
41 changes: 28 additions & 13 deletions src/tanuki/language_models/language_model_manager.py
@@ -28,7 +28,7 @@ def __init__(self,
        self.api_provider = api_provider
        self.function_modeler = function_modeler
        self.default_generation_length = generation_token_limit
        self.current_generators = {}
        self.initialized_functions = {}
        self.token_counts = {}

    def __call__(self,
@@ -83,17 +83,24 @@ def generate(self, args, kwargs, function_description, llm_parameters={}):
        The main generation function, given the args, kwargs, function description and model type, generate a response and check if the datapoint can be saved to the finetune dataset
        """

        func_hash = function_description.__hash__()
        prompt, model, save_to_finetune, is_distilled_model = self.get_generation_case(args, kwargs,
                                                                                       function_description,
                                                                                       llm_parameters)
        func_hash = function_description.__hash__()
                                                                                       llm_parameters,
                                                                                       func_hash)
        # loggings
        if func_hash not in self.current_generators:
            logging.info(f"Generating function outputs with {model.model_name}")
            self.current_generators[func_hash] = model.model_name
        elif self.current_generators[func_hash] != model.model_name:
            logging.info(f"Switching output generation from {self.current_generators[func_hash]} to {model.model_name}")
            self.current_generators[func_hash] = model.model_name
        current_function_setup = self.initialized_functions.get(func_hash, None) # getting the current function setup - model and align statements
        if current_function_setup:
            generator_model = current_function_setup["model"]
            if is_distilled_model:
                logging.info(f"Generating function outputs for {function_description.name} with a finetuned model: {model.model_name}.")
                self.initialized_functions[func_hash]["model"] = model.model_name
            elif generator_model == "":
                logging.info(f"Found {len(current_function_setup['examples'])} align statements for {function_description.name}. Generating function outputs with {model.model_name}.")
                self.initialized_functions[func_hash]["model"] = model.model_name
            elif generator_model != model.model_name:
                logging.info(f"Switching output generation from {generator_model} to {model.model_name} for function {function_description.name}.")
                self.initialized_functions[func_hash]["model"] = model.model_name

        choice = self._synthesise_answer(prompt, model, llm_parameters)
        output = LanguageModelOutput(choice, save_to_finetune, is_distilled_model)
@@ -114,7 +121,7 @@ def _synthesise_answer(self, prompt, model, llm_parameters):
        return self.api_provider[model.provider].generate(model, system_message, prompt, **llm_parameters)


    def get_generation_case(self, args, kwargs, function_description, llm_parameters):
    def get_generation_case(self, args, kwargs, function_description, llm_parameters, func_hash):
        """
        Get the generation case with the correct prompt and model
        First get the current model, then if distilled model, do zero-shot prompt and return False as suitable_for_finetune
@@ -126,6 +133,9 @@ def get_generation_case(self, args, kwargs, function_description, llm_parameters
        is_distilled_model = distilled_model.model_name != ""
        suitable_for_distillation, input_prompt_token_count = self.suitable_for_finetuning_token_check(args, kwargs, f,
                                                                                                       distilled_model)
        if func_hash not in self.initialized_functions:
            # initialise the initialized_functions dict
            self.initialized_functions[func_hash] = {"model": "", "examples": []}
        # no examples needed, using a finetuned model. Dont save to finetune dataset
        if is_distilled_model and suitable_for_distillation:
            prompt = self.construct_prompt(f, args, kwargs, [], distilled_model)
@@ -136,13 +146,18 @@ def get_generation_case(self, args, kwargs, function_description, llm_parameters
            examples = [f"Inputs:\nArgs: {align['args']}\nKwargs: {align['kwargs']}\nOutput: {align['output']}" for align in
                        aligns]

            # update the examples in the initialized_functions dict
            self.initialized_functions[func_hash]["examples"] = examples

            examples_token_count = sum([approximate_token_count(example) for example in examples])
            generation_tokens = llm_parameters.get("max_new_tokens", self.default_generation_length)
            model = self.choose_model_from_tokens(teacher_models,
                                                  examples_token_count + input_prompt_token_count + generation_tokens,
                                                  len(examples))
            if model:
                prompt = self.construct_prompt(f, args, kwargs, examples, model)
                examples_with_parsing_tokens = [f"Inputs:\nArgs: {align['args']}\nKwargs: {align['kwargs']}\nOutput:{model.parsing_helper_tokens['start_token']}{align['output']}{model.parsing_helper_tokens['end_token']}" for align in
                                                aligns]
                prompt = self.construct_prompt(f, args, kwargs, examples_with_parsing_tokens, model)
                return prompt, model, suitable_for_distillation, False
            else:
                raise ValueError(
@@ -179,14 +194,14 @@ def construct_prompt(self, f, args, kwargs, examples, model):
        """
        if examples:
            final_examples = "\n".join(
                [f"{model.parsing_helper_tokens['start_token']}{align}{model.parsing_helper_tokens['end_token']}" for align in
                [f"{align}" for align in
                 examples])
            example_input = f"Examples:{final_examples}\n"
        else:
            example_input = ""

        instruction_prompt = model.instructions
        content = f"{instruction_prompt}\nFunction: {f}\n{example_input}---\n{model.parsing_helper_tokens['start_token']}Inputs:\nArgs: {args}\nKwargs: {kwargs}\nOutput:"
        content = f"{instruction_prompt}\nFunction: {f}\n{example_input}---\nInputs:\nArgs: {args}\nKwargs: {kwargs}\nOutput:{model.parsing_helper_tokens['start_token']}"
        return content

    def repair_generate(self, args, kwargs, f, failed_outputs_list, aligns, models, llm_parameters):
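Note: to make the prompt change above easier to follow, align-statement outputs are now wrapped in the model's parsing helper tokens inside the examples, and the prompt itself ends with the start token so the answer can later be sliced out of the completion. The sketch below illustrates the resulting prompt shape; it is not the library's actual `construct_prompt` implementation.

```python
def build_prompt_sketch(instructions, f, args, kwargs, aligns, start_token, end_token):
    """Assemble a prompt in the shape produced by the updated construct_prompt /
    get_generation_case: examples carry start/end tokens around their outputs and
    the prompt ends with the start token."""
    examples = [
        f"Inputs:\nArgs: {align['args']}\nKwargs: {align['kwargs']}\n"
        f"Output:{start_token}{align['output']}{end_token}"
        for align in aligns
    ]
    example_input = ("Examples:" + "\n".join(examples) + "\n") if examples else ""
    return (
        f"{instructions}\nFunction: {f}\n{example_input}---\n"
        f"Inputs:\nArgs: {args}\nKwargs: {kwargs}\nOutput:{start_token}"
    )
```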
6 changes: 5 additions & 1 deletion src/tanuki/language_models/llama_bedrock_api.py
@@ -55,5 +55,9 @@ def generate(self, model: BaseModelConfig, system_message: str, prompt: str, **k
        if model.parsing_helper_tokens["end_token"]:
            # remove the end token from the choice
            choice = choice.split(model.parsing_helper_tokens["end_token"])[0]
        # check if starting token is in choice
        if model.parsing_helper_tokens["start_token"] in choice:
            # remove the starting token from the choice
            choice = choice.split(model.parsing_helper_tokens["start_token"])[-1]

        return choice
        return choice.strip()
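
Note: the Bedrock/Llama post-processing above reduces to a small token-stripping helper. A standalone sketch follows; the function name and example tokens are illustrative.

```python
def strip_parsing_tokens(choice: str, start_token: str, end_token: str) -> str:
    """Trim a raw completion to the answer between the parsing helper tokens."""
    if end_token:
        # drop everything after the end token
        choice = choice.split(end_token)[0]
    if start_token and start_token in choice:
        # drop everything up to and including the start token
        choice = choice.split(start_token)[-1]
    return choice.strip()
```

For example, with hypothetical tokens, `strip_parsing_tokens("<START> 42 <END>", "<START>", "<END>")` returns `"42"`.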