Ken Wiltshire committed on
Commit
e7da6e0
·
1 Parent(s): 45c89eb

set kg by start param, defaults to networkx

Browse files
.gitignore CHANGED
@@ -7,5 +7,4 @@ lightrag-dev/
7
  dist/
8
  env/
9
  local_neo4jWorkDir/
10
- local_neo4jWorkDir.bak/
11
  neo4jWorkDir/
 
7
  dist/
8
  env/
9
  local_neo4jWorkDir/
 
10
  neo4jWorkDir/
lightrag/kg/__init__.py CHANGED
@@ -1,5 +1,5 @@
1
  print ("init package vars here. ......")
2
- from .neo4j import GraphStorage as Neo4JStorage
3
 
4
 
5
  # import sys
 
1
  print ("init package vars here. ......")
2
+ # from .neo4j import GraphStorage as Neo4JStorage
3
 
4
 
5
  # import sys
lightrag/kg/{neo4j.py → neo4j_impl.py} RENAMED
File without changes
lightrag/lightrag.py CHANGED
@@ -25,7 +25,7 @@ from .storage import (
25
  NetworkXStorage,
26
  )
27
 
28
- from .kg.neo4j import (
29
  GraphStorage as Neo4JStorage
30
  )
31
 
@@ -58,10 +58,14 @@ def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
58
 
59
  @dataclass
60
  class LightRAG:
 
61
  working_dir: str = field(
62
  default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
63
  )
64
 
 
 
 
65
  # text chunking
66
  chunk_token_size: int = 1200
67
  chunk_overlap_token_size: int = 100
@@ -99,20 +103,15 @@ class LightRAG:
99
  key_string_value_json_storage_cls: Type[BaseKVStorage] = JsonKVStorage
100
  vector_db_storage_cls: Type[BaseVectorStorage] = NanoVectorDBStorage
101
  vector_db_storage_cls_kwargs: dict = field(default_factory=dict)
102
-
103
- # module = importlib.import_module('kg.neo4j')
104
- # Neo4JStorage = getattr(module, 'GraphStorage')
105
- if True==True:
106
- print ("using KG")
107
- graph_storage_cls: Type[BaseGraphStorage] = Neo4JStorage
108
- else:
109
- graph_storage_cls: Type[BaseGraphStorage] = NetworkXStorage
110
  enable_llm_cache: bool = True
111
 
112
  # extension
113
  addon_params: dict = field(default_factory=dict)
114
  convert_response_to_json_func: callable = convert_response_to_json
115
 
 
 
 
116
  def __post_init__(self):
117
  log_file = os.path.join(self.working_dir, "lightrag.log")
118
  set_logger(log_file)
@@ -121,6 +120,10 @@ class LightRAG:
121
  _print_config = ",\n ".join([f"{k} = {v}" for k, v in asdict(self).items()])
122
  logger.debug(f"LightRAG init with param:\n {_print_config}\n")
123
 
 
 
 
 
124
  if not os.path.exists(self.working_dir):
125
  logger.info(f"Creating working directory {self.working_dir}")
126
  os.makedirs(self.working_dir)
@@ -169,6 +172,12 @@ class LightRAG:
169
  self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
170
  partial(self.llm_model_func, hashing_kv=self.llm_response_cache)
171
  )
 
 
 
 
 
 
172
 
173
  def insert(self, string_or_strings):
174
  loop = always_get_an_event_loop()
 
25
  NetworkXStorage,
26
  )
27
 
28
+ from .kg.neo4j_impl import (
29
  GraphStorage as Neo4JStorage
30
  )
31
 
 
58
 
59
  @dataclass
60
  class LightRAG:
61
+
62
  working_dir: str = field(
63
  default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
64
  )
65
 
66
+ kg: str = field(default="NetworkXStorage")
67
+
68
+
69
  # text chunking
70
  chunk_token_size: int = 1200
71
  chunk_overlap_token_size: int = 100
 
103
  key_string_value_json_storage_cls: Type[BaseKVStorage] = JsonKVStorage
104
  vector_db_storage_cls: Type[BaseVectorStorage] = NanoVectorDBStorage
105
  vector_db_storage_cls_kwargs: dict = field(default_factory=dict)
 
 
 
 
 
 
 
 
106
  enable_llm_cache: bool = True
107
 
108
  # extension
109
  addon_params: dict = field(default_factory=dict)
110
  convert_response_to_json_func: callable = convert_response_to_json
111
 
112
+ # def get_configured_KG(self):
113
+ # return self.kg
114
+
115
  def __post_init__(self):
