yangdx committed
Commit cfd1740 · 2 parents: b212444 602ff41

Merge branch 'main' into upload-error

examples/lightrag_gemini_track_token_demo.py ADDED
@@ -0,0 +1,153 @@
+# pip install -q -U google-genai to use gemini as a client
+
+import os
+import asyncio
+import numpy as np
+import nest_asyncio
+from google import genai
+from google.genai import types
+from dotenv import load_dotenv
+from lightrag.utils import EmbeddingFunc
+from lightrag import LightRAG, QueryParam
+from lightrag.kg.shared_storage import initialize_pipeline_status
+from lightrag.llm.siliconcloud import siliconcloud_embedding
+from lightrag.utils import setup_logger
+from lightrag.utils import TokenTracker
+
+setup_logger("lightrag", level="DEBUG")
+
+# Apply nest_asyncio to solve event loop issues
+nest_asyncio.apply()
+
+load_dotenv()
+gemini_api_key = os.getenv("GEMINI_API_KEY")
+siliconflow_api_key = os.getenv("SILICONFLOW_API_KEY")
+
+WORKING_DIR = "./dickens"
+
+if not os.path.exists(WORKING_DIR):
+    os.mkdir(WORKING_DIR)
+
+token_tracker = TokenTracker()
+
+
+async def llm_model_func(
+    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
+) -> str:
+    # 1. Initialize the GenAI Client with your Gemini API Key
+    client = genai.Client(api_key=gemini_api_key)
+
+    # 2. Combine prompts: system prompt, history, and user prompt
+    if history_messages is None:
+        history_messages = []
+
+    combined_prompt = ""
+    if system_prompt:
+        combined_prompt += f"{system_prompt}\n"
+
+    for msg in history_messages:
+        # Each msg is expected to be a dict: {"role": "...", "content": "..."}
+        combined_prompt += f"{msg['role']}: {msg['content']}\n"
+
+    # Finally, add the new user prompt
+    combined_prompt += f"user: {prompt}"
+
+    # 3. Call the Gemini model
+    response = client.models.generate_content(
+        model="gemini-2.0-flash",
+        contents=[combined_prompt],
+        config=types.GenerateContentConfig(
+            max_output_tokens=5000, temperature=0, top_k=10
+        ),
+    )
+
+    # 4. Get token counts with null safety
+    usage = getattr(response, "usage_metadata", None)
+    prompt_tokens = getattr(usage, "prompt_token_count", 0) or 0
+    completion_tokens = getattr(usage, "candidates_token_count", 0) or 0
+    total_tokens = getattr(usage, "total_token_count", 0) or (
+        prompt_tokens + completion_tokens
+    )
+
+    token_counts = {
+        "prompt_tokens": prompt_tokens,
+        "completion_tokens": completion_tokens,
+        "total_tokens": total_tokens,
+    }
+
+    token_tracker.add_usage(token_counts)
+
+    # 5. Return the response text
+    return response.text
+
+
+async def embedding_func(texts: list[str]) -> np.ndarray:
+    return await siliconcloud_embedding(
+        texts,
+        model="BAAI/bge-m3",
+        api_key=siliconflow_api_key,
+        max_token_size=512,
+    )
+
+
+async def initialize_rag():
+    rag = LightRAG(
+        working_dir=WORKING_DIR,
+        entity_extract_max_gleaning=1,
+        enable_llm_cache=True,
+        enable_llm_cache_for_entity_extract=True,
+        embedding_cache_config={"enabled": True, "similarity_threshold": 0.90},
+        llm_model_func=llm_model_func,
+        embedding_func=EmbeddingFunc(
+            embedding_dim=1024,
+            max_token_size=8192,
+            func=embedding_func,
+        ),
+    )
+
+    await rag.initialize_storages()
+    await initialize_pipeline_status()
+
+    return rag
+
+
+def main():
+    # Initialize RAG instance
+    rag = asyncio.run(initialize_rag())
+
+    # Reset tracker before processing queries
+    token_tracker.reset()
+
+    with open("./book.txt", "r", encoding="utf-8") as f:
+        rag.insert(f.read())
+
+    print(
+        rag.query(
+            "What are the top themes in this story?", param=QueryParam(mode="naive")
+        )
+    )
+
+    print(
+        rag.query(
+            "What are the top themes in this story?", param=QueryParam(mode="local")
+        )
+    )
+
+    print(
+        rag.query(
+            "What are the top themes in this story?", param=QueryParam(mode="global")
+        )
+    )
+
+    print(
+        rag.query(
+            "What are the top themes in this story?", param=QueryParam(mode="hybrid")
+        )
+    )
+
+    # Display final token usage after main query
+    print("Token usage:", token_tracker.get_usage())
+
+
+if __name__ == "__main__":
+    main()
examples/lightrag_siliconcloud_track_token_demo.py ADDED
@@ -0,0 +1,114 @@
+import os
+import asyncio
+from lightrag import LightRAG, QueryParam
+from lightrag.llm.openai import openai_complete_if_cache
+from lightrag.llm.siliconcloud import siliconcloud_embedding
+from lightrag.utils import EmbeddingFunc
+from lightrag.utils import TokenTracker
+import numpy as np
+from lightrag.kg.shared_storage import initialize_pipeline_status
+from dotenv import load_dotenv
+
+load_dotenv()
+
+token_tracker = TokenTracker()
+WORKING_DIR = "./dickens"
+
+if not os.path.exists(WORKING_DIR):
+    os.mkdir(WORKING_DIR)
+
+
+async def llm_model_func(
+    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
+) -> str:
+    return await openai_complete_if_cache(
+        "Qwen/Qwen2.5-7B-Instruct",
+        prompt,
+        system_prompt=system_prompt,
+        history_messages=history_messages,
+        api_key=os.getenv("SILICONFLOW_API_KEY"),
+        base_url="https://api.siliconflow.cn/v1/",
+        token_tracker=token_tracker,
+        **kwargs,
+    )
+
+
+async def embedding_func(texts: list[str]) -> np.ndarray:
+    return await siliconcloud_embedding(
+        texts,
+        model="BAAI/bge-m3",
+        api_key=os.getenv("SILICONFLOW_API_KEY"),
+        max_token_size=512,
+    )
+
+
+# function test
+async def test_funcs():
+    # Reset tracker before processing queries
+    token_tracker.reset()
+
+    result = await llm_model_func("How are you?")
+    print("llm_model_func: ", result)
+
+    # Display final token usage after main query
+    print("Token usage:", token_tracker.get_usage())
+
+
+asyncio.run(test_funcs())
+
+
+async def initialize_rag():
+    rag = LightRAG(
+        working_dir=WORKING_DIR,
+        llm_model_func=llm_model_func,
+        embedding_func=EmbeddingFunc(
+            embedding_dim=1024, max_token_size=512, func=embedding_func
+        ),
+    )
+
+    await rag.initialize_storages()
+    await initialize_pipeline_status()
+
+    return rag
+
+
+def main():
+    # Initialize RAG instance
+    rag = asyncio.run(initialize_rag())
+
+    # Reset tracker before processing queries
+    token_tracker.reset()
+
+    with open("./book.txt", "r", encoding="utf-8") as f:
+        rag.insert(f.read())
+
+    print(
+        rag.query(
+            "What are the top themes in this story?", param=QueryParam(mode="naive")
+        )
+    )
+
+    print(
+        rag.query(
+            "What are the top themes in this story?", param=QueryParam(mode="local")
+        )
+    )
+
+    print(
+        rag.query(
+            "What are the top themes in this story?", param=QueryParam(mode="global")
+        )
+    )
+
+    print(
+        rag.query(
+            "What are the top themes in this story?", param=QueryParam(mode="hybrid")
+        )
+    )
+
+    # Display final token usage after main query
+    print("Token usage:", token_tracker.get_usage())
+
+
+if __name__ == "__main__":
+    main()
lightrag/llm/openai.py CHANGED
@@ -58,6 +58,7 @@ async def openai_complete_if_cache(
     history_messages: list[dict[str, Any]] | None = None,
     base_url: str | None = None,
     api_key: str | None = None,
+    token_tracker: Any | None = None,
     **kwargs: Any,
 ) -> str:
     if history_messages is None:
@@ -154,6 +155,15 @@ async def openai_complete_if_cache(

     if r"\u" in content:
         content = safe_unicode_decode(content.encode("utf-8"))
+
+    if token_tracker and hasattr(response, "usage"):
+        token_counts = {
+            "prompt_tokens": getattr(response.usage, "prompt_tokens", 0),
+            "completion_tokens": getattr(response.usage, "completion_tokens", 0),
+            "total_tokens": getattr(response.usage, "total_tokens", 0),
+        }
+        token_tracker.add_usage(token_counts)
+
     return content

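For reference, a minimal sketch (not part of this commit) of how a caller can pass a TokenTracker through the new `token_tracker` parameter of `openai_complete_if_cache`; the model name and API key below are placeholders.

```python
# Hypothetical usage sketch: pass a TokenTracker so openai_complete_if_cache
# records prompt/completion/total tokens for each call it makes.
import asyncio

from lightrag.llm.openai import openai_complete_if_cache
from lightrag.utils import TokenTracker

token_tracker = TokenTracker()


async def demo():
    answer = await openai_complete_if_cache(
        "gpt-4o-mini",  # placeholder model name
        "Summarize LightRAG in one sentence.",
        api_key="YOUR_API_KEY",  # placeholder credential
        token_tracker=token_tracker,
    )
    print(answer)
    print(token_tracker)  # formatted by TokenTracker.__str__ (see lightrag/utils.py)


asyncio.run(demo())
```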
lightrag/operate.py CHANGED
@@ -1038,7 +1038,7 @@ async def mix_kg_vector_query(
         # Include time information in content
         formatted_chunks = []
         for c in maybe_trun_chunks:
-            chunk_text = c["content"]
+            chunk_text = "File path: " + c["file_path"] + "\n" + c["content"]
             if c["created_at"]:
                 chunk_text = f"[Created at: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(c['created_at']))}]\n{chunk_text}"
             formatted_chunks.append(chunk_text)
@@ -1334,9 +1334,9 @@ async def _get_node_data(
     )
     relations_context = list_of_list_to_csv(relations_section_list)

-    text_units_section_list = [["id", "content"]]
+    text_units_section_list = [["id", "content", "file_path"]]
     for i, t in enumerate(use_text_units):
-        text_units_section_list.append([i, t["content"]])
+        text_units_section_list.append([i, t["content"], t["file_path"]])
     text_units_context = list_of_list_to_csv(text_units_section_list)
     return entities_context, relations_context, text_units_context

@@ -1597,9 +1597,9 @@ async def _get_edge_data(
     )
     entities_context = list_of_list_to_csv(entites_section_list)

-    text_units_section_list = [["id", "content"]]
+    text_units_section_list = [["id", "content", "file_path"]]
     for i, t in enumerate(use_text_units):
-        text_units_section_list.append([i, t["content"]])
+        text_units_section_list.append([i, t["content"], t["file_path"]])
     text_units_context = list_of_list_to_csv(text_units_section_list)
     return entities_context, relations_context, text_units_context

@@ -1785,7 +1785,12 @@ async def naive_query(
         f"Truncate chunks from {len(chunks)} to {len(maybe_trun_chunks)} (max tokens:{query_param.max_token_for_text_unit})"
     )

-    section = "\n--New Chunk--\n".join([c["content"] for c in maybe_trun_chunks])
+    section = "\n--New Chunk--\n".join(
+        [
+            "File path: " + c["file_path"] + "\n" + c["content"]
+            for c in maybe_trun_chunks
+        ]
+    )

     if query_param.only_need_context:
         return section
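To illustrate the effect of the operate.py change: retrieved chunks are now prefixed with their source file path before being handed to the LLM. The sketch below uses invented sample data to show the resulting chunk text; it is an illustration, not repository code.

```python
# Illustration with sample data: how a retrieved chunk is formatted after this change.
import time

chunk = {
    "file_path": "books/a_christmas_carol.txt",  # sample value
    "content": "Marley was dead: to begin with.",  # sample value
    "created_at": 1735689600,  # sample UNIX timestamp
}

# Prepend the file path, then (if present) the creation time, as operate.py now does.
chunk_text = "File path: " + chunk["file_path"] + "\n" + chunk["content"]
if chunk["created_at"]:
    created = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(chunk["created_at"]))
    chunk_text = f"[Created at: {created}]\n{chunk_text}"
print(chunk_text)
```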
lightrag/prompt.py CHANGED
@@ -222,7 +222,7 @@ When handling relationships with timestamps:
 - Use markdown formatting with appropriate section headings
 - Please respond in the same language as the user's question.
 - Ensure the response maintains continuity with the conversation history.
-- List up to 5 most important reference sources at the end under "References" section. Clearly indicating whether each source is from Knowledge Graph (KG) or Vector Data (DC), and include the file path if available, in the following format: [KG/DC] Source content (File: file_path)
+- List up to 5 most important reference sources at the end under "References" section. Clearly indicating whether each source is from Knowledge Graph (KG) or Vector Data (DC), and include the file path if available, in the following format: [KG/DC] file_path
 - If you don't know the answer, just say so.
 - Do not make anything up. Do not include information not provided by the Knowledge Base."""

@@ -320,7 +320,7 @@ When handling content with timestamps:
 - Use markdown formatting with appropriate section headings
 - Please respond in the same language as the user's question.
 - Ensure the response maintains continuity with the conversation history.
-- List up to 5 most important reference sources at the end under "References" section. Clearly indicating whether each source is from Knowledge Graph (KG) or Vector Data (DC), and include the file path if available, in the following format: [KG/DC] Source content (File: file_path)
+- List up to 5 most important reference sources at the end under "References" section. Clearly indicating whether each source is from Knowledge Graph (KG) or Vector Data (DC), and include the file path if available, in the following format: [KG/DC] file_path
 - If you don't know the answer, just say so.
 - Do not include information not provided by the Document Chunks."""

@@ -382,6 +382,6 @@ When handling information with timestamps:
 - Ensure the response maintains continuity with the conversation history.
 - Organize answer in sections focusing on one main point or aspect of the answer
 - Use clear and descriptive section titles that reflect the content
-- List up to 5 most important reference sources at the end under "References" section. Clearly indicating whether each source is from Knowledge Graph (KG) or Vector Data (DC), and include the file path if available, in the following format: [KG/DC] Source content (File: file_path)
+- List up to 5 most important reference sources at the end under "References" section. Clearly indicating whether each source is from Knowledge Graph (KG) or Vector Data (DC), and include the file path if available, in the following format: [KG/DC] file_path
 - If you don't know the answer, just say so. Do not make anything up.
 - Do not include information not provided by the Data Sources."""
lightrag/utils.py CHANGED
@@ -953,3 +953,53 @@ def check_storage_env_vars(storage_name: str) -> None:
             f"Storage implementation '{storage_name}' requires the following "
             f"environment variables: {', '.join(missing_vars)}"
         )
+
+
+class TokenTracker:
+    """Track token usage for LLM calls."""
+
+    def __init__(self):
+        self.reset()
+
+    def reset(self):
+        self.prompt_tokens = 0
+        self.completion_tokens = 0
+        self.total_tokens = 0
+        self.call_count = 0
+
+    def add_usage(self, token_counts):
+        """Add token usage from one LLM call.
+
+        Args:
+            token_counts: A dictionary containing prompt_tokens, completion_tokens, total_tokens
+        """
+        self.prompt_tokens += token_counts.get("prompt_tokens", 0)
+        self.completion_tokens += token_counts.get("completion_tokens", 0)
+
+        # If total_tokens is provided, use it directly; otherwise calculate the sum
+        if "total_tokens" in token_counts:
+            self.total_tokens += token_counts["total_tokens"]
+        else:
+            self.total_tokens += token_counts.get(
+                "prompt_tokens", 0
+            ) + token_counts.get("completion_tokens", 0)
+
+        self.call_count += 1
+
+    def get_usage(self):
+        """Get current usage statistics."""
+        return {
+            "prompt_tokens": self.prompt_tokens,
+            "completion_tokens": self.completion_tokens,
+            "total_tokens": self.total_tokens,
+            "call_count": self.call_count,
+        }
+
+    def __str__(self):
+        usage = self.get_usage()
+        return (
+            f"LLM call count: {usage['call_count']}, "
+            f"Prompt tokens: {usage['prompt_tokens']}, "
+            f"Completion tokens: {usage['completion_tokens']}, "
+            f"Total tokens: {usage['total_tokens']}"
+        )