File size: 3,972 Bytes
5f72ddb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
import os
import asyncio
from lightrag import LightRAG, QueryParam
from lightrag.utils import EmbeddingFunc
import numpy as np
from dotenv import load_dotenv
import logging
from openai import OpenAI
from lightrag.kg.shared_storage import initialize_pipeline_status
logging.basicConfig(level=logging.INFO)
load_dotenv()

# Endpoint / model configuration. Every value can be overridden through the
# environment (load_dotenv() above also reads a .env file); the embedding
# endpoint and key fall back to the LLM ones when not set explicitly.
LLM_MODEL = os.environ.get("LLM_MODEL", "qwen-turbo-latest")
# Previously hard-coded; now overridable like every other setting, with the
# same DashScope compatible-mode default.
LLM_BINDING_HOST = os.getenv(
    "LLM_BINDING_HOST", "https://dashscope.aliyuncs.com/compatible-mode/v1"
)
LLM_BINDING_API_KEY = os.getenv("LLM_BINDING_API_KEY")
EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "text-embedding-v3")
EMBEDDING_BINDING_HOST = os.getenv("EMBEDDING_BINDING_HOST", LLM_BINDING_HOST)
EMBEDDING_BINDING_API_KEY = os.getenv("EMBEDDING_BINDING_API_KEY", LLM_BINDING_API_KEY)
# Must match the vector size actually returned by EMBEDDING_MODEL.
EMBEDDING_DIM = int(os.environ.get("EMBEDDING_DIM", 1024))
EMBEDDING_MAX_TOKEN_SIZE = int(os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", 8192))
# Maximum number of inputs sent per embeddings request.
EMBEDDING_MAX_BATCH_SIZE = int(os.environ.get("EMBEDDING_MAX_BATCH_SIZE", 10))

print(f"LLM_MODEL: {LLM_MODEL}")
print(f"EMBEDDING_MODEL: {EMBEDDING_MODEL}")

WORKING_DIR = "./dickens"

# Start every run from an empty working directory. LightRAG is given this
# directory as its working_dir, so presumably anything left over from a
# previous run would otherwise be picked up.
if os.path.exists(WORKING_DIR):
    import shutil

    shutil.rmtree(WORKING_DIR)

os.mkdir(WORKING_DIR)
async def llm_model_func(
    prompt, system_prompt=None, history_messages=None, keyword_extraction=False, **kwargs
) -> str:
    """Send one chat-completion request to the configured OpenAI-compatible LLM.

    Args:
        prompt: The user message appended last.
        system_prompt: Optional system message placed first.
        history_messages: Optional list of prior chat messages (dicts with
            ``role``/``content``) inserted between the system and user
            messages. Defaults to no history.
        keyword_extraction: Accepted for signature compatibility with
            LightRAG; not used by this implementation.
        **kwargs: ``temperature`` (default 0), ``top_p`` (default 1) and
            ``n`` (default 1) are forwarded; other keys are ignored.

    Returns:
        The message content of the first returned choice.
    """
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    if history_messages:
        messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})

    def _complete() -> str:
        # Fresh synchronous client per call, closed as soon as the request ends.
        with OpenAI(
            api_key=LLM_BINDING_API_KEY,
            base_url=LLM_BINDING_HOST,
        ) as client:
            chat_completion = client.chat.completions.create(
                model=LLM_MODEL,
                messages=messages,
                temperature=kwargs.get("temperature", 0),
                top_p=kwargs.get("top_p", 1),
                n=kwargs.get("n", 1),
                # NOTE(review): presumably disables Qwen "thinking" output on
                # DashScope's compatible mode — confirm against provider docs.
                extra_body={"enable_thinking": False},
            )
        return chat_completion.choices[0].message.content

    # The OpenAI client is blocking; running it in a worker thread keeps the
    # event loop free so concurrent LightRAG tasks are not serialized.
    return await asyncio.to_thread(_complete)
