zrguo committed
Commit 76a313b · 2 Parent(s): 4a14cb4 2d1e6e5

Merge pull request #8 from TianyuFan0504/main

Files changed (3):
  1. lightrag/lightrag.py +5 -3
  2. lightrag/llm.py +82 -2
  3. lightrag/operate.py +60 -14
lightrag/lightrag.py CHANGED
@@ -5,7 +5,7 @@ from datetime import datetime
 from functools import partial
 from typing import Type, cast
 
-from .llm import gpt_4o_complete, gpt_4o_mini_complete, openai_embedding
+from .llm import gpt_4o_complete, gpt_4o_mini_complete, openai_embedding, hf_model_complete, hf_embedding
 from .operate import (
     chunking_by_token_size,
     extract_entities,
@@ -77,12 +77,14 @@ class LightRAG:
     )
 
     # text embedding
-    embedding_func: EmbeddingFunc = field(default_factory=lambda: openai_embedding)
+    # embedding_func: EmbeddingFunc = field(default_factory=lambda: hf_embedding)
+    embedding_func: EmbeddingFunc = field(default_factory=lambda: openai_embedding)
     embedding_batch_num: int = 32
     embedding_func_max_async: int = 16
 
     # LLM
-    llm_model_func: callable = gpt_4o_mini_complete
+    llm_model_func: callable = gpt_4o_mini_complete  # or hf_model_complete
+    llm_model_name: str = 'meta-llama/Llama-3.2-1B-Instruct'  # e.g. 'meta-llama/Llama-3.2-1B', 'google/gemma-2-2b-it'
     llm_model_max_token_size: int = 32768
     llm_model_max_async: int = 16
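The two new fields make the backend swappable from user code. A minimal usage sketch (assuming the dataclass fields above can be overridden through the constructor and that the package root exports LightRAG, as this repo's examples do; the working directory path is illustrative, not from this diff):

```python
# Sketch: pointing LightRAG at the Hugging Face backend added by this PR.
from lightrag import LightRAG
from lightrag.llm import hf_model_complete, hf_embedding

rag = LightRAG(
    working_dir="./lightrag_cache",            # illustrative path
    llm_model_func=hf_model_complete,          # generate with a local HF model
    llm_model_name="meta-llama/Llama-3.2-1B-Instruct",
    embedding_func=hf_embedding,               # 384-dim MiniLM embeddings
)
```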
lightrag/llm.py CHANGED
@@ -7,10 +7,12 @@ from tenacity import (
     wait_exponential,
     retry_if_exception_type,
 )
-
+from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM
+import torch
 from .base import BaseKVStorage
 from .utils import compute_args_hash, wrap_embedding_func_with_attrs
-
+import copy
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
 @retry(
     stop=stop_after_attempt(3),
     wait=wait_exponential(multiplier=1, min=4, max=10),
@@ -42,6 +44,52 @@ async def openai_complete_if_cache(
     )
     return response.choices[0].message.content
 
+async def hf_model_if_cache(
+    model, prompt, system_prompt=None, history_messages=[], **kwargs
+) -> str:
+    model_name = model
+    hf_tokenizer = AutoTokenizer.from_pretrained(model_name, device_map='auto')
+    if hf_tokenizer.pad_token is None:
+        hf_tokenizer.pad_token = hf_tokenizer.eos_token
+    hf_model = AutoModelForCausalLM.from_pretrained(model_name, device_map='auto')
+    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
+    messages = []
+    if system_prompt:
+        messages.append({"role": "system", "content": system_prompt})
+    messages.extend(history_messages)
+    messages.append({"role": "user", "content": prompt})
+
+    if hashing_kv is not None:
+        args_hash = compute_args_hash(model, messages)
+        if_cache_return = await hashing_kv.get_by_id(args_hash)
+        if if_cache_return is not None:
+            return if_cache_return["return"]
+    input_prompt = ''
+    try:
+        input_prompt = hf_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+    except Exception:
+        try:
+            ori_message = copy.deepcopy(messages)
+            if messages[0]['role'] == "system":
+                messages[1]['content'] = "<system>" + messages[0]['content'] + "</system>\n" + messages[1]['content']
+                messages = messages[1:]
+                input_prompt = hf_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+        except Exception:
+            len_message = len(ori_message)
+            for msgid in range(len_message):
+                input_prompt = input_prompt + '<' + ori_message[msgid]['role'] + '>' + ori_message[msgid]['content'] + '</' + ori_message[msgid]['role'] + '>\n'
+
+    input_ids = hf_tokenizer(input_prompt, return_tensors='pt', padding=True, truncation=True).to("cuda")
+    output = hf_model.generate(**input_ids, max_new_tokens=200, num_return_sequences=1, early_stopping=True)
+    response_text = hf_tokenizer.decode(output[0], skip_special_tokens=True)
+    if hashing_kv is not None:
+        await hashing_kv.upsert(
+            {args_hash: {"return": response_text, "model": model}}
+        )
+    return response_text
+
+
 async def gpt_4o_complete(
     prompt, system_prompt=None, history_messages=[], **kwargs
 ) -> str:
@@ -65,6 +113,20 @@ async def gpt_4o_mini_complete(
         **kwargs,
     )
 
+
+async def hf_model_complete(
+    prompt, system_prompt=None, history_messages=[], **kwargs
+) -> str:
+    input_string = kwargs['hashing_kv'].global_config['llm_model_name']
+    return await hf_model_if_cache(
+        input_string,
+        prompt,
+        system_prompt=system_prompt,
+        history_messages=history_messages,
+        **kwargs,
+    )
+
 @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
 @retry(
     stop=stop_after_attempt(3),
@@ -78,6 +140,24 @@ async def openai_embedding(texts: list[str]) -> np.ndarray:
     )
     return np.array([dp.embedding for dp in response.data])
 
+
+EMBED_MODEL = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
+tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
+@wrap_embedding_func_with_attrs(
+    embedding_dim=384,
+    max_token_size=5000,
+)
+async def hf_embedding(texts: list[str]) -> np.ndarray:
+    input_ids = tokenizer(texts, return_tensors='pt', padding=True, truncation=True).input_ids
+    with torch.no_grad():
+        outputs = EMBED_MODEL(input_ids)
+    embeddings = outputs.last_hidden_state.mean(dim=1)
+    return embeddings.detach().numpy()
+
+
 if __name__ == "__main__":
     import asyncio
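A note on the prompt construction above: hf_model_if_cache tries three strategies in order, and the fallbacks only trigger when apply_chat_template raises (Gemma's template, for instance, rejects a "system" role). The same logic, isolated into a runnable sketch (transformers >= 4.34 assumed for apply_chat_template; build_prompt is a hypothetical helper, not part of this PR):

```python
# Sketch of the three-tier prompt building used by hf_model_if_cache.
import copy

def build_prompt(tokenizer, messages: list[dict]) -> str:
    try:
        # 1) Preferred: the tokenizer's own chat template.
        return tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
    except Exception:
        pass
    try:
        # 2) Some templates reject a "system" role: fold the system
        #    message into the first user turn and retry.
        msgs = copy.deepcopy(messages)
        if msgs and msgs[0]["role"] == "system":
            msgs[1]["content"] = (
                "<system>" + msgs[0]["content"] + "</system>\n" + msgs[1]["content"]
            )
            msgs = msgs[1:]
        return tokenizer.apply_chat_template(
            msgs, tokenize=False, add_generation_prompt=True
        )
    except Exception:
        # 3) Last resort: serialize the turns with ad-hoc role tags.
        return "".join(f"<{m['role']}>{m['content']}</{m['role']}>\n" for m in messages)
```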
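hf_model_complete resolves the model name at call time from hashing_kv.global_config, which is how the llm_model_name field on the LightRAG dataclass reaches the backend. One cost worth flagging: hf_model_if_cache runs from_pretrained on every call. A possible refinement (not part of this PR; load_hf_model is a hypothetical helper) is to memoize the load:

```python
# Sketch (not in this PR): cache the loaded model/tokenizer so repeated
# calls don't re-run from_pretrained each time.
from functools import lru_cache
from transformers import AutoModelForCausalLM, AutoTokenizer

@lru_cache(maxsize=1)
def load_hf_model(model_name: str):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token  # reuse EOS for padding
    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
    return model, tokenizer
```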
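hf_embedding mean-pools the last hidden state over every token position, padding included, which slightly dilutes embeddings in mixed-length batches. A mask-aware variant (a common refinement, not what this PR commits; masked_mean_embedding is a hypothetical name):

```python
# Sketch: mean pooling that excludes padded positions via the attention mask.
import numpy as np
import torch
from transformers import AutoModel, AutoTokenizer

MODEL = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
TOKENIZER = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

def masked_mean_embedding(texts: list[str]) -> np.ndarray:
    batch = TOKENIZER(texts, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        hidden = MODEL(**batch).last_hidden_state        # (B, T, 384)
    mask = batch["attention_mask"].unsqueeze(-1).float() # (B, T, 1)
    summed = (hidden * mask).sum(dim=1)                  # zero out padding
    counts = mask.sum(dim=1).clamp(min=1e-9)             # real tokens per text
    return (summed / counts).numpy()
```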
lightrag/operate.py CHANGED
@@ -3,7 +3,7 @@ import json
 import re
 from typing import Union
 from collections import Counter, defaultdict
-
+import warnings
 from .utils import (
     logger,
     clean_str,
@@ -398,10 +398,15 @@ async def local_query(
         keywords = keywords_data.get("low_level_keywords", [])
         keywords = ', '.join(keywords)
     except json.JSONDecodeError as e:
+        try:
+            result = result.replace(kw_prompt[:-1], '').replace('user', '').replace('model', '').strip().strip('```').strip('json')
+            keywords_data = json.loads(result)
+            keywords = keywords_data.get("low_level_keywords", [])
+            keywords = ', '.join(keywords)
         # Handle parsing error
-        print(f"JSON parsing error: {e}")
-        return PROMPTS["fail_response"]
-
+        except json.JSONDecodeError as e:
+            print(f"JSON parsing error: {e}")
+            return PROMPTS["fail_response"]
     context = await _build_local_query_context(
         keywords,
         knowledge_graph_inst,
@@ -421,6 +426,9 @@ async def local_query(
         query,
         system_prompt=sys_prompt,
     )
+    if len(response) > len(sys_prompt):
+        response = response.replace(sys_prompt, '').replace('user', '').replace('model', '').replace(query, '').replace('<system>', '').replace('</system>', '').strip()
+
     return response
 
 async def _build_local_query_context(
@@ -617,9 +625,16 @@ async def global_query(
         keywords = keywords_data.get("high_level_keywords", [])
         keywords = ', '.join(keywords)
     except json.JSONDecodeError as e:
-        # Handle parsing error
-        print(f"JSON parsing error: {e}")
-        return PROMPTS["fail_response"]
+        try:
+            result = result.replace(kw_prompt[:-1], '').replace('user', '').replace('model', '').strip().strip('```').strip('json')
+            keywords_data = json.loads(result)
+            keywords = keywords_data.get("high_level_keywords", [])
+            keywords = ', '.join(keywords)
+
+        except json.JSONDecodeError as e:
+            # Handle parsing error
+            print(f"JSON parsing error: {e}")
+            return PROMPTS["fail_response"]
 
     context = await _build_global_query_context(
         keywords,
@@ -643,6 +658,9 @@ async def global_query(
         query,
         system_prompt=sys_prompt,
     )
+    if len(response) > len(sys_prompt):
+        response = response.replace(sys_prompt, '').replace('user', '').replace('model', '').replace(query, '').replace('<system>', '').replace('</system>', '').strip()
+
     return response
 
 async def _build_global_query_context(
@@ -822,8 +840,8 @@ async def hybird_query(
 
     kw_prompt_temp = PROMPTS["keywords_extraction"]
     kw_prompt = kw_prompt_temp.format(query=query)
+
     result = await use_model_func(kw_prompt)
-
     try:
         keywords_data = json.loads(result)
         hl_keywords = keywords_data.get("high_level_keywords", [])
@@ -831,10 +849,18 @@ async def hybird_query(
         hl_keywords = ', '.join(hl_keywords)
         ll_keywords = ', '.join(ll_keywords)
     except json.JSONDecodeError as e:
+        try:
+            result = result.replace(kw_prompt[:-1], '').replace('user', '').replace('model', '').strip().strip('```').strip('json')
+            keywords_data = json.loads(result)
+            hl_keywords = keywords_data.get("high_level_keywords", [])
+            ll_keywords = keywords_data.get("low_level_keywords", [])
+            hl_keywords = ', '.join(hl_keywords)
+            ll_keywords = ', '.join(ll_keywords)
         # Handle parsing error
-        print(f"JSON parsing error: {e}")
-        return PROMPTS["fail_response"]
-
+        except json.JSONDecodeError as e:
+            print(f"JSON parsing error: {e}")
+            return PROMPTS["fail_response"]
+
     low_level_context = await _build_local_query_context(
         ll_keywords,
         knowledge_graph_inst,
@@ -851,7 +877,7 @@ async def hybird_query(
         text_chunks_db,
         query_param,
     )
-
+
     context = combine_contexts(high_level_context, low_level_context)
 
     if query_param.only_need_context:
@@ -867,10 +893,13 @@ async def hybird_query(
         query,
         system_prompt=sys_prompt,
     )
+    if len(response) > len(sys_prompt):
+        response = response.replace(sys_prompt, '').replace('user', '').replace('model', '').replace(query, '').replace('<system>', '').replace('</system>', '').strip()
     return response
 
 def combine_contexts(high_level_context, low_level_context):
     # Function to extract entities, relationships, and sources from context strings
+
     def extract_sections(context):
         entities_match = re.search(r'-----Entities-----\s*```csv\s*(.*?)\s*```', context, re.DOTALL)
         relationships_match = re.search(r'-----Relationships-----\s*```csv\s*(.*?)\s*```', context, re.DOTALL)
@@ -883,8 +912,21 @@ def combine_contexts(high_level_context, low_level_context):
         return entities, relationships, sources
 
     # Extract sections from both contexts
-    hl_entities, hl_relationships, hl_sources = extract_sections(high_level_context)
-    ll_entities, ll_relationships, ll_sources = extract_sections(low_level_context)
+
+    if high_level_context is None:
+        warnings.warn("High Level context is None. Return empty High entity/relationship/source")
+        hl_entities, hl_relationships, hl_sources = '', '', ''
+    else:
+        hl_entities, hl_relationships, hl_sources = extract_sections(high_level_context)
+
+    if low_level_context is None:
+        warnings.warn("Low Level context is None. Return empty Low entity/relationship/source")
+        ll_entities, ll_relationships, ll_sources = '', '', ''
+    else:
+        ll_entities, ll_relationships, ll_sources = extract_sections(low_level_context)
+
 
     # Combine and deduplicate the entities
     combined_entities_set = set(filter(None, hl_entities.strip().split('\n') + ll_entities.strip().split('\n')))
@@ -940,5 +982,9 @@ async def naive_query(
         query,
         system_prompt=sys_prompt,
     )
+
+    if len(response) > len(sys_prompt):
+        response = response[len(sys_prompt):].replace(sys_prompt, '').replace('user', '').replace('model', '').replace(query, '').replace('<system>', '').replace('</system>', '').strip()
+
     return response
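All three query paths (local, global, hybrid) gain the same fallback: when the keyword-extraction model echoes the prompt or wraps its JSON in a markdown fence, json.loads fails on the raw output, so the retry strips the echoed prompt, stray role markers, and the fence before parsing once more. Worth knowing when reading it: str.strip takes a set of characters, not a substring, which happens to be sufficient here. A minimal demonstration:

```python
# What the cleanup does to a fenced model reply, in isolation.
# strip('```') removes backtick characters from both ends; strip('json')
# removes the letters j/s/o/n before the opening brace.
import json

raw = '```json\n{"low_level_keywords": ["agriculture", "irrigation"]}\n```'
cleaned = raw.strip().strip('```').strip('json').strip()
print(json.loads(cleaned)["low_level_keywords"])  # ['agriculture', 'irrigation']
```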
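The other recurring change guards against prompt echo: for decoder-only models, generate returns the input ids followed by the new tokens, so decode(output[0], ...) contains the rendered prompt as a prefix. The committed cleanup removes it by string replacement; an alternative sketch (not what this PR does) is to slice off the prompt tokens before decoding. The helper below is hypothetical and assumes the hf_model / hf_tokenizer objects from hf_model_if_cache in lightrag/llm.py:

```python
# Sketch: return only newly generated text by skipping the echoed prompt ids.
def generate_completion_only(hf_model, hf_tokenizer, input_prompt: str) -> str:
    inputs = hf_tokenizer(input_prompt, return_tensors="pt").to(hf_model.device)
    output = hf_model.generate(**inputs, max_new_tokens=200)
    new_tokens = output[0][inputs["input_ids"].shape[1]:]  # drop prompt tokens
    return hf_tokenizer.decode(new_tokens, skip_special_tokens=True)
```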
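combine_contexts now tolerates a missing side: if either sub-query produced no context, it warns and substitutes empty strings, and the deduplication that follows still works because filter(None, ...) discards the empty entries. A minimal demonstration of that step:

```python
# The dedup step with one empty side: filter(None, ...) drops the '' entries
# that an empty context contributes after strip().split('\n').
hl_entities, ll_entities = "", "id,entity\n1,Scrooge"
combined = set(
    filter(None, hl_entities.strip().split("\n") + ll_entities.strip().split("\n"))
)
print(sorted(combined))  # ['1,Scrooge', 'id,entity']
```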
990