# Author: mtyrrell — "port of test repo" (commit 3d98931)
# NOTE: The Gradio UI is not currently working.
import gradio as gr
from fastapi import FastAPI
from langserve import add_routes
from langgraph.graph import StateGraph, START, END
from typing import Optional, Dict, Any
from typing_extensions import TypedDict
from pydantic import BaseModel
from gradio_client import Client
import uvicorn
import os
from datetime import datetime
import logging
from contextlib import asynccontextmanager
import threading
from langchain_core.runnables import RunnableLambda
from utils import getconfig
# Service endpoints read from params.cfg; both values are passed straight to
# gradio_client.Client(...) by the graph nodes. Presumably `config` behaves
# like a ConfigParser (section, option) lookup — confirm against utils.getconfig.
config = getconfig("params.cfg")
RETRIEVER = config.get("retriever", "RETRIEVER")
GENERATOR = config.get("generator", "GENERATOR")

# Root logging: INFO and above, timestamped.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
# Models
class GraphState(TypedDict):
    # State passed between the LangGraph nodes (retrieve -> generate).
    # process_query_core seeds every key; each node returns a partial update.
    query: str  # the user question
    context: str  # retriever output; "" initially and when retrieval fails
    result: str  # generated answer, or "Error: ..." when generation fails
    reports_filter: str  # retriever filters; process_query_core turns None into ""
    sources_filter: str
    subtype_filter: str
    year_filter: str
    metadata: Optional[Dict[str, Any]]  # ids, timings and per-stage success flags
class _ChatFedInputRequired(TypedDict):
    # Keys every /chatfed request must supply.
    query: str


class ChatFedInput(_ChatFedInputRequired, total=False):
    # Input schema for the /chatfed LangServe route.
    # Only ``query`` is required. With the previous total=True TypedDict every
    # key below was also required by schema validation, although
    # process_query_langserve reads them with .get(...) and falls back to
    # "" / None. Field order (query first) is preserved for the schema.
    reports_filter: Optional[str]
    sources_filter: Optional[str]
    subtype_filter: Optional[str]
    year_filter: Optional[str]
    session_id: Optional[str]
    user_id: Optional[str]
class ChatFedOutput(TypedDict):
    # Output schema for the /chatfed LangServe route.
    result: str  # generated answer, or "Error: ..." on failure
    metadata: Dict[str, Any]  # session/user ids, timings and success flags
class ChatUIInput(BaseModel):
    # Request body for the /chatfed-ui-stream route.
    text: str  # user query, forwarded to process_query_core as `query`
# Module functions
def retrieve_node(state: GraphState) -> GraphState:
    """Fetch context for ``state["query"]`` from the remote retriever.

    Calls the ``/retrieve`` Gradio endpoint of ``RETRIEVER`` with the query and
    the four filter strings from ``state``. A new ``gradio_client.Client`` is
    constructed on every call. Failures are not raised: they are logged and
    reported through the returned metadata.

    Args:
        state: Current graph state. ``query`` is required; missing filter keys
            default to ``""``. ``metadata`` may be missing or ``None``.

    Returns:
        A partial state update: ``context`` (``""`` on failure) and a *copy* of
        ``metadata`` extended with ``retrieval_duration``,
        ``retrieval_success`` and either ``context_length`` or
        ``retrieval_error``.
    """
    start_time = datetime.now()
    logger.info("Retrieval: %s...", state["query"][:50])
    # Copy rather than mutate the incoming dict, and tolerate metadata=None
    # (GraphState declares it Optional; the old code crashed on .update).
    metadata = dict(state.get("metadata") or {})
    try:
        client = Client(RETRIEVER)
        context = client.predict(
            query=state["query"],
            reports_filter=state.get("reports_filter", ""),
            sources_filter=state.get("sources_filter", ""),
            subtype_filter=state.get("subtype_filter", ""),
            year_filter=state.get("year_filter", ""),
            api_name="/retrieve",
        )
        metadata.update({
            "retrieval_duration": (datetime.now() - start_time).total_seconds(),
            "context_length": len(context) if context else 0,
            "retrieval_success": True,
        })
        return {"context": context, "metadata": metadata}
    except Exception as e:
        # Best-effort: generation still runs, with empty context.
        logger.exception("Retrieval failed: %s", e)
        metadata.update({
            "retrieval_duration": (datetime.now() - start_time).total_seconds(),
            "retrieval_success": False,
            "retrieval_error": str(e),
        })
        return {"context": "", "metadata": metadata}
