import asyncio
import contextlib
import os
import threading
import time
import typing

import fastapi
import httpx

from agent_server.helpers import sse_headers
from agent_server.sanitizing_think_tags import scrub_think_tags
from agent_server.std_tee import QueueWriter, _serialize_step


async def run_agent_stream(task: str, agent_obj: typing.Optional[typing.Any] = None):
    """
    Run *agent_obj* on *task* in a worker thread and stream incremental data.

    Three sources are multiplexed onto one asyncio.Queue and yielded as they
    arrive:
      1. live stdout/stderr lines captured by a QueueWriter tee,
      2. newly appended memory steps (polled every 100 ms),
      3. any iterable the agent itself yields (when its API supports it).

    The stream always ends with a ``{"__final__": <answer-or-None>}`` item;
    a failing run surfaces as a ``{"__error__": str}`` item before it.
    """
    loop = asyncio.get_running_loop()
    q: asyncio.Queue = asyncio.Queue()
    agent_to_use = agent_obj
    stop_evt = threading.Event()

    def enqueue(item) -> None:
        # asyncio.Queue is NOT thread-safe: put_nowait must execute on the
        # event-loop thread, otherwise the `await q.get()` consumer may never
        # be woken. Route every worker-thread put through the loop.
        try:
            loop.call_soon_threadsafe(q.put_nowait, item)
        except RuntimeError:
            # Event loop already closed; nobody is listening any more.
            pass

    # 1) stdout/stderr live tee.
    # NOTE(review): QueueWriter receives the raw queue and writes to it from
    # the worker thread -- confirm it routes its puts through the loop too.
    qwriter = QueueWriter(q)

    # 2) memory poller: watch the agent's step memory, forward new steps.
    def poll_memory():
        last_len = 0
        while not stop_evt.is_set():
            try:
                steps = []
                try:
                    # Common API: agent.memory.get_full_steps()
                    steps = agent_to_use.memory.get_full_steps()  # type: ignore[attr-defined]
                except Exception:
                    # Fallbacks: attribute names differ across agent versions.
                    steps = (
                        getattr(agent_to_use, "steps", [])
                        or getattr(agent_to_use, "memory", [])
                        or []
                    )
                if steps is None:
                    steps = []
                curr_len = len(steps)
                if curr_len > last_len:
                    new = steps[last_len:curr_len]
                    last_len = curr_len
                    for s in new:
                        s_text = _serialize_step(s)
                        if s_text:
                            enqueue({"__step__": s_text})
            except Exception:
                # Best-effort polling: never let a bad step kill the thread.
                pass
            time.sleep(0.10)  # 100 ms cadence

    # 3) agent runner: try streaming APIs first, then synchronous fallbacks.
    def run_agent():
        final_result = None
        try:
            with contextlib.redirect_stdout(qwriter), contextlib.redirect_stderr(
                qwriter
            ):
                used_iterable = False
                run_fn = getattr(agent_to_use, "run", None)
                if callable(run_fn):
                    try:
                        res = run_fn(task, stream=True)
                        if hasattr(res, "__iter__") and not isinstance(
                            res, (str, bytes)
                        ):
                            used_iterable = True
                            for it in res:
                                enqueue(it)
                            # The iterable may already contain the answer.
                            final_result = None
                        else:
                            final_result = res
                    except TypeError:
                        # run(stream=True) not supported -> fall back below.
                        pass
                if final_result is None and not used_iterable:
                    # Try other common streaming method names.
                    for name in (
                        "run_stream",
                        "stream",
                        "stream_run",
                        "run_with_callback",
                    ):
                        fn = getattr(agent_to_use, name, None)
                        if not callable(fn):
                            continue
                        try:
                            res = fn(task)
                            if hasattr(res, "__iter__") and not isinstance(
                                res, (str, bytes)
                            ):
                                for it in res:
                                    enqueue(it)
                                final_result = None
                            else:
                                final_result = res
                            break
                        except TypeError:
                            # Maybe it takes a callback instead of returning
                            # a stream; enqueue doubles as the callback.
                            try:
                                fn(task, enqueue)
                                final_result = None
                                break
                            except Exception:
                                continue
                if final_result is None and not used_iterable:
                    # Last resort: synchronous run()/generate()/callable.
                    if callable(run_fn):
                        final_result = run_fn(task)
                    elif callable(getattr(agent_to_use, "generate", None)):
                        final_result = agent_to_use.generate(task)
                    elif callable(agent_to_use):
                        final_result = agent_to_use(task)
        except Exception as e:
            with contextlib.suppress(Exception):
                qwriter.flush()
            enqueue({"__error__": str(e)})
        finally:
            with contextlib.suppress(Exception):
                qwriter.flush()
            # Always deliver the final sentinel so the consumer terminates.
            enqueue({"__final__": final_result})
            stop_evt.set()

    # Kick off worker threads (daemonized: never block interpreter exit).
    threading.Thread(target=poll_memory, daemon=True).start()
    threading.Thread(target=run_agent, daemon=True).start()

    # Async consumer: drain the queue until the final sentinel arrives.
    while True:
        item = await q.get()
        yield item
        if isinstance(item, dict) and "__final__" in item:
            break


def _recursively_scrub(obj):
    """Return *obj* with think-tags scrubbed from every nested string.

    Recurses into dicts and lists; any other type passes through unchanged.
    """
    if isinstance(obj, str):
        return scrub_think_tags(obj)
    if isinstance(obj, dict):
        return {k: _recursively_scrub(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [_recursively_scrub(v) for v in obj]
    return obj


async def proxy_upstream_chat_completions(
    body: dict, stream: bool, scrub_think: bool = False
):
    """
    Forward an OpenAI-style /chat/completions request to the upstream server.

    Args:
        body: JSON payload forwarded verbatim.
        stream: when True, relay the upstream SSE byte stream; otherwise
            return the upstream JSON response.
        scrub_think: when True, strip think-tags from the response text.

    Returns:
        fastapi StreamingResponse (stream=True) or JSONResponse otherwise.
    """
    # NOTE(review): the token is read from OPENAI_API_KEY although the local
    # name suggested HF_TOKEN -- confirm which variable deployments set.
    token = os.getenv("OPENAI_API_KEY")
    headers = {"Content-Type": "application/json"}
    if token:
        # Only attach Authorization when a token exists; an empty
        # "Authorization:" header is rejected by some upstreams.
        headers["Authorization"] = f"Bearer {token}"
    UPSTREAM_BASE = os.getenv("UPSTREAM_OPENAI_BASE", "").rstrip("/")
    url = f"{UPSTREAM_BASE}/chat/completions"

    if stream:

        async def proxy_stream():
            async with httpx.AsyncClient(timeout=None) as client:
                async with client.stream(
                    "POST", url, headers=headers, json=body
                ) as resp:
                    resp.raise_for_status()
                    if scrub_think:
                        # Pull text segments, scrub tags, and re-encode.
                        # NOTE(review): a tag split across two chunks escapes
                        # the scrubber -- best-effort only.
                        async for txt in resp.aiter_text():
                            try:
                                yield scrub_think_tags(txt).encode("utf-8")
                            except Exception:
                                yield txt.encode("utf-8")
                    else:
                        async for chunk in resp.aiter_bytes():
                            yield chunk

        return fastapi.responses.StreamingResponse(
            proxy_stream(), media_type="text/event-stream", headers=sse_headers()
        )
    else:
        async with httpx.AsyncClient(timeout=None) as client:
            r = await client.post(url, headers=headers, json=body)
            try:
                payload = r.json()
            except Exception:
                # Upstream returned non-JSON; wrap it for the client.
                payload = {"status_code": r.status_code, "text": r.text}
            if scrub_think:
                try:
                    payload = _recursively_scrub(payload)
                except Exception:
                    pass
            return fastapi.responses.JSONResponse(
                status_code=r.status_code, content=payload
            )