aifeifei798's picture
Upload 8 files
3d40769 verified
raw
history blame
7.39 kB
# database/setup.py
import json
import os
import sqlite3
from contextlib import closing

from pymilvus import MilvusClient, FieldSchema, CollectionSchema, DataType
from llama_index.embeddings.google import GooglePairedEmbeddings

# 导入你的工具注册表
from tools.tool_registry import get_all_tools
# --- Persistent storage locations ---
# All persistent data lives in a "data" directory one level above this file's directory.
DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")
SQLITE_DB_PATH = os.path.join(DATA_DIR, "tools.metadata.db")
# Milvus Lite keeps its entire store in this single local file.
MILVUS_DATA_PATH = os.path.join(DATA_DIR, "milvus_data.db")

# --- Milvus Lite settings ---
MILVUS_COLLECTION_NAME = "tool_embeddings"
# Vector length produced by Google's text-embedding-004 model.
EMBEDDING_DIM = 768

# --- Process-wide flags that keep one-time setup from being repeated ---
_db_initialized = False
_milvus_initialized = False
def initialize_system():
    """Bring up storage, register the tools and build the tool recommender.

    Ensures the data directory exists, creates the SQLite table and the
    Milvus Lite collection (each at most once per process, tracked by the
    module-level ``_db_initialized`` / ``_milvus_initialized`` flags), then
    syncs tool metadata into SQLite and description embeddings into Milvus.
    Later calls reconnect to Milvus and re-run both sync steps.

    Returns:
        tuple: ``(all_tools_definitions, tool_recommender)``. The first item is
        whatever ``get_all_tools()`` returns; the second is a
        ``LlamaIndexToolRecommender`` bound to the Milvus client and the
        SQLite database path.
    """
    global _db_initialized, _milvus_initialized

    print("--- 开始系统初始化 ---")
    os.makedirs(DATA_DIR, exist_ok=True)

    # Table creation only needs to happen once per process.
    if not _db_initialized:
        _init_sqlite_db()
        _db_initialized = True

    # First call creates/loads the collection; afterwards just open a client.
    if _milvus_initialized:
        milvus_client = MilvusClient(uri=MILVUS_DATA_PATH)
    else:
        milvus_client = _init_milvus_lite()
        _milvus_initialized = True

    tool_definitions = get_all_tools()
    _sync_tools_to_sqlite(tool_definitions)
    _sync_tool_embeddings_to_milvus(milvus_client)

    # Deferred import; presumably avoids a circular import with
    # core.tool_recommender — TODO confirm.
    from core.tool_recommender import LlamaIndexToolRecommender

    recommender = LlamaIndexToolRecommender(
        milvus_client=milvus_client, sqlite_db_path=SQLITE_DB_PATH
    )
    print("--- 系统初始化完成 ---")
    return tool_definitions, recommender
