Merge remote-tracking branch 'origin/main'
# Conflicts:
# lightrag/llm.py
# lightrag/operate.py
- examples/lightrag_api_open_webui_demo.py +140 -0
- examples/lightrag_ollama_demo.py +19 -0
- lightrag/base.py +1 -0
- lightrag/kg/oracle_impl.py +1 -1
- lightrag/llm.py +35 -4
- lightrag/operate.py +2 -1
examples/lightrag_api_open_webui_demo.py
ADDED
@@ -0,0 +1,140 @@
+from datetime import datetime, timezone
+from fastapi import FastAPI
+from fastapi.responses import StreamingResponse
+import inspect
+import json
+from pydantic import BaseModel
+from typing import Optional
+
+import os
+import logging
+from lightrag import LightRAG, QueryParam
+from lightrag.llm import ollama_model_complete, ollama_embed
+from lightrag.utils import EmbeddingFunc
+
+import nest_asyncio
+
+WORKING_DIR = "./dickens"
+
+logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO)
+
+if not os.path.exists(WORKING_DIR):
+    os.mkdir(WORKING_DIR)
+
+rag = LightRAG(
+    working_dir=WORKING_DIR,
+    llm_model_func=ollama_model_complete,
+    llm_model_name="qwen2.5:latest",
+    llm_model_max_async=4,
+    llm_model_max_token_size=32768,
+    llm_model_kwargs={"host": "http://localhost:11434", "options": {"num_ctx": 32768}},
+    embedding_func=EmbeddingFunc(
+        embedding_dim=1024,
+        max_token_size=8192,
+        func=lambda texts: ollama_embed(
+            texts=texts, embed_model="bge-m3:latest", host="http://127.0.0.1:11434"
+        ),
+    ),
+)
+
+with open("./book.txt", "r", encoding="utf-8") as f:
+    rag.insert(f.read())
+
+# Apply nest_asyncio to solve event loop issues
+nest_asyncio.apply()
+
+app = FastAPI(title="LightRAG", description="LightRAG API open-webui")
+
+
+# Data models
+MODEL_NAME = "LightRAG:latest"
+
+
+class Message(BaseModel):
+    role: Optional[str] = None
+    content: str
+
+
+class OpenWebUIRequest(BaseModel):
+    stream: Optional[bool] = None
+    model: Optional[str] = None
+    messages: list[Message]
+
+
+# API routes
+
+
+@app.get("/")
+async def index():
+    return "Set Ollama link to http://ip:port/ollama in Open-WebUI Settings"
+
+
+@app.get("/ollama/api/version")
+async def ollama_version():
+    return {"version": "0.4.7"}
+
+
+@app.get("/ollama/api/tags")
+async def ollama_tags():
+    return {
+        "models": [
+            {
+                "name": MODEL_NAME,
+                "model": MODEL_NAME,
+                "modified_at": "2024-11-12T20:22:37.561463923+08:00",
+                "size": 4683087332,
+                "digest": "845dbda0ea48ed749caafd9e6037047aa19acfcfd82e704d7ca97d631a0b697e",
+                "details": {
+                    "parent_model": "",
+                    "format": "gguf",
+                    "family": "qwen2",
+                    "families": ["qwen2"],
+                    "parameter_size": "7.6B",
+                    "quantization_level": "Q4_K_M",
+                },
+            }
+        ]
+    }
+
+
+@app.post("/ollama/api/chat")
+async def ollama_chat(request: OpenWebUIRequest):
+    resp = rag.query(
+        request.messages[-1].content, param=QueryParam(mode="hybrid", stream=True)
+    )
+    if inspect.isasyncgen(resp):
+
+        async def ollama_resp(chunks):
+            async for chunk in chunks:
+                yield (
+                    json.dumps(
+                        {
+                            "model": MODEL_NAME,
+                            "created_at": datetime.now(timezone.utc).strftime(
+                                "%Y-%m-%dT%H:%M:%S.%fZ"
+                            ),
+                            "message": {
+                                "role": "assistant",
+                                "content": chunk,
+                            },
+                            "done": False,
+                        },
+                        ensure_ascii=False,
+                    ).encode("utf-8")
+                    + b"\n"
+                )  # the b"\n" is important
+
+        return StreamingResponse(ollama_resp(resp), media_type="application/json")
+    else:
+        return resp
+
+
+@app.get("/health")
+async def health_check():
+    return {"status": "healthy"}
+
+
+if __name__ == "__main__":
+    import uvicorn
+
+    uvicorn.run(app, host="0.0.0.0", port=8020)
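The new demo exposes Ollama-compatible routes so Open-WebUI can talk to LightRAG. As a rough client-side smoke test (not part of the diff; it assumes the demo above is already running on localhost:8020 and that the `requests` package is available), the streaming chat route can be exercised like this:

```python
import json

import requests  # assumption: requests is installed; any HTTP client would do

payload = {
    "model": "LightRAG:latest",
    "stream": True,
    "messages": [{"role": "user", "content": "What are the top themes in this story?"}],
}

# /ollama/api/chat streams newline-delimited JSON chunks shaped like Ollama chat
# responses; each line carries one piece of the answer in message.content.
with requests.post(
    "http://localhost:8020/ollama/api/chat", json=payload, stream=True
) as r:
    for line in r.iter_lines():
        if line:
            chunk = json.loads(line)
            print(chunk["message"]["content"], end="", flush=True)
```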
examples/lightrag_ollama_demo.py
CHANGED
@@ -1,4 +1,6 @@
+import asyncio
 import os
+import inspect
 import logging
 from lightrag import LightRAG, QueryParam
 from lightrag.llm import ollama_model_complete, ollama_embedding
@@ -49,3 +51,20 @@ print(
 print(
     rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
 )
+
+# stream response
+resp = rag.query(
+    "What are the top themes in this story?",
+    param=QueryParam(mode="hybrid", stream=True),
+)
+
+
+async def print_stream(stream):
+    async for chunk in stream:
+        print(chunk, end="", flush=True)
+
+
+if inspect.isasyncgen(resp):
+    asyncio.run(print_stream(resp))
+else:
+    print(resp)
lightrag/base.py
CHANGED
@@ -19,6 +19,7 @@ class QueryParam:
     only_need_context: bool = False
     only_need_prompt: bool = False
     response_type: str = "Multiple Paragraphs"
+    stream: bool = False
     # Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode.
     top_k: int = 60
     # Number of document chunks to retrieve.
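The only change here is the new `stream` field on QueryParam. A quick illustration of how it is meant to be set (assuming the lightrag package from this branch is importable):

```python
from lightrag import QueryParam

# The new field defaults to False, so existing callers are unaffected.
default_param = QueryParam(mode="hybrid")
assert default_param.stream is False

# Opting in: with stream=True, rag.query(...) may return an async generator
# of text chunks instead of a str (see the ollama demo above and kg_query below).
streaming_param = QueryParam(mode="hybrid", stream=True)
assert streaming_param.stream is True
```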
lightrag/kg/oracle_impl.py
CHANGED
@@ -143,7 +143,7 @@ class OracleDB:
                 data = None
             return data

-    async def execute(self, sql: str, data: list = None):
+    async def execute(self, sql: str, data: Union[list, dict] = None):
         # logger.info("go into OracleDB execute method")
         try:
             async with self.pool.acquire() as connection:
lightrag/llm.py
CHANGED
@@ -4,8 +4,7 @@ import json
 import os
 import struct
 from functools import lru_cache
-from typing import List, Dict, Callable, Any
-from dataclasses import dataclass
+from typing import List, Dict, Callable, Any, Union

 import aioboto3
 import aiohttp
@@ -37,6 +36,13 @@ from .utils import (
     get_best_cached_response,
 )

+import sys
+
+if sys.version_info < (3, 9):
+    from typing import AsyncIterator
+else:
+    from collections.abc import AsyncIterator
+
 os.environ["TOKENIZERS_PARALLELISM"] = "false"


@@ -397,7 +403,8 @@ async def ollama_model_if_cache(
     system_prompt=None,
     history_messages=[],
     **kwargs,
-) -> str:
+) -> Union[str, AsyncIterator[str]]:
+    stream = True if kwargs.get("stream") else False
     kwargs.pop("max_tokens", None)
     # kwargs.pop("response_format", None) # allow json
     host = kwargs.pop("host", None)
@@ -422,7 +429,31 @@ async def ollama_model_if_cache(
         return cached_response

     response = await ollama_client.chat(model=model, messages=messages, **kwargs)
+    if stream:
+        """ cannot cache stream response """

+        async def inner():
+            async for chunk in response:
+                yield chunk["message"]["content"]
+
+        return inner()
+    else:
+        result = response["message"]["content"]
+        # Save to cache
+        await save_to_cache(
+            hashing_kv,
+            CacheData(
+                args_hash=args_hash,
+                content=result,
+                model=model,
+                prompt=prompt,
+                quantized=quantized,
+                min_val=min_val,
+                max_val=max_val,
+                mode=mode,
+            ),
+        )
+        return result
     result = response["message"]["content"]

     # Save to cache
@@ -697,7 +728,7 @@ async def hf_model_complete(

 async def ollama_model_complete(
     prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
-) -> str:
+) -> Union[str, AsyncIterator[str]]:
     keyword_extraction = kwargs.pop("keyword_extraction", None)
     if keyword_extraction:
         kwargs["format"] = "json"
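The version-gated import exists because subscripting `collections.abc.AsyncIterator` in annotations only works on Python 3.9+, while older interpreters must fall back to `typing.AsyncIterator`. A minimal sketch of the resulting `Union[str, AsyncIterator[str]]` pattern, with hypothetical names (`fake_stream`, `describe`) that are not part of the diff:

```python
import sys
from typing import Union

if sys.version_info < (3, 9):
    from typing import AsyncIterator
else:
    from collections.abc import AsyncIterator


async def fake_stream() -> AsyncIterator[str]:
    # Hypothetical stand-in for the streaming branch of ollama_model_if_cache:
    # an async generator that yields text chunks one at a time.
    for token in ("Hello", ", ", "world"):
        yield token


def describe(resp: Union[str, AsyncIterator[str]]) -> str:
    # Callers of ollama_model_complete now have to handle both shapes.
    return "stream" if hasattr(resp, "__anext__") else "plain string"
```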
lightrag/operate.py
CHANGED
@@ -536,9 +536,10 @@ async def kg_query(
     response = await use_model_func(
         query,
         system_prompt=sys_prompt,
+        stream=query_param.stream,
         mode=query_param.mode,
     )
-    if len(response) > len(sys_prompt):
+    if isinstance(response, str) and len(response) > len(sys_prompt):
         response = (
             response.replace(sys_prompt, "")
             .replace("user", "")