def generate_node(state: GraphState) -> GraphState:
    """Generate an answer for ``state["query"]`` from the retrieved context.

    Calls the ``/generate`` Gradio endpoint of ``GENERATOR`` with the query and
    ``state["context"]``. A new ``gradio_client.Client`` is constructed on
    every call. Failures are not raised: they are logged, and the error text
    is returned as the result.

    Args:
        state: Current graph state; ``query`` and ``context`` are read.
            ``metadata`` may be missing or ``None``.

    Returns:
        A partial state update: ``result`` (``"Error: <message>"`` on failure)
        and a *copy* of ``metadata`` extended with ``generation_duration``,
        ``generation_success`` and either ``result_length`` or
        ``generation_error``.
    """
    start_time = datetime.now()
    logger.info("Generation: %s...", state["query"][:50])
    # Copy rather than mutate the incoming dict, and tolerate metadata=None
    # (GraphState declares it Optional; the old code crashed on .update).
    metadata = dict(state.get("metadata") or {})
    try:
        client = Client(GENERATOR)
        result = client.predict(
            query=state["query"],
            context=state["context"],
            api_name="/generate",
        )
        metadata.update({
            "generation_duration": (datetime.now() - start_time).total_seconds(),
            "result_length": len(result) if result else 0,
            "generation_success": True,
        })
        return {"result": result, "metadata": metadata}
    except Exception as e:
        logger.exception("Generation failed: %s", e)
        metadata.update({
            "generation_duration": (datetime.now() - start_time).total_seconds(),
            "generation_success": False,
            "generation_error": str(e),
        })
        return {"result": f"Error: {str(e)}", "metadata": metadata}
# Build and compile the linear pipeline: START -> retrieve -> generate -> END.
workflow = StateGraph(GraphState)
for _node_name, _node_fn in (("retrieve", retrieve_node), ("generate", generate_node)):
    workflow.add_node(_node_name, _node_fn)
for _src, _dst in ((START, "retrieve"), ("retrieve", "generate"), ("generate", END)):
    workflow.add_edge(_src, _dst)
compiled_graph = workflow.compile()
def process_query_core(
    query: str,
    reports_filter: str = "",
    sources_filter: str = "",
    subtype_filter: str = "",
    year_filter: str = "",
    session_id: Optional[str] = None,
    user_id: Optional[str] = None,
    return_metadata: bool = False,
):
    """Run one query through the compiled retrieve -> generate graph.

    Args:
        query: The user question.
        reports_filter, sources_filter, subtype_filter, year_filter: Filter
            strings forwarded to the retriever; ``None`` becomes ``""``.
        session_id: Recorded in the metadata. When falsy, one is generated
            as ``session_<YYYYmmdd_HHMMSS>`` from the start time.
        user_id: Recorded in the metadata.
        return_metadata: When True, return ``{"result": ..., "metadata": ...}``;
            otherwise return only the result string.

    Returns:
        The generated result (or ``"Error: <message>"`` if the graph raised),
        or a ``{"result", "metadata"}`` dict when ``return_metadata`` is True.
        ``pipeline_success`` only records whether the graph invocation raised;
        the nodes report their own failures via ``retrieval_success`` and
        ``generation_success``.
    """
    start_time = datetime.now()
    if not session_id:
        session_id = f"session_{start_time.strftime('%Y%m%d_%H%M%S')}"

    initial_state = {
        "query": query,
        "context": "",
        "result": "",
        "reports_filter": reports_filter or "",
        "sources_filter": sources_filter or "",
        "subtype_filter": subtype_filter or "",
        "year_filter": year_filter or "",
        "metadata": {
            "session_id": session_id,
            "user_id": user_id,
            "start_time": start_time.isoformat(),
        },
    }

    try:
        final_state = compiled_graph.invoke(initial_state)
        result = final_state["result"]
    except Exception as e:
        # Take one timestamp so total_duration and end_time agree.
        end_time = datetime.now()
        logger.exception("Pipeline failed: %s", e)
        if not return_metadata:
            return f"Error: {str(e)}"
        error_metadata = {
            "session_id": session_id,
            "user_id": user_id,
            "start_time": start_time.isoformat(),
            "end_time": end_time.isoformat(),
            "total_duration": (end_time - start_time).total_seconds(),
            "pipeline_success": False,
            "error": str(e),
        }
        return {"result": f"Error: {str(e)}", "metadata": error_metadata}

    if not return_metadata:
        return result

    end_time = datetime.now()
    final_metadata = dict(final_state.get("metadata") or {})
    final_metadata.update({
        "total_duration": (end_time - start_time).total_seconds(),
        "end_time": end_time.isoformat(),
        "pipeline_success": True,
    })
    return {"result": result, "metadata": final_metadata}
