# Gradio UI not currently working.
"""ChatFed Orchestrator.

Builds a two-node LangGraph pipeline (``retrieve`` -> ``generate``) whose
nodes call remote services through ``gradio_client.Client``, and exposes the
pipeline through LangServe routes mounted on a FastAPI application
(``/chatfed`` and ``/chatfed-ui-stream``). A Gradio Blocks UI is also defined
and launched on a background thread when the module is run as a script.
"""
import logging
import os
import threading
import time
import uuid
from contextlib import asynccontextmanager
from datetime import datetime
from typing import Any, Dict, Optional

import gradio as gr
import uvicorn
from fastapi import FastAPI
from gradio_client import Client
from langchain_core.runnables import RunnableLambda
from langgraph.graph import StateGraph, START, END
from langserve import add_routes
from pydantic import BaseModel
from typing_extensions import TypedDict

from utils import getconfig

# NOTE(review): getconfig is a project helper; the (section, option) call
# shape below looks like configparser -- confirm.
config = getconfig("params.cfg")

# Targets handed to gradio_client.Client for the two pipeline stages.
RETRIEVER = config.get("retriever", "RETRIEVER")
GENERATOR = config.get("generator", "GENERATOR")

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)


# Models
class GraphState(TypedDict):
    """State passed between the LangGraph nodes."""

    query: str            # user question forwarded to both services
    context: str          # output of the retrieve node ("" on failure)
    result: str           # output of the generate node
    reports_filter: str   # the four *_filter values are forwarded verbatim
    sources_filter: str   # to the retriever's "/retrieve" endpoint
    subtype_filter: str
    year_filter: str
    metadata: Optional[Dict[str, Any]]  # timing / success bookkeeping


class ChatFedInput(TypedDict):
    """Input payload of the ``/chatfed`` LangServe route."""

    # NOTE(review): the values are Optional, but the keys themselves are
    # still required by this TypedDict. process_query_langserve reads them
    # with .get(), so presumably they were meant to be omittable -- confirm.
    query: str
    reports_filter: Optional[str]
    sources_filter: Optional[str]
    subtype_filter: Optional[str]
    year_filter: Optional[str]
    session_id: Optional[str]
    user_id: Optional[str]


class ChatFedOutput(TypedDict):
    """Output payload of the ``/chatfed`` LangServe route."""

    result: str
    metadata: Dict[str, Any]


class ChatUIInput(BaseModel):
    """Input payload of the ``/chatfed-ui-stream`` LangServe route."""

    text: str


# Internal helpers
def _new_session_id(prefix: str) -> str:
    """Return ``<prefix>_<YYYYmmdd_HHMMSS>_<8 hex chars>``.

    The timestamp alone only has one-second resolution, so two requests
    arriving within the same second would share an id; the random suffix
    keeps generated ids distinct.
    """
    stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    return f"{prefix}_{stamp}_{uuid.uuid4().hex[:8]}"


def _metadata_with(state, **updates) -> Dict[str, Any]:
    """Return a copy of ``state["metadata"]`` extended with ``updates``.

    A missing or ``None`` metadata entry is treated as empty, and the input
    mapping is never mutated.
    """
    merged = dict(state.get("metadata") or {})
    merged.update(updates)
    return merged


# Module functions
def retrieve_node(state: GraphState) -> GraphState:
    """Call the retriever's ``/retrieve`` endpoint for the current query.

    Returns a partial state update containing ``context`` and ``metadata``.
    Failures are not raised: on any exception ``context`` is ``""`` and the
    metadata carries ``retrieval_success=False`` and ``retrieval_error``.
    """
    started = time.perf_counter()
    logger.info("Retrieval: %s...", state['query'][:50])

    try:
        client = Client(RETRIEVER)
        context = client.predict(
            query=state["query"],
            reports_filter=state.get("reports_filter", ""),
            sources_filter=state.get("sources_filter", ""),
            subtype_filter=state.get("subtype_filter", ""),
            year_filter=state.get("year_filter", ""),
            api_name="/retrieve",
        )
        metadata = _metadata_with(
            state,
            retrieval_duration=time.perf_counter() - started,
            context_length=len(context) if context else 0,
            retrieval_success=True,
        )
        return {"context": context, "metadata": metadata}
    except Exception as exc:
        # Boundary handler: record the failure in metadata so the generate
        # node still runs (with an empty context).
        logger.exception("Retrieval failed: %s", exc)
        metadata = _metadata_with(
            state,
            retrieval_duration=time.perf_counter() - started,
            retrieval_success=False,
            retrieval_error=str(exc),
        )
        return {"context": "", "metadata": metadata}


