|
|
import sqlite3 |
|
|
import json |
|
|
from contextlib import contextmanager |
|
|
from typing import List, Dict, Any, Tuple |
|
|
from config import DB_PATH |
|
|
|
|
|
@contextmanager
def get_db_connection():
    """Yield a SQLite connection to ``DB_PATH`` that returns ``sqlite3.Row`` rows.

    The connection is always closed when the ``with`` block exits, including
    on error. It is never committed here: writers must call ``commit()``
    themselves, because closing a connection discards uncommitted changes.
    """
    connection = sqlite3.connect(DB_PATH)
    try:
        # Row objects allow both positional and column-name access (row['id']).
        connection.row_factory = sqlite3.Row
        yield connection
    finally:
        connection.close()
|
|
|
|
|
def fetch_all_embeddings(table: str) -> List[Tuple[int, str, List[float]]]:
    """Return ``(id, full_text, embedding)`` for every row of ``table``.

    The ``embedding`` column is decoded from JSON. Rows whose embedding is
    NULL or not valid JSON are skipped rather than raising.

    NOTE: ``table`` is interpolated directly into the SQL text (identifiers
    cannot be bound as parameters), so only pass trusted, code-defined
    table names, never user input.
    """
    query = f"SELECT id, full_text, embedding FROM {table}"
    with get_db_connection() as conn:
        records = conn.cursor().execute(query).fetchall()

    result = []
    for record in records:
        record_id, text = record['id'], record['full_text']
        try:
            vector = json.loads(record['embedding'])
        except (json.JSONDecodeError, TypeError):
            # NULL (TypeError) or malformed JSON: treat as "not embedded yet".
            continue
        result.append((record_id, text, vector))
    return result
|
|
|
|
|
def fetch_row_by_id(table: str, row_id: int) -> Dict[str, Any]:
    """Return the row of ``table`` whose ``id`` equals ``row_id`` as a dict.

    Returns an empty dict when no such row exists.

    NOTE: ``table`` is interpolated into the SQL text, so it must be a
    trusted table name. Only ``row_id`` is bound as a parameter.
    """
    sql = f"SELECT * FROM {table} WHERE id = ?"
    with get_db_connection() as conn:
        match = conn.execute(sql, (row_id,)).fetchone()
    if not match:
        return {}
    return dict(match)
|
|
|
|
|
def fetch_all_faq_embeddings() -> List[Tuple[int, str, str, List[float]]]:
    """Return ``(id, question, answer, embedding)`` for FAQ entries with a usable embedding.

    The ``embedding`` column is JSON-decoded. Entries whose embedding is NULL
    (for example after ``update_faq_entry`` clears it) or is not valid JSON
    are left out of the result.
    """
    with get_db_connection() as conn:
        faq_rows = conn.execute(
            "SELECT id, question, answer, embedding FROM faq_entries"
        ).fetchall()

    usable = []
    for faq in faq_rows:
        faq_id, question, answer = faq['id'], faq['question'], faq['answer']
        try:
            vector = json.loads(faq['embedding'])
        except (json.JSONDecodeError, TypeError):
            continue
        usable.append((faq_id, question, answer, vector))
    return usable
