jin commited on
Commit
8d66028
·
1 Parent(s): 58f78e8

Oracle Database support

Browse files

Add oracle 23ai database as the KV/vector/graph storage

examples/lightrag_oracle_demo.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ import sys, os
4
+ print(os.getcwd())
5
+ from pathlib import Path
6
+ script_directory = Path(__file__).resolve().parent.parent
7
+ sys.path.append(os.path.abspath(script_directory))
8
+
9
+ import asyncio
10
+ from lightrag import LightRAG, QueryParam
11
+ from lightrag.llm import openai_complete_if_cache, openai_embedding
12
+ from lightrag.utils import EmbeddingFunc
13
+ import numpy as np
14
+ from datetime import datetime
15
+
16
+ from lightrag.kg.oracle_impl import OracleDB
17
+
18
+
19
+ WORKING_DIR = "./dickens"
20
+
21
+ # We use OpenAI compatible API to call LLM on Oracle Cloud
22
+ # More docs here https://github.com/jin38324/OCI_GenAI_access_gateway
23
+ BASE_URL = "http://xxx.xxx.xxx.xxx:8088/v1/"
24
+ APIKEY = "ocigenerativeai"
25
+ CHATMODEL = "cohere.command-r-plus"
26
+ EMBEDMODEL = "cohere.embed-multilingual-v3.0"
27
+
28
+
29
+ if not os.path.exists(WORKING_DIR):
30
+ os.mkdir(WORKING_DIR)
31
+
32
+ async def llm_model_func(
33
+ prompt, system_prompt=None, history_messages=[], **kwargs
34
+ ) -> str:
35
+ return await openai_complete_if_cache(
36
+ CHATMODEL,
37
+ prompt,
38
+ system_prompt=system_prompt,
39
+ history_messages=history_messages,
40
+ api_key=APIKEY,
41
+ base_url=BASE_URL,
42
+ **kwargs,
43
+ )
44
+
45
+
46
+ async def embedding_func(texts: list[str]) -> np.ndarray:
47
+ return await openai_embedding(
48
+ texts,
49
+ model=EMBEDMODEL,
50
+ api_key=APIKEY,
51
+ base_url=BASE_URL,
52
+ )
53
+
54
+
55
+ async def get_embedding_dim():
56
+ test_text = ["This is a test sentence."]
57
+ embedding = await embedding_func(test_text)
58
+ embedding_dim = embedding.shape[1]
59
+ return embedding_dim
60
+
61
+
62
+ async def main():
63
+ try:
64
+ # Detect embedding dimension
65
+ embedding_dimension = await get_embedding_dim()
66
+ print(f"Detected embedding dimension: {embedding_dimension}")
67
+
68
+ # Create Oracle DB connection
69
+ # The `config` parameter is the connection configuration of Oracle DB
70
+ # More docs here https://python-oracledb.readthedocs.io/en/latest/user_guide/connection_handling.html
71
+ # We storage data in unified tables, so we need to set a `workspace` parameter to specify which docs we want to store and query
72
+ # Below is an example of how to connect to Oracle Autonomous Database on Oracle Cloud
73
+ oracle_db = OracleDB(config={
74
+ "user":"RAG",
75
+ "password":"xxxxxxxxx",
76
+ "dsn":"xxxxxxx_medium",
77
+ "config_dir":"dir/path/to/oracle/config",
78
+ "wallet_location":"dir/path/to/oracle/wallet",
79
+ "wallet_password":"xxxxxxxxx",
80
+ "workspace":"company" # specify which docs we want to store and query
81
+ }
82
+ )
83
+
84
+
85
+ # Check if Oracle DB tables exist, if not, tables will be created
86
+ await oracle_db.check_tables()
87
+
88
+
89
+ # Initialize LightRAG
90
+ # We use Oracle DB as the KV/vector/graph storage
91
+ rag = LightRAG(
92
+ enable_llm_cache=False,
93
+ working_dir=WORKING_DIR,
94
+ chunk_token_size=512,
95
+ llm_model_func=llm_model_func,
96
+ embedding_func=EmbeddingFunc(
97
+ embedding_dim=embedding_dimension,
98
+ max_token_size=512,
99
+ func=embedding_func,
100
+ ),
101
+ graph_storage = "OracleGraphStorage",
102
+ kv_storage="OracleKVStorage",
103
+ vector_storage="OracleVectorDBStorage"
104
+ )
105
+
106
+ # Setthe KV/vector/graph storage's `db` property, so all operation will use same connection pool
107
+ rag.graph_storage_cls.db = oracle_db
108
+ rag.key_string_value_json_storage_cls.db = oracle_db
109
+ rag.vector_db_storage_cls.db = oracle_db
110
+
111
+ # Extract and Insert into LightRAG storage
112
+ with open("./dickens/demo.txt", "r", encoding="utf-8") as f:
113
+ await rag.ainsert(f.read())
114
+
115
+ # Perform search in different modes
116
+ modes = ["naive", "local", "global", "hybrid"]
117
+ for mode in modes:
118
+ print("="*20, mode, "="*20)
119
+ print(await rag.aquery("这个文章讲了什么?", param=QueryParam(mode=mode)))
120
+ print("-"*100, "\n")
121
+
122
+ except Exception as e:
123
+ print(f"An error occurred: {e}")
124
+
125
+
126
+ if __name__ == "__main__":
127
+ asyncio.run(main())
lightrag/base.py CHANGED
@@ -59,6 +59,7 @@ class BaseVectorStorage(StorageNameSpace):
59
 
60
  @dataclass
61
  class BaseKVStorage(Generic[T], StorageNameSpace):
 
62
  async def all_keys(self) -> list[str]:
63
  raise NotImplementedError
64
 
@@ -83,6 +84,7 @@ class BaseKVStorage(Generic[T], StorageNameSpace):
83
 
84
  @dataclass
85
  class BaseGraphStorage(StorageNameSpace):
 
86
  async def has_node(self, node_id: str) -> bool:
87
  raise NotImplementedError
88
 
 
59
 
60
  @dataclass
61
  class BaseKVStorage(Generic[T], StorageNameSpace):
62
+ embedding_func: EmbeddingFunc
63
  async def all_keys(self) -> list[str]:
64
  raise NotImplementedError
65
 
 
84
 
85
  @dataclass
86
  class BaseGraphStorage(StorageNameSpace):
87
+ embedding_func: EmbeddingFunc
88
  async def has_node(self, node_id: str) -> bool:
89
  raise NotImplementedError
90
 
lightrag/kg/oracle_impl.py ADDED
@@ -0,0 +1,767 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ #import html
3
+ #import os
4
+ from dataclasses import dataclass
5
+ from typing import Any, Union, cast
6
+ import networkx as nx
7
+ import numpy as np
8
+ import array
9
+
10
+ from ..utils import logger
11
+ from ..base import (
12
+ BaseGraphStorage,
13
+ BaseKVStorage,
14
+ BaseVectorStorage,
15
+ )
16
+
17
+ import oracledb
18
+
19
+ class OracleDB:
20
+ def __init__(self,config,**kwargs):
21
+ self.host = config.get("host", None)
22
+ self.port = config.get("port", None)
23
+ self.user = config.get("user", None)
24
+ self.password = config.get("password", None)
25
+ self.dsn = config.get("dsn", None)
26
+ self.config_dir = config.get("config_dir", None)
27
+ self.wallet_location = config.get("wallet_location", None)
28
+ self.wallet_password = config.get("wallet_password", None)
29
+ self.workspace = config.get("workspace", None)
30
+ self.max = 12
31
+ self.increment = 1
32
+ logger.info(f"Using the label {self.workspace} for Oracle Graph as identifier")
33
+ if self.user is None or self.password is None:
34
+ raise ValueError("Missing database user or password in addon_params")
35
+
36
+ try:
37
+ oracledb.defaults.fetch_lobs = False
38
+
39
+ self.pool = oracledb.create_pool_async(
40
+ user = self.user,
41
+ password = self.password,
42
+ dsn = self.dsn,
43
+ config_dir = self.config_dir,
44
+ wallet_location = self.wallet_location,
45
+ wallet_password = self.wallet_password,
46
+ min = 1,
47
+ max = self.max,
48
+ increment = self.increment
49
+ )
50
+ logger.info(f"Connected to Oracle database at {self.dsn}")
51
+ except Exception as e:
52
+ logger.error(f"Failed to connect to Oracle database at {self.dsn}")
53
+ logger.error(f"Oracle database error: {e}")
54
+ raise
55
+
56
+ def numpy_converter_in(self, value):
57
+ """Convert numpy array to array.array"""
58
+ if value.dtype == np.float64:
59
+ dtype = "d"
60
+ elif value.dtype == np.float32:
61
+ dtype = "f"
62
+ else:
63
+ dtype = "b"
64
+ return array.array(dtype, value)
65
+
66
+ def input_type_handler(self, cursor, value, arraysize):
67
+ """Set the type handler for the input data"""
68
+ if isinstance(value, np.ndarray):
69
+ return cursor.var(
70
+ oracledb.DB_TYPE_VECTOR,
71
+ arraysize=arraysize,
72
+ inconverter=self.numpy_converter_in,
73
+ )
74
+
75
+ def numpy_converter_out(self, value):
76
+ """Convert array.array to numpy array"""
77
+ if value.typecode == "b":
78
+ dtype = np.int8
79
+ elif value.typecode == "f":
80
+ dtype = np.float32
81
+ else:
82
+ dtype = np.float64
83
+ return np.array(value, copy=False, dtype=dtype)
84
+
85
+ def output_type_handler(self, cursor, metadata):
86
+ """Set the type handler for the output data"""
87
+ if metadata.type_code is oracledb.DB_TYPE_VECTOR:
88
+ return cursor.var(
89
+ metadata.type_code,
90
+ arraysize=cursor.arraysize,
91
+ outconverter=self.numpy_converter_out,
92
+ )
93
+
94
+ async def check_tables(self):
95
+ for k,v in TABLES.items():
96
+ try:
97
+ if k.lower() == "lightrag_graph":
98
+ await self.query("SELECT id FROM GRAPH_TABLE (lightrag_graph MATCH (a) COLUMNS (a.id)) fetch first row only")
99
+ else:
100
+ await self.query("SELECT 1 FROM {k}".format(k=k))
101
+ except Exception as e:
102
+ logger.error(f"Failed to check table {k} in Oracle database")
103
+ logger.error(f"Oracle database error: {e}")
104
+ try:
105
+ # print(v["ddl"])
106
+ await self.execute(v["ddl"])
107
+ logger.info(f"Created table {k} in Oracle database")
108
+ except Exception as e:
109
+ logger.error(f"Failed to create table {k} in Oracle database")
110
+ logger.error(f"Oracle database error: {e}")
111
+
112
+ logger.info(f"Finished check all tables in Oracle database")
113
+
114
+
115
+ async def query(self,sql: str, multirows: bool = False) -> Union[dict, None]:
116
+ async with self.pool.acquire() as connection:
117
+ connection.inputtypehandler = self.input_type_handler
118
+ connection.outputtypehandler = self.output_type_handler
119
+ with connection.cursor() as cursor:
120
+ try:
121
+ await cursor.execute(sql)
122
+ except Exception as e:
123
+ logger.error(f"Oracle database error: {e}")
124
+ print(sql)
125
+ raise
126
+ columns = [column[0].lower() for column in cursor.description]
127
+ if multirows:
128
+ rows = await cursor.fetchall()
129
+ if rows:
130
+ data = [dict(zip(columns, row)) for row in rows]
131
+ else:
132
+ data = []
133
+ else:
134
+ row = await cursor.fetchone()
135
+ if row:
136
+ data = dict(zip(columns, row))
137
+ else:
138
+ data = None
139
+ return data
140
+
141
+ async def execute(self,sql: str, data: list = None):
142
+ # logger.info("go into OracleDB execute method")
143
+ try:
144
+ async with self.pool.acquire() as connection:
145
+ connection.inputtypehandler = self.input_type_handler
146
+ connection.outputtypehandler = self.output_type_handler
147
+ with connection.cursor() as cursor:
148
+ if data is None:
149
+ await cursor.execute(sql)
150
+ else:
151
+ #print(data)
152
+ #print(sql)
153
+ await cursor.execute(sql,data)
154
+ await connection.commit()
155
+ except Exception as e:
156
+ logger.error(f"Oracle database error: {e}")
157
+ print(sql)
158
+ print(data)
159
+ raise
160
+
161
+ @dataclass
162
+ class OracleKVStorage(BaseKVStorage):
163
+
164
+ # should pass db object to self.db
165
+ def __post_init__(self):
166
+ self._data = {}
167
+ self._max_batch_size = self.global_config["embedding_batch_num"]
168
+
169
+ ################ QUERY METHODS ################
170
+
171
+ async def get_by_id(self, id: str) -> Union[dict, None]:
172
+ """根据 id 获取 doc_full 数据."""
173
+ SQL = SQL_TEMPLATES["get_by_id_"+self.namespace].format(workspace=self.db.workspace,id=id)
174
+ #print("get_by_id:"+SQL)
175
+ res = await self.db.query(SQL)
176
+ if res:
177
+ data = res #{"data":res}
178
+ #print (data)
179
+ return data
180
+ else:
181
+ return None
182
+
183
+ # Query by id
184
+ async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict],None]:
185
+ """根据 id 获取 doc_chunks 数据"""
186
+ SQL = SQL_TEMPLATES["get_by_ids_"+self.namespace].format(workspace=self.db.workspace,
187
+ ids=",".join([f"'{id}'" for id in ids]))
188
+ #print("get_by_ids:"+SQL)
189
+ res = await self.db.query(SQL,multirows=True)
190
+ if res:
191
+ data = res # [{"data":i} for i in res]
192
+ #print(data)
193
+ return data
194
+ else:
195
+ return None
196
+
197
+ async def filter_keys(self, keys: list[str]) -> set[str]:
198
+ """过滤掉重复内容"""
199
+ SQL = SQL_TEMPLATES["filter_keys"].format(table_name=N_T[self.namespace],
200
+ workspace=self.db.workspace,
201
+ ids=",".join([f"'{k}'" for k in keys]))
202
+ res = await self.db.query(SQL,multirows=True)
203
+ data = None
204
+ if res:
205
+ exist_keys = [key["id"] for key in res]
206
+ data = set([s for s in keys if s not in exist_keys])
207
+ else:
208
+ exist_keys = []
209
+ data = set([s for s in keys if s not in exist_keys])
210
+ return data
211
+
212
+
213
+ ################ INSERT METHODS ################
214
+ async def upsert(self, data: dict[str, dict]):
215
+ left_data = {k: v for k, v in data.items() if k not in self._data}
216
+ self._data.update(left_data)
217
+ #print(self._data)
218
+ #values = []
219
+ if self.namespace == "text_chunks":
220
+ list_data = [
221
+ {
222
+ "__id__": k,
223
+ **{k1: v1 for k1, v1 in v.items()},
224
+ }
225
+ for k, v in data.items()
226
+ ]
227
+ contents = [v["content"] for v in data.values()]
228
+ batches = [
229
+ contents[i: i + self._max_batch_size]
230
+ for i in range(0, len(contents), self._max_batch_size)
231
+ ]
232
+ embeddings_list = await asyncio.gather(
233
+ *[self.embedding_func(batch) for batch in batches]
234
+ )
235
+ embeddings = np.concatenate(embeddings_list)
236
+ for i, d in enumerate(list_data):
237
+ d["__vector__"] = embeddings[i]
238
+ #print(list_data)
239
+ for item in list_data:
240
+ merge_sql = SQL_TEMPLATES["merge_chunk"].format(
241
+ check_id=item["__id__"]
242
+ )
243
+
244
+ values = [item["__id__"], item["content"], self.db.workspace, item["tokens"],
245
+ item["chunk_order_index"], item["full_doc_id"], item["__vector__"]]
246
+ #print(merge_sql)
247
+ await self.db.execute(merge_sql, values)
248
+
249
+ if self.namespace == "full_docs":
250
+ for k, v in self._data.items():
251
+ #values.clear()
252
+ merge_sql = SQL_TEMPLATES["merge_doc_full"].format(
253
+ check_id=k,
254
+ )
255
+ values = [k, self._data[k]["content"], self.db.workspace]
256
+ #print(merge_sql)
257
+ await self.db.execute(merge_sql, values)
258
+ return left_data
259
+
260
+
261
+ async def index_done_callback(self):
262
+ if self.namespace in ["full_docs", "text_chunks"]:
263
+ logger.info("full doc and chunk data had been saved into oracle db!")
264
+
265
+
266
+
267
+ @dataclass
268
+ class OracleVectorDBStorage(BaseVectorStorage):
269
+ cosine_better_than_threshold: float = 0.2
270
+
271
+ def __post_init__(self):
272
+ pass
273
+
274
+ async def upsert(self, data: dict[str, dict]):
275
+ """向向量数据库中插入数据"""
276
+ pass
277
+
278
+ async def index_done_callback(self):
279
+ pass
280
+
281
+
282
+ #################### query method ################
283
+ async def query(self, query: str, top_k=5) -> Union[dict, list[dict]]:
284
+ """从向量数据库中查询数据"""
285
+ embeddings = await self.embedding_func([query])
286
+ embedding = embeddings[0]
287
+ # 转换精度
288
+ dtype = str(embedding.dtype).upper()
289
+ dimension = embedding.shape[0]
290
+ embedding_string = ', '.join(map(str, embedding.tolist()))
291
+
292
+ SQL = SQL_TEMPLATES[self.namespace].format(
293
+ embedding_string=embedding_string,
294
+ dimension=dimension,
295
+ dtype=dtype,
296
+ workspace=self.db.workspace,
297
+ top_k=top_k,
298
+ better_than_threshold=self.cosine_better_than_threshold,
299
+ )
300
+ # print(SQL)
301
+ results = await self.db.query(SQL, multirows=True)
302
+ #print("vector search result:",results)
303
+ return results
304
+
305
+
306
+ @dataclass
307
+ class OracleGraphStorage(BaseGraphStorage):
308
+ """基于Oracle的图存储模块"""
309
+ # @staticmethod
310
+ # def load_graph(file_name) -> nx.Graph:
311
+ # """读取graphhml图文件"""
312
+
313
+ # @staticmethod
314
+ # def write_graph(graph: nx.Graph, file_name):
315
+ # # """写入graphhml图文件"""
316
+
317
+ # @staticmethod
318
+ # def stable_largest_connected_component(graph: nx.Graph) -> nx.Graph:
319
+ # """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
320
+ # Return the largest connected component of the graph, with nodes and edges sorted in a stable way.
321
+ # 用于产生稳定的最大连通分量的模块,即相同的输入图==相同的输出lcc。
322
+ # """
323
+
324
+
325
+ # @staticmethod
326
+ # def _stabilize_graph(graph: nx.Graph) -> nx.Graph:
327
+ # """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
328
+ # Ensure an undirected graph with the same relationships will always be read the same way.
329
+ # 确保具有相同关系的无向图始终以相同的方式读取。
330
+ # """
331
+
332
+ def __post_init__(self):
333
+ """从graphml文件加载图"""
334
+ self._max_batch_size = self.global_config["embedding_batch_num"]
335
+
336
+
337
+ #################### insert method ################
338
+
339
+ async def upsert_node(self, node_id: str, node_data: dict[str, str]):
340
+ """插入或更新节点"""
341
+ #print("go into upsert node method")
342
+ entity_name = node_id
343
+ entity_type = node_data["entity_type"]
344
+ description = node_data["description"]
345
+ source_id = node_data["source_id"]
346
+ content = entity_name+description
347
+ contents = [content]
348
+ batches = [
349
+ contents[i: i + self._max_batch_size]
350
+ for i in range(0, len(contents), self._max_batch_size)
351
+ ]
352
+ embeddings_list = await asyncio.gather(
353
+ *[self.embedding_func(batch) for batch in batches]
354
+ )
355
+ embeddings = np.concatenate(embeddings_list)
356
+ content_vector = embeddings[0]
357
+ merge_sql = SQL_TEMPLATES["merge_node"].format(
358
+ workspace=self.db.workspace,name=entity_name, source_chunk_id=source_id
359
+ )
360
+ #print(merge_sql)
361
+ await self.db.execute(merge_sql, [self.db.workspace,entity_name,entity_type,description,source_id,content,content_vector])
362
+ #self._graph.add_node(node_id, **node_data)
363
+
364
+ async def upsert_edge(
365
+ self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
366
+ ):
367
+ """插入或更新边"""
368
+ #print("go into upsert edge method")
369
+ source_name = source_node_id
370
+ target_name = target_node_id
371
+ weight = edge_data["weight"]
372
+ keywords = edge_data["keywords"]
373
+ description = edge_data["description"]
374
+ source_chunk_id = edge_data["source_id"]
375
+ content = keywords+source_name+target_name+description
376
+ contents = [content]
377
+ batches = [
378
+ contents[i: i + self._max_batch_size]
379
+ for i in range(0, len(contents), self._max_batch_size)
380
+ ]
381
+ embeddings_list = await asyncio.gather(
382
+ *[self.embedding_func(batch) for batch in batches]
383
+ )
384
+ embeddings = np.concatenate(embeddings_list)
385
+ content_vector = embeddings[0]
386
+ merge_sql = SQL_TEMPLATES["merge_edge"].format(
387
+ workspace=self.db.workspace,source_name=source_name, target_name=target_name, source_chunk_id=source_chunk_id
388
+ )
389
+ #print(merge_sql)
390
+ await self.db.execute(merge_sql, [self.db.workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector])
391
+ #self._graph.add_edge(source_node_id, target_node_id, **edge_data)
392
+
393
+ async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
394
+ """为节点生成向量"""
395
+ if algorithm not in self._node_embed_algorithms:
396
+ raise ValueError(f"Node embedding algorithm {algorithm} not supported")
397
+ return await self._node_embed_algorithms[algorithm]()
398
+
399
+ async def _node2vec_embed(self):
400
+ """为节点生成向量"""
401
+ from graspologic import embed
402
+
403
+ embeddings, nodes = embed.node2vec_embed(
404
+ self._graph,
405
+ **self.config["node2vec_params"],
406
+ )
407
+
408
+ nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes]
409
+ return embeddings, nodes_ids
410
+
411
+
412
+ async def index_done_callback(self):
413
+ """写入graphhml图文件"""
414
+ logger.info("Node and edge data had been saved into oracle db already, so nothing to do here!")
415
+
416
+ #################### query method ################
417
+ async def has_node(self, node_id: str) -> bool:
418
+ """根据节点id检查节点是否存在"""
419
+ SQL = SQL_TEMPLATES["has_node"].format(workspace=self.db.workspace, node_id=node_id)
420
+ # print(SQL)
421
+ #print(self.db.workspace, node_id)
422
+ res = await self.db.query(SQL)
423
+ if res:
424
+ #print("Node exist!",res)
425
+ return True
426
+ else:
427
+ #print("Node not exist!")
428
+ return False
429
+
430
+ async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
431
+ """根据源和目标节点id检查边是否存在"""
432
+ SQL = SQL_TEMPLATES["has_edge"].format(workspace=self.db.workspace,
433
+ source_node_id=source_node_id,
434
+ target_node_id=target_node_id)
435
+ # print(SQL)
436
+ res = await self.db.query(SQL)
437
+ if res:
438
+ #print("Edge exist!",res)
439
+ return True
440
+ else:
441
+ #print("Edge not exist!")
442
+ return False
443
+
444
+ async def node_degree(self, node_id: str) -> int:
445
+ """根据节点id获取节点的度"""
446
+ SQL = SQL_TEMPLATES["node_degree"].format(workspace=self.db.workspace, node_id=node_id)
447
+ # print(SQL)
448
+ res = await self.db.query(SQL)
449
+ if res:
450
+ #print("Node degree",res["degree"])
451
+ return res["degree"]
452
+ else:
453
+ #print("Edge not exist!")
454
+ return 0
455
+
456
+
457
+ async def edge_degree(self, src_id: str, tgt_id: str) -> int:
458
+ """根据源和目标节点id获取边的度"""
459
+ degree = await self.node_degree(src_id) + await self.node_degree(tgt_id)
460
+ #print("Edge degree",degree)
461
+ return degree
462
+
463
+
464
+ async def get_node(self, node_id: str) -> Union[dict, None]:
465
+ """根据节点id获取节点数据"""
466
+ SQL = SQL_TEMPLATES["get_node"].format(workspace=self.db.workspace, node_id=node_id)
467
+ # print(self.db.workspace, node_id)
468
+ # print(SQL)
469
+ res = await self.db.query(SQL)
470
+ if res:
471
+ #print("Get node!",self.db.workspace, node_id,res)
472
+ return res
473
+ else:
474
+ #print("Can't get node!",self.db.workspace, node_id)
475
+ return None
476
+
477
+ async def get_edge(
478
+ self, source_node_id: str, target_node_id: str
479
+ ) -> Union[dict, None]:
480
+ """根据源和目标节点id获取边"""
481
+ SQL = SQL_TEMPLATES["get_edge"].format(workspace=self.db.workspace,
482
+ source_node_id=source_node_id,
483
+ target_node_id=target_node_id)
484
+ res = await self.db.query(SQL)
485
+ if res:
486
+ #print("Get edge!",self.db.workspace, source_node_id, target_node_id,res[0])
487
+ return res
488
+ else:
489
+ #print("Edge not exist!",self.db.workspace, source_node_id, target_node_id)
490
+ return None
491
+
492
+ async def get_node_edges(self, source_node_id: str):
493
+ """根据节点id获取节点的所有边"""
494
+ if await self.has_node(source_node_id):
495
+ SQL = SQL_TEMPLATES["get_node_edges"].format(workspace=self.db.workspace,
496
+ source_node_id=source_node_id)
497
+ res = await self.db.query(sql=SQL, multirows=True)
498
+ if res:
499
+ data = [(i["source_name"],i["target_name"]) for i in res]
500
+ #print("Get node edge!",self.db.workspace, source_node_id,data)
501
+ return data
502
+ else:
503
+ #print("Node Edge not exist!",self.db.workspace, source_node_id)
504
+ return []
505
+
506
+ #################### INSERT method ################
507
+ async def upsert_node(self, node_id: str, node_data: dict[str, str]):
508
+ """插入或更新节点"""
509
+ #print("go into upsert node method")
510
+ entity_name = node_id
511
+ entity_type = node_data["entity_type"]
512
+ description = node_data["description"]
513
+ source_id = node_data["source_id"]
514
+ content = entity_name+description
515
+ contents = [content]
516
+ batches = [
517
+ contents[i: i + self._max_batch_size]
518
+ for i in range(0, len(contents), self._max_batch_size)
519
+ ]
520
+ embeddings_list = await asyncio.gather(
521
+ *[self.embedding_func(batch) for batch in batches]
522
+ )
523
+ embeddings = np.concatenate(embeddings_list)
524
+ content_vector = embeddings[0]
525
+ merge_sql = SQL_TEMPLATES["merge_node"].format(
526
+ workspace=self.db.workspace,name=entity_name, source_chunk_id=source_id
527
+ )
528
+ #print(merge_sql)
529
+ await self.db.execute(merge_sql, [self.db.workspace,entity_name,entity_type,description,source_id,content,content_vector])
530
+ #self._graph.add_node(node_id, **node_data)
531
+
532
+ async def upsert_edge(
533
+ self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
534
+ ):
535
+ """插入或更新边"""
536
+ #print("go into upsert edge method")
537
+ source_name = source_node_id
538
+ target_name = target_node_id
539
+ weight = edge_data["weight"]
540
+ keywords = edge_data["keywords"]
541
+ description = edge_data["description"]
542
+ source_chunk_id = edge_data["source_id"]
543
+ content = keywords+source_name+target_name+description
544
+ contents = [content]
545
+ batches = [
546
+ contents[i: i + self._max_batch_size]
547
+ for i in range(0, len(contents), self._max_batch_size)
548
+ ]
549
+ embeddings_list = await asyncio.gather(
550
+ *[self.embedding_func(batch) for batch in batches]
551
+ )
552
+ embeddings = np.concatenate(embeddings_list)
553
+ content_vector = embeddings[0]
554
+ merge_sql = SQL_TEMPLATES["merge_edge"].format(
555
+ workspace=self.db.workspace,source_name=source_name, target_name=target_name, source_chunk_id=source_chunk_id
556
+ )
557
+ #print(merge_sql)
558
+ await self.db.execute(merge_sql, [self.db.workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector])
559
+ #self._graph.add_edge(source_node_id, target_node_id, **edge_data)
560
+
561
+ async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
562
+ """为节点生成向量"""
563
+ if algorithm not in self._node_embed_algorithms:
564
+ raise ValueError(f"Node embedding algorithm {algorithm} not supported")
565
+ return await self._node_embed_algorithms[algorithm]()
566
+
567
+ async def _node2vec_embed(self):
568
+ """为节点生成向量"""
569
+ from graspologic import embed
570
+
571
+ embeddings, nodes = embed.node2vec_embed(
572
+ self._graph,
573
+ **self.config["node2vec_params"],
574
+ )
575
+
576
+ nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes]
577
+ return embeddings, nodes_ids
578
+
579
+
580
+ N_T = {
581
+ "full_docs": "LIGHTRAG_DOC_FULL",
582
+ "text_chunks": "LIGHTRAG_DOC_CHUNKS",
583
+ "chunks": "LIGHTRAG_DOC_CHUNKS",
584
+ "entities": "LIGHTRAG_GRAPH_NODES",
585
+ "relationships": "LIGHTRAG_GRAPH_EDGES"
586
+ }
587
+
588
+ TABLES = {
589
+ "LIGHTRAG_DOC_FULL":
590
+ {"ddl":"""CREATE TABLE LIGHTRAG_DOC_FULL (
591
+ id varchar(256)PRIMARY KEY,
592
+ workspace varchar(1024),
593
+ doc_name varchar(1024),
594
+ content CLOB,
595
+ meta JSON
596
+ )"""},
597
+
598
+ "LIGHTRAG_DOC_CHUNKS":
599
+ {"ddl":"""CREATE TABLE LIGHTRAG_DOC_CHUNKS (
600
+ id varchar(256) PRIMARY KEY,
601
+ workspace varchar(1024),
602
+ full_doc_id varchar(256),
603
+ chunk_order_index NUMBER,
604
+ tokens NUMBER,
605
+ content CLOB,
606
+ content_vector VECTOR
607
+ )"""},
608
+
609
+ "LIGHTRAG_GRAPH_NODES":
610
+ {"ddl":"""CREATE TABLE LIGHTRAG_GRAPH_NODES (
611
+ id NUMBER GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY,
612
+ workspace varchar(1024),
613
+ name varchar(2048),
614
+ entity_type varchar(1024),
615
+ description CLOB,
616
+ source_chunk_id varchar(256),
617
+ content CLOB,
618
+ content_vector VECTOR
619
+ )"""},
620
+ "LIGHTRAG_GRAPH_EDGES":
621
+ {"ddl":"""CREATE TABLE LIGHTRAG_GRAPH_EDGES (
622
+ id NUMBER GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY,
623
+ workspace varchar(1024),
624
+ source_name varchar(2048),
625
+ target_name varchar(2048),
626
+ weight NUMBER,
627
+ keywords CLOB,
628
+ description CLOB,
629
+ source_chunk_id varchar(256),
630
+ content CLOB,
631
+ content_vector VECTOR
632
+ )"""},
633
+ "LIGHTRAG_LLM_CACHE":
634
+ {"ddl":"""CREATE TABLE LIGHTRAG_LLM_CACHE (
635
+ id varchar(256) PRIMARY KEY,
636
+ return clob,
637
+ model varchar(1024)
638
+ )"""},
639
+
640
+ "LIGHTRAG_GRAPH":
641
+ {"ddl":"""CREATE OR REPLACE PROPERTY GRAPH lightrag_graph
642
+ VERTEX TABLES (
643
+ lightrag_graph_nodes KEY (id)
644
+ LABEL entity
645
+ PROPERTIES (id,workspace,name) -- ,entity_type,description,source_chunk_id)
646
+ )
647
+ EDGE TABLES (
648
+ lightrag_graph_edges KEY (id)
649
+ SOURCE KEY (source_name) REFERENCES lightrag_graph_nodes(name)
650
+ DESTINATION KEY (target_name) REFERENCES lightrag_graph_nodes(name)
651
+ LABEL has_relation
652
+ PROPERTIES (id,workspace,source_name,target_name) -- ,weight, keywords,description,source_chunk_id)
653
+ ) OPTIONS(ALLOW MIXED PROPERTY TYPES)"""},
654
+ }
655
+
656
+
657
+ SQL_TEMPLATES = {
658
+ # SQL for KVStorage
659
+ "get_by_id_full_docs":
660
+ "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace='{workspace}' and ID='{id}'",
661
+
662
+ "get_by_id_text_chunks":
663
+ "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace='{workspace}' and ID='{id}'",
664
+
665
+ "get_by_ids_full_docs":
666
+ "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace='{workspace}' and ID in ({ids})",
667
+
668
+ "get_by_ids_text_chunks":
669
+ "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace='{workspace}' and ID in ({ids})",
670
+
671
+ "filter_keys":
672
+ "select id from {table_name} where workspace='{workspace}' and id in ({ids})",
673
+
674
+ "merge_doc_full": """ MERGE INTO LIGHTRAG_DOC_FULL a
675
+ USING DUAL
676
+ ON (a.id = '{check_id}')
677
+ WHEN NOT MATCHED THEN
678
+ INSERT(id,content,workspace) values(:1,:2,:3)
679
+ """,
680
+
681
+ "merge_chunk": """MERGE INTO LIGHTRAG_DOC_CHUNKS a
682
+ USING DUAL
683
+ ON (a.id = '{check_id}')
684
+ WHEN NOT MATCHED THEN
685
+ INSERT(id,content,workspace,tokens,chunk_order_index,full_doc_id,content_vector)
686
+ values (:1,:2,:3,:4,:5,:6,:7) """,
687
+
688
+ # SQL for VectorStorage
689
+ "entities":
690
+ """SELECT name as entity_name FROM
691
+ (SELECT id,name,VECTOR_DISTANCE(content_vector,vector('[{embedding_string}]',{dimension},{dtype}),COSINE) as distance
692
+ FROM LIGHTRAG_GRAPH_NODES WHERE workspace='{workspace}')
693
+ WHERE distance>{better_than_threshold} ORDER BY distance ASC FETCH FIRST {top_k} ROWS ONLY""",
694
+
695
+ "relationships":
696
+ """SELECT source_name as src_id, target_name as tgt_id FROM
697
+ (SELECT id,source_name,target_name,VECTOR_DISTANCE(content_vector,vector('[{embedding_string}]',{dimension},{dtype}),COSINE) as distance
698
+ FROM LIGHTRAG_GRAPH_EDGES WHERE workspace='{workspace}')
699
+ WHERE distance>{better_than_threshold} ORDER BY distance ASC FETCH FIRST {top_k} ROWS ONLY""",
700
+
701
+ "chunks":
702
+ """SELECT id FROM
703
+ (SELECT id,VECTOR_DISTANCE(content_vector,vector('[{embedding_string}]',{dimension},{dtype}),COSINE) as distance
704
+ FROM LIGHTRAG_DOC_CHUNKS WHERE workspace='{workspace}')
705
+ WHERE distance>{better_than_threshold} ORDER BY distance ASC FETCH FIRST {top_k} ROWS ONLY""",
706
+
707
+ # SQL for GraphStorage
708
+ "has_node":
709
+ """SELECT * FROM GRAPH_TABLE (lightrag_graph
710
+ MATCH (a)
711
+ WHERE a.workspace='{workspace}' AND a.name='{node_id}'
712
+ COLUMNS (a.name))""",
713
+
714
+ "has_edge":
715
+ """SELECT * FROM GRAPH_TABLE (lightrag_graph
716
+ MATCH (a) -[e]-> (b)
717
+ WHERE e.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}'
718
+ AND a.name='{source_node_id}' AND b.name='{target_node_id}'
719
+ COLUMNS (e.source_name,e.target_name) )""",
720
+
721
+ "node_degree":
722
+ """SELECT count(1) as degree FROM GRAPH_TABLE (lightrag_graph
723
+ MATCH (a)-[e]->(b)
724
+ WHERE a.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}'
725
+ AND a.name='{node_id}' or b.name = '{node_id}'
726
+ COLUMNS (a.name))""",
727
+
728
+ "get_node":
729
+ """SELECT t1.name,t2.entity_type,t2.source_chunk_id as source_id,NVL(t2.description,'') AS description
730
+ FROM GRAPH_TABLE (lightrag_graph
731
+ MATCH (a)
732
+ WHERE a.workspace='{workspace}' AND a.name='{node_id}'
733
+ COLUMNS (a.name)
734
+ ) t1 JOIN LIGHTRAG_GRAPH_NODES t2 on t1.name=t2.name
735
+ WHERE t2.workspace='{workspace}'""",
736
+
737
+ "get_edge":
738
+ """SELECT t1.source_id,t2.weight,t2.source_chunk_id as source_id,t2.keywords,
739
+ NVL(t2.description,'') AS description,NVL(t2.KEYWORDS,'') AS keywords
740
+ FROM GRAPH_TABLE (lightrag_graph
741
+ MATCH (a)-[e]->(b)
742
+ WHERE e.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}'
743
+ AND a.name='{source_node_id}' and b.name = '{target_node_id}'
744
+ COLUMNS (e.id,a.name as source_id)
745
+ ) t1 JOIN LIGHTRAG_GRAPH_EDGES t2 on t1.id=t2.id""",
746
+
747
+ "get_node_edges":
748
+ """SELECT source_name,target_name
749
+ FROM GRAPH_TABLE (lightrag_graph
750
+ MATCH (a)-[e]->(b)
751
+ WHERE e.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}'
752
+ AND a.name='{source_node_id}'
753
+ COLUMNS (a.name as source_name,b.name as target_name))""",
754
+
755
+ "merge_node": """MERGE INTO LIGHTRAG_GRAPH_NODES a
756
+ USING DUAL
757
+ ON (a.workspace = '{workspace}' and a.name='{name}' and a.source_chunk_id='{source_chunk_id}')
758
+ WHEN NOT MATCHED THEN
759
+ INSERT(workspace,name,entity_type,description,source_chunk_id,content,content_vector)
760
+ values (:1,:2,:3,:4,:5,:6,:7) """,
761
+ "merge_edge": """MERGE INTO LIGHTRAG_GRAPH_EDGES a
762
+ USING DUAL
763
+ ON (a.workspace = '{workspace}' and a.source_name='{source_name}' and a.target_name='{target_name}' and a.source_chunk_id='{source_chunk_id}')
764
+ WHEN NOT MATCHED THEN
765
+ INSERT(workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector)
766
+ values (:1,:2,:3,:4,:5,:6,:7,:8,:9) """
767
+ }
lightrag/lightrag.py CHANGED
@@ -18,20 +18,6 @@ from .operate import (
18
  naive_query,
19
  )
