Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
import streamlit as st
|
| 2 |
import os
|
|
|
|
| 3 |
from langchain_core.messages import AIMessage, HumanMessage
|
| 4 |
from langchain_core.prompts import ChatPromptTemplate
|
| 5 |
from langchain_core.runnables import RunnablePassthrough
|
|
@@ -7,41 +8,46 @@ from langchain_community.utilities import SQLDatabase
|
|
| 7 |
from langchain_core.output_parsers import StrOutputParser
|
| 8 |
from langchain_openai import ChatOpenAI
|
| 9 |
from langchain_groq import ChatGroq
|
| 10 |
-
import toml
|
| 11 |
|
| 12 |
# Function to update config.toml file
|
| 13 |
-
def update_secrets_file(data):
|
| 14 |
secrets_file_path = ".streamlit/config.toml"
|
| 15 |
secrets_data = {}
|
| 16 |
-
|
| 17 |
-
# Load existing data from
|
| 18 |
if os.path.exists(secrets_file_path):
|
| 19 |
-
|
| 20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
-
|
| 23 |
-
secrets_data.update(data)
|
| 24 |
|
| 25 |
-
# Write updated data back to
|
| 26 |
-
with open(secrets_file_path, "
|
| 27 |
toml.dump(secrets_data, file)
|
| 28 |
|
| 29 |
|
| 30 |
# Initialize database connections
|
| 31 |
-
def init_databases():
|
| 32 |
secrets_file_path = ".streamlit/config.toml"
|
| 33 |
secrets_data = {}
|
| 34 |
if os.path.exists(secrets_file_path):
|
| 35 |
with open(secrets_file_path, "r") as file:
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
|
| 40 |
db_connections = {}
|
| 41 |
-
for database in
|
| 42 |
database = database.strip()
|
| 43 |
if database:
|
| 44 |
-
db_uri = f"mysql+mysqlconnector://{
|
| 45 |
db_connections[database] = SQLDatabase.from_uri(db_uri)
|
| 46 |
return db_connections
|
| 47 |
|
|
@@ -113,7 +119,27 @@ def get_sql_chain(dbs, llm):
|
|
| 113 |
|
| 114 |
Question: How many shirts are available in stock grouped by colours from each size and finally show me all brands?
|
| 115 |
SQL Query: SELECT brand, color, size, SUM(stock_quantity) AS total_stock FROM t_shirts GROUP BY brand, color, size
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
Your turn:
|
| 118 |
|
| 119 |
Question: {question}
|
|
@@ -189,6 +215,7 @@ with st.sidebar:
|
|
| 189 |
st.write("This is a simple chat application using MySQL. Connect to the database and start chatting.")
|
| 190 |
|
| 191 |
if "db" not in st.session_state:
|
|
|
|
| 192 |
st.session_state.Host = st.text_input("Host")
|
| 193 |
st.session_state.Port = st.text_input("Port")
|
| 194 |
st.session_state.User = st.text_input("User")
|
|
@@ -204,8 +231,8 @@ with st.sidebar:
|
|
| 204 |
if st.button("Connect"):
|
| 205 |
with st.spinner("Connecting to databases..."):
|
| 206 |
|
| 207 |
-
# Update
|
| 208 |
-
update_secrets_file({
|
| 209 |
"Host": st.session_state.Host,
|
| 210 |
"Port": st.session_state.Port,
|
| 211 |
"User": st.session_state.User,
|
|
@@ -213,7 +240,7 @@ with st.sidebar:
|
|
| 213 |
"Databases": st.session_state.Databases
|
| 214 |
})
|
| 215 |
|
| 216 |
-
dbs = init_databases()
|
| 217 |
st.session_state.dbs = dbs
|
| 218 |
|
| 219 |
if len(dbs) > 1:
|
|
@@ -253,4 +280,4 @@ if user_query is not None and user_query.strip() != "":
|
|
| 253 |
response = get_response(user_query, st.session_state.dbs, st.session_state.chat_history, st.session_state.llm)
|
| 254 |
st.markdown(response)
|
| 255 |
|
| 256 |
-
st.session_state.chat_history.append(AIMessage(content=response))
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
import os
|
| 3 |
+
import toml
|
| 4 |
from langchain_core.messages import AIMessage, HumanMessage
|
| 5 |
from langchain_core.prompts import ChatPromptTemplate
|
| 6 |
from langchain_core.runnables import RunnablePassthrough
|
|
|
|
| 8 |
from langchain_core.output_parsers import StrOutputParser
|
| 9 |
from langchain_openai import ChatOpenAI
|
| 10 |
from langchain_groq import ChatGroq
|
|
|
|
| 11 |
|
| 12 |
# Function to update config.toml file
|
| 13 |
+
def update_secrets_file(user, data):
|
| 14 |
secrets_file_path = ".streamlit/config.toml"
|
| 15 |
secrets_data = {}
|
| 16 |
+
|
| 17 |
+
# Load existing data from config.toml
|
| 18 |
if os.path.exists(secrets_file_path):
|
| 19 |
+
try:
|
| 20 |
+
with open(secrets_file_path, "r") as file:
|
| 21 |
+
secrets_data = toml.load(file)
|
| 22 |
+
except toml.TomlDecodeError:
|
| 23 |
+
secrets_data = {}
|
| 24 |
+
|
| 25 |
+
# Update user-specific secrets data
|
| 26 |
+
if user not in secrets_data:
|
| 27 |
+
secrets_data[user] = {}
|
| 28 |
|
| 29 |
+
secrets_data[user].update(data)
|
|
|
|
| 30 |
|
| 31 |
+
# Write updated data back to config.toml
|
| 32 |
+
with open(secrets_file_path, "w") as file:
|
| 33 |
toml.dump(secrets_data, file)
|
| 34 |
|
| 35 |
|
| 36 |
# Initialize database connections
|
| 37 |
+
def init_databases(user):
|
| 38 |
secrets_file_path = ".streamlit/config.toml"
|
| 39 |
secrets_data = {}
|
| 40 |
if os.path.exists(secrets_file_path):
|
| 41 |
with open(secrets_file_path, "r") as file:
|
| 42 |
+
secrets_data = toml.load(file)
|
| 43 |
+
|
| 44 |
+
user_data = secrets_data.get(user, {})
|
| 45 |
|
| 46 |
db_connections = {}
|
| 47 |
+
for database in user_data.get("Databases", "").split(','):
|
| 48 |
database = database.strip()
|
| 49 |
if database:
|
| 50 |
+
db_uri = f"mysql+mysqlconnector://{user_data['User']}:{user_data['Password']}@{user_data['Host']}:{user_data['Port']}/{database}"
|
| 51 |
db_connections[database] = SQLDatabase.from_uri(db_uri)
|
| 52 |
return db_connections
|
| 53 |
|
|
|
|
| 119 |
|
| 120 |
Question: How many shirts are available in stock grouped by colours from each size and finally show me all brands?
|
| 121 |
SQL Query: SELECT brand, color, size, SUM(stock_quantity) AS total_stock FROM t_shirts GROUP BY brand, color, size
|
| 122 |
+
|
| 123 |
+
Question: select all the movies with minimum and maximum release_year. Note that there can be more than one movies in min and max year hence output rows can be more than 2?
|
| 124 |
+
SQL Query: select * from movies where release_year in (
|
| 125 |
+
(select min(release_year) from movies),
|
| 126 |
+
(select max(release_year) from movies));
|
| 127 |
|
| 128 |
+
Question: Generate a yearly report for Croma India where there are two columns 1. Fiscal Year and 2. Total Gross Sales amount In that year from Croma
|
| 129 |
+
SQL Query: select
|
| 130 |
+
get_fiscal_year(date) as fiscal_year,
|
| 131 |
+
sum(round(sold_quantity*g.gross_price,2)) as yearly_sales
|
| 132 |
+
from fact_sales_monthly s
|
| 133 |
+
join fact_gross_price g
|
| 134 |
+
on
|
| 135 |
+
g.fiscal_year=get_fiscal_year(s.date) and
|
| 136 |
+
g.product_code=s.product_code
|
| 137 |
+
where
|
| 138 |
+
customer_code=90002002
|
| 139 |
+
group by get_fiscal_year(date)
|
| 140 |
+
order by fiscal_year;
|
| 141 |
+
|
| 142 |
+
|
| 143 |
Your turn:
|
| 144 |
|
| 145 |
Question: {question}
|
|
|
|
| 215 |
st.write("This is a simple chat application using MySQL. Connect to the database and start chatting.")
|
| 216 |
|
| 217 |
if "db" not in st.session_state:
|
| 218 |
+
st.session_state.user_id = st.text_input("User ID")
|
| 219 |
st.session_state.Host = st.text_input("Host")
|
| 220 |
st.session_state.Port = st.text_input("Port")
|
| 221 |
st.session_state.User = st.text_input("User")
|
|
|
|
| 231 |
if st.button("Connect"):
|
| 232 |
with st.spinner("Connecting to databases..."):
|
| 233 |
|
| 234 |
+
# Update config.toml with user-specific connection details
|
| 235 |
+
update_secrets_file(st.session_state.user_id, {
|
| 236 |
"Host": st.session_state.Host,
|
| 237 |
"Port": st.session_state.Port,
|
| 238 |
"User": st.session_state.User,
|
|
|
|
| 240 |
"Databases": st.session_state.Databases
|
| 241 |
})
|
| 242 |
|
| 243 |
+
dbs = init_databases(st.session_state.user_id)
|
| 244 |
st.session_state.dbs = dbs
|
| 245 |
|
| 246 |
if len(dbs) > 1:
|
|
|
|
| 280 |
response = get_response(user_query, st.session_state.dbs, st.session_state.chat_history, st.session_state.llm)
|
| 281 |
st.markdown(response)
|
| 282 |
|
| 283 |
+
st.session_state.chat_history.append(AIMessage(content=response))
|