Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Removed DETA db support, due to shutdown of the app #12

Merged
merged 8 commits into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ h11==0.14.0
httpcore==0.17.3
httpx==0.24.1
httpx-oauth==0.13.0
huggingface-hub==0.24.2
huggingface-hub==0.23.4
idna==3.7
importlib-metadata==6.11.0
iniconfig==2.0.0
Expand All @@ -35,9 +35,9 @@ jsonpatch==1.33
jsonpointer==3.0.0
jsonschema==4.23.0
jsonschema-specifications==2023.12.1
langchain==0.2.11
langchain-community==0.2.10
langchain-core==0.2.24
langchain==0.2.14
langchain-core==0.2.32
langchain-community>=0.0.37
langchain-huggingface==0.0.3
langchain-text-splitters==0.2.2
langsmith==0.1.93
Expand Down Expand Up @@ -93,7 +93,7 @@ six==1.16.0
smmap==5.0.1
sniffio==1.3.1
SQLAlchemy==2.0.31
streamlit==1.28.0
streamlit==1.36.0
streamlit-oauth==0.1.5
sympy==1.13.1
tenacity==8.5.0
Expand Down
190 changes: 75 additions & 115 deletions src/app.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
import streamlit as st
from deta import Deta
import sys
import os
import json
from backend import (
configure_page_styles,
create_oauth2_component,
display_github_badge,
handle_google_login_if_needed,
hide_main_menu_and_footer,
)
from frontend import (
Expand All @@ -19,129 +17,91 @@
handle_new_chat,
)
from model import create_huggingface_hub


sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

from src.auth import *
from src.constant import *

def format_chat_history(messages):
"""Format the chat history as a structured JSON string."""
history = []
for msg in messages[1:]:
content = msg['content']
if '```sql' in content:
content = content.replace('```sql\n', '').replace('\n```', '').strip()

history.append({
"role": msg['role'],
"query" if msg['role'] == 'user' else "response": content
})

formatted_history = json.dumps(history, indent=2)
print("Formatted history:", formatted_history)
return formatted_history

def extract_sql_code(response):
"""Extract clean SQL code from the response."""
sql_code_start = response.find("```sql")
if sql_code_start != -1:
sql_code_end = response.find("```", sql_code_start + 5)
if sql_code_end != -1:
sql_code = response[sql_code_start + 6:sql_code_end].strip()
return f"```sql\n{sql_code}\n```"
return response

def main():
"""Main function to configure and run the Querypls application."""
configure_page_styles("static/css/styles.css")
deta = Deta(DETA_PROJECT_KEY)

if "model" not in st.session_state:
llm = create_huggingface_hub()
st.session_state["model"] = llm
db = deta.Base("users")
oauth2 = create_oauth2_component()

if "code" not in st.session_state or not st.session_state.code:
st.session_state.code = False

if "code" not in st.session_state:
st.session_state.code = False


if "messages" not in st.session_state:
create_message()

hide_main_menu_and_footer()
if st.session_state.code == False:
col1, col2, col3 = st.columns(3)
with col1:
pass
with col2:
with st.container():

display_github_badge()
display_logo_and_heading()

st.markdown("`Made with 🤍`")
if "token" not in st.session_state:
result = oauth2.authorize_button(
"Connect with Google",
REDIRECT_URI,
SCOPE,
icon="data:image/svg+xml;charset=utf-8,%3Csvg \
xmlns='http://www.w3.org/2000/svg' \
xmlns:xlink='http://www.w3.org/1999/xlink' \
viewBox='0 0 48 48'%3E%3Cdefs%3E%3Cpath id='a' \
d='M44.5 20H24v8.5h11.8C34.7 33.9 30.1 37 24 37c-7.2 \
0-13-5.8-13-13s5.8-13 13-13c3.1 0 5.9 1.1 8.1 \
2.9l6.4-6.4C34.6 4.1 29.6 2 24 2 11.8 2 2 11.8 2 \
24s9.8 22 22 22c11 0 21-8 21-22 \
0-1.3-.2-2.7-.5-4z'/%3E%3C/defs%3E%3CclipPath \
id='b'%3E%3Cuse xlink:href='%23a' \
overflow='visible'/%3E%3C/clipPath%3E%3Cpath \
clip-path='url(%23b)' fill='%23FBBC05' \
d='M0 37V11l17 13z'/%3E%3Cpath clip-path='url(%23b)' \
fill='%23EA4335' d='M0 11l17 13 7-6.1L48 \
14V0H0z'/%3E%3Cpath clip-path='url(%23b)' \
fill='%2334A853' d='M0 37l30-23 7.9 1L48 \
0v48H0z'/%3E%3Cpath clip-path='url(%23b)' \
fill='%234285F4' d='M48 48L17 24l-4-3 \
35-10z'/%3E%3C/svg%3E",
use_container_width=True,
)
handle_google_login_if_needed(result)
if st.session_state.code:
st.rerun()
with col3:
pass
else:
with st.sidebar:
display_github_badge()
display_logo_and_heading()
st.markdown("`Made with 🤍`")
if st.session_state.code:
handle_new_chat(db)
if st.session_state.code:
display_previous_chats(db)

if "messages" not in st.session_state:
create_message()
display_welcome_message()

for message in st.session_state.messages:
with st.chat_message(message["role"]):
st.markdown(message["content"], unsafe_allow_html=True)

if prompt := st.chat_input(disabled=(st.session_state.code is False)):
st.session_state.messages.append(
{"role": "user", "content": prompt}
)
with st.chat_message("user"):
st.write(prompt)

prompt_template = PromptTemplate(
template=TEMPLATE, input_variables=["question"]
)