20
 
21
- from .storage import (
22
- JsonKVStorage,
23
- NanoVectorDBStorage,
24
- NetworkXStorage,
25
- )
26
-
27
- from .kg.neo4j_impl import Neo4JStorage
28
- # future KG integrations
29
-
30
- # from .kg.ArangoDB_impl import (
31
- # GraphStorage as ArangoDBStorage
32
- # )
33
-
34
-
35
  from .utils import (
36
  EmbeddingFunc,
37
  compute_mdhash_id,
@@ -49,6 +35,26 @@ from .base import (
49
  )
50
 
51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
53
  try:
54
  return asyncio.get_event_loop()
@@ -68,7 +74,9 @@ class LightRAG:
68
  default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
69
  )
70
 
71
- kg: str = field(default="NetworkXStorage")
 
 
72
 
73
  current_log_level = logger.level
74
  log_level: str = field(default=current_log_level)
@@ -108,9 +116,16 @@ class LightRAG:
108
  llm_model_kwargs: dict = field(default_factory=dict)
109
 
110
  # storage
111
- key_string_value_json_storage_cls: Type[BaseKVStorage] = JsonKVStorage
112
- vector_db_storage_cls: Type[BaseVectorStorage] = NanoVectorDBStorage
113
  vector_db_storage_cls_kwargs: dict = field(default_factory=dict)
 
 
 
 
 
 
 
 
114
  enable_llm_cache: bool = True
115
 
116
  # extension
@@ -128,21 +143,16 @@ class LightRAG:
128
  logger.debug(f"LightRAG init with param:\n {_print_config}\n")
129
 
130
  # @TODO: should move all storage setup here to leverage initial start params attached to self.
131
- self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class()[
132
- self.kg
133
- ]
134
 
 
 
 
 
 
135
  if not os.path.exists(self.working_dir):
136
  logger.info(f"Creating working directory {self.working_dir}")
137
  os.makedirs(self.working_dir)
138
 
139
- self.full_docs = self.key_string_value_json_storage_cls(
140
- namespace="full_docs", global_config=asdict(self)
141
- )
142
-
143
- self.text_chunks = self.key_string_value_json_storage_cls(
144
- namespace="text_chunks", global_config=asdict(self)
145
- )
146
 
147
  self.llm_response_cache = (
148
  self.key_string_value_json_storage_cls(
@@ -151,14 +161,27 @@ class LightRAG:
151
  if self.enable_llm_cache
152
  else None
153
  )
154
- self.chunk_entity_relation_graph = self.graph_storage_cls(
155
- namespace="chunk_entity_relation", global_config=asdict(self)
156
- )
157
 
158
  self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(
159
  self.embedding_func
160
  )
161
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
  self.entities_vdb = self.vector_db_storage_cls(
163
  namespace="entities",
164
  global_config=asdict(self),
@@ -187,8 +210,15 @@ class LightRAG:
187
 
188
  def _get_storage_class(self) -> Type[BaseGraphStorage]:
189
  return {
 
 
 
 
 
 
190
  "Neo4JStorage": Neo4JStorage,
191
  "NetworkXStorage": NetworkXStorage,
 
192
  # "ArangoDBStorage": ArangoDBStorage
193
  }
194
 
 
18
  naive_query,
19
  )
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  from .utils import (
22
  EmbeddingFunc,
23
  compute_mdhash_id,
 
35
  )
36
 
37
 
38
+ from .storage import (
39
+ JsonKVStorage,
40
+ NanoVectorDBStorage,
41
+ NetworkXStorage,
42
+ )
43
+
44
+ from .kg.neo4j_impl import Neo4JStorage
45
+
46
+ from .kg.oracle_impl import (
47
+ OracleKVStorage,
48
+ OracleGraphStorage,
49
+ OracleVectorDBStorage
50
+ )
51
+
52
+ # future KG integrations
53
+
54
+ # from .kg.ArangoDB_impl import (
55
+ # GraphStorage as ArangoDBStorage
56
+ # )
57
+
58
  def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
59
  try:
60
  return asyncio.get_event_loop()
 
74
  default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
75
  )
76
 
77
+ kv_storage : str = field(default="JsonKVStorage")
78
+ vector_storage: str = field(default="NanoVectorDBStorage")
79
+ graph_storage: str = field(default="NetworkXStorage")
80
 
81
  current_log_level = logger.level
82
  log_level: str = field(default=current_log_level)
 
116
  llm_model_kwargs: dict = field(default_factory=dict)
117
 
118
  # storage
119
+
 
120
  vector_db_storage_cls_kwargs: dict = field(default_factory=dict)
121
+ # if DATABASE_TYPE is None:
122
+ # key_string_value_json_storage_cls: Type[BaseKVStorage] = JsonKVStorage
123
+ # vector_db_storage_cls: Type[BaseVectorStorage] = NanoVectorDBStorage
124
+ # vector_db_storage_cls_kwargs: dict = field(default_factory=dict)
125
+ # elif DATABASE_TYPE == "oracle":
126
+ # key_string_value_json_storage_cls: Type[BaseKVStorage] = OracleKVStorage,
127
+ # vector_db_storage_cls: Type[BaseVectorStorage] = OracleVectorDBStorage,
128
+
129
  enable_llm_cache: bool = True
130
 
131
  # extension
 
143
  logger.debug(f"LightRAG init with param:\n {_print_config}\n")
144
 
145
  # @TODO: should move all storage setup here to leverage initial start params attached to self.
 
 
 
146
 
147
+ self. key_string_value_json_storage_cls: Type[BaseKVStorage] = self._get_storage_class()[self.kv_storage]
148
+
149
+ self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class()[self.graph_storage]
150
+
151
+ self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class()[self.vector_storage]
152
  if not os.path.exists(self.working_dir):
153
  logger.info(f"Creating working directory {self.working_dir}")
154
  os.makedirs(self.working_dir)
155
 
 
 
 
 
 
 
 
156
 
157
  self.llm_response_cache = (
158
  self.key_string_value_json_storage_cls(
 
161
  if self.enable_llm_cache
162
  else None
163
  )
 
 
 
164
 
165
  self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(
166
  self.embedding_func
167
  )
168
 
169
+ ####
170
+ # add embedding func by walter
171
+ ####
172
+ self.full_docs = self.key_string_value_json_storage_cls(
173
+ namespace="full_docs", global_config=asdict(self), embedding_func=self.embedding_func
174
+ )
175
+ self.text_chunks = self.key_string_value_json_storage_cls(
176
+ namespace="text_chunks", global_config=asdict(self), embedding_func=self.embedding_func
177
+ )
178
+ self.chunk_entity_relation_graph = self.graph_storage_cls(
179
+ namespace="chunk_entity_relation", global_config=asdict(self), embedding_func=self.embedding_func
180
+ )
181
+ ####
182
+ # add embedding func by walter over
183
+ ####
184
+
185
  self.entities_vdb = self.vector_db_storage_cls(
186
  namespace="entities",
187
  global_config=asdict(self),
 
210
 
211
  def _get_storage_class(self) -> Type[BaseGraphStorage]:
212
  return {
213
+ "JsonKVStorage":JsonKVStorage,
214
+ "OracleKVStorage":OracleKVStorage,
215
+
216
+ "NanoVectorDBStorage":NanoVectorDBStorage,
217
+ "OracleVectorDBStorage":OracleVectorDBStorage,
218
+
219
  "Neo4JStorage": Neo4JStorage,
220
  "NetworkXStorage": NetworkXStorage,
221
+ "OracleGraphStorage": OracleGraphStorage,
222
  # "ArangoDBStorage": ArangoDBStorage
223
  }
224
 
lightrag/prompt.py CHANGED
@@ -222,14 +222,24 @@ Output:
222
 
223
  """
224
 
225
- PROMPTS["naive_rag_response"] = """You're a helpful assistant
226
- Below are the knowledge you know:
227
- {content_data}
228
- ---
229
- If you don't know the answer or if the provided knowledge do not contain sufficient information to provide an answer, just say so. Do not make anything up.
 
 
230
  Generate a response of the target length and format that responds to the user's question, summarizing all information in the input data tables appropriate for the response length and format, and incorporating any relevant general knowledge.
231
  If you don't know the answer, just say so. Do not make anything up.
232
  Do not include information where the supporting evidence for it is not provided.
 
233
  ---Target response length and format---
 
234
  {response_type}
 
 
 
 
 
 
235
  """
 
222
 
223
  """
224
 
225
+ PROMPTS["naive_rag_response"] = """---Role---
226
+
227
+ You are a helpful assistant responding to questions about documents provided.
228
+
229
+
230
+ ---Goal---
231
+
232
  Generate a response of the target length and format that responds to the user's question, summarizing all information in the input data tables appropriate for the response length and format, and incorporating any relevant general knowledge.
233
  If you don't know the answer, just say so. Do not make anything up.
234
  Do not include information where the supporting evidence for it is not provided.
235
+
236
  ---Target response length and format---
237
+
238
  {response_type}
239
+
240
+ ---Documents---
241
+
242
+ {content_data}
243
+
244
+ Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown.
245
  """
requirements.txt CHANGED
@@ -15,3 +15,4 @@ torch
15
  transformers
16
  xxhash
17
  # lmdeploy[all]
 
 
15
  transformers
16
  xxhash
17
  # lmdeploy[all]
18
+ oracledb