|
|
|
|
|
def log_question(
    question: str,
    session_id: str = None,
    category: str = None,
    answer: str = None,
    detected_mode: str = None,
    routing_question: str = None,
    rule_triggered: str = None,
    link_provided: bool = False
):
    """Log a user question to the database with comprehensive observability metadata.

    Logging is best-effort and never raises ``sqlite3.OperationalError`` to
    the caller. If the full insert fails (for example, a ``question_logs``
    table that lacks the observability columns), a second insert with only
    ``session_id``, ``question``, ``category`` and ``answer`` is attempted.
    If that also fails, a warning is printed and nothing is stored.

    Args:
        question: The user's question
        session_id: Session identifier
        category: Question category (e.g., 'faq_match', 'llm_generated', 'policy_violation')
        answer: The bot's response
        detected_mode: Operating mode ('Mode A' or 'Mode B')
        routing_question: The routing question asked (if any)
        rule_triggered: Business rule that was triggered (e.g., 'audit_rule', 'free_class_first')
        link_provided: Whether a direct link was included in the response
    """
    with get_db_connection() as conn:
        cur = conn.cursor()

        try:
            cur.execute("""
                INSERT INTO question_logs (
                    session_id, question, category, answer,
                    detected_mode, routing_question, rule_triggered, link_provided
                )
                VALUES (?, ?, ?, ?, ?, ?, ?, ?)
            """, (
                session_id, question, category, answer,
                detected_mode, routing_question, rule_triggered,
                # Stored as an INTEGER flag (SQLite has no boolean type).
                1 if link_provided else 0
            ))
        except sqlite3.OperationalError as e:
            print(f"⚠️ Logging error: {e}. Falling back to basic logging.")
            # Previously the message above promised a fallback that never
            # happened, so the question was silently lost. Perform it now.
            # NOTE(review): assumes these four core columns exist in older
            # question_logs schemas; confirm against the migrations.
            try:
                cur.execute(
                    "INSERT INTO question_logs (session_id, question, category, answer) "
                    "VALUES (?, ?, ?, ?)",
                    (session_id, question, category, answer),
                )
            except sqlite3.OperationalError as fallback_error:
                print(f"⚠️ Basic logging also failed: {fallback_error}")

        conn.commit()
|
|
|
|
|
def get_recent_history(session_id: str, limit: int = 5) -> List[Dict[str, str]]:
    """Retrieve recent conversation history for a session from the logs.

    Selects the ``limit`` most recent ``question_logs`` rows for
    ``session_id`` that have a non-NULL answer. Each row becomes a user
    message followed by an assistant message, oldest first. The result
    therefore holds up to ``2 * limit`` messages.

    Args:
        session_id: Session identifier; a falsy value returns ``[]``
            without touching the database.
        limit: Maximum number of question/answer rows to include.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts in chronological order.
    """
    if not session_id:
        return []

    with get_db_connection() as conn:
        cur = conn.cursor()
        # `rowid` breaks ties between rows sharing the same timestamp value
        # (e.g. several logs written within one second). Without it, SQLite
        # returns tied rows in unspecified order and the history can be
        # scrambled.
        cur.execute("""
            SELECT question, answer
            FROM question_logs
            WHERE session_id = ?
              AND answer IS NOT NULL
            ORDER BY timestamp DESC, rowid DESC
            LIMIT ?
        """, (session_id, limit))
        rows = cur.fetchall()

    history = []
    # Rows arrive newest-first; reverse so the conversation reads in order.
    for row in reversed(rows):
        history.append({"role": "user", "content": row['question']})
        history.append({"role": "assistant", "content": row['answer']})
    return history
|
|
|
|
|
def get_session_state(session_id: str) -> Dict[str, Any]:
    """Return the ``user_sessions`` row for ``session_id`` as a dict.

    When no row exists, a default state is returned instead: no preference,
    zero counters, and ``knowledge_context`` as the JSON *string* ``"{}"``
    (not a dict), matching the text form that ``update_session_state``
    stores.
    """
    with get_db_connection() as conn:
        found = conn.execute(
            "SELECT * FROM user_sessions WHERE session_id = ?", (session_id,)
        ).fetchone()
    if not found:
        return {"preference": None, "msg_count": 0, "clarification_count": 0, "knowledge_context": "{}"}
    return dict(found)
|
|
|
|
|
def _load_knowledge_context(raw: Any) -> Dict[str, Any]:
    """Decode a stored ``knowledge_context`` value into a dict.

    Returns ``{}`` for NULL/None, invalid JSON, or JSON that is not an
    object, so the caller can always ``.update()`` the result safely.
    """
    try:
        value = json.loads(raw)
    except (ValueError, TypeError):
        return {}
    return value if isinstance(value, dict) else {}


def _next_clarification_count(current: int, increment: bool, reset: bool) -> int:
    """Apply the clarification flags to ``current``; ``reset`` wins over ``increment``."""
    if reset:
        return 0
    if increment:
        return current + 1
    return current