if "model" in st.session_state:
llm_chain = (
prompt_template
| st.session_state.model
| StrOutputParser()
)
if st.session_state.messages[-1]["role"] != "assistant":
with st.chat_message("assistant"):
with st.spinner("Generating..."):
response = llm_chain.invoke(prompt)
import re

code_block_match = re.search(
r"```sql(.*?)```", response, re.DOTALL
)
if code_block_match:
code_block = code_block_match.group(1)
st.markdown(
f"```sql\n{code_block}\n```",
unsafe_allow_html=True,
)
message = {
"role": "assistant",
"content": f"```sql\n{code_block}\n```",
}
st.session_state.messages.append(message)


with st.sidebar:
display_github_badge()
display_logo_and_heading()
st.markdown("`Made with 🤍`")
handle_new_chat()

display_welcome_message()
for message in st.session_state.messages:
with st.chat_message(message["role"]):
st.markdown(message["content"])

if prompt := st.chat_input():
st.session_state.messages.append({"role": "user", "content": prompt})
with st.chat_message("user"):
st.markdown(prompt)

conversation_history = format_chat_history(st.session_state.messages)
prompt_template = PromptTemplate(
template=TEMPLATE,
input_variables=["input", "conversation_history"]
)

if "model" in st.session_state:
llm_chain = prompt_template | st.session_state.model | StrOutputParser()

with st.chat_message("assistant"):
with st.spinner("Generating..."):
response = llm_chain.invoke({
"input": prompt,
"conversation_history": conversation_history
})

# Clean and format the response
formatted_response = extract_sql_code(response)
st.markdown(formatted_response)

# Add to chat history
st.session_state.messages.append({
"role": "assistant",
"content": formatted_response
})

if __name__ == "__main__":
main()
main()
74 changes: 38 additions & 36 deletions src/frontend.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import streamlit as st
from src.database import database, get_previous_chats


def display_logo_and_heading():
"""Displays the Querypls logo."""
Expand All @@ -14,77 +12,81 @@ def display_welcome_message():
st.markdown(f"#### Welcome to \n ## 🗃️💬Querypls - Prompt to SQL")


def handle_new_chat(db, max_chat_histories=5):
def handle_new_chat(max_chat_histories=5):
"""Handles the initiation of a new chat session.

Displays the remaining chat history count and provides a button to start a new chat.

Args:
db: Deta Base instance.
max_chat_histories (int, optional): Maximum number of chat histories to retain.

Returns:
None
"""
remaining_chats = max_chat_histories - len(
get_previous_chats(db, st.session_state.user_email)
)
remaining_chats = max_chat_histories - len(st.session_state.get("previous_chats", []))
st.markdown(
f" #### Remaining Chat Histories: \
`{remaining_chats}/{max_chat_histories}`"
f" #### Remaining Chat Histories: `{remaining_chats}/{max_chat_histories}`"
)
st.markdown(
"You can create up to 5 chat histories. Each history \
can contain unlimited messages."
"You can create up to 5 chat histories. Each history can contain unlimited messages."
)

if st.button("➕ New chat"):
database(db, previous_key=st.session_state.key)
save_chat_history() # Save current chat before creating a new one
create_message()


def display_previous_chats(db):
"""Displays previous chat records.
def display_previous_chats():
"""Displays previous chat records stored in session state.

Retrieves and displays a list of previous chat records for the user.
Allows the user to select a chat to view.

Args:
db: Deta Base instance.

Returns:
None
"""
previous_chats = get_previous_chats(db, st.session_state.user_email)
reversed_chats = reversed(previous_chats)
if "previous_chats" in st.session_state:
reversed_chats = reversed(st.session_state["previous_chats"])

for chat in reversed_chats:
if st.button(chat["title"], key=chat["key"]):
update_session_state(db, chat)
for chat in reversed_chats:
if st.button(chat["title"], key=chat["key"]):
update_session_state(chat)


def create_message():
"""Creates a default assistant message and initializes a session key."""

st.session_state["messages"] = [
{"role": "assistant", "content": "How may I help you?"}
]
st.session_state["key"] = "key"
return


def update_session_state(db, chat):
def update_session_state(chat):
"""Updates the session state with selected chat information.

Args:
db: Deta Base instance.
chat (dict): Selected chat information.

Returns:
None
"""
previous_chat = st.session_state["messages"]
previous_key = st.session_state["key"]
st.session_state["messages"] = chat["chat"]
st.session_state["key"] = chat["key"]
database(db, previous_key, previous_chat)


def save_chat_history():
"""Saves the current chat to session state if it contains messages."""
if "messages" in st.session_state and len(st.session_state["messages"]) > 1:
# Initialize previous chats list if it doesn't exist
if "previous_chats" not in st.session_state:
st.session_state["previous_chats"] = []

# Create a chat summary to store in session
title = st.session_state["messages"][1]["content"]
chat_summary = {
"title": title[:25] + "....." if len(title) > 25 else title,
"chat": st.session_state["messages"],
"key": f"chat_{len(st.session_state['previous_chats']) + 1}"
}

st.session_state["previous_chats"].append(chat_summary)

# Limit chat histories to a maximum number
if len(st.session_state["previous_chats"]) > 5:
st.session_state["previous_chats"].pop(0) # Remove oldest chat
st.warning(
f"The oldest chat history has been removed as you reached the limit of 5 chat histories."
)
2 changes: 1 addition & 1 deletion src/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,5 @@ def create_huggingface_hub():
return HuggingFaceHub(
huggingfacehub_api_token=HUGGINGFACE_API_TOKEN,
repo_id=REPO_ID,
model_kwargs={"temperature": 0.2, "max_new_tokens": 180},
model_kwargs={"temperature": 0.7, "max_new_tokens": 180},
)
Loading