"""
OpenAI-compatible FastAPI proxy that wraps a smolagents CodeAgent.

Refactored for readability and modularity (single-file).
"""
# Standard library
import logging  # For logging
import os  # For dealing with env vars
import typing  # For type annotations

# Third-party
import fastapi
import fastapi.responses

# Upstream pass-through + local helpers
from agent_server.agent_streaming import (
    proxy_upstream_chat_completions,
)
from agent_server.chat_completions import (
    normalize_model_name,
    is_upstream_passthrough,
    is_upstream_passthrough_nothink,
    apply_nothink_to_body,
    agent_for_model,
    make_sse_generator,
    run_non_streaming,
)
from agent_server.helpers import (
    messages_to_task,
    openai_response,
    sse_headers,
)
from agent_server.models import models_payload
from agent_server.openai_schemas import ChatMessage, ChatCompletionRequest

# Local agent factories
# --------------------------------------------------------------------------------------
# Logging / Config
# --------------------------------------------------------------------------------------
# Log level is controlled via the LOG_LEVEL env var (defaults to INFO).
logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO").upper())
log = logging.getLogger(__name__)

# --------------------------------------------------------------------------------------
# FastAPI app
# --------------------------------------------------------------------------------------
app = fastapi.FastAPI()
# --------------------------------------------------------------------------------------
# HTTP Handlers (thin wrappers around helpers)
# --------------------------------------------------------------------------------------
async def healthz():
    """Liveness probe: always reports the service as healthy."""
    return dict(ok=True)
async def list_models():
    """Return the OpenAI-style model catalog (delegates to models_payload)."""
    payload = models_payload()
    return payload
async def chat_completions(req: fastapi.Request):
    """Handle an OpenAI-style chat-completions request.

    Depending on the (normalized) model name the request is either:
      * proxied verbatim to the upstream OpenAI-compatible server,
      * proxied with the "/nothink" body rewrite applied, or
      * executed locally by a smolagents agent — streamed as SSE when
        ``stream`` is true, otherwise returned as a single JSON payload.

    Returns:
        A FastAPI response: StreamingResponse (SSE) or JSONResponse in
        OpenAI chat-completion shape. Errors are mapped to OpenAI-style
        error payloads with 400 (invalid request), 503 (upstream
        unavailable) or 500 (other runtime failures).
    """
    # ---------------- Parse & basic validation ----------------
    try:
        raw = await req.json()
    except Exception as e:
        return fastapi.responses.JSONResponse(
            {"error": {"message": f"Invalid JSON: {e}"}}, status_code=400
        )
    # A syntactically valid JSON payload may still be a list/string/number.
    # Reject those explicitly with a 400 instead of crashing on .get()
    # below (which would surface as a misleading 500 "agent_error").
    if not isinstance(raw, dict):
        return fastapi.responses.JSONResponse(
            {"error": {"message": "Request body must be a JSON object"}},
            status_code=400,
        )
    body: ChatCompletionRequest = typing.cast(ChatCompletionRequest, raw)
    messages: typing.List[ChatMessage] = typing.cast(
        typing.List[ChatMessage], body.get("messages") or []
    )
    stream: bool = bool(body.get("stream", False))
    model_name: str = normalize_model_name(body.get("model"))
    try:
        # ---------------- Upstream pass-through modes ----------------
        if is_upstream_passthrough(model_name):
            # Raw pass-through to upstream
            return await proxy_upstream_chat_completions(dict(body), stream)
        if is_upstream_passthrough_nothink(model_name):
            # Modify body for /nothink and forward to upstream
            return await proxy_upstream_chat_completions(
                apply_nothink_to_body(body, messages), stream, scrub_think=True
            )
        # ---------------- Local agent execution ----------------
        # Convert OpenAI messages -> internal "task"
        task: str = messages_to_task(messages)
        # Create agent impl for the requested local model
        agent_for_request = agent_for_model(model_name)
        if stream:
            # Streaming: return SSE response
            gen = make_sse_generator(task, agent_for_request, model_name)
            return fastapi.responses.StreamingResponse(
                gen(), media_type="text/event-stream", headers=sse_headers()
            )
        # Non-streaming: materialize final text and wrap in OpenAI shape
        result_text = await run_non_streaming(task, agent_for_request)
        return fastapi.responses.JSONResponse(
            openai_response(result_text, model_name)
        )
    except ValueError as ve:
        # Unknown model or other parameter validation errors
        log.error("Invalid request: %s", ve)
        return fastapi.responses.JSONResponse(
            status_code=400,
            content={"error": {"message": str(ve), "type": "invalid_request_error"}},
        )
    except Exception as e:
        # Operational / agent runtime errors. Heuristically map upstream
        # unavailability to 503 so clients know a retry may succeed.
        msg = str(e)
        status = 503 if "503" in msg or "Service Unavailable" in msg else 500
        log.error("Agent error (%s): %s", status, msg)
        return fastapi.responses.JSONResponse(
            status_code=status,
            content={
                "error": {"message": f"Agent error: {msg}", "type": "agent_error"}
            },
        )
# --------------------------------------------------------------------------------------
# Local dev entrypoint
# --------------------------------------------------------------------------------------
if __name__ == "__main__":
    # Imported lazily so the module stays importable without uvicorn.
    import uvicorn

    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=int(os.getenv("PORT", "8000")),
        reload=False,
    )