# TextToSQL / app.py
# thechaiexperiment's picture
# Update app.py
# 3b6af07 verified
import gradio as gr
import openai
import sqlite3
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
from typing import Optional, Tuple
import re
# OpenRouter API key.
# SECURITY: the key is read from the OPENROUTER_API_KEY environment variable
# rather than being committed to the repository. Any key that has already been
# published in source control should be treated as compromised and revoked.
# When the variable is unset the key is an empty string; requests are then
# expected to fail authentication at call time -- TODO confirm the client
# accepts an empty key at construction.
OPENROUTER_API_KEY = os.environ.get("OPENROUTER_API_KEY", "")
OPENROUTER_MODEL = "sophosympatheia/rogue-rose-103b-v0.2:free"

# Hugging Face Space path (relative to the process working directory)
DB_PATH = "ecommerce.db"

# Ensure dataset exists: download it on first start if the file is missing.
if not os.path.exists(DB_PATH):
    os.system("wget https://your-dataset-link.com/ecommerce.db -O ecommerce.db")  # Replace with actual dataset link

# Initialize OpenAI client, pointed at the OpenRouter endpoint.
openai_client = openai.OpenAI(api_key=OPENROUTER_API_KEY, base_url="https://openrouter.ai/api/v1")
# Function: Fetch database schema
def fetch_schema(db_path: str) -> str:
    """Return a plain-text listing of the tables and columns in a SQLite file.

    Args:
        db_path: Path handed to ``sqlite3.connect``.

    Returns:
        One ``"Table: <name>"`` line per table listed in ``sqlite_master``,
        each followed by one ``" Column: <name>, Type: <type>"`` line per
        column. The result is an empty string when there are no tables.

    Raises:
        sqlite3.Error: If the database cannot be opened or queried.
    """
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        table_names = [row[0] for row in cursor.fetchall()]

        parts = []
        for table_name in table_names:
            # PRAGMA arguments cannot be bound as parameters, so the name is
            # embedded as a double-quoted identifier (embedded quotes doubled).
            # This keeps names containing spaces or keywords working.
            quoted_name = '"' + table_name.replace('"', '""') + '"'
            cursor.execute(f"PRAGMA table_info({quoted_name});")
            parts.append(f"Table: {table_name}\n")
            # table_info rows: index 1 is the column name, index 2 its type.
            parts.extend(
                f" Column: {column[1]}, Type: {column[2]}\n"
                for column in cursor.fetchall()
            )
        return "".join(parts)
    finally:
        # Release the connection even if one of the queries raised.
        conn.close()
# Function: Extract SQL query from LLM response
def extract_sql_query(response: str) -> str:
    """Pull the SQL statement out of a Markdown-fenced model response.

    Lookup order:
      1. The first fence tagged ``sql`` (any letter case).
      2. Otherwise the first fence whose opening line is bare or carries some
         other tag (for example a plain triple-backtick fence).
      3. Otherwise the response itself, unchanged.

    Args:
        response: Raw text returned by the language model.

    Returns:
        The fenced content with surrounding whitespace stripped, or
        ``response`` as-is when no fenced block is found.
    """
    # Tagged fence first, so a ```sql block wins over an earlier untagged one.
    tagged = re.search(r"```sql(.*?)```", response, re.DOTALL | re.IGNORECASE)
    if tagged:
        return tagged.group(1).strip()

    # Any other fence: skip the optional info string on the opening line so
    # the backticks and tag are never passed on as part of the query.
    generic = re.search(r"```[\w+-]*\n(.*?)```", response, re.DOTALL)
    if generic:
        return generic.group(1).strip()

    # Fallback: no fenced block at all, return the entire response.
    return response
# Function: Convert text to SQL
def text_to_sql(query: str, schema: str) -> str:
    """Ask the configured chat model to translate a question into SQL.

    Args:
        query: The natural-language request.
        schema: Schema description that is embedded in the prompt.

    Returns:
        The SQL extracted from the model's reply, or a string of the form
        ``"Error: <exception>"`` if anything in the request or the
        post-processing raises.
    """
    user_prompt = "".join(
        [
            "You are an SQL expert. Given the following database schema:\n\n",
            f"{schema}\n\n",
            "Convert the following query into SQL:\n\n",
            f"Query: {query}\n",
            "SQL:",
        ]
    )
    chat_messages = [
        {"role": "system", "content": "You are an SQL expert."},
        {"role": "user", "content": user_prompt},
    ]
    try:
        completion = openai_client.chat.completions.create(
            model=OPENROUTER_MODEL,
            messages=chat_messages,
        )
        raw_answer = completion.choices[0].message.content.strip()
        # Keep only the SQL portion of the reply.
        return extract_sql_query(raw_answer)
    except Exception as exc:
        # Failures are reported through the return value, not raised.
        return f"Error: {exc}"
def preprocess_sql_for_sqlite(sql_query: str) -> str:
    """Rewrite ``MONTH(col)`` / ``YEAR(col)`` calls into SQLite expressions.

    SQLite has no ``MONTH``/``YEAR`` functions, so each call on a plain or
    dotted column reference is replaced by ``strftime``. Because ``strftime``
    yields text (``'03'``), the result is cast to INTEGER so numeric
    comparisons such as ``MONTH(d) = 3`` evaluate as intended. Function names
    are matched case-insensitively.

    Args:
        sql_query: SQL text that may contain ``MONTH(...)`` / ``YEAR(...)``.

    Returns:
        The SQL text with those calls replaced; all other text is unchanged.
    """
    # Only simple arguments (word characters and dots) are rewritten; calls
    # with more complex expressions as the argument are left untouched.
    replacements = (
        (r"\bMONTH\s*\(\s*([\w.]+)\s*\)", r"CAST(strftime('%m', \1) AS INTEGER)"),
        (r"\bYEAR\s*\(\s*([\w.]+)\s*\)", r"CAST(strftime('%Y', \1) AS INTEGER)"),
    )
    for pattern, replacement in replacements:
        sql_query = re.sub(pattern, replacement, sql_query, flags=re.IGNORECASE)
    return sql_query