def generate_node(state: GraphState) -> GraphState:
    """Call the generator's ``/generate`` endpoint with query and context.

    Returns a partial state update containing ``result`` and ``metadata``.
    Failures are not raised: on any exception ``result`` is
    ``"Error: <message>"`` and the metadata carries
    ``generation_success=False`` and ``generation_error``.
    """
    started = time.perf_counter()
    logger.info("Generation: %s...", state['query'][:50])

    try:
        client = Client(GENERATOR)
        result = client.predict(
            query=state["query"],
            context=state["context"],
            api_name="/generate",
        )
        metadata = _metadata_with(
            state,
            generation_duration=time.perf_counter() - started,
            result_length=len(result) if result else 0,
            generation_success=True,
        )
        return {"result": result, "metadata": metadata}
    except Exception as exc:
        logger.exception("Generation failed: %s", exc)
        metadata = _metadata_with(
            state,
            generation_duration=time.perf_counter() - started,
            generation_success=False,
            generation_error=str(exc),
        )
        return {"result": f"Error: {str(exc)}", "metadata": metadata}


# start the graph
workflow = StateGraph(GraphState)
workflow.add_node("retrieve", retrieve_node)
workflow.add_node("generate", generate_node)
workflow.add_edge(START, "retrieve")
workflow.add_edge("retrieve", "generate")
workflow.add_edge("generate", END)
compiled_graph = workflow.compile()


def process_query_core(
    query: str,
    reports_filter: str = "",
    sources_filter: str = "",
    subtype_filter: str = "",
    year_filter: str = "",
    session_id: Optional[str] = None,
    user_id: Optional[str] = None,
    return_metadata: bool = False
):
    """Run the retrieve -> generate graph for one query.

    Args:
        query: The user question.
        reports_filter, sources_filter, subtype_filter, year_filter:
            Filter values forwarded to the retriever; ``None`` is
            normalised to ``""``.
        session_id: Caller-supplied id; one is generated when falsy.
        user_id: Stored in the metadata as-is.
        return_metadata: When true, return a dict with ``result`` and
            ``metadata``; otherwise return only the result string.

    Returns:
        The result string, or ``{"result": ..., "metadata": ...}``.
        Exceptions are not propagated: a failure yields
        ``"Error: <message>"`` as the result.

    Note:
        ``pipeline_success`` is ``True`` whenever the graph finished without
        raising. The nodes swallow their own errors, so inspect
        ``retrieval_success`` / ``generation_success`` for per-stage status.
    """
    start_time = datetime.now()
    started = time.perf_counter()
    if not session_id:
        session_id = _new_session_id("session")

    try:
        initial_state = {
            "query": query,
            "context": "",
            "result": "",
            "reports_filter": reports_filter or "",
            "sources_filter": sources_filter or "",
            "subtype_filter": subtype_filter or "",
            "year_filter": year_filter or "",
            "metadata": {
                "session_id": session_id,
                "user_id": user_id,
                "start_time": start_time.isoformat(),
            },
        }

        final_state = compiled_graph.invoke(initial_state)

        final_metadata = _metadata_with(
            final_state,
            total_duration=time.perf_counter() - started,
            end_time=datetime.now().isoformat(),
            pipeline_success=True,
        )

        if return_metadata:
            return {"result": final_state["result"], "metadata": final_metadata}
        return final_state["result"]

    except Exception as exc:
        logger.exception("Pipeline failed: %s", exc)
        if not return_metadata:
            return f"Error: {str(exc)}"

        error_metadata = {
            "session_id": session_id,
            "user_id": user_id,
            "total_duration": time.perf_counter() - started,
            "pipeline_success": False,
            "error": str(exc),
        }
        return {"result": f"Error: {str(exc)}", "metadata": error_metadata}


def process_query_gradio(query: str, reports_filter: str = "", sources_filter: str = "",
                         subtype_filter: str = "", year_filter: str = "") -> str:
    """Gradio callback: run the pipeline and return only the result text."""
    return process_query_core(
        query=query,
        reports_filter=reports_filter,
        sources_filter=sources_filter,
        subtype_filter=subtype_filter,
        year_filter=year_filter,
        session_id=_new_session_id("gradio"),
        return_metadata=False,
    )


def chatui_adapter(data) -> str:
    """Adapt a ChatUI payload (object or dict with ``text``) to the pipeline.

    Returns the pipeline result string, or an ``"Error: ..."`` string when
    the payload has no ``text`` field or processing raises.
    """
    try:
        # Handle both dict and Pydantic model input
        if hasattr(data, 'text'):
            text = data.text
        elif isinstance(data, dict) and 'text' in data:
            text = data['text']
        else:
            logger.error("Unexpected input structure: %s", data)
            return "Error: Invalid input format. Expected 'text' field."

        return process_query_core(
            query=text,
            session_id=_new_session_id("chatui"),
            return_metadata=False,
        )
    except Exception as exc:
        logger.exception("ChatUI error: %s", exc)
        return f"Error: {str(exc)}"


