partoneplay committed on
Commit
59bd4fe
·
1 Parent(s): 0114598

Add support for Ollama streaming output and integrate Open-WebUI as the chat UI demo

Browse files
examples/lightrag_api_open_webui_demo.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime, timezone
2
+ from fastapi import FastAPI
3
+ from fastapi.responses import StreamingResponse
4
+ import inspect
5
+ import json
6
+ from pydantic import BaseModel
7
+ from typing import Optional
8
+
9
+ import os
10
+ import logging
11
+ from lightrag import LightRAG, QueryParam
12
+ from lightrag.llm import ollama_model_complete, ollama_embed
13
+ from lightrag.utils import EmbeddingFunc
14
+
15
+ import nest_asyncio
16
+
17
+ WORKING_DIR = "./dickens"
18
+
19
+ logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO)
20
+
21
+ if not os.path.exists(WORKING_DIR):
22
+ os.mkdir(WORKING_DIR)
23
+
24
+ rag = LightRAG(
25
+ working_dir=WORKING_DIR,
26
+ llm_model_func=ollama_model_complete,
27
+ llm_model_name="qwen2.5:latest",
28
+ llm_model_max_async=4,
29
+ llm_model_max_token_size=32768,
30
+ llm_model_kwargs={"host": "http://localhost:11434", "options": {"num_ctx": 32768}},
31
+ embedding_func=EmbeddingFunc(
32
+ embedding_dim=1024,
33
+ max_token_size=8192,
34
+ func=lambda texts: ollama_embed(
35
+ texts=texts, embed_model="bge-m3:latest", host="http://127.0.0.1:11434"
36
+ ),
37
+ ),
38
+ )
39
+
40
+ with open("./book.txt", "r", encoding="utf-8") as f:
41
+ rag.insert(f.read())
42
+
43
+ # Apply nest_asyncio to solve event loop issues
44
+ nest_asyncio.apply()
45
+
46
+ app = FastAPI(title="LightRAG", description="LightRAG API open-webui")
47
+
48
+
49
+ # Data models
50
+ MODEL_NAME = "LightRAG:latest"
51
+
52
+
53
+ class Message(BaseModel):
54
+ role: Optional[str] = None
55
+ content: str
56
+
57
+
58
+ class OpenWebUIRequest(BaseModel):
59
+ stream: Optional[bool] = None
60
+ model: Optional[str] = None
61
+ messages: list[Message]
62
+
63
+
64
+ # API routes
65
+
66
+
67
+ @app.get("/")
68
+ async def index():
69
+ return "Set Ollama link to http://ip:port/ollama in Open-WebUI Settings"
70
+
71
+
72
+ @app.get("/ollama/api/version")
73
+ async def ollama_version():
74
+ return {"version": "0.4.7"}
75
+
76
+
77
+ @app.get("/ollama/api/tags")
78
+ async def ollama_tags():
79
+ return {
80
+ "models": [
81
+ {
82
+ "name": MODEL_NAME,
83
+ "model": MODEL_NAME,
84
+ "modified_at": "2024-11-12T20:22:37.561463923+08:00",
85
+ "size": 4683087332,
86
+ "digest": "845dbda0ea48ed749caafd9e6037047aa19acfcfd82e704d7ca97d631a0b697e",
87
+ "details": {
88
+ "parent_model": "",
89
+ "format": "gguf",
90
+ "family": "qwen2",
91
+ "families": ["qwen2"],
92
+ "parameter_size": "7.6B",
93
+ "quantization_level": "Q4_K_M",
94
+ },
95
+ }
96
+ ]
97
+ }
98
+
99
+
100
+ @app.post("/ollama/api/chat")
101
+ async def ollama_chat(request: OpenWebUIRequest):
102
+ resp = rag.query(
103
+ request.messages[-1].content, param=QueryParam(mode="hybrid", stream=True)
104
+ )
105
+ if inspect.isasyncgen(resp):
106
+
107
+ async def ollama_resp(chunks):
108
+ async for chunk in chunks:
109
+ yield (
110
+ json.dumps(
111
+ {
112
+ "model": MODEL_NAME,
113
+ "created_at": datetime.now(timezone.utc).strftime(
114
+ "%Y-%m-%dT%H:%M:%S.%fZ"
115
+ ),
116
+ "message": {
117
+ "role": "assistant",
118
+ "content": chunk,
119
+ },
120
+ "done": False,
121
+ },
122
+ ensure_ascii=False,
123
+ ).encode("utf-8")
124
+ + b"\n"
125
+ ) # the b"\n" is important
126
+
127
+ return StreamingResponse(ollama_resp(resp), media_type="application/json")
128
+ else:
129
+ return resp
130
+
131
+
132
+ @app.get("/health")
133
+ async def health_check():
134
+ return {"status": "healthy"}
135
+
136
+
137
+ if __name__ == "__main__":
138
+ import uvicorn
139
+
140
+ uvicorn.run(app, host="0.0.0.0", port=8020)
examples/lightrag_ollama_demo.py CHANGED
@@ -1,4 +1,6 @@
 