def _init_sqlite_db():
"""初始化SQLite数据库并创建表。"""
print(f"SQLite DB 路径: {SQLITE_DB_PATH}")
with sqlite3.connect(SQLITE_DB_PATH) as conn:
cursor = conn.cursor()
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS tools (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT UNIQUE NOT NULL,
description TEXT NOT NULL,
parameters TEXT NOT NULL -- 存储JSON字符串
)
"""
)
conn.commit()
print("SQLite DB 表已确认存在。")
def _init_milvus_lite():
    """Open the Milvus Lite store and make sure the tool collection is usable.

    If ``MILVUS_COLLECTION_NAME`` is missing it is created with an INT64
    primary key ``id`` and a FLOAT_VECTOR ``embedding`` field of
    ``EMBEDDING_DIM`` dimensions, plus an AUTOINDEX index using the L2 metric.
    In every case the collection is then loaded so it can be searched.

    Returns:
        MilvusClient: client connected to ``MILVUS_DATA_PATH``.
    """
    print(f"Milvus Lite 数据路径: {MILVUS_DATA_PATH}")
    client = MilvusClient(uri=MILVUS_DATA_PATH)

    if client.has_collection(collection_name=MILVUS_COLLECTION_NAME):
        print(f"Milvus集合 '{MILVUS_COLLECTION_NAME}' 已存在。")
    else:
        print(f"Milvus集合 '{MILVUS_COLLECTION_NAME}' 不存在,正在创建...")
        id_field = FieldSchema(name="id", dtype=DataType.INT64, is_primary=True)
        vector_field = FieldSchema(
            name="embedding", dtype=DataType.FLOAT_VECTOR, dim=EMBEDDING_DIM
        )
        client.create_collection(
            collection_name=MILVUS_COLLECTION_NAME,
            schema=CollectionSchema(
                [id_field, vector_field], description="Tool embedding collection"
            ),
        )

        index_params = client.prepare_index_params()
        # AUTOINDEX lets Milvus pick the index type itself.
        index_params.add_index(
            field_name="embedding", index_type="AUTOINDEX", metric_type="L2"
        )
        client.create_index(
            collection_name=MILVUS_COLLECTION_NAME, index_params=index_params
        )
        print("Milvus集合和索引创建完成。")

    # Load into memory so searches can run against it.
    client.load_collection(collection_name=MILVUS_COLLECTION_NAME)
    return client
def _sync_tools_to_sqlite(tools_definitions):
"""将工具定义同步到SQLite数据库。"""
print("正在同步工具元数据到SQLite...")
with sqlite3.connect(SQLITE_DB_PATH) as conn:
cursor = conn.cursor()
for tool in tools_definitions:
cursor.execute("SELECT id FROM tools WHERE name = ?", (tool.name,))
if cursor.fetchone() is None:
# 工具不存在,插入新工具
cursor.execute(
"INSERT INTO tools (name, description, parameters) VALUES (?, ?, ?)",
(tool.name, tool.description, json.dumps(tool.args)),
)
print(f" - 新增工具到SQLite: {tool.name}")
conn.commit()
print("SQLite同步完成。")
def _sync_tool_embeddings_to_milvus(milvus_client):
    """Embed tool descriptions that have no vector in Milvus yet and store them.

    Reads every ``(id, description)`` row of the SQLite ``tools`` table, asks
    Milvus which ids already have a vector, embeds only the missing
    descriptions with the ``models/text-embedding-004`` model and writes them
    to ``MILVUS_COLLECTION_NAME`` keyed by the SQLite row id.

    If the embedding model cannot be constructed, an error is printed and the
    function returns without writing to Milvus. Errors raised while embedding
    or writing propagate to the caller.

    Args:
        milvus_client: Client connected to the Milvus Lite store.
    """
    print("正在同步工具嵌入到Milvus...")

    # 1. Load all tools from SQLite. closing() is what releases the connection;
    #    sqlite3's own context manager only ends the transaction.
    with closing(sqlite3.connect(SQLITE_DB_PATH)) as conn:
        all_tools_in_db = conn.execute("SELECT id, description FROM tools").fetchall()

    # 2. Ids that already have an embedding in Milvus.
    try:
        existing_milvus_ids_raw = milvus_client.query(
            collection_name=MILVUS_COLLECTION_NAME,
            filter="id > 0",  # matches every row: SQLite AUTOINCREMENT ids start at 1
            output_fields=["id"],
        )
        existing_milvus_ids = {item["id"] for item in existing_milvus_ids_raw}
    except Exception as e:
        # Best-effort fallback: treat every tool as new. The upsert below keeps
        # this from producing duplicate primary keys in the collection.
        print(f"警告:无法查询Milvus中已有的工具ID,将为所有工具重新生成嵌入。 - {e}")
        existing_milvus_ids = set()

    # 3. Tools that still need an embedding.
    new_tools_to_embed = [
        (tool_id, description)
        for tool_id, description in all_tools_in_db
        if tool_id not in existing_milvus_ids
    ]
    if not new_tools_to_embed:
        print("所有工具嵌入已是最新,无需同步。")
        return
    print(f"发现 {len(new_tools_to_embed)} 个新工具需要生成嵌入...")

    # 4. Build the embedding model (the Google API key must be set in the environment).
    try:
        # NOTE(review): confirm ``GooglePairedEmbeddings`` exists in the installed
        # ``llama_index.embeddings.google`` package and accepts these arguments.
        embed_model = GooglePairedEmbeddings(
            model_name="models/text-embedding-004",
            task_type="retrieval_document",  # document-side embeddings for storage
        )
    except Exception as e:
        print(f"错误:无法初始化Google嵌入模型。请检查API Key。 - {e}")
        return

    # 5. Embed the descriptions and pair each vector with its SQLite id.
    tool_ids_to_insert = [tool_id for tool_id, _ in new_tools_to_embed]
    descriptions_to_embed = [description for _, description in new_tools_to_embed]
    embeddings = embed_model.get_text_embedding_batch(
        descriptions_to_embed, show_progress=True
    )
    data_to_insert = [
        {"id": tool_id, "embedding": embedding}
        for tool_id, embedding in zip(tool_ids_to_insert, embeddings)
    ]

    # 6. Write to Milvus. Milvus does not de-duplicate primary keys on insert,
    #    so upsert keeps repeated syncs from creating duplicate vectors.
    milvus_client.upsert(collection_name=MILVUS_COLLECTION_NAME, data=data_to_insert)
    # MilvusClient.flush takes a single collection name, not a list.
    milvus_client.flush(collection_name=MILVUS_COLLECTION_NAME)
    print(f"成功将 {len(data_to_insert)} 个新嵌入插入到Milvus。")