File size: 4,380 Bytes
5dcb28f 8d66028 0553d6a 8d66028 8b3b01c 8d66028 bf0dfbd 8d66028 0fc2402 8d66028 c832152 5dcb28f 8d66028 94cd4d3 8d66028 0553d6a 8d66028 275e33e 8b3b01c 8d66028 275e33e 8d66028 8b3b01c 9f4950c 8d66028 121ced9 18a40e5 121ced9 18a40e5 797f88c 8d66028 18a40e5 121ced9 8d66028 5dcb28f 8d66028 5dcb28f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
import sys
import os
from pathlib import Path
import asyncio
from lightrag import LightRAG, QueryParam
from lightrag.llm.openai import openai_complete_if_cache, openai_embed
from lightrag.utils import EmbeddingFunc
import numpy as np
from lightrag.kg.shared_storage import initialize_pipeline_status
# Show the current working directory; the relative WORKING_DIR below is
# resolved against it.
print(os.getcwd())

# Make the parent of this script's directory importable.
script_directory = Path(__file__).resolve().parent.parent
sys.path.append(os.path.abspath(script_directory))

WORKING_DIR = "./dickens"

# We use an OpenAI-compatible API to call the LLM on Oracle Cloud.
# More docs here: https://github.com/jin38324/OCI_GenAI_access_gateway
BASE_URL = "http://xxx.xxx.xxx.xxx:8088/v1/"
APIKEY = "ocigenerativeai"
CHATMODEL = "cohere.command-r-plus"
EMBEDMODEL = "cohere.embed-multilingual-v3.0"
CHUNK_TOKEN_SIZE = 1024
MAX_TOKENS = 4000

# exist_ok=True replaces the check-then-create sequence, which could fail if
# the directory appeared between the existence check and the mkdir call.
os.makedirs(WORKING_DIR, exist_ok=True)

# Placeholder Oracle connection settings; replace them with real values
# before running. (Presumably read by the Oracle storage backends selected in
# initialize_rag -- confirm against the LightRAG Oracle storage docs.)
os.environ["ORACLE_USER"] = "username"
os.environ["ORACLE_PASSWORD"] = "xxxxxxxxx"
os.environ["ORACLE_DSN"] = "xxxxxxx_medium"
os.environ["ORACLE_CONFIG_DIR"] = "path_to_config_dir"
os.environ["ORACLE_WALLET_LOCATION"] = "path_to_wallet_location"
os.environ["ORACLE_WALLET_PASSWORD"] = "wallet_password"
os.environ["ORACLE_WORKSPACE"] = "company"
async def llm_model_func(
    prompt, system_prompt=None, history_messages=None, keyword_extraction=False, **kwargs
) -> str:
    """Send a completion request to the configured chat model.

    Thin wrapper around ``openai_complete_if_cache`` that pins the model name,
    API key and base URL to the module-level ``CHATMODEL``, ``APIKEY`` and
    ``BASE_URL`` settings.

    Args:
        prompt: The user prompt to complete.
        system_prompt: Optional system prompt, forwarded unchanged.
        history_messages: Earlier conversation messages. ``None`` (the
            default) is replaced by a fresh empty list on every call, so no
            list object is shared between calls.
        keyword_extraction: Accepted as a named parameter so it does not end
            up in ``**kwargs``; this wrapper does not use or forward it.
        **kwargs: Any further keyword arguments, forwarded unchanged.

    Returns:
        The awaited result of ``openai_complete_if_cache``.
    """
    return await openai_complete_if_cache(
        CHATMODEL,
        prompt,
        system_prompt=system_prompt,
        history_messages=[] if history_messages is None else history_messages,
        api_key=APIKEY,
        base_url=BASE_URL,
        **kwargs,
    )
async def embedding_func(texts: list[str]) -> np.ndarray:
    """Embed ``texts`` with the configured embedding model.

    Delegates to ``openai_embed`` using the module-level ``EMBEDMODEL``,
    ``APIKEY`` and ``BASE_URL`` settings.

    Args:
        texts: The strings to embed.

    Returns:
        The awaited result of ``openai_embed`` (annotated as a NumPy array).
    """
    endpoint_options = {
        "model": EMBEDMODEL,
        "api_key": APIKEY,
        "base_url": BASE_URL,
    }
    vectors = await openai_embed(texts, **endpoint_options)
    return vectors
