yangdx
committed on
Commit
·
ead5d24
1
Parent(s):
a42342a
临时保存
Browse files- lightrag/api/lightrag_ollama.py +98 -28
lightrag/api/lightrag_ollama.py
CHANGED
@@ -472,10 +472,25 @@ def create_app(args):
|
|
472 |
)
|
473 |
|
474 |
if request.stream:
|
475 |
-
|
476 |
-
|
477 |
-
|
478 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
479 |
else:
|
480 |
return QueryResponse(response=response)
|
481 |
except Exception as e:
|
@@ -484,7 +499,7 @@ def create_app(args):
|
|
484 |
@app.post("/query/stream", dependencies=[Depends(optional_api_key)])
|
485 |
async def query_text_stream(request: QueryRequest):
|
486 |
try:
|
487 |
-
response = rag.query
|
488 |
request.query,
|
489 |
param=QueryParam(
|
490 |
mode=request.mode,
|
@@ -493,11 +508,24 @@ def create_app(args):
|
|
493 |
),
|
494 |
)
|
495 |
|
|
|
|
|
496 |
async def stream_generator():
|
497 |
async for chunk in response:
|
498 |
-
yield chunk
|
499 |
-
|
500 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
501 |
except Exception as e:
|
502 |
raise HTTPException(status_code=500, detail=str(e))
|
503 |
|
@@ -659,20 +687,48 @@ def create_app(args):
|
|
659 |
cleaned_query, mode = parse_query_mode(query)
|
660 |
|
661 |
# 调用RAG进行查询
|
|
|
|
|
|
|
|
|
|
|
|
|
662 |
if request.stream:
|
663 |
-
|
|
|
|
|
|
|
664 |
cleaned_query,
|
665 |
-
param=
|
666 |
-
mode=mode,
|
667 |
-
stream=True,
|
668 |
-
only_need_context=False
|
669 |
-
),
|
670 |
)
|
671 |
|
672 |
async def stream_generator():
|
673 |
try:
|
674 |
-
|
675 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
676 |
"model": LIGHTRAG_MODEL,
|
677 |
"created_at": LIGHTRAG_CREATED_AT,
|
678 |
"message": {
|
@@ -681,7 +737,10 @@ def create_app(args):
|
|
681 |
},
|
682 |
"done": False
|
683 |
}
|
684 |
-
|
|
|
|
|
|
|
685 |
"model": LIGHTRAG_MODEL,
|
686 |
"created_at": LIGHTRAG_CREATED_AT,
|
687 |
"message": {
|
@@ -690,30 +749,41 @@ def create_app(args):
|
|
690 |
},
|
691 |
"done": True
|
692 |
}
|
|
|
693 |
except Exception as e:
|
694 |
logging.error(f"Error in stream_generator: {str(e)}")
|
695 |
raise
|
696 |
-
|
697 |
-
import json
|
698 |
return StreamingResponse(
|
699 |
-
|
700 |
-
media_type="text/event-stream"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
701 |
)
|
702 |
else:
|
703 |
-
|
|
|
704 |
cleaned_query,
|
705 |
-
param=
|
706 |
-
mode=mode,
|
707 |
-
stream=False,
|
708 |
-
only_need_context=False
|
709 |
-
),
|
710 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
711 |
return OllamaChatResponse(
|
712 |
model=LIGHTRAG_MODEL,
|
713 |
created_at=LIGHTRAG_CREATED_AT,
|
714 |
message=OllamaMessage(
|
715 |
role="assistant",
|
716 |
-
content=
|
717 |
),
|
718 |
done=True
|
719 |
)
|
|
|
472 |
)
|
473 |
|
474 |
if request.stream:
|
475 |
+
from fastapi.responses import StreamingResponse
|
476 |
+
import json
|
477 |
+
|
478 |
+
async def stream_generator():
|
479 |
+
async for chunk in response:
|
480 |
+
yield f"data: {json.dumps({'response': chunk})}\n\n"
|
481 |
+
|
482 |
+
return StreamingResponse(
|
483 |
+
stream_generator(),
|
484 |
+
media_type="text/event-stream",
|
485 |
+
headers={
|
486 |
+
"Cache-Control": "no-cache",
|
487 |
+
"Connection": "keep-alive",
|
488 |
+
"Content-Type": "text/event-stream",
|
489 |
+
"Access-Control-Allow-Origin": "*",
|
490 |
+
"Access-Control-Allow-Methods": "POST, OPTIONS",
|
491 |
+
"Access-Control-Allow-Headers": "Content-Type"
|
492 |
+
}
|
493 |
+
)
|
494 |
else:
|
495 |
return QueryResponse(response=response)
|
496 |
except Exception as e:
|
|
|
499 |
@app.post("/query/stream", dependencies=[Depends(optional_api_key)])
|
500 |
async def query_text_stream(request: QueryRequest):
|
501 |
try:
|
502 |
+
response = await rag.aquery( # 使用 aquery 而不是 query,并添加 await
|
503 |
request.query,
|
504 |
param=QueryParam(
|
505 |
mode=request.mode,
|
|
|
508 |
),
|
509 |
)
|
510 |
|
511 |
+
from fastapi.responses import StreamingResponse
|
512 |
+
|
513 |
async def stream_generator():
|
514 |
async for chunk in response:
|
515 |
+
yield f"data: {chunk}\n\n"
|
516 |
+
|
517 |
+
return StreamingResponse(
|
518 |
+
stream_generator(),
|
519 |
+
media_type="text/event-stream",
|
520 |
+
headers={
|
521 |
+
"Cache-Control": "no-cache",
|
522 |
+
"Connection": "keep-alive",
|
523 |
+
"Content-Type": "text/event-stream",
|
524 |
+
"Access-Control-Allow-Origin": "*",
|
525 |
+
"Access-Control-Allow-Methods": "POST, OPTIONS",
|
526 |
+
"Access-Control-Allow-Headers": "Content-Type"
|
527 |
+
}
|
528 |
+
)
|
529 |
except Exception as e:
|
530 |
raise HTTPException(status_code=500, detail=str(e))
|
531 |
|
|
|
687 |
cleaned_query, mode = parse_query_mode(query)
|
688 |
|
689 |
# 调用RAG进行查询
|
690 |
+
query_param = QueryParam(
|
691 |
+
mode=mode, # 使用解析出的模式,如果没有前缀则为默认的 hybrid
|
692 |
+
stream=request.stream,
|
693 |
+
only_need_context=False
|
694 |
+
)
|
695 |
+
|
696 |
if request.stream:
|
697 |
+
from fastapi.responses import StreamingResponse
|
698 |
+
import json
|
699 |
+
|
700 |
+
response = await rag.aquery( # 需要 await 来获取异步生成器
|
701 |
cleaned_query,
|
702 |
+
param=query_param
|
|
|
|
|
|
|
|
|
703 |
)
|
704 |
|
705 |
async def stream_generator():
|
706 |
try:
|
707 |
+
# 确保 response 是异步生成器
|
708 |
+
if isinstance(response, str):
|
709 |
+
data = {
|
710 |
+
'model': LIGHTRAG_MODEL,
|
711 |
+
'created_at': LIGHTRAG_CREATED_AT,
|
712 |
+
'message': {
|
713 |
+
'role': 'assistant',
|
714 |
+
'content': response
|
715 |
+
},
|
716 |
+
'done': True
|
717 |
+
}
|
718 |
+
yield f"data: {json.dumps(data)}\n\n"
|
719 |
+
else:
|
720 |
+
async for chunk in response:
|
721 |
+
data = {
|
722 |
+
"model": LIGHTRAG_MODEL,
|
723 |
+
"created_at": LIGHTRAG_CREATED_AT,
|
724 |
+
"message": {
|
725 |
+
"role": "assistant",
|
726 |
+
"content": chunk
|
727 |
+
},
|
728 |
+
"done": False
|
729 |
+
}
|
730 |
+
yield f"data: {json.dumps(data)}\n\n"
|
731 |
+
data = {
|
732 |
"model": LIGHTRAG_MODEL,
|
733 |
"created_at": LIGHTRAG_CREATED_AT,
|
734 |
"message": {
|
|
|
737 |
},
|
738 |
"done": False
|
739 |
}
|
740 |
+
yield f"data: {json.dumps(data)}\n\n"
|
741 |
+
|
742 |
+
# 发送完成标记
|
743 |
+
data = {
|
744 |
"model": LIGHTRAG_MODEL,
|
745 |
"created_at": LIGHTRAG_CREATED_AT,
|
746 |
"message": {
|
|
|
749 |
},
|
750 |
"done": True
|
751 |
}
|
752 |
+
yield f"data: {json.dumps(data)}\n\n"
|
753 |
except Exception as e:
|
754 |
logging.error(f"Error in stream_generator: {str(e)}")
|
755 |
raise
|
756 |
+
|
|
|
757 |
return StreamingResponse(
|
758 |
+
stream_generator(),
|
759 |
+
media_type="text/event-stream",
|
760 |
+
headers={
|
761 |
+
"Cache-Control": "no-cache",
|
762 |
+
"Connection": "keep-alive",
|
763 |
+
"Content-Type": "text/event-stream",
|
764 |
+
"Access-Control-Allow-Origin": "*",
|
765 |
+
"Access-Control-Allow-Methods": "POST, OPTIONS",
|
766 |
+
"Access-Control-Allow-Headers": "Content-Type"
|
767 |
+
}
|
768 |
)
|
769 |
else:
|
770 |
+
# 非流式响应
|
771 |
+
response_text = await rag.aquery(
|
772 |
cleaned_query,
|
773 |
+
param=query_param
|
|
|
|
|
|
|
|
|
774 |
)
|
775 |
+
|
776 |
+
# 确保响应不为空
|
777 |
+
if not response_text:
|
778 |
+
response_text = "No response generated"
|
779 |
+
|
780 |
+
# 构造并返回响应
|
781 |
return OllamaChatResponse(
|
782 |
model=LIGHTRAG_MODEL,
|
783 |
created_at=LIGHTRAG_CREATED_AT,
|
784 |
message=OllamaMessage(
|
785 |
role="assistant",
|
786 |
+
content=str(response_text) # 确保转换为字符串
|
787 |
),
|
788 |
done=True
|
789 |
)
|