def update_session_state(session_id: str, preference: str = None, increment_count: bool = True, increment_clarification: bool = False, reset_clarification: bool = False, knowledge_update: Dict = None):
    """Update session state with Knowledge Dictionary support.

    Creates the ``user_sessions`` row for ``session_id`` if it does not
    exist, otherwise updates it, then commits.

    Args:
        session_id: Session identifier (row key).
        preference: New preference; a falsy value keeps the stored one.
        increment_count: Add one to ``msg_count`` for this call.
        increment_clarification: Add one to ``clarification_count``.
        reset_clarification: Set ``clarification_count`` to 0 (takes
            priority over ``increment_clarification``).
        knowledge_update: Keys merged into the JSON ``knowledge_context``
            dict (new values overwrite existing keys).

    When ``msg_count`` would exceed 10, the stored preference, knowledge and
    clarification count are discarded and the count restarts at 1. The
    values supplied in *this* call are then applied on top of that fresh
    state, exactly as for a brand-new session.
    """
    with get_db_connection() as conn:
        cur = conn.cursor()
        cur.execute(
            "SELECT preference, msg_count, clarification_count, knowledge_context "
            "FROM user_sessions WHERE session_id = ?",
            (session_id,),
        )
        row = cur.fetchone()

        if row:
            curr_pref, curr_count, curr_clarification, curr_knowledge_json = row
            curr_count = curr_count or 0  # tolerate NULL counters
            new_pref = preference if preference else curr_pref
            new_count = curr_count + 1 if increment_count else curr_count
            current_knowledge = _load_knowledge_context(curr_knowledge_json)
            new_clarification = _next_clarification_count(
                curr_clarification or 0, increment_clarification, reset_clarification
            )

            if new_count > 10:
                print(f"🔄 Session {session_id} reached 10 messages. Resetting memory context.")
                # Drop the *old* context only. Previously the preference and
                # clarification flags passed in this same call were lost here,
                # while knowledge_update was still applied.
                new_count = 1
                new_pref = preference
                current_knowledge = {}
                new_clarification = _next_clarification_count(
                    0, increment_clarification, reset_clarification
                )
        else:
            new_pref = preference
            new_count = 1 if increment_count else 0
            current_knowledge = {}
            new_clarification = _next_clarification_count(
                0, increment_clarification, reset_clarification
            )

        if knowledge_update:
            current_knowledge.update(knowledge_update)
        new_knowledge_json = json.dumps(current_knowledge)

        if row:
            cur.execute("""
                UPDATE user_sessions
                SET preference = ?, msg_count = ?, clarification_count = ?, knowledge_context = ?, last_updated = CURRENT_TIMESTAMP
                WHERE session_id = ?
            """, (new_pref, new_count, new_clarification, new_knowledge_json, session_id))
        else:
            cur.execute("""
                INSERT INTO user_sessions (session_id, preference, msg_count, clarification_count, knowledge_context)
                VALUES (?, ?, ?, ?, ?)
            """, (session_id, new_pref, new_count, new_clarification, new_knowledge_json))

        conn.commit()
|
|
|
|
|
def update_faq_entry(faq_id: int, question: str, answer: str):
    """Overwrite the question and answer of FAQ ``faq_id`` and commit.

    The stored embedding is reset to NULL because it no longer matches the
    new text. Readers that JSON-decode embeddings skip such rows.
    """
    params = (question, answer, faq_id)
    with get_db_connection() as conn:
        conn.execute(
            "UPDATE faq_entries SET question = ?, answer = ?, embedding = NULL WHERE id = ?",
            params,
        )
        conn.commit()
|
|
|
|
|
def delete_faq_entry(faq_id: int):
    """Delete the FAQ row with id ``faq_id`` and commit; an unknown id is a silent no-op."""
    with get_db_connection() as conn:
        conn.execute("DELETE FROM faq_entries WHERE id = ?", (faq_id,))
        conn.commit()
