eko9989 committed on
Commit
e7ef3b9
·
1 Parent(s): 9518296

Update utils.py

Browse files
Files changed (1) hide show
  1. lightrag/utils.py +26 -14
lightrag/utils.py CHANGED
@@ -1,5 +1,7 @@
1
  import asyncio
2
  import html
 
 
3
  import json
4
  import logging
5
  import os
@@ -7,7 +9,7 @@ import re
7
  from dataclasses import dataclass
8
  from functools import wraps
9
  from hashlib import md5
10
- from typing import Any, Union
11
  import xml.etree.ElementTree as ET
12
 
13
  import numpy as np
@@ -175,10 +177,21 @@ def truncate_list_by_token_size(list_data: list, key: callable, max_token_size:
175
  return list_data
176
 
177
 
178
- def list_of_list_to_csv(data: list[list]):
179
- return "\n".join(
180
- [",\t".join([str(data_dd) for data_dd in data_d]) for data_d in data]
181
- )
 
 
 
 
 
 
 
 
 
 
 
182
 
183
 
184
  def save_data_to_file(data, file_name):
@@ -248,8 +261,8 @@ def xml_to_json(xml_file):
248
  #混合检索中的合并函数
249
  def process_combine_contexts(hl, ll):
250
  header = None
251
- list_hl = hl.strip().split("\n")
252
- list_ll = ll.strip().split("\n")
253
  # 去掉第一个元素(如果不为空)
254
  if list_hl:
255
  header=list_hl[0]
@@ -259,12 +272,11 @@ def process_combine_contexts(hl, ll):
259
  list_ll = list_ll[1:]
260
  if header is None:
261
  return ""
262
-
263
- # 去掉每个子元素中逗号分隔后的第一个元素(如果不为空)
264
  if list_hl:
265
- list_hl = [','.join(item.split(',')[1:]) for item in list_hl if item]
266
  if list_ll:
267
- list_ll = [','.join(item.split(',')[1:]) for item in list_ll if item]
268
 
269
  # 合并并去重
270
  combined_sources_set = set(
@@ -272,12 +284,12 @@ def process_combine_contexts(hl, ll):
272
  )
273
 
274
  # 创建包含头部的新列表
275
- combined_sources = [header]
276
  # 为 combined_sources_set 中的每个元素添加自增数字
277
  for i, item in enumerate(combined_sources_set, start=1):
278
  combined_sources.append(f"{i},\t{item}")
279
-
280
- # 将列表转换为字符串,元素之间用换行符分隔
281
  combined_sources = "\n".join(combined_sources)
282
 
283
  return combined_sources
 
1
  import asyncio
2
  import html
3
+ import io
4
+ import csv
5
  import json
6
  import logging
7
  import os
 
9
  from dataclasses import dataclass
10
  from functools import wraps
11
  from hashlib import md5
12
+ from typing import Any, Union,List
13
  import xml.etree.ElementTree as ET
14
 
15
  import numpy as np
 
177
  return list_data
178
 
179
 
180
+ # def list_of_list_to_csv(data: list[list]):
181
+ # return "\n".join(
182
+ # [",\t".join([str(data_dd) for data_dd in data_d]) for data_d in data]
183
+ # )
184
def list_of_list_to_csv(data: List[List[str]]) -> str:
    """Serialize rows of values into one CSV-formatted string.

    Every inner list of ``data`` is written as a single record using the
    default ``csv.writer`` dialect: comma-separated fields, quoting applied
    only where a field requires it, and a ``\\r\\n`` terminator after each
    record (including the last one).

    Args:
        data: The rows to serialize, one inner list per record.

    Returns:
        The CSV text; an empty string when ``data`` has no rows.
    """
    with io.StringIO() as buffer:
        csv_writer = csv.writer(buffer)
        for record in data:
            csv_writer.writerow(record)
        # Read the text out before the buffer is closed by the ``with`` exit.
        return buffer.getvalue()
189
def csv_string_to_list(csv_string: str) -> List[List[str]]:
    """Parse CSV-formatted text into a list of rows.

    Args:
        csv_string: CSV text in the default ``csv`` dialect (comma-separated,
            double-quote quoting). Records may be terminated by ``\\n``,
            ``\\r\\n`` or a bare ``\\r``.

    Returns:
        One list of field strings per record, in input order. Empty input
        yields an empty list.

    Raises:
        csv.Error: If the text is not well-formed CSV.
    """
    # newline="" passes line endings to the csv module untranslated, as the
    # csv documentation requires. With the StringIO default ("\n") a bare
    # "\r" record terminator arrives in the middle of a line and the reader
    # raises "new-line character seen in unquoted field".
    buffer = io.StringIO(csv_string, newline="")
    return list(csv.reader(buffer))
193
+
194
+
195
 
196
 
197
  def save_data_to_file(data, file_name):
 
261
  #混合检索中的合并函数
262
  def process_combine_contexts(hl, ll):
263
  header = None
264
+ list_hl = csv_string_to_list(hl.strip())
265
+ list_ll = csv_string_to_list(ll.strip())
266
  # 去掉第一个元素(如果不为空)
267
  if list_hl:
268
  header=list_hl[0]
 
272
  list_ll = list_ll[1:]
273
  if header is None:
274
  return ""
275
+ # 去掉每个子元素中的第一个元素(如果不为空),再转为一维数组,用于合并去重
 
276
  if list_hl:
277
+ list_hl = [','.join(item[1:]) for item in list_hl if item]
278
  if list_ll:
279
+ list_ll = [','.join(item[1:]) for item in list_ll if item]
280
 
281
  # 合并并去重
282
  combined_sources_set = set(
 
284
  )
285
 
286
  # 创建包含头部的新列表
287
+ combined_sources = [",\t".join(header)]
288
  # 为 combined_sources_set 中的每个元素添加自增数字
289
  for i, item in enumerate(combined_sources_set, start=1):
290
  combined_sources.append(f"{i},\t{item}")
291
+
292
+ # 将列表转换为字符串,子元素之间用换行符分隔
293
  combined_sources = "\n".join(combined_sources)
294
 
295
  return combined_sources