Skip to content

Commit

Permalink
Merge pull request #32 from rmusser01/main
Browse files Browse the repository at this point in the history
Ollama API fix
  • Loading branch information
rmusser01 authored Oct 13, 2024
2 parents 2597e66 + ebfd4c8 commit e2f8139
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 3 deletions.
19 changes: 16 additions & 3 deletions App_Function_Libraries/RAG/Embeddings_Create.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,13 @@
# Global cache for embedding models
embedding_models = {}

# Pinned Hugging Face commit hashes, keyed by model repo id.
# These models are loaded with trust_remote_code=True, so the value is passed
# as `revision=` to from_pretrained() to fix the exact remote code that runs.
# Models without an entry fall back to revision=None (the repo's default branch).
commit_hashes = {
    "jinaai/jina-embeddings-v3": "4be32c2f5d65b95e4bcce473545b7883ec8d2edd",
    "Alibaba-NLP/gte-large-en-v1.5": "104333d6af6f97649377c2afbde10a7704870c7b",
    # Correct repo id; the original key was misspelled ("setll"), so the lookup
    # never matched and this model was loaded unpinned.
    "dunzhang/stella_en_400M_v5": "2aa5579fcae1c579de199a3866b6e514bbbf5d10",
    # Legacy misspelled key, retained only so any existing lookup keeps working.
    "dunzhang/setll_en_400M_v5": "2aa5579fcae1c579de199a3866b6e514bbbf5d10",
}

class HuggingFaceEmbedder:
def __init__(self, model_name, cache_dir, timeout_seconds=30):
self.model_name = model_name
Expand All @@ -53,17 +60,20 @@ def __init__(self, model_name, cache_dir, timeout_seconds=30):
self.unload_timer = None

def load_model(self):
# https://huggingface.co/docs/transformers/custom_models
if self.model is None:
# Pass cache_dir to from_pretrained to specify download directory
self.tokenizer = AutoTokenizer.from_pretrained(
self.model_name,
trust_remote_code=True,
cache_dir=self.cache_dir # Specify cache directory
cache_dir=self.cache_dir, # Specify cache directory
revision=commit_hashes.get(self.model_name, None) # Pass commit hash
)
self.model = AutoModel.from_pretrained(
self.model_name,
trust_remote_code=True,
cache_dir=self.cache_dir # Specify cache directory
cache_dir=self.cache_dir, # Specify cache directory
revision=commit_hashes.get(self.model_name, None) # Pass commit hash
)
self.model.to(self.device)
self.last_used_time = time.time()
Expand All @@ -88,6 +98,7 @@ def reset_timer(self):

def create_embeddings(self, texts):
self.load_model()
# https://huggingface.co/docs/transformers/custom_models
inputs = self.tokenizer(
texts,
return_tensors="pt",
Expand Down Expand Up @@ -117,10 +128,12 @@ class ONNXEmbedder:
def __init__(self, model_name, onnx_model_dir, timeout_seconds=30):
self.model_name = model_name
self.model_path = os.path.join(onnx_model_dir, f"{model_name}.onnx")
# https://huggingface.co/docs/transformers/custom_models
self.tokenizer = AutoTokenizer.from_pretrained(
model_name,
trust_remote_code=True,
cache_dir=onnx_model_dir # Ensure tokenizer uses the same directory
cache_dir=onnx_model_dir, # Ensure tokenizer uses the same directory
revision=commit_hashes.get(model_name, None) # Pass commit hash
)
self.session = None
self.timeout_seconds = timeout_seconds
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ def summarize(
return summarize_with_local_llm(input_data, custom_prompt_arg, temp, system_message)
elif api_name.lower() == "huggingface":
return summarize_with_huggingface(api_key, input_data, custom_prompt_arg, temp, )#system_message)
elif api_name.lower() == "custom-openai":
return summarize_with_custom_openai(api_key, input_data, custom_prompt_arg, temp, system_message)
elif api_name.lower() == "ollama":
return summarize_with_ollama(input_data, custom_prompt_arg, api_key, temp, system_message)
else:
return f"Error: Invalid API Name {api_name}"

Expand Down

0 comments on commit e2f8139

Please sign in to comment.