jin commited on
Commit
5dcb28f
·
1 Parent(s): 4d468e5

fix pre commit

Browse files
examples/lightrag_api_oracle_demo..py CHANGED
@@ -1,10 +1,10 @@
1
-
2
  from fastapi import FastAPI, HTTPException, File, UploadFile
3
  from contextlib import asynccontextmanager
4
  from pydantic import BaseModel
5
  from typing import Optional
6
 
7
- import sys, os
 
8
  from pathlib import Path
9
 
10
  import asyncio
@@ -13,7 +13,6 @@ from lightrag import LightRAG, QueryParam
13
  from lightrag.llm import openai_complete_if_cache, openai_embedding
14
  from lightrag.utils import EmbeddingFunc
15
  import numpy as np
16
- from datetime import datetime
17
 
18
  from lightrag.kg.oracle_impl import OracleDB
19
 
@@ -24,8 +23,6 @@ script_directory = Path(__file__).resolve().parent.parent
24
  sys.path.append(os.path.abspath(script_directory))
25
 
26
 
27
-
28
-
29
  # Apply nest_asyncio to solve event loop issues
30
  nest_asyncio.apply()
31
 
@@ -51,6 +48,7 @@ print(f"EMBEDDING_MAX_TOKEN_SIZE: {EMBEDDING_MAX_TOKEN_SIZE}")
51
  if not os.path.exists(WORKING_DIR):
52
  os.mkdir(WORKING_DIR)
53
 
 
54
  async def llm_model_func(
55
  prompt, system_prompt=None, history_messages=[], **kwargs
56
  ) -> str:
@@ -80,8 +78,8 @@ async def get_embedding_dim():
80
  embedding_dim = embedding.shape[1]
81
  return embedding_dim
82
 
 
83
  async def init():
84
-
85
  # Detect embedding dimension
86
  embedding_dimension = await get_embedding_dim()
87
  print(f"Detected embedding dimension: {embedding_dimension}")
@@ -91,36 +89,36 @@ async def init():
91
  # We storage data in unified tables, so we need to set a `workspace` parameter to specify which docs we want to store and query
92
  # Below is an example of how to connect to Oracle Autonomous Database on Oracle Cloud
93
 
 
 
 
 
 
 
 
 
 
 
 
94
 
95
- oracle_db = OracleDB(config={
96
- "user":"",
97
- "password":"",
98
- "dsn":"",
99
- "config_dir":"",
100
- "wallet_location":"",
101
- "wallet_password":"",
102
- "workspace":""
103
- } # specify which docs you want to store and query
104
- )
105
-
106
  # Check if Oracle DB tables exist, if not, tables will be created
107
  await oracle_db.check_tables()
108
  # Initialize LightRAG
109
- # We use Oracle DB as the KV/vector/graph storage
110
  rag = LightRAG(
111
- enable_llm_cache=False,
112
- working_dir=WORKING_DIR,
113
- chunk_token_size=512,
114
- llm_model_func=llm_model_func,
115
- embedding_func=EmbeddingFunc(
116
- embedding_dim=embedding_dimension,
117
- max_token_size=512,
118
- func=embedding_func,
119
- ),
120
- graph_storage = "OracleGraphStorage",
121
- kv_storage="OracleKVStorage",
122
- vector_storage="OracleVectorDBStorage"
123
- )
124
 
125
  # Setthe KV/vector/graph storage's `db` property, so all operation will use same connection pool
126
  rag.graph_storage_cls.db = oracle_db
@@ -129,6 +127,7 @@ async def init():
129
 
130
  return rag
131
 
 
132
  # Data models
133
 
134
 
@@ -152,6 +151,7 @@ class Response(BaseModel):
152
 
153
  rag = None # 定义为全局对象
154
 
 
155
  @asynccontextmanager
156
  async def lifespan(app: FastAPI):
157
  global rag
@@ -160,18 +160,21 @@ async def lifespan(app: FastAPI):
160
  yield
161
 
162
 
163
- app = FastAPI(title="LightRAG API", description="API for RAG operations",lifespan=lifespan)
 
 
 
164
 
165
  @app.post("/query", response_model=Response)
166
  async def query_endpoint(request: QueryRequest):
167
  try:
168
  # loop = asyncio.get_event_loop()
169
  result = await rag.aquery(
170
- request.query,
171
- param=QueryParam(
172
- mode=request.mode, only_need_context=request.only_need_context
173
- ),
174
- )
175
  return Response(status="success", data=result)
176
  except Exception as e:
177
  raise HTTPException(status_code=500, detail=str(e))
@@ -234,4 +237,4 @@ if __name__ == "__main__":
234
  # curl -X POST "http://127.0.0.1:8020/insert_file" -H "Content-Type: application/json" -d '{"file_path": "path/to/your/file.txt"}'
235
 
236
  # 4. Health check:
237
- # curl -X GET "http://127.0.0.1:8020/health"
 
 
1
  from fastapi import FastAPI, HTTPException, File, UploadFile
2
  from contextlib import asynccontextmanager
3
  from pydantic import BaseModel
4
  from typing import Optional
5
 
6
+ import sys
7
+ import os
8
  from pathlib import Path
9
 
10
  import asyncio
 
13
  from lightrag.llm import openai_complete_if_cache, openai_embedding
14
  from lightrag.utils import EmbeddingFunc
15
  import numpy as np
 
16
 
17
  from lightrag.kg.oracle_impl import OracleDB
18
 
 
23
  sys.path.append(os.path.abspath(script_directory))
24
 
25
 
 
 
26
  # Apply nest_asyncio to solve event loop issues
27
  nest_asyncio.apply()
28
 
 
48
  if not os.path.exists(WORKING_DIR):
49
  os.mkdir(WORKING_DIR)
50
 
51
+
52
  async def llm_model_func(
53
  prompt, system_prompt=None, history_messages=[], **kwargs
54
  ) -> str:
 
78
  embedding_dim = embedding.shape[1]
79
  return embedding_dim
80
 
81
+
82
  async def init():
 
83
  # Detect embedding dimension
84
  embedding_dimension = await get_embedding_dim()
85
  print(f"Detected embedding dimension: {embedding_dimension}")
 
89
  # We storage data in unified tables, so we need to set a `workspace` parameter to specify which docs we want to store and query
90
  # Below is an example of how to connect to Oracle Autonomous Database on Oracle Cloud
91
 
92
+ oracle_db = OracleDB(
93
+ config={
94
+ "user": "",
95
+ "password": "",
96
+ "dsn": "",
97
+ "config_dir": "",
98
+ "wallet_location": "",
99
+ "wallet_password": "",
100
+ "workspace": "",
101
+ } # specify which docs you want to store and query
102
+ )
103
 
 
 
 
 
 
 
 
 
 
 
 
104
  # Check if Oracle DB tables exist, if not, tables will be created
105
  await oracle_db.check_tables()
106
  # Initialize LightRAG
107
+ # We use Oracle DB as the KV/vector/graph storage
108
  rag = LightRAG(
109
+ enable_llm_cache=False,
110
+ working_dir=WORKING_DIR,
111
+ chunk_token_size=512,
112
+ llm_model_func=llm_model_func,
113
+ embedding_func=EmbeddingFunc(
114
+ embedding_dim=embedding_dimension,
115
+ max_token_size=512,
116
+ func=embedding_func,
117
+ ),
118
+ graph_storage="OracleGraphStorage",
119
+ kv_storage="OracleKVStorage",
120
+ vector_storage="OracleVectorDBStorage",
121
+ )
122
 
123
  # Setthe KV/vector/graph storage's `db` property, so all operation will use same connection pool
124
  rag.graph_storage_cls.db = oracle_db
 
127
 
128
  return rag
129
 
130
+
131
  # Data models
132
 
133
 
 
151
 
152
  rag = None # 定义为全局对象
153
 
154
+
155
  @asynccontextmanager
156
  async def lifespan(app: FastAPI):
157
  global rag
 
160
  yield
161
 
162
 
163
+ app = FastAPI(
164
+ title="LightRAG API", description="API for RAG operations", lifespan=lifespan
165
+ )
166
+
167
 
168
  @app.post("/query", response_model=Response)
169
  async def query_endpoint(request: QueryRequest):
170
  try:
171
  # loop = asyncio.get_event_loop()
172
  result = await rag.aquery(
173
+ request.query,
174
+ param=QueryParam(
175
+ mode=request.mode, only_need_context=request.only_need_context
176
+ ),
177
+ )
178
  return Response(status="success", data=result)
179
  except Exception as e:
180
  raise HTTPException(status_code=500, detail=str(e))
 
237
  # curl -X POST "http://127.0.0.1:8020/insert_file" -H "Content-Type: application/json" -d '{"file_path": "path/to/your/file.txt"}'
238
 
239
  # 4. Health check:
240
+ # curl -X GET "http://127.0.0.1:8020/health"
examples/lightrag_oracle_demo.py CHANGED
@@ -1,11 +1,11 @@
1
- import sys, os
 
2
  from pathlib import Path
3
  import asyncio
4
  from lightrag import LightRAG, QueryParam
5
  from lightrag.llm import openai_complete_if_cache, openai_embedding
6
  from lightrag.utils import EmbeddingFunc
7
  import numpy as np
