mtyrrell committed
Commit d357a83 · Parent: 944dc09

refactored approach
Files changed (1)
  app/main.py  +128 -84
app/main.py CHANGED
@@ -1,6 +1,3 @@
-"""
-Complete ChatFed Orchestrator with flexible input handling for ChatUI compatibility
-"""
 from fastapi import FastAPI, Request
 from fastapi.responses import StreamingResponse
 import json, uuid
@@ -11,6 +8,7 @@ from typing import List, Literal, Optional, Dict, Any
 import gradio as gr
 from datetime import datetime
 import logging
+import asyncio
 
 # Set up logging
 logging.basicConfig(level=logging.INFO)
@@ -131,6 +129,43 @@ def extract_user_query_fallback(data: Any) -> str:
         logger.error(f"Error extracting query from {data}: {e}")
         return ""
 
+# ─────────────────────────────────────────
+# Streaming Generators
+# ─────────────────────────────────────────
+
+async def generate_streaming_response(query: str):
+    """Generate streaming response for ChatUI compatibility"""
+    try:
+        # Process the query
+        result = process_chatfed_query_core(query)
+        full_text = result.result
+
+        # Emit metadata event
+        metadata = {
+            "run_id": str(uuid.uuid4()),
+            "timestamp": datetime.now().isoformat(),
+            **result.metadata
+        }
+        yield f"event: metadata\ndata: {json.dumps(metadata)}\n\n"
+
+        # Stream tokens one by one
+        for i, char in enumerate(full_text):
+            # Ensure each token is a string
+            token_data = str(char)
+            yield f"event: data\ndata: {json.dumps(token_data)}\n\n"
+
+            # Small delay to simulate realistic streaming
+            await asyncio.sleep(0.01)
+
+        # End event
+        yield f"event: end\ndata: {json.dumps('[DONE]')}\n\n"
+
+    except Exception as e:
+        logger.error(f"Error in streaming response: {e}")
+        error_data = {"error": str(e)}
+        yield f"event: error\ndata: {json.dumps(error_data)}\n\n"
+        yield f"event: end\ndata: {json.dumps('[ERROR]')}\n\n"
+
 # ─────────────────────────────────────────
 # Handlers - Multiple for Different Use Cases
 # ─────────────────────────────────────────
@@ -329,8 +364,6 @@ def create_gradio_interface():
 
     return demo
 
-
-
 # ─────────────────────────────────────────
 # App Startup
 # ─────────────────────────────────────────
@@ -341,30 +374,30 @@ app = FastAPI(
     version="1.0.0"
 )
 
-# # Add request logging middleware for debugging
-# @app.middleware("http")
-# async def log_requests(request: Request, call_next):
-#     """Log incoming requests for debugging"""
-#     if request.url.path.startswith("/chatfed") or request.url.path.startswith("/debug"):
-#         try:
-#             body = await request.body()
-#             logger.info(f"=== REQUEST DEBUG ===")
-#             logger.info(f"Path: {request.url.path}")
-#             logger.info(f"Method: {request.method}")
-#             logger.info(f"Headers: {dict(request.headers)}")
-#             logger.info(f"Body: {body.decode('utf-8') if body else 'Empty'}")
-
-#             # Recreate request for next handler
-#             async def receive():
-#                 return {"type": "http.request", "body": body}
-
-#             request._receive = receive
-
-#         except Exception as e:
-#             logger.error(f"Error logging request: {e}")
-
-#     response = await call_next(request)
-#     return response
+# Add request logging middleware for debugging
+@app.middleware("http")
+async def log_requests(request: Request, call_next):
+    """Log incoming requests for debugging"""
+    if request.url.path.startswith("/chatfed") or request.url.path.startswith("/debug"):
+        try:
+            body = await request.body()
+            logger.info(f"=== REQUEST DEBUG ===")
+            logger.info(f"Path: {request.url.path}")
+            logger.info(f"Method: {request.method}")
+            logger.info(f"Headers: {dict(request.headers)}")
+            logger.info(f"Body: {body.decode('utf-8') if body else 'Empty'}")
+
+            # Recreate request for next handler
+            async def receive():
+                return {"type": "http.request", "body": body}
+
+            request._receive = receive
+
+        except Exception as e:
+            logger.error(f"Error logging request: {e}")
+
+    response = await call_next(request)
+    return response
 
 # ─────────────────────────────────────────
 # LangServe Routes - Flexible input handling
@@ -375,7 +408,6 @@ add_routes(
     app,
     RunnableLambda(flexible_handler),
     path="/chatfed",
-    # Remove strict input type to allow both dicts and Pydantic models
     output_type=ChatFedOutput
 )
 
@@ -384,7 +416,6 @@ add_routes(
     app,
     RunnableLambda(chatui_handler),
     path="/chatfed-chatui",
-    # Remove strict input type to allow both dicts and Pydantic models
     output_type=ChatFedOutput
 )
 
@@ -393,16 +424,78 @@ add_routes(
     app,
     RunnableLambda(legacy_langserve_handler),
     path="/chatfed-strict",
-    # Remove strict input type to allow both dicts and Pydantic models
     output_type=ChatFedOutput
 )
 
-# ChatUI-compatible streaming route (yields tokens via SSE)
-add_routes(
-    app,
-    RunnableLambda(flexible_handler),
-    path="/chatfed-ui-stream"
-)
+# ─────────────────────────────────────────
+# Custom Streaming Endpoint for ChatUI
+# ─────────────────────────────────────────
+
+@app.post("/chatfed-ui-stream/stream")
+async def chatui_stream_endpoint(request: Request) -> StreamingResponse:
+    """
+    Proper streaming endpoint for ChatUI's langserve integration.
+    Returns Server-Sent Events with individual tokens.
+    """
+    try:
+        # Get the request payload
+        payload = await request.json()
+        logger.info(f"Stream endpoint received: {payload}")
+
+        # Handle ChatUI's envelope format
+        if isinstance(payload.get("input"), dict):
+            input_data = payload["input"]
+        else:
+            input_data = payload
+
+        # Extract query using flexible approach
+        try:
+            flexible_input = FlexibleChatInput(**input_data)
+            query = flexible_input.extract_query()
+        except Exception as e:
+            logger.warning(f"Failed to parse as FlexibleChatInput: {e}")
+            query = extract_user_query_fallback(input_data)
+
+        if not query.strip():
+            # Return error stream
+            async def error_stream():
+                yield f"event: error\ndata: {json.dumps({'error': 'No valid query found'})}\n\n"
+                yield f"event: end\ndata: {json.dumps('[ERROR]')}\n\n"
+
+            return StreamingResponse(
+                error_stream(),
+                media_type="text/event-stream",
+                headers={
+                    "Cache-Control": "no-cache",
+                    "Connection": "keep-alive",
+                }
+            )
+
+        # Return successful stream
+        return StreamingResponse(
+            generate_streaming_response(query),
+            media_type="text/event-stream",
+            headers={
+                "Cache-Control": "no-cache",
+                "Connection": "keep-alive",
+            }
+        )
+
+    except Exception as e:
+        logger.error(f"Error in stream endpoint: {e}")
+
+        async def error_stream():
+            yield f"event: error\ndata: {json.dumps({'error': str(e)})}\n\n"
+            yield f"event: end\ndata: {json.dumps('[ERROR]')}\n\n"
+
+        return StreamingResponse(
+            error_stream(),
+            media_type="text/event-stream",
+            headers={
+                "Cache-Control": "no-cache",
+                "Connection": "keep-alive",
+            }
+        )
 
 # ─────────────────────────────────────────
 # Additional Endpoints
@@ -423,6 +516,7 @@ async def root():
         "primary": "/chatfed (flexible input - use this for ChatUI)",
         "chatui": "/chatfed-chatui",
         "legacy": "/chatfed-strict (requires 'query' field)",
+        "streaming": "/chatfed-ui-stream/stream (proper SSE streaming for ChatUI)",
         "openai": "/v1/chat/completions",
         "simple": "/simple-chat",
         "gradio_ui": "/ui",
@@ -546,56 +640,6 @@ async def debug_input_endpoint(request: Request):
         "error": str(e),
         "raw_body": raw_body.decode('utf-8') if 'raw_body' in locals() else "failed_to_read"
     }
-
-# @app.post("/chatfed-ui-stream/stream")
-# async def ui_stream_sse(request: Request) -> StreamingResponse:
-#     """
-#     Stream-friendly handler for ChatUI's langserve integration.
-#     Emits proper SSE: metadata → data (one char/token per event) → end.
-#     """
-#     try:
-#         payload = await request.json()
-#         # 1) Unwrap ChatUI envelope if needed
-#         if isinstance(payload.get("input"), dict):
-#             payload = payload["input"]
-
-#         # 2) Extract query (flexible or fallback)
-#         try:
-#             input_data = FlexibleChatInput(**payload)
-#             query = input_data.extract_query()
-#         except Exception:
-#             query = extract_user_query_fallback(payload)
-
-#         if not query.strip():
-#             # send a single "end" event with an error
-#             async def no_q():
-#                 yield 'event: end\ndata: ["error","No valid query found"]\n\n'
-#             return StreamingResponse(no_q(), media_type="text/event-stream")
-
-#         # 3) Generate full text
-#         full_text = process_chatfed_query_core(query).result
-
-#         # 4) Build SSE generator
-#         async def event_gen():
-#             # metadata
-#             meta = {"run_id": str(uuid.uuid4())}
-#             yield f"event: metadata\ndata: {json.dumps(meta)}\n\n"
-
-#             # one data event per character/token
-#             for ch in full_text:
-#                 yield f"event: data\ndata: {json.dumps(ch)}\n\n"
-
-#             # end
-#             # yield "event: end\ndata: [DONE]\n\n"
-#             yield f"event: end\ndata: {json.dumps('[DONE]')}\n\n"
-
-#         return StreamingResponse(event_gen(), media_type="text/event-stream")
-
-#     except Exception as e:
-#         logger.exception(f"ui_stream_sse error: {e}")
-#         async def crash():
-#             yield f'event: end\ndata: ["error","{str(e)}"]\n\n'
-#         return StreamingResponse(crash(), media_type="text/event-stream")
 
 # Mount Gradio at a specific path
 demo = create_gradio_interface()
 
 
 
 
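Note: a successful request to the new /chatfed-ui-stream/stream endpoint yields an SSE stream shaped like the trace below. The values are illustrative placeholders; the metadata event also carries whatever keys result.metadata contributes, and each data event holds exactly one JSON-encoded character:

event: metadata
data: {"run_id": "<uuid4>", "timestamp": "<iso-8601 timestamp>"}

event: data
data: "H"

event: data
data: "i"

event: end
data: "[DONE]"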
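For testing, here is a minimal client sketch (not part of this commit) that consumes the stream. It assumes the app is served at http://localhost:8000 and that FlexibleChatInput accepts a query field inside ChatUI's input envelope (both are assumptions, as is the stream_chatfed helper name):

# Hypothetical test client for /chatfed-ui-stream/stream -- not part of this commit.
# Requires `pip install httpx`; the server URL and payload shape are assumptions.
import json
import httpx

def stream_chatfed(query: str, base_url: str = "http://localhost:8000") -> None:
    payload = {"input": {"query": query}}  # ChatUI-style envelope; field name assumed
    with httpx.stream("POST", f"{base_url}/chatfed-ui-stream/stream",
                      json=payload, timeout=None) as response:
        event = None
        for line in response.iter_lines():
            if line.startswith("event: "):
                event = line[len("event: "):]
            elif line.startswith("data: "):
                data = json.loads(line[len("data: "):])
                if event == "metadata":
                    print(f"[run {data.get('run_id')}]")
                elif event == "data":
                    print(data, end="", flush=True)  # one character per SSE event
                elif event in ("end", "error"):
                    print()
                    break

if __name__ == "__main__":
    stream_chatfed("What does ChatFed do?")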