Update utils.py
Browse files- lightrag/utils.py +26 -14
lightrag/utils.py
CHANGED
@@ -1,5 +1,7 @@
|
|
1 |
import asyncio
|
2 |
import html
|
|
|
|
|
3 |
import json
|
4 |
import logging
|
5 |
import os
|
@@ -7,7 +9,7 @@ import re
|
|
7 |
from dataclasses import dataclass
|
8 |
from functools import wraps
|
9 |
from hashlib import md5
|
10 |
-
from typing import Any, Union
|
11 |
import xml.etree.ElementTree as ET
|
12 |
|
13 |
import numpy as np
|
@@ -175,10 +177,21 @@ def truncate_list_by_token_size(list_data: list, key: callable, max_token_size:
|
|
175 |
return list_data
|
176 |
|
177 |
|
178 |
-
def list_of_list_to_csv(data: list[list]):
|
179 |
-
|
180 |
-
|
181 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
182 |
|
183 |
|
184 |
def save_data_to_file(data, file_name):
|
@@ -248,8 +261,8 @@ def xml_to_json(xml_file):
|
|
248 |
#混合检索中的合并函数
|
249 |
def process_combine_contexts(hl, ll):
|
250 |
header = None
|
251 |
-
list_hl = hl.strip()
|
252 |
-
list_ll = ll.strip()
|
253 |
# 去掉第一个元素(如果不为空)
|
254 |
if list_hl:
|
255 |
header=list_hl[0]
|
@@ -259,12 +272,11 @@ def process_combine_contexts(hl, ll):
|
|
259 |
list_ll = list_ll[1:]
|
260 |
if header is None:
|
261 |
return ""
|
262 |
-
|
263 |
-
# 去掉每个子元素中逗号分隔后的第一个元素(如果不为空)
|
264 |
if list_hl:
|
265 |
-
list_hl = [','.join(item
|
266 |
if list_ll:
|
267 |
-
list_ll = [','.join(item
|
268 |
|
269 |
# 合并并去重
|
270 |
combined_sources_set = set(
|
@@ -272,12 +284,12 @@ def process_combine_contexts(hl, ll):
|
|
272 |
)
|
273 |
|
274 |
# 创建包含头部的新列表
|
275 |
-
combined_sources = [header]
|
276 |
# 为 combined_sources_set 中的每个元素添加自增数字
|
277 |
for i, item in enumerate(combined_sources_set, start=1):
|
278 |
combined_sources.append(f"{i},\t{item}")
|
279 |
-
|
280 |
-
#
|
281 |
combined_sources = "\n".join(combined_sources)
|
282 |
|
283 |
return combined_sources
|
|
|
1 |
import asyncio
|
2 |
import html
|
3 |
+
import io
|
4 |
+
import csv
|
5 |
import json
|
6 |
import logging
|
7 |
import os
|
|
|
9 |
from dataclasses import dataclass
|
10 |
from functools import wraps
|
11 |
from hashlib import md5
|
12 |
+
from typing import Any, Union,List
|
13 |
import xml.etree.ElementTree as ET
|
14 |
|
15 |
import numpy as np
|
|
|
177 |
return list_data
|
178 |
|
179 |
|
180 |
+
# def list_of_list_to_csv(data: list[list]):
|
181 |
+
# return "\n".join(
|
182 |
+
# [",\t".join([str(data_dd) for data_dd in data_d]) for data_d in data]
|
183 |
+
# )
|
184 |
+
def list_of_list_to_csv(data: List[List[str]]) -> str:
    """Serialize rows of fields into a CSV-formatted string.

    Uses the stdlib ``csv`` writer, so quoting/escaping of commas and
    quotes follows the default dialect (rows terminated with ``\r\n``).

    Args:
        data: A list of rows, each row a list of field strings.

    Returns:
        The CSV text for all rows, including a trailing row terminator
        when ``data`` is non-empty.
    """
    buffer = io.StringIO()
    csv.writer(buffer).writerows(data)
    return buffer.getvalue()
|
189 |
+
def csv_string_to_list(csv_string: str) -> List[List[str]]:
    """Parse a CSV-formatted string into rows of field strings.

    Inverse of ``list_of_list_to_csv``: the default ``csv`` dialect is
    used, so quoted fields containing commas are handled correctly.

    Args:
        csv_string: CSV text; an empty string yields an empty list.

    Returns:
        One inner list of string fields per CSV row.
    """
    return list(csv.reader(io.StringIO(csv_string)))
|
193 |
+
|
194 |
+
|
195 |
|
196 |
|
197 |
def save_data_to_file(data, file_name):
|
|
|
261 |
#混合检索中的合并函数
|
262 |
def process_combine_contexts(hl, ll):
|
263 |
header = None
|
264 |
+
list_hl = csv_string_to_list(hl.strip())
|
265 |
+
list_ll = csv_string_to_list(ll.strip())
|
266 |
# 去掉第一个元素(如果不为空)
|
267 |
if list_hl:
|
268 |
header=list_hl[0]
|
|
|
272 |
list_ll = list_ll[1:]
|
273 |
if header is None:
|
274 |
return ""
|
275 |
+
# 去掉每个子元素中的第一个元素(如果不为空),再转为一维数组,用于合并去重
|
|
|
276 |
if list_hl:
|
277 |
+
list_hl = [','.join(item[1:]) for item in list_hl if item]
|
278 |
if list_ll:
|
279 |
+
list_ll = [','.join(item[1:]) for item in list_ll if item]
|
280 |
|
281 |
# 合并并去重
|
282 |
combined_sources_set = set(
|
|
|
284 |
)
|
285 |
|
286 |
# 创建包含头部的新列表
|
287 |
+
combined_sources = [",\t".join(header)]
|
288 |
# 为 combined_sources_set 中的每个元素添加自增数字
|
289 |
for i, item in enumerate(combined_sources_set, start=1):
|
290 |
combined_sources.append(f"{i},\t{item}")
|
291 |
+
|
292 |
+
# 将列表转换为字符串,子元素之间用换行符分隔
|
293 |
combined_sources = "\n".join(combined_sources)
|
294 |
|
295 |
return combined_sources
|