LarFii commited on
Commit
4460ba5
·
1 Parent(s): 76a313b

Add huggingface model support

Browse files
README.md CHANGED
@@ -59,8 +59,8 @@ print(rag.query("What are the top themes in this story?", param=QueryParam(mode=
59
  # Perform global search
60
  print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global")))
61
 
62
- # Perform hybird search
63
- print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybird")))
64
  ```
65
  Batch Insert
66
  ```python
@@ -287,8 +287,8 @@ def extract_queries(file_path):
287
  ├── examples
288
  │ ├── batch_eval.py
289
  │ ├── generate_query.py
290
- │ ├── insert.py
291
- │ └── query.py
292
  ├── lightrag
293
  │ ├── __init__.py
294
  │ ├── base.py
 
59
  # Perform global search
60
  print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global")))
61
 
62
+ # Perform hybrid search
63
+ print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))
64
  ```
65
  Batch Insert
66
  ```python
 
287
  ├── examples
288
  │ ├── batch_eval.py
289
  │ ├── generate_query.py
290
+ │ ├── lightrag_openai_demo.py
291
+ │ └── lightrag_hf_demo.py
292
  ├── lightrag
293
  │ ├── __init__.py
294
  │ ├── base.py
examples/insert.py DELETED
@@ -1,18 +0,0 @@
1
- import os
2
- import sys
3
-
4
- from lightrag import LightRAG
5
-
6
- # os.environ["OPENAI_API_KEY"] = ""
7
-
8
- WORKING_DIR = ""
9
-
10
- if not os.path.exists(WORKING_DIR):
11
- os.mkdir(WORKING_DIR)
12
-
13
- rag = LightRAG(working_dir=WORKING_DIR)
14
-
15
- with open('./text.txt', 'r') as f:
16
- text = f.read()
17
-
18
- rag.insert(text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
examples/lightrag_hf_demo.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+ from lightrag import LightRAG, QueryParam
5
+ from lightrag.llm import hf_model_complete, hf_embedding
6
+ from transformers import AutoModel,AutoTokenizer
7
+
8
+ WORKING_DIR = "./dickens"
9
+
10
+ if not os.path.exists(WORKING_DIR):
11
+ os.mkdir(WORKING_DIR)
12
+
13
+ rag = LightRAG(
14
+ working_dir=WORKING_DIR,
15
+ llm_model_func=hf_model_complete,
16
+ llm_model_name='meta-llama/Llama-3.1-8B-Instruct',
17
+ embedding_func=hf_embedding,
18
+ tokenizer=AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2"),
19
+ embed_model=AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
20
+ )
21
+
22
+
23
+ with open("./book.txt") as f:
24
+ rag.insert(f.read())
25
+
26
+ # Perform naive search
27
+ print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))
28
+
29
+ # Perform local search
30
+ print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local")))
31
+
32
+ # Perform global search
33
+ print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global")))
34
+
35
+ # Perform hybrid search
36
+ print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))
examples/lightrag_openai_demo.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+ from lightrag import LightRAG, QueryParam
5
+ from lightrag.llm import gpt_4o_mini_complete, gpt_4o_complete
6
+ from transformers import AutoModel,AutoTokenizer
7
+
8
+ WORKING_DIR = "./dickens"
9
+
10
+ if not os.path.exists(WORKING_DIR):
11
+ os.mkdir(WORKING_DIR)
12
+
13
+ rag = LightRAG(
14
+ working_dir=WORKING_DIR,
15
+ llm_model_func=gpt_4o_complete
16
+ # llm_model_func=gpt_4o_mini_complete
17
+ )
18
+
19
+
20
+ with open("./book.txt") as f:
21
+ rag.insert(f.read())
22
+
23
+ # Perform naive search
24
+ print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))
25
+
26
+ # Perform local search
27
+ print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local")))
28
+
29
+ # Perform global search
30
+ print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global")))
31
+
32
+ # Perform hybrid search
33
+ print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))
examples/query.py DELETED
@@ -1,16 +0,0 @@
1
- import os
2
- import sys
3
-
4
- from lightrag import LightRAG, QueryParam
5
-
6
- # os.environ["OPENAI_API_KEY"] = ""
7
-
8
- WORKING_DIR = ""
9
-
10
- rag = LightRAG(working_dir=WORKING_DIR)
11
-
12
- mode = 'global'
13
- query_param = QueryParam(mode=mode)
14
-
15
- result = rag.query("", param=query_param)
16
- print(result)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lightrag/__init__.py CHANGED
@@ -1,5 +1,5 @@
1
  from .lightrag import LightRAG, QueryParam
