Commit 78cb39b (1 parent: ba6eaac). Showing 18 changed files with 346 additions and 62 deletions.
@@ -26,7 +26,7 @@
 setup(
     name="TruthTorchLM",  # Your package name
-    version="0.1.8",  # Package version
+    version="0.1.9",  # Package version
     author="Yavuz Faruk Bakman",
     author_email="[email protected]",
     description="TruthTorchLM is an open-source library designed to assess truthfulness in language models' outputs. The library integrates state-of-the-art methods, offers comprehensive benchmarking tools across various tasks, and enables seamless integration with popular frameworks like Huggingface and LiteLLM.",
src/TruthTorchLM/long_form_generation/claim_check_methods/naive_application.py
127 changes: 127 additions & 0 deletions (new file)
@@ -0,0 +1,127 @@
from .claim_check_method import ClaimCheckMethod
from TruthTorchLM.truth_methods import TruthMethod
from TruthTorchLM.utils.common_utils import fix_tokenizer_chat
from TruthTorchLM.generation import get_sampling_properties, sample_generations_hf_local, sample_generations_api
from ..templates import ANSWER_GENERATION_INSTRUCTION

from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, PreTrainedModel

import torch
from typing import Union
from copy import deepcopy


class NaiveApplication(ClaimCheckMethod):
    """Scores a claim by treating it as the model's answer to the original question and
    running the configured truth methods on that question-claim pair."""

    def __init__(self,
                 generate_answer_instruction: list = ANSWER_GENERATION_INSTRUCTION,
                 truth_methods: list[TruthMethod] = None, batch_generation: bool = True,
                 use_question: bool = True):
        super().__init__()

        self.generate_answer_instruction = generate_answer_instruction
        self.truth_methods = truth_methods
        self.batch_generation = batch_generation
        self.use_question = use_question

    def _get_truth_value_local(self, truth_methods, model, tokenizer, question, text, answer, model_output,
                               generation_seed, messages, **kwargs):
        # Sample generations once and share them across all truth methods.
        number_of_generations, return_text, return_logits, return_logprobs, return_attentions, return_activations = get_sampling_properties(truth_methods)

        sampled_gen_dict = sample_generations_hf_local(model, text, tokenizer, generation_seed,
                                                       number_of_generations=number_of_generations,
                                                       return_text=return_text, return_logits=return_logits,
                                                       return_logprobs=return_logprobs, return_attentions=return_attentions,
                                                       return_activations=return_activations,
                                                       batch_generation=self.batch_generation, **kwargs)

        normalized_truth_values = []
        unnormalized_truth_values = []
        method_spec_outputs = []
        for truth_method in truth_methods:
            truth_values = truth_method(model=model, input_text=text, generated_text=answer, question_context=question,
                                        all_ids=model_output, tokenizer=tokenizer, generation_seed=generation_seed,
                                        sampled_generations_dict=sampled_gen_dict, messages=messages, **kwargs)
            normalized_truth_values.append(truth_values['normalized_truth_value'])
            unnormalized_truth_values.append(truth_values['truth_value'])
            method_spec_outputs.append(truth_values)

        return normalized_truth_values, unnormalized_truth_values, method_spec_outputs

    def check_claim_local(self, model: PreTrainedModel, input_text: str, generated_text: str, question_context: str,
                          claim: str, text_so_far: str, all_ids: Union[list, torch.Tensor],
                          tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] = None, generation_seed=None,
                          messages: list = [], **kwargs):

        # Build the prompt from the original question (or an empty string) and append the claim as the answer.
        question = question_context if self.use_question else ""
        q_messages = deepcopy(self.generate_answer_instruction)
        q_messages[1]["content"] = question
        tokenizer, q_messages = fix_tokenizer_chat(tokenizer, q_messages)
        text = tokenizer.apply_chat_template(q_messages, tokenize=False, add_generation_prompt=True, continue_final_message=False)
        q_messages.append({"role": "assistant", "content": claim})
        tokenizer, q_messages = fix_tokenizer_chat(tokenizer, q_messages)
        text_messages = tokenizer.apply_chat_template(q_messages, tokenize=False, add_generation_prompt=True, continue_final_message=False)
        model_outputs = tokenizer.encode(text_messages, return_tensors="pt").to(model.device)

        t_messages = deepcopy(self.generate_answer_instruction)
        t_messages[-1] = {"role": "user", "content": question}
        normalized_truth_values, unnormalized_truth_values, method_spec_outputs = self._get_truth_value_local(
            self.truth_methods, model=model, tokenizer=tokenizer,
            question=question, text=text, answer=claim,
            model_output=model_outputs, generation_seed=generation_seed,
            messages=t_messages, **kwargs)

        final_method_specific_outputs = []
        for i in range(len(self.truth_methods)):
            output_dict = {"Truth method name": self.truth_methods[i].__class__.__name__}
            method_spec_outputs[i].pop('generated_text', None)
            output_dict["detailed_outputs"] = method_spec_outputs[i]
            final_method_specific_outputs.append(output_dict)

        return {"claim": claim, "normalized_truth_values": normalized_truth_values, "truth_values": unnormalized_truth_values,
                "question": question, "truth_method_spec_outputs": final_method_specific_outputs}

    def _get_truth_value_api(self, truth_methods, model, q_messages, question, answer, generation_seed, **kwargs):
        # Get sampled generations to be used in truth methods.
        number_of_generations, return_text, return_logits, return_logprobs, return_attentions, return_activations = get_sampling_properties(truth_methods)
        sampled_gen_dict = sample_generations_api(model, q_messages, generation_seed,
                                                  number_of_generations=number_of_generations,
                                                  return_text=return_text, return_logits=return_logits,
                                                  return_logprobs=return_logprobs, return_attentions=return_attentions,
                                                  return_activations=return_activations, **kwargs)

        normalized_truth_values = []
        unnormalized_truth_values = []
        method_spec_outputs = []
        for truth_method in truth_methods:
            truth_values = truth_method(model=model, messages=q_messages, generated_text=answer, question_context=question,
                                        generation_seed=generation_seed, sampled_generations_dict=sampled_gen_dict, **kwargs)
            normalized_truth_values.append(truth_values['normalized_truth_value'])
            unnormalized_truth_values.append(truth_values['truth_value'])
            method_spec_outputs.append(truth_values)

        return normalized_truth_values, unnormalized_truth_values, method_spec_outputs

    def check_claim_api(self, model: str, messages: list, generated_text: str,
                        question_context: str, claim: str, text_so_far: str, generation_seed=None, **kwargs):

        # Truth methods that need raw logprobs cannot be applied through this method over an API model.
        requires_logprobs = False
        for truth_method in self.truth_methods:
            if truth_method.REQUIRES_LOGPROBS:
                requires_logprobs = True
                print(f"Truth method '{truth_method.__class__.__name__}' requires logprobs.")

        if requires_logprobs:
            raise ValueError("Truth methods requiring logprobs cannot be used with the NaiveApplication claim check method.")

        q_messages = deepcopy(self.generate_answer_instruction)
        question = question_context if self.use_question else ""
        # Get truth values for each truth method.
        q_messages[1]["content"] = question
        normalized_truth_values, unnormalized_truth_values, method_spec_outputs = self._get_truth_value_api(
            self.truth_methods, model=model,
            q_messages=q_messages, question=question, answer=claim,
            generation_seed=generation_seed, **kwargs)

        final_method_specific_outputs = []
        for i in range(len(self.truth_methods)):
            output_dict = {"Truth method name": self.truth_methods[i].__class__.__name__}
            method_spec_outputs[i].pop('generated_text', None)
            output_dict["detailed_outputs"] = method_spec_outputs[i]
            final_method_specific_outputs.append(output_dict)

        return {"claim": claim, "normalized_truth_values": normalized_truth_values, "truth_values": unnormalized_truth_values,
                "question": question_context, "truth_method_spec_outputs": final_method_specific_outputs}

    def __str__(self):
        return (f"Claim Check Method using the original question and the claim.\n"
                f"Use original question (if False, an empty string is used as the question): {self.use_question}\n"
                f"Answer generation instruction (used as the template for the original question-claim pair):\n {self.generate_answer_instruction}\n\n"
                f"Truth methods used to assign a score to the question(s):\n {self.truth_methods}")
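For orientation, here is a minimal usage sketch of the new claim check method; it is not part of the commit. It assumes the SAR truth method added in this commit is exported from TruthTorchLM.truth_methods, that the module path mirrors the file path above, and that the model string is one the library's API backend accepts; none of these are confirmed by the diff.

from TruthTorchLM.truth_methods import SAR  # assumed export location
from TruthTorchLM.long_form_generation.claim_check_methods.naive_application import NaiveApplication  # assumed module path

sar = SAR(number_of_generations=5, similarity_model_device="cpu")  # "cpu" avoids requiring a GPU for the similarity model
checker = NaiveApplication(truth_methods=[sar], use_question=True)

result = checker.check_claim_api(
    model="gpt-4o-mini",                                   # hypothetical API model name
    messages=[],                                           # chat history; the method builds its own prompt from the instruction template
    generated_text="Marie Curie won two Nobel Prizes.",
    question_context="Who was Marie Curie?",
    claim="Marie Curie won two Nobel Prizes.",
    text_so_far="",
)
print(result["claim"], result["truth_values"])             # one unnormalized score per truth method

Setting use_question=False would score the claim against an empty question instead of the original one, as the class's __str__ description states.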
116 additions & 0 deletions (new file)
@@ -0,0 +1,116 @@
import torch
import numpy as np
from typing import Union

from .truth_method import TruthMethod
from ..generation import sample_generations_hf_local, sample_generations_api
from TruthTorchLM.error_handler import handle_logprobs_error

from sentence_transformers.cross_encoder import CrossEncoder
from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast


class SAR(TruthMethod):
    """Semantic-similarity-weighted entropy over sampled generations: token-level relevance
    re-weights each generation's logprobs, and sentence-level similarity re-weights the
    resulting generation probabilities before the entropy is computed."""

    REQUIRES_SAMPLED_TEXT = True
    REQUIRES_SAMPLED_LOGPROBS = True

    def __init__(self, number_of_generations=5, t=0.001,
                 model_for_similarity=None, similarity_model_device='cuda', batch_generation=True):
        super().__init__()
        if model_for_similarity is None:
            self.model_for_similarity = CrossEncoder('cross-encoder/stsb-roberta-large', num_labels=1, device=similarity_model_device)
        else:
            self.model_for_similarity = model_for_similarity

        self.number_of_generations = number_of_generations
        self.t = t  # temperature controlling how strongly similar generations share probability mass
        self.batch_generation = batch_generation

    def _sentsar(self, generated_texts: list[str], question_context: str, scores: list[float], sampled_generations_dict: dict):
        # Pairwise semantic similarities between generations, each prefixed with the question context.
        similarities = {}
        for i in range(len(generated_texts)):
            similarities[i] = []

        for i in range(len(generated_texts)):
            for j in range(i + 1, len(generated_texts)):
                gen_i = question_context + generated_texts[i]
                gen_j = question_context + generated_texts[j]
                similarity_i_j = self.model_for_similarity.predict([gen_i, gen_j])
                similarities[i].append(similarity_i_j)
                similarities[j].append(similarity_i_j)

        probs = torch.exp(torch.tensor(scores))
        assert len(probs) == len(similarities)

        # Each generation's probability is boosted by the similarity-weighted probability of the others.
        sentence_scores = []
        for idx, prob in enumerate(probs):
            w_ent = -torch.log(
                prob + ((torch.tensor(similarities[idx]) / self.t) * torch.cat([probs[:idx], probs[idx + 1:]])).sum())
            sentence_scores.append(w_ent)
        sentence_scores = torch.tensor(sentence_scores)

        entropy = (torch.sum(sentence_scores, dim=0) / torch.tensor(sentence_scores.shape[0])).item()
        return {"truth_value": -entropy, 'SAR': entropy, "score_for_each_generation": scores,
                'generated_texts': generated_texts, "similarities": similarities}

    def _tokensar_local(self, question_context: str, generated_text: str, tokens: list[int], logprobs: list[float],
                        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast]):
        # A token's relevance is how much removing it changes the meaning of the answer.
        importance_vector = []
        for i in range(len(tokens)):
            removed_answer_ids = tokens[:i] + tokens[i + 1:]
            removed_answer = tokenizer.decode(removed_answer_ids, skip_special_tokens=True)
            score = self.model_for_similarity.predict([(question_context + " " + removed_answer, question_context + ' ' + generated_text)])
            score = 1 - score[0]
            importance_vector.append(score)

        importance_vector = importance_vector / np.sum(importance_vector)
        return np.dot(importance_vector, logprobs)

    def forward_hf_local(self, model: PreTrainedModel, input_text: str, generated_text: str, question_context: str,
                         all_ids: Union[list, torch.Tensor],
                         tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] = None, generation_seed=None,
                         sampled_generations_dict: dict = None, messages: list = [], **kwargs):

        if sampled_generations_dict is None:
            sampled_generations_dict = sample_generations_hf_local(model=model, input_text=input_text, tokenizer=tokenizer,
                                                                   generation_seed=generation_seed,
                                                                   number_of_generations=self.number_of_generations,
                                                                   return_text=True, return_logprobs=True,
                                                                   batch_generation=self.batch_generation, **kwargs)

        generated_texts = sampled_generations_dict["generated_texts"][:self.number_of_generations]
        generated_tokens = sampled_generations_dict["tokens"][:self.number_of_generations]
        logprobs = sampled_generations_dict["logprobs"][:self.number_of_generations]

        scores = []
        for i in range(self.number_of_generations):
            score = self._tokensar_local(question_context, generated_texts[i], generated_tokens[i], logprobs[i], tokenizer)
            scores.append(score)  # scores are in log scale

        return self._sentsar(generated_texts, question_context, scores, sampled_generations_dict)

    def _tokensar_api(self, question_context: str, generated_text: str, tokens: list[str], logprobs: list[float]):
        importance_vector = []
        for i in range(len(tokens)):
            removed_answer = "".join(tokens[:i]) + "".join(tokens[i + 1:])
            score = self.model_for_similarity.predict([(question_context + " " + removed_answer, question_context + ' ' + generated_text)])
            score = 1 - score[0]
            importance_vector.append(score)

        importance_vector = importance_vector / np.sum(importance_vector)
        return np.dot(importance_vector, logprobs)

    @handle_logprobs_error
    def forward_api(self, model: str, messages: list, generated_text: str, question_context: str, generation_seed=None,
                    sampled_generations_dict: dict = None, logprobs: list = None, generated_tokens: list = None, **kwargs):

        if sampled_generations_dict is None:
            sampled_generations_dict = sample_generations_api(model=model, messages=messages, generation_seed=generation_seed,
                                                              number_of_generations=self.number_of_generations,
                                                              return_text=True, return_logprobs=True, **kwargs)

        generated_texts = sampled_generations_dict["generated_texts"][:self.number_of_generations]
        generated_tokens = sampled_generations_dict["tokens"][:self.number_of_generations]
        logprobs = sampled_generations_dict["logprobs"][:self.number_of_generations]

        scores = []
        for i in range(self.number_of_generations):
            score = self._tokensar_api(question_context, generated_texts[i], generated_tokens[i], logprobs[i])
            scores.append(score)  # scores are in log scale

        return self._sentsar(generated_texts, question_context, scores, sampled_generations_dict)
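This class appears to implement the SAR (Shifting Attention to Relevance) uncertainty measure: token-level relevance weights the sampled logprobs, and sentence-level semantic similarity softens the entropy over generations. A minimal local-model sketch follows; it is not part of the commit, and the model name, the choice to call forward_hf_local directly rather than through a higher-level TruthTorchLM entry point, and the unused all_ids=None argument are all illustrative assumptions.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/Llama-3.2-1B-Instruct"      # hypothetical local model
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)

sar = SAR(number_of_generations=5, t=0.001)          # the SAR class defined above

question = "What is the capital of France?"
answer = "The capital of France is Paris."
out = sar.forward_hf_local(
    model=model,
    input_text=question,          # in practice, the formatted prompt that produced the answer
    generated_text=answer,
    question_context=question,
    all_ids=None,                 # kept for the TruthMethod interface; not used by SAR
    tokenizer=tokenizer,
)
print(out["truth_value"], out["SAR"])  # truth_value is the negated SAR entropy: higher means more confident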