hyb182 committed on
Commit
75922df
·
1 Parent(s): 2a71867

feat: 增加redis KV存储,增加openai+neo4j+milvus+redis的demo测试,新增lightrag.py: RedisKVStorage,新增requirements.txt:aioredis依赖

Browse files
examples/lightrag_openai_neo4j_milvus_redis_demo.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from lightrag import LightRAG, QueryParam
3
+ from lightrag.llm import ollama_embed, openai_complete_if_cache
4
+ from lightrag.utils import EmbeddingFunc
5
+
6
# Working directory: a "myKG" folder located next to this script.
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
WORKING_DIR = os.path.join(ROOT_DIR, "myKG")
# makedirs(exist_ok=True) replaces the exists()/mkdir() pair, which could raise
# FileExistsError if the directory appeared between the check and the create.
os.makedirs(WORKING_DIR, exist_ok=True)
print(f"WorkingDir: {WORKING_DIR}")

# Backend connection settings are published through environment variables.
# setdefault() is used so that values already exported by the user win over
# the demo defaults below.
# NOTE(review): the hosts and passwords here are demo defaults committed to
# source; override them through the environment for any real deployment.

# redis
os.environ.setdefault("REDIS_URI", "redis://localhost:6379")

# neo4j
BATCH_SIZE_NODES = 500
BATCH_SIZE_EDGES = 100
os.environ.setdefault("NEO4J_URI", "bolt://117.50.173.35:7687")
os.environ.setdefault("NEO4J_USERNAME", "neo4j")
os.environ.setdefault("NEO4J_PASSWORD", "12345678")

# milvus
os.environ.setdefault("MILVUS_URI", "http://117.50.173.35:19530")
os.environ.setdefault("MILVUS_USER", "root")
os.environ.setdefault("MILVUS_PASSWORD", "Milvus")
os.environ.setdefault("MILVUS_DB_NAME", "lightrag")
28
+
29
+
30
async def llm_model_func(
    prompt, system_prompt=None, history_messages=None, keyword_extraction=False, **kwargs
) -> str:
    """Complete ``prompt`` with the DeepSeek chat model via the OpenAI-compatible API.

    Args:
        prompt: User prompt forwarded to the model.
        system_prompt: Optional system prompt.
        history_messages: Optional prior messages; ``None`` means no history.
        keyword_extraction: Accepted so callers may pass it, but it is not
            forwarded to ``openai_complete_if_cache``.
        **kwargs: Extra keyword arguments forwarded unchanged.

    Returns:
        Whatever ``openai_complete_if_cache`` returns (annotated as ``str``).

    Raises:
        RuntimeError: If the ``DEEPSEEK_API_KEY`` environment variable is unset.
    """
    # The API key must come from the environment; secrets do not belong in
    # source control.
    api_key = os.environ.get("DEEPSEEK_API_KEY")
    if not api_key:
        raise RuntimeError(
            "Set the DEEPSEEK_API_KEY environment variable before running this demo"
        )

    # A fresh list per call avoids sharing one mutable default across calls.
    if history_messages is None:
        history_messages = []

    return await openai_complete_if_cache(
        "deepseek-chat",
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        api_key=api_key,
        base_url="https://api.deepseek.com",
        **kwargs,
    )
42
+
43
+
44
# Embedding function: forwards each batch of texts to ollama_embed with a
# fixed model name and host.
embedding_func = EmbeddingFunc(
    embedding_dim=768,
    max_token_size=512,
    func=lambda texts: ollama_embed(
        texts, embed_model="shaw/dmeta-embedding-zh", host="http://117.50.173.35:11434"
    ),
)

# LightRAG instance wired to Redis (KV + doc status), Neo4j (graph) and
# Milvus (vectors). The storage names are string keys and must be kept
# exactly as written, including the "Storge" spelling.
rag = LightRAG(
    working_dir=WORKING_DIR,
    llm_model_func=llm_model_func,
    llm_model_max_token_size=32768,
    embedding_func=embedding_func,
    chunk_token_size=512,
    chunk_overlap_token_size=256,
    kv_storage="RedisKVStorage",
    graph_storage="Neo4JStorage",
    vector_storage="MilvusVectorDBStorge",
    doc_status_storage="RedisKVStorage",
)

# Ingest and query only when executed as a script, so importing this module
# does not trigger document insertion or an LLM call.
if __name__ == "__main__":
    # The path is resolved against the current working directory.
    book_path = "../book.txt"
    with open(book_path, "r", encoding="utf-8") as f:
        rag.insert(f.read())

    print(rag.query("谁会3D建模 ?", param=QueryParam(mode="mix")))
