Commit 59bd4fe
Parent(s): 0114598

Add support for Ollama streaming output and integrate Open-WebUI as the chat UI demo

Files changed:
- examples/lightrag_api_open_webui_demo.py +140 -0
- examples/lightrag_ollama_demo.py +19 -0
- lightrag/base.py +1 -0
- lightrag/llm.py +41 -22
- lightrag/operate.py +2 -1
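At a glance, the commit threads a new stream flag from QueryParam (lightrag/base.py) through kg_query (lightrag/operate.py) into the Ollama model functions (lightrag/llm.py), so a query can hand back an async iterator of text chunks instead of a single string. A rough sketch of the flow, using only names touched in the diffs below:

# QueryParam(stream=True)                       # lightrag/base.py: new `stream` field (default False)
#   -> kg_query(..., query_param)               # lightrag/operate.py: passes stream=query_param.stream to the model func
#   -> ollama_model_complete(..., stream=True)  # lightrag/llm.py
#   -> ollama_model_if_cache(...)               # returns a str, or an AsyncIterator[str] when streaming
from lightrag import QueryParam

param = QueryParam(mode="hybrid", stream=True)  # existing callers keep the default stream=False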
examples/lightrag_api_open_webui_demo.py
ADDED
@@ -0,0 +1,140 @@
from datetime import datetime, timezone
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
import inspect
import json
from pydantic import BaseModel
from typing import Optional

import os
import logging
from lightrag import LightRAG, QueryParam
from lightrag.llm import ollama_model_complete, ollama_embed
from lightrag.utils import EmbeddingFunc

import nest_asyncio

WORKING_DIR = "./dickens"

logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO)

if not os.path.exists(WORKING_DIR):
    os.mkdir(WORKING_DIR)

rag = LightRAG(
    working_dir=WORKING_DIR,
    llm_model_func=ollama_model_complete,
    llm_model_name="qwen2.5:latest",
    llm_model_max_async=4,
    llm_model_max_token_size=32768,
    llm_model_kwargs={"host": "http://localhost:11434", "options": {"num_ctx": 32768}},
    embedding_func=EmbeddingFunc(
        embedding_dim=1024,
        max_token_size=8192,
        func=lambda texts: ollama_embed(
            texts=texts, embed_model="bge-m3:latest", host="http://127.0.0.1:11434"
        ),
    ),
)

with open("./book.txt", "r", encoding="utf-8") as f:
    rag.insert(f.read())

# Apply nest_asyncio to solve event loop issues
nest_asyncio.apply()

app = FastAPI(title="LightRAG", description="LightRAG API open-webui")


# Data models
MODEL_NAME = "LightRAG:latest"


class Message(BaseModel):
    role: Optional[str] = None
    content: str


class OpenWebUIRequest(BaseModel):
    stream: Optional[bool] = None
    model: Optional[str] = None
    messages: list[Message]


# API routes


@app.get("/")
async def index():
    return "Set Ollama link to http://ip:port/ollama in Open-WebUI Settings"


@app.get("/ollama/api/version")
async def ollama_version():
    return {"version": "0.4.7"}


@app.get("/ollama/api/tags")
async def ollama_tags():
    return {
        "models": [
            {
                "name": MODEL_NAME,
                "model": MODEL_NAME,
                "modified_at": "2024-11-12T20:22:37.561463923+08:00",
                "size": 4683087332,
                "digest": "845dbda0ea48ed749caafd9e6037047aa19acfcfd82e704d7ca97d631a0b697e",
                "details": {
                    "parent_model": "",
                    "format": "gguf",
                    "family": "qwen2",
                    "families": ["qwen2"],
                    "parameter_size": "7.6B",
                    "quantization_level": "Q4_K_M",
                },
            }
        ]
    }


@app.post("/ollama/api/chat")
async def ollama_chat(request: OpenWebUIRequest):
    resp = rag.query(
        request.messages[-1].content, param=QueryParam(mode="hybrid", stream=True)
    )
    if inspect.isasyncgen(resp):

        async def ollama_resp(chunks):
            async for chunk in chunks:
                yield (
                    json.dumps(
                        {
                            "model": MODEL_NAME,
                            "created_at": datetime.now(timezone.utc).strftime(
                                "%Y-%m-%dT%H:%M:%S.%fZ"
                            ),
                            "message": {
                                "role": "assistant",
                                "content": chunk,
                            },
                            "done": False,
                        },
                        ensure_ascii=False,
                    ).encode("utf-8")
                    + b"\n"
                )  # the b"\n" is important

        return StreamingResponse(ollama_resp(resp), media_type="application/json")
    else:
        return resp


@app.get("/health")
async def health_check():
    return {"status": "healthy"}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8020)
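Not part of the commit, but a quick way to exercise the route above once the server is running: a hypothetical streaming client. It assumes the app is reachable at localhost:8020 (as in the __main__ block) and uses the httpx library, which is an extra dependency the demo itself does not pull in.

# Hypothetical client for /ollama/api/chat -- not part of this commit.
import json

import httpx

payload = {
    "model": "LightRAG:latest",
    "stream": True,
    "messages": [{"role": "user", "content": "What are the top themes in this story?"}],
}

with httpx.stream(
    "POST", "http://localhost:8020/ollama/api/chat", json=payload, timeout=None
) as r:
    for line in r.iter_lines():  # each line is one NDJSON chunk emitted by ollama_resp()
        if not line:
            continue
        chunk = json.loads(line)
        print(chunk["message"]["content"], end="", flush=True)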
examples/lightrag_ollama_demo.py
CHANGED
@@ -1,4 +1,6 @@
+import asyncio
 import os
+import inspect
 import logging
 from lightrag import LightRAG, QueryParam
 from lightrag.llm import ollama_model_complete, ollama_embedding
@@ -49,3 +51,20 @@ print(
 print(
     rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
 )
+
+# stream response
+resp = rag.query(
+    "What are the top themes in this story?",
+    param=QueryParam(mode="hybrid", stream=True),
+)
+
+
+async def print_stream(stream):
+    async for chunk in stream:
+        print(chunk, end="", flush=True)
+
+
+if inspect.isasyncgen(resp):
+    asyncio.run(print_stream(resp))
+else:
+    print(resp)
lightrag/base.py
CHANGED
@@ -19,6 +19,7 @@ class QueryParam:
     only_need_context: bool = False
     only_need_prompt: bool = False
     response_type: str = "Multiple Paragraphs"
+    stream: bool = False
     # Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode.
     top_k: int = 60
     # Number of document chunks to retrieve.
lightrag/llm.py
CHANGED
@@ -27,7 +27,7 @@ from tenacity import (
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
 from pydantic import BaseModel, Field
-from typing import List, Dict, Callable, Any
+from typing import List, Dict, Callable, Any, Union
 from .base import BaseKVStorage
 from .utils import (
     compute_args_hash,
@@ -37,6 +37,13 @@ from .utils import (
     get_best_cached_response,
 )

+import sys
+
+if sys.version_info < (3, 9):
+    from typing import AsyncIterator
+else:
+    from collections.abc import AsyncIterator
+
 os.environ["TOKENIZERS_PARALLELISM"] = "false"


@@ -454,7 +461,8 @@ async def ollama_model_if_cache(
     system_prompt=None,
     history_messages=[],
     **kwargs,
-) -> str:
+) -> Union[str, AsyncIterator[str]]:
+    stream = True if kwargs.get("stream") else False
     kwargs.pop("max_tokens", None)
     # kwargs.pop("response_format", None) # allow json
     host = kwargs.pop("host", None)
@@ -494,28 +502,39 @@
         return if_cache_return["return"]

     response = await ollama_client.chat(model=model, messages=messages, **kwargs)
+    if stream:
+        """ cannot cache stream response """
+
+        async def inner():
+            async for chunk in response:
+                yield chunk["message"]["content"]
+
+        return inner()
+    else:
+        result = response["message"]["content"]
+        if hashing_kv is not None:
+            await hashing_kv.upsert(
+                {
+                    args_hash: {
+                        "return": result,
+                        "model": model,
+                        "embedding": quantized.tobytes().hex()
+                        if is_embedding_cache_enabled
+                        else None,
+                        "embedding_shape": quantized.shape
+                        if is_embedding_cache_enabled
+                        else None,
+                        "embedding_min": min_val
+                        if is_embedding_cache_enabled
+                        else None,
+                        "embedding_max": max_val
+                        if is_embedding_cache_enabled
+                        else None,
+                        "original_prompt": prompt,
+                    }
+                }
+            )
+        return result


 @lru_cache(maxsize=1)
@@ -785,7 +804,7 @@ async def hf_model_complete(

 async def ollama_model_complete(
     prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
-) -> str:
+) -> Union[str, AsyncIterator[str]]:
     keyword_extraction = kwargs.pop("keyword_extraction", None)
     if keyword_extraction:
         kwargs["format"] = "json"
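Because the streaming branch above returns before the hashing_kv upsert ("cannot cache stream response"), a caller that still wants the whole answer as one string has to assemble it itself. A minimal caller-side sketch, not something this commit adds:

# Collect a possibly-streamed response into a single string -- caller-side helper sketch.
import inspect


async def collect(resp) -> str:
    # resp is whatever ollama_model_complete(...) returned: a str, or an async generator of chunks
    if inspect.isasyncgen(resp):
        return "".join([chunk async for chunk in resp])
    return resp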
lightrag/operate.py
CHANGED
@@ -534,8 +534,9 @@ async def kg_query(
     response = await use_model_func(
         query,
         system_prompt=sys_prompt,
+        stream=query_param.stream,
     )
-    if len(response) > len(sys_prompt):
+    if isinstance(response, str) and len(response) > len(sys_prompt):
         response = (
             response.replace(sys_prompt, "")
             .replace("user", "")