diff --git a/llama_hub/tools/library.json b/llama_hub/tools/library.json
index ae6e2ec54a..80b70bd901 100644
--- a/llama_hub/tools/library.json
+++ b/llama_hub/tools/library.json
@@ -114,6 +114,10 @@
     "id": "tools/vector_db",
     "author": "jerryjliu"
   },
+  "WaiiToolSpec": {
+    "id": "tools/waii",
+    "author": "wangdatan"
+  },
   "WikipediaToolSpec": {
     "id": "tools/wikipedia",
     "author": "ajhofmann"
diff --git a/llama_hub/tools/notebooks/waii.ipynb b/llama_hub/tools/notebooks/waii.ipynb
new file mode 100644
index 0000000000..60e4006755
--- /dev/null
+++ b/llama_hub/tools/notebooks/waii.ipynb
@@ -0,0 +1,517 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "7f0195c2-4a20-488e-8782-ca5a83488d0d",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from llama_hub.tools.waii import WaiiToolSpec\n",
+    "\n",
+    "waii_tool = WaiiToolSpec(\n",
+    "    url=\"https://tweakit.waii.ai/api/\",\n",
+    "    # API key of Waii (not an OpenAI API key)\n",
+    "    api_key=\"3a...\",\n",
+    "    # Which database you want to use; you need to add the db connection to Waii first\n",
+    "    database_key=\"snowflake://...\"\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "0a79a9fa-e5ff-4242-99a2-08cc85e158a9",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "\"The table 'TABLES' contains the most columns. The top 5 tables with the number of columns are 'TABLES' with 25 columns, followed by 'SHOW' with 5 columns.\""
+      ]
+     },
+     "execution_count": 2,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "from llama_index import VectorStoreIndex\n",
+    "\n",
+    "# Use as a data loader: load data into an index and query it\n",
+    "documents = waii_tool.load_data('Get all tables with their number of columns')\n",
+    "index = VectorStoreIndex.from_documents(documents).as_query_engine()\n",
+    "\n",
+    "index.query('Which table contains most columns, tell me top 5 tables with number of columns?').response"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "b259d9cd-bbb8-4fff-a4ce-80fb0f3a1a10",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Use as a tool: initialize the agent\n",
+    "from llama_index.agent import OpenAIAgent\n",
+    "from llama_index.llms import OpenAI\n",
+    "\n",
+    "agent = OpenAIAgent.from_tools(waii_tool.to_tool_list(), llm=OpenAI(model='gpt-4-1106-preview'), verbose=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "094b7878-59d6-4f12-b357-4f0d254953da",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "STARTING TURN 1\n",
+      "---------------\n",
+      "\n",
+      "=== Calling Function ===\n",
+      "Calling function: get_answer with args: {\"ask\":\"What are the top 3 countries with the highest number of car factories?\"}\n",
+      "Got output: Japan, Germany, and the USA.\n",
+      "========================\n",
+      "\n",
+      "STARTING TURN 2\n",
+      "---------------\n",
+      "\n",
+      "The top 3 countries with the highest number of car factories are Japan, Germany, and the USA.\n"
+     ]
+    }
+   ],
+   "source": [
+    "# Ask simple questions\n",
+    "print(agent.chat(\"Give me top 3 countries with the most number of car factory\"))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "id": "233deb3d-547b-49a2-89a4-28fa1abea9dc",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "STARTING TURN 1\n",
+      "---------------\n",
+      "\n",
+      "=== Calling Function ===\n",
+      "Calling function: get_answer with args: {\"ask\": \"What are the car factories in 
Japan?\"}\n", + "Got output: Nissan Motors, Honda, Mazda, Subaru, and Toyota are car factories in Japan.\n", + "========================\n", + "\n", + "=== Calling Function ===\n", + "Calling function: get_answer with args: {\"ask\": \"What are the car factories in Germany?\"}\n", + "Got output: Volkswagen, BMW, Daimler Benz, and Opel are car factories in Germany.\n", + "========================\n", + "\n", + "=== Calling Function ===\n", + "Calling function: get_answer with args: {\"ask\": \"What are the car factories in the USA?\"}\n", + "Got output: amc, gm, ford, and chrysler are car factories in the USA.\n", + "========================\n", + "\n", + "STARTING TURN 2\n", + "---------------\n", + "\n", + "Here are some of the car factories in the top 3 countries with the most number of car factories:\n", + "\n", + "- In Japan: Nissan Motors, Honda, Mazda, Subaru, and Toyota.\n", + "- In Germany: Volkswagen, BMW, Daimler Benz, and Opel.\n", + "- In the USA: AMC, GM, Ford, and Chrysler.\n" + ] + } + ], + "source": [ + "print(agent.chat(\"What are the car factories of these countries\"))" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "90c2ba4d-6ac4-4cbb-93b0-e03a8d015042", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "STARTING TURN 1\n", + "---------------\n", + "\n", + "=== Calling Function ===\n", + "Calling function: get_answer with args: {\"ask\":\"What are the top 3 longest running queries and their durations?\"}\n", + "Got output: The top 3 longest running queries and their durations are as follows:\n", + "1. Query ID: 01b08971-0001-e7a0-0022-ba8700c28122, Duration: 365143 milliseconds\n", + "2. Query ID: 01b08aca-0001-e830-0022-ba8700c2a09e, Duration: 190413 milliseconds\n", + "3. Query ID: 01b08ac4-0001-e7ed-0022-ba8700c2951a, Duration: 170837 milliseconds\n", + "========================\n", + "\n", + "STARTING TURN 2\n", + "---------------\n", + "\n", + "The top 3 longest running queries and their respective durations are:\n", + "\n", + "1. Query ID: 01b08971-0001-e7a0-0022-ba8700c28122, with a duration of 365,143 milliseconds.\n", + "2. Query ID: 01b08aca-0001-e830-0022-ba8700c2a09e, with a duration of 190,413 milliseconds.\n", + "3. 
Query ID: 01b08ac4-0001-e7ed-0022-ba8700c2951a, with a duration of 170,837 milliseconds.\n" + ] + } + ], + "source": [ + "# Do performance analysis\n", + "print(agent.chat(\"Give me top 3 longest running queries, and their duration.\"))" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "0706d166-d94f-40fc-adae-15b1379f501b", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "STARTING TURN 1\n", + "---------------\n", + "\n", + "=== Calling Function ===\n", + "Calling function: performance_analyze with args: {\"query_uuid\":\"01b08aca-0001-e830-0022-ba8700c2a09e\"}\n", + "Got output: {\n", + " \"summary\": [\n", + " \"The 'aggregate' operator with functions sum(sum(sum(sum_internal(ss.ss_sales_price, count(*))))) and grouping keys i.i_category, i.i_class, i.i_brand, i.i_product_name, d.d_year, d.d_qoy, d.d_moy, s.s_store_id is the most time-consuming, with an overall percentage of 0.31 of execution time and a significant amount of bytes spilled to local storage (4774297600 bytes).\",\n", + " \"The 'windowfunction' operator with function rank() over (partition by i.i_category order by sum(sum(sum(sum_internal(ss.ss_sales_price, count(*))))) desc nulls first) also contributes to the high execution time with an overall percentage of 0.2 and a large amount of bytes spilled to local storage (15166328832 bytes).\",\n", + " \"The 'tablescan' operator on the store_sales table is scanning a large amount of data (3514269696 bytes) without utilizing cache, which indicates a potential area for optimization.\"\n", + " ],\n", + " \"recommendations\": [\n", + " \"Consider pushing down relevant filters to the 'tablescan' operator on the store_sales table to reduce the bytes scanned. This could be done by analyzing the join conditions and where clause to identify any filters that can be applied earlier in the query execution.\",\n", + " \"Optimize the 'aggregate' operator by reviewing the necessity of all grouping keys. If some keys are not essential for the final result, removing them could reduce the complexity of aggregation and the amount of data spilled to local storage.\",\n", + " \"For the 'windowfunction' operator, ensure that the input size is minimized by optimizing preceding operations, such as 'aggregate', to reduce the data volume before ranking. 
This could involve pushing down filters or optimizing the grouping strategy.\"\n", + " ],\n", + " \"query_text\": \"WITH sales_data AS ( SELECT i.i_category, i.i_class, i.i_brand, i.i_product_name, d.d_year, d.d_qoy, d.d_moy, s.s_store_id, SUM(ss.ss_sales_price) AS total_sales FROM tweakit_perf_db.sample_tpcds.store_sales AS ss INNER JOIN tweakit_perf_db.sample_tpcds.item AS i ON ss.ss_item_sk = i.i_item_sk INNER JOIN tweakit_perf_db.sample_tpcds.date_dim AS d ON ss.ss_sold_date_sk = d.d_date_sk INNER JOIN tweakit_perf_db.sample_tpcds.store AS s ON ss.ss_store_sk = s.s_store_sk WHERE d.d_year = 2000 GROUP BY i.i_category, i.i_class, i.i_brand, i.i_product_name, d.d_year, d.d_qoy, d.d_moy, s.s_store_id ), ranked_data AS ( SELECT *, RANK() OVER (PARTITION BY i_category ORDER BY total_sales DESC) AS rank FROM sales_data ) SELECT * FROM ranked_data WHERE rank <= 100 ORDER BY i_category, i_class, i_brand, i_product_name, d_year, d_qoy, d_moy, s_store_id, total_sales, rank limit 100\",\n", + " \"execution_time_ms\": 190094,\n", + " \"compilation_time_ms\": 341\n", + "}\n", + "========================\n", + "\n", + "STARTING TURN 2\n", + "---------------\n", + "\n", + "The analysis of the 2nd-longest running query reveals the following insights:\n", + "\n", + "- The most time-consuming operator is the 'aggregate' operator, which uses functions like `sum()` and grouping keys such as `i.i_category`, `i.i_class`, `i.i_brand`, `i.i_product_name`, `d.d_year`, `d.d_qoy`, `d.d_moy`, and `s.s_store_id`. It accounts for 0.31 of the execution time and has spilled a significant amount of bytes (4,774,297,600 bytes) to local storage.\n", + "- The 'windowfunction' operator with the `rank()` function also contributes to the high execution time, with an overall percentage of 0.2 and a large amount of bytes spilled to local storage (15,166,328,832 bytes).\n", + "- The 'tablescan' operator on the `store_sales` table is scanning a large amount of data (3,514,269,696 bytes) without utilizing cache, indicating a potential area for optimization.\n", + "\n", + "Based on these findings, the recommendations for optimization are:\n", + "\n", + "1. Push down relevant filters to the 'tablescan' operator on the `store_sales` table to reduce the bytes scanned. This could be achieved by analyzing the join conditions and where clause to identify any filters that can be applied earlier in the query execution.\n", + "2. Optimize the 'aggregate' operator by reviewing the necessity of all grouping keys. Removing non-essential keys could reduce the complexity of aggregation and the amount of data spilled to local storage.\n", + "3. Minimize the input size for the 'windowfunction' operator by optimizing preceding operations, such as 'aggregate', to reduce the data volume before ranking. 
This could involve pushing down filters or optimizing the grouping strategy.\n", + "\n", + "The query text is as follows:\n", + "```sql\n", + "WITH sales_data AS (\n", + " SELECT i.i_category, i.i_class, i.i_brand, i.i_product_name, d.d_year, d.d_qoy, d.d_moy, s.s_store_id, SUM(ss.ss_sales_price) AS total_sales\n", + " FROM tweakit_perf_db.sample_tpcds.store_sales AS ss\n", + " INNER JOIN tweakit_perf_db.sample_tpcds.item AS i ON ss.ss_item_sk = i.i_item_sk\n", + " INNER JOIN tweakit_perf_db.sample_tpcds.date_dim AS d ON ss.ss_sold_date_sk = d.d_date_sk\n", + " INNER JOIN tweakit_perf_db.sample_tpcds.store AS s ON ss.ss_store_sk = s.s_store_sk\n", + " WHERE d.d_year = 2000\n", + " GROUP BY i.i_category, i.i_class, i.i_brand, i.i_product_name, d.d_year, d.d_qoy, d.d_moy, s.s_store_id\n", + "), ranked_data AS (\n", + " SELECT *, RANK() OVER (PARTITION BY i_category ORDER BY total_sales DESC) AS rank\n", + " FROM sales_data\n", + ")\n", + "SELECT *\n", + "FROM ranked_data\n", + "WHERE rank <= 100\n", + "ORDER BY i_category, i_class, i_brand, i_product_name, d_year, d_qoy, d_moy, s_store_id, total_sales, rank\n", + "LIMIT 100\n", + "```\n", + "\n", + "The execution time of the query was 190,094 milliseconds, and the compilation time was 341 milliseconds.\n" + ] + } + ], + "source": [ + "print(agent.chat(\"analyze the 2nd-longest running query\"))" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "47530eba-24be-42d9-b1a0-fa1af28934f7", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "STARTING TURN 1\n", + "---------------\n", + "\n", + "=== Calling Function ===\n", + "Calling function: diff_query with args: {\"previous_query\":\"SELECT\\n employee_id,\\n department,\\n salary,\\n AVG(salary) OVER (PARTITION BY department) AS department_avg_salary,\\n salary - AVG(salary) OVER (PARTITION BY department) AS diff_from_avg\\nFROM\\n employees;\",\"current_query\":\"SELECT\\n employee_id,\\n department,\\n salary,\\n MAX(salary) OVER (PARTITION BY department) AS department_max_salary,\\n salary - AVG(salary) OVER (PARTITION BY department) AS diff_from_avg\\nFROM\\n employees;\\nLIMIT 100;\"}\n", + "Got output: The query retrieves the employee ID, department, salary, maximum salary within the department, and the difference between the employee's salary and the average salary within the department. The result set is limited to 100 rows.\n", + "========================\n", + "\n", + "STARTING TURN 2\n", + "---------------\n", + "\n", + "The difference between the two queries is as follows:\n", + "\n", + "1. In the first query, the fourth column calculates the average salary for each department (`department_avg_salary`) using the `AVG()` window function partitioned by the department.\n", + "\n", + "2. 
In the second query, the fourth column calculates the maximum salary for each department (`department_max_salary`) using the `MAX()` window function partitioned by the department.\n", + "\n", + "Additionally, the second query has a `LIMIT 100` clause at the end, which restricts the result set to the first 100 rows.\n", + "\n", + "Here's a summary of the differences:\n", + "- The first query calculates the average salary per department, while the second query calculates the maximum salary per department.\n", + "- The second query limits the result set to 100 rows, whereas the first query does not impose any limit on the number of rows returned.\n" + ] + } + ], + "source": [ + "# Diff two queries\n", + "previous_query = \"\"\"\n", + "SELECT\n", + " employee_id,\n", + " department,\n", + " salary,\n", + " AVG(salary) OVER (PARTITION BY department) AS department_avg_salary,\n", + " salary - AVG(salary) OVER (PARTITION BY department) AS diff_from_avg\n", + "FROM\n", + " employees;\n", + "\"\"\"\n", + "current_query = \"\"\"\n", + "SELECT\n", + " employee_id,\n", + " department,\n", + " salary,\n", + " MAX(salary) OVER (PARTITION BY department) AS department_max_salary,\n", + " salary - AVG(salary) OVER (PARTITION BY department) AS diff_from_avg\n", + "FROM\n", + " employees;\n", + "LIMIT 100;\n", + "\"\"\"\n", + "print(agent.chat(f\"tell me difference between {previous_query} and {current_query}\"))" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "7a222df4-d00e-4fe8-be8e-9efdb43f1462", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "STARTING TURN 1\n", + "---------------\n", + "\n", + "=== Calling Function ===\n", + "Calling function: describe_dataset with args: {\"ask\":\"Can you summarize the dataset for me?\"}\n", + "Got output: The dataset consists of multiple schemas, each containing tables related to different domains. The schemas include \"FLIGHT\" which contains tables related to airlines, airports, and flights, \"STUDENT_TRANSCRIPTS_TRACKING\" which contains tables related to student information, courses, degree programs, departments, semesters, and transcripts, \"WORLD\" which contains tables related to cities, countries, and languages, \"INFORMATION_SCHEMA\" which provides information about the database objects and metadata in the WAII database, \"EMPLOYEE_HIRE_EVALUATION\" which contains tables related to employee information, evaluations, hiring evaluations, and shop details, and \"PETS\" which contains tables related to pet ownership and student information.\n", + "\n", + "Each schema has a summary describing its purpose and the type of information it contains. Additionally, there are common questions associated with each schema, which provide insights into the types of queries that can be answered using the dataset. The common tables within each schema are also described, providing information about the specific tables and their purpose within the schema.\n", + "\n", + "Overall, the dataset covers a wide range of domains including airlines, airports, flights, student records, courses, degrees, departments, cities, countries, languages, database metadata, employee information, evaluations, hiring evaluations, shop details, pet ownership, and student information. 
It can be used to analyze and generate reports on various aspects of these domains.\n", + "========================\n", + "\n", + "STARTING TURN 2\n", + "---------------\n", + "\n", + "The dataset is a comprehensive collection of information spread across multiple schemas, each dedicated to a specific domain:\n", + "\n", + "- **FLIGHT**: This schema includes tables related to airlines, airports, and flights, providing data that could be used for analyzing air travel and operations.\n", + "- **STUDENT_TRANSCRIPTS_TRACKING**: This schema contains tables with student information, courses, degree programs, departments, semesters, and transcripts, which can be used for educational administration and tracking academic progress.\n", + "- **WORLD**: This schema features tables related to cities, countries, and languages, offering a dataset for geographical and linguistic analysis.\n", + "- **INFORMATION_SCHEMA**: This schema provides metadata about the database objects within the WAII database, useful for database management and schema exploration.\n", + "- **EMPLOYEE_HIRE_EVALUATION**: This schema includes tables related to employee information, evaluations, hiring evaluations, and shop details, which can be used for human resources management and performance analysis.\n", + "- **PETS**: This schema contains tables related to pet ownership and student information, which could be used for studies on pet ownership demographics and related student activities.\n", + "\n", + "Each schema is accompanied by a summary that outlines its purpose and the nature of the data it holds. There are also common questions associated with each schema to guide users in understanding the types of queries that can be performed. The dataset is versatile and can be utilized to conduct analyses and generate reports across various fields such as air travel, education, geography, database management, human resources, and pet ownership studies.\n" + ] + } + ], + "source": [ + "# Describe dataset\n", + "print(agent.chat(\"Summarize the dataset\"))" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "d6bdf837-241a-4637-a07e-a73fafd52a07", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "STARTING TURN 1\n", + "---------------\n", + "\n", + "=== Calling Function ===\n", + "Calling function: transcode with args: {\"instruction\":\"Translate this PySpark query to Snowflake SQL.\",\"source_dialect\":\"pyspark\",\"source_query\":\"from pyspark.sql import SparkSession\\nfrom pyspark.sql.functions import avg, lag, lead, round\\nfrom pyspark.sql.window import Window\\n\\nspark = SparkSession.builder.appName(\\\"yearly_car_analysis\\\").getOrCreate()\\n\\nyearly_avg_hp = cars_data.groupBy(\\\"year\\\").agg(avg(\\\"horsepower\\\").alias(\\\"avg_horsepower\\\"))\\n\\nwindowSpec = Window.orderBy(\\\"year\\\")\\n\\nyearly_comparisons = yearly_avg_hp.select(\\n \\\"year\\\",\\n \\\"avg_horsepower\\\",\\n lag(\\\"avg_horsepower\\\").over(windowSpec).alias(\\\"prev_year_hp\\\"),\\n lead(\\\"avg_horsepower\\\").over(windowSpec).alias(\\\"next_year_hp\\\")\\n)\\n\\nfinal_result = yearly_comparisons.select(\\n \\\"year\\\",\\n \\\"avg_horsepower\\\",\\n round(\\n (yearly_comparisons.avg_horsepower - yearly_comparisons.prev_year_hp) / \\n yearly_comparisons.prev_year_hp * 100, 2\\n ).alias(\\\"percentage_diff_prev_year\\\"),\\n round(\\n (yearly_comparisons.next_year_hp - yearly_comparisons.avg_horsepower) / \\n yearly_comparisons.avg_horsepower * 100, 2\\n 
).alias(\\\"percentage_diff_next_year\\\")\\n).orderBy(\\\"year\\\")\\n\\nfinal_result.show()\",\"target_dialect\":\"snowflake\"}\n", + "Got output: WITH yearly_avg_hp AS (\n", + " SELECT\n", + " year,\n", + " AVG(horsepower) AS avg_horsepower\n", + " FROM waii.car.cars_data\n", + " GROUP BY\n", + " year\n", + "),\n", + "\n", + "yearly_comparisons AS (\n", + " SELECT\n", + " year,\n", + " avg_horsepower,\n", + " LAG(avg_horsepower) OVER (ORDER BY year) AS prev_year_hp,\n", + " LEAD(avg_horsepower) OVER (ORDER BY year) AS next_year_hp\n", + " FROM yearly_avg_hp\n", + ")\n", + "\n", + "SELECT\n", + " year,\n", + " avg_horsepower,\n", + " ROUND((\n", + " (\n", + " avg_horsepower - prev_year_hp\n", + " ) / NULLIF(prev_year_hp, 0) * 100\n", + " ), 2) AS percentage_diff_prev_year,\n", + " ROUND((\n", + " (\n", + " next_year_hp - avg_horsepower\n", + " ) / NULLIF(avg_horsepower, 0) * 100\n", + " ), 2) AS percentage_diff_next_year\n", + "FROM yearly_comparisons\n", + "ORDER BY\n", + " year\n", + "\n", + "========================\n", + "\n", + "STARTING TURN 2\n", + "---------------\n", + "\n", + "The translated Snowflake SQL query from the provided PySpark code is as follows:\n", + "\n", + "```sql\n", + "WITH yearly_avg_hp AS (\n", + " SELECT\n", + " year,\n", + " AVG(horsepower) AS avg_horsepower\n", + " FROM waii.car.cars_data\n", + " GROUP BY\n", + " year\n", + "),\n", + "\n", + "yearly_comparisons AS (\n", + " SELECT\n", + " year,\n", + " avg_horsepower,\n", + " LAG(avg_horsepower) OVER (ORDER BY year) AS prev_year_hp,\n", + " LEAD(avg_horsepower) OVER (ORDER BY year) AS next_year_hp\n", + " FROM yearly_avg_hp\n", + ")\n", + "\n", + "SELECT\n", + " year,\n", + " avg_horsepower,\n", + " ROUND((\n", + " (\n", + " avg_horsepower - prev_year_hp\n", + " ) / NULLIF(prev_year_hp, 0) * 100\n", + " ), 2) AS percentage_diff_prev_year,\n", + " ROUND((\n", + " (\n", + " next_year_hp - avg_horsepower\n", + " ) / NULLIF(avg_horsepower, 0) * 100\n", + " ), 2) AS percentage_diff_next_year\n", + "FROM yearly_comparisons\n", + "ORDER BY\n", + " year\n", + "```\n", + "\n", + "This query calculates the average horsepower (`avg_horsepower`) for each year from the `cars_data` table, then uses window functions to get the previous and next year's average horsepower. 
It finally computes the percentage difference from the previous and next year's average horsepower and orders the result by year.\n"
+     ]
+    }
+   ],
+   "source": [
+    "q = \"\"\"\n",
+    "from pyspark.sql import SparkSession\n",
+    "from pyspark.sql.functions import avg, lag, lead, round\n",
+    "from pyspark.sql.window import Window\n",
+    "\n",
+    "spark = SparkSession.builder.appName(\"yearly_car_analysis\").getOrCreate()\n",
+    "\n",
+    "yearly_avg_hp = cars_data.groupBy(\"year\").agg(avg(\"horsepower\").alias(\"avg_horsepower\"))\n",
+    "\n",
+    "windowSpec = Window.orderBy(\"year\")\n",
+    "\n",
+    "yearly_comparisons = yearly_avg_hp.select(\n",
+    "    \"year\",\n",
+    "    \"avg_horsepower\",\n",
+    "    lag(\"avg_horsepower\").over(windowSpec).alias(\"prev_year_hp\"),\n",
+    "    lead(\"avg_horsepower\").over(windowSpec).alias(\"next_year_hp\")\n",
+    ")\n",
+    "\n",
+    "final_result = yearly_comparisons.select(\n",
+    "    \"year\",\n",
+    "    \"avg_horsepower\",\n",
+    "    round(\n",
+    "        (yearly_comparisons.avg_horsepower - yearly_comparisons.prev_year_hp) / \n",
+    "        yearly_comparisons.prev_year_hp * 100, 2\n",
+    "    ).alias(\"percentage_diff_prev_year\"),\n",
+    "    round(\n",
+    "        (yearly_comparisons.next_year_hp - yearly_comparisons.avg_horsepower) / \n",
+    "        yearly_comparisons.avg_horsepower * 100, 2\n",
+    "    ).alias(\"percentage_diff_next_year\")\n",
+    ").orderBy(\"year\")\n",
+    "\n",
+    "final_result.show()\n",
+    "\"\"\"\n",
+    "print(agent.chat(f\"translate this pyspark query {q}, to Snowflake\"))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "e94a9d08-1562-47c1-8983-77270bcca21e",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "myenv",
+   "language": "python",
+   "name": "myenv"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.13"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/llama_hub/tools/waii/README.md b/llama_hub/tools/waii/README.md
new file mode 100644
index 0000000000..3805478285
--- /dev/null
+++ b/llama_hub/tools/waii/README.md
@@ -0,0 +1,160 @@
+# Waii Tool
+
+This tool connects to database connections managed by Waii and lets you run SQL queries, analyze query performance, describe SQL queries, and more.
+
+## Usage
+
+First, you need a waii.ai account; you can request one [here](https://waii.ai/).
+
+Initialize the tool with your account credentials:
+
+```python
+from llama_hub.tools.waii import WaiiToolSpec
+
+waii_tool = WaiiToolSpec(
+    url="https://tweakit.waii.ai/api/",
+    # API key of Waii (not an OpenAI API key)
+    api_key="...",
+    # Connection key of the database connected to Waii, see https://github.com/waii-ai/waii-sdk-py#get-connections
+    database_key='...'
+)
+```
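+
+The tool spec also accepts an optional `verbose` flag (default `False`), which is passed through to the underlying Waii SDK calls. For example:
+
+```python
+# Same credentials as above, with verbose output from the SDK calls enabled
+waii_tool = WaiiToolSpec(
+    url="https://tweakit.waii.ai/api/",
+    api_key="...",
+    database_key='...',
+    verbose=True
+)
+```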
+
+## Tools
+
+The tools available are:
+
+- `get_answer`: Get the answer to a natural language question (generates a SQL query, runs it, and explains the result)
+- `describe_query`: Describe a SQL query
+- `performance_analyze`: Analyze the performance of a SQL query (by `query_id`)
+- `diff_query`: Compare two SQL queries
+- `describe_dataset`: Describe the dataset (tables, schemas, etc.)
+- `transcode`: Transcode a SQL query to another SQL dialect
+- `get_semantic_contexts`: Get the pre-defined semantic contexts
+- `generate_query_only`: Generate a SQL query without running it
+- `run_query`: Run a SQL query
+
+You can also load data directly by calling `load_data`.
+
+## Examples
+
+### Load data
+
+```python
+documents = waii_tool.load_data('Get all tables with their number of columns')
+index = VectorStoreIndex.from_documents(documents).as_query_engine()
+
+print(index.query('Which table contains most columns?'))
+```
+
+### Use as a Tool
+
+#### Initialize the agent:
+
+```python
+from llama_index.agent import OpenAIAgent
+from llama_index.llms import OpenAI
+
+agent = OpenAIAgent.from_tools(waii_tool.to_tool_list(), llm=OpenAI(model='gpt-4-1106-preview'))
+```
+
+#### Ask simple questions
+
+```python
+agent.chat("Give me top 3 countries with the most number of car factory")
+agent.chat("What are the car factories of these countries")
+```
+
+#### Analyze query performance
+
+```python
+agent.chat("Give me top 3 longest running queries, and their duration.")
+agent.chat("analyze the 2nd-longest running query")
+```
+
+#### Diff two queries
+
+```python
+previous_query = """
+SELECT
+    employee_id,
+    department,
+    salary,
+    AVG(salary) OVER (PARTITION BY department) AS department_avg_salary,
+    salary - AVG(salary) OVER (PARTITION BY department) AS diff_from_avg
+FROM
+    employees;
+"""
+current_query = """
+SELECT
+    employee_id,
+    department,
+    salary,
+    MAX(salary) OVER (PARTITION BY department) AS department_max_salary,
+    salary - AVG(salary) OVER (PARTITION BY department) AS diff_from_avg
+FROM
+    employees;
+LIMIT 100;
+"""
+agent.chat(f"tell me difference between {previous_query} and {current_query}")
+```
+
+#### Describe dataset
+
+```python
+agent.chat("Summarize the dataset")
+agent.chat("Give me questions which I can ask about this dataset")
+```
+
+#### Describe a query
+
+```python
+q = """
+SELECT
+    employee_id,
+    department,
+    salary,
+    AVG(salary) OVER (PARTITION BY department) AS department_avg_salary,
+    salary - AVG(salary) OVER (PARTITION BY department) AS diff_from_avg
+FROM
+    employees;
+"""
+agent.chat(f"what does this query do? {q}")
+```
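+
+#### Generate or run queries directly
+
+The lower-level tools can also be called as methods on the tool spec itself. A minimal sketch (the question is illustrative; the generated SQL depends on your connected database):
+
+```python
+# Generate a SQL query from natural language without running it
+sql = waii_tool.generate_query_only('Get all tables with their number of columns')
+
+# Run a SQL query and get a summarized answer
+print(waii_tool.run_query(sql))
+
+# List the pre-defined semantic contexts of the connection
+print(waii_tool.get_semantic_contexts())
+```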
+
+#### Migrate query to another dialect
+
+```python
+q = """
+from pyspark.sql import SparkSession
+from pyspark.sql.functions import avg, col
+from pyspark.sql.window import Window
+
+# Initialize Spark session
+spark = SparkSession.builder.appName("example").getOrCreate()
+
+# Assuming you have a DataFrame called 'employees'
+# If not, you need to read your data into a DataFrame first
+
+# Define window specification
+windowSpec = Window.partitionBy("department")
+
+# Perform the query
+result = (employees
+          .select(
+              col("employee_id"),
+              col("department"),
+              col("salary"),
+              avg("salary").over(windowSpec).alias("department_avg_salary"),
+              (col("salary") - avg("salary").over(windowSpec)).alias("diff_from_avg")
+          ))
+
+# Show the result
+result.show()
+"""
+agent.chat(f"translate this pyspark query {q}, to Snowflake")
+```
+
+### Use Waii API directly
+
+You can also use the Waii API directly; see [here](https://github.com/waii-ai/waii-sdk-py).
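+
+A minimal sketch of direct SDK usage, based on the same calls this tool spec makes internally (assuming the URL, API key, and connection key from the examples above):
+
+```python
+from waii_sdk_py import WAII
+from waii_sdk_py.query import QueryGenerationRequest, RunQueryRequest
+
+# Initialize the client and activate a database connection
+WAII.initialize(url="https://tweakit.waii.ai/api/", api_key="...")
+WAII.Database.activate_connection(key="...")
+
+# Generate a query from natural language, then run it
+query = WAII.Query.generate(QueryGenerationRequest(ask="Get all tables with their number of columns")).query
+rows = WAII.Query.run(RunQueryRequest(query=query)).rows
+print(rows)
+```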
\ No newline at end of file
diff --git a/llama_hub/tools/waii/__init__.py b/llama_hub/tools/waii/__init__.py
new file mode 100644
index 0000000000..af4fc1ae98
--- /dev/null
+++ b/llama_hub/tools/waii/__init__.py
@@ -0,0 +1,6 @@
+# __init__.py
+from llama_hub.tools.waii.base import (
+    WaiiToolSpec,
+)
+
+__all__ = ["WaiiToolSpec"]
diff --git a/llama_hub/tools/waii/base.py b/llama_hub/tools/waii/base.py
new file mode 100644
index 0000000000..082ad0fb42
--- /dev/null
+++ b/llama_hub/tools/waii/base.py
@@ -0,0 +1,274 @@
+"""Waii Tool."""
+import json
+from typing import List, Optional
+
+from llama_index.readers.base import BaseReader
+from llama_index.readers.schema.base import Document
+from llama_index.response_synthesizers import TreeSummarize
+from llama_index.tools.tool_spec.base import BaseToolSpec
+
+
+class WaiiToolSpec(BaseToolSpec, BaseReader):
+    spec_functions = [
+        "get_answer",
+        "describe_query",
+        "performance_analyze",
+        "diff_query",
+        "describe_dataset",
+        "transcode",
+        "get_semantic_contexts",
+        "generate_query_only",
+        "run_query",
+    ]
+
+    def __init__(
+        self,
+        url: Optional[str] = None,
+        api_key: Optional[str] = None,
+        database_key: Optional[str] = None,
+        verbose: Optional[bool] = False,
+    ) -> None:
+        from waii_sdk_py import WAII
+
+        WAII.initialize(url=url, api_key=api_key)
+        WAII.Database.activate_connection(key=database_key)
+        self.verbose = verbose
+
+    def load_data(self, ask: str) -> List[Document]:
+        """Query using natural language and load data from the database, returning a list of Documents.
+
+        Args:
+            ask: a natural language question.
+
+        Returns:
+            List[Document]: A list of Document objects.
+        """
+        from waii_sdk_py import WAII
+        from waii_sdk_py.query import QueryGenerationRequest, RunQueryRequest
+
+        query = WAII.Query.generate(
+            QueryGenerationRequest(ask=ask), verbose=self.verbose
+        ).query
+        documents = WAII.Query.run(
+            RunQueryRequest(query=query), verbose=self.verbose
+        ).rows
+        return [Document(text=str(doc)) for doc in documents]
+
+    def _get_summarization(self, original_ask: str, documents):
+        # Concatenate document texts, capping the summarization input at
+        # 8192 characters and noting how many results were dropped.
+        texts = []
+
+        n_chars = 0
+        for i in range(len(documents)):
+            t = str(documents[i].text)
+            if len(t) + n_chars > 8192:
+                texts.append(f"... {len(documents) - i} more results")
+                break
+            texts.append(t)
+            n_chars += len(t)
+
+        summarizer = TreeSummarize(verbose=self.verbose)
+        response = summarizer.get_response(original_ask, texts)
+        return response
+
+    def get_answer(self, ask: str):
+        """
+        Generate a SQL query, run it against the database, and return a summarization of the answer.
+
+        Args:
+            ask: a natural language question.
+
+        Returns:
+            str: A string containing the summarization of the answer.
+        """
+        docs = self.load_data(ask)
+        return self._get_summarization(ask, docs)
+
+    def generate_query_only(self, ask: str):
+        """
+        Generate a SQL query without running it, returning the query. If you need the answer, use get_answer instead.
+
+        Args:
+            ask: a natural language question.
+
+        Returns:
+            str: A string containing the query.
+        """
+        from waii_sdk_py import WAII
+        from waii_sdk_py.query import QueryGenerationRequest
+
+        query = WAII.Query.generate(
+            QueryGenerationRequest(ask=ask), verbose=self.verbose
+        ).query
+        return query
+
+    def run_query(self, sql: str):
+        """
+        Run a SQL query against the database, returning a summarization of the answer. Choose this function when you already have a SQL query to run.
+
+        Args:
+            sql: a SQL query.
+
+        Returns:
+            str: A string containing the summarization of the answer.
+        """
+        from waii_sdk_py import WAII
+        from waii_sdk_py.query import RunQueryRequest
+
+        documents = WAII.Query.run(
+            RunQueryRequest(query=sql), verbose=self.verbose
+        ).rows
+        return self._get_summarization(
+            "run query", [Document(text=str(doc)) for doc in documents]
+        )
+
+    def describe_query(self, question: str, query: str):
+        """
+        Describe a SQL query, returning a summarization of the answer.
+
+        Args:
+            question: the natural language question the user wants to ask.
+            query: a SQL query.
+
+        Returns:
+            str: A string containing the summarization of the answer.
+        """
+        from waii_sdk_py import WAII
+        from waii_sdk_py.query import DescribeQueryRequest
+
+        result = WAII.Query.describe(DescribeQueryRequest(query=query))
+        result = json.dumps(result.dict(), indent=2)
+        response = self._get_summarization(question, [Document(text=result)])
+        return response
+
+    def performance_analyze(self, query_uuid: str):
+        """
+        Analyze the performance of a query, returning the analysis as a JSON string.
+
+        Args:
+            query_uuid: a query uuid, e.g. xxxxxxxxxxxxx...
+
+        Returns:
+            str: A JSON string containing the performance analysis.
+        """
+        from waii_sdk_py import WAII
+        from waii_sdk_py.query import QueryPerformanceRequest
+
+        result = WAII.Query.analyze_performance(
+            QueryPerformanceRequest(query_id=query_uuid)
+        )
+        result = json.dumps(result.dict(), indent=2)
+        return result
+
+    def diff_query(self, previous_query: str, current_query: str):
+        """
+        Diff two SQL queries, returning a summarization of the differences.
+
+        Args:
+            previous_query: the previous SQL query.
+            current_query: the current SQL query.
+
+        Returns:
+            str: A string containing the summarization of the answer.
+        """
+        from waii_sdk_py import WAII
+        from waii_sdk_py.query import DiffQueryRequest
+
+        result = WAII.Query.diff(
+            DiffQueryRequest(query=current_query, previous_query=previous_query)
+        )
+        result = json.dumps(result.dict(), indent=2)
+        return self._get_summarization("get diff summary", [Document(text=result)])
+
+    def describe_dataset(
+        self,
+        ask: str,
+        schema_name: Optional[str] = None,
+        table_name: Optional[str] = None,
+    ):
+        """
+        Describe a dataset (whether a table or a schema), returning a summarization of the answer.
+        Example questions include: "describe the dataset", "what is the schema about", "example questions for the table xxx", etc.
+        When both schema and table are None, describe the whole database.
+
+        Args:
+            ask: a natural language question (how you want to describe the dataset).
+            schema_name: a schema name (shouldn't include the database name or the table name).
+            table_name: a table name (shouldn't include the database name or the schema name).
+
+        Returns:
+            str: A string containing the summarization of the answer.
+        """
+        from waii_sdk_py import WAII
+
+        catalog = WAII.Database.get_catalogs()
+
+        # Filter the catalog by the requested schema / table
+        schemas = {}
+        tables = {}
+
+        for c in catalog.catalogs:
+            for s in c.schemas:
+                for t in s.tables:
+                    if (
+                        schema_name is not None
+                        and schema_name.lower() != t.name.schema_name.lower()
+                    ):
+                        continue
+                    if table_name is not None:
+                        if table_name.lower() != t.name.table_name.lower():
+                            continue
+                        tables[str(t.name)] = t
+                    schemas[str(s.name)] = s
+
+        # Remove table references from the schema objects (tables are summarized separately)
+        for schema in schemas:
+            schemas[schema].tables = None
+
+        # Generate the response from the filtered schema and table documents
+        response = self._get_summarization(
+            ask + ", use the provided information to get comprehensive summarization",
+            [Document(text=str(schemas[schema])) for schema in schemas]
+            + [Document(text=str(tables[table])) for table in tables],
+        )
+        return response
+
+    def transcode(
+        self,
+        instruction: Optional[str] = "",
+        source_dialect: Optional[str] = None,
+        source_query: Optional[str] = None,
+        target_dialect: Optional[str] = None,
+    ):
+        """
+        Transcode a SQL query from one dialect to another, returning the generated query.
+
+        Args:
+            instruction: an instruction in natural language.
+            source_dialect: the source dialect of the query.
+            source_query: the source query.
+            target_dialect: the target dialect of the query.
+
+        Returns:
+            str: A string containing the generated query.
+        """
+        from waii_sdk_py import WAII
+        from waii_sdk_py.query import TranscodeQueryRequest
+
+        result = WAII.Query.transcode(
+            TranscodeQueryRequest(
+                ask=instruction,
+                source_dialect=source_dialect,
+                source_query=source_query,
+                target_dialect=target_dialect,
+            )
+        )
+        return result.query
+
+    def get_semantic_contexts(self):
+        """
+        Get all pre-defined semantic contexts.
+
+        Returns:
+            The pre-defined semantic contexts of the current connection.
+        """
+        from waii_sdk_py import WAII
+
+        return WAII.SemanticContext.get_semantic_context().semantic_context
diff --git a/llama_hub/tools/waii/requirements.txt b/llama_hub/tools/waii/requirements.txt
new file mode 100644
index 0000000000..2fa2f6e0e5
--- /dev/null
+++ b/llama_hub/tools/waii/requirements.txt
@@ -0,0 +1 @@
+waii-sdk-py
\ No newline at end of file