""" OpenAI-compatible FastAPI proxy that wraps a smolagents CodeAgent """ import os # For dealing with env vars import re # For tag stripping import json # For JSON handling import time # For timestamps and sleeps import asyncio # For async operations import typing # For type annotations import logging # For logging import threading # For threading operations import fastapi import fastapi.responses import io import contextlib # Upstream pass-through import httpx from agents.code_writing_agent import create_code_writing_agent # Logging setup logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO").upper()) log = logging.getLogger(__name__) # Config from env vars UPSTREAM_BASE = os.getenv("UPSTREAM_OPENAI_BASE", "").rstrip("/") HF_TOKEN = ( os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN") or os.getenv("API_TOKEN") or "" ) AGENT_MODEL = os.getenv("AGENT_MODEL", "Qwen/Qwen3-1.7B") if not UPSTREAM_BASE: log.warning( "UPSTREAM_OPENAI_BASE is empty; OpenAI-compatible upstream calls will fail." ) if not HF_TOKEN: log.warning("HF_TOKEN is empty; upstream may 401/403 if it requires auth.") # ================== FastAPI ================== app = fastapi.FastAPI() @app.get("/healthz") async def healthz(): return {"ok": True} # ---------- OpenAI-compatible minimal schemas ---------- class ChatMessage(typing.TypedDict, total=False): role: str content: typing.Any # str or multimodal list class ChatCompletionRequest(typing.TypedDict, total=False): model: typing.Optional[str] messages: typing.List[ChatMessage] temperature: typing.Optional[float] stream: typing.Optional[bool] max_tokens: typing.Optional[int] # ---------- Helpers ---------- def normalize_content_to_text(content: typing.Any) -> str: if isinstance(content, str): return content if isinstance(content, (bytes, bytearray)): try: return content.decode("utf-8", errors="ignore") except Exception: return str(content) if isinstance(content, list): parts = [] for item in content: if ( isinstance(item, dict) and item.get("type") == "text" and isinstance(item.get("text"), str) ): parts.append(item["text"]) else: try: parts.append(json.dumps(item, ensure_ascii=False)) except Exception: parts.append(str(item)) return "\n".join(parts) if isinstance(content, dict): try: return json.dumps(content, ensure_ascii=False) except Exception: return str(content) return str(content) def _messages_to_task(messages: typing.List[ChatMessage]) -> str: system_parts = [ normalize_content_to_text(m.get("content", "")) for m in messages if m.get("role") == "system" ] user_parts = [ normalize_content_to_text(m.get("content", "")) for m in messages if m.get("role") == "user" ] assistant_parts = [ normalize_content_to_text(m.get("content", "")) for m in messages if m.get("role") == "assistant" ] sys_txt = "\n".join([s for s in system_parts if s]).strip() history = "" if assistant_parts: history = "\n\nPrevious assistant replies (for context):\n" + "\n---\n".join( assistant_parts ) last_user = user_parts[-1] if user_parts else "" prefix = ( "You are a very small agent with only a Python REPL tool.\n" "Prefer short, correct answers. If Python is unnecessary, just answer plainly.\n" "If you do use Python, print only final results—no extra logs.\n" ) if sys_txt: prefix = f"{sys_txt}\n\n{prefix}" return f"{prefix}\nTask:\n{last_user}\n{history}".strip() def _openai_response( message_text: str, model_name: str ) -> typing.Dict[str, typing.Any]: now = int(time.time()) return { "id": f"chatcmpl-smol-{now}", "object": "chat.completion", "created": now, "model": model_name, "choices": [ { "index": 0, "message": {"role": "assistant", "content": message_text}, "finish_reason": "stop", } ], "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}, } def _sse_headers() -> dict: return { "Cache-Control": "no-cache, no-transform", "Connection": "keep-alive", "X-Accel-Buffering": "no", } # ---------- Sanitizer: remove think/thank tags from LLM-originated text ---------- _THINK_TAG_RE = re.compile(r"]*>", flags=re.IGNORECASE) _THANK_TAG_RE = re.compile(r"]*>", flags=re.IGNORECASE) # typo safety _ESC_THINK_TAG_RE = re.compile(r"</?\s*think\b[^&]*>", flags=re.IGNORECASE) _ESC_THANK_TAG_RE = re.compile(r"</?\s*thank\b[^&]*>", flags=re.IGNORECASE) def scrub_think_tags(text: typing.Any) -> str: """ Remove literal and HTML-escaped / (and variants) tags. Content inside the tags is preserved; only the tags are stripped. """ if not isinstance(text, str): try: text = str(text) except Exception: return "" t = _THINK_TAG_RE.sub("", text) t = _THANK_TAG_RE.sub("", t) t = _ESC_THINK_TAG_RE.sub("", t) t = _ESC_THANK_TAG_RE.sub("", t) return t # ---------- Reasoning formatting for Chat-UI ---------- def _format_reasoning_chunk(text: str, tag: str, idx: int) -> str: """ Lightweight formatter for reasoning stream. Avoid huge code fences; make it readable and incremental. Also filters out ASCII/box-drawing noise. """ text = scrub_think_tags(text).rstrip("\n") if not text: return "" noisy_prefixes = ( "OpenAIServerModel", "Output message of the LLM", "─ Executing parsed code", "New run", "╭", "╰", "│", "━", "─", ) stripped = text.strip() if not stripped: return "" # Lines made mostly of box drawing/separators if all(ch in " ─━╭╮╰╯│═·—-_=+•" for ch in stripped): return "" if any(stripped.startswith(p) for p in noisy_prefixes): return "" # Excessively long lines with little signal (no alphanumerics) if len(stripped) > 240 and not re.search(r"[A-Za-z0-9]{3,}", stripped): return "" # No tag/idx prefix; add a trailing blank line for readability in markdown return f"{stripped}\n\n" def _extract_final_text(item: typing.Any) -> typing.Optional[str]: if isinstance(item, dict) and ("__stdout__" in item or "__step__" in item): return None if isinstance(item, (bytes, bytearray)): try: item = item.decode("utf-8", errors="ignore") except Exception: item = str(item) if isinstance(item, str): s = scrub_think_tags(item.strip()) return s or None # If it's a step-like object with an 'output' attribute, use that try: if not isinstance(item, (dict, list, bytes, bytearray)): out = getattr(item, "output", None) if out is not None: s = scrub_think_tags(str(out)).strip() if s: return s except Exception: pass if isinstance(item, dict): for key in ("content", "text", "message", "output", "final", "answer"): if key in item: val = item[key] if isinstance(val, (dict, list)): try: return scrub_think_tags(json.dumps(val, ensure_ascii=False)) except Exception: return scrub_think_tags(str(val)) if isinstance(val, (bytes, bytearray)): try: val = val.decode("utf-8", errors="ignore") except Exception: val = str(val) s = scrub_think_tags(str(val).strip()) return s or None try: return scrub_think_tags(json.dumps(item, ensure_ascii=False)) except Exception: return scrub_think_tags(str(item)) try: return scrub_think_tags(str(item)) except Exception: return None # Helper to parse explicit "Final answer:" from stdout lines _FINAL_RE = re.compile(r"(?:^|\\b)Final\\s+answer:\\s*(.+)$", flags=re.IGNORECASE) def _maybe_parse_final_from_stdout(line: str) -> typing.Optional[str]: if not isinstance(line, str): return None m = _FINAL_RE.search(line.strip()) if not m: return None return scrub_think_tags(m.group(1)).strip() or None # ---------- Live stdout/stderr tee ---------- class QueueWriter(io.TextIOBase): """ File-like object that pushes each write to an asyncio.Queue immediately. """ def __init__(self, q: "asyncio.Queue"): self.q = q self._lock = threading.Lock() self._buf = [] # accumulate until newline to reduce spam def write(self, s: str): if not s: return 0 with self._lock: self._buf.append(s) # flush on newline to keep granularity reasonable if "\n" in s: chunk = "".join(self._buf) self._buf.clear() try: self.q.put_nowait({"__stdout__": chunk}) except Exception: pass return len(s) def flush(self): with self._lock: if self._buf: chunk = "".join(self._buf) self._buf.clear() try: self.q.put_nowait({"__stdout__": chunk}) except Exception: pass def _serialize_step(step) -> str: """ Best-effort pretty string for a smolagents MemoryStep / ActionStep. Works even if attributes are missing on some versions. """ parts = [] sn = getattr(step, "step_number", None) if sn is not None: parts.append(f"Step {sn}") thought_val = getattr(step, "thought", None) if thought_val: parts.append(f"Thought: {scrub_think_tags(str(thought_val))}") tool_val = getattr(step, "tool", None) if tool_val: parts.append(f"Tool: {scrub_think_tags(str(tool_val))}") code_val = getattr(step, "code", None) if code_val: code_str = scrub_think_tags(str(code_val)).strip() parts.append("```python\n" + code_str + "\n```") args = getattr(step, "args", None) if args: try: parts.append( "Args: " + scrub_think_tags(json.dumps(args, ensure_ascii=False)) ) except Exception: parts.append("Args: " + scrub_think_tags(str(args))) error = getattr(step, "error", None) if error: parts.append(f"Error: {scrub_think_tags(str(error))}") obs = getattr(step, "observations", None) if obs is not None: if isinstance(obs, (list, tuple)): obs_str = "\n".join(map(str, obs)) else: obs_str = str(obs) parts.append("Observation:\n" + scrub_think_tags(obs_str).strip()) # If this looks like a FinalAnswer step object, surface a clean final answer try: tname = type(step).__name__ except Exception: tname = "" if tname.lower().startswith("finalanswer"): out = getattr(step, "output", None) if out is not None: return f"Final answer: {scrub_think_tags(str(out)).strip()}" # Fallback: try to parse from string repr "FinalAnswerStep(output=...)" s = scrub_think_tags(str(step)) m = re.search(r"FinalAnswer[^()]*\(\s*output\s*=\s*([^,)]+)", s) if m: return f"Final answer: {m.group(1).strip()}" # If the only content would be an object repr like FinalAnswerStep(...), drop it; # a cleaner "Final answer: ..." will come from the rule above or stdout. joined = "\n".join(parts).strip() if re.match(r"^FinalAnswer[^\n]+\)$", joined): return "" return joined or scrub_think_tags(str(step)) # ---------- Agent streaming bridge (truly live) ---------- async def run_agent_stream(task: str, agent_obj: typing.Optional[typing.Any] = None): """ Start the agent in a worker thread. Stream THREE sources of incremental data into the async generator: (1) live stdout/stderr lines, (2) newly appended memory steps (polled), (3) any iterable the agent may yield (if supported). Finally emit a __final__ item with the last answer. """ loop = asyncio.get_running_loop() q: asyncio.Queue = asyncio.Queue() agent_to_use = agent_obj or create_code_writing_agent stop_evt = threading.Event() # 1) stdout/stderr live tee qwriter = QueueWriter(q) # 2) memory poller def poll_memory(): last_len = 0 while not stop_evt.is_set(): try: steps = [] try: # Common API: agent.memory.get_full_steps() steps = agent_to_use.memory.get_full_steps() # type: ignore[attr-defined] except Exception: # Fallbacks: different names across versions steps = ( getattr(agent_to_use, "steps", []) or getattr(agent_to_use, "memory", []) or [] ) if steps is None: steps = [] curr_len = len(steps) if curr_len > last_len: new = steps[last_len:curr_len] last_len = curr_len for s in new: s_text = _serialize_step(s) if s_text: try: q.put_nowait({"__step__": s_text}) except Exception: pass except Exception: pass time.sleep(0.10) # 100 ms cadence # 3) agent runner (may or may not yield) def run_agent(): final_result = None try: with contextlib.redirect_stdout(qwriter), contextlib.redirect_stderr( qwriter ): used_iterable = False if hasattr(agent_to_use, "run") and callable( getattr(agent_to_use, "run") ): try: res = agent_to_use.run(task, stream=True) if hasattr(res, "__iter__") and not isinstance( res, (str, bytes) ): used_iterable = True for it in res: try: q.put_nowait(it) except Exception: pass final_result = ( None # iterable may already contain the answer ) else: final_result = res except TypeError: # run(stream=True) not supported -> fall back pass if final_result is None and not used_iterable: # Try other common streaming signatures for name in ( "run_stream", "stream", "stream_run", "run_with_callback", ): fn = getattr(agent_to_use, name, None) if callable(fn): try: res = fn(task) if hasattr(res, "__iter__") and not isinstance( res, (str, bytes) ): for it in res: q.put_nowait(it) final_result = None else: final_result = res break except TypeError: # maybe callback signature def cb(item): try: q.put_nowait(item) except Exception: pass try: fn(task, cb) final_result = None break except Exception: continue if final_result is None and not used_iterable: pass # (typo guard removed below) if final_result is None and not used_iterable: # Last resort: synchronous run()/generate()/callable if hasattr(agent_to_use, "run") and callable( getattr(agent_to_use, "run") ): final_result = agent_to_use.run(task) elif hasattr(agent_to_use, "generate") and callable( getattr(agent_to_use, "generate") ): final_result = agent_to_use.generate(task) elif callable(agent_to_use): final_result = agent_to_use(task) except Exception as e: try: qwriter.flush() except Exception: pass try: q.put_nowait({"__error__": str(e)}) except Exception: pass finally: try: qwriter.flush() except Exception: pass try: q.put_nowait({"__final__": final_result}) except Exception: pass stop_evt.set() # Kick off threads mem_thread = threading.Thread(target=poll_memory, daemon=True) run_thread = threading.Thread(target=run_agent, daemon=True) mem_thread.start() run_thread.start() # Async consumer while True: item = await q.get() yield item if isinstance(item, dict) and "__final__" in item: break def _recursively_scrub(obj): if isinstance(obj, str): return scrub_think_tags(obj) if isinstance(obj, dict): return {k: _recursively_scrub(v) for k, v in obj.items()} if isinstance(obj, list): return [_recursively_scrub(v) for v in obj] return obj async def _proxy_upstream_chat_completions( body: dict, stream: bool, scrub_think: bool = False ): if not UPSTREAM_BASE: return fastapi.responses.JSONResponse( {"error": {"message": "UPSTREAM_OPENAI_BASE not configured"}}, status_code=500, ) headers = { "Authorization": f"Bearer {HF_TOKEN}" if HF_TOKEN else "", "Content-Type": "application/json", } url = f"{UPSTREAM_BASE}/chat/completions" if stream: async def proxy_stream(): async with httpx.AsyncClient(timeout=None) as client: async with client.stream( "POST", url, headers=headers, json=body ) as resp: resp.raise_for_status() if scrub_think: # Pull text segments, scrub tags, and yield bytes async for txt in resp.aiter_text(): try: cleaned = scrub_think_tags(txt) yield cleaned.encode("utf-8") except Exception: yield txt.encode("utf-8") else: async for chunk in resp.aiter_bytes(): yield chunk return fastapi.responses.StreamingResponse( proxy_stream(), media_type="text/event-stream", headers=_sse_headers() ) else: async with httpx.AsyncClient(timeout=None) as client: r = await client.post(url, headers=headers, json=body) try: payload = r.json() except Exception: payload = {"status_code": r.status_code, "text": r.text} if scrub_think: try: payload = _recursively_scrub(payload) except Exception: pass return fastapi.responses.JSONResponse( status_code=r.status_code, content=payload ) # ---------- Endpoints ---------- @app.get("/v1/models") async def list_models(): now = int(time.time()) return { "object": "list", "data": [ { "id": "code-writing-agent", "object": "model", "created": now, "owned_by": "you", }, { "id": AGENT_MODEL, "object": "model", "created": now, "owned_by": "upstream", }, { "id": AGENT_MODEL + "-nothink", "object": "model", "created": now, "owned_by": "upstream", }, ], } @app.post("/v1/chat/completions") async def chat_completions(req: fastapi.Request): try: body: ChatCompletionRequest = typing.cast( ChatCompletionRequest, await req.json() ) except Exception as e: return fastapi.responses.JSONResponse( {"error": {"message": f"Invalid JSON: {e}"}}, status_code=400 ) messages = body.get("messages") or [] stream = bool(body.get("stream", False)) raw_model = body.get("model") model_name = ( raw_model.get("id") if isinstance(raw_model, dict) else (raw_model or "code-writing-agent") ) # Pure pass-through if the user selects the upstream model id if model_name == AGENT_MODEL: return await _proxy_upstream_chat_completions(dict(body), stream) if model_name == AGENT_MODEL + "-nothink": # Remove "-nothink" from the model name in body body["model"] = AGENT_MODEL # Add /nothink to the end of the message contents to disable think tags new_messages = [] for msg in messages: if msg.get("role") == "user": content = normalize_content_to_text(msg.get("content", "")) content += "\n/nothink" new_msg: ChatMessage = { "role": "user", "content": content, } new_messages.append(new_msg) else: new_messages.append(msg) body["messages"] = new_messages return await _proxy_upstream_chat_completions( dict(body), stream, scrub_think=True ) # Otherwise, reasoning-aware wrapper task = _messages_to_task(messages) # Per-request agent override if a custom model id was provided (different from defaults) agent_for_request = None if model_name not in ( "code-writing-agent", AGENT_MODEL, AGENT_MODEL + "-nothink", ) and isinstance(model_name, str): try: agent_for_request = create_code_writing_agent() except Exception: log.exception( "Failed to construct agent for model '%s'; using default", model_name ) agent_for_request = None try: if stream: async def sse_streamer(): base = { "id": f"chatcmpl-smol-{int(time.time())}", "object": "chat.completion.chunk", "created": int(time.time()), "model": model_name, "choices": [ { "index": 0, "delta": {"role": "assistant"}, "finish_reason": None, } ], } yield f"data: {json.dumps(base)}\n\n" reasoning_idx = 0 final_candidate: typing.Optional[str] = None async for item in run_agent_stream(task, agent_for_request): # Error short-circuit if isinstance(item, dict) and "__error__" in item: error_chunk = { **base, "choices": [ {"index": 0, "delta": {}, "finish_reason": "error"} ], } yield f"data: {json.dumps(error_chunk)}\n\n" yield f"data: {json.dumps({'error': item['__error__']})}\n\n" break # Explicit final result from the agent if isinstance(item, dict) and "__final__" in item: val = item["__final__"] cand = _extract_final_text(val) # Only update if the agent actually provided a non-empty answer if cand and cand.strip().lower() != "none": final_candidate = cand # do not emit anything yet; we'll send a single final chunk below continue # Live stdout -> reasoning_content if ( isinstance(item, dict) and "__stdout__" in item and isinstance(item["__stdout__"], str) ): for line in item["__stdout__"].splitlines(): parsed = _maybe_parse_final_from_stdout(line) if parsed: final_candidate = parsed rt = _format_reasoning_chunk( line, "stdout", reasoning_idx := reasoning_idx + 1 ) if rt: r_chunk = { **base, "choices": [ {"index": 0, "delta": {"reasoning_content": rt}} ], } yield f"data: {json.dumps(r_chunk, ensure_ascii=False)}\n\n" continue # Newly observed step -> reasoning_content if ( isinstance(item, dict) and "__step__" in item and isinstance(item["__step__"], str) ): for line in item["__step__"].splitlines(): parsed = _maybe_parse_final_from_stdout(line) if parsed: final_candidate = parsed rt = _format_reasoning_chunk( line, "step", reasoning_idx := reasoning_idx + 1 ) if rt: r_chunk = { **base, "choices": [ {"index": 0, "delta": {"reasoning_content": rt}} ], } yield f"data: {json.dumps(r_chunk, ensure_ascii=False)}\n\n" continue # Any iterable output from the agent (rare) — treat as candidate answer cand = _extract_final_text(item) if cand: final_candidate = cand await asyncio.sleep(0) # keep the loop fair # Emit the visible answer once at the end (scrub any stray tags) visible = scrub_think_tags(final_candidate or "") if not visible or visible.strip().lower() == "none": visible = "Done." final_chunk = { **base, "choices": [{"index": 0, "delta": {"content": visible}}], } yield f"data: {json.dumps(final_chunk, ensure_ascii=False)}\n\n" stop_chunk = { **base, "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}], } yield f"data: {json.dumps(stop_chunk)}\n\n" yield "data: [DONE]\n\n" return fastapi.responses.StreamingResponse( sse_streamer(), media_type="text/event-stream", headers=_sse_headers() ) else: # Non-streaming: collect into + final reasoning_lines: typing.List[str] = [] final_candidate: typing.Optional[str] = None async for item in run_agent_stream(task, agent_for_request): if isinstance(item, dict) and "__error__" in item: raise Exception(item["__error__"]) if isinstance(item, dict) and "__final__" in item: val = item["__final__"] cand = _extract_final_text(val) if cand and cand.strip().lower() != "none": final_candidate = cand continue if isinstance(item, dict) and "__stdout__" in item: lines = ( scrub_think_tags(item["__stdout__"]).rstrip("\n").splitlines() ) for line in lines: parsed = _maybe_parse_final_from_stdout(line) if parsed: final_candidate = parsed rt = _format_reasoning_chunk( line, "stdout", len(reasoning_lines) + 1 ) if rt: reasoning_lines.append(rt) continue if isinstance(item, dict) and "__step__" in item: lines = scrub_think_tags(item["__step__"]).rstrip("\n").splitlines() for line in lines: parsed = _maybe_parse_final_from_stdout(line) if parsed: final_candidate = parsed rt = _format_reasoning_chunk( line, "step", len(reasoning_lines) + 1 ) if rt: reasoning_lines.append(rt) continue cand = _extract_final_text(item) if cand: final_candidate = cand reasoning_blob = "\n".join(reasoning_lines).strip() if len(reasoning_blob) > 24000: reasoning_blob = reasoning_blob[:24000] + "\n… [truncated]" think_block = ( f"\n{reasoning_blob}\n\n" if reasoning_blob else "" ) final_text = scrub_think_tags(final_candidate or "") if not final_text or final_text.strip().lower() == "none": final_text = "Done." result_text = f"{think_block}{final_text}" except Exception as e: msg = str(e) status = 503 if "503" in msg or "Service Unavailable" in msg else 500 log.error("Agent error (%s): %s", status, msg) return fastapi.responses.JSONResponse( status_code=status, content={ "error": {"message": f"Agent error: {msg}", "type": "agent_error"} }, ) # Non-streaming response if result_text is None: result_text = "" if not isinstance(result_text, str): try: result_text = json.dumps(result_text, ensure_ascii=False) except Exception: result_text = str(result_text) return fastapi.responses.JSONResponse(_openai_response(result_text, model_name)) # Optional: local run if __name__ == "__main__": import uvicorn uvicorn.run( "app:app", host="0.0.0.0", port=int(os.getenv("PORT", "8000")), reload=False )