Spaces:
Sleeping
Sleeping
import re | |
import sqlite3 | |
from dotenv import load_dotenv | |
import os | |
import streamlit as st | |
import google.generativeai as genai | |
# Load variables from a local .env file into the process environment, then give
# the Gemini SDK whatever GOOGLE_API_KEY is now set (None when it is absent).
load_dotenv()
genai.configure(api_key=os.environ.get("GOOGLE_API_KEY"))
def get_gemini_response(question, prompt):
    """Send the instruction prompt and the user's question to Gemini.

    Args:
        question: The user's free-text question.
        prompt: A sequence whose first element is the instruction text sent
            ahead of the question; any further elements are ignored.

    Returns:
        The text of the model's reply.
    """
    gemini = genai.GenerativeModel("gemini-1.5-pro-latest")
    instructions = prompt[0]
    reply = gemini.generate_content([instructions, question])
    return reply.text
def _extract_sql(response):
    """Return the SQL statement contained in a model reply, or None if there is none.

    A fenced code block (```sql ... ``` or a bare ``` ... ```, tag matched
    case-insensitively) is preferred. Otherwise the first ``SELECT ... FROM ...``
    span is taken, ending at the first ``;`` or, if no semicolon follows, at the
    end of the text.
    """
    fenced = re.search(
        r"```(?:sql)?[ \t]*\r?\n(.*?)```", response, re.DOTALL | re.IGNORECASE
    )
    if fenced:
        return fenced.group(1).strip()
    # The prompt asks for bare SQL, so fall back to locating a SELECT statement.
    # The trailing ';' is optional because the model does not always emit one.
    bare = re.search(
        r"\bSELECT\b.*?\bFROM\b.*?(?:;|$)", response, re.DOTALL | re.IGNORECASE
    )
    if bare:
        return bare.group(0).strip()
    return None


def read_sql_query(response, db):
    """Extract a SQL statement from a Gemini reply and run it read-only on ``db``.

    Args:
        response: Raw text returned by the model.
        db: Database path passed to ``sqlite3.connect``.

    Returns:
        The result rows as a list of tuples, or a string starting with
        ``"SQLite error: "`` if SQLite rejects the statement (including any
        attempt to modify the database).

    If no SQL can be found in ``response``, an error is shown on the page and
    the Streamlit script run is halted with ``st.stop()``.
    """
    # Parse before connecting so the no-SQL path never opens (and leaks) a connection.
    sql = _extract_sql(response)
    if sql is None:
        st.error("Error: No valid SQL query found in the response.\nPlease write more clear retrieval query.")
        st.stop()
        return []  # defensive: only reached if st.stop() returns instead of raising

    conn = sqlite3.connect(db)
    try:
        # The statement is model-generated from free-form user input. Query-only
        # mode makes SQLite reject every write (INSERT/UPDATE/DELETE/DROP/...),
        # so a crafted question cannot alter or destroy the data.
        conn.execute("PRAGMA query_only = ON")
        return conn.execute(sql).fetchall()
    except (sqlite3.Error, sqlite3.Warning) as e:
        # sqlite3.Warning is what Python < 3.12 raises for "You can only execute
        # one statement at a time."; from 3.12 it is a ProgrammingError (an sqlite3.Error).
        return f"SQLite error: {e}"
    finally:
        conn.close()
# Instruction text sent ahead of every user question (get_gemini_response uses prompt[0]).
# String literals in the examples use single quotes: in SQL, "AI" is an identifier,
# and it only works in SQLite through a legacy double-quoted-string fallback.
prompt = [
    """
You are an expert in converting English questions to SQL queries for SQLite.
The database has a table named STUDENT with the following columns: NAME, CLASS, SECTION.

For example,
Example 1: How many records are present?
The SQL command will be something like this: SELECT COUNT(*) FROM STUDENT;

Example 2: What are the names of students who study in AI?
The SQL command will be something like this: SELECT NAME FROM STUDENT WHERE CLASS = 'AI';

Return only the SQL statement: do not wrap it in ``` code fences and do not prefix it with the word sql.
"""
]
# Page layout: configuration, intro text, question box and submit button.
st.set_page_config(page_title="SQL Retrieval App")
st.info("This is a demo the default db is a student database containing a table of student name, class and section")
st.header("Gemini App - Retrieve SQL Data")
question = st.text_input("Input a retrieval question about the database (e.g. get all student names and class) ", key="input")
submit = st.button("Ask SQL")

if submit and not question.strip():
    # Don't spend a model call on an empty question.
    st.warning("Please enter a question first.")
elif submit:
    try:
        llm_reply = get_gemini_response(question, prompt)
    except Exception as err:  # UI boundary: show the failure instead of a raw traceback
        st.error(f"Gemini request failed: {err}")
    else:
        print(llm_reply)
        rows = read_sql_query(llm_reply, "student.db")
        st.subheader("Response: ")
        if isinstance(rows, str):
            # read_sql_query returns an "SQLite error: ..." message instead of rows;
            # iterating that string would render one header per character.
            st.error(rows)
        elif not rows:
            st.write("The query returned no rows.")
        else:
            for row in rows:
                print(row)
                st.header(row)