async def embedding_func(texts: list[str]) -> np.ndarray:
    """Embed ``texts`` with the configured OpenAI-compatible embedding endpoint.

    Inputs are sent in consecutive batches of at most
    ``EMBEDDING_MAX_BATCH_SIZE`` items, and the returned vectors are
    concatenated in request order.

    Args:
        texts: Strings to embed.

    Returns:
        A NumPy array built from the returned vectors. Assuming the endpoint
        returns one vector per input, that is one row per string; it has
        shape ``(0,)`` when ``texts`` is empty.
    """
    print("##### embedding: texts: %d #####" % len(texts))

    def _embed_all() -> list[list[float]]:
        # Synchronous client, closed deterministically once all batches are done.
        with OpenAI(
            api_key=EMBEDDING_BINDING_API_KEY,
            base_url=EMBEDDING_BINDING_HOST,
        ) as client:
            vectors: list[list[float]] = []
            for start in range(0, len(texts), EMBEDDING_MAX_BATCH_SIZE):
                batch = texts[start : start + EMBEDDING_MAX_BATCH_SIZE]
                response = client.embeddings.create(model=EMBEDDING_MODEL, input=batch)
                vectors.extend(item.embedding for item in response.data)
            return vectors

    # Offload the blocking HTTP calls so the event loop stays responsive
    # while LightRAG awaits other tasks.
    return np.array(await asyncio.to_thread(_embed_all))
async def test_funcs():
    """Smoke-test the LLM and embedding callables against the live endpoints.

    Prints the LLM reply to a fixed prompt, then the shape of the embedding
    array for a one-item batch and the size of its second axis.
    """
    reply = await llm_model_func("How are you?")
    print("Resposta do llm_model_func: ", reply)

    vectors = await embedding_func(["How are you?"])
    dims = vectors.shape
    print("Resultado do embedding_func: ", dims)
    print("Dimensão da embedding: ", dims[1])


# NOTE(review): this executes at import time and makes network calls;
# consider moving it under the __main__ guard.
asyncio.run(test_funcs())
async def initialize_rag():
    """Create a LightRAG instance wired to this module's model functions.

    The instance uses ``llm_model_func`` for generation and wraps
    ``embedding_func`` in an ``EmbeddingFunc`` sized from the configuration.
    Its storages and the shared pipeline status are initialized before it
    is returned.
    """
    embedder = EmbeddingFunc(
        embedding_dim=EMBEDDING_DIM,
        max_token_size=EMBEDDING_MAX_TOKEN_SIZE,
        func=embedding_func,
    )
    rag = LightRAG(
        working_dir=WORKING_DIR,
        llm_model_func=llm_model_func,
        embedding_func=embedder,
    )

    await rag.initialize_storages()
    await initialize_pipeline_status()
    return rag
# (label printed before each answer, LightRAG query mode)
_QUERY_MODES = (
    ("Result (Naive):", "naive"),
    ("\nResult (Local):", "local"),
    ("\nResult (Global):", "global"),
    ("\nResult (Hybrid):", "hybrid"),
    ("\nResult (mix):", "mix"),
)


async def _run_demo() -> None:
    """Index ./book.txt, then print one answer per query mode, in one event loop."""
    rag = await initialize_rag()
    try:
        with open("./book.txt", "r", encoding="utf-8") as f:
            await rag.ainsert(f.read())

        query_text = "What are the main themes?"
        for label, mode in _QUERY_MODES:
            print(label)
            print(await rag.aquery(query_text, param=QueryParam(mode=mode)))
    finally:
        # Release storage resources even if insertion or a query fails.
        await rag.finalize_storages()


def main():
    """Run the demo end to end.

    Everything runs inside a single asyncio.run() call. The previous version
    created the storages in one event loop and then used LightRAG's sync
    wrappers, which drive additional loops.
    """
    asyncio.run(_run_demo())


if __name__ == "__main__":
    main()
|