def process_query_gradio(query: str, reports_filter: str = "", sources_filter: str = "",
                         subtype_filter: str = "", year_filter: str = "") -> str:
    """Gradio click handler: run the pipeline and return only the result text.

    Each call is tagged with a ``gradio_<YYYYmmdd_HHMMSS>`` session id.
    """
    gradio_session = "gradio_" + datetime.now().strftime("%Y%m%d_%H%M%S")
    return process_query_core(
        query, reports_filter, sources_filter, subtype_filter, year_filter,
        session_id=gradio_session,
        return_metadata=False,
    )
def chatui_adapter(data) -> str:
    """Adapt a chat-UI payload to ``process_query_core`` and return the text result.

    ``data`` may be any object with a ``text`` attribute (e.g. ``ChatUIInput``)
    or a dict containing a ``"text"`` key. Anything else yields an
    ``"Error: Invalid input format..."`` string. Exceptions raised while
    processing are logged and returned as ``"Error: <message>"``.
    """
    try:
        has_text_attr = hasattr(data, 'text')
        is_text_dict = (not has_text_attr) and isinstance(data, dict) and 'text' in data
        if not (has_text_attr or is_text_dict):
            logger.error(f"Unexpected input structure: {data}")
            return "Error: Invalid input format. Expected 'text' field."

        user_text = data.text if has_text_attr else data['text']
        chatui_session = "chatui_" + datetime.now().strftime('%Y%m%d_%H%M%S')
        return process_query_core(
            query=user_text,
            session_id=chatui_session,
            return_metadata=False,
        )
    except Exception as exc:
        logger.error(f"ChatUI error: {str(exc)}")
        return f"Error: {str(exc)}"
def process_query_langserve(input_data: ChatFedInput) -> ChatFedOutput:
    """LangServe handler for /chatfed: run the pipeline and return result + metadata.

    ``query`` is required (KeyError if absent). Missing filter keys default to
    ``""``, and a missing session_id/user_id becomes ``None``.
    """
    query = input_data["query"]
    filter_kwargs = {
        key: input_data.get(key, "")
        for key in ("reports_filter", "sources_filter", "subtype_filter", "year_filter")
    }
    response = process_query_core(
        query=query,
        session_id=input_data.get("session_id"),
        user_id=input_data.get("user_id"),
        return_metadata=True,
        **filter_kwargs,
    )
    return ChatFedOutput(result=response["result"], metadata=response["metadata"])
