Skip to content

Commit

Permalink
adding support for gpt4all, tgi, along with openai
Browse files Browse the repository at this point in the history
  • Loading branch information
doncamilom committed Feb 29, 2024
1 parent f6336c5 commit 067f65f
Showing 1 changed file with 59 additions and 18 deletions.
77 changes: 59 additions & 18 deletions chemcrow/agents/chemcrow.py
Original file line number Diff line number Diff line change
@@ -1,68 +1,109 @@
import os
from dotenv import load_dotenv
from typing import Optional, Dict
from typing import Optional, Dict, Literal
import langchain
import nest_asyncio
from langchain import PromptTemplate, chains
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from pydantic import ValidationError
from rmrkl import ChatZeroShotAgent, RetryAgentExecutor

from langchain.llms import GPT4All

from .prompts import FORMAT_INSTRUCTIONS, QUESTION_PROMPT, REPHRASE_TEMPLATE, SUFFIX
from .tools import make_tools


def _make_llm(model, temp, verbose, api_key, max_tokens=1000, n_ctx=2048):
if model.startswith("gpt-3.5-turbo") or model.startswith("gpt-4"):
def _make_llm(
model_type: Literal["openai", "tgi", "gpt4all"],
model_server_url: Optional[str],
verbose,
api_key,
**kwargs
):
if model_type == "openai":
load_dotenv()
try:
llm = langchain.chat_models.ChatOpenAI(
temperature=temp,
model_name=model,
temperature=kwargs['temp'],
model_name=kwargs['model'],
request_timeout=1000,
streaming=True if verbose else False,
callbacks=[StreamingStdOutCallbackHandler()] if verbose else [None],
openai_api_key = api_key
)
except:
raise ValueError("Invalid OpenAI API key")
elif os.path.exists(model):
ext = os.path.splitext(model)[-1].lower()
if ext == ".gguf":
# If GPT4All style weights
llm = GPT4All(model=model, max_tokens=max_tokens, verbose=False)

elif model_type == "tgi":
from langchain.llms import HuggingFaceTextGenInference
llm = HuggingFaceTextGenInference(
inference_server_url=model_server_url,
max_new_tokens=kwargs['max_tokens'],
top_k=10,
top_p=0.95,
typical_p=0.95,
temperature=kwargs['temp'],
repetition_penalty=1.03,
)

elif model_type == "gpt4all":
from langchain.llms import GPT4All
model = kwargs['model']
if isinstance(model, str):
if os.path.exists(model) and model.endswith(".gguf"):
llm = GPT4All(
model=model,
max_tokens=kwargs['max_tokens'],
temp=kwargs['temp'],
verbose=False
)
else:
raise ValueError(f"Couldn't load model. Only models with .gguf format are suported currently.")
else:
raise ValueError(f"Found file: {model}, however only models with .gguf format are suported currently.")
else:
raise ValueError(f"Invalid model name: {model}")
raise ValueError(f"Invalid model name: {model}")
return llm



class ChemCrow:
def __init__(
self,
model_type = 'openai',
model_server_url: Optional[str] = None,
tools=None,
model="gpt-4-0613",
tools_model="gpt-3.5-turbo-0613",
temp=0.1,
max_tokens: int = 4096,
max_iterations=40,
verbose=True,
streaming: bool = True,
openai_api_key: str = '',
api_keys: Dict[str, str] = {},
max_tokens: int = 1000, # Not required for using OpenAI's API
n_ctx: int = 2048
):
"""Initialize ChemCrow agent."""

self.llm = _make_llm(model, temp, verbose, openai_api_key, max_tokens, n_ctx)
self.llm = _make_llm(
model_type,
model_server_url,
verbose,
openai_api_key,
model=model,
max_tokens=max_tokens,
temp=temp
)

if tools is None:
api_keys["OPENAI_API_KEY"] = openai_api_key
tools_llm = _make_llm(tools_model, temp, verbose, openai_api_key, max_tokens, n_ctx)
tools_llm = _make_llm(
model_type,
model_server_url,
verbose,
openai_api_key,
model=model,
max_tokens=max_tokens,
temp=temp
)
tools = make_tools(
tools_llm,
api_keys = api_keys,
Expand Down

0 comments on commit 067f65f

Please sign in to comment.