lightrag/kg/redis_impl.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from tqdm.asyncio import tqdm as tqdm_async
3
+ from dataclasses import dataclass
4
+ import aioredis
5
+ from lightrag.utils import logger
6
+ from lightrag.base import BaseKVStorage
7
+ import json
8
+
9
+
10
@dataclass
class RedisKVStorage(BaseKVStorage):
    """Key-value storage backed by Redis.

    Each record is stored as a JSON string under the Redis key
    ``"{namespace}:{id}"``. The server URL is read from the ``REDIS_URI``
    environment variable and defaults to ``redis://localhost:6379``.
    """

    def __post_init__(self):
        redis_url = os.environ.get("REDIS_URI", "redis://localhost:6379")
        self._redis = aioredis.from_url(redis_url, decode_responses=True)
        logger.info(f"Use Redis as KV {self.namespace}")

    def _key(self, record_id) -> str:
        """Return the namespaced Redis key for ``record_id``."""
        return f"{self.namespace}:{record_id}"

    async def all_keys(self) -> list[str]:
        """Return the ids of every record stored under this namespace."""
        prefix = f"{self.namespace}:"
        # NOTE(review): KEYS walks the whole keyspace; consider SCAN for large
        # databases — confirm against the deployment size.
        keys = await self._redis.keys(f"{prefix}*")
        # Strip exactly the namespace prefix; splitting on the first ":" would
        # return the wrong id when the namespace itself contains a colon.
        return [key[len(prefix):] for key in keys]

    async def get_by_id(self, id):
        """Return the decoded record for ``id``, or ``None`` if it is absent."""
        data = await self._redis.get(self._key(id))
        return json.loads(data) if data else None

    async def get_by_ids(self, ids, fields=None):
        """Return the records for ``ids`` in order, with ``None`` for missing ids.

        When ``fields`` is given, each found record is reduced to the listed
        fields that it actually contains.
        """
        pipe = self._redis.pipeline()
        for record_id in ids:
            pipe.get(self._key(record_id))
        results = await pipe.execute()

        # Decode first, mapping missing keys to None. Previously the field
        # filter called json.loads(None) on a missing key and raised TypeError.
        records = [json.loads(raw) if raw else None for raw in results]
        if not fields:
            return records

        return [
            {field: record[field] for field in fields if field in record}
            if record
            else None
            for record in records
        ]

    async def filter_keys(self, data: list[str]) -> set[str]:
        """Return the subset of ``data`` that does NOT yet exist in Redis."""
        pipe = self._redis.pipeline()
        for key in data:
            pipe.exists(self._key(key))
        results = await pipe.execute()

        return {key for key, exists in zip(data, results) if not exists}

    async def upsert(self, data: dict[str, dict]):
        """Store every ``id -> record`` pair as JSON and return ``data``.

        After the write, each record in ``data`` is mutated in place to carry
        its id under ``"_id"``; the stored JSON does not include that field.
        """
        pipe = self._redis.pipeline()
        for k, v in tqdm_async(data.items(), desc="Upserting"):
            pipe.set(self._key(k), json.dumps(v))
        await pipe.execute()

        for k in data:
            data[k]["_id"] = k
        return data

    async def drop(self):
        """Delete every key stored under this namespace."""
        keys = await self._redis.keys(f"{self.namespace}:*")
        if keys:
            await self._redis.delete(*keys)
lightrag/lightrag.py CHANGED
@@ -52,6 +52,7 @@ STORAGES = {
52
  "OracleVectorDBStorage": ".kg.oracle_impl",
53
  "MilvusVectorDBStorge": ".kg.milvus_impl",
54
  "MongoKVStorage": ".kg.mongo_impl",
 
55
  "ChromaVectorDBStorage": ".kg.chroma_impl",
56
  "TiDBKVStorage": ".kg.tidb_impl",
57
  "TiDBVectorDBStorage": ".kg.tidb_impl",
 
52
  "OracleVectorDBStorage": ".kg.oracle_impl",
53
  "MilvusVectorDBStorge": ".kg.milvus_impl",
54
  "MongoKVStorage": ".kg.mongo_impl",
55
+ "RedisKVStorage": ".kg.redis_impl",
56
  "ChromaVectorDBStorage": ".kg.chroma_impl",
57
  "TiDBKVStorage": ".kg.tidb_impl",
58
  "TiDBVectorDBStorage": ".kg.tidb_impl",
requirements.txt CHANGED
@@ -2,6 +2,7 @@ accelerate
2
  aioboto3
3
  aiofiles
4
  aiohttp
 
5
  asyncpg
6
 
7
  # database packages
 
2
  aioboto3
3
  aiofiles
4
  aiohttp
5
+ aioredis
6
  asyncpg
7
 
8
  # database packages