jin
commited on
Commit
·
5dcb28f
1
Parent(s):
4d468e5
fix pre commit
Browse files- examples/lightrag_api_oracle_demo..py +41 -38
- examples/lightrag_oracle_demo.py +27 -22
- lightrag/base.py +2 -0
- lightrag/kg/oracle_impl.py +233 -216
- lightrag/lightrag.py +30 -22
- lightrag/operate.py +1 -1
- requirements.txt +12 -12
examples/lightrag_api_oracle_demo..py
CHANGED
@@ -1,10 +1,10 @@
|
|
1 |
-
|
2 |
from fastapi import FastAPI, HTTPException, File, UploadFile
|
3 |
from contextlib import asynccontextmanager
|
4 |
from pydantic import BaseModel
|
5 |
from typing import Optional
|
6 |
|
7 |
-
import sys
|
|
|
8 |
from pathlib import Path
|
9 |
|
10 |
import asyncio
|
@@ -13,7 +13,6 @@ from lightrag import LightRAG, QueryParam
|
|
13 |
from lightrag.llm import openai_complete_if_cache, openai_embedding
|
14 |
from lightrag.utils import EmbeddingFunc
|
15 |
import numpy as np
|
16 |
-
from datetime import datetime
|
17 |
|
18 |
from lightrag.kg.oracle_impl import OracleDB
|
19 |
|
@@ -24,8 +23,6 @@ script_directory = Path(__file__).resolve().parent.parent
|
|
24 |
sys.path.append(os.path.abspath(script_directory))
|
25 |
|
26 |
|
27 |
-
|
28 |
-
|
29 |
# Apply nest_asyncio to solve event loop issues
|
30 |
nest_asyncio.apply()
|
31 |
|
@@ -51,6 +48,7 @@ print(f"EMBEDDING_MAX_TOKEN_SIZE: {EMBEDDING_MAX_TOKEN_SIZE}")
|
|
51 |
if not os.path.exists(WORKING_DIR):
|
52 |
os.mkdir(WORKING_DIR)
|
53 |
|
|
|
54 |
async def llm_model_func(
|
55 |
prompt, system_prompt=None, history_messages=[], **kwargs
|
56 |
) -> str:
|
@@ -80,8 +78,8 @@ async def get_embedding_dim():
|
|
80 |
embedding_dim = embedding.shape[1]
|
81 |
return embedding_dim
|
82 |
|
|
|
83 |
async def init():
|
84 |
-
|
85 |
# Detect embedding dimension
|
86 |
embedding_dimension = await get_embedding_dim()
|
87 |
print(f"Detected embedding dimension: {embedding_dimension}")
|
@@ -91,36 +89,36 @@ async def init():
|
|
91 |
# We storage data in unified tables, so we need to set a `workspace` parameter to specify which docs we want to store and query
|
92 |
# Below is an example of how to connect to Oracle Autonomous Database on Oracle Cloud
|
93 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
94 |
|
95 |
-
oracle_db = OracleDB(config={
|
96 |
-
"user":"",
|
97 |
-
"password":"",
|
98 |
-
"dsn":"",
|
99 |
-
"config_dir":"",
|
100 |
-
"wallet_location":"",
|
101 |
-
"wallet_password":"",
|
102 |
-
"workspace":""
|
103 |
-
} # specify which docs you want to store and query
|
104 |
-
)
|
105 |
-
|
106 |
# Check if Oracle DB tables exist, if not, tables will be created
|
107 |
await oracle_db.check_tables()
|
108 |
# Initialize LightRAG
|
109 |
-
|
110 |
rag = LightRAG(
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
|
125 |
# Setthe KV/vector/graph storage's `db` property, so all operation will use same connection pool
|
126 |
rag.graph_storage_cls.db = oracle_db
|
@@ -129,6 +127,7 @@ async def init():
|
|
129 |
|
130 |
return rag
|
131 |
|
|
|
132 |
# Data models
|
133 |
|
134 |
|
@@ -152,6 +151,7 @@ class Response(BaseModel):
|
|
152 |
|
153 |
rag = None # 定义为全局对象
|
154 |
|
|
|
155 |
@asynccontextmanager
|
156 |
async def lifespan(app: FastAPI):
|
157 |
global rag
|
@@ -160,18 +160,21 @@ async def lifespan(app: FastAPI):
|
|
160 |
yield
|
161 |
|
162 |
|
163 |
-
app = FastAPI(
|
|
|
|
|
|
|
164 |
|
165 |
@app.post("/query", response_model=Response)
|
166 |
async def query_endpoint(request: QueryRequest):
|
167 |
try:
|
168 |
# loop = asyncio.get_event_loop()
|
169 |
result = await rag.aquery(
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
return Response(status="success", data=result)
|
176 |
except Exception as e:
|
177 |
raise HTTPException(status_code=500, detail=str(e))
|
@@ -234,4 +237,4 @@ if __name__ == "__main__":
|
|
234 |
# curl -X POST "http://127.0.0.1:8020/insert_file" -H "Content-Type: application/json" -d '{"file_path": "path/to/your/file.txt"}'
|
235 |
|
236 |
# 4. Health check:
|
237 |
-
# curl -X GET "http://127.0.0.1:8020/health"
|
|
|
|
|
1 |
from fastapi import FastAPI, HTTPException, File, UploadFile
|
2 |
from contextlib import asynccontextmanager
|
3 |
from pydantic import BaseModel
|
4 |
from typing import Optional
|
5 |
|
6 |
+
import sys
|
7 |
+
import os
|
8 |
from pathlib import Path
|
9 |
|
10 |
import asyncio
|
|
|
13 |
from lightrag.llm import openai_complete_if_cache, openai_embedding
|
14 |
from lightrag.utils import EmbeddingFunc
|
15 |
import numpy as np
|
|
|
16 |
|
17 |
from lightrag.kg.oracle_impl import OracleDB
|
18 |
|
|
|
23 |
sys.path.append(os.path.abspath(script_directory))
|
24 |
|
25 |
|
|
|
|
|
26 |
# Apply nest_asyncio to solve event loop issues
|
27 |
nest_asyncio.apply()
|
28 |
|
|
|
48 |
if not os.path.exists(WORKING_DIR):
|
49 |
os.mkdir(WORKING_DIR)
|
50 |
|
51 |
+
|
52 |
async def llm_model_func(
|
53 |
prompt, system_prompt=None, history_messages=[], **kwargs
|
54 |
) -> str:
|
|
|
78 |
embedding_dim = embedding.shape[1]
|
79 |
return embedding_dim
|
80 |
|
81 |
+
|
82 |
async def init():
|
|
|
83 |
# Detect embedding dimension
|
84 |
embedding_dimension = await get_embedding_dim()
|
85 |
print(f"Detected embedding dimension: {embedding_dimension}")
|
|
|
89 |
# We storage data in unified tables, so we need to set a `workspace` parameter to specify which docs we want to store and query
|
90 |
# Below is an example of how to connect to Oracle Autonomous Database on Oracle Cloud
|
91 |
|
92 |
+
oracle_db = OracleDB(
|
93 |
+
config={
|
94 |
+
"user": "",
|
95 |
+
"password": "",
|
96 |
+
"dsn": "",
|
97 |
+
"config_dir": "",
|
98 |
+
"wallet_location": "",
|
99 |
+
"wallet_password": "",
|
100 |
+
"workspace": "",
|
101 |
+
} # specify which docs you want to store and query
|
102 |
+
)
|
103 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
104 |
# Check if Oracle DB tables exist, if not, tables will be created
|
105 |
await oracle_db.check_tables()
|
106 |
# Initialize LightRAG
|
107 |
+
# We use Oracle DB as the KV/vector/graph storage
|
108 |
rag = LightRAG(
|
109 |
+
enable_llm_cache=False,
|
110 |
+
working_dir=WORKING_DIR,
|
111 |
+
chunk_token_size=512,
|
112 |
+
llm_model_func=llm_model_func,
|
113 |
+
embedding_func=EmbeddingFunc(
|
114 |
+
embedding_dim=embedding_dimension,
|
115 |
+
max_token_size=512,
|
116 |
+
func=embedding_func,
|
117 |
+
),
|
118 |
+
graph_storage="OracleGraphStorage",
|
119 |
+
kv_storage="OracleKVStorage",
|
120 |
+
vector_storage="OracleVectorDBStorage",
|
121 |
+
)
|
122 |
|
123 |
# Setthe KV/vector/graph storage's `db` property, so all operation will use same connection pool
|
124 |
rag.graph_storage_cls.db = oracle_db
|
|
|
127 |
|
128 |
return rag
|
129 |
|
130 |
+
|
131 |
# Data models
|
132 |
|
133 |
|
|
|
151 |
|
152 |
rag = None # 定义为全局对象
|
153 |
|
154 |
+
|
155 |
@asynccontextmanager
|
156 |
async def lifespan(app: FastAPI):
|
157 |
global rag
|
|
|
160 |
yield
|
161 |
|
162 |
|
163 |
+
app = FastAPI(
|
164 |
+
title="LightRAG API", description="API for RAG operations", lifespan=lifespan
|
165 |
+
)
|
166 |
+
|
167 |
|
168 |
@app.post("/query", response_model=Response)
|
169 |
async def query_endpoint(request: QueryRequest):
|
170 |
try:
|
171 |
# loop = asyncio.get_event_loop()
|
172 |
result = await rag.aquery(
|
173 |
+
request.query,
|
174 |
+
param=QueryParam(
|
175 |
+
mode=request.mode, only_need_context=request.only_need_context
|
176 |
+
),
|
177 |
+
)
|
178 |
return Response(status="success", data=result)
|
179 |
except Exception as e:
|
180 |
raise HTTPException(status_code=500, detail=str(e))
|
|
|
237 |
# curl -X POST "http://127.0.0.1:8020/insert_file" -H "Content-Type: application/json" -d '{"file_path": "path/to/your/file.txt"}'
|
238 |
|
239 |
# 4. Health check:
|
240 |
+
# curl -X GET "http://127.0.0.1:8020/health"
|
examples/lightrag_oracle_demo.py
CHANGED
@@ -1,11 +1,11 @@
|
|
1 |
-
import sys
|
|
|
2 |
from pathlib import Path
|
3 |
import asyncio
|
4 |
from lightrag import LightRAG, QueryParam
|
5 |
from lightrag.llm import openai_complete_if_cache, openai_embedding
|
6 |
from lightrag.utils import EmbeddingFunc
|
7 |
import numpy as np
|
8 |
-
from datetime import datetime
|
9 |
from lightrag.kg.oracle_impl import OracleDB
|
10 |
|
11 |
print(os.getcwd())
|
@@ -25,6 +25,7 @@ EMBEDMODEL = "cohere.embed-multilingual-v3.0"
|
|
25 |
if not os.path.exists(WORKING_DIR):
|
26 |
os.mkdir(WORKING_DIR)
|
27 |
|
|
|
28 |
async def llm_model_func(
|
29 |
prompt, system_prompt=None, history_messages=[], **kwargs
|
30 |
) -> str:
|
@@ -66,22 +67,21 @@ async def main():
|
|
66 |
# More docs here https://python-oracledb.readthedocs.io/en/latest/user_guide/connection_handling.html
|
67 |
# We storage data in unified tables, so we need to set a `workspace` parameter to specify which docs we want to store and query
|
68 |
# Below is an example of how to connect to Oracle Autonomous Database on Oracle Cloud
|
69 |
-
oracle_db = OracleDB(
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
|
|
77 |
}
|
78 |
-
|
79 |
-
|
80 |
|
81 |
# Check if Oracle DB tables exist, if not, tables will be created
|
82 |
await oracle_db.check_tables()
|
83 |
|
84 |
-
|
85 |
# Initialize LightRAG
|
86 |
# We use Oracle DB as the KV/vector/graph storage
|
87 |
rag = LightRAG(
|
@@ -93,10 +93,10 @@ async def main():
|
|
93 |
embedding_dim=embedding_dimension,
|
94 |
max_token_size=512,
|
95 |
func=embedding_func,
|
96 |
-
|
97 |
-
graph_storage
|
98 |
-
kv_storage="OracleKVStorage",
|
99 |
-
vector_storage="OracleVectorDBStorage"
|
100 |
)
|
101 |
|
102 |
# Setthe KV/vector/graph storage's `db` property, so all operation will use same connection pool
|
@@ -106,18 +106,23 @@ async def main():
|
|
106 |
|
107 |
# Extract and Insert into LightRAG storage
|
108 |
with open("./dickens/demo.txt", "r", encoding="utf-8") as f:
|
109 |
-
|
110 |
|
111 |
# Perform search in different modes
|
112 |
modes = ["naive", "local", "global", "hybrid"]
|
113 |
for mode in modes:
|
114 |
-
print("="*20, mode, "="*20)
|
115 |
-
print(
|
116 |
-
|
|
|
|
|
|
|
|
|
|
|
117 |
|
118 |
except Exception as e:
|
119 |
print(f"An error occurred: {e}")
|
120 |
|
121 |
|
122 |
if __name__ == "__main__":
|
123 |
-
asyncio.run(main())
|
|
|
1 |
+
import sys
|
2 |
+
import os
|
3 |
from pathlib import Path
|
4 |
import asyncio
|
5 |
from lightrag import LightRAG, QueryParam
|
6 |
from lightrag.llm import openai_complete_if_cache, openai_embedding
|
7 |
from lightrag.utils import EmbeddingFunc
|
8 |
import numpy as np
|
|
|
9 |
from lightrag.kg.oracle_impl import OracleDB
|
10 |
|
11 |
print(os.getcwd())
|
|
|
25 |
if not os.path.exists(WORKING_DIR):
|
26 |
os.mkdir(WORKING_DIR)
|
27 |
|
28 |
+
|
29 |
async def llm_model_func(
|
30 |
prompt, system_prompt=None, history_messages=[], **kwargs
|
31 |
) -> str:
|
|
|
67 |
# More docs here https://python-oracledb.readthedocs.io/en/latest/user_guide/connection_handling.html
|
68 |
# We storage data in unified tables, so we need to set a `workspace` parameter to specify which docs we want to store and query
|
69 |
# Below is an example of how to connect to Oracle Autonomous Database on Oracle Cloud
|
70 |
+
oracle_db = OracleDB(
|
71 |
+
config={
|
72 |
+
"user": "username",
|
73 |
+
"password": "xxxxxxxxx",
|
74 |
+
"dsn": "xxxxxxx_medium",
|
75 |
+
"config_dir": "dir/path/to/oracle/config",
|
76 |
+
"wallet_location": "dir/path/to/oracle/wallet",
|
77 |
+
"wallet_password": "xxxxxxxxx",
|
78 |
+
"workspace": "company", # specify which docs you want to store and query
|
79 |
}
|
80 |
+
)
|
|
|
81 |
|
82 |
# Check if Oracle DB tables exist, if not, tables will be created
|
83 |
await oracle_db.check_tables()
|
84 |
|
|
|
85 |
# Initialize LightRAG
|
86 |
# We use Oracle DB as the KV/vector/graph storage
|
87 |
rag = LightRAG(
|
|
|
93 |
embedding_dim=embedding_dimension,
|
94 |
max_token_size=512,
|
95 |
func=embedding_func,
|
96 |
+
),
|
97 |
+
graph_storage="OracleGraphStorage",
|
98 |
+
kv_storage="OracleKVStorage",
|
99 |
+
vector_storage="OracleVectorDBStorage",
|
100 |
)
|
101 |
|
102 |
# Setthe KV/vector/graph storage's `db` property, so all operation will use same connection pool
|
|
|
106 |
|
107 |
# Extract and Insert into LightRAG storage
|
108 |
with open("./dickens/demo.txt", "r", encoding="utf-8") as f:
|
109 |
+
await rag.ainsert(f.read())
|
110 |
|
111 |
# Perform search in different modes
|
112 |
modes = ["naive", "local", "global", "hybrid"]
|
113 |
for mode in modes:
|
114 |
+
print("=" * 20, mode, "=" * 20)
|
115 |
+
print(
|
116 |
+
await rag.aquery(
|
117 |
+
"What are the top themes in this story?",
|
118 |
+
param=QueryParam(mode=mode),
|
119 |
+
)
|
120 |
+
)
|
121 |
+
print("-" * 100, "\n")
|
122 |
|
123 |
except Exception as e:
|
124 |
print(f"An error occurred: {e}")
|
125 |
|
126 |
|
127 |
if __name__ == "__main__":
|
128 |
+
asyncio.run(main())
|
lightrag/base.py
CHANGED
@@ -60,6 +60,7 @@ class BaseVectorStorage(StorageNameSpace):
|
|
60 |
@dataclass
|
61 |
class BaseKVStorage(Generic[T], StorageNameSpace):
|
62 |
embedding_func: EmbeddingFunc
|
|
|
63 |
async def all_keys(self) -> list[str]:
|
64 |
raise NotImplementedError
|
65 |
|
@@ -85,6 +86,7 @@ class BaseKVStorage(Generic[T], StorageNameSpace):
|
|
85 |
@dataclass
|
86 |
class BaseGraphStorage(StorageNameSpace):
|
87 |
embedding_func: EmbeddingFunc = None
|
|
|
88 |
async def has_node(self, node_id: str) -> bool:
|
89 |
raise NotImplementedError
|
90 |
|
|
|
60 |
@dataclass
|
61 |
class BaseKVStorage(Generic[T], StorageNameSpace):
|
62 |
embedding_func: EmbeddingFunc
|
63 |
+
|
64 |
async def all_keys(self) -> list[str]:
|
65 |
raise NotImplementedError
|
66 |
|
|
|
86 |
@dataclass
|
87 |
class BaseGraphStorage(StorageNameSpace):
|
88 |
embedding_func: EmbeddingFunc = None
|
89 |
+
|
90 |
async def has_node(self, node_id: str) -> bool:
|
91 |
raise NotImplementedError
|
92 |
|
lightrag/kg/oracle_impl.py
CHANGED
@@ -1,9 +1,9 @@
|
|
1 |
import asyncio
|
2 |
-
|
3 |
-
#import
|
|
|
4 |
from dataclasses import dataclass
|
5 |
-
from typing import
|
6 |
-
import networkx as nx
|
7 |
import numpy as np
|
8 |
import array
|
9 |
|
@@ -16,8 +16,9 @@ from ..base import (
|
|
16 |
|
17 |
import oracledb
|
18 |
|
|
|
19 |
class OracleDB:
|
20 |
-
def __init__(self,config
|
21 |
self.host = config.get("host", None)
|
22 |
self.port = config.get("port", None)
|
23 |
self.user = config.get("user", None)
|
@@ -32,21 +33,21 @@ class OracleDB:
|
|
32 |
logger.info(f"Using the label {self.workspace} for Oracle Graph as identifier")
|
33 |
if self.user is None or self.password is None:
|
34 |
raise ValueError("Missing database user or password in addon_params")
|
35 |
-
|
36 |
try:
|
37 |
oracledb.defaults.fetch_lobs = False
|
38 |
|
39 |
self.pool = oracledb.create_pool_async(
|
40 |
-
user
|
41 |
-
password
|
42 |
-
dsn
|
43 |
-
config_dir
|
44 |
-
wallet_location
|
45 |
-
wallet_password
|
46 |
-
min
|
47 |
-
max
|
48 |
-
increment
|
49 |
-
|
50 |
logger.info(f"Connected to Oracle database at {self.dsn}")
|
51 |
except Exception as e:
|
52 |
logger.error(f"Failed to connect to Oracle database at {self.dsn}")
|
@@ -90,12 +91,14 @@ class OracleDB:
|
|
90 |
arraysize=cursor.arraysize,
|
91 |
outconverter=self.numpy_converter_out,
|
92 |
)
|
93 |
-
|
94 |
async def check_tables(self):
|
95 |
-
for k,v in TABLES.items():
|
96 |
try:
|
97 |
if k.lower() == "lightrag_graph":
|
98 |
-
await self.query(
|
|
|
|
|
99 |
else:
|
100 |
await self.query("SELECT 1 FROM {k}".format(k=k))
|
101 |
except Exception as e:
|
@@ -108,12 +111,11 @@ class OracleDB:
|
|
108 |
except Exception as e:
|
109 |
logger.error(f"Failed to create table {k} in Oracle database")
|
110 |
logger.error(f"Oracle database error: {e}")
|
111 |
-
|
112 |
-
logger.info(
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
async with self.pool.acquire() as connection:
|
117 |
connection.inputtypehandler = self.input_type_handler
|
118 |
connection.outputtypehandler = self.output_type_handler
|
119 |
with connection.cursor() as cursor:
|
@@ -136,9 +138,9 @@ class OracleDB:
|
|
136 |
data = dict(zip(columns, row))
|
137 |
else:
|
138 |
data = None
|
139 |
-
return data
|
140 |
|
141 |
-
async def execute(self,sql: str, data: list = None):
|
142 |
# logger.info("go into OracleDB execute method")
|
143 |
try:
|
144 |
async with self.pool.acquire() as connection:
|
@@ -148,58 +150,63 @@ class OracleDB:
|
|
148 |
if data is None:
|
149 |
await cursor.execute(sql)
|
150 |
else:
|
151 |
-
#print(data)
|
152 |
-
#print(sql)
|
153 |
-
await cursor.execute(sql,data)
|
154 |
await connection.commit()
|
155 |
except Exception as e:
|
156 |
-
logger.error(f"Oracle database error: {e}")
|
157 |
print(sql)
|
158 |
print(data)
|
159 |
raise
|
160 |
|
|
|
161 |
@dataclass
|
162 |
class OracleKVStorage(BaseKVStorage):
|
163 |
-
|
164 |
# should pass db object to self.db
|
165 |
def __post_init__(self):
|
166 |
self._data = {}
|
167 |
-
self._max_batch_size = self.global_config["embedding_batch_num"]
|
168 |
-
|
169 |
################ QUERY METHODS ################
|
170 |
|
171 |
async def get_by_id(self, id: str) -> Union[dict, None]:
|
172 |
"""根据 id 获取 doc_full 数据."""
|
173 |
-
SQL = SQL_TEMPLATES["get_by_id_"+self.namespace].format(
|
174 |
-
|
175 |
-
|
|
|
|
|
176 |
if res:
|
177 |
-
data = res
|
178 |
-
#print (data)
|
179 |
return data
|
180 |
else:
|
181 |
return None
|
182 |
|
183 |
# Query by id
|
184 |
-
async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict],None]:
|
185 |
"""根据 id 获取 doc_chunks 数据"""
|
186 |
-
SQL = SQL_TEMPLATES["get_by_ids_"+self.namespace].format(
|
187 |
-
|
188 |
-
|
189 |
-
|
|
|
190 |
if res:
|
191 |
-
data = res
|
192 |
-
#print(data)
|
193 |
return data
|
194 |
else:
|
195 |
return None
|
196 |
-
|
197 |
async def filter_keys(self, keys: list[str]) -> set[str]:
|
198 |
"""过滤掉重复内容"""
|
199 |
-
SQL = SQL_TEMPLATES["filter_keys"].format(
|
200 |
-
|
201 |
-
|
202 |
-
|
|
|
|
|
203 |
data = None
|
204 |
if res:
|
205 |
exist_keys = [key["id"] for key in res]
|
@@ -208,14 +215,13 @@ class OracleKVStorage(BaseKVStorage):
|
|
208 |
exist_keys = []
|
209 |
data = set([s for s in keys if s not in exist_keys])
|
210 |
return data
|
211 |
-
|
212 |
-
|
213 |
################ INSERT METHODS ################
|
214 |
async def upsert(self, data: dict[str, dict]):
|
215 |
left_data = {k: v for k, v in data.items() if k not in self._data}
|
216 |
self._data.update(left_data)
|
217 |
-
#print(self._data)
|
218 |
-
#values = []
|
219 |
if self.namespace == "text_chunks":
|
220 |
list_data = [
|
221 |
{
|
@@ -226,7 +232,7 @@ class OracleKVStorage(BaseKVStorage):
|
|
226 |
]
|
227 |
contents = [v["content"] for v in data.values()]
|
228 |
batches = [
|
229 |
-
contents[i: i + self._max_batch_size]
|
230 |
for i in range(0, len(contents), self._max_batch_size)
|
231 |
]
|
232 |
embeddings_list = await asyncio.gather(
|
@@ -235,42 +241,45 @@ class OracleKVStorage(BaseKVStorage):
|
|
235 |
embeddings = np.concatenate(embeddings_list)
|
236 |
for i, d in enumerate(list_data):
|
237 |
d["__vector__"] = embeddings[i]
|
238 |
-
#print(list_data)
|
239 |
for item in list_data:
|
240 |
-
merge_sql = SQL_TEMPLATES["merge_chunk"].format(
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
|
|
|
|
|
|
|
|
|
|
247 |
await self.db.execute(merge_sql, values)
|
248 |
|
249 |
if self.namespace == "full_docs":
|
250 |
for k, v in self._data.items():
|
251 |
-
#values.clear()
|
252 |
merge_sql = SQL_TEMPLATES["merge_doc_full"].format(
|
253 |
check_id=k,
|
254 |
)
|
255 |
values = [k, self._data[k]["content"], self.db.workspace]
|
256 |
-
#print(merge_sql)
|
257 |
await self.db.execute(merge_sql, values)
|
258 |
return left_data
|
259 |
|
260 |
-
|
261 |
async def index_done_callback(self):
|
262 |
if self.namespace in ["full_docs", "text_chunks"]:
|
263 |
logger.info("full doc and chunk data had been saved into oracle db!")
|
264 |
|
265 |
|
266 |
-
|
267 |
@dataclass
|
268 |
class OracleVectorDBStorage(BaseVectorStorage):
|
269 |
cosine_better_than_threshold: float = 0.2
|
270 |
|
271 |
def __post_init__(self):
|
272 |
pass
|
273 |
-
|
274 |
async def upsert(self, data: dict[str, dict]):
|
275 |
"""向向量数据库中插入数据"""
|
276 |
pass
|
@@ -278,53 +287,51 @@ class OracleVectorDBStorage(BaseVectorStorage):
|
|
278 |
async def index_done_callback(self):
|
279 |
pass
|
280 |
|
281 |
-
|
282 |
#################### query method ###############
|
283 |
async def query(self, query: str, top_k=5) -> Union[dict, list[dict]]:
|
284 |
-
"""从向量数据库中查询数据"""
|
285 |
embeddings = await self.embedding_func([query])
|
286 |
embedding = embeddings[0]
|
287 |
# 转换精度
|
288 |
dtype = str(embedding.dtype).upper()
|
289 |
dimension = embedding.shape[0]
|
290 |
-
embedding_string =
|
291 |
|
292 |
SQL = SQL_TEMPLATES[self.namespace].format(
|
293 |
-
|
294 |
-
|
295 |
-
|
296 |
-
|
297 |
-
|
298 |
-
|
299 |
-
|
300 |
# print(SQL)
|
301 |
results = await self.db.query(SQL, multirows=True)
|
302 |
-
#print("vector search result:",results)
|
303 |
return results
|
304 |
|
305 |
|
306 |
@dataclass
|
307 |
-
class OracleGraphStorage(BaseGraphStorage):
|
308 |
"""基于Oracle的图存储模块"""
|
309 |
-
|
310 |
def __post_init__(self):
|
311 |
"""从graphml文件加载图"""
|
312 |
self._max_batch_size = self.global_config["embedding_batch_num"]
|
313 |
|
314 |
-
|
315 |
#################### insert method ################
|
316 |
-
|
317 |
async def upsert_node(self, node_id: str, node_data: dict[str, str]):
|
318 |
"""插入或更新节点"""
|
319 |
-
#print("go into upsert node method")
|
320 |
entity_name = node_id
|
321 |
entity_type = node_data["entity_type"]
|
322 |
description = node_data["description"]
|
323 |
-
source_id
|
324 |
-
content = entity_name+description
|
325 |
contents = [content]
|
326 |
batches = [
|
327 |
-
contents[i: i + self._max_batch_size]
|
328 |
for i in range(0, len(contents), self._max_batch_size)
|
329 |
]
|
330 |
embeddings_list = await asyncio.gather(
|
@@ -333,27 +340,38 @@ class OracleGraphStorage(BaseGraphStorage):
|
|
333 |
embeddings = np.concatenate(embeddings_list)
|
334 |
content_vector = embeddings[0]
|
335 |
merge_sql = SQL_TEMPLATES["merge_node"].format(
|
336 |
-
workspace=self.db.workspace,name=entity_name, source_chunk_id=source_id
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
337 |
)
|
338 |
-
#
|
339 |
-
await self.db.execute(merge_sql, [self.db.workspace,entity_name,entity_type,description,source_id,content,content_vector])
|
340 |
-
#self._graph.add_node(node_id, **node_data)
|
341 |
|
342 |
async def upsert_edge(
|
343 |
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
344 |
):
|
345 |
"""插入或更新边"""
|
346 |
-
#print("go into upsert edge method")
|
347 |
source_name = source_node_id
|
348 |
target_name = target_node_id
|
349 |
weight = edge_data["weight"]
|
350 |
keywords = edge_data["keywords"]
|
351 |
description = edge_data["description"]
|
352 |
source_chunk_id = edge_data["source_id"]
|
353 |
-
content = keywords+source_name+target_name+description
|
354 |
contents = [content]
|
355 |
batches = [
|
356 |
-
contents[i: i + self._max_batch_size]
|
357 |
for i in range(0, len(contents), self._max_batch_size)
|
358 |
]
|
359 |
embeddings_list = await asyncio.gather(
|
@@ -362,11 +380,27 @@ class OracleGraphStorage(BaseGraphStorage):
|
|
362 |
embeddings = np.concatenate(embeddings_list)
|
363 |
content_vector = embeddings[0]
|
364 |
merge_sql = SQL_TEMPLATES["merge_edge"].format(
|
365 |
-
workspace=self.db.workspace,
|
|
|
|
|
|
|
366 |
)
|
367 |
-
#print(merge_sql)
|
368 |
-
await self.db.execute(
|
369 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
370 |
|
371 |
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
|
372 |
"""为节点生成向量"""
|
@@ -386,99 +420,109 @@ class OracleGraphStorage(BaseGraphStorage):
|
|
386 |
nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes]
|
387 |
return embeddings, nodes_ids
|
388 |
|
389 |
-
|
390 |
async def index_done_callback(self):
|
391 |
"""写入graphhml图文件"""
|
392 |
-
logger.info(
|
393 |
-
|
|
|
|
|
394 |
#################### query method #################
|
395 |
async def has_node(self, node_id: str) -> bool:
|
396 |
-
"""根据节点id检查节点是否存在"""
|
397 |
-
SQL = SQL_TEMPLATES["has_node"].format(
|
398 |
-
|
399 |
-
|
|
|
|
|
400 |
res = await self.db.query(SQL)
|
401 |
if res:
|
402 |
-
#print("Node exist!",res)
|
403 |
return True
|
404 |
else:
|
405 |
-
#print("Node not exist!")
|
406 |
return False
|
407 |
|
408 |
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
409 |
"""根据源和目标节点id检查边是否存在"""
|
410 |
-
SQL = SQL_TEMPLATES["has_edge"].format(
|
411 |
-
|
412 |
-
|
|
|
|
|
413 |
# print(SQL)
|
414 |
res = await self.db.query(SQL)
|
415 |
if res:
|
416 |
-
#print("Edge exist!",res)
|
417 |
return True
|
418 |
else:
|
419 |
-
#print("Edge not exist!")
|
420 |
return False
|
421 |
|
422 |
async def node_degree(self, node_id: str) -> int:
|
423 |
-
"""根据节点id获取节点的度"""
|
424 |
-
SQL = SQL_TEMPLATES["node_degree"].format(
|
|
|
|
|
425 |
# print(SQL)
|
426 |
res = await self.db.query(SQL)
|
427 |
if res:
|
428 |
-
#print("Node degree",res["degree"])
|
429 |
return res["degree"]
|
430 |
else:
|
431 |
-
#print("Edge not exist!")
|
432 |
return 0
|
433 |
|
434 |
-
|
435 |
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
|
436 |
"""根据源和目标节点id获取边的度"""
|
437 |
degree = await self.node_degree(src_id) + await self.node_degree(tgt_id)
|
438 |
-
#print("Edge degree",degree)
|
439 |
return degree
|
440 |
|
441 |
-
|
442 |
async def get_node(self, node_id: str) -> Union[dict, None]:
|
443 |
"""根据节点id获取节点数据"""
|
444 |
-
SQL = SQL_TEMPLATES["get_node"].format(
|
|
|
|
|
445 |
# print(self.db.workspace, node_id)
|
446 |
# print(SQL)
|
447 |
res = await self.db.query(SQL)
|
448 |
if res:
|
449 |
-
#print("Get node!",self.db.workspace, node_id,res)
|
450 |
return res
|
451 |
else:
|
452 |
-
#print("Can't get node!",self.db.workspace, node_id)
|
453 |
return None
|
454 |
-
|
455 |
async def get_edge(
|
456 |
self, source_node_id: str, target_node_id: str
|
457 |
) -> Union[dict, None]:
|
458 |
"""根据源和目标节点id获取边"""
|
459 |
-
SQL = SQL_TEMPLATES["get_edge"].format(
|
460 |
-
|
461 |
-
|
|
|
|
|
462 |
res = await self.db.query(SQL)
|
463 |
if res:
|
464 |
-
#print("Get edge!",self.db.workspace, source_node_id, target_node_id,res[0])
|
465 |
return res
|
466 |
else:
|
467 |
-
#print("Edge not exist!",self.db.workspace, source_node_id, target_node_id)
|
468 |
return None
|
469 |
|
470 |
async def get_node_edges(self, source_node_id: str):
|
471 |
"""根据节点id获取节点的所有边"""
|
472 |
if await self.has_node(source_node_id):
|
473 |
-
SQL = SQL_TEMPLATES["get_node_edges"].format(
|
474 |
-
|
|
|
475 |
res = await self.db.query(sql=SQL, multirows=True)
|
476 |
if res:
|
477 |
-
data = [(i["source_name"],i["target_name"]) for i in res]
|
478 |
-
#print("Get node edge!",self.db.workspace, source_node_id,data)
|
479 |
return data
|
480 |
else:
|
481 |
-
#print("Node Edge not exist!",self.db.workspace, source_node_id)
|
482 |
return []
|
483 |
|
484 |
|
@@ -487,12 +531,12 @@ N_T = {
|
|
487 |
"text_chunks": "LIGHTRAG_DOC_CHUNKS",
|
488 |
"chunks": "LIGHTRAG_DOC_CHUNKS",
|
489 |
"entities": "LIGHTRAG_GRAPH_NODES",
|
490 |
-
"relationships": "LIGHTRAG_GRAPH_EDGES"
|
491 |
}
|
492 |
|
493 |
TABLES = {
|
494 |
-
"LIGHTRAG_DOC_FULL":
|
495 |
-
|
496 |
id varchar(256)PRIMARY KEY,
|
497 |
workspace varchar(1024),
|
498 |
doc_name varchar(1024),
|
@@ -500,61 +544,63 @@ TABLES = {
|
|
500 |
meta JSON,
|
501 |
createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
502 |
updatetime TIMESTAMP DEFAULT NULL
|
503 |
-
)"""
|
504 |
-
|
505 |
-
"LIGHTRAG_DOC_CHUNKS":
|
506 |
-
|
507 |
id varchar(256) PRIMARY KEY,
|
508 |
workspace varchar(1024),
|
509 |
full_doc_id varchar(256),
|
510 |
chunk_order_index NUMBER,
|
511 |
-
tokens NUMBER,
|
512 |
content CLOB,
|
513 |
content_vector VECTOR,
|
514 |
createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
515 |
-
updatetime TIMESTAMP DEFAULT NULL
|
516 |
-
)"""
|
517 |
-
|
518 |
-
"LIGHTRAG_GRAPH_NODES":
|
519 |
-
|
520 |
id NUMBER GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY,
|
521 |
workspace varchar(1024),
|
522 |
name varchar(2048),
|
523 |
-
entity_type varchar(1024),
|
524 |
description CLOB,
|
525 |
source_chunk_id varchar(256),
|
526 |
content CLOB,
|
527 |
content_vector VECTOR,
|
528 |
createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
529 |
updatetime TIMESTAMP DEFAULT NULL
|
530 |
-
)"""
|
531 |
-
|
532 |
-
|
|
|
533 |
id NUMBER GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY,
|
534 |
workspace varchar(1024),
|
535 |
source_name varchar(2048),
|
536 |
-
target_name varchar(2048),
|
537 |
weight NUMBER,
|
538 |
-
keywords CLOB,
|
539 |
description CLOB,
|
540 |
source_chunk_id varchar(256),
|
541 |
content CLOB,
|
542 |
content_vector VECTOR,
|
543 |
createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
544 |
updatetime TIMESTAMP DEFAULT NULL
|
545 |
-
)"""
|
546 |
-
|
547 |
-
|
|
|
548 |
id varchar(256) PRIMARY KEY,
|
549 |
send clob,
|
550 |
return clob,
|
551 |
model varchar(1024),
|
552 |
createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
553 |
updatetime TIMESTAMP DEFAULT NULL
|
554 |
-
)"""
|
555 |
-
|
556 |
-
"LIGHTRAG_GRAPH":
|
557 |
-
|
558 |
VERTEX TABLES (
|
559 |
lightrag_graph_nodes KEY (id)
|
560 |
LABEL entity
|
@@ -565,93 +611,67 @@ TABLES = {
|
|
565 |
SOURCE KEY (source_name) REFERENCES lightrag_graph_nodes(name)
|
566 |
DESTINATION KEY (target_name) REFERENCES lightrag_graph_nodes(name)
|
567 |
LABEL has_relation
|
568 |
-
PROPERTIES (id,workspace,source_name,target_name) -- ,weight, keywords,description,source_chunk_id)
|
569 |
-
) OPTIONS(ALLOW MIXED PROPERTY TYPES)"""
|
570 |
-
}
|
|
|
571 |
|
572 |
|
573 |
SQL_TEMPLATES = {
|
574 |
# SQL for KVStorage
|
575 |
-
"get_by_id_full_docs":
|
576 |
-
|
577 |
-
|
578 |
-
"
|
579 |
-
|
580 |
-
|
581 |
-
"get_by_ids_full_docs":
|
582 |
-
"select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace='{workspace}' and ID in ({ids})",
|
583 |
-
|
584 |
-
"get_by_ids_text_chunks":
|
585 |
-
"select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace='{workspace}' and ID in ({ids})",
|
586 |
-
|
587 |
-
"filter_keys":
|
588 |
-
"select id from {table_name} where workspace='{workspace}' and id in ({ids})",
|
589 |
-
|
590 |
"merge_doc_full": """ MERGE INTO LIGHTRAG_DOC_FULL a
|
591 |
USING DUAL
|
592 |
ON (a.id = '{check_id}')
|
593 |
WHEN NOT MATCHED THEN
|
594 |
INSERT(id,content,workspace) values(:1,:2,:3)
|
595 |
""",
|
596 |
-
|
597 |
"merge_chunk": """MERGE INTO LIGHTRAG_DOC_CHUNKS a
|
598 |
USING DUAL
|
599 |
ON (a.id = '{check_id}')
|
600 |
WHEN NOT MATCHED THEN
|
601 |
INSERT(id,content,workspace,tokens,chunk_order_index,full_doc_id,content_vector)
|
602 |
values (:1,:2,:3,:4,:5,:6,:7) """,
|
603 |
-
|
604 |
# SQL for VectorStorage
|
605 |
-
"entities":
|
606 |
-
|
607 |
-
|
608 |
-
FROM LIGHTRAG_GRAPH_NODES WHERE workspace='{workspace}')
|
609 |
WHERE distance>{better_than_threshold} ORDER BY distance ASC FETCH FIRST {top_k} ROWS ONLY""",
|
610 |
-
|
611 |
-
|
612 |
-
|
613 |
-
(SELECT id,source_name,target_name,VECTOR_DISTANCE(content_vector,vector('[{embedding_string}]',{dimension},{dtype}),COSINE) as distance
|
614 |
-
FROM LIGHTRAG_GRAPH_EDGES WHERE workspace='{workspace}')
|
615 |
WHERE distance>{better_than_threshold} ORDER BY distance ASC FETCH FIRST {top_k} ROWS ONLY""",
|
616 |
-
|
617 |
-
|
618 |
-
|
619 |
-
(SELECT id,VECTOR_DISTANCE(content_vector,vector('[{embedding_string}]',{dimension},{dtype}),COSINE) as distance
|
620 |
-
FROM LIGHTRAG_DOC_CHUNKS WHERE workspace='{workspace}')
|
621 |
WHERE distance>{better_than_threshold} ORDER BY distance ASC FETCH FIRST {top_k} ROWS ONLY""",
|
622 |
-
|
623 |
# SQL for GraphStorage
|
624 |
-
"has_node":
|
625 |
-
"""SELECT * FROM GRAPH_TABLE (lightrag_graph
|
626 |
MATCH (a)
|
627 |
WHERE a.workspace='{workspace}' AND a.name='{node_id}'
|
628 |
COLUMNS (a.name))""",
|
629 |
-
|
630 |
-
"has_edge":
|
631 |
-
"""SELECT * FROM GRAPH_TABLE (lightrag_graph
|
632 |
MATCH (a) -[e]-> (b)
|
633 |
WHERE e.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}'
|
634 |
AND a.name='{source_node_id}' AND b.name='{target_node_id}'
|
635 |
COLUMNS (e.source_name,e.target_name) )""",
|
636 |
-
|
637 |
-
"node_degree":
|
638 |
-
"""SELECT count(1) as degree FROM GRAPH_TABLE (lightrag_graph
|
639 |
MATCH (a)-[e]->(b)
|
640 |
WHERE a.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}'
|
641 |
AND a.name='{node_id}' or b.name = '{node_id}'
|
642 |
COLUMNS (a.name))""",
|
643 |
-
|
644 |
-
"get_node":
|
645 |
-
"""SELECT t1.name,t2.entity_type,t2.source_chunk_id as source_id,NVL(t2.description,'') AS description
|
646 |
FROM GRAPH_TABLE (lightrag_graph
|
647 |
-
MATCH (a)
|
648 |
WHERE a.workspace='{workspace}' AND a.name='{node_id}'
|
649 |
COLUMNS (a.name)
|
650 |
) t1 JOIN LIGHTRAG_GRAPH_NODES t2 on t1.name=t2.name
|
651 |
WHERE t2.workspace='{workspace}'""",
|
652 |
-
|
653 |
-
"get_edge":
|
654 |
-
"""SELECT t1.source_id,t2.weight,t2.source_chunk_id as source_id,t2.keywords,
|
655 |
NVL(t2.description,'') AS description,NVL(t2.KEYWORDS,'') AS keywords
|
656 |
FROM GRAPH_TABLE (lightrag_graph
|
657 |
MATCH (a)-[e]->(b)
|
@@ -659,15 +679,12 @@ SQL_TEMPLATES = {
|
|
659 |
AND a.name='{source_node_id}' and b.name = '{target_node_id}'
|
660 |
COLUMNS (e.id,a.name as source_id)
|
661 |
) t1 JOIN LIGHTRAG_GRAPH_EDGES t2 on t1.id=t2.id""",
|
662 |
-
|
663 |
-
"get_node_edges":
|
664 |
-
"""SELECT source_name,target_name
|
665 |
FROM GRAPH_TABLE (lightrag_graph
|
666 |
MATCH (a)-[e]->(b)
|
667 |
WHERE e.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}'
|
668 |
AND a.name='{source_node_id}'
|
669 |
COLUMNS (a.name as source_name,b.name as target_name))""",
|
670 |
-
|
671 |
"merge_node": """MERGE INTO LIGHTRAG_GRAPH_NODES a
|
672 |
USING DUAL
|
673 |
ON (a.workspace = '{workspace}' and a.name='{name}' and a.source_chunk_id='{source_chunk_id}')
|
@@ -679,5 +696,5 @@ SQL_TEMPLATES = {
|
|
679 |
ON (a.workspace = '{workspace}' and a.source_name='{source_name}' and a.target_name='{target_name}' and a.source_chunk_id='{source_chunk_id}')
|
680 |
WHEN NOT MATCHED THEN
|
681 |
INSERT(workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector)
|
682 |
-
values (:1,:2,:3,:4,:5,:6,:7,:8,:9) """
|
683 |
-
|
|
|
1 |
import asyncio
|
2 |
+
|
3 |
+
# import html
|
4 |
+
# import os
|
5 |
from dataclasses import dataclass
|
6 |
+
from typing import Union
|
|
|
7 |
import numpy as np
|
8 |
import array
|
9 |
|
|
|
16 |
|
17 |
import oracledb
|
18 |
|
19 |
+
|
20 |
class OracleDB:
|
21 |
+
def __init__(self, config, **kwargs):
|
22 |
self.host = config.get("host", None)
|
23 |
self.port = config.get("port", None)
|
24 |
self.user = config.get("user", None)
|
|
|
33 |
logger.info(f"Using the label {self.workspace} for Oracle Graph as identifier")
|
34 |
if self.user is None or self.password is None:
|
35 |
raise ValueError("Missing database user or password in addon_params")
|
36 |
+
|
37 |
try:
|
38 |
oracledb.defaults.fetch_lobs = False
|
39 |
|
40 |
self.pool = oracledb.create_pool_async(
|
41 |
+
user=self.user,
|
42 |
+
password=self.password,
|
43 |
+
dsn=self.dsn,
|
44 |
+
config_dir=self.config_dir,
|
45 |
+
wallet_location=self.wallet_location,
|
46 |
+
wallet_password=self.wallet_password,
|
47 |
+
min=1,
|
48 |
+
max=self.max,
|
49 |
+
increment=self.increment,
|
50 |
+
)
|
51 |
logger.info(f"Connected to Oracle database at {self.dsn}")
|
52 |
except Exception as e:
|
53 |
logger.error(f"Failed to connect to Oracle database at {self.dsn}")
|
|
|
91 |
arraysize=cursor.arraysize,
|
92 |
outconverter=self.numpy_converter_out,
|
93 |
)
|
94 |
+
|
95 |
async def check_tables(self):
|
96 |
+
for k, v in TABLES.items():
|
97 |
try:
|
98 |
if k.lower() == "lightrag_graph":
|
99 |
+
await self.query(
|
100 |
+
"SELECT id FROM GRAPH_TABLE (lightrag_graph MATCH (a) COLUMNS (a.id)) fetch first row only"
|
101 |
+
)
|
102 |
else:
|
103 |
await self.query("SELECT 1 FROM {k}".format(k=k))
|
104 |
except Exception as e:
|
|
|
111 |
except Exception as e:
|
112 |
logger.error(f"Failed to create table {k} in Oracle database")
|
113 |
logger.error(f"Oracle database error: {e}")
|
114 |
+
|
115 |
+
logger.info("Finished check all tables in Oracle database")
|
116 |
+
|
117 |
+
async def query(self, sql: str, multirows: bool = False) -> Union[dict, None]:
|
118 |
+
async with self.pool.acquire() as connection:
|
|
|
119 |
connection.inputtypehandler = self.input_type_handler
|
120 |
connection.outputtypehandler = self.output_type_handler
|
121 |
with connection.cursor() as cursor:
|
|
|
138 |
data = dict(zip(columns, row))
|
139 |
else:
|
140 |
data = None
|
141 |
+
return data
|
142 |
|
143 |
+
async def execute(self, sql: str, data: list = None):
|
144 |
# logger.info("go into OracleDB execute method")
|
145 |
try:
|
146 |
async with self.pool.acquire() as connection:
|
|
|
150 |
if data is None:
|
151 |
await cursor.execute(sql)
|
152 |
else:
|
153 |
+
# print(data)
|
154 |
+
# print(sql)
|
155 |
+
await cursor.execute(sql, data)
|
156 |
await connection.commit()
|
157 |
except Exception as e:
|
158 |
+
logger.error(f"Oracle database error: {e}")
|
159 |
print(sql)
|
160 |
print(data)
|
161 |
raise
|
162 |
|
163 |
+
|
164 |
@dataclass
|
165 |
class OracleKVStorage(BaseKVStorage):
|
|
|
166 |
# should pass db object to self.db
|
167 |
def __post_init__(self):
|
168 |
self._data = {}
|
169 |
+
self._max_batch_size = self.global_config["embedding_batch_num"]
|
170 |
+
|
171 |
################ QUERY METHODS ################
|
172 |
|
173 |
async def get_by_id(self, id: str) -> Union[dict, None]:
|
174 |
"""根据 id 获取 doc_full 数据."""
|
175 |
+
SQL = SQL_TEMPLATES["get_by_id_" + self.namespace].format(
|
176 |
+
workspace=self.db.workspace, id=id
|
177 |
+
)
|
178 |
+
# print("get_by_id:"+SQL)
|
179 |
+
res = await self.db.query(SQL)
|
180 |
if res:
|
181 |
+
data = res # {"data":res}
|
182 |
+
# print (data)
|
183 |
return data
|
184 |
else:
|
185 |
return None
|
186 |
|
187 |
# Query by id
|
188 |
+
async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]:
|
189 |
"""根据 id 获取 doc_chunks 数据"""
|
190 |
+
SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
|
191 |
+
workspace=self.db.workspace, ids=",".join([f"'{id}'" for id in ids])
|
192 |
+
)
|
193 |
+
# print("get_by_ids:"+SQL)
|
194 |
+
res = await self.db.query(SQL, multirows=True)
|
195 |
if res:
|
196 |
+
data = res # [{"data":i} for i in res]
|
197 |
+
# print(data)
|
198 |
return data
|
199 |
else:
|
200 |
return None
|
201 |
+
|
202 |
async def filter_keys(self, keys: list[str]) -> set[str]:
|
203 |
"""过滤掉重复内容"""
|
204 |
+
SQL = SQL_TEMPLATES["filter_keys"].format(
|
205 |
+
table_name=N_T[self.namespace],
|
206 |
+
workspace=self.db.workspace,
|
207 |
+
ids=",".join([f"'{k}'" for k in keys]),
|
208 |
+
)
|
209 |
+
res = await self.db.query(SQL, multirows=True)
|
210 |
data = None
|
211 |
if res:
|
212 |
exist_keys = [key["id"] for key in res]
|
|
|
215 |
exist_keys = []
|
216 |
data = set([s for s in keys if s not in exist_keys])
|
217 |
return data
|
218 |
+
|
|
|
219 |
################ INSERT METHODS ################
|
220 |
async def upsert(self, data: dict[str, dict]):
|
221 |
left_data = {k: v for k, v in data.items() if k not in self._data}
|
222 |
self._data.update(left_data)
|
223 |
+
# print(self._data)
|
224 |
+
# values = []
|
225 |
if self.namespace == "text_chunks":
|
226 |
list_data = [
|
227 |
{
|
|
|
232 |
]
|
233 |
contents = [v["content"] for v in data.values()]
|
234 |
batches = [
|
235 |
+
contents[i : i + self._max_batch_size]
|
236 |
for i in range(0, len(contents), self._max_batch_size)
|
237 |
]
|
238 |
embeddings_list = await asyncio.gather(
|
|
|
241 |
embeddings = np.concatenate(embeddings_list)
|
242 |
for i, d in enumerate(list_data):
|
243 |
d["__vector__"] = embeddings[i]
|
244 |
+
# print(list_data)
|
245 |
for item in list_data:
|
246 |
+
merge_sql = SQL_TEMPLATES["merge_chunk"].format(check_id=item["__id__"])
|
247 |
+
|
248 |
+
values = [
|
249 |
+
item["__id__"],
|
250 |
+
item["content"],
|
251 |
+
self.db.workspace,
|
252 |
+
item["tokens"],
|
253 |
+
item["chunk_order_index"],
|
254 |
+
item["full_doc_id"],
|
255 |
+
item["__vector__"],
|
256 |
+
]
|
257 |
+
# print(merge_sql)
|
258 |
await self.db.execute(merge_sql, values)
|
259 |
|
260 |
if self.namespace == "full_docs":
|
261 |
for k, v in self._data.items():
|
262 |
+
# values.clear()
|
263 |
merge_sql = SQL_TEMPLATES["merge_doc_full"].format(
|
264 |
check_id=k,
|
265 |
)
|
266 |
values = [k, self._data[k]["content"], self.db.workspace]
|
267 |
+
# print(merge_sql)
|
268 |
await self.db.execute(merge_sql, values)
|
269 |
return left_data
|
270 |
|
|
|
271 |
async def index_done_callback(self):
|
272 |
if self.namespace in ["full_docs", "text_chunks"]:
|
273 |
logger.info("full doc and chunk data had been saved into oracle db!")
|
274 |
|
275 |
|
|
|
276 |
@dataclass
|
277 |
class OracleVectorDBStorage(BaseVectorStorage):
|
278 |
cosine_better_than_threshold: float = 0.2
|
279 |
|
280 |
def __post_init__(self):
|
281 |
pass
|
282 |
+
|
283 |
async def upsert(self, data: dict[str, dict]):
|
284 |
"""向向量数据库中插入数据"""
|
285 |
pass
|
|
|
287 |
async def index_done_callback(self):
|
288 |
pass
|
289 |
|
|
|
290 |
#################### query method ###############
|
291 |
async def query(self, query: str, top_k=5) -> Union[dict, list[dict]]:
|
292 |
+
"""从向量数据库中查询数据"""
|
293 |
embeddings = await self.embedding_func([query])
|
294 |
embedding = embeddings[0]
|
295 |
# 转换精度
|
296 |
dtype = str(embedding.dtype).upper()
|
297 |
dimension = embedding.shape[0]
|
298 |
+
embedding_string = ", ".join(map(str, embedding.tolist()))
|
299 |
|
300 |
SQL = SQL_TEMPLATES[self.namespace].format(
|
301 |
+
embedding_string=embedding_string,
|
302 |
+
dimension=dimension,
|
303 |
+
dtype=dtype,
|
304 |
+
workspace=self.db.workspace,
|
305 |
+
top_k=top_k,
|
306 |
+
better_than_threshold=self.cosine_better_than_threshold,
|
307 |
+
)
|
308 |
# print(SQL)
|
309 |
results = await self.db.query(SQL, multirows=True)
|
310 |
+
# print("vector search result:",results)
|
311 |
return results
|
312 |
|
313 |
|
314 |
@dataclass
|
315 |
+
class OracleGraphStorage(BaseGraphStorage):
|
316 |
"""基于Oracle的图存储模块"""
|
317 |
+
|
318 |
def __post_init__(self):
|
319 |
"""从graphml文件加载图"""
|
320 |
self._max_batch_size = self.global_config["embedding_batch_num"]
|
321 |
|
|
|
322 |
#################### insert method ################
|
323 |
+
|
324 |
async def upsert_node(self, node_id: str, node_data: dict[str, str]):
|
325 |
"""插入或更新节点"""
|
326 |
+
# print("go into upsert node method")
|
327 |
entity_name = node_id
|
328 |
entity_type = node_data["entity_type"]
|
329 |
description = node_data["description"]
|
330 |
+
source_id = node_data["source_id"]
|
331 |
+
content = entity_name + description
|
332 |
contents = [content]
|
333 |
batches = [
|
334 |
+
contents[i : i + self._max_batch_size]
|
335 |
for i in range(0, len(contents), self._max_batch_size)
|
336 |
]
|
337 |
embeddings_list = await asyncio.gather(
|
|
|
340 |
embeddings = np.concatenate(embeddings_list)
|
341 |
content_vector = embeddings[0]
|
342 |
merge_sql = SQL_TEMPLATES["merge_node"].format(
|
343 |
+
workspace=self.db.workspace, name=entity_name, source_chunk_id=source_id
|
344 |
+
)
|
345 |
+
# print(merge_sql)
|
346 |
+
await self.db.execute(
|
347 |
+
merge_sql,
|
348 |
+
[
|
349 |
+
self.db.workspace,
|
350 |
+
entity_name,
|
351 |
+
entity_type,
|
352 |
+
description,
|
353 |
+
source_id,
|
354 |
+
content,
|
355 |
+
content_vector,
|
356 |
+
],
|
357 |
)
|
358 |
+
# self._graph.add_node(node_id, **node_data)
|
|
|
|
|
359 |
|
360 |
async def upsert_edge(
|
361 |
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
362 |
):
|
363 |
"""插入或更新边"""
|
364 |
+
# print("go into upsert edge method")
|
365 |
source_name = source_node_id
|
366 |
target_name = target_node_id
|
367 |
weight = edge_data["weight"]
|
368 |
keywords = edge_data["keywords"]
|
369 |
description = edge_data["description"]
|
370 |
source_chunk_id = edge_data["source_id"]
|
371 |
+
content = keywords + source_name + target_name + description
|
372 |
contents = [content]
|
373 |
batches = [
|
374 |
+
contents[i : i + self._max_batch_size]
|
375 |
for i in range(0, len(contents), self._max_batch_size)
|
376 |
]
|
377 |
embeddings_list = await asyncio.gather(
|
|
|
380 |
embeddings = np.concatenate(embeddings_list)
|
381 |
content_vector = embeddings[0]
|
382 |
merge_sql = SQL_TEMPLATES["merge_edge"].format(
|
383 |
+
workspace=self.db.workspace,
|
384 |
+
source_name=source_name,
|
385 |
+
target_name=target_name,
|
386 |
+
source_chunk_id=source_chunk_id,
|
387 |
)
|
388 |
+
# print(merge_sql)
|
389 |
+
await self.db.execute(
|
390 |
+
merge_sql,
|
391 |
+
[
|
392 |
+
self.db.workspace,
|
393 |
+
source_name,
|
394 |
+
target_name,
|
395 |
+
weight,
|
396 |
+
keywords,
|
397 |
+
description,
|
398 |
+
source_chunk_id,
|
399 |
+
content,
|
400 |
+
content_vector,
|
401 |
+
],
|
402 |
+
)
|
403 |
+
# self._graph.add_edge(source_node_id, target_node_id, **edge_data)
|
404 |
|
405 |
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
|
406 |
"""为节点生成向量"""
|
|
|
420 |
nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes]
|
421 |
return embeddings, nodes_ids
|
422 |
|
|
|
423 |
async def index_done_callback(self):
|
424 |
"""写入graphhml图文件"""
|
425 |
+
logger.info(
|
426 |
+
"Node and edge data had been saved into oracle db already, so nothing to do here!"
|
427 |
+
)
|
428 |
+
|
429 |
#################### query method #################
|
430 |
async def has_node(self, node_id: str) -> bool:
|
431 |
+
"""根据节点id检查节点是否存在"""
|
432 |
+
SQL = SQL_TEMPLATES["has_node"].format(
|
433 |
+
workspace=self.db.workspace, node_id=node_id
|
434 |
+
)
|
435 |
+
# print(SQL)
|
436 |
+
# print(self.db.workspace, node_id)
|
437 |
res = await self.db.query(SQL)
|
438 |
if res:
|
439 |
+
# print("Node exist!",res)
|
440 |
return True
|
441 |
else:
|
442 |
+
# print("Node not exist!")
|
443 |
return False
|
444 |
|
445 |
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
446 |
"""根据源和目标节点id检查边是否存在"""
|
447 |
+
SQL = SQL_TEMPLATES["has_edge"].format(
|
448 |
+
workspace=self.db.workspace,
|
449 |
+
source_node_id=source_node_id,
|
450 |
+
target_node_id=target_node_id,
|
451 |
+
)
|
452 |
# print(SQL)
|
453 |
res = await self.db.query(SQL)
|
454 |
if res:
|
455 |
+
# print("Edge exist!",res)
|
456 |
return True
|
457 |
else:
|
458 |
+
# print("Edge not exist!")
|
459 |
return False
|
460 |
|
461 |
async def node_degree(self, node_id: str) -> int:
|
462 |
+
"""根据节点id获取节点的度"""
|
463 |
+
SQL = SQL_TEMPLATES["node_degree"].format(
|
464 |
+
workspace=self.db.workspace, node_id=node_id
|
465 |
+
)
|
466 |
# print(SQL)
|
467 |
res = await self.db.query(SQL)
|
468 |
if res:
|
469 |
+
# print("Node degree",res["degree"])
|
470 |
return res["degree"]
|
471 |
else:
|
472 |
+
# print("Edge not exist!")
|
473 |
return 0
|
474 |
|
|
|
475 |
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
|
476 |
"""根据源和目标节点id获取边的度"""
|
477 |
degree = await self.node_degree(src_id) + await self.node_degree(tgt_id)
|
478 |
+
# print("Edge degree",degree)
|
479 |
return degree
|
480 |
|
|
|
481 |
async def get_node(self, node_id: str) -> Union[dict, None]:
|
482 |
"""根据节点id获取节点数据"""
|
483 |
+
SQL = SQL_TEMPLATES["get_node"].format(
|
484 |
+
workspace=self.db.workspace, node_id=node_id
|
485 |
+
)
|
486 |
# print(self.db.workspace, node_id)
|
487 |
# print(SQL)
|
488 |
res = await self.db.query(SQL)
|
489 |
if res:
|
490 |
+
# print("Get node!",self.db.workspace, node_id,res)
|
491 |
return res
|
492 |
else:
|
493 |
+
# print("Can't get node!",self.db.workspace, node_id)
|
494 |
return None
|
495 |
+
|
496 |
async def get_edge(
|
497 |
self, source_node_id: str, target_node_id: str
|
498 |
) -> Union[dict, None]:
|
499 |
"""根据源和目标节点id获取边"""
|
500 |
+
SQL = SQL_TEMPLATES["get_edge"].format(
|
501 |
+
workspace=self.db.workspace,
|
502 |
+
source_node_id=source_node_id,
|
503 |
+
target_node_id=target_node_id,
|
504 |
+
)
|
505 |
res = await self.db.query(SQL)
|
506 |
if res:
|
507 |
+
# print("Get edge!",self.db.workspace, source_node_id, target_node_id,res[0])
|
508 |
return res
|
509 |
else:
|
510 |
+
# print("Edge not exist!",self.db.workspace, source_node_id, target_node_id)
|
511 |
return None
|
512 |
|
513 |
async def get_node_edges(self, source_node_id: str):
|
514 |
"""根据节点id获取节点的所有边"""
|
515 |
if await self.has_node(source_node_id):
|
516 |
+
SQL = SQL_TEMPLATES["get_node_edges"].format(
|
517 |
+
workspace=self.db.workspace, source_node_id=source_node_id
|
518 |
+
)
|
519 |
res = await self.db.query(sql=SQL, multirows=True)
|
520 |
if res:
|
521 |
+
data = [(i["source_name"], i["target_name"]) for i in res]
|
522 |
+
# print("Get node edge!",self.db.workspace, source_node_id,data)
|
523 |
return data
|
524 |
else:
|
525 |
+
# print("Node Edge not exist!",self.db.workspace, source_node_id)
|
526 |
return []
|
527 |
|
528 |
|
|
|
531 |
"text_chunks": "LIGHTRAG_DOC_CHUNKS",
|
532 |
"chunks": "LIGHTRAG_DOC_CHUNKS",
|
533 |
"entities": "LIGHTRAG_GRAPH_NODES",
|
534 |
+
"relationships": "LIGHTRAG_GRAPH_EDGES",
|
535 |
}
|
536 |
|
537 |
TABLES = {
|
538 |
+
"LIGHTRAG_DOC_FULL": {
|
539 |
+
"ddl": """CREATE TABLE LIGHTRAG_DOC_FULL (
|
540 |
id varchar(256)PRIMARY KEY,
|
541 |
workspace varchar(1024),
|
542 |
doc_name varchar(1024),
|
|
|
544 |
meta JSON,
|
545 |
createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
546 |
updatetime TIMESTAMP DEFAULT NULL
|
547 |
+
)"""
|
548 |
+
},
|
549 |
+
"LIGHTRAG_DOC_CHUNKS": {
|
550 |
+
"ddl": """CREATE TABLE LIGHTRAG_DOC_CHUNKS (
|
551 |
id varchar(256) PRIMARY KEY,
|
552 |
workspace varchar(1024),
|
553 |
full_doc_id varchar(256),
|
554 |
chunk_order_index NUMBER,
|
555 |
+
tokens NUMBER,
|
556 |
content CLOB,
|
557 |
content_vector VECTOR,
|
558 |
createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
559 |
+
updatetime TIMESTAMP DEFAULT NULL
|
560 |
+
)"""
|
561 |
+
},
|
562 |
+
"LIGHTRAG_GRAPH_NODES": {
|
563 |
+
"ddl": """CREATE TABLE LIGHTRAG_GRAPH_NODES (
|
564 |
id NUMBER GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY,
|
565 |
workspace varchar(1024),
|
566 |
name varchar(2048),
|
567 |
+
entity_type varchar(1024),
|
568 |
description CLOB,
|
569 |
source_chunk_id varchar(256),
|
570 |
content CLOB,
|
571 |
content_vector VECTOR,
|
572 |
createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
573 |
updatetime TIMESTAMP DEFAULT NULL
|
574 |
+
)"""
|
575 |
+
},
|
576 |
+
"LIGHTRAG_GRAPH_EDGES": {
|
577 |
+
"ddl": """CREATE TABLE LIGHTRAG_GRAPH_EDGES (
|
578 |
id NUMBER GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY,
|
579 |
workspace varchar(1024),
|
580 |
source_name varchar(2048),
|
581 |
+
target_name varchar(2048),
|
582 |
weight NUMBER,
|
583 |
+
keywords CLOB,
|
584 |
description CLOB,
|
585 |
source_chunk_id varchar(256),
|
586 |
content CLOB,
|
587 |
content_vector VECTOR,
|
588 |
createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
589 |
updatetime TIMESTAMP DEFAULT NULL
|
590 |
+
)"""
|
591 |
+
},
|
592 |
+
"LIGHTRAG_LLM_CACHE": {
|
593 |
+
"ddl": """CREATE TABLE LIGHTRAG_LLM_CACHE (
|
594 |
id varchar(256) PRIMARY KEY,
|
595 |
send clob,
|
596 |
return clob,
|
597 |
model varchar(1024),
|
598 |
createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
599 |
updatetime TIMESTAMP DEFAULT NULL
|
600 |
+
)"""
|
601 |
+
},
|
602 |
+
"LIGHTRAG_GRAPH": {
|
603 |
+
"ddl": """CREATE OR REPLACE PROPERTY GRAPH lightrag_graph
|
604 |
VERTEX TABLES (
|
605 |
lightrag_graph_nodes KEY (id)
|
606 |
LABEL entity
|
|
|
611 |
SOURCE KEY (source_name) REFERENCES lightrag_graph_nodes(name)
|
612 |
DESTINATION KEY (target_name) REFERENCES lightrag_graph_nodes(name)
|
613 |
LABEL has_relation
|
614 |
+
PROPERTIES (id,workspace,source_name,target_name) -- ,weight, keywords,description,source_chunk_id)
|
615 |
+
) OPTIONS(ALLOW MIXED PROPERTY TYPES)"""
|
616 |
+
},
|
617 |
+
}
|
618 |
|
619 |
|
620 |
SQL_TEMPLATES = {
|
621 |
# SQL for KVStorage
|
622 |
+
"get_by_id_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace='{workspace}' and ID='{id}'",
|
623 |
+
"get_by_id_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace='{workspace}' and ID='{id}'",
|
624 |
+
"get_by_ids_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace='{workspace}' and ID in ({ids})",
|
625 |
+
"get_by_ids_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace='{workspace}' and ID in ({ids})",
|
626 |
+
"filter_keys": "select id from {table_name} where workspace='{workspace}' and id in ({ids})",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
627 |
"merge_doc_full": """ MERGE INTO LIGHTRAG_DOC_FULL a
|
628 |
USING DUAL
|
629 |
ON (a.id = '{check_id}')
|
630 |
WHEN NOT MATCHED THEN
|
631 |
INSERT(id,content,workspace) values(:1,:2,:3)
|
632 |
""",
|
|
|
633 |
"merge_chunk": """MERGE INTO LIGHTRAG_DOC_CHUNKS a
|
634 |
USING DUAL
|
635 |
ON (a.id = '{check_id}')
|
636 |
WHEN NOT MATCHED THEN
|
637 |
INSERT(id,content,workspace,tokens,chunk_order_index,full_doc_id,content_vector)
|
638 |
values (:1,:2,:3,:4,:5,:6,:7) """,
|
|
|
639 |
# SQL for VectorStorage
|
640 |
+
"entities": """SELECT name as entity_name FROM
|
641 |
+
(SELECT id,name,VECTOR_DISTANCE(content_vector,vector('[{embedding_string}]',{dimension},{dtype}),COSINE) as distance
|
642 |
+
FROM LIGHTRAG_GRAPH_NODES WHERE workspace='{workspace}')
|
|
|
643 |
WHERE distance>{better_than_threshold} ORDER BY distance ASC FETCH FIRST {top_k} ROWS ONLY""",
|
644 |
+
"relationships": """SELECT source_name as src_id, target_name as tgt_id FROM
|
645 |
+
(SELECT id,source_name,target_name,VECTOR_DISTANCE(content_vector,vector('[{embedding_string}]',{dimension},{dtype}),COSINE) as distance
|
646 |
+
FROM LIGHTRAG_GRAPH_EDGES WHERE workspace='{workspace}')
|
|
|
|
|
647 |
WHERE distance>{better_than_threshold} ORDER BY distance ASC FETCH FIRST {top_k} ROWS ONLY""",
|
648 |
+
"chunks": """SELECT id FROM
|
649 |
+
(SELECT id,VECTOR_DISTANCE(content_vector,vector('[{embedding_string}]',{dimension},{dtype}),COSINE) as distance
|
650 |
+
FROM LIGHTRAG_DOC_CHUNKS WHERE workspace='{workspace}')
|
|
|
|
|
651 |
WHERE distance>{better_than_threshold} ORDER BY distance ASC FETCH FIRST {top_k} ROWS ONLY""",
|
|
|
652 |
# SQL for GraphStorage
|
653 |
+
"has_node": """SELECT * FROM GRAPH_TABLE (lightrag_graph
|
|
|
654 |
MATCH (a)
|
655 |
WHERE a.workspace='{workspace}' AND a.name='{node_id}'
|
656 |
COLUMNS (a.name))""",
|
657 |
+
"has_edge": """SELECT * FROM GRAPH_TABLE (lightrag_graph
|
|
|
|
|
658 |
MATCH (a) -[e]-> (b)
|
659 |
WHERE e.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}'
|
660 |
AND a.name='{source_node_id}' AND b.name='{target_node_id}'
|
661 |
COLUMNS (e.source_name,e.target_name) )""",
|
662 |
+
"node_degree": """SELECT count(1) as degree FROM GRAPH_TABLE (lightrag_graph
|
|
|
|
|
663 |
MATCH (a)-[e]->(b)
|
664 |
WHERE a.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}'
|
665 |
AND a.name='{node_id}' or b.name = '{node_id}'
|
666 |
COLUMNS (a.name))""",
|
667 |
+
"get_node": """SELECT t1.name,t2.entity_type,t2.source_chunk_id as source_id,NVL(t2.description,'') AS description
|
|
|
|
|
668 |
FROM GRAPH_TABLE (lightrag_graph
|
669 |
+
MATCH (a)
|
670 |
WHERE a.workspace='{workspace}' AND a.name='{node_id}'
|
671 |
COLUMNS (a.name)
|
672 |
) t1 JOIN LIGHTRAG_GRAPH_NODES t2 on t1.name=t2.name
|
673 |
WHERE t2.workspace='{workspace}'""",
|
674 |
+
"get_edge": """SELECT t1.source_id,t2.weight,t2.source_chunk_id as source_id,t2.keywords,
|
|
|
|
|
675 |
NVL(t2.description,'') AS description,NVL(t2.KEYWORDS,'') AS keywords
|
676 |
FROM GRAPH_TABLE (lightrag_graph
|
677 |
MATCH (a)-[e]->(b)
|
|
|
679 |
AND a.name='{source_node_id}' and b.name = '{target_node_id}'
|
680 |
COLUMNS (e.id,a.name as source_id)
|
681 |
) t1 JOIN LIGHTRAG_GRAPH_EDGES t2 on t1.id=t2.id""",
|
682 |
+
"get_node_edges": """SELECT source_name,target_name
|
|
|
|
|
683 |
FROM GRAPH_TABLE (lightrag_graph
|
684 |
MATCH (a)-[e]->(b)
|
685 |
WHERE e.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}'
|
686 |
AND a.name='{source_node_id}'
|
687 |
COLUMNS (a.name as source_name,b.name as target_name))""",
|
|
|
688 |
"merge_node": """MERGE INTO LIGHTRAG_GRAPH_NODES a
|
689 |
USING DUAL
|
690 |
ON (a.workspace = '{workspace}' and a.name='{name}' and a.source_chunk_id='{source_chunk_id}')
|
|
|
696 |
ON (a.workspace = '{workspace}' and a.source_name='{source_name}' and a.target_name='{target_name}' and a.source_chunk_id='{source_chunk_id}')
|
697 |
WHEN NOT MATCHED THEN
|
698 |
INSERT(workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector)
|
699 |
+
values (:1,:2,:3,:4,:5,:6,:7,:8,:9) """,
|
700 |
+
}
|
lightrag/lightrag.py
CHANGED
@@ -38,15 +38,11 @@ from .storage import (
|
|
38 |
JsonKVStorage,
|
39 |
NanoVectorDBStorage,
|
40 |
NetworkXStorage,
|
41 |
-
|
42 |
|
43 |
from .kg.neo4j_impl import Neo4JStorage
|
44 |
|
45 |
-
from .kg.oracle_impl import
|
46 |
-
OracleKVStorage,
|
47 |
-
OracleGraphStorage,
|
48 |
-
OracleVectorDBStorage
|
49 |
-
)
|
50 |
|
51 |
# future KG integrations
|
52 |
|
@@ -54,6 +50,7 @@ from .kg.oracle_impl import (
|
|
54 |
# GraphStorage as ArangoDBStorage
|
55 |
# )
|
56 |
|
|
|
57 |
def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
|
58 |
try:
|
59 |
return asyncio.get_event_loop()
|
@@ -72,7 +69,7 @@ class LightRAG:
|
|
72 |
default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
|
73 |
)
|
74 |
|
75 |
-
kv_storage
|
76 |
vector_storage: str = field(default="NanoVectorDBStorage")
|
77 |
graph_storage: str = field(default="NetworkXStorage")
|
78 |
|
@@ -115,7 +112,7 @@ class LightRAG:
|
|
115 |
|
116 |
# storage
|
117 |
vector_db_storage_cls_kwargs: dict = field(default_factory=dict)
|
118 |
-
|
119 |
enable_llm_cache: bool = True
|
120 |
|
121 |
# extension
|
@@ -134,18 +131,25 @@ class LightRAG:
|
|
134 |
|
135 |
# @TODO: should move all storage setup here to leverage initial start params attached to self.
|
136 |
|
137 |
-
self.key_string_value_json_storage_cls: Type[BaseKVStorage] =
|
138 |
-
|
139 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
140 |
|
141 |
if not os.path.exists(self.working_dir):
|
142 |
logger.info(f"Creating working directory {self.working_dir}")
|
143 |
os.makedirs(self.working_dir)
|
144 |
|
145 |
-
|
146 |
self.llm_response_cache = (
|
147 |
self.key_string_value_json_storage_cls(
|
148 |
-
namespace="llm_response_cache",
|
|
|
|
|
149 |
)
|
150 |
if self.enable_llm_cache
|
151 |
else None
|
@@ -159,13 +163,19 @@ class LightRAG:
|
|
159 |
# add embedding func by walter
|
160 |
####
|
161 |
self.full_docs = self.key_string_value_json_storage_cls(
|
162 |
-
namespace="full_docs",
|
|
|
|
|
163 |
)
|
164 |
self.text_chunks = self.key_string_value_json_storage_cls(
|
165 |
-
namespace="text_chunks",
|
|
|
|
|
166 |
)
|
167 |
self.chunk_entity_relation_graph = self.graph_storage_cls(
|
168 |
-
namespace="chunk_entity_relation",
|
|
|
|
|
169 |
)
|
170 |
####
|
171 |
# add embedding func by walter over
|
@@ -200,13 +210,11 @@ class LightRAG:
|
|
200 |
def _get_storage_class(self) -> Type[BaseGraphStorage]:
|
201 |
return {
|
202 |
# kv storage
|
203 |
-
"JsonKVStorage":JsonKVStorage,
|
204 |
-
"OracleKVStorage":OracleKVStorage,
|
205 |
-
|
206 |
# vector storage
|
207 |
-
"NanoVectorDBStorage":NanoVectorDBStorage,
|
208 |
-
"OracleVectorDBStorage":OracleVectorDBStorage,
|
209 |
-
|
210 |
# graph storage
|
211 |
"NetworkXStorage": NetworkXStorage,
|
212 |
"Neo4JStorage": Neo4JStorage,
|
|
|
38 |
JsonKVStorage,
|
39 |
NanoVectorDBStorage,
|
40 |
NetworkXStorage,
|
41 |
+
)
|
42 |
|
43 |
from .kg.neo4j_impl import Neo4JStorage
|
44 |
|
45 |
+
from .kg.oracle_impl import OracleKVStorage, OracleGraphStorage, OracleVectorDBStorage
|
|
|
|
|
|
|
|
|
46 |
|
47 |
# future KG integrations
|
48 |
|
|
|
50 |
# GraphStorage as ArangoDBStorage
|
51 |
# )
|
52 |
|
53 |
+
|
54 |
def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
|
55 |
try:
|
56 |
return asyncio.get_event_loop()
|
|
|
69 |
default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
|
70 |
)
|
71 |
|
72 |
+
kv_storage: str = field(default="JsonKVStorage")
|
73 |
vector_storage: str = field(default="NanoVectorDBStorage")
|
74 |
graph_storage: str = field(default="NetworkXStorage")
|
75 |
|
|
|
112 |
|
113 |
# storage
|
114 |
vector_db_storage_cls_kwargs: dict = field(default_factory=dict)
|
115 |
+
|
116 |
enable_llm_cache: bool = True
|
117 |
|
118 |
# extension
|
|
|
131 |
|
132 |
# @TODO: should move all storage setup here to leverage initial start params attached to self.
|
133 |
|
134 |
+
self.key_string_value_json_storage_cls: Type[BaseKVStorage] = (
|
135 |
+
self._get_storage_class()[self.kv_storage]
|
136 |
+
)
|
137 |
+
self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class()[
|
138 |
+
self.vector_storage
|
139 |
+
]
|
140 |
+
self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class()[
|
141 |
+
self.graph_storage
|
142 |
+
]
|
143 |
|
144 |
if not os.path.exists(self.working_dir):
|
145 |
logger.info(f"Creating working directory {self.working_dir}")
|
146 |
os.makedirs(self.working_dir)
|
147 |
|
|
|
148 |
self.llm_response_cache = (
|
149 |
self.key_string_value_json_storage_cls(
|
150 |
+
namespace="llm_response_cache",
|
151 |
+
global_config=asdict(self),
|
152 |
+
embedding_func=None,
|
153 |
)
|
154 |
if self.enable_llm_cache
|
155 |
else None
|
|
|
163 |
# add embedding func by walter
|
164 |
####
|
165 |
self.full_docs = self.key_string_value_json_storage_cls(
|
166 |
+
namespace="full_docs",
|
167 |
+
global_config=asdict(self),
|
168 |
+
embedding_func=self.embedding_func,
|
169 |
)
|
170 |
self.text_chunks = self.key_string_value_json_storage_cls(
|
171 |
+
namespace="text_chunks",
|
172 |
+
global_config=asdict(self),
|
173 |
+
embedding_func=self.embedding_func,
|
174 |
)
|
175 |
self.chunk_entity_relation_graph = self.graph_storage_cls(
|
176 |
+
namespace="chunk_entity_relation",
|
177 |
+
global_config=asdict(self),
|
178 |
+
embedding_func=self.embedding_func,
|
179 |
)
|
180 |
####
|
181 |
# add embedding func by walter over
|
|
|
210 |
def _get_storage_class(self) -> Type[BaseGraphStorage]:
|
211 |
return {
|
212 |
# kv storage
|
213 |
+
"JsonKVStorage": JsonKVStorage,
|
214 |
+
"OracleKVStorage": OracleKVStorage,
|
|
|
215 |
# vector storage
|
216 |
+
"NanoVectorDBStorage": NanoVectorDBStorage,
|
217 |
+
"OracleVectorDBStorage": OracleVectorDBStorage,
|
|
|
218 |
# graph storage
|
219 |
"NetworkXStorage": NetworkXStorage,
|
220 |
"Neo4JStorage": Neo4JStorage,
|
lightrag/operate.py
CHANGED
@@ -16,7 +16,7 @@ from .utils import (
|
|
16 |
split_string_by_multi_markers,
|
17 |
truncate_list_by_token_size,
|
18 |
process_combine_contexts,
|
19 |
-
locate_json_string_body_from_string
|
20 |
)
|
21 |
from .base import (
|
22 |
BaseGraphStorage,
|
|
|
16 |
split_string_by_multi_markers,
|
17 |
truncate_list_by_token_size,
|
18 |
process_combine_contexts,
|
19 |
+
locate_json_string_body_from_string,
|
20 |
)
|
21 |
from .base import (
|
22 |
BaseGraphStorage,
|
requirements.txt
CHANGED
@@ -1,22 +1,22 @@
|
|
1 |
accelerate
|
|
|
2 |
aiohttp
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
pyvis
|
4 |
tenacity
|
5 |
-
xxhash
|
6 |
# lmdeploy[all]
|
7 |
|
8 |
# LLM packages
|
9 |
tiktoken
|
10 |
torch
|
11 |
transformers
|
12 |
-
|
13 |
-
ollama
|
14 |
-
openai
|
15 |
-
|
16 |
-
# database packages
|
17 |
-
graspologic
|
18 |
-
hnswlib
|
19 |
-
networkx
|
20 |
-
oracledb
|
21 |
-
nano-vectordb
|
22 |
-
neo4j
|
|
|
1 |
accelerate
|
2 |
+
aioboto3
|
3 |
aiohttp
|
4 |
+
|
5 |
+
# database packages
|
6 |
+
graspologic
|
7 |
+
hnswlib
|
8 |
+
nano-vectordb
|
9 |
+
neo4j
|
10 |
+
networkx
|
11 |
+
ollama
|
12 |
+
openai
|
13 |
+
oracledb
|
14 |
pyvis
|
15 |
tenacity
|
|
|
16 |
# lmdeploy[all]
|
17 |
|
18 |
# LLM packages
|
19 |
tiktoken
|
20 |
torch
|
21 |
transformers
|
22 |
+
xxhash
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|