zrguo committed
Commit 9d0622c (unverified)
2 Parent(s): a374cd8 68834c7

Merge pull request #592 from danielaskdd/yangdx

README.md CHANGED
@@ -716,7 +716,7 @@ Output the results in the following structure:
  ```
  </details>
 
- ### Batch Eval
+ ### Batch Eval
  To evaluate the performance of two RAG systems on high-level queries, LightRAG uses the following prompt, with the specific code available in `example/batch_eval.py`.
 
  <details>
@@ -767,6 +767,7 @@ Output your evaluation in the following JSON format:
  </details>
 
  ### Overall Performance Table
+
  | | **Agriculture** | | **CS** | | **Legal** | | **Mix** | |
  |----------------------|-------------------------|-----------------------|-----------------------|-----------------------|-----------------------|-----------------------|-----------------------|-----------------------|
  | | NaiveRAG | **LightRAG** | NaiveRAG | **LightRAG** | NaiveRAG | **LightRAG** | NaiveRAG | **LightRAG** |
lightrag/api/lightrag_ollama.py ADDED
@@ -0,0 +1,924 @@
1
+ from fastapi import FastAPI, HTTPException, File, UploadFile, Form, Request
2
+ from pydantic import BaseModel
3
+ import logging
4
+ import argparse
5
+ import json
6
+ import time
7
+ import re
8
+ from typing import List, Dict, Any, Optional
9
+ from lightrag import LightRAG, QueryParam
10
+ from lightrag.llm import openai_complete_if_cache, ollama_embedding
11
+
12
+ from lightrag.utils import EmbeddingFunc
13
+ from enum import Enum
14
+ from pathlib import Path
15
+ import shutil
16
+ import aiofiles
17
+ from ascii_colors import trace_exception
18
+ import os
19
+
20
+ from fastapi import Depends, Security
21
+ from fastapi.security import APIKeyHeader
22
+ from fastapi.middleware.cors import CORSMiddleware
23
+
24
+ from starlette.status import HTTP_403_FORBIDDEN
25
+
26
+ from dotenv import load_dotenv
27
+
28
+ load_dotenv()
29
+
30
+
31
+ def estimate_tokens(text: str) -> int:
32
+ """Estimate the number of tokens in text
33
+ Chinese characters: approximately 1.5 tokens per character
34
+ English characters: approximately 0.25 tokens per character
35
+ """
36
+ # Use regex to match Chinese and non-Chinese characters separately
37
+ chinese_chars = len(re.findall(r"[\u4e00-\u9fff]", text))
38
+ non_chinese_chars = len(re.findall(r"[^\u4e00-\u9fff]", text))
39
+
40
+ # Calculate estimated token count
41
+ tokens = chinese_chars * 1.5 + non_chinese_chars * 0.25
42
+
43
+ return int(tokens)
44
+
45
+
46
+ # Constants for model information
47
+ LIGHTRAG_NAME = "lightrag"
48
+ LIGHTRAG_TAG = "latest"
49
+ LIGHTRAG_MODEL = "lightrag:latest"
50
+ LIGHTRAG_SIZE = 7365960935
51
+ LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z"
52
+ LIGHTRAG_DIGEST = "sha256:lightrag"
53
+
54
+
55
+ async def llm_model_func(
56
+ prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
57
+ ) -> str:
58
+ return await openai_complete_if_cache(
59
+ "deepseek-chat",
60
+ prompt,
61
+ system_prompt=system_prompt,
62
+ history_messages=history_messages,
63
+ api_key=os.getenv("DEEPSEEK_API_KEY"),
64
+ base_url=os.getenv("DEEPSEEK_ENDPOINT"),
65
+ **kwargs,
66
+ )
67
+
68
+
69
+ def get_default_host(binding_type: str) -> str:
70
+ default_hosts = {
71
+ "ollama": "http://m4.lan.znipower.com:11434",
72
+ "lollms": "http://localhost:9600",
73
+ "azure_openai": "https://api.openai.com/v1",
74
+ "openai": os.getenv("DEEPSEEK_ENDPOINT"),
75
+ }
76
+ return default_hosts.get(
77
+ binding_type, "http://localhost:11434"
78
+ ) # fallback to ollama if unknown
79
+
80
+
81
+ def parse_args():
82
+ parser = argparse.ArgumentParser(
83
+ description="LightRAG FastAPI Server with separate working and input directories"
84
+ )
85
+
86
+ # Start by the bindings
87
+ parser.add_argument(
88
+ "--llm-binding",
89
+ default="ollama",
90
+ help="LLM binding to be used. Supported: lollms, ollama, openai (default: ollama)",
91
+ )
92
+ parser.add_argument(
93
+ "--embedding-binding",
94
+ default="ollama",
95
+ help="Embedding binding to be used. Supported: lollms, ollama, openai (default: ollama)",
96
+ )
97
+
98
+ # Parse just these arguments first
99
+ temp_args, _ = parser.parse_known_args()
100
+
101
+ # Add remaining arguments with dynamic defaults for hosts
102
+ # Server configuration
103
+ parser.add_argument(
104
+ "--host", default="0.0.0.0", help="Server host (default: 0.0.0.0)"
105
+ )
106
+ parser.add_argument(
107
+ "--port", type=int, default=9621, help="Server port (default: 9621)"
108
+ )
109
+
110
+ # Directory configuration
111
+ parser.add_argument(
112
+ "--working-dir",
113
+ default="./rag_storage",
114
+ help="Working directory for RAG storage (default: ./rag_storage)",
115
+ )
116
+ parser.add_argument(
117
+ "--input-dir",
118
+ default="./inputs",
119
+ help="Directory containing input documents (default: ./inputs)",
120
+ )
121
+
122
+ # LLM Model configuration
123
+ default_llm_host = get_default_host(temp_args.llm_binding)
124
+ parser.add_argument(
125
+ "--llm-binding-host",
126
+ default=default_llm_host,
127
+ help=f"llm server host URL (default: {default_llm_host})",
128
+ )
129
+
130
+ parser.add_argument(
131
+ "--llm-model",
132
+ default="mistral-nemo:latest",
133
+ help="LLM model name (default: mistral-nemo:latest)",
134
+ )
135
+
136
+ # Embedding model configuration
137
+ default_embedding_host = get_default_host(temp_args.embedding_binding)
138
+ parser.add_argument(
139
+ "--embedding-binding-host",
140
+ default=default_embedding_host,
141
+ help=f"embedding server host URL (default: {default_embedding_host})",
142
+ )
143
+
144
+ parser.add_argument(
145
+ "--embedding-model",
146
+ default="bge-m3:latest",
147
+ help="Embedding model name (default: bge-m3:latest)",
148
+ )
149
+
150
+ def timeout_type(value):
151
+ if value is None or value == "None":
152
+ return None
153
+ return int(value)
154
+
155
+ parser.add_argument(
156
+ "--timeout",
157
+ default=None,
158
+ type=timeout_type,
159
+ help="Timeout in seconds (useful when using slow AI). Use None for infinite timeout",
160
+ )
161
+ # RAG configuration
162
+ parser.add_argument(
163
+ "--max-async", type=int, default=4, help="Maximum async operations (default: 4)"
164
+ )
165
+ parser.add_argument(
166
+ "--max-tokens",
167
+ type=int,
168
+ default=32768,
169
+ help="Maximum token size (default: 32768)",
170
+ )
171
+ parser.add_argument(
172
+ "--embedding-dim",
173
+ type=int,
174
+ default=1024,
175
+ help="Embedding dimensions (default: 1024)",
176
+ )
177
+ parser.add_argument(
178
+ "--max-embed-tokens",
179
+ type=int,
180
+ default=8192,
181
+ help="Maximum embedding token size (default: 8192)",
182
+ )
183
+
184
+ # Logging configuration
185
+ parser.add_argument(
186
+ "--log-level",
187
+ default="INFO",
188
+ choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
189
+ help="Logging level (default: INFO)",
190
+ )
191
+
192
+ parser.add_argument(
193
+ "--key",
194
+ type=str,
195
+ help="API key for authentication. This protects lightrag server against unauthorized access",
196
+ default=None,
197
+ )
198
+
199
+ # Optional https parameters
200
+ parser.add_argument(
201
+ "--ssl", action="store_true", help="Enable HTTPS (default: False)"
202
+ )
203
+ parser.add_argument(
204
+ "--ssl-certfile",
205
+ default=None,
206
+ help="Path to SSL certificate file (required if --ssl is enabled)",
207
+ )
208
+ parser.add_argument(
209
+ "--ssl-keyfile",
210
+ default=None,
211
+ help="Path to SSL private key file (required if --ssl is enabled)",
212
+ )
213
+ return parser.parse_args()
214
+
215
+
216
+ class DocumentManager:
217
+ """Handles document operations and tracking"""
218
+
219
+ def __init__(self, input_dir: str, supported_extensions: tuple = (".txt", ".md")):
220
+ self.input_dir = Path(input_dir)
221
+ self.supported_extensions = supported_extensions
222
+ self.indexed_files = set()
223
+
224
+ # Create input directory if it doesn't exist
225
+ self.input_dir.mkdir(parents=True, exist_ok=True)
226
+
227
+ def scan_directory(self) -> List[Path]:
228
+ """Scan input directory for new files"""
229
+ new_files = []
230
+ for ext in self.supported_extensions:
231
+ for file_path in self.input_dir.rglob(f"*{ext}"):
232
+ if file_path not in self.indexed_files:
233
+ new_files.append(file_path)
234
+ return new_files
235
+
236
+ def mark_as_indexed(self, file_path: Path):
237
+ """Mark a file as indexed"""
238
+ self.indexed_files.add(file_path)
239
+
240
+ def is_supported_file(self, filename: str) -> bool:
241
+ """Check if file type is supported"""
242
+ return any(filename.lower().endswith(ext) for ext in self.supported_extensions)
243
+
244
+
245
+ # Pydantic models
246
+ class SearchMode(str, Enum):
247
+ naive = "naive"
248
+ local = "local"
249
+ global_ = "global" # Using global_ because global is a Python reserved keyword, but enum value will be converted to string "global"
250
+ hybrid = "hybrid"
251
+ mix = "mix"
252
+
253
+
254
+ # Ollama API compatible models
255
+ class OllamaMessage(BaseModel):
256
+ role: str
257
+ content: str
258
+ images: Optional[List[str]] = None
259
+
260
+
261
+ class OllamaChatRequest(BaseModel):
262
+ model: str = LIGHTRAG_MODEL
263
+ messages: List[OllamaMessage]
264
+ stream: bool = True # Default to streaming mode
265
+ options: Optional[Dict[str, Any]] = None
266
+
267
+
268
+ class OllamaChatResponse(BaseModel):
269
+ model: str
270
+ created_at: str
271
+ message: OllamaMessage
272
+ done: bool
273
+
274
+
275
+ class OllamaVersionResponse(BaseModel):
276
+ version: str
277
+
278
+
279
+ class OllamaModelDetails(BaseModel):
280
+ parent_model: str
281
+ format: str
282
+ family: str
283
+ families: List[str]
284
+ parameter_size: str
285
+ quantization_level: str
286
+
287
+
288
+ class OllamaModel(BaseModel):
289
+ name: str
290
+ model: str
291
+ size: int
292
+ digest: str
293
+ modified_at: str
294
+ details: OllamaModelDetails
295
+
296
+
297
+ class OllamaTagResponse(BaseModel):
298
+ models: List[OllamaModel]
299
+
300
+
301
+ # Original LightRAG models
302
+ class QueryRequest(BaseModel):
303
+ query: str
304
+ mode: SearchMode = SearchMode.hybrid
305
+ stream: bool = False
306
+ only_need_context: bool = False
307
+
308
+
309
+ class QueryResponse(BaseModel):
310
+ response: str
311
+
312
+
313
+ class InsertTextRequest(BaseModel):
314
+ text: str
315
+ description: Optional[str] = None
316
+
317
+
318
+ class InsertResponse(BaseModel):
319
+ status: str
320
+ message: str
321
+ document_count: int
322
+
323
+
324
+ def get_api_key_dependency(api_key: Optional[str]):
325
+ if not api_key:
326
+ # If no API key is configured, return a dummy dependency that always succeeds
327
+ async def no_auth():
328
+ return None
329
+
330
+ return no_auth
331
+
332
+ # If API key is configured, use proper authentication
333
+ api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
334
+
335
+ async def api_key_auth(api_key_header_value: str | None = Security(api_key_header)):
336
+ if not api_key_header_value:
337
+ raise HTTPException(
338
+ status_code=HTTP_403_FORBIDDEN, detail="API Key required"
339
+ )
340
+ if api_key_header_value != api_key:
341
+ raise HTTPException(
342
+ status_code=HTTP_403_FORBIDDEN, detail="Invalid API Key"
343
+ )
344
+ return api_key_header_value
345
+
346
+ return api_key_auth
347
+
348
+
349
+ def create_app(args):
350
+ # Verify that the bindings are correctly set up
351
+ if args.llm_binding not in ["lollms", "ollama", "openai"]:
352
+ raise Exception("llm binding not supported")
353
+
354
+ if args.embedding_binding not in ["lollms", "ollama", "openai"]:
355
+ raise Exception("embedding binding not supported")
356
+
357
+ # Add SSL validation
358
+ if args.ssl:
359
+ if not args.ssl_certfile or not args.ssl_keyfile:
360
+ raise Exception(
361
+ "SSL certificate and key files must be provided when SSL is enabled"
362
+ )
363
+ if not os.path.exists(args.ssl_certfile):
364
+ raise Exception(f"SSL certificate file not found: {args.ssl_certfile}")
365
+ if not os.path.exists(args.ssl_keyfile):
366
+ raise Exception(f"SSL key file not found: {args.ssl_keyfile}")
367
+
368
+ # Setup logging
369
+ logging.basicConfig(
370
+ format="%(levelname)s:%(message)s", level=getattr(logging, args.log_level)
371
+ )
372
+
373
+ # Check if API key is provided either through env var or args
374
+ api_key = os.getenv("LIGHTRAG_API_KEY") or args.key
375
+
376
+ # Initialize FastAPI
377
+ app = FastAPI(
378
+ title="LightRAG API",
379
+ description="API for querying text using LightRAG with separate storage and input directories"
380
+ + "(With authentication)"
381
+ if api_key
382
+ else "",
383
+ version="1.0.1",
384
+ openapi_tags=[{"name": "api"}],
385
+ )
386
+
387
+ # Add CORS middleware
388
+ app.add_middleware(
389
+ CORSMiddleware,
390
+ allow_origins=["*"],
391
+ allow_credentials=True,
392
+ allow_methods=["*"],
393
+ allow_headers=["*"],
394
+ )
395
+
396
+ # Create the optional API key dependency
397
+ optional_api_key = get_api_key_dependency(api_key)
398
+
399
+ # Create working directory if it doesn't exist
400
+ Path(args.working_dir).mkdir(parents=True, exist_ok=True)
401
+
402
+ # Initialize document manager
403
+ doc_manager = DocumentManager(args.input_dir)
404
+
405
+ # Initialize RAG
406
+ rag = LightRAG(
407
+ working_dir=args.working_dir,
408
+ llm_model_func=llm_model_func,
409
+ embedding_func=EmbeddingFunc(
410
+ embedding_dim=1024,
411
+ max_token_size=8192,
412
+ func=lambda texts: ollama_embedding(
413
+ texts,
414
+ embed_model="bge-m3:latest",
415
+ host="http://m4.lan.znipower.com:11434",
416
+ ),
417
+ ),
418
+ )
419
+
420
+ @app.on_event("startup")
421
+ async def startup_event():
422
+ """Index all files in input directory during startup"""
423
+ try:
424
+ new_files = doc_manager.scan_directory()
425
+ for file_path in new_files:
426
+ try:
427
+ # Use async file reading
428
+ async with aiofiles.open(file_path, "r", encoding="utf-8") as f:
429
+ content = await f.read()
430
+ # Use the async version of insert directly
431
+ await rag.ainsert(content)
432
+ doc_manager.mark_as_indexed(file_path)
433
+ logging.info(f"Indexed file: {file_path}")
434
+ except Exception as e:
435
+ trace_exception(e)
436
+ logging.error(f"Error indexing file {file_path}: {str(e)}")
437
+
438
+ logging.info(f"Indexed {len(new_files)} documents from {args.input_dir}")
439
+
440
+ except Exception as e:
441
+ logging.error(f"Error during startup indexing: {str(e)}")
442
+
443
+ @app.post("/documents/scan", dependencies=[Depends(optional_api_key)])
444
+ async def scan_for_new_documents():
445
+ """Manually trigger scanning for new documents"""
446
+ try:
447
+ new_files = doc_manager.scan_directory()
448
+ indexed_count = 0
449
+
450
+ for file_path in new_files:
451
+ try:
452
+ with open(file_path, "r", encoding="utf-8") as f:
453
+ content = f.read()
454
+ await rag.ainsert(content)
455
+ doc_manager.mark_as_indexed(file_path)
456
+ indexed_count += 1
457
+ except Exception as e:
458
+ logging.error(f"Error indexing file {file_path}: {str(e)}")
459
+
460
+ return {
461
+ "status": "success",
462
+ "indexed_count": indexed_count,
463
+ "total_documents": len(doc_manager.indexed_files),
464
+ }
465
+ except Exception as e:
466
+ raise HTTPException(status_code=500, detail=str(e))
467
+
468
+ @app.post("/documents/upload", dependencies=[Depends(optional_api_key)])
469
+ async def upload_to_input_dir(file: UploadFile = File(...)):
470
+ """Upload a file to the input directory"""
471
+ try:
472
+ if not doc_manager.is_supported_file(file.filename):
473
+ raise HTTPException(
474
+ status_code=400,
475
+ detail=f"Unsupported file type. Supported types: {doc_manager.supported_extensions}",
476
+ )
477
+
478
+ file_path = doc_manager.input_dir / file.filename
479
+ with open(file_path, "wb") as buffer:
480
+ shutil.copyfileobj(file.file, buffer)
481
+
482
+ # Immediately index the uploaded file
483
+ with open(file_path, "r", encoding="utf-8") as f:
484
+ content = f.read()
485
+ await rag.ainsert(content)
486
+ doc_manager.mark_as_indexed(file_path)
487
+
488
+ return {
489
+ "status": "success",
490
+ "message": f"File uploaded and indexed: {file.filename}",
491
+ "total_documents": len(doc_manager.indexed_files),
492
+ }
493
+ except Exception as e:
494
+ raise HTTPException(status_code=500, detail=str(e))
495
+
496
+ @app.post(
497
+ "/query", response_model=QueryResponse, dependencies=[Depends(optional_api_key)]
498
+ )
499
+ async def query_text(request: QueryRequest):
500
+ try:
501
+ response = await rag.aquery(
502
+ request.query,
503
+ param=QueryParam(
504
+ mode=request.mode,
505
+ stream=request.stream,
506
+ only_need_context=request.only_need_context,
507
+ ),
508
+ )
509
+
510
+ # If response is a string (e.g. cache hit), return directly
511
+ if isinstance(response, str):
512
+ return QueryResponse(response=response)
513
+
514
+ # If it's an async generator, decide whether to stream based on stream parameter
515
+ if request.stream:
516
+ result = ""
517
+ async for chunk in response:
518
+ result += chunk
519
+ return QueryResponse(response=result)
520
+ else:
521
+ result = ""
522
+ async for chunk in response:
523
+ result += chunk
524
+ return QueryResponse(response=result)
525
+ except Exception as e:
526
+ raise HTTPException(status_code=500, detail=str(e))
527
+
528
+ @app.post("/query/stream", dependencies=[Depends(optional_api_key)])
529
+ async def query_text_stream(request: QueryRequest):
530
+ try:
531
+ response = await rag.aquery( # Use aquery instead of query, and add await
532
+ request.query,
533
+ param=QueryParam(
534
+ mode=request.mode,
535
+ stream=True,
536
+ only_need_context=request.only_need_context,
537
+ ),
538
+ )
539
+
540
+ from fastapi.responses import StreamingResponse
541
+
542
+ async def stream_generator():
543
+ if isinstance(response, str):
544
+ # If it's a string, send it all at once
545
+ yield f"{json.dumps({'response': response})}\n"
546
+ else:
547
+ # If it's an async generator, send chunks one by one
548
+ try:
549
+ async for chunk in response:
550
+ if chunk: # Only send non-empty content
551
+ yield f"{json.dumps({'response': chunk})}\n"
552
+ except Exception as e:
553
+ logging.error(f"Streaming error: {str(e)}")
554
+ yield f"{json.dumps({'error': str(e)})}\n"
555
+
556
+ return StreamingResponse(
557
+ stream_generator(),
558
+ media_type="application/x-ndjson",
559
+ headers={
560
+ "Cache-Control": "no-cache",
561
+ "Connection": "keep-alive",
562
+ "Content-Type": "application/x-ndjson",
563
+ "Access-Control-Allow-Origin": "*",
564
+ "Access-Control-Allow-Methods": "POST, OPTIONS",
565
+ "Access-Control-Allow-Headers": "Content-Type",
566
+ "X-Accel-Buffering": "no", # Disable Nginx buffering
567
+ },
568
+ )
569
+ except Exception as e:
570
+ raise HTTPException(status_code=500, detail=str(e))
571
+
572
+ @app.post(
573
+ "/documents/text",
574
+ response_model=InsertResponse,
575
+ dependencies=[Depends(optional_api_key)],
576
+ )
577
+ async def insert_text(request: InsertTextRequest):
578
+ try:
579
+ await rag.ainsert(request.text)
580
+ return InsertResponse(
581
+ status="success",
582
+ message="Text successfully inserted",
583
+ document_count=1,
584
+ )
585
+ except Exception as e:
586
+ raise HTTPException(status_code=500, detail=str(e))
587
+
588
+ @app.post(
589
+ "/documents/file",
590
+ response_model=InsertResponse,
591
+ dependencies=[Depends(optional_api_key)],
592
+ )
593
+ async def insert_file(file: UploadFile = File(...), description: str = Form(None)):
594
+ try:
595
+ content = await file.read()
596
+
597
+ if file.filename.endswith((".txt", ".md")):
598
+ text = content.decode("utf-8")
599
+ await rag.ainsert(text)
600
+ else:
601
+ raise HTTPException(
602
+ status_code=400,
603
+ detail="Unsupported file type. Only .txt and .md files are supported",
604
+ )
605
+
606
+ return InsertResponse(
607
+ status="success",
608
+ message=f"File '{file.filename}' successfully inserted",
609
+ document_count=1,
610
+ )
611
+ except UnicodeDecodeError:
612
+ raise HTTPException(status_code=400, detail="File encoding not supported")
613
+ except Exception as e:
614
+ raise HTTPException(status_code=500, detail=str(e))
615
+
616
+ @app.post(
617
+ "/documents/batch",
618
+ response_model=InsertResponse,
619
+ dependencies=[Depends(optional_api_key)],
620
+ )
621
+ async def insert_batch(files: List[UploadFile] = File(...)):
622
+ try:
623
+ inserted_count = 0
624
+ failed_files = []
625
+
626
+ for file in files:
627
+ try:
628
+ content = await file.read()
629
+ if file.filename.endswith((".txt", ".md")):
630
+ text = content.decode("utf-8")
631
+ await rag.ainsert(text)
632
+ inserted_count += 1
633
+ else:
634
+ failed_files.append(f"{file.filename} (unsupported type)")
635
+ except Exception as e:
636
+ failed_files.append(f"{file.filename} ({str(e)})")
637
+
638
+ status_message = f"Successfully inserted {inserted_count} documents"
639
+ if failed_files:
640
+ status_message += f". Failed files: {', '.join(failed_files)}"
641
+
642
+ return InsertResponse(
643
+ status="success" if inserted_count > 0 else "partial_success",
644
+ message=status_message,
645
+ document_count=len(files),
646
+ )
647
+ except Exception as e:
648
+ raise HTTPException(status_code=500, detail=str(e))
649
+
650
+ @app.delete(
651
+ "/documents",
652
+ response_model=InsertResponse,
653
+ dependencies=[Depends(optional_api_key)],
654
+ )
655
+ async def clear_documents():
656
+ try:
657
+ rag.text_chunks = []
658
+ rag.entities_vdb = None
659
+ rag.relationships_vdb = None
660
+ return InsertResponse(
661
+ status="success",
662
+ message="All documents cleared successfully",
663
+ document_count=0,
664
+ )
665
+ except Exception as e:
666
+ raise HTTPException(status_code=500, detail=str(e))
667
+
668
+ # Ollama compatible API endpoints
669
+ @app.get("/api/version")
670
+ async def get_version():
671
+ """Get Ollama version information"""
672
+ return OllamaVersionResponse(version="0.5.4")
673
+
674
+ @app.get("/api/tags")
675
+ async def get_tags():
676
+ """Get available models"""
677
+ return OllamaTagResponse(
678
+ models=[
679
+ {
680
+ "name": LIGHTRAG_MODEL,
681
+ "model": LIGHTRAG_MODEL,
682
+ "size": LIGHTRAG_SIZE,
683
+ "digest": LIGHTRAG_DIGEST,
684
+ "modified_at": LIGHTRAG_CREATED_AT,
685
+ "details": {
686
+ "parent_model": "",
687
+ "format": "gguf",
688
+ "family": LIGHTRAG_NAME,
689
+ "families": [LIGHTRAG_NAME],
690
+ "parameter_size": "13B",
691
+ "quantization_level": "Q4_0",
692
+ },
693
+ }
694
+ ]
695
+ )
696
+
697
+ def parse_query_mode(query: str) -> tuple[str, SearchMode]:
698
+ """Parse query prefix to determine search mode
699
+ Returns tuple of (cleaned_query, search_mode)
700
+ """
701
+ mode_map = {
702
+ "/local ": SearchMode.local,
703
+ "/global ": SearchMode.global_, # global_ is used because 'global' is a Python keyword
704
+ "/naive ": SearchMode.naive,
705
+ "/hybrid ": SearchMode.hybrid,
706
+ "/mix ": SearchMode.mix,
707
+ }
708
+
709
+ for prefix, mode in mode_map.items():
710
+ if query.startswith(prefix):
711
+ # Strip the prefix and any leading spaces
712
+ cleaned_query = query[len(prefix) :].lstrip()
713
+ return cleaned_query, mode
714
+
715
+ return query, SearchMode.hybrid
716
+
717
+ @app.post("/api/chat")
718
+ async def chat(raw_request: Request, request: OllamaChatRequest):
719
+ """Handle chat completion requests"""
720
+ try:
721
+ # Get all messages
722
+ messages = request.messages
723
+ if not messages:
724
+ raise HTTPException(status_code=400, detail="No messages provided")
725
+
726
+ # Get the last message as query
727
+ query = messages[-1].content
728
+
729
+ # Parse the query mode from the prefix
730
+ cleaned_query, mode = parse_query_mode(query)
731
+
732
+ # Start timing
733
+ start_time = time.time_ns()
734
+
735
+ # Estimate the number of input tokens
736
+ prompt_tokens = estimate_tokens(cleaned_query)
737
+
738
+ # Query the RAG system
739
+ query_param = QueryParam(
740
+ mode=mode, stream=request.stream, only_need_context=False
741
+ )
742
+
743
+ if request.stream:
744
+ from fastapi.responses import StreamingResponse
745
+
746
+ response = await rag.aquery( # Need await to get async generator
747
+ cleaned_query, param=query_param
748
+ )
749
+
750
+ async def stream_generator():
751
+ try:
752
+ first_chunk_time = None
753
+ last_chunk_time = None
754
+ total_response = ""
755
+
756
+ # Ensure response is an async generator
757
+ if isinstance(response, str):
758
+ # If it's a string, send in two parts
759
+ first_chunk_time = time.time_ns()
760
+ last_chunk_time = first_chunk_time
761
+ total_response = response
762
+
763
+ data = {
764
+ "model": LIGHTRAG_MODEL,
765
+ "created_at": LIGHTRAG_CREATED_AT,
766
+ "message": {
767
+ "role": "assistant",
768
+ "content": response,
769
+ "images": None,
770
+ },
771
+ "done": False,
772
+ }
773
+ yield f"{json.dumps(data, ensure_ascii=False)}\n"
774
+
775
+ completion_tokens = estimate_tokens(total_response)
776
+ total_time = last_chunk_time - start_time
777
+ prompt_eval_time = first_chunk_time - start_time
778
+ eval_time = last_chunk_time - first_chunk_time
779
+
780
+ data = {
781
+ "model": LIGHTRAG_MODEL,
782
+ "created_at": LIGHTRAG_CREATED_AT,
783
+ "done": True,
784
+ "total_duration": total_time,
785
+ "load_duration": 0,
786
+ "prompt_eval_count": prompt_tokens,
787
+ "prompt_eval_duration": prompt_eval_time,
788
+ "eval_count": completion_tokens,
789
+ "eval_duration": eval_time,
790
+ }
791
+ yield f"{json.dumps(data, ensure_ascii=False)}\n"
792
+ else:
793
+ async for chunk in response:
794
+ if chunk:
795
+ if first_chunk_time is None:
796
+ first_chunk_time = time.time_ns()
797
+
798
+ last_chunk_time = time.time_ns()
799
+
800
+ total_response += chunk
801
+ data = {
802
+ "model": LIGHTRAG_MODEL,
803
+ "created_at": LIGHTRAG_CREATED_AT,
804
+ "message": {
805
+ "role": "assistant",
806
+ "content": chunk,
807
+ "images": None,
808
+ },
809
+ "done": False,
810
+ }
811
+ yield f"{json.dumps(data, ensure_ascii=False)}\n"
812
+
813
+ completion_tokens = estimate_tokens(total_response)
814
+ total_time = last_chunk_time - start_time
815
+ prompt_eval_time = first_chunk_time - start_time
816
+ eval_time = last_chunk_time - first_chunk_time
817
+
818
+ data = {
819
+ "model": LIGHTRAG_MODEL,
820
+ "created_at": LIGHTRAG_CREATED_AT,
821
+ "done": True,
822
+ "total_duration": total_time,
823
+ "load_duration": 0,
824
+ "prompt_eval_count": prompt_tokens,
825
+ "prompt_eval_duration": prompt_eval_time,
826
+ "eval_count": completion_tokens,
827
+ "eval_duration": eval_time,
828
+ }
829
+ yield f"{json.dumps(data, ensure_ascii=False)}\n"
830
+ return # Ensure the generator ends immediately after sending the completion marker
831
+ except Exception as e:
832
+ logging.error(f"Error in stream_generator: {str(e)}")
833
+ raise
834
+
835
+ return StreamingResponse(
836
+ stream_generator(),
837
+ media_type="application/x-ndjson",
838
+ headers={
839
+ "Cache-Control": "no-cache",
840
+ "Connection": "keep-alive",
841
+ "Content-Type": "application/x-ndjson",
842
+ "Access-Control-Allow-Origin": "*",
843
+ "Access-Control-Allow-Methods": "POST, OPTIONS",
844
+ "Access-Control-Allow-Headers": "Content-Type",
845
+ },
846
+ )
847
+ else:
848
+ first_chunk_time = time.time_ns()
849
+ response_text = await rag.aquery(cleaned_query, param=query_param)
850
+ last_chunk_time = time.time_ns()
851
+
852
+ if not response_text:
853
+ response_text = "No response generated"
854
+
855
+ completion_tokens = estimate_tokens(str(response_text))
856
+ total_time = last_chunk_time - start_time
857
+ prompt_eval_time = first_chunk_time - start_time
858
+ eval_time = last_chunk_time - first_chunk_time
859
+
860
+ return {
861
+ "model": LIGHTRAG_MODEL,
862
+ "created_at": LIGHTRAG_CREATED_AT,
863
+ "message": {
864
+ "role": "assistant",
865
+ "content": str(response_text),
866
+ "images": None,
867
+ },
868
+ "done": True,
869
+ "total_duration": total_time,
870
+ "load_duration": 0,
871
+ "prompt_eval_count": prompt_tokens,
872
+ "prompt_eval_duration": prompt_eval_time,
873
+ "eval_count": completion_tokens,
874
+ "eval_duration": eval_time,
875
+ }
876
+ except Exception as e:
877
+ raise HTTPException(status_code=500, detail=str(e))
878
+
879
+ @app.get("/health", dependencies=[Depends(optional_api_key)])
880
+ async def get_status():
881
+ """Get current system status"""
882
+ return {
883
+ "status": "healthy",
884
+ "working_directory": str(args.working_dir),
885
+ "input_directory": str(args.input_dir),
886
+ "indexed_files": len(doc_manager.indexed_files),
887
+ "configuration": {
888
+ # LLM configuration binding/host address (if applicable)/model (if applicable)
889
+ "llm_binding": args.llm_binding,
890
+ "llm_binding_host": args.llm_binding_host,
891
+ "llm_model": args.llm_model,
892
+ # embedding model configuration binding/host address (if applicable)/model (if applicable)
893
+ "embedding_binding": args.embedding_binding,
894
+ "embedding_binding_host": args.embedding_binding_host,
895
+ "embedding_model": args.embedding_model,
896
+ "max_tokens": args.max_tokens,
897
+ },
898
+ }
899
+
900
+ return app
901
+
902
+
903
+ def main():
904
+ args = parse_args()
905
+ import uvicorn
906
+
907
+ app = create_app(args)
908
+ uvicorn_config = {
909
+ "app": app,
910
+ "host": args.host,
911
+ "port": args.port,
912
+ }
913
+ if args.ssl:
914
+ uvicorn_config.update(
915
+ {
916
+ "ssl_certfile": args.ssl_certfile,
917
+ "ssl_keyfile": args.ssl_keyfile,
918
+ }
919
+ )
920
+ uvicorn.run(**uvicorn_config)
921
+
922
+
923
+ if __name__ == "__main__":
924
+ main()
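
For context, here is a minimal client-side sketch of how the Ollama-compatible endpoint added above can be exercised. It assumes the server is running locally on the default port 9621; the question text is arbitrary, and the `/hybrid ` prefix is one of the mode prefixes handled by `parse_query_mode`.

```python
import requests

# Minimal sketch: call the Ollama-compatible /api/chat endpoint from lightrag_ollama.py.
# Assumes the server is running on the default host/port (localhost:9621).
payload = {
    "model": "lightrag:latest",
    # Optional prefixes /local, /global, /naive, /hybrid, /mix select the search mode.
    "messages": [{"role": "user", "content": "/hybrid What topics do the indexed documents cover?"}],
    "stream": False,
}
resp = requests.post("http://localhost:9621/api/chat", json=payload, timeout=30)
resp.raise_for_status()
print(resp.json()["message"]["content"])
```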
lightrag/api/requirements.txt CHANGED
@@ -1,7 +1,6 @@
  aioboto3
  ascii_colors
  fastapi
- lightrag-hku
  nano_vectordb
  nest_asyncio
  numpy
setup.py CHANGED
@@ -101,6 +101,7 @@ setuptools.setup(
     entry_points={
         "console_scripts": [
             "lightrag-server=lightrag.api.lightrag_server:main [api]",
+            "lightrag-ollama=lightrag.api.lightrag_ollama:main [api]",
         ],
     },
 )
start-server.sh ADDED
@@ -0,0 +1,3 @@
1
+ . venv/bin/activate
2
+
3
+ lightrag-ollama --llm-binding openai --llm-model deepseek-chat --embedding-model "bge-m3:latest" --embedding-dim 1024
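
Since `lightrag_ollama.py` calls `load_dotenv()` at import time and reads `DEEPSEEK_API_KEY`, `DEEPSEEK_ENDPOINT`, and optionally `LIGHTRAG_API_KEY` from the environment, a quick pre-flight check like the following sketch can confirm they are visible before running the start script. The variable names come from the code above; values are expected in your own `.env`.

```python
import os
from dotenv import load_dotenv

# Same mechanism lightrag_ollama.py uses at import time.
load_dotenv()

# Environment variables referenced in lightrag_ollama.py (LIGHTRAG_API_KEY is optional).
for var in ("DEEPSEEK_API_KEY", "DEEPSEEK_ENDPOINT", "LIGHTRAG_API_KEY"):
    print(f"{var}: {'set' if os.getenv(var) else 'not set'}")
```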
test_lightrag_ollama_chat.py ADDED
@@ -0,0 +1,572 @@
1
+ """
2
+ LightRAG Ollama Compatibility Interface Test Script
3
+
4
+ This script tests LightRAG's Ollama compatibility interface, including:
5
+ 1. Basic functionality tests (streaming and non-streaming responses)
6
+ 2. Query mode tests (local, global, naive, hybrid)
7
+ 3. Error handling tests (including streaming and non-streaming scenarios)
8
+
9
+ All responses use the JSON Lines format, complying with the Ollama API specification.
10
+ """
11
+
12
+ import requests
13
+ import json
14
+ import argparse
15
+ import time
16
+ from typing import Dict, Any, Optional, List, Callable
17
+ from dataclasses import dataclass, asdict
18
+ from datetime import datetime
19
+ from pathlib import Path
20
+
21
+
22
+ class OutputControl:
23
+ """Output control class, manages the verbosity of test output"""
24
+
25
+ _verbose: bool = False
26
+
27
+ @classmethod
28
+ def set_verbose(cls, verbose: bool) -> None:
29
+ cls._verbose = verbose
30
+
31
+ @classmethod
32
+ def is_verbose(cls) -> bool:
33
+ return cls._verbose
34
+
35
+
36
+ @dataclass
37
+ class TestResult:
38
+ """Test result data class"""
39
+
40
+ name: str
41
+ success: bool
42
+ duration: float
43
+ error: Optional[str] = None
44
+ timestamp: str = ""
45
+
46
+ def __post_init__(self):
47
+ if not self.timestamp:
48
+ self.timestamp = datetime.now().isoformat()
49
+
50
+
51
+ class TestStats:
52
+ """Test statistics"""
53
+
54
+ def __init__(self):
55
+ self.results: List[TestResult] = []
56
+ self.start_time = datetime.now()
57
+
58
+ def add_result(self, result: TestResult):
59
+ self.results.append(result)
60
+
61
+ def export_results(self, path: str = "test_results.json"):
62
+ """Export test results to a JSON file
63
+ Args:
64
+ path: Output file path
65
+ """
66
+ results_data = {
67
+ "start_time": self.start_time.isoformat(),
68
+ "end_time": datetime.now().isoformat(),
69
+ "results": [asdict(r) for r in self.results],
70
+ "summary": {
71
+ "total": len(self.results),
72
+ "passed": sum(1 for r in self.results if r.success),
73
+ "failed": sum(1 for r in self.results if not r.success),
74
+ "total_duration": sum(r.duration for r in self.results),
75
+ },
76
+ }
77
+
78
+ with open(path, "w", encoding="utf-8") as f:
79
+ json.dump(results_data, f, ensure_ascii=False, indent=2)
80
+ print(f"\nTest results saved to: {path}")
81
+
82
+ def print_summary(self):
83
+ total = len(self.results)
84
+ passed = sum(1 for r in self.results if r.success)
85
+ failed = total - passed
86
+ duration = sum(r.duration for r in self.results)
87
+
88
+ print("\n=== Test Summary ===")
89
+ print(f"Start time: {self.start_time.strftime('%Y-%m-%d %H:%M:%S')}")
90
+ print(f"Total duration: {duration:.2f} seconds")
91
+ print(f"Total tests: {total}")
92
+ print(f"Passed: {passed}")
93
+ print(f"Failed: {failed}")
94
+
95
+ if failed > 0:
96
+ print("\nFailed tests:")
97
+ for result in self.results:
98
+ if not result.success:
99
+ print(f"- {result.name}: {result.error}")
100
+
101
+
102
+ DEFAULT_CONFIG = {
103
+ "server": {
104
+ "host": "localhost",
105
+ "port": 9621,
106
+ "model": "lightrag:latest",
107
+ "timeout": 30,
108
+ "max_retries": 3,
109
+ "retry_delay": 1,
110
+ },
111
+ "test_cases": {"basic": {"query": "唐僧有几个徒弟"}},
112
+ }
113
+
114
+
115
+ def make_request(
116
+ url: str, data: Dict[str, Any], stream: bool = False
117
+ ) -> requests.Response:
118
+ """Send an HTTP request with retry mechanism
119
+ Args:
120
+ url: Request URL
121
+ data: Request data
122
+ stream: Whether to use streaming response
123
+ Returns:
124
+ requests.Response: Response object
125
+
126
+ Raises:
127
+ requests.exceptions.RequestException: Request failed after all retries
128
+ """
129
+ server_config = CONFIG["server"]
130
+ max_retries = server_config["max_retries"]
131
+ retry_delay = server_config["retry_delay"]
132
+ timeout = server_config["timeout"]
133
+
134
+ for attempt in range(max_retries):
135
+ try:
136
+ response = requests.post(url, json=data, stream=stream, timeout=timeout)
137
+ return response
138
+ except requests.exceptions.RequestException as e:
139
+ if attempt == max_retries - 1: # Last retry
140
+ raise
141
+ print(f"\nRequest failed, retrying in {retry_delay} seconds: {str(e)}")
142
+ time.sleep(retry_delay)
143
+
144
+
145
+ def load_config() -> Dict[str, Any]:
146
+ """Load configuration file
147
+
148
+ First try to load from config.json in the current directory,
149
+ if it doesn't exist, use the default configuration
150
+ Returns:
151
+ Configuration dictionary
152
+ """
153
+ config_path = Path("config.json")
154
+ if config_path.exists():
155
+ with open(config_path, "r", encoding="utf-8") as f:
156
+ return json.load(f)
157
+ return DEFAULT_CONFIG
158
+
159
+
160
+ def print_json_response(data: Dict[str, Any], title: str = "", indent: int = 2) -> None:
161
+ """Format and print JSON response data
162
+ Args:
163
+ data: Data dictionary to print
164
+ title: Title to print
165
+ indent: Number of spaces for JSON indentation
166
+ """
167
+ if OutputControl.is_verbose():
168
+ if title:
169
+ print(f"\n=== {title} ===")
170
+ print(json.dumps(data, ensure_ascii=False, indent=indent))
171
+
172
+
173
+ # Global configuration
174
+ CONFIG = load_config()
175
+
176
+
177
+ def get_base_url() -> str:
178
+ """Return the base URL"""
179
+ server = CONFIG["server"]
180
+ return f"http://{server['host']}:{server['port']}/api/chat"
181
+
182
+
183
+ def create_request_data(
184
+ content: str, stream: bool = False, model: str = None
185
+ ) -> Dict[str, Any]:
186
+ """Create basic request data
187
+ Args:
188
+ content: User message content
189
+ stream: Whether to use streaming response
190
+ model: Model name
191
+ Returns:
192
+ Dictionary containing complete request data
193
+ """
194
+ return {
195
+ "model": model or CONFIG["server"]["model"],
196
+ "messages": [{"role": "user", "content": content}],
197
+ "stream": stream,
198
+ }
199
+
200
+
201
+ # Global test statistics
202
+ STATS = TestStats()
203
+
204
+
205
+ def run_test(func: Callable, name: str) -> None:
206
+ """Run a test and record the results
207
+ Args:
208
+ func: Test function
209
+ name: Test name
210
+ """
211
+ start_time = time.time()
212
+ try:
213
+ func()
214
+ duration = time.time() - start_time
215
+ STATS.add_result(TestResult(name, True, duration))
216
+ except Exception as e:
217
+ duration = time.time() - start_time
218
+ STATS.add_result(TestResult(name, False, duration, str(e)))
219
+ raise
220
+
221
+
222
+ def test_non_stream_chat():
223
+ """Test non-streaming call to /api/chat endpoint"""
224
+ url = get_base_url()
225
+ data = create_request_data(CONFIG["test_cases"]["basic"]["query"], stream=False)
226
+
227
+ # Send request
228
+ response = make_request(url, data)
229
+
230
+ # Print response
231
+ if OutputControl.is_verbose():
232
+ print("\n=== Non-streaming call response ===")
233
+ response_json = response.json()
234
+
235
+ # Print response content
236
+ print_json_response(
237
+ {"model": response_json["model"], "message": response_json["message"]},
238
+ "Response content",
239
+ )
240
+
241
+
242
+ def test_stream_chat():
243
+ """Test streaming call to /api/chat endpoint
244
+
245
+ Use JSON Lines format to process streaming responses, each line is a complete JSON object.
246
+ Response format:
247
+ {
248
+ "model": "lightrag:latest",
249
+ "created_at": "2024-01-15T00:00:00Z",
250
+ "message": {
251
+ "role": "assistant",
252
+ "content": "Partial response content",
253
+ "images": null
254
+ },
255
+ "done": false
256
+ }
257
+
258
+ The last message will contain performance statistics, with done set to true.
259
+ """
260
+ url = get_base_url()
261
+ data = create_request_data(CONFIG["test_cases"]["basic"]["query"], stream=True)
262
+
263
+ # Send request and get streaming response
264
+ response = make_request(url, data, stream=True)
265
+
266
+ if OutputControl.is_verbose():
267
+ print("\n=== Streaming call response ===")
268
+ output_buffer = []
269
+ try:
270
+ for line in response.iter_lines():
271
+ if line: # Skip empty lines
272
+ try:
273
+ # Decode and parse JSON
274
+ data = json.loads(line.decode("utf-8"))
275
+ if data.get("done", True): # If it's the completion marker
276
+ if (
277
+ "total_duration" in data
278
+ ): # Final performance statistics message
279
+ # print_json_response(data, "Performance statistics")
280
+ break
281
+ else: # Normal content message
282
+ message = data.get("message", {})
283
+ content = message.get("content", "")
284
+ if content: # Only collect non-empty content
285
+ output_buffer.append(content)
286
+ print(
287
+ content, end="", flush=True
288
+ ) # Print content in real-time
289
+ except json.JSONDecodeError:
290
+ print("Error decoding JSON from response line")
291
+ finally:
292
+ response.close() # Ensure the response connection is closed
293
+
294
+ # Print a newline
295
+ print()
296
+
297
+
298
+ def test_query_modes():
299
+ """Test different query mode prefixes
300
+
301
+ Supported query modes:
302
+ - /local: Local retrieval mode, searches only in highly relevant documents
303
+ - /global: Global retrieval mode, searches across all documents
304
+ - /naive: Naive mode, does not use any optimization strategies
305
+ - /hybrid: Hybrid mode (default), combines multiple strategies
306
+ - /mix: Mix mode
307
+
308
+ Each mode will return responses in the same format, but with different retrieval strategies.
309
+ """
310
+ url = get_base_url()
311
+ modes = ["local", "global", "naive", "hybrid", "mix"]
312
+
313
+ for mode in modes:
314
+ if OutputControl.is_verbose():
315
+ print(f"\n=== Testing /{mode} mode ===")
316
+ data = create_request_data(
317
+ f"/{mode} {CONFIG['test_cases']['basic']['query']}", stream=False
318
+ )
319
+
320
+ # Send request
321
+ response = make_request(url, data)
322
+ response_json = response.json()
323
+
324
+ # Print response content
325
+ print_json_response(
326
+ {"model": response_json["model"], "message": response_json["message"]}
327
+ )
328
+
329
+
330
+ def create_error_test_data(error_type: str) -> Dict[str, Any]:
331
+ """Create request data for error testing
332
+ Args:
333
+ error_type: Error type, supported:
334
+ - empty_messages: Empty message list
335
+ - invalid_role: Invalid role field
336
+ - missing_content: Missing content field
337
+
338
+ Returns:
339
+ Request dictionary containing error data
340
+ """
341
+ error_data = {
342
+ "empty_messages": {"model": "lightrag:latest", "messages": [], "stream": True},
343
+ "invalid_role": {
344
+ "model": "lightrag:latest",
345
+ "messages": [{"invalid_role": "user", "content": "Test message"}],
346
+ "stream": True,
347
+ },
348
+ "missing_content": {
349
+ "model": "lightrag:latest",
350
+ "messages": [{"role": "user"}],
351
+ "stream": True,
352
+ },
353
+ }
354
+ return error_data.get(error_type, error_data["empty_messages"])
355
+
356
+
357
+ def test_stream_error_handling():
358
+ """Test error handling for streaming responses
359
+
360
+ Test scenarios:
361
+ 1. Empty message list
362
+ 2. Message format error (missing required fields)
363
+
364
+ Error responses should be returned immediately without establishing a streaming connection.
365
+ The status code should be 4xx, and detailed error information should be returned.
366
+ """
367
+ url = get_base_url()
368
+
369
+ if OutputControl.is_verbose():
370
+ print("\n=== Testing streaming response error handling ===")
371
+
372
+ # Test empty message list
373
+ if OutputControl.is_verbose():
374
+ print("\n--- Testing empty message list (streaming) ---")
375
+ data = create_error_test_data("empty_messages")
376
+ response = make_request(url, data, stream=True)
377
+ print(f"Status code: {response.status_code}")
378
+ if response.status_code != 200:
379
+ print_json_response(response.json(), "Error message")
380
+ response.close()
381
+
382
+ # Test invalid role field
383
+ if OutputControl.is_verbose():
384
+ print("\n--- Testing invalid role field (streaming) ---")
385
+ data = create_error_test_data("invalid_role")
386
+ response = make_request(url, data, stream=True)
387
+ print(f"Status code: {response.status_code}")
388
+ if response.status_code != 200:
389
+ print_json_response(response.json(), "Error message")
390
+ response.close()
391
+
392
+ # Test missing content field
393
+ if OutputControl.is_verbose():
394
+ print("\n--- Testing missing content field (streaming) ---")
395
+ data = create_error_test_data("missing_content")
396
+ response = make_request(url, data, stream=True)
397
+ print(f"Status code: {response.status_code}")
398
+ if response.status_code != 200:
399
+ print_json_response(response.json(), "Error message")
400
+ response.close()
401
+
402
+
403
+ def test_error_handling():
404
+ """Test error handling for non-streaming responses
405
+
406
+ Test scenarios:
407
+ 1. Empty message list
408
+ 2. Message format error (missing required fields)
409
+
410
+ Error response format:
411
+ {
412
+ "detail": "Error description"
413
+ }
414
+
415
+ All errors should return appropriate HTTP status codes and clear error messages.
416
+ """
417
+ url = get_base_url()
418
+
419
+ if OutputControl.is_verbose():
420
+ print("\n=== Testing error handling ===")
421
+
422
+ # Test empty message list
423
+ if OutputControl.is_verbose():
424
+ print("\n--- Testing empty message list ---")
425
+ data = create_error_test_data("empty_messages")
426
+ data["stream"] = False # Change to non-streaming mode
427
+ response = make_request(url, data)
428
+ print(f"Status code: {response.status_code}")
429
+ print_json_response(response.json(), "Error message")
430
+
431
+ # Test invalid role field
432
+ if OutputControl.is_verbose():
433
+ print("\n--- Testing invalid role field ---")
434
+ data = create_error_test_data("invalid_role")
435
+ data["stream"] = False # Change to non-streaming mode
436
+ response = make_request(url, data)
437
+ print(f"Status code: {response.status_code}")
438
+ print_json_response(response.json(), "Error message")
439
+
440
+ # Test missing content field
441
+ if OutputControl.is_verbose():
442
+ print("\n--- Testing missing content field ---")
443
+ data = create_error_test_data("missing_content")
444
+ data["stream"] = False # Change to non-streaming mode
445
+ response = make_request(url, data)
446
+ print(f"Status code: {response.status_code}")
447
+ print_json_response(response.json(), "Error message")
448
+
449
+
450
+ def get_test_cases() -> Dict[str, Callable]:
451
+ """Get all available test cases
452
+ Returns:
453
+ A dictionary mapping test names to test functions
454
+ """
455
+ return {
456
+ "non_stream": test_non_stream_chat,
457
+ "stream": test_stream_chat,
458
+ "modes": test_query_modes,
459
+ "errors": test_error_handling,
460
+ "stream_errors": test_stream_error_handling,
461
+ }
462
+
463
+
464
+ def create_default_config():
465
+ """Create a default configuration file"""
466
+ config_path = Path("config.json")
467
+ if not config_path.exists():
468
+ with open(config_path, "w", encoding="utf-8") as f:
469
+ json.dump(DEFAULT_CONFIG, f, ensure_ascii=False, indent=2)
470
+ print(f"Default configuration file created: {config_path}")
471
+
472
+
473
+ def parse_args() -> argparse.Namespace:
474
+ """Parse command line arguments"""
475
+ parser = argparse.ArgumentParser(
476
+ description="LightRAG Ollama Compatibility Interface Testing",
477
+ formatter_class=argparse.RawDescriptionHelpFormatter,
478
+ epilog="""
479
+ Configuration file (config.json):
480
+ {
481
+ "server": {
482
+ "host": "localhost", # Server address
483
+ "port": 9621, # Server port
484
+ "model": "lightrag:latest" # Default model name
485
+ },
486
+ "test_cases": {
487
+ "basic": {
488
+ "query": "Test query", # Basic query text
489
+ "stream_query": "Stream query" # Stream query text
490
+ }
491
+ }
492
+ }
493
+ """,
494
+ )
495
+ parser.add_argument(
496
+ "-q",
497
+ "--quiet",
498
+ action="store_true",
499
+ help="Silent mode, only display test result summary",
500
+ )
501
+ parser.add_argument(
502
+ "-a",
503
+ "--ask",
504
+ type=str,
505
+ help="Specify query content, which will override the query settings in the configuration file",
506
+ )
507
+ parser.add_argument(
508
+ "--init-config", action="store_true", help="Create default configuration file"
509
+ )
510
+ parser.add_argument(
511
+ "--output",
512
+ type=str,
513
+ default="",
514
+ help="Test result output file path, default is not to output to a file",
515
+ )
516
+ parser.add_argument(
517
+ "--tests",
518
+ nargs="+",
519
+ choices=list(get_test_cases().keys()) + ["all"],
520
+ default=["all"],
521
+ help="Test cases to run, options: %(choices)s. Use 'all' to run all tests",
522
+ )
523
+ return parser.parse_args()
524
+
525
+
526
+ if __name__ == "__main__":
527
+ args = parse_args()
528
+
529
+ # Set output mode
530
+ OutputControl.set_verbose(not args.quiet)
531
+
532
+ # If query content is specified, update the configuration
533
+ if args.ask:
534
+ CONFIG["test_cases"]["basic"]["query"] = args.ask
535
+
536
+ # If specified to create a configuration file
537
+ if args.init_config:
538
+ create_default_config()
539
+ exit(0)
540
+
541
+ test_cases = get_test_cases()
542
+
543
+ try:
544
+ if "all" in args.tests:
545
+ # Run all tests
546
+ if OutputControl.is_verbose():
547
+ print("\n【Basic Functionality Tests】")
548
+ run_test(test_non_stream_chat, "Non-streaming Call Test")
549
+ run_test(test_stream_chat, "Streaming Call Test")
550
+
551
+ if OutputControl.is_verbose():
552
+ print("\n【Query Mode Tests】")
553
+ run_test(test_query_modes, "Query Mode Test")
554
+
555
+ if OutputControl.is_verbose():
556
+ print("\n【Error Handling Tests】")
557
+ run_test(test_error_handling, "Error Handling Test")
558
+ run_test(test_stream_error_handling, "Streaming Error Handling Test")
559
+ else:
560
+ # Run specified tests
561
+ for test_name in args.tests:
562
+ if OutputControl.is_verbose():
563
+ print(f"\n【Running Test: {test_name}】")
564
+ run_test(test_cases[test_name], test_name)
565
+ except Exception as e:
566
+ print(f"\nAn error occurred: {str(e)}")
567
+ finally:
568
+ # Print test statistics
569
+ STATS.print_summary()
570
+ # If an output file path is specified, export the results
571
+ if args.output:
572
+ STATS.export_results(args.output)