Refactor query endpoint for Streaming Responses. (#53)
* streaming response from the model at /query endpoint

Signed-off-by: SarveshAtawane <[email protected]>

* formatting error fix

Signed-off-by: SarveshAtawane <[email protected]>

---------

Signed-off-by: SarveshAtawane <[email protected]>
SarveshAtawane authored Oct 10, 2024
1 parent f0a9fde commit 4d814c6
Showing 3 changed files with 66 additions and 77 deletions.
108 changes: 45 additions & 63 deletions src/core/conversation.py
@@ -1,93 +1,75 @@
 import torch
-from transformers import pipeline
-from langchain_huggingface import HuggingFacePipeline
-from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
-from langchain.chains import create_history_aware_retriever, create_retrieval_chain
-from langchain.chains.combine_documents import create_stuff_documents_chain
-from langchain_core.runnables.history import RunnableWithMessageHistory
-from langchain_community.vectorstores import Chroma
+from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
+from threading import Thread
+from langchain.vectorstores import Chroma
+from langchain.embeddings import HuggingFaceEmbeddings
+from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
 from model import load_quantized_model
-from tokenizer import initialize_tokenizer
-from embeddings import embedding_function
-from session_history import get_session_history
 from utils import load_yaml_file
+from session_history import get_session_history


-def get_conversation():
+def initialize_models():
     config_data = load_yaml_file("config.yaml")

+    tokenizer = AutoTokenizer.from_pretrained(config_data["model_name"])
     model = load_quantized_model(config_data["model_name"])

-    tokenizer = initialize_tokenizer(config_data["model_name"])
-
-    embeddings = embedding_function()
-
+    if tokenizer.pad_token_id is None:
+        tokenizer.pad_token_id = tokenizer.eos_token_id
+    embeddings = HuggingFaceEmbeddings(model_name=config_data["embedding_model_name"])
     vectordb = Chroma(
         embedding_function=embeddings,
         persist_directory=config_data["persist_directory"],
     )
+    return model, tokenizer, vectordb

-    # Retrieve and generate using the relevant snippets of the blog.
-    retriever = vectordb.as_retriever()
-
-    # build huggingface pipeline for using zephyr-7b-beta
-    llm_pipeline = pipeline(
-        "text-generation",
-        model=model,
-        tokenizer=tokenizer,
-        use_cache=True,
-        device_map="auto",
-        max_length=4096,  # 4096
-        do_sample=True,
-        top_k=5,
-        num_return_sequences=1,
-        eos_token_id=tokenizer.eos_token_id,
-        pad_token_id=tokenizer.eos_token_id,
-    )

-    # specify the llm
-    llm = HuggingFacePipeline(pipeline=llm_pipeline)
+def retrieve_relevant_context(query, vectordb, top_k=3):
+    results = vectordb.similarity_search(query, k=top_k)
+    return "\n".join([doc.page_content for doc in results])


+def contextualize_question(query, conversation_history):

     # Contextualize question
     contextualize_q_system_prompt = """Given a chat history and the latest user question \
     which might reference context in the chat history, formulate a standalone question \
     which can be understood without the chat history. Do NOT answer the question, \
     just reformulate it if needed and otherwise return it as is."""
     contextualize_q_prompt = ChatPromptTemplate.from_messages(
         [
             ("system", contextualize_q_system_prompt),
             MessagesPlaceholder("chat_history"),
             ("human", "{input}"),
         ]
     )
-    history_aware_retriever = create_history_aware_retriever(
-        llm, retriever, contextualize_q_prompt
+    contextualized_query = contextualize_q_prompt.format(
+        chat_history=conversation_history, input=query
     )
+    return contextualized_query


+def generate_response(session_id, model, tokenizer, query, vectordb):
+    conversation_history = get_session_history(session_id)
+    contextualized_query = contextualize_question(query, conversation_history.messages)
+
+    context = retrieve_relevant_context(contextualized_query, vectordb)
+
     # Answer question
     qa_system_prompt = """You are an assistant for question-answering tasks. \
     Use the following pieces of retrieved context to answer the question. \
     If you don't know the answer, just say that you don't know. \
-    Use five sentences maximum and keep the answer concise.\
+    Use five sentences maximum and keep the answer concise."""

-    {context}"""
-    qa_prompt = ChatPromptTemplate.from_messages(
-        [
-            ("system", qa_system_prompt),
-            MessagesPlaceholder("chat_history"),
-            ("human", "{input}"),
-        ]
+    full_prompt = f"{qa_system_prompt}\n\nContext: {context}\n\nQuestion: {contextualized_query}\n\nAnswer:"
+    inputs = tokenizer(full_prompt, return_tensors="pt", padding=True).to(model.device)
+    streamer = TextIteratorStreamer(
+        tokenizer, skip_prompt=True, skip_special_tokens=True
     )
-    question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
-
-    rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
-
-    conversational_rag_chain = RunnableWithMessageHistory(
-        rag_chain,
-        get_session_history,
-        input_messages_key="input",
-        history_messages_key="chat_history",
-        output_messages_key="answer",
+    generation_kwargs = dict(
+        inputs, streamer=streamer, max_new_tokens=1000, do_sample=True, temperature=0.7
     )
-
-    return conversational_rag_chain
+    thread = Thread(target=model.generate, kwargs=generation_kwargs)
+    thread.start()
+
+    response = ""
+    for token in streamer:
+        yield token
+        response += token
+
+    conversation_history.add_user_message(query)
+    conversation_history.add_ai_message(response)
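
Note that the refactored generate_response is an ordinary Python generator: model.generate runs on a background Thread and TextIteratorStreamer yields decoded token chunks as they are produced. A minimal sketch of driving it directly, outside FastAPI (not part of the commit; it assumes config.yaml, the Chroma store, and a session id such as "1" are set up as in initialize_models):

from conversation import initialize_models, generate_response

model, tokenizer, vectordb = initialize_models()

# Tokens arrive incrementally while generation runs on the background thread;
# printing them as they come reproduces the streamed answer on the console.
for token in generate_response("1", model, tokenizer, "How to install Hyperledger fabric?", vectordb):
    print(token, end="", flush=True)
print()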
3 changes: 2 additions & 1 deletion src/core/requirements.txt
@@ -6,4 +6,5 @@ chromadb
 bitsandbytes
 accelerate
 uvicorn
-black
+black
+sse-starlette
32 changes: 19 additions & 13 deletions src/core/routes/main.py
@@ -1,26 +1,32 @@
-from fastapi import APIRouter, HTTPException
+from fastapi import APIRouter, HTTPException, Request
+from sse_starlette.sse import EventSourceResponse
 from pydantic import BaseModel
-from conversation import get_conversation
+from conversation import initialize_models, generate_response

-conversational_rag_chain = get_conversation()
+model, tokenizer, vectordb = initialize_models()


 # define the Query class that contains the question
 class Query(BaseModel):
     text: str


 # Initialization of the router :
 router = APIRouter()


 # reply to POST requests: '{"text": "How to install Hyperledger fabric?"}'
 @router.post("/query")
-def answer(q: Query):
+async def stream_answer(q: Query, request: Request):
     question = q.text
-    ai_msg_1 = conversational_rag_chain.invoke(
-        {"input": question},
-        config={"configurable": {"session_id": "1"}},
-    )["answer"]
-
-    return {"msg": ai_msg_1}
+    session_id = "1"
+
+    async def event_generator():
+        try:
+            for token in generate_response(
+                session_id, model, tokenizer, question, vectordb
+            ):
+                if await request.is_disconnected():
+                    break
+                yield {"data": token}
+        except Exception as e:
+            raise HTTPException(status_code=500, detail=str(e))
+
+    return EventSourceResponse(event_generator())
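
Because the endpoint now returns an EventSourceResponse, clients receive the answer as a Server-Sent Events stream instead of a single JSON body. A rough client-side sketch (not part of the commit), assuming the requests package and the API served locally by uvicorn on port 8000:

import requests

# Stream the SSE response; sse-starlette frames each yielded token as a
# "data: <token>" line, with a blank line separating events.
with requests.post(
    "http://localhost:8000/query",
    json={"text": "How to install Hyperledger fabric?"},
    stream=True,
) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if line.startswith("data:"):
            # Per the SSE spec, one optional space follows "data:"; the rest
            # of the line is the token text produced by the model.
            print(line.removeprefix("data:").removeprefix(" "), end="", flush=True)
print()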