116
  log_file = os.path.join(self.working_dir, "lightrag.log")
117
  set_logger(log_file)
 
120
  _print_config = ",\n ".join([f"{k} = {v}" for k, v in asdict(self).items()])
121
  logger.debug(f"LightRAG init with param:\n {_print_config}\n")
122
 
123
+ #should move all storage setup here to leverage initial start params attached to self.
124
+ print (f"self.kg set to: {self.kg}")
125
+ self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class()[self.kg]
126
+
127
  if not os.path.exists(self.working_dir):
128
  logger.info(f"Creating working directory {self.working_dir}")
129
  os.makedirs(self.working_dir)
 
172
  self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
173
  partial(self.llm_model_func, hashing_kv=self.llm_response_cache)
174
  )
175
+ def _get_storage_class(self) -> Type[BaseGraphStorage]:
176
+ return {
177
+ "Neo4JStorage": Neo4JStorage,
178
+ "NetworkXStorage": NetworkXStorage,
179
+ # "new_kg_here": KGClass
180
+ }
181
 
182
  def insert(self, string_or_strings):
183
  loop = always_get_an_event_loop()
lightrag/operate.py CHANGED
@@ -71,6 +71,7 @@ async def _handle_entity_relation_summary(
71
  use_prompt = prompt_template.format(**context_base)
72
  logger.debug(f"Trigger summary: {entity_or_relation_name}")
73
  summary = await use_llm_func(use_prompt, max_tokens=summary_max_tokens)
 
74
  return summary
75
 
76
 
@@ -78,6 +79,7 @@ async def _handle_single_entity_extraction(
78
  record_attributes: list[str],
79
  chunk_key: str,
80
  ):
 
81
  if len(record_attributes) < 4 or record_attributes[0] != '"entity"':
82
  return None
83
  # add this record as a node in the G
@@ -263,6 +265,7 @@ async def extract_entities(
263
 
264
  async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]):
265
  nonlocal already_processed, already_entities, already_relations
 
266
  chunk_key = chunk_key_dp[0]
267
  chunk_dp = chunk_key_dp[1]
268
  content = chunk_dp["content"]
@@ -432,6 +435,7 @@ async def local_query(
432
  text_chunks_db,
433
  query_param,
434
  )
 
435
  if query_param.only_need_context:
436
  return context
437
  if context is None:
@@ -440,6 +444,7 @@ async def local_query(
440
  sys_prompt = sys_prompt_temp.format(
441
  context_data=context, response_type=query_param.response_type
442
  )
 
443
  response = await use_model_func(
444
  query,
445
  system_prompt=sys_prompt,
@@ -465,14 +470,20 @@ async def _build_local_query_context(
465
  text_chunks_db: BaseKVStorage[TextChunkSchema],
466
  query_param: QueryParam,
467
  ):
 
 
468
  results = await entities_vdb.query(query, top_k=query_param.top_k)
 
 
469
  if not len(results):
470
  return None
 
471
  node_datas = await asyncio.gather(
472
  *[knowledge_graph_inst.get_node(r["entity_name"]) for r in results]
473
  )
474
  if not all([n is not None for n in node_datas]):
475
  logger.warning("Some nodes are missing, maybe the storage is damaged")
 
476
  node_degrees = await asyncio.gather(
477
  *[knowledge_graph_inst.node_degree(r["entity_name"]) for r in results]
478
  )
@@ -480,7 +491,7 @@ async def _build_local_query_context(
480
  {**n, "entity_name": k["entity_name"], "rank": d}
481
  for k, n, d in zip(results, node_datas, node_degrees)
482
  if n is not None
483
- ]
484
  use_text_units = await _find_most_related_text_unit_from_entities(
485
  node_datas, query_param, text_chunks_db, knowledge_graph_inst
486
  )
@@ -718,6 +729,7 @@ async def _build_global_query_context(
718
  text_chunks_db: BaseKVStorage[TextChunkSchema],
719
  query_param: QueryParam,
720
  ):
 
721
  results = await relationships_vdb.query(keywords, top_k=query_param.top_k)
722
 
723
  if not len(results):
@@ -883,12 +895,14 @@ async def hybrid_query(
883
  query_param: QueryParam,
884
  global_config: dict,
885
  ) -> str:
 
886
  low_level_context = None
887
  high_level_context = None
888
  use_model_func = global_config["llm_model_func"]
889
 
890
  kw_prompt_temp = PROMPTS["keywords_extraction"]
891
  kw_prompt = kw_prompt_temp.format(query=query)
 
892
 
893
  result = await use_model_func(kw_prompt)
894
  try:
@@ -897,6 +911,8 @@ async def hybrid_query(
897
  ll_keywords = keywords_data.get("low_level_keywords", [])
898
  hl_keywords = ", ".join(hl_keywords)
899
  ll_keywords = ", ".join(ll_keywords)
 
 
900
  except json.JSONDecodeError:
901
  try:
902
  result = (
@@ -926,6 +942,8 @@ async def hybrid_query(
926
  query_param,
927
  )
928
 
 
 
929
  if hl_keywords:
930
  high_level_context = await _build_global_query_context(
931
  hl_keywords,
@@ -935,6 +953,8 @@ async def hybrid_query(
935
  text_chunks_db,
936
  query_param,
937
  )
 
 
938
 
939
  context = combine_contexts(high_level_context, low_level_context)
940
 
@@ -951,6 +971,7 @@ async def hybrid_query(
951
  query,
952
  system_prompt=sys_prompt,
953
  )
 
954
  if len(response) > len(sys_prompt):
955
  response = (
956
  response.replace(sys_prompt, "")
@@ -1044,10 +1065,13 @@ async def naive_query(
1044
  ):
1045
  use_model_func = global_config["llm_model_func"]
1046
  results = await chunks_vdb.query(query, top_k=query_param.top_k)
 
1047
  if not len(results):
1048
  return PROMPTS["fail_response"]
1049
  chunks_ids = [r["id"] for r in results]
1050
  chunks = await text_chunks_db.get_by_ids(chunks_ids)
 
 
1051
 
1052
  maybe_trun_chunks = truncate_list_by_token_size(
1053
  chunks,
 
71
  use_prompt = prompt_template.format(**context_base)
72
  logger.debug(f"Trigger summary: {entity_or_relation_name}")
73
  summary = await use_llm_func(use_prompt, max_tokens=summary_max_tokens)
74
+ print ("Summarized: {context_base} for entity relationship {} ")
75
  return summary
76
 
77
 
 
79
  record_attributes: list[str],
80
  chunk_key: str,
81
  ):
82
+ print (f"_handle_single_entity_extraction {record_attributes} chunk_key {chunk_key}")
83
  if len(record_attributes) < 4 or record_attributes[0] != '"entity"':
84
  return None
85
  # add this record as a node in the G
 
265
 
266
  async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]):
267
  nonlocal already_processed, already_entities, already_relations
268
+ print (f"kw: processing a single chunk, {chunk_key_dp}")
269
  chunk_key = chunk_key_dp[0]
270
  chunk_dp = chunk_key_dp[1]
271
  content = chunk_dp["content"]
 
435
  text_chunks_db,
436
  query_param,
437
  )
438
+ print (f"got the following context {context} based on prompt keywords {keywords}")
439
  if query_param.only_need_context:
440
  return context
441
  if context is None:
 
444
  sys_prompt = sys_prompt_temp.format(
445
  context_data=context, response_type=query_param.response_type
446
  )
447
+ print (f"local query:{query} local sysprompt:{sys_prompt}")
448
  response = await use_model_func(
449
  query,
450
  system_prompt=sys_prompt,
 
470
  text_chunks_db: BaseKVStorage[TextChunkSchema],
471
  query_param: QueryParam,
472
  ):
473
+ print ("kw1: ENTITIES VDB QUERY**********************************")
474
+
475
  results = await entities_vdb.query(query, top_k=query_param.top_k)
476
+ print (f"kw2: ENTITIES VDB QUERY, RESULTS {results}**********************************")
477
+
478
  if not len(results):
479
  return None
480
+ print ("kw3: using entities to get_nodes returned in above vdb query. search results from embedding your query keywords")
481
  node_datas = await asyncio.gather(
482
  *[knowledge_graph_inst.get_node(r["entity_name"]) for r in results]
483
  )
484
  if not all([n is not None for n in node_datas]):
485
  logger.warning("Some nodes are missing, maybe the storage is damaged")
486
+ print ("kw4: getting node degrees next for the same entities/nodes")
487
  node_degrees = await asyncio.gather(
488
  *[knowledge_graph_inst.node_degree(r["entity_name"]) for r in results]
489
  )
 
491
  {**n, "entity_name": k["entity_name"], "rank": d}
492
  for k, n, d in zip(results, node_datas, node_degrees)
493
  if n is not None
494
+ ]#what is this text_chunks_db doing. dont remember it in airvx. check the diagram.
495
  use_text_units = await _find_most_related_text_unit_from_entities(
496
  node_datas, query_param, text_chunks_db, knowledge_graph_inst
497
  )
 
729
  text_chunks_db: BaseKVStorage[TextChunkSchema],
730
  query_param: QueryParam,
731
  ):
732
+ print ("RELATIONSHIPS VDB QUERY**********************************")
733
  results = await relationships_vdb.query(keywords, top_k=query_param.top_k)
734
 
735
  if not len(results):
 
895
  query_param: QueryParam,
896
  global_config: dict,
897
  ) -> str:
898
+ print ("HYBRID QUERY *********")
899
  low_level_context = None
900
  high_level_context = None
901
  use_model_func = global_config["llm_model_func"]
902
 
903
  kw_prompt_temp = PROMPTS["keywords_extraction"]
904
  kw_prompt = kw_prompt_temp.format(query=query)
905
+ print ( f"kw:kw_prompt: {kw_prompt}")
906
 
907
  result = await use_model_func(kw_prompt)
908
  try:
 
911
  ll_keywords = keywords_data.get("low_level_keywords", [])
912
  hl_keywords = ", ".join(hl_keywords)
913
  ll_keywords = ", ".join(ll_keywords)
914
+ print (f"High level key words: {hl_keywords}")
915
+ print (f"Low level key words: {ll_keywords}")
916
  except json.JSONDecodeError:
917
  try:
918
  result = (
 
942
  query_param,
943
  )
944
 
945
+ print (f"low_level_context: {low_level_context}")
946
+
947
  if hl_keywords:
948
  high_level_context = await _build_global_query_context(
949
  hl_keywords,
 
953
  text_chunks_db,
954
  query_param,
955
  )
956
+ print (f"high_level_context: {high_level_context}")
957
+
958
 
959
  context = combine_contexts(high_level_context, low_level_context)
960
 
 
971
  query,
972
  system_prompt=sys_prompt,
973
  )
974
+ print (f"kw: got system prompt: {sys_prompt}. got response for that prompt: {response}")
975
  if len(response) > len(sys_prompt):
976
  response = (
977
  response.replace(sys_prompt, "")
 
1065
  ):
1066
  use_model_func = global_config["llm_model_func"]
1067
  results = await chunks_vdb.query(query, top_k=query_param.top_k)
1068
+ print (f"raw chunks from chunks_vdb.query {results}")
1069
  if not len(results):
1070
  return PROMPTS["fail_response"]
1071
  chunks_ids = [r["id"] for r in results]
1072
  chunks = await text_chunks_db.get_by_ids(chunks_ids)
1073
+ print (f"raw chunks from text_chunks_db {chunks} retrieved by id using the above chunk ids from prev chunks_vdb ")
1074
+
1075
 
1076
  maybe_trun_chunks = truncate_list_by_token_size(
1077
  chunks,
lightrag/storage.py CHANGED
@@ -95,6 +95,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
95
  embeddings = np.concatenate(embeddings_list)
96
  for i, d in enumerate(list_data):
97
  d["__vector__"] = embeddings[i]
 
98
  results = self._client.upsert(datas=list_data)
99
  return results
100
 
@@ -109,6 +110,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
109
  results = [
110
  {**dp, "id": dp["__id__"], "distance": dp["__metrics__"]} for dp in results
111
  ]
 
112
  return results
113
 
114
  async def index_done_callback(self):
 
95
  embeddings = np.concatenate(embeddings_list)
96
  for i, d in enumerate(list_data):
97
  d["__vector__"] = embeddings[i]
98
+ print (f"Upserting to vector: {list_data}")
99
  results = self._client.upsert(datas=list_data)
100
  return results
101
 
 
110
  results = [
111
  {**dp, "id": dp["__id__"], "distance": dp["__metrics__"]} for dp in results
112
  ]
113
+ print (f"vector db results {results} for query {query}")
114
  return results
115
 
116
  async def index_done_callback(self):
testkg.py CHANGED
@@ -15,7 +15,8 @@ if not os.path.exists(WORKING_DIR):
15
 
16
  rag = LightRAG(
17
  working_dir=WORKING_DIR,
18
- llm_model_func=gpt_4o_mini_complete # Use gpt_4o_mini_complete LLM model
 
19
  # llm_model_func=gpt_4o_complete # Optionally, use a stronger model
20
  )
21
 
 
15
 
16
  rag = LightRAG(
17
  working_dir=WORKING_DIR,
18
+ llm_model_func=gpt_4o_mini_complete, # Use gpt_4o_mini_complete LLM model
19
+ kg="Neo4JStorage"
20
  # llm_model_func=gpt_4o_complete # Optionally, use a stronger model
21
  )
22