Ken Wiltshire committed · Commit b2bf632 · Parent(s): e7da6e0

securing for production with env vars for creds

Files changed:
- .gitignore +2 -1
- lightrag/kg/neo4j_impl.py +19 -32
- lightrag/lightrag.py +15 -3
- lightrag/operate.py +0 -19
- testkg.py +4 -3
.gitignore CHANGED
@@ -7,4 +7,5 @@ lightrag-dev/
 dist/
 env/
 local_neo4jWorkDir/
-neo4jWorkDir/
+neo4jWorkDir/
+ignore_this.txt
lightrag/kg/neo4j_impl.py CHANGED
@@ -5,7 +5,6 @@ from dataclasses import dataclass
 from typing import Any, Union, cast
 import numpy as np
 import inspect
-# import package.common.utils as utils
 from lightrag.utils import load_json, logger, write_json
 from ..base import (
     BaseGraphStorage
@@ -22,27 +21,6 @@ from tenacity import (



-
-
-# @TODO: catch and retry "ERROR:neo4j.io:Failed to write data to connection ResolvedIPv4Address"
-# during indexing.
-
-
-
-
-# Replace with your actual URI, username, and password
-#local
-URI = "neo4j://localhost:7687"
-USERNAME = "neo4j"
-PASSWORD = "password"
-
-#aura
-# URI = "neo4j+s://91fbae6c.databases.neo4j.io"
-# USERNAME = "neo4j"
-# PASSWORD = "KWKPXfXcClDbUlmDdGgIQhU5mL1N4E_2CJp2BDFbEbw"
-# Create a driver object
-
-
 @dataclass
 class GraphStorage(BaseGraphStorage):
     @staticmethod
@@ -51,6 +29,15 @@ class GraphStorage(BaseGraphStorage):

     def __post_init__(self):
         # self._graph = preloaded_graph or nx.Graph()
+        credetial_parts = ['URI', 'USERNAME','PASSWORD']
+        credentials_set = all(x in os.environ for x in credetial_parts )
+        if credentials_set:
+            URI = os.environ["URI"]
+            USERNAME = os.environ["USERNAME"]
+            PASSWORD = os.environ["PASSWORD"]
+        else:
+            raise Exception (f"One or more Neo4J Credentials, {credetial_parts}, not found in the environment")
+
         self._driver = GraphDatabase.driver(URI, auth=(USERNAME, PASSWORD))
         self._node_embed_algorithms = {
             "node2vec": self._node2vec_embed,
@@ -65,7 +52,7 @@ class GraphStorage(BaseGraphStorage):
             query = f"MATCH (n:`{label}`) RETURN count(n) > 0 AS node_exists"
             result = tx.run(query)
             single_result = result.single()
-            logger.info(
+            logger.debug(
                 f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["node_exists"]}'
             )

@@ -90,7 +77,7 @@ class GraphStorage(BaseGraphStorage):
             # if result.single() == None:
             #     print (f"this should not happen: ---- {label1}/{label2} {query}")

-            logger.info(
+            logger.debug(
                 f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["edgeExists"]}'
             )

@@ -111,7 +98,7 @@ class GraphStorage(BaseGraphStorage):
             result = session.run(query)
             for record in result:
                 result = record["n"]
-            logger.info(
+            logger.debug(
                 f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}'
             )
             return result
@@ -133,7 +120,7 @@ class GraphStorage(BaseGraphStorage):
             record = result.single()
             if record:
                 edge_count = record["totalEdgeCount"]
-                logger.info(
+                logger.debug(
                     f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{edge_count}'
                 )
                 return edge_count
@@ -154,7 +141,7 @@ class GraphStorage(BaseGraphStorage):
                 RETURN count(r) AS degree"""
             result = session.run(query)
             record = result.single()
-            logger.info(
+            logger.debug(
                 f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{record["degree"]}'
             )
             return record["degree"]
@@ -183,7 +170,7 @@ class GraphStorage(BaseGraphStorage):
             record = result.single()
             if record:
                 result = dict(record["edge_properties"])
-                logger.info(
+                logger.debug(
                     f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}'
                 )
                 return result
@@ -254,7 +241,7 @@ class GraphStorage(BaseGraphStorage):
         #     if source_label and target_label:
         #         connections.append((source_label, target_label))

-        #     logger.info(
+        #     logger.debug(
         #         f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{connections}'
         #     )
         # return connections
@@ -308,7 +295,7 @@ class GraphStorage(BaseGraphStorage):
             result = tx.run(query, properties=properties)
             record = result.single()
             if record:
-                logger.info(
+                logger.debug(
                     f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{dict(record["n"])}'
                 )
                 return dict(record["n"])
@@ -364,7 +351,7 @@ class GraphStorage(BaseGraphStorage):
         """

         result = tx.run(query, properties=edge_properties)
-        logger.info(
+        logger.debug(
             f'{inspect.currentframe().f_code.co_name}:query:{query}:edge_properties:{edge_properties}'
         )
         return result.single()
@@ -385,7 +372,7 @@ class GraphStorage(BaseGraphStorage):
         with self._driver.session() as session:
             #Define the Cypher query
             options = self.global_config["node2vec_params"]
-            logger.info(f"building embeddings with options {options}")
+            logger.debug(f"building embeddings with options {options}")
             query = f"""CALL gds.node2vec.write('91fbae6c', {
                 options
             })
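
With this change the connection settings come from the process environment instead of hard-coded constants. A minimal usage sketch (the variable names URI, USERNAME, and PASSWORD are the ones the diff reads; the values below are placeholders, not real credentials):

    import os

    # Must be set before the storage class is constructed, otherwise
    # __post_init__ raises the "not found in the environment" exception.
    os.environ["URI"] = "neo4j://localhost:7687"
    os.environ["USERNAME"] = "neo4j"
    os.environ["PASSWORD"] = "<your-neo4j-password>"

In practice these would be exported in the shell or injected by the deployment environment rather than set in code.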
lightrag/lightrag.py CHANGED
@@ -28,6 +28,13 @@ from .storage import (
 from .kg.neo4j_impl import (
     GraphStorage as Neo4JStorage
 )
+#future KG integrations
+
+# from .kg.ArangoDB_impl import (
+#     GraphStorage as ArangoDBStorage
+# )
+
+

 from .utils import (
     EmbeddingFunc,
@@ -64,7 +71,11 @@ class LightRAG:
     )

     kg: str = field(default="NetworkXStorage")
-
+
+    current_log_level = logger.level
+    log_level: str = field(default=current_log_level)
+
+

     # text chunking
     chunk_token_size: int = 1200
@@ -115,13 +126,14 @@ class LightRAG:
     def __post_init__(self):
         log_file = os.path.join(self.working_dir, "lightrag.log")
         set_logger(log_file)
+        logger.setLevel(self.log_level)
+
         logger.info(f"Logger initialized for working directory: {self.working_dir}")

         _print_config = ",\n  ".join([f"{k} = {v}" for k, v in asdict(self).items()])
         logger.debug(f"LightRAG init with param:\n  {_print_config}\n")

         #should move all storage setup here to leverage initial start params attached to self.
-        print (f"self.kg set to: {self.kg}")
         self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class()[self.kg]

         if not os.path.exists(self.working_dir):
@@ -176,7 +188,7 @@ class LightRAG:
         return {
             "Neo4JStorage": Neo4JStorage,
             "NetworkXStorage": NetworkXStorage,
-            # "
+            # "ArangoDBStorage": ArangoDBStorage
         }

     def insert(self, string_or_strings):
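
The new log_level field defaults to the logger's current level and is applied in __post_init__ via logger.setLevel, which for a standard logging.Logger accepts either a registered level name or a numeric constant. An illustrative sketch, assuming the package-level import used elsewhere in the repo and leaving the other constructor arguments at their defaults:

    import logging
    from lightrag import LightRAG

    rag = LightRAG(working_dir="./workdir", log_level="DEBUG")
    # or, equivalently, with a numeric level:
    rag = LightRAG(working_dir="./workdir", log_level=logging.DEBUG)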
lightrag/operate.py CHANGED
@@ -71,7 +71,6 @@ async def _handle_entity_relation_summary(
     use_prompt = prompt_template.format(**context_base)
     logger.debug(f"Trigger summary: {entity_or_relation_name}")
     summary = await use_llm_func(use_prompt, max_tokens=summary_max_tokens)
-    print ("Summarized: {context_base} for entity relationship {} ")
     return summary


@@ -79,7 +78,6 @@ async def _handle_single_entity_extraction(
     record_attributes: list[str],
     chunk_key: str,
 ):
-    print (f"_handle_single_entity_extraction {record_attributes} chunk_key {chunk_key}")
     if len(record_attributes) < 4 or record_attributes[0] != '"entity"':
         return None
     # add this record as a node in the G
@@ -265,7 +263,6 @@ async def extract_entities(

     async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]):
         nonlocal already_processed, already_entities, already_relations
-        print (f"kw: processing a single chunk, {chunk_key_dp}")
         chunk_key = chunk_key_dp[0]
         chunk_dp = chunk_key_dp[1]
         content = chunk_dp["content"]
@@ -435,7 +432,6 @@ async def local_query(
         text_chunks_db,
         query_param,
     )
-    print (f"got the following context {context} based on prompt keywords {keywords}")
     if query_param.only_need_context:
         return context
     if context is None:
@@ -444,7 +440,6 @@ async def local_query(
     sys_prompt = sys_prompt_temp.format(
         context_data=context, response_type=query_param.response_type
     )
-    print (f"local query:{query} local sysprompt:{sys_prompt}")
     response = await use_model_func(
         query,
         system_prompt=sys_prompt,
@@ -470,20 +465,16 @@ async def _build_local_query_context(
     text_chunks_db: BaseKVStorage[TextChunkSchema],
     query_param: QueryParam,
 ):
-    print ("kw1: ENTITIES VDB QUERY**********************************")

     results = await entities_vdb.query(query, top_k=query_param.top_k)
-    print (f"kw2: ENTITIES VDB QUERY, RESULTS {results}**********************************")

     if not len(results):
         return None
-    print ("kw3: using entities to get_nodes returned in above vdb query. search results from embedding your query keywords")
     node_datas = await asyncio.gather(
         *[knowledge_graph_inst.get_node(r["entity_name"]) for r in results]
     )
     if not all([n is not None for n in node_datas]):
         logger.warning("Some nodes are missing, maybe the storage is damaged")
-    print ("kw4: getting node degrees next for the same entities/nodes")
     node_degrees = await asyncio.gather(
         *[knowledge_graph_inst.node_degree(r["entity_name"]) for r in results]
     )
@@ -729,7 +720,6 @@ async def _build_global_query_context(
     text_chunks_db: BaseKVStorage[TextChunkSchema],
     query_param: QueryParam,
 ):
-    print ("RELATIONSHIPS VDB QUERY**********************************")
     results = await relationships_vdb.query(keywords, top_k=query_param.top_k)

     if not len(results):
@@ -895,14 +885,12 @@ async def hybrid_query(
     query_param: QueryParam,
     global_config: dict,
 ) -> str:
-    print ("HYBRID QUERY *********")
     low_level_context = None
     high_level_context = None
     use_model_func = global_config["llm_model_func"]

     kw_prompt_temp = PROMPTS["keywords_extraction"]
     kw_prompt = kw_prompt_temp.format(query=query)
-    print ( f"kw:kw_prompt: {kw_prompt}")

     result = await use_model_func(kw_prompt)
     try:
@@ -911,8 +899,6 @@ async def hybrid_query(
         ll_keywords = keywords_data.get("low_level_keywords", [])
         hl_keywords = ", ".join(hl_keywords)
         ll_keywords = ", ".join(ll_keywords)
-        print (f"High level key words: {hl_keywords}")
-        print (f"Low level key words: {ll_keywords}")
     except json.JSONDecodeError:
         try:
             result = (
@@ -942,7 +928,6 @@ async def hybrid_query(
             query_param,
         )

-        print (f"low_level_context: {low_level_context}")

     if hl_keywords:
         high_level_context = await _build_global_query_context(
@@ -953,7 +938,6 @@ async def hybrid_query(
             text_chunks_db,
             query_param,
         )
-        print (f"high_level_context: {high_level_context}")


     context = combine_contexts(high_level_context, low_level_context)
@@ -971,7 +955,6 @@ async def hybrid_query(
         query,
         system_prompt=sys_prompt,
     )
-    print (f"kw: got system prompt: {sys_prompt}. got response for that prompt: {response}")
     if len(response) > len(sys_prompt):
         response = (
             response.replace(sys_prompt, "")
@@ -1065,12 +1048,10 @@ async def naive_query(
 ):
     use_model_func = global_config["llm_model_func"]
     results = await chunks_vdb.query(query, top_k=query_param.top_k)
-    print (f"raw chunks from chunks_vdb.query {results}")
     if not len(results):
         return PROMPTS["fail_response"]
     chunks_ids = [r["id"] for r in results]
     chunks = await text_chunks_db.get_by_ids(chunks_ids)
-    print (f"raw chunks from text_chunks_db {chunks} retrieved by id using the above chunk ids from prev chunks_vdb ")


     maybe_trun_chunks = truncate_list_by_token_size(
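
These deletions strip ad-hoc print() tracing from the extraction and query paths; the module's remaining diagnostics go through the shared logger, whose verbosity is now controlled by the log_level parameter added above. A sketch of the equivalent logger-based pattern (illustrative, not part of this commit; the two variables are stand-ins):

    from lightrag.utils import logger

    context, keywords = "<retrieved context>", "<prompt keywords>"  # stand-ins
    # Respects LightRAG(log_level=...) instead of always writing to stdout.
    logger.debug(f"got context {context} based on prompt keywords {keywords}")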
testkg.py CHANGED
@@ -16,12 +16,13 @@ if not os.path.exists(WORKING_DIR):
 rag = LightRAG(
     working_dir=WORKING_DIR,
     llm_model_func=gpt_4o_mini_complete,  # Use gpt_4o_mini_complete LLM model
-    kg="Neo4JStorage"
+    kg="Neo4JStorage",
+    log_level="INFO"
     # llm_model_func=gpt_4o_complete  # Optionally, use a stronger model
 )

-with open("./book.txt") as f:
-    rag.insert(f.read())
+# with open("./book.txt") as f:
+#     rag.insert(f.read())

 # Perform naive search
 print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))
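
With the insert step commented out, the script only queries the graph built on a previous run, and it presumes the three credential variables from neo4j_impl.py are already in the environment. The operate.py diff above shows the retrieval paths a query can take (naive_query, local_query, hybrid_query), selected via QueryParam. A sketch continuing from the script, trying the other modes (the mode strings are inferred from those function names and should be treated as an assumption):

    for mode in ["naive", "local", "hybrid"]:
        print(rag.query("What are the top themes in this story?",
                        param=QueryParam(mode=mode)))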