choizhang committed on
Commit 417da19 · 1 Parent(s): 0943277

feat: Add TokenTracker to track token usage for LLM calls

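The new TokenTracker (added to lightrag/utils.py) accumulates prompt, completion, and total token counts across LLM calls; openai_complete_if_cache gains an optional token_tracker argument, and the two example scripts below show manual tracking (Gemini) and automatic tracking through that argument (SiliconCloud). A minimal sketch of the pattern, with illustrative numbers:

from lightrag.utils import TokenTracker

token_tracker = TokenTracker()
token_tracker.reset()  # clear counters before a run

# Record one LLM call; all keys are optional, total is inferred if omitted
token_tracker.add_usage({"prompt_tokens": 120, "completion_tokens": 45, "total_tokens": 165})

print("Token usage:", token_tracker.get_usage())
# -> {'prompt_tokens': 120, 'completion_tokens': 45, 'total_tokens': 165, 'call_count': 1}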
examples/lightrag_gemini_track_token_demo.py ADDED
@@ -0,0 +1,153 @@
+ # pip install -q -U google-genai to use gemini as a client
+
+ import os
+ import asyncio
+ import numpy as np
+ import nest_asyncio
+ from google import genai
+ from google.genai import types
+ from dotenv import load_dotenv
+ from lightrag.utils import EmbeddingFunc
+ from lightrag import LightRAG, QueryParam
+ from lightrag.kg.shared_storage import initialize_pipeline_status
+ from lightrag.llm.siliconcloud import siliconcloud_embedding
+ from lightrag.utils import setup_logger
+ from lightrag.utils import TokenTracker
+
+ setup_logger("lightrag", level="DEBUG")
+
+ # Apply nest_asyncio to solve event loop issues
+ nest_asyncio.apply()
+
+ load_dotenv()
+ gemini_api_key = os.getenv("GEMINI_API_KEY")
+ siliconflow_api_key = os.getenv("SILICONFLOW_API_KEY")
+
+ WORKING_DIR = "./dickens"
+
+ if not os.path.exists(WORKING_DIR):
+     os.mkdir(WORKING_DIR)
+
+ token_tracker = TokenTracker()
+
+
+ async def llm_model_func(
+     prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
+ ) -> str:
+     # 1. Initialize the GenAI Client with your Gemini API Key
+     client = genai.Client(api_key=gemini_api_key)
+
+     # 2. Combine prompts: system prompt, history, and user prompt
+     if history_messages is None:
+         history_messages = []
+
+     combined_prompt = ""
+     if system_prompt:
+         combined_prompt += f"{system_prompt}\n"
+
+     for msg in history_messages:
+         # Each msg is expected to be a dict: {"role": "...", "content": "..."}
+         combined_prompt += f"{msg['role']}: {msg['content']}\n"
+
+     # Finally, add the new user prompt
+     combined_prompt += f"user: {prompt}"
+
+     # 3. Call the Gemini model
+     response = client.models.generate_content(
+         model="gemini-2.0-flash",
+         contents=[combined_prompt],
+         config=types.GenerateContentConfig(
+             max_output_tokens=5000, temperature=0, top_k=10
+         ),
+     )
+
+     # 4. Get token counts with null safety
+     usage = getattr(response, "usage_metadata", None)
+     prompt_tokens = getattr(usage, "prompt_token_count", 0) or 0
+     completion_tokens = getattr(usage, "candidates_token_count", 0) or 0
+     total_tokens = getattr(usage, "total_token_count", 0) or (
+         prompt_tokens + completion_tokens
+     )
+
+     token_counts = {
+         "prompt_tokens": prompt_tokens,
+         "completion_tokens": completion_tokens,
+         "total_tokens": total_tokens,
+     }
+
+     token_tracker.add_usage(token_counts)
+
+     # 5. Return the response text
+     return response.text
+
+
+ async def embedding_func(texts: list[str]) -> np.ndarray:
+     return await siliconcloud_embedding(
+         texts,
+         model="BAAI/bge-m3",
+         api_key=siliconflow_api_key,
+         max_token_size=512,
+     )
+
+
+ async def initialize_rag():
+     rag = LightRAG(
+         working_dir=WORKING_DIR,
+         entity_extract_max_gleaning=1,
+         enable_llm_cache=True,
+         enable_llm_cache_for_entity_extract=True,
+         embedding_cache_config={"enabled": True, "similarity_threshold": 0.90},
+         llm_model_func=llm_model_func,
+         embedding_func=EmbeddingFunc(
+             embedding_dim=1024,
+             max_token_size=8192,
+             func=embedding_func,
+         ),
+     )
+
+     await rag.initialize_storages()
+     await initialize_pipeline_status()
+
+     return rag
+
+
+ def main():
+     # Initialize RAG instance
+     rag = asyncio.run(initialize_rag())
+
+     # Reset tracker before processing queries
+     token_tracker.reset()
+
+     with open("./book.txt", "r", encoding="utf-8") as f:
+         rag.insert(f.read())
+
+     print(
+         rag.query(
+             "What are the top themes in this story?", param=QueryParam(mode="naive")
+         )
+     )
+
+     print(
+         rag.query(
+             "What are the top themes in this story?", param=QueryParam(mode="local")
+         )
+     )
+
+     print(
+         rag.query(
+             "What are the top themes in this story?", param=QueryParam(mode="global")
+         )
+     )
+
+     print(
+         rag.query(
+             "What are the top themes in this story?", param=QueryParam(mode="hybrid")
+         )
+     )
+
+     # Display final token usage after main query
+     print("Token usage:", token_tracker.get_usage())
+
+
+ if __name__ == "__main__":
+     main()
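When the script finishes, token_tracker.get_usage() reports the counters accumulated over the insert and the four queries; the printed line has the following shape (the numbers here are illustrative and depend on the document and model):

Token usage: {'prompt_tokens': 52310, 'completion_tokens': 8714, 'total_tokens': 61024, 'call_count': 37}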
examples/lightrag_siliconcloud_track_token_demo.py ADDED
@@ -0,0 +1,114 @@
+ import os
+ import asyncio
+ from lightrag import LightRAG, QueryParam
+ from lightrag.llm.openai import openai_complete_if_cache
+ from lightrag.llm.siliconcloud import siliconcloud_embedding
+ from lightrag.utils import EmbeddingFunc
+ from lightrag.utils import TokenTracker
+ import numpy as np
+ from lightrag.kg.shared_storage import initialize_pipeline_status
+ from dotenv import load_dotenv
+
+ load_dotenv()
+
+ token_tracker = TokenTracker()
+ WORKING_DIR = "./dickens"
+
+ if not os.path.exists(WORKING_DIR):
+     os.mkdir(WORKING_DIR)
+
+
+ async def llm_model_func(
+     prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
+ ) -> str:
+     return await openai_complete_if_cache(
+         "Qwen/Qwen2.5-7B-Instruct",
+         prompt,
+         system_prompt=system_prompt,
+         history_messages=history_messages,
+         api_key=os.getenv("SILICONFLOW_API_KEY"),
+         base_url="https://api.siliconflow.cn/v1/",
+         token_tracker=token_tracker,
+         **kwargs,
+     )
+
+
+ async def embedding_func(texts: list[str]) -> np.ndarray:
+     return await siliconcloud_embedding(
+         texts,
+         model="BAAI/bge-m3",
+         api_key=os.getenv("SILICONFLOW_API_KEY"),
+         max_token_size=512,
+     )
+
+
+ # function test
+ async def test_funcs():
+     # Reset tracker before processing queries
+     token_tracker.reset()
+
+     result = await llm_model_func("How are you?")
+     print("llm_model_func: ", result)
+
+     # Display final token usage after main query
+     print("Token usage:", token_tracker.get_usage())
+
+
+ asyncio.run(test_funcs())
+
+
+ async def initialize_rag():
+     rag = LightRAG(
+         working_dir=WORKING_DIR,
+         llm_model_func=llm_model_func,
+         embedding_func=EmbeddingFunc(
+             embedding_dim=1024, max_token_size=512, func=embedding_func
+         ),
+     )
+
+     await rag.initialize_storages()
+     await initialize_pipeline_status()
+
+     return rag
+
+
+ def main():
+     # Initialize RAG instance
+     rag = asyncio.run(initialize_rag())
+
+     # Reset tracker before processing queries
+     token_tracker.reset()
+
+     with open("./book.txt", "r", encoding="utf-8") as f:
+         rag.insert(f.read())
+
+     print(
+         rag.query(
+             "What are the top themes in this story?", param=QueryParam(mode="naive")
+         )
+     )
+
+     print(
+         rag.query(
+             "What are the top themes in this story?", param=QueryParam(mode="local")
+         )
+     )
+
+     print(
+         rag.query(
+             "What are the top themes in this story?", param=QueryParam(mode="global")
+         )
+     )
+
+     print(
+         rag.query(
+             "What are the top themes in this story?", param=QueryParam(mode="hybrid")
+         )
+     )
+
+     # Display final token usage after main query
+     print("Token usage:", token_tracker.get_usage())
+
+
+ if __name__ == "__main__":
+     main()
lightrag/llm/openai.py CHANGED
@@ -58,6 +58,7 @@ async def openai_complete_if_cache(
     history_messages: list[dict[str, Any]] | None = None,
     base_url: str | None = None,
     api_key: str | None = None,
+     token_tracker: Any | None = None,
     **kwargs: Any,
 ) -> str:
     if history_messages is None:
@@ -154,6 +155,15 @@ async def openai_complete_if_cache(

     if r"\u" in content:
         content = safe_unicode_decode(content.encode("utf-8"))
+
+     if token_tracker and hasattr(response, "usage"):
+         token_counts = {
+             "prompt_tokens": getattr(response.usage, "prompt_tokens", 0),
+             "completion_tokens": getattr(response.usage, "completion_tokens", 0),
+             "total_tokens": getattr(response.usage, "total_tokens", 0),
+         }
+         token_tracker.add_usage(token_counts)
+
     return content


lightrag/utils.py CHANGED
@@ -953,3 +953,53 @@ def check_storage_env_vars(storage_name: str) -> None:
            f"Storage implementation '{storage_name}' requires the following "
            f"environment variables: {', '.join(missing_vars)}"
        )
+
+
+ class TokenTracker:
+     """Track token usage for LLM calls."""
+
+     def __init__(self):
+         self.reset()
+
+     def reset(self):
+         self.prompt_tokens = 0
+         self.completion_tokens = 0
+         self.total_tokens = 0
+         self.call_count = 0
+
+     def add_usage(self, token_counts):
+         """Add token usage from one LLM call.
+
+         Args:
+             token_counts: A dictionary containing prompt_tokens, completion_tokens, total_tokens
+         """
+         self.prompt_tokens += token_counts.get("prompt_tokens", 0)
+         self.completion_tokens += token_counts.get("completion_tokens", 0)
+
+         # If total_tokens is provided, use it directly; otherwise calculate the sum
+         if "total_tokens" in token_counts:
+             self.total_tokens += token_counts["total_tokens"]
+         else:
+             self.total_tokens += token_counts.get(
+                 "prompt_tokens", 0
+             ) + token_counts.get("completion_tokens", 0)
+
+         self.call_count += 1
+
+     def get_usage(self):
+         """Get current usage statistics."""
+         return {
+             "prompt_tokens": self.prompt_tokens,
+             "completion_tokens": self.completion_tokens,
+             "total_tokens": self.total_tokens,
+             "call_count": self.call_count,
+         }
+
+     def __str__(self):
+         usage = self.get_usage()
+         return (
+             f"LLM call count: {usage['call_count']}, "
+             f"Prompt tokens: {usage['prompt_tokens']}, "
+             f"Completion tokens: {usage['completion_tokens']}, "
+             f"Total tokens: {usage['total_tokens']}"
+         )