8
- from datetime import datetime
9
  from lightrag.kg.oracle_impl import OracleDB
10
 
11
  print(os.getcwd())
@@ -25,6 +25,7 @@ EMBEDMODEL = "cohere.embed-multilingual-v3.0"
25
  if not os.path.exists(WORKING_DIR):
26
  os.mkdir(WORKING_DIR)
27
 
 
28
  async def llm_model_func(
29
  prompt, system_prompt=None, history_messages=[], **kwargs
30
  ) -> str:
@@ -66,22 +67,21 @@ async def main():
66
  # More docs here https://python-oracledb.readthedocs.io/en/latest/user_guide/connection_handling.html
67
  # We storage data in unified tables, so we need to set a `workspace` parameter to specify which docs we want to store and query
68
  # Below is an example of how to connect to Oracle Autonomous Database on Oracle Cloud
69
- oracle_db = OracleDB(config={
70
- "user":"username",
71
- "password":"xxxxxxxxx",
72
- "dsn":"xxxxxxx_medium",
73
- "config_dir":"dir/path/to/oracle/config",
74
- "wallet_location":"dir/path/to/oracle/wallet",
75
- "wallet_password":"xxxxxxxxx",
76
- "workspace":"company" # specify which docs you want to store and query
 
77
  }
78
- )
79
-
80
 
81
  # Check if Oracle DB tables exist, if not, tables will be created
82
  await oracle_db.check_tables()
83
 
84
-
85
  # Initialize LightRAG
86
  # We use Oracle DB as the KV/vector/graph storage
87
  rag = LightRAG(
@@ -93,10 +93,10 @@ async def main():
93
  embedding_dim=embedding_dimension,
94
  max_token_size=512,
95
  func=embedding_func,
96
- ),
97
- graph_storage = "OracleGraphStorage",
98
- kv_storage="OracleKVStorage",
99
- vector_storage="OracleVectorDBStorage"
100
  )
101
 
102
  # Setthe KV/vector/graph storage's `db` property, so all operation will use same connection pool
@@ -106,18 +106,23 @@ async def main():
106
 
107
  # Extract and Insert into LightRAG storage
108
  with open("./dickens/demo.txt", "r", encoding="utf-8") as f:
109
- await rag.ainsert(f.read())
110
 
111
  # Perform search in different modes
112
  modes = ["naive", "local", "global", "hybrid"]
113
  for mode in modes:
114
- print("="*20, mode, "="*20)
115
- print(await rag.aquery("What are the top themes in this story?", param=QueryParam(mode=mode)))
116
- print("-"*100, "\n")
 
 
 
 
 
117
 
118
  except Exception as e:
119
  print(f"An error occurred: {e}")
120
 
121
 
122
  if __name__ == "__main__":
123
- asyncio.run(main())
 
1
+ import sys
2
+ import os
3
  from pathlib import Path
4
  import asyncio
5
  from lightrag import LightRAG, QueryParam
6
  from lightrag.llm import openai_complete_if_cache, openai_embedding
7
  from lightrag.utils import EmbeddingFunc
8
  import numpy as np
 
9
  from lightrag.kg.oracle_impl import OracleDB
10
 
11
  print(os.getcwd())
 
25
  if not os.path.exists(WORKING_DIR):
26
  os.mkdir(WORKING_DIR)
27
 
28
+
29
  async def llm_model_func(
30
  prompt, system_prompt=None, history_messages=[], **kwargs
31
  ) -> str:
 
67
  # More docs here https://python-oracledb.readthedocs.io/en/latest/user_guide/connection_handling.html
68
  # We storage data in unified tables, so we need to set a `workspace` parameter to specify which docs we want to store and query
69
  # Below is an example of how to connect to Oracle Autonomous Database on Oracle Cloud
70
+ oracle_db = OracleDB(
71
+ config={
72
+ "user": "username",
73
+ "password": "xxxxxxxxx",
74
+ "dsn": "xxxxxxx_medium",
75
+ "config_dir": "dir/path/to/oracle/config",
76
+ "wallet_location": "dir/path/to/oracle/wallet",
77
+ "wallet_password": "xxxxxxxxx",
78
+ "workspace": "company", # specify which docs you want to store and query
79
  }
80
+ )
 
81
 
82
  # Check if Oracle DB tables exist, if not, tables will be created
83
  await oracle_db.check_tables()
84
 
 
85
  # Initialize LightRAG
86
  # We use Oracle DB as the KV/vector/graph storage
87
  rag = LightRAG(
 
93
  embedding_dim=embedding_dimension,
94
  max_token_size=512,
95
  func=embedding_func,
96
+ ),
97
+ graph_storage="OracleGraphStorage",
98
+ kv_storage="OracleKVStorage",
99
+ vector_storage="OracleVectorDBStorage",
100
  )
101
 
102
  # Setthe KV/vector/graph storage's `db` property, so all operation will use same connection pool
 
106
 
107
  # Extract and Insert into LightRAG storage
108
  with open("./dickens/demo.txt", "r", encoding="utf-8") as f:
109
+ await rag.ainsert(f.read())
110
 
111
  # Perform search in different modes
112
  modes = ["naive", "local", "global", "hybrid"]
113
  for mode in modes:
114
+ print("=" * 20, mode, "=" * 20)
115
+ print(
116
+ await rag.aquery(
117
+ "What are the top themes in this story?",
118
+ param=QueryParam(mode=mode),
119
+ )
120
+ )
121
+ print("-" * 100, "\n")
122
 
123
  except Exception as e:
124
  print(f"An error occurred: {e}")
125
 
126
 
127
  if __name__ == "__main__":
128
+ asyncio.run(main())
lightrag/base.py CHANGED
@@ -60,6 +60,7 @@ class BaseVectorStorage(StorageNameSpace):
60
  @dataclass
61
  class BaseKVStorage(Generic[T], StorageNameSpace):
62
  embedding_func: EmbeddingFunc
 
63
  async def all_keys(self) -> list[str]:
64
  raise NotImplementedError
65
 
@@ -85,6 +86,7 @@ class BaseKVStorage(Generic[T], StorageNameSpace):
85
  @dataclass
86
  class BaseGraphStorage(StorageNameSpace):
87
  embedding_func: EmbeddingFunc = None
 
88
  async def has_node(self, node_id: str) -> bool:
89
  raise NotImplementedError
90
 
 
60
  @dataclass
61
  class BaseKVStorage(Generic[T], StorageNameSpace):
62
  embedding_func: EmbeddingFunc
63
+
64
  async def all_keys(self) -> list[str]:
65
  raise NotImplementedError
66
 
 
86
  @dataclass
87
  class BaseGraphStorage(StorageNameSpace):
88
  embedding_func: EmbeddingFunc = None
89
+
90
  async def has_node(self, node_id: str) -> bool:
91
  raise NotImplementedError
92
 
lightrag/kg/oracle_impl.py CHANGED
@@ -1,9 +1,9 @@
1
  import asyncio
2
- #import html
3
- #import os
 
4
  from dataclasses import dataclass
5
- from typing import Any, Union, cast
6
- import networkx as nx
7
  import numpy as np
8
  import array
9
 
@@ -16,8 +16,9 @@ from ..base import (
16
 
17
  import oracledb
18
 
 
19
  class OracleDB:
20
- def __init__(self,config,**kwargs):
21
  self.host = config.get("host", None)
22
  self.port = config.get("port", None)
23
  self.user = config.get("user", None)
@@ -32,21 +33,21 @@ class OracleDB:
32
  logger.info(f"Using the label {self.workspace} for Oracle Graph as identifier")
33
  if self.user is None or self.password is None:
34
  raise ValueError("Missing database user or password in addon_params")
35
-
36
  try:
37
  oracledb.defaults.fetch_lobs = False
38
 
39
  self.pool = oracledb.create_pool_async(
40
- user = self.user,
41
- password = self.password,
42
- dsn = self.dsn,
43
- config_dir = self.config_dir,
44
- wallet_location = self.wallet_location,
45
- wallet_password = self.wallet_password,
46
- min = 1,
47
- max = self.max,
48
- increment = self.increment
49
- )
50
  logger.info(f"Connected to Oracle database at {self.dsn}")
51
  except Exception as e:
52
  logger.error(f"Failed to connect to Oracle database at {self.dsn}")
@@ -90,12 +91,14 @@ class OracleDB:
90
  arraysize=cursor.arraysize,
91
  outconverter=self.numpy_converter_out,
92
  )
93
-
94
  async def check_tables(self):
95
- for k,v in TABLES.items():
96
  try:
97
  if k.lower() == "lightrag_graph":
98
- await self.query("SELECT id FROM GRAPH_TABLE (lightrag_graph MATCH (a) COLUMNS (a.id)) fetch first row only")
 
 
99
  else:
100
  await self.query("SELECT 1 FROM {k}".format(k=k))
101
  except Exception as e:
@@ -108,12 +111,11 @@ class OracleDB:
108
  except Exception as e:
109
  logger.error(f"Failed to create table {k} in Oracle database")
110
  logger.error(f"Oracle database error: {e}")
111
-
112
- logger.info(f"Finished check all tables in Oracle database")
113
-
114
-
115
- async def query(self,sql: str, multirows: bool = False) -> Union[dict, None]:
116
- async with self.pool.acquire() as connection:
117
  connection.inputtypehandler = self.input_type_handler
