drahnreb commited on
Commit
0228302
·
1 Parent(s): f420d5a

add: GemmaTokenizer example

Browse files
examples/lightrag_gemini_demo_no_tiktoken.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # pip install -q -U google-genai to use gemini as a client
2
+
3
+ import os
4
+ from typing import Optional
5
+ import dataclasses
6
+ from pathlib import Path
7
+ import hashlib
8
+ import numpy as np
9
+ from google import genai
10
+ from google.genai import types
11
+ from dotenv import load_dotenv
12
+ from lightrag.utils import EmbeddingFunc, Tokenizer
13
+ from lightrag import LightRAG, QueryParam
14
+ from sentence_transformers import SentenceTransformer
15
+ from lightrag.kg.shared_storage import initialize_pipeline_status
16
+ import sentencepiece as spm
17
+ import requests
18
+
19
+ import asyncio
20
+ import nest_asyncio
21
+
22
+ # Apply nest_asyncio to solve event loop issues
23
+ nest_asyncio.apply()
24
+
25
+ load_dotenv()
26
+ gemini_api_key = os.getenv("GEMINI_API_KEY")
27
+
28
+ WORKING_DIR = "./dickens"
29
+
30
+ if os.path.exists(WORKING_DIR):
31
+ import shutil
32
+
33
+ shutil.rmtree(WORKING_DIR)
34
+
35
+ os.mkdir(WORKING_DIR)
36
+
37
+
38
+ class GemmaTokenizer(Tokenizer):
39
+ # adapted from google-cloud-aiplatform[tokenization]
40
+
41
+ @dataclasses.dataclass(frozen=True)
42
+ class _TokenizerConfig:
43
+ tokenizer_model_url: str
44
+ tokenizer_model_hash: str
45
+
46
+ _TOKENIZERS = {
47
+ "google/gemma2": _TokenizerConfig(
48
+ tokenizer_model_url="https://raw.githubusercontent.com/google/gemma_pytorch/33b652c465537c6158f9a472ea5700e5e770ad3f/tokenizer/tokenizer.model",
49
+ tokenizer_model_hash="61a7b147390c64585d6c3543dd6fc636906c9af3865a5548f27f31aee1d4c8e2",
50
+ ),
51
+ "google/gemma3": _TokenizerConfig(
52
+ tokenizer_model_url="https://raw.githubusercontent.com/google/gemma_pytorch/cb7c0152a369e43908e769eb09e1ce6043afe084/tokenizer/gemma3_cleaned_262144_v2.spiece.model",
53
+ tokenizer_model_hash="1299c11d7cf632ef3b4e11937501358ada021bbdf7c47638d13c0ee982f2e79c",
54
+ )
55
+ }
56
+
57
+ def __init__(self, model_name: str = "gemini-2.0-flash", tokenizer_dir: Optional[str] = None):
58
+ # https://github.com/google/gemma_pytorch/tree/main/tokenizer
59
+ if "1.5" in model_name or "1.0" in model_name:
60
+ # up to gemini 1.5 gemma2 is a comparable local tokenizer
61
+ # https://github.com/googleapis/python-aiplatform/blob/main/vertexai/tokenization/_tokenizer_loading.py
62
+ tokenizer_name = "google/gemma2"
63
+ else:
64
+ # for gemini > 2.0 gemma3 was used
65
+ tokenizer_name = "google/gemma3"
66
+
67
+ file_url = self._TOKENIZERS[tokenizer_name].tokenizer_model_url
68
+ tokenizer_model_name = file_url.rsplit("/", 1)[1]
69
+ expected_hash = self._TOKENIZERS[tokenizer_name].tokenizer_model_hash
70
+
71
+ tokenizer_dir = Path(tokenizer_dir)
72
+ if tokenizer_dir.is_dir():
73
+ file_path = tokenizer_dir / tokenizer_model_name
74
+ model_data = self._maybe_load_from_cache(
75
+ file_path=file_path, expected_hash=expected_hash
76
+ )
77
+ else:
78
+ model_data = None
79
+ if not model_data:
80
+ model_data = self._load_from_url(file_url=file_url, expected_hash=expected_hash)
81
+ self.save_tokenizer_to_cache(cache_path=file_path, model_data=model_data)
82
+
83
+ tokenizer = spm.SentencePieceProcessor()
84
+ tokenizer.LoadFromSerializedProto(model_data)
85
+ super().__init__(model_name=model_name, tokenizer=tokenizer)
86
+
87
+ def _is_valid_model(self, model_data: bytes, expected_hash: str) -> bool:
88
+ """Returns true if the content is valid by checking the hash."""
89
+ return hashlib.sha256(model_data).hexdigest() == expected_hash
90
+
91
+ def _maybe_load_from_cache(self, file_path: Path, expected_hash: str) -> bytes:
92
+ """Loads the model data from the cache path."""
93
+ if not file_path.is_file():
94
+ return
95
+ with open(file_path, "rb") as f:
96
+ content = f.read()
97
+ if self._is_valid_model(model_data=content, expected_hash=expected_hash):
98
+ return content
99
+
100
+ # Cached file corrupted.
101
+ self._maybe_remove_file(file_path)
102
+
103
+ def _load_from_url(self, file_url: str, expected_hash: str) -> bytes:
104
+ """Loads model bytes from the given file url."""
105
+ resp = requests.get(file_url)
106
+ resp.raise_for_status()
107
+ content = resp.content
108
+
109
+ if not self._is_valid_model(model_data=content, expected_hash=expected_hash):
110
+ actual_hash = hashlib.sha256(content).hexdigest()
111
+ raise ValueError(
112
+ f"Downloaded model file is corrupted."
113
+ f" Expected hash {expected_hash}. Got file hash {actual_hash}."
114
+ )
115
+ return content
116
+
117
+ @staticmethod
118
+ def save_tokenizer_to_cache(cache_path: Path, model_data: bytes) -> None:
119
+ """Saves the model data to the cache path."""
120
+ try:
121
+ if not cache_path.is_file():
122
+ cache_dir = cache_path.parent
123
+ cache_dir.mkdir(parents=True, exist_ok=True)
124
+ with open(cache_path, "wb") as f:
125
+ f.write(model_data)
126
+ except OSError:
127
+ # Don't raise if we cannot write file.
128
+ pass
129
+
130
+ @staticmethod
131
+ def _maybe_remove_file(file_path: Path) -> None:
132
+ """Removes the file if exists."""
133
+ if not file_path.is_file():
134
+ return
135
+ try:
136
+ file_path.unlink()
137
+ except OSError:
138
+ # Don't raise if we cannot remove file.
139
+ pass
140
+
141
+ # def encode(self, content: str) -> list[int]:
142
+ # return self.tokenizer.encode(content)
143
+
144
+ # def decode(self, tokens: list[int]) -> str:
145
+ # return self.tokenizer.decode(tokens)
146
+
147
+
148
+ async def llm_model_func(
149
+ prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
150
+ ) -> str:
151
+ # 1. Initialize the GenAI Client with your Gemini API Key
152
+ client = genai.Client(api_key=gemini_api_key)
153
+
154
+ # 2. Combine prompts: system prompt, history, and user prompt
155
+ if history_messages is None:
156
+ history_messages = []
157
+
158
+ combined_prompt = ""
159
+ if system_prompt:
160
+ combined_prompt += f"{system_prompt}\n"
161
+
162
+ for msg in history_messages:
163
+ # Each msg is expected to be a dict: {"role": "...", "content": "..."}
164
+ combined_prompt += f"{msg['role']}: {msg['content']}\n"
165
+
166
+ # Finally, add the new user prompt
167
+ combined_prompt += f"user: {prompt}"
168
+
169
+ # 3. Call the Gemini model
170
+ response = client.models.generate_content(
171
+ model="gemini-1.5-flash",
172
+ contents=[combined_prompt],
173
+ config=types.GenerateContentConfig(max_output_tokens=500, temperature=0.1),
174
+ )
175
+
176
+ # 4. Return the response text
177
+ return response.text
178
+
179
+
180
+ async def embedding_func(texts: list[str]) -> np.ndarray:
181
+ model = SentenceTransformer("all-MiniLM-L6-v2")
182
+ embeddings = model.encode(texts, convert_to_numpy=True)
183
+ return embeddings
184
+
185
+
186
+ async def initialize_rag():
187
+ rag = LightRAG(
188
+ working_dir=WORKING_DIR,
189
+ # tiktoken_model_name="gpt-4o-mini",
190
+ tokenizer=GemmaTokenizer(tokenizer_dir=(Path(WORKING_DIR) / "vertexai_tokenizer_model"), model_name="gemini-2.0-flash"),
191
+ llm_model_func=llm_model_func,
192
+ embedding_func=EmbeddingFunc(
193
+ embedding_dim=384,
194
+ max_token_size=8192,
195
+ func=embedding_func,
196
+ ),
197
+ )
198
+
199
+ await rag.initialize_storages()
200
+ await initialize_pipeline_status()
201
+
202
+ return rag
203
+
204
+
205
+ def main():
206
+ # Initialize RAG instance
207
+ rag = asyncio.run(initialize_rag())
208
+ file_path = "story.txt"
209
+ with open(file_path, "r") as file:
210
+ text = file.read()
211
+
212
+ rag.insert(text)
213
+
214
+ response = rag.query(
215
+ query="What is the main theme of the story?",
216
+ param=QueryParam(mode="hybrid", top_k=5, response_type="single line"),
217
+ )
218
+
219
+ print(response)
220
+
221
+
222
+ if __name__ == "__main__":
223
+ main()