cidadao.ai-backend / src /api /routes /websocket_chat.py
anderson-ufrj
refactor(performance): replace all json imports with json_utils
9730fbc
"""
WebSocket endpoints for real-time bidirectional chat communication.
This module provides WebSocket connections for:
- Real-time chat with agents
- Live investigation status updates
- Anomaly detection notifications
- Multi-user collaboration
"""
import asyncio
from datetime import datetime, timezone
from typing import Dict, List, Set, Optional, Any
from uuid import uuid4

from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends, HTTPException, Query
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel, Field

from src.core import json_utils
from src.core import get_logger
from src.core.exceptions import AuthenticationError
from src.services.chat_service import ChatService
from src.api.dependencies import get_current_optional_user
router = APIRouter()
logger = get_logger(__name__)

# Optional HTTP bearer scheme. With auto_error=False a missing or non-Bearer
# Authorization header resolves to None instead of an automatic error response.
# NOTE(review): not referenced by the WebSocket routes in this module, which
# read their token from a ``token`` query parameter instead.
security = HTTPBearer(auto_error=False)
class WebSocketMessage(BaseModel):
    """Envelope for every frame exchanged over the chat and investigation sockets.

    Incoming client frames are parsed into this model, and outgoing frames are
    produced from it with ``model_dump(mode='json')``.
    """

    # Message kind, e.g. "chat", "subscribe", "ping", "heartbeat".
    type: str = Field(..., description="Message type")
    # Type-specific payload; empty dict when omitted.
    data: Dict[str, Any] = Field(default_factory=dict, description="Message data")
    # Creation time as a timezone-aware UTC datetime, so the serialized value
    # carries an explicit UTC designator. datetime.utcnow() is deprecated since
    # Python 3.12 and returns a naive value that clients may read as local time.
    timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    # Unique message id (random UUID4 rendered as a string).
    id: str = Field(default_factory=lambda: str(uuid4()))
class ConnectionManager:
    """Registry of live WebSocket connections with session and investigation fan-out.

    One ``session_id`` may own several sockets at the same time; session
    broadcasts go to all of them. Sessions can also subscribe to investigation
    ids and then receive every message broadcast for that investigation.

    Sends are awaited, so other tasks may connect or disconnect sockets while a
    broadcast is in progress. The broadcast methods therefore iterate over
    snapshots, and ``disconnect`` tolerates being called twice for one socket.
    """

    def __init__(self):
        # session_id -> sockets currently registered for that session
        self.active_connections: Dict[str, List[WebSocket]] = {}
        # investigation_id -> session_ids subscribed to its updates
        self.investigation_subscriptions: Dict[str, Set[str]] = {}
        # id(websocket) -> metadata passed to connect()
        self.connection_metadata: Dict[int, Dict[str, Any]] = {}

    async def connect(
        self,
        websocket: WebSocket,
        session_id: str,
        metadata: Optional[Dict[str, Any]] = None,
    ):
        """Accept ``websocket``, register it under ``session_id`` and send a welcome frame.

        Args:
            websocket: The incoming connection; it is accepted here.
            session_id: Session the socket belongs to; other sockets may
                already be registered under it.
            metadata: Optional per-socket data (the chat route passes the
                user); stored as an empty dict when omitted or falsy.
        """
        await websocket.accept()
        connections = self.active_connections.setdefault(session_id, [])
        connections.append(websocket)
        self.connection_metadata[id(websocket)] = metadata or {}
        logger.info(f"WebSocket connected: session={session_id}, total_connections={len(connections)}")
        await self.send_personal_message(
            WebSocketMessage(
                type="connection",
                data={
                    "status": "connected",
                    "session_id": session_id,
                    "message": "Conectado ao Cidadão.AI em tempo real"
                }
            ),
            websocket
        )

    async def disconnect(self, websocket: WebSocket, session_id: str):
        """Unregister ``websocket`` from ``session_id``.

        Idempotent: calling it again for a socket that was already removed
        (e.g. after a failed broadcast and then the endpoint's own cleanup)
        is a no-op apart from the log line. The session's investigation
        subscriptions are dropped only once its last socket is gone, so the
        remaining sockets of that session keep receiving updates.
        """
        connections = self.active_connections.get(session_id)
        if connections is not None:
            if websocket in connections:
                connections.remove(websocket)
            if not connections:
                del self.active_connections[session_id]

        self.connection_metadata.pop(id(websocket), None)

        if session_id not in self.active_connections:
            # Iterate over a copy of the keys: unsubscribing deletes entries
            # whose subscriber set becomes empty.
            for investigation_id in list(self.investigation_subscriptions):
                self.unsubscribe_from_investigation(session_id, investigation_id)

        logger.info(f"WebSocket disconnected: session={session_id}")

    async def send_personal_message(self, message: WebSocketMessage, websocket: WebSocket):
        """Send ``message`` to a single socket; send errors are logged, not raised."""
        try:
            await websocket.send_json(message.model_dump(mode='json'))
        except Exception as e:
            logger.error(f"Error sending WebSocket message: {e}")

    async def broadcast_to_session(self, message: WebSocketMessage, session_id: str):
        """Send ``message`` to every socket of ``session_id``.

        Sockets whose send fails are unregistered via ``disconnect`` after the
        loop. Unknown sessions are ignored.
        """
        connections = self.active_connections.get(session_id)
        if not connections:
            return

        failed: List[WebSocket] = []
        # Snapshot: every await yields to the event loop, and other tasks may
        # add or remove sockets of this session before the loop finishes.
        for websocket in list(connections):
            try:
                await websocket.send_json(message.model_dump(mode='json'))
            except Exception as e:
                logger.error(f"Error broadcasting to session {session_id}: {e}")
                failed.append(websocket)

        for websocket in failed:
            await self.disconnect(websocket, session_id)

    async def broadcast_to_investigation(self, message: WebSocketMessage, investigation_id: str):
        """Send ``message`` to every session subscribed to ``investigation_id``."""
        # Snapshot: a failed send disconnects the session, which removes it
        # from this very set; mutating a set while iterating it raises
        # RuntimeError.
        subscribers = list(self.investigation_subscriptions.get(investigation_id, ()))
        for session_id in subscribers:
            await self.broadcast_to_session(message, session_id)

    def subscribe_to_investigation(self, session_id: str, investigation_id: str):
        """Subscribe ``session_id`` to broadcasts for ``investigation_id``."""
        self.investigation_subscriptions.setdefault(investigation_id, set()).add(session_id)
        logger.info(f"Session {session_id} subscribed to investigation {investigation_id}")

    def unsubscribe_from_investigation(self, session_id: str, investigation_id: str):
        """Remove ``session_id`` from ``investigation_id``; drop the entry once empty."""
        subscribers = self.investigation_subscriptions.get(investigation_id)
        if subscribers is None:
            return
        subscribers.discard(session_id)
        if not subscribers:
            del self.investigation_subscriptions[investigation_id]
