Ken Wiltshire
committed on
Commit
·
e7da6e0
1
Parent(s):
45c89eb
Set the knowledge-graph storage backend via the `kg` start parameter; defaults to NetworkX
Browse files- .gitignore +0 -1
- lightrag/kg/__init__.py +1 -1
- lightrag/kg/{neo4j.py → neo4j_impl.py} +0 -0
- lightrag/lightrag.py +18 -9
- lightrag/operate.py +25 -1
- lightrag/storage.py +2 -0
- testkg.py +2 -1
.gitignore
CHANGED
@@ -7,5 +7,4 @@ lightrag-dev/
|
|
7 |
dist/
|
8 |
env/
|
9 |
local_neo4jWorkDir/
|
10 |
-
local_neo4jWorkDir.bak/
|
11 |
neo4jWorkDir/
|
|
|
7 |
dist/
|
8 |
env/
|
9 |
local_neo4jWorkDir/
|
|
|
10 |
neo4jWorkDir/
|
lightrag/kg/__init__.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
print ("init package vars here. ......")
|
2 |
-
from .neo4j import GraphStorage as Neo4JStorage
|
3 |
|
4 |
|
5 |
# import sys
|
|
|
1 |
print ("init package vars here. ......")
|
2 |
+
# from .neo4j import GraphStorage as Neo4JStorage
|
3 |
|
4 |
|
5 |
# import sys
|
lightrag/kg/{neo4j.py → neo4j_impl.py}
RENAMED
File without changes
|
lightrag/lightrag.py
CHANGED
@@ -25,7 +25,7 @@ from .storage import (
|
|
25 |
NetworkXStorage,
|
26 |
)
|
27 |
|
28 |
-
from .kg.
|
29 |
GraphStorage as Neo4JStorage
|
30 |
)
|
31 |
|
@@ -58,10 +58,14 @@ def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
|
|
58 |
|
59 |
@dataclass
|
60 |
class LightRAG:
|
|
|
61 |
working_dir: str = field(
|
62 |
default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
|
63 |
)
|
64 |
|
|
|
|
|
|
|
65 |
# text chunking
|
66 |
chunk_token_size: int = 1200
|
67 |
chunk_overlap_token_size: int = 100
|
@@ -99,20 +103,15 @@ class LightRAG:
|
|
99 |
key_string_value_json_storage_cls: Type[BaseKVStorage] = JsonKVStorage
|
100 |
vector_db_storage_cls: Type[BaseVectorStorage] = NanoVectorDBStorage
|
101 |
vector_db_storage_cls_kwargs: dict = field(default_factory=dict)
|
102 |
-
|
103 |
-
# module = importlib.import_module('kg.neo4j')
|
104 |
-
# Neo4JStorage = getattr(module, 'GraphStorage')
|
105 |
-
if True==True:
|
106 |
-
print ("using KG")
|
107 |
-
graph_storage_cls: Type[BaseGraphStorage] = Neo4JStorage
|
108 |
-
else:
|
109 |
-
graph_storage_cls: Type[BaseGraphStorage] = NetworkXStorage
|
110 |
enable_llm_cache: bool = True
|
111 |
|
112 |
# extension
|
113 |
addon_params: dict = field(default_factory=dict)
|
114 |
convert_response_to_json_func: callable = convert_response_to_json
|
115 |
|
|
|
|
|
|
|
116 |
def __post_init__(self):
|
117 |
log_file = os.path.join(self.working_dir, "lightrag.log")
|
118 |
set_logger(log_file)
|
@@ -121,6 +120,10 @@ class LightRAG:
|
|
121 |
_print_config = ",\n ".join([f"{k} = {v}" for k, v in asdict(self).items()])
|
122 |
logger.debug(f"LightRAG init with param:\n {_print_config}\n")
|
123 |
|
|
|
|
|
|
|
|
|
124 |
if not os.path.exists(self.working_dir):
|
125 |
logger.info(f"Creating working directory {self.working_dir}")
|
126 |
os.makedirs(self.working_dir)
|
@@ -169,6 +172,12 @@ class LightRAG:
|
|
169 |
self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
|
170 |
partial(self.llm_model_func, hashing_kv=self.llm_response_cache)
|
171 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
172 |
|
173 |
def insert(self, string_or_strings):
|
174 |
loop = always_get_an_event_loop()
|
|
|
25 |
NetworkXStorage,
|
26 |
)
|
27 |
|
28 |
+
from .kg.neo4j_impl import (
|
29 |
GraphStorage as Neo4JStorage
|
30 |
)
|
31 |
|
|
|
58 |
|
59 |
@dataclass
|
60 |
class LightRAG:
|
61 |
+
|
62 |
working_dir: str = field(
|
63 |
default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
|
64 |
)
|
65 |
|
66 |
+
kg: str = field(default="NetworkXStorage")
|
67 |
+
|
68 |
+
|
69 |
# text chunking
|
70 |
chunk_token_size: int = 1200
|
71 |
chunk_overlap_token_size: int = 100
|
|
|
103 |
key_string_value_json_storage_cls: Type[BaseKVStorage] = JsonKVStorage
|
104 |
vector_db_storage_cls: Type[BaseVectorStorage] = NanoVectorDBStorage
|
105 |
vector_db_storage_cls_kwargs: dict = field(default_factory=dict)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
106 |
enable_llm_cache: bool = True
|
107 |
|
108 |
# extension
|
109 |
addon_params: dict = field(default_factory=dict)
|
110 |
convert_response_to_json_func: callable = convert_response_to_json
|
111 |
|
112 |
+
# def get_configured_KG(self):
|
113 |
+
# return self.kg
|
114 |
+
|
115 |
def __post_init__(self):
|
116 |
log_file = os.path.join(self.working_dir, "lightrag.log")
|
117 |
set_logger(log_file)
|
|
|
120 |
_print_config = ",\n ".join([f"{k} = {v}" for k, v in asdict(self).items()])
|
121 |
logger.debug(f"LightRAG init with param:\n {_print_config}\n")
|
122 |
|
123 |
+
#should move all storage setup here to leverage initial start params attached to self.
|
124 |
+
print (f"self.kg set to: {self.kg}")
|
125 |
+
self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class()[self.kg]
|
126 |
+
|
127 |
if not os.path.exists(self.working_dir):
|
128 |
logger.info(f"Creating working directory {self.working_dir}")
|
129 |
os.makedirs(self.working_dir)
|
|
|
172 |
self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
|
173 |
partial(self.llm_model_func, hashing_kv=self.llm_response_cache)
|
174 |
)
|
175 |
+
def _get_storage_class(self) -> Type[BaseGraphStorage]:
|
176 |
+
return {
|
177 |
+
"Neo4JStorage": Neo4JStorage,
|
178 |
+
"NetworkXStorage": NetworkXStorage,
|
179 |
+
# "new_kg_here": KGClass
|
180 |
+
}
|
181 |
|
182 |
def insert(self, string_or_strings):
|
183 |
loop = always_get_an_event_loop()
|
lightrag/operate.py
CHANGED
@@ -71,6 +71,7 @@ async def _handle_entity_relation_summary(
|
|
71 |
use_prompt = prompt_template.format(**context_base)
|
72 |
logger.debug(f"Trigger summary: {entity_or_relation_name}")
|
73 |
summary = await use_llm_func(use_prompt, max_tokens=summary_max_tokens)
|
|
|
74 |
return summary
|
75 |
|
76 |
|
@@ -78,6 +79,7 @@ async def _handle_single_entity_extraction(
|
|
78 |
record_attributes: list[str],
|
79 |
chunk_key: str,
|
80 |
):
|
|
|
81 |
if len(record_attributes) < 4 or record_attributes[0] != '"entity"':
|
82 |
return None
|
83 |
# add this record as a node in the G
|
@@ -263,6 +265,7 @@ async def extract_entities(
|
|
263 |
|
264 |
async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]):
|
265 |
nonlocal already_processed, already_entities, already_relations
|
|
|
266 |
chunk_key = chunk_key_dp[0]
|
267 |
chunk_dp = chunk_key_dp[1]
|
268 |
content = chunk_dp["content"]
|
@@ -432,6 +435,7 @@ async def local_query(
|
|
432 |
text_chunks_db,
|
433 |
query_param,
|
434 |
)
|
|
|
435 |
if query_param.only_need_context:
|
436 |
return context
|
437 |
if context is None:
|
@@ -440,6 +444,7 @@ async def local_query(
|
|
440 |
sys_prompt = sys_prompt_temp.format(
|
441 |
context_data=context, response_type=query_param.response_type
|
442 |
)
|
|
|
443 |
response = await use_model_func(
|
444 |
query,
|
445 |
system_prompt=sys_prompt,
|
@@ -465,14 +470,20 @@ async def _build_local_query_context(
|
|
465 |
text_chunks_db: BaseKVStorage[TextChunkSchema],
|
466 |
query_param: QueryParam,
|
467 |
):
|
|
|
|
|
468 |
results = await entities_vdb.query(query, top_k=query_param.top_k)
|
|
|
|
|
469 |
if not len(results):
|
470 |
return None
|
|
|
471 |
node_datas = await asyncio.gather(
|
472 |
*[knowledge_graph_inst.get_node(r["entity_name"]) for r in results]
|
473 |
)
|
474 |
if not all([n is not None for n in node_datas]):
|
475 |
logger.warning("Some nodes are missing, maybe the storage is damaged")
|
|
|
476 |
node_degrees = await asyncio.gather(
|
477 |
*[knowledge_graph_inst.node_degree(r["entity_name"]) for r in results]
|
478 |
)
|
@@ -480,7 +491,7 @@ async def _build_local_query_context(
|
|
480 |
{**n, "entity_name": k["entity_name"], "rank": d}
|
481 |
for k, n, d in zip(results, node_datas, node_degrees)
|
482 |
if n is not None
|
483 |
-
]
|
484 |
use_text_units = await _find_most_related_text_unit_from_entities(
|
485 |
node_datas, query_param, text_chunks_db, knowledge_graph_inst
|
486 |
)
|
@@ -718,6 +729,7 @@ async def _build_global_query_context(
|
|
718 |
text_chunks_db: BaseKVStorage[TextChunkSchema],
|
719 |
query_param: QueryParam,
|
720 |
):
|
|
|
721 |
results = await relationships_vdb.query(keywords, top_k=query_param.top_k)
|
722 |
|
723 |
if not len(results):
|
@@ -883,12 +895,14 @@ async def hybrid_query(
|
|
883 |
query_param: QueryParam,
|
884 |
global_config: dict,
|
885 |
) -> str:
|
|
|
886 |
low_level_context = None
|
887 |
high_level_context = None
|
888 |
use_model_func = global_config["llm_model_func"]
|
889 |
|
890 |
kw_prompt_temp = PROMPTS["keywords_extraction"]
|
891 |
kw_prompt = kw_prompt_temp.format(query=query)
|
|
|
892 |
|
893 |
result = await use_model_func(kw_prompt)
|
894 |
try:
|
@@ -897,6 +911,8 @@ async def hybrid_query(
|
|
897 |
ll_keywords = keywords_data.get("low_level_keywords", [])
|
898 |
hl_keywords = ", ".join(hl_keywords)
|
899 |
ll_keywords = ", ".join(ll_keywords)
|
|
|
|
|
900 |
except json.JSONDecodeError:
|
901 |
try:
|
902 |
result = (
|
@@ -926,6 +942,8 @@ async def hybrid_query(
|
|
926 |
query_param,
|
927 |
)
|
928 |
|
|
|
|
|
929 |
if hl_keywords:
|
930 |
high_level_context = await _build_global_query_context(
|
931 |
hl_keywords,
|
@@ -935,6 +953,8 @@ async def hybrid_query(
|
|
935 |
text_chunks_db,
|
936 |
query_param,
|
937 |
)
|
|
|
|
|
938 |
|
939 |
context = combine_contexts(high_level_context, low_level_context)
|
940 |
|
@@ -951,6 +971,7 @@ async def hybrid_query(
|
|
951 |
query,
|
952 |
system_prompt=sys_prompt,
|
953 |
)
|
|
|
954 |
if len(response) > len(sys_prompt):
|
955 |
response = (
|
956 |
response.replace(sys_prompt, "")
|
@@ -1044,10 +1065,13 @@ async def naive_query(
|
|
1044 |
):
|
1045 |
use_model_func = global_config["llm_model_func"]
|
1046 |
results = await chunks_vdb.query(query, top_k=query_param.top_k)
|
|
|
1047 |
if not len(results):
|
1048 |
return PROMPTS["fail_response"]
|
1049 |
chunks_ids = [r["id"] for r in results]
|
1050 |
chunks = await text_chunks_db.get_by_ids(chunks_ids)
|
|
|
|
|
1051 |
|
1052 |
maybe_trun_chunks = truncate_list_by_token_size(
|
1053 |
chunks,
|
|
|
71 |
use_prompt = prompt_template.format(**context_base)
|
72 |
logger.debug(f"Trigger summary: {entity_or_relation_name}")
|
73 |
summary = await use_llm_func(use_prompt, max_tokens=summary_max_tokens)
|
74 |
+
print ("Summarized: {context_base} for entity relationship {} ")
|
75 |
return summary
|
76 |
|
77 |
|
|
|
79 |
record_attributes: list[str],
|
80 |
chunk_key: str,
|
81 |
):
|
82 |
+
print (f"_handle_single_entity_extraction {record_attributes} chunk_key {chunk_key}")
|
83 |
if len(record_attributes) < 4 or record_attributes[0] != '"entity"':
|
84 |
return None
|
85 |
# add this record as a node in the G
|
|
|
265 |
|
266 |
async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]):
|
267 |
nonlocal already_processed, already_entities, already_relations
|
268 |
+
print (f"kw: processing a single chunk, {chunk_key_dp}")
|
269 |
chunk_key = chunk_key_dp[0]
|
270 |
chunk_dp = chunk_key_dp[1]
|
271 |
content = chunk_dp["content"]
|
|
|
435 |
text_chunks_db,
|
436 |
query_param,
|
437 |
)
|
438 |
+
print (f"got the following context {context} based on prompt keywords {keywords}")
|
439 |
if query_param.only_need_context:
|
440 |
return context
|
441 |
if context is None:
|
|
|
444 |
sys_prompt = sys_prompt_temp.format(
|
445 |
context_data=context, response_type=query_param.response_type
|
446 |
)
|
447 |
+
print (f"local query:{query} local sysprompt:{sys_prompt}")
|
448 |
response = await use_model_func(
|
449 |
query,
|
450 |
system_prompt=sys_prompt,
|
|
|
470 |
text_chunks_db: BaseKVStorage[TextChunkSchema],
|
471 |
query_param: QueryParam,
|
472 |
):
|
473 |
+
print ("kw1: ENTITIES VDB QUERY**********************************")
|
474 |
+
|
475 |
results = await entities_vdb.query(query, top_k=query_param.top_k)
|
476 |
+
print (f"kw2: ENTITIES VDB QUERY, RESULTS {results}**********************************")
|
477 |
+
|
478 |
if not len(results):
|
479 |
return None
|
480 |
+
print ("kw3: using entities to get_nodes returned in above vdb query. search results from embedding your query keywords")
|
481 |
node_datas = await asyncio.gather(
|
482 |
*[knowledge_graph_inst.get_node(r["entity_name"]) for r in results]
|
483 |
)
|
484 |
if not all([n is not None for n in node_datas]):
|
485 |
logger.warning("Some nodes are missing, maybe the storage is damaged")
|
486 |
+
print ("kw4: getting node degrees next for the same entities/nodes")
|
487 |
node_degrees = await asyncio.gather(
|
488 |
*[knowledge_graph_inst.node_degree(r["entity_name"]) for r in results]
|
489 |
)
|
|
|
491 |
{**n, "entity_name": k["entity_name"], "rank": d}
|
492 |
for k, n, d in zip(results, node_datas, node_degrees)
|
493 |
if n is not None
|
494 |
+
]#what is this text_chunks_db doing. dont remember it in airvx. check the diagram.
|
495 |
use_text_units = await _find_most_related_text_unit_from_entities(
|
496 |
node_datas, query_param, text_chunks_db, knowledge_graph_inst
|
497 |
)
|
|
|
729 |
text_chunks_db: BaseKVStorage[TextChunkSchema],
|
730 |
query_param: QueryParam,
|
731 |
):
|
732 |
+
print ("RELATIONSHIPS VDB QUERY**********************************")
|
733 |
results = await relationships_vdb.query(keywords, top_k=query_param.top_k)
|
734 |
|
735 |
if not len(results):
|
|
|
895 |
query_param: QueryParam,
|
896 |
global_config: dict,
|
897 |
) -> str:
|
898 |
+
print ("HYBRID QUERY *********")
|
899 |
low_level_context = None
|
900 |
high_level_context = None
|
901 |
use_model_func = global_config["llm_model_func"]
|
902 |
|
903 |
kw_prompt_temp = PROMPTS["keywords_extraction"]
|
904 |
kw_prompt = kw_prompt_temp.format(query=query)
|
905 |
+
print ( f"kw:kw_prompt: {kw_prompt}")
|
906 |
|
907 |
result = await use_model_func(kw_prompt)
|
908 |
try:
|
|
|
911 |
ll_keywords = keywords_data.get("low_level_keywords", [])
|
912 |
hl_keywords = ", ".join(hl_keywords)
|
913 |
ll_keywords = ", ".join(ll_keywords)
|
914 |
+
print (f"High level key words: {hl_keywords}")
|
915 |
+
print (f"Low level key words: {ll_keywords}")
|
916 |
except json.JSONDecodeError:
|
917 |
try:
|
918 |
result = (
|
|
|
942 |
query_param,
|
943 |
)
|
944 |
|
945 |
+
print (f"low_level_context: {low_level_context}")
|
946 |
+
|
947 |
if hl_keywords:
|
948 |
high_level_context = await _build_global_query_context(
|
949 |
hl_keywords,
|
|
|
953 |
text_chunks_db,
|
954 |
query_param,
|
955 |
)
|
956 |
+
print (f"high_level_context: {high_level_context}")
|
957 |
+
|
958 |
|
959 |
context = combine_contexts(high_level_context, low_level_context)
|
960 |
|
|
|
971 |
query,
|
972 |
system_prompt=sys_prompt,
|
973 |
)
|
974 |
+
print (f"kw: got system prompt: {sys_prompt}. got response for that prompt: {response}")
|
975 |
if len(response) > len(sys_prompt):
|
976 |
response = (
|
977 |
response.replace(sys_prompt, "")
|
|
|
1065 |
):
|
1066 |
use_model_func = global_config["llm_model_func"]
|
1067 |
results = await chunks_vdb.query(query, top_k=query_param.top_k)
|
1068 |
+
print (f"raw chunks from chunks_vdb.query {results}")
|
1069 |
if not len(results):
|
1070 |
return PROMPTS["fail_response"]
|
1071 |
chunks_ids = [r["id"] for r in results]
|
1072 |
chunks = await text_chunks_db.get_by_ids(chunks_ids)
|
1073 |
+
print (f"raw chunks from text_chunks_db {chunks} retrieved by id using the above chunk ids from prev chunks_vdb ")
|
1074 |
+
|
1075 |
|
1076 |
maybe_trun_chunks = truncate_list_by_token_size(
|
1077 |
chunks,
|
lightrag/storage.py
CHANGED
@@ -95,6 +95,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
|
95 |
embeddings = np.concatenate(embeddings_list)
|
96 |
for i, d in enumerate(list_data):
|
97 |
d["__vector__"] = embeddings[i]
|
|
|
98 |
results = self._client.upsert(datas=list_data)
|
99 |
return results
|
100 |
|
@@ -109,6 +110,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
|
109 |
results = [
|
110 |
{**dp, "id": dp["__id__"], "distance": dp["__metrics__"]} for dp in results
|
111 |
]
|
|
|
112 |
return results
|
113 |
|
114 |
async def index_done_callback(self):
|
|
|
95 |
embeddings = np.concatenate(embeddings_list)
|
96 |
for i, d in enumerate(list_data):
|
97 |
d["__vector__"] = embeddings[i]
|
98 |
+
print (f"Upserting to vector: {list_data}")
|
99 |
results = self._client.upsert(datas=list_data)
|
100 |
return results
|
101 |
|
|
|
110 |
results = [
|
111 |
{**dp, "id": dp["__id__"], "distance": dp["__metrics__"]} for dp in results
|
112 |
]
|
113 |
+
print (f"vector db results {results} for query {query}")
|
114 |
return results
|
115 |
|
116 |
async def index_done_callback(self):
|
testkg.py
CHANGED
@@ -15,7 +15,8 @@ if not os.path.exists(WORKING_DIR):
|
|
15 |
|
16 |
rag = LightRAG(
|
17 |
working_dir=WORKING_DIR,
|
18 |
-
llm_model_func=gpt_4o_mini_complete # Use gpt_4o_mini_complete LLM model
|
|
|
19 |
# llm_model_func=gpt_4o_complete # Optionally, use a stronger model
|
20 |
)
|
21 |
|
|
|
15 |
|
16 |
rag = LightRAG(
|
17 |
working_dir=WORKING_DIR,
|
18 |
+
llm_model_func=gpt_4o_mini_complete, # Use gpt_4o_mini_complete LLM model
|
19 |
+
kg="Neo4JStorage"
|
20 |
# llm_model_func=gpt_4o_complete # Optionally, use a stronger model
|
21 |
)
|
22 |
|