|
import os |
|
import logging |
|
import chromadb |
|
from llama_index.core.tools import FunctionTool |
|
from llama_index.embeddings.nebius import NebiusEmbedding |
|
from llama_index.vector_stores.chroma import ChromaVectorStore |
|
from llama_index.core import VectorStoreIndex, Settings |
|
|
|
logging.basicConfig(level=logging.INFO) |
|
|
|
|
|
current_file_dir = os.path.dirname(os.path.abspath(__file__)) |
|
|
|
|
|
CHROMA_DB_PATH = os.path.join(current_file_dir, '..', '..', '..', 'chroma_db_schema') |
|
logging.info(f"ChromaDB Schema Path set to: {CHROMA_DB_PATH}") |
|
|
|
|
|
embed_model_name = "BAAI/bge-en-icl" |
|
embed_api_base = "https://api.studio.nebius.com/v1/" |
|
|
|
embeddings = None |
|
try: |
|
embeddings = NebiusEmbedding( |
|
api_key=os.environ.get("NEBIUS_API_KEY"), |
|
model_name=embed_model_name, |
|
api_base=embed_api_base |
|
) |
|
Settings.embed_model = embeddings |
|
|
|
_ = embeddings.get_text_embedding("test validation string") |
|
logging.info("NebiusEmbedding initialized successfully for schema retriever.") |
|
except Exception as e: |
|
logging.error(f"Error initializing NebiusEmbedding in schema_retriever_tool: {e}") |
|
embeddings = None |
|
|
|
|
|
if embeddings: |
|
Settings.embed_model = embeddings |
|
|
|
|
|
def retrieve_schema_context(natural_language_query: str) -> str: |
|
if embeddings is None: |
|
return "Error: Embedding model not initialized for schema retrieval. Cannot perform RAG. Please check your Nebius API key and model configuration." |
|
|
|
try: |
|
db = chromadb.PersistentClient(path=CHROMA_DB_PATH) |
|
chroma_collection = db.get_or_create_collection(name="schema_kb") |
|
logging.info(f"ChromaDB collection 'schema_kb' opened successfully for retrieval.") |
|
|
|
vector_store = ChromaVectorStore(chroma_collection=chroma_collection) |
|
|
|
index = VectorStoreIndex.from_vector_store( |
|
vector_store, |
|
embed_model=embeddings |
|
) |
|
query_engine = index.as_retriever(similarity_top_k=2) |
|
retrieved_nodes = query_engine.retrieve(natural_language_query) |
|
|
|
schema_snippets = [node.get_content() for node in retrieved_nodes] |
|
if not schema_snippets: |
|
return "No relevant schema context found for your query. Please rephrase or simplify." |
|
|
|
return "Retrieved Database Schema Context (relevant to query):\n" + "\n---\n".join(schema_snippets) |
|
|
|
except Exception as e: |
|
logging.exception("Error in retrieve_schema_context:") |
|
return f"Error retrieving schema from RAG: {str(e)}. Ensure ChromaDB is built at {CHROMA_DB_PATH} and embedding model is compatible." |
|
|
|
|
|
|
|
def get_schema_retriever_tool() -> FunctionTool: |
|
return FunctionTool.from_defaults( |
|
fn=retrieve_schema_context, |
|
name="retrieve_schema_context", |
|
description=( |
|
"Retrieves relevant database schema information (tables, columns, relationships, descriptions) " |
|
"from the sales database knowledge base using semantic search (RAG). Always call this first " |
|
"if you need to understand the schema for SQL generation." |
|
) |
|
) |
|
|
|
|
|
if __name__ == "__main__": |
|
print("--- Testing Schema Retriever Tool Implementation (with RAG) ---") |
|
if not os.environ.get("NEBIUS_API_KEY"): |
|
print("Warning: NEBIUS_API_KEY not set. Schema retrieval might fail.") |
|
|
|
schema_tool = get_schema_retriever_tool() |
|
|
|
queries = [ |
|
"What are the columns in the sales and products tables, and how are they related?", |
|
"Show me customer names and their regions.", |
|
"What is the purpose of the regions table?", |
|
"Tell me about the sales table schema and its purpose." |
|
] |
|
|
|
for query in queries: |
|
print(f"\nCalling tool with: '{query}'") |
|
result = schema_tool.call(natural_language_query=query) |
|
print(f"Result:\n{result}") |