118
  connection.outputtypehandler = self.output_type_handler
119
  with connection.cursor() as cursor:
@@ -136,9 +138,9 @@ class OracleDB:
136
  data = dict(zip(columns, row))
137
  else:
138
  data = None
139
- return data
140
 
141
- async def execute(self,sql: str, data: list = None):
142
  # logger.info("go into OracleDB execute method")
143
  try:
144
  async with self.pool.acquire() as connection:
@@ -148,58 +150,63 @@ class OracleDB:
148
  if data is None:
149
  await cursor.execute(sql)
150
  else:
151
- #print(data)
152
- #print(sql)
153
- await cursor.execute(sql,data)
154
  await connection.commit()
155
  except Exception as e:
156
- logger.error(f"Oracle database error: {e}")
157
  print(sql)
158
  print(data)
159
  raise
160
 
 
161
  @dataclass
162
  class OracleKVStorage(BaseKVStorage):
163
-
164
  # should pass db object to self.db
165
  def __post_init__(self):
166
  self._data = {}
167
- self._max_batch_size = self.global_config["embedding_batch_num"]
168
-
169
  ################ QUERY METHODS ################
170
 
171
  async def get_by_id(self, id: str) -> Union[dict, None]:
172
  """根据 id 获取 doc_full 数据."""
173
- SQL = SQL_TEMPLATES["get_by_id_"+self.namespace].format(workspace=self.db.workspace,id=id)
174
- #print("get_by_id:"+SQL)
175
- res = await self.db.query(SQL)
 
 
176
  if res:
177
- data = res #{"data":res}
178
- #print (data)
179
  return data
180
  else:
181
  return None
182
 
183
  # Query by id
184
- async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict],None]:
185
  """根据 id 获取 doc_chunks 数据"""
186
- SQL = SQL_TEMPLATES["get_by_ids_"+self.namespace].format(workspace=self.db.workspace,
187
- ids=",".join([f"'{id}'" for id in ids]))
188
- #print("get_by_ids:"+SQL)
189
- res = await self.db.query(SQL,multirows=True)
 
190
  if res:
191
- data = res # [{"data":i} for i in res]
192
- #print(data)
193
  return data
194
  else:
195
  return None
196
-
197
  async def filter_keys(self, keys: list[str]) -> set[str]:
198
  """过滤掉重复内容"""
199
- SQL = SQL_TEMPLATES["filter_keys"].format(table_name=N_T[self.namespace],
200
- workspace=self.db.workspace,
201
- ids=",".join([f"'{k}'" for k in keys]))
202
- res = await self.db.query(SQL,multirows=True)
 
 
203
  data = None
204
  if res:
205
  exist_keys = [key["id"] for key in res]
@@ -208,14 +215,13 @@ class OracleKVStorage(BaseKVStorage):
208
  exist_keys = []
209
  data = set([s for s in keys if s not in exist_keys])
210
  return data
211
-
212
-
213
  ################ INSERT METHODS ################
214
  async def upsert(self, data: dict[str, dict]):
215
  left_data = {k: v for k, v in data.items() if k not in self._data}
216
  self._data.update(left_data)
217
- #print(self._data)
218
- #values = []
219
  if self.namespace == "text_chunks":
