Parthiban97 commited on
Commit
ff95d63
·
verified ·
1 Parent(s): 5155bdd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -20
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import streamlit as st
2
  import os
 
3
  from langchain_core.messages import AIMessage, HumanMessage
4
  from langchain_core.prompts import ChatPromptTemplate
5
  from langchain_core.runnables import RunnablePassthrough
@@ -7,41 +8,46 @@ from langchain_community.utilities import SQLDatabase
7
  from langchain_core.output_parsers import StrOutputParser
8
  from langchain_openai import ChatOpenAI
9
  from langchain_groq import ChatGroq
10
- import toml
11
 
12
  # Function to update config.toml file
13
- def update_secrets_file(data):
14
  secrets_file_path = ".streamlit/config.toml"
15
  secrets_data = {}
16
-
17
- # Load existing data from secrets.toml
18
  if os.path.exists(secrets_file_path):
19
- with open(secrets_file_path, "r") as file:
20
- secrets_data = toml.load(file)
 
 
 
 
 
 
 
21
 
22
- # Update secrets data with new data
23
- secrets_data.update(data)
24
 
25
- # Write updated data back to secrets.toml
26
- with open(secrets_file_path, "a") as file:
27
  toml.dump(secrets_data, file)
28
 
29
 
30
  # Initialize database connections
31
- def init_databases():
32
  secrets_file_path = ".streamlit/config.toml"
33
  secrets_data = {}
34
  if os.path.exists(secrets_file_path):
35
  with open(secrets_file_path, "r") as file:
36
- content = file.read().strip()
37
- if content:
38
- secrets_data = toml.loads(content)
39
 
40
  db_connections = {}
41
- for database in secrets_data.get("Databases", "").split(','):
42
  database = database.strip()
43
  if database:
44
- db_uri = f"mysql+mysqlconnector://{secrets_data['User']}:{secrets_data['Password']}@{secrets_data['Host']}:{secrets_data['Port']}/{database}"
45
  db_connections[database] = SQLDatabase.from_uri(db_uri)
46
  return db_connections
47
 
@@ -113,7 +119,27 @@ def get_sql_chain(dbs, llm):
113
 
114
  Question: How many shirts are available in stock grouped by colours from each size and finally show me all brands?
115
  SQL Query: SELECT brand, color, size, SUM(stock_quantity) AS total_stock FROM t_shirts GROUP BY brand, color, size
 
 
 
 
 
116
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
  Your turn:
118
 
119
  Question: {question}
@@ -189,6 +215,7 @@ with st.sidebar:
189
  st.write("This is a simple chat application using MySQL. Connect to the database and start chatting.")
190
 
191
  if "db" not in st.session_state:
 
192
  st.session_state.Host = st.text_input("Host")
193
  st.session_state.Port = st.text_input("Port")
194
  st.session_state.User = st.text_input("User")
@@ -204,8 +231,8 @@ with st.sidebar:
204
  if st.button("Connect"):
205
  with st.spinner("Connecting to databases..."):
206
 
207
- # Update secrets.toml with connection details
208
- update_secrets_file({
209
  "Host": st.session_state.Host,
210
  "Port": st.session_state.Port,
211
  "User": st.session_state.User,
@@ -213,7 +240,7 @@ with st.sidebar:
213
  "Databases": st.session_state.Databases
214
  })
215
 
216
- dbs = init_databases()
217
  st.session_state.dbs = dbs
218
 
219
  if len(dbs) > 1:
@@ -253,4 +280,4 @@ if user_query is not None and user_query.strip() != "":
253
  response = get_response(user_query, st.session_state.dbs, st.session_state.chat_history, st.session_state.llm)
254
  st.markdown(response)
255
 
256
- st.session_state.chat_history.append(AIMessage(content=response))
 
1
  import streamlit as st
2
  import os
3
+ import toml
4
  from langchain_core.messages import AIMessage, HumanMessage
5
  from langchain_core.prompts import ChatPromptTemplate
6
  from langchain_core.runnables import RunnablePassthrough
 
8
  from langchain_core.output_parsers import StrOutputParser
9
  from langchain_openai import ChatOpenAI