2
 
3
- __version__ = "0.0.3"
4
  __author__ = "Zirui Guo"
5
  __url__ = "https://github.com/HKUDS/LightRAG"
 
1
  from .lightrag import LightRAG, QueryParam
2
 
3
+ __version__ = "0.0.4"
4
  __author__ = "Zirui Guo"
5
  __url__ = "https://github.com/HKUDS/LightRAG"
lightrag/base.py CHANGED
@@ -14,7 +14,7 @@ T = TypeVar("T")
14
 
15
  @dataclass
16
  class QueryParam:
17
- mode: Literal["local", "global", "hybird", "naive"] = "global"
18
  only_need_context: bool = False
19
  response_type: str = "Multiple Paragraphs"
20
  top_k: int = 60
 
14
 
15
  @dataclass
16
  class QueryParam:
17
+ mode: Literal["local", "global", "hybrid", "naive"] = "global"
18
  only_need_context: bool = False
19
  response_type: str = "Multiple Paragraphs"
20
  top_k: int = 60
lightrag/lightrag.py CHANGED
@@ -3,7 +3,8 @@ import os
3
  from dataclasses import asdict, dataclass, field
4
  from datetime import datetime
5
  from functools import partial
6
- from typing import Type, cast
 
7
 
8
  from .llm import gpt_4o_complete, gpt_4o_mini_complete, openai_embedding,hf_model_complete,hf_embedding
