import asyncio
import contextlib
import os
import threading
import time
import typing

import fastapi
import httpx

from agent_server.helpers import sse_headers
from agent_server.sanitizing_think_tags import scrub_think_tags
from agent_server.std_tee import QueueWriter, _serialize_step


async def run_agent_stream(task: str, agent_obj: typing.Optional[typing.Any] = None):
    """
    Start the agent in a worker thread.

    Stream THREE sources of incremental data into the async generator:
      (1) live stdout/stderr lines,
      (2) newly appended memory steps (polled),
      (3) any iterable the agent may yield (if supported).

    Finally emit a __final__ item with the last answer.
    """
    loop = asyncio.get_running_loop()
    q: asyncio.Queue = asyncio.Queue()
    agent_to_use = agent_obj
    stop_evt = threading.Event()

    # 1) stdout/stderr live tee
    qwriter = QueueWriter(q)
    # 2) memory poller
    def poll_memory():
        last_len = 0
        while not stop_evt.is_set():
            try:
                steps = []
                try:
                    # Common API: agent.memory.get_full_steps()
                    steps = agent_to_use.memory.get_full_steps()  # type: ignore[attr-defined]
                except Exception:
                    # Fallbacks: different names across versions
                    steps = (
                        getattr(agent_to_use, "steps", [])
                        or getattr(agent_to_use, "memory", [])
                        or []
                    )
                if steps is None:
                    steps = []
                curr_len = len(steps)
                if curr_len > last_len:
                    new = steps[last_len:curr_len]
                    last_len = curr_len
                    for s in new:
                        s_text = _serialize_step(s)
                        if s_text:
                            try:
                                q.put_nowait({"__step__": s_text})
                            except Exception:
                                pass
            except Exception:
                pass
            time.sleep(0.10)  # 100 ms cadence
    # 3) agent runner (may or may not yield)
    def run_agent():
        final_result = None
        try:
            with contextlib.redirect_stdout(qwriter), contextlib.redirect_stderr(
                qwriter
            ):
                used_iterable = False
                if hasattr(agent_to_use, "run") and callable(getattr(agent_to_use, "run")):
                    try:
                        res = agent_to_use.run(task, stream=True)
                        if hasattr(res, "__iter__") and not isinstance(res, (str, bytes)):
                            used_iterable = True
                            for it in res:
                                try:
                                    q.put_nowait(it)
                                except Exception:
                                    pass
                            final_result = None  # iterable may already contain the answer
                        else:
                            final_result = res
                    except TypeError:
                        # run(stream=True) not supported -> fall back
                        pass

                if final_result is None and not used_iterable:
                    # Try other common streaming signatures
                    for name in ("run_stream", "stream", "stream_run", "run_with_callback"):
                        fn = getattr(agent_to_use, name, None)
                        if callable(fn):
                            try:
                                res = fn(task)
                                if hasattr(res, "__iter__") and not isinstance(res, (str, bytes)):
                                    for it in res:
                                        q.put_nowait(it)
                                    final_result = None
                                else:
                                    final_result = res
                                break
                            except TypeError:
                                # maybe a callback signature
                                def cb(item):
                                    try:
                                        q.put_nowait(item)
                                    except Exception:
                                        pass

                                try:
                                    fn(task, cb)
                                    final_result = None
                                    break
                                except Exception:
                                    continue
                if final_result is None and not used_iterable:
                    # Last resort: synchronous run()/generate()/callable
                    if hasattr(agent_to_use, "run") and callable(getattr(agent_to_use, "run")):
                        final_result = agent_to_use.run(task)
                    elif hasattr(agent_to_use, "generate") and callable(getattr(agent_to_use, "generate")):
                        final_result = agent_to_use.generate(task)
                    elif callable(agent_to_use):
                        final_result = agent_to_use(task)
        except Exception as e:
            try:
                qwriter.flush()
            except Exception:
                pass
            try:
                q.put_nowait({"__error__": str(e)})
            except Exception:
                pass
        finally:
            try:
                qwriter.flush()
            except Exception:
                pass
            try:
                q.put_nowait({"__final__": final_result})
            except Exception:
                pass
            stop_evt.set()
    # Kick off threads
    mem_thread = threading.Thread(target=poll_memory, daemon=True)
    run_thread = threading.Thread(target=run_agent, daemon=True)
    mem_thread.start()
    run_thread.start()

    # Async consumer
    while True:
        item = await q.get()
        yield item
        if isinstance(item, dict) and "__final__" in item:
            break
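

# Usage sketch (illustrative, not part of the original module): one way to expose
# run_agent_stream() as a Server-Sent Events endpoint. The route path, the JSON
# framing, and the make_stream_route() helper below are assumptions; only
# run_agent_stream() and sse_headers() come from this file.
def make_stream_route(app: fastapi.FastAPI, agent: typing.Any) -> None:
    import json  # local import to keep the sketch self-contained

    @app.get("/agent/stream")
    async def agent_stream(task: str):
        async def event_source():
            async for item in run_agent_stream(task, agent_obj=agent):
                # Items are stdout/stderr text, {"__step__": ...}, {"__error__": ...},
                # or the terminating {"__final__": ...} dict.
                yield f"data: {json.dumps(item, default=str)}\n\n"

        return fastapi.responses.StreamingResponse(
            event_source(), media_type="text/event-stream", headers=sse_headers()
        )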


def _recursively_scrub(obj):
    if isinstance(obj, str):
        return scrub_think_tags(obj)
    if isinstance(obj, dict):
        return {k: _recursively_scrub(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [_recursively_scrub(v) for v in obj]
    return obj
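
# Example (illustrative): _recursively_scrub() applies scrub_think_tags() to every
# string nested inside dicts and lists, so a chat-completions payload such as
#   {"choices": [{"message": {"content": "<think>...</think>Hello"}}]}
# would come back with the reasoning block removed from "content", assuming
# scrub_think_tags() strips <think>...</think> spans. The payload shape here is
# only for demonstration.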


async def proxy_upstream_chat_completions(
    body: dict, stream: bool, scrub_think: bool = False
):
    api_key = os.getenv("OPENAI_API_KEY")
    headers = {
        "Authorization": f"Bearer {api_key}" if api_key else "",
        "Content-Type": "application/json",
    }
    upstream_base = os.getenv("UPSTREAM_OPENAI_BASE", "").rstrip("/")
    url = f"{upstream_base}/chat/completions"
    if stream:

        async def proxy_stream():
            async with httpx.AsyncClient(timeout=None) as client:
                async with client.stream(
                    "POST", url, headers=headers, json=body
                ) as resp:
                    resp.raise_for_status()
                    if scrub_think:
                        # Pull text segments, scrub tags, and yield bytes
                        async for txt in resp.aiter_text():
                            try:
                                cleaned = scrub_think_tags(txt)
                                yield cleaned.encode("utf-8")
                            except Exception:
                                yield txt.encode("utf-8")
                    else:
                        async for chunk in resp.aiter_bytes():
                            yield chunk

        return fastapi.responses.StreamingResponse(
            proxy_stream(), media_type="text/event-stream", headers=sse_headers()
        )
    else:
        async with httpx.AsyncClient(timeout=None) as client:
            r = await client.post(url, headers=headers, json=body)
            try:
                payload = r.json()
            except Exception:
                payload = {"status_code": r.status_code, "text": r.text}
            if scrub_think:
                try:
                    payload = _recursively_scrub(payload)
                except Exception:
                    pass
            return fastapi.responses.JSONResponse(
                status_code=r.status_code, content=payload
            )
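

# Usage sketch (illustrative, not part of the original module): a hypothetical
# FastAPI route that forwards OpenAI-style requests through the proxy above.
# The path, the `scrub` query parameter, and make_proxy_route() are assumptions;
# the upstream URL and API key are read from UPSTREAM_OPENAI_BASE and
# OPENAI_API_KEY inside proxy_upstream_chat_completions() itself.
def make_proxy_route(app: fastapi.FastAPI) -> None:
    @app.post("/v1/chat/completions")
    async def chat_completions(request: fastapi.Request, scrub: bool = False):
        body = await request.json()
        # Honor the caller's own "stream" flag when deciding SSE vs JSON.
        return await proxy_upstream_chat_completions(
            body, stream=bool(body.get("stream", False)), scrub_think=scrub
        )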