yangdx commited on
Commit
353b669
·
2 Parent(s): 3cef7fa 44e2f81

Merge branch 'context_format_csv_to_json'

Browse files
Files changed (2) hide show
  1. lightrag/operate.py +19 -14
  2. lightrag/utils.py +29 -61
lightrag/operate.py CHANGED
@@ -14,7 +14,6 @@ from .utils import (
14
  compute_mdhash_id,
15
  Tokenizer,
16
  is_float_regex,
17
- list_of_list_to_csv,
18
  normalize_extracted_info,
19
  pack_user_ass_to_openai_messages,
20
  split_string_by_multi_markers,
@@ -26,6 +25,7 @@ from .utils import (
26
  CacheData,
27
  get_conversation_turns,
28
  use_llm_func_with_cache,
 
29
  )
30
  from .base import (
31
  BaseGraphStorage,
@@ -1333,21 +1333,26 @@ async def _build_query_context(
1333
  [hl_text_units_context, ll_text_units_context],
1334
  )
1335
  # not necessary to use LLM to generate a response
1336
- if not entities_context.strip() and not relations_context.strip():
1337
  return None
1338
 
 
 
 
 
 
1339
  result = f"""
1340
  -----Entities-----
1341
- ```csv
1342
- {entities_context}
1343
  ```
1344
  -----Relationships-----
1345
- ```csv
1346
- {relations_context}
1347
  ```
1348
  -----Sources-----
1349
- ```csv
1350
- {text_units_context}
1351
  ```
1352
  """.strip()
1353
  return result
@@ -1453,7 +1458,7 @@ async def _get_node_data(
1453
  file_path,
1454
  ]
1455
  )
1456
- entities_context = list_of_list_to_csv(entites_section_list)
1457
 
1458
  relations_section_list = [
1459
  [
@@ -1490,14 +1495,14 @@ async def _get_node_data(
1490
  file_path,
1491
  ]
1492
  )
1493
- relations_context = list_of_list_to_csv(relations_section_list)
1494
 
1495
  text_units_section_list = [["id", "content", "file_path"]]
1496
  for i, t in enumerate(use_text_units):
1497
  text_units_section_list.append(
1498
  [i, t["content"], t.get("file_path", "unknown_source")]
1499
  )
1500
- text_units_context = list_of_list_to_csv(text_units_section_list)
1501
  return entities_context, relations_context, text_units_context
1502
 
1503
 
@@ -1775,7 +1780,7 @@ async def _get_edge_data(
1775
  file_path,
1776
  ]
1777
  )
1778
- relations_context = list_of_list_to_csv(relations_section_list)
1779
 
1780
  entites_section_list = [
1781
  ["id", "entity", "type", "description", "rank", "created_at", "file_path"]
@@ -1800,12 +1805,12 @@ async def _get_edge_data(
1800
  file_path,
1801
  ]
1802
  )
1803
- entities_context = list_of_list_to_csv(entites_section_list)
1804
 
1805
  text_units_section_list = [["id", "content", "file_path"]]
1806
  for i, t in enumerate(use_text_units):
1807
  text_units_section_list.append([i, t["content"], t.get("file_path", "unknown")])
1808
- text_units_context = list_of_list_to_csv(text_units_section_list)
1809
  return entities_context, relations_context, text_units_context
1810
 
1811
 
 
14
  compute_mdhash_id,
15
  Tokenizer,
16
  is_float_regex,
 
17
  normalize_extracted_info,
18
  pack_user_ass_to_openai_messages,
19
  split_string_by_multi_markers,
 
25
  CacheData,
26
  get_conversation_turns,
27
  use_llm_func_with_cache,
28
+ list_of_list_to_json,
29
  )
30
  from .base import (
31
  BaseGraphStorage,
 
1333
  [hl_text_units_context, ll_text_units_context],
1334
  )
1335
  # not necessary to use LLM to generate a response
1336
+ if not entities_context and not relations_context:
1337
  return None
1338
 
1339
+ # Convert contexts to JSON strings
1340
+ entities_str = json.dumps(entities_context, ensure_ascii=False)
1341
+ relations_str = json.dumps(relations_context, ensure_ascii=False)
1342
+ text_units_str = json.dumps(text_units_context, ensure_ascii=False)
1343
+
1344
  result = f"""
1345
  -----Entities-----
1346
+ ```json
1347
+ {entities_str}
1348
  ```
1349
  -----Relationships-----
1350
+ ```json
1351
+ {relations_str}
1352
  ```
1353
  -----Sources-----
1354
+ ```json
1355
+ {text_units_str}
1356
  ```
1357
  """.strip()
1358
  return result
 
1458
  file_path,
1459
  ]
1460
  )
1461
+ entities_context = list_of_list_to_json(entites_section_list)
1462
 
1463
  relations_section_list = [
1464
  [
 
1495
  file_path,
1496
  ]
1497
  )
1498
+ relations_context = list_of_list_to_json(relations_section_list)
1499
 
1500
  text_units_section_list = [["id", "content", "file_path"]]
1501
  for i, t in enumerate(use_text_units):
1502
  text_units_section_list.append(
1503
  [i, t["content"], t.get("file_path", "unknown_source")]
1504
  )
1505
+ text_units_context = list_of_list_to_json(text_units_section_list)
1506
  return entities_context, relations_context, text_units_context
1507
 
1508
 
 
1780
  file_path,
1781
  ]
1782
  )
1783
+ relations_context = list_of_list_to_json(relations_section_list)
1784
 
1785
  entites_section_list = [
1786
  ["id", "entity", "type", "description", "rank", "created_at", "file_path"]
 
1805
  file_path,
1806
  ]
1807
  )
1808
+ entities_context = list_of_list_to_json(entites_section_list)
1809
 
