# Commit b0394f8 ("Cleaning up imports")
"""
OpenAI-compatible FastAPI proxy that wraps a smolagents CodeAgent
Refactored for readability and modularity (single-file).
"""
import logging # For logging
import os # For dealing with env vars
import typing # For type annotations
import fastapi
import fastapi.responses
# Upstream pass-through + local helpers
from agent_server.agent_streaming import (
proxy_upstream_chat_completions,
)
from agent_server.chat_completions import (
normalize_model_name,
is_upstream_passthrough,
is_upstream_passthrough_nothink,
apply_nothink_to_body,
agent_for_model,
make_sse_generator,
run_non_streaming,
)
from agent_server.helpers import (
messages_to_task,
openai_response,
sse_headers,
)
from agent_server.models import models_payload
from agent_server.openai_schemas import ChatMessage, ChatCompletionRequest
# Local agent factories
# --------------------------------------------------------------------------------------
# Logging / Config
# --------------------------------------------------------------------------------------
# Log level is configurable via the LOG_LEVEL env var; defaults to INFO.
logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO").upper())
# Module-level logger used by the request handlers below.
log = logging.getLogger(__name__)
# --------------------------------------------------------------------------------------
# FastAPI app
# --------------------------------------------------------------------------------------
app = fastapi.FastAPI()
# --------------------------------------------------------------------------------------
# HTTP Handlers (thin wrappers around helpers)
# --------------------------------------------------------------------------------------
@app.get("/healthz")
async def healthz():
    """Liveness probe: unconditionally report the service as healthy."""
    status = {"ok": True}
    return status
@app.get("/v1/models")
async def list_models():
    """Return the OpenAI-style model listing produced by models_payload()."""
    payload = models_payload()
    return payload
@app.post("/v1/chat/completions")
async def chat_completions(req: fastapi.Request):
    """OpenAI-compatible chat completions endpoint.

    Dispatches based on the requested model name:
      * upstream pass-through (raw forward),
      * upstream pass-through with /nothink applied,
      * local agent execution (streaming SSE or non-streaming JSON).

    Returns an OpenAI-shaped JSON error body with status 400 for request
    problems, and 500/503 for agent or upstream failures.
    """
    # ---------------- Parse & basic validation ----------------
    try:
        body: ChatCompletionRequest = typing.cast(
            ChatCompletionRequest, await req.json()
        )
    except Exception as e:
        return fastapi.responses.JSONResponse(
            {"error": {"message": f"Invalid JSON: {e}"}}, status_code=400
        )
    # A syntactically valid JSON payload may still be a non-object (list,
    # string, number, null). Reject it here; otherwise body.get(...) below
    # raises AttributeError and surfaces as a misleading 500 "agent_error".
    if not isinstance(body, dict):
        return fastapi.responses.JSONResponse(
            {"error": {"message": "Request body must be a JSON object"}},
            status_code=400,
        )
    messages: typing.List[ChatMessage] = typing.cast(
        typing.List[ChatMessage], body.get("messages") or []
    )
    stream: bool = bool(body.get("stream", False))
    model_name: str = normalize_model_name(body.get("model"))
    try:
        # ---------------- Upstream pass-through modes ----------------
        if is_upstream_passthrough(model_name):
            # Raw pass-through to upstream (copy so we never mutate `body`)
            return await proxy_upstream_chat_completions(dict(body), stream)
        if is_upstream_passthrough_nothink(model_name):
            # Modify body for /nothink and forward to upstream
            return await proxy_upstream_chat_completions(
                apply_nothink_to_body(body, messages), stream, scrub_think=True
            )
        # ---------------- Local agent execution ----------------
        # Convert OpenAI messages -> internal "task"
        task: str = messages_to_task(messages)
        # Create agent impl for the requested local model; may raise
        # ValueError for an unknown model (handled below as a 400)
        agent_for_request = agent_for_model(model_name)
        if stream:
            # Streaming: return SSE response
            gen = make_sse_generator(task, agent_for_request, model_name)
            return fastapi.responses.StreamingResponse(
                gen(), media_type="text/event-stream", headers=sse_headers()
            )
        # Non-streaming: materialize final text and wrap in OpenAI shape
        result_text = await run_non_streaming(task, agent_for_request)
        return fastapi.responses.JSONResponse(
            openai_response(result_text, model_name)
        )
    except ValueError as ve:
        # Unknown model or other parameter validation errors
        log.error("Invalid request: %s", ve)
        return fastapi.responses.JSONResponse(
            status_code=400,
            content={"error": {"message": str(ve), "type": "invalid_request_error"}},
        )
    except Exception as e:
        # Operational / agent runtime errors. Map apparent upstream 503s
        # through; everything else is a generic 500.
        msg = str(e)
        status = 503 if "503" in msg or "Service Unavailable" in msg else 500
        # log.exception records the full traceback, not just the message
        log.exception("Agent error (%s): %s", status, msg)
        return fastapi.responses.JSONResponse(
            status_code=status,
            content={
                "error": {"message": f"Agent error: {msg}", "type": "agent_error"}
            },
        )
# --------------------------------------------------------------------------------------
# Local dev entrypoint
# --------------------------------------------------------------------------------------
if __name__ == "__main__":
    # Local dev server; production should run an external ASGI server instead.
    import uvicorn

    dev_port = int(os.getenv("PORT", "8000"))
    uvicorn.run("app:app", host="0.0.0.0", port=dev_port, reload=False)