Merge pull request #8 from TianyuFan0504/main
- lightrag/lightrag.py +5 -3
- lightrag/llm.py +82 -2
- lightrag/operate.py +60 -14
lightrag/lightrag.py
CHANGED
@@ -5,7 +5,7 @@ from datetime import datetime
 from functools import partial
 from typing import Type, cast
 
-from .llm import gpt_4o_complete, gpt_4o_mini_complete, openai_embedding
+from .llm import gpt_4o_complete, gpt_4o_mini_complete, openai_embedding, hf_model_complete, hf_embedding
 from .operate import (
     chunking_by_token_size,
     extract_entities,
@@ -77,12 +77,14 @@ class LightRAG:
     )
 
     # text embedding
-    embedding_func: EmbeddingFunc = field(default_factory=lambda: openai_embedding)
+    # embedding_func: EmbeddingFunc = field(default_factory=lambda: hf_embedding)
+    embedding_func: EmbeddingFunc = field(default_factory=lambda: openai_embedding)
     embedding_batch_num: int = 32
     embedding_func_max_async: int = 16
 
     # LLM
-    llm_model_func: callable = gpt_4o_mini_complete
+    llm_model_func: callable = gpt_4o_mini_complete  # hf_model_complete
+    llm_model_name: str = 'meta-llama/Llama-3.2-1B-Instruct'  # 'meta-llama/Llama-3.2-1B', 'google/gemma-2-2b-it'
     llm_model_max_token_size: int = 32768
     llm_model_max_async: int = 16
 
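With these fields in place, switching LightRAG to the Hugging Face backends is a pure configuration change. A minimal sketch, not part of this PR, assuming LightRAG is importable from the package root and the remaining fields keep their defaults:

```python
# Minimal sketch (not part of this PR): point LightRAG at the HF backends added here.
# Field names come from the dataclass above; the LightRAG import path is an assumption.
from lightrag import LightRAG
from lightrag.llm import hf_model_complete, hf_embedding

rag = LightRAG(
    llm_model_func=hf_model_complete,                   # HF completion wrapper (see llm.py below)
    llm_model_name="meta-llama/Llama-3.2-1B-Instruct",  # read back via global_config at call time
    embedding_func=hf_embedding,                        # 384-dim all-MiniLM-L6-v2 embeddings
)
```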
lightrag/llm.py
CHANGED
@@ -7,10 +7,12 @@ from tenacity import (
     wait_exponential,
     retry_if_exception_type,
 )
-
+from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM
+import torch
 from .base import BaseKVStorage
 from .utils import compute_args_hash, wrap_embedding_func_with_attrs
-
+import copy
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
 @retry(
     stop=stop_after_attempt(3),
     wait=wait_exponential(multiplier=1, min=4, max=10),
@@ -42,6 +44,52 @@ async def openai_complete_if_cache(
     )
     return response.choices[0].message.content
 
+async def hf_model_if_cache(
+    model, prompt, system_prompt=None, history_messages=[], **kwargs
+) -> str:
+    model_name = model
+    hf_tokenizer = AutoTokenizer.from_pretrained(model_name, device_map='auto')
+    if hf_tokenizer.pad_token == None:
+        # print("use eos token")
+        hf_tokenizer.pad_token = hf_tokenizer.eos_token
+    hf_model = AutoModelForCausalLM.from_pretrained(model_name, device_map='auto')
+    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
+    messages = []
+    if system_prompt:
+        messages.append({"role": "system", "content": system_prompt})
+    messages.extend(history_messages)
+    messages.append({"role": "user", "content": prompt})
+
+    if hashing_kv is not None:
+        args_hash = compute_args_hash(model, messages)
+        if_cache_return = await hashing_kv.get_by_id(args_hash)
+        if if_cache_return is not None:
+            return if_cache_return["return"]
+    input_prompt = ''
+    try:
+        input_prompt = hf_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+    except:
+        try:
+            ori_message = copy.deepcopy(messages)
+            if messages[0]['role'] == "system":
+                messages[1]['content'] = "<system>" + messages[0]['content'] + "</system>\n" + messages[1]['content']
+                messages = messages[1:]
+                input_prompt = hf_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+        except:
+            len_message = len(ori_message)
+            for msgid in range(len_message):
+                input_prompt = input_prompt + '<' + ori_message[msgid]['role'] + '>' + ori_message[msgid]['content'] + '</' + ori_message[msgid]['role'] + '>\n'
+
+    input_ids = hf_tokenizer(input_prompt, return_tensors='pt', padding=True, truncation=True).to("cuda")
+    output = hf_model.generate(**input_ids, max_new_tokens=200, num_return_sequences=1, early_stopping=True)
+    response_text = hf_tokenizer.decode(output[0], skip_special_tokens=True)
+    if hashing_kv is not None:
+        await hashing_kv.upsert(
+            {args_hash: {"return": response_text, "model": model}}
+        )
+    return response_text
+
+
 async def gpt_4o_complete(
     prompt, system_prompt=None, history_messages=[], **kwargs
 ) -> str:
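For context, a usage sketch of the new wrapper (not part of the diff): hf_model_if_cache loads the tokenizer and model on each call, builds the prompt with apply_chat_template, falls back to folding the system prompt into the first user turn and finally to plain <role>…</role> tags, generates up to 200 new tokens, and caches the result when a hashing_kv store is passed. As written it moves inputs to "cuda", so a GPU is assumed.

```python
# Illustrative call (not part of this PR); requires a CUDA device because the
# function sends its tokenized inputs to "cuda".
import asyncio
from lightrag.llm import hf_model_if_cache

async def demo():
    answer = await hf_model_if_cache(
        "meta-llama/Llama-3.2-1B-Instruct",        # same default as lightrag.py above
        "What is a knowledge graph?",              # user prompt
        system_prompt="Answer in one sentence.",   # goes through the chat template
    )
    print(answer)

asyncio.run(demo())
```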
@@ -65,6 +113,20 @@ async def gpt_4o_mini_complete(
         **kwargs,
     )
 
+
+
+async def hf_model_complete(
+    prompt, system_prompt=None, history_messages=[], **kwargs
+) -> str:
+    input_string = kwargs['hashing_kv'].global_config['llm_model_name']
+    return await hf_model_if_cache(
+        input_string,
+        prompt,
+        system_prompt=system_prompt,
+        history_messages=history_messages,
+        **kwargs,
+    )
+
 @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
 @retry(
     stop=stop_after_attempt(3),
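hf_model_complete takes the model name from the hashing_kv object it receives rather than from an argument, which is why llm_model_name was added to the LightRAG dataclass: the storage passed as hashing_kv is expected to expose the dataclass fields as global_config. A small sketch of that contract (not part of the diff):

```python
# Sketch (not part of this PR): whatever object is passed as hashing_kv only needs a
# global_config mapping that contains llm_model_name.
from types import SimpleNamespace

fake_kv = SimpleNamespace(global_config={"llm_model_name": "meta-llama/Llama-3.2-1B-Instruct"})
model_name = fake_kv.global_config["llm_model_name"]  # the value hf_model_complete forwards
print(model_name)
```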
@@ -78,6 +140,24 @@ async def openai_embedding(texts: list[str]) -> np.ndarray:
     )
     return np.array([dp.embedding for dp in response.data])
 
+
+
+global EMBED_MODEL
+global tokenizer
+EMBED_MODEL = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
+tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
+@wrap_embedding_func_with_attrs(
+    embedding_dim=384,
+    max_token_size=5000,
+)
+async def hf_embedding(texts: list[str]) -> np.ndarray:
+    input_ids = tokenizer(texts, return_tensors='pt', padding=True, truncation=True).input_ids
+    with torch.no_grad():
+        outputs = EMBED_MODEL(input_ids)
+    embeddings = outputs.last_hidden_state.mean(dim=1)
+    return embeddings.detach().numpy()
+
+
 if __name__ == "__main__":
     import asyncio
 
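A quick sanity check (not part of the diff): all-MiniLM-L6-v2 has a 384-wide hidden state, which is why hf_embedding declares embedding_dim=384. Note that the mean pool above runs over input_ids only, without an attention mask, so padding positions are averaged in for shorter texts in a batch.

```python
# Shape check (not part of this PR); hf_embedding is async, so drive it with asyncio.
import asyncio
from lightrag.llm import hf_embedding

vectors = asyncio.run(hf_embedding(["hello world", "knowledge graphs link entities"]))
print(vectors.shape)  # expected: (2, 384)
```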
lightrag/operate.py
CHANGED
@@ -3,7 +3,7 @@ import json
 import re
 from typing import Union
 from collections import Counter, defaultdict
-
+import warnings
 from .utils import (
     logger,
     clean_str,
@@ -398,10 +398,15 @@ async def local_query(
         keywords = keywords_data.get("low_level_keywords", [])
         keywords = ', '.join(keywords)
     except json.JSONDecodeError as e:
+        try:
+            result = result.replace(kw_prompt[:-1], '').replace('user', '').replace('model', '').strip().strip('```').strip('json')
+            keywords_data = json.loads(result)
+            keywords = keywords_data.get("low_level_keywords", [])
+            keywords = ', '.join(keywords)
         # Handle parsing error
-        print(f"JSON parsing error: {e}")
-        return PROMPTS["fail_response"]
-
+        except json.JSONDecodeError as e:
+            print(f"JSON parsing error: {e}")
+            return PROMPTS["fail_response"]
     context = await _build_local_query_context(
         keywords,
         knowledge_graph_inst,
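The nested fallback exists because local HF chat models often echo the prompt and wrap their JSON in a code fence, which the first json.loads rejects. A standalone sketch of the cleanup (not part of the diff; the prompt text is made up):

```python
# Standalone sketch (not part of this PR) of the recovery path above. kw_prompt stands
# in for PROMPTS["keywords_extraction"].format(query=query).
import json

kw_prompt = "Given the query, list keywords as JSON."
result = kw_prompt[:-1] + 'model```json{"low_level_keywords": ["graph", "entity"]}```'

cleaned = (result.replace(kw_prompt[:-1], '')
                 .replace('user', '').replace('model', '')
                 .strip().strip('```').strip('json'))
keywords = ', '.join(json.loads(cleaned).get("low_level_keywords", []))
print(keywords)  # graph, entity
```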
@@ -421,6 +426,9 @@ async def local_query(
         query,
         system_prompt=sys_prompt,
     )
+    if len(response) > len(sys_prompt):
+        response = response.replace(sys_prompt, '').replace('user', '').replace('model', '').replace(query, '').replace('<system>', '').replace('</system>', '').strip()
+
     return response
 
 async def _build_local_query_context(
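The new post-processing guards against HF models whose decoded output still contains the prompt: generate() returns the prompt plus the completion, and hf_model_if_cache decodes the whole sequence. A sketch of what the stripping does (not part of the diff; the strings are made up):

```python
# Sketch (not part of this PR): strip the echoed system prompt, role markers, tags,
# and the query so only the generated answer remains.
sys_prompt = "---Role--- You answer questions using the provided knowledge base."
query = "Who founded the company?"
response = sys_prompt + "user" + query + "model" + "It was founded by A. Person."

if len(response) > len(sys_prompt):
    response = (response.replace(sys_prompt, '').replace('user', '')
                        .replace('model', '').replace(query, '')
                        .replace('<system>', '').replace('</system>', '').strip())
print(response)  # It was founded by A. Person.
```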
@@ -617,9 +625,16 @@ async def global_query(
         keywords = keywords_data.get("high_level_keywords", [])
         keywords = ', '.join(keywords)
     except json.JSONDecodeError as e:
-        # Handle parsing error
-        print(f"JSON parsing error: {e}")
-        return PROMPTS["fail_response"]
+        try:
+            result = result.replace(kw_prompt[:-1], '').replace('user', '').replace('model', '').strip().strip('```').strip('json')
+            keywords_data = json.loads(result)
+            keywords = keywords_data.get("high_level_keywords", [])
+            keywords = ', '.join(keywords)
+
+        except json.JSONDecodeError as e:
+            # Handle parsing error
+            print(f"JSON parsing error: {e}")
+            return PROMPTS["fail_response"]
 
     context = await _build_global_query_context(
         keywords,
@@ -643,6 +658,9 @@ async def global_query(
         query,
         system_prompt=sys_prompt,
     )
+    if len(response) > len(sys_prompt):
+        response = response.replace(sys_prompt, '').replace('user', '').replace('model', '').replace(query, '').replace('<system>', '').replace('</system>', '').strip()
+
     return response
 
 async def _build_global_query_context(
@@ -822,8 +840,8 @@ async def hybird_query(
 
     kw_prompt_temp = PROMPTS["keywords_extraction"]
     kw_prompt = kw_prompt_temp.format(query=query)
+
     result = await use_model_func(kw_prompt)
-
     try:
         keywords_data = json.loads(result)
         hl_keywords = keywords_data.get("high_level_keywords", [])
@@ -831,10 +849,18 @@ async def hybird_query(
         hl_keywords = ', '.join(hl_keywords)
         ll_keywords = ', '.join(ll_keywords)
     except json.JSONDecodeError as e:
+        try:
+            result = result.replace(kw_prompt[:-1], '').replace('user', '').replace('model', '').strip().strip('```').strip('json')
+            keywords_data = json.loads(result)
+            hl_keywords = keywords_data.get("high_level_keywords", [])
+            ll_keywords = keywords_data.get("low_level_keywords", [])
+            hl_keywords = ', '.join(hl_keywords)
+            ll_keywords = ', '.join(ll_keywords)
         # Handle parsing error
-        print(f"JSON parsing error: {e}")
-        return PROMPTS["fail_response"]
-
+        except json.JSONDecodeError as e:
+            print(f"JSON parsing error: {e}")
+            return PROMPTS["fail_response"]
+
     low_level_context = await _build_local_query_context(
         ll_keywords,
         knowledge_graph_inst,
@@ -851,7 +877,7 @@ async def hybird_query(
         text_chunks_db,
         query_param,
     )
-
+
     context = combine_contexts(high_level_context, low_level_context)
 
     if query_param.only_need_context:
@@ -867,10 +893,13 @@ async def hybird_query(
         query,
         system_prompt=sys_prompt,
     )
+    if len(response) > len(sys_prompt):
+        response = response.replace(sys_prompt, '').replace('user', '').replace('model', '').replace(query, '').replace('<system>', '').replace('</system>', '').strip()
     return response
 
 def combine_contexts(high_level_context, low_level_context):
     # Function to extract entities, relationships, and sources from context strings
+
     def extract_sections(context):
         entities_match = re.search(r'-----Entities-----\s*```csv\s*(.*?)\s*```', context, re.DOTALL)
         relationships_match = re.search(r'-----Relationships-----\s*```csv\s*(.*?)\s*```', context, re.DOTALL)
@@ -883,8 +912,21 @@ def combine_contexts(high_level_context, low_level_context):
         return entities, relationships, sources
 
     # Extract sections from both contexts
-    hl_entities, hl_relationships, hl_sources = extract_sections(high_level_context)
-    ll_entities, ll_relationships, ll_sources = extract_sections(low_level_context)
+
+    if high_level_context == None:
+        warnings.warn("High Level context is None. Return empty High entity/relationship/source")
+        hl_entities, hl_relationships, hl_sources = '', '', ''
+    else:
+        hl_entities, hl_relationships, hl_sources = extract_sections(high_level_context)
+
+
+    if low_level_context == None:
+        warnings.warn("Low Level context is None. Return empty Low entity/relationship/source")
+        ll_entities, ll_relationships, ll_sources = '', '', ''
+    else:
+        ll_entities, ll_relationships, ll_sources = extract_sections(low_level_context)
+
+
 
     # Combine and deduplicate the entities
     combined_entities_set = set(filter(None, hl_entities.strip().split('\n') + ll_entities.strip().split('\n')))
@@ -940,5 +982,9 @@ async def naive_query(
         query,
         system_prompt=sys_prompt,
     )
+
+    if len(response) > len(sys_prompt):
+        response = response[len(sys_prompt):].replace(sys_prompt, '').replace('user', '').replace('model', '').replace(query, '').replace('<system>', '').replace('</system>', '').strip()
+
     return response
 