Skip to content

Commit

Permalink
fix: made it compatible to ContextSources
Browse files Browse the repository at this point in the history
Signed-off-by: Kannav02 <[email protected]>
  • Loading branch information
Kannav02 committed Jan 30, 2025
1 parent b9e94c1 commit 28f3154
Showing 1 changed file with 37 additions and 70 deletions.
107 changes: 37 additions & 70 deletions frontend/streamlit_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from typing import Callable, Any



def measure_response_time(func: Callable[..., Any]) -> Callable[..., tuple[Any, float]]:
def wrapper(*args: Any, **kwargs: Any) -> tuple[Any, float]:
start_time = time.time()
Expand Down Expand Up @@ -43,6 +44,25 @@ def translate_chat_history_to_api(chat_history, max_pairs=4):
i -= 1
return api_format

def display_sources_context(context_sources:list[dict[str,str]]):

with st.expander("Sources and Context"):
try:
if context_sources:
for idx, cs in enumerate(context_sources, 1):
st.markdown(f"**Source {idx}**:")
if cs.get("source"):
st.markdown(f"[{cs['source']}]({cs['source']})")
if cs.get("context"):
st.markdown(f"**Related Context:**\n> {cs['context']}")
st.markdown("---") # Separator between source-context pairs
else:
st.markdown("No Sources or Context Available.")
except (ValueError, SyntaxError) as e:
st.markdown(f"Failed to parse sources: {e}")




@measure_response_time
def response_generator(user_input: str) -> tuple[str, str] | tuple[None, None]:
Expand Down Expand Up @@ -71,12 +91,12 @@ def response_generator(user_input: str) -> tuple[str, str] | tuple[None, None]:
if not isinstance(data, dict):
st.error("Invalid response format")
return None, None
sources = data.get("sources", "")
context_sources = data.get("context_sources", [])
st.session_state.metadata[user_input] = {
"sources": sources,
"context": data.get("context", ""),
"context_sources": context_sources,

}
return data.get("response", ""), sources
return data.get("response", ""), context_sources
except requests.exceptions.RequestException as e:
st.error(f"Request failed: {e}")
return None, None
Expand All @@ -93,7 +113,7 @@ def main() -> None:
st.title("OR Assistant")

base_url = os.getenv("CHAT_ENDPOINT", "http://localhost:8000")
selected_endpoint = "/graphs/agent-retriever"
selected_endpoint = "/chains/mock"

if "selected_endpoint" not in st.session_state:
st.session_state.selected_endpoint = selected_endpoint
Expand All @@ -110,8 +130,6 @@ def main() -> None:
st.session_state.chat_history = []
if "metadata" not in st.session_state:
st.session_state.metadata = {}
if "sources" not in st.session_state:
st.session_state.sources = {}

if not st.session_state.chat_history:
st.session_state.chat_history.append(
Expand All @@ -129,33 +147,8 @@ def main() -> None:
user_message = st.session_state.chat_history[idx - 1]
if user_message["role"] == "user":
user_input = user_message["content"]
sources = st.session_state.sources.get(user_input)
with st.expander("Sources:"):
try:
if sources:
if isinstance(sources, str):
cleaned_sources = sources.replace("{", "[").replace(
"}", "]"
)
parsed_sources = ast.literal_eval(cleaned_sources)
else:
parsed_sources = sources
if (
isinstance(parsed_sources, (list, set))
and parsed_sources
):
sources_list = "\n".join(
f"- [{link}]({link})"
for link in parsed_sources
if link.strip()
)
st.markdown(sources_list)
else:
st.markdown("No Sources Attached.")
else:
st.markdown("No Sources Attached.")
except (ValueError, SyntaxError) as e:
st.markdown(f"Failed to parse sources: {e}")
context_sources = st.session_state.metadata.get(user_input,{}).get("context_sources",[])
display_sources_context(context_sources)

user_input = st.chat_input("Enter your queries ...")

Expand All @@ -173,7 +166,7 @@ def main() -> None:
and isinstance(response_tuple, tuple)
and len(response_tuple) == 2
):
response, sources = response_tuple
response, context_sources = response_tuple
if response is not None:
response_buffer = response

Expand All @@ -199,35 +192,8 @@ def main() -> None:
"role": "ai",
}
)
display_sources_context(context_sources)

st.session_state.sources[user_input] = sources

with st.expander("Sources:"):
try:
if sources:
if isinstance(sources, str):
cleaned_sources = sources.replace("{", "[").replace(
"}", "]"
)
parsed_sources = ast.literal_eval(cleaned_sources)
else:
parsed_sources = sources
if (
isinstance(parsed_sources, (list, set))
and parsed_sources
):
sources_list = "\n".join(
f"- [{link}]({link})"
for link in parsed_sources
if link.strip()
)
st.markdown(sources_list)
else:
st.markdown("No Sources Attached.")
else:
st.markdown("No Sources Attached.")
except (ValueError, SyntaxError) as e:
st.markdown(f"Failed to parse sources: {e}")
else:
st.error("Invalid response from the API")

Expand Down Expand Up @@ -266,19 +232,20 @@ def update_state() -> None:
gen_ans = st.session_state.chat_history[-1][
"content"
] # Last AI response
sources = st.session_state.metadata.get(selected_question, {}).get(
"sources", ["N/A"]
)
context = st.session_state.metadata.get(selected_question, {}).get(
"context", ["N/A"]
)
metadata = st.session_state.metadata.get(selected_question, {})
context_sources = metadata.get("context_sources", [])

# Extract sources and contexts separately for feedback
sources = [cs["source"] for cs in context_sources if cs.get("source")]
contexts = [cs["context"] for cs in context_sources if cs.get("context")]

reaction = "upvote" if thumbs_up else "downvote"

submit_feedback_to_google_sheet(
question=selected_question,
answer=gen_ans,
sources=sources if isinstance(sources, list) else [sources],
context=context if isinstance(context, list) else [context],
context=contexts if isinstance(contexts, list) else [contexts],
issue="", # Leave issue blank
version=os.getenv("RAG_VERSION", get_git_commit_hash()),
reaction=reaction, # Pass the reaction
Expand Down

0 comments on commit 28f3154

Please sign in to comment.