# Module-level registry shared by both WebSocket routes and the notify_* helpers.
manager: ConnectionManager = ConnectionManager()
async def _stream_chat_reply(chat_service: ChatService, session_id: str, user_message: str):
    """Stream the assistant's reply to ``user_message`` to every socket of the session.

    Sends a ``typing`` indicator, then one ``chat`` frame per chunk yielded by
    the chat service, then ``chat_complete``. Errors while producing the reply
    are logged and reported to the session as an ``error`` frame, so the
    connection itself stays open.
    """
    await manager.broadcast_to_session(
        WebSocketMessage(
            type="typing",
            data={"agent": "processing"}
        ),
        session_id
    )
    try:
        response = await chat_service.process_message(
            message=user_message,
            session_id=session_id,
            stream=True
        )
        # assumes process_message(stream=True) returns an async iterator of
        # dict-like chunks - TODO confirm against ChatService
        async for chunk in response:
            await manager.broadcast_to_session(
                WebSocketMessage(
                    type="chat",
                    data={
                        "role": "assistant",
                        "content": chunk.get("content", ""),
                        "agent_id": chunk.get("agent_id"),
                        "agent_name": chunk.get("agent_name"),
                        "chunk": True
                    }
                ),
                session_id
            )
        await manager.broadcast_to_session(
            WebSocketMessage(
                type="chat_complete",
                data={"status": "completed"}
            ),
            session_id
        )
    except Exception as e:
        logger.error(f"Error processing chat message: {e}")
        await manager.broadcast_to_session(
            WebSocketMessage(
                type="error",
                data={"message": "Erro ao processar mensagem"}
            ),
            session_id
        )


