yangdx committed on
Commit
330ab9a
·
2 Parent(s): b59915a edf0ba0

Merge branch 'drahnreb/add-custom-tokenizer'

README-zh.md CHANGED
@@ -1090,7 +1090,8 @@ rag.clear_cache(modes=["local"])
  | **doc_status_storage** | `str` | Storage type for documents process status. Supported types: `JsonDocStatusStorage`,`PGDocStatusStorage`,`MongoDocStatusStorage` | `JsonDocStatusStorage` |
  | **chunk_token_size** | `int` | Maximum token size per chunk when splitting documents | `1200` |
  | **chunk_overlap_token_size** | `int` | Overlap token size between two chunks when splitting documents | `100` |
- | **tiktoken_model_name** | `str` | Model name for the Tiktoken encoder used to calculate token numbers | `gpt-4o-mini` |
  | **entity_extract_max_gleaning** | `int` | Number of loops in the entity extraction process, appending history messages | `1` |
  | **entity_summary_to_max_tokens** | `int` | Maximum token size for each entity summary | `500` |
  | **node_embedding_algorithm** | `str` | Algorithm for node embedding (currently not used) | `node2vec` |
 
  | **doc_status_storage** | `str` | Storage type for documents process status. Supported types: `JsonDocStatusStorage`,`PGDocStatusStorage`,`MongoDocStatusStorage` | `JsonDocStatusStorage` |
  | **chunk_token_size** | `int` | Maximum token size per chunk when splitting documents | `1200` |
  | **chunk_overlap_token_size** | `int` | Overlap token size between two chunks when splitting documents | `100` |
+ | **tokenizer** | `Tokenizer` | The function used to convert text into tokens (numbers) and back to text, using .encode() and .decode() functions following the TokenizerInterface protocol. If you do not specify one, the default Tiktoken tokenizer is used. | `TiktokenTokenizer` |
+ | **tiktoken_model_name** | `str` | If you are using the default Tiktoken tokenizer, this is the name of the specific Tiktoken model to use. This setting is ignored if you provide your own tokenizer. | `gpt-4o-mini` |
  | **entity_extract_max_gleaning** | `int` | Number of loops in the entity extraction process, appending history messages | `1` |
  | **entity_summary_to_max_tokens** | `int` | Maximum token size for each entity summary | `500` |
  | **node_embedding_algorithm** | `str` | Algorithm for node embedding (currently not used) | `node2vec` |
README.md CHANGED
@@ -1156,7 +1156,8 @@ Valid modes are:
  | **doc_status_storage** | `str` | Storage type for documents process status. Supported types: `JsonDocStatusStorage`,`PGDocStatusStorage`,`MongoDocStatusStorage` | `JsonDocStatusStorage` |
  | **chunk_token_size** | `int` | Maximum token size per chunk when splitting documents | `1200` |
  | **chunk_overlap_token_size** | `int` | Overlap token size between two chunks when splitting documents | `100` |
- | **tiktoken_model_name** | `str` | Model name for the Tiktoken encoder used to calculate token numbers | `gpt-4o-mini` |
  | **entity_extract_max_gleaning** | `int` | Number of loops in the entity extraction process, appending history messages | `1` |
  | **entity_summary_to_max_tokens** | `int` | Maximum token size for each entity summary | `500` |
  | **node_embedding_algorithm** | `str` | Algorithm for node embedding (currently not used) | `node2vec` |
 
  | **doc_status_storage** | `str` | Storage type for documents process status. Supported types: `JsonDocStatusStorage`,`PGDocStatusStorage`,`MongoDocStatusStorage` | `JsonDocStatusStorage` |
  | **chunk_token_size** | `int` | Maximum token size per chunk when splitting documents | `1200` |
  | **chunk_overlap_token_size** | `int` | Overlap token size between two chunks when splitting documents | `100` |
+ | **tokenizer** | `Tokenizer` | The tokenizer used to convert text into tokens (integer IDs) and back, via `.encode()` and `.decode()` methods that follow the `TokenizerInterface` protocol. If you don't specify one, the default Tiktoken tokenizer is used. | `TiktokenTokenizer` |
+ | **tiktoken_model_name** | `str` | If you're using the default Tiktoken tokenizer, this is the name of the specific Tiktoken model to use. This setting is ignored if you provide your own tokenizer. | `gpt-4o-mini` |
  | **entity_extract_max_gleaning** | `int` | Number of loops in the entity extraction process, appending history messages | `1` |
  | **entity_summary_to_max_tokens** | `int` | Maximum token size for each entity summary | `500` |
  | **node_embedding_algorithm** | `str` | Algorithm for node embedding (currently not used) | `node2vec` |
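
To make the `tokenizer` row above concrete, here is a minimal sketch of an object that exposes `.encode()` and `.decode()`. The `TokenizerLike` protocol and the `WhitespaceTokenizer` class are hypothetical stand-ins written for illustration only; they are not LightRAG's `TokenizerInterface` or `Tokenizer`, whose definitions are not shown in this diff.

```python
from typing import Dict, List, Protocol


class TokenizerLike(Protocol):
    """Hypothetical stand-in for an encode/decode tokenizer protocol."""

    def encode(self, content: str) -> List[int]: ...

    def decode(self, tokens: List[int]) -> str: ...


class WhitespaceTokenizer:
    """Toy tokenizer: one integer ID per whitespace-separated word."""

    def __init__(self) -> None:
        self._word_to_id: Dict[str, int] = {}
        self._id_to_word: List[str] = []

    def encode(self, content: str) -> List[int]:
        ids = []
        for word in content.split():
            # Assign the next free ID the first time a word is seen.
            if word not in self._word_to_id:
                self._word_to_id[word] = len(self._id_to_word)
                self._id_to_word.append(word)
            ids.append(self._word_to_id[word])
        return ids

    def decode(self, tokens: List[int]) -> str:
        return " ".join(self._id_to_word[t] for t in tokens)


def count_tokens(tokenizer: TokenizerLike, text: str) -> int:
    """Count tokens with whatever tokenizer the caller supplies."""
    return len(tokenizer.encode(text))


tok = WhitespaceTokenizer()
ids = tok.encode("the cat saw the dog")
print(ids)
print(tok.decode(ids))
print(count_tokens(tok, "the cat saw the dog"))
```

Any object with these two methods can be passed where token counting is needed, which is the point of replacing a hard-coded Tiktoken call with an injected tokenizer.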
examples/lightrag_gemini_demo_no_tiktoken.py ADDED
@@ -0,0 +1,230 @@
1
+ # pip install -q -U google-genai to use gemini as a client
2
+
3
+ import os
4
+ from typing import Optional
5
+ import dataclasses
6
+ from pathlib import Path
7
+ import hashlib
8
+ import numpy as np
9
+ from google import genai
10
+ from google.genai import types
11
+ from dotenv import load_dotenv
12
+ from lightrag.utils import EmbeddingFunc, Tokenizer
13
+ from lightrag import LightRAG, QueryParam
14
+ from sentence_transformers import SentenceTransformer
15
+ from lightrag.kg.shared_storage import initialize_pipeline_status
16
+ import sentencepiece as spm
17
+ import requests
18
+
19
+ import asyncio
20
+ import nest_asyncio
21
+
22
+ # Apply nest_asyncio to solve event loop issues
23
+ nest_asyncio.apply()
24
+
25
+ load_dotenv()
26
+ gemini_api_key = os.getenv("GEMINI_API_KEY")
27
+
28
+ WORKING_DIR = "./dickens"
29
+
30
+ if os.path.exists(WORKING_DIR):
31
+ import shutil
32
+
33
+ shutil.rmtree(WORKING_DIR)
34
+
35
+ os.mkdir(WORKING_DIR)
36
+
37
+
38
+ class GemmaTokenizer(Tokenizer):
39
+ # adapted from google-cloud-aiplatform[tokenization]
40
+
41
+ @dataclasses.dataclass(frozen=True)
42
+ class _TokenizerConfig:
43
+ tokenizer_model_url: str
44
+ tokenizer_model_hash: str
45
+
46
+ _TOKENIZERS = {
47
+ "google/gemma2": _TokenizerConfig(
48
+ tokenizer_model_url="https://raw.githubusercontent.com/google/gemma_pytorch/33b652c465537c6158f9a472ea5700e5e770ad3f/tokenizer/tokenizer.model",
49
+ tokenizer_model_hash="61a7b147390c64585d6c3543dd6fc636906c9af3865a5548f27f31aee1d4c8e2",
50
+ ),
51
+ "google/gemma3": _TokenizerConfig(
52
+ tokenizer_model_url="https://raw.githubusercontent.com/google/gemma_pytorch/cb7c0152a369e43908e769eb09e1ce6043afe084/tokenizer/gemma3_cleaned_262144_v2.spiece.model",
53
+ tokenizer_model_hash="1299c11d7cf632ef3b4e11937501358ada021bbdf7c47638d13c0ee982f2e79c",
54
+ ),
55
+ }
56
+
57
+ def __init__(
58
+ self, model_name: str = "gemini-2.0-flash", tokenizer_dir: Optional[str] = None
59
+ ):
60
+ # https://github.com/google/gemma_pytorch/tree/main/tokenizer
61
+ if "1.5" in model_name or "1.0" in model_name:
62
+ # up to gemini 1.5 gemma2 is a comparable local tokenizer
63
+ # https://github.com/googleapis/python-aiplatform/blob/main/vertexai/tokenization/_tokenizer_loading.py
64
+ tokenizer_name = "google/gemma2"
65
+ else:
66
+ # for gemini 2.0 and later, use the gemma3 tokenizer
67
+ tokenizer_name = "google/gemma3"
68
+
69
+ file_url = self._TOKENIZERS[tokenizer_name].tokenizer_model_url
70
+ tokenizer_model_name = file_url.rsplit("/", 1)[1]
71
+ expected_hash = self._TOKENIZERS[tokenizer_name].tokenizer_model_hash
72
+
+         # Resolve the cache location up front so file_path is always defined,
+         # even when tokenizer_dir is None or does not exist yet.
+         file_path = (
+             Path(tokenizer_dir) / tokenizer_model_name if tokenizer_dir else None
+         )
+         model_data = None
+         if file_path is not None:
+             model_data = self._maybe_load_from_cache(
+                 file_path=file_path, expected_hash=expected_hash
+             )
+         if not model_data:
+             model_data = self._load_from_url(
+                 file_url=file_url, expected_hash=expected_hash
+             )
+             if file_path is not None:
+                 self.save_tokenizer_to_cache(
+                     cache_path=file_path, model_data=model_data
+                 )
86
+
87
+ tokenizer = spm.SentencePieceProcessor()
88
+ tokenizer.LoadFromSerializedProto(model_data)
89
+ super().__init__(model_name=model_name, tokenizer=tokenizer)
90
+
91
+ def _is_valid_model(self, model_data: bytes, expected_hash: str) -> bool:
92
+ """Returns true if the content is valid by checking the hash."""
93
+ return hashlib.sha256(model_data).hexdigest() == expected_hash
94
+
95
+ def _maybe_load_from_cache(self, file_path: Path, expected_hash: str) -> Optional[bytes]:
96
+ """Loads the model data from the cache path."""
97
+ if not file_path.is_file():
98
+ return
99
+ with open(file_path, "rb") as f:
100
+ content = f.read()
101
+ if self._is_valid_model(model_data=content, expected_hash=expected_hash):
102
+ return content
103
+
104
+ # Cached file corrupted.
105
+ self._maybe_remove_file(file_path)
106
+
107
+ def _load_from_url(self, file_url: str, expected_hash: str) -> bytes:
108
+ """Loads model bytes from the given file url."""
109
+ resp = requests.get(file_url)
110
+ resp.raise_for_status()
111
+ content = resp.content
112
+
113
+ if not self._is_valid_model(model_data=content, expected_hash=expected_hash):
114
+ actual_hash = hashlib.sha256(content).hexdigest()
115
+ raise ValueError(
116
+ f"Downloaded model file is corrupted."
117
+ f" Expected hash {expected_hash}. Got file hash {actual_hash}."
118
+ )
119
+ return content
120
+
121
+ @staticmethod
122
+ def save_tokenizer_to_cache(cache_path: Path, model_data: bytes) -> None:
123
+ """Saves the model data to the cache path."""
124
+ try:
125
+ if not cache_path.is_file():
126
+ cache_dir = cache_path.parent
127
+ cache_dir.mkdir(parents=True, exist_ok=True)
128
+ with open(cache_path, "wb") as f:
129
+ f.write(model_data)
130
+ except OSError:
131
+ # Don't raise if we cannot write file.
132
+ pass
133
+
134
+ @staticmethod
135
+ def _maybe_remove_file(file_path: Path) -> None:
136
+ """Removes the file if exists."""
137
+ if not file_path.is_file():
138
+ return
139
+ try:
140
+ file_path.unlink()
141
+ except OSError:
142
+ # Don't raise if we cannot remove file.
143
+ pass
144
+
145
+ # def encode(self, content: str) -> list[int]:
146
+ # return self.tokenizer.encode(content)
147
+
148
+ # def decode(self, tokens: list[int]) -> str:
149
+ # return self.tokenizer.decode(tokens)
150
+
151
+
152
+ async def llm_model_func(
153
+ prompt, system_prompt=None, history_messages=None, keyword_extraction=False, **kwargs
154
+ ) -> str:
155
+ # 1. Initialize the GenAI Client with your Gemini API Key
156
+ client = genai.Client(api_key=gemini_api_key)
157
+
158
+ # 2. Combine prompts: system prompt, history, and user prompt
159
+ if history_messages is None:
160
+ history_messages = []
161
+
162
+ combined_prompt = ""
163
+ if system_prompt:
164
+ combined_prompt += f"{system_prompt}\n"
165
+
166
+ for msg in history_messages:
167
+ # Each msg is expected to be a dict: {"role": "...", "content": "..."}
168
+ combined_prompt += f"{msg['role']}: {msg['content']}\n"
169
+
170
+ # Finally, add the new user prompt
171
+ combined_prompt += f"user: {prompt}"
172
+
173
+ # 3. Call the Gemini model
174
+ response = client.models.generate_content(
175
+ model="gemini-1.5-flash",
176
+ contents=[combined_prompt],
177
+ config=types.GenerateContentConfig(max_output_tokens=500, temperature=0.1),
178
+ )
179
+
180
+ # 4. Return the response text
181
+ return response.text
182
+
183
+
184
+ async def embedding_func(texts: list[str]) -> np.ndarray:
185
+ model = SentenceTransformer("all-MiniLM-L6-v2")
186
+ embeddings = model.encode(texts, convert_to_numpy=True)
187
+ return embeddings
188
+
189
+
190
+ async def initialize_rag():
191
+ rag = LightRAG(
192
+ working_dir=WORKING_DIR,
193
+ # tiktoken_model_name="gpt-4o-mini",
194
+ tokenizer=GemmaTokenizer(
195
+ tokenizer_dir=(Path(WORKING_DIR) / "vertexai_tokenizer_model"),
196
+ model_name="gemini-2.0-flash",
197
+ ),
198
+ llm_model_func=llm_model_func,
199
+ embedding_func=EmbeddingFunc(
200
+ embedding_dim=384,
201
+ max_token_size=8192,
202
+ func=embedding_func,
203
+ ),
204
+ )
205
+
206
+ await rag.initialize_storages()
207
+ await initialize_pipeline_status()
208
+
209
+ return rag
210
+
211
+
212
+ def main():
213
+ # Initialize RAG instance
214
+ rag = asyncio.run(initialize_rag())
215
+ file_path = "story.txt"
216
+ with open(file_path, "r") as file:
217
+ text = file.read()
218
+
219
+ rag.insert(text)
220
+
221
+ response = rag.query(
222
+ query="What is the main theme of the story?",
223
+ param=QueryParam(mode="hybrid", top_k=5, response_type="single line"),
224
+ )
225
+
226
+ print(response)
227
+
228
+
229
+ if __name__ == "__main__":
230
+ main()
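
The caching helpers in the demo above follow a verify-then-cache pattern: trust a cached file only if its SHA-256 matches, otherwise fetch it again and check the download too. The sketch below shows that pattern in isolation. `load_verified` and `fake_fetch` are hypothetical names, and the network call is replaced by a plain callable so the example runs offline; it is not the demo's exact code.

```python
import hashlib
import tempfile
from pathlib import Path
from typing import Callable, List, Optional


def sha256_hex(data: bytes) -> str:
    return hashlib.sha256(data).hexdigest()


def load_verified(
    cache_path: Optional[Path],
    expected_hash: str,
    fetch: Callable[[], bytes],
) -> bytes:
    """Return bytes whose SHA-256 matches expected_hash, preferring the cache."""
    if cache_path is not None and cache_path.is_file():
        cached = cache_path.read_bytes()
        if sha256_hex(cached) == expected_hash:
            return cached
        cache_path.unlink()  # cached copy is corrupted, discard it

    data = fetch()
    if sha256_hex(data) != expected_hash:
        raise ValueError("Downloaded file does not match the expected hash.")

    if cache_path is not None:
        try:
            cache_path.parent.mkdir(parents=True, exist_ok=True)
            cache_path.write_bytes(data)
        except OSError:
            pass  # caching is best-effort; a failed write is not fatal
    return data


payload = b"fake tokenizer model"
calls: List[int] = []


def fake_fetch() -> bytes:
    """Stand-in for a download; records how often it is called."""
    calls.append(1)
    return payload


with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "models" / "tokenizer.model"
    first = load_verified(path, sha256_hex(payload), fake_fetch)
    second = load_verified(path, sha256_hex(payload), fake_fetch)
    fetch_count = len(calls)

print(fetch_count)
```

The second call is served from the cache, so the fetch callable runs only once. Passing `cache_path=None` skips caching entirely, which avoids relying on a path variable that may never have been set.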
lightrag/api/routers/ollama_api.py CHANGED
@@ -10,7 +10,7 @@ from fastapi.responses import StreamingResponse
10
  import asyncio
11
  from ascii_colors import trace_exception
12
  from lightrag import LightRAG, QueryParam
13
- from lightrag.utils import encode_string_by_tiktoken
14
  from lightrag.api.utils_api import ollama_server_infos, get_combined_auth_dependency
15
  from fastapi import Depends
16
 
@@ -97,7 +97,7 @@ class OllamaTagResponse(BaseModel):
97
 
98
  def estimate_tokens(text: str) -> int:
99
  """Estimate the number of tokens in text using tiktoken"""
100
- tokens = encode_string_by_tiktoken(text)
101
  return len(tokens)
102
 
103
 
 
10
  import asyncio
11
  from ascii_colors import trace_exception
12
  from lightrag import LightRAG, QueryParam
13
+ from lightrag.utils import TiktokenTokenizer
14
  from lightrag.api.utils_api import ollama_server_infos, get_combined_auth_dependency
15
  from fastapi import Depends
16
 
 
97
 
98
  def estimate_tokens(text: str) -> int:
99
  """Estimate the number of tokens in text using tiktoken"""
100
+ tokens = TiktokenTokenizer().encode(text)
101
  return len(tokens)
102
 
103
 
lightrag/lightrag.py CHANGED
@@ -7,7 +7,18 @@ import warnings
7
  from dataclasses import asdict, dataclass, field
8
  from datetime import datetime
9
  from functools import partial
10
- from typing import Any, AsyncIterator, Callable, Iterator, cast, final, Literal
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  from lightrag.kg import (
13
  STORAGES,
@@ -41,11 +52,12 @@ from .operate import (
41
  )
42
  from .prompt import GRAPH_FIELD_SEP, PROMPTS
43
  from .utils import (
 
 
44
  EmbeddingFunc,
45
  always_get_an_event_loop,
46
  compute_mdhash_id,
47
  convert_response_to_json,
48
- encode_string_by_tiktoken,
49
  lazy_external_import,
50
  limit_async_func_call,
51
  get_content_summary,
@@ -122,33 +134,38 @@ class LightRAG:
122
  )
123
  """Number of overlapping tokens between consecutive text chunks to preserve context."""
124
 
125
- tiktoken_model_name: str = field(default="gpt-4o-mini")
126
- """Model name used for tokenization when chunking text."""
 
 
 
 
127
 
128
- """Maximum number of tokens used for summarizing extracted entities."""
 
129
 
130
  chunking_func: Callable[
131
  [
 
132
  str,
133
- str | None,
134
  bool,
135
  int,
136
  int,
137
- str,
138
  ],
139
- list[dict[str, Any]],
140
  ] = field(default_factory=lambda: chunking_by_token_size)
141
  """
142
  Custom chunking function for splitting text into chunks before processing.
143
 
144
  The function should take the following parameters:
145
 
 
146
  - `content`: The text to be split into chunks.
147
  - `split_by_character`: The character to split the text on. If None, the text is split into chunks of `chunk_token_size` tokens.
148
  - `split_by_character_only`: If True, the text is split only on the specified character.
149
  - `chunk_token_size`: The maximum number of tokens per chunk.
150
  - `chunk_overlap_token_size`: The number of overlapping tokens between consecutive chunks.
151
- - `tiktoken_model_name`: The name of the tiktoken model to use for tokenization.
152
 
153
  The function should return a list of dictionaries, where each dictionary contains the following keys:
154
  - `tokens`: The number of tokens in the chunk.
@@ -310,7 +327,15 @@ class LightRAG:
310
  _print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()])
311
  logger.debug(f"LightRAG init with param:\n {_print_config}\n")
312
 
313
- # Init LLM
 
 
 
 
 
 
 
 
314
  self.embedding_func = limit_async_func_call(self.embedding_func_max_async)( # type: ignore
315
  self.embedding_func
316
  )
@@ -603,11 +628,7 @@ class LightRAG:
603
  inserting_chunks: dict[str, Any] = {}
604
  for index, chunk_text in enumerate(text_chunks):
605
  chunk_key = compute_mdhash_id(chunk_text, prefix="chunk-")
606
- tokens = len(
607
- encode_string_by_tiktoken(
608
- chunk_text, model_name=self.tiktoken_model_name
609
- )
610
- )
611
  inserting_chunks[chunk_key] = {
612
  "content": chunk_text,
613
  "full_doc_id": doc_key,
@@ -900,12 +921,12 @@ class LightRAG:
900
  "file_path": file_path, # Add file path to each chunk
901
  }
902
  for dp in self.chunking_func(
 
903
  status_doc.content,
904
  split_by_character,
905
  split_by_character_only,
906
  self.chunk_overlap_token_size,
907
  self.chunk_token_size,
908
- self.tiktoken_model_name,
909
  )
910
  }
911
 
@@ -1133,11 +1154,7 @@ class LightRAG:
1133
  for chunk_data in custom_kg.get("chunks", []):
1134
  chunk_content = clean_text(chunk_data["content"])
1135
  source_id = chunk_data["source_id"]
1136
- tokens = len(
1137
- encode_string_by_tiktoken(
1138
- chunk_content, model_name=self.tiktoken_model_name
1139
- )
1140
- )
1141
  chunk_order_index = (
1142
  0
1143
  if "chunk_order_index" not in chunk_data.keys()
 
7
  from dataclasses import asdict, dataclass, field
8
  from datetime import datetime
9
  from functools import partial
10
+ from typing import (
11
+ Any,
12
+ AsyncIterator,
13
+ Callable,
14
+ Iterator,
15
+ cast,
16
+ final,
17
+ Literal,
18
+ Optional,
19
+ List,
20
+ Dict,
21
+ )
22
 
23
  from lightrag.kg import (
24
  STORAGES,
 
52
  )
53
  from .prompt import GRAPH_FIELD_SEP, PROMPTS
54
  from .utils import (
55
+ Tokenizer,
56
+ TiktokenTokenizer,
57
  EmbeddingFunc,
58
  always_get_an_event_loop,
59
  compute_mdhash_id,
60
  convert_response_to_json,
 
61
  lazy_external_import,
62
  limit_async_func_call,
63
  get_content_summary,
 
134
  )
135
  """Number of overlapping tokens between consecutive text chunks to preserve context."""
136
 
137
+ tokenizer: Optional[Tokenizer] = field(default=None)
138
+ """
139
+ A Tokenizer instance used to encode and decode text.
140
+ If None, and a `tiktoken_model_name` is provided, a TiktokenTokenizer will be created.
141
+ If both are None, the default TiktokenTokenizer is used.
142
+ """
143
 
144
+ tiktoken_model_name: str = field(default="gpt-4o-mini")
145
+ """Model name used for tokenization when chunking text with tiktoken. Defaults to `gpt-4o-mini`."""
146
 
147
  chunking_func: Callable[
148
  [
149
+ Tokenizer,
150
  str,
151
+ Optional[str],
152
  bool,
153
  int,
154
  int,
 
155
  ],
156
+ List[Dict[str, Any]],
157
  ] = field(default_factory=lambda: chunking_by_token_size)
158
  """
159
  Custom chunking function for splitting text into chunks before processing.
160
 
161
  The function should take the following parameters:
162
 
163
+ - `tokenizer`: A Tokenizer instance to use for tokenization.
164
  - `content`: The text to be split into chunks.
165
  - `split_by_character`: The character to split the text on. If None, the text is split into chunks of `chunk_token_size` tokens.
166
  - `split_by_character_only`: If True, the text is split only on the specified character.
167
  - `chunk_token_size`: The maximum number of tokens per chunk.
168
  - `chunk_overlap_token_size`: The number of overlapping tokens between consecutive chunks.
 
169
 
170
  The function should return a list of dictionaries, where each dictionary contains the following keys:
171
  - `tokens`: The number of tokens in the chunk.
 
327
  _print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()])
328
  logger.debug(f"LightRAG init with param:\n {_print_config}\n")
329
 
330
+ # Init Tokenizer
331
+ # Post-initialization hook to handle backward-compatible tokenizer initialization based on provided parameters
332
+ if self.tokenizer is None:
333
+ if self.tiktoken_model_name:
334
+ self.tokenizer = TiktokenTokenizer(self.tiktoken_model_name)
335
+ else:
336
+ self.tokenizer = TiktokenTokenizer()
337
+
338
+ # Init Embedding
339
  self.embedding_func = limit_async_func_call(self.embedding_func_max_async)( # type: ignore
340
  self.embedding_func
341
  )
 
628
  inserting_chunks: dict[str, Any] = {}
629
  for index, chunk_text in enumerate(text_chunks):
630
  chunk_key = compute_mdhash_id(chunk_text, prefix="chunk-")
631
+ tokens = len(self.tokenizer.encode(chunk_text))
 
 
 
 
632
  inserting_chunks[chunk_key] = {
633
  "content": chunk_text,
634
  "full_doc_id": doc_key,
 
921
  "file_path": file_path, # Add file path to each chunk
922
  }
923
  for dp in self.chunking_func(
924
+ self.tokenizer,
925
  status_doc.content,
926
  split_by_character,
927
  split_by_character_only,
928
  self.chunk_overlap_token_size,
929
  self.chunk_token_size,
 
930
  )
931
  }
932
 
 
1154
  for chunk_data in custom_kg.get("chunks", []):
1155
  chunk_content = clean_text(chunk_data["content"])
1156
  source_id = chunk_data["source_id"]
1157
+ tokens = len(self.tokenizer.encode(chunk_content))
 
 
 
 
1158
  chunk_order_index = (
1159
  0
1160
  if "chunk_order_index" not in chunk_data.keys()
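
The `chunking_func` contract in this file now takes a tokenizer as its first argument. As a rough sketch of what a compatible function could look like, the code below splits text into overlapping token windows. `CharTokenizer` and `chunk_by_tokens` are hypothetical names, the character-level tokenizer is a toy, and the `content` and `chunk_order_index` keys are assumptions (only the `tokens` key is visible in the docstring shown here). The character-splitting parameters are omitted for brevity.

```python
from typing import Any, Dict, List


class CharTokenizer:
    """Toy tokenizer for illustration: one token per character."""

    def encode(self, content: str) -> List[int]:
        return [ord(ch) for ch in content]

    def decode(self, tokens: List[int]) -> str:
        return "".join(chr(t) for t in tokens)


def chunk_by_tokens(
    tokenizer: CharTokenizer,
    content: str,
    overlap_token_size: int = 2,
    max_token_size: int = 5,
) -> List[Dict[str, Any]]:
    """Split content into windows of max_token_size tokens that overlap."""
    if overlap_token_size >= max_token_size:
        raise ValueError("overlap_token_size must be smaller than max_token_size")
    tokens = tokenizer.encode(content)
    step = max_token_size - overlap_token_size  # how far each window advances
    results: List[Dict[str, Any]] = []
    for index, start in enumerate(range(0, len(tokens), step)):
        window = tokens[start : start + max_token_size]
        results.append(
            {
                "tokens": len(window),
                "content": tokenizer.decode(window),
                "chunk_order_index": index,  # assumed key name
            }
        )
    return results


chunks = chunk_by_tokens(CharTokenizer(), "abcdefghij")
for chunk in chunks:
    print(chunk)
```

Because the function only calls `encode` and `decode`, swapping the toy tokenizer for a Tiktoken or SentencePiece wrapper changes the chunk boundaries but not the chunking logic.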
lightrag/operate.py CHANGED
@@ -12,8 +12,7 @@ from .utils import (
12
  logger,
13
  clean_str,
14
  compute_mdhash_id,
15
- decode_tokens_by_tiktoken,
16
- encode_string_by_tiktoken,
17
  is_float_regex,
18
  list_of_list_to_csv,
19
  normalize_extracted_info,
@@ -46,32 +45,31 @@ load_dotenv(dotenv_path=".env", override=False)
46
 
47
 
48
  def chunking_by_token_size(
 
49
  content: str,
50
  split_by_character: str | None = None,
51
  split_by_character_only: bool = False,
52
  overlap_token_size: int = 128,
53
  max_token_size: int = 1024,
54
- tiktoken_model: str = "gpt-4o",
55
  ) -> list[dict[str, Any]]:
56
- tokens = encode_string_by_tiktoken(content, model_name=tiktoken_model)
57
  results: list[dict[str, Any]] = []
58
  if split_by_character:
59
  raw_chunks = content.split(split_by_character)
60
  new_chunks = []
61
  if split_by_character_only:
62
  for chunk in raw_chunks:
63
- _tokens = encode_string_by_tiktoken(chunk, model_name=tiktoken_model)
64
  new_chunks.append((len(_tokens), chunk))
65
  else:
66
  for chunk in raw_chunks:
67
- _tokens = encode_string_by_tiktoken(chunk, model_name=tiktoken_model)
68
  if len(_tokens) > max_token_size:
69
  for start in range(
70
  0, len(_tokens), max_token_size - overlap_token_size
71
  ):
72
- chunk_content = decode_tokens_by_tiktoken(
73
- _tokens[start : start + max_token_size],
74
- model_name=tiktoken_model,
75
  )
76
  new_chunks.append(
77
  (min(max_token_size, len(_tokens) - start), chunk_content)
@@ -90,9 +88,7 @@ def chunking_by_token_size(
90
  for index, start in enumerate(
91
  range(0, len(tokens), max_token_size - overlap_token_size)
92
  ):
93
- chunk_content = decode_tokens_by_tiktoken(
94
- tokens[start : start + max_token_size], model_name=tiktoken_model
95
- )
96
  results.append(
97
  {
98
  "tokens": min(max_token_size, len(tokens) - start),
@@ -116,19 +112,19 @@ async def _handle_entity_relation_summary(
116
  If too long, use LLM to summarize.
117
  """
118
  use_llm_func: callable = global_config["llm_model_func"]
 
119
  llm_max_tokens = global_config["llm_model_max_token_size"]
120
- tiktoken_model_name = global_config["tiktoken_model_name"]
121
  summary_max_tokens = global_config["summary_to_max_tokens"]
122
 
123
  language = global_config["addon_params"].get(
124
  "language", PROMPTS["DEFAULT_LANGUAGE"]
125
  )
126
 
127
- tokens = encode_string_by_tiktoken(description, model_name=tiktoken_model_name)
 
 
128
  prompt_template = PROMPTS["summarize_entity_descriptions"]
129
- use_description = decode_tokens_by_tiktoken(
130
- tokens[:llm_max_tokens], model_name=tiktoken_model_name
131
- )
132
  context_base = dict(
133
  entity_name=entity_or_relation_name,
134
  description_list=use_description.split(GRAPH_FIELD_SEP),
@@ -865,7 +861,8 @@ async def kg_query(
865
  if query_param.only_need_prompt:
866
  return sys_prompt
867
 
868
- len_of_prompts = len(encode_string_by_tiktoken(query + sys_prompt))
 
869
  logger.debug(f"[kg_query]Prompt Tokens: {len_of_prompts}")
870
 
871
  response = await use_model_func(
@@ -987,7 +984,8 @@ async def extract_keywords_only(
987
  query=text, examples=examples, language=language, history=history_context
988
  )
989
 
990
- len_of_prompts = len(encode_string_by_tiktoken(kw_prompt))
 
991
  logger.debug(f"[kg_query]Prompt Tokens: {len_of_prompts}")
992
 
993
  # 5. Call the LLM for keyword extraction
@@ -1054,6 +1052,8 @@ async def mix_kg_vector_query(
1054
  2. Retrieving relevant text chunks through vector similarity
1055
  3. Combining both results for comprehensive answer generation
1056
  """
 
 
1057
  # 1. Cache handling
1058
  use_model_func = (
1059
  query_param.model_func
@@ -1153,6 +1153,7 @@ async def mix_kg_vector_query(
1153
  valid_chunks,
1154
  key=lambda x: x["content"],
1155
  max_token_size=query_param.max_token_for_text_unit,
 
1156
  )
1157
 
1158
  if not maybe_trun_chunks:
@@ -1210,7 +1211,7 @@ async def mix_kg_vector_query(
1210
  if query_param.only_need_prompt:
1211
  return sys_prompt
1212
 
1213
- len_of_prompts = len(encode_string_by_tiktoken(query + sys_prompt))
1214
  logger.debug(f"[mix_kg_vector_query]Prompt Tokens: {len_of_prompts}")
1215
 
1216
  # 6. Generate response
@@ -1373,17 +1374,24 @@ async def _get_node_data(
1373
  ] # what is this text_chunks_db doing. dont remember it in airvx. check the diagram.
1374
  # get entitytext chunk
1375
  use_text_units = await _find_most_related_text_unit_from_entities(
1376
- node_datas, query_param, text_chunks_db, knowledge_graph_inst
 
 
 
1377
  )
1378
  use_relations = await _find_most_related_edges_from_entities(
1379
- node_datas, query_param, knowledge_graph_inst
 
 
1380
  )
1381
 
 
1382
  len_node_datas = len(node_datas)
1383
  node_datas = truncate_list_by_token_size(
1384
  node_datas,
1385
  key=lambda x: x["description"] if x["description"] is not None else "",
1386
  max_token_size=query_param.max_token_for_local_context,
 
1387
  )
1388
  logger.debug(
1389
  f"Truncate entities from {len_node_datas} to {len(node_datas)} (max tokens:{query_param.max_token_for_local_context})"
@@ -1558,14 +1566,15 @@ async def _find_most_related_text_unit_from_entities(
1558
  logger.warning("No valid text units found")
1559
  return []
1560
 
 
1561
  all_text_units = sorted(
1562
  all_text_units, key=lambda x: (x["order"], -x["relation_counts"])
1563
  )
1564
-
1565
  all_text_units = truncate_list_by_token_size(
1566
  all_text_units,
1567
  key=lambda x: x["data"]["content"],
1568
  max_token_size=query_param.max_token_for_text_unit,
 
1569
  )
1570
 
1571
  logger.debug(
@@ -1619,6 +1628,7 @@ async def _find_most_related_edges_from_entities(
1619
  }
1620
  all_edges_data.append(combined)
1621
 
 
1622
  all_edges_data = sorted(
1623
  all_edges_data, key=lambda x: (x["rank"], x["weight"]), reverse=True
1624
  )
@@ -1626,6 +1636,7 @@ async def _find_most_related_edges_from_entities(
1626
  all_edges_data,
1627
  key=lambda x: x["description"] if x["description"] is not None else "",
1628
  max_token_size=query_param.max_token_for_global_context,
 
1629
  )
1630
 
1631
  logger.debug(
@@ -1681,6 +1692,7 @@ async def _get_edge_data(
1681
  }
1682
  edge_datas.append(combined)
1683
 
 
1684
  edge_datas = sorted(
1685
  edge_datas, key=lambda x: (x["rank"], x["weight"]), reverse=True
1686
  )
@@ -1688,13 +1700,19 @@ async def _get_edge_data(
1688
  edge_datas,
1689
  key=lambda x: x["description"] if x["description"] is not None else "",
1690
  max_token_size=query_param.max_token_for_global_context,
 
1691
  )
1692
  use_entities, use_text_units = await asyncio.gather(
1693
  _find_most_related_entities_from_relationships(
1694
- edge_datas, query_param, knowledge_graph_inst
 
 
1695
  ),
1696
  _find_related_text_unit_from_relationships(
1697
- edge_datas, query_param, text_chunks_db, knowledge_graph_inst
 
 
 
1698
  ),
1699
  )
1700
  logger.info(
@@ -1804,11 +1822,13 @@ async def _find_most_related_entities_from_relationships(
1804
  combined = {**node, "entity_name": entity_name, "rank": degree}
1805
  node_datas.append(combined)
1806
 
 
1807
  len_node_datas = len(node_datas)
1808
  node_datas = truncate_list_by_token_size(
1809
  node_datas,
1810
  key=lambda x: x["description"] if x["description"] is not None else "",
1811
  max_token_size=query_param.max_token_for_local_context,
 
1812
  )
1813
  logger.debug(
1814
  f"Truncate entities from {len_node_datas} to {len(node_datas)} (max tokens:{query_param.max_token_for_local_context})"
@@ -1863,10 +1883,12 @@ async def _find_related_text_unit_from_relationships(
1863
  logger.warning("No valid text chunks after filtering")
1864
  return []
1865
 
 
1866
  truncated_text_units = truncate_list_by_token_size(
1867
  valid_text_units,
1868
  key=lambda x: x["data"]["content"],
1869
  max_token_size=query_param.max_token_for_text_unit,
 
1870
  )
1871
 
1872
  logger.debug(
@@ -1937,10 +1959,12 @@ async def naive_query(
1937
  logger.warning("No valid chunks found after filtering")
1938
  return PROMPTS["fail_response"]
1939
 
 
1940
  maybe_trun_chunks = truncate_list_by_token_size(
1941
  valid_chunks,
1942
  key=lambda x: x["content"],
1943
  max_token_size=query_param.max_token_for_text_unit,
 
1944
  )
1945
 
1946
  if not maybe_trun_chunks:
@@ -1978,7 +2002,7 @@ async def naive_query(
1978
  if query_param.only_need_prompt:
1979
  return sys_prompt
1980
 
1981
- len_of_prompts = len(encode_string_by_tiktoken(query + sys_prompt))
1982
  logger.debug(f"[naive_query]Prompt Tokens: {len_of_prompts}")
1983
 
1984
  response = await use_model_func(
@@ -2125,7 +2149,8 @@ async def kg_query_with_keywords(
2125
  if query_param.only_need_prompt:
2126
  return sys_prompt
2127
 
2128
- len_of_prompts = len(encode_string_by_tiktoken(query + sys_prompt))
 
2129
  logger.debug(f"[kg_query_with_keywords]Prompt Tokens: {len_of_prompts}")
2130
 
2131
  # 6. Generate response
 
12
  logger,
13
  clean_str,
14
  compute_mdhash_id,
15
+ Tokenizer,
 
16
  is_float_regex,
17
  list_of_list_to_csv,
18
  normalize_extracted_info,
 
45
 
46
 
47
  def chunking_by_token_size(
48
+ tokenizer: Tokenizer,
49
  content: str,
50
  split_by_character: str | None = None,
51
  split_by_character_only: bool = False,
52
  overlap_token_size: int = 128,
53
  max_token_size: int = 1024,
 
54
  ) -> list[dict[str, Any]]:
55
+ tokens = tokenizer.encode(content)
56
  results: list[dict[str, Any]] = []
57
  if split_by_character:
58
  raw_chunks = content.split(split_by_character)
59
  new_chunks = []
60
  if split_by_character_only:
61
  for chunk in raw_chunks:
62
+ _tokens = tokenizer.encode(chunk)
63
  new_chunks.append((len(_tokens), chunk))
64
  else:
65
  for chunk in raw_chunks:
66
+ _tokens = tokenizer.encode(chunk)
67
  if len(_tokens) > max_token_size:
68
  for start in range(
69
  0, len(_tokens), max_token_size - overlap_token_size
70
  ):
71
+ chunk_content = tokenizer.decode(
72
+ _tokens[start : start + max_token_size]
 
73
  )
74
  new_chunks.append(
75
  (min(max_token_size, len(_tokens) - start), chunk_content)
 
88
  for index, start in enumerate(
89
  range(0, len(tokens), max_token_size - overlap_token_size)
90
  ):
91
+ chunk_content = tokenizer.decode(tokens[start : start + max_token_size])
 
 
92
  results.append(
93
  {
94
  "tokens": min(max_token_size, len(tokens) - start),
 
      If too long, use LLM to summarize.
      """
      use_llm_func: callable = global_config["llm_model_func"]
+     tokenizer: Tokenizer = global_config["tokenizer"]
      llm_max_tokens = global_config["llm_model_max_token_size"]
      summary_max_tokens = global_config["summary_to_max_tokens"]

      language = global_config["addon_params"].get(
          "language", PROMPTS["DEFAULT_LANGUAGE"]
      )

+     tokens = tokenizer.encode(description)
+     if len(tokens) < summary_max_tokens:  # No need for summary
+         return description
      prompt_template = PROMPTS["summarize_entity_descriptions"]
+     use_description = tokenizer.decode(tokens[:llm_max_tokens])
      context_base = dict(
          entity_name=entity_or_relation_name,
          description_list=use_description.split(GRAPH_FIELD_SEP),
 
      if query_param.only_need_prompt:
          return sys_prompt

+     tokenizer: Tokenizer = global_config["tokenizer"]
+     len_of_prompts = len(tokenizer.encode(query + sys_prompt))
      logger.debug(f"[kg_query]Prompt Tokens: {len_of_prompts}")

      response = await use_model_func(
 
          query=text, examples=examples, language=language, history=history_context
      )

+     tokenizer: Tokenizer = global_config["tokenizer"]
+     len_of_prompts = len(tokenizer.encode(kw_prompt))
      logger.debug(f"[kg_query]Prompt Tokens: {len_of_prompts}")

      # 5. Call the LLM for keyword extraction
 
      2. Retrieving relevant text chunks through vector similarity
      3. Combining both results for comprehensive answer generation
      """
+     # get tokenizer
+     tokenizer: Tokenizer = global_config["tokenizer"]
      # 1. Cache handling
      use_model_func = (
          query_param.model_func
 
                  valid_chunks,
                  key=lambda x: x["content"],
                  max_token_size=query_param.max_token_for_text_unit,
+                 tokenizer=tokenizer,
              )

              if not maybe_trun_chunks:
 
      if query_param.only_need_prompt:
          return sys_prompt

+     len_of_prompts = len(tokenizer.encode(query + sys_prompt))
      logger.debug(f"[mix_kg_vector_query]Prompt Tokens: {len_of_prompts}")

      # 6. Generate response
 
      ]  # what is this text_chunks_db doing. dont remember it in airvx. check the diagram.
      # get entitytext chunk
      use_text_units = await _find_most_related_text_unit_from_entities(
+         node_datas,
+         query_param,
+         text_chunks_db,
+         knowledge_graph_inst,
      )
      use_relations = await _find_most_related_edges_from_entities(
+         node_datas,
+         query_param,
+         knowledge_graph_inst,
      )

+     tokenizer: Tokenizer = text_chunks_db.global_config.get("tokenizer")
      len_node_datas = len(node_datas)
      node_datas = truncate_list_by_token_size(
          node_datas,
          key=lambda x: x["description"] if x["description"] is not None else "",
          max_token_size=query_param.max_token_for_local_context,
+         tokenizer=tokenizer,
      )
      logger.debug(
          f"Truncate entities from {len_node_datas} to {len(node_datas)} (max tokens:{query_param.max_token_for_local_context})"
 
          logger.warning("No valid text units found")
          return []

+     tokenizer: Tokenizer = text_chunks_db.global_config.get("tokenizer")
      all_text_units = sorted(
          all_text_units, key=lambda x: (x["order"], -x["relation_counts"])
      )
      all_text_units = truncate_list_by_token_size(
          all_text_units,
          key=lambda x: x["data"]["content"],
          max_token_size=query_param.max_token_for_text_unit,
+         tokenizer=tokenizer,
      )

      logger.debug(
 
              }
              all_edges_data.append(combined)

+     tokenizer: Tokenizer = knowledge_graph_inst.global_config.get("tokenizer")
      all_edges_data = sorted(
          all_edges_data, key=lambda x: (x["rank"], x["weight"]), reverse=True
      )
 
          all_edges_data,
          key=lambda x: x["description"] if x["description"] is not None else "",
          max_token_size=query_param.max_token_for_global_context,
+         tokenizer=tokenizer,
      )

      logger.debug(
 
              }
              edge_datas.append(combined)

+     tokenizer: Tokenizer = text_chunks_db.global_config.get("tokenizer")
      edge_datas = sorted(
          edge_datas, key=lambda x: (x["rank"], x["weight"]), reverse=True
      )
 
          edge_datas,
          key=lambda x: x["description"] if x["description"] is not None else "",
          max_token_size=query_param.max_token_for_global_context,
+         tokenizer=tokenizer,
      )
      use_entities, use_text_units = await asyncio.gather(
          _find_most_related_entities_from_relationships(
+             edge_datas,
+             query_param,
+             knowledge_graph_inst,
          ),
          _find_related_text_unit_from_relationships(
+             edge_datas,
+             query_param,
+             text_chunks_db,
+             knowledge_graph_inst,
          ),
      )
      logger.info(
 
          combined = {**node, "entity_name": entity_name, "rank": degree}
          node_datas.append(combined)

+     tokenizer: Tokenizer = knowledge_graph_inst.global_config.get("tokenizer")
      len_node_datas = len(node_datas)
      node_datas = truncate_list_by_token_size(
          node_datas,
          key=lambda x: x["description"] if x["description"] is not None else "",
          max_token_size=query_param.max_token_for_local_context,
+         tokenizer=tokenizer,
      )
      logger.debug(
          f"Truncate entities from {len_node_datas} to {len(node_datas)} (max tokens:{query_param.max_token_for_local_context})"
 
          logger.warning("No valid text chunks after filtering")
          return []

+     tokenizer: Tokenizer = text_chunks_db.global_config.get("tokenizer")
      truncated_text_units = truncate_list_by_token_size(
          valid_text_units,
          key=lambda x: x["data"]["content"],
          max_token_size=query_param.max_token_for_text_unit,
+         tokenizer=tokenizer,
      )

      logger.debug(
 
          logger.warning("No valid chunks found after filtering")
          return PROMPTS["fail_response"]

+     tokenizer: Tokenizer = global_config["tokenizer"]
      maybe_trun_chunks = truncate_list_by_token_size(
          valid_chunks,
          key=lambda x: x["content"],
          max_token_size=query_param.max_token_for_text_unit,
+         tokenizer=tokenizer,
      )

      if not maybe_trun_chunks:
 
      if query_param.only_need_prompt:
          return sys_prompt

+     len_of_prompts = len(tokenizer.encode(query + sys_prompt))
      logger.debug(f"[naive_query]Prompt Tokens: {len_of_prompts}")

      response = await use_model_func(
 
      if query_param.only_need_prompt:
          return sys_prompt

+     tokenizer: Tokenizer = global_config["tokenizer"]
+     len_of_prompts = len(tokenizer.encode(query + sys_prompt))
      logger.debug(f"[kg_query_with_keywords]Prompt Tokens: {len_of_prompts}")

      # 6. Generate response
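
The chunking hunk in this file advances a window of `max_token_size` tokens by `max_token_size - overlap_token_size` on each step, so consecutive chunks share `overlap_token_size` tokens. The following is a minimal sketch of that windowing under stated assumptions: `WhitespaceTokenizer` and `sliding_chunks` are hypothetical names introduced for illustration (they are not part of LightRAG), the `index` key is a stand-in, and the explicit overlap check is an addition of this sketch rather than something the diff shows.

```python
class WhitespaceTokenizer:
    """Toy tokenizer (illustration only): one integer id per whitespace-separated word."""

    def __init__(self) -> None:
        self._ids: dict[str, int] = {}
        self._words: list[str] = []

    def encode(self, content: str) -> list[int]:
        ids = []
        for word in content.split():
            if word not in self._ids:
                self._ids[word] = len(self._words)
                self._words.append(word)
            ids.append(self._ids[word])
        return ids

    def decode(self, tokens: list[int]) -> str:
        return " ".join(self._words[t] for t in tokens)


def sliding_chunks(tokenizer, content, max_token_size=4, overlap_token_size=1):
    """Split content into overlapping token windows (hypothetical helper)."""
    # Added guard (not in the diff): a zero or negative step would not make progress.
    if overlap_token_size >= max_token_size:
        raise ValueError("overlap_token_size must be smaller than max_token_size")
    tokens = tokenizer.encode(content)
    step = max_token_size - overlap_token_size
    results = []
    for index, start in enumerate(range(0, len(tokens), step)):
        window = tokens[start : start + max_token_size]
        results.append(
            {
                # len(window) equals min(max_token_size, len(tokens) - start)
                "tokens": len(window),
                "content": tokenizer.decode(window),
                "index": index,
            }
        )
    return results


chunks = sliding_chunks(WhitespaceTokenizer(), "a b c d e f g h i j")
for chunk in chunks:
    print(chunk)
```

With ten tokens, a window of four and an overlap of one, the windows start at offsets 0, 3, 6 and 9, so the last chunk holds a single token. Any object exposing `encode` and `decode` can be swapped in for the toy tokenizer.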
lightrag/utils.py CHANGED
@@ -12,10 +12,9 @@ import re
  from dataclasses import dataclass
  from functools import wraps
  from hashlib import md5
- from typing import Any, Callable, TYPE_CHECKING
  import xml.etree.ElementTree as ET
  import numpy as np
- import tiktoken
  from lightrag.prompt import PROMPTS
  from dotenv import load_dotenv

@@ -193,9 +192,6 @@ class UnlimitedSemaphore:
          pass


- ENCODER = None
-
-
  @dataclass
  class EmbeddingFunc:
      embedding_dim: int
@@ -311,20 +307,89 @@ def write_json(json_obj, file_name):
          json.dump(json_obj, f, indent=2, ensure_ascii=False)


- def encode_string_by_tiktoken(content: str, model_name: str = "gpt-4o"):
-     global ENCODER
-     if ENCODER is None:
-         ENCODER = tiktoken.encoding_for_model(model_name)
-     tokens = ENCODER.encode(content)
-     return tokens


- def decode_tokens_by_tiktoken(tokens: list[int], model_name: str = "gpt-4o"):
-     global ENCODER
-     if ENCODER is None:
-         ENCODER = tiktoken.encoding_for_model(model_name)
-     content = ENCODER.decode(tokens)
-     return content


  def pack_user_ass_to_openai_messages(*args: str):
@@ -361,14 +426,17 @@ def is_float_regex(value: str) -> bool:
  def truncate_list_by_token_size(
-     list_data: list[Any], key: Callable[[Any], str], max_token_size: int
  ) -> list[int]:
      """Truncate a list of data by token size"""
      if max_token_size <= 0:
          return []
      tokens = 0
      for i, data in enumerate(list_data):
-         tokens += len(encode_string_by_tiktoken(key(data)))
          if tokens > max_token_size:
              return list_data[:i]
      return list_data
 
  from dataclasses import dataclass
  from functools import wraps
  from hashlib import md5
+ from typing import Any, Protocol, Callable, TYPE_CHECKING, List
  import xml.etree.ElementTree as ET
  import numpy as np
  from lightrag.prompt import PROMPTS
  from dotenv import load_dotenv

 
          pass


  @dataclass
  class EmbeddingFunc:
      embedding_dim: int
 
          json.dump(json_obj, f, indent=2, ensure_ascii=False)


+ class TokenizerInterface(Protocol):
+     """
+     Defines the interface for a tokenizer, requiring encode and decode methods.
+     """
+
+     def encode(self, content: str) -> List[int]:
+         """Encodes a string into a list of tokens."""
+         ...
+
+     def decode(self, tokens: List[int]) -> str:
+         """Decodes a list of tokens into a string."""
+         ...
+
+
+ class Tokenizer:
+     """
+     A wrapper around a tokenizer to provide a consistent interface for encoding and decoding.
+     """
+
+     def __init__(self, model_name: str, tokenizer: TokenizerInterface):
+         """
+         Initializes the Tokenizer with a tokenizer model name and a tokenizer instance.
+
+         Args:
+             model_name: The associated model name for the tokenizer.
+             tokenizer: An instance of a class implementing the TokenizerInterface.
+         """
+         self.model_name: str = model_name
+         self.tokenizer: TokenizerInterface = tokenizer
+
+     def encode(self, content: str) -> List[int]:
+         """
+         Encodes a string into a list of tokens using the underlying tokenizer.
+
+         Args:
+             content: The string to encode.

+         Returns:
+             A list of integer tokens.
+         """
+         return self.tokenizer.encode(content)

+     def decode(self, tokens: List[int]) -> str:
+         """
+         Decodes a list of tokens into a string using the underlying tokenizer.
+
+         Args:
+             tokens: A list of integer tokens to decode.
+
+         Returns:
+             The decoded string.
+         """
+         return self.tokenizer.decode(tokens)
+
+
+ class TiktokenTokenizer(Tokenizer):
+     """
+     A Tokenizer implementation using the tiktoken library.
+     """
+
+     def __init__(self, model_name: str = "gpt-4o-mini"):
+         """
+         Initializes the TiktokenTokenizer with a specified model name.
+
+         Args:
+             model_name: The model name for the tiktoken tokenizer to use. Defaults to "gpt-4o-mini".
+
+         Raises:
+             ImportError: If tiktoken is not installed.
+             ValueError: If the model_name is invalid.
+         """
+         try:
+             import tiktoken
+         except ImportError:
+             raise ImportError(
+                 "tiktoken is not installed. Please install it with `pip install tiktoken` or define custom `tokenizer_func`."
+             )
+
+         try:
+             tokenizer = tiktoken.encoding_for_model(model_name)
+             super().__init__(model_name=model_name, tokenizer=tokenizer)
+         except KeyError:
+             raise ValueError(f"Invalid model_name: {model_name}.")


  def pack_user_ass_to_openai_messages(*args: str):
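
Because `TokenizerInterface` is a `Protocol`, any object with matching `encode` and `decode` methods can back a `Tokenizer` without inheriting from anything. The sketch below illustrates this with a hypothetical `Utf8ByteBackend` (one token per UTF-8 byte); the two classes copied from the diff are condensed here only to keep the example self-contained, and the `"utf8-bytes"` model name is an arbitrary label.

```python
from typing import List, Protocol


class TokenizerInterface(Protocol):
    """Condensed copy of the protocol added in this commit."""

    def encode(self, content: str) -> List[int]: ...

    def decode(self, tokens: List[int]) -> str: ...


class Tokenizer:
    """Condensed copy of the wrapper added in this commit."""

    def __init__(self, model_name: str, tokenizer: TokenizerInterface):
        self.model_name = model_name
        self.tokenizer = tokenizer

    def encode(self, content: str) -> List[int]:
        return self.tokenizer.encode(content)

    def decode(self, tokens: List[int]) -> str:
        return self.tokenizer.decode(tokens)


class Utf8ByteBackend:
    """Hypothetical backend for illustration: one token per UTF-8 byte."""

    def encode(self, content: str) -> List[int]:
        return list(content.encode("utf-8"))

    def decode(self, tokens: List[int]) -> str:
        # errors="replace" keeps decoding from failing when a slice of
        # tokens cuts a multi-byte character in half.
        return bytes(tokens).decode("utf-8", errors="replace")


byte_tokenizer = Tokenizer(model_name="utf8-bytes", tokenizer=Utf8ByteBackend())
ids = byte_tokenizer.encode("héllo")
print(len(ids), byte_tokenizer.decode(ids))
print(repr(byte_tokenizer.decode(ids[:2])))
```

The last line slices the token list the same way the summary code does with `tokens[:llm_max_tokens]`. With a byte-level backend the cut can land inside a character, which is why this sketch decodes leniently; whether a given real tokenizer needs similar care depends on its own `decode` behavior.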
 
  def truncate_list_by_token_size(
+     list_data: list[Any],
+     key: Callable[[Any], str],
+     max_token_size: int,
+     tokenizer: Tokenizer,
  ) -> list[int]:
      """Truncate a list of data by token size"""
      if max_token_size <= 0:
          return []
      tokens = 0
      for i, data in enumerate(list_data):
+         tokens += len(tokenizer.encode(key(data)))
          if tokens > max_token_size:
              return list_data[:i]
      return list_data
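
`truncate_list_by_token_size` keeps the longest prefix of `list_data` whose cumulative token count stays within `max_token_size`; it stops at the first item that overflows the budget and does not go on to look for smaller items later in the list. A small sketch of that behavior follows. The function body repeats the logic from the diff, while `CharTokenizer` (one token per character) and the sample data are hypothetical. The return annotation is written as `list[Any]` in this sketch because the function returns a slice of its input, whereas the diff keeps `list[int]`.

```python
from typing import Any, Callable


class CharTokenizer:
    """Hypothetical tokenizer for illustration: one token per character."""

    def encode(self, content: str) -> list[int]:
        return [ord(ch) for ch in content]

    def decode(self, tokens: list[int]) -> str:
        return "".join(chr(t) for t in tokens)


def truncate_list_by_token_size(
    list_data: list[Any],
    key: Callable[[Any], str],
    max_token_size: int,
    tokenizer: CharTokenizer,
) -> list[Any]:
    """Same logic as in the diff: keep the prefix that fits the token budget."""
    if max_token_size <= 0:
        return []
    tokens = 0
    for i, data in enumerate(list_data):
        tokens += len(tokenizer.encode(key(data)))
        if tokens > max_token_size:
            return list_data[:i]
    return list_data


items = [{"description": "aaaa"}, {"description": "bbbb"}, {"description": "c"}]
kept = truncate_list_by_token_size(
    items,
    key=lambda x: x["description"],
    max_token_size=6,
    tokenizer=CharTokenizer(),
)
print(kept)
```

With a budget of six tokens only the first item survives: the second item pushes the running total to eight, and the one-token third item is never considered even though it would fit next to the first. Callers that sort their input first, as the query functions in `operate.py` do, therefore keep the highest-ranked entries.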