Merge branch 'drahnreb/add-custom-tokenizer'
Browse files- README-zh.md +2 -1
- README.md +2 -1
- examples/lightrag_gemini_demo_no_tiktoken.py +230 -0
- lightrag/api/routers/ollama_api.py +2 -2
- lightrag/lightrag.py +38 -21
- lightrag/operate.py +52 -27
- lightrag/utils.py +87 -19
README-zh.md
CHANGED
@@ -1090,7 +1090,8 @@ rag.clear_cache(modes=["local"])
|
|
1090 |
| **doc_status_storage** | `str` | Storage type for documents process status. Supported types: `JsonDocStatusStorage`,`PGDocStatusStorage`,`MongoDocStatusStorage` | `JsonDocStatusStorage` |
|
1091 |
| **chunk_token_size** | `int` | 拆分文档时每个块的最大令牌大小 | `1200` |
|
1092 |
| **chunk_overlap_token_size** | `int` | 拆分文档时两个块之间的重叠令牌大小 | `100` |
|
1093 |
-
| **
|
|
|
1094 |
| **entity_extract_max_gleaning** | `int` | 实体提取过程中的循环次数,附加历史消息 | `1` |
|
1095 |
| **entity_summary_to_max_tokens** | `int` | 每个实体摘要的最大令牌大小 | `500` |
|
1096 |
| **node_embedding_algorithm** | `str` | 节点嵌入算法(当前未使用) | `node2vec` |
|
|
|
1090 |
| **doc_status_storage** | `str` | Storage type for documents process status. Supported types: `JsonDocStatusStorage`,`PGDocStatusStorage`,`MongoDocStatusStorage` | `JsonDocStatusStorage` |
|
1091 |
| **chunk_token_size** | `int` | 拆分文档时每个块的最大令牌大小 | `1200` |
|
1092 |
| **chunk_overlap_token_size** | `int` | 拆分文档时两个块之间的重叠令牌大小 | `100` |
|
1093 |
+
| **tokenizer** | `Tokenizer` | 用于将文本转换为 tokens(数字)以及使用遵循 TokenizerInterface 协议的 .encode() 和 .decode() 函数将 tokens 转换回文本的函数。 如果您不指定,它将使用默认的 Tiktoken tokenizer。 | `TiktokenTokenizer` |
|
1094 |
+
| **tiktoken_model_name** | `str` | 如果您使用的是默认的 Tiktoken tokenizer,那么这是要使用的特定 Tiktoken 模型的名称。如果您提供自己的 tokenizer,则忽略此设置。 | `gpt-4o-mini` |
|
1095 |
| **entity_extract_max_gleaning** | `int` | 实体提取过程中的循环次数,附加历史消息 | `1` |
|
1096 |
| **entity_summary_to_max_tokens** | `int` | 每个实体摘要的最大令牌大小 | `500` |
|
1097 |
| **node_embedding_algorithm** | `str` | 节点嵌入算法(当前未使用) | `node2vec` |
|
README.md
CHANGED
@@ -1156,7 +1156,8 @@ Valid modes are:
|
|
1156 |
| **doc_status_storage** | `str` | Storage type for documents process status. Supported types: `JsonDocStatusStorage`,`PGDocStatusStorage`,`MongoDocStatusStorage` | `JsonDocStatusStorage` |
|
1157 |
| **chunk_token_size** | `int` | Maximum token size per chunk when splitting documents | `1200` |
|
1158 |
| **chunk_overlap_token_size** | `int` | Overlap token size between two chunks when splitting documents | `100` |
|
1159 |
-
| **
|
|
|
1160 |
| **entity_extract_max_gleaning** | `int` | Number of loops in the entity extraction process, appending history messages | `1` |
|
1161 |
| **entity_summary_to_max_tokens** | `int` | Maximum token size for each entity summary | `500` |
|
1162 |
| **node_embedding_algorithm** | `str` | Algorithm for node embedding (currently not used) | `node2vec` |
|
|
|
1156 |
| **doc_status_storage** | `str` | Storage type for documents process status. Supported types: `JsonDocStatusStorage`,`PGDocStatusStorage`,`MongoDocStatusStorage` | `JsonDocStatusStorage` |
|
1157 |
| **chunk_token_size** | `int` | Maximum token size per chunk when splitting documents | `1200` |
|
1158 |
| **chunk_overlap_token_size** | `int` | Overlap token size between two chunks when splitting documents | `100` |
|
1159 |
+
| **tokenizer** | `Tokenizer` | The function used to convert text into tokens (numbers) and back using .encode() and .decode() functions following `TokenizerInterface` protocol. If you don't specify one, it will use the default Tiktoken tokenizer. | `TiktokenTokenizer` |
|
1160 |
+
| **tiktoken_model_name** | `str` | If you're using the default Tiktoken tokenizer, this is the name of the specific Tiktoken model to use. This setting is ignored if you provide your own tokenizer. | `gpt-4o-mini` |
|
1161 |
| **entity_extract_max_gleaning** | `int` | Number of loops in the entity extraction process, appending history messages | `1` |
|
1162 |
| **entity_summary_to_max_tokens** | `int` | Maximum token size for each entity summary | `500` |
|
1163 |
| **node_embedding_algorithm** | `str` | Algorithm for node embedding (currently not used) | `node2vec` |
|
examples/lightrag_gemini_demo_no_tiktoken.py
ADDED
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# pip install -q -U google-genai to use gemini as a client
|
2 |
+
|
3 |
+
import os
|
4 |
+
from typing import Optional
|
5 |
+
import dataclasses
|
6 |
+
from pathlib import Path
|
7 |
+
import hashlib
|
8 |
+
import numpy as np
|
9 |
+
from google import genai
|
10 |
+
from google.genai import types
|
11 |
+
from dotenv import load_dotenv
|
12 |
+
from lightrag.utils import EmbeddingFunc, Tokenizer
|
13 |
+
from lightrag import LightRAG, QueryParam
|
14 |
+
from sentence_transformers import SentenceTransformer
|
15 |
+
from lightrag.kg.shared_storage import initialize_pipeline_status
|
16 |
+
import sentencepiece as spm
|
17 |
+
import requests
|
18 |
+
|
19 |
+
import asyncio
|
20 |
+
import nest_asyncio
|
21 |
+
|
22 |
+
# Apply nest_asyncio to solve event loop issues
|
23 |
+
nest_asyncio.apply()
|
24 |
+
|
25 |
+
load_dotenv()
|
26 |
+
gemini_api_key = os.getenv("GEMINI_API_KEY")
|
27 |
+
|
28 |
+
WORKING_DIR = "./dickens"
|
29 |
+
|
30 |
+
if os.path.exists(WORKING_DIR):
|
31 |
+
import shutil
|
32 |
+
|
33 |
+
shutil.rmtree(WORKING_DIR)
|
34 |
+
|
35 |
+
os.mkdir(WORKING_DIR)
|
36 |
+
|
37 |
+
|
38 |
+
class GemmaTokenizer(Tokenizer):
|
39 |
+
# adapted from google-cloud-aiplatform[tokenization]
|
40 |
+
|
41 |
+
@dataclasses.dataclass(frozen=True)
|
42 |
+
class _TokenizerConfig:
|
43 |
+
tokenizer_model_url: str
|
44 |
+
tokenizer_model_hash: str
|
45 |
+
|
46 |
+
_TOKENIZERS = {
|
47 |
+
"google/gemma2": _TokenizerConfig(
|
48 |
+
tokenizer_model_url="https://raw.githubusercontent.com/google/gemma_pytorch/33b652c465537c6158f9a472ea5700e5e770ad3f/tokenizer/tokenizer.model",
|
49 |
+
tokenizer_model_hash="61a7b147390c64585d6c3543dd6fc636906c9af3865a5548f27f31aee1d4c8e2",
|
50 |
+
),
|
51 |
+
"google/gemma3": _TokenizerConfig(
|
52 |
+
tokenizer_model_url="https://raw.githubusercontent.com/google/gemma_pytorch/cb7c0152a369e43908e769eb09e1ce6043afe084/tokenizer/gemma3_cleaned_262144_v2.spiece.model",
|
53 |
+
tokenizer_model_hash="1299c11d7cf632ef3b4e11937501358ada021bbdf7c47638d13c0ee982f2e79c",
|
54 |
+
),
|
55 |
+
}
|
56 |
+
|
57 |
+
def __init__(
|
58 |
+
self, model_name: str = "gemini-2.0-flash", tokenizer_dir: Optional[str] = None
|
59 |
+
):
|
60 |
+
# https://github.com/google/gemma_pytorch/tree/main/tokenizer
|
61 |
+
if "1.5" in model_name or "1.0" in model_name:
|
62 |
+
# up to gemini 1.5 gemma2 is a comparable local tokenizer
|
63 |
+
# https://github.com/googleapis/python-aiplatform/blob/main/vertexai/tokenization/_tokenizer_loading.py
|
64 |
+
tokenizer_name = "google/gemma2"
|
65 |
+
else:
|
66 |
+
# for gemini > 2.0 gemma3 was used
|
67 |
+
tokenizer_name = "google/gemma3"
|
68 |
+
|
69 |
+
file_url = self._TOKENIZERS[tokenizer_name].tokenizer_model_url
|
70 |
+
tokenizer_model_name = file_url.rsplit("/", 1)[1]
|
71 |
+
expected_hash = self._TOKENIZERS[tokenizer_name].tokenizer_model_hash
|
72 |
+
|
73 |
+
tokenizer_dir = Path(tokenizer_dir)
|
74 |
+
if tokenizer_dir.is_dir():
|
75 |
+
file_path = tokenizer_dir / tokenizer_model_name
|
76 |
+
model_data = self._maybe_load_from_cache(
|
77 |
+
file_path=file_path, expected_hash=expected_hash
|
78 |
+
)
|
79 |
+
else:
|
80 |
+
model_data = None
|
81 |
+
if not model_data:
|
82 |
+
model_data = self._load_from_url(
|
83 |
+
file_url=file_url, expected_hash=expected_hash
|
84 |
+
)
|
85 |
+
self.save_tokenizer_to_cache(cache_path=file_path, model_data=model_data)
|
86 |
+
|
87 |
+
tokenizer = spm.SentencePieceProcessor()
|
88 |
+
tokenizer.LoadFromSerializedProto(model_data)
|
89 |
+
super().__init__(model_name=model_name, tokenizer=tokenizer)
|
90 |
+
|
91 |
+
def _is_valid_model(self, model_data: bytes, expected_hash: str) -> bool:
|
92 |
+
"""Returns true if the content is valid by checking the hash."""
|
93 |
+
return hashlib.sha256(model_data).hexdigest() == expected_hash
|
94 |
+
|
95 |
+
def _maybe_load_from_cache(self, file_path: Path, expected_hash: str) -> bytes:
|
96 |
+
"""Loads the model data from the cache path."""
|
97 |
+
if not file_path.is_file():
|
98 |
+
return
|
99 |
+
with open(file_path, "rb") as f:
|
100 |
+
content = f.read()
|
101 |
+
if self._is_valid_model(model_data=content, expected_hash=expected_hash):
|
102 |
+
return content
|
103 |
+
|
104 |
+
# Cached file corrupted.
|
105 |
+
self._maybe_remove_file(file_path)
|
106 |
+
|
107 |
+
def _load_from_url(self, file_url: str, expected_hash: str) -> bytes:
|
108 |
+
"""Loads model bytes from the given file url."""
|
109 |
+
resp = requests.get(file_url)
|
110 |
+
resp.raise_for_status()
|
111 |
+
content = resp.content
|
112 |
+
|
113 |
+
if not self._is_valid_model(model_data=content, expected_hash=expected_hash):
|
114 |
+
actual_hash = hashlib.sha256(content).hexdigest()
|
115 |
+
raise ValueError(
|
116 |
+
f"Downloaded model file is corrupted."
|
117 |
+
f" Expected hash {expected_hash}. Got file hash {actual_hash}."
|
118 |
+
)
|
119 |
+
return content
|
120 |
+
|
121 |
+
@staticmethod
|
122 |
+
def save_tokenizer_to_cache(cache_path: Path, model_data: bytes) -> None:
|
123 |
+
"""Saves the model data to the cache path."""
|
124 |
+
try:
|
125 |
+
if not cache_path.is_file():
|
126 |
+
cache_dir = cache_path.parent
|
127 |
+
cache_dir.mkdir(parents=True, exist_ok=True)
|
128 |
+
with open(cache_path, "wb") as f:
|
129 |
+
f.write(model_data)
|
130 |
+
except OSError:
|
131 |
+
# Don't raise if we cannot write file.
|
132 |
+
pass
|
133 |
+
|
134 |
+
@staticmethod
|
135 |
+
def _maybe_remove_file(file_path: Path) -> None:
|
136 |
+
"""Removes the file if exists."""
|
137 |
+
if not file_path.is_file():
|
138 |
+
return
|
139 |
+
try:
|
140 |
+
file_path.unlink()
|
141 |
+
except OSError:
|
142 |
+
# Don't raise if we cannot remove file.
|
143 |
+
pass
|
144 |
+
|
145 |
+
# def encode(self, content: str) -> list[int]:
|
146 |
+
# return self.tokenizer.encode(content)
|
147 |
+
|
148 |
+
# def decode(self, tokens: list[int]) -> str:
|
149 |
+
# return self.tokenizer.decode(tokens)
|
150 |
+
|
151 |
+
|
152 |
+
async def llm_model_func(
|
153 |
+
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
|
154 |
+
) -> str:
|
155 |
+
# 1. Initialize the GenAI Client with your Gemini API Key
|
156 |
+
client = genai.Client(api_key=gemini_api_key)
|
157 |
+
|
158 |
+
# 2. Combine prompts: system prompt, history, and user prompt
|
159 |
+
if history_messages is None:
|
160 |
+
history_messages = []
|
161 |
+
|
162 |
+
combined_prompt = ""
|
163 |
+
if system_prompt:
|
164 |
+
combined_prompt += f"{system_prompt}\n"
|
165 |
+
|
166 |
+
for msg in history_messages:
|
167 |
+
# Each msg is expected to be a dict: {"role": "...", "content": "..."}
|
168 |
+
combined_prompt += f"{msg['role']}: {msg['content']}\n"
|
169 |
+
|
170 |
+
# Finally, add the new user prompt
|
171 |
+
combined_prompt += f"user: {prompt}"
|
172 |
+
|
173 |
+
# 3. Call the Gemini model
|
174 |
+
response = client.models.generate_content(
|
175 |
+
model="gemini-1.5-flash",
|
176 |
+
contents=[combined_prompt],
|
177 |
+
config=types.GenerateContentConfig(max_output_tokens=500, temperature=0.1),
|
178 |
+
)
|
179 |
+
|
180 |
+
# 4. Return the response text
|
181 |
+
return response.text
|
182 |
+
|
183 |
+
|
184 |
+
async def embedding_func(texts: list[str]) -> np.ndarray:
|
185 |
+
model = SentenceTransformer("all-MiniLM-L6-v2")
|
186 |
+
embeddings = model.encode(texts, convert_to_numpy=True)
|
187 |
+
return embeddings
|
188 |
+
|
189 |
+
|
190 |
+
async def initialize_rag():
|
191 |
+
rag = LightRAG(
|
192 |
+
working_dir=WORKING_DIR,
|
193 |
+
# tiktoken_model_name="gpt-4o-mini",
|
194 |
+
tokenizer=GemmaTokenizer(
|
195 |
+
tokenizer_dir=(Path(WORKING_DIR) / "vertexai_tokenizer_model"),
|
196 |
+
model_name="gemini-2.0-flash",
|
197 |
+
),
|
198 |
+
llm_model_func=llm_model_func,
|
199 |
+
embedding_func=EmbeddingFunc(
|
200 |
+
embedding_dim=384,
|
201 |
+
max_token_size=8192,
|
202 |
+
func=embedding_func,
|
203 |
+
),
|
204 |
+
)
|
205 |
+
|
206 |
+
await rag.initialize_storages()
|
207 |
+
await initialize_pipeline_status()
|
208 |
+
|
209 |
+
return rag
|
210 |
+
|
211 |
+
|
212 |
+
def main():
|
213 |
+
# Initialize RAG instance
|
214 |
+
rag = asyncio.run(initialize_rag())
|
215 |
+
file_path = "story.txt"
|
216 |
+
with open(file_path, "r") as file:
|
217 |
+
text = file.read()
|
218 |
+
|
219 |
+
rag.insert(text)
|
220 |
+
|
221 |
+
response = rag.query(
|
222 |
+
query="What is the main theme of the story?",
|
223 |
+
param=QueryParam(mode="hybrid", top_k=5, response_type="single line"),
|
224 |
+
)
|
225 |
+
|
226 |
+
print(response)
|
227 |
+
|
228 |
+
|
229 |
+
if __name__ == "__main__":
|
230 |
+
main()
|
lightrag/api/routers/ollama_api.py
CHANGED
@@ -10,7 +10,7 @@ from fastapi.responses import StreamingResponse
|
|
10 |
import asyncio
|
11 |
from ascii_colors import trace_exception
|
12 |
from lightrag import LightRAG, QueryParam
|
13 |
-
from lightrag.utils import
|
14 |
from lightrag.api.utils_api import ollama_server_infos, get_combined_auth_dependency
|
15 |
from fastapi import Depends
|
16 |
|
@@ -97,7 +97,7 @@ class OllamaTagResponse(BaseModel):
|
|
97 |
|
98 |
def estimate_tokens(text: str) -> int:
|
99 |
"""Estimate the number of tokens in text using tiktoken"""
|
100 |
-
tokens =
|
101 |
return len(tokens)
|
102 |
|
103 |
|
|
|
10 |
import asyncio
|
11 |
from ascii_colors import trace_exception
|
12 |
from lightrag import LightRAG, QueryParam
|
13 |
+
from lightrag.utils import TiktokenTokenizer
|
14 |
from lightrag.api.utils_api import ollama_server_infos, get_combined_auth_dependency
|
15 |
from fastapi import Depends
|
16 |
|
|
|
97 |
|
98 |
def estimate_tokens(text: str) -> int:
|
99 |
"""Estimate the number of tokens in text using tiktoken"""
|
100 |
+
tokens = TiktokenTokenizer().encode(text)
|
101 |
return len(tokens)
|
102 |
|
103 |
|
lightrag/lightrag.py
CHANGED
@@ -7,7 +7,18 @@ import warnings
|
|
7 |
from dataclasses import asdict, dataclass, field
|
8 |
from datetime import datetime
|
9 |
from functools import partial
|
10 |
-
from typing import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
from lightrag.kg import (
|
13 |
STORAGES,
|
@@ -41,11 +52,12 @@ from .operate import (
|
|
41 |
)
|
42 |
from .prompt import GRAPH_FIELD_SEP, PROMPTS
|
43 |
from .utils import (
|
|
|
|
|
44 |
EmbeddingFunc,
|
45 |
always_get_an_event_loop,
|
46 |
compute_mdhash_id,
|
47 |
convert_response_to_json,
|
48 |
-
encode_string_by_tiktoken,
|
49 |
lazy_external_import,
|
50 |
limit_async_func_call,
|
51 |
get_content_summary,
|
@@ -122,33 +134,38 @@ class LightRAG:
|
|
122 |
)
|
123 |
"""Number of overlapping tokens between consecutive text chunks to preserve context."""
|
124 |
|
125 |
-
|
126 |
-
"""
|
|
|
|
|
|
|
|
|
127 |
|
128 |
-
|
|
|
129 |
|
130 |
chunking_func: Callable[
|
131 |
[
|
|
|
132 |
str,
|
133 |
-
str
|
134 |
bool,
|
135 |
int,
|
136 |
int,
|
137 |
-
str,
|
138 |
],
|
139 |
-
|
140 |
] = field(default_factory=lambda: chunking_by_token_size)
|
141 |
"""
|
142 |
Custom chunking function for splitting text into chunks before processing.
|
143 |
|
144 |
The function should take the following parameters:
|
145 |
|
|
|
146 |
- `content`: The text to be split into chunks.
|
147 |
- `split_by_character`: The character to split the text on. If None, the text is split into chunks of `chunk_token_size` tokens.
|
148 |
- `split_by_character_only`: If True, the text is split only on the specified character.
|
149 |
- `chunk_token_size`: The maximum number of tokens per chunk.
|
150 |
- `chunk_overlap_token_size`: The number of overlapping tokens between consecutive chunks.
|
151 |
-
- `tiktoken_model_name`: The name of the tiktoken model to use for tokenization.
|
152 |
|
153 |
The function should return a list of dictionaries, where each dictionary contains the following keys:
|
154 |
- `tokens`: The number of tokens in the chunk.
|
@@ -310,7 +327,15 @@ class LightRAG:
|
|
310 |
_print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()])
|
311 |
logger.debug(f"LightRAG init with param:\n {_print_config}\n")
|
312 |
|
313 |
-
# Init
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
314 |
self.embedding_func = limit_async_func_call(self.embedding_func_max_async)( # type: ignore
|
315 |
self.embedding_func
|
316 |
)
|
@@ -603,11 +628,7 @@ class LightRAG:
|
|
603 |
inserting_chunks: dict[str, Any] = {}
|
604 |
for index, chunk_text in enumerate(text_chunks):
|
605 |
chunk_key = compute_mdhash_id(chunk_text, prefix="chunk-")
|
606 |
-
tokens = len(
|
607 |
-
encode_string_by_tiktoken(
|
608 |
-
chunk_text, model_name=self.tiktoken_model_name
|
609 |
-
)
|
610 |
-
)
|
611 |
inserting_chunks[chunk_key] = {
|
612 |
"content": chunk_text,
|
613 |
"full_doc_id": doc_key,
|
@@ -900,12 +921,12 @@ class LightRAG:
|
|
900 |
"file_path": file_path, # Add file path to each chunk
|
901 |
}
|
902 |
for dp in self.chunking_func(
|
|
|
903 |
status_doc.content,
|
904 |
split_by_character,
|
905 |
split_by_character_only,
|
906 |
self.chunk_overlap_token_size,
|
907 |
self.chunk_token_size,
|
908 |
-
self.tiktoken_model_name,
|
909 |
)
|
910 |
}
|
911 |
|
@@ -1133,11 +1154,7 @@ class LightRAG:
|
|
1133 |
for chunk_data in custom_kg.get("chunks", []):
|
1134 |
chunk_content = clean_text(chunk_data["content"])
|
1135 |
source_id = chunk_data["source_id"]
|
1136 |
-
tokens = len(
|
1137 |
-
encode_string_by_tiktoken(
|
1138 |
-
chunk_content, model_name=self.tiktoken_model_name
|
1139 |
-
)
|
1140 |
-
)
|
1141 |
chunk_order_index = (
|
1142 |
0
|
1143 |
if "chunk_order_index" not in chunk_data.keys()
|
|
|
7 |
from dataclasses import asdict, dataclass, field
|
8 |
from datetime import datetime
|
9 |
from functools import partial
|
10 |
+
from typing import (
|
11 |
+
Any,
|
12 |
+
AsyncIterator,
|
13 |
+
Callable,
|
14 |
+
Iterator,
|
15 |
+
cast,
|
16 |
+
final,
|
17 |
+
Literal,
|
18 |
+
Optional,
|
19 |
+
List,
|
20 |
+
Dict,
|
21 |
+
)
|
22 |
|
23 |
from lightrag.kg import (
|
24 |
STORAGES,
|
|
|
52 |
)
|
53 |
from .prompt import GRAPH_FIELD_SEP, PROMPTS
|
54 |
from .utils import (
|
55 |
+
Tokenizer,
|
56 |
+
TiktokenTokenizer,
|
57 |
EmbeddingFunc,
|
58 |
always_get_an_event_loop,
|
59 |
compute_mdhash_id,
|
60 |
convert_response_to_json,
|
|
|
61 |
lazy_external_import,
|
62 |
limit_async_func_call,
|
63 |
get_content_summary,
|
|
|
134 |
)
|
135 |
"""Number of overlapping tokens between consecutive text chunks to preserve context."""
|
136 |
|
137 |
+
tokenizer: Optional[Tokenizer] = field(default=None)
|
138 |
+
"""
|
139 |
+
A function that returns a Tokenizer instance.
|
140 |
+
If None, and a `tiktoken_model_name` is provided, a TiktokenTokenizer will be created.
|
141 |
+
If both are None, the default TiktokenTokenizer is used.
|
142 |
+
"""
|
143 |
|
144 |
+
tiktoken_model_name: str = field(default="gpt-4o-mini")
|
145 |
+
"""Model name used for tokenization when chunking text with tiktoken. Defaults to `gpt-4o-mini`."""
|
146 |
|
147 |
chunking_func: Callable[
|
148 |
[
|
149 |
+
Tokenizer,
|
150 |
str,
|
151 |
+
Optional[str],
|
152 |
bool,
|
153 |
int,
|
154 |
int,
|
|
|
155 |
],
|
156 |
+
List[Dict[str, Any]],
|
157 |
] = field(default_factory=lambda: chunking_by_token_size)
|
158 |
"""
|
159 |
Custom chunking function for splitting text into chunks before processing.
|
160 |
|
161 |
The function should take the following parameters:
|
162 |
|
163 |
+
- `tokenizer`: A Tokenizer instance to use for tokenization.
|
164 |
- `content`: The text to be split into chunks.
|
165 |
- `split_by_character`: The character to split the text on. If None, the text is split into chunks of `chunk_token_size` tokens.
|
166 |
- `split_by_character_only`: If True, the text is split only on the specified character.
|
167 |
- `chunk_token_size`: The maximum number of tokens per chunk.
|
168 |
- `chunk_overlap_token_size`: The number of overlapping tokens between consecutive chunks.
|
|
|
169 |
|
170 |
The function should return a list of dictionaries, where each dictionary contains the following keys:
|
171 |
- `tokens`: The number of tokens in the chunk.
|
|
|
327 |
_print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()])
|
328 |
logger.debug(f"LightRAG init with param:\n {_print_config}\n")
|
329 |
|
330 |
+
# Init Tokenizer
|
331 |
+
# Post-initialization hook to handle backward compatabile tokenizer initialization based on provided parameters
|
332 |
+
if self.tokenizer is None:
|
333 |
+
if self.tiktoken_model_name:
|
334 |
+
self.tokenizer = TiktokenTokenizer(self.tiktoken_model_name)
|
335 |
+
else:
|
336 |
+
self.tokenizer = TiktokenTokenizer()
|
337 |
+
|
338 |
+
# Init Embedding
|
339 |
self.embedding_func = limit_async_func_call(self.embedding_func_max_async)( # type: ignore
|
340 |
self.embedding_func
|
341 |
)
|
|
|
628 |
inserting_chunks: dict[str, Any] = {}
|
629 |
for index, chunk_text in enumerate(text_chunks):
|
630 |
chunk_key = compute_mdhash_id(chunk_text, prefix="chunk-")
|
631 |
+
tokens = len(self.tokenizer.encode(chunk_text))
|
|
|
|
|
|
|
|
|
632 |
inserting_chunks[chunk_key] = {
|
633 |
"content": chunk_text,
|
634 |
"full_doc_id": doc_key,
|
|
|
921 |
"file_path": file_path, # Add file path to each chunk
|
922 |
}
|
923 |
for dp in self.chunking_func(
|
924 |
+
self.tokenizer,
|
925 |
status_doc.content,
|
926 |
split_by_character,
|
927 |
split_by_character_only,
|
928 |
self.chunk_overlap_token_size,
|
929 |
self.chunk_token_size,
|
|
|
930 |
)
|
931 |
}
|
932 |
|
|
|
1154 |
for chunk_data in custom_kg.get("chunks", []):
|
1155 |
chunk_content = clean_text(chunk_data["content"])
|
1156 |
source_id = chunk_data["source_id"]
|
1157 |
+
tokens = len(self.tokenizer.encode(chunk_content))
|
|
|
|
|
|
|
|
|
1158 |
chunk_order_index = (
|
1159 |
0
|
1160 |
if "chunk_order_index" not in chunk_data.keys()
|
lightrag/operate.py
CHANGED
@@ -12,8 +12,7 @@ from .utils import (
|
|
12 |
logger,
|
13 |
clean_str,
|
14 |
compute_mdhash_id,
|
15 |
-
|
16 |
-
encode_string_by_tiktoken,
|
17 |
is_float_regex,
|
18 |
list_of_list_to_csv,
|
19 |
normalize_extracted_info,
|
@@ -46,32 +45,31 @@ load_dotenv(dotenv_path=".env", override=False)
|
|
46 |
|
47 |
|
48 |
def chunking_by_token_size(
|
|
|
49 |
content: str,
|
50 |
split_by_character: str | None = None,
|
51 |
split_by_character_only: bool = False,
|
52 |
overlap_token_size: int = 128,
|
53 |
max_token_size: int = 1024,
|
54 |
-
tiktoken_model: str = "gpt-4o",
|
55 |
) -> list[dict[str, Any]]:
|
56 |
-
tokens =
|
57 |
results: list[dict[str, Any]] = []
|
58 |
if split_by_character:
|
59 |
raw_chunks = content.split(split_by_character)
|
60 |
new_chunks = []
|
61 |
if split_by_character_only:
|
62 |
for chunk in raw_chunks:
|
63 |
-
_tokens =
|
64 |
new_chunks.append((len(_tokens), chunk))
|
65 |
else:
|
66 |
for chunk in raw_chunks:
|
67 |
-
_tokens =
|
68 |
if len(_tokens) > max_token_size:
|
69 |
for start in range(
|
70 |
0, len(_tokens), max_token_size - overlap_token_size
|
71 |
):
|
72 |
-
chunk_content =
|
73 |
-
_tokens[start : start + max_token_size]
|
74 |
-
model_name=tiktoken_model,
|
75 |
)
|
76 |
new_chunks.append(
|
77 |
(min(max_token_size, len(_tokens) - start), chunk_content)
|
@@ -90,9 +88,7 @@ def chunking_by_token_size(
|
|
90 |
for index, start in enumerate(
|
91 |
range(0, len(tokens), max_token_size - overlap_token_size)
|
92 |
):
|
93 |
-
chunk_content =
|
94 |
-
tokens[start : start + max_token_size], model_name=tiktoken_model
|
95 |
-
)
|
96 |
results.append(
|
97 |
{
|
98 |
"tokens": min(max_token_size, len(tokens) - start),
|
@@ -116,19 +112,19 @@ async def _handle_entity_relation_summary(
|
|
116 |
If too long, use LLM to summarize.
|
117 |
"""
|
118 |
use_llm_func: callable = global_config["llm_model_func"]
|
|
|
119 |
llm_max_tokens = global_config["llm_model_max_token_size"]
|
120 |
-
tiktoken_model_name = global_config["tiktoken_model_name"]
|
121 |
summary_max_tokens = global_config["summary_to_max_tokens"]
|
122 |
|
123 |
language = global_config["addon_params"].get(
|
124 |
"language", PROMPTS["DEFAULT_LANGUAGE"]
|
125 |
)
|
126 |
|
127 |
-
tokens =
|
|
|
|
|
128 |
prompt_template = PROMPTS["summarize_entity_descriptions"]
|
129 |
-
use_description =
|
130 |
-
tokens[:llm_max_tokens], model_name=tiktoken_model_name
|
131 |
-
)
|
132 |
context_base = dict(
|
133 |
entity_name=entity_or_relation_name,
|
134 |
description_list=use_description.split(GRAPH_FIELD_SEP),
|
@@ -865,7 +861,8 @@ async def kg_query(
|
|
865 |
if query_param.only_need_prompt:
|
866 |
return sys_prompt
|
867 |
|
868 |
-
|
|
|
869 |
logger.debug(f"[kg_query]Prompt Tokens: {len_of_prompts}")
|
870 |
|
871 |
response = await use_model_func(
|
@@ -987,7 +984,8 @@ async def extract_keywords_only(
|
|
987 |
query=text, examples=examples, language=language, history=history_context
|
988 |
)
|
989 |
|
990 |
-
|
|
|
991 |
logger.debug(f"[kg_query]Prompt Tokens: {len_of_prompts}")
|
992 |
|
993 |
# 5. Call the LLM for keyword extraction
|
@@ -1054,6 +1052,8 @@ async def mix_kg_vector_query(
|
|
1054 |
2. Retrieving relevant text chunks through vector similarity
|
1055 |
3. Combining both results for comprehensive answer generation
|
1056 |
"""
|
|
|
|
|
1057 |
# 1. Cache handling
|
1058 |
use_model_func = (
|
1059 |
query_param.model_func
|
@@ -1153,6 +1153,7 @@ async def mix_kg_vector_query(
|
|
1153 |
valid_chunks,
|
1154 |
key=lambda x: x["content"],
|
1155 |
max_token_size=query_param.max_token_for_text_unit,
|
|
|
1156 |
)
|
1157 |
|
1158 |
if not maybe_trun_chunks:
|
@@ -1210,7 +1211,7 @@ async def mix_kg_vector_query(
|
|
1210 |
if query_param.only_need_prompt:
|
1211 |
return sys_prompt
|
1212 |
|
1213 |
-
len_of_prompts = len(
|
1214 |
logger.debug(f"[mix_kg_vector_query]Prompt Tokens: {len_of_prompts}")
|
1215 |
|
1216 |
# 6. Generate response
|
@@ -1373,17 +1374,24 @@ async def _get_node_data(
|
|
1373 |
] # what is this text_chunks_db doing. dont remember it in airvx. check the diagram.
|
1374 |
# get entitytext chunk
|
1375 |
use_text_units = await _find_most_related_text_unit_from_entities(
|
1376 |
-
node_datas,
|
|
|
|
|
|
|
1377 |
)
|
1378 |
use_relations = await _find_most_related_edges_from_entities(
|
1379 |
-
node_datas,
|
|
|
|
|
1380 |
)
|
1381 |
|
|
|
1382 |
len_node_datas = len(node_datas)
|
1383 |
node_datas = truncate_list_by_token_size(
|
1384 |
node_datas,
|
1385 |
key=lambda x: x["description"] if x["description"] is not None else "",
|
1386 |
max_token_size=query_param.max_token_for_local_context,
|
|
|
1387 |
)
|
1388 |
logger.debug(
|
1389 |
f"Truncate entities from {len_node_datas} to {len(node_datas)} (max tokens:{query_param.max_token_for_local_context})"
|
@@ -1558,14 +1566,15 @@ async def _find_most_related_text_unit_from_entities(
|
|
1558 |
logger.warning("No valid text units found")
|
1559 |
return []
|
1560 |
|
|
|
1561 |
all_text_units = sorted(
|
1562 |
all_text_units, key=lambda x: (x["order"], -x["relation_counts"])
|
1563 |
)
|
1564 |
-
|
1565 |
all_text_units = truncate_list_by_token_size(
|
1566 |
all_text_units,
|
1567 |
key=lambda x: x["data"]["content"],
|
1568 |
max_token_size=query_param.max_token_for_text_unit,
|
|
|
1569 |
)
|
1570 |
|
1571 |
logger.debug(
|
@@ -1619,6 +1628,7 @@ async def _find_most_related_edges_from_entities(
|
|
1619 |
}
|
1620 |
all_edges_data.append(combined)
|
1621 |
|
|
|
1622 |
all_edges_data = sorted(
|
1623 |
all_edges_data, key=lambda x: (x["rank"], x["weight"]), reverse=True
|
1624 |
)
|
@@ -1626,6 +1636,7 @@ async def _find_most_related_edges_from_entities(
|
|
1626 |
all_edges_data,
|
1627 |
key=lambda x: x["description"] if x["description"] is not None else "",
|
1628 |
max_token_size=query_param.max_token_for_global_context,
|
|
|
1629 |
)
|
1630 |
|
1631 |
logger.debug(
|
@@ -1681,6 +1692,7 @@ async def _get_edge_data(
|
|
1681 |
}
|
1682 |
edge_datas.append(combined)
|
1683 |
|
|
|
1684 |
edge_datas = sorted(
|
1685 |
edge_datas, key=lambda x: (x["rank"], x["weight"]), reverse=True
|
1686 |
)
|
@@ -1688,13 +1700,19 @@ async def _get_edge_data(
|
|
1688 |
edge_datas,
|
1689 |
key=lambda x: x["description"] if x["description"] is not None else "",
|
1690 |
max_token_size=query_param.max_token_for_global_context,
|
|
|
1691 |
)
|
1692 |
use_entities, use_text_units = await asyncio.gather(
|
1693 |
_find_most_related_entities_from_relationships(
|
1694 |
-
edge_datas,
|
|
|
|
|
1695 |
),
|
1696 |
_find_related_text_unit_from_relationships(
|
1697 |
-
edge_datas,
|
|
|
|
|
|
|
1698 |
),
|
1699 |
)
|
1700 |
logger.info(
|
@@ -1804,11 +1822,13 @@ async def _find_most_related_entities_from_relationships(
|
|
1804 |
combined = {**node, "entity_name": entity_name, "rank": degree}
|
1805 |
node_datas.append(combined)
|
1806 |
|
|
|
1807 |
len_node_datas = len(node_datas)
|
1808 |
node_datas = truncate_list_by_token_size(
|
1809 |
node_datas,
|
1810 |
key=lambda x: x["description"] if x["description"] is not None else "",
|
1811 |
max_token_size=query_param.max_token_for_local_context,
|
|
|
1812 |
)
|
1813 |
logger.debug(
|
1814 |
f"Truncate entities from {len_node_datas} to {len(node_datas)} (max tokens:{query_param.max_token_for_local_context})"
|
@@ -1863,10 +1883,12 @@ async def _find_related_text_unit_from_relationships(
|
|
1863 |
logger.warning("No valid text chunks after filtering")
|
1864 |
return []
|
1865 |
|
|
|
1866 |
truncated_text_units = truncate_list_by_token_size(
|
1867 |
valid_text_units,
|
1868 |
key=lambda x: x["data"]["content"],
|
1869 |
max_token_size=query_param.max_token_for_text_unit,
|
|
|
1870 |
)
|
1871 |
|
1872 |
logger.debug(
|
@@ -1937,10 +1959,12 @@ async def naive_query(
|
|
1937 |
logger.warning("No valid chunks found after filtering")
|
1938 |
return PROMPTS["fail_response"]
|
1939 |
|
|
|
1940 |
maybe_trun_chunks = truncate_list_by_token_size(
|
1941 |
valid_chunks,
|
1942 |
key=lambda x: x["content"],
|
1943 |
max_token_size=query_param.max_token_for_text_unit,
|
|
|
1944 |
)
|
1945 |
|
1946 |
if not maybe_trun_chunks:
|
@@ -1978,7 +2002,7 @@ async def naive_query(
|
|
1978 |
if query_param.only_need_prompt:
|
1979 |
return sys_prompt
|
1980 |
|
1981 |
-
len_of_prompts = len(
|
1982 |
logger.debug(f"[naive_query]Prompt Tokens: {len_of_prompts}")
|
1983 |
|
1984 |
response = await use_model_func(
|
@@ -2125,7 +2149,8 @@ async def kg_query_with_keywords(
|
|
2125 |
if query_param.only_need_prompt:
|
2126 |
return sys_prompt
|
2127 |
|
2128 |
-
|
|
|
2129 |
logger.debug(f"[kg_query_with_keywords]Prompt Tokens: {len_of_prompts}")
|
2130 |
|
2131 |
# 6. Generate response
|
|
|
12 |
logger,
|
13 |
clean_str,
|
14 |
compute_mdhash_id,
|
15 |
+
Tokenizer,
|
|
|
16 |
is_float_regex,
|
17 |
list_of_list_to_csv,
|
18 |
normalize_extracted_info,
|
|
|
45 |
|
46 |
|
47 |
def chunking_by_token_size(
|
48 |
+
tokenizer: Tokenizer,
|
49 |
content: str,
|
50 |
split_by_character: str | None = None,
|
51 |
split_by_character_only: bool = False,
|
52 |
overlap_token_size: int = 128,
|
53 |
max_token_size: int = 1024,
|
|
|
54 |
) -> list[dict[str, Any]]:
|
55 |
+
tokens = tokenizer.encode(content)
|
56 |
results: list[dict[str, Any]] = []
|
57 |
if split_by_character:
|
58 |
raw_chunks = content.split(split_by_character)
|
59 |
new_chunks = []
|
60 |
if split_by_character_only:
|
61 |
for chunk in raw_chunks:
|
62 |
+
_tokens = tokenizer.encode(chunk)
|
63 |
new_chunks.append((len(_tokens), chunk))
|
64 |
else:
|
65 |
for chunk in raw_chunks:
|
66 |
+
_tokens = tokenizer.encode(chunk)
|
67 |
if len(_tokens) > max_token_size:
|
68 |
for start in range(
|
69 |
0, len(_tokens), max_token_size - overlap_token_size
|
70 |
):
|
71 |
+
chunk_content = tokenizer.decode(
|
72 |
+
_tokens[start : start + max_token_size]
|
|
|
73 |
)
|
74 |
new_chunks.append(
|
75 |
(min(max_token_size, len(_tokens) - start), chunk_content)
|
|
|
88 |
for index, start in enumerate(
|
89 |
range(0, len(tokens), max_token_size - overlap_token_size)
|
90 |
):
|
91 |
+
chunk_content = tokenizer.decode(tokens[start : start + max_token_size])
|
|
|
|
|
92 |
results.append(
|
93 |
{
|
94 |
"tokens": min(max_token_size, len(tokens) - start),
|
|
|
112 |
If too long, use LLM to summarize.
|
113 |
"""
|
114 |
use_llm_func: callable = global_config["llm_model_func"]
|
115 |
+
tokenizer: Tokenizer = global_config["tokenizer"]
|
116 |
llm_max_tokens = global_config["llm_model_max_token_size"]
|
|
|
117 |
summary_max_tokens = global_config["summary_to_max_tokens"]
|
118 |
|
119 |
language = global_config["addon_params"].get(
|
120 |
"language", PROMPTS["DEFAULT_LANGUAGE"]
|
121 |
)
|
122 |
|
123 |
+
tokens = tokenizer.encode(description)
|
124 |
+
if len(tokens) < summary_max_tokens: # No need for summary
|
125 |
+
return description
|
126 |
prompt_template = PROMPTS["summarize_entity_descriptions"]
|
127 |
+
use_description = tokenizer.decode(tokens[:llm_max_tokens])
|
|
|
|
|
128 |
context_base = dict(
|
129 |
entity_name=entity_or_relation_name,
|
130 |
description_list=use_description.split(GRAPH_FIELD_SEP),
|
|
|
861 |
if query_param.only_need_prompt:
|
862 |
return sys_prompt
|
863 |
|
864 |
+
tokenizer: Tokenizer = global_config["tokenizer"]
|
865 |
+
len_of_prompts = len(tokenizer.encode(query + sys_prompt))
|
866 |
logger.debug(f"[kg_query]Prompt Tokens: {len_of_prompts}")
|
867 |
|
868 |
response = await use_model_func(
|
|
|
984 |
query=text, examples=examples, language=language, history=history_context
|
985 |
)
|
986 |
|
987 |
+
tokenizer: Tokenizer = global_config["tokenizer"]
|
988 |
+
len_of_prompts = len(tokenizer.encode(kw_prompt))
|
989 |
logger.debug(f"[kg_query]Prompt Tokens: {len_of_prompts}")
|
990 |
|
991 |
# 5. Call the LLM for keyword extraction
|
|
|
1052 |
2. Retrieving relevant text chunks through vector similarity
|
1053 |
3. Combining both results for comprehensive answer generation
|
1054 |
"""
|
1055 |
+
# get tokenizer
|
1056 |
+
tokenizer: Tokenizer = global_config["tokenizer"]
|
1057 |
# 1. Cache handling
|
1058 |
use_model_func = (
|
1059 |
query_param.model_func
|
|
|
1153 |
valid_chunks,
|
1154 |
key=lambda x: x["content"],
|
1155 |
max_token_size=query_param.max_token_for_text_unit,
|
1156 |
+
tokenizer=tokenizer,
|
1157 |
)
|
1158 |
|
1159 |
if not maybe_trun_chunks:
|
|
|
1211 |
if query_param.only_need_prompt:
|
1212 |
return sys_prompt
|
1213 |
|
1214 |
+
len_of_prompts = len(tokenizer.encode(query + sys_prompt))
|
1215 |
logger.debug(f"[mix_kg_vector_query]Prompt Tokens: {len_of_prompts}")
|
1216 |
|
1217 |
# 6. Generate response
|
|
|
1374 |
] # what is this text_chunks_db doing. dont remember it in airvx. check the diagram.
|
1375 |
# get entitytext chunk
|
1376 |
use_text_units = await _find_most_related_text_unit_from_entities(
|
1377 |
+
node_datas,
|
1378 |
+
query_param,
|
1379 |
+
text_chunks_db,
|
1380 |
+
knowledge_graph_inst,
|
1381 |
)
|
1382 |
use_relations = await _find_most_related_edges_from_entities(
|
1383 |
+
node_datas,
|
1384 |
+
query_param,
|
1385 |
+
knowledge_graph_inst,
|
1386 |
)
|
1387 |
|
1388 |
+
tokenizer: Tokenizer = text_chunks_db.global_config.get("tokenizer")
|
1389 |
len_node_datas = len(node_datas)
|
1390 |
node_datas = truncate_list_by_token_size(
|
1391 |
node_datas,
|
1392 |
key=lambda x: x["description"] if x["description"] is not None else "",
|
1393 |
max_token_size=query_param.max_token_for_local_context,
|
1394 |
+
tokenizer=tokenizer,
|
1395 |
)
|
1396 |
logger.debug(
|
1397 |
f"Truncate entities from {len_node_datas} to {len(node_datas)} (max tokens:{query_param.max_token_for_local_context})"
|
|
|
1566 |
logger.warning("No valid text units found")
|
1567 |
return []
|
1568 |
|
1569 |
+
tokenizer: Tokenizer = text_chunks_db.global_config.get("tokenizer")
|
1570 |
all_text_units = sorted(
|
1571 |
all_text_units, key=lambda x: (x["order"], -x["relation_counts"])
|
1572 |
)
|
|
|
1573 |
all_text_units = truncate_list_by_token_size(
|
1574 |
all_text_units,
|
1575 |
key=lambda x: x["data"]["content"],
|
1576 |
max_token_size=query_param.max_token_for_text_unit,
|
1577 |
+
tokenizer=tokenizer,
|
1578 |
)
|
1579 |
|
1580 |
logger.debug(
|
|
|
1628 |
}
|
1629 |
all_edges_data.append(combined)
|
1630 |
|
1631 |
+
tokenizer: Tokenizer = knowledge_graph_inst.global_config.get("tokenizer")
|
1632 |
all_edges_data = sorted(
|
1633 |
all_edges_data, key=lambda x: (x["rank"], x["weight"]), reverse=True
|
1634 |
)
|
|
|
1636 |
all_edges_data,
|
1637 |
key=lambda x: x["description"] if x["description"] is not None else "",
|
1638 |
max_token_size=query_param.max_token_for_global_context,
|
1639 |
+
tokenizer=tokenizer,
|
1640 |
)
|
1641 |
|
1642 |
logger.debug(
|
|
|
1692 |
}
|
1693 |
edge_datas.append(combined)
|
1694 |
|
1695 |
+
tokenizer: Tokenizer = text_chunks_db.global_config.get("tokenizer")
|
1696 |
edge_datas = sorted(
|
1697 |
edge_datas, key=lambda x: (x["rank"], x["weight"]), reverse=True
|
1698 |
)
|
|
|
1700 |
edge_datas,
|
1701 |
key=lambda x: x["description"] if x["description"] is not None else "",
|
1702 |
max_token_size=query_param.max_token_for_global_context,
|
1703 |
+
tokenizer=tokenizer,
|
1704 |
)
|
1705 |
use_entities, use_text_units = await asyncio.gather(
|
1706 |
_find_most_related_entities_from_relationships(
|
1707 |
+
edge_datas,
|
1708 |
+
query_param,
|
1709 |
+
knowledge_graph_inst,
|
1710 |
),
|
1711 |
_find_related_text_unit_from_relationships(
|
1712 |
+
edge_datas,
|
1713 |
+
query_param,
|
1714 |
+
text_chunks_db,
|
1715 |
+
knowledge_graph_inst,
|
1716 |
),
|
1717 |
)
|
1718 |
logger.info(
|
|
|
1822 |
combined = {**node, "entity_name": entity_name, "rank": degree}
|
1823 |
node_datas.append(combined)
|
1824 |
|
1825 |
+
tokenizer: Tokenizer = knowledge_graph_inst.global_config.get("tokenizer")
|
1826 |
len_node_datas = len(node_datas)
|
1827 |
node_datas = truncate_list_by_token_size(
|
1828 |
node_datas,
|
1829 |
key=lambda x: x["description"] if x["description"] is not None else "",
|
1830 |
max_token_size=query_param.max_token_for_local_context,
|
1831 |
+
tokenizer=tokenizer,
|
1832 |
)
|
1833 |
logger.debug(
|
1834 |
f"Truncate entities from {len_node_datas} to {len(node_datas)} (max tokens:{query_param.max_token_for_local_context})"
|
|
|
1883 |
logger.warning("No valid text chunks after filtering")
|
1884 |
return []
|
1885 |
|
1886 |
+
tokenizer: Tokenizer = text_chunks_db.global_config.get("tokenizer")
|
1887 |
truncated_text_units = truncate_list_by_token_size(
|
1888 |
valid_text_units,
|
1889 |
key=lambda x: x["data"]["content"],
|
1890 |
max_token_size=query_param.max_token_for_text_unit,
|
1891 |
+
tokenizer=tokenizer,
|
1892 |
)
|
1893 |
|
1894 |
logger.debug(
|
|
|
1959 |
logger.warning("No valid chunks found after filtering")
|
1960 |
return PROMPTS["fail_response"]
|
1961 |
|
1962 |
+
tokenizer: Tokenizer = global_config["tokenizer"]
|
1963 |
maybe_trun_chunks = truncate_list_by_token_size(
|
1964 |
valid_chunks,
|
1965 |
key=lambda x: x["content"],
|
1966 |
max_token_size=query_param.max_token_for_text_unit,
|
1967 |
+
tokenizer=tokenizer,
|
1968 |
)
|
1969 |
|
1970 |
if not maybe_trun_chunks:
|
|
|
2002 |
if query_param.only_need_prompt:
|
2003 |
return sys_prompt
|
2004 |
|
2005 |
+
len_of_prompts = len(tokenizer.encode(query + sys_prompt))
|
2006 |
logger.debug(f"[naive_query]Prompt Tokens: {len_of_prompts}")
|
2007 |
|
2008 |
response = await use_model_func(
|
|
|
2149 |
if query_param.only_need_prompt:
|
2150 |
return sys_prompt
|
2151 |
|
2152 |
+
tokenizer: Tokenizer = global_config["tokenizer"]
|
2153 |
+
len_of_prompts = len(tokenizer.encode(query + sys_prompt))
|
2154 |
logger.debug(f"[kg_query_with_keywords]Prompt Tokens: {len_of_prompts}")
|
2155 |
|
2156 |
# 6. Generate response
|
lightrag/utils.py
CHANGED
@@ -12,10 +12,9 @@ import re
|
|
12 |
from dataclasses import dataclass
|
13 |
from functools import wraps
|
14 |
from hashlib import md5
|
15 |
-
from typing import Any, Callable, TYPE_CHECKING
|
16 |
import xml.etree.ElementTree as ET
|
17 |
import numpy as np
|
18 |
-
import tiktoken
|
19 |
from lightrag.prompt import PROMPTS
|
20 |
from dotenv import load_dotenv
|
21 |
|
@@ -193,9 +192,6 @@ class UnlimitedSemaphore:
|
|
193 |
pass
|
194 |
|
195 |
|
196 |
-
ENCODER = None
|
197 |
-
|
198 |
-
|
199 |
@dataclass
|
200 |
class EmbeddingFunc:
|
201 |
embedding_dim: int
|
@@ -311,20 +307,89 @@ def write_json(json_obj, file_name):
|
|
311 |
json.dump(json_obj, f, indent=2, ensure_ascii=False)
|
312 |
|
313 |
|
314 |
-
|
315 |
-
|
316 |
-
|
317 |
-
|
318 |
-
|
319 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
320 |
|
|
|
|
|
|
|
|
|
321 |
|
322 |
-
def
|
323 |
-
|
324 |
-
|
325 |
-
|
326 |
-
|
327 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
328 |
|
329 |
|
330 |
def pack_user_ass_to_openai_messages(*args: str):
|
@@ -361,14 +426,17 @@ def is_float_regex(value: str) -> bool:
|
|
361 |
|
362 |
|
363 |
def truncate_list_by_token_size(
|
364 |
-
list_data: list[Any],
|
|
|
|
|
|
|
365 |
) -> list[int]:
|
366 |
"""Truncate a list of data by token size"""
|
367 |
if max_token_size <= 0:
|
368 |
return []
|
369 |
tokens = 0
|
370 |
for i, data in enumerate(list_data):
|
371 |
-
tokens += len(
|
372 |
if tokens > max_token_size:
|
373 |
return list_data[:i]
|
374 |
return list_data
|
|
|
12 |
from dataclasses import dataclass
|
13 |
from functools import wraps
|
14 |
from hashlib import md5
|
15 |
+
from typing import Any, Protocol, Callable, TYPE_CHECKING, List
|
16 |
import xml.etree.ElementTree as ET
|
17 |
import numpy as np
|
|
|
18 |
from lightrag.prompt import PROMPTS
|
19 |
from dotenv import load_dotenv
|
20 |
|
|
|
192 |
pass
|
193 |
|
194 |
|
|
|
|
|
|
|
195 |
@dataclass
|
196 |
class EmbeddingFunc:
|
197 |
embedding_dim: int
|
|
|
307 |
json.dump(json_obj, f, indent=2, ensure_ascii=False)
|
308 |
|
309 |
|
310 |
+
class TokenizerInterface(Protocol):
|
311 |
+
"""
|
312 |
+
Defines the interface for a tokenizer, requiring encode and decode methods.
|
313 |
+
"""
|
314 |
+
|
315 |
+
def encode(self, content: str) -> List[int]:
|
316 |
+
"""Encodes a string into a list of tokens."""
|
317 |
+
...
|
318 |
+
|
319 |
+
def decode(self, tokens: List[int]) -> str:
|
320 |
+
"""Decodes a list of tokens into a string."""
|
321 |
+
...
|
322 |
+
|
323 |
+
|
324 |
+
class Tokenizer:
|
325 |
+
"""
|
326 |
+
A wrapper around a tokenizer to provide a consistent interface for encoding and decoding.
|
327 |
+
"""
|
328 |
+
|
329 |
+
def __init__(self, model_name: str, tokenizer: TokenizerInterface):
|
330 |
+
"""
|
331 |
+
Initializes the Tokenizer with a tokenizer model name and a tokenizer instance.
|
332 |
+
|
333 |
+
Args:
|
334 |
+
model_name: The associated model name for the tokenizer.
|
335 |
+
tokenizer: An instance of a class implementing the TokenizerInterface.
|
336 |
+
"""
|
337 |
+
self.model_name: str = model_name
|
338 |
+
self.tokenizer: TokenizerInterface = tokenizer
|
339 |
+
|
340 |
+
def encode(self, content: str) -> List[int]:
|
341 |
+
"""
|
342 |
+
Encodes a string into a list of tokens using the underlying tokenizer.
|
343 |
+
|
344 |
+
Args:
|
345 |
+
content: The string to encode.
|
346 |
|
347 |
+
Returns:
|
348 |
+
A list of integer tokens.
|
349 |
+
"""
|
350 |
+
return self.tokenizer.encode(content)
|
351 |
|
352 |
+
def decode(self, tokens: List[int]) -> str:
|
353 |
+
"""
|
354 |
+
Decodes a list of tokens into a string using the underlying tokenizer.
|
355 |
+
|
356 |
+
Args:
|
357 |
+
tokens: A list of integer tokens to decode.
|
358 |
+
|
359 |
+
Returns:
|
360 |
+
The decoded string.
|
361 |
+
"""
|
362 |
+
return self.tokenizer.decode(tokens)
|
363 |
+
|
364 |
+
|
365 |
+
class TiktokenTokenizer(Tokenizer):
|
366 |
+
"""
|
367 |
+
A Tokenizer implementation using the tiktoken library.
|
368 |
+
"""
|
369 |
+
|
370 |
+
def __init__(self, model_name: str = "gpt-4o-mini"):
|
371 |
+
"""
|
372 |
+
Initializes the TiktokenTokenizer with a specified model name.
|
373 |
+
|
374 |
+
Args:
|
375 |
+
model_name: The model name for the tiktoken tokenizer to use. Defaults to "gpt-4o-mini".
|
376 |
+
|
377 |
+
Raises:
|
378 |
+
ImportError: If tiktoken is not installed.
|
379 |
+
ValueError: If the model_name is invalid.
|
380 |
+
"""
|
381 |
+
try:
|
382 |
+
import tiktoken
|
383 |
+
except ImportError:
|
384 |
+
raise ImportError(
|
385 |
+
"tiktoken is not installed. Please install it with `pip install tiktoken` or define custom `tokenizer_func`."
|
386 |
+
)
|
387 |
+
|
388 |
+
try:
|
389 |
+
tokenizer = tiktoken.encoding_for_model(model_name)
|
390 |
+
super().__init__(model_name=model_name, tokenizer=tokenizer)
|
391 |
+
except KeyError:
|
392 |
+
raise ValueError(f"Invalid model_name: {model_name}.")
|
393 |
|
394 |
|
395 |
def pack_user_ass_to_openai_messages(*args: str):
|
|
|
426 |
|
427 |
|
428 |
def truncate_list_by_token_size(
|
429 |
+
list_data: list[Any],
|
430 |
+
key: Callable[[Any], str],
|
431 |
+
max_token_size: int,
|
432 |
+
tokenizer: Tokenizer,
|
433 |
) -> list[int]:
|
434 |
"""Truncate a list of data by token size"""
|
435 |
if max_token_size <= 0:
|
436 |
return []
|
437 |
tokens = 0
|
438 |
for i, data in enumerate(list_data):
|
439 |
+
tokens += len(tokenizer.encode(key(data)))
|
440 |
if tokens > max_token_size:
|
441 |
return list_data[:i]
|
442 |
return list_data
|