zrguo committed on
Commit b2a4240 · 2 Parent(s): 2cf0f7e 5dcb28f

Merge pull request #257 from jin38324/main


Add Oracle database as a backend for all storage types (KV/vector/graph)
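In short, this PR lets LightRAG run entirely on Oracle Database 23ai. A minimal sketch of the new wiring, condensed from the examples added below (connection values are placeholders, and the calls assume an async context):

from lightrag import LightRAG
from lightrag.kg.oracle_impl import OracleDB

oracle_db = OracleDB(
    config={
        "user": "username",
        "password": "xxxxxxxxx",
        "dsn": "xxxxxxx_medium",
        "workspace": "company",
    }
)
await oracle_db.check_tables()  # creates the LIGHTRAG_* tables on first run

rag = LightRAG(
    working_dir="./dickens",
    llm_model_func=llm_model_func,   # any async completion function
    embedding_func=embedding_func,   # any EmbeddingFunc instance
    kv_storage="OracleKVStorage",
    vector_storage="OracleVectorDBStorage",
    graph_storage="OracleGraphStorage",
)

# Point all three storage classes at the same connection pool
rag.graph_storage_cls.db = oracle_db
rag.key_string_value_json_storage_cls.db = oracle_db
rag.vector_db_storage_cls.db = oracle_db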

.gitignore CHANGED
@@ -10,4 +10,5 @@ local_neo4jWorkDir/
 neo4jWorkDir/
 ignore_this.txt
 .venv/
+*.ignore.*
 .ruff_cache/
README.md CHANGED
@@ -22,6 +22,7 @@ This repository hosts the code of LightRAG. The structure of this code is based
 </div>
 
 ## 🎉 News
+- [x] [2024.11.12]🎯📢You can [use Oracle Database 23ai for all storage types (kv/vector/graph)](https://github.com/HKUDS/LightRAG/blob/main/examples/lightrag_oracle_demo.py) now.
 - [x] [2024.11.11]🎯📢LightRAG now supports [deleting entities by their names](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#delete-entity).
 - [x] [2024.11.09]🎯📢Introducing the [LightRAG Gui](https://lightrag-gui.streamlit.app), which allows you to insert, query, visualize, and download LightRAG knowledge.
 - [x] [2024.11.04]🎯📢You can now [use Neo4J for Storage](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#using-neo4j-for-storage).
examples/lightrag_api_oracle_demo..py ADDED
@@ -0,0 +1,240 @@
from fastapi import FastAPI, HTTPException, File, UploadFile
from contextlib import asynccontextmanager
from pydantic import BaseModel
from typing import Optional

import sys
import os
from pathlib import Path

import asyncio
import nest_asyncio
from lightrag import LightRAG, QueryParam
from lightrag.llm import openai_complete_if_cache, openai_embedding
from lightrag.utils import EmbeddingFunc
import numpy as np

from lightrag.kg.oracle_impl import OracleDB


print(os.getcwd())

script_directory = Path(__file__).resolve().parent.parent
sys.path.append(os.path.abspath(script_directory))


# Apply nest_asyncio to solve event loop issues
nest_asyncio.apply()

DEFAULT_RAG_DIR = "index_default"


# We use an OpenAI-compatible API to call the LLM on Oracle Cloud
# More docs here: https://github.com/jin38324/OCI_GenAI_access_gateway
BASE_URL = "http://xxx.xxx.xxx.xxx:8088/v1/"
APIKEY = "ocigenerativeai"

# Configure working directory
WORKING_DIR = os.environ.get("RAG_DIR", f"{DEFAULT_RAG_DIR}")
print(f"WORKING_DIR: {WORKING_DIR}")
LLM_MODEL = os.environ.get("LLM_MODEL", "cohere.command-r-plus")
print(f"LLM_MODEL: {LLM_MODEL}")
EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "cohere.embed-multilingual-v3.0")
print(f"EMBEDDING_MODEL: {EMBEDDING_MODEL}")
EMBEDDING_MAX_TOKEN_SIZE = int(os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", 512))
print(f"EMBEDDING_MAX_TOKEN_SIZE: {EMBEDDING_MAX_TOKEN_SIZE}")


if not os.path.exists(WORKING_DIR):
    os.mkdir(WORKING_DIR)


async def llm_model_func(
    prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
    return await openai_complete_if_cache(
        LLM_MODEL,
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        api_key=APIKEY,
        base_url=BASE_URL,
        **kwargs,
    )


async def embedding_func(texts: list[str]) -> np.ndarray:
    return await openai_embedding(
        texts,
        model=EMBEDDING_MODEL,
        api_key=APIKEY,
        base_url=BASE_URL,
    )


async def get_embedding_dim():
    test_text = ["This is a test sentence."]
    embedding = await embedding_func(test_text)
    embedding_dim = embedding.shape[1]
    return embedding_dim


async def init():
    # Detect embedding dimension
    embedding_dimension = await get_embedding_dim()
    print(f"Detected embedding dimension: {embedding_dimension}")
    # Create the Oracle DB connection
    # The `config` parameter is the connection configuration of Oracle DB
    # More docs here: https://python-oracledb.readthedocs.io/en/latest/user_guide/connection_handling.html
    # We store data in unified tables, so a `workspace` parameter is set to specify which docs to store and query
    # Below is an example of how to connect to Oracle Autonomous Database on Oracle Cloud

    oracle_db = OracleDB(
        config={
            "user": "",
            "password": "",
            "dsn": "",
            "config_dir": "",
            "wallet_location": "",
            "wallet_password": "",
            "workspace": "",  # specify which docs you want to store and query
        }
    )

    # Check if the Oracle DB tables exist; if not, they will be created
    await oracle_db.check_tables()
    # Initialize LightRAG
    # We use Oracle DB as the KV/vector/graph storage
    rag = LightRAG(
        enable_llm_cache=False,
        working_dir=WORKING_DIR,
        chunk_token_size=512,
        llm_model_func=llm_model_func,
        embedding_func=EmbeddingFunc(
            embedding_dim=embedding_dimension,
            max_token_size=512,
            func=embedding_func,
        ),
        graph_storage="OracleGraphStorage",
        kv_storage="OracleKVStorage",
        vector_storage="OracleVectorDBStorage",
    )

    # Set the KV/vector/graph storages' `db` property so all operations share the same connection pool
    rag.graph_storage_cls.db = oracle_db
    rag.key_string_value_json_storage_cls.db = oracle_db
    rag.vector_db_storage_cls.db = oracle_db

    return rag


# Data models


class QueryRequest(BaseModel):
    query: str
    mode: str = "hybrid"
    only_need_context: bool = False


class InsertRequest(BaseModel):
    text: str


class Response(BaseModel):
    status: str
    data: Optional[str] = None
    message: Optional[str] = None


# API routes

rag = None  # global LightRAG instance, initialized at startup


@asynccontextmanager
async def lifespan(app: FastAPI):
    global rag
    rag = await init()  # initialize `rag` when the application starts
    print("done!")
    yield


app = FastAPI(
    title="LightRAG API", description="API for RAG operations", lifespan=lifespan
)


@app.post("/query", response_model=Response)
async def query_endpoint(request: QueryRequest):
    try:
        result = await rag.aquery(
            request.query,
            param=QueryParam(
                mode=request.mode, only_need_context=request.only_need_context
            ),
        )
        return Response(status="success", data=result)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/insert", response_model=Response)
async def insert_endpoint(request: InsertRequest):
    try:
        loop = asyncio.get_event_loop()
        await loop.run_in_executor(None, lambda: rag.insert(request.text))
        return Response(status="success", message="Text inserted successfully")
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/insert_file", response_model=Response)
async def insert_file(file: UploadFile = File(...)):
    try:
        file_content = await file.read()
        # Read file content
        try:
            content = file_content.decode("utf-8")
        except UnicodeDecodeError:
            # If UTF-8 decoding fails, try another encoding
            content = file_content.decode("gbk")
        # Insert file content
        loop = asyncio.get_event_loop()
        await loop.run_in_executor(None, lambda: rag.insert(content))

        return Response(
            status="success",
            message=f"File content from {file.filename} inserted successfully",
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/health")
async def health_check():
    return {"status": "healthy"}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8020)

# Usage example
# To run the server, use the following command in your terminal:
# python lightrag_api_oracle_demo..py

# Example requests:
# 1. Query:
# curl -X POST "http://127.0.0.1:8020/query" -H "Content-Type: application/json" -d '{"query": "your query here", "mode": "hybrid"}'

# 2. Insert text:
# curl -X POST "http://127.0.0.1:8020/insert" -H "Content-Type: application/json" -d '{"text": "your text here"}'

# 3. Insert file (multipart upload, since the endpoint takes an UploadFile):
# curl -X POST "http://127.0.0.1:8020/insert_file" -F "file=@path/to/your/file.txt"

# 4. Health check:
# curl -X GET "http://127.0.0.1:8020/health"
examples/lightrag_oracle_demo.py ADDED
@@ -0,0 +1,128 @@
import sys
import os
from pathlib import Path
import asyncio
from lightrag import LightRAG, QueryParam
from lightrag.llm import openai_complete_if_cache, openai_embedding
from lightrag.utils import EmbeddingFunc
import numpy as np
from lightrag.kg.oracle_impl import OracleDB

print(os.getcwd())
script_directory = Path(__file__).resolve().parent.parent
sys.path.append(os.path.abspath(script_directory))

WORKING_DIR = "./dickens"

# We use an OpenAI-compatible API to call the LLM on Oracle Cloud
# More docs here: https://github.com/jin38324/OCI_GenAI_access_gateway
BASE_URL = "http://xxx.xxx.xxx.xxx:8088/v1/"
APIKEY = "ocigenerativeai"
CHATMODEL = "cohere.command-r-plus"
EMBEDMODEL = "cohere.embed-multilingual-v3.0"


if not os.path.exists(WORKING_DIR):
    os.mkdir(WORKING_DIR)


async def llm_model_func(
    prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
    return await openai_complete_if_cache(
        CHATMODEL,
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        api_key=APIKEY,
        base_url=BASE_URL,
        **kwargs,
    )


async def embedding_func(texts: list[str]) -> np.ndarray:
    return await openai_embedding(
        texts,
        model=EMBEDMODEL,
        api_key=APIKEY,
        base_url=BASE_URL,
    )


async def get_embedding_dim():
    test_text = ["This is a test sentence."]
    embedding = await embedding_func(test_text)
    embedding_dim = embedding.shape[1]
    return embedding_dim


async def main():
    try:
        # Detect embedding dimension
        embedding_dimension = await get_embedding_dim()
        print(f"Detected embedding dimension: {embedding_dimension}")

        # Create the Oracle DB connection
        # The `config` parameter is the connection configuration of Oracle DB
        # More docs here: https://python-oracledb.readthedocs.io/en/latest/user_guide/connection_handling.html
        # We store data in unified tables, so a `workspace` parameter is set to specify which docs to store and query
        # Below is an example of how to connect to Oracle Autonomous Database on Oracle Cloud
        oracle_db = OracleDB(
            config={
                "user": "username",
                "password": "xxxxxxxxx",
                "dsn": "xxxxxxx_medium",
                "config_dir": "dir/path/to/oracle/config",
                "wallet_location": "dir/path/to/oracle/wallet",
                "wallet_password": "xxxxxxxxx",
                "workspace": "company",  # specify which docs you want to store and query
            }
        )

        # Check if the Oracle DB tables exist; if not, they will be created
        await oracle_db.check_tables()

        # Initialize LightRAG
        # We use Oracle DB as the KV/vector/graph storage
        rag = LightRAG(
            enable_llm_cache=False,
            working_dir=WORKING_DIR,
            chunk_token_size=512,
            llm_model_func=llm_model_func,
            embedding_func=EmbeddingFunc(
                embedding_dim=embedding_dimension,
                max_token_size=512,
                func=embedding_func,
            ),
            graph_storage="OracleGraphStorage",
            kv_storage="OracleKVStorage",
            vector_storage="OracleVectorDBStorage",
        )

        # Set the KV/vector/graph storages' `db` property so all operations share the same connection pool
        rag.graph_storage_cls.db = oracle_db
        rag.key_string_value_json_storage_cls.db = oracle_db
        rag.vector_db_storage_cls.db = oracle_db

        # Extract and insert into LightRAG storage
        with open("./dickens/demo.txt", "r", encoding="utf-8") as f:
            await rag.ainsert(f.read())

        # Perform search in different modes
        modes = ["naive", "local", "global", "hybrid"]
        for mode in modes:
            print("=" * 20, mode, "=" * 20)
            print(
                await rag.aquery(
                    "What are the top themes in this story?",
                    param=QueryParam(mode=mode),
                )
            )
            print("-" * 100, "\n")

    except Exception as e:
        print(f"An error occurred: {e}")


if __name__ == "__main__":
    asyncio.run(main())
lightrag/base.py CHANGED
@@ -59,6 +59,8 @@ class BaseVectorStorage(StorageNameSpace):
 
 @dataclass
 class BaseKVStorage(Generic[T], StorageNameSpace):
+    embedding_func: EmbeddingFunc
+
     async def all_keys(self) -> list[str]:
         raise NotImplementedError
 
@@ -83,6 +85,8 @@ class BaseKVStorage(Generic[T], StorageNameSpace):
 
 @dataclass
 class BaseGraphStorage(StorageNameSpace):
+    embedding_func: EmbeddingFunc = None
+
     async def has_node(self, node_id: str) -> bool:
         raise NotImplementedError
 
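The new `embedding_func` field means storage backends are now constructed with the embedding function in hand. A minimal sketch of what a custom KV backend can do with it (the class and method body are illustrative, not part of this PR):

from dataclasses import dataclass
from lightrag.base import BaseKVStorage


@dataclass
class MyKVStorage(BaseKVStorage):
    # embedding_func is inherited from BaseKVStorage and populated by
    # LightRAG.__post_init__, so vector-aware KV backends (like
    # OracleKVStorage below) can embed chunk content during upsert()
    async def upsert(self, data: dict[str, dict]):
        contents = [v["content"] for v in data.values()]
        vectors = await self.embedding_func(contents)  # one embedding row per item
        ...  # persist ids, contents, and vectors to the backing store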
lightrag/kg/oracle_impl.py ADDED
@@ -0,0 +1,700 @@
import asyncio
import array
from dataclasses import dataclass
from typing import Union

import numpy as np
import oracledb

from ..utils import logger
from ..base import (
    BaseGraphStorage,
    BaseKVStorage,
    BaseVectorStorage,
)


class OracleDB:
    def __init__(self, config, **kwargs):
        self.host = config.get("host", None)
        self.port = config.get("port", None)
        self.user = config.get("user", None)
        self.password = config.get("password", None)
        self.dsn = config.get("dsn", None)
        self.config_dir = config.get("config_dir", None)
        self.wallet_location = config.get("wallet_location", None)
        self.wallet_password = config.get("wallet_password", None)
        self.workspace = config.get("workspace", None)
        self.max = 12
        self.increment = 1
        logger.info(f"Using the label {self.workspace} for Oracle Graph as identifier")
        if self.user is None or self.password is None:
            raise ValueError("Missing database user or password in addon_params")

        try:
            oracledb.defaults.fetch_lobs = False

            self.pool = oracledb.create_pool_async(
                user=self.user,
                password=self.password,
                dsn=self.dsn,
                config_dir=self.config_dir,
                wallet_location=self.wallet_location,
                wallet_password=self.wallet_password,
                min=1,
                max=self.max,
                increment=self.increment,
            )
            logger.info(f"Connected to Oracle database at {self.dsn}")
        except Exception as e:
            logger.error(f"Failed to connect to Oracle database at {self.dsn}")
            logger.error(f"Oracle database error: {e}")
            raise

    def numpy_converter_in(self, value):
        """Convert a numpy array to array.array"""
        if value.dtype == np.float64:
            dtype = "d"
        elif value.dtype == np.float32:
            dtype = "f"
        else:
            dtype = "b"
        return array.array(dtype, value)

    def input_type_handler(self, cursor, value, arraysize):
        """Set the type handler for input data"""
        if isinstance(value, np.ndarray):
            return cursor.var(
                oracledb.DB_TYPE_VECTOR,
                arraysize=arraysize,
                inconverter=self.numpy_converter_in,
            )

    def numpy_converter_out(self, value):
        """Convert array.array to a numpy array"""
        if value.typecode == "b":
            dtype = np.int8
        elif value.typecode == "f":
            dtype = np.float32
        else:
            dtype = np.float64
        return np.array(value, copy=False, dtype=dtype)

    def output_type_handler(self, cursor, metadata):
        """Set the type handler for output data"""
        if metadata.type_code is oracledb.DB_TYPE_VECTOR:
            return cursor.var(
                metadata.type_code,
                arraysize=cursor.arraysize,
                outconverter=self.numpy_converter_out,
            )

    async def check_tables(self):
        for k, v in TABLES.items():
            try:
                if k.lower() == "lightrag_graph":
                    await self.query(
                        "SELECT id FROM GRAPH_TABLE (lightrag_graph MATCH (a) COLUMNS (a.id)) fetch first row only"
                    )
                else:
                    await self.query("SELECT 1 FROM {k}".format(k=k))
            except Exception as e:
                logger.error(f"Failed to check table {k} in Oracle database")
                logger.error(f"Oracle database error: {e}")
                try:
                    await self.execute(v["ddl"])
                    logger.info(f"Created table {k} in Oracle database")
                except Exception as e:
                    logger.error(f"Failed to create table {k} in Oracle database")
                    logger.error(f"Oracle database error: {e}")

        logger.info("Finished checking all tables in Oracle database")

    async def query(self, sql: str, multirows: bool = False) -> Union[dict, None]:
        async with self.pool.acquire() as connection:
            connection.inputtypehandler = self.input_type_handler
            connection.outputtypehandler = self.output_type_handler
            with connection.cursor() as cursor:
                try:
                    await cursor.execute(sql)
                except Exception as e:
                    logger.error(f"Oracle database error: {e}")
                    print(sql)
                    raise
                columns = [column[0].lower() for column in cursor.description]
                if multirows:
                    rows = await cursor.fetchall()
                    if rows:
                        data = [dict(zip(columns, row)) for row in rows]
                    else:
                        data = []
                else:
                    row = await cursor.fetchone()
                    if row:
                        data = dict(zip(columns, row))
                    else:
                        data = None
                return data

    async def execute(self, sql: str, data: list = None):
        try:
            async with self.pool.acquire() as connection:
                connection.inputtypehandler = self.input_type_handler
                connection.outputtypehandler = self.output_type_handler
                with connection.cursor() as cursor:
                    if data is None:
                        await cursor.execute(sql)
                    else:
                        await cursor.execute(sql, data)
                    await connection.commit()
        except Exception as e:
            logger.error(f"Oracle database error: {e}")
            print(sql)
            print(data)
            raise


@dataclass
class OracleKVStorage(BaseKVStorage):
    # the caller is expected to assign an OracleDB instance to self.db
    def __post_init__(self):
        self._data = {}
        self._max_batch_size = self.global_config["embedding_batch_num"]

    ################ QUERY METHODS ################

    async def get_by_id(self, id: str) -> Union[dict, None]:
        """Fetch doc_full data by id."""
        SQL = SQL_TEMPLATES["get_by_id_" + self.namespace].format(
            workspace=self.db.workspace, id=id
        )
        res = await self.db.query(SQL)
        if res:
            return res
        else:
            return None

    async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]:
        """Fetch doc_chunks data by ids."""
        SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
            workspace=self.db.workspace, ids=",".join([f"'{id}'" for id in ids])
        )
        res = await self.db.query(SQL, multirows=True)
        if res:
            return res
        else:
            return None

    async def filter_keys(self, keys: list[str]) -> set[str]:
        """Filter out keys whose content already exists in storage."""
        SQL = SQL_TEMPLATES["filter_keys"].format(
            table_name=N_T[self.namespace],
            workspace=self.db.workspace,
            ids=",".join([f"'{k}'" for k in keys]),
        )
        res = await self.db.query(SQL, multirows=True)
        if res:
            exist_keys = [key["id"] for key in res]
        else:
            exist_keys = []
        return set([s for s in keys if s not in exist_keys])

    ################ INSERT METHODS ################
    async def upsert(self, data: dict[str, dict]):
        left_data = {k: v for k, v in data.items() if k not in self._data}
        self._data.update(left_data)
        if self.namespace == "text_chunks":
            list_data = [
                {
                    "__id__": k,
                    **{k1: v1 for k1, v1 in v.items()},
                }
                for k, v in data.items()
            ]
            contents = [v["content"] for v in data.values()]
            batches = [
                contents[i : i + self._max_batch_size]
                for i in range(0, len(contents), self._max_batch_size)
            ]
            embeddings_list = await asyncio.gather(
                *[self.embedding_func(batch) for batch in batches]
            )
            embeddings = np.concatenate(embeddings_list)
            for i, d in enumerate(list_data):
                d["__vector__"] = embeddings[i]
            for item in list_data:
                merge_sql = SQL_TEMPLATES["merge_chunk"].format(check_id=item["__id__"])
                values = [
                    item["__id__"],
                    item["content"],
                    self.db.workspace,
                    item["tokens"],
                    item["chunk_order_index"],
                    item["full_doc_id"],
                    item["__vector__"],
                ]
                await self.db.execute(merge_sql, values)

        if self.namespace == "full_docs":
            for k, v in self._data.items():
                merge_sql = SQL_TEMPLATES["merge_doc_full"].format(
                    check_id=k,
                )
                values = [k, self._data[k]["content"], self.db.workspace]
                await self.db.execute(merge_sql, values)
        return left_data

    async def index_done_callback(self):
        if self.namespace in ["full_docs", "text_chunks"]:
            logger.info("Full doc and chunk data have been saved into Oracle DB!")


@dataclass
class OracleVectorDBStorage(BaseVectorStorage):
    cosine_better_than_threshold: float = 0.2

    def __post_init__(self):
        pass

    async def upsert(self, data: dict[str, dict]):
        """Insert data into the vector database (a no-op here: vectors are written by the KV/graph storages)"""
        pass

    async def index_done_callback(self):
        pass

    #################### query method ###############
    async def query(self, query: str, top_k=5) -> Union[dict, list[dict]]:
        """Query the vector database"""
        embeddings = await self.embedding_func([query])
        embedding = embeddings[0]
        # Pass the dtype and dimension through to Oracle's vector() constructor
        dtype = str(embedding.dtype).upper()
        dimension = embedding.shape[0]
        embedding_string = ", ".join(map(str, embedding.tolist()))

        SQL = SQL_TEMPLATES[self.namespace].format(
            embedding_string=embedding_string,
            dimension=dimension,
            dtype=dtype,
            workspace=self.db.workspace,
            top_k=top_k,
            better_than_threshold=self.cosine_better_than_threshold,
        )
        results = await self.db.query(SQL, multirows=True)
        return results


@dataclass
class OracleGraphStorage(BaseGraphStorage):
    """Oracle-based graph storage module"""

    def __post_init__(self):
        """Initialize the embedding batch size from the global config"""
        self._max_batch_size = self.global_config["embedding_batch_num"]

    #################### insert method ################

    async def upsert_node(self, node_id: str, node_data: dict[str, str]):
        """Insert or update a node"""
        entity_name = node_id
        entity_type = node_data["entity_type"]
        description = node_data["description"]
        source_id = node_data["source_id"]
        content = entity_name + description
        contents = [content]
        batches = [
            contents[i : i + self._max_batch_size]
            for i in range(0, len(contents), self._max_batch_size)
        ]
        embeddings_list = await asyncio.gather(
            *[self.embedding_func(batch) for batch in batches]
        )
        embeddings = np.concatenate(embeddings_list)
        content_vector = embeddings[0]
        merge_sql = SQL_TEMPLATES["merge_node"].format(
            workspace=self.db.workspace, name=entity_name, source_chunk_id=source_id
        )
        await self.db.execute(
            merge_sql,
            [
                self.db.workspace,
                entity_name,
                entity_type,
                description,
                source_id,
                content,
                content_vector,
            ],
        )

    async def upsert_edge(
        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
    ):
        """Insert or update an edge"""
        source_name = source_node_id
        target_name = target_node_id
        weight = edge_data["weight"]
        keywords = edge_data["keywords"]
        description = edge_data["description"]
        source_chunk_id = edge_data["source_id"]
        content = keywords + source_name + target_name + description
        contents = [content]
        batches = [
            contents[i : i + self._max_batch_size]
            for i in range(0, len(contents), self._max_batch_size)
        ]
        embeddings_list = await asyncio.gather(
            *[self.embedding_func(batch) for batch in batches]
        )
        embeddings = np.concatenate(embeddings_list)
        content_vector = embeddings[0]
        merge_sql = SQL_TEMPLATES["merge_edge"].format(
            workspace=self.db.workspace,
            source_name=source_name,
            target_name=target_name,
            source_chunk_id=source_chunk_id,
        )
        await self.db.execute(
            merge_sql,
            [
                self.db.workspace,
                source_name,
                target_name,
                weight,
                keywords,
                description,
                source_chunk_id,
                content,
                content_vector,
            ],
        )

    async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
        """Generate embeddings for nodes"""
        if algorithm not in self._node_embed_algorithms:
            raise ValueError(f"Node embedding algorithm {algorithm} not supported")
        return await self._node_embed_algorithms[algorithm]()

    async def _node2vec_embed(self):
        """Generate node embeddings with node2vec"""
        from graspologic import embed

        embeddings, nodes = embed.node2vec_embed(
            self._graph,
            **self.config["node2vec_params"],
        )

        nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes]
        return embeddings, nodes_ids

    async def index_done_callback(self):
        """Nothing to flush: node and edge data are persisted on upsert"""
        logger.info(
            "Node and edge data have already been saved into Oracle DB, so nothing to do here!"
        )

    #################### query method #################
    async def has_node(self, node_id: str) -> bool:
        """Check whether a node exists by node id"""
        SQL = SQL_TEMPLATES["has_node"].format(
            workspace=self.db.workspace, node_id=node_id
        )
        res = await self.db.query(SQL)
        if res:
            return True
        else:
            return False

    async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
        """Check whether an edge exists by source and target node ids"""
        SQL = SQL_TEMPLATES["has_edge"].format(
            workspace=self.db.workspace,
            source_node_id=source_node_id,
            target_node_id=target_node_id,
        )
        res = await self.db.query(SQL)
        if res:
            return True
        else:
            return False

    async def node_degree(self, node_id: str) -> int:
        """Get the degree of a node by node id"""
        SQL = SQL_TEMPLATES["node_degree"].format(
            workspace=self.db.workspace, node_id=node_id
        )
        res = await self.db.query(SQL)
        if res:
            return res["degree"]
        else:
            return 0

    async def edge_degree(self, src_id: str, tgt_id: str) -> int:
        """Get the degree of an edge as the sum of its endpoints' degrees"""
        degree = await self.node_degree(src_id) + await self.node_degree(tgt_id)
        return degree

    async def get_node(self, node_id: str) -> Union[dict, None]:
        """Get node data by node id"""
        SQL = SQL_TEMPLATES["get_node"].format(
            workspace=self.db.workspace, node_id=node_id
        )
        res = await self.db.query(SQL)
        if res:
            return res
        else:
            return None

    async def get_edge(
        self, source_node_id: str, target_node_id: str
    ) -> Union[dict, None]:
        """Get an edge by source and target node ids"""
        SQL = SQL_TEMPLATES["get_edge"].format(
            workspace=self.db.workspace,
            source_node_id=source_node_id,
            target_node_id=target_node_id,
        )
        res = await self.db.query(SQL)
        if res:
            return res
        else:
            return None

    async def get_node_edges(self, source_node_id: str):
        """Get all edges of a node by node id"""
        if await self.has_node(source_node_id):
            SQL = SQL_TEMPLATES["get_node_edges"].format(
                workspace=self.db.workspace, source_node_id=source_node_id
            )
            res = await self.db.query(sql=SQL, multirows=True)
            if res:
                data = [(i["source_name"], i["target_name"]) for i in res]
                return data
            else:
                return []


N_T = {
    "full_docs": "LIGHTRAG_DOC_FULL",
    "text_chunks": "LIGHTRAG_DOC_CHUNKS",
    "chunks": "LIGHTRAG_DOC_CHUNKS",
    "entities": "LIGHTRAG_GRAPH_NODES",
    "relationships": "LIGHTRAG_GRAPH_EDGES",
}

TABLES = {
    "LIGHTRAG_DOC_FULL": {
        "ddl": """CREATE TABLE LIGHTRAG_DOC_FULL (
                    id varchar(256) PRIMARY KEY,
                    workspace varchar(1024),
                    doc_name varchar(1024),
                    content CLOB,
                    meta JSON,
                    createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    updatetime TIMESTAMP DEFAULT NULL
                    )"""
    },
    "LIGHTRAG_DOC_CHUNKS": {
        "ddl": """CREATE TABLE LIGHTRAG_DOC_CHUNKS (
                    id varchar(256) PRIMARY KEY,
                    workspace varchar(1024),
                    full_doc_id varchar(256),
                    chunk_order_index NUMBER,
                    tokens NUMBER,
                    content CLOB,
                    content_vector VECTOR,
                    createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    updatetime TIMESTAMP DEFAULT NULL
                    )"""
    },
    "LIGHTRAG_GRAPH_NODES": {
        "ddl": """CREATE TABLE LIGHTRAG_GRAPH_NODES (
                    id NUMBER GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY,
                    workspace varchar(1024),
                    name varchar(2048),
                    entity_type varchar(1024),
                    description CLOB,
                    source_chunk_id varchar(256),
                    content CLOB,
                    content_vector VECTOR,
                    createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    updatetime TIMESTAMP DEFAULT NULL
                    )"""
    },
    "LIGHTRAG_GRAPH_EDGES": {
        "ddl": """CREATE TABLE LIGHTRAG_GRAPH_EDGES (
                    id NUMBER GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY,
                    workspace varchar(1024),
                    source_name varchar(2048),
                    target_name varchar(2048),
                    weight NUMBER,
                    keywords CLOB,
                    description CLOB,
                    source_chunk_id varchar(256),
                    content CLOB,
                    content_vector VECTOR,
                    createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    updatetime TIMESTAMP DEFAULT NULL
                    )"""
    },
    "LIGHTRAG_LLM_CACHE": {
        "ddl": """CREATE TABLE LIGHTRAG_LLM_CACHE (
                    id varchar(256) PRIMARY KEY,
                    send clob,
                    return clob,
                    model varchar(1024),
                    createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    updatetime TIMESTAMP DEFAULT NULL
                    )"""
    },
    "LIGHTRAG_GRAPH": {
        "ddl": """CREATE OR REPLACE PROPERTY GRAPH lightrag_graph
                VERTEX TABLES (
                    lightrag_graph_nodes KEY (id)
                        LABEL entity
                        PROPERTIES (id,workspace,name) -- ,entity_type,description,source_chunk_id)
                )
                EDGE TABLES (
                    lightrag_graph_edges KEY (id)
                        SOURCE KEY (source_name) REFERENCES lightrag_graph_nodes(name)
                        DESTINATION KEY (target_name) REFERENCES lightrag_graph_nodes(name)
                        LABEL has_relation
                        PROPERTIES (id,workspace,source_name,target_name) -- ,weight, keywords,description,source_chunk_id)
                ) OPTIONS(ALLOW MIXED PROPERTY TYPES)"""
    },
}


SQL_TEMPLATES = {
    # SQL for KVStorage
    "get_by_id_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace='{workspace}' and ID='{id}'",
    "get_by_id_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace='{workspace}' and ID='{id}'",
    "get_by_ids_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace='{workspace}' and ID in ({ids})",
    "get_by_ids_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace='{workspace}' and ID in ({ids})",
    "filter_keys": "select id from {table_name} where workspace='{workspace}' and id in ({ids})",
    "merge_doc_full": """MERGE INTO LIGHTRAG_DOC_FULL a
                    USING DUAL
                    ON (a.id = '{check_id}')
                    WHEN NOT MATCHED THEN
                    INSERT(id,content,workspace) values(:1,:2,:3)""",
    "merge_chunk": """MERGE INTO LIGHTRAG_DOC_CHUNKS a
                    USING DUAL
                    ON (a.id = '{check_id}')
                    WHEN NOT MATCHED THEN
                    INSERT(id,content,workspace,tokens,chunk_order_index,full_doc_id,content_vector)
                    values (:1,:2,:3,:4,:5,:6,:7)""",
    # SQL for VectorStorage
    "entities": """SELECT name as entity_name FROM
                (SELECT id,name,VECTOR_DISTANCE(content_vector,vector('[{embedding_string}]',{dimension},{dtype}),COSINE) as distance
                FROM LIGHTRAG_GRAPH_NODES WHERE workspace='{workspace}')
                WHERE distance>{better_than_threshold} ORDER BY distance ASC FETCH FIRST {top_k} ROWS ONLY""",
    "relationships": """SELECT source_name as src_id, target_name as tgt_id FROM
                (SELECT id,source_name,target_name,VECTOR_DISTANCE(content_vector,vector('[{embedding_string}]',{dimension},{dtype}),COSINE) as distance
                FROM LIGHTRAG_GRAPH_EDGES WHERE workspace='{workspace}')
                WHERE distance>{better_than_threshold} ORDER BY distance ASC FETCH FIRST {top_k} ROWS ONLY""",
    "chunks": """SELECT id FROM
                (SELECT id,VECTOR_DISTANCE(content_vector,vector('[{embedding_string}]',{dimension},{dtype}),COSINE) as distance
                FROM LIGHTRAG_DOC_CHUNKS WHERE workspace='{workspace}')
                WHERE distance>{better_than_threshold} ORDER BY distance ASC FETCH FIRST {top_k} ROWS ONLY""",
    # SQL for GraphStorage
    "has_node": """SELECT * FROM GRAPH_TABLE (lightrag_graph
                MATCH (a)
                WHERE a.workspace='{workspace}' AND a.name='{node_id}'
                COLUMNS (a.name))""",
    "has_edge": """SELECT * FROM GRAPH_TABLE (lightrag_graph
                MATCH (a) -[e]-> (b)
                WHERE e.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}'
                AND a.name='{source_node_id}' AND b.name='{target_node_id}'
                COLUMNS (e.source_name,e.target_name) )""",
    "node_degree": """SELECT count(1) as degree FROM GRAPH_TABLE (lightrag_graph
                MATCH (a)-[e]->(b)
                WHERE e.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}'
                AND (a.name='{node_id}' or b.name = '{node_id}')
                COLUMNS (a.name))""",
    "get_node": """SELECT t1.name,t2.entity_type,t2.source_chunk_id as source_id,NVL(t2.description,'') AS description
                FROM GRAPH_TABLE (lightrag_graph
                MATCH (a)
                WHERE a.workspace='{workspace}' AND a.name='{node_id}'
                COLUMNS (a.name)
                ) t1 JOIN LIGHTRAG_GRAPH_NODES t2 on t1.name=t2.name
                WHERE t2.workspace='{workspace}'""",
    "get_edge": """SELECT t1.source_id,t2.weight,t2.source_chunk_id as source_id,t2.keywords,
                NVL(t2.description,'') AS description,NVL(t2.KEYWORDS,'') AS keywords
                FROM GRAPH_TABLE (lightrag_graph
                MATCH (a)-[e]->(b)
                WHERE e.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}'
                AND a.name='{source_node_id}' and b.name = '{target_node_id}'
                COLUMNS (e.id,a.name as source_id)
                ) t1 JOIN LIGHTRAG_GRAPH_EDGES t2 on t1.id=t2.id""",
    "get_node_edges": """SELECT source_name,target_name
                FROM GRAPH_TABLE (lightrag_graph
                MATCH (a)-[e]->(b)
                WHERE e.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}'
                AND a.name='{source_node_id}'
                COLUMNS (a.name as source_name,b.name as target_name))""",
    "merge_node": """MERGE INTO LIGHTRAG_GRAPH_NODES a
                    USING DUAL
                    ON (a.workspace = '{workspace}' and a.name='{name}' and a.source_chunk_id='{source_chunk_id}')
                    WHEN NOT MATCHED THEN
                    INSERT(workspace,name,entity_type,description,source_chunk_id,content,content_vector)
                    values (:1,:2,:3,:4,:5,:6,:7)""",
    "merge_edge": """MERGE INTO LIGHTRAG_GRAPH_EDGES a
                    USING DUAL
                    ON (a.workspace = '{workspace}' and a.source_name='{source_name}' and a.target_name='{target_name}' and a.source_chunk_id='{source_chunk_id}')
                    WHEN NOT MATCHED THEN
                    INSERT(workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector)
                    values (:1,:2,:3,:4,:5,:6,:7,:8,:9)""",
}
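For orientation, a short usage sketch of the OracleDB helper on its own (inside an async context, with a working connection config; the statements are illustrative):

db = OracleDB(config={"user": "username", "password": "xxxxxxxxx", "dsn": "xxxxxxx_medium"})
await db.check_tables()  # create any missing LIGHTRAG_* tables

# query() lowercases column names and returns a dict, or a list of dicts with multirows=True
row = await db.query("SELECT 1 AS ok FROM DUAL")  # -> {"ok": 1}
rows = await db.query("SELECT id FROM LIGHTRAG_DOC_FULL", multirows=True)

# DML goes through execute(), which commits and accepts bind values
await db.execute("DELETE FROM LIGHTRAG_LLM_CACHE")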
lightrag/lightrag.py CHANGED
@@ -18,20 +18,6 @@ from .operate import (
     naive_query,
 )
 
-from .storage import (
-    JsonKVStorage,
-    NanoVectorDBStorage,
-    NetworkXStorage,
-)
-
-from .kg.neo4j_impl import Neo4JStorage
-# future KG integrations
-
-# from .kg.ArangoDB_impl import (
-#     GraphStorage as ArangoDBStorage
-# )
-
-
 from .utils import (
     EmbeddingFunc,
     compute_mdhash_id,
@@ -48,6 +34,22 @@ from .base import (
     QueryParam,
 )
 
+
+from .storage import (
+    JsonKVStorage,
+    NanoVectorDBStorage,
+    NetworkXStorage,
+)
+
+from .kg.neo4j_impl import Neo4JStorage
+
+from .kg.oracle_impl import OracleKVStorage, OracleGraphStorage, OracleVectorDBStorage
+
+# future KG integrations
+
+# from .kg.ArangoDB_impl import (
+#     GraphStorage as ArangoDBStorage
+# )
 
 def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
     try:
@@ -67,7 +69,9 @@ class LightRAG:
         default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
     )
 
-    kg: str = field(default="NetworkXStorage")
+    kv_storage: str = field(default="JsonKVStorage")
+    vector_storage: str = field(default="NanoVectorDBStorage")
+    graph_storage: str = field(default="NetworkXStorage")
 
     current_log_level = logger.level
     log_level: str = field(default=current_log_level)
@@ -107,9 +111,8 @@ class LightRAG:
     llm_model_kwargs: dict = field(default_factory=dict)
 
     # storage
-    key_string_value_json_storage_cls: Type[BaseKVStorage] = JsonKVStorage
-    vector_db_storage_cls: Type[BaseVectorStorage] = NanoVectorDBStorage
     vector_db_storage_cls_kwargs: dict = field(default_factory=dict)
+
    enable_llm_cache: bool = True
 
     # extension
@@ -127,37 +130,57 @@ class LightRAG:
         logger.debug(f"LightRAG init with param:\n  {_print_config}\n")
 
         # @TODO: should move all storage setup here to leverage initial start params attached to self.
+
+        self.key_string_value_json_storage_cls: Type[BaseKVStorage] = (
+            self._get_storage_class()[self.kv_storage]
+        )
+        self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class()[
+            self.vector_storage
+        ]
         self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class()[
-            self.kg
+            self.graph_storage
         ]
 
         if not os.path.exists(self.working_dir):
             logger.info(f"Creating working directory {self.working_dir}")
             os.makedirs(self.working_dir)
 
-        self.full_docs = self.key_string_value_json_storage_cls(
-            namespace="full_docs", global_config=asdict(self)
-        )
-
-        self.text_chunks = self.key_string_value_json_storage_cls(
-            namespace="text_chunks", global_config=asdict(self)
-        )
-
         self.llm_response_cache = (
             self.key_string_value_json_storage_cls(
-                namespace="llm_response_cache", global_config=asdict(self)
+                namespace="llm_response_cache",
+                global_config=asdict(self),
+                embedding_func=None,
             )
             if self.enable_llm_cache
             else None
         )
-        self.chunk_entity_relation_graph = self.graph_storage_cls(
-            namespace="chunk_entity_relation", global_config=asdict(self)
-        )
 
         self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(
             self.embedding_func
         )
 
+        ####
+        # add embedding func by walter
+        ####
+        self.full_docs = self.key_string_value_json_storage_cls(
+            namespace="full_docs",
+            global_config=asdict(self),
+            embedding_func=self.embedding_func,
+        )
+        self.text_chunks = self.key_string_value_json_storage_cls(
+            namespace="text_chunks",
+            global_config=asdict(self),
+            embedding_func=self.embedding_func,
+        )
+        self.chunk_entity_relation_graph = self.graph_storage_cls(
+            namespace="chunk_entity_relation",
+            global_config=asdict(self),
+            embedding_func=self.embedding_func,
+        )
+        ####
+        # add embedding func by walter over
+        ####
+
         self.entities_vdb = self.vector_db_storage_cls(
             namespace="entities",
             global_config=asdict(self),
@@ -186,8 +209,17 @@ class LightRAG:
 
     def _get_storage_class(self) -> Type[BaseGraphStorage]:
         return {
-            "Neo4JStorage": Neo4JStorage,
+            # kv storage
+            "JsonKVStorage": JsonKVStorage,
+            "OracleKVStorage": OracleKVStorage,
+            # vector storage
+            "NanoVectorDBStorage": NanoVectorDBStorage,
+            "OracleVectorDBStorage": OracleVectorDBStorage,
+            # graph storage
             "NetworkXStorage": NetworkXStorage,
+            "Neo4JStorage": Neo4JStorage,
+            "OracleGraphStorage": OracleGraphStorage,
+            # "ArangoDBStorage": ArangoDBStorage
         }
 
     def insert(self, string_or_strings):
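Because backends are now looked up by name in this registry, swapping an implementation is a constructor argument rather than a code change; for example, a sketch mixing the defaults with Neo4j for the graph:

rag = LightRAG(
    working_dir="./lightrag_cache",
    llm_model_func=llm_model_func,
    embedding_func=embedding_func,
    kv_storage="JsonKVStorage",            # default
    vector_storage="NanoVectorDBStorage",  # default
    graph_storage="Neo4JStorage",          # only the graph backend is swapped
)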
lightrag/operate.py CHANGED
@@ -16,6 +16,7 @@ from .utils import (
     split_string_by_multi_markers,
     truncate_list_by_token_size,
     process_combine_contexts,
+    locate_json_string_body_from_string,
 )
 from .base import (
     BaseGraphStorage,
@@ -403,9 +404,10 @@ async def local_query(
     kw_prompt_temp = PROMPTS["keywords_extraction"]
     kw_prompt = kw_prompt_temp.format(query=query)
     result = await use_model_func(kw_prompt)
+    json_text = locate_json_string_body_from_string(result)
 
     try:
-        keywords_data = json.loads(result)
+        keywords_data = json.loads(json_text)
         keywords = keywords_data.get("low_level_keywords", [])
         keywords = ", ".join(keywords)
     except json.JSONDecodeError:
@@ -669,9 +671,10 @@ async def global_query(
     kw_prompt_temp = PROMPTS["keywords_extraction"]
     kw_prompt = kw_prompt_temp.format(query=query)
     result = await use_model_func(kw_prompt)
+    json_text = locate_json_string_body_from_string(result)
 
     try:
-        keywords_data = json.loads(result)
+        keywords_data = json.loads(json_text)
         keywords = keywords_data.get("high_level_keywords", [])
         keywords = ", ".join(keywords)
     except json.JSONDecodeError:
@@ -910,8 +913,9 @@ async def hybrid_query(
     kw_prompt = kw_prompt_temp.format(query=query)
 
     result = await use_model_func(kw_prompt)
+    json_text = locate_json_string_body_from_string(result)
     try:
-        keywords_data = json.loads(result)
+        keywords_data = json.loads(json_text)
         hl_keywords = keywords_data.get("high_level_keywords", [])
         ll_keywords = keywords_data.get("low_level_keywords", [])
         hl_keywords = ", ".join(hl_keywords)
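The `locate_json_string_body_from_string` helper comes from `.utils` and is not shown in this diff; a plausible sketch of its behavior (an assumption, not the committed implementation) is a regex that pulls the first JSON object body out of an LLM reply, so that models which wrap their JSON in prose or code fences still parse:

import re


def locate_json_string_body_from_string(content: str) -> str:
    # Hypothetical sketch: grab the outermost {...} span, else fall back to the raw text
    match = re.search(r"\{.*\}", content, re.DOTALL)
    return match.group(0) if match else content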
lightrag/prompt.py CHANGED
@@ -14,7 +14,7 @@ Given a text document that is potentially relevant to this activity and a list o
 
 -Steps-
 1. Identify all entities. For each identified entity, extract the following information:
-- entity_name: Name of the entity, capitalized
+- entity_name: Name of the entity, in the same language as the input text. If English, capitalize the name.
 - entity_type: One of the following types: [{entity_types}]
 - entity_description: Comprehensive description of the entity's attributes and activities
 Format each entity as ("entity"{tuple_delimiter}<entity_name>{tuple_delimiter}<entity_type>{tuple_delimiter}<entity_description>
@@ -222,14 +222,24 @@ Output:
 
 """
 
-PROMPTS["naive_rag_response"] = """You're a helpful assistant
-Below are the knowledge you know:
-{content_data}
----
-If you don't know the answer or if the provided knowledge do not contain sufficient information to provide an answer, just say so. Do not make anything up.
+PROMPTS["naive_rag_response"] = """---Role---
+
+You are a helpful assistant responding to questions about documents provided.
+
+
+---Goal---
+
 Generate a response of the target length and format that responds to the user's question, summarizing all information in the input data tables appropriate for the response length and format, and incorporating any relevant general knowledge.
 If you don't know the answer, just say so. Do not make anything up.
 Do not include information where the supporting evidence for it is not provided.
+
 ---Target response length and format---
+
 {response_type}
+
+---Documents---
+
+{content_data}
+
+Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown.
 """
requirements.txt CHANGED
@@ -1,6 +1,8 @@
 accelerate
 aioboto3
 aiohttp
+
+# database packages
 graspologic
 hnswlib
 nano-vectordb
@@ -8,10 +10,13 @@ neo4j
 networkx
 ollama
 openai
+oracledb
 pyvis
 tenacity
+# lmdeploy[all]
+
+# LLM packages
 tiktoken
 torch
 transformers
 xxhash
-# lmdeploy[all]
test.py CHANGED
@@ -18,7 +18,7 @@ rag = LightRAG(
     # llm_model_func=gpt_4o_complete  # Optionally, use a stronger model
 )
 
-with open("./book.txt") as f:
+with open("./dickens/book.txt", "r", encoding="utf-8") as f:
     rag.insert(f.read())
 
 # Perform naive search