Refactor query endpoint for Streaming Responses. (#53)
* streaming response from the model at /query endpoint
  Signed-off-by: SarveshAtawane <[email protected]>
* formatting error fix
  Signed-off-by: SarveshAtawane <[email protected]>
---------
Signed-off-by: SarveshAtawane <[email protected]>
1 parent f0a9fde · commit 4d814c6
Showing 3 changed files with 66 additions and 77 deletions.
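The refactored /query endpoint streams its answer as Server-Sent Events (via sse-starlette) instead of returning a single JSON message. As a rough illustration of how a client consumes it, here is a minimal sketch using plain requests; the localhost:8000 address and the requests-based approach are assumptions for the example, not part of this commit:

import requests

# Hypothetical local deployment; point this at wherever the FastAPI app is served.
url = "http://localhost:8000/query"
payload = {"text": "How to install Hyperledger fabric?"}

# stream=True keeps the connection open so SSE events arrive token by token.
with requests.post(url, json=payload, stream=True) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines(decode_unicode=True):
        # sse-starlette frames each yielded token as a "data: ..." line.
        if line and line.startswith("data:"):
            token = line[len("data:"):]
            if token.startswith(" "):
                token = token[1:]  # SSE allows one optional space after the colon
            print(token, end="", flush=True)
print()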
conversation.py

@@ -1,93 +1,75 @@
-import torch
-from transformers import pipeline
-from langchain_huggingface import HuggingFacePipeline
-from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
-from langchain.chains import create_history_aware_retriever, create_retrieval_chain
-from langchain.chains.combine_documents import create_stuff_documents_chain
-from langchain_core.runnables.history import RunnableWithMessageHistory
-from langchain_community.vectorstores import Chroma
+from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
+from threading import Thread
+from langchain.vectorstores import Chroma
 from langchain.embeddings import HuggingFaceEmbeddings
+from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
 from model import load_quantized_model
 from tokenizer import initialize_tokenizer
+from embeddings import embedding_function
+from session_history import get_session_history
 from utils import load_yaml_file
-from session_history import get_session_history
 
 
-def get_conversation():
+def initialize_models():
     config_data = load_yaml_file("config.yaml")
 
-    model = load_quantized_model(config_data["model_name"])
-
-    tokenizer = initialize_tokenizer(config_data["model_name"])
-
-    if tokenizer.pad_token_id is None:
-        tokenizer.pad_token_id = tokenizer.eos_token_id
-    embeddings = HuggingFaceEmbeddings(model_name=config_data["embedding_model_name"])
+    tokenizer = AutoTokenizer.from_pretrained(config_data["model_name"])
+    model = load_quantized_model(config_data["model_name"])
+
+    embeddings = embedding_function()
+
     vectordb = Chroma(
         embedding_function=embeddings,
         persist_directory=config_data["persist_directory"],
     )
+    return model, tokenizer, vectordb
 
-    # Retrieve and generate using the relevant snippets of the blog.
-    retriever = vectordb.as_retriever()
-
-    # build huggingface pipeline for using zephyr-7b-beta
-    llm_pipeline = pipeline(
-        "text-generation",
-        model=model,
-        tokenizer=tokenizer,
-        use_cache=True,
-        device_map="auto",
-        max_length=4096,  # 4096
-        do_sample=True,
-        top_k=5,
-        num_return_sequences=1,
-        eos_token_id=tokenizer.eos_token_id,
-        pad_token_id=tokenizer.eos_token_id,
-    )
 
-    # specify the llm
-    llm = HuggingFacePipeline(pipeline=llm_pipeline)
+def retrieve_relevant_context(query, vectordb, top_k=3):
+    results = vectordb.similarity_search(query, k=top_k)
+    return "\n".join([doc.page_content for doc in results])
 
-    # Contextualize question
+
+def contextualize_question(query, conversation_history):
     contextualize_q_system_prompt = """Given a chat history and the latest user question \
 which might reference context in the chat history, formulate a standalone question \
 which can be understood without the chat history. Do NOT answer the question, \
 just reformulate it if needed and otherwise return it as is."""
     contextualize_q_prompt = ChatPromptTemplate.from_messages(
         [
             ("system", contextualize_q_system_prompt),
             MessagesPlaceholder("chat_history"),
             ("human", "{input}"),
         ]
     )
-    history_aware_retriever = create_history_aware_retriever(
-        llm, retriever, contextualize_q_prompt
+    contextualized_query = contextualize_q_prompt.format(
+        chat_history=conversation_history, input=query
     )
+    return contextualized_query
 
-    # Answer question
+
+def generate_response(session_id, model, tokenizer, query, vectordb):
+    conversation_history = get_session_history(session_id)
+    contextualized_query = contextualize_question(query, conversation_history.messages)
+
+    context = retrieve_relevant_context(contextualized_query, vectordb)
+
     qa_system_prompt = """You are an assistant for question-answering tasks. \
 Use the following pieces of retrieved context to answer the question. \
 If you don't know the answer, just say that you don't know. \
-Use five sentences maximum and keep the answer concise.\
-
-{context}"""
-    qa_prompt = ChatPromptTemplate.from_messages(
-        [
-            ("system", qa_system_prompt),
-            MessagesPlaceholder("chat_history"),
-            ("human", "{input}"),
-        ]
-    )
-    question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
-
-    rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
-
-    conversational_rag_chain = RunnableWithMessageHistory(
-        rag_chain,
-        get_session_history,
-        input_messages_key="input",
-        history_messages_key="chat_history",
-        output_messages_key="answer",
-    )
-
-    return conversational_rag_chain
+Use five sentences maximum and keep the answer concise."""
+
+    full_prompt = f"{qa_system_prompt}\n\nContext: {context}\n\nQuestion: {contextualized_query}\n\nAnswer:"
+    inputs = tokenizer(full_prompt, return_tensors="pt", padding=True).to(model.device)
+    streamer = TextIteratorStreamer(
+        tokenizer, skip_prompt=True, skip_special_tokens=True
+    )
+    generation_kwargs = dict(
+        inputs, streamer=streamer, max_new_tokens=1000, do_sample=True, temperature=0.7
+    )
+    thread = Thread(target=model.generate, kwargs=generation_kwargs)
+    thread.start()
+
+    response = ""
+    for token in streamer:
+        yield token
+        response += token
+
+    conversation_history.add_user_message(query)
+    conversation_history.add_ai_message(response)
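The heart of the new generate_response() above is the TextIteratorStreamer pattern: model.generate() runs on a background thread while the streamer yields decoded text pieces as soon as they are produced, which is what lets the function be a generator. A self-contained sketch of just that mechanism, using the small public gpt2 checkpoint purely as a stand-in (the project itself loads its quantized model from config.yaml):

from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Tiny stand-in model so the sketch runs anywhere; not the model used by the repo.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Hyperledger Fabric is", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# generate() blocks until it finishes, so it runs on a worker thread while the
# streamer is consumed incrementally on the main thread.
thread = Thread(
    target=model.generate,
    kwargs=dict(inputs, streamer=streamer, max_new_tokens=30, do_sample=False),
)
thread.start()

for piece in streamer:
    print(piece, end="", flush=True)
thread.join()
print()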
Requirements file

@@ -6,4 +6,5 @@ chromadb
 bitsandbytes
 accelerate
 uvicorn
-black
+black
+sse-starlette
Query router (FastAPI endpoint)

@@ -1,26 +1,32 @@
-from fastapi import APIRouter, HTTPException
+from fastapi import APIRouter, HTTPException, Request
+from sse_starlette.sse import EventSourceResponse
 from pydantic import BaseModel
-from conversation import get_conversation
+from conversation import initialize_models, generate_response
 
-conversational_rag_chain = get_conversation()
+model, tokenizer, vectordb = initialize_models()
 
 
 # define the Query class that contains the question
 class Query(BaseModel):
     text: str
 
 
 # Initialization of the router :
 router = APIRouter()
 
 
 # reply to POST requests: '{"text": "How to install Hyperledger fabric?"}'
 @router.post("/query")
-def answer(q: Query):
+async def stream_answer(q: Query, request: Request):
     question = q.text
-    ai_msg_1 = conversational_rag_chain.invoke(
-        {"input": question},
-        config={"configurable": {"session_id": "1"}},
-    )["answer"]
-
-    return {"msg": ai_msg_1}
+    session_id = "1"
+
+    async def event_generator():
+        try:
+            for token in generate_response(
+                session_id, model, tokenizer, question, vectordb
+            ):
+                if await request.is_disconnected():
+                    break
+                yield {"data": token}
+        except Exception as e:
+            raise HTTPException(status_code=500, detail=str(e))
+
+    return EventSourceResponse(event_generator())
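The diff above only shows the router itself. A minimal sketch of how such an APIRouter is typically mounted in a FastAPI app and served with uvicorn (uvicorn is already in the requirements); the import path your_api_module is a placeholder, since the router's file name is not shown in this commit:

from fastapi import FastAPI

from your_api_module import router  # placeholder: import from wherever the router above lives

app = FastAPI()
app.include_router(router)

if __name__ == "__main__":
    import uvicorn

    # Serve the SSE-enabled /query endpoint; host and port are illustrative.
    uvicorn.run(app, host="0.0.0.0", port=8000)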