def process_query_langserve(input_data: ChatFedInput) -> ChatFedOutput:
    """LangServe handler for ``/chatfed``: run the pipeline with metadata."""
    outcome = process_query_core(
        query=input_data["query"],
        reports_filter=input_data.get("reports_filter", ""),
        sources_filter=input_data.get("sources_filter", ""),
        subtype_filter=input_data.get("subtype_filter", ""),
        year_filter=input_data.get("year_filter", ""),
        session_id=input_data.get("session_id"),
        user_id=input_data.get("user_id"),
        return_metadata=True,
    )
    return ChatFedOutput(result=outcome["result"], metadata=outcome["metadata"])


# This is not working currently. It is problematic because HF doesn't allow
# more than one port to be open at the same time.
def create_gradio_interface():
    """Build the Gradio Blocks UI wired to ``process_query_gradio``."""
    with gr.Blocks(title="ChatFed Orchestrator") as demo:
        gr.Markdown("# ChatFed Orchestrator")
        gr.Markdown("MCP endpoints available at `/gradio_api/mcp/sse`")

        with gr.Row():
            with gr.Column():
                query_input = gr.Textbox(label="Query", lines=2, placeholder="Enter your question...")
                reports_filter_input = gr.Textbox(label="Reports Filter", placeholder="e.g., annual_reports")
                sources_filter_input = gr.Textbox(label="Sources Filter", placeholder="e.g., internal")
                subtype_filter_input = gr.Textbox(label="Subtype Filter", placeholder="e.g., financial")
                year_filter_input = gr.Textbox(label="Year Filter", placeholder="e.g., 2024")
                submit_btn = gr.Button("Submit", variant="primary")

            with gr.Column():
                output = gr.Textbox(label="Response", lines=10)

        submit_btn.click(
            fn=process_query_gradio,
            inputs=[
                query_input,
                reports_filter_input,
                sources_filter_input,
                subtype_filter_input,
                year_filter_input,
            ],
            outputs=output,
        )

    return demo


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Log application startup and shutdown."""
    logger.info("ChatFed Orchestrator starting up...")
    yield
    logger.info("Orchestrator shutting down...")


app = FastAPI(
    title="ChatFed Orchestrator",
    version="1.0.0",
    lifespan=lifespan,
    docs_url=None,
    redoc_url=None,
)


@app.get("/health")
async def health_check():
    """Return a static health payload."""
    return {"status": "healthy"}


@app.get("/")
async def root():
    """Describe the API and list its main endpoints."""
    return {
        "message": "ChatFed Orchestrator API",
        "endpoints": {
            "health": "/health",
            "chatfed": "/chatfed",
            "chatfed-ui-stream": "/chatfed-ui-stream",
        },
    }


# LangServe routes (these are the main endpoints)
add_routes(
    app,
    RunnableLambda(process_query_langserve),
    path="/chatfed",
    input_type=ChatFedInput,
    output_type=ChatFedOutput,
)

add_routes(
    app,
    RunnableLambda(chatui_adapter),
    path="/chatfed-ui-stream",
    input_type=ChatUIInput,
    output_type=str,
    enable_feedback_endpoint=True,
    enable_public_trace_link_endpoint=True,
)


def run_gradio_server(port: int = 7861):
    """Launch the Gradio UI (with ``mcp_server=True``) on ``port``.

    Intended to run on a background thread; a launch failure is logged
    rather than left to the default thread exception hook.
    """
    try:
        demo = create_gradio_interface()
        demo.launch(
            server_name="0.0.0.0",
            server_port=port,
            mcp_server=True,
            show_error=True,
            share=False,
            quiet=True,
        )
    except Exception as exc:
        logger.exception("Gradio server failed: %s", exc)


if __name__ == "__main__":
    # Daemon thread so the Gradio server does not keep the process alive
    # after uvicorn exits.
    gradio_thread = threading.Thread(target=run_gradio_server, daemon=True)
    gradio_thread.start()
    logger.info("Gradio MCP server started on port 7861")

    host = os.getenv("HOST", "0.0.0.0")
    port = int(os.getenv("PORT", "7860"))

    logger.info("Starting FastAPI server on %s:%s", host, port)

    uvicorn.run(app, host=host, port=port, log_level="info", access_log=True)