9
  from .operate import (
@@ -11,7 +12,7 @@ from .operate import (
11
  extract_entities,
12
  local_query,
13
  global_query,
14
- hybird_query,
15
  naive_query,
16
  )
17
 
@@ -38,15 +39,14 @@ from .base import (
38
 
39
  def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
40
  try:
41
- # If there is already an event loop, use it.
42
- loop = asyncio.get_event_loop()
43
  except RuntimeError:
44
- # If in a sub-thread, create a new event loop.
45
  logger.info("Creating a new event loop in a sub-thread.")
46
  loop = asyncio.new_event_loop()
47
  asyncio.set_event_loop(loop)
48
  return loop
49
 
 
50
  @dataclass
51
  class LightRAG:
52
  working_dir: str = field(
@@ -77,6 +77,9 @@ class LightRAG:
77
  )
78
 
79
  # text embedding
 
 
 
80
  # embedding_func: EmbeddingFunc = field(default_factory=lambda:hf_embedding)
81
  embedding_func: EmbeddingFunc = field(default_factory=lambda:openai_embedding)#
82
  embedding_batch_num: int = 32
@@ -100,6 +103,13 @@ class LightRAG:
100
  convert_response_to_json_func: callable = convert_response_to_json
101
 
102
  def __post_init__(self):
 
 
 
 
 
 
 
103
  log_file = os.path.join(self.working_dir, "lightrag.log")
104
  set_logger(log_file)
105
  logger.info(f"Logger initialized for working directory: {self.working_dir}")
@@ -130,8 +140,11 @@ class LightRAG:
130
  namespace="chunk_entity_relation", global_config=asdict(self)
131
  )
132
  self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(
133
- self.embedding_func
 
 
134
  )
 
135
  self.entities_vdb = (
136
  self.vector_db_storage_cls(
137
  namespace="entities",
@@ -267,8 +280,8 @@ class LightRAG:
267
  param,
268
  asdict(self),
269
  )
270
- elif param.mode == "hybird":
271
- response = await hybird_query(
272
  query,
273
  self.chunk_entity_relation_graph,
274
  self.entities_vdb,
 
3
  from dataclasses import asdict, dataclass, field
4
  from datetime import datetime
5
  from functools import partial
6
+ from typing import Type, cast, Any
7
+ from transformers import AutoModel,AutoTokenizer, AutoModelForCausalLM
8
 
9
  from .llm import gpt_4o_complete, gpt_4o_mini_complete, openai_embedding,hf_model_complete,hf_embedding
10
  from .operate import (
 
12
  extract_entities,
13
  local_query,
14
  global_query,
15
+ hybrid_query,
16
  naive_query,
17
  )
18
 
 
39
 
40
  def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
41
  try:
42
+ loop = asyncio.get_running_loop()
 
43
  except RuntimeError:
 
44
  logger.info("Creating a new event loop in a sub-thread.")
45
  loop = asyncio.new_event_loop()
46
  asyncio.set_event_loop(loop)
47
  return loop
48
 
49
+
50
  @dataclass
51
  class LightRAG:
52
  working_dir: str = field(
 
77
  )
78
 
79
  # text embedding
80
+ tokenizer: Any = None
81
+ embed_model: Any = None
82
+
83
  # embedding_func: EmbeddingFunc = field(default_factory=lambda:hf_embedding)
84
  embedding_func: EmbeddingFunc = field(default_factory=lambda:openai_embedding)#
85
  embedding_batch_num: int = 32
 
103
  convert_response_to_json_func: callable = convert_response_to_json
104
 
105
  def __post_init__(self):
106
+ if callable(self.embedding_func) and self.embedding_func.__name__ == 'hf_embedding':
107
+ if self.tokenizer is None:
108
+ self.tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
109
+ if self.embed_model is None:
110
+ self.embed_model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
111
+
112
+
113
  log_file = os.path.join(self.working_dir, "lightrag.log")
114
  set_logger(log_file)
115
  logger.info(f"Logger initialized for working directory: {self.working_dir}")
 
140
  namespace="chunk_entity_relation", global_config=asdict(self)
141
  )
142
  self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(
143
+ lambda texts: self.embedding_func(texts, self.tokenizer, self.embed_model)
144
+ if callable(self.embedding_func) and self.embedding_func.__name__ == 'hf_embedding'
145
+ else self.embedding_func(texts)
146
  )
147
+
148
  self.entities_vdb = (
149
  self.vector_db_storage_cls(
150
  namespace="entities",
 
280
  param,
281
  asdict(self),
282
  )
283
+ elif param.mode == "hybrid":
284
+ response = await hybrid_query(
285
  query,
286
  self.chunk_entity_relation_graph,
287
  self.entities_vdb,
lightrag/llm.py CHANGED
@@ -142,18 +142,14 @@ async def openai_embedding(texts: list[str]) -> np.ndarray:
142
 
143
 
144
 
145
- global EMBED_MODEL
146
- global tokenizer
147
- EMBED_MODEL = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
148
- tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
149
  @wrap_embedding_func_with_attrs(
150
  embedding_dim=384,
151
  max_token_size=5000,
152
  )
153
- async def hf_embedding(texts: list[str]) -> np.ndarray:
154
  input_ids = tokenizer(texts, return_tensors='pt', padding=True, truncation=True).input_ids
155
  with torch.no_grad():
156
- outputs = EMBED_MODEL(input_ids)
157
  embeddings = outputs.last_hidden_state.mean(dim=1)
158
  return embeddings.detach().numpy()
159
 
 
142
 
143
 
144
 
 
 
 
 
145
  @wrap_embedding_func_with_attrs(
146
  embedding_dim=384,
147
  max_token_size=5000,
148
  )
149
+ async def hf_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray:
150
  input_ids = tokenizer(texts, return_tensors='pt', padding=True, truncation=True).input_ids
151
  with torch.no_grad():
152
+ outputs = embed_model(input_ids)
153
  embeddings = outputs.last_hidden_state.mean(dim=1)
154
  return embeddings.detach().numpy()
155
 
lightrag/operate.py CHANGED
@@ -827,7 +827,7 @@ async def _find_related_text_unit_from_relationships(
827
 
828
  return all_text_units
829
 
830
- async def hybird_query(
831
  query,
832
  knowledge_graph_inst: BaseGraphStorage,
833
  entities_vdb: BaseVectorStorage,
 
827
 
828
  return all_text_units
829
 
830
+ async def hybrid_query(
831
  query,
832
  knowledge_graph_inst: BaseGraphStorage,
833
  entities_vdb: BaseVectorStorage,
reproduce/Step_3.py CHANGED
@@ -52,7 +52,7 @@ def run_queries_and_save_to_json(queries, rag_instance, query_param, output_file
52
 
53
  if __name__ == "__main__":
54
  cls = "agriculture"
55
- mode = "hybird"
56
  WORKING_DIR = "../{cls}"
57
 
58
  rag = LightRAG(working_dir=WORKING_DIR)
 
52
 
53
  if __name__ == "__main__":
54
  cls = "agriculture"
55
+ mode = "hybrid"
56
  WORKING_DIR = "../{cls}"
57
 
58
  rag = LightRAG(working_dir=WORKING_DIR)