zrguo committed
Commit c52ae01 · 2 Parent(s): e15ba5c 1673f79

Merge pull request #407 from partoneplay/main


Add support for Ollama streaming output and integrate Open-WebUI as the chat UI demo
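
In short, QueryParam gains a stream flag, and when it is set rag.query() may return an async iterator of text chunks instead of a complete string. A minimal sketch of the calling pattern, condensed from the demo changes below (it assumes an already-initialized rag instance):

    import asyncio
    import inspect

    from lightrag import QueryParam

    # `rag` is an initialized LightRAG instance (see the demo files below).
    resp = rag.query(
        "What are the top themes in this story?",
        param=QueryParam(mode="hybrid", stream=True),
    )

    async def print_stream(stream):
        async for chunk in stream:
            print(chunk, end="", flush=True)

    if inspect.isasyncgen(resp):
        # Streaming backends (here: Ollama) hand back an async generator of chunks.
        asyncio.run(print_stream(resp))
    else:
        # Non-streaming model functions still return the full answer as a plain string.
        print(resp)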

examples/lightrag_api_open_webui_demo.py ADDED
@@ -0,0 +1,140 @@
+from datetime import datetime, timezone
+from fastapi import FastAPI
+from fastapi.responses import StreamingResponse
+import inspect
+import json
+from pydantic import BaseModel
+from typing import Optional
+
+import os
+import logging
+from lightrag import LightRAG, QueryParam
+from lightrag.llm import ollama_model_complete, ollama_embed
+from lightrag.utils import EmbeddingFunc
+
+import nest_asyncio
+
+WORKING_DIR = "./dickens"
+
+logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO)
+
+if not os.path.exists(WORKING_DIR):
+    os.mkdir(WORKING_DIR)
+
+rag = LightRAG(
+    working_dir=WORKING_DIR,
+    llm_model_func=ollama_model_complete,
+    llm_model_name="qwen2.5:latest",
+    llm_model_max_async=4,
+    llm_model_max_token_size=32768,
+    llm_model_kwargs={"host": "http://localhost:11434", "options": {"num_ctx": 32768}},
+    embedding_func=EmbeddingFunc(
+        embedding_dim=1024,
+        max_token_size=8192,
+        func=lambda texts: ollama_embed(
+            texts=texts, embed_model="bge-m3:latest", host="http://127.0.0.1:11434"
+        ),
+    ),
+)
+
+with open("./book.txt", "r", encoding="utf-8") as f:
+    rag.insert(f.read())
+
+# Apply nest_asyncio to solve event loop issues
+nest_asyncio.apply()
+
+app = FastAPI(title="LightRAG", description="LightRAG API open-webui")
+
+
+# Data models
+MODEL_NAME = "LightRAG:latest"
+
+
+class Message(BaseModel):
+    role: Optional[str] = None
+    content: str
+
+
+class OpenWebUIRequest(BaseModel):
+    stream: Optional[bool] = None
+    model: Optional[str] = None
+    messages: list[Message]
+
+
+# API routes
+
+
+@app.get("/")
+async def index():
+    return "Set Ollama link to http://ip:port/ollama in Open-WebUI Settings"
+
+
+@app.get("/ollama/api/version")
+async def ollama_version():
+    return {"version": "0.4.7"}
+
+
+@app.get("/ollama/api/tags")
+async def ollama_tags():
+    return {
+        "models": [
+            {
+                "name": MODEL_NAME,
+                "model": MODEL_NAME,
+                "modified_at": "2024-11-12T20:22:37.561463923+08:00",
+                "size": 4683087332,
+                "digest": "845dbda0ea48ed749caafd9e6037047aa19acfcfd82e704d7ca97d631a0b697e",
+                "details": {
+                    "parent_model": "",
+                    "format": "gguf",
+                    "family": "qwen2",
+                    "families": ["qwen2"],
+                    "parameter_size": "7.6B",
+                    "quantization_level": "Q4_K_M",
+                },
+            }
+        ]
+    }
+
+
+@app.post("/ollama/api/chat")
+async def ollama_chat(request: OpenWebUIRequest):
+    resp = rag.query(
+        request.messages[-1].content, param=QueryParam(mode="hybrid", stream=True)
+    )
+    if inspect.isasyncgen(resp):
+
+        async def ollama_resp(chunks):
+            async for chunk in chunks:
+                yield (
+                    json.dumps(
+                        {
+                            "model": MODEL_NAME,
+                            "created_at": datetime.now(timezone.utc).strftime(
+                                "%Y-%m-%dT%H:%M:%S.%fZ"
+                            ),
+                            "message": {
+                                "role": "assistant",
+                                "content": chunk,
+                            },
+                            "done": False,
+                        },
+                        ensure_ascii=False,
+                    ).encode("utf-8")
+                    + b"\n"
+                )  # the b"\n" is important
+
+        return StreamingResponse(ollama_resp(resp), media_type="application/json")
+    else:
+        return resp
+
+
+@app.get("/health")
+async def health_check():
+    return {"status": "healthy"}
+
+
+if __name__ == "__main__":
+    import uvicorn
+
+    uvicorn.run(app, host="0.0.0.0", port=8020)
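
To sanity-check the new chat route without a full Open-WebUI install, a small client can read the newline-delimited JSON emitted above. This is only an illustrative sketch: the host/port and payload shape mirror this demo, and requests is an assumed extra dependency (any HTTP client with streaming support works):

    import json
    import requests

    payload = {
        "model": "LightRAG:latest",
        "stream": True,
        "messages": [{"role": "user", "content": "What are the top themes in this story?"}],
    }
    with requests.post(
        "http://localhost:8020/ollama/api/chat", json=payload, stream=True
    ) as resp:
        for line in resp.iter_lines():  # one JSON object per line, terminated by b"\n"
            if line:
                print(json.loads(line)["message"]["content"], end="", flush=True)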
examples/lightrag_ollama_demo.py CHANGED
@@ -1,4 +1,6 @@
+import asyncio
 import os
+import inspect
 import logging
 from lightrag import LightRAG, QueryParam
 from lightrag.llm import ollama_model_complete, ollama_embedding
@@ -49,3 +51,20 @@ print(
 print(
     rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
 )
+
+# stream response
+resp = rag.query(
+    "What are the top themes in this story?",
+    param=QueryParam(mode="hybrid", stream=True),
+)
+
+
+async def print_stream(stream):
+    async for chunk in stream:
+        print(chunk, end="", flush=True)
+
+
+if inspect.isasyncgen(resp):
+    asyncio.run(print_stream(resp))
+else:
+    print(resp)
lightrag/base.py CHANGED
@@ -19,6 +19,7 @@ class QueryParam:
     only_need_context: bool = False
     only_need_prompt: bool = False
     response_type: str = "Multiple Paragraphs"
+    stream: bool = False
     # Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode.
     top_k: int = 60
     # Number of document chunks to retrieve.
lightrag/llm.py CHANGED
@@ -4,7 +4,7 @@ import json
 import os
 import struct
 from functools import lru_cache
-from typing import List, Dict, Callable, Any
+from typing import List, Dict, Callable, Any, Union
 
 import aioboto3
 import aiohttp
@@ -36,6 +36,13 @@ from .utils import (
     get_best_cached_response,
 )
 
+import sys
+
+if sys.version_info < (3, 9):
+    from typing import AsyncIterator
+else:
+    from collections.abc import AsyncIterator
+
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
 
@@ -474,7 +481,8 @@ async def ollama_model_if_cache(
     system_prompt=None,
     history_messages=[],
     **kwargs,
-) -> str:
+) -> Union[str, AsyncIterator[str]]:
+    stream = True if kwargs.get("stream") else False
     kwargs.pop("max_tokens", None)
     # kwargs.pop("response_format", None) # allow json
     host = kwargs.pop("host", None)
@@ -517,28 +525,39 @@
         return if_cache_return["return"]
 
     response = await ollama_client.chat(model=model, messages=messages, **kwargs)
+    if stream:
+        """ cannot cache stream response """
 
-    result = response["message"]["content"]
+        async def inner():
+            async for chunk in response:
+                yield chunk["message"]["content"]
 
-    if hashing_kv is not None:
-        await hashing_kv.upsert(
-            {
-                args_hash: {
-                    "return": result,
-                    "model": model,
-                    "embedding": quantized.tobytes().hex()
-                    if is_embedding_cache_enabled
-                    else None,
-                    "embedding_shape": quantized.shape
-                    if is_embedding_cache_enabled
-                    else None,
-                    "embedding_min": min_val if is_embedding_cache_enabled else None,
-                    "embedding_max": max_val if is_embedding_cache_enabled else None,
-                    "original_prompt": prompt,
+        return inner()
+    else:
+        result = response["message"]["content"]
+        if hashing_kv is not None:
+            await hashing_kv.upsert(
+                {
+                    args_hash: {
+                        "return": result,
+                        "model": model,
+                        "embedding": quantized.tobytes().hex()
+                        if is_embedding_cache_enabled
+                        else None,
+                        "embedding_shape": quantized.shape
+                        if is_embedding_cache_enabled
+                        else None,
+                        "embedding_min": min_val
+                        if is_embedding_cache_enabled
+                        else None,
+                        "embedding_max": max_val
+                        if is_embedding_cache_enabled
+                        else None,
+                        "original_prompt": prompt,
+                    }
                 }
-            }
-        )
-    return result
+            )
+        return result
 
 
 @lru_cache(maxsize=1)
@@ -816,7 +835,7 @@ async def hf_model_complete(
 
 async def ollama_model_complete(
     prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
-) -> str:
+) -> Union[str, AsyncIterator[str]]:
    keyword_extraction = kwargs.pop("keyword_extraction", None)
    if keyword_extraction:
        kwargs["format"] = "json"
lightrag/operate.py CHANGED
@@ -534,8 +534,9 @@ async def kg_query(
     response = await use_model_func(
         query,
         system_prompt=sys_prompt,
+        stream=query_param.stream,
     )
-    if len(response) > len(sys_prompt):
+    if isinstance(response, str) and len(response) > len(sys_prompt):
         response = (
             response.replace(sys_prompt, "")
             .replace("user", "")