"""
OpenAI-compatible FastAPI proxy that wraps a smolagents CodeAgent
"""
import os # For dealing with env vars
import re # For tag stripping
import json # For JSON handling
import time # For timestamps and sleeps
import asyncio # For async operations
import typing # For type annotations
import logging # For logging
import threading # For threading operations
import fastapi
import fastapi.responses
import io
import contextlib
# Upstream pass-through
import httpx
from agents.code_writing_agent import create_code_writing_agent
# Logging setup
logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO").upper())
log = logging.getLogger(__name__)
# Config from env vars
UPSTREAM_BASE = os.getenv("UPSTREAM_OPENAI_BASE", "").rstrip("/")
HF_TOKEN = (
os.getenv("HF_TOKEN")
or os.getenv("HUGGINGFACEHUB_API_TOKEN")
or os.getenv("API_TOKEN")
or ""
)
AGENT_MODEL = os.getenv("AGENT_MODEL", "Qwen/Qwen3-1.7B")
if not UPSTREAM_BASE:
log.warning(
"UPSTREAM_OPENAI_BASE is empty; OpenAI-compatible upstream calls will fail."
)
if not HF_TOKEN:
log.warning("HF_TOKEN is empty; upstream may 401/403 if it requires auth.")
# ================== FastAPI ==================
app = fastapi.FastAPI()
@app.get("/healthz")
async def healthz():
return {"ok": True}
# ---------- OpenAI-compatible minimal schemas ----------
class ChatMessage(typing.TypedDict, total=False):
role: str
content: typing.Any # str or multimodal list
class ChatCompletionRequest(typing.TypedDict, total=False):
model: typing.Optional[str]
messages: typing.List[ChatMessage]
temperature: typing.Optional[float]
stream: typing.Optional[bool]
max_tokens: typing.Optional[int]
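# Illustrative request body these schemas are meant to accept. The body is only
# cast, never validated, so unknown OpenAI fields pass through untouched:
#
#   {
#     "model": "code-writing-agent",
#     "messages": [
#       {"role": "system", "content": "Be terse."},
#       {"role": "user", "content": "What is 2 + 2?"}
#     ],
#     "stream": true
#   }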
# ---------- Helpers ----------
def normalize_content_to_text(content: typing.Any) -> str:
if isinstance(content, str):
return content
if isinstance(content, (bytes, bytearray)):
try:
return content.decode("utf-8", errors="ignore")
except Exception:
return str(content)
if isinstance(content, list):
parts = []
for item in content:
if (
isinstance(item, dict)
and item.get("type") == "text"
and isinstance(item.get("text"), str)
):
parts.append(item["text"])
else:
try:
parts.append(json.dumps(item, ensure_ascii=False))
except Exception:
parts.append(str(item))
return "\n".join(parts)
if isinstance(content, dict):
try:
return json.dumps(content, ensure_ascii=False)
except Exception:
return str(content)
return str(content)
def _messages_to_task(messages: typing.List[ChatMessage]) -> str:
system_parts = [
normalize_content_to_text(m.get("content", ""))
for m in messages
if m.get("role") == "system"
]
user_parts = [
normalize_content_to_text(m.get("content", ""))
for m in messages
if m.get("role") == "user"
]
assistant_parts = [
normalize_content_to_text(m.get("content", ""))
for m in messages
if m.get("role") == "assistant"
]
sys_txt = "\n".join([s for s in system_parts if s]).strip()
history = ""
if assistant_parts:
history = "\n\nPrevious assistant replies (for context):\n" + "\n---\n".join(
assistant_parts
)
last_user = user_parts[-1] if user_parts else ""
prefix = (
"You are a very small agent with only a Python REPL tool.\n"
"Prefer short, correct answers. If Python is unnecessary, just answer plainly.\n"
"If you do use Python, print only final results—no extra logs.\n"
)
if sys_txt:
prefix = f"{sys_txt}\n\n{prefix}"
return f"{prefix}\nTask:\n{last_user}\n{history}".strip()
def _openai_response(
message_text: str, model_name: str
) -> typing.Dict[str, typing.Any]:
now = int(time.time())
return {
"id": f"chatcmpl-smol-{now}",
"object": "chat.completion",
"created": now,
"model": model_name,
"choices": [
{
"index": 0,
"message": {"role": "assistant", "content": message_text},
"finish_reason": "stop",
}
],
"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
}
def _sse_headers() -> dict:
return {
"Cache-Control": "no-cache, no-transform",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
}
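# The streaming endpoints below use standard OpenAI-style SSE framing: each
# event is a "data: <json>\n\n" record and the stream terminates with the
# literal "data: [DONE]\n\n" sentinel. Illustrative chunk:
#
#   data: {"object": "chat.completion.chunk",
#          "choices": [{"index": 0, "delta": {"content": "4"},
#                       "finish_reason": null}]}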
# ---------- Sanitizer: remove <think>/<thank> tags from LLM-originated text ----------
_THINK_TAG_RE = re.compile(r"</?\s*think\b[^>]*>", flags=re.IGNORECASE)
_THANK_TAG_RE = re.compile(r"</?\s*thank\b[^>]*>", flags=re.IGNORECASE)  # typo safety
_ESC_THINK_TAG_RE = re.compile(r"&lt;\s*/?\s*think\b[^&]*&gt;", flags=re.IGNORECASE)
_ESC_THANK_TAG_RE = re.compile(r"&lt;\s*/?\s*thank\b[^&]*&gt;", flags=re.IGNORECASE)
def scrub_think_tags(text: typing.Any) -> str:
"""
Remove literal and HTML-escaped / (and variants) tags.
Content inside the tags is preserved; only the tags are stripped.
"""
if not isinstance(text, str):
try:
text = str(text)
except Exception:
return ""
t = _THINK_TAG_RE.sub("", text)
t = _THANK_TAG_RE.sub("", t)
t = _ESC_THINK_TAG_RE.sub("", t)
t = _ESC_THANK_TAG_RE.sub("", t)
return t
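# Illustrative behavior (doctest-style, not executed at import time):
#
#   >>> scrub_think_tags("<think>use a loop</think>42")
#   'use a loop42'
#   >>> scrub_think_tags("&lt;think&gt;hidden&lt;/think&gt;ok")
#   'hiddenok'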
# ---------- Reasoning formatting for Chat-UI ----------
def _format_reasoning_chunk(text: str, tag: str, idx: int) -> str:
"""
Lightweight formatter for reasoning stream. Avoid huge code fences;
make it readable and incremental. Also filters out ASCII/box-drawing noise.
"""
text = scrub_think_tags(text).rstrip("\n")
if not text:
return ""
noisy_prefixes = (
"OpenAIServerModel",
"Output message of the LLM",
"─ Executing parsed code",
"New run",
"╭",
"╰",
"│",
"━",
"─",
)
stripped = text.strip()
if not stripped:
return ""
# Lines made mostly of box drawing/separators
if all(ch in " ─━╭╮╰╯│═·—-_=+•" for ch in stripped):
return ""
if any(stripped.startswith(p) for p in noisy_prefixes):
return ""
# Excessively long lines with little signal (no alphanumerics)
if len(stripped) > 240 and not re.search(r"[A-Za-z0-9]{3,}", stripped):
return ""
# No tag/idx prefix; add a trailing blank line for readability in markdown
return f"{stripped}\n\n"
def _extract_final_text(item: typing.Any) -> typing.Optional[str]:
if isinstance(item, dict) and ("__stdout__" in item or "__step__" in item):
return None
if isinstance(item, (bytes, bytearray)):
try:
item = item.decode("utf-8", errors="ignore")
except Exception:
item = str(item)
if isinstance(item, str):
s = scrub_think_tags(item.strip())
return s or None
# If it's a step-like object with an 'output' attribute, use that
try:
if not isinstance(item, (dict, list, bytes, bytearray)):
out = getattr(item, "output", None)
if out is not None:
s = scrub_think_tags(str(out)).strip()
if s:
return s
except Exception:
pass
if isinstance(item, dict):
for key in ("content", "text", "message", "output", "final", "answer"):
if key in item:
val = item[key]
if isinstance(val, (dict, list)):
try:
return scrub_think_tags(json.dumps(val, ensure_ascii=False))
except Exception:
return scrub_think_tags(str(val))
if isinstance(val, (bytes, bytearray)):
try:
val = val.decode("utf-8", errors="ignore")
except Exception:
val = str(val)
s = scrub_think_tags(str(val).strip())
return s or None
try:
return scrub_think_tags(json.dumps(item, ensure_ascii=False))
except Exception:
return scrub_think_tags(str(item))
try:
return scrub_think_tags(str(item))
except Exception:
return None
# Helper to parse explicit "Final answer:" from stdout lines
_FINAL_RE = re.compile(r"(?:^|\b)Final\s+answer:\s*(.+)$", flags=re.IGNORECASE)
def _maybe_parse_final_from_stdout(line: str) -> typing.Optional[str]:
if not isinstance(line, str):
return None
m = _FINAL_RE.search(line.strip())
if not m:
return None
return scrub_think_tags(m.group(1)).strip() or None
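# Illustrative (doctest-style, not executed):
#
#   >>> _maybe_parse_final_from_stdout("Final answer: 42")
#   '42'
#   >>> _maybe_parse_final_from_stdout("some other log line") is None
#   True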
# ---------- Live stdout/stderr tee ----------
class QueueWriter(io.TextIOBase):
    """
    File-like object that pushes each write to an asyncio.Queue immediately.
    Writes arrive on worker threads and asyncio.Queue is not thread-safe, so
    items are handed to the loop via call_soon_threadsafe().
    """
    def __init__(self, q: "asyncio.Queue", loop: asyncio.AbstractEventLoop):
        self.q = q
        self.loop = loop
        self._lock = threading.Lock()
        self._buf = []  # accumulate until newline to reduce spam
def write(self, s: str):
if not s:
return 0
with self._lock:
self._buf.append(s)
# flush on newline to keep granularity reasonable
if "\n" in s:
chunk = "".join(self._buf)
self._buf.clear()
            try:
                # asyncio.Queue is not thread-safe; hop onto the event loop.
                self.loop.call_soon_threadsafe(
                    self.q.put_nowait, {"__stdout__": chunk}
                )
            except Exception:
                pass
return len(s)
def flush(self):
with self._lock:
if self._buf:
chunk = "".join(self._buf)
self._buf.clear()
            try:
                self.loop.call_soon_threadsafe(
                    self.q.put_nowait, {"__stdout__": chunk}
                )
            except Exception:
                pass
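# Usage sketch: run_agent_stream() below runs the agent under
# contextlib.redirect_stdout(QueueWriter(q, loop)), so anything the agent
# prints is forwarded to the async consumer as {"__stdout__": ...} items
# while the run is still in progress.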
def _serialize_step(step) -> str:
"""
Best-effort pretty string for a smolagents MemoryStep / ActionStep.
Works even if attributes are missing on some versions.
"""
parts = []
sn = getattr(step, "step_number", None)
if sn is not None:
parts.append(f"Step {sn}")
thought_val = getattr(step, "thought", None)
if thought_val:
parts.append(f"Thought: {scrub_think_tags(str(thought_val))}")
tool_val = getattr(step, "tool", None)
if tool_val:
parts.append(f"Tool: {scrub_think_tags(str(tool_val))}")
code_val = getattr(step, "code", None)
if code_val:
code_str = scrub_think_tags(str(code_val)).strip()
parts.append("```python\n" + code_str + "\n```")
args = getattr(step, "args", None)
if args:
try:
parts.append(
"Args: " + scrub_think_tags(json.dumps(args, ensure_ascii=False))
)
except Exception:
parts.append("Args: " + scrub_think_tags(str(args)))
error = getattr(step, "error", None)
if error:
parts.append(f"Error: {scrub_think_tags(str(error))}")
obs = getattr(step, "observations", None)
if obs is not None:
if isinstance(obs, (list, tuple)):
obs_str = "\n".join(map(str, obs))
else:
obs_str = str(obs)
parts.append("Observation:\n" + scrub_think_tags(obs_str).strip())
# If this looks like a FinalAnswer step object, surface a clean final answer
try:
tname = type(step).__name__
except Exception:
tname = ""
if tname.lower().startswith("finalanswer"):
out = getattr(step, "output", None)
if out is not None:
return f"Final answer: {scrub_think_tags(str(out)).strip()}"
# Fallback: try to parse from string repr "FinalAnswerStep(output=...)"
s = scrub_think_tags(str(step))
m = re.search(r"FinalAnswer[^()]*\(\s*output\s*=\s*([^,)]+)", s)
if m:
return f"Final answer: {m.group(1).strip()}"
# If the only content would be an object repr like FinalAnswerStep(...), drop it;
# a cleaner "Final answer: ..." will come from the rule above or stdout.
joined = "\n".join(parts).strip()
if re.match(r"^FinalAnswer[^\n]+\)$", joined):
return ""
return joined or scrub_think_tags(str(step))
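# Illustrative rendering for a typical action step (exact fields vary across
# smolagents versions, hence the getattr fallbacks above):
#
#   Step 1
#   Thought: I should compute this in Python.
#   ```python
#   print(2 + 2)
#   ```
#   Observation:
#   4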
# ---------- Agent streaming bridge (truly live) ----------
async def run_agent_stream(task: str, agent_obj: typing.Optional[typing.Any] = None):
"""
Start the agent in a worker thread.
Stream THREE sources of incremental data into the async generator:
(1) live stdout/stderr lines,
(2) newly appended memory steps (polled),
(3) any iterable the agent may yield (if supported).
Finally emit a __final__ item with the last answer.
"""
loop = asyncio.get_running_loop()
q: asyncio.Queue = asyncio.Queue()
    agent_to_use = agent_obj or create_code_writing_agent()
stop_evt = threading.Event()
# 1) stdout/stderr live tee
    qwriter = QueueWriter(q, loop)
# 2) memory poller
def poll_memory():
last_len = 0
while not stop_evt.is_set():
try:
steps = []
try:
# Common API: agent.memory.get_full_steps()
steps = agent_to_use.memory.get_full_steps() # type: ignore[attr-defined]
except Exception:
# Fallbacks: different names across versions
steps = (
getattr(agent_to_use, "steps", [])
or getattr(agent_to_use, "memory", [])
or []
)
if steps is None:
steps = []
curr_len = len(steps)
if curr_len > last_len:
new = steps[last_len:curr_len]
last_len = curr_len
for s in new:
s_text = _serialize_step(s)
if s_text:
                            try:
                                loop.call_soon_threadsafe(
                                    q.put_nowait, {"__step__": s_text}
                                )
                            except Exception:
                                pass
except Exception:
pass
time.sleep(0.10) # 100 ms cadence
# 3) agent runner (may or may not yield)
def run_agent():
final_result = None
try:
with contextlib.redirect_stdout(qwriter), contextlib.redirect_stderr(
qwriter
):
used_iterable = False
if hasattr(agent_to_use, "run") and callable(
getattr(agent_to_use, "run")
):
try:
res = agent_to_use.run(task, stream=True)
if hasattr(res, "__iter__") and not isinstance(
res, (str, bytes)
):
used_iterable = True
for it in res:
                                try:
                                    loop.call_soon_threadsafe(q.put_nowait, it)
                                except Exception:
                                    pass
final_result = (
None # iterable may already contain the answer
)
else:
final_result = res
except TypeError:
# run(stream=True) not supported -> fall back
pass
if final_result is None and not used_iterable:
# Try other common streaming signatures
for name in (
"run_stream",
"stream",
"stream_run",
"run_with_callback",
):
fn = getattr(agent_to_use, name, None)
if callable(fn):
try:
res = fn(task)
if hasattr(res, "__iter__") and not isinstance(
res, (str, bytes)
):
                                    for it in res:
                                        loop.call_soon_threadsafe(q.put_nowait, it)
final_result = None
else:
final_result = res
break
except TypeError:
# maybe callback signature
                                def cb(item):
                                    try:
                                        loop.call_soon_threadsafe(
                                            q.put_nowait, item
                                        )
                                    except Exception:
                                        pass
try:
fn(task, cb)
final_result = None
break
except Exception:
continue
if final_result is None and not used_iterable:
# Last resort: synchronous run()/generate()/callable
if hasattr(agent_to_use, "run") and callable(
getattr(agent_to_use, "run")
):
final_result = agent_to_use.run(task)
elif hasattr(agent_to_use, "generate") and callable(
getattr(agent_to_use, "generate")
):
final_result = agent_to_use.generate(task)
elif callable(agent_to_use):
final_result = agent_to_use(task)
except Exception as e:
try:
qwriter.flush()
except Exception:
pass
            try:
                loop.call_soon_threadsafe(
                    q.put_nowait, {"__error__": str(e)}
                )
            except Exception:
                pass
finally:
try:
qwriter.flush()
except Exception:
pass
            try:
                loop.call_soon_threadsafe(
                    q.put_nowait, {"__final__": final_result}
                )
            except Exception:
                pass
stop_evt.set()
# Kick off threads
mem_thread = threading.Thread(target=poll_memory, daemon=True)
run_thread = threading.Thread(target=run_agent, daemon=True)
mem_thread.start()
run_thread.start()
# Async consumer
while True:
item = await q.get()
yield item
if isinstance(item, dict) and "__final__" in item:
break
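# Consumption sketch (a simplified version of what the endpoint below does):
#
#   async for item in run_agent_stream("What is 2 + 2?"):
#       if isinstance(item, dict) and "__stdout__" in item:
#           print("live:", item["__stdout__"], end="")
#       elif isinstance(item, dict) and "__final__" in item:
#           print("final:", item["__final__"])
#           break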
def _recursively_scrub(obj):
if isinstance(obj, str):
return scrub_think_tags(obj)
if isinstance(obj, dict):
return {k: _recursively_scrub(v) for k, v in obj.items()}
if isinstance(obj, list):
return [_recursively_scrub(v) for v in obj]
return obj
async def _proxy_upstream_chat_completions(
body: dict, stream: bool, scrub_think: bool = False
):
if not UPSTREAM_BASE:
return fastapi.responses.JSONResponse(
{"error": {"message": "UPSTREAM_OPENAI_BASE not configured"}},
status_code=500,
)
    headers = {"Content-Type": "application/json"}
    if HF_TOKEN:
        # Only attach Authorization when a token is configured; an empty
        # Bearer header can confuse some upstreams.
        headers["Authorization"] = f"Bearer {HF_TOKEN}"
url = f"{UPSTREAM_BASE}/chat/completions"
if stream:
async def proxy_stream():
async with httpx.AsyncClient(timeout=None) as client:
async with client.stream(
"POST", url, headers=headers, json=body
) as resp:
resp.raise_for_status()
if scrub_think:
# Pull text segments, scrub tags, and yield bytes
async for txt in resp.aiter_text():
try:
cleaned = scrub_think_tags(txt)
yield cleaned.encode("utf-8")
except Exception:
yield txt.encode("utf-8")
else:
async for chunk in resp.aiter_bytes():
yield chunk
return fastapi.responses.StreamingResponse(
proxy_stream(), media_type="text/event-stream", headers=_sse_headers()
)
else:
async with httpx.AsyncClient(timeout=None) as client:
r = await client.post(url, headers=headers, json=body)
try:
payload = r.json()
except Exception:
payload = {"status_code": r.status_code, "text": r.text}
if scrub_think:
try:
payload = _recursively_scrub(payload)
except Exception:
pass
return fastapi.responses.JSONResponse(
status_code=r.status_code, content=payload
)
# ---------- Endpoints ----------
@app.get("/v1/models")
async def list_models():
now = int(time.time())
return {
"object": "list",
"data": [
{
"id": "code-writing-agent",
"object": "model",
"created": now,
"owned_by": "you",
},
{
"id": AGENT_MODEL,
"object": "model",
"created": now,
"owned_by": "upstream",
},
{
"id": AGENT_MODEL + "-nothink",
"object": "model",
"created": now,
"owned_by": "upstream",
},
],
}
@app.post("/v1/chat/completions")
async def chat_completions(req: fastapi.Request):
try:
body: ChatCompletionRequest = typing.cast(
ChatCompletionRequest, await req.json()
)
except Exception as e:
return fastapi.responses.JSONResponse(
{"error": {"message": f"Invalid JSON: {e}"}}, status_code=400
)
messages = body.get("messages") or []
stream = bool(body.get("stream", False))
raw_model = body.get("model")
model_name = (
raw_model.get("id")
if isinstance(raw_model, dict)
else (raw_model or "code-writing-agent")
)
# Pure pass-through if the user selects the upstream model id
if model_name == AGENT_MODEL:
return await _proxy_upstream_chat_completions(dict(body), stream)
if model_name == AGENT_MODEL + "-nothink":
# Remove "-nothink" from the model name in body
body["model"] = AGENT_MODEL
# Add /nothink to the end of the message contents to disable think tags
new_messages = []
for msg in messages:
if msg.get("role") == "user":
content = normalize_content_to_text(msg.get("content", ""))
content += "\n/nothink"
new_msg: ChatMessage = {
"role": "user",
"content": content,
}
new_messages.append(new_msg)
else:
new_messages.append(msg)
body["messages"] = new_messages
return await _proxy_upstream_chat_completions(
dict(body), stream, scrub_think=True
)
# Otherwise, reasoning-aware wrapper
task = _messages_to_task(messages)
# Per-request agent override if a custom model id was provided (different from defaults)
agent_for_request = None
if model_name not in (
"code-writing-agent",
AGENT_MODEL,
AGENT_MODEL + "-nothink",
) and isinstance(model_name, str):
try:
agent_for_request = create_code_writing_agent()
except Exception:
log.exception(
"Failed to construct agent for model '%s'; using default", model_name
)
agent_for_request = None
try:
if stream:
async def sse_streamer():
base = {
"id": f"chatcmpl-smol-{int(time.time())}",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model_name,
"choices": [
{
"index": 0,
"delta": {"role": "assistant"},
"finish_reason": None,
}
],
}
yield f"data: {json.dumps(base)}\n\n"
reasoning_idx = 0
final_candidate: typing.Optional[str] = None
async for item in run_agent_stream(task, agent_for_request):
# Error short-circuit
if isinstance(item, dict) and "__error__" in item:
error_chunk = {
**base,
"choices": [
{"index": 0, "delta": {}, "finish_reason": "error"}
],
}
yield f"data: {json.dumps(error_chunk)}\n\n"
yield f"data: {json.dumps({'error': item['__error__']})}\n\n"
break
# Explicit final result from the agent
if isinstance(item, dict) and "__final__" in item:
val = item["__final__"]
cand = _extract_final_text(val)
# Only update if the agent actually provided a non-empty answer
if cand and cand.strip().lower() != "none":
final_candidate = cand
# do not emit anything yet; we'll send a single final chunk below
continue
# Live stdout -> reasoning_content
if (
isinstance(item, dict)
and "__stdout__" in item
and isinstance(item["__stdout__"], str)
):
for line in item["__stdout__"].splitlines():
parsed = _maybe_parse_final_from_stdout(line)
if parsed:
final_candidate = parsed
rt = _format_reasoning_chunk(
line, "stdout", reasoning_idx := reasoning_idx + 1
)
if rt:
r_chunk = {
**base,
"choices": [
{"index": 0, "delta": {"reasoning_content": rt}}
],
}
yield f"data: {json.dumps(r_chunk, ensure_ascii=False)}\n\n"
continue
# Newly observed step -> reasoning_content
if (
isinstance(item, dict)
and "__step__" in item
and isinstance(item["__step__"], str)
):
for line in item["__step__"].splitlines():
parsed = _maybe_parse_final_from_stdout(line)
if parsed:
final_candidate = parsed
rt = _format_reasoning_chunk(
line, "step", reasoning_idx := reasoning_idx + 1
)
if rt:
r_chunk = {
**base,
"choices": [
{"index": 0, "delta": {"reasoning_content": rt}}
],
}
yield f"data: {json.dumps(r_chunk, ensure_ascii=False)}\n\n"
continue
# Any iterable output from the agent (rare) — treat as candidate answer
cand = _extract_final_text(item)
if cand:
final_candidate = cand
await asyncio.sleep(0) # keep the loop fair
# Emit the visible answer once at the end (scrub any stray tags)
visible = scrub_think_tags(final_candidate or "")
if not visible or visible.strip().lower() == "none":
visible = "Done."
final_chunk = {
**base,
"choices": [{"index": 0, "delta": {"content": visible}}],
}
yield f"data: {json.dumps(final_chunk, ensure_ascii=False)}\n\n"
stop_chunk = {
**base,
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
}
yield f"data: {json.dumps(stop_chunk)}\n\n"
yield "data: [DONE]\n\n"
return fastapi.responses.StreamingResponse(
sse_streamer(), media_type="text/event-stream", headers=_sse_headers()
)
else:
            # Non-streaming: collect reasoning into <think>…</think> + final answer
reasoning_lines: typing.List[str] = []
final_candidate: typing.Optional[str] = None
async for item in run_agent_stream(task, agent_for_request):
if isinstance(item, dict) and "__error__" in item:
raise Exception(item["__error__"])
if isinstance(item, dict) and "__final__" in item:
val = item["__final__"]
cand = _extract_final_text(val)
if cand and cand.strip().lower() != "none":
final_candidate = cand
continue
if isinstance(item, dict) and "__stdout__" in item:
lines = (
scrub_think_tags(item["__stdout__"]).rstrip("\n").splitlines()
)
for line in lines:
parsed = _maybe_parse_final_from_stdout(line)
if parsed:
final_candidate = parsed
rt = _format_reasoning_chunk(
line, "stdout", len(reasoning_lines) + 1
)
if rt:
reasoning_lines.append(rt)
continue
if isinstance(item, dict) and "__step__" in item:
lines = scrub_think_tags(item["__step__"]).rstrip("\n").splitlines()
for line in lines:
parsed = _maybe_parse_final_from_stdout(line)
if parsed:
final_candidate = parsed
rt = _format_reasoning_chunk(
line, "step", len(reasoning_lines) + 1
)
if rt:
reasoning_lines.append(rt)
continue
cand = _extract_final_text(item)
if cand:
final_candidate = cand
reasoning_blob = "\n".join(reasoning_lines).strip()
if len(reasoning_blob) > 24000:
reasoning_blob = reasoning_blob[:24000] + "\n… [truncated]"
            think_block = (
                f"<think>\n{reasoning_blob}\n</think>\n\n" if reasoning_blob else ""
            )
final_text = scrub_think_tags(final_candidate or "")
if not final_text or final_text.strip().lower() == "none":
final_text = "Done."
result_text = f"{think_block}{final_text}"
except Exception as e:
msg = str(e)
status = 503 if "503" in msg or "Service Unavailable" in msg else 500
log.error("Agent error (%s): %s", status, msg)
return fastapi.responses.JSONResponse(
status_code=status,
content={
"error": {"message": f"Agent error: {msg}", "type": "agent_error"}
},
)
# Non-streaming response
if result_text is None:
result_text = ""
if not isinstance(result_text, str):
try:
result_text = json.dumps(result_text, ensure_ascii=False)
except Exception:
result_text = str(result_text)
return fastapi.responses.JSONResponse(_openai_response(result_text, model_name))
# Optional: local run
if __name__ == "__main__":
import uvicorn
uvicorn.run(
"app:app", host="0.0.0.0", port=int(os.getenv("PORT", "8000")), reload=False
)
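# Smoke-test sketch (assumes the server is running locally on the default
# port; the endpoint and payload shape match this file, nothing else is assumed):
#
#   import httpx
#   r = httpx.post(
#       "http://localhost:8000/v1/chat/completions",
#       json={
#           "model": "code-writing-agent",
#           "messages": [{"role": "user", "content": "What is 2 + 2?"}],
#       },
#       timeout=120,
#   )
#   print(r.json()["choices"][0]["message"]["content"])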