|
|
|
|
|
import asyncio
import functools
import os

import nest_asyncio
import numpy as np
from dotenv import load_dotenv
from google import genai
from google.genai import types
from lightrag import LightRAG, QueryParam
from lightrag.kg.shared_storage import initialize_pipeline_status
from lightrag.utils import EmbeddingFunc
from sentence_transformers import SentenceTransformer
|
|
|
|
|
# Patch asyncio through nest_asyncio before any event loop is started below.
nest_asyncio.apply()

# Load .env entries first, then read the Gemini API key from the environment
# (None when the variable is not set).
load_dotenv()
gemini_api_key = os.environ.get("GEMINI_API_KEY")

# Directory handed to LightRAG as its working directory.
WORKING_DIR = "./dickens"

# Every run starts from an empty working directory: remove whatever an
# earlier run left behind, then create the directory again.
if os.path.exists(WORKING_DIR):
    import shutil

    shutil.rmtree(WORKING_DIR)

os.mkdir(WORKING_DIR)
|
|
|
|
|
async def llm_model_func(
    prompt, system_prompt=None, history_messages=None, keyword_extraction=False, **kwargs
) -> str:
    """Send a prompt to Gemini and return the generated text.

    The system prompt, the earlier conversation turns and the new user prompt
    are flattened into one plain-text prompt, which is sent as a single
    ``contents`` entry.

    Args:
        prompt: The new user message; appended last as ``user: <prompt>``.
        system_prompt: Optional instructions placed first in the combined
            prompt. Skipped when empty or ``None``.
        history_messages: Optional earlier turns. Each item is indexed with
            ``"role"`` and ``"content"``. ``None`` means no history.
        keyword_extraction: Accepted so callers may pass it; not used here.
        **kwargs: Accepted so callers may pass extra options; ignored.

    Returns:
        The text of the model response, or ``""`` when the response carries
        no text.
    """
    client = genai.Client(api_key=gemini_api_key)

    # Collect the pieces and join once instead of growing a string with +=.
    parts = []
    if system_prompt:
        parts.append(f"{system_prompt}\n")
    parts.extend(
        f"{msg['role']}: {msg['content']}\n" for msg in (history_messages or ())
    )
    parts.append(f"user: {prompt}")
    combined_prompt = "".join(parts)

    # Use the SDK's async client so this coroutine does not block the event
    # loop while waiting for the network round trip.
    response = await client.aio.models.generate_content(
        model="gemini-1.5-flash",
        contents=[combined_prompt],
        config=types.GenerateContentConfig(max_output_tokens=500, temperature=0.1),
    )

    # response.text may be None; keep the declared ``str`` return type.
    return response.text or ""
|
|
|
|
|
@functools.lru_cache(maxsize=1)
def _get_embedding_model() -> "SentenceTransformer":
    """Build the ``all-MiniLM-L6-v2`` model on first use and reuse it afterwards."""
    return SentenceTransformer("all-MiniLM-L6-v2")


async def embedding_func(texts: list[str]) -> np.ndarray:
    """Embed a batch of texts with the ``all-MiniLM-L6-v2`` sentence-transformer.

    The model is loaded once and cached; constructing it on every call would
    repeat the model load for each batch.

    Args:
        texts: The strings to embed.

    Returns:
        The embeddings as a NumPy array (``convert_to_numpy=True``).
    """
    model = _get_embedding_model()
    return model.encode(texts, convert_to_numpy=True)
|
|
|
|
|
async def initialize_rag():
    """Create a LightRAG instance rooted at WORKING_DIR and prepare it for use.

    Wires in the Gemini-backed ``llm_model_func`` and the local
    ``embedding_func``, then awaits storage initialization and pipeline-status
    initialization before returning the instance.
    """
    embedder = EmbeddingFunc(
        embedding_dim=384,
        max_token_size=8192,
        func=embedding_func,
    )
    instance = LightRAG(
        working_dir=WORKING_DIR,
        llm_model_func=llm_model_func,
        embedding_func=embedder,
    )

    # Both initialization steps are awaited, in this order, before the
    # instance is handed back to the caller.
    await instance.initialize_storages()
    await initialize_pipeline_status()
    return instance
|
|
|
|
|
def main():
    """Index ``story.txt`` with LightRAG and print the answer to one sample query."""
    rag = asyncio.run(initialize_rag())

    file_path = "story.txt"
    # Decode explicitly as UTF-8 so the text read does not depend on the
    # platform's locale-default encoding.
    with open(file_path, "r", encoding="utf-8") as file:
        text = file.read()

    rag.insert(text)

    response = rag.query(
        query="What is the main theme of the story?",
        param=QueryParam(mode="hybrid", top_k=5, response_type="single line"),
    )
    print(response)


# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|