@router.websocket("/ws/chat/{session_id}")
async def websocket_chat_endpoint(
    websocket: WebSocket,
    session_id: str,
    token: Optional[str] = Query(None)
):
    """
    WebSocket endpoint for real-time chat.

    Features:
    - Bidirectional communication with agents
    - Real-time streaming responses
    - Investigation status updates
    - Anomaly notifications

    Message Types:
    - chat: User messages and agent responses
    - status: Investigation status updates
    - notification: Anomaly alerts and notifications
    - subscribe/unsubscribe: Investigation subscriptions
    - ping/pong: Keep-alive messages

    Malformed client frames (invalid JSON, a non-object payload, or a message
    failing WebSocketMessage validation) are answered with an ``error`` frame
    and the connection stays open. Unknown message types are ignored.
    """
    # Optional authentication.
    user = None
    if token:
        try:
            # TODO(security): placeholder - the token is NOT validated; any
            # non-empty value maps to the same fixed user. Replace with real
            # JWT validation before relying on ``user`` for access control.
            user = {"id": "user-123", "name": "User"}
        except Exception:
            await websocket.close(code=1008, reason="Invalid authentication")
            return

    await manager.connect(websocket, session_id, {"user": user})
    chat_service = ChatService()

    try:
        while True:
            try:
                data = await websocket.receive_json()
                message = WebSocketMessage(**data)
            except (ValueError, TypeError) as e:
                # JSONDecodeError and pydantic's ValidationError are ValueErrors;
                # a non-object payload makes ``**data`` raise TypeError. Report
                # the bad frame instead of dropping the whole connection.
                logger.warning(f"Invalid WebSocket message: session={session_id}, error={e}")
                await manager.send_personal_message(
                    WebSocketMessage(
                        type="error",
                        data={"message": "Mensagem inválida"}
                    ),
                    websocket
                )
                continue

            logger.info(f"WebSocket message received: type={message.type}, session={session_id}")

            if message.type == "chat":
                await _stream_chat_reply(
                    chat_service,
                    session_id,
                    message.data.get("message", "")
                )

            elif message.type == "subscribe":
                investigation_id = message.data.get("investigation_id")
                if investigation_id:
                    manager.subscribe_to_investigation(session_id, investigation_id)
                    await manager.send_personal_message(
                        WebSocketMessage(
                            type="subscribed",
                            data={
                                "investigation_id": investigation_id,
                                "message": f"Inscrito para atualizações da investigação {investigation_id}"
                            }
                        ),
                        websocket
                    )

            elif message.type == "unsubscribe":
                investigation_id = message.data.get("investigation_id")
                if investigation_id:
                    manager.unsubscribe_from_investigation(session_id, investigation_id)
                    await manager.send_personal_message(
                        WebSocketMessage(
                            type="unsubscribed",
                            data={"investigation_id": investigation_id}
                        ),
                        websocket
                    )

            elif message.type == "ping":
                # Keep-alive ping.
                await manager.send_personal_message(
                    WebSocketMessage(type="pong", data={}),
                    websocket
                )
    except WebSocketDisconnect:
        pass
    except Exception as e:
        logger.error(f"WebSocket error: {e}")
    finally:
        # Runs on normal disconnect, unexpected errors and task cancellation
        # alike, so the socket never stays registered in the manager.
        await manager.disconnect(websocket, session_id)
# Seconds between heartbeat frames on investigation sockets.
_INVESTIGATION_HEARTBEAT_SECONDS = 30


@router.websocket("/ws/investigations/{investigation_id}")
async def websocket_investigation_endpoint(
    websocket: WebSocket,
    investigation_id: str,
    token: Optional[str] = Query(None)
):
    """
    WebSocket endpoint for investigation-specific updates.

    Receives real-time updates for:
    - Investigation progress
    - Anomaly detections
    - Agent findings
    - Report generation status

    Each connection gets its own unique session subscribed to
    ``investigation_id``; updates arrive through the manager's investigation
    broadcasts, and a ``heartbeat`` frame is sent every
    ``_INVESTIGATION_HEARTBEAT_SECONDS`` seconds.
    """
    # NOTE(review): ``token`` is accepted but never checked on this route.
    session_id = f"investigation-{investigation_id}-{uuid4()}"
    await manager.connect(websocket, session_id)
    manager.subscribe_to_investigation(session_id, investigation_id)

    try:
        while True:
            await asyncio.sleep(_INVESTIGATION_HEARTBEAT_SECONDS)
            # Send directly instead of via manager.send_personal_message, which
            # swallows send errors: this loop never receives, so a failing
            # heartbeat is its only signal that the client is gone. Swallowing
            # it would keep this coroutine and its registration alive forever.
            await websocket.send_json(
                WebSocketMessage(
                    type="heartbeat",
                    data={"investigation_id": investigation_id}
                ).model_dump(mode='json')
            )
    except WebSocketDisconnect:
        pass
    except Exception as e:
        logger.error(f"Investigation WebSocket error: {e}")
    finally:
        await manager.disconnect(websocket, session_id)
# Helper functions for sending notifications from other parts of the app
async def notify_investigation_update(investigation_id: str, update_type: str, data: Dict[str, Any]):
    """Broadcast an ``investigation_<update_type>`` event to the investigation's subscribers.

    Args:
        investigation_id: Investigation whose subscribed sessions are notified.
        update_type: Suffix of the message type; also echoed in the payload.
        data: Extra payload fields, merged last, so they can override
            ``investigation_id`` and ``update_type``.
    """
    payload = {"investigation_id": investigation_id, "update_type": update_type, **data}
    event = WebSocketMessage(type=f"investigation_{update_type}", data=payload)
    await manager.broadcast_to_investigation(event, investigation_id)
async def notify_anomaly_detected(investigation_id: str, anomaly_data: Dict[str, Any]):
    """Broadcast an ``anomaly_detected`` event to the investigation's subscribers.

    ``severity`` defaults to ``"medium"`` and ``description`` to ``None`` when
    missing from ``anomaly_data``. The complete ``anomaly_data`` is also
    forwarded under ``details``.
    """
    severity = anomaly_data.get("severity", "medium")
    description = anomaly_data.get("description")
    alert = WebSocketMessage(
        type="anomaly_detected",
        data={
            "investigation_id": investigation_id,
            "severity": severity,
            "description": description,
            "details": anomaly_data,
        },
    )
    await manager.broadcast_to_investigation(alert, investigation_id)
async def notify_chat_session(session_id: str, notification: Dict[str, Any]):
    """Push ``notification`` as a ``notification`` frame to every socket of ``session_id``."""
    outgoing = WebSocketMessage(type="notification", data=notification)
    await manager.broadcast_to_session(outgoing, session_id)