10
  from langchain_groq import ChatGroq
 
11
 
12
  # Function to update config.toml file
13
+ def update_secrets_file(user, data):
14
  secrets_file_path = ".streamlit/config.toml"
15
  secrets_data = {}
16
+
17
+ # Load existing data from config.toml
18
  if os.path.exists(secrets_file_path):
19
+ try:
20
+ with open(secrets_file_path, "r") as file:
21
+ secrets_data = toml.load(file)
22
+ except toml.TomlDecodeError:
23
+ secrets_data = {}
24
+
25
+ # Update user-specific secrets data
26
+ if user not in secrets_data:
27
+ secrets_data[user] = {}
28
 
29
+ secrets_data[user].update(data)
 
30
 
31
+ # Write updated data back to config.toml
32
+ with open(secrets_file_path, "w") as file:
33
  toml.dump(secrets_data, file)
34
 
35
 
36
  # Initialize database connections
37
+ def init_databases(user):
38
  secrets_file_path = ".streamlit/config.toml"
39
  secrets_data = {}
40
  if os.path.exists(secrets_file_path):
41
  with open(secrets_file_path, "r") as file:
42
+ secrets_data = toml.load(file)
43
+
44
+ user_data = secrets_data.get(user, {})
45
 
46
  db_connections = {}
47
+ for database in user_data.get("Databases", "").split(','):
48
  database = database.strip()
49
  if database:
50
+ db_uri = f"mysql+mysqlconnector://{user_data['User']}:{user_data['Password']}@{user_data['Host']}:{user_data['Port']}/{database}"
51
  db_connections[database] = SQLDatabase.from_uri(db_uri)
52
  return db_connections
53
 
 
119
 
120
  Question: How many shirts are available in stock grouped by colours from each size and finally show me all brands?
121
  SQL Query: SELECT brand, color, size, SUM(stock_quantity) AS total_stock FROM t_shirts GROUP BY brand, color, size
122
+
123
+ Question: select all the movies with minimum and maximum release_year. Note that there can be more than one movies in min and max year hence output rows can be more than 2?
124
+ SQL Query: select * from movies where release_year in (
125
+ (select min(release_year) from movies),
126
+ (select max(release_year) from movies));
127
 
128
+ Question: Generate a yearly report for Croma India where there are two columns 1. Fiscal Year and 2. Total Gross Sales amount In that year from Croma
129
+ SQL Query: select
130
+ get_fiscal_year(date) as fiscal_year,
131
+ sum(round(sold_quantity*g.gross_price,2)) as yearly_sales
132
+ from fact_sales_monthly s
133
+ join fact_gross_price g
134
+ on
135
+ g.fiscal_year=get_fiscal_year(s.date) and
136
+ g.product_code=s.product_code
137
+ where
138
+ customer_code=90002002
139
+ group by get_fiscal_year(date)
140
+ order by fiscal_year;
141
+
142
+
143
  Your turn:
144
 
145
  Question: {question}
 
215
  st.write("This is a simple chat application using MySQL. Connect to the database and start chatting.")
216
 
217
  if "db" not in st.session_state:
218
+ st.session_state.user_id = st.text_input("User ID")
219
  st.session_state.Host = st.text_input("Host")
220
  st.session_state.Port = st.text_input("Port")
221
  st.session_state.User = st.text_input("User")
 
231
  if st.button("Connect"):
232
  with st.spinner("Connecting to databases..."):
233
 
234
+ # Update config.toml with user-specific connection details
235
+ update_secrets_file(st.session_state.user_id, {
236
  "Host": st.session_state.Host,
237
  "Port": st.session_state.Port,
238
  "User": st.session_state.User,
 
240
  "Databases": st.session_state.Databases
241
  })
242
 
243
+ dbs = init_databases(st.session_state.user_id)
244
  st.session_state.dbs = dbs
245
 
246
  if len(dbs) > 1:
 
280
  response = get_response(user_query, st.session_state.dbs, st.session_state.chat_history, st.session_state.llm)
281
  st.markdown(response)
282
 
283
+ st.session_state.chat_history.append(AIMessage(content=response))