DakshChaudhary committed on
Commit a343d94 · 0 Parent(s)

Fine-tuned model + RAG + completed NL-to-SQL Agent

.gitignore ADDED
@@ -0,0 +1 @@
+ .env
app.py ADDED
File without changes
fine-tuning/finetune_script.py ADDED
@@ -0,0 +1,70 @@
+ import os
+ from openai import OpenAI
+ import time
+ from dotenv import load_dotenv
+ load_dotenv()
+
+
+ client = OpenAI(
+     base_url="https://api.studio.nebius.com/v1/",
+     api_key=os.environ.get("NEBIUS_API_KEY"),
+ )
+
+ # Upload a training dataset
+ training_dataset = client.files.create(
+     file=open("fine-tuning/nl_sql_finetune_dataset.jsonl", "rb"),
+     purpose="fine-tune"
+ )
+
+ # Fine-tuning job parameters
+ job_request = {
+     "model": "meta-llama/Llama-3.1-8B-Instruct",
+     "suffix": "nl-to-sql-finetuned",
+     "training_file": training_dataset.id,
+     "hyperparameters": {
+         "batch_size": 8,
+         "learning_rate": 2e-5,
+         "n_epochs": 10,
+         "warmup_ratio": 0.1,
+         "weight_decay": 0.01,
+         "lora": True,
+         "lora_r": 8,
+         "lora_alpha": 16,
+         "lora_dropout": 0.05,
+         "packing": True,
+         "max_grad_norm": 1,
+     },
+ }
+
+ # Create and run the fine-tuning job
+ job = client.fine_tuning.jobs.create(**job_request)
+
+ # Wait until the job has finished or been cancelled
+ active_statuses = ["validating_files", "queued", "running"]
+ while job.status in active_statuses:
+     time.sleep(15)
+     job = client.fine_tuning.jobs.retrieve(job.id)
+     print("current status is", job.status)
+
+ print("Job ID:", job.id)
+
+ if job.status == "succeeded":
+     # Check the job events
+     events = client.fine_tuning.jobs.list_events(job.id)
+     print(events)
+
+     for checkpoint in client.fine_tuning.jobs.checkpoints.list(job.id).data:
+         print("Checkpoint ID:", checkpoint.id)
+
+         # Create a directory for every checkpoint
+         checkpoint_dir = f"fine-tuning/model_checkpoints/{checkpoint.id}"
+         os.makedirs(checkpoint_dir, exist_ok=True)
+
+         for model_file_id in checkpoint.result_files:
+             # Get the name of a model file
+             filename = client.files.retrieve(model_file_id).filename
+
+             # Retrieve the contents of the file
+             file_content = client.files.content(model_file_id)
+
+             # Save the contents into the checkpoint directory
+             file_content.write_to_file(os.path.join(checkpoint_dir, filename))
fine-tuning/finetuned_model_deployment.py ADDED
@@ -0,0 +1,70 @@
+ import requests
+ import os
+ from openai import OpenAI
+ import time
+ from dotenv import load_dotenv
+ load_dotenv()
+
+ api_key = os.environ.get('NEBIUS_API_KEY')
+ api_url = "https://api.studio.nebius.com"
+
+ client = OpenAI(
+     base_url=api_url + "/v1",
+     api_key=api_key,
+ )
+
+
+ def create_lora_from_job(name, ft_job, ft_checkpoint, base_model):
+     fine_tuning_result = f"{ft_job}:{ft_checkpoint}"
+     lora_creation_request = {
+         "source": fine_tuning_result,
+         "base_model": base_model,
+         "name": name,
+         "description": "Deploying fine-tuned LoRA model"
+     }
+
+     response = requests.post(
+         f"{api_url}/v0/models",
+         json=lora_creation_request,
+         headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
+     )
+
+     print("Model creation response code:", response.status_code)
+     print("Model creation response body:", response.text)
+
+     response.raise_for_status()
+     return response.json()
+
+ # Get the custom model information, check its status and wait for the validation to complete
+ def wait_for_validation(name, delay=5):
+     while True:
+         time.sleep(delay)
+         print(f"Validating LoRA: {name}")
+         lora_info = requests.get(
+             f"{api_url}/v0/models/{name}",
+             headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
+         ).json()
+         if lora_info.get("status") in {"active", "error"}:
+             return lora_info
+
+ # Create a multi-message request
+ def get_completion(model):
+     completion = client.chat.completions.create(
+         model=model,
+         messages=[{"role": "user", "content": "Hello"}],
+     )
+     return completion.choices[0].message.content
+
+ # Deploy a LoRA adapter model by using IDs of a fine-tuning job and its checkpoint
+ lora_name = create_lora_from_job("nl-to-sql-adapter", "ftjob-b8246b132b234b78a73e4b508817868d", "ftckpt_db41d377-7e1c-43f8-838a-73801db71444", "meta-llama/Meta-Llama-3.1-8B-Instruct").get("name")
+
+ # Check the custom model status
+ lora_info = wait_for_validation(lora_name)
+
+ # If the custom model validation is successful, create a multi-message request with this model
+ if lora_info.get("status") == "active":
+     print(get_completion(lora_name))
+ # If there is an error, display the reason
+ elif lora_info.get("status") == "error":
+     print(f"An error occurred: {lora_info['status_reason']}")
+
fine-tuning/nl_sql_finetune_dataset.jsonl ADDED
@@ -0,0 +1,66 @@
+ {"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"How many regions are there in the database?"},{"role":"assistant","content":"SELECT COUNT(*) FROM regions;"}]}
+ {"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"List all region names."},{"role":"assistant","content":"SELECT region_name FROM regions;"}]}
+ {"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Retrieve region_ids and names for regions starting with 'West'."},{"role":"assistant","content":"SELECT region_id, region_name FROM regions WHERE region_name LIKE 'West%';"}]}
+ {"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Show me region names in alphabetical order."},{"role":"assistant","content":"SELECT region_name FROM regions ORDER BY region_name ASC;"}]}
+ {"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Count regions whose names contain 'East'."},{"role":"assistant","content":"SELECT COUNT(*) FROM regions WHERE region_name LIKE '%East%';"}]}
+ {"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Get the maximum region_id."},{"role":"assistant","content":"SELECT MAX(region_id) FROM regions;"}]}
+ {"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Select all regions with region_id > 3."},{"role":"assistant","content":"SELECT * FROM regions WHERE region_id > 3;"}]}
+ {"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Retrieve the second highest region_id."},{"role":"assistant","content":"SELECT region_id FROM regions ORDER BY region_id DESC LIMIT 1 OFFSET 1;"}]}
+ {"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"List the first three regions by ID."},{"role":"assistant","content":"SELECT * FROM regions ORDER BY region_id ASC LIMIT 3;"}]}
+ {"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"List region names in reverse alphabetical order."},{"role":"assistant","content":"SELECT region_name FROM regions ORDER BY region_name DESC;"}]}
+ {"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Which regions have even IDs?"},{"role":"assistant","content":"SELECT * FROM regions WHERE region_id % 2 = 0;"}]}
+ {"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"How many products are available?"},{"role":"assistant","content":"SELECT COUNT(*) FROM products;"}]}
+ {"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"List all product names and categories."},{"role":"assistant","content":"SELECT product_name, category FROM products;"}]}
+ {"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Find products priced above 100."},{"role":"assistant","content":"SELECT * FROM products WHERE price > 100;"}]}
+ {"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Average product price per category."},{"role":"assistant","content":"SELECT category, AVG(price) FROM products GROUP BY category;"}]}
+ {"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Retrieve the most expensive product."},{"role":"assistant","content":"SELECT * FROM products ORDER BY price DESC LIMIT 1;"}]}
+ {"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Count how many 'Electronics' products exist."},{"role":"assistant","content":"SELECT COUNT(*) FROM products WHERE category = 'Electronics';"}]}
+ {"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Show products priced between 50 and 150."},{"role":"assistant","content":"SELECT * FROM products WHERE price BETWEEN 50 AND 150;"}]}
+ {"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"List categories and product counts."},{"role":"assistant","content":"SELECT category, COUNT(*) FROM products GROUP BY category;"}]}
+ {"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Get the total sum of all product prices."},{"role":"assistant","content":"SELECT SUM(price) FROM products;"}]}
+ {"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Retrieve the cheapest product name."},{"role":"assistant","content":"SELECT product_name FROM products ORDER BY price ASC LIMIT 1;"}]}
+ {"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Which products end with 'Pro'?"},{"role":"assistant","content":"SELECT product_name FROM products WHERE product_name LIKE '%Pro';"}]}
+ {"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"How many customers are registered?"},{"role":"assistant","content":"SELECT COUNT(*) FROM customers;"}]}
+ {"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"List all customer names and emails."},{"role":"assistant","content":"SELECT customer_name, email FROM customers;"}]}
+ {"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Find customers in region 2."},{"role":"assistant","content":"SELECT * FROM customers WHERE region_id = 2;"}]}
+ {"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Show customers with an email address."},{"role":"assistant","content":"SELECT customer_name FROM customers WHERE email IS NOT NULL;"}]}
+ {"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Count customers per region."},{"role":"assistant","content":"SELECT region_id, COUNT(*) FROM customers GROUP BY region_id;"}]}
+ {"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"List customers whose name starts with 'A'."},{"role":"assistant","content":"SELECT * FROM customers WHERE customer_name LIKE 'A%';"}]}
+ {"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Retrieve emails sorted alphabetically."},{"role":"assistant","content":"SELECT email FROM customers WHERE email IS NOT NULL ORDER BY email ASC;"}]}
+ {"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Get the most recent customer by ID."},{"role":"assistant","content":"SELECT * FROM customers ORDER BY customer_id DESC LIMIT 1;"}]}
+ {"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Which customers have no region assigned?"},{"role":"assistant","content":"SELECT * FROM customers WHERE region_id IS NULL;"}]}
+ {"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"How many sales have occurred?"},{"role":"assistant","content":"SELECT COUNT(*) FROM sales;"}]}
+ {"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"List sale_id, sale_date, and amount for all sales."},{"role":"assistant","content":"SELECT sale_id, sale_date, amount FROM sales;"}]}
+ {"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Find sales between '2025-01-01' and '2025-06-01'."},{"role":"assistant","content":"SELECT * FROM sales WHERE sale_date BETWEEN '2025-01-01' AND '2025-06-01';"}]}
+ {"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Sum total sales amount by region."},{"role":"assistant","content":"SELECT region_id, SUM(amount) FROM sales GROUP BY region_id;"}]}
+ {"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Average quantity sold per product."},{"role":"assistant","content":"SELECT product_id, AVG(quantity) FROM sales GROUP BY product_id;"}]}
+ {"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Retrieve the highest single sale amount."},{"role":"assistant","content":"SELECT MAX(amount) FROM sales;"}]}
+ {"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Count sales for customer 5."},{"role":"assistant","content":"SELECT COUNT(*) FROM sales WHERE customer_id = 5;"}]}
+ {"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Show sales of product 3 ordered by date descending."},{"role":"assistant","content":"SELECT * FROM sales WHERE product_id = 3 ORDER BY sale_date DESC;"}]}
+ {"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Get the total amount for all sales."},{"role":"assistant","content":"SELECT SUM(amount) FROM sales;"}]}
+ {"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"What was the quantity sold on the earliest sale?"},{"role":"assistant","content":"SELECT quantity FROM sales ORDER BY sale_date ASC LIMIT 1;"}]}
+ {"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"For each customer, show total sales amount."},{"role":"assistant","content":"SELECT customer_id, SUM(amount) FROM sales GROUP BY customer_id;"}]}
+ {"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"List top 5 customers by total amount spent."},{"role":"assistant","content":"SELECT customer_id, SUM(amount) as total_spent FROM sales GROUP BY customer_id ORDER BY total_spent DESC LIMIT 5;"}]}
+ {"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Join customers and sales to get names and sale amounts."},{"role":"assistant","content":"SELECT c.customer_name, s.amount FROM customers c JOIN sales s ON c.customer_id = s.customer_id;"}]}
+ {"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Find the most popular product by sales count."},{"role":"assistant","content":"SELECT product_id, COUNT(*) as cnt FROM sales GROUP BY product_id ORDER BY cnt DESC LIMIT 1;"}]}
+ {"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Show average sale amount per region for May 2025."},{"role":"assistant","content":"SELECT region_id, AVG(amount) FROM sales WHERE sale_date LIKE '2025-05-%' GROUP BY region_id;"}]}
+ {"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"List the sale_date and customer_name for every sale."},{"role":"assistant","content":"SELECT s.sale_date, c.customer_name FROM sales s JOIN customers c ON s.customer_id = c.customer_id;"}]}
+ {"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Show customer_name, product_name, quantity, and amount for each sale."},{"role":"assistant","content":"SELECT c.customer_name, p.product_name, s.quantity, s.amount FROM sales s JOIN customers c ON s.customer_id = c.customer_id JOIN products p ON s.product_id = p.product_id;"}]}
+ {"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Calculate total sales amount per product category."},{"role":"assistant","content":"SELECT p.category, SUM(s.amount) AS total_revenue FROM sales s JOIN products p ON s.product_id = p.product_id GROUP BY p.category;"}]}
+ {"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Get the top 3 product categories by total revenue."},{"role":"assistant","content":"SELECT p.category, SUM(s.amount) AS revenue FROM sales s JOIN products p ON s.product_id = p.product_id GROUP BY p.category ORDER BY revenue DESC LIMIT 3;"}]}
+ {"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"List customer_name and region_name for all sales over 100."},{"role":"assistant","content":"SELECT c.customer_name, r.region_name FROM sales s JOIN customers c ON s.customer_id = c.customer_id JOIN regions r ON s.region_id = r.region_id WHERE s.amount > 100;"}]}
+ {"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Find all products that have never been sold."},{"role":"assistant","content":"SELECT p.product_name FROM products p LEFT JOIN sales s ON p.product_id = s.product_id WHERE s.sale_id IS NULL;"}]}
+ {"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"List regions that had no sales in June 2025."},{"role":"assistant","content":"SELECT r.region_name FROM regions r LEFT JOIN sales s ON r.region_id = s.region_id AND s.sale_date LIKE '2025-06-%' WHERE s.sale_id IS NULL;"}]}
+ {"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Show a running total of amount over sale_date."},{"role":"assistant","content":"SELECT sale_date, amount, SUM(amount) OVER (ORDER BY sale_date) AS running_total FROM sales;"}]}
+ {"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Rank customers by total spent."},{"role":"assistant","content":"SELECT customer_id, SUM(amount) AS total_spent, RANK() OVER (ORDER BY SUM(amount) DESC) AS rank FROM sales GROUP BY customer_id;"}]}
+ {"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Calculate the average sale amount per calendar month."},{"role":"assistant","content":"SELECT SUBSTR(sale_date,1,7) AS month, AVG(amount) FROM sales GROUP BY month;"}]}
+ {"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Compute each region’s percentage of total sales amount."},{"role":"assistant","content":"SELECT region_id, SUM(amount)*100.0/(SELECT SUM(amount) FROM sales) AS pct_of_total FROM sales GROUP BY region_id;"}]}
+ {"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"For each product, show its price and the average sale amount."},{"role":"assistant","content":"SELECT p.product_name, p.price, (SELECT AVG(amount) FROM sales WHERE product_id = p.product_id) AS avg_sale FROM products p;"}]}
+ {"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"List customers who spent more than the average customer."},{"role":"assistant","content":"SELECT DISTINCT c.customer_name FROM customers c WHERE (SELECT SUM(amount) FROM sales WHERE customer_id = c.customer_id) > (SELECT AVG(total) FROM (SELECT SUM(amount) AS total FROM sales GROUP BY customer_id));"}]}
+ {"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Use a CTE to show monthly total sales."},{"role":"assistant","content":"WITH monthly_sales AS (SELECT SUBSTR(sale_date,1,7) AS month, SUM(amount) AS total FROM sales GROUP BY month) SELECT * FROM monthly_sales;"}]}
+ {"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Use a CTE to find the top 5 products by revenue."},{"role":"assistant","content":"WITH prod_rev AS (SELECT product_id, SUM(amount) AS revenue FROM sales GROUP BY product_id) SELECT p.product_name, pr.revenue FROM prod_rev pr JOIN products p ON pr.product_id = p.product_id ORDER BY pr.revenue DESC LIMIT 5;"}]}
+ {"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Find each customer’s first sale date."},{"role":"assistant","content":"SELECT customer_id, MIN(sale_date) AS first_sale FROM sales GROUP BY customer_id;"}]}
+ {"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Identify repeat customers (more than one sale)."},{"role":"assistant","content":"SELECT customer_id FROM sales GROUP BY customer_id HAVING COUNT(*) > 1;"}]}
+ {"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Find customers who purchased both product 1 and product 2."},{"role":"assistant","content":"SELECT DISTINCT s1.customer_id FROM sales s1 JOIN sales s2 ON s1.customer_id = s2.customer_id WHERE s1.product_id = 1 AND s2.product_id = 2;"}]}
+ {"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"List products sold in every region."},{"role":"assistant","content":"SELECT product_id FROM sales GROUP BY product_id HAVING COUNT(DISTINCT region_id) = (SELECT COUNT(*) FROM regions);"}]}
+ {"messages":[{"role":"system","content":"You are a helpful assistant that translates natural language queries into executable SQL queries. Only respond with the SQL query."},{"role":"user","content":"Determine which region has the highest average sale amount."},{"role":"assistant","content":"SELECT region_id FROM sales GROUP BY region_id ORDER BY AVG(amount) DESC LIMIT 1;"}]}
knowledge_base/business_glossary/kpi_definitions.md ADDED
@@ -0,0 +1,31 @@
+ # Business Glossary and KPI Definitions
+
+ This document provides definitions for key business terms and performance indicators used in our sales analysis.
+
+ ---
+
+ ## Key Performance Indicators (KPIs)
+
+ ### Average Sales
+ **Definition:** The average monetary value of a single sales transaction. It is calculated by dividing the total sales amount by the total number of sales transactions over a given period.
+ **Calculation:** `SUM(amount) / COUNT(sale_id)` from the `sales` table.
+
+ ### Sales Volume
+ **Definition:** The total number of units sold over a specified period. This indicates the quantity of goods moved.
+ **Calculation:** `SUM(quantity)` from the `sales` table.
+
+ ### Revenue by Region
+ **Definition:** The total monetary value of sales attributed to a specific geographical region over a given period.
+ **Calculation:** `SUM(amount)` grouped by `region_name` from the `sales` and `regions` tables.
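+
+ **Example queries (illustrative only):** against the schema in `knowledge_base/schema/sales_schema.sql`, these KPIs could be computed roughly as follows:
+
+ ```sql
+ -- Average Sales
+ SELECT SUM(amount) / COUNT(sale_id) AS average_sales FROM sales;
+
+ -- Sales Volume
+ SELECT SUM(quantity) AS sales_volume FROM sales;
+
+ -- Revenue by Region
+ SELECT r.region_name, SUM(s.amount) AS revenue
+ FROM sales s JOIN regions r ON s.region_id = r.region_id
+ GROUP BY r.region_name;
+ ```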
+
+ ---
+
+ ## Business Terms
+
+ ### Customer Segment
+ **Definition:** A group of customers sharing similar characteristics, often based on demographics, purchasing behavior, or psychographics. (Note: the current database has no explicit segment column; this entry is an example of the kind of term you would define here.)
+
+ ### Last Quarter
+ **Definition:** The most recently completed three-month period. For example, if the current month is June, the last quarter is January through March; if it is July, the last quarter is April through June.
+
+ ---
knowledge_base/schema/data_dictionary.md ADDED
@@ -0,0 +1,60 @@
+ # Data Dictionary for Sales Database
+
+ This document describes the tables and columns within the sales database.
+
+ ---
+
+ ## Table: `regions`
+ **Purpose:** Stores information about geographical sales regions.
+
+ | Column        | Type    | Description                                        |
+ | :------------ | :------ | :------------------------------------------------- |
+ | `region_id`   | INTEGER | Unique identifier for each sales region.           |
+ | `region_name` | TEXT    | Name of the sales region (e.g., 'North', 'South'). |
+
+ ---
+
+ ## Table: `products`
+ **Purpose:** Contains details about products sold.
+
+ | Column         | Type    | Description                                           |
+ | :------------- | :------ | :---------------------------------------------------- |
+ | `product_id`   | INTEGER | Unique identifier for each product.                   |
+ | `product_name` | TEXT    | Name of the product (e.g., 'Laptop', 'Pen Set').      |
+ | `category`     | TEXT    | Product category (e.g., 'Electronics', 'Furniture').  |
+ | `price`        | REAL    | Standard price of the product.                        |
+
+ ---
+
+ ## Table: `customers`
+ **Purpose:** Stores customer information.
+
+ | Column          | Type    | Description                                                              |
+ | :-------------- | :------ | :----------------------------------------------------------------------- |
+ | `customer_id`   | INTEGER | Unique identifier for each customer.                                     |
+ | `customer_name` | TEXT    | Full name of the customer.                                               |
+ | `email`         | TEXT    | Customer's email address, unique per customer.                           |
+ | `region_id`     | INTEGER | ID of the region the customer belongs to. Links to `regions.region_id`. |
+
+ ---
+
+ ## Table: `sales`
+ **Purpose:** Records individual sales transactions.
+
+ | Column        | Type    | Description                                                                                                                  |
+ | :------------ | :------ | :---------------------------------------------------------------------------------------------------------------------------- |
+ | `sale_id`     | INTEGER | Unique identifier for each sale transaction.                                                                                 |
+ | `product_id`  | INTEGER | ID of the product sold. Links to `products.product_id`.                                                                      |
+ | `customer_id` | INTEGER | ID of the customer involved in the sale. Links to `customers.customer_id`.                                                   |
+ | `region_id`   | INTEGER | ID of the region where the sale occurred. This can also be derived from the customer's region. Links to `regions.region_id`. |
+ | `sale_date`   | TEXT    | Date of the sale in 'YYYY-MM-DD' format.                                                                                     |
+ | `quantity`    | INTEGER | Number of units sold in this transaction.                                                                                    |
+ | `amount`      | REAL    | Total monetary value of the sale (quantity * price, potentially with variations).                                            |
+
+ ---
+
+ ## Relationships
+ * `customers.region_id` links to `regions.region_id`
+ * `sales.product_id` links to `products.product_id`
+ * `sales.customer_id` links to `customers.customer_id`
+ * `sales.region_id` links to `regions.region_id`
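+
+ An illustrative query that follows all of these links (not part of the dictionary itself, just an example of how the tables join):
+
+ ```sql
+ SELECT c.customer_name, r.region_name, p.product_name, s.sale_date, s.amount
+ FROM sales s
+ JOIN customers c ON s.customer_id = c.customer_id
+ JOIN products p ON s.product_id = p.product_id
+ JOIN regions r ON s.region_id = r.region_id;
+ ```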
knowledge_base/schema/sales_schema.sql ADDED
@@ -0,0 +1,32 @@
+ CREATE TABLE regions (
+     region_id INTEGER PRIMARY KEY,
+     region_name TEXT NOT NULL UNIQUE
+ );
+
+ CREATE TABLE products (
+     product_id INTEGER PRIMARY KEY,
+     product_name TEXT NOT NULL,
+     category TEXT NOT NULL,
+     price REAL NOT NULL
+ );
+
+ CREATE TABLE customers (
+     customer_id INTEGER PRIMARY KEY,
+     customer_name TEXT NOT NULL,
+     email TEXT UNIQUE,
+     region_id INTEGER,
+     FOREIGN KEY (region_id) REFERENCES regions(region_id)
+ );
+
+ CREATE TABLE sales (
+     sale_id INTEGER PRIMARY KEY AUTOINCREMENT,
+     product_id INTEGER NOT NULL,
+     customer_id INTEGER NOT NULL,
+     region_id INTEGER NOT NULL,
+     sale_date TEXT NOT NULL, -- Stored as YYYY-MM-DD
+     quantity INTEGER NOT NULL,
+     amount REAL NOT NULL,
+     FOREIGN KEY (product_id) REFERENCES products(product_id),
+     FOREIGN KEY (customer_id) REFERENCES customers(customer_id),
+     FOREIGN KEY (region_id) REFERENCES regions(region_id)
+ );
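+
+ -- Illustrative sample rows (hypothetical values, commented out; they only demonstrate how the
+ -- tables reference each other, while the real data lives in data/sales_database.db):
+ -- INSERT INTO regions (region_id, region_name) VALUES (1, 'North');
+ -- INSERT INTO products (product_id, product_name, category, price) VALUES (1, 'Laptop', 'Electronics', 1200.0);
+ -- INSERT INTO customers (customer_id, customer_name, email, region_id) VALUES (1, 'Alice Smith', 'alice@example.com', 1);
+ -- INSERT INTO sales (product_id, customer_id, region_id, sale_date, quantity, amount) VALUES (1, 1, 1, '2025-05-10', 2, 2400.0);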
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ python-dotenv
+ llama-index
+ llama-index-embeddings-nebius
+ llama-index-vector-stores-chroma
+ llama-index-llms-nebius
+ sqlalchemy
+ openai
+ chromadb
+ pandas
+ requests
+ tqdm
src/agents/agent_models/__pycache__/models.cpython-313.pyc ADDED
Binary file (973 Bytes).
 
src/agents/agent_models/models.py ADDED
@@ -0,0 +1,14 @@
+ import os
+ from llama_index.llms.nebius import NebiusLLM
+ from dotenv import load_dotenv
+ load_dotenv()
+
+ base_agent_model_id = "Qwen/Qwen3-235B-A22B"
+ finetuned_model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct-LoRa:nl-to-sql-finetuned-jbkN"
+
+ def get_base_agent_model():
+     return NebiusLLM(api_key=os.environ["NEBIUS_API_KEY"], model=base_agent_model_id)
+
+ def get_finetuned_model():
+     return NebiusLLM(api_key=os.environ["NEBIUS_API_KEY"], model=finetuned_model_id)
src/agents/agent_tools/__pycache__/schema_retriever_tool.cpython-313.pyc ADDED
Binary file (5.08 kB).
 
src/agents/agent_tools/__pycache__/sql_executor_tool.cpython-313.pyc ADDED
Binary file (3.52 kB).
 
src/agents/agent_tools/schema_retriever_tool.py ADDED
@@ -0,0 +1,103 @@
+ # src/agents/tools/schema_retriever_tool.py
+
+ import os
+ import logging
+ import chromadb
+ from llama_index.core.tools import FunctionTool
+ from llama_index.embeddings.nebius import NebiusEmbedding
+ from llama_index.vector_stores.chroma import ChromaVectorStore
+ from llama_index.core import VectorStoreIndex, Settings
+
+ logging.basicConfig(level=logging.INFO)
+
+ # Define the path to your ChromaDB for schema.
+ current_file_dir = os.path.dirname(os.path.abspath(__file__))
+ # Navigate up to the project root for the ChromaDB directory.
+ # This path is relative to the *tool file*, so three levels up to the root.
+ CHROMA_DB_PATH = os.path.join(current_file_dir, '..', '..', '..', 'chroma_db_schema')
+ logging.info(f"ChromaDB Schema Path set to: {CHROMA_DB_PATH}")
+
+ # Initialize NebiusEmbedding
+ embed_model_name = "BAAI/bge-en-icl"
+ embed_api_base = "https://api.studio.nebius.com/v1/"
+
+ embeddings = None
+ try:
+     embeddings = NebiusEmbedding(
+         api_key=os.environ.get("NEBIUS_API_KEY"),
+         model_name=embed_model_name,
+         api_base=embed_api_base
+     )
+     Settings.embed_model = embeddings
+     # Test the embedding model
+     _ = embeddings.get_text_embedding("test validation string")
+     logging.info("NebiusEmbedding initialized successfully for schema retriever.")
+ except Exception as e:
+     logging.error(f"Error initializing NebiusEmbedding in schema_retriever_tool: {e}")
+     embeddings = None
+
+ # Set the global embedding model for LlamaIndex if not already set (good practice)
+ if embeddings:
+     Settings.embed_model = embeddings
+
+ # Main retrieval function
+ def retrieve_schema_context(natural_language_query: str) -> str:
+     if embeddings is None:
+         return "Error: Embedding model not initialized for schema retrieval. Cannot perform RAG. Please check your Nebius API key and model configuration."
+
+     try:
+         db = chromadb.PersistentClient(path=CHROMA_DB_PATH)
+         chroma_collection = db.get_or_create_collection(name="schema_kb")
+         logging.info("ChromaDB collection 'schema_kb' opened successfully for retrieval.")
+
+         vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
+
+         index = VectorStoreIndex.from_vector_store(
+             vector_store,
+             embed_model=embeddings  # Explicitly pass the initialized embeddings
+         )
+         query_engine = index.as_retriever(similarity_top_k=2)
+         retrieved_nodes = query_engine.retrieve(natural_language_query)
+
+         schema_snippets = [node.get_content() for node in retrieved_nodes]
+         if not schema_snippets:
+             return "No relevant schema context found for your query. Please rephrase or simplify."
+
+         return "Retrieved Database Schema Context (relevant to query):\n" + "\n---\n".join(schema_snippets)
+
+     except Exception as e:
+         logging.exception("Error in retrieve_schema_context:")
+         return f"Error retrieving schema from RAG: {str(e)}. Ensure ChromaDB is built at {CHROMA_DB_PATH} and embedding model is compatible."
+
+
+ # Exportable tool
+ def get_schema_retriever_tool() -> FunctionTool:
+     return FunctionTool.from_defaults(
+         fn=retrieve_schema_context,
+         name="retrieve_schema_context",
+         description=(
+             "Retrieves relevant database schema information (tables, columns, relationships, descriptions) "
+             "from the sales database knowledge base using semantic search (RAG). Always call this first "
+             "if you need to understand the schema for SQL generation."
+         )
+     )
+
+ # Test entry point
+ if __name__ == "__main__":
+     print("--- Testing Schema Retriever Tool Implementation (with RAG) ---")
+     if not os.environ.get("NEBIUS_API_KEY"):
+         print("Warning: NEBIUS_API_KEY not set. Schema retrieval might fail.")
+
+     schema_tool = get_schema_retriever_tool()
+
+     queries = [
+         "What are the columns in the sales and products tables, and how are they related?",
+         "Show me customer names and their regions.",
+         "What is the purpose of the regions table?",
+         "Tell me about the sales table schema and its purpose."
+     ]
+
+     for query in queries:
+         print(f"\nCalling tool with: '{query}'")
+         result = schema_tool.call(natural_language_query=query)
+         print(f"Result:\n{result}")
src/agents/agent_tools/sql_executor_tool.py ADDED
@@ -0,0 +1,63 @@
+ import sqlite3
+ import pandas as pd
+ import io
+ import os
+ from llama_index.core.tools import FunctionTool
+
+ current_dir = os.path.dirname(os.path.abspath(__file__))
+ DATABASE_PATH = os.path.join(current_dir, '..', '..', '..', 'data', 'sales_database.db')
+
+ def execute_sql_query(sql_query: str) -> str:
+     """
+     Executes a SQL SELECT query against the sales database and returns the results as a formatted string (CSV representation).
+
+     This tool should only be used to retrieve data using SELECT statements.
+     It does not support INSERT, UPDATE, DELETE, or DDL commands for security reasons.
+
+     Args:
+         sql_query (str): The complete and correct SQL SELECT query to execute.
+                          Do not include semicolons at the end of the query.
+
+     Returns:
+         str: A CSV string representation of the query results, or an error message.
+     """
+     # Safety check: only allow SELECT queries
+     if not sql_query.strip().upper().startswith("SELECT"):
+         return "Error: Only SELECT queries are allowed for security reasons."
+
+     conn = None
+     try:
+         conn = sqlite3.connect(DATABASE_PATH)
+         df = pd.read_sql_query(sql_query, conn)
+
+         if df.empty:
+             return "Query executed successfully, but no results were found."
+
+         # Convert DataFrame to a CSV string for LLM consumption
+         output = io.StringIO()
+         df.to_csv(output, index=False)
+         return output.getvalue()
+
+     except pd.io.sql.DatabaseError as e:
+         return f"Database Query Error: {e}"
+     except Exception as e:
+         return f"An unexpected error occurred during SQL execution: {e}"
+     finally:
+         if conn:
+             conn.close()
+
+ def get_sql_executor_tool() -> FunctionTool:
+     """
+     Returns a LlamaIndex FunctionTool for executing SQL SELECT queries.
+     """
+     return FunctionTool.from_defaults(
+         fn=execute_sql_query,
+         name="execute_sql_query",
+         description=(
+             "Executes a SQL SELECT query against the sales database and returns the results. "
+             "Use this tool to get actual data to answer the user's question. "
+             "Always generate the full, correct SQL query to answer the question before calling this tool. "
+             "Only SELECT queries are allowed. "
+             "Input should be the complete and correct SQL SELECT query string."
+         )
+     )
src/agents/nl_sql_agent.py ADDED
@@ -0,0 +1,123 @@
+ import os
+ import logging
+ import sys
+ from llama_index.core.agent import ReActAgent
+ from agent_tools.sql_executor_tool import get_sql_executor_tool
+ from agent_tools.schema_retriever_tool import get_schema_retriever_tool
+ from agent_models.models import get_finetuned_model
+
+ # Configure logging for better visibility into the agent's thought process
+ logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
+ logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))
+
+ class NLSQLAgent:
+     def __init__(self, model_name: str, api_key: str, base_url: str):
+         """
+         Initializes the NL-to-SQL Agent, which translates natural language to SQL, executes it, and provides answers.
+
+         Args:
+             model_name (str): The name/ID of the fine-tuned model deployed on Nebius AI.
+             api_key (str): Nebius AI API key.
+             base_url (str): The base URL for the Nebius AI API endpoint.
+         """
+         # The fine-tuned LLM is loaded via agent_models.models
+         self.llm = get_finetuned_model()
+         self.system_prompt = (
+             "You are an expert SQL assistant for a sales database. Your task is to accurately translate "
+             "natural language questions into SQL queries, execute them using the provided tools, and then "
+             "provide concise, natural language answers based on the query results. "
+             "You MUST strictly follow the Thought-Action-Action Input-Observation format for ALL steps. "
+             "Pay extreme attention to the 'Action Input' format, which MUST be a valid JSON dictionary "
+             "matching the tool's parameters. **ALWAYS provide ALL required parameters.**\n"
+             "Follow these rules strictly:\n"
+             "1. **Analyze the query:** Understand the user's intent and identify if schema context is needed.\n"
+             "2. **Schema Retrieval (if needed):** If the user's question requires knowledge about the "
+             "   database structure (tables, columns, relationships) that you might not implicitly know, "
+             "   **FIRST, ALWAYS call `retrieve_schema_context`** using the user's original query. "
+             "   The Action Input for `retrieve_schema_context` MUST be a JSON object like: "
+             "   `{\"natural_language_query\": \"user's exact question here\"}`. "
+             "   Process the retrieved schema before generating SQL.\n"
+             "3. **SQL Generation (Crucial):** Generate a syntactically correct and semantically accurate SQL SELECT query "
+             "   that directly answers the user's question based on the available schema information (either "
+             "   from `retrieve_schema_context` or your fine-tuned knowledge). Only use tables and columns "
+             "   that exist in the schema. Do not include semicolons at the end of the query.\n"
+             "4. **SQL Execution (MANDATORY):** **IMMEDIATELY AFTER generating a SQL query, you MUST call the `execute_sql_query` tool to run it.** "
+             "   The Action Input for `execute_sql_query` MUST be a JSON object like: "
+             "   `{\"sql_query\": \"your generated SQL query here\"}`.\n"
+             "5. **Answer Formulation (Only after SQL Execution and Observation):** Analyze the results from `execute_sql_query` "
+             "   and formulate a clear, concise, and helpful natural language answer to the user's original question. "
+             "   If the query returns no results, state that clearly.\n"
+             "6. **Error Handling:** If a question cannot be answered from the database, explain why politely. "
+             "   Do not make up data or generate SQL for non-existent tables/columns.\n"
+             "7. **Iterative Refinement:** If a tool call fails or provides unexpected output, consider revising your approach "
+             "   or informing the user about the issue. Continue the thought process until a final answer is derived.\n"
+             "8. The database schema includes the tables `regions`, `products`, `customers`, and `sales`. "
+             "   You should understand their structures and relationships from your training and schema context. "
+             "   For example: the `sales` table has columns `sale_id`, `product_id`, `customer_id`, `region_id`, `sale_date`, `quantity`, `amount`; "
+             "   the `customers` table has `customer_id`, `customer_name`, `email`, `region_id`; the `products` table has `product_id`, `product_name`, `category`, `price`; "
+             "   and the `regions` table has `region_id`, `region_name`.\n"
+             "Remember, your final output should be an `Answer:` tag with the natural language response, not raw SQL or tool calls unless explicitly asked to show your work."
+         )
+
+         self.tools = [get_schema_retriever_tool(), get_sql_executor_tool()]
+
+         self.agent = ReActAgent.from_tools(
+             llm=self.llm,
+             tools=self.tools,
+             context=self.system_prompt,
+             verbose=True,
+         )
+
+     async def process_query(self, user_query: str) -> str:
+         """
+         Processes a user's natural language query using the NL-to-SQL agent.
+         This method executes the agent's Thought-Action-Observation loop.
+
+         Args:
+             user_query (str): The natural language question from the user.
+
+         Returns:
+             str: The final natural language answer based on SQL execution, or an error/explanation.
+         """
+         try:
+             response_object = await self.agent.achat(user_query)
+             return str(response_object)
+         except Exception as e:
+             logging.error(f"Error in NLSQLAgent.process_query: {e}")
+             return f"I encountered an error while processing your request: {e}. Please try again or rephrase."
+
+ # Example usage (for testing the NLSQLAgent directly)
+ if __name__ == "__main__":
+     NEBIUS_API_KEY = os.environ.get("NEBIUS_API_KEY")
+     NEBIUS_BASE_URL = "https://api.studio.nebius.com/v1/openai"  # Example: verify with the Nebius AI docs
+     YOUR_FINETUNED_MODEL_ID = "meta-llama/Meta-Llama-3.1-8B-Instruct-LoRa:nl-to-sql-finetuned-jbkN"  # Your deployed fine-tuned model ID
+
+     if not NEBIUS_API_KEY:
+         print("Error: NEBIUS_API_KEY environment variable not set.")
+         print("Please set it before running the agent.")
+     elif YOUR_FINETUNED_MODEL_ID == "your-finetuned-model-id":
+         print("Error: Please replace 'your-finetuned-model-id' with your actual model ID.")
+     else:
+         nl_sql_agent = NLSQLAgent(
+             model_name=YOUR_FINETUNED_MODEL_ID,
+             api_key=NEBIUS_API_KEY,
+             base_url=NEBIUS_BASE_URL
+         )
+
+         print("\nNL-to-SQL Agent initialized. Ask a question about your sales database.")
+
+         async def main_loop():
+             while True:
+                 user_input = input("\nYou: ")
+                 if user_input.lower() in ["exit", "quit"]:
+                     print("Exiting agent.")
+                     break
+
+                 response = await nl_sql_agent.process_query(user_input)
+                 print(f"Agent: {response}")
+
+                 # For multi-turn conversations, ReActAgent handles history automatically.
+                 # To reset explicitly, re-instantiate the agent or call agent.reset().
+
+         import asyncio
+         asyncio.run(main_loop())
src/agents/orchestrator_agent.py ADDED
@@ -0,0 +1,20 @@
+ import sys
+ import logging
+ from llama_index.core.agent import ReActAgent
+ from agent_models.models import get_base_agent_model
+
+ # Logging
+ logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
+ logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))
+
+ class OrchestratorAgent:
+     def __init__(self):
+         self.tools = []
+         self.llm = get_base_agent_model()
+         # Minimal placeholder context for now; the orchestrator has no tools or detailed prompt yet.
+         self.system_prompt = "You are an orchestrator agent that routes user questions to the appropriate specialist agents."
+         self.agent = ReActAgent.from_tools(llm=self.llm, tools=self.tools, context=self.system_prompt, verbose=True)
+
+     async def __call__(self, question: str) -> str:
+         response_object = await self.agent.achat(question)
+         full_response_text = str(response_object)
+
+         return full_response_text
src/rag_index.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dotenv import load_dotenv
2
+ load_dotenv()
3
+ import os
4
+ import chromadb
5
+ import re
6
+ import logging
7
+ from sqlalchemy import create_engine
8
+ from uuid import uuid4
9
+ from llama_index.core import SimpleDirectoryReader, VectorStoreIndex, Settings
10
+ from llama_index.vector_stores.chroma import ChromaVectorStore
11
+ from llama_index.embeddings.nebius import NebiusEmbedding
12
+ from llama_index.core import SQLDatabase
13
+ from llama_index.core.schema import TextNode
14
+
15
+ # Configure logging
16
+ logging.basicConfig(level=logging.INFO)
17
+
18
+ # --- Configuration ---
19
+ DATABASE_PATH = os.path.join('data', 'sales_database.db')
20
+
21
+ # CORRECTED PATHS based on your clarification:
22
+ # schema is under knowledge_base
23
+ SCHEMA_DIR = os.path.join('knowledge_base', 'schema')
24
+ DATA_DICTIONARY_PATH = os.path.join(SCHEMA_DIR, 'data_dictionary.md')
25
+ SALES_SCHEMA_SQL_PATH = os.path.join(SCHEMA_DIR, 'sales_schema.sql')
26
+
27
+ # kpi_definitions.md is under business_glossary
28
+ KPI_DEFINITIONS_PATH = os.path.join('knowledge_base', 'business_glossary', 'kpi_definitions.md')
29
+
30
+ CHROMA_DB_SCHEMA_PATH = os.path.join('.', 'chroma_db_schema')
31
+ CHROMA_DB_KPI_PATH = os.path.join('.', 'chroma_db_kpi')
32
+
33
+ # Ensure NEBIUS_API_KEY is set
34
+ if "NEBIUS_API_KEY" not in os.environ:
35
+ logging.critical("NEBIUS_API_KEY environment variable not set. Please set it before running this script.")
36
+ raise ValueError("NEBIUS_API_KEY environment variable not set. Please set it before running this script.")
37
+
38
+ # --- Initialize Nebius AI Embedding Model (Global for consistency) ---
39
+ print("\n--- Initializing Nebius AI Embedding Model ---")
40
+ try:
41
+ embed_model = NebiusEmbedding(
42
+ api_key=os.environ["NEBIUS_API_KEY"],
43
+ model_name="BAAI/bge-en-icl", # Consistent with schema_retriever_tool for now
44
+ api_base="https://api.studio.nebius.com/v1/" # Verify this base URL is correct for the model
45
+ )
46
+ test_string = "This is a test string to generate an embedding for diagnostic purposes."
47
+ test_embedding = embed_model.get_text_embedding(test_string)
48
+
49
+ if test_embedding is None or not isinstance(test_embedding, list) or not all(isinstance(x, (int, float)) for x in test_embedding):
50
+ logging.critical(f"FATAL ERROR: NebiusEmbedding returned invalid output for test string. Type: {type(test_embedding)}, Value: {test_embedding[:10] if isinstance(test_embedding, list) else test_embedding}")
51
+ raise ValueError("NebiusEmbedding test failed: returned invalid embedding. Cannot proceed with RAG index creation.")
52
+
53
+ logging.info(f"Nebius AI Embedding Model initialized and tested successfully. Embedding length: {len(test_embedding)}")
54
+
55
+ except Exception as e:
56
+ logging.critical(f"UNRECOVERABLE ERROR during Nebius AI Embedding Model initialization or testing: {e}")
57
+ print(f"\n!!!! UNRECOVERABLE ERROR: {e} !!!!")
58
+ print("Please check your NEBIUS_API_KEY, model_name, and api_base configuration carefully.")
59
+ raise
60
+
61
+ print("--- Nebius AI Embedding Model setup complete. Proceeding to Knowledge Base setup. ---")
62
+
63
+ # --- Helper Function to parse Data Dictionary ---
64
+ def parse_data_dictionary_md(md_path: str) -> dict:
65
+ table_descriptions = {}
66
+ if not os.path.exists(md_path):
67
+ logging.error(f"Data dictionary file not found: {md_path}")
68
+ return {}
69
+
70
+ with open(md_path, 'r', encoding='utf-8') as f:
71
+ content = f.read()
72
+
73
+ table_headers = re.finditer(r'## Table: `(\w+)`\n\*\*Purpose:\*\* (.+?)(?=\n## Table: `|\Z)', content, re.DOTALL)
74
+
75
+ for match in table_headers:
76
+ table_name = match.group(1).strip()
77
+ purpose = match.group(2).strip()
78
+ table_descriptions[table_name] = purpose
79
+ logging.info(f"Parsed table: {table_name}, Purpose length: {len(purpose)}")
80
+
81
+ return table_descriptions
82
+
83
+ # --- Setup for Schema Retriever Agent's Knowledge Base ---
84
+ print("\n--- Setting up Schema Retriever Agent's Knowledge Base (chroma_db_schema) ---")
85
+ try:
86
+ data_dict_descriptions = parse_data_dictionary_md(DATA_DICTIONARY_PATH)
87
+ logging.info(f"Loaded descriptions for {len(data_dict_descriptions)} tables from data_dictionary.md")
88
+
89
+ engine = create_engine(f"sqlite:///{DATABASE_PATH}")
90
+ sql_database = SQLDatabase(engine)
91
+ logging.info(f"Connected to database: {DATABASE_PATH}")
92
+
93
+ all_table_names = sql_database.get_usable_table_names()
94
+
95
+ schema_nodes = []
96
+ if not all_table_names:
97
+ logging.warning("WARNING: No tables found in the database. Schema KB will be empty.")
98
+
99
+ for table_name in all_table_names:
100
+ ddl = sql_database.get_single_table_info(table_name)
101
+ human_description = data_dict_descriptions.get(table_name, "No specific description available for this table.")
102
+
103
+ combined_context = f"Table Name: {table_name}\n" \
104
+ f"Description: {human_description}\n" \
105
+ f"Table Schema (DDL):\n{ddl if ddl else ''}"
106
+
107
+ node_embedding = embed_model.get_text_embedding(combined_context)
108
+ if node_embedding is None:
109
+ raise ValueError(f"Failed to generate embedding for schema table: {table_name}")
110
+
111
+ # Adding a simple metadata dictionary
112
+ schema_nodes.append(TextNode(text=combined_context, embedding=node_embedding, id_=table_name, metadata={"table_name": table_name, "source": "data_dictionary"}))
113
+
114
+ chroma_client_schema = chromadb.PersistentClient(path=CHROMA_DB_SCHEMA_PATH)
115
+ chroma_collection_schema = chroma_client_schema.get_or_create_collection(name="schema_kb")
116
+
117
+ if chroma_collection_schema.count() > 0:
118
+ logging.info(f"Clearing {chroma_collection_schema.count()} existing items from schema_kb before re-indexing.")
119
+ chroma_collection_schema.delete(ids=chroma_collection_schema.get()['ids'])
120
+
121
+
122
+ logging.info(f"Adding {len(schema_nodes)} schema nodes to ChromaDB directly...")
123
+ if schema_nodes:
124
+ chroma_collection_schema.add(
125
+ documents=[node.text for node in schema_nodes],
126
+ embeddings=[node.embedding for node in schema_nodes],
127
+ # Pass non-empty metadata dict for each node
128
+ metadatas=[node.metadata for node in schema_nodes],
129
+ ids=[node.id_ for node in schema_nodes]
130
+ )
131
+ logging.info(f"Schema knowledge base indexed and persisted to {CHROMA_DB_SCHEMA_PATH}")
132
+ except Exception as e:
133
+ logging.exception("Error setting up Schema KB:")
134
+ print(f"Error setting up Schema KB: {e}")
135
+ print("Please ensure your database is created and accessible and data_dictionary.md is correctly formatted.")
136
+ print("If the error persists, there might be a deeper compatibility issue. See above for more detailed embedding checks.")
137
+ raise
138
+
139
+
140
+ # --- Setup for KPI Answering Agent's Knowledge Base (kpi_definitions.md) ---
141
+ print("\n--- Setting up KPI Answering Agent's Knowledge Base (chroma_db_kpi) ---")
142
+ try:
143
+ kpi_docs = SimpleDirectoryReader(input_files=[KPI_DEFINITIONS_PATH]).load_data()
144
+ logging.info(f"Loaded {len(kpi_docs)} documents for KPI Agent.")
145
+
146
+ kpi_nodes = []
147
+ for doc in kpi_docs:
148
+ node_embedding = embed_model.get_text_embedding(doc.get_content())
149
+ if node_embedding is None:
150
+ raise ValueError(f"Failed to generate embedding for KPI document: {doc.id_}")
151
+ # Adding a simple metadata dictionary
152
+ kpi_nodes.append(TextNode(text=doc.get_content(), embedding=node_embedding, id_=str(uuid4()), metadata={"source_file": os.path.basename(KPI_DEFINITIONS_PATH), "doc_type": "kpi_definition"}))
153
+
154
+ chroma_client_kpi = chromadb.PersistentClient(path=CHROMA_DB_KPI_PATH)
155
+ chroma_collection_kpi = chroma_client_kpi.get_or_create_collection(name="kpi_kb")
156
+
157
+ if chroma_collection_kpi.count() > 0:
158
+ logging.info(f"Clearing {chroma_collection_kpi.count()} existing items from kpi_kb before re-indexing.")
159
+ chroma_collection_kpi.delete(ids=chroma_collection_kpi.get()['ids'])
160
+
161
+ logging.info(f"Adding {len(kpi_nodes)} KPI nodes to ChromaDB directly...")
162
+ if kpi_nodes:
163
+ chroma_collection_kpi.add(
164
+ documents=[node.text for node in kpi_nodes],
165
+ embeddings=[node.embedding for node in kpi_nodes],
166
+ # Pass non-empty metadata dict for each node
167
+ metadatas=[node.metadata for node in kpi_nodes],
168
+ ids=[node.id_ for node in kpi_nodes]
169
+ )
170
+ logging.info(f"KPI knowledge base indexed and persisted to {CHROMA_DB_KPI_PATH}")
171
+ except Exception as e:
172
+ logging.exception("Error setting up KPI KB:")
173
+ print(f"Error setting up KPI KB: {e}")
174
+ print("Please ensure your 'kpi_definitions.md' file exists and contains valid text content.")
175
+ print("If the error persists, there might be a deeper compatibility issue. See above for more detailed embedding checks.")
176
+ raise
177
+
178
+
179
+ print("\nKnowledge base setup complete for RAG agents!")
180
+ print(f"Indices are saved in '{CHROMA_DB_SCHEMA_PATH}' and '{CHROMA_DB_KPI_PATH}' directories.")
src/setup_database.py ADDED
@@ -0,0 +1,242 @@
1
+ import sqlite3
2
+ import os
3
+ import pandas as pd
4
+ from datetime import datetime, timedelta
5
+ import random
6
+ from tqdm import tqdm
7
+
8
+ DATABASE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'data')
9
+ DATABASE_PATH = os.path.join(DATABASE_DIR, 'sales_database.db')
10
+
11
+ def create_database():
12
+ """Creates the SQLite database file if it doesn't exist."""
13
+ os.makedirs(DATABASE_DIR, exist_ok=True)
14
+ conn = None
15
+ try:
16
+ conn = sqlite3.connect(DATABASE_PATH)
17
+ print(f"Database created successfully at {DATABASE_PATH}")
18
+ except sqlite3.Error as e:
19
+ print(f"Error creating database: {e}")
20
+ finally:
21
+ if conn:
22
+ conn.close()
23
+
24
+ def create_tables(conn):
25
+ """Creates necessary tables in the database."""
26
+ cursor = conn.cursor()
27
+
28
+ # Regions Table
29
+ cursor.execute('''
30
+ CREATE TABLE IF NOT EXISTS regions (
31
+ region_id INTEGER PRIMARY KEY,
32
+ region_name TEXT NOT NULL UNIQUE
33
+ );
34
+ ''')
35
+
36
+ # Products Table
37
+ cursor.execute('''
38
+ CREATE TABLE IF NOT EXISTS products (
39
+ product_id INTEGER PRIMARY KEY,
40
+ product_name TEXT NOT NULL,
41
+ category TEXT NOT NULL,
42
+ price REAL NOT NULL
43
+ );
44
+ ''')
45
+
46
+ # Customers Table
47
+ cursor.execute('''
48
+ CREATE TABLE IF NOT EXISTS customers (
49
+ customer_id INTEGER PRIMARY KEY,
50
+ customer_name TEXT NOT NULL,
51
+ email TEXT UNIQUE,
52
+ region_id INTEGER,
53
+ FOREIGN KEY (region_id) REFERENCES regions(region_id)
54
+ );
55
+ ''')
56
+
57
+ # Sales Table
58
+ cursor.execute('''
59
+ CREATE TABLE IF NOT EXISTS sales (
60
+ sale_id INTEGER PRIMARY KEY AUTOINCREMENT,
61
+ product_id INTEGER NOT NULL,
62
+ customer_id INTEGER NOT NULL,
63
+ region_id INTEGER NOT NULL,
64
+ sale_date TEXT NOT NULL, -- Stored as YYYY-MM-DD
65
+ quantity INTEGER NOT NULL,
66
+ amount REAL NOT NULL,
67
+ FOREIGN KEY (product_id) REFERENCES products(product_id),
68
+ FOREIGN KEY (customer_id) REFERENCES customers(customer_id),
69
+ FOREIGN KEY (region_id) REFERENCES regions(region_id)
70
+ );
71
+ ''')
72
+ conn.commit()
73
+ print("Tables created successfully.")
74
+
75
+ def insert_dummy_data(conn):
76
+ """Inserts sample data into the tables."""
77
+ cursor = conn.cursor()
78
+
79
+ # Clear existing data before inserting new records to ensure a fresh start
80
+ print("Clearing existing data...")
81
+ cursor.execute("DELETE FROM sales")
82
+ cursor.execute("DELETE FROM customers")
83
+ cursor.execute("DELETE FROM products")
84
+ cursor.execute("DELETE FROM regions")
85
+ conn.commit()
86
+ print("Existing data cleared.")
87
+
88
+
89
+ # Insert Regions
90
+ regions_data = [
91
+ (1, 'North'),
92
+ (2, 'South'),
93
+ (3, 'East'),
94
+ (4, 'West'),
95
+ (5, 'Central') # Added a new region
96
+ ]
97
+ cursor.executemany("INSERT OR IGNORE INTO regions (region_id, region_name) VALUES (?, ?)", regions_data)
98
+ print(f"Inserted {len(regions_data)} regions.")
99
+
100
+ # Insert Products (More variety, including sub-categories implicitly via name)
101
+ products_data = [
102
+ (101, 'Laptop Basic', 'Electronics', 800.00),
103
+ (102, 'Laptop Pro', 'Electronics', 1500.00),
104
+ (103, 'Smartphone X', 'Electronics', 700.00),
105
+ (104, 'Smartwatch', 'Electronics', 250.00),
106
+ (105, 'Headphones Noise-Cancelling', 'Electronics', 180.00),
107
+ (201, 'Office Desk Standard', 'Furniture', 200.00),
108
+ (202, 'Ergonomic Chair', 'Furniture', 350.00),
109
+ (203, 'Bookshelf Small', 'Furniture', 80.00),
110
+ (301, 'Notebook A4', 'Stationery', 12.00),
111
+ (302, 'Premium Pen Set', 'Stationery', 30.00),
112
+ (303, 'Art Supplies Kit', 'Stationery', 45.00),
113
+ (401, 'Running Shoes', 'Apparel', 110.00), # New Category
114
+ (402, 'T-Shirt Cotton', 'Apparel', 25.00),
115
+ ]
116
+ cursor.executemany("INSERT OR IGNORE INTO products (product_id, product_name, category, price) VALUES (?, ?, ?, ?)", products_data)
117
+ print(f"Inserted {len(products_data)} products.")
118
+
119
+ # Insert Customers (More customers)
120
+ customer_names = [
121
+ 'Alice Smith', 'Bob Johnson', 'Charlie Brown', 'Diana Prince', 'Eve Adams',
122
+ 'Frank White', 'Grace Lee', 'Henry Green', 'Ivy Chen', 'Jack Taylor',
123
+ 'Karen Black', 'Liam Davis', 'Mia Wilson', 'Noah Martinez', 'Olivia Garcia',
124
+ 'Peter Rodriguez', 'Quinn Miller', 'Rachel Jones', 'Sam Hernandez', 'Tina Clark'
125
+ ] # 20 customers
126
+ customers_data = []
127
+ for i, name in enumerate(customer_names):
128
+ # Assign each customer to a region via a weighted random choice
129
+ region_id = random.choice([1, 1, 2, 3, 4, 5]) # North region (ID 1) slightly more customers
130
+ customers_data.append((i + 1, name, f"{name.replace(' ', '').lower()}@example.com", region_id))
131
+ cursor.executemany("INSERT OR IGNORE INTO customers (customer_id, customer_name, email, region_id) VALUES (?, ?, ?, ?)", customers_data)
132
+ print(f"Inserted {len(customers_data)} customers.")
133
+
134
+
135
+ # --- Sales Data Generation Configuration ---
136
+ NUM_SALES_RECORDS = 15000
137
+
138
+ end_date = datetime.now()
139
+
140
+ # Ensure a good chunk of sales are recent
141
+ recent_sales_period_days = 60 # Sales within last 2 months
142
+ long_term_sales_period_days = 730 # Sales within last 2 years
143
+
144
+ # Define how many sales should be recent vs. long-term
145
+ recent_sales_proportion = 0.3 # 30% of sales in the last 2 months
146
+ num_recent_sales = int(NUM_SALES_RECORDS * recent_sales_proportion)
147
+ num_long_term_sales = NUM_SALES_RECORDS - num_recent_sales
148
+
149
+ products_df = pd.DataFrame(products_data, columns=['product_id', 'product_name', 'category', 'price'])
150
+ customers_df = pd.DataFrame(customers_data, columns=['customer_id', 'customer_name', 'email', 'region_id'])
151
+ regions_df = pd.DataFrame(regions_data, columns=['region_id', 'region_name'])
152
+
153
+ sales_records = []
154
+
155
+ # Define some regional biases (e.g., North sells more)
156
+ region_weights = {1: 0.30, 2: 0.25, 3: 0.20, 4: 0.15, 5: 0.10}
157
+ weighted_regions = [r for r, weight in region_weights.items() for _ in range(int(weight * 100))]
158
+
159
+ # Define some category biases (e.g., Electronics sells more)
160
+ category_weights = {'Electronics': 0.40, 'Furniture': 0.30, 'Stationery': 0.20, 'Apparel': 0.10}
161
+ all_weighted_products = []
162
+ for cat, weight in category_weights.items():
163
+ cat_products = products_df[products_df['category'] == cat]['product_id'].tolist()
164
+ all_weighted_products.extend(cat_products * int(weight * 100 / (len(cat_products) if cat_products else 1)))
165
+
166
+ # Simulate some seasonality (e.g., more sales towards end of year)
167
+ def get_seasonal_factor(date):
168
+ month = date.month
169
+ if month in [11, 12]: return random.uniform(1.1, 1.3)
170
+ elif month in [1, 2]: return random.uniform(0.8, 1.0)
171
+ else: return random.uniform(0.9, 1.1)
172
+
173
+ # --- End Sales Data Generation Configuration ---
174
+
175
+
176
+ # Generate recent sales
177
+ for _ in tqdm(range(num_recent_sales), desc="Generating Recent Sales Data"):
178
+ current_date = end_date - timedelta(days=random.randint(0, recent_sales_period_days - 1))
179
+
180
+ chosen_region_id = random.choice(weighted_regions)
181
+ region_customers = customers_df[customers_df['region_id'] == chosen_region_id]
182
+ customer = region_customers.sample(1).iloc[0] if not region_customers.empty else customers_df.sample(1).iloc[0]
183
+
184
+ product = products_df[products_df['product_id'] == random.choice(all_weighted_products)].iloc[0]
185
+
186
+ quantity = random.randint(1, 5)
187
+ seasonal_factor = get_seasonal_factor(current_date)
188
+ amount = round(quantity * product['price'] * (1 + random.uniform(-0.1, 0.1)) * seasonal_factor, 2)
189
+
190
+ sales_records.append((
191
+ int(product['product_id']), # Explicitly cast to int
192
+ int(customer['customer_id']), # Explicitly cast to int
193
+ int(customer['region_id']), # Explicitly cast to int
194
+ current_date.strftime('%Y-%m-%d'),
195
+ quantity,
196
+ amount
197
+ ))
198
+
199
+ # Generate long-term sales (same logic as the recent sales above, over the older date range)
200
+ for _ in tqdm(range(num_long_term_sales), desc="Generating Long-Term Sales Data"):
201
+ current_date = end_date - timedelta(days=random.randint(recent_sales_period_days, long_term_sales_period_days - 1))
202
+
203
+ chosen_region_id = random.choice(weighted_regions)
204
+ region_customers = customers_df[customers_df['region_id'] == chosen_region_id]
205
+ customer = region_customers.sample(1).iloc[0] if not region_customers.empty else customers_df.sample(1).iloc[0]
206
+
207
+ product = products_df[products_df['product_id'] == random.choice(all_weighted_products)].iloc[0]
208
+
209
+ quantity = random.randint(1, 5)
210
+ seasonal_factor = get_seasonal_factor(current_date)
211
+ amount = round(quantity * product['price'] * (1 + random.uniform(-0.1, 0.1)) * seasonal_factor, 2)
212
+
213
+ sales_records.append((
214
+ int(product['product_id']), # Explicitly cast to int
215
+ int(customer['customer_id']), # Explicitly cast to int
216
+ int(customer['region_id']), # Explicitly cast to int
217
+ current_date.strftime('%Y-%m-%d'),
218
+ quantity,
219
+ amount
220
+ ))
221
+
222
+ cursor.executemany("INSERT INTO sales (product_id, customer_id, region_id, sale_date, quantity, amount) VALUES (?, ?, ?, ?, ?, ?)", sales_records)
223
+ conn.commit()
224
+ print(f"Inserted {len(sales_records)} dummy sales records.")
225
+ print("Dummy data inserted successfully.")
226
+
227
+ def main():
228
+ create_database()
229
+ conn = None
230
+ try:
231
+ conn = sqlite3.connect(DATABASE_PATH)
232
+ create_tables(conn)
233
+ insert_dummy_data(conn)
234
+ except sqlite3.Error as e:
235
+ print(f"An error occurred: {e}")
236
+ finally:
237
+ if conn:
238
+ conn.close()
239
+ print("Database connection closed.")
240
+
241
+ if __name__ == "__main__":
242
+ main()
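After running this script (python src/setup_database.py), a quick sanity check of the generated data might look like the following sketch; it assumes the script was run from the project root so the database sits at data/sales_database.db.

import sqlite3

conn = sqlite3.connect("data/sales_database.db")
cursor = conn.cursor()
# Total order count and sales amount per region, highest revenue first;
# this exercises the joins the NL-to-SQL agent will rely on.
cursor.execute("""
    SELECT r.region_name, COUNT(*) AS orders, ROUND(SUM(s.amount), 2) AS total_sales
    FROM sales s
    JOIN regions r ON s.region_id = r.region_id
    GROUP BY r.region_name
    ORDER BY total_sales DESC;
""")
for row in cursor.fetchall():
    print(row)
conn.close()

With the region and category weights defined above, North and Electronics should come out on top if the data generation behaved as intended.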