Commit
·
a343d94
0
Parent(s):
Fine-tuned model + RAG + completed NL-to-SQL Agent
Browse files- .gitignore +1 -0
- app.py +0 -0
- fine-tuning/finetune_script.py +70 -0
- fine-tuning/finetuned_model_deployment.py +70 -0
- fine-tuning/nl_sql_finetune_dataset.jsonl +66 -0
- knowledge_base/business_glossary/kpi_definitions.md +31 -0
- knowledge_base/schema/data_dictionary.md +60 -0
- knowledge_base/schema/sales_schema.sql +32 -0
- requirements.txt +8 -0
- src/agents/agent_models/__pycache__/models.cpython-313.pyc +0 -0
- src/agents/agent_models/models.py +14 -0
- src/agents/agent_tools/__pycache__/schema_retriever_tool.cpython-313.pyc +0 -0
- src/agents/agent_tools/__pycache__/sql_executor_tool.cpython-313.pyc +0 -0
- src/agents/agent_tools/schema_retriever_tool.py +103 -0
- src/agents/agent_tools/sql_executor_tool.py +63 -0
- src/agents/nl_sql_agent.py +123 -0
- src/agents/orchestrator_agent.py +20 -0
- src/rag_index.py +180 -0
- src/setup_database.py +242 -0
.gitignore
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
.env
|
app.py
ADDED
File without changes
|
fine-tuning/finetune_script.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from openai import OpenAI
|
3 |
+
import time
|
4 |
+
from dotenv import load_dotenv
|
5 |
+
load_dotenv()
|
6 |
+
|
7 |
+
|
8 |
+
client = OpenAI(
|
9 |
+
base_url="https://api.studio.nebius.com/v1/",
|
10 |
+
api_key=os.environ.get("NEBIUS_API_KEY"),
|
11 |
+
)
|
12 |
+
|
13 |
+
# Upload a training dataset
|
14 |
+
training_dataset = client.files.create(
|
15 |
+
file=open("fine-tuning/nl_sql_finetune_dataset.jsonl", "rb"),
|
16 |
+
purpose="fine-tune"
|
17 |
+
)
|
18 |
+
|
19 |
+
# Fine-tuning job parameters
|
20 |
+
job_request = {
|
21 |
+
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
22 |
+
"suffix": "nl-to-sql-finetuned",
|
23 |
+
"training_file": training_dataset.id,
|
24 |
+
"hyperparameters": {
|
25 |
+
"batch_size": 8,
|
26 |
+
"learning_rate": 2e-5,
|
27 |
+
"n_epochs": 10,
|
28 |
+
"warmup_ratio": 0.1,
|
29 |
+
"weight_decay": 0.01,
|
30 |
+
"lora": True,
|
31 |
+
"lora_r": 8,
|
32 |
+
"lora_alpha": 16,
|
33 |
+
"lora_dropout": 0.05,
|
34 |
+
"packing": True,
|
35 |
+
"max_grad_norm": 1,
|
36 |
+
},
|
37 |
+
}
|
38 |
+
|
39 |
+
# Create and run the fine-tuning job
|
40 |
+
job = client.fine_tuning.jobs.create(**job_request)
|
41 |
+
|
42 |
+
# Make sure that the job has been finished or cancelled
|
43 |
+
active_statuses = ["validating_files", "queued", "running"]
|
44 |
+
while job.status in active_statuses:
|
45 |
+
time.sleep(15)
|
46 |
+
job = client.fine_tuning.jobs.retrieve(job.id)
|
47 |
+
print("current status is", job.status)
|
48 |
+
|
49 |
+
print("Job ID:", job.id)
|
50 |
+
|
51 |
+
if job.status == "succeeded":
|
52 |
+
# Check the job events
|
53 |
+
events = client.fine_tuning.jobs.list_events(job.id)
|
54 |
+
print(events)
|
55 |
+
|
56 |
+
for checkpoint in client.fine_tuning.jobs.checkpoints.list(job.id).data:
|
57 |
+
print("Checkpoint ID:", checkpoint.id)
|
58 |
+
|
59 |
+
# Create a directory for every checkpoint
|
60 |
+
os.makedirs(f"fine-tuning/model_checkpoints/{checkpoint.id}", exist_ok=True)
|
61 |
+
|
62 |
+
for model_file_id in checkpoint.result_files:
|
63 |
+
# Get the name of a model file
|
64 |
+
filename = client.files.retrieve(model_file_id).filename
|
65 |
+
|
66 |
+
# Retrieve the contents of the file
|
67 |
+
file_content = client.files.content(model_file_id)
|
68 |
+
|
69 |
+
# Save the contents into a local file
|
70 |
+
file_content.write_to_file(filename)
|
fine-tuning/finetuned_model_deployment.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import requests
|
2 |
+
import os
|
3 |
+
from openai import OpenAI
|
4 |
+
import time
|
5 |
+
from dotenv import load_dotenv
|
6 |
+
load_dotenv()
|
7 |
+
|
8 |
+
api_key=os.environ.get('NEBIUS_API_KEY')
|
9 |
+
api_url="https://api.studio.nebius.com"
|
10 |
+
|
11 |
+
client = OpenAI(
|
12 |
+
base_url=api_url+"/v1",
|
13 |
+
api_key=api_key,
|
14 |
+
)
|
15 |
+
|
16 |
+
|
17 |
+
def create_lora_from_job(name, ft_job, ft_checkpoint, base_model):
|
18 |
+
fine_tuning_result = f"{ft_job}:{ft_checkpoint}"
|
19 |
+
lora_creation_request = {
|
20 |
+
"source": fine_tuning_result,
|
21 |
+
"base_model": base_model,
|
22 |
+
"name": name,
|
23 |
+
"description": "Deploying fine-tuned LoRA model"
|
24 |
+
}
|
25 |
+
|
26 |
+
response = requests.post(
|
27 |
+
f"{api_url}/v0/models",
|
28 |
+
json=lora_creation_request,
|
29 |
+
headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
|
30 |
+
)
|
31 |
+
|
32 |
+
print("Model creation response code:", response.status_code)
|
33 |
+
print("Model creation response body:", response.text)
|
34 |
+
|
35 |
+
response.raise_for_status()
|
36 |
+
return response.json()
|
37 |
+
|
38 |
+
# Get the custom model information, check its status and wait for the validation to complete
|
39 |
+
def wait_for_validation(name, delay=5):
|
40 |
+
while True:
|
41 |
+
time.sleep(delay)
|
42 |
+
print(f"Validating LORA: {name}")
|
43 |
+
lora_info=requests.get(
|
44 |
+
f"{api_url}/v0/models/{name}",
|
45 |
+
headers={"Content-Type": "application/json","Authorization": f"Bearer {api_key}"}
|
46 |
+
).json()
|
47 |
+
if lora_info.get("status") in {"active", "error"}:
|
48 |
+
return lora_info
|
49 |
+
|
50 |
+
# Create a multi-message request
|
51 |
+
def get_completion(model):
|
52 |
+
completion = client.chat.completions.create(
|
53 |
+
model=model,
|
54 |
+
messages=[{"role":"user","content":"Hello"}],
|
55 |
+
)
|
56 |
+
return completion.choices[0].message.content
|
57 |
+
|
58 |
+
# Deploy a LoRA adapter model by using IDs of a fine-tuning job and its checkpoint
|
59 |
+
lora_name=create_lora_from_job("nl-to-sql-adapter", "ftjob-b8246b132b234b78a73e4b508817868d", "ftckpt_db41d377-7e1c-43f8-838a-73801db71444", "meta-llama/Meta-Llama-3.1-8B-Instruct").get("name")
|
60 |
+
|
61 |
+
# Check the custom model status
|
62 |
+
lora_info = wait_for_validation(lora_name)
|
63 |
+
|
64 |
+
# If the custom model validation is successful, create a multi-message request with this model
|
65 |
+
if lora_info.get("status") == "active":
|
66 |
+
print(get_completion(lora_name))
|
67 |
+
# If there is an error, display the reason
|
68 |
+
elif lora_info.get("status") == "error":
|
69 |
+
print(f"An error occurred: {lora_info["status_reason"]}")
|
70 |
+
|
fine-tuning/nl_sql_finetune_dataset.jsonl
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"How many regions are there in the database?"},{"role":"assistant","content":"SELECT COUNT(*) FROM regions;"}]}
|
2 |
+
{"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"List all region names."},{"role":"assistant","content":"SELECT region_name FROM regions;"}]}
|
3 |
+
{"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Retrieve region_ids and names for regions starting with 'West'."},{"role":"assistant","content":"SELECT region_id, region_name FROM regions WHERE region_name LIKE 'West%';"}]}
|
4 |
+
{"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Show me region names in alphabetical order."},{"role":"assistant","content":"SELECT region_name FROM regions ORDER BY region_name ASC;"}]}
|
5 |
+
{"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Count regions whose names contain 'East'."},{"role":"assistant","content":"SELECT COUNT(*) FROM regions WHERE region_name LIKE '%East%';"}]}
|
6 |
+
{"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Get the maximum region_id."},{"role":"assistant","content":"SELECT MAX(region_id) FROM regions;"}]}
|
7 |
+
{"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Select all regions with region_id > 3."},{"role":"assistant","content":"SELECT * FROM regions WHERE region_id > 3;"}]}
|
8 |
+
{"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Retrieve the second highest region_id."},{"role":"assistant","content":"SELECT region_id FROM regions ORDER BY region_id DESC LIMIT 1 OFFSET 1;"}]}
|
9 |
+
{"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"List the first three regions by ID."},{"role":"assistant","content":"SELECT * FROM regions ORDER BY region_id ASC LIMIT 3;"}]}
|
10 |
+
{"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"List region names in reverse alphabetical order."},{"role":"assistant","content":"SELECT region_name FROM regions ORDER BY region_name DESC;"}]}
|
11 |
+
{"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Which regions have even IDs?"},{"role":"assistant","content":"SELECT * FROM regions WHERE region_id % 2 = 0;"}]}
|
12 |
+
{"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"How many products are available?"},{"role":"assistant","content":"SELECT COUNT(*) FROM products;"}]}
|
13 |
+
{"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"List all product names and categories."},{"role":"assistant","content":"SELECT product_name, category FROM products;"}]}
|
14 |
+
{"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Find products priced above 100."},{"role":"assistant","content":"SELECT * FROM products WHERE price > 100;"}]}
|
15 |
+
{"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Average product price per category."},{"role":"assistant","content":"SELECT category, AVG(price) FROM products GROUP BY category;"}]}
|
16 |
+
{"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Retrieve the most expensive product."},{"role":"assistant","content":"SELECT * FROM products ORDER BY price DESC LIMIT 1;"}]}
|
17 |
+
{"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Count how many 'Electronics' products exist."},{"role":"assistant","content":"SELECT COUNT(*) FROM products WHERE category = 'Electronics';"}]}
|
18 |
+
{"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Show products priced between 50 and 150."},{"role":"assistant","content":"SELECT * FROM products WHERE price BETWEEN 50 AND 150;"}]}
|
19 |
+
{"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"List categories and product counts."},{"role":"assistant","content":"SELECT category, COUNT(*) FROM products GROUP BY category;"}]}
|
20 |
+
{"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Get the total sum of all product prices."},{"role":"assistant","content":"SELECT SUM(price) FROM products;"}]}
|
21 |
+
{"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Retrieve the cheapest product name."},{"role":"assistant","content":"SELECT product_name FROM products ORDER BY price ASC LIMIT 1;"}]}
|
22 |
+
{"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Which products end with 'Pro'?"},{"role":"assistant","content":"SELECT product_name FROM products WHERE product_name LIKE '%Pro';"}]}
|
23 |
+
{"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"How many customers are registered?"},{"role":"assistant","content":"SELECT COUNT(*) FROM customers;"}]}
|
24 |
+
{"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"List all customer names and emails."},{"role":"assistant","content":"SELECT customer_name, email FROM customers;"}]}
|
25 |
+
{"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Find customers in region 2."},{"role":"assistant","content":"SELECT * FROM customers WHERE region_id = 2;"}]}
|
26 |
+
{"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Show customers with an email address."},{"role":"assistant","content":"SELECT customer_name FROM customers WHERE email IS NOT NULL;"}]}
|
27 |
+
{"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Count customers per region."},{"role":"assistant","content":"SELECT region_id, COUNT(*) FROM customers GROUP BY region_id;"}]}
|
28 |
+
{"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"List customers whose name starts with 'A'."},{"role":"assistant","content":"SELECT * FROM customers WHERE customer_name LIKE 'A%';"}]}
|
29 |
+
{"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Retrieve emails sorted alphabetically."},{"role":"assistant","content":"SELECT email FROM customers WHERE email IS NOT NULL ORDER BY email ASC;"}]}
|
30 |
+
{"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Get the most recent customer by ID."},{"role":"assistant","content":"SELECT * FROM customers ORDER BY customer_id DESC LIMIT 1;"}]}
|
31 |
+
{"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Which customers have no region assigned?"},{"role":"assistant","content":"SELECT * FROM customers WHERE region_id IS NULL;"}]}
|
32 |
+
{"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"How many sales have occurred?"},{"role":"assistant","content":"SELECT COUNT(*) FROM sales;"}]}
|
33 |
+
{"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"List sale_id, sale_date, and amount for all sales."},{"role":"assistant","content":"SELECT sale_id, sale_date, amount FROM sales;"}]}
|
34 |
+
{"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Find sales between '2025-01-01' and '2025-06-01'."},{"role":"assistant","content":"SELECT * FROM sales WHERE sale_date BETWEEN '2025-01-01' AND '2025-06-01';"}]}
|
35 |
+
{"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Sum total sales amount by region."},{"role":"assistant","content":"SELECT region_id, SUM(amount) FROM sales GROUP BY region_id;"}]}
|
36 |
+
{"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Average quantity sold per product."},{"role":"assistant","content":"SELECT product_id, AVG(quantity) FROM sales GROUP BY product_id;"}]}
|
37 |
+
{"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Retrieve the highest single sale amount."},{"role":"assistant","content":"SELECT MAX(amount) FROM sales;"}]}
|
38 |
+
{"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Count sales for customer 5."},{"role":"assistant","content":"SELECT COUNT(*) FROM sales WHERE customer_id = 5;"}]}
|
39 |
+
{"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Show sales of product 3 ordered by date descending."},{"role":"assistant","content":"SELECT * FROM sales WHERE product_id = 3 ORDER BY sale_date DESC;"}]}
|
40 |
+
{"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Get the total amount for all sales."},{"role":"assistant","content":"SELECT SUM(amount) FROM sales;"}]}
|
41 |
+
{"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"What was the quantity sold on the earliest sale?"},{"role":"assistant","content":"SELECT quantity FROM sales ORDER BY sale_date ASC LIMIT 1;"}]}
|
42 |
+
{"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"For each customer, show total sales amount."},{"role":"assistant","content":"SELECT customer_id, SUM(amount) FROM sales GROUP BY customer_id;"}]}
|
43 |
+
{"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"List top 5 customers by total amount spent."},{"role":"assistant","content":"SELECT customer_id, SUM(amount) as total_spent FROM sales GROUP BY customer_id ORDER BY total_spent DESC LIMIT 5;"}]}
|
44 |
+
{"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Join customers and sales to get names and sale amounts."},{"role":"assistant","content":"SELECT c.customer_name, s.amount FROM customers c JOIN sales s ON c.customer_id = s.customer_id;"}]}
|
45 |
+
{"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Find the most popular product by sales count."},{"role":"assistant","content":"SELECT product_id, COUNT(*) as cnt FROM sales GROUP BY product_id ORDER BY cnt DESC LIMIT 1;"}]}
|
46 |
+
{"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Show average sale amount per region for May 2025."},{"role":"assistant","content":"SELECT region_id, AVG(amount) FROM sales WHERE sale_date LIKE '2025-05-%' GROUP BY region_id;"}]}
|
47 |
+
{"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"List the sale_date and customer_name for every sale."},{"role":"assistant","content":"SELECT s.sale_date, c.customer_name FROM sales s JOIN customers c ON s.customer_id = c.customer_id;"}]}
|
48 |
+
{"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Show customer_name, product_name, quantity, and amount for each sale."},{"role":"assistant","content":"SELECT c.customer_name, p.product_name, s.quantity, s.amount FROM sales s JOIN customers c ON s.customer_id = c.customer_id JOIN products p ON s.product_id = p.product_id;"}]}
|
49 |
+
{"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Calculate total sales amount per product category."},{"role":"assistant","content":"SELECT p.category, SUM(s.amount) AS total_revenue FROM sales s JOIN products p ON s.product_id = p.product_id GROUP BY p.category;"}]}
|
50 |
+
{"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Get the top 3 product categories by total revenue."},{"role":"assistant","content":"SELECT p.category, SUM(s.amount) AS revenue FROM sales s JOIN products p ON s.product_id = p.product_id GROUP BY p.category ORDER BY revenue DESC LIMIT 3;"}]}
|
51 |
+
{"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"List customer_name and region_name for all sales over 100."},{"role":"assistant","content":"SELECT c.customer_name, r.region_name FROM sales s JOIN customers c ON s.customer_id = c.customer_id JOIN regions r ON s.region_id = r.region_id WHERE s.amount > 100;"}]}
|
52 |
+
{"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Find all products that have never been sold."},{"role":"assistant","content":"SELECT p.product_name FROM products p LEFT JOIN sales s ON p.product_id = s.product_id WHERE s.sale_id IS NULL;"}]}
|
53 |
+
{"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"List regions that had no sales in June 2025."},{"role":"assistant","content":"SELECT r.region_name FROM regions r LEFT JOIN sales s ON r.region_id = s.region_id AND s.sale_date LIKE '2025-06-%' WHERE s.sale_id IS NULL;"}]}
|
54 |
+
{"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Show a running total of amount over sale_date."},{"role":"assistant","content":"SELECT sale_date, amount, SUM(amount) OVER (ORDER BY sale_date) AS running_total FROM sales;"}]}
|
55 |
+
{"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Rank customers by total spent."},{"role":"assistant","content":"SELECT customer_id, SUM(amount) AS total_spent, RANK() OVER (ORDER BY SUM(amount) DESC) AS rank FROM sales GROUP BY customer_id;"}]}
|
56 |
+
{"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Calculate the average sale amount per calendar month."},{"role":"assistant","content":"SELECT SUBSTR(sale_date,1,7) AS month, AVG(amount) FROM sales GROUP BY month;"}]}
|
57 |
+
{"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Compute each region’s percentage of total sales amount."},{"role":"assistant","content":"SELECT region_id, SUM(amount)*100.0/(SELECT SUM(amount) FROM sales) AS pct_of_total FROM sales GROUP BY region_id;"}]}
|
58 |
+
{"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"For each product, show its price and the average sale amount."},{"role":"assistant","content":"SELECT p.product_name, p.price, (SELECT AVG(amount) FROM sales WHERE product_id = p.product_id) AS avg_sale FROM products p;"}]}
|
59 |
+
{"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"List customers who spent more than the average customer."},{"role":"assistant","content":"SELECT DISTINCT c.customer_name FROM customers c WHERE (SELECT SUM(amount) FROM sales WHERE customer_id = c.customer_id) > (SELECT AVG(total) FROM (SELECT SUM(amount) AS total FROM sales GROUP BY customer_id));"}]}
|
60 |
+
{"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Use a CTE to show monthly total sales."},{"role":"assistant","content":"WITH monthly_sales AS (SELECT SUBSTR(sale_date,1,7) AS month, SUM(amount) AS total FROM sales GROUP BY month) SELECT * FROM monthly_sales;"}]}
|
61 |
+
{"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Use a CTE to find the top 5 products by revenue."},{"role":"assistant","content":"WITH prod_rev AS (SELECT product_id, SUM(amount) AS revenue FROM sales GROUP BY product_id) SELECT p.product_name, pr.revenue FROM prod_rev pr JOIN products p ON pr.product_id = p.product_id ORDER BY pr.revenue DESC LIMIT 5;"}]}
|
62 |
+
{"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Find each customer’s first sale date."},{"role":"assistant","content":"SELECT customer_id, MIN(sale_date) AS first_sale FROM sales GROUP BY customer_id;"}]}
|
63 |
+
{"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Identify repeat customers (more than one sale)."},{"role":"assistant","content":"SELECT customer_id FROM sales GROUP BY customer_id HAVING COUNT(*) > 1;"}]}
|
64 |
+
{"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Find customers who purchased both product 1 and product 2."},{"role":"assistant","content":"SELECT DISTINCT s1.customer_id FROM sales s1 JOIN sales s2 ON s1.customer_id = s2.customer_id WHERE s1.product_id = 1 AND s2.product_id = 2;"}]}
|
65 |
+
{"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"List products sold in every region."},{"role":"assistant","content":"SELECT product_id FROM sales GROUP BY product_id HAVING COUNT(DISTINCT region_id) = (SELECT COUNT(*) FROM regions);"}]}
|
66 |
+
{"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Determine which region has the highest average sale amount."},{"role":"assistant","content":"SELECT region_id FROM sales GROUP BY region_id ORDER BY AVG(amount) DESC LIMIT 1;"}]}
|
knowledge_base/business_glossary/kpi_definitions.md
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Business Glossary and KPI Definitions
|
2 |
+
|
3 |
+
This document provides definitions for key business terms and performance indicators used in our sales analysis.
|
4 |
+
|
5 |
+
---
|
6 |
+
|
7 |
+
## Key Performance Indicators (KPIs)
|
8 |
+
|
9 |
+
### Average Sales
|
10 |
+
**Definition:** The average monetary value of a single sales transaction. It is calculated by dividing the total sales amount by the total number of sales transactions over a given period.
|
11 |
+
**Calculation:** `SUM(amount) / COUNT(sale_id)` from the `sales` table.
|
12 |
+
|
13 |
+
### Sales Volume
|
14 |
+
**Definition:** The total number of products sold over a specified period. This indicates the quantity of goods moved.
|
15 |
+
**Calculation:** `SUM(quantity)` from the `sales` table.
|
16 |
+
|
17 |
+
### Revenue by Region
|
18 |
+
**Definition:** The total monetary value of sales attributed to a specific geographical region over a given period.
|
19 |
+
**Calculation:** `SUM(amount)` grouped by `region_name` from `sales` and `regions` tables.
|
20 |
+
|
21 |
+
---
|
22 |
+
|
23 |
+
## Business Terms
|
24 |
+
|
25 |
+
### Customer Segment
|
26 |
+
**Definition:** A group of customers sharing similar characteristics, often based on demographics, purchasing behavior, or psychographics. (Note: We don't have explicit segments in the current DB, but this is an example of what you'd define).
|
27 |
+
|
28 |
+
### Last Quarter
|
29 |
+
**Definition:** Refers to the most recently completed three-month period. For example, if the current month is June, the last quarter would typically be January-March. If it's July, the last quarter would be April-June.
|
30 |
+
|
31 |
+
---
|
knowledge_base/schema/data_dictionary.md
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Data Dictionary for Sales Database
|
2 |
+
|
3 |
+
This document describes the tables and columns within the sales database.
|
4 |
+
|
5 |
+
---
|
6 |
+
|
7 |
+
## Table: `regions`
|
8 |
+
**Purpose:** Stores information about geographical sales regions.
|
9 |
+
|
10 |
+
| Column | Type | Description |
|
11 |
+
| :---------- | :------ | :-------------------------------------------- |
|
12 |
+
| `region_id` | INTEGER | Unique identifier for each sales region. |
|
13 |
+
| `region_name`| TEXT | Name of the sales region (e.g., 'North', 'South'). |
|
14 |
+
|
15 |
+
---
|
16 |
+
|
17 |
+
## Table: `products`
|
18 |
+
**Purpose:** Contains details about products sold.
|
19 |
+
|
20 |
+
| Column | Type | Description |
|
21 |
+
| :------------ | :------ | :---------------------------------------------- |
|
22 |
+
| `product_id` | INTEGER | Unique identifier for each product. |
|
23 |
+
| `product_name`| TEXT | Name of the product (e.g., 'Laptop', 'Pen Set').|
|
24 |
+
| `category` | TEXT | Product category (e.g., 'Electronics', 'Furniture').|
|
25 |
+
| `price` | REAL | Standard price of the product. |
|
26 |
+
|
27 |
+
---
|
28 |
+
|
29 |
+
## Table: `customers`
|
30 |
+
**Purpose:** Stores customer information.
|
31 |
+
|
32 |
+
| Column | Type | Description |
|
33 |
+
| :------------ | :------ | :---------------------------------------------- |
|
34 |
+
| `customer_id` | INTEGER | Unique identifier for each customer. |
|
35 |
+
| `customer_name`| TEXT | Full name of the customer. |
|
36 |
+
| `email` | TEXT | Customer's email address, unique per customer. |
|
37 |
+
| `region_id` | INTEGER | ID of the region the customer belongs to. Links to `regions.region_id`. |
|
38 |
+
|
39 |
+
---
|
40 |
+
|
41 |
+
## Table: `sales`
|
42 |
+
**Purpose:** Records individual sales transactions.
|
43 |
+
|
44 |
+
| Column | Type | Description |
|
45 |
+
| :------------ | :------ | :---------------------------------------------- |
|
46 |
+
| `sale_id` | INTEGER | Unique identifier for each sale transaction. |
|
47 |
+
| `product_id` | INTEGER | ID of the product sold. Links to `products.product_id`. |
|
48 |
+
| `customer_id` | INTEGER | ID of the customer involved in the sale. Links to `customers.customer_id`. |
|
49 |
+
| `region_id` | INTEGER | ID of the region where the sale occurred. This can also be derived from the customer's region. Links to `regions.region_id`. |
|
50 |
+
| `sale_date` | TEXT | Date of the sale in 'YYYY-MM-DD' format. |
|
51 |
+
| `quantity` | INTEGER | Number of units sold in this transaction. |
|
52 |
+
| `amount` | REAL | Total monetary value of the sale (quantity * price, potentially with variations). |
|
53 |
+
|
54 |
+
---
|
55 |
+
|
56 |
+
## Relationships:
|
57 |
+
* `customers.region_id` links to `regions.region_id`
|
58 |
+
* `sales.product_id` links to `products.product_id`
|
59 |
+
* `sales.customer_id` links to `customers.customer_id`
|
60 |
+
* `sales.region_id` links to `regions.region_id`
|
knowledge_base/schema/sales_schema.sql
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
CREATE TABLE regions (
|
2 |
+
region_id INTEGER PRIMARY KEY,
|
3 |
+
region_name TEXT NOT NULL UNIQUE
|
4 |
+
);
|
5 |
+
|
6 |
+
CREATE TABLE products (
|
7 |
+
product_id INTEGER PRIMARY KEY,
|
8 |
+
product_name TEXT NOT NULL,
|
9 |
+
category TEXT NOT NULL,
|
10 |
+
price REAL NOT NULL
|
11 |
+
);
|
12 |
+
|
13 |
+
CREATE TABLE customers (
|
14 |
+
customer_id INTEGER PRIMARY KEY,
|
15 |
+
customer_name TEXT NOT NULL,
|
16 |
+
email TEXT UNIQUE,
|
17 |
+
region_id INTEGER,
|
18 |
+
FOREIGN KEY (region_id) REFERENCES regions(region_id)
|
19 |
+
);
|
20 |
+
|
21 |
+
CREATE TABLE sales (
|
22 |
+
sale_id INTEGER PRIMARY KEY AUTOINCREMENT,
|
23 |
+
product_id INTEGER NOT NULL,
|
24 |
+
customer_id INTEGER NOT NULL,
|
25 |
+
region_id INTEGER NOT NULL,
|
26 |
+
sale_date TEXT NOT NULL, -- Stored as YYYY-MM-DD
|
27 |
+
quantity INTEGER NOT NULL,
|
28 |
+
amount REAL NOT NULL,
|
29 |
+
FOREIGN KEY (product_id) REFERENCES products(product_id),
|
30 |
+
FOREIGN KEY (customer_id) REFERENCES customers(customer_id),
|
31 |
+
FOREIGN KEY (region_id) REFERENCES regions(region_id)
|
32 |
+
);
|
requirements.txt
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
python-dotenv
|
2 |
+
llama-index
|
3 |
+
llama-index-embeddings-nebius
|
4 |
+
llama-index-vector-stores-chroma
|
5 |
+
sqlalchemy
|
6 |
+
llama-index-llms-nebius
|
7 |
+
openai
|
8 |
+
tqdm
|
src/agents/agent_models/__pycache__/models.cpython-313.pyc
ADDED
Binary file (973 Bytes). View file
|
|
src/agents/agent_models/models.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from llama_index.llms.nebius import NebiusLLM
|
3 |
+
from openai import OpenAI
|
4 |
+
from dotenv import load_dotenv
|
5 |
+
load_dotenv()
|
6 |
+
|
7 |
+
base_agent_model_id = "Qwen/Qwen3-235B-A22B"
|
8 |
+
finetuned_model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct-LoRa:nl-to-sql-finetuned-jbkN"
|
9 |
+
|
10 |
+
def get_base_agent_model():
|
11 |
+
return NebiusLLM(api_key=os.environ["NEBIUS_API_KEY"], model=base_agent_model_id)
|
12 |
+
|
13 |
+
def get_finetuned_model():
|
14 |
+
return NebiusLLM(api_key=os.environ["NEBIUS_API_KEY"], model=finetuned_model_id)
|
src/agents/agent_tools/__pycache__/schema_retriever_tool.cpython-313.pyc
ADDED
Binary file (5.08 kB). View file
|
|
src/agents/agent_tools/__pycache__/sql_executor_tool.cpython-313.pyc
ADDED
Binary file (3.52 kB). View file
|
|
src/agents/agent_tools/schema_retriever_tool.py
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# src/agents/tools/schema_retriever_tool.py
|
2 |
+
|
3 |
+
import os
|
4 |
+
import logging
|
5 |
+
import chromadb
|
6 |
+
from llama_index.core.tools import FunctionTool
|
7 |
+
from llama_index.embeddings.nebius import NebiusEmbedding
|
8 |
+
from llama_index.vector_stores.chroma import ChromaVectorStore
|
9 |
+
from llama_index.core import VectorStoreIndex, Settings
|
10 |
+
|
11 |
+
logging.basicConfig(level=logging.INFO)
|
12 |
+
|
13 |
+
# Define the path to your ChromaDB for schema.
|
14 |
+
current_file_dir = os.path.dirname(os.path.abspath(__file__))
|
15 |
+
# Navigate up to the project root for the ChromaDB directory
|
16 |
+
# This path is relative to the *tool file*, so three levels up to the root.
|
17 |
+
CHROMA_DB_PATH = os.path.join(current_file_dir, '..', '..', '..', 'chroma_db_schema')
|
18 |
+
logging.info(f"ChromaDB Schema Path set to: {CHROMA_DB_PATH}")
|
19 |
+
|
20 |
+
# Initialize NebiusEmbedding
|
21 |
+
embed_model_name = "BAAI/bge-en-icl"
|
22 |
+
embed_api_base = "https://api.studio.nebius.com/v1/"
|
23 |
+
|
24 |
+
embeddings = None
|
25 |
+
try:
|
26 |
+
embeddings = NebiusEmbedding(
|
27 |
+
api_key=os.environ.get("NEBIUS_API_KEY"),
|
28 |
+
model_name=embed_model_name,
|
29 |
+
api_base=embed_api_base
|
30 |
+
)
|
31 |
+
Settings.embed_model = embeddings
|
32 |
+
# Test the embedding model
|
33 |
+
_ = embeddings.get_text_embedding("test validation string")
|
34 |
+
logging.info("NebiusEmbedding initialized successfully for schema retriever.")
|
35 |
+
except Exception as e:
|
36 |
+
logging.error(f"Error initializing NebiusEmbedding in schema_retriever_tool: {e}")
|
37 |
+
embeddings = None
|
38 |
+
|
39 |
+
# Set the global embedding model for LlamaIndex if not already set (good practice)
|
40 |
+
if embeddings:
|
41 |
+
Settings.embed_model = embeddings
|
42 |
+
|
43 |
+
# Main retrieval function
|
44 |
+
def retrieve_schema_context(natural_language_query: str) -> str:
|
45 |
+
if embeddings is None:
|
46 |
+
return "Error: Embedding model not initialized for schema retrieval. Cannot perform RAG. Please check your Nebius API key and model configuration."
|
47 |
+
|
48 |
+
try:
|
49 |
+
db = chromadb.PersistentClient(path=CHROMA_DB_PATH)
|
50 |
+
chroma_collection = db.get_or_create_collection(name="schema_kb")
|
51 |
+
logging.info(f"ChromaDB collection 'schema_kb' opened successfully for retrieval.")
|
52 |
+
|
53 |
+
vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
|
54 |
+
|
55 |
+
index = VectorStoreIndex.from_vector_store(
|
56 |
+
vector_store,
|
57 |
+
embed_model=embeddings # Explicitly pass the initialized embeddings
|
58 |
+
)
|
59 |
+
query_engine = index.as_retriever(similarity_top_k=2)
|
60 |
+
retrieved_nodes = query_engine.retrieve(natural_language_query)
|
61 |
+
|
62 |
+
schema_snippets = [node.get_content() for node in retrieved_nodes]
|
63 |
+
if not schema_snippets:
|
64 |
+
return "No relevant schema context found for your query. Please rephrase or simplify."
|
65 |
+
|
66 |
+
return "Retrieved Database Schema Context (relevant to query):\n" + "\n---\n".join(schema_snippets)
|
67 |
+
|
68 |
+
except Exception as e:
|
69 |
+
logging.exception("Error in retrieve_schema_context:")
|
70 |
+
return f"Error retrieving schema from RAG: {str(e)}. Ensure ChromaDB is built at {CHROMA_DB_PATH} and embedding model is compatible."
|
71 |
+
|
72 |
+
|
73 |
+
# Exportable tool
|
74 |
+
def get_schema_retriever_tool() -> FunctionTool:
|
75 |
+
return FunctionTool.from_defaults(
|
76 |
+
fn=retrieve_schema_context,
|
77 |
+
name="retrieve_schema_context",
|
78 |
+
description=(
|
79 |
+
"Retrieves relevant database schema information (tables, columns, relationships, descriptions) "
|
80 |
+
"from the sales database knowledge base using semantic search (RAG). Always call this first "
|
81 |
+
"if you need to understand the schema for SQL generation."
|
82 |
+
)
|
83 |
+
)
|
84 |
+
|
85 |
+
# Test entry point
|
86 |
+
if __name__ == "__main__":
|
87 |
+
print("--- Testing Schema Retriever Tool Implementation (with RAG) ---")
|
88 |
+
if not os.environ.get("NEBIUS_API_KEY"):
|
89 |
+
print("Warning: NEBIUS_API_KEY not set. Schema retrieval might fail.")
|
90 |
+
|
91 |
+
schema_tool = get_schema_retriever_tool()
|
92 |
+
|
93 |
+
queries = [
|
94 |
+
"What are the columns in the sales and products tables, and how are they related?",
|
95 |
+
"Show me customer names and their regions.",
|
96 |
+
"What is the purpose of the regions table?",
|
97 |
+
"Tell me about the sales table schema and its purpose."
|
98 |
+
]
|
99 |
+
|
100 |
+
for query in queries:
|
101 |
+
print(f"\nCalling tool with: '{query}'")
|
102 |
+
result = schema_tool.call(natural_language_query=query)
|
103 |
+
print(f"Result:\n{result}")
|
src/agents/agent_tools/sql_executor_tool.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sqlite3
|
2 |
+
import pandas as pd
|
3 |
+
import io
|
4 |
+
import os
|
5 |
+
from llama_index.core.tools import FunctionTool
|
6 |
+
|
7 |
+
current_dir = os.path.dirname(os.path.abspath(__file__))
|
8 |
+
DATABASE_PATH = os.path.join(current_dir, '..', '..', '..', 'data', 'sales_database.db')
|
9 |
+
|
10 |
+
def execute_sql_query(sql_query: str) -> str:
|
11 |
+
"""
|
12 |
+
Executes a SQL SELECT query against the sales database and returns the results as a formatted string (CSV representation).
|
13 |
+
|
14 |
+
This tool should only be used to retrieve data using SELECT statements.
|
15 |
+
It does not support INSERT, UPDATE, DELETE, or DDL commands for security reasons.
|
16 |
+
|
17 |
+
Args:
|
18 |
+
sql_query (str): The complete and correct SQL SELECT query to execute.
|
19 |
+
Do not include semicolons at the end of the query.
|
20 |
+
|
21 |
+
Returns:
|
22 |
+
str: A CSV string representation of the query results, or an error message.
|
23 |
+
"""
|
24 |
+
# Safety check: Only allow SELECT queries
|
25 |
+
if not sql_query.strip().upper().startswith("SELECT"):
|
26 |
+
return "Error: Only SELECT queries are allowed for security reasons."
|
27 |
+
|
28 |
+
conn = None
|
29 |
+
try:
|
30 |
+
conn = sqlite3.connect(DATABASE_PATH)
|
31 |
+
df = pd.read_sql_query(sql_query, conn)
|
32 |
+
|
33 |
+
if df.empty:
|
34 |
+
return "Query executed successfully, but no results were found."
|
35 |
+
|
36 |
+
# Convert DataFrame to a CSV string for LLM consumption
|
37 |
+
output = io.StringIO()
|
38 |
+
df.to_csv(output, index=False)
|
39 |
+
return output.getvalue()
|
40 |
+
|
41 |
+
except pd.io.sql.DatabaseError as e:
|
42 |
+
return f"Database Query Error: {e}"
|
43 |
+
except Exception as e:
|
44 |
+
return f"An unexpected error occurred during SQL execution: {e}"
|
45 |
+
finally:
|
46 |
+
if conn:
|
47 |
+
conn.close()
|
48 |
+
|
49 |
+
def get_sql_executor_tool() -> FunctionTool:
|
50 |
+
"""
|
51 |
+
Returns a LlamaIndex FunctionTool for executing SQL SELECT queries.
|
52 |
+
"""
|
53 |
+
return FunctionTool.from_defaults(
|
54 |
+
fn=execute_sql_query,
|
55 |
+
name="execute_sql_query",
|
56 |
+
description=(
|
57 |
+
"Executes a SQL SELECT query against the sales database and returns the results. "
|
58 |
+
"Use this tool to get actual data to answer the user's question. "
|
59 |
+
"Always generate the full, correct SQL query to answer the question before calling this tool. "
|
60 |
+
"Only SELECT queries are allowed. "
|
61 |
+
"Input should be the complete and correct SQL SELECT query string."
|
62 |
+
)
|
63 |
+
)
|
src/agents/nl_sql_agent.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import logging
|
3 |
+
import sys
|
4 |
+
from llama_index.llms.openai import OpenAI
|
5 |
+
from llama_index.core.agent import ReActAgent
|
6 |
+
from agent_tools.sql_executor_tool import get_sql_executor_tool
|
7 |
+
from agent_tools.schema_retriever_tool import get_schema_retriever_tool
|
8 |
+
from agent_models.models import get_finetuned_model
|
9 |
+
|
10 |
+
# Configure logging for better visibility into agent's thought process
|
11 |
+
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
|
12 |
+
logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))
|
13 |
+
|
14 |
+
class NLSQLAgent:
|
15 |
+
def __init__(self, model_name: str, api_key: str, base_url: str):
|
16 |
+
"""
|
17 |
+
Initializes the NL-to-SQL Agent, which translates natural language to SQL, executes it, and provides answers.
|
18 |
+
|
19 |
+
Args:
|
20 |
+
model_name (str): The name/ID of fine-tuned model deployed on Nebius AI.
|
21 |
+
api_key (str): Nebius AI API key.
|
22 |
+
base_url (str): The base URL for the Nebius AI API endpoint
|
23 |
+
"""
|
24 |
+
self.llm = get_finetuned_model()
|
25 |
+
self.system_prompt = (
|
26 |
+
"You are an expert SQL assistant for a sales database. Your task is to accurately translate "
|
27 |
+
"natural language questions into SQL queries, execute them using the provided tools, and then "
|
28 |
+
"provide concise, natural language answers based on the query results. "
|
29 |
+
"You MUST strictly follow the Thought-Action-Action Input-Observation format for ALL steps. "
|
30 |
+
"Pay extreme attention to the 'Action Input' format, which MUST be a valid JSON dictionary "
|
31 |
+
"matching the tool's parameters. **ALWAYS provide ALL required parameters.**\n"
|
32 |
+
"Follow these rules strictly:\n"
|
33 |
+
"1. **Analyze the query:** Understand the user's intent and identify if schema context is needed.\n"
|
34 |
+
"2. **Schema Retrieval (if needed):** If the user's question requires knowledge about the "
|
35 |
+
" database structure (tables, columns, relationships) that you might not implicitly know, "
|
36 |
+
" **FIRST, ALWAYS call `retrieve_schema_context`** using the user's original query. "
|
37 |
+
" The Action Input for `retrieve_schema_context` MUST be a JSON object like: "
|
38 |
+
" `{\"natural_language_query\": \"user's exact question here\"}`. "
|
39 |
+
" Process the retrieved schema before generating SQL.\n"
|
40 |
+
"3. **SQL Generation (Crucial):** Generate a syntactically correct and semantically accurate SQL SELECT query "
|
41 |
+
" that directly answers the user's question based on the available schema information (either "
|
42 |
+
" from `retrieve_schema_context` or your fine-tuned knowledge). Only use tables and columns "
|
43 |
+
" that exist in the schema. Do not include semicolons at the end of the query.\n"
|
44 |
+
"4. **SQL Execution (MANDATORY):** **IMMEDIATELY AFTER generating a SQL query, you MUST call the `execute_sql_query` tool to run it.** "
|
45 |
+
" The Action Input for `execute_sql_query` MUST be a JSON object like: "
|
46 |
+
" `{\"sql_query\": \"your generated SQL query here\"}`.\n"
|
47 |
+
"5. **Answer Formulation (Only after SQL Execution and Observation):** Analyze the results from `execute_sql_query` "
|
48 |
+
" and formulate a clear, concise, and helpful natural language answer to the user's original question. "
|
49 |
+
" If the query returns no results, state that clearly.\n"
|
50 |
+
"6. **Error Handling:** If a question cannot be answered from the database, explain why politely. "
|
51 |
+
" Do not make up data or generate SQL for non-existent tables/columns.\n"
|
52 |
+
"7. **Iterative Refinement:** If a tool call fails or provides unexpected output, consider revising your approach "
|
53 |
+
" or informing the user about the issue. Continue the thought-process until a final answer is derived.\n"
|
54 |
+
"8. The database schema includes tables such as `regions`, `products`, `customers`, and `sales`. "
|
55 |
+
" You should understand their structures and relationships from your training and schema context. "
|
56 |
+
" For example: `sales` table likely has columns like `sale_id`, `customer_id`, `product_id`, `date`, `amount`. "
|
57 |
+
" `customers` table might have `customer_id`, `customer_name`, `region_id`. `products` table might have `product_id`, `product_name`, `price`. "
|
58 |
+
" `regions` table might have `region_id`, `region_name`."
|
59 |
+
"Remember, your final output should be an `Answer:` tag with the natural language response, not raw SQL or tool calls unless explicitly asked to show your work."
|
60 |
+
)
|
61 |
+
|
62 |
+
self.tools = [get_schema_retriever_tool(), get_sql_executor_tool()]
|
63 |
+
|
64 |
+
self.agent = ReActAgent.from_tools(
|
65 |
+
llm=self.llm,
|
66 |
+
tools=self.tools,
|
67 |
+
context=self.system_prompt,
|
68 |
+
verbose=True,
|
69 |
+
)
|
70 |
+
|
71 |
+
async def process_query(self, user_query: str) -> str:
|
72 |
+
"""
|
73 |
+
Processes a user's natural language query using the NL-to-SQL agent.
|
74 |
+
This method executes the agent's Thought-Action-Observation loop.
|
75 |
+
|
76 |
+
Args:
|
77 |
+
user_query (str): The natural language question from the user.
|
78 |
+
|
79 |
+
Returns:
|
80 |
+
str: The final natural language answer based on SQL execution, or an error/explanation.
|
81 |
+
"""
|
82 |
+
try:
|
83 |
+
response_object = await self.agent.achat(user_query)
|
84 |
+
return str(response_object)
|
85 |
+
except Exception as e:
|
86 |
+
logging.error(f"Error in NLSQLAgent.process_query: {e}")
|
87 |
+
return f"I encountered an error while processing your request: {e}. Please try again or rephrase."
|
88 |
+
|
89 |
+
# Example Usage (for testing the NLSQLAgent directly)
|
90 |
+
if __name__ == "__main__":
|
91 |
+
NEBIUS_API_KEY = os.environ.get("NEBIUS_API_KEY")
|
92 |
+
NEBIUS_BASE_URL = "https://api.studio.nebius.com/v1/openai" # Example: Verify with Nebius AI docs
|
93 |
+
YOUR_FINETUNED_MODEL_ID = "meta-llama/Meta-Llama-3.1-8B-Instruct-LoRa:nl-to-sql-finetuned-jbkN" # Replace this
|
94 |
+
|
95 |
+
if not NEBIUS_API_KEY:
|
96 |
+
print("Error: NEBIUS_API_KEY environment variable not set.")
|
97 |
+
print("Please set it before running the agent.")
|
98 |
+
elif YOUR_FINETUNED_MODEL_ID == "your-finetuned-model-id":
|
99 |
+
print("Error: Please replace 'your-finetuned-model-id' with your actual model ID.")
|
100 |
+
else:
|
101 |
+
nl_sql_agent = NLSQLAgent(
|
102 |
+
model_name=YOUR_FINETUNED_MODEL_ID,
|
103 |
+
api_key=NEBIUS_API_KEY,
|
104 |
+
base_url=NEBIUS_BASE_URL
|
105 |
+
)
|
106 |
+
|
107 |
+
print("\nNL-to-SQL Agent initialized. Ask a question about your sales database.")
|
108 |
+
|
109 |
+
async def main_loop():
|
110 |
+
while True:
|
111 |
+
user_input = input("\nYou: ")
|
112 |
+
if user_input.lower() in ["exit", "quit"]:
|
113 |
+
print("Exiting agent.")
|
114 |
+
break
|
115 |
+
|
116 |
+
response = await nl_sql_agent.process_query(user_input)
|
117 |
+
print(f"Agent: {response}")
|
118 |
+
|
119 |
+
# For multi-turn conversations, ReActAgent handles history automatically.
|
120 |
+
# If you need to explicitly reset, you might re-instantiate or call agent.reset()
|
121 |
+
|
122 |
+
import asyncio
|
123 |
+
asyncio.run(main_loop())
|
src/agents/orchestrator_agent.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import logging
|
3 |
+
from llama_index.core.agent import ReActAgent
|
4 |
+
from agents.models import get_base_agent_model
|
5 |
+
|
6 |
+
#Logging
|
7 |
+
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
|
8 |
+
logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))
|
9 |
+
|
10 |
+
class OrchestratorAgent:
|
11 |
+
def __init__(self):
|
12 |
+
self.tools = []
|
13 |
+
self.llm = get_base_agent_model()
|
14 |
+
self.agent = ReActAgent.from_tools(llm=self.llm, tools=self.tools, context=self.system_prompt, verbose=True)
|
15 |
+
|
16 |
+
async def __call__(self, question: str) -> str:
|
17 |
+
response_object = self.agent.chat(question)
|
18 |
+
full_response_text = str(response_object)
|
19 |
+
|
20 |
+
return full_response_text
|
src/rag_index.py
ADDED
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dotenv import load_dotenv
|
2 |
+
load_dotenv()
|
3 |
+
import os
|
4 |
+
import chromadb
|
5 |
+
import re
|
6 |
+
import logging
|
7 |
+
from sqlalchemy import create_engine
|
8 |
+
from uuid import uuid4
|
9 |
+
from llama_index.core import SimpleDirectoryReader, VectorStoreIndex, Settings
|
10 |
+
from llama_index.vector_stores.chroma import ChromaVectorStore
|
11 |
+
from llama_index.embeddings.nebius import NebiusEmbedding
|
12 |
+
from llama_index.core import SQLDatabase
|
13 |
+
from llama_index.core.schema import TextNode
|
14 |
+
|
15 |
+
# Configure logging
|
16 |
+
logging.basicConfig(level=logging.INFO)
|
17 |
+
|
18 |
+
# --- Configuration ---
|
19 |
+
DATABASE_PATH = os.path.join('data', 'sales_database.db')
|
20 |
+
|
21 |
+
# CORRECTED PATHS based on your clarification:
|
22 |
+
# schema is under knowledge_base
|
23 |
+
SCHEMA_DIR = os.path.join('knowledge_base', 'schema')
|
24 |
+
DATA_DICTIONARY_PATH = os.path.join(SCHEMA_DIR, 'data_dictionary.md')
|
25 |
+
SALES_SCHEMA_SQL_PATH = os.path.join(SCHEMA_DIR, 'sales_schema.sql')
|
26 |
+
|
27 |
+
# kpi_definitions.md is under business_glossary
|
28 |
+
KPI_DEFINITIONS_PATH = os.path.join('knowledge_base', 'business_glossary', 'kpi_definitions.md')
|
29 |
+
|
30 |
+
CHROMA_DB_SCHEMA_PATH = os.path.join('.', 'chroma_db_schema')
|
31 |
+
CHROMA_DB_KPI_PATH = os.path.join('.', 'chroma_db_kpi')
|
32 |
+
|
33 |
+
# Ensure NEBIUS_API_KEY is set
|
34 |
+
if "NEBIUS_API_KEY" not in os.environ:
|
35 |
+
logging.critical("NEBIUS_API_KEY environment variable not set. Please set it before running this script.")
|
36 |
+
raise ValueError("NEBIUS_API_KEY environment variable not set. Please set it before running this script.")
|
37 |
+
|
38 |
+
# --- Initialize Nebius AI Embedding Model (Global for consistency) ---
|
39 |
+
print("\n--- Initializing Nebius AI Embedding Model ---")
|
40 |
+
try:
|
41 |
+
embed_model = NebiusEmbedding(
|
42 |
+
api_key=os.environ["NEBIUS_API_KEY"],
|
43 |
+
model_name="BAAI/bge-en-icl", # Consistent with schema_retriever_tool for now
|
44 |
+
api_base="https://api.studio.nebius.com/v1/" # Verify this base URL is correct for the model
|
45 |
+
)
|
46 |
+
test_string = "This is a test string to generate an embedding for diagnostic purposes."
|
47 |
+
test_embedding = embed_model.get_text_embedding(test_string)
|
48 |
+
|
49 |
+
if test_embedding is None or not isinstance(test_embedding, list) or not all(isinstance(x, (int, float)) for x in test_embedding):
|
50 |
+
logging.critical(f"FATAL ERROR: NebiusEmbedding returned invalid output for test string. Type: {type(test_embedding)}, Value: {test_embedding[:10] if isinstance(test_embedding, list) else test_embedding}")
|
51 |
+
raise ValueError("NebiusEmbedding test failed: returned invalid embedding. Cannot proceed with RAG index creation.")
|
52 |
+
|
53 |
+
logging.info(f"Nebius AI Embedding Model initialized and tested successfully. Embedding length: {len(test_embedding)}")
|
54 |
+
|
55 |
+
except Exception as e:
|
56 |
+
logging.critical(f"UNRECOVERABLE ERROR during Nebius AI Embedding Model initialization or testing: {e}")
|
57 |
+
print(f"\n!!!! UNRECOVERABLE ERROR: {e} !!!!")
|
58 |
+
print("Please check your NEBIUS_API_KEY, model_name, and api_base configuration carefully.")
|
59 |
+
raise
|
60 |
+
|
61 |
+
print("--- Nebius AI Embedding Model setup complete. Proceeding to Knowledge Base setup. ---")
|
62 |
+
|
63 |
+
# --- Helper Function to parse Data Dictionary ---
|
64 |
+
def parse_data_dictionary_md(md_path: str) -> dict:
|
65 |
+
table_descriptions = {}
|
66 |
+
if not os.path.exists(md_path):
|
67 |
+
logging.error(f"Data dictionary file not found: {md_path}")
|
68 |
+
return {}
|
69 |
+
|
70 |
+
with open(md_path, 'r', encoding='utf-8') as f:
|
71 |
+
content = f.read()
|
72 |
+
|
73 |
+
table_headers = re.finditer(r'## Table: `(\w+)`\n\*\*Purpose:\*\* (.+?)(?=\n## Table: `|\Z)', content, re.DOTALL)
|
74 |
+
|
75 |
+
for match in table_headers:
|
76 |
+
table_name = match.group(1).strip()
|
77 |
+
purpose = match.group(2).strip()
|
78 |
+
table_descriptions[table_name] = purpose
|
79 |
+
logging.info(f"Parsed table: {table_name}, Purpose length: {len(purpose)}")
|
80 |
+
|
81 |
+
return table_descriptions
|
82 |
+
|
83 |
+
# --- Setup for Schema Retriever Agent's Knowledge Base ---
|
84 |
+
print("\n--- Setting up Schema Retriever Agent's Knowledge Base (chroma_db_schema) ---")
|
85 |
+
try:
|
86 |
+
data_dict_descriptions = parse_data_dictionary_md(DATA_DICTIONARY_PATH)
|
87 |
+
logging.info(f"Loaded descriptions for {len(data_dict_descriptions)} tables from data_dictionary.md")
|
88 |
+
|
89 |
+
engine = create_engine(f"sqlite:///{DATABASE_PATH}")
|
90 |
+
sql_database = SQLDatabase(engine)
|
91 |
+
logging.info(f"Connected to database: {DATABASE_PATH}")
|
92 |
+
|
93 |
+
all_table_names = sql_database.get_usable_table_names()
|
94 |
+
|
95 |
+
schema_nodes = []
|
96 |
+
if not all_table_names:
|
97 |
+
logging.warning("WARNING: No tables found in the database. Schema KB will be empty.")
|
98 |
+
|
99 |
+
for table_name in all_table_names:
|
100 |
+
ddl = sql_database.get_single_table_info(table_name)
|
101 |
+
human_description = data_dict_descriptions.get(table_name, "No specific description available for this table.")
|
102 |
+
|
103 |
+
combined_context = f"Table Name: {table_name}\n" \
|
104 |
+
f"Description: {human_description}\n" \
|
105 |
+
f"Table Schema (DDL):\n{ddl if ddl else ''}"
|
106 |
+
|
107 |
+
node_embedding = embed_model.get_text_embedding(combined_context)
|
108 |
+
if node_embedding is None:
|
109 |
+
raise ValueError(f"Failed to generate embedding for schema table: {table_name}")
|
110 |
+
|
111 |
+
# Adding a simple metadata dictionary
|
112 |
+
schema_nodes.append(TextNode(text=combined_context, embedding=node_embedding, id_=table_name, metadata={"table_name": table_name, "source": "data_dictionary"}))
|
113 |
+
|
114 |
+
chroma_client_schema = chromadb.PersistentClient(path=CHROMA_DB_SCHEMA_PATH)
|
115 |
+
chroma_collection_schema = chroma_client_schema.get_or_create_collection(name="schema_kb")
|
116 |
+
|
117 |
+
if chroma_collection_schema.count() > 0:
|
118 |
+
logging.info(f"Clearing {chroma_collection_schema.count()} existing items from schema_kb before re-indexing.")
|
119 |
+
chroma_collection_schema.delete(ids=[id_ for id_ in chroma_collection_schema.get()['ids']])
|
120 |
+
|
121 |
+
|
122 |
+
logging.info(f"Adding {len(schema_nodes)} schema nodes to ChromaDB directly...")
|
123 |
+
if schema_nodes:
|
124 |
+
chroma_collection_schema.add(
|
125 |
+
documents=[node.text for node in schema_nodes],
|
126 |
+
embeddings=[node.embedding for node in schema_nodes],
|
127 |
+
# Pass non-empty metadata dict for each node
|
128 |
+
metadatas=[node.metadata for node in schema_nodes],
|
129 |
+
ids=[node.id_ for node in schema_nodes]
|
130 |
+
)
|
131 |
+
logging.info(f"Schema knowledge base indexed and persisted to {CHROMA_DB_SCHEMA_PATH}")
|
132 |
+
except Exception as e:
|
133 |
+
logging.exception("Error setting up Schema KB:")
|
134 |
+
print(f"Error setting up Schema KB: {e}")
|
135 |
+
print("Please ensure your database is created and accessible and data_dictionary.md is correctly formatted.")
|
136 |
+
print("If the error persists, there might be a deeper compatibility issue. See above for more detailed embedding checks.")
|
137 |
+
raise
|
138 |
+
|
139 |
+
|
140 |
+
# --- Setup for KPI Answering Agent's Knowledge Base (kpi_definitions.md) ---
|
141 |
+
print("\n--- Setting up KPI Answering Agent's Knowledge Base (chroma_db_kpi) ---")
|
142 |
+
try:
|
143 |
+
kpi_docs = SimpleDirectoryReader(input_files=[KPI_DEFINITIONS_PATH]).load_data()
|
144 |
+
logging.info(f"Loaded {len(kpi_docs)} documents for KPI Agent.")
|
145 |
+
|
146 |
+
kpi_nodes = []
|
147 |
+
for doc in kpi_docs:
|
148 |
+
node_embedding = embed_model.get_text_embedding(doc.get_content())
|
149 |
+
if node_embedding is None:
|
150 |
+
raise ValueError(f"Failed to generate embedding for KPI document: {doc.id_}")
|
151 |
+
# Adding a simple metadata dictionary
|
152 |
+
kpi_nodes.append(TextNode(text=doc.get_content(), embedding=node_embedding, id_=str(uuid4()), metadata={"source_file": os.path.basename(KPI_DEFINITIONS_PATH), "doc_type": "kpi_definition"}))
|
153 |
+
|
154 |
+
chroma_client_kpi = chromadb.PersistentClient(path=CHROMA_DB_KPI_PATH)
|
155 |
+
chroma_collection_kpi = chroma_client_kpi.get_or_create_collection(name="kpi_kb")
|
156 |
+
|
157 |
+
if chroma_collection_kpi.count() > 0:
|
158 |
+
logging.info(f"Clearing {chroma_collection_kpi.count()} existing items from kpi_kb before re-indexing.")
|
159 |
+
chroma_collection_kpi.delete(ids=[id_ for id_ in chroma_collection_kpi.get()['ids']])
|
160 |
+
|
161 |
+
logging.info(f"Adding {len(kpi_nodes)} KPI nodes to ChromaDB directly...")
|
162 |
+
if kpi_nodes:
|
163 |
+
chroma_collection_kpi.add(
|
164 |
+
documents=[node.text for node in kpi_nodes],
|
165 |
+
embeddings=[node.embedding for node in kpi_nodes],
|
166 |
+
# Pass non-empty metadata dict for each node
|
167 |
+
metadatas=[node.metadata for node in kpi_nodes],
|
168 |
+
ids=[node.id_ for node in kpi_nodes]
|
169 |
+
)
|
170 |
+
logging.info(f"KPI knowledge base indexed and persisted to {CHROMA_DB_KPI_PATH}")
|
171 |
+
except Exception as e:
|
172 |
+
logging.exception("Error setting up KPI KB:")
|
173 |
+
print(f"Error setting up KPI KB: {e}")
|
174 |
+
print("Please ensure your 'kpi_definitions.md' file exists and contains valid text content.")
|
175 |
+
print("If the error persists, there might be a deeper compatibility issue. See above for more detailed embedding checks.")
|
176 |
+
raise
|
177 |
+
|
178 |
+
|
179 |
+
print("\nKnowledge base setup complete for RAG agents!")
|
180 |
+
print(f"Indices are saved in '{CHROMA_DB_SCHEMA_PATH}' and '{CHROMA_DB_KPI_PATH}' directories.")
|
src/setup_database.py
ADDED
@@ -0,0 +1,242 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sqlite3
|
2 |
+
import os
|
3 |
+
import pandas as pd
|
4 |
+
from datetime import datetime, timedelta
|
5 |
+
import random
|
6 |
+
from tqdm import tqdm
|
7 |
+
|
8 |
+
DATABASE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'data')
|
9 |
+
DATABASE_PATH = os.path.join(DATABASE_DIR, 'sales_database.db')
|
10 |
+
|
11 |
+
def create_database():
|
12 |
+
"""Creates the SQLite database file if it doesn't exist."""
|
13 |
+
os.makedirs(DATABASE_DIR, exist_ok=True)
|
14 |
+
conn = None
|
15 |
+
try:
|
16 |
+
conn = sqlite3.connect(DATABASE_PATH)
|
17 |
+
print(f"Database created successfully at {DATABASE_PATH}")
|
18 |
+
except sqlite3.Error as e:
|
19 |
+
print(f"Error creating database: {e}")
|
20 |
+
finally:
|
21 |
+
if conn:
|
22 |
+
conn.close()
|
23 |
+
|
24 |
+
def create_tables(conn):
|
25 |
+
"""Creates necessary tables in the database."""
|
26 |
+
cursor = conn.cursor()
|
27 |
+
|
28 |
+
# Regions Table
|
29 |
+
cursor.execute('''
|
30 |
+
CREATE TABLE IF NOT EXISTS regions (
|
31 |
+
region_id INTEGER PRIMARY KEY,
|
32 |
+
region_name TEXT NOT NULL UNIQUE
|
33 |
+
);
|
34 |
+
''')
|
35 |
+
|
36 |
+
# Products Table
|
37 |
+
cursor.execute('''
|
38 |
+
CREATE TABLE IF NOT EXISTS products (
|
39 |
+
product_id INTEGER PRIMARY KEY,
|
40 |
+
product_name TEXT NOT NULL,
|
41 |
+
category TEXT NOT NULL,
|
42 |
+
price REAL NOT NULL
|
43 |
+
);
|
44 |
+
''')
|
45 |
+
|
46 |
+
# Customers Table
|
47 |
+
cursor.execute('''
|
48 |
+
CREATE TABLE IF NOT EXISTS customers (
|
49 |
+
customer_id INTEGER PRIMARY KEY,
|
50 |
+
customer_name TEXT NOT NULL,
|
51 |
+
email TEXT UNIQUE,
|
52 |
+
region_id INTEGER,
|
53 |
+
FOREIGN KEY (region_id) REFERENCES regions(region_id)
|
54 |
+
);
|
55 |
+
''')
|
56 |
+
|
57 |
+
# Sales Table
|
58 |
+
cursor.execute('''
|
59 |
+
CREATE TABLE IF NOT EXISTS sales (
|
60 |
+
sale_id INTEGER PRIMARY KEY AUTOINCREMENT,
|
61 |
+
product_id INTEGER NOT NULL,
|
62 |
+
customer_id INTEGER NOT NULL,
|
63 |
+
region_id INTEGER NOT NULL,
|
64 |
+
sale_date TEXT NOT NULL, -- Stored as YYYY-MM-DD
|
65 |
+
quantity INTEGER NOT NULL,
|
66 |
+
amount REAL NOT NULL,
|
67 |
+
FOREIGN KEY (product_id) REFERENCES products(product_id),
|
68 |
+
FOREIGN KEY (customer_id) REFERENCES customers(customer_id),
|
69 |
+
FOREIGN KEY (region_id) REFERENCES regions(region_id)
|
70 |
+
);
|
71 |
+
''')
|
72 |
+
conn.commit()
|
73 |
+
print("Tables created successfully.")
|
74 |
+
|
75 |
+
def insert_dummy_data(conn):
|
76 |
+
"""Inserts sample data into the tables."""
|
77 |
+
cursor = conn.cursor()
|
78 |
+
|
79 |
+
# Clear existing data before inserting new, to ensure fresh start
|
80 |
+
print("Clearing existing data...")
|
81 |
+
cursor.execute("DELETE FROM sales")
|
82 |
+
cursor.execute("DELETE FROM customers")
|
83 |
+
cursor.execute("DELETE FROM products")
|
84 |
+
cursor.execute("DELETE FROM regions")
|
85 |
+
conn.commit()
|
86 |
+
print("Existing data cleared.")
|
87 |
+
|
88 |
+
|
89 |
+
# Insert Regions
|
90 |
+
regions_data = [
|
91 |
+
(1, 'North'),
|
92 |
+
(2, 'South'),
|
93 |
+
(3, 'East'),
|
94 |
+
(4, 'West'),
|
95 |
+
(5, 'Central') # Added a new region
|
96 |
+
]
|
97 |
+
cursor.executemany("INSERT OR IGNORE INTO regions (region_id, region_name) VALUES (?, ?)", regions_data)
|
98 |
+
print(f"Inserted {len(regions_data)} regions.")
|
99 |
+
|
100 |
+
# Insert Products (More variety, including sub-categories implicitly via name)
|
101 |
+
products_data = [
|
102 |
+
(101, 'Laptop Basic', 'Electronics', 800.00),
|
103 |
+
(102, 'Laptop Pro', 'Electronics', 1500.00),
|
104 |
+
(103, 'Smartphone X', 'Electronics', 700.00),
|
105 |
+
(104, 'Smartwatch', 'Electronics', 250.00),
|
106 |
+
(105, 'Headphones Noise-Cancelling', 'Electronics', 180.00),
|
107 |
+
(201, 'Office Desk Standard', 'Furniture', 200.00),
|
108 |
+
(202, 'Ergonomic Chair', 'Furniture', 350.00),
|
109 |
+
(203, 'Bookshelf Small', 'Furniture', 80.00),
|
110 |
+
(301, 'Notebook A4', 'Stationery', 12.00),
|
111 |
+
(302, 'Premium Pen Set', 'Stationery', 30.00),
|
112 |
+
(303, 'Art Supplies Kit', 'Stationery', 45.00),
|
113 |
+
(401, 'Running Shoes', 'Apparel', 110.00), # New Category
|
114 |
+
(402, 'T-Shirt Cotton', 'Apparel', 25.00),
|
115 |
+
]
|
116 |
+
cursor.executemany("INSERT OR IGNORE INTO products (product_id, product_name, category, price) VALUES (?, ?, ?, ?)", products_data)
|
117 |
+
print(f"Inserted {len(products_data)} products.")
|
118 |
+
|
119 |
+
# Insert Customers (More customers)
|
120 |
+
customer_names = [
|
121 |
+
'Alice Smith', 'Bob Johnson', 'Charlie Brown', 'Diana Prince', 'Eve Adams',
|
122 |
+
'Frank White', 'Grace Lee', 'Henry Green', 'Ivy Chen', 'Jack Taylor',
|
123 |
+
'Karen Black', 'Liam Davis', 'Mia Wilson', 'Noah Martinez', 'Olivia Garcia',
|
124 |
+
'Peter Rodriguez', 'Quinn Miller', 'Rachel Jones', 'Sam Hernandez', 'Tina Clark'
|
125 |
+
] # 20 customers
|
126 |
+
customers_data = []
|
127 |
+
for i, name in enumerate(customer_names):
|
128 |
+
# Assign customers to regions, trying to balance or bias
|
129 |
+
region_id = random.choice([1, 1, 2, 3, 4, 5]) # North region (ID 1) slightly more customers
|
130 |
+
customers_data.append((i + 1, name, f"{name.replace(' ', '').lower()}@example.com", region_id))
|
131 |
+
cursor.executemany("INSERT OR IGNORE INTO customers (customer_id, customer_name, email, region_id) VALUES (?, ?, ?, ?)", customers_data)
|
132 |
+
print(f"Inserted {len(customers_data)} customers.")
|
133 |
+
|
134 |
+
|
135 |
+
# --- Sales Data Generation Configuration (MOVED TO HERE) ---
|
136 |
+
NUM_SALES_RECORDS = 15000
|
137 |
+
|
138 |
+
end_date = datetime.now()
|
139 |
+
|
140 |
+
# Ensure a good chunk of sales are recent
|
141 |
+
recent_sales_period_days = 60 # Sales within last 2 months
|
142 |
+
long_term_sales_period_days = 730 # Sales within last 2 years
|
143 |
+
|
144 |
+
# Define how many sales should be recent vs. long-term
|
145 |
+
recent_sales_proportion = 0.3 # 30% of sales in the last 2 months
|
146 |
+
num_recent_sales = int(NUM_SALES_RECORDS * recent_sales_proportion)
|
147 |
+
num_long_term_sales = NUM_SALES_RECORDS - num_recent_sales
|
148 |
+
|
149 |
+
products_df = pd.DataFrame(products_data, columns=['product_id', 'product_name', 'category', 'price'])
|
150 |
+
customers_df = pd.DataFrame(customers_data, columns=['customer_id', 'customer_name', 'email', 'region_id'])
|
151 |
+
regions_df = pd.DataFrame(regions_data, columns=['region_id', 'region_name'])
|
152 |
+
|
153 |
+
sales_records = []
|
154 |
+
|
155 |
+
# Define some regional biases (e.g., North sells more)
|
156 |
+
region_weights = {1: 0.30, 2: 0.25, 3: 0.20, 4: 0.15, 5: 0.10}
|
157 |
+
weighted_regions = [r for r, weight in region_weights.items() for _ in range(int(weight * 100))]
|
158 |
+
|
159 |
+
# Define some category biases (e.g., Electronics sells more)
|
160 |
+
category_weights = {'Electronics': 0.40, 'Furniture': 0.30, 'Stationery': 0.20, 'Apparel': 0.10}
|
161 |
+
all_weighted_products = []
|
162 |
+
for cat, weight in category_weights.items():
|
163 |
+
cat_products = products_df[products_df['category'] == cat]['product_id'].tolist()
|
164 |
+
all_weighted_products.extend(cat_products * int(weight * 100 / (len(cat_products) if cat_products else 1)))
|
165 |
+
|
166 |
+
# Simulate some seasonality (e.g., more sales towards end of year)
|
167 |
+
def get_seasonal_factor(date):
|
168 |
+
month = date.month
|
169 |
+
if month in [11, 12]: return random.uniform(1.1, 1.3)
|
170 |
+
elif month in [1, 2]: return random.uniform(0.8, 1.0)
|
171 |
+
else: return random.uniform(0.9, 1.1)
|
172 |
+
|
173 |
+
# --- End Sales Data Generation Configuration ---
|
174 |
+
|
175 |
+
|
176 |
+
# Generate recent sales
|
177 |
+
for _ in tqdm(range(num_recent_sales), desc="Generating Recent Sales Data"):
|
178 |
+
current_date = end_date - timedelta(days=random.randint(0, recent_sales_period_days - 1))
|
179 |
+
|
180 |
+
chosen_region_id = random.choice(weighted_regions)
|
181 |
+
region_customers = customers_df[customers_df['region_id'] == chosen_region_id]
|
182 |
+
customer = region_customers.sample(1).iloc[0] if not region_customers.empty else customers_df.sample(1).iloc[0]
|
183 |
+
|
184 |
+
product = products_df[products_df['product_id'] == random.choice(all_weighted_products)].iloc[0]
|
185 |
+
|
186 |
+
quantity = random.randint(1, 5)
|
187 |
+
seasonal_factor = get_seasonal_factor(current_date)
|
188 |
+
amount = round(quantity * product['price'] * (1 + random.uniform(-0.1, 0.1)) * seasonal_factor, 2)
|
189 |
+
|
190 |
+
sales_records.append((
|
191 |
+
int(product['product_id']), # Explicitly cast to int
|
192 |
+
int(customer['customer_id']), # Explicitly cast to int
|
193 |
+
int(customer['region_id']), # Explicitly cast to int
|
194 |
+
current_date.strftime('%Y-%m-%d'),
|
195 |
+
quantity,
|
196 |
+
amount
|
197 |
+
))
|
198 |
+
|
199 |
+
# Generate long-term sales (apply the same fix here)
|
200 |
+
for _ in tqdm(range(num_long_term_sales), desc="Generating Long-Term Sales Data"):
|
201 |
+
current_date = end_date - timedelta(days=random.randint(recent_sales_period_days, long_term_sales_period_days - 1))
|
202 |
+
|
203 |
+
chosen_region_id = random.choice(weighted_regions)
|
204 |
+
region_customers = customers_df[customers_df['region_id'] == chosen_region_id]
|
205 |
+
customer = region_customers.sample(1).iloc[0] if not region_customers.empty else customers_df.sample(1).iloc[0]
|
206 |
+
|
207 |
+
product = products_df[products_df['product_id'] == random.choice(all_weighted_products)].iloc[0]
|
208 |
+
|
209 |
+
quantity = random.randint(1, 5)
|
210 |
+
seasonal_factor = get_seasonal_factor(current_date)
|
211 |
+
amount = round(quantity * product['price'] * (1 + random.uniform(-0.1, 0.1)) * seasonal_factor, 2)
|
212 |
+
|
213 |
+
sales_records.append((
|
214 |
+
int(product['product_id']), # Explicitly cast to int
|
215 |
+
int(customer['customer_id']), # Explicitly cast to int
|
216 |
+
int(customer['region_id']), # Explicitly cast to int
|
217 |
+
current_date.strftime('%Y-%m-%d'),
|
218 |
+
quantity,
|
219 |
+
amount
|
220 |
+
))
|
221 |
+
|
222 |
+
cursor.executemany("INSERT INTO sales (product_id, customer_id, region_id, sale_date, quantity, amount) VALUES (?, ?, ?, ?, ?, ?)", sales_records)
|
223 |
+
conn.commit()
|
224 |
+
print(f"Inserted {len(sales_records)} dummy sales records.")
|
225 |
+
print("Dummy data inserted successfully.")
|
226 |
+
|
227 |
+
def main():
|
228 |
+
create_database()
|
229 |
+
conn = None
|
230 |
+
try:
|
231 |
+
conn = sqlite3.connect(DATABASE_PATH)
|
232 |
+
create_tables(conn)
|
233 |
+
insert_dummy_data(conn)
|
234 |
+
except sqlite3.Error as e:
|
235 |
+
print(f"An error occurred: {e}")
|
236 |
+
finally:
|
237 |
+
if conn:
|
238 |
+
conn.close()
|
239 |
+
print("Database connection closed.")
|
240 |
+
|
241 |
+
if __name__ == "__main__":
|
242 |
+
main()
|