Skip to content

Commit

Permalink
fix: utilised ContextSource class to have more cohesiveness between s…
Browse files Browse the repository at this point in the history
…ource and chunks

Signed-off-by: Kannav02 <[email protected]>
  • Loading branch information
Kannav02 committed Jan 28, 2025
1 parent 75f0163 commit b9e94c1
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 35 deletions.
69 changes: 42 additions & 27 deletions backend/src/api/routers/chains.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,30 +111,37 @@ async def get_hybrid_response(user_input: UserInput) -> ChatResponse:
user_question = user_input.query
result = hybrid_llm_chain.invoke(user_question)

links = []
context = []
context_sources = []
for i in result["context"]:
if "url" in i.metadata:
links.append(i.metadata["url"])
context_sources.append(ContextSource(
context=i.page_content,
source=i.metadata["url"]
))
elif "source" in i.metadata:
links.append(i.metadata["source"])
context.append(i.page_content)

links = list(set(links))
links = list(set(links))
context_sources.append(ContextSource(
context=i.page_content,
source=i.metadata["source"]
))

if user_input.list_sources and user_input.list_context:
response = {
response = {
"response": result["answer"],
"sources": (links),
"context": (context),
"context_sources": context_sources
}

elif user_input.list_sources:
response = {"response": result["answer"], "sources": (links)}
response = {
"response": result["answer"],
"context_sources": [ContextSource(context="", source=cs.source) for cs in context_sources]
}
elif user_input.list_context:
response = {"response": result["answer"], "context": (context)}
response = {
"response": result["answer"],
"context_sources": [ContextSource(context=cs.context, source="") for cs in context_sources]
}
else:
response = {"response": result["answer"]}
response = {"response": result["answer"],"context_sources": []}

return ChatResponse(**response)

Expand All @@ -160,29 +167,37 @@ async def get_sim_response(user_input: UserInput) -> ChatResponse:
user_question = user_input.query
result = sim_llm_chain.invoke(user_question)

links = []
context = []
context_sources = []
for i in result["context"]:
if "url" in i.metadata:
links.append(i.metadata["url"])
context_sources.append(ContextSource(
context=i.page_content,
source=i.metadata["url"]
))
elif "source" in i.metadata:
links.append(i.metadata["source"])
context.append(i.page_content)

links = list(set(links))
context_sources.append(ContextSource(
context=i.page_content,
source=i.metadata["source"]
))

if user_input.list_sources and user_input.list_context:
response = {
response = {
"response": result["answer"],
"sources": (links),
"context": (context),
"context_sources": context_sources
}

elif user_input.list_sources:
response = {"response": result["answer"], "sources": (links)}
response = {
"response": result["answer"],
"context_sources": [ContextSource(context="", source=cs.source) for cs in context_sources]
}
elif user_input.list_context:
response = {"response": result["answer"], "context": (context)}
response = {
"response": result["answer"],
"context_sources": [ContextSource(context=cs.context, source="") for cs in context_sources]
}
else:
response = {"response": result["answer"]}
response = {"response": result["answer"],"context_sources": []}

return ChatResponse(**response)

Expand Down
18 changes: 10 additions & 8 deletions backend/src/api/routers/graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from starlette.responses import StreamingResponse

from ...agents.retriever_graph import RetrieverGraph
from ..models.response_model import ChatResponse, UserInput
from ..models.response_model import ChatResponse, ContextSource, UserInput

logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO").upper())
load_dotenv()
Expand Down Expand Up @@ -119,30 +119,32 @@ async def get_agent_response(user_input: UserInput) -> ChatResponse:
llm_response = output[-1]["generate"]["messages"][0]
tools = output[0]["agent"]["tools"]

urls = []
context = []
context_sources = []
tool_index = 1
for tool in tools:
urls.extend(list(output[tool_index].values())[0]["urls"])
context.append(list(output[tool_index].values())[0]["context"])
tool_index += 1

for url, context in zip(urls, [context]):
context_sources.append(ContextSource(context=context, source=url))
tool_index += 1
else:
llm_response = "LLM response extraction failed"
logging.error("LLM response extraction failed")

if user_input.list_sources and user_input.list_context:
response = {
"response": llm_response,
"sources": (urls),
"context": (context),
"context_sources": context_sources,
"tool": tools,
}
elif user_input.list_sources:
response = {"response": llm_response, "sources": (urls), "tool": tools}
response = {"response": llm_response, "context_sources":[ContextSource(context="", source=cs.source) for cs in context_sources], "tool": tools}
elif user_input.list_context:
response = {"response": llm_response, "context": (context), "tool": tools}
response = {"response": llm_response, "context_sources":[ContextSource(context=cs.context, source="") for cs in context_sources], "tool": tools}
else:
response = {"response": llm_response, "tool": tools}
response = {"response": llm_response,"context_sources":[ContextSource(context="", source="")], "tool": tools}

return ChatResponse(**response)

Expand Down

0 comments on commit b9e94c1

Please sign in to comment.