zrguo committed on
Commit
7b6db35
·
2 Parent(s): adbc3bf e03ffeb

Merge pull request #393 from partoneplay/main

Browse files
examples/lightrag_ollama_neo4j_milvus_demo.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from lightrag import LightRAG, QueryParam
3
+ from lightrag.llm import ollama_model_complete, ollama_embed
4
+ from lightrag.utils import EmbeddingFunc
5
+
6
# Demo: LightRAG with Ollama (LLM + embeddings), Neo4j (graph storage) and
# Milvus (vector storage). Inserts ./book.txt and runs one hybrid query.

# WorkingDir: a "myKG" folder next to this script.
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
WORKING_DIR = os.path.join(ROOT_DIR, "myKG")
# exist_ok avoids the check-then-create race of `if not exists: mkdir`.
os.makedirs(WORKING_DIR, exist_ok=True)
print(f"WorkingDir: {WORKING_DIR}")

# neo4j
# NOTE(review): the two batch sizes are not referenced anywhere in this
# script; kept in case the Neo4j integration reads them -- confirm.
BATCH_SIZE_NODES = 500
BATCH_SIZE_EDGES = 100
os.environ["NEO4J_URI"] = "bolt://localhost:7687"
os.environ["NEO4J_USERNAME"] = "neo4j"
os.environ["NEO4J_PASSWORD"] = "neo4j"

# milvus
os.environ["MILVUS_URI"] = "http://localhost:19530"
os.environ["MILVUS_USER"] = "root"
os.environ["MILVUS_PASSWORD"] = "root"
os.environ["MILVUS_DB_NAME"] = "lightrag"


rag = LightRAG(
    working_dir=WORKING_DIR,
    llm_model_func=ollama_model_complete,
    llm_model_name="qwen2.5:14b",
    llm_model_max_async=4,
    llm_model_max_token_size=32768,
    llm_model_kwargs={"host": "http://127.0.0.1:11434", "options": {"num_ctx": 32768}},
    embedding_func=EmbeddingFunc(
        embedding_dim=1024,
        max_token_size=8192,
        func=lambda texts: ollama_embed(
            texts=texts, embed_model="bge-m3:latest", host="http://127.0.0.1:11434"
        ),
    ),
    graph_storage="Neo4JStorage",
    vector_storage="MilvusVectorDBStorge",
)

# The path is relative to the current working directory, not to this script.
file = "./book.txt"
# Explicit encoding so the text is decoded the same way on every platform.
with open(file, "r", encoding="utf-8") as f:
    rag.insert(f.read())

