Merge pull request #393 from partoneplay/main
Browse files- examples/lightrag_ollama_neo4j_milvus_demo.py +51 -0
- lightrag/kg/milvus_impl.py +88 -0
- lightrag/lightrag.py +3 -0
- requirements.txt +1 -0
examples/lightrag_ollama_neo4j_milvus_demo.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from lightrag import LightRAG, QueryParam
|
3 |
+
from lightrag.llm import ollama_model_complete, ollama_embed
|
4 |
+
from lightrag.utils import EmbeddingFunc
|
5 |
+
|
6 |
+
# WorkingDir
|
7 |
+
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
8 |
+
WORKING_DIR = os.path.join(ROOT_DIR, "myKG")
|
9 |
+
if not os.path.exists(WORKING_DIR):
|
10 |
+
os.mkdir(WORKING_DIR)
|
11 |
+
print(f"WorkingDir: {WORKING_DIR}")
|
12 |
+
|
13 |
+
# neo4j
|
14 |
+
BATCH_SIZE_NODES = 500
|
15 |
+
BATCH_SIZE_EDGES = 100
|
16 |
+
os.environ["NEO4J_URI"] = "bolt://localhost:7687"
|
17 |
+
os.environ["NEO4J_USERNAME"] = "neo4j"
|
18 |
+
os.environ["NEO4J_PASSWORD"] = "neo4j"
|
19 |
+
|
20 |
+
# milvus
|
21 |
+
os.environ["MILVUS_URI"] = "http://localhost:19530"
|
22 |
+
os.environ["MILVUS_USER"] = "root"
|
23 |
+
os.environ["MILVUS_PASSWORD"] = "root"
|
24 |
+
os.environ["MILVUS_DB_NAME"] = "lightrag"
|
25 |
+
|
26 |
+
|
27 |
+
rag = LightRAG(
|
28 |
+
working_dir=WORKING_DIR,
|
29 |
+
llm_model_func=ollama_model_complete,
|
30 |
+
llm_model_name="qwen2.5:14b",
|
31 |
+
llm_model_max_async=4,
|
32 |
+
llm_model_max_token_size=32768,
|
33 |
+
llm_model_kwargs={"host": "http://127.0.0.1:11434", "options": {"num_ctx": 32768}},
|
34 |
+
embedding_func=EmbeddingFunc(
|
35 |
+
embedding_dim=1024,
|
36 |
+
max_token_size=8192,
|
37 |
+
func=lambda texts: ollama_embed(
|
38 |
+
texts=texts, embed_model="bge-m3:latest", host="http://127.0.0.1:11434"
|
39 |
+
),
|
40 |
+
),
|
41 |
+
graph_storage="Neo4JStorage",
|
42 |
+
vector_storage="MilvusVectorDBStorge",
|
43 |
+
)
|
44 |
+
|
45 |
+
file = "./book.txt"
|
46 |
+
with open(file, "r") as f:
|
47 |
+
rag.insert(f.read())
|
48 |
+
|
49 |
+
print(
|
50 |
+
rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
|
51 |
+
)
|
lightrag/kg/milvus_impl.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
import os
|
3 |
+
from tqdm.asyncio import tqdm as tqdm_async
|
4 |
+
from dataclasses import dataclass
|
5 |
+
import numpy as np
|
6 |
+
from lightrag.utils import logger
|
7 |
+
from ..base import BaseVectorStorage
|
8 |
+
|
9 |
+
from pymilvus import MilvusClient
|
10 |
+
|
11 |
+
|
12 |
+
@dataclass
|
13 |
+
class MilvusVectorDBStorge(BaseVectorStorage):
|
14 |
+
@staticmethod
|
15 |
+
def create_collection_if_not_exist(
|
16 |
+
client: MilvusClient, collection_name: str, **kwargs
|
17 |
+
):
|
18 |
+
if client.has_collection(collection_name):
|
19 |
+
return
|
20 |
+
client.create_collection(
|
21 |
+
collection_name, max_length=64, id_type="string", **kwargs
|
22 |
+
)
|
23 |
+
|
24 |
+
def __post_init__(self):
|
25 |
+
self._client = MilvusClient(
|
26 |
+
uri=os.environ.get(
|
27 |
+
"MILVUS_URI",
|
28 |
+
os.path.join(self.global_config["working_dir"], "milvus_lite.db"),
|
29 |
+
),
|
30 |
+
user=os.environ.get("MILVUS_USER", ""),
|
31 |
+
password=os.environ.get("MILVUS_PASSWORD", ""),
|
32 |
+
token=os.environ.get("MILVUS_TOKEN", ""),
|
33 |
+
db_name=os.environ.get("MILVUS_DB_NAME", ""),
|
34 |
+
)
|
35 |
+
self._max_batch_size = self.global_config["embedding_batch_num"]
|
36 |
+
MilvusVectorDBStorge.create_collection_if_not_exist(
|
37 |
+
self._client,
|
38 |
+
self.namespace,
|
39 |
+
dimension=self.embedding_func.embedding_dim,
|
40 |
+
)
|
41 |
+
|
42 |
+
async def upsert(self, data: dict[str, dict]):
|
43 |
+
logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
|
44 |
+
if not len(data):
|
45 |
+
logger.warning("You insert an empty data to vector DB")
|
46 |
+
return []
|
47 |
+
list_data = [
|
48 |
+
{
|
49 |
+
"id": k,
|
50 |
+
**{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
|
51 |
+
}
|
52 |
+
for k, v in data.items()
|
53 |
+
]
|
54 |
+
contents = [v["content"] for v in data.values()]
|
55 |
+
batches = [
|
56 |
+
contents[i : i + self._max_batch_size]
|
57 |
+
for i in range(0, len(contents), self._max_batch_size)
|
58 |
+
]
|
59 |
+
embedding_tasks = [self.embedding_func(batch) for batch in batches]
|
60 |
+
embeddings_list = []
|
61 |
+
for f in tqdm_async(
|
62 |
+
asyncio.as_completed(embedding_tasks),
|
63 |
+
total=len(embedding_tasks),
|
64 |
+
desc="Generating embeddings",
|
65 |
+
unit="batch",
|
66 |
+
):
|
67 |
+
embeddings = await f
|
68 |
+
embeddings_list.append(embeddings)
|
69 |
+
embeddings = np.concatenate(embeddings_list)
|
70 |
+
for i, d in enumerate(list_data):
|
71 |
+
d["vector"] = embeddings[i]
|
72 |
+
results = self._client.upsert(collection_name=self.namespace, data=list_data)
|
73 |
+
return results
|
74 |
+
|
75 |
+
async def query(self, query, top_k=5):
|
76 |
+
embedding = await self.embedding_func([query])
|
77 |
+
results = self._client.search(
|
78 |
+
collection_name=self.namespace,
|
79 |
+
data=embedding,
|
80 |
+
limit=top_k,
|
81 |
+
output_fields=list(self.meta_fields),
|
82 |
+
search_params={"metric_type": "COSINE", "params": {"radius": 0.2}},
|
83 |
+
)
|
84 |
+
print(results)
|
85 |
+
return [
|
86 |
+
{**dp["entity"], "id": dp["id"], "distance": dp["distance"]}
|
87 |
+
for dp in results[0]
|
88 |
+
]
|
lightrag/lightrag.py
CHANGED
@@ -44,6 +44,8 @@ from .kg.neo4j_impl import Neo4JStorage
|
|
44 |
|
45 |
from .kg.oracle_impl import OracleKVStorage, OracleGraphStorage, OracleVectorDBStorage
|
46 |
|
|
|
|
|
47 |
# future KG integrations
|
48 |
|
49 |
# from .kg.ArangoDB_impl import (
|
@@ -228,6 +230,7 @@ class LightRAG:
|
|
228 |
# vector storage
|
229 |
"NanoVectorDBStorage": NanoVectorDBStorage,
|
230 |
"OracleVectorDBStorage": OracleVectorDBStorage,
|
|
|
231 |
# graph storage
|
232 |
"NetworkXStorage": NetworkXStorage,
|
233 |
"Neo4JStorage": Neo4JStorage,
|
|
|
44 |
|
45 |
from .kg.oracle_impl import OracleKVStorage, OracleGraphStorage, OracleVectorDBStorage
|
46 |
|
47 |
+
from .kg.milvus_impl import MilvusVectorDBStorge
|
48 |
+
|
49 |
# future KG integrations
|
50 |
|
51 |
# from .kg.ArangoDB_impl import (
|
|
|
230 |
# vector storage
|
231 |
"NanoVectorDBStorage": NanoVectorDBStorage,
|
232 |
"OracleVectorDBStorage": OracleVectorDBStorage,
|
233 |
+
"MilvusVectorDBStorge": MilvusVectorDBStorge,
|
234 |
# graph storage
|
235 |
"NetworkXStorage": NetworkXStorage,
|
236 |
"Neo4JStorage": Neo4JStorage,
|
requirements.txt
CHANGED
@@ -11,6 +11,7 @@ networkx
|
|
11 |
ollama
|
12 |
openai
|
13 |
oracledb
|
|
|
14 |
pyvis
|
15 |
tenacity
|
16 |
# lmdeploy[all]
|
|
|
11 |
ollama
|
12 |
openai
|
13 |
oracledb
|
14 |
+
pymilvus
|
15 |
pyvis
|
16 |
tenacity
|
17 |
# lmdeploy[all]
|