# Gradio UI not currently working.
import gradio as gr
from fastapi import FastAPI
from langserve import add_routes
from langgraph.graph import StateGraph, START, END
from typing import Optional, Dict, Any
from typing_extensions import TypedDict
from pydantic import BaseModel
from gradio_client import Client
import uvicorn
import os
from datetime import datetime
import logging
from contextlib import asynccontextmanager
import threading
from langchain_core.runnables import RunnableLambda
from utils import getconfig

config = getconfig("params.cfg")
RETRIEVER = config.get("retriever", "RETRIEVER")
GENERATOR = config.get("generator", "GENERATOR")

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
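# Example params.cfg layout (a sketch -- getconfig is assumed to wrap configparser,
# and the section/option names below mirror the config.get() calls above;
# the placeholder values are made up):
#
#   [retriever]
#   RETRIEVER = <HF Space id or URL of the retrieval service>
#
#   [generator]
#   GENERATOR = <HF Space id or URL of the generation service>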
# Models
class GraphState(TypedDict):
    query: str
    context: str
    result: str
    reports_filter: str
    sources_filter: str
    subtype_filter: str
    year_filter: str
    metadata: Optional[Dict[str, Any]]

class ChatFedInput(TypedDict):
    query: str
    reports_filter: Optional[str]
    sources_filter: Optional[str]
    subtype_filter: Optional[str]
    year_filter: Optional[str]
    session_id: Optional[str]
    user_id: Optional[str]

class ChatFedOutput(TypedDict):
    result: str
    metadata: Dict[str, Any]

class ChatUIInput(BaseModel):
    text: str
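# Example request body for the /chatfed-ui-stream route backed by ChatUIInput
# (a sketch; the exact payload ChatUI sends may differ, which is why
# chatui_adapter below also tolerates a plain dict with a "text" key):
#
#   {"input": {"text": "What did the 2024 annual report say about emissions?"}}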
# Module functions
def retrieve_node(state: GraphState) -> GraphState:
    start_time = datetime.now()
    logger.info(f"Retrieval: {state['query'][:50]}...")
    try:
        client = Client(RETRIEVER)
        context = client.predict(
            query=state["query"],
            reports_filter=state.get("reports_filter", ""),
            sources_filter=state.get("sources_filter", ""),
            subtype_filter=state.get("subtype_filter", ""),
            year_filter=state.get("year_filter", ""),
            api_name="/retrieve"
        )
        duration = (datetime.now() - start_time).total_seconds()
        metadata = state.get("metadata", {})
        metadata.update({
            "retrieval_duration": duration,
            "context_length": len(context) if context else 0,
            "retrieval_success": True
        })
        return {"context": context, "metadata": metadata}
    except Exception as e:
        duration = (datetime.now() - start_time).total_seconds()
        logger.error(f"Retrieval failed: {str(e)}")
        metadata = state.get("metadata", {})
        metadata.update({
            "retrieval_duration": duration,
            "retrieval_success": False,
            "retrieval_error": str(e)
        })
        return {"context": "", "metadata": metadata}

def generate_node(state: GraphState) -> GraphState:
    start_time = datetime.now()
    logger.info(f"Generation: {state['query'][:50]}...")
    try:
        client = Client(GENERATOR)
        result = client.predict(
            query=state["query"],
            context=state["context"],
            api_name="/generate"
        )
        duration = (datetime.now() - start_time).total_seconds()
        metadata = state.get("metadata", {})
        metadata.update({
            "generation_duration": duration,
            "result_length": len(result) if result else 0,
            "generation_success": True
        })
        return {"result": result, "metadata": metadata}
    except Exception as e:
        duration = (datetime.now() - start_time).total_seconds()
        logger.error(f"Generation failed: {str(e)}")
        metadata = state.get("metadata", {})
        metadata.update({
            "generation_duration": duration,
            "generation_success": False,
            "generation_error": str(e)
        })
        return {"result": f"Error: {str(e)}", "metadata": metadata}
# Build and compile the graph: retrieve -> generate
workflow = StateGraph(GraphState)
workflow.add_node("retrieve", retrieve_node)
workflow.add_node("generate", generate_node)
workflow.add_edge(START, "retrieve")
workflow.add_edge("retrieve", "generate")
workflow.add_edge("generate", END)
compiled_graph = workflow.compile()
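# Quick local check of the compiled graph (a sketch; uncomment to run -- it calls
# the live RETRIEVER/GENERATOR Spaces, and the query text here is made up):
#
#   state = compiled_graph.invoke({
#       "query": "example question",
#       "context": "", "result": "",
#       "reports_filter": "", "sources_filter": "",
#       "subtype_filter": "", "year_filter": "",
#       "metadata": {},
#   })
#   print(state["result"])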
def process_query_core(
    query: str,
    reports_filter: str = "",
    sources_filter: str = "",
    subtype_filter: str = "",
    year_filter: str = "",
    session_id: Optional[str] = None,
    user_id: Optional[str] = None,
    return_metadata: bool = False
):
    start_time = datetime.now()
    if not session_id:
        session_id = f"session_{start_time.strftime('%Y%m%d_%H%M%S')}"
    try:
        initial_state = {
            "query": query,
            "context": "",
            "result": "",
            "reports_filter": reports_filter or "",
            "sources_filter": sources_filter or "",
            "subtype_filter": subtype_filter or "",
            "year_filter": year_filter or "",
            "metadata": {
                "session_id": session_id,
                "user_id": user_id,
                "start_time": start_time.isoformat()
            }
        }
        final_state = compiled_graph.invoke(initial_state)
        total_duration = (datetime.now() - start_time).total_seconds()
        final_metadata = final_state.get("metadata", {})
        final_metadata.update({
            "total_duration": total_duration,
            "end_time": datetime.now().isoformat(),
            "pipeline_success": True
        })
        if return_metadata:
            return {"result": final_state["result"], "metadata": final_metadata}
        else:
            return final_state["result"]
    except Exception as e:
        total_duration = (datetime.now() - start_time).total_seconds()
        logger.error(f"Pipeline failed: {str(e)}")
        if return_metadata:
            error_metadata = {
                "session_id": session_id,
                "total_duration": total_duration,
                "pipeline_success": False,
                "error": str(e)
            }
            return {"result": f"Error: {str(e)}", "metadata": error_metadata}
        else:
            return f"Error: {str(e)}"
def process_query_gradio(query: str, reports_filter: str = "", sources_filter: str = "",
                         subtype_filter: str = "", year_filter: str = "") -> str:
    return process_query_core(
        query=query,
        reports_filter=reports_filter,
        sources_filter=sources_filter,
        subtype_filter=subtype_filter,
        year_filter=year_filter,
        session_id=f"gradio_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
        return_metadata=False
    )
def chatui_adapter(data) -> str:
    try:
        # Handle both dict and Pydantic model input
        if hasattr(data, 'text'):
            text = data.text
        elif isinstance(data, dict) and 'text' in data:
            text = data['text']
        else:
            logger.error(f"Unexpected input structure: {data}")
            return "Error: Invalid input format. Expected 'text' field."
        result = process_query_core(
            query=text,
            session_id=f"chatui_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
            return_metadata=False
        )
        return result
    except Exception as e:
        logger.error(f"ChatUI error: {str(e)}")
        return f"Error: {str(e)}"
def process_query_langserve(input_data: ChatFedInput) -> ChatFedOutput:
    result = process_query_core(
        query=input_data["query"],
        reports_filter=input_data.get("reports_filter", ""),
        sources_filter=input_data.get("sources_filter", ""),
        subtype_filter=input_data.get("subtype_filter", ""),
        year_filter=input_data.get("year_filter", ""),
        session_id=input_data.get("session_id"),
        user_id=input_data.get("user_id"),
        return_metadata=True
    )
    return ChatFedOutput(result=result["result"], metadata=result["metadata"])
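# Example call to the LangServe endpoint this function backs (a sketch; LangServe
# wraps the payload in an "input" envelope on the /invoke route, and the query
# and filter values here are made up):
#
#   curl -X POST http://localhost:7860/chatfed/invoke \
#     -H "Content-Type: application/json" \
#     -d '{"input": {"query": "example question", "year_filter": "2024"}}'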
# Not currently working: Hugging Face Spaces exposes only one port, so this
# Gradio UI cannot be served alongside the FastAPI app.
def create_gradio_interface():
    with gr.Blocks(title="ChatFed Orchestrator") as demo:
        gr.Markdown("# ChatFed Orchestrator")
        gr.Markdown("MCP endpoints available at `/gradio_api/mcp/sse`")
        with gr.Row():
            with gr.Column():
                query_input = gr.Textbox(label="Query", lines=2, placeholder="Enter your question...")
                reports_filter_input = gr.Textbox(label="Reports Filter", placeholder="e.g., annual_reports")
                sources_filter_input = gr.Textbox(label="Sources Filter", placeholder="e.g., internal")
                subtype_filter_input = gr.Textbox(label="Subtype Filter", placeholder="e.g., financial")
                year_filter_input = gr.Textbox(label="Year Filter", placeholder="e.g., 2024")
                submit_btn = gr.Button("Submit", variant="primary")
            with gr.Column():
                output = gr.Textbox(label="Response", lines=10)
        submit_btn.click(
            fn=process_query_gradio,
            inputs=[query_input, reports_filter_input, sources_filter_input, subtype_filter_input, year_filter_input],
            outputs=output
        )
    return demo
# FastAPI expects an async context manager here, so the decorator is required
@asynccontextmanager
async def lifespan(app: FastAPI):
    logger.info("ChatFed Orchestrator starting up...")
    yield
    logger.info("Orchestrator shutting down...")
app = FastAPI(
    title="ChatFed Orchestrator",
    version="1.0.0",
    lifespan=lifespan,
    docs_url=None,
    redoc_url=None
)
@app.get("/health")
async def health_check():
    return {"status": "healthy"}
@app.get("/")
async def root():
    return {
        "message": "ChatFed Orchestrator API",
        "endpoints": {
            "health": "/health",
            "chatfed": "/chatfed",
            "chatfed-ui-stream": "/chatfed-ui-stream"
        }
    }
# LangServe routes (these are the main endpoints)
add_routes(
    app,
    RunnableLambda(process_query_langserve),
    path="/chatfed",
    input_type=ChatFedInput,
    output_type=ChatFedOutput
)

add_routes(
    app,
    RunnableLambda(chatui_adapter),
    path="/chatfed-ui-stream",
    input_type=ChatUIInput,
    output_type=str,
    enable_feedback_endpoint=True,
    enable_public_trace_link_endpoint=True,
)
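# Example call for the ChatUI route (a sketch; same LangServe "input" envelope,
# with the single "text" field defined by ChatUIInput above):
#
#   curl -X POST http://localhost:7860/chatfed-ui-stream/invoke \
#     -H "Content-Type: application/json" \
#     -d '{"input": {"text": "example question"}}'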
def run_gradio_server():
    demo = create_gradio_interface()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7861,
        mcp_server=True,
        show_error=True,
        share=False,
        quiet=True
    )
if __name__ == "__main__":
    gradio_thread = threading.Thread(target=run_gradio_server, daemon=True)
    gradio_thread.start()
    logger.info("Gradio MCP server started on port 7861")

    host = os.getenv("HOST", "0.0.0.0")
    port = int(os.getenv("PORT", "7860"))
    logger.info(f"Starting FastAPI server on {host}:{port}")
    uvicorn.run(app, host=host, port=port, log_level="info", access_log=True)
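# Running locally (a sketch; assumes this file is the Space's app.py):
# `python app.py` serves FastAPI on $PORT (default 7860) and the Gradio/MCP
# server on 7861. On Hugging Face Spaces only the FastAPI port is exposed,
# which is why the Gradio UI above is flagged as not working.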