def execute_sql(sql_query: str) -> Tuple[Optional[pd.DataFrame], Optional[str]]:
    """Run a query against the database at ``DB_PATH``.

    The query is first passed through ``preprocess_sql_for_sqlite``.

    NOTE(review): the SQL executed here is model-generated from user input
    and runs on a normal read/write connection; consider a read-only
    connection if the database must be protected from modification.

    Args:
        sql_query: SQL text to execute.

    Returns:
        ``(dataframe, None)`` on success, or
        ``(None, "SQL Execution Error: <exception>")`` on any failure.
    """
    try:
        conn = sqlite3.connect(DB_PATH)
        try:
            sql_query = preprocess_sql_for_sqlite(sql_query)  # Convert to SQLite-compatible SQL
            df = pd.read_sql_query(sql_query, conn)
        finally:
            # Close the connection even when the query itself fails.
            conn.close()
        return df, None
    except Exception as e:
        return None, f"SQL Execution Error: {e}"
# Function: Generate Dynamic Visualization
def visualize_data(df: pd.DataFrame) -> Optional[str]:
    """Render a chart for a query result and save it as ``chart.png``.

    Chart selection, based on the number of numeric columns:
      * exactly one  -> histogram (with KDE) of that column
      * exactly two  -> scatter plot of the first against the second
      * more, with fewer than 10 rows -> pie chart of the first numeric
        column, labelled by the first column of the frame
      * more, otherwise -> bar chart of the first numeric column against the
        first column of the frame

    Args:
        df: Query result to plot.

    Returns:
        ``"chart.png"`` when a chart was written, or ``None`` when the frame
        is empty, has fewer than two columns, or has no numeric column.
    """
    if df.empty or df.shape[1] < 2:
        return None

    # Detect numeric columns before creating a figure, so the early return
    # below does not leave an orphaned figure behind.
    numeric_cols = df.select_dtypes(include=['number']).columns
    if len(numeric_cols) < 1:
        return None

    fig = plt.figure(figsize=(6, 4))
    try:
        sns.set_theme(style="darkgrid")

        # Choose visualization type dynamically
        if len(numeric_cols) == 1:  # Single numeric column, assume it's a count metric
            sns.histplot(df[numeric_cols[0]], bins=10, kde=True, color="teal")
            plt.title(f"Distribution of {numeric_cols[0]}")
        elif len(numeric_cols) == 2:  # Two numeric columns, assume X-Y plot
            sns.scatterplot(x=df[numeric_cols[0]], y=df[numeric_cols[1]], color="blue")
            plt.title(f"{numeric_cols[0]} vs {numeric_cols[1]}")
        elif df.shape[0] < 10:  # If rows are few, prefer pie chart
            plt.pie(df[numeric_cols[0]], labels=df.iloc[:, 0], autopct='%1.1f%%', colors=sns.color_palette("pastel"))
            plt.title(f"Proportion of {numeric_cols[0]}")
        else:  # Default: Bar chart for categories + values
            sns.barplot(x=df.iloc[:, 0], y=df[numeric_cols[0]], palette="coolwarm")
            plt.xticks(rotation=45)
            plt.title(f"{df.columns[0]} vs {numeric_cols[0]}")

        plt.tight_layout()
        plt.savefig("chart.png")
    finally:
        # Figures stay registered with pyplot until closed; close this one so
        # repeated requests do not accumulate open figures.
        plt.close(fig)
    return "chart.png"
# Gradio UI
def gradio_ui(query: str) -> Tuple[str, str, Optional[str]]:
    """Handle one request: generate SQL, execute it, and chart the result.

    Args:
        query: Natural-language question typed by the user.

    Returns:
        A ``(sql, results_text, chart_path)`` tuple. ``chart_path`` is
        ``None`` whenever no chart was produced; on failure ``results_text``
        carries the message to display.
    """
    # Nothing to translate: skip the model call entirely.
    if not query or not query.strip():
        return "", "Please enter a query.", None

    schema = fetch_schema(DB_PATH)
    sql_query = text_to_sql(query, schema)

    # text_to_sql reports failures as an "Error: ..." string; running that
    # text as SQL would only yield a misleading syntax error.
    if sql_query.startswith("Error:"):
        return sql_query, "SQL generation failed; nothing was executed.", None

    df, error = execute_sql(sql_query)
    if error:
        return sql_query, error, None

    visualization = visualize_data(df) if df is not None else None
    return sql_query, df.to_string(index=False), visualization
# Launch Gradio App
# Layout: one input box and a button, followed by three outputs
# (generated SQL, textual results, chart image).
with gr.Blocks() as demo:
    gr.Markdown("## SQL Explorer: Text-to-SQL with Real Execution & Visualization")

    query_input = gr.Textbox(
        label="Enter your query",
        placeholder="e.g., Show all products sold in 2018.",
    )
    submit_btn = gr.Button("Convert & Execute")

    sql_output = gr.Textbox(label="Generated SQL Query")
    table_output = gr.Textbox(label="Query Results")
    chart_output = gr.Image(label="Data Visualization")

    # The three values returned by gradio_ui map onto the outputs in order.
    submit_btn.click(
        gradio_ui,
        inputs=[query_input],
        outputs=[sql_output, table_output, chart_output],
    )

demo.launch()