async def get_embedding_dim():
    """Return the embedding width by embedding one probe sentence.

    Calls ``embedding_func`` with a single fixed sentence and reads the size
    of the result's second axis.

    Returns:
        ``shape[1]`` of the array returned by ``embedding_func``.
    """
    probe = await embedding_func(["This is a test sentence."])
    # Assumes the result is laid out as (number of texts, dimension) --
    # TODO confirm against the embedding backend.
    return probe.shape[1]
async def initialize_rag():
    """Build a LightRAG instance backed by Oracle storage and initialize it.

    The embedding dimension is detected at runtime via ``get_embedding_dim``
    and passed to ``EmbeddingFunc``. After construction the storages and the
    pipeline status are initialized before the instance is returned.

    Returns:
        The initialized ``LightRAG`` instance.
    """
    # Detect embedding dimension
    embedding_dimension = await get_embedding_dim()
    print(f"Detected embedding dimension: {embedding_dimension}")

    # Initialize LightRAG
    # We use Oracle DB as the KV/vector/graph storage
    # You can add `addon_params={"example_number": 1, "language": "Simplified Chinese"}` to control the prompt
    rag = LightRAG(
        # log_level="DEBUG",
        working_dir=WORKING_DIR,
        entity_extract_max_gleaning=1,
        enable_llm_cache=True,
        enable_llm_cache_for_entity_extract=True,
        embedding_cache_config=None,  # {"enabled": True,"similarity_threshold": 0.90},
        chunk_token_size=CHUNK_TOKEN_SIZE,
        llm_model_max_token_size=MAX_TOKENS,
        llm_model_func=llm_model_func,
        embedding_func=EmbeddingFunc(
            embedding_dim=embedding_dimension,
            # NOTE(review): this limit is lower than the chunk_token_size
            # passed above -- confirm the embedding endpoint accepts chunks
            # longer than max_token_size.
            max_token_size=500,
            func=embedding_func,
        ),
        graph_storage="OracleGraphStorage",
        kv_storage="OracleKVStorage",
        vector_storage="OracleVectorDBStorage",
        addon_params={
            "example_number": 1,
            # Spelling fixed from "Simplfied Chinese".
            "language": "Simplified Chinese",
            "entity_types": ["organization", "person", "geo", "event"],
            "insert_batch_size": 2,
        },
    )

    # Storages and the pipeline status are initialized explicitly before the
    # instance is handed to callers.
    await rag.initialize_storages()
    await initialize_pipeline_status()

    return rag
async def main():
    """Index ``docs.txt`` from the working directory and run sample queries.

    Steps:
      1. Build the RAG instance via ``initialize_rag``.
      2. Read ``<WORKING_DIR>/docs.txt`` and treat each non-blank line as one
         document.
      3. Enqueue and process the documents through the pipeline API.
      4. Ask the same question in the naive, local, global and hybrid modes
         and print each answer.

    Any exception is reported on stdout together with its traceback; it is
    not re-raised.
    """
    import traceback

    try:
        # Initialize RAG instance
        rag = await initialize_rag()

        # Extract and Insert into LightRAG storage
        docs_path = Path(WORKING_DIR) / "docs.txt"
        with open(docs_path, "r", encoding="utf-8") as f:
            all_text = f.read()
        # One document per line; blank and whitespace-only lines carry no
        # content and are skipped rather than enqueued.
        texts = [line for line in all_text.split("\n") if line.strip()]

        # New mode use pipeline
        await rag.apipeline_enqueue_documents(texts)
        await rag.apipeline_process_enqueue_documents()

        # Old method use ainsert
        # await rag.ainsert(texts)

        # Perform search in different modes
        for mode in ("naive", "local", "global", "hybrid"):
            print("=" * 20, mode, "=" * 20)
            print(
                await rag.aquery(
                    "What are the top themes in this story?",
                    param=QueryParam(mode=mode),
                )
            )
            print("-" * 100, "\n")
    except Exception as e:
        print(f"An error occurred: {e}")
        # The message alone hides where the failure happened; include the
        # traceback so it can be diagnosed.
        traceback.print_exc()
if __name__ == "__main__":
    # Run the async entry point only when executed as a script, not on import.
    asyncio.run(main())
|