Magicyuan committed on
Commit d979af0 · 2 Parent(s): 55d34bc c52ae01

Merge remote-tracking branch 'origin/main'

# Conflicts:
# lightrag/llm.py
# lightrag/operate.py

examples/lightrag_api_open_webui_demo.py ADDED
@@ -0,0 +1,140 @@
+ from datetime import datetime, timezone
+ from fastapi import FastAPI
+ from fastapi.responses import StreamingResponse
+ import inspect
+ import json
+ from pydantic import BaseModel
+ from typing import Optional
+
+ import os
+ import logging
+ from lightrag import LightRAG, QueryParam
+ from lightrag.llm import ollama_model_complete, ollama_embed
+ from lightrag.utils import EmbeddingFunc
+
+ import nest_asyncio
+
+ WORKING_DIR = "./dickens"
+
+ logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO)
+
+ if not os.path.exists(WORKING_DIR):
+     os.mkdir(WORKING_DIR)
+
+ rag = LightRAG(
+     working_dir=WORKING_DIR,
+     llm_model_func=ollama_model_complete,
+     llm_model_name="qwen2.5:latest",
+     llm_model_max_async=4,
+     llm_model_max_token_size=32768,
+     llm_model_kwargs={"host": "http://localhost:11434", "options": {"num_ctx": 32768}},
+     embedding_func=EmbeddingFunc(
+         embedding_dim=1024,
+         max_token_size=8192,
+         func=lambda texts: ollama_embed(
+             texts=texts, embed_model="bge-m3:latest", host="http://127.0.0.1:11434"
+         ),
+     ),
+ )
+
+ with open("./book.txt", "r", encoding="utf-8") as f:
+     rag.insert(f.read())
+
+ # Apply nest_asyncio to solve event loop issues
+ nest_asyncio.apply()
+
+ app = FastAPI(title="LightRAG", description="LightRAG API open-webui")
+
+
+ # Data models
+ MODEL_NAME = "LightRAG:latest"
+
+
+ class Message(BaseModel):
+     role: Optional[str] = None
+     content: str
+
+
+ class OpenWebUIRequest(BaseModel):
+     stream: Optional[bool] = None
+     model: Optional[str] = None
+     messages: list[Message]
+
+
+ # API routes
+
+
+ @app.get("/")
+ async def index():
+     return "Set Ollama link to http://ip:port/ollama in Open-WebUI Settings"
+
+
+ @app.get("/ollama/api/version")
+ async def ollama_version():
+     return {"version": "0.4.7"}
+
+
+ @app.get("/ollama/api/tags")
+ async def ollama_tags():
+     return {
+         "models": [
+             {
+                 "name": MODEL_NAME,
+                 "model": MODEL_NAME,
+                 "modified_at": "2024-11-12T20:22:37.561463923+08:00",
+                 "size": 4683087332,
+                 "digest": "845dbda0ea48ed749caafd9e6037047aa19acfcfd82e704d7ca97d631a0b697e",
+                 "details": {
+                     "parent_model": "",
+                     "format": "gguf",
+                     "family": "qwen2",
+                     "families": ["qwen2"],
+                     "parameter_size": "7.6B",
+                     "quantization_level": "Q4_K_M",
+                 },
+             }
+         ]
+     }
+
+
+ @app.post("/ollama/api/chat")
+ async def ollama_chat(request: OpenWebUIRequest):
+     resp = rag.query(
+         request.messages[-1].content, param=QueryParam(mode="hybrid", stream=True)
+     )
+     if inspect.isasyncgen(resp):
+
+         async def ollama_resp(chunks):
+             async for chunk in chunks:
+                 yield (
+                     json.dumps(
+                         {
+                             "model": MODEL_NAME,
+                             "created_at": datetime.now(timezone.utc).strftime(
+                                 "%Y-%m-%dT%H:%M:%S.%fZ"
+                             ),
+                             "message": {
+                                 "role": "assistant",
+                                 "content": chunk,
+                             },
+                             "done": False,
+                         },
+                         ensure_ascii=False,
+                     ).encode("utf-8")
+                     + b"\n"
+                 )  # the b"\n" is important
+
+         return StreamingResponse(ollama_resp(resp), media_type="application/json")
+     else:
+         return resp
+
+
+ @app.get("/health")
+ async def health_check():
+     return {"status": "healthy"}
+
+
+ if __name__ == "__main__":
+     import uvicorn
+
+     uvicorn.run(app, host="0.0.0.0", port=8020)
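
For reference, a minimal way to exercise the new /ollama/api/chat route once this demo is running (uvicorn on port 8020, as configured above) is a small streaming client. This snippet is not part of the commit; it is a sketch that assumes the requests package is installed and simply consumes the newline-delimited JSON chunks produced by ollama_resp:

import json
import requests  # assumption: requests is available; the demo itself does not use it

payload = {
    "model": "LightRAG:latest",
    "stream": True,
    "messages": [{"role": "user", "content": "What are the top themes in this story?"}],
}
with requests.post(
    "http://localhost:8020/ollama/api/chat", json=payload, stream=True
) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines():
        if not line:
            continue
        chunk = json.loads(line)  # one JSON object per line, as emitted by ollama_chat
        print(chunk["message"]["content"], end="", flush=True)
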
examples/lightrag_ollama_demo.py CHANGED
@@ -1,4 +1,6 @@
+ import asyncio
  import os
+ import inspect
  import logging
  from lightrag import LightRAG, QueryParam
  from lightrag.llm import ollama_model_complete, ollama_embedding
@@ -49,3 +51,20 @@ print(
  print(
      rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
  )
+
+ # stream response
+ resp = rag.query(
+     "What are the top themes in this story?",
+     param=QueryParam(mode="hybrid", stream=True),
+ )
+
+
+ async def print_stream(stream):
+     async for chunk in stream:
+         print(chunk, end="", flush=True)
+
+
+ if inspect.isasyncgen(resp):
+     asyncio.run(print_stream(resp))
+ else:
+     print(resp)
lightrag/base.py CHANGED
@@ -19,6 +19,7 @@ class QueryParam:
      only_need_context: bool = False
      only_need_prompt: bool = False
      response_type: str = "Multiple Paragraphs"
+     stream: bool = False
      # Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode.
      top_k: int = 60
      # Number of document chunks to retrieve.
lightrag/kg/oracle_impl.py CHANGED
@@ -143,7 +143,7 @@ class OracleDB:
              data = None
          return data

-     async def execute(self, sql: str, data: list | dict = None):
+     async def execute(self, sql: str, data: Union[list, dict] = None):
          # logger.info("go into OracleDB execute method")
          try:
              async with self.pool.acquire() as connection:
lightrag/llm.py CHANGED
@@ -4,8 +4,7 @@ import json
  import os
  import struct
  from functools import lru_cache
- from typing import List, Dict, Callable, Any, Optional
- from dataclasses import dataclass
+ from typing import List, Dict, Callable, Any, Union

  import aioboto3
  import aiohttp
@@ -37,6 +36,13 @@ from .utils import (
      get_best_cached_response,
  )

+ import sys
+
+ if sys.version_info < (3, 9):
+     from typing import AsyncIterator
+ else:
+     from collections.abc import AsyncIterator
+
  os.environ["TOKENIZERS_PARALLELISM"] = "false"


@@ -397,7 +403,8 @@ async def ollama_model_if_cache(
      system_prompt=None,
      history_messages=[],
      **kwargs,
- ) -> str:
+ ) -> Union[str, AsyncIterator[str]]:
+     stream = True if kwargs.get("stream") else False
      kwargs.pop("max_tokens", None)
      # kwargs.pop("response_format", None) # allow json
      host = kwargs.pop("host", None)
@@ -422,7 +429,31 @@ async def ollama_model_if_cache(
          return cached_response

      response = await ollama_client.chat(model=model, messages=messages, **kwargs)
+     if stream:
+         """ cannot cache stream response """
+
+         async def inner():
+             async for chunk in response:
+                 yield chunk["message"]["content"]
+
+         return inner()
+     else:
+         result = response["message"]["content"]
+         # Save to cache
+         await save_to_cache(
+             hashing_kv,
+             CacheData(
+                 args_hash=args_hash,
+                 content=result,
+                 model=model,
+                 prompt=prompt,
+                 quantized=quantized,
+                 min_val=min_val,
+                 max_val=max_val,
+                 mode=mode,
+             ),
+         )
+         return result
      result = response["message"]["content"]

      # Save to cache
@@ -697,7 +728,7 @@ async def hf_model_complete(

  async def ollama_model_complete(
      prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
- ) -> str:
+ ) -> Union[str, AsyncIterator[str]]:
      keyword_extraction = kwargs.pop("keyword_extraction", None)
      if keyword_extraction:
          kwargs["format"] = "json"
lightrag/operate.py CHANGED
@@ -536,9 +536,10 @@ async def kg_query(
      response = await use_model_func(
          query,
          system_prompt=sys_prompt,
+         stream=query_param.stream,
          mode=query_param.mode,
      )
-     if len(response) > len(sys_prompt):
+     if isinstance(response, str) and len(response) > len(sys_prompt):
          response = (
              response.replace(sys_prompt, "")
              .replace("user", "")
  .replace("user", "")