1
  import os
 
2
  import logging
3
  from lightrag import LightRAG, QueryParam
4
  from lightrag.llm import ollama_model_complete, ollama_embedding
@@ -49,3 +51,20 @@ print(
49
  print(
50
  rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
51
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
  import os
3
+ import inspect
4
  import logging
5
  from lightrag import LightRAG, QueryParam
6
  from lightrag.llm import ollama_model_complete, ollama_embedding
 
51
  print(
52
  rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
53
  )
54
+
55
+ # stream response
56
+ resp = rag.query(
57
+ "What are the top themes in this story?",
58
+ param=QueryParam(mode="hybrid", stream=True),
59
+ )
60
+
61
+
62
+ async def print_stream(stream):
63
+ async for chunk in stream:
64
+ print(chunk, end="", flush=True)
65
+
66
+
67
+ if inspect.isasyncgen(resp):
68
+ asyncio.run(print_stream(resp))
69
+ else:
70
+ print(resp)
lightrag/base.py CHANGED
@@ -19,6 +19,7 @@ class QueryParam:
19
  only_need_context: bool = False
20
  only_need_prompt: bool = False
21
  response_type: str = "Multiple Paragraphs"
 
22
  # Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode.
23
  top_k: int = 60
24
  # Number of document chunks to retrieve.
 
19
  only_need_context: bool = False
20
  only_need_prompt: bool = False
21
  response_type: str = "Multiple Paragraphs"
22
+ stream: bool = False
23
  # Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode.
24
  top_k: int = 60
25
  # Number of document chunks to retrieve.
lightrag/llm.py CHANGED
@@ -27,7 +27,7 @@ from tenacity import (
27
  from transformers import AutoTokenizer, AutoModelForCausalLM
28
  import torch
29
  from pydantic import BaseModel, Field
30
- from typing import List, Dict, Callable, Any
31
  from .base import BaseKVStorage
32
  from .utils import (
33
  compute_args_hash,
@@ -37,6 +37,13 @@ from .utils import (
37
  get_best_cached_response,
38
  )
39
 
 
 
 
 
 
 
 
40
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
41
 
42
 
@@ -454,7 +461,8 @@ async def ollama_model_if_cache(
454
  system_prompt=None,
455
  history_messages=[],
456
  **kwargs,
457
- ) -> str:
 
458
  kwargs.pop("max_tokens", None)
459
  # kwargs.pop("response_format", None) # allow json
460
  host = kwargs.pop("host", None)
@@ -494,28 +502,39 @@ async def ollama_model_if_cache(
494
  return if_cache_return["return"]
495
 
496
  response = await ollama_client.chat(model=model, messages=messages, **kwargs)
 
 
497
 
498
- result = response["message"]["content"]
 
 
499
 
500
- if hashing_kv is not None:
501
- await hashing_kv.upsert(
502
- {
503
- args_hash: {
504
- "return": result,
505
- "model": model,
506
- "embedding": quantized.tobytes().hex()
507
- if is_embedding_cache_enabled
508
- else None,
509
- "embedding_shape": quantized.shape
510
- if is_embedding_cache_enabled
511
- else None,
512
- "embedding_min": min_val if is_embedding_cache_enabled else None,
513
- "embedding_max": max_val if is_embedding_cache_enabled else None,
514
- "original_prompt": prompt,
 
 
 
 
 
 
 
 
515
  }
516
- }
517
- )
518
- return result
519
 
520
 
521
  @lru_cache(maxsize=1)
@@ -785,7 +804,7 @@ async def hf_model_complete(
785
 
786
  async def ollama_model_complete(
787
  prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
788
- ) -> str:
789
  keyword_extraction = kwargs.pop("keyword_extraction", None)
790
  if keyword_extraction:
791
  kwargs["format"] = "json"
 
27
  from transformers import AutoTokenizer, AutoModelForCausalLM
28
  import torch
29
  from pydantic import BaseModel, Field
30
+ from typing import List, Dict, Callable, Any, Union
31
  from .base import BaseKVStorage
32
  from .utils import (
33
  compute_args_hash,
 
37
  get_best_cached_response,
38
  )
39
 
40
+ import sys
41
+
42
+ if sys.version_info < (3, 9):
43
+ from typing import AsyncIterator
44
+ else:
45
+ from collections.abc import AsyncIterator
46
+
47
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
48
 
49
 
 
461
  system_prompt=None,
462
  history_messages=[],
463
  **kwargs,
464
+ ) -> Union[str, AsyncIterator[str]]:
465
+ stream = True if kwargs.get("stream") else False
466
  kwargs.pop("max_tokens", None)
467
  # kwargs.pop("response_format", None) # allow json
468
  host = kwargs.pop("host", None)
 
502
  return if_cache_return["return"]
503
 
504
  response = await ollama_client.chat(model=model, messages=messages, **kwargs)
505
+ if stream:
506
+ """ cannot cache stream response """
507
 
508
+ async def inner():
509
+ async for chunk in response:
510
+ yield chunk["message"]["content"]
511
 
512
+ return inner()
513
+ else:
514
+ result = response["message"]["content"]
515
+ if hashing_kv is not None:
516
+ await hashing_kv.upsert(
517
+ {
518
+ args_hash: {
519
+ "return": result,
520
+ "model": model,
521
+ "embedding": quantized.tobytes().hex()
522
+ if is_embedding_cache_enabled
523
+ else None,
524
+ "embedding_shape": quantized.shape
525
+ if is_embedding_cache_enabled
526
+ else None,
527
+ "embedding_min": min_val
528
+ if is_embedding_cache_enabled
529
+ else None,
530
+ "embedding_max": max_val
531
+ if is_embedding_cache_enabled
532
+ else None,
533
+ "original_prompt": prompt,
534
+ }
535
  }
536
+ )
537
+ return result
 
538
 
539
 
540
  @lru_cache(maxsize=1)
 
804
 
805
  async def ollama_model_complete(
806
  prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
807
+ ) -> Union[str, AsyncIterator[str]]:
808
  keyword_extraction = kwargs.pop("keyword_extraction", None)
809
  if keyword_extraction:
810
  kwargs["format"] = "json"
lightrag/operate.py CHANGED
@@ -534,8 +534,9 @@ async def kg_query(
534
  response = await use_model_func(
535
  query,
536
  system_prompt=sys_prompt,
 
537
  )
538
- if len(response) > len(sys_prompt):
539
  response = (
540
  response.replace(sys_prompt, "")
541
  .replace("user", "")
 
534
  response = await use_model_func(
535
  query,
536
  system_prompt=sys_prompt,
537
+ stream=query_param.stream,
538
  )
539
+ if isinstance(response, str) and len(response) > len(sys_prompt):
540
  response = (
541
  response.replace(sys_prompt, "")
542
  .replace("user", "")