# NOTE: This UI does not currently work on Hugging Face, which allows only one open port at a time; this server needs a second one (7861) alongside the FastAPI app.
def create_gradio_interface():
    """Build the Gradio Blocks page for the orchestrator.

    Layout: a query box, four filter boxes and a Submit button in the left
    column, and a response box in the right column. Submit calls
    ``process_query_gradio`` with the query followed by the four filters.
    """
    # (label, placeholder) for each filter textbox, in argument order.
    filter_fields = [
        ("Reports Filter", "e.g., annual_reports"),
        ("Sources Filter", "e.g., internal"),
        ("Subtype Filter", "e.g., financial"),
        ("Year Filter", "e.g., 2024"),
    ]
    with gr.Blocks(title="ChatFed Orchestrator") as demo:
        gr.Markdown("# ChatFed Orchestrator")
        gr.Markdown("MCP endpoints available at `/gradio_api/mcp/sse`")
        with gr.Row():
            with gr.Column():
                question_box = gr.Textbox(label="Query", lines=2, placeholder="Enter your question...")
                filter_boxes = [
                    gr.Textbox(label=label, placeholder=hint)
                    for label, hint in filter_fields
                ]
                send_button = gr.Button("Submit", variant="primary")
            with gr.Column():
                answer_box = gr.Textbox(label="Response", lines=10)
        send_button.click(
            fn=process_query_gradio,
            inputs=[question_box, *filter_boxes],
            outputs=answer_box,
        )
    return demo
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook that only logs startup and shutdown.

    Code before ``yield`` runs at startup and code after it at shutdown.
    No resources are created or released here.
    """
    on_start = "ChatFed Orchestrator starting up..."
    on_stop = "Orchestrator shutting down..."
    logger.info(on_start)
    yield
    logger.info(on_stop)
# The ASGI application. Interactive docs (/docs, /redoc) are disabled;
# startup/shutdown logging comes from `lifespan`.
app = FastAPI(
    title="ChatFed Orchestrator",
    version="1.0.0",
    docs_url=None,
    redoc_url=None,
    lifespan=lifespan,
)
@app.get("/health")
async def health_check():
    """Liveness probe: always reports ``{"status": "healthy"}``."""
    payload = {"status": "healthy"}
    return payload
@app.get("/")
async def root():
    """Describe the API and list its main endpoint paths."""
    endpoint_paths = {
        "health": "/health",
        "chatfed": "/chatfed",
        "chatfed-ui-stream": "/chatfed-ui-stream",
    }
    return {"message": "ChatFed Orchestrator API", "endpoints": endpoint_paths}
# LangServe routes (these are the main endpoints).
# /chatfed: structured input (query + optional filters/ids) -> result + metadata.
_chatfed_runnable = RunnableLambda(process_query_langserve)
add_routes(
    app,
    _chatfed_runnable,
    path="/chatfed",
    input_type=ChatFedInput,
    output_type=ChatFedOutput,
)

# /chatfed-ui-stream: {"text": ...} -> plain result string, with the feedback
# and public-trace-link endpoints enabled.
_chatui_runnable = RunnableLambda(chatui_adapter)
add_routes(
    app,
    _chatui_runnable,
    path="/chatfed-ui-stream",
    input_type=ChatUIInput,
    output_type=str,
    enable_feedback_endpoint=True,
    enable_public_trace_link_endpoint=True,
)
def run_gradio_server():
    """Build the Gradio UI and serve it on 0.0.0.0:7861 with the MCP server enabled.

    The ``__main__`` entry point runs this in a daemon thread next to uvicorn.
    """
    launch_options = dict(
        server_name="0.0.0.0",
        server_port=7861,
        mcp_server=True,
        show_error=True,
        share=False,
        quiet=True,
    )
    create_gradio_interface().launch(**launch_options)
if __name__ == "__main__":
    # The Gradio MCP UI runs on its own port (7861) in a daemon thread, so it
    # never keeps the process alive after uvicorn exits. Known not to work on
    # Hugging Face, which only allows a single open port.
    gradio_thread = threading.Thread(target=run_gradio_server, name="gradio-mcp", daemon=True)
    gradio_thread.start()
    # Only the thread has started at this point; the server binds
    # asynchronously, so don't claim it is already up.
    logger.info("Starting Gradio MCP server on port 7861")

    host = os.getenv("HOST", "0.0.0.0")
    port = int(os.getenv("PORT", "7860"))
    logger.info(f"Starting FastAPI server on {host}:{port}")
    uvicorn.run(app, host=host, port=port, log_level="info", access_log=True)