Skip to content

Commit

Permalink
Add support for more LLM models (docker#186)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomasonjo authored Nov 5, 2024
1 parent 6680d03 commit caec526
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 4 deletions.
20 changes: 18 additions & 2 deletions chains.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,14 @@
from utils import BaseLogger, extract_title_and_question
from langchain_google_genai import GoogleGenerativeAIEmbeddings

# Amazon Bedrock model-ID prefixes that load_llm() routes to ChatBedrock.
# Matching is done with str.startswith(AWS_MODELS), so each entry only needs
# to be an unambiguous *prefix* of the full Bedrock model ID — e.g.
# "anthropic.claude" matches "anthropic.claude-3-sonnet-20240229-v1:0".
AWS_MODELS = (
    "ai21.jamba-instruct-v1:0",
    "amazon.titan",
    "anthropic.claude",
    "cohere.command",
    "meta.llama",
    # "mistral." covers every Bedrock Mistral family (mistral-7b,
    # mixtral-8x7b, mistral-large, mistral-small, ...). The earlier
    # "mistral.mi" prefix matched the same IDs but was needlessly tight.
    "mistral.",
)

def load_embedding_model(embedding_model_name: str, logger=BaseLogger(), config={}):
if embedding_model_name == "ollama":
Expand Down Expand Up @@ -55,9 +63,9 @@ def load_embedding_model(embedding_model_name: str, logger=BaseLogger(), config=


def load_llm(llm_name: str, logger=BaseLogger(), config={}):
if llm_name == "gpt-4":
if llm_name in ["gpt-4", "gpt-4o", "gpt-4-turbo"]:
logger.info("LLM: Using GPT-4")
return ChatOpenAI(temperature=0, model_name="gpt-4", streaming=True)
return ChatOpenAI(temperature=0, model_name=llm_name, streaming=True)
elif llm_name == "gpt-3.5":
logger.info("LLM: Using GPT-3.5")
return ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo", streaming=True)
Expand All @@ -68,6 +76,14 @@ def load_llm(llm_name: str, logger=BaseLogger(), config={}):
model_kwargs={"temperature": 0.0, "max_tokens_to_sample": 1024},
streaming=True,
)
elif llm_name.startswith(AWS_MODELS):
logger.info(f"LLM: {llm_name}")
return ChatBedrock(
model_id=llm_name,
model_kwargs={"temperature": 0.0, "max_tokens_to_sample": 1024},
streaming=True,
)

elif len(llm_name):
logger.info(f"LLM: Using Ollama: {llm_name}")
return ChatOllama(
Expand Down
2 changes: 1 addition & 1 deletion env.example
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#*****************************************************************
# LLM and Embedding Model
#*****************************************************************
LLM=llama2 #or any Ollama model tag, gpt-4, gpt-3.5, or claudev2
LLM=llama2 #or any Ollama model tag; gpt-4, gpt-4o, gpt-4-turbo, gpt-3.5; or any Amazon Bedrock model ID
EMBEDDING_MODEL=sentence_transformer #or google-genai-embedding-001, openai, ollama, or aws

#*****************************************************************
Expand Down
10 changes: 9 additions & 1 deletion pull_model.Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,15 @@ COPY <<EOF pull_model.clj
(let [llm (get (System/getenv) "LLM")
url (get (System/getenv) "OLLAMA_BASE_URL")]
(println (format "pulling ollama model %s using %s" llm url))
(if (and llm url (not (#{"gpt-4" "gpt-3.5" "claudev2"} llm)))
(if (and llm
url
(not (#{"gpt-4" "gpt-3.5" "claudev2" "gpt-4o" "gpt-4-turbo"} llm))
(not (some #(.startsWith llm %) ["ai21.jamba-instruct-v1:0"
"amazon.titan"
"anthropic.claude"
"cohere.command"
"meta.llama"
"mistral.mi"])))

;; ----------------------------------------------------------------------
;; just call `ollama pull` here - create OLLAMA_HOST from OLLAMA_BASE_URL
Expand Down

0 comments on commit caec526

Please sign in to comment.