File size: 5,348 Bytes
5a3ba21
 
 
a25342b
 
0553d6a
a25342b
5a3ba21
 
 
a25342b
5a3ba21
 
 
 
 
 
 
a25342b
5a3ba21
a25342b
5a3ba21
 
 
 
a25342b
 
5a3ba21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a25342b
 
5a3ba21
a25342b
 
5a3ba21
0553d6a
5a3ba21
a25342b
5a3ba21
 
 
a25342b
 
 
5a3ba21
 
a25342b
5a3ba21
0553d6a
5a3ba21
a25342b
5a3ba21
 
 
a25342b
 
 
5a3ba21
 
a25342b
 
5a3ba21
 
a25342b
5a3ba21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a25342b
5a3ba21
 
 
a25342b
5a3ba21
 
a25342b
 
 
5a3ba21
 
 
a25342b
 
5a3ba21
 
 
a25342b
5a3ba21
 
 
a25342b
5a3ba21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
import os
import asyncio
from lightrag import LightRAG, QueryParam
from lightrag.llm import (
    openai_complete_if_cache,
    nvidia_openai_embed,
)
from lightrag.utils import EmbeddingFunc
import numpy as np

# for custom llm_model_func
from lightrag.utils import locate_json_string_body_from_string

WORKING_DIR = "./dickens"

# Create the working directory if needed. exist_ok=True avoids the
# check-then-create race of os.path.exists() followed by os.mkdir().
os.makedirs(WORKING_DIR, exist_ok=True)

# NVIDIA API key: taken from the NVIDIA_OPENAI_API_KEY environment variable
# when it is set; otherwise the placeholder below is used (replace it with
# your own key, or preferably export the environment variable instead).
NVIDIA_OPENAI_API_KEY = os.getenv("NVIDIA_OPENAI_API_KEY", "nvapi-xxxx")

# Alternative: use LightRAG's pre-defined NVIDIA (OpenAI-compatible) LLM function.
# llm_model_func = nvidia_openai_complete


# Custom llm_model_func that calls a model hosted on the NVIDIA API through
# LightRAG's OpenAI-compatible helper (same pattern as the other examples).
async def llm_model_func(
    prompt, system_prompt=None, history_messages=None, keyword_extraction=False, **kwargs
) -> str:
    """Complete *prompt* with ``nvidia/llama-3.1-nemotron-70b-instruct``.

    Args:
        prompt: User prompt text.
        system_prompt: Optional system prompt forwarded to the model.
        history_messages: Optional prior chat messages forwarded to the model.
            ``None`` (the default) means an empty history.
        keyword_extraction: When true, return
            ``locate_json_string_body_from_string(result)`` instead of the raw
            completion text.
        **kwargs: Extra options forwarded to ``openai_complete_if_cache``.

    Returns:
        The completion text, or the value extracted from it when
        ``keyword_extraction`` is true.
    """
    # A fresh list per call: a ``[]`` default would be shared across calls.
    if history_messages is None:
        history_messages = []
    result = await openai_complete_if_cache(
        "nvidia/llama-3.1-nemotron-70b-instruct",
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        api_key=NVIDIA_OPENAI_API_KEY,
        base_url="https://integrate.api.nvidia.com/v1",
        **kwargs,
    )
    if keyword_extraction:
        return locate_json_string_body_from_string(result)
    return result


# Embedding model shared by the indexing and query embedding functions
# (noted at its call sites as accepting at most 512 tokens per input).
nvidia_embed_model: str = "nvidia/nv-embedqa-e5-v5"


async def _nvidia_embed(texts: list[str], input_type: str) -> np.ndarray:
    """Embed *texts* through the NVIDIA OpenAI-compatible embedding endpoint.

    Shared implementation for the indexing and query embedders, which differ
    only in the ``input_type`` they send ("passage" vs. "query").
    """
    return await nvidia_openai_embed(
        texts,
        model=nvidia_embed_model,  # maximum 512 tokens per input
        # model="nvidia/llama-3.2-nv-embedqa-1b-v1",
        api_key=NVIDIA_OPENAI_API_KEY,
        base_url="https://integrate.api.nvidia.com/v1",
        input_type=input_type,
        # Let the server truncate (from the end) any input longer than the
        # model's maximum token count instead of rejecting it.
        trunc="END",
        encode="float",
    )


async def indexing_embedding_func(texts: list[str]) -> np.ndarray:
    """Embed document chunks for indexing (``input_type="passage"``)."""
    return await _nvidia_embed(texts, "passage")


async def query_embedding_func(texts: list[str]) -> np.ndarray:
    """Embed search queries (``input_type="query"``)."""
    return await _nvidia_embed(texts, "query")


# Indexing and query embeddings use the same model, so their dimensions match.
async def get_embedding_dim():
    """Embed one probe sentence with the indexing embedder and return its width."""
    probe = await indexing_embedding_func(["This is a test sentence."])
    # One row per input text; the column count is the embedding dimension.
    return probe.shape[1]


# Manual smoke test for the LLM and embedding wrappers.
async def test_funcs():
    """Call the LLM and the indexing embedder once each and print the raw results."""
    probe = "How are you?"

    llm_reply = await llm_model_func(probe)
    print("llm_model_func: ", llm_reply)

    vectors = await indexing_embedding_func([probe])
    print("embedding_func: ", vectors)


# asyncio.run(test_funcs())


def _build_rag(embedding_dimension, embed_func):
    """Create a LightRAG instance over WORKING_DIR that embeds with *embed_func*.

    The indexing and query instances share everything except the embedding
    function, so both are built here.
    """
    return LightRAG(
        working_dir=WORKING_DIR,
        llm_model_func=llm_model_func,
        # llm_model_name="meta/llama3-70b-instruct",  # uncomment to set an explicit model name
        embedding_func=EmbeddingFunc(
            embedding_dim=embedding_dimension,
            # Token limit of the embedding model. Inputs have still been seen to
            # exceed it (presumably LightRAG's tokenizer counts differently from
            # the NVIDIA model's -- TODO investigate), so the embedding
            # functions also ask the server to truncate (trunc="END").
            max_token_size=512,
            func=embed_func,
        ),
    )


async def main():
    """Index ./book.txt with passage embeddings, then query it in every mode.

    Any exception is reported (message plus traceback) instead of being
    propagated, so the example prints the failure and finishes normally.
    """
    import traceback  # only needed for error reporting

    try:
        embedding_dimension = await get_embedding_dim()
        print(f"Detected embedding dimension: {embedding_dimension}")

        # Instance used for indexing ("passage"-type embeddings).
        rag = _build_rag(embedding_dimension, indexing_embedding_func)

        # Read the whole book first so the file is closed before the
        # insert runs.
        with open("./book.txt", "r", encoding="utf-8") as f:
            book_text = f.read()
        await rag.ainsert(book_text)

        # Recreate the instance over the same working dir so that queries are
        # embedded with "query"-type embeddings.
        rag = _build_rag(embedding_dimension, query_embedding_func)

        question = "What are the top themes in this story?"
        # (banner label, query mode) pairs, in the order they are run.
        for label, mode in (
            ("Naive", "naive"),
            ("local", "local"),
            ("global", "global"),
            ("hybrid", "hybrid"),
        ):
            print(f"=============={label}===============")
            print(await rag.aquery(question, param=QueryParam(mode=mode)))
    except Exception as e:
        # Top-level boundary of the example: report and continue, but keep the
        # traceback so the failure can actually be diagnosed.
        print(f"An error occurred: {e}")
        traceback.print_exc()


if __name__ == "__main__":
    # Run the async example end to end on a fresh event loop.
    asyncio.run(main())