1810
  text_units_section_list = [["id", "content", "file_path"]]
1811
  for i, t in enumerate(use_text_units):
1812
  text_units_section_list.append([i, t["content"], t.get("file_path", "unknown")])
1813
+ text_units_context = list_of_list_to_json(text_units_section_list)
1814
  return entities_context, relations_context, text_units_context
1815
 
1816
 
lightrag/utils.py CHANGED
@@ -2,7 +2,6 @@ from __future__ import annotations
2
 
3
  import asyncio
4
  import html
5
- import io
6
  import csv
7
  import json
8
  import logging
@@ -442,37 +441,24 @@ def truncate_list_by_token_size(
442
  return list_data
443
 
444
 
445
- def list_of_list_to_csv(data: list[list[str]]) -> str:
446
- output = io.StringIO()
447
- writer = csv.writer(
448
- output,
449
- quoting=csv.QUOTE_ALL, # Quote all fields
450
- escapechar="\\", # Use backslash as escape character
451
- quotechar='"', # Use double quotes
452
- lineterminator="\n", # Explicit line terminator
453
- )
454
- writer.writerows(data)
455
- return output.getvalue()
456
-
457
 
458
- def csv_string_to_list(csv_string: str) -> list[list[str]]:
459
- # Clean the string by removing NUL characters
460
- cleaned_string = csv_string.replace("\0", "")
461
 
462
- output = io.StringIO(cleaned_string)
463
- reader = csv.reader(
464
- output,
465
- quoting=csv.QUOTE_ALL, # Match the writer configuration
466
- escapechar="\\", # Use backslash as escape character
467
- quotechar='"', # Use double quotes
468
- )
 
 
469
 
470
- try:
471
- return [row for row in reader]
472
- except csv.Error as e:
473
- raise ValueError(f"Failed to parse CSV string: {str(e)}")
474
- finally:
475
- output.close()
476
 
477
 
478
  def save_data_to_file(data, file_name):
@@ -540,41 +526,23 @@ def xml_to_json(xml_file):
540
  return None
541
 
542
 
543
- def process_combine_contexts(hl: str, ll: str):
544
- header = None
545
- list_hl = csv_string_to_list(hl.strip())
546
- list_ll = csv_string_to_list(ll.strip())
547
-
548
- if list_hl:
549
- header = list_hl[0]
550
- list_hl = list_hl[1:]
551
- if list_ll:
552
- header = list_ll[0]
553
- list_ll = list_ll[1:]
554
- if header is None:
555
- return ""
556
-
557
- if list_hl:
558
- list_hl = [",".join(item[1:]) for item in list_hl if item]
559
- if list_ll:
560
- list_ll = [",".join(item[1:]) for item in list_ll if item]
561
-
562
- combined_sources = []
563
- seen = set()
564
-
565
- for item in list_hl + list_ll:
566
- if item and item not in seen:
567
- combined_sources.append(item)
568
- seen.add(item)
569
-
570
- combined_sources_result = [",\t".join(header)]
571
 
572
- for i, item in enumerate(combined_sources, start=1):
573
- combined_sources_result.append(f"{i},\t{item}")
 
 
 
 
574
 
575
- combined_sources_result = "\n".join(combined_sources_result)
 
576
 
577
- return combined_sources_result
578
 
579
 
580
  async def get_best_cached_response(
 
2
 
3
  import asyncio
4
  import html
 
5
  import csv
6
  import json
7
  import logging
 
441
  return list_data
442
 
443
 
444
def list_of_list_to_json(data: list[list[str]]) -> list[dict[str, str]]:
    """Convert a header-plus-rows table into a list of records.

    The first inner list is treated as the column names; every subsequent
    inner list becomes one dict mapping column name -> stringified cell.

    Args:
        data: Table as a list of rows, where ``data[0]`` is the header row.

    Returns:
        One dict per data row. Rows shorter than two cells are dropped;
        rows shorter than the header are right-padded with empty strings.
        An empty list is returned when there is no header or no data rows.
    """
    if not data or len(data) <= 1:
        return []

    field_names = data[0]
    records: list[dict[str, str]] = []

    for values in data[1:]:
        # Skip degenerate rows (fewer than two cells carry no usable record).
        if len(values) < 2:
            continue
        records.append(
            {
                name: str(values[pos]) if pos < len(values) else ""
                for pos, name in enumerate(field_names)
            }
        )

    return records
 
 
 
 
 
462
 
463
 
464
  def save_data_to_file(data, file_name):
 
526
  return None
527
 
528
 
529
def process_combine_contexts(
    hl_context: list[dict[str, str]], ll_context: list[dict[str, str]]
):
    """Merge high-level and low-level context records, dropping duplicates.

    Two records are considered duplicates when all of their fields other
    than ``"id"`` are equal; the first occurrence (high-level records take
    precedence by ordering) is kept. The surviving records are then
    renumbered with sequential string ids, mutating the dicts in place.

    Args:
        hl_context: High-level retrieval records.
        ll_context: Low-level retrieval records.

    Returns:
        The deduplicated records (same dict objects, ids rewritten).
    """
    merged: list[dict[str, str]] = []
    fingerprints: set[tuple] = set()

    for record in hl_context + ll_context:
        # Fingerprint on every field except "id" so renumbering never
        # affects duplicate detection.
        fingerprint = tuple(sorted((k, v) for k, v in record.items() if k != "id"))
        if fingerprint in fingerprints:
            continue
        fingerprints.add(fingerprint)
        merged.append(record)

    # Re-assign contiguous string ids over the combined result.
    for index, record in enumerate(merged):
        record["id"] = str(index)

    return merged
546
 
547
 
548
  async def get_best_cached_response(