Merge pull request #644 from danielaskdd/Add-Ollama-generate-API-support
Browse files- lightrag/api/README.md +1 -0
- lightrag/api/lightrag_server.py +178 -6
- test_lightrag_ollama_chat.py +230 -25
lightrag/api/README.md
CHANGED
@@ -94,6 +94,7 @@ For example, chat message "/mix 唐僧有几个徒弟" will trigger a mix mode q
|
|
94 |
|
95 |
After starting the lightrag-server, you can add an Ollama-type connection in the Open WebUI admin pannel. And then a model named lightrag:latest will appear in Open WebUI's model management interface. Users can then send queries to LightRAG through the chat interface.
|
96 |
|
|
|
97 |
|
98 |
## Configuration
|
99 |
|
|
|
94 |
|
95 |
After starting the lightrag-server, you can add an Ollama-type connection in the Open WebUI admin pannel. And then a model named lightrag:latest will appear in Open WebUI's model management interface. Users can then send queries to LightRAG through the chat interface.
|
96 |
|
97 |
+
To prevent Open WebUI from using LightRAG when generating conversation titles, go to Admin Panel > Interface > Set Task Model and change both Local Models and External Models to any option except "Current Model".
|
98 |
|
99 |
## Configuration
|
100 |
|
lightrag/api/lightrag_server.py
CHANGED
@@ -533,6 +533,7 @@ class OllamaChatRequest(BaseModel):
|
|
533 |
messages: List[OllamaMessage]
|
534 |
stream: bool = True # Default to streaming mode
|
535 |
options: Optional[Dict[str, Any]] = None
|
|
|
536 |
|
537 |
|
538 |
class OllamaChatResponse(BaseModel):
|
@@ -542,6 +543,28 @@ class OllamaChatResponse(BaseModel):
|
|
542 |
done: bool
|
543 |
|
544 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
545 |
class OllamaVersionResponse(BaseModel):
|
546 |
version: str
|
547 |
|
@@ -1417,6 +1440,145 @@ def create_app(args):
|
|
1417 |
|
1418 |
return query, SearchMode.hybrid
|
1419 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1420 |
@app.post("/api/chat")
|
1421 |
async def chat(raw_request: Request, request: OllamaChatRequest):
|
1422 |
"""Handle chat completion requests"""
|
@@ -1429,16 +1591,12 @@ def create_app(args):
|
|
1429 |
# Get the last message as query
|
1430 |
query = messages[-1].content
|
1431 |
|
1432 |
-
#
|
1433 |
cleaned_query, mode = parse_query_mode(query)
|
1434 |
|
1435 |
-
# 开始计时
|
1436 |
start_time = time.time_ns()
|
1437 |
-
|
1438 |
-
# 计算输入token数量
|
1439 |
prompt_tokens = estimate_tokens(cleaned_query)
|
1440 |
|
1441 |
-
# 调用RAG进行查询
|
1442 |
query_param = QueryParam(
|
1443 |
mode=mode, stream=request.stream, only_need_context=False
|
1444 |
)
|
@@ -1549,7 +1707,21 @@ def create_app(args):
|
|
1549 |
)
|
1550 |
else:
|
1551 |
first_chunk_time = time.time_ns()
|
1552 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1553 |
last_chunk_time = time.time_ns()
|
1554 |
|
1555 |
if not response_text:
|
|
|
533 |
messages: List[OllamaMessage]
|
534 |
stream: bool = True # Default to streaming mode
|
535 |
options: Optional[Dict[str, Any]] = None
|
536 |
+
system: Optional[str] = None
|
537 |
|
538 |
|
539 |
class OllamaChatResponse(BaseModel):
|
|
|
543 |
done: bool
|
544 |
|
545 |
|
546 |
+
class OllamaGenerateRequest(BaseModel):
|
547 |
+
model: str = LIGHTRAG_MODEL
|
548 |
+
prompt: str
|
549 |
+
system: Optional[str] = None
|
550 |
+
stream: bool = False
|
551 |
+
options: Optional[Dict[str, Any]] = None
|
552 |
+
|
553 |
+
|
554 |
+
class OllamaGenerateResponse(BaseModel):
|
555 |
+
model: str
|
556 |
+
created_at: str
|
557 |
+
response: str
|
558 |
+
done: bool
|
559 |
+
context: Optional[List[int]]
|
560 |
+
total_duration: Optional[int]
|
561 |
+
load_duration: Optional[int]
|
562 |
+
prompt_eval_count: Optional[int]
|
563 |
+
prompt_eval_duration: Optional[int]
|
564 |
+
eval_count: Optional[int]
|
565 |
+
eval_duration: Optional[int]
|
566 |
+
|
567 |
+
|
568 |
class OllamaVersionResponse(BaseModel):
|
569 |
version: str
|
570 |
|
|
|
1440 |
|
1441 |
return query, SearchMode.hybrid
|
1442 |
|
1443 |
+
@app.post("/api/generate")
|
1444 |
+
async def generate(raw_request: Request, request: OllamaGenerateRequest):
|
1445 |
+
"""Handle generate completion requests"""
|
1446 |
+
try:
|
1447 |
+
query = request.prompt
|
1448 |
+
start_time = time.time_ns()
|
1449 |
+
prompt_tokens = estimate_tokens(query)
|
1450 |
+
|
1451 |
+
if request.system:
|
1452 |
+
rag.llm_model_kwargs["system_prompt"] = request.system
|
1453 |
+
|
1454 |
+
if request.stream:
|
1455 |
+
from fastapi.responses import StreamingResponse
|
1456 |
+
|
1457 |
+
response = await rag.llm_model_func(
|
1458 |
+
query, stream=True, **rag.llm_model_kwargs
|
1459 |
+
)
|
1460 |
+
|
1461 |
+
async def stream_generator():
|
1462 |
+
try:
|
1463 |
+
first_chunk_time = None
|
1464 |
+
last_chunk_time = None
|
1465 |
+
total_response = ""
|
1466 |
+
|
1467 |
+
# Ensure response is an async generator
|
1468 |
+
if isinstance(response, str):
|
1469 |
+
# If it's a string, send in two parts
|
1470 |
+
first_chunk_time = time.time_ns()
|
1471 |
+
last_chunk_time = first_chunk_time
|
1472 |
+
total_response = response
|
1473 |
+
|
1474 |
+
data = {
|
1475 |
+
"model": LIGHTRAG_MODEL,
|
1476 |
+
"created_at": LIGHTRAG_CREATED_AT,
|
1477 |
+
"response": response,
|
1478 |
+
"done": False,
|
1479 |
+
}
|
1480 |
+
yield f"{json.dumps(data, ensure_ascii=False)}\n"
|
1481 |
+
|
1482 |
+
completion_tokens = estimate_tokens(total_response)
|
1483 |
+
total_time = last_chunk_time - start_time
|
1484 |
+
prompt_eval_time = first_chunk_time - start_time
|
1485 |
+
eval_time = last_chunk_time - first_chunk_time
|
1486 |
+
|
1487 |
+
data = {
|
1488 |
+
"model": LIGHTRAG_MODEL,
|
1489 |
+
"created_at": LIGHTRAG_CREATED_AT,
|
1490 |
+
"done": True,
|
1491 |
+
"total_duration": total_time,
|
1492 |
+
"load_duration": 0,
|
1493 |
+
"prompt_eval_count": prompt_tokens,
|
1494 |
+
"prompt_eval_duration": prompt_eval_time,
|
1495 |
+
"eval_count": completion_tokens,
|
1496 |
+
"eval_duration": eval_time,
|
1497 |
+
}
|
1498 |
+
yield f"{json.dumps(data, ensure_ascii=False)}\n"
|
1499 |
+
else:
|
1500 |
+
async for chunk in response:
|
1501 |
+
if chunk:
|
1502 |
+
if first_chunk_time is None:
|
1503 |
+
first_chunk_time = time.time_ns()
|
1504 |
+
|
1505 |
+
last_chunk_time = time.time_ns()
|
1506 |
+
|
1507 |
+
total_response += chunk
|
1508 |
+
data = {
|
1509 |
+
"model": LIGHTRAG_MODEL,
|
1510 |
+
"created_at": LIGHTRAG_CREATED_AT,
|
1511 |
+
"response": chunk,
|
1512 |
+
"done": False,
|
1513 |
+
}
|
1514 |
+
yield f"{json.dumps(data, ensure_ascii=False)}\n"
|
1515 |
+
|
1516 |
+
completion_tokens = estimate_tokens(total_response)
|
1517 |
+
total_time = last_chunk_time - start_time
|
1518 |
+
prompt_eval_time = first_chunk_time - start_time
|
1519 |
+
eval_time = last_chunk_time - first_chunk_time
|
1520 |
+
|
1521 |
+
data = {
|
1522 |
+
"model": LIGHTRAG_MODEL,
|
1523 |
+
"created_at": LIGHTRAG_CREATED_AT,
|
1524 |
+
"done": True,
|
1525 |
+
"total_duration": total_time,
|
1526 |
+
"load_duration": 0,
|
1527 |
+
"prompt_eval_count": prompt_tokens,
|
1528 |
+
"prompt_eval_duration": prompt_eval_time,
|
1529 |
+
"eval_count": completion_tokens,
|
1530 |
+
"eval_duration": eval_time,
|
1531 |
+
}
|
1532 |
+
yield f"{json.dumps(data, ensure_ascii=False)}\n"
|
1533 |
+
return
|
1534 |
+
|
1535 |
+
except Exception as e:
|
1536 |
+
logging.error(f"Error in stream_generator: {str(e)}")
|
1537 |
+
raise
|
1538 |
+
|
1539 |
+
return StreamingResponse(
|
1540 |
+
stream_generator(),
|
1541 |
+
media_type="application/x-ndjson",
|
1542 |
+
headers={
|
1543 |
+
"Cache-Control": "no-cache",
|
1544 |
+
"Connection": "keep-alive",
|
1545 |
+
"Content-Type": "application/x-ndjson",
|
1546 |
+
"Access-Control-Allow-Origin": "*",
|
1547 |
+
"Access-Control-Allow-Methods": "POST, OPTIONS",
|
1548 |
+
"Access-Control-Allow-Headers": "Content-Type",
|
1549 |
+
},
|
1550 |
+
)
|
1551 |
+
else:
|
1552 |
+
first_chunk_time = time.time_ns()
|
1553 |
+
response_text = await rag.llm_model_func(
|
1554 |
+
query, stream=False, **rag.llm_model_kwargs
|
1555 |
+
)
|
1556 |
+
last_chunk_time = time.time_ns()
|
1557 |
+
|
1558 |
+
if not response_text:
|
1559 |
+
response_text = "No response generated"
|
1560 |
+
|
1561 |
+
completion_tokens = estimate_tokens(str(response_text))
|
1562 |
+
total_time = last_chunk_time - start_time
|
1563 |
+
prompt_eval_time = first_chunk_time - start_time
|
1564 |
+
eval_time = last_chunk_time - first_chunk_time
|
1565 |
+
|
1566 |
+
return {
|
1567 |
+
"model": LIGHTRAG_MODEL,
|
1568 |
+
"created_at": LIGHTRAG_CREATED_AT,
|
1569 |
+
"response": str(response_text),
|
1570 |
+
"done": True,
|
1571 |
+
"total_duration": total_time,
|
1572 |
+
"load_duration": 0,
|
1573 |
+
"prompt_eval_count": prompt_tokens,
|
1574 |
+
"prompt_eval_duration": prompt_eval_time,
|
1575 |
+
"eval_count": completion_tokens,
|
1576 |
+
"eval_duration": eval_time,
|
1577 |
+
}
|
1578 |
+
except Exception as e:
|
1579 |
+
trace_exception(e)
|
1580 |
+
raise HTTPException(status_code=500, detail=str(e))
|
1581 |
+
|
1582 |
@app.post("/api/chat")
|
1583 |
async def chat(raw_request: Request, request: OllamaChatRequest):
|
1584 |
"""Handle chat completion requests"""
|
|
|
1591 |
# Get the last message as query
|
1592 |
query = messages[-1].content
|
1593 |
|
1594 |
+
# Check for query prefix
|
1595 |
cleaned_query, mode = parse_query_mode(query)
|
1596 |
|
|
|
1597 |
start_time = time.time_ns()
|
|
|
|
|
1598 |
prompt_tokens = estimate_tokens(cleaned_query)
|
1599 |
|
|
|
1600 |
query_param = QueryParam(
|
1601 |
mode=mode, stream=request.stream, only_need_context=False
|
1602 |
)
|
|
|
1707 |
)
|
1708 |
else:
|
1709 |
first_chunk_time = time.time_ns()
|
1710 |
+
|
1711 |
+
# Determine if the request is from Open WebUI's session title and session keyword generation task
|
1712 |
+
match_result = re.search(
|
1713 |
+
r"\n<chat_history>\nUSER:", cleaned_query, re.MULTILINE
|
1714 |
+
)
|
1715 |
+
if match_result:
|
1716 |
+
if request.system:
|
1717 |
+
rag.llm_model_kwargs["system_prompt"] = request.system
|
1718 |
+
|
1719 |
+
response_text = await rag.llm_model_func(
|
1720 |
+
cleaned_query, stream=False, **rag.llm_model_kwargs
|
1721 |
+
)
|
1722 |
+
else:
|
1723 |
+
response_text = await rag.aquery(cleaned_query, param=query_param)
|
1724 |
+
|
1725 |
last_chunk_time = time.time_ns()
|
1726 |
|
1727 |
if not response_text:
|
test_lightrag_ollama_chat.py
CHANGED
@@ -108,7 +108,10 @@ DEFAULT_CONFIG = {
|
|
108 |
"max_retries": 3,
|
109 |
"retry_delay": 1,
|
110 |
},
|
111 |
-
"test_cases": {
|
|
|
|
|
|
|
112 |
}
|
113 |
|
114 |
|
@@ -174,22 +177,27 @@ def print_json_response(data: Dict[str, Any], title: str = "", indent: int = 2)
|
|
174 |
CONFIG = load_config()
|
175 |
|
176 |
|
177 |
-
def get_base_url() -> str:
|
178 |
-
"""Return the base URL
|
|
|
|
|
|
|
|
|
|
|
179 |
server = CONFIG["server"]
|
180 |
-
return f"http://{server['host']}:{server['port']}/api/
|
181 |
|
182 |
|
183 |
-
def
|
184 |
content: str, stream: bool = False, model: str = None
|
185 |
) -> Dict[str, Any]:
|
186 |
-
"""Create
|
187 |
Args:
|
188 |
content: User message content
|
189 |
stream: Whether to use streaming response
|
190 |
model: Model name
|
191 |
Returns:
|
192 |
-
Dictionary containing complete request data
|
193 |
"""
|
194 |
return {
|
195 |
"model": model or CONFIG["server"]["model"],
|
@@ -198,6 +206,35 @@ def create_request_data(
|
|
198 |
}
|
199 |
|
200 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
201 |
# Global test statistics
|
202 |
STATS = TestStats()
|
203 |
|
@@ -219,10 +256,12 @@ def run_test(func: Callable, name: str) -> None:
|
|
219 |
raise
|
220 |
|
221 |
|
222 |
-
def test_non_stream_chat():
|
223 |
"""Test non-streaming call to /api/chat endpoint"""
|
224 |
url = get_base_url()
|
225 |
-
data =
|
|
|
|
|
226 |
|
227 |
# Send request
|
228 |
response = make_request(url, data)
|
@@ -239,7 +278,7 @@ def test_non_stream_chat():
|
|
239 |
)
|
240 |
|
241 |
|
242 |
-
def test_stream_chat():
|
243 |
"""Test streaming call to /api/chat endpoint
|
244 |
|
245 |
Use JSON Lines format to process streaming responses, each line is a complete JSON object.
|
@@ -258,7 +297,7 @@ def test_stream_chat():
|
|
258 |
The last message will contain performance statistics, with done set to true.
|
259 |
"""
|
260 |
url = get_base_url()
|
261 |
-
data =
|
262 |
|
263 |
# Send request and get streaming response
|
264 |
response = make_request(url, data, stream=True)
|
@@ -295,7 +334,7 @@ def test_stream_chat():
|
|
295 |
print()
|
296 |
|
297 |
|
298 |
-
def test_query_modes():
|
299 |
"""Test different query mode prefixes
|
300 |
|
301 |
Supported query modes:
|
@@ -313,7 +352,7 @@ def test_query_modes():
|
|
313 |
for mode in modes:
|
314 |
if OutputControl.is_verbose():
|
315 |
print(f"\n=== Testing /{mode} mode ===")
|
316 |
-
data =
|
317 |
f"/{mode} {CONFIG['test_cases']['basic']['query']}", stream=False
|
318 |
)
|
319 |
|
@@ -354,7 +393,7 @@ def create_error_test_data(error_type: str) -> Dict[str, Any]:
|
|
354 |
return error_data.get(error_type, error_data["empty_messages"])
|
355 |
|
356 |
|
357 |
-
def test_stream_error_handling():
|
358 |
"""Test error handling for streaming responses
|
359 |
|
360 |
Test scenarios:
|
@@ -400,7 +439,7 @@ def test_stream_error_handling():
|
|
400 |
response.close()
|
401 |
|
402 |
|
403 |
-
def test_error_handling():
|
404 |
"""Test error handling for non-streaming responses
|
405 |
|
406 |
Test scenarios:
|
@@ -447,6 +486,165 @@ def test_error_handling():
|
|
447 |
print_json_response(response.json(), "Error message")
|
448 |
|
449 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
450 |
def get_test_cases() -> Dict[str, Callable]:
|
451 |
"""Get all available test cases
|
452 |
Returns:
|
@@ -458,6 +656,11 @@ def get_test_cases() -> Dict[str, Callable]:
|
|
458 |
"modes": test_query_modes,
|
459 |
"errors": test_error_handling,
|
460 |
"stream_errors": test_stream_error_handling,
|
|
|
|
|
|
|
|
|
|
|
461 |
}
|
462 |
|
463 |
|
@@ -544,18 +747,20 @@ if __name__ == "__main__":
|
|
544 |
if "all" in args.tests:
|
545 |
# Run all tests
|
546 |
if OutputControl.is_verbose():
|
547 |
-
print("\n【
|
548 |
-
run_test(test_non_stream_chat, "Non-streaming
|
549 |
-
run_test(test_stream_chat, "Streaming
|
550 |
-
|
551 |
-
|
552 |
-
|
553 |
-
run_test(test_query_modes, "Query Mode Test")
|
554 |
|
555 |
if OutputControl.is_verbose():
|
556 |
-
print("\n【
|
557 |
-
run_test(
|
558 |
-
run_test(
|
|
|
|
|
|
|
559 |
else:
|
560 |
# Run specified tests
|
561 |
for test_name in args.tests:
|
|
|
108 |
"max_retries": 3,
|
109 |
"retry_delay": 1,
|
110 |
},
|
111 |
+
"test_cases": {
|
112 |
+
"basic": {"query": "唐僧有几个徒弟"},
|
113 |
+
"generate": {"query": "电视剧西游记导演是谁"},
|
114 |
+
},
|
115 |
}
|
116 |
|
117 |
|
|
|
177 |
CONFIG = load_config()
|
178 |
|
179 |
|
180 |
+
def get_base_url(endpoint: str = "chat") -> str:
|
181 |
+
"""Return the base URL for specified endpoint
|
182 |
+
Args:
|
183 |
+
endpoint: API endpoint name (chat or generate)
|
184 |
+
Returns:
|
185 |
+
Complete URL for the endpoint
|
186 |
+
"""
|
187 |
server = CONFIG["server"]
|
188 |
+
return f"http://{server['host']}:{server['port']}/api/{endpoint}"
|
189 |
|
190 |
|
191 |
+
def create_chat_request_data(
|
192 |
content: str, stream: bool = False, model: str = None
|
193 |
) -> Dict[str, Any]:
|
194 |
+
"""Create chat request data
|
195 |
Args:
|
196 |
content: User message content
|
197 |
stream: Whether to use streaming response
|
198 |
model: Model name
|
199 |
Returns:
|
200 |
+
Dictionary containing complete chat request data
|
201 |
"""
|
202 |
return {
|
203 |
"model": model or CONFIG["server"]["model"],
|
|
|
206 |
}
|
207 |
|
208 |
|
209 |
+
def create_generate_request_data(
|
210 |
+
prompt: str,
|
211 |
+
system: str = None,
|
212 |
+
stream: bool = False,
|
213 |
+
model: str = None,
|
214 |
+
options: Dict[str, Any] = None,
|
215 |
+
) -> Dict[str, Any]:
|
216 |
+
"""Create generate request data
|
217 |
+
Args:
|
218 |
+
prompt: Generation prompt
|
219 |
+
system: System prompt
|
220 |
+
stream: Whether to use streaming response
|
221 |
+
model: Model name
|
222 |
+
options: Additional options
|
223 |
+
Returns:
|
224 |
+
Dictionary containing complete generate request data
|
225 |
+
"""
|
226 |
+
data = {
|
227 |
+
"model": model or CONFIG["server"]["model"],
|
228 |
+
"prompt": prompt,
|
229 |
+
"stream": stream,
|
230 |
+
}
|
231 |
+
if system:
|
232 |
+
data["system"] = system
|
233 |
+
if options:
|
234 |
+
data["options"] = options
|
235 |
+
return data
|
236 |
+
|
237 |
+
|
238 |
# Global test statistics
|
239 |
STATS = TestStats()
|
240 |
|
|
|
256 |
raise
|
257 |
|
258 |
|
259 |
+
def test_non_stream_chat() -> None:
|
260 |
"""Test non-streaming call to /api/chat endpoint"""
|
261 |
url = get_base_url()
|
262 |
+
data = create_chat_request_data(
|
263 |
+
CONFIG["test_cases"]["basic"]["query"], stream=False
|
264 |
+
)
|
265 |
|
266 |
# Send request
|
267 |
response = make_request(url, data)
|
|
|
278 |
)
|
279 |
|
280 |
|
281 |
+
def test_stream_chat() -> None:
|
282 |
"""Test streaming call to /api/chat endpoint
|
283 |
|
284 |
Use JSON Lines format to process streaming responses, each line is a complete JSON object.
|
|
|
297 |
The last message will contain performance statistics, with done set to true.
|
298 |
"""
|
299 |
url = get_base_url()
|
300 |
+
data = create_chat_request_data(CONFIG["test_cases"]["basic"]["query"], stream=True)
|
301 |
|
302 |
# Send request and get streaming response
|
303 |
response = make_request(url, data, stream=True)
|
|
|
334 |
print()
|
335 |
|
336 |
|
337 |
+
def test_query_modes() -> None:
|
338 |
"""Test different query mode prefixes
|
339 |
|
340 |
Supported query modes:
|
|
|
352 |
for mode in modes:
|
353 |
if OutputControl.is_verbose():
|
354 |
print(f"\n=== Testing /{mode} mode ===")
|
355 |
+
data = create_chat_request_data(
|
356 |
f"/{mode} {CONFIG['test_cases']['basic']['query']}", stream=False
|
357 |
)
|
358 |
|
|
|
393 |
return error_data.get(error_type, error_data["empty_messages"])
|
394 |
|
395 |
|
396 |
+
def test_stream_error_handling() -> None:
|
397 |
"""Test error handling for streaming responses
|
398 |
|
399 |
Test scenarios:
|
|
|
439 |
response.close()
|
440 |
|
441 |
|
442 |
+
def test_error_handling() -> None:
|
443 |
"""Test error handling for non-streaming responses
|
444 |
|
445 |
Test scenarios:
|
|
|
486 |
print_json_response(response.json(), "Error message")
|
487 |
|
488 |
|
489 |
+
def test_non_stream_generate() -> None:
|
490 |
+
"""Test non-streaming call to /api/generate endpoint"""
|
491 |
+
url = get_base_url("generate")
|
492 |
+
data = create_generate_request_data(
|
493 |
+
CONFIG["test_cases"]["generate"]["query"], stream=False
|
494 |
+
)
|
495 |
+
|
496 |
+
# Send request
|
497 |
+
response = make_request(url, data)
|
498 |
+
|
499 |
+
# Print response
|
500 |
+
if OutputControl.is_verbose():
|
501 |
+
print("\n=== Non-streaming generate response ===")
|
502 |
+
response_json = response.json()
|
503 |
+
|
504 |
+
# Print response content
|
505 |
+
print_json_response(
|
506 |
+
{
|
507 |
+
"model": response_json["model"],
|
508 |
+
"response": response_json["response"],
|
509 |
+
"done": response_json["done"],
|
510 |
+
},
|
511 |
+
"Response content",
|
512 |
+
)
|
513 |
+
|
514 |
+
|
515 |
+
def test_stream_generate() -> None:
|
516 |
+
"""Test streaming call to /api/generate endpoint"""
|
517 |
+
url = get_base_url("generate")
|
518 |
+
data = create_generate_request_data(
|
519 |
+
CONFIG["test_cases"]["generate"]["query"], stream=True
|
520 |
+
)
|
521 |
+
|
522 |
+
# Send request and get streaming response
|
523 |
+
response = make_request(url, data, stream=True)
|
524 |
+
|
525 |
+
if OutputControl.is_verbose():
|
526 |
+
print("\n=== Streaming generate response ===")
|
527 |
+
output_buffer = []
|
528 |
+
try:
|
529 |
+
for line in response.iter_lines():
|
530 |
+
if line: # Skip empty lines
|
531 |
+
try:
|
532 |
+
# Decode and parse JSON
|
533 |
+
data = json.loads(line.decode("utf-8"))
|
534 |
+
if data.get("done", True): # If it's the completion marker
|
535 |
+
if (
|
536 |
+
"total_duration" in data
|
537 |
+
): # Final performance statistics message
|
538 |
+
break
|
539 |
+
else: # Normal content message
|
540 |
+
content = data.get("response", "")
|
541 |
+
if content: # Only collect non-empty content
|
542 |
+
output_buffer.append(content)
|
543 |
+
print(
|
544 |
+
content, end="", flush=True
|
545 |
+
) # Print content in real-time
|
546 |
+
except json.JSONDecodeError:
|
547 |
+
print("Error decoding JSON from response line")
|
548 |
+
finally:
|
549 |
+
response.close() # Ensure the response connection is closed
|
550 |
+
|
551 |
+
# Print a newline
|
552 |
+
print()
|
553 |
+
|
554 |
+
|
555 |
+
def test_generate_with_system() -> None:
|
556 |
+
"""Test generate with system prompt"""
|
557 |
+
url = get_base_url("generate")
|
558 |
+
data = create_generate_request_data(
|
559 |
+
CONFIG["test_cases"]["generate"]["query"],
|
560 |
+
system="你是一个知识渊博的助手",
|
561 |
+
stream=False,
|
562 |
+
)
|
563 |
+
|
564 |
+
# Send request
|
565 |
+
response = make_request(url, data)
|
566 |
+
|
567 |
+
# Print response
|
568 |
+
if OutputControl.is_verbose():
|
569 |
+
print("\n=== Generate with system prompt response ===")
|
570 |
+
response_json = response.json()
|
571 |
+
|
572 |
+
# Print response content
|
573 |
+
print_json_response(
|
574 |
+
{
|
575 |
+
"model": response_json["model"],
|
576 |
+
"response": response_json["response"],
|
577 |
+
"done": response_json["done"],
|
578 |
+
},
|
579 |
+
"Response content",
|
580 |
+
)
|
581 |
+
|
582 |
+
|
583 |
+
def test_generate_error_handling() -> None:
|
584 |
+
"""Test error handling for generate endpoint"""
|
585 |
+
url = get_base_url("generate")
|
586 |
+
|
587 |
+
# Test empty prompt
|
588 |
+
if OutputControl.is_verbose():
|
589 |
+
print("\n=== Testing empty prompt ===")
|
590 |
+
data = create_generate_request_data("", stream=False)
|
591 |
+
response = make_request(url, data)
|
592 |
+
print(f"Status code: {response.status_code}")
|
593 |
+
print_json_response(response.json(), "Error message")
|
594 |
+
|
595 |
+
# Test invalid options
|
596 |
+
if OutputControl.is_verbose():
|
597 |
+
print("\n=== Testing invalid options ===")
|
598 |
+
data = create_generate_request_data(
|
599 |
+
CONFIG["test_cases"]["basic"]["query"],
|
600 |
+
options={"invalid_option": "value"},
|
601 |
+
stream=False,
|
602 |
+
)
|
603 |
+
response = make_request(url, data)
|
604 |
+
print(f"Status code: {response.status_code}")
|
605 |
+
print_json_response(response.json(), "Error message")
|
606 |
+
|
607 |
+
|
608 |
+
def test_generate_concurrent() -> None:
|
609 |
+
"""Test concurrent generate requests"""
|
610 |
+
import asyncio
|
611 |
+
import aiohttp
|
612 |
+
from contextlib import asynccontextmanager
|
613 |
+
|
614 |
+
@asynccontextmanager
|
615 |
+
async def get_session():
|
616 |
+
async with aiohttp.ClientSession() as session:
|
617 |
+
yield session
|
618 |
+
|
619 |
+
async def make_request(session, prompt: str):
|
620 |
+
url = get_base_url("generate")
|
621 |
+
data = create_generate_request_data(prompt, stream=False)
|
622 |
+
try:
|
623 |
+
async with session.post(url, json=data) as response:
|
624 |
+
return await response.json()
|
625 |
+
except Exception as e:
|
626 |
+
return {"error": str(e)}
|
627 |
+
|
628 |
+
async def run_concurrent_requests():
|
629 |
+
prompts = ["第一个问题", "第二个问题", "第三个问题", "第四个问题", "第五个问题"]
|
630 |
+
|
631 |
+
async with get_session() as session:
|
632 |
+
tasks = [make_request(session, prompt) for prompt in prompts]
|
633 |
+
results = await asyncio.gather(*tasks)
|
634 |
+
return results
|
635 |
+
|
636 |
+
if OutputControl.is_verbose():
|
637 |
+
print("\n=== Testing concurrent generate requests ===")
|
638 |
+
|
639 |
+
# Run concurrent requests
|
640 |
+
results = asyncio.run(run_concurrent_requests())
|
641 |
+
|
642 |
+
# Print results
|
643 |
+
for i, result in enumerate(results, 1):
|
644 |
+
print(f"\nRequest {i} result:")
|
645 |
+
print_json_response(result)
|
646 |
+
|
647 |
+
|
648 |
def get_test_cases() -> Dict[str, Callable]:
|
649 |
"""Get all available test cases
|
650 |
Returns:
|
|
|
656 |
"modes": test_query_modes,
|
657 |
"errors": test_error_handling,
|
658 |
"stream_errors": test_stream_error_handling,
|
659 |
+
"non_stream_generate": test_non_stream_generate,
|
660 |
+
"stream_generate": test_stream_generate,
|
661 |
+
"generate_with_system": test_generate_with_system,
|
662 |
+
"generate_errors": test_generate_error_handling,
|
663 |
+
"generate_concurrent": test_generate_concurrent,
|
664 |
}
|
665 |
|
666 |
|
|
|
747 |
if "all" in args.tests:
|
748 |
# Run all tests
|
749 |
if OutputControl.is_verbose():
|
750 |
+
print("\n【Chat API Tests】")
|
751 |
+
run_test(test_non_stream_chat, "Non-streaming Chat Test")
|
752 |
+
run_test(test_stream_chat, "Streaming Chat Test")
|
753 |
+
run_test(test_query_modes, "Chat Query Mode Test")
|
754 |
+
run_test(test_error_handling, "Chat Error Handling Test")
|
755 |
+
run_test(test_stream_error_handling, "Chat Streaming Error Test")
|
|
|
756 |
|
757 |
if OutputControl.is_verbose():
|
758 |
+
print("\n【Generate API Tests】")
|
759 |
+
run_test(test_non_stream_generate, "Non-streaming Generate Test")
|
760 |
+
run_test(test_stream_generate, "Streaming Generate Test")
|
761 |
+
run_test(test_generate_with_system, "Generate with System Prompt Test")
|
762 |
+
run_test(test_generate_error_handling, "Generate Error Handling Test")
|
763 |
+
run_test(test_generate_concurrent, "Generate Concurrent Test")
|
764 |
else:
|
765 |
# Run specified tests
|
766 |
for test_name in args.tests:
|