yangdx
committed on
Commit
·
dcd4769
1
Parent(s):
4a5fb8c
Fix linting, remove redundant comments and clean up code for better readability
Browse files
- lightrag/api/lightrag_server.py +38 -52
- test_lightrag_ollama_chat.py +38 -35
lightrag/api/lightrag_server.py
CHANGED
@@ -476,6 +476,7 @@ class OllamaChatResponse(BaseModel):
|
|
476 |
message: OllamaMessage
|
477 |
done: bool
|
478 |
|
|
|
479 |
class OllamaGenerateRequest(BaseModel):
|
480 |
model: str = LIGHTRAG_MODEL
|
481 |
prompt: str
|
@@ -483,6 +484,7 @@ class OllamaGenerateRequest(BaseModel):
|
|
483 |
stream: bool = False
|
484 |
options: Optional[Dict[str, Any]] = None
|
485 |
|
|
|
486 |
class OllamaGenerateResponse(BaseModel):
|
487 |
model: str
|
488 |
created_at: str
|
@@ -490,12 +492,13 @@ class OllamaGenerateResponse(BaseModel):
|
|
490 |
done: bool
|
491 |
context: Optional[List[int]]
|
492 |
total_duration: Optional[int]
|
493 |
-
load_duration: Optional[int]
|
494 |
prompt_eval_count: Optional[int]
|
495 |
prompt_eval_duration: Optional[int]
|
496 |
eval_count: Optional[int]
|
497 |
eval_duration: Optional[int]
|
498 |
|
|
|
499 |
class OllamaVersionResponse(BaseModel):
|
500 |
version: str
|
501 |
|
@@ -1262,52 +1265,45 @@ def create_app(args):
|
|
1262 |
"""Handle generate completion requests"""
|
1263 |
try:
|
1264 |
query = request.prompt
|
1265 |
-
|
1266 |
-
# 开始计时
|
1267 |
start_time = time.time_ns()
|
1268 |
-
|
1269 |
-
# 计算输入token数量
|
1270 |
prompt_tokens = estimate_tokens(query)
|
1271 |
-
|
1272 |
-
# 直接使用 llm_model_func 进行查询
|
1273 |
if request.system:
|
1274 |
rag.llm_model_kwargs["system_prompt"] = request.system
|
1275 |
-
|
1276 |
if request.stream:
|
1277 |
from fastapi.responses import StreamingResponse
|
1278 |
-
|
1279 |
response = await rag.llm_model_func(
|
1280 |
-
query,
|
1281 |
-
stream=True,
|
1282 |
-
**rag.llm_model_kwargs
|
1283 |
)
|
1284 |
-
|
1285 |
async def stream_generator():
|
1286 |
try:
|
1287 |
first_chunk_time = None
|
1288 |
last_chunk_time = None
|
1289 |
total_response = ""
|
1290 |
-
|
1291 |
-
#
|
1292 |
if isinstance(response, str):
|
1293 |
-
#
|
1294 |
first_chunk_time = time.time_ns()
|
1295 |
last_chunk_time = first_chunk_time
|
1296 |
total_response = response
|
1297 |
-
|
1298 |
data = {
|
1299 |
"model": LIGHTRAG_MODEL,
|
1300 |
"created_at": LIGHTRAG_CREATED_AT,
|
1301 |
"response": response,
|
1302 |
-
"done": False
|
1303 |
}
|
1304 |
yield f"{json.dumps(data, ensure_ascii=False)}\n"
|
1305 |
-
|
1306 |
completion_tokens = estimate_tokens(total_response)
|
1307 |
total_time = last_chunk_time - start_time
|
1308 |
prompt_eval_time = first_chunk_time - start_time
|
1309 |
eval_time = last_chunk_time - first_chunk_time
|
1310 |
-
|
1311 |
data = {
|
1312 |
"model": LIGHTRAG_MODEL,
|
1313 |
"created_at": LIGHTRAG_CREATED_AT,
|
@@ -1317,7 +1313,7 @@ def create_app(args):
|
|
1317 |
"prompt_eval_count": prompt_tokens,
|
1318 |
"prompt_eval_duration": prompt_eval_time,
|
1319 |
"eval_count": completion_tokens,
|
1320 |
-
"eval_duration": eval_time
|
1321 |
}
|
1322 |
yield f"{json.dumps(data, ensure_ascii=False)}\n"
|
1323 |
else:
|
@@ -1325,23 +1321,23 @@ def create_app(args):
|
|
1325 |
if chunk:
|
1326 |
if first_chunk_time is None:
|
1327 |
first_chunk_time = time.time_ns()
|
1328 |
-
|
1329 |
last_chunk_time = time.time_ns()
|
1330 |
-
|
1331 |
total_response += chunk
|
1332 |
data = {
|
1333 |
"model": LIGHTRAG_MODEL,
|
1334 |
"created_at": LIGHTRAG_CREATED_AT,
|
1335 |
"response": chunk,
|
1336 |
-
"done": False
|
1337 |
}
|
1338 |
yield f"{json.dumps(data, ensure_ascii=False)}\n"
|
1339 |
-
|
1340 |
completion_tokens = estimate_tokens(total_response)
|
1341 |
total_time = last_chunk_time - start_time
|
1342 |
prompt_eval_time = first_chunk_time - start_time
|
1343 |
eval_time = last_chunk_time - first_chunk_time
|
1344 |
-
|
1345 |
data = {
|
1346 |
"model": LIGHTRAG_MODEL,
|
1347 |
"created_at": LIGHTRAG_CREATED_AT,
|
@@ -1351,15 +1347,15 @@ def create_app(args):
|
|
1351 |
"prompt_eval_count": prompt_tokens,
|
1352 |
"prompt_eval_duration": prompt_eval_time,
|
1353 |
"eval_count": completion_tokens,
|
1354 |
-
"eval_duration": eval_time
|
1355 |
}
|
1356 |
yield f"{json.dumps(data, ensure_ascii=False)}\n"
|
1357 |
return
|
1358 |
-
|
1359 |
except Exception as e:
|
1360 |
logging.error(f"Error in stream_generator: {str(e)}")
|
1361 |
raise
|
1362 |
-
|
1363 |
return StreamingResponse(
|
1364 |
stream_generator(),
|
1365 |
media_type="application/x-ndjson",
|
@@ -1375,20 +1371,18 @@ def create_app(args):
|
|
1375 |
else:
|
1376 |
first_chunk_time = time.time_ns()
|
1377 |
response_text = await rag.llm_model_func(
|
1378 |
-
query,
|
1379 |
-
stream=False,
|
1380 |
-
**rag.llm_model_kwargs
|
1381 |
)
|
1382 |
last_chunk_time = time.time_ns()
|
1383 |
-
|
1384 |
if not response_text:
|
1385 |
response_text = "No response generated"
|
1386 |
-
|
1387 |
completion_tokens = estimate_tokens(str(response_text))
|
1388 |
total_time = last_chunk_time - start_time
|
1389 |
prompt_eval_time = first_chunk_time - start_time
|
1390 |
eval_time = last_chunk_time - first_chunk_time
|
1391 |
-
|
1392 |
return {
|
1393 |
"model": LIGHTRAG_MODEL,
|
1394 |
"created_at": LIGHTRAG_CREATED_AT,
|
@@ -1399,7 +1393,7 @@ def create_app(args):
|
|
1399 |
"prompt_eval_count": prompt_tokens,
|
1400 |
"prompt_eval_duration": prompt_eval_time,
|
1401 |
"eval_count": completion_tokens,
|
1402 |
-
"eval_duration": eval_time
|
1403 |
}
|
1404 |
except Exception as e:
|
1405 |
trace_exception(e)
|
@@ -1417,16 +1411,12 @@ def create_app(args):
|
|
1417 |
# Get the last message as query
|
1418 |
query = messages[-1].content
|
1419 |
|
1420 |
-
#
|
1421 |
cleaned_query, mode = parse_query_mode(query)
|
1422 |
|
1423 |
-
# 开始计时
|
1424 |
start_time = time.time_ns()
|
1425 |
-
|
1426 |
-
# 计算输入token数量
|
1427 |
prompt_tokens = estimate_tokens(cleaned_query)
|
1428 |
|
1429 |
-
# 调用RAG进行查询
|
1430 |
query_param = QueryParam(
|
1431 |
mode=mode, stream=request.stream, only_need_context=False
|
1432 |
)
|
@@ -1537,25 +1527,21 @@ def create_app(args):
|
|
1537 |
)
|
1538 |
else:
|
1539 |
first_chunk_time = time.time_ns()
|
1540 |
-
|
1541 |
-
# 判断是否包含特定字符串,使用正则表达式进行匹配
|
1542 |
-
logging.info(f"Cleaned query content: {cleaned_query}")
|
1543 |
-
match_result = re.search(r'\n<chat_history>\nUSER:', cleaned_query, re.MULTILINE)
|
1544 |
-
logging.info(f"Regex match result: {bool(match_result)}")
|
1545 |
-
|
1546 |
-
if match_result:
|
1547 |
|
|
|
|
|
|
|
|
|
|
|
1548 |
if request.system:
|
1549 |
rag.llm_model_kwargs["system_prompt"] = request.system
|
1550 |
|
1551 |
response_text = await rag.llm_model_func(
|
1552 |
-
cleaned_query,
|
1553 |
-
stream=False,
|
1554 |
-
**rag.llm_model_kwargs
|
1555 |
)
|
1556 |
else:
|
1557 |
response_text = await rag.aquery(cleaned_query, param=query_param)
|
1558 |
-
|
1559 |
last_chunk_time = time.time_ns()
|
1560 |
|
1561 |
if not response_text:
|
|
|
476 |
message: OllamaMessage
|
477 |
done: bool
|
478 |
|
479 |
+
|
480 |
class OllamaGenerateRequest(BaseModel):
|
481 |
model: str = LIGHTRAG_MODEL
|
482 |
prompt: str
|
|
|
484 |
stream: bool = False
|
485 |
options: Optional[Dict[str, Any]] = None
|
486 |
|
487 |
+
|
488 |
class OllamaGenerateResponse(BaseModel):
|
489 |
model: str
|
490 |
created_at: str
|
|
|
492 |
done: bool
|
493 |
context: Optional[List[int]]
|
494 |
total_duration: Optional[int]
|
495 |
+
load_duration: Optional[int]
|
496 |
prompt_eval_count: Optional[int]
|
497 |
prompt_eval_duration: Optional[int]
|
498 |
eval_count: Optional[int]
|
499 |
eval_duration: Optional[int]
|
500 |
|
501 |
+
|
502 |
class OllamaVersionResponse(BaseModel):
|
503 |
version: str
|
504 |
|
|
|
1265 |
"""Handle generate completion requests"""
|
1266 |
try:
|
1267 |
query = request.prompt
|
|
|
|
|
1268 |
start_time = time.time_ns()
|
|
|
|
|
1269 |
prompt_tokens = estimate_tokens(query)
|
1270 |
+
|
|
|
1271 |
if request.system:
|
1272 |
rag.llm_model_kwargs["system_prompt"] = request.system
|
1273 |
+
|
1274 |
if request.stream:
|
1275 |
from fastapi.responses import StreamingResponse
|
1276 |
+
|
1277 |
response = await rag.llm_model_func(
|
1278 |
+
query, stream=True, **rag.llm_model_kwargs
|
|
|
|
|
1279 |
)
|
1280 |
+
|
1281 |
async def stream_generator():
|
1282 |
try:
|
1283 |
first_chunk_time = None
|
1284 |
last_chunk_time = None
|
1285 |
total_response = ""
|
1286 |
+
|
1287 |
+
# Ensure response is an async generator
|
1288 |
if isinstance(response, str):
|
1289 |
+
# If it's a string, send in two parts
|
1290 |
first_chunk_time = time.time_ns()
|
1291 |
last_chunk_time = first_chunk_time
|
1292 |
total_response = response
|
1293 |
+
|
1294 |
data = {
|
1295 |
"model": LIGHTRAG_MODEL,
|
1296 |
"created_at": LIGHTRAG_CREATED_AT,
|
1297 |
"response": response,
|
1298 |
+
"done": False,
|
1299 |
}
|
1300 |
yield f"{json.dumps(data, ensure_ascii=False)}\n"
|
1301 |
+
|
1302 |
completion_tokens = estimate_tokens(total_response)
|
1303 |
total_time = last_chunk_time - start_time
|
1304 |
prompt_eval_time = first_chunk_time - start_time
|
1305 |
eval_time = last_chunk_time - first_chunk_time
|
1306 |
+
|
1307 |
data = {
|
1308 |
"model": LIGHTRAG_MODEL,
|
1309 |
"created_at": LIGHTRAG_CREATED_AT,
|
|
|
1313 |
"prompt_eval_count": prompt_tokens,
|
1314 |
"prompt_eval_duration": prompt_eval_time,
|
1315 |
"eval_count": completion_tokens,
|
1316 |
+
"eval_duration": eval_time,
|
1317 |
}
|
1318 |
yield f"{json.dumps(data, ensure_ascii=False)}\n"
|
1319 |
else:
|
|
|
1321 |
if chunk:
|
1322 |
if first_chunk_time is None:
|
1323 |
first_chunk_time = time.time_ns()
|
1324 |
+
|
1325 |
last_chunk_time = time.time_ns()
|
1326 |
+
|
1327 |
total_response += chunk
|
1328 |
data = {
|
1329 |
"model": LIGHTRAG_MODEL,
|
1330 |
"created_at": LIGHTRAG_CREATED_AT,
|
1331 |
"response": chunk,
|
1332 |
+
"done": False,
|
1333 |
}
|
1334 |
yield f"{json.dumps(data, ensure_ascii=False)}\n"
|
1335 |
+
|
1336 |
completion_tokens = estimate_tokens(total_response)
|
1337 |
total_time = last_chunk_time - start_time
|
1338 |
prompt_eval_time = first_chunk_time - start_time
|
1339 |
eval_time = last_chunk_time - first_chunk_time
|
1340 |
+
|
1341 |
data = {
|
1342 |
"model": LIGHTRAG_MODEL,
|
1343 |
"created_at": LIGHTRAG_CREATED_AT,
|
|
|
1347 |
"prompt_eval_count": prompt_tokens,
|
1348 |
"prompt_eval_duration": prompt_eval_time,
|
1349 |
"eval_count": completion_tokens,
|
1350 |
+
"eval_duration": eval_time,
|
1351 |
}
|
1352 |
yield f"{json.dumps(data, ensure_ascii=False)}\n"
|
1353 |
return
|
1354 |
+
|
1355 |
except Exception as e:
|
1356 |
logging.error(f"Error in stream_generator: {str(e)}")
|
1357 |
raise
|
1358 |
+
|
1359 |
return StreamingResponse(
|
1360 |
stream_generator(),
|
1361 |
media_type="application/x-ndjson",
|
|
|
1371 |
else:
|
1372 |
first_chunk_time = time.time_ns()
|
1373 |
response_text = await rag.llm_model_func(
|
1374 |
+
query, stream=False, **rag.llm_model_kwargs
|
|
|
|
|
1375 |
)
|
1376 |
last_chunk_time = time.time_ns()
|
1377 |
+
|
1378 |
if not response_text:
|
1379 |
response_text = "No response generated"
|
1380 |
+
|
1381 |
completion_tokens = estimate_tokens(str(response_text))
|
1382 |
total_time = last_chunk_time - start_time
|
1383 |
prompt_eval_time = first_chunk_time - start_time
|
1384 |
eval_time = last_chunk_time - first_chunk_time
|
1385 |
+
|
1386 |
return {
|
1387 |
"model": LIGHTRAG_MODEL,
|
1388 |
"created_at": LIGHTRAG_CREATED_AT,
|
|
|
1393 |
"prompt_eval_count": prompt_tokens,
|
1394 |
"prompt_eval_duration": prompt_eval_time,
|
1395 |
"eval_count": completion_tokens,
|
1396 |
+
"eval_duration": eval_time,
|
1397 |
}
|
1398 |
except Exception as e:
|
1399 |
trace_exception(e)
|
|
|
1411 |
# Get the last message as query
|
1412 |
query = messages[-1].content
|
1413 |
|
1414 |
+
# Check for query prefix
|
1415 |
cleaned_query, mode = parse_query_mode(query)
|
1416 |
|
|
|
1417 |
start_time = time.time_ns()
|
|
|
|
|
1418 |
prompt_tokens = estimate_tokens(cleaned_query)
|
1419 |
|
|
|
1420 |
query_param = QueryParam(
|
1421 |
mode=mode, stream=request.stream, only_need_context=False
|
1422 |
)
|
|
|
1527 |
)
|
1528 |
else:
|
1529 |
first_chunk_time = time.time_ns()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1530 |
|
1531 |
+
# Determine if the request is from Open WebUI's session title and session keyword generation task
|
1532 |
+
match_result = re.search(
|
1533 |
+
r"\n<chat_history>\nUSER:", cleaned_query, re.MULTILINE
|
1534 |
+
)
|
1535 |
+
if match_result:
|
1536 |
if request.system:
|
1537 |
rag.llm_model_kwargs["system_prompt"] = request.system
|
1538 |
|
1539 |
response_text = await rag.llm_model_func(
|
1540 |
+
cleaned_query, stream=False, **rag.llm_model_kwargs
|
|
|
|
|
1541 |
)
|
1542 |
else:
|
1543 |
response_text = await rag.aquery(cleaned_query, param=query_param)
|
1544 |
+
|
1545 |
last_chunk_time = time.time_ns()
|
1546 |
|
1547 |
if not response_text:
|
test_lightrag_ollama_chat.py
CHANGED
@@ -110,7 +110,7 @@ DEFAULT_CONFIG = {
|
|
110 |
},
|
111 |
"test_cases": {
|
112 |
"basic": {"query": "唐僧有几个徒弟"},
|
113 |
-
"generate": {"query": "电视剧西游记导演是谁"}
|
114 |
},
|
115 |
}
|
116 |
|
@@ -205,12 +205,13 @@ def create_chat_request_data(
|
|
205 |
"stream": stream,
|
206 |
}
|
207 |
|
|
|
208 |
def create_generate_request_data(
|
209 |
-
prompt: str,
|
210 |
system: str = None,
|
211 |
-
stream: bool = False,
|
212 |
model: str = None,
|
213 |
-
options: Dict[str, Any] = None
|
214 |
) -> Dict[str, Any]:
|
215 |
"""Create generate request data
|
216 |
Args:
|
@@ -225,7 +226,7 @@ def create_generate_request_data(
|
|
225 |
data = {
|
226 |
"model": model or CONFIG["server"]["model"],
|
227 |
"prompt": prompt,
|
228 |
-
"stream": stream
|
229 |
}
|
230 |
if system:
|
231 |
data["system"] = system
|
@@ -258,7 +259,9 @@ def run_test(func: Callable, name: str) -> None:
|
|
258 |
def test_non_stream_chat() -> None:
|
259 |
"""Test non-streaming call to /api/chat endpoint"""
|
260 |
url = get_base_url()
|
261 |
-
data = create_chat_request_data(
|
|
|
|
|
262 |
|
263 |
# Send request
|
264 |
response = make_request(url, data)
|
@@ -487,8 +490,7 @@ def test_non_stream_generate() -> None:
|
|
487 |
"""Test non-streaming call to /api/generate endpoint"""
|
488 |
url = get_base_url("generate")
|
489 |
data = create_generate_request_data(
|
490 |
-
CONFIG["test_cases"]["generate"]["query"],
|
491 |
-
stream=False
|
492 |
)
|
493 |
|
494 |
# Send request
|
@@ -504,17 +506,17 @@ def test_non_stream_generate() -> None:
|
|
504 |
{
|
505 |
"model": response_json["model"],
|
506 |
"response": response_json["response"],
|
507 |
-
"done": response_json["done"]
|
508 |
},
|
509 |
-
"Response content"
|
510 |
)
|
511 |
|
|
|
512 |
def test_stream_generate() -> None:
|
513 |
"""Test streaming call to /api/generate endpoint"""
|
514 |
url = get_base_url("generate")
|
515 |
data = create_generate_request_data(
|
516 |
-
CONFIG["test_cases"]["generate"]["query"],
|
517 |
-
stream=True
|
518 |
)
|
519 |
|
520 |
# Send request and get streaming response
|
@@ -530,13 +532,17 @@ def test_stream_generate() -> None:
|
|
530 |
# Decode and parse JSON
|
531 |
data = json.loads(line.decode("utf-8"))
|
532 |
if data.get("done", True): # If it's the completion marker
|
533 |
-
if
|
|
|
|
|
534 |
break
|
535 |
else: # Normal content message
|
536 |
content = data.get("response", "")
|
537 |
if content: # Only collect non-empty content
|
538 |
output_buffer.append(content)
|
539 |
-
print(
|
|
|
|
|
540 |
except json.JSONDecodeError:
|
541 |
print("Error decoding JSON from response line")
|
542 |
finally:
|
@@ -545,13 +551,14 @@ def test_stream_generate() -> None:
|
|
545 |
# Print a newline
|
546 |
print()
|
547 |
|
|
|
548 |
def test_generate_with_system() -> None:
|
549 |
"""Test generate with system prompt"""
|
550 |
url = get_base_url("generate")
|
551 |
data = create_generate_request_data(
|
552 |
CONFIG["test_cases"]["generate"]["query"],
|
553 |
system="你是一个知识渊博的助手",
|
554 |
-
stream=False
|
555 |
)
|
556 |
|
557 |
# Send request
|
@@ -567,15 +574,16 @@ def test_generate_with_system() -> None:
|
|
567 |
{
|
568 |
"model": response_json["model"],
|
569 |
"response": response_json["response"],
|
570 |
-
"done": response_json["done"]
|
571 |
},
|
572 |
-
"Response content"
|
573 |
)
|
574 |
|
|
|
575 |
def test_generate_error_handling() -> None:
|
576 |
"""Test error handling for generate endpoint"""
|
577 |
url = get_base_url("generate")
|
578 |
-
|
579 |
# Test empty prompt
|
580 |
if OutputControl.is_verbose():
|
581 |
print("\n=== Testing empty prompt ===")
|
@@ -583,14 +591,14 @@ def test_generate_error_handling() -> None:
|
|
583 |
response = make_request(url, data)
|
584 |
print(f"Status code: {response.status_code}")
|
585 |
print_json_response(response.json(), "Error message")
|
586 |
-
|
587 |
# Test invalid options
|
588 |
if OutputControl.is_verbose():
|
589 |
print("\n=== Testing invalid options ===")
|
590 |
data = create_generate_request_data(
|
591 |
CONFIG["test_cases"]["basic"]["query"],
|
592 |
options={"invalid_option": "value"},
|
593 |
-
stream=False
|
594 |
)
|
595 |
response = make_request(url, data)
|
596 |
print(f"Status code: {response.status_code}")
|
@@ -602,12 +610,12 @@ def test_generate_concurrent() -> None:
|
|
602 |
import asyncio
|
603 |
import aiohttp
|
604 |
from contextlib import asynccontextmanager
|
605 |
-
|
606 |
@asynccontextmanager
|
607 |
async def get_session():
|
608 |
async with aiohttp.ClientSession() as session:
|
609 |
yield session
|
610 |
-
|
611 |
async def make_request(session, prompt: str):
|
612 |
url = get_base_url("generate")
|
613 |
data = create_generate_request_data(prompt, stream=False)
|
@@ -616,32 +624,27 @@ def test_generate_concurrent() -> None:
|
|
616 |
return await response.json()
|
617 |
except Exception as e:
|
618 |
return {"error": str(e)}
|
619 |
-
|
620 |
async def run_concurrent_requests():
|
621 |
-
prompts = [
|
622 |
-
|
623 |
-
"第二个问题",
|
624 |
-
"第三个问题",
|
625 |
-
"第四个问题",
|
626 |
-
"第五个问题"
|
627 |
-
]
|
628 |
-
|
629 |
async with get_session() as session:
|
630 |
tasks = [make_request(session, prompt) for prompt in prompts]
|
631 |
results = await asyncio.gather(*tasks)
|
632 |
return results
|
633 |
-
|
634 |
if OutputControl.is_verbose():
|
635 |
print("\n=== Testing concurrent generate requests ===")
|
636 |
-
|
637 |
# Run concurrent requests
|
638 |
results = asyncio.run(run_concurrent_requests())
|
639 |
-
|
640 |
# Print results
|
641 |
for i, result in enumerate(results, 1):
|
642 |
print(f"\nRequest {i} result:")
|
643 |
print_json_response(result)
|
644 |
|
|
|
645 |
def get_test_cases() -> Dict[str, Callable]:
|
646 |
"""Get all available test cases
|
647 |
Returns:
|
@@ -657,7 +660,7 @@ def get_test_cases() -> Dict[str, Callable]:
|
|
657 |
"stream_generate": test_stream_generate,
|
658 |
"generate_with_system": test_generate_with_system,
|
659 |
"generate_errors": test_generate_error_handling,
|
660 |
-
"generate_concurrent": test_generate_concurrent
|
661 |
}
|
662 |
|
663 |
|
|
|
110 |
},
|
111 |
"test_cases": {
|
112 |
"basic": {"query": "唐僧有几个徒弟"},
|
113 |
+
"generate": {"query": "电视剧西游记导演是谁"},
|
114 |
},
|
115 |
}
|
116 |
|
|
|
205 |
"stream": stream,
|
206 |
}
|
207 |
|
208 |
+
|
209 |
def create_generate_request_data(
|
210 |
+
prompt: str,
|
211 |
system: str = None,
|
212 |
+
stream: bool = False,
|
213 |
model: str = None,
|
214 |
+
options: Dict[str, Any] = None,
|
215 |
) -> Dict[str, Any]:
|
216 |
"""Create generate request data
|
217 |
Args:
|
|
|
226 |
data = {
|
227 |
"model": model or CONFIG["server"]["model"],
|
228 |
"prompt": prompt,
|
229 |
+
"stream": stream,
|
230 |
}
|
231 |
if system:
|
232 |
data["system"] = system
|
|
|
259 |
def test_non_stream_chat() -> None:
|
260 |
"""Test non-streaming call to /api/chat endpoint"""
|
261 |
url = get_base_url()
|
262 |
+
data = create_chat_request_data(
|
263 |
+
CONFIG["test_cases"]["basic"]["query"], stream=False
|
264 |
+
)
|
265 |
|
266 |
# Send request
|
267 |
response = make_request(url, data)
|
|
|
490 |
"""Test non-streaming call to /api/generate endpoint"""
|
491 |
url = get_base_url("generate")
|
492 |
data = create_generate_request_data(
|
493 |
+
CONFIG["test_cases"]["generate"]["query"], stream=False
|
|
|
494 |
)
|
495 |
|
496 |
# Send request
|
|
|
506 |
{
|
507 |
"model": response_json["model"],
|
508 |
"response": response_json["response"],
|
509 |
+
"done": response_json["done"],
|
510 |
},
|
511 |
+
"Response content",
|
512 |
)
|
513 |
|
514 |
+
|
515 |
def test_stream_generate() -> None:
|
516 |
"""Test streaming call to /api/generate endpoint"""
|
517 |
url = get_base_url("generate")
|
518 |
data = create_generate_request_data(
|
519 |
+
CONFIG["test_cases"]["generate"]["query"], stream=True
|
|
|
520 |
)
|
521 |
|
522 |
# Send request and get streaming response
|
|
|
532 |
# Decode and parse JSON
|
533 |
data = json.loads(line.decode("utf-8"))
|
534 |
if data.get("done", True): # If it's the completion marker
|
535 |
+
if (
|
536 |
+
"total_duration" in data
|
537 |
+
): # Final performance statistics message
|
538 |
break
|
539 |
else: # Normal content message
|
540 |
content = data.get("response", "")
|
541 |
if content: # Only collect non-empty content
|
542 |
output_buffer.append(content)
|
543 |
+
print(
|
544 |
+
content, end="", flush=True
|
545 |
+
) # Print content in real-time
|
546 |
except json.JSONDecodeError:
|
547 |
print("Error decoding JSON from response line")
|
548 |
finally:
|
|
|
551 |
# Print a newline
|
552 |
print()
|
553 |
|
554 |
+
|
555 |
def test_generate_with_system() -> None:
|
556 |
"""Test generate with system prompt"""
|
557 |
url = get_base_url("generate")
|
558 |
data = create_generate_request_data(
|
559 |
CONFIG["test_cases"]["generate"]["query"],
|
560 |
system="你是一个知识渊博的助手",
|
561 |
+
stream=False,
|
562 |
)
|
563 |
|
564 |
# Send request
|
|
|
574 |
{
|
575 |
"model": response_json["model"],
|
576 |
"response": response_json["response"],
|
577 |
+
"done": response_json["done"],
|
578 |
},
|
579 |
+
"Response content",
|
580 |
)
|
581 |
|
582 |
+
|
583 |
def test_generate_error_handling() -> None:
|
584 |
"""Test error handling for generate endpoint"""
|
585 |
url = get_base_url("generate")
|
586 |
+
|
587 |
# Test empty prompt
|
588 |
if OutputControl.is_verbose():
|
589 |
print("\n=== Testing empty prompt ===")
|
|
|
591 |
response = make_request(url, data)
|
592 |
print(f"Status code: {response.status_code}")
|
593 |
print_json_response(response.json(), "Error message")
|
594 |
+
|
595 |
# Test invalid options
|
596 |
if OutputControl.is_verbose():
|
597 |
print("\n=== Testing invalid options ===")
|
598 |
data = create_generate_request_data(
|
599 |
CONFIG["test_cases"]["basic"]["query"],
|
600 |
options={"invalid_option": "value"},
|
601 |
+
stream=False,
|
602 |
)
|
603 |
response = make_request(url, data)
|
604 |
print(f"Status code: {response.status_code}")
|
|
|
610 |
import asyncio
|
611 |
import aiohttp
|
612 |
from contextlib import asynccontextmanager
|
613 |
+
|
614 |
@asynccontextmanager
|
615 |
async def get_session():
|
616 |
async with aiohttp.ClientSession() as session:
|
617 |
yield session
|
618 |
+
|
619 |
async def make_request(session, prompt: str):
|
620 |
url = get_base_url("generate")
|
621 |
data = create_generate_request_data(prompt, stream=False)
|
|
|
624 |
return await response.json()
|
625 |
except Exception as e:
|
626 |
return {"error": str(e)}
|
627 |
+
|
628 |
async def run_concurrent_requests():
|
629 |
+
prompts = ["第一个问题", "第二个问题", "第三个问题", "第四个问题", "第五个问题"]
|
630 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
631 |
async with get_session() as session:
|
632 |
tasks = [make_request(session, prompt) for prompt in prompts]
|
633 |
results = await asyncio.gather(*tasks)
|
634 |
return results
|
635 |
+
|
636 |
if OutputControl.is_verbose():
|
637 |
print("\n=== Testing concurrent generate requests ===")
|
638 |
+
|
639 |
# Run concurrent requests
|
640 |
results = asyncio.run(run_concurrent_requests())
|
641 |
+
|
642 |
# Print results
|
643 |
for i, result in enumerate(results, 1):
|
644 |
print(f"\nRequest {i} result:")
|
645 |
print_json_response(result)
|
646 |
|
647 |
+
|
648 |
def get_test_cases() -> Dict[str, Callable]:
|
649 |
"""Get all available test cases
|
650 |
Returns:
|
|
|
660 |
"stream_generate": test_stream_generate,
|
661 |
"generate_with_system": test_generate_with_system,
|
662 |
"generate_errors": test_generate_error_handling,
|
663 |
+
"generate_concurrent": test_generate_concurrent,
|
664 |
}
|
665 |
|
666 |
|