# Author: howdy3
# Commit: e9bcde7 "Create app.py" (verified)
import re
import sqlite3
from dotenv import load_dotenv
import os
import streamlit as st
import google.generativeai as genai
# Populate os.environ from a .env file via python-dotenv, then pass the
# GOOGLE_API_KEY value (None if unset) to the google.generativeai client.
load_dotenv()
genai.configure(api_key=os.environ.get("GOOGLE_API_KEY"))
def get_gemini_response(question, prompt):
    """Send the instruction prompt plus the user's question to Gemini.

    Only ``prompt[0]`` is sent as the instructions; any further entries in
    *prompt* are ignored. Returns the text of the model's reply.
    """
    gemini = genai.GenerativeModel("gemini-1.5-pro-latest")
    parts = [prompt[0], question]
    return gemini.generate_content(parts).text
def _extract_sql(response):
    """Return the SQL statement embedded in model output *response*, or None.

    A fenced ```sql block is preferred; otherwise the first
    ``SELECT ... FROM ...;`` span (case-insensitive) is used.
    """
    fenced = re.search(r'```sql\n(.*?)\n```', response, re.DOTALL)
    if fenced:
        return fenced.group(1).strip()
    # No fenced block: fall back to the first SELECT-like statement.
    bare = re.search(r'SELECT.*?FROM.*?;', response, re.DOTALL | re.IGNORECASE)
    if bare:
        return bare.group(0).strip()
    return None


def read_sql_query(response, db):
    """Extract a SQL query from *response* and run it against SQLite file *db*.

    Returns the fetched rows (a list of tuples) on success, or a string
    starting with ``"SQLite error: "`` if SQLite rejects the statement.
    If no SQL can be found in *response*, an error header is shown and the
    Streamlit script run is stopped.

    The SQL is model-generated from free-form user input, so the connection
    is put into ``PRAGMA query_only`` mode. Any write (INSERT/UPDATE/DELETE,
    DROP, ...) then fails with an error instead of modifying the database.
    """
    # Extract first so no connection is opened (and leaked) when there is
    # nothing to run.
    sql = _extract_sql(response)
    if sql is None:
        st.header("Error: No valid SQL query found in the response.\nPlease write more clear retrieval query.")
        st.stop()  # Streamlit's documented way to halt a run (replaces exit())
        return []  # defensive: keeps the return a list if st.stop() does not raise

    conn = sqlite3.connect(db)
    try:
        conn.execute("PRAGMA query_only = ON")
        return conn.execute(sql).fetchall()
    except (sqlite3.Error, sqlite3.Warning) as e:
        # Before Python 3.12, execute() raises sqlite3.Warning (not an
        # sqlite3.Error subclass) when given more than one statement.
        return f"SQLite error: {e}"
    finally:
        conn.close()
# Instruction prompt sent to Gemini ahead of the user's question; only
# prompt[0] is passed by get_gemini_response. It describes the STUDENT table
# (columns NAME, CLASS, SECTION) and gives two few-shot question -> SQL examples.
# NOTE(review): the last instruction forbids ''' and the word 'sql', which
# presumably means Markdown ``` code fences; read_sql_query also accepts a
# ```sql fenced block, so either output form is handled.
prompt = [
"""
You are an expert in converting English questions to SQL code.
The SQL database has the name STUDENT and has the following columns - NAME, CLASS,
SECTION \n\nFor example, \nExample 1: How many entries of records are present?,
the SQL command will be something like this SELECT COUNT(*) FROM STUDENT ;
\nExample 2: What are the names of students who study in AI?, the SQL command will be
something like this SELECT NAME FROM STUDENT WHERE CLASS = "AI";
also the sql code should not have ''' in beginning or end or a 'sql' word in output
"""
]
# Streamlit page: takes an English question, asks Gemini for SQL, runs it
# against student.db and renders each returned row.
st.set_page_config(page_title="SQL Retrieval App")
st.info("This is a demo the default db is a student database containing a table of student name, class and section")
st.header("Gemini App - Retrieve SQL Data")
question = st.text_input("Input a retrieval question about the database (e.g. get all student names and class) ", key="input")
submit = st.button("Ask SQL")

if submit:
    if not (question or "").strip():
        # Nothing to translate; avoid calling the model with an empty question.
        st.warning("Please enter a question first.")
    else:
        response = get_gemini_response(question, prompt)
        print(response)  # raw model output to the server console for debugging
        response = read_sql_query(response, "student.db")
        st.subheader("Response: ")
        if isinstance(response, str):
            # read_sql_query reports SQLite failures as a message string;
            # iterating it would render one header per character.
            st.error(response)
        else:
            for row in response:
                print(row)
                st.header(row)