|
|
|
|
|
def add_faq_entry(question: str, answer: str):
    """Insert a single FAQ row and commit.

    Only ``question`` and ``answer`` are written. The embedding column is not
    set by this insert.
    """
    new_row = (question, answer)
    with get_db_connection() as conn:
        conn.execute("INSERT INTO faq_entries (question, answer) VALUES (?, ?)", new_row)
        conn.commit()
|
|
|
|
|
def bulk_update_faqs(entries: List[Dict[str, str]]):
    """Insert FAQ rows from a list of dicts, then commit once.

    Each dict may use ``'Question'``/``'Answer'`` or lowercase
    ``'question'``/``'answer'`` keys; the capitalised key is tried first.
    Entries lacking a non-empty question or answer are skipped. Despite the
    name, existing rows are never modified: every accepted entry becomes a
    new row.
    """
    with get_db_connection() as conn:
        accepted = []
        for item in entries:
            q = item.get('Question') or item.get('question')
            a = item.get('Answer') or item.get('answer')
            if not (q and a):
                continue
            accepted.append((q, a))
        conn.executemany("INSERT INTO faq_entries (question, answer) VALUES (?, ?)", accepted)
        conn.commit()
|
|
|
|
|
def bulk_update_podcasts(entries: List[Dict[str, str]]):
    """Insert podcast episodes from a list of dicts, then commit once.

    Accepted keys, with the first spelling tried first:
    ``'Guest Name'``/``'guest_name'``, ``'YouTube URL'``/``'youtube_url'``
    and ``'Summary'``/``'summary'``. Entries missing any of the three (or
    with an empty value) are skipped. Rows are only inserted, never updated.

    The summary is stored twice: as ``highlight_json`` (a one-element JSON
    list ``[{"summary": ...}]``) and inside ``full_text`` as
    ``"Guest: <guest>. Summary: <summary>"``.
    """
    with get_db_connection() as conn:
        accepted = []
        for item in entries:
            guest = item.get('Guest Name') or item.get('guest_name')
            url = item.get('YouTube URL') or item.get('youtube_url')
            summary = item.get('Summary') or item.get('summary')
            if not (guest and url and summary):
                continue
            highlights = json.dumps([{"summary": summary}])
            searchable = f"Guest: {guest}. Summary: {summary}"
            accepted.append((guest, url, highlights, searchable))
        conn.executemany(
            "INSERT INTO podcast_episodes (guest_name, youtube_url, highlight_json, full_text) VALUES (?, ?, ?, ?)",
            accepted,
        )
        conn.commit()
|
|
|
|
|
def fetch_all_podcast_metadata() -> List[Dict[str, Any]]:
    """Fetch all podcast metadata for the admin table.

    Returns one dict per ``podcast_episodes`` row with keys ``id``,
    ``guest_name``, ``youtube_url``, ``highlight_json`` and a derived
    ``summary``. ``summary`` is ``highlight_json[0]["summary"]`` when the
    column holds a non-empty JSON list whose first element has that key.
    Otherwise it falls back to the raw ``highlight_json`` value.
    """
    with get_db_connection() as conn:
        cur = conn.cursor()
        cur.execute("SELECT id, guest_name, youtube_url, highlight_json FROM podcast_episodes")
        rows = cur.fetchall()

    results = []
    for row in rows:
        d = dict(row)
        raw = d['highlight_json']
        try:
            h = json.loads(raw)
            d['summary'] = h[0]['summary'] if h and isinstance(h, list) else raw
        except (ValueError, TypeError, KeyError):
            # ValueError: malformed JSON (incl. undecodable bytes);
            # TypeError: NULL column or a first element that is not a dict;
            # KeyError: first element lacks "summary". A bare `except:` here
            # used to also swallow KeyboardInterrupt/SystemExit.
            d['summary'] = raw
        results.append(d)
    return results
|
|
|
|
|
def fetch_all_faq_metadata() -> List[Dict[str, Any]]:
    """Return every FAQ entry as a dict with ``id``, ``question`` and ``answer`` (no embedding)."""
    with get_db_connection() as conn:
        faq_rows = conn.execute("SELECT id, question, answer FROM faq_entries").fetchall()
        return list(map(dict, faq_rows))