Ken Wiltshire committed
Commit b2bf632 · Parent(s): e7da6e0

securing for production with env vars for creds

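This change moves the Neo4j connection details out of the source tree and into the environment. For reference, a minimal sketch of supplying them before the storage class is constructed; the variable names URI, USERNAME, and PASSWORD are the ones the new code reads, and the values below are placeholders:

```python
import os

# Placeholder credentials for a local Neo4j instance; replace with your own.
os.environ["URI"] = "neo4j://localhost:7687"
os.environ["USERNAME"] = "neo4j"
os.environ["PASSWORD"] = "your-password"
```

In a real deployment these would normally come from the shell or a secrets manager rather than being set in code.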
.gitignore CHANGED
@@ -7,4 +7,5 @@ lightrag-dev/
 dist/
 env/
 local_neo4jWorkDir/
-neo4jWorkDir/
+neo4jWorkDir/
+ignore_this.txt

lightrag/kg/neo4j_impl.py CHANGED
@@ -5,7 +5,6 @@ from dataclasses import dataclass
 from typing import Any, Union, cast
 import numpy as np
 import inspect
-# import package.common.utils as utils
 from lightrag.utils import load_json, logger, write_json
 from ..base import (
     BaseGraphStorage
@@ -22,27 +21,6 @@ from tenacity import (
 
 
 
-
-
-# @TODO: catch and retry "ERROR:neo4j.io:Failed to write data to connection ResolvedIPv4Address"
-# during indexing.
-
-
-
-
-# Replace with your actual URI, username, and password
-#local
-URI = "neo4j://localhost:7687"
-USERNAME = "neo4j"
-PASSWORD = "password"
-
-#aura
-# URI = "neo4j+s://91fbae6c.databases.neo4j.io"
-# USERNAME = "neo4j"
-# PASSWORD = "KWKPXfXcClDbUlmDdGgIQhU5mL1N4E_2CJp2BDFbEbw"
-# Create a driver object
-
-
 @dataclass
 class GraphStorage(BaseGraphStorage):
     @staticmethod
@@ -51,6 +29,15 @@ class GraphStorage(BaseGraphStorage):
 
     def __post_init__(self):
         # self._graph = preloaded_graph or nx.Graph()
+        credential_parts = ["URI", "USERNAME", "PASSWORD"]
+        credentials_set = all(x in os.environ for x in credential_parts)
+        if credentials_set:
+            URI = os.environ["URI"]
+            USERNAME = os.environ["USERNAME"]
+            PASSWORD = os.environ["PASSWORD"]
+        else:
+            raise Exception(f"One or more Neo4j credentials, {credential_parts}, not found in the environment")
+
         self._driver = GraphDatabase.driver(URI, auth=(USERNAME, PASSWORD))
         self._node_embed_algorithms = {
             "node2vec": self._node2vec_embed,
@@ -65,7 +52,7 @@ class GraphStorage(BaseGraphStorage):
             query = f"MATCH (n:`{label}`) RETURN count(n) > 0 AS node_exists"
             result = tx.run(query)
             single_result = result.single()
-            logger.info(
+            logger.debug(
                 f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["node_exists"]}'
             )
 
@@ -90,7 +77,7 @@ class GraphStorage(BaseGraphStorage):
             # if result.single() == None:
             #     print (f"this should not happen: ---- {label1}/{label2} {query}")
 
-            logger.info(
+            logger.debug(
                 f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["edgeExists"]}'
             )
 
@@ -111,7 +98,7 @@ class GraphStorage(BaseGraphStorage):
             result = session.run(query)
             for record in result:
                 result = record["n"]
-                logger.info(
+                logger.debug(
                     f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}'
                 )
             return result
@@ -133,7 +120,7 @@ class GraphStorage(BaseGraphStorage):
             record = result.single()
             if record:
                 edge_count = record["totalEdgeCount"]
-                logger.info(
+                logger.debug(
                     f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{edge_count}'
                 )
                 return edge_count
@@ -154,7 +141,7 @@ class GraphStorage(BaseGraphStorage):
             RETURN count(r) AS degree"""
             result = session.run(query)
             record = result.single()
-            logger.info(
+            logger.debug(
                 f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{record["degree"]}'
             )
             return record["degree"]
@@ -183,7 +170,7 @@ class GraphStorage(BaseGraphStorage):
             record = result.single()
             if record:
                 result = dict(record["edge_properties"])
-                logger.info(
+                logger.debug(
                     f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}'
                 )
                 return result
@@ -254,7 +241,7 @@ class GraphStorage(BaseGraphStorage):
         #     if source_label and target_label:
        #         connections.append((source_label, target_label))
 
-        #     logger.info(
+        #     logger.debug(
         #         f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{connections}'
         #     )
         # return connections
@@ -308,7 +295,7 @@ class GraphStorage(BaseGraphStorage):
             result = tx.run(query, properties=properties)
             record = result.single()
             if record:
-                logger.info(
+                logger.debug(
                     f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{dict(record["n"])}'
                 )
                 return dict(record["n"])
@@ -364,7 +351,7 @@ class GraphStorage(BaseGraphStorage):
             """
 
             result = tx.run(query, properties=edge_properties)
-            logger.info(
+            logger.debug(
                 f'{inspect.currentframe().f_code.co_name}:query:{query}:edge_properties:{edge_properties}'
             )
             return result.single()
@@ -385,7 +372,7 @@ class GraphStorage(BaseGraphStorage):
         with self._driver.session() as session:
             #Define the Cypher query
             options = self.global_config["node2vec_params"]
-            logger.info(f"building embeddings with options {options}")
+            logger.debug(f"building embeddings with options {options}")
             query = f"""CALL gds.node2vec.write('91fbae6c', {
                 options
             })
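
A standalone sketch of the fail-fast credential handling that the new `__post_init__` performs; the helper name `driver_from_env` is hypothetical, and `verify_connectivity()` is an optional extra available in recent `neo4j` drivers, not something this commit calls. Note the sketch imports `os` explicitly, which the patched module must also do:

```python
import os

from neo4j import GraphDatabase

REQUIRED = ("URI", "USERNAME", "PASSWORD")


def driver_from_env():
    # Hypothetical helper mirroring the new __post_init__ logic: refuse to
    # start with missing credentials instead of falling back to defaults.
    missing = [name for name in REQUIRED if name not in os.environ]
    if missing:
        raise RuntimeError(f"Neo4j credentials missing from environment: {missing}")
    driver = GraphDatabase.driver(
        os.environ["URI"],
        auth=(os.environ["USERNAME"], os.environ["PASSWORD"]),
    )
    driver.verify_connectivity()  # optional: surface auth/connection errors up front
    return driver
```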
lightrag/lightrag.py CHANGED
@@ -28,6 +28,13 @@ from .storage import (
 from .kg.neo4j_impl import (
     GraphStorage as Neo4JStorage
 )
+#future KG integrations
+
+# from .kg.ArangoDB_impl import (
+#     GraphStorage as ArangoDBStorage
+# )
+
+
 
 from .utils import (
     EmbeddingFunc,
@@ -64,7 +71,11 @@ class LightRAG:
     )
 
     kg: str = field(default="NetworkXStorage")
-
+
+    current_log_level = logger.level
+    log_level: str = field(default=current_log_level)
+
+
 
     # text chunking
     chunk_token_size: int = 1200
@@ -115,13 +126,14 @@ class LightRAG:
     def __post_init__(self):
         log_file = os.path.join(self.working_dir, "lightrag.log")
         set_logger(log_file)
+        logger.setLevel(self.log_level)
+
         logger.info(f"Logger initialized for working directory: {self.working_dir}")
 
         _print_config = ",\n ".join([f"{k} = {v}" for k, v in asdict(self).items()])
         logger.debug(f"LightRAG init with param:\n {_print_config}\n")
 
         #should move all storage setup here to leverage initial start params attached to self.
-        print (f"self.kg set to: {self.kg}")
         self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class()[self.kg]
 
         if not os.path.exists(self.working_dir):
@@ -176,7 +188,7 @@ class LightRAG:
         return {
             "Neo4JStorage": Neo4JStorage,
             "NetworkXStorage": NetworkXStorage,
-            # "new_kg_here": KGClass
+            # "ArangoDBStorage": ArangoDBStorage
         }
 
     def insert(self, string_or_strings):
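
The new `log_level` field defaults to whatever level the shared logger already has and is applied with `logger.setLevel` in `__post_init__`. A minimal standard-library sketch of what that call accepts, independent of LightRAG:

```python
import logging

logger = logging.getLogger("lightrag")
logger.setLevel("INFO")           # level names are accepted as strings...
logger.setLevel(logging.DEBUG)    # ...as are the numeric logging constants
```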
lightrag/operate.py CHANGED
@@ -71,7 +71,6 @@ async def _handle_entity_relation_summary(
     use_prompt = prompt_template.format(**context_base)
     logger.debug(f"Trigger summary: {entity_or_relation_name}")
     summary = await use_llm_func(use_prompt, max_tokens=summary_max_tokens)
-    print ("Summarized: {context_base} for entity relationship {} ")
     return summary
 
 
@@ -79,7 +78,6 @@ async def _handle_single_entity_extraction(
     record_attributes: list[str],
     chunk_key: str,
 ):
-    print (f"_handle_single_entity_extraction {record_attributes} chunk_key {chunk_key}")
     if len(record_attributes) < 4 or record_attributes[0] != '"entity"':
         return None
     # add this record as a node in the G
@@ -265,7 +263,6 @@ async def extract_entities(
 
     async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]):
         nonlocal already_processed, already_entities, already_relations
-        print (f"kw: processing a single chunk, {chunk_key_dp}")
         chunk_key = chunk_key_dp[0]
         chunk_dp = chunk_key_dp[1]
         content = chunk_dp["content"]
@@ -435,7 +432,6 @@ async def local_query(
         text_chunks_db,
         query_param,
     )
-    print (f"got the following context {context} based on prompt keywords {keywords}")
     if query_param.only_need_context:
         return context
     if context is None:
@@ -444,7 +440,6 @@
     sys_prompt = sys_prompt_temp.format(
         context_data=context, response_type=query_param.response_type
     )
-    print (f"local query:{query} local sysprompt:{sys_prompt}")
     response = await use_model_func(
         query,
         system_prompt=sys_prompt,
@@ -470,20 +465,16 @@ async def _build_local_query_context(
     text_chunks_db: BaseKVStorage[TextChunkSchema],
     query_param: QueryParam,
 ):
-    print ("kw1: ENTITIES VDB QUERY**********************************")
 
     results = await entities_vdb.query(query, top_k=query_param.top_k)
-    print (f"kw2: ENTITIES VDB QUERY, RESULTS {results}**********************************")
 
     if not len(results):
         return None
-    print ("kw3: using entities to get_nodes returned in above vdb query. search results from embedding your query keywords")
     node_datas = await asyncio.gather(
         *[knowledge_graph_inst.get_node(r["entity_name"]) for r in results]
     )
     if not all([n is not None for n in node_datas]):
         logger.warning("Some nodes are missing, maybe the storage is damaged")
-    print ("kw4: getting node degrees next for the same entities/nodes")
     node_degrees = await asyncio.gather(
         *[knowledge_graph_inst.node_degree(r["entity_name"]) for r in results]
     )
@@ -729,7 +720,6 @@ async def _build_global_query_context(
     text_chunks_db: BaseKVStorage[TextChunkSchema],
     query_param: QueryParam,
 ):
-    print ("RELATIONSHIPS VDB QUERY**********************************")
     results = await relationships_vdb.query(keywords, top_k=query_param.top_k)
 
     if not len(results):
@@ -895,14 +885,12 @@ async def hybrid_query(
     query_param: QueryParam,
     global_config: dict,
 ) -> str:
-    print ("HYBRID QUERY *********")
     low_level_context = None
     high_level_context = None
     use_model_func = global_config["llm_model_func"]
 
     kw_prompt_temp = PROMPTS["keywords_extraction"]
     kw_prompt = kw_prompt_temp.format(query=query)
-    print ( f"kw:kw_prompt: {kw_prompt}")
 
     result = await use_model_func(kw_prompt)
     try:
@@ -911,8 +899,6 @@
         ll_keywords = keywords_data.get("low_level_keywords", [])
         hl_keywords = ", ".join(hl_keywords)
         ll_keywords = ", ".join(ll_keywords)
-        print (f"High level key words: {hl_keywords}")
-        print (f"Low level key words: {ll_keywords}")
     except json.JSONDecodeError:
         try:
             result = (
@@ -942,7 +928,6 @@
             query_param,
         )
 
-        print (f"low_level_context: {low_level_context}")
 
     if hl_keywords:
         high_level_context = await _build_global_query_context(
@@ -953,7 +938,6 @@
             text_chunks_db,
             query_param,
         )
-        print (f"high_level_context: {high_level_context}")
 
 
     context = combine_contexts(high_level_context, low_level_context)
@@ -971,7 +955,6 @@
         query,
         system_prompt=sys_prompt,
     )
-    print (f"kw: got system prompt: {sys_prompt}. got response for that prompt: {response}")
     if len(response) > len(sys_prompt):
         response = (
             response.replace(sys_prompt, "")
@@ -1065,12 +1048,10 @@ async def naive_query(
 ):
     use_model_func = global_config["llm_model_func"]
     results = await chunks_vdb.query(query, top_k=query_param.top_k)
-    print (f"raw chunks from chunks_vdb.query {results}")
     if not len(results):
         return PROMPTS["fail_response"]
     chunks_ids = [r["id"] for r in results]
     chunks = await text_chunks_db.get_by_ids(chunks_ids)
-    print (f"raw chunks from text_chunks_db {chunks} retrieved by id using the above chunk ids from prev chunks_vdb ")
 
 
     maybe_trun_chunks = truncate_list_by_token_size(
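
With the ad-hoc `print()` diagnostics gone, the same detail can be recovered on demand through the logger, gated by the new `log_level` setting. A minimal sketch; the logger name and sample payload are illustrative:

```python
import logging

logging.basicConfig(level=logging.DEBUG)  # enable debug output for this run
logger = logging.getLogger("lightrag")

results = [{"id": "chunk-1"}, {"id": "chunk-2"}]  # stand-in for a vdb query result
logger.debug("raw chunks from chunks_vdb.query %s", results)
```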
testkg.py CHANGED
@@ -16,12 +16,13 @@ if not os.path.exists(WORKING_DIR):
 rag = LightRAG(
     working_dir=WORKING_DIR,
     llm_model_func=gpt_4o_mini_complete,  # Use gpt_4o_mini_complete LLM model
-    kg="Neo4JStorage"
+    kg="Neo4JStorage",
+    log_level="INFO"
     # llm_model_func=gpt_4o_complete  # Optionally, use a stronger model
 )
 
-with open("./book.txt") as f:
-    rag.insert(f.read())
+# with open("./book.txt") as f:
+#     rag.insert(f.read())
 
 # Perform naive search
 print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))
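
Putting the commit together, a hedged end-to-end sketch of running the test script against the new scheme. The import paths follow the repo's existing pattern for `LightRAG`, `QueryParam`, and `gpt_4o_mini_complete`; the working directory and credential values are placeholders:

```python
import os

# Credentials must be in the environment before Neo4JStorage is constructed.
os.environ.setdefault("URI", "neo4j://localhost:7687")
os.environ.setdefault("USERNAME", "neo4j")
os.environ.setdefault("PASSWORD", "your-password")

from lightrag import LightRAG, QueryParam
from lightrag.llm import gpt_4o_mini_complete

rag = LightRAG(
    working_dir="./local_neo4jWorkDir",
    llm_model_func=gpt_4o_mini_complete,
    kg="Neo4JStorage",
    log_level="INFO",
)
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))
```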