Skip to content

Commit

Permalink
🎉 Add SQLZilla class
Browse files Browse the repository at this point in the history
  • Loading branch information
henryhamon committed Jul 29, 2024
1 parent a0b165b commit e555050
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 19 deletions.
30 changes: 19 additions & 11 deletions python/sqlzilla/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pandas as pd
from dotenv import load_dotenv
import os
from sqlzilla import SQLZilla

# Load environment variables from .env file
load_dotenv()
Expand All @@ -27,13 +28,15 @@
if 'query_result' not in st.session_state:
st.session_state.query_result = None

def db_connection():
def db_connection_str():
user = st.session_state.user
pwd = st.session_state.pwd
host = st.session_state.hostname
prt = st.session_state.port
ns = st.session_state.namespace
iris_conn_str = f"iris://{user}:{pwd}@{host}:{prt}/{ns}"
return f"iris://{user}:{pwd}@{host}:{prt}/{ns}"

def db_connection(iris_conn_str):
engine = create_engine(iris_conn_str)
return engine.connect().connection

Expand All @@ -48,13 +51,15 @@ def run_query():
except Exception as e:
st.error(f"Error running query: {str(e)}")

def assistant_interaction(sqlzilla, prompt):
response = sqlzilla.prompt(prompt)
st.session_state.chat_history.append({"role": "user", "content": prompt})
st.session_state.chat_history.append({"role": "assistant", "content": response})

# Function to simulate assistant interaction
def assistant_interaction(prompt):
# This is a placeholder. In a real scenario, you'd call an AI service here.
response = f"Assistant: I've analyzed your prompt '{prompt}'. Here's a suggested SQL query:\n\nSELECT * FROM Table WHERE condition = 'value';"
st.session_state.chat_history.append(("User", prompt))
st.session_state.chat_history.append(("Assistant", response))
# Check if the response contains SQL code and update the editor
if "SELECT" in response.upper():
st.session_state.query_result = response

return response

left_co, cent_co, last_co = st.columns(3)
Expand Down Expand Up @@ -89,7 +94,10 @@ def assistant_interaction(prompt):
# Initial prompts for namespace and database schema
database_schema = st.text_input('Enter Database Schema')

if st.session_state.namespace and database_schema:
if st.session_state.namespace and database_schema and st.session_state.openai_api_key:
sqlzilla = SQLZilla(db_connection_str(), st.session_state.openai_api_key)
context = sqlzilla.schema_context_management(database_schema)

# Layout for the page
col1, col2 = st.columns(2)

Expand Down Expand Up @@ -122,11 +130,11 @@ def assistant_interaction(prompt):
# Add user message to chat history
st.session_state.chat_history.append({"role": "user", "content": prompt})

response = assistant_interaction(prompt)
response = assistant_interaction(sqlzilla, prompt)
# Display assistant response in chat message container
with st.chat_message("assistant"):
st.markdown(response)
# Add assistant response to chat history
st.session_state.chat_history.append({"role": "assistant", "content": response})
else:
st.warning('Please select a namespace and enter a database schema to proceed.')
st.warning('Please select a database schema to proceed.')
17 changes: 9 additions & 8 deletions python/sqlzilla/sqlzilla.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@
from langchain_iris import IRISVector

class SQLZilla:
def __init__(self, engine, cnx):
self.engine = engine
self.cnx = cnx
def __init__(self, connection_string, openai_api_key):
self.openai_api_key = openai_api_key
self.iris_conn_str = connection_string
self.engine = create_engine(connection_string)
self.cnx = self.engine.connect().connection
self.context = {}
self.context["top_k"] = 3
self.examples = [
Expand Down Expand Up @@ -83,7 +85,6 @@ def __init__(self, engine, cnx):
},
]


def get_table_definitions_array(self, schema, table=None):
cursor = self.cnx.cursor()

Expand Down Expand Up @@ -197,16 +198,16 @@ def schema_context_management(self, schema):
new_tables_docs, tables_docs_ids = self.filter_not_in_collection(
"sql_tables",
self.tables_docs,
self.get_ids_from_string_array([x.page_content for x in tables_docs])
self.get_ids_from_string_array([x.page_content for x in self.tables_docs])
)
self.tables_docs_ids = tables_docs_ids


def prompt(self, input):
db = IRISVector.from_documents(
embedding = OpenAIEmbeddings(),
embedding = OpenAIEmbeddings(openai_api_key=self.openai_api_key),
documents = self.tables_docs,
connection_string=iris_conn_str,
connection_string= self.iris_conn_str,
collection_name="sql_tables",
ids=self.tables_docs_ids
)
Expand Down Expand Up @@ -252,7 +253,7 @@ def prompt(self, input):
"input": input
})

model = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
model = ChatOpenAI(model="gpt-3.5-turbo", temperature=0, api_key=self.openai_api_key)
output_parser = StrOutputParser()
chain_model = prompt | model | output_parser
response = chain_model.invoke({
Expand Down

0 comments on commit e555050

Please sign in to comment.