print(
    rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
)
lightrag/kg/milvus_impl.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import os
3
+ from tqdm.asyncio import tqdm as tqdm_async
4
+ from dataclasses import dataclass
5
+ import numpy as np
6
+ from lightrag.utils import logger
7
+ from ..base import BaseVectorStorage
8
+
9
+ from pymilvus import MilvusClient
10
+
11
+
12
@dataclass
class MilvusVectorDBStorge(BaseVectorStorage):
    """Vector storage that keeps embeddings in a Milvus collection.

    Each instance works on one collection named after ``self.namespace``.
    Connection settings are read from the ``MILVUS_URI``, ``MILVUS_USER``,
    ``MILVUS_PASSWORD``, ``MILVUS_TOKEN`` and ``MILVUS_DB_NAME`` environment
    variables. When ``MILVUS_URI`` is unset, the URI falls back to a
    ``milvus_lite.db`` path inside the configured working directory.
    """

    @staticmethod
    def create_collection_if_not_exist(
        client: MilvusClient, collection_name: str, **kwargs
    ):
        """Create ``collection_name`` on ``client`` unless it already exists.

        Args:
            client: Milvus client used for the existence check and creation.
            collection_name: Name of the collection to ensure.
            **kwargs: Forwarded unchanged to ``client.create_collection``
                (the class itself passes ``dimension``).
        """
        if client.has_collection(collection_name):
            return
        # String primary keys, capped at 64 characters.
        client.create_collection(
            collection_name, max_length=64, id_type="string", **kwargs
        )

    def __post_init__(self):
        """Connect to Milvus and make sure the namespace collection exists."""
        default_uri = os.path.join(
            self.global_config["working_dir"], "milvus_lite.db"
        )
        self._client = MilvusClient(
            uri=os.environ.get("MILVUS_URI", default_uri),
            user=os.environ.get("MILVUS_USER", ""),
            password=os.environ.get("MILVUS_PASSWORD", ""),
            token=os.environ.get("MILVUS_TOKEN", ""),
            db_name=os.environ.get("MILVUS_DB_NAME", ""),
        )
        # Number of texts sent to the embedding function per call.
        self._max_batch_size = self.global_config["embedding_batch_num"]
        MilvusVectorDBStorge.create_collection_if_not_exist(
            self._client,
            self.namespace,
            dimension=self.embedding_func.embedding_dim,
        )

    async def upsert(self, data: dict[str, dict]):
        """Embed the ``"content"`` of every item and upsert it into Milvus.

        Args:
            data: Mapping of record id to a payload dict. Every payload must
                contain a ``"content"`` key; only keys listed in
                ``self.meta_fields`` are stored next to the vector.

        Returns:
            ``[]`` when ``data`` is empty, otherwise whatever
            ``MilvusClient.upsert`` returns.
        """
        logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
        if not data:
            logger.warning("You insert an empty data to vector DB")
            return []

        list_data = [
            {
                "id": k,
                **{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
            }
            for k, v in data.items()
        ]
        contents = [v["content"] for v in data.values()]
        batches = [
            contents[i : i + self._max_batch_size]
            for i in range(0, len(contents), self._max_batch_size)
        ]

        async def _embed_batch(index, batch):
            # Tag the result with its batch index: as_completed yields in
            # completion order, which may differ from submission order.
            return index, await self.embedding_func(batch)

        embedding_tasks = [
            _embed_batch(index, batch) for index, batch in enumerate(batches)
        ]
        embeddings_list = [None] * len(embedding_tasks)
        for f in tqdm_async(
            asyncio.as_completed(embedding_tasks),
            total=len(embedding_tasks),
            desc="Generating embeddings",
            unit="batch",
        ):
            index, batch_embeddings = await f
            # Store by original position so row i still belongs to contents[i].
            embeddings_list[index] = batch_embeddings

        embeddings = np.concatenate(embeddings_list)
        for d, vector in zip(list_data, embeddings):
            d["vector"] = vector
        results = self._client.upsert(collection_name=self.namespace, data=list_data)
        return results

    async def query(self, query, top_k=5):
        """Search the collection for the entries closest to ``query``.

        Args:
            query: Text to embed and search for.
            top_k: Maximum number of hits requested from Milvus.

        Returns:
            A list of dicts, one per hit, holding the stored meta fields plus
            ``"id"`` and ``"distance"`` as reported by Milvus.
        """
        embedding = await self.embedding_func([query])
        results = self._client.search(
            collection_name=self.namespace,
            data=embedding,
            limit=top_k,
            output_fields=list(self.meta_fields),
            # NOTE(review): "radius" presumably acts as the range-search
            # threshold for the COSINE metric -- confirm against Milvus docs.
            search_params={"metric_type": "COSINE", "params": {"radius": 0.2}},
        )
        # A single query vector was sent, so only the first result set is used.
        return [
            {**dp["entity"], "id": dp["id"], "distance": dp["distance"]}
            for dp in results[0]
        ]
lightrag/lightrag.py CHANGED
@@ -44,6 +44,8 @@ from .kg.neo4j_impl import Neo4JStorage
44
 
45
  from .kg.oracle_impl import OracleKVStorage, OracleGraphStorage, OracleVectorDBStorage
46
 
 
 
47
  # future KG integrations
48
 
49
  # from .kg.ArangoDB_impl import (
@@ -228,6 +230,7 @@ class LightRAG:
228
  # vector storage
229
  "NanoVectorDBStorage": NanoVectorDBStorage,
230
  "OracleVectorDBStorage": OracleVectorDBStorage,
 
231
  # graph storage
232
  "NetworkXStorage": NetworkXStorage,
233
  "Neo4JStorage": Neo4JStorage,
 
44
 
45
  from .kg.oracle_impl import OracleKVStorage, OracleGraphStorage, OracleVectorDBStorage
46
 
47
+ from .kg.milvus_impl import MilvusVectorDBStorge
48
+
49
  # future KG integrations
50
 
51
  # from .kg.ArangoDB_impl import (
 
230
  # vector storage
231
  "NanoVectorDBStorage": NanoVectorDBStorage,
232
  "OracleVectorDBStorage": OracleVectorDBStorage,
233
+ "MilvusVectorDBStorge": MilvusVectorDBStorge,
234
  # graph storage
235
  "NetworkXStorage": NetworkXStorage,
236
  "Neo4JStorage": Neo4JStorage,
requirements.txt CHANGED
@@ -11,6 +11,7 @@ networkx
11
  ollama
12
  openai
13
  oracledb
 
14
  pyvis
15
  tenacity
16
  # lmdeploy[all]
 
11
  ollama
12
  openai
13
  oracledb
14
+ pymilvus
15
  pyvis
16
  tenacity
17
  # lmdeploy[all]