220
  list_data = [
221
  {
@@ -226,7 +232,7 @@ class OracleKVStorage(BaseKVStorage):
226
  ]
227
  contents = [v["content"] for v in data.values()]
228
  batches = [
229
- contents[i: i + self._max_batch_size]
230
  for i in range(0, len(contents), self._max_batch_size)
231
  ]
232
  embeddings_list = await asyncio.gather(
@@ -235,42 +241,45 @@ class OracleKVStorage(BaseKVStorage):
235
  embeddings = np.concatenate(embeddings_list)
236
  for i, d in enumerate(list_data):
237
  d["__vector__"] = embeddings[i]
238
- #print(list_data)
239
  for item in list_data:
240
- merge_sql = SQL_TEMPLATES["merge_chunk"].format(
241
- check_id=item["__id__"]
242
- )
243
-
244
- values = [item["__id__"], item["content"], self.db.workspace, item["tokens"],
245
- item["chunk_order_index"], item["full_doc_id"], item["__vector__"]]
246
- #print(merge_sql)
 
 
 
 
 
247
  await self.db.execute(merge_sql, values)
248
 
249
  if self.namespace == "full_docs":
250
  for k, v in self._data.items():
251
- #values.clear()
252
  merge_sql = SQL_TEMPLATES["merge_doc_full"].format(
253
  check_id=k,
254
  )
255
  values = [k, self._data[k]["content"], self.db.workspace]
256
- #print(merge_sql)
257
  await self.db.execute(merge_sql, values)
258
  return left_data
259
 
260
-
261
  async def index_done_callback(self):
262
  if self.namespace in ["full_docs", "text_chunks"]:
263
  logger.info("full doc and chunk data had been saved into oracle db!")
264
 
265
 
266
-
267
  @dataclass
268
  class OracleVectorDBStorage(BaseVectorStorage):
269
  cosine_better_than_threshold: float = 0.2
270
 
271
  def __post_init__(self):
272
  pass
273
-
274
  async def upsert(self, data: dict[str, dict]):
275
  """向向量数据库中插入数据"""
276
  pass
@@ -278,53 +287,51 @@ class OracleVectorDBStorage(BaseVectorStorage):
278
  async def index_done_callback(self):
279
  pass
280
 
281
-
282
  #################### query method ###############
283
  async def query(self, query: str, top_k=5) -> Union[dict, list[dict]]:
284
- """从向量数据库中查询数据"""
285
  embeddings = await self.embedding_func([query])
286
  embedding = embeddings[0]
287
  # 转换精度
288
  dtype = str(embedding.dtype).upper()
289
  dimension = embedding.shape[0]
290
- embedding_string = ', '.join(map(str, embedding.tolist()))
291
 
292
  SQL = SQL_TEMPLATES[self.namespace].format(
293
- embedding_string=embedding_string,
294
- dimension=dimension,
295
- dtype=dtype,
296
- workspace=self.db.workspace,
297
- top_k=top_k,
298
- better_than_threshold=self.cosine_better_than_threshold,
299
- )
300
  # print(SQL)
301
  results = await self.db.query(SQL, multirows=True)
302
- #print("vector search result:",results)
303
  return results
304
 
305
 
306
  @dataclass
307
- class OracleGraphStorage(BaseGraphStorage):
308
  """基于Oracle的图存储模块"""
309
-
310
  def __post_init__(self):
311
  """从graphml文件加载图"""
312
  self._max_batch_size = self.global_config["embedding_batch_num"]
313
 
314
-
315
  #################### insert method ################
316
-
317
  async def upsert_node(self, node_id: str, node_data: dict[str, str]):
318
  """插入或更新节点"""
319
- #print("go into upsert node method")
320
  entity_name = node_id
321
  entity_type = node_data["entity_type"]
322
  description = node_data["description"]
323
- source_id = node_data["source_id"]
324
- content = entity_name+description
325
  contents = [content]
326
  batches = [
327
- contents[i: i + self._max_batch_size]
328
  for i in range(0, len(contents), self._max_batch_size)
329
  ]
330
  embeddings_list = await asyncio.gather(
@@ -333,27 +340,38 @@ class OracleGraphStorage(BaseGraphStorage):
333
  embeddings = np.concatenate(embeddings_list)
334
  content_vector = embeddings[0]
335
  merge_sql = SQL_TEMPLATES["merge_node"].format(
336
- workspace=self.db.workspace,name=entity_name, source_chunk_id=source_id
 
 
 
 
 
 
 
 
 
 
 
 
 
337
  )
338
- #print(merge_sql)
339
- await self.db.execute(merge_sql, [self.db.workspace,entity_name,entity_type,description,source_id,content,content_vector])
340
- #self._graph.add_node(node_id, **node_data)
341
 
342
  async def upsert_edge(
343
  self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
344
  ):
345
  """插入或更新边"""
346
- #print("go into upsert edge method")
347
  source_name = source_node_id
348
  target_name = target_node_id
349
  weight = edge_data["weight"]
350
  keywords = edge_data["keywords"]
351
  description = edge_data["description"]
352
  source_chunk_id = edge_data["source_id"]
353
- content = keywords+source_name+target_name+description
354
  contents = [content]
355
  batches = [
356
- contents[i: i + self._max_batch_size]
357
  for i in range(0, len(contents), self._max_batch_size)
358
  ]
359
  embeddings_list = await asyncio.gather(
@@ -362,11 +380,27 @@ class OracleGraphStorage(BaseGraphStorage):
362
  embeddings = np.concatenate(embeddings_list)
363
  content_vector = embeddings[0]
364
  merge_sql = SQL_TEMPLATES["merge_edge"].format(
365
- workspace=self.db.workspace,source_name=source_name, target_name=target_name, source_chunk_id=source_chunk_id
 
 
 
366
  )
367
- #print(merge_sql)
368
- await self.db.execute(merge_sql, [self.db.workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector])
369
- #self._graph.add_edge(source_node_id, target_node_id, **edge_data)
 
 
 
 
 
 
 
 
 
 
 
 
 
370
 
371
  async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
372
  """为节点生成向量"""
@@ -386,99 +420,109 @@ class OracleGraphStorage(BaseGraphStorage):
386
  nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes]
387
  return embeddings, nodes_ids
388
 
389
-
390
  async def index_done_callback(self):
391
  """写入graphhml图文件"""
392
- logger.info("Node and edge data had been saved into oracle db already, so nothing to do here!")
393
-
 
 
394
  #################### query method #################
395
  async def has_node(self, node_id: str) -> bool:
396
- """根据节点id检查节点是否存在"""
397
- SQL = SQL_TEMPLATES["has_node"].format(workspace=self.db.workspace, node_id=node_id)
398
- # print(SQL)
399
- #print(self.db.workspace, node_id)
 
 
400
  res = await self.db.query(SQL)
401
  if res:
402
- #print("Node exist!",res)
403
  return True
404
  else:
405
- #print("Node not exist!")
406
  return False
407
 
408
  async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
409
  """根据源和目标节点id检查边是否存在"""
410
- SQL = SQL_TEMPLATES["has_edge"].format(workspace=self.db.workspace,
411
- source_node_id=source_node_id,
412
- target_node_id=target_node_id)
 
 
413
  # print(SQL)
414
  res = await self.db.query(SQL)
415
  if res:
416
- #print("Edge exist!",res)
417
  return True
418
  else:
419
- #print("Edge not exist!")
420
  return False
421
 
422
  async def node_degree(self, node_id: str) -> int:
423
- """根据节点id获取节点的度"""
424
- SQL = SQL_TEMPLATES["node_degree"].format(workspace=self.db.workspace, node_id=node_id)
 
 
425
  # print(SQL)
426
  res = await self.db.query(SQL)
427
  if res:
428
- #print("Node degree",res["degree"])
429
  return res["degree"]
430
  else:
431
- #print("Edge not exist!")
432
  return 0
433
 
434
-
435
  async def edge_degree(self, src_id: str, tgt_id: str) -> int:
436
  """根据源和目标节点id获取边的度"""
437
  degree = await self.node_degree(src_id) + await self.node_degree(tgt_id)
438
- #print("Edge degree",degree)
439
  return degree
440
 
441
-
442
  async def get_node(self, node_id: str) -> Union[dict, None]:
443
  """根据节点id获取节点数据"""
444
- SQL = SQL_TEMPLATES["get_node"].format(workspace=self.db.workspace, node_id=node_id)
 
 
445
  # print(self.db.workspace, node_id)
446
  # print(SQL)
447
  res = await self.db.query(SQL)
448
  if res:
449
- #print("Get node!",self.db.workspace, node_id,res)
450
  return res
451
  else:
452
- #print("Can't get node!",self.db.workspace, node_id)
453
  return None
454
-
455
  async def get_edge(
456
  self, source_node_id: str, target_node_id: str
457
  ) -> Union[dict, None]:
458
  """根据源和目标节点id获取边"""
459
- SQL = SQL_TEMPLATES["get_edge"].format(workspace=self.db.workspace,
460
- source_node_id=source_node_id,
461
- target_node_id=target_node_id)
 
 
462
  res = await self.db.query(SQL)
463
  if res:
464
- #print("Get edge!",self.db.workspace, source_node_id, target_node_id,res[0])
465
  return res
466
  else:
467
- #print("Edge not exist!",self.db.workspace, source_node_id, target_node_id)
468
  return None
469
 
470
  async def get_node_edges(self, source_node_id: str):
471
  """根据节点id获取节点的所有边"""
472
  if await self.has_node(source_node_id):
473
- SQL = SQL_TEMPLATES["get_node_edges"].format(workspace=self.db.workspace,
474
- source_node_id=source_node_id)
 
475
  res = await self.db.query(sql=SQL, multirows=True)
476
  if res:
477
- data = [(i["source_name"],i["target_name"]) for i in res]
478
- #print("Get node edge!",self.db.workspace, source_node_id,data)
479
  return data
480
  else:
481
- #print("Node Edge not exist!",self.db.workspace, source_node_id)
482
  return []
483
 
484
 
@@ -487,12 +531,12 @@ N_T = {
487
  "text_chunks": "LIGHTRAG_DOC_CHUNKS",
488
  "chunks": "LIGHTRAG_DOC_CHUNKS",
489
  "entities": "LIGHTRAG_GRAPH_NODES",
490
- "relationships": "LIGHTRAG_GRAPH_EDGES"
491
  }
492
 
493
  TABLES = {
494
- "LIGHTRAG_DOC_FULL":
495
- {"ddl":"""CREATE TABLE LIGHTRAG_DOC_FULL (
496
  id varchar(256)PRIMARY KEY,
497
  workspace varchar(1024),
498
  doc_name varchar(1024),
@@ -500,61 +544,63 @@ TABLES = {
500
  meta JSON,
501
  createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
502
  updatetime TIMESTAMP DEFAULT NULL
503
- )"""},
504
-
505
- "LIGHTRAG_DOC_CHUNKS":
506
- {"ddl":"""CREATE TABLE LIGHTRAG_DOC_CHUNKS (
507
  id varchar(256) PRIMARY KEY,
508
  workspace varchar(1024),
509
  full_doc_id varchar(256),
510
  chunk_order_index NUMBER,
511
- tokens NUMBER,
512
  content CLOB,
513
  content_vector VECTOR,
514
  createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
515
- updatetime TIMESTAMP DEFAULT NULL
516
- )"""},
517
-
518
- "LIGHTRAG_GRAPH_NODES":
519
- {"ddl":"""CREATE TABLE LIGHTRAG_GRAPH_NODES (
520
  id NUMBER GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY,
521
  workspace varchar(1024),
522
  name varchar(2048),
523
- entity_type varchar(1024),
524
  description CLOB,
525
  source_chunk_id varchar(256),
526
  content CLOB,
527
  content_vector VECTOR,
528
  createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
529
  updatetime TIMESTAMP DEFAULT NULL
530
- )"""},
531
- "LIGHTRAG_GRAPH_EDGES":
532
- {"ddl":"""CREATE TABLE LIGHTRAG_GRAPH_EDGES (
 
533
  id NUMBER GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY,
534
  workspace varchar(1024),
535
  source_name varchar(2048),
536
- target_name varchar(2048),
537
  weight NUMBER,
538
- keywords CLOB,
539
  description CLOB,
540
  source_chunk_id varchar(256),
541
  content CLOB,
542
  content_vector VECTOR,
543
  createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
544
  updatetime TIMESTAMP DEFAULT NULL
545
- )"""},
546
- "LIGHTRAG_LLM_CACHE":
547
- {"ddl":"""CREATE TABLE LIGHTRAG_LLM_CACHE (
 
548
  id varchar(256) PRIMARY KEY,
549
  send clob,
550
  return clob,
551
  model varchar(1024),
552
  createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
553
  updatetime TIMESTAMP DEFAULT NULL
554
- )"""},
555
-
556
- "LIGHTRAG_GRAPH":
557
- {"ddl":"""CREATE OR REPLACE PROPERTY GRAPH lightrag_graph
558
  VERTEX TABLES (
559
  lightrag_graph_nodes KEY (id)
560
  LABEL entity
@@ -565,93 +611,67 @@ TABLES = {
565
  SOURCE KEY (source_name) REFERENCES lightrag_graph_nodes(name)
566
  DESTINATION KEY (target_name) REFERENCES lightrag_graph_nodes(name)
567
  LABEL has_relation
568
- PROPERTIES (id,workspace,source_name,target_name) -- ,weight, keywords,description,source_chunk_id)
569
- ) OPTIONS(ALLOW MIXED PROPERTY TYPES)"""},
570
- }
 
571
 
572
 
573
  SQL_TEMPLATES = {
574
  # SQL for KVStorage
575
- "get_by_id_full_docs":
576
- "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace='{workspace}' and ID='{id}'",
577
-
578
- "get_by_id_text_chunks":
579
- "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace='{workspace}' and ID='{id}'",
580
-
581
- "get_by_ids_full_docs":
582
- "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace='{workspace}' and ID in ({ids})",
583
-
584
- "get_by_ids_text_chunks":
585
- "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace='{workspace}' and ID in ({ids})",
586
-
587
- "filter_keys":
588
- "select id from {table_name} where workspace='{workspace}' and id in ({ids})",
589
-
590
  "merge_doc_full": """ MERGE INTO LIGHTRAG_DOC_FULL a
591
  USING DUAL
592
  ON (a.id = '{check_id}')
593
  WHEN NOT MATCHED THEN
594
  INSERT(id,content,workspace) values(:1,:2,:3)
595
  """,
596
-
597
  "merge_chunk": """MERGE INTO LIGHTRAG_DOC_CHUNKS a
598
  USING DUAL
599
  ON (a.id = '{check_id}')
600
  WHEN NOT MATCHED THEN
601
  INSERT(id,content,workspace,tokens,chunk_order_index,full_doc_id,content_vector)
602
  values (:1,:2,:3,:4,:5,:6,:7) """,
603
-
604
  # SQL for VectorStorage
605
- "entities":
606
- """SELECT name as entity_name FROM
607
- (SELECT id,name,VECTOR_DISTANCE(content_vector,vector('[{embedding_string}]',{dimension},{dtype}),COSINE) as distance
608
- FROM LIGHTRAG_GRAPH_NODES WHERE workspace='{workspace}')
609
  WHERE distance>{better_than_threshold} ORDER BY distance ASC FETCH FIRST {top_k} ROWS ONLY""",
610
-
611
- "relationships":
612
- """SELECT source_name as src_id, target_name as tgt_id FROM
613
- (SELECT id,source_name,target_name,VECTOR_DISTANCE(content_vector,vector('[{embedding_string}]',{dimension},{dtype}),COSINE) as distance
614
- FROM LIGHTRAG_GRAPH_EDGES WHERE workspace='{workspace}')
615
  WHERE distance>{better_than_threshold} ORDER BY distance ASC FETCH FIRST {top_k} ROWS ONLY""",
616
-
617
- "chunks":
618
- """SELECT id FROM
619
- (SELECT id,VECTOR_DISTANCE(content_vector,vector('[{embedding_string}]',{dimension},{dtype}),COSINE) as distance
620
- FROM LIGHTRAG_DOC_CHUNKS WHERE workspace='{workspace}')
621
  WHERE distance>{better_than_threshold} ORDER BY distance ASC FETCH FIRST {top_k} ROWS ONLY""",
622
-
623
  # SQL for GraphStorage
624
- "has_node":
625
- """SELECT * FROM GRAPH_TABLE (lightrag_graph
626
  MATCH (a)
627
  WHERE a.workspace='{workspace}' AND a.name='{node_id}'
628
  COLUMNS (a.name))""",
629
-
630
- "has_edge":
631
- """SELECT * FROM GRAPH_TABLE (lightrag_graph
632
  MATCH (a) -[e]-> (b)
633
  WHERE e.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}'
634
  AND a.name='{source_node_id}' AND b.name='{target_node_id}'
635
  COLUMNS (e.source_name,e.target_name) )""",
636
-
637
- "node_degree":
638
- """SELECT count(1) as degree FROM GRAPH_TABLE (lightrag_graph
639
  MATCH (a)-[e]->(b)
640
  WHERE a.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}'
641
  AND a.name='{node_id}' or b.name = '{node_id}'
642
  COLUMNS (a.name))""",
643
-
644
- "get_node":
645
- """SELECT t1.name,t2.entity_type,t2.source_chunk_id as source_id,NVL(t2.description,'') AS description
646
  FROM GRAPH_TABLE (lightrag_graph
647
- MATCH (a)
648
  WHERE a.workspace='{workspace}' AND a.name='{node_id}'
649
  COLUMNS (a.name)
650
  ) t1 JOIN LIGHTRAG_GRAPH_NODES t2 on t1.name=t2.name
651
  WHERE t2.workspace='{workspace}'""",
652
-
653
- "get_edge":
654
- """SELECT t1.source_id,t2.weight,t2.source_chunk_id as source_id,t2.keywords,
655
  NVL(t2.description,'') AS description,NVL(t2.KEYWORDS,'') AS keywords
656
  FROM GRAPH_TABLE (lightrag_graph
657
  MATCH (a)-[e]->(b)
@@ -659,15 +679,12 @@ SQL_TEMPLATES = {
659
  AND a.name='{source_node_id}' and b.name = '{target_node_id}'
660
  COLUMNS (e.id,a.name as source_id)
661
  ) t1 JOIN LIGHTRAG_GRAPH_EDGES t2 on t1.id=t2.id""",
662
-
663
- "get_node_edges":
664
- """SELECT source_name,target_name
665
  FROM GRAPH_TABLE (lightrag_graph
666
  MATCH (a)-[e]->(b)
667
  WHERE e.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}'
668
  AND a.name='{source_node_id}'
669
  COLUMNS (a.name as source_name,b.name as target_name))""",
670
-
671
  "merge_node": """MERGE INTO LIGHTRAG_GRAPH_NODES a
672
  USING DUAL
673
  ON (a.workspace = '{workspace}' and a.name='{name}' and a.source_chunk_id='{source_chunk_id}')
@@ -679,5 +696,5 @@ SQL_TEMPLATES = {
679
  ON (a.workspace = '{workspace}' and a.source_name='{source_name}' and a.target_name='{target_name}' and a.source_chunk_id='{source_chunk_id}')
680
  WHEN NOT MATCHED THEN
681
  INSERT(workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector)
682
- values (:1,:2,:3,:4,:5,:6,:7,:8,:9) """
683
- }
 
1
  import asyncio
2
+
3
+ # import html
4
+ # import os
5
  from dataclasses import dataclass
6
+ from typing import Union
 
7
  import numpy as np
8
  import array
9
 
 
16
 
17
  import oracledb
18
 
19
+
20
  class OracleDB:
21
+ def __init__(self, config, **kwargs):
22
  self.host = config.get("host", None)
23
  self.port = config.get("port", None)
24
  self.user = config.get("user", None)
 
33
  logger.info(f"Using the label {self.workspace} for Oracle Graph as identifier")
34
  if self.user is None or self.password is None:
35
  raise ValueError("Missing database user or password in addon_params")
36
+
37
  try:
38
  oracledb.defaults.fetch_lobs = False
39
 
40
  self.pool = oracledb.create_pool_async(
41
+ user=self.user,
42
+ password=self.password,
43
+ dsn=self.dsn,
44
+ config_dir=self.config_dir,
45
+ wallet_location=self.wallet_location,
46
+ wallet_password=self.wallet_password,
47
+ min=1,
48
+ max=self.max,
49
+ increment=self.increment,
50
+ )
51
  logger.info(f"Connected to Oracle database at {self.dsn}")
52
  except Exception as e:
53
  logger.error(f"Failed to connect to Oracle database at {self.dsn}")
 
91
  arraysize=cursor.arraysize,
92
  outconverter=self.numpy_converter_out,
93
  )
94
+
95
  async def check_tables(self):
96
+ for k, v in TABLES.items():
97
  try:
98
  if k.lower() == "lightrag_graph":
99
+ await self.query(
100
+ "SELECT id FROM GRAPH_TABLE (lightrag_graph MATCH (a) COLUMNS (a.id)) fetch first row only"
101
+ )
102
  else:
103
  await self.query("SELECT 1 FROM {k}".format(k=k))
104
  except Exception as e:
 
111
  except Exception as e:
112
  logger.error(f"Failed to create table {k} in Oracle database")
113
  logger.error(f"Oracle database error: {e}")
114
+
115
+ logger.info("Finished check all tables in Oracle database")
116
+
117
+ async def query(self, sql: str, multirows: bool = False) -> Union[dict, None]:
118
+ async with self.pool.acquire() as connection:
 
119
  connection.inputtypehandler = self.input_type_handler
120
  connection.outputtypehandler = self.output_type_handler
121
  with connection.cursor() as cursor:
 
138
  data = dict(zip(columns, row))
139
  else:
140
  data = None
141
+ return data
142
 
143
+ async def execute(self, sql: str, data: list = None):
144
  # logger.info("go into OracleDB execute method")
145
  try:
146
  async with self.pool.acquire() as connection:
 
150
  if data is None:
151
  await cursor.execute(sql)
152
  else:
153
+ # print(data)
154
+ # print(sql)
155
+ await cursor.execute(sql, data)
156
  await connection.commit()
157
  except Exception as e:
158
+ logger.error(f"Oracle database error: {e}")
159
  print(sql)
160
  print(data)
161
  raise
162
 
163
+
164
  @dataclass
165
  class OracleKVStorage(BaseKVStorage):
 
166
  # should pass db object to self.db
167
  def __post_init__(self):
168
  self._data = {}
169
+ self._max_batch_size = self.global_config["embedding_batch_num"]
170
+
171
  ################ QUERY METHODS ################
172
 
173
  async def get_by_id(self, id: str) -> Union[dict, None]:
174
  """根据 id 获取 doc_full 数据."""
175
+ SQL = SQL_TEMPLATES["get_by_id_" + self.namespace].format(
176
+ workspace=self.db.workspace, id=id
177
+ )
178
+ # print("get_by_id:"+SQL)
179
+ res = await self.db.query(SQL)
180
  if res:
181
+ data = res # {"data":res}
182
+ # print (data)
183
  return data
184
  else:
185
  return None
186
 
187
  # Query by id
188
+ async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]:
189
  """根据 id 获取 doc_chunks 数据"""
190
+ SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
191
+ workspace=self.db.workspace, ids=",".join([f"'{id}'" for id in ids])
192
+ )
193
+ # print("get_by_ids:"+SQL)
194
+ res = await self.db.query(SQL, multirows=True)
195
  if res:
196
+ data = res # [{"data":i} for i in res]
197
+ # print(data)
198
  return data
199
  else:
200
  return None
201
+
202
  async def filter_keys(self, keys: list[str]) -> set[str]:
203
  """过滤掉重复内容"""
204
+ SQL = SQL_TEMPLATES["filter_keys"].format(
205
+ table_name=N_T[self.namespace],
206
+ workspace=self.db.workspace,
207
+ ids=",".join([f"'{k}'" for k in keys]),
208
+ )
209
+ res = await self.db.query(SQL, multirows=True)
210
  data = None
211
  if res:
212
  exist_keys = [key["id"] for key in res]
 
215
  exist_keys = []
216
  data = set([s for s in keys if s not in exist_keys])
217
  return data
218
+
 
219
  ################ INSERT METHODS ################
220
  async def upsert(self, data: dict[str, dict]):
221
  left_data = {k: v for k, v in data.items() if k not in self._data}
222
  self._data.update(left_data)
223
+ # print(self._data)
224
+ # values = []
225
  if self.namespace == "text_chunks":
226
  list_data = [
227
  {
 
232
  ]
233
  contents = [v["content"] for v in data.values()]
234
  batches = [
235
+ contents[i : i + self._max_batch_size]
236
  for i in range(0, len(contents), self._max_batch_size)
237
  ]
238
  embeddings_list = await asyncio.gather(
 
241
  embeddings = np.concatenate(embeddings_list)
242
  for i, d in enumerate(list_data):
243
  d["__vector__"] = embeddings[i]
244
+ # print(list_data)
245
  for item in list_data:
246
+ merge_sql = SQL_TEMPLATES["merge_chunk"].format(check_id=item["__id__"])
247
+
248
+ values = [
249
+ item["__id__"],
250
+ item["content"],
251
+ self.db.workspace,
252
+ item["tokens"],
253
+ item["chunk_order_index"],
254
+ item["full_doc_id"],
255
+ item["__vector__"],
256
+ ]
257
+ # print(merge_sql)
258
  await self.db.execute(merge_sql, values)
259
 
260
  if self.namespace == "full_docs":
261
  for k, v in self._data.items():
262
+ # values.clear()
263
  merge_sql = SQL_TEMPLATES["merge_doc_full"].format(
264
  check_id=k,
265
  )
266
  values = [k, self._data[k]["content"], self.db.workspace]
267
+ # print(merge_sql)
268
  await self.db.execute(merge_sql, values)
269
  return left_data
270
 
 
271
  async def index_done_callback(self):
272
  if self.namespace in ["full_docs", "text_chunks"]:
273
  logger.info("full doc and chunk data had been saved into oracle db!")
274
 
275
 
 
276
  @dataclass
277
  class OracleVectorDBStorage(BaseVectorStorage):
278
  cosine_better_than_threshold: float = 0.2
279
 
280
  def __post_init__(self):
281
  pass
282
+
283
  async def upsert(self, data: dict[str, dict]):
284
  """向向量数据库中插入数据"""
285
  pass
 
287
  async def index_done_callback(self):
288
  pass
289
 
 
290
  #################### query method ###############
291
  async def query(self, query: str, top_k=5) -> Union[dict, list[dict]]:
292
+ """从向量数据库中查询数据"""
293
  embeddings = await self.embedding_func([query])
294
  embedding = embeddings[0]
295
  # 转换精度
296
  dtype = str(embedding.dtype).upper()
297
  dimension = embedding.shape[0]
298
+ embedding_string = ", ".join(map(str, embedding.tolist()))
299
 
300
  SQL = SQL_TEMPLATES[self.namespace].format(
301
+ embedding_string=embedding_string,
302
+ dimension=dimension,
303
+ dtype=dtype,
304
+ workspace=self.db.workspace,
305
+ top_k=top_k,
306
+ better_than_threshold=self.cosine_better_than_threshold,
307
+ )
308
  # print(SQL)
309
  results = await self.db.query(SQL, multirows=True)
310
+ # print("vector search result:",results)
311
  return results
312
 
313
 
314
  @dataclass
315
+ class OracleGraphStorage(BaseGraphStorage):
316
  """基于Oracle的图存储模块"""
317
+
318
  def __post_init__(self):
319
  """从graphml文件加载图"""
320
  self._max_batch_size = self.global_config["embedding_batch_num"]
321
 
 
322
  #################### insert method ################
323
+
324
  async def upsert_node(self, node_id: str, node_data: dict[str, str]):
325
  """插入或更新节点"""
326
+ # print("go into upsert node method")
327
  entity_name = node_id
328
  entity_type = node_data["entity_type"]
329
  description = node_data["description"]
330
+ source_id = node_data["source_id"]
331
+ content = entity_name + description
332
  contents = [content]
333
  batches = [
334
+ contents[i : i + self._max_batch_size]
335
  for i in range(0, len(contents), self._max_batch_size)
336
  ]
337
  embeddings_list = await asyncio.gather(
 
340
  embeddings = np.concatenate(embeddings_list)
341
  content_vector = embeddings[0]
342
  merge_sql = SQL_TEMPLATES["merge_node"].format(
343
+ workspace=self.db.workspace, name=entity_name, source_chunk_id=source_id
344
+ )
345
+ # print(merge_sql)
346
+ await self.db.execute(
347
+ merge_sql,
348
+ [
349
+ self.db.workspace,
350
+ entity_name,
351
+ entity_type,
352
+ description,
353
+ source_id,
354
+ content,
355
+ content_vector,
356
+ ],
357
  )
358
+ # self._graph.add_node(node_id, **node_data)
 
 
359
 
360
  async def upsert_edge(
361
  self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
362
  ):
363
  """插入或更新边"""
364
+ # print("go into upsert edge method")
365
  source_name = source_node_id
366
  target_name = target_node_id
367
  weight = edge_data["weight"]
368
  keywords = edge_data["keywords"]
369
  description = edge_data["description"]
370
  source_chunk_id = edge_data["source_id"]
371
+ content = keywords + source_name + target_name + description
372
  contents = [content]
373
  batches = [
374
+ contents[i : i + self._max_batch_size]
375
  for i in range(0, len(contents), self._max_batch_size)
376
  ]
377
  embeddings_list = await asyncio.gather(
 
380
  embeddings = np.concatenate(embeddings_list)
381
  content_vector = embeddings[0]
382
  merge_sql = SQL_TEMPLATES["merge_edge"].format(
383
+ workspace=self.db.workspace,
384
+ source_name=source_name,
385
+ target_name=target_name,
386
+ source_chunk_id=source_chunk_id,
387
  )
388
+ # print(merge_sql)
389
+ await self.db.execute(
390
+ merge_sql,
391
+ [
392
+ self.db.workspace,
393
+ source_name,
394
+ target_name,
395
+ weight,
396
+ keywords,
397
+ description,
398
+ source_chunk_id,
399
+ content,
400
+ content_vector,
401
+ ],
402
+ )
403
+ # self._graph.add_edge(source_node_id, target_node_id, **edge_data)
404
 
405
  async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
406
  """为节点生成向量"""
 
420
  nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes]
421
  return embeddings, nodes_ids
422
 
 
423
  async def index_done_callback(self):
424
  """写入graphhml图文件"""
425
+ logger.info(
426
+ "Node and edge data had been saved into oracle db already, so nothing to do here!"
427
+ )
428
+
429
  #################### query method #################
430
  async def has_node(self, node_id: str) -> bool:
431
+ """根据节点id检查节点是否存在"""
432
+ SQL = SQL_TEMPLATES["has_node"].format(
433
+ workspace=self.db.workspace, node_id=node_id
434
+ )
435
+ # print(SQL)
436
+ # print(self.db.workspace, node_id)
437
  res = await self.db.query(SQL)
438
  if res:
439
+ # print("Node exist!",res)
440
  return True
441
  else:
442
+ # print("Node not exist!")
443
  return False
444
 
445
  async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
446
  """根据源和目标节点id检查边是否存在"""
447
+ SQL = SQL_TEMPLATES["has_edge"].format(
448
+ workspace=self.db.workspace,
449
+ source_node_id=source_node_id,
450
+ target_node_id=target_node_id,
451
+ )
452
  # print(SQL)
453
  res = await self.db.query(SQL)
454
  if res:
455
+ # print("Edge exist!",res)
456
  return True
457
  else:
458
+ # print("Edge not exist!")
459
  return False
460
 
461
  async def node_degree(self, node_id: str) -> int:
462
+ """根据节点id获取节点的度"""
463
+ SQL = SQL_TEMPLATES["node_degree"].format(
464
+ workspace=self.db.workspace, node_id=node_id
465
+ )
466
  # print(SQL)
467
  res = await self.db.query(SQL)
468
  if res:
469
+ # print("Node degree",res["degree"])
470
  return res["degree"]
471
  else:
472
+ # print("Edge not exist!")
473
  return 0
474
 
 
475
  async def edge_degree(self, src_id: str, tgt_id: str) -> int:
476
  """根据源和目标节点id获取边的度"""
477
  degree = await self.node_degree(src_id) + await self.node_degree(tgt_id)
478
+ # print("Edge degree",degree)
479
  return degree
480
 
 
481
  async def get_node(self, node_id: str) -> Union[dict, None]:
482
  """根据节点id获取节点数据"""
483
+ SQL = SQL_TEMPLATES["get_node"].format(
484
+ workspace=self.db.workspace, node_id=node_id
485
+ )
486
  # print(self.db.workspace, node_id)
487
  # print(SQL)
488
  res = await self.db.query(SQL)
489
  if res:
490
+ # print("Get node!",self.db.workspace, node_id,res)
491
  return res
492
  else:
493
+ # print("Can't get node!",self.db.workspace, node_id)
494
  return None
495
+
496
  async def get_edge(
497
  self, source_node_id: str, target_node_id: str
498
  ) -> Union[dict, None]:
499
  """根据源和目标节点id获取边"""
500
+ SQL = SQL_TEMPLATES["get_edge"].format(
501
+ workspace=self.db.workspace,
502
+ source_node_id=source_node_id,
503
+ target_node_id=target_node_id,
504
+ )
505
  res = await self.db.query(SQL)
506
  if res:
507
+ # print("Get edge!",self.db.workspace, source_node_id, target_node_id,res[0])
508
  return res
509
  else:
510
+ # print("Edge not exist!",self.db.workspace, source_node_id, target_node_id)
511
  return None
512
 
513
  async def get_node_edges(self, source_node_id: str):
514
  """根据节点id获取节点的所有边"""
515
  if await self.has_node(source_node_id):
516
+ SQL = SQL_TEMPLATES["get_node_edges"].format(
517
+ workspace=self.db.workspace, source_node_id=source_node_id
518
+ )
519
  res = await self.db.query(sql=SQL, multirows=True)
520
  if res:
521
+ data = [(i["source_name"], i["target_name"]) for i in res]
522
+ # print("Get node edge!",self.db.workspace, source_node_id,data)
523
  return data
524
  else:
525
+ # print("Node Edge not exist!",self.db.workspace, source_node_id)
526
  return []
527
 
528
 
 
531
  "text_chunks": "LIGHTRAG_DOC_CHUNKS",
532
  "chunks": "LIGHTRAG_DOC_CHUNKS",
533
  "entities": "LIGHTRAG_GRAPH_NODES",
534
+ "relationships": "LIGHTRAG_GRAPH_EDGES",
535
  }
536
 
537
  TABLES = {
538
+ "LIGHTRAG_DOC_FULL": {
539
+ "ddl": """CREATE TABLE LIGHTRAG_DOC_FULL (
540
  id varchar(256)PRIMARY KEY,
541
  workspace varchar(1024),
542
  doc_name varchar(1024),
 
544
  meta JSON,
545
  createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
546
  updatetime TIMESTAMP DEFAULT NULL
547
+ )"""
548
+ },
549
+ "LIGHTRAG_DOC_CHUNKS": {
550
+ "ddl": """CREATE TABLE LIGHTRAG_DOC_CHUNKS (
551
  id varchar(256) PRIMARY KEY,
552
  workspace varchar(1024),
553
  full_doc_id varchar(256),
554
  chunk_order_index NUMBER,
555
+ tokens NUMBER,
556
  content CLOB,
557
  content_vector VECTOR,
558
  createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
559
+ updatetime TIMESTAMP DEFAULT NULL
560
+ )"""
561
+ },
562
+ "LIGHTRAG_GRAPH_NODES": {
563
+ "ddl": """CREATE TABLE LIGHTRAG_GRAPH_NODES (
564
  id NUMBER GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY,
565
  workspace varchar(1024),
566
  name varchar(2048),
567
+ entity_type varchar(1024),
568
  description CLOB,
569
  source_chunk_id varchar(256),
570
  content CLOB,
571
  content_vector VECTOR,
572
  createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
573
  updatetime TIMESTAMP DEFAULT NULL
574
+ )"""
575
+ },
576
+ "LIGHTRAG_GRAPH_EDGES": {
577
+ "ddl": """CREATE TABLE LIGHTRAG_GRAPH_EDGES (
578
  id NUMBER GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY,
579
  workspace varchar(1024),
580
  source_name varchar(2048),
581
+ target_name varchar(2048),
582
  weight NUMBER,
583
+ keywords CLOB,
584
  description CLOB,
585
  source_chunk_id varchar(256),
586
  content CLOB,
587
  content_vector VECTOR,
588
  createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
589
  updatetime TIMESTAMP DEFAULT NULL
590
+ )"""
591
+ },
592
+ "LIGHTRAG_LLM_CACHE": {
593
+ "ddl": """CREATE TABLE LIGHTRAG_LLM_CACHE (
594
  id varchar(256) PRIMARY KEY,
595
  send clob,
596
  return clob,
597
  model varchar(1024),
598
  createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
599
  updatetime TIMESTAMP DEFAULT NULL
600
+ )"""
601
+ },
602
+ "LIGHTRAG_GRAPH": {
603
+ "ddl": """CREATE OR REPLACE PROPERTY GRAPH lightrag_graph
604
  VERTEX TABLES (
605
  lightrag_graph_nodes KEY (id)
606
  LABEL entity
 
611
  SOURCE KEY (source_name) REFERENCES lightrag_graph_nodes(name)
612
  DESTINATION KEY (target_name) REFERENCES lightrag_graph_nodes(name)
613
  LABEL has_relation
614
+ PROPERTIES (id,workspace,source_name,target_name) -- ,weight, keywords,description,source_chunk_id)
615
+ ) OPTIONS(ALLOW MIXED PROPERTY TYPES)"""
616
+ },
617
+ }
618
 
619
 
620
  SQL_TEMPLATES = {
621
  # SQL for KVStorage
622
+ "get_by_id_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace='{workspace}' and ID='{id}'",
623
+ "get_by_id_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace='{workspace}' and ID='{id}'",
624
+ "get_by_ids_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace='{workspace}' and ID in ({ids})",
625
+ "get_by_ids_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace='{workspace}' and ID in ({ids})",
626
+ "filter_keys": "select id from {table_name} where workspace='{workspace}' and id in ({ids})",
 
 
 
 
 
 
 
 
 
 
627
  "merge_doc_full": """ MERGE INTO LIGHTRAG_DOC_FULL a
628
  USING DUAL
629
  ON (a.id = '{check_id}')
630
  WHEN NOT MATCHED THEN
631
  INSERT(id,content,workspace) values(:1,:2,:3)
632
  """,
 
633
  "merge_chunk": """MERGE INTO LIGHTRAG_DOC_CHUNKS a
634
  USING DUAL
635
  ON (a.id = '{check_id}')
636
  WHEN NOT MATCHED THEN
637
  INSERT(id,content,workspace,tokens,chunk_order_index,full_doc_id,content_vector)
638
  values (:1,:2,:3,:4,:5,:6,:7) """,
 
639
  # SQL for VectorStorage
640
+ "entities": """SELECT name as entity_name FROM
641
+ (SELECT id,name,VECTOR_DISTANCE(content_vector,vector('[{embedding_string}]',{dimension},{dtype}),COSINE) as distance
642
+ FROM LIGHTRAG_GRAPH_NODES WHERE workspace='{workspace}')
 
643
  WHERE distance>{better_than_threshold} ORDER BY distance ASC FETCH FIRST {top_k} ROWS ONLY""",
644
+ "relationships": """SELECT source_name as src_id, target_name as tgt_id FROM
645
+ (SELECT id,source_name,target_name,VECTOR_DISTANCE(content_vector,vector('[{embedding_string}]',{dimension},{dtype}),COSINE) as distance
646
+ FROM LIGHTRAG_GRAPH_EDGES WHERE workspace='{workspace}')
 
 
647
  WHERE distance>{better_than_threshold} ORDER BY distance ASC FETCH FIRST {top_k} ROWS ONLY""",
648
+ "chunks": """SELECT id FROM
649
+ (SELECT id,VECTOR_DISTANCE(content_vector,vector('[{embedding_string}]',{dimension},{dtype}),COSINE) as distance
650
+ FROM LIGHTRAG_DOC_CHUNKS WHERE workspace='{workspace}')
 
 
651
  WHERE distance>{better_than_threshold} ORDER BY distance ASC FETCH FIRST {top_k} ROWS ONLY""",
 
652
  # SQL for GraphStorage
653
+ "has_node": """SELECT * FROM GRAPH_TABLE (lightrag_graph
 
654
  MATCH (a)
655
  WHERE a.workspace='{workspace}' AND a.name='{node_id}'
656
  COLUMNS (a.name))""",
657
+ "has_edge": """SELECT * FROM GRAPH_TABLE (lightrag_graph
 
 
658
  MATCH (a) -[e]-> (b)
659
  WHERE e.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}'
660
  AND a.name='{source_node_id}' AND b.name='{target_node_id}'
661
  COLUMNS (e.source_name,e.target_name) )""",
662
+ "node_degree": """SELECT count(1) as degree FROM GRAPH_TABLE (lightrag_graph
 
 
663
  MATCH (a)-[e]->(b)
664
  WHERE a.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}'
665
  AND a.name='{node_id}' or b.name = '{node_id}'
666
  COLUMNS (a.name))""",
667
+ "get_node": """SELECT t1.name,t2.entity_type,t2.source_chunk_id as source_id,NVL(t2.description,'') AS description
 
 
668
  FROM GRAPH_TABLE (lightrag_graph
669
+ MATCH (a)
670
  WHERE a.workspace='{workspace}' AND a.name='{node_id}'
671
  COLUMNS (a.name)
672
  ) t1 JOIN LIGHTRAG_GRAPH_NODES t2 on t1.name=t2.name
673
  WHERE t2.workspace='{workspace}'""",
674
+ "get_edge": """SELECT t1.source_id,t2.weight,t2.source_chunk_id as source_id,t2.keywords,
 
 
675
  NVL(t2.description,'') AS description,NVL(t2.KEYWORDS,'') AS keywords
676
  FROM GRAPH_TABLE (lightrag_graph
677
  MATCH (a)-[e]->(b)
 
679
  AND a.name='{source_node_id}' and b.name = '{target_node_id}'
680
  COLUMNS (e.id,a.name as source_id)
681
  ) t1 JOIN LIGHTRAG_GRAPH_EDGES t2 on t1.id=t2.id""",
682
+ "get_node_edges": """SELECT source_name,target_name
 
 
683
  FROM GRAPH_TABLE (lightrag_graph
684
  MATCH (a)-[e]->(b)
685
  WHERE e.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}'
686
  AND a.name='{source_node_id}'
687
  COLUMNS (a.name as source_name,b.name as target_name))""",
 
688
  "merge_node": """MERGE INTO LIGHTRAG_GRAPH_NODES a
689
  USING DUAL
690
  ON (a.workspace = '{workspace}' and a.name='{name}' and a.source_chunk_id='{source_chunk_id}')
 
696
  ON (a.workspace = '{workspace}' and a.source_name='{source_name}' and a.target_name='{target_name}' and a.source_chunk_id='{source_chunk_id}')
697
  WHEN NOT MATCHED THEN
698
  INSERT(workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector)
699
+ values (:1,:2,:3,:4,:5,:6,:7,:8,:9) """,
700
+ }
lightrag/lightrag.py CHANGED
@@ -38,15 +38,11 @@ from .storage import (
38
  JsonKVStorage,
39
  NanoVectorDBStorage,
40
  NetworkXStorage,
41
- )
42
 
43
  from .kg.neo4j_impl import Neo4JStorage
44
 
45
- from .kg.oracle_impl import (
46
- OracleKVStorage,
47
- OracleGraphStorage,
48
- OracleVectorDBStorage
49
- )
50
 
51
  # future KG integrations
52
 
@@ -54,6 +50,7 @@ from .kg.oracle_impl import (
54
  # GraphStorage as ArangoDBStorage
55
  # )
56
 
 
57
  def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
58
  try:
59
  return asyncio.get_event_loop()
@@ -72,7 +69,7 @@ class LightRAG:
72
  default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
73
  )
74
 
75
- kv_storage : str = field(default="JsonKVStorage")
76
  vector_storage: str = field(default="NanoVectorDBStorage")
77
  graph_storage: str = field(default="NetworkXStorage")
78
 
@@ -115,7 +112,7 @@ class LightRAG:
115
 
116
  # storage
117
  vector_db_storage_cls_kwargs: dict = field(default_factory=dict)
118
-
119
  enable_llm_cache: bool = True
120
 
121
  # extension
@@ -134,18 +131,25 @@ class LightRAG:
134
 
135
  # @TODO: should move all storage setup here to leverage initial start params attached to self.
136
 
137
- self.key_string_value_json_storage_cls: Type[BaseKVStorage] = self._get_storage_class()[self.kv_storage]
138
- self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class()[self.vector_storage]
139
- self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class()[self.graph_storage]
 
 
 
 
 
 
140
 
141
  if not os.path.exists(self.working_dir):
142
  logger.info(f"Creating working directory {self.working_dir}")
143
  os.makedirs(self.working_dir)
144
 
145
-
146
  self.llm_response_cache = (
147
  self.key_string_value_json_storage_cls(
148
- namespace="llm_response_cache", global_config=asdict(self),embedding_func=None
 
 
149
  )
150
  if self.enable_llm_cache
151
  else None
@@ -159,13 +163,19 @@ class LightRAG:
159
  # add embedding func by walter
160
  ####
161
  self.full_docs = self.key_string_value_json_storage_cls(
162
- namespace="full_docs", global_config=asdict(self), embedding_func=self.embedding_func
 
 
163
  )
164
  self.text_chunks = self.key_string_value_json_storage_cls(
165
- namespace="text_chunks", global_config=asdict(self), embedding_func=self.embedding_func
 
 
166
  )
167
  self.chunk_entity_relation_graph = self.graph_storage_cls(
168
- namespace="chunk_entity_relation", global_config=asdict(self), embedding_func=self.embedding_func
 
 
169
  )
170
  ####
171
  # add embedding func by walter over
@@ -200,13 +210,11 @@ class LightRAG:
200
  def _get_storage_class(self) -> Type[BaseGraphStorage]:
201
  return {
202
  # kv storage
203
- "JsonKVStorage":JsonKVStorage,
204
- "OracleKVStorage":OracleKVStorage,
205
-
206
  # vector storage
207
- "NanoVectorDBStorage":NanoVectorDBStorage,
208
- "OracleVectorDBStorage":OracleVectorDBStorage,
209
-
210
  # graph storage
211
  "NetworkXStorage": NetworkXStorage,
212
  "Neo4JStorage": Neo4JStorage,
 
38
  JsonKVStorage,
39
  NanoVectorDBStorage,
40
  NetworkXStorage,
41
+ )
42
 
43
  from .kg.neo4j_impl import Neo4JStorage
44
 
45
+ from .kg.oracle_impl import OracleKVStorage, OracleGraphStorage, OracleVectorDBStorage
 
 
 
 
46
 
47
  # future KG integrations
48
 
 
50
  # GraphStorage as ArangoDBStorage
51
  # )
52
 
53
+
54
  def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
55
  try:
56
  return asyncio.get_event_loop()
 
69
  default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
70
  )
71
 
72
+ kv_storage: str = field(default="JsonKVStorage")
73
  vector_storage: str = field(default="NanoVectorDBStorage")
74
  graph_storage: str = field(default="NetworkXStorage")
75
 
 
112
 
113
  # storage
114
  vector_db_storage_cls_kwargs: dict = field(default_factory=dict)
115
+
116
  enable_llm_cache: bool = True
117
 
118
  # extension
 
131
 
132
  # @TODO: should move all storage setup here to leverage initial start params attached to self.
133
 
134
+ self.key_string_value_json_storage_cls: Type[BaseKVStorage] = (
135
+ self._get_storage_class()[self.kv_storage]
136
+ )
137
+ self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class()[
138
+ self.vector_storage
139
+ ]
140
+ self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class()[
141
+ self.graph_storage
142
+ ]
143
 
144
  if not os.path.exists(self.working_dir):
145
  logger.info(f"Creating working directory {self.working_dir}")
146
  os.makedirs(self.working_dir)
147
 
 
148
  self.llm_response_cache = (
149
  self.key_string_value_json_storage_cls(
150
+ namespace="llm_response_cache",
151
+ global_config=asdict(self),
152
+ embedding_func=None,
153
  )
154
  if self.enable_llm_cache
155
  else None
 
163
  # add embedding func by walter
164
  ####
165
  self.full_docs = self.key_string_value_json_storage_cls(
166
+ namespace="full_docs",
167
+ global_config=asdict(self),
168
+ embedding_func=self.embedding_func,
169
  )
170
  self.text_chunks = self.key_string_value_json_storage_cls(
171
+ namespace="text_chunks",
172
+ global_config=asdict(self),
173
+ embedding_func=self.embedding_func,
174
  )
175
  self.chunk_entity_relation_graph = self.graph_storage_cls(
176
+ namespace="chunk_entity_relation",
177
+ global_config=asdict(self),
178
+ embedding_func=self.embedding_func,
179
  )
180
  ####
181
  # add embedding func by walter over
 
210
  def _get_storage_class(self) -> Type[BaseGraphStorage]:
211
  return {
212
  # kv storage
213
+ "JsonKVStorage": JsonKVStorage,
214
+ "OracleKVStorage": OracleKVStorage,
 
215
  # vector storage
216
+ "NanoVectorDBStorage": NanoVectorDBStorage,
217
+ "OracleVectorDBStorage": OracleVectorDBStorage,
 
218
  # graph storage
219
  "NetworkXStorage": NetworkXStorage,
220
  "Neo4JStorage": Neo4JStorage,
lightrag/operate.py CHANGED
@@ -16,7 +16,7 @@ from .utils import (
16
  split_string_by_multi_markers,
17
  truncate_list_by_token_size,
18
  process_combine_contexts,
19
- locate_json_string_body_from_string
20
  )
21
  from .base import (
22
  BaseGraphStorage,
 
16
  split_string_by_multi_markers,
17
  truncate_list_by_token_size,
18
  process_combine_contexts,
19
+ locate_json_string_body_from_string,
20
  )
21
  from .base import (
22
  BaseGraphStorage,
requirements.txt CHANGED
@@ -1,22 +1,22 @@
1
  accelerate
 
2
  aiohttp
 
 
 
 
 
 
 
 
 
 
3
  pyvis
4
  tenacity
5
- xxhash
6
  # lmdeploy[all]
7
 
8
  # LLM packages
9
  tiktoken
10
  torch
11
  transformers
12
- aioboto3
13
- ollama
14
- openai
15
-
16
- # database packages
17
- graspologic
18
- hnswlib
19
- networkx
20
- oracledb
21
- nano-vectordb
22
- neo4j
 
1
  accelerate
2
+ aioboto3
3
  aiohttp
4
+
5
+ # database packages
6
+ graspologic
7
+ hnswlib
8
+ nano-vectordb
9
+ neo4j
10
+ networkx
11
+ ollama
12
+ openai
13
+ oracledb
14
  pyvis
15
  tenacity
 
16
  # lmdeploy[all]
17
 
18
  # LLM packages
19
  tiktoken
20
  torch
21
  transformers
22
+ xxhash