import asyncio
import contextlib
import os
import threading
import time
import typing

import fastapi
import httpx

from agent_server.helpers import sse_headers
from agent_server.sanitizing_think_tags import scrub_think_tags
from agent_server.std_tee import QueueWriter, _serialize_step
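
# Local helpers (behaviour assumed from how they are used in this module):
#   sse_headers()       -> header dict suitable for Server-Sent Events responses
#   scrub_think_tags(s) -> copy of s with model "think" tags removed
#   QueueWriter(q)      -> file-like writer that forwards written text into q
#   _serialize_step(s)  -> best-effort string rendering of an agent memory step
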
async def run_agent_stream(task: str, agent_obj: typing.Optional[typing.Any] = None):
"""
Start the agent in a worker thread.
Stream THREE sources of incremental data into the async generator:
(1) live stdout/stderr lines,
(2) newly appended memory steps (polled),
(3) any iterable the agent may yield (if supported).
Finally emit a __final__ item with the last answer.
"""
q: asyncio.Queue = asyncio.Queue()
agent_to_use = agent_obj
stop_evt = threading.Event()
# 1) stdout/stderr live tee
qwriter = QueueWriter(q)
# 2) memory poller
def poll_memory():
last_len = 0
while not stop_evt.is_set():
try:
steps = []
try:
# Common API: agent.memory.get_full_steps()
steps = agent_to_use.memory.get_full_steps() # type: ignore[attr-defined]
except Exception:
# Fallbacks: different names across versions
steps = (
getattr(agent_to_use, "steps", [])
or getattr(agent_to_use, "memory", [])
or []
)
if steps is None:
steps = []
curr_len = len(steps)
if curr_len > last_len:
new = steps[last_len:curr_len]
last_len = curr_len
for s in new:
s_text = _serialize_step(s)
if s_text:
try:
q.put_nowait({"__step__": s_text})
except Exception:
pass
except Exception:
pass
time.sleep(0.10) # 100 ms cadence
# 3) agent runner (may or may not yield)
def run_agent():
final_result = None
try:
with contextlib.redirect_stdout(qwriter), contextlib.redirect_stderr(
qwriter
):
used_iterable = False
if hasattr(agent_to_use, "run") and callable(
getattr(agent_to_use, "run")
):
try:
res = agent_to_use.run(task, stream=True)
if hasattr(res, "__iter__") and not isinstance(
res, (str, bytes)
):
used_iterable = True
for it in res:
try:
q.put_nowait(it)
except Exception:
pass
                            final_result = None  # iterable may already contain the answer
else:
final_result = res
except TypeError:
# run(stream=True) not supported -> fall back
pass
if final_result is None and not used_iterable:
# Try other common streaming signatures
for name in (
"run_stream",
"stream",
"stream_run",
"run_with_callback",
):
fn = getattr(agent_to_use, name, None)
if callable(fn):
try:
res = fn(task)
if hasattr(res, "__iter__") and not isinstance(
res, (str, bytes)
):
for it in res:
q.put_nowait(it)
final_result = None
else:
final_result = res
break
except TypeError:
# maybe callback signature
def cb(item):
try:
q.put_nowait(item)
except Exception:
pass
try:
fn(task, cb)
final_result = None
break
except Exception:
continue
if final_result is None and not used_iterable:
# Last resort: synchronous run()/generate()/callable
if hasattr(agent_to_use, "run") and callable(
getattr(agent_to_use, "run")
):
final_result = agent_to_use.run(task)
elif hasattr(agent_to_use, "generate") and callable(
getattr(agent_to_use, "generate")
):
final_result = agent_to_use.generate(task)
elif callable(agent_to_use):
final_result = agent_to_use(task)
except Exception as e:
try:
qwriter.flush()
except Exception:
pass
try:
q.put_nowait({"__error__": str(e)})
except Exception:
pass
finally:
try:
qwriter.flush()
except Exception:
pass
try:
q.put_nowait({"__final__": final_result})
except Exception:
pass
stop_evt.set()
# Kick off threads
mem_thread = threading.Thread(target=poll_memory, daemon=True)
run_thread = threading.Thread(target=run_agent, daemon=True)
mem_thread.start()
run_thread.start()
# Async consumer
while True:
item = await q.get()
yield item
if isinstance(item, dict) and "__final__" in item:
break
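

# Typical consumption (for reference only; the concrete endpoint lives elsewhere
# in the server, so the names below are illustrative): an SSE handler iterates
# the generator and turns each queue item into a "data:" line, e.g.
#
#     async for item in run_agent_stream(task, agent_obj=my_agent):
#         if isinstance(item, dict) and "__final__" in item:
#             yield f"data: {item['__final__']}\n\n"
#             break
#         yield f"data: {item}\n\n"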
def _recursively_scrub(obj):
if isinstance(obj, str):
return scrub_think_tags(obj)
if isinstance(obj, dict):
return {k: _recursively_scrub(v) for k, v in obj.items()}
if isinstance(obj, list):
return [_recursively_scrub(v) for v in obj]
return obj
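
# Example (illustrative): for a chat-completions payload such as
#     {"choices": [{"message": {"content": "<think>...</think>Hello"}}]}
# _recursively_scrub returns the same structure with the think span removed,
# assuming scrub_think_tags strips <think>...</think> blocks from strings.
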
async def proxy_upstream_chat_completions(
body: dict, stream: bool, scrub_think: bool = False
):
    # Bearer token for the OpenAI-compatible upstream, exposed via OPENAI_API_KEY.
    HF_TOKEN = os.getenv("OPENAI_API_KEY")
    headers = {"Content-Type": "application/json"}
    if HF_TOKEN:
        headers["Authorization"] = f"Bearer {HF_TOKEN}"
UPSTREAM_BASE = os.getenv("UPSTREAM_OPENAI_BASE", "").rstrip("/")
url = f"{UPSTREAM_BASE}/chat/completions"
if stream:
async def proxy_stream():
async with httpx.AsyncClient(timeout=None) as client:
async with client.stream(
"POST", url, headers=headers, json=body
) as resp:
resp.raise_for_status()
if scrub_think:
# Pull text segments, scrub tags, and yield bytes
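                        # Note: each chunk is scrubbed independently, so a tag
                        # split across chunk boundaries may slip through.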
async for txt in resp.aiter_text():
try:
cleaned = scrub_think_tags(txt)
yield cleaned.encode("utf-8")
except Exception:
yield txt.encode("utf-8")
else:
async for chunk in resp.aiter_bytes():
yield chunk
return fastapi.responses.StreamingResponse(
proxy_stream(), media_type="text/event-stream", headers=sse_headers()
)
else:
async with httpx.AsyncClient(timeout=None) as client:
r = await client.post(url, headers=headers, json=body)
try:
payload = r.json()
except Exception:
payload = {"status_code": r.status_code, "text": r.text}
if scrub_think:
try:
payload = _recursively_scrub(payload)
except Exception:
pass
return fastapi.responses.JSONResponse(
status_code=r.status_code, content=payload
)
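

# Illustrative, self-contained demo (not part of the library API): one way to
# mount proxy_upstream_chat_completions on a FastAPI app. The route path, port,
# and uvicorn dependency are assumptions made only for this sketch.
if __name__ == "__main__":
    import uvicorn  # assumed available alongside fastapi, for the demo only

    demo_app = fastapi.FastAPI()

    @demo_app.post("/v1/chat/completions")
    async def chat_completions(request: fastapi.Request):
        # Forward the raw body upstream; honour the client's "stream" flag.
        body = await request.json()
        return await proxy_upstream_chat_completions(
            body, stream=bool(body.get("stream", False)), scrub_think=True
        )

    uvicorn.run(demo_app, host="0.0.0.0", port=8000)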