import asyncio
import contextlib
import os
import threading
import time
import typing

import fastapi
import httpx

from agent_server.helpers import sse_headers
from agent_server.sanitizing_think_tags import scrub_think_tags
from agent_server.std_tee import QueueWriter, _serialize_step


async def run_agent_stream(task: str, agent_obj: typing.Optional[typing.Any] = None):
    """
    Start the agent in a worker thread.
    Stream THREE sources of incremental data into the async generator:
      (1) live stdout/stderr lines,
      (2) newly appended memory steps (polled),
      (3) any iterable the agent may yield (if supported).
    Finally emit a __final__ item with the last answer.
    """
    loop = asyncio.get_running_loop()
    # Items this function enqueues from its worker threads are handed to the
    # event loop via loop.call_soon_threadsafe, since asyncio.Queue is not
    # thread-safe.
    q: asyncio.Queue = asyncio.Queue()
    agent_to_use = agent_obj

    stop_evt = threading.Event()

    # 1) stdout/stderr live tee
    qwriter = QueueWriter(q)

    # 2) memory poller
    def poll_memory():
        last_len = 0
        while not stop_evt.is_set():
            try:
                steps = []
                try:
                    # Common API: agent.memory.get_full_steps()
                    steps = agent_to_use.memory.get_full_steps()  # type: ignore[attr-defined]
                except Exception:
                    # Fallbacks: different names across versions
                    steps = (
                        getattr(agent_to_use, "steps", [])
                        or getattr(agent_to_use, "memory", [])
                        or []
                    )
                if steps is None:
                    steps = []
                curr_len = len(steps)
                if curr_len > last_len:
                    new = steps[last_len:curr_len]
                    last_len = curr_len
                    for s in new:
                        s_text = _serialize_step(s)
                        if s_text:
                            try:
                                loop.call_soon_threadsafe(
                                    q.put_nowait, {"__step__": s_text}
                                )
                            except Exception:
                                pass
            except Exception:
                pass
            time.sleep(0.10)  # 100 ms cadence

    # 3) agent runner (may or may not yield)
    def run_agent():
        final_result = None
        try:
            with contextlib.redirect_stdout(qwriter), contextlib.redirect_stderr(
                qwriter
            ):
                used_iterable = False
                if hasattr(agent_to_use, "run") and callable(
                    getattr(agent_to_use, "run")
                ):
                    try:
                        res = agent_to_use.run(task, stream=True)
                        if hasattr(res, "__iter__") and not isinstance(
                            res, (str, bytes)
                        ):
                            used_iterable = True
                            for it in res:
                                try:
                                    loop.call_soon_threadsafe(q.put_nowait, it)
                                except Exception:
                                    pass
                            final_result = (
                                None  # iterable may already contain the answer
                            )
                        else:
                            final_result = res
                    except TypeError:
                        # run(stream=True) not supported -> fall back
                        pass

                if final_result is None and not used_iterable:
                    # Try other common streaming signatures
                    for name in (
                        "run_stream",
                        "stream",
                        "stream_run",
                        "run_with_callback",
                    ):
                        fn = getattr(agent_to_use, name, None)
                        if callable(fn):
                            try:
                                res = fn(task)
                                if hasattr(res, "__iter__") and not isinstance(
                                    res, (str, bytes)
                                ):
                                    used_iterable = True
                                    for it in res:
                                        loop.call_soon_threadsafe(q.put_nowait, it)
                                    final_result = None
                                else:
                                    final_result = res
                                break
                            except TypeError:
                                # maybe callback signature
                                def cb(item):
                                    try:
                                        loop.call_soon_threadsafe(q.put_nowait, item)
                                    except Exception:
                                        pass

                                try:
                                    fn(task, cb)
                                    used_iterable = True
                                    final_result = None
                                    break
                                except Exception:
                                    continue

                if final_result is None and not used_iterable:
                    # Last resort: synchronous run()/generate()/callable
                    if hasattr(agent_to_use, "run") and callable(
                        getattr(agent_to_use, "run")
                    ):
                        final_result = agent_to_use.run(task)
                    elif hasattr(agent_to_use, "generate") and callable(
                        getattr(agent_to_use, "generate")
                    ):
                        final_result = agent_to_use.generate(task)
                    elif callable(agent_to_use):
                        final_result = agent_to_use(task)

        except Exception as e:
            try:
                qwriter.flush()
            except Exception:
                pass
            try:
                loop.call_soon_threadsafe(q.put_nowait, {"__error__": str(e)})
            except Exception:
                pass
        finally:
            try:
                qwriter.flush()
            except Exception:
                pass
            try:
                loop.call_soon_threadsafe(q.put_nowait, {"__final__": final_result})
            except Exception:
                pass
            stop_evt.set()

    # Kick off threads
    mem_thread = threading.Thread(target=poll_memory, daemon=True)
    run_thread = threading.Thread(target=run_agent, daemon=True)
    mem_thread.start()
    run_thread.start()

    # Async consumer
    while True:
        item = await q.get()
        yield item
        if isinstance(item, dict) and "__final__" in item:
            break

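# Example usage (sketch): one way to expose run_agent_stream over SSE. The
# FastAPI app object, route path, and build_agent() factory are illustrative
# assumptions, not part of this module.
#
#     import json
#
#     app = fastapi.FastAPI()
#
#     @app.get("/agent/stream")
#     async def agent_stream(task: str):
#         async def event_source():
#             async for item in run_agent_stream(task, agent_obj=build_agent()):
#                 yield f"data: {json.dumps(item, default=str)}\n\n"
#
#         return fastapi.responses.StreamingResponse(
#             event_source(), media_type="text/event-stream", headers=sse_headers()
#         )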

def _recursively_scrub(obj):
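    """Apply scrub_think_tags to every string nested inside dicts and lists."""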
    if isinstance(obj, str):
        return scrub_think_tags(obj)
    if isinstance(obj, dict):
        return {k: _recursively_scrub(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [_recursively_scrub(v) for v in obj]
    return obj


async def proxy_upstream_chat_completions(
    body: dict, stream: bool, scrub_think: bool = False
):
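    """
    Forward an OpenAI-style chat completions request to the upstream endpoint
    configured via UPSTREAM_OPENAI_BASE, optionally scrubbing think tags from
    the response. Returns a StreamingResponse when stream=True, otherwise a
    JSONResponse mirroring the upstream status code.
    """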
    api_key = os.getenv("OPENAI_API_KEY")
    headers = {"Content-Type": "application/json"}
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"
    upstream_base = os.getenv("UPSTREAM_OPENAI_BASE", "").rstrip("/")
    url = f"{upstream_base}/chat/completions"

    if stream:

        async def proxy_stream():
            async with httpx.AsyncClient(timeout=None) as client:
                async with client.stream(
                    "POST", url, headers=headers, json=body
                ) as resp:
                    resp.raise_for_status()
                    if scrub_think:
                        # Pull text segments, scrub tags, and yield bytes
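                        # Note: a think tag split across two text chunks can
                        # slip through, since each chunk is scrubbed on its own.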
                        async for txt in resp.aiter_text():
                            try:
                                cleaned = scrub_think_tags(txt)
                                yield cleaned.encode("utf-8")
                            except Exception:
                                yield txt.encode("utf-8")
                    else:
                        async for chunk in resp.aiter_bytes():
                            yield chunk

        return fastapi.responses.StreamingResponse(
            proxy_stream(), media_type="text/event-stream", headers=sse_headers()
        )
    else:
        async with httpx.AsyncClient(timeout=None) as client:
            r = await client.post(url, headers=headers, json=body)
            try:
                payload = r.json()
            except Exception:
                payload = {"status_code": r.status_code, "text": r.text}

            if scrub_think:
                try:
                    payload = _recursively_scrub(payload)
                except Exception:
                    pass

            return fastapi.responses.JSONResponse(
                status_code=r.status_code, content=payload
            )
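

# Example usage (sketch): wiring the proxy into an OpenAI-compatible route.
# The FastAPI app object and route path below are illustrative assumptions.
#
#     app = fastapi.FastAPI()
#
#     @app.post("/v1/chat/completions")
#     async def chat_completions(request: fastapi.Request):
#         body = await request.json()
#         return await proxy_upstream_chat_completions(
#             body, stream=bool(body.get("stream")), scrub_think=True
#         )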