cidadao.ai-backend / src /api /websocket.py
anderson-ufrj
refactor(performance): replace all json imports with json_utils
9730fbc
raw
history blame
12.7 kB
"""
WebSocket manager for real-time communication in Cidadão.AI
Handles investigation streaming, analysis updates, and notifications
"""
from src.core import json_utils
import asyncio
import logging
from typing import Dict, List, Set, Optional
from datetime import datetime
from fastapi import WebSocket, WebSocketDisconnect
from pydantic import BaseModel
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class WebSocketMessage(BaseModel):
    """Standard WebSocket message format.

    Every message sent through the connection manager is an instance of
    this model, serialized to JSON before being written to the socket.
    """

    # Message discriminator (e.g. "ping", "error", "system_notification").
    type: str
    # Message payload; its keys depend on the message type.
    data: dict
    # Creation time. Filled in with the current naive UTC time when the
    # caller does not provide one.
    timestamp: Optional[datetime] = None
    # Target user, when the message is addressed to a specific user.
    user_id: Optional[str] = None

    def __init__(self, **data):
        """Build a message, defaulting ``timestamp`` to the current UTC time.

        Args:
            **data: Field values for the model. A missing or ``None``
                ``timestamp`` is replaced with ``datetime.utcnow()``.
        """
        # Treat an explicit None like a missing value so a message never
        # goes out without a timestamp.
        if data.get('timestamp') is None:
            data['timestamp'] = datetime.utcnow()
        super().__init__(**data)
class ConnectionManager:
    """Manages WebSocket connections and message broadcasting.

    Connections are indexed several ways at once: by user ID, by the
    investigation and analysis IDs they subscribed to, and in a global
    notification set. Any socket whose send fails is removed from every
    index via ``disconnect``.
    """

    def __init__(self):
        """Create a manager with no registered connections."""
        # Active connections by user ID
        self.user_connections: Dict[str, Set[WebSocket]] = {}
        # Connections by investigation ID
        self.investigation_connections: Dict[str, Set[WebSocket]] = {}
        # Connections by analysis ID
        self.analysis_connections: Dict[str, Set[WebSocket]] = {}
        # Global notification connections
        self.notification_connections: Set[WebSocket] = set()
        # Connection metadata
        self.connection_metadata: Dict[WebSocket, dict] = {}

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _serialize(message: WebSocketMessage) -> Optional[str]:
        """Return the JSON text for ``message``, or ``None`` if it cannot be encoded.

        A payload that cannot be serialized is a problem with the message,
        not with any connection, so it is logged here and must not cause
        sockets to be dropped.
        """
        try:
            return message.json()
        except Exception as e:
            logger.error(f"Failed to serialize WebSocket message: {e}")
            return None

    @staticmethod
    def _drop_subscriptions(subscriptions: Dict[str, Set[WebSocket]], websocket: WebSocket) -> None:
        """Remove ``websocket`` from every set in ``subscriptions``, pruning empty sets."""
        # Collect the keys first: entries are deleted while cleaning up.
        subscribed_keys = [key for key, sockets in subscriptions.items() if websocket in sockets]
        for key in subscribed_keys:
            sockets = subscriptions[key]
            sockets.discard(websocket)
            if not sockets:
                del subscriptions[key]

    async def _fan_out(self, sockets: Set[WebSocket], message: WebSocketMessage, failure_prefix: str) -> None:
        """Send ``message`` to each socket in ``sockets`` and drop the ones that fail.

        Args:
            sockets: Recipients. A snapshot is iterated, so the set may be
                modified while sends are awaited.
            message: Message to deliver; serialized once for all recipients.
            failure_prefix: Text logged in front of the exception when a
                send fails.
        """
        payload = self._serialize(message)
        if payload is None:
            return

        failed = set()
        for websocket in list(sockets):
            try:
                await websocket.send_text(payload)
            except Exception as e:
                logger.error(f"{failure_prefix}: {e}")
                failed.add(websocket)

        # Clean up disconnected sockets
        for websocket in failed:
            self.disconnect(websocket)

    # ------------------------------------------------------------------
    # Connection lifecycle
    # ------------------------------------------------------------------

    async def connect(self, websocket: WebSocket, user_id: str, connection_type: str = "general"):
        """Accept a new WebSocket connection and register it.

        The socket is recorded in the metadata table, the user's
        connection set and the global notification set, then sent a
        ``connection_established`` message.

        Args:
            websocket: Connection to accept.
            user_id: ID of the user that owns the connection.
            connection_type: Free-form label stored in the metadata.
        """
        await websocket.accept()

        # Store connection metadata
        now = datetime.utcnow()
        self.connection_metadata[websocket] = {
            'user_id': user_id,
            'connection_type': connection_type,
            'connected_at': now,
            'last_ping': now,
        }

        # Add to user connections
        self.user_connections.setdefault(user_id, set()).add(websocket)

        # Add to notification connections
        self.notification_connections.add(websocket)

        logger.info(f"WebSocket connected: user_id={user_id}, type={connection_type}")

        # Send welcome message
        await self.send_personal_message(websocket, WebSocketMessage(
            type="connection_established",
            data={
                "message": "WebSocket connection established",
                "user_id": user_id,
                "connection_type": connection_type,
            },
        ))

    def disconnect(self, websocket: WebSocket):
        """Remove a WebSocket from every index kept by the manager.

        Safe to call for a socket that is unknown or already removed.
        Subscription sets are cleaned even when the socket has no metadata
        entry, because the subscribe methods do not require a prior
        ``connect``; otherwise a dead socket would stay subscribed forever.

        Args:
            websocket: Connection to forget.
        """
        metadata = self.connection_metadata.pop(websocket, None)

        if metadata is not None:
            user_id = metadata['user_id']
            user_sockets = self.user_connections.get(user_id)
            if user_sockets is not None:
                user_sockets.discard(websocket)
                if not user_sockets:
                    del self.user_connections[user_id]

        self.notification_connections.discard(websocket)

        # Remove from investigation/analysis connections, dropping entries
        # that become empty so the statistics stay accurate.
        self._drop_subscriptions(self.investigation_connections, websocket)
        self._drop_subscriptions(self.analysis_connections, websocket)

        if metadata is not None:
            logger.info(f"WebSocket disconnected: user_id={metadata['user_id']}")

    # ------------------------------------------------------------------
    # Sending
    # ------------------------------------------------------------------

    async def send_personal_message(self, websocket: WebSocket, message: WebSocketMessage):
        """Send a message to one specific WebSocket connection.

        A send failure is logged and the socket is disconnected; the
        exception is not propagated.
        """
        payload = self._serialize(message)
        if payload is None:
            return

        try:
            await websocket.send_text(payload)
        except Exception as e:
            logger.error(f"Failed to send message to WebSocket: {e}")
            self.disconnect(websocket)

    async def send_to_user(self, user_id: str, message: WebSocketMessage):
        """Send a message to all connections of a specific user.

        Does nothing when the user has no registered connection. Note that
        ``message.user_id`` is overwritten with ``user_id`` before sending.
        """
        sockets = self.user_connections.get(user_id)
        if sockets is None:
            return

        message.user_id = user_id
        await self._fan_out(sockets, message, f"Failed to send message to user {user_id}")

    async def broadcast_to_all(self, message: WebSocketMessage):
        """Broadcast a message to every connection in the notification set."""
        await self._fan_out(self.notification_connections, message, "Failed to broadcast message")

    # ------------------------------------------------------------------
    # Investigation subscriptions
    # ------------------------------------------------------------------

    async def subscribe_to_investigation(self, websocket: WebSocket, investigation_id: str):
        """Subscribe a WebSocket to investigation updates and confirm it to the client."""
        self.investigation_connections.setdefault(investigation_id, set()).add(websocket)

        await self.send_personal_message(websocket, WebSocketMessage(
            type="subscribed_to_investigation",
            data={
                "investigation_id": investigation_id,
                "message": f"Subscribed to investigation {investigation_id}",
            },
        ))

    async def unsubscribe_from_investigation(self, websocket: WebSocket, investigation_id: str):
        """Unsubscribe a WebSocket from investigation updates.

        The investigation entry is removed once its last subscriber leaves.
        """
        subscribers = self.investigation_connections.get(investigation_id)
        if subscribers is None:
            return

        subscribers.discard(websocket)
        if not subscribers:
            del self.investigation_connections[investigation_id]

    async def send_to_investigation(self, investigation_id: str, message: WebSocketMessage):
        """Send a message to all subscribers of an investigation."""
        subscribers = self.investigation_connections.get(investigation_id)
        if subscribers is None:
            return

        await self._fan_out(subscribers, message, "Failed to send investigation update")

    # ------------------------------------------------------------------
    # Analysis subscriptions
    # ------------------------------------------------------------------

    async def subscribe_to_analysis(self, websocket: WebSocket, analysis_id: str):
        """Subscribe a WebSocket to analysis updates and confirm it to the client."""
        self.analysis_connections.setdefault(analysis_id, set()).add(websocket)

        await self.send_personal_message(websocket, WebSocketMessage(
            type="subscribed_to_analysis",
            data={
                "analysis_id": analysis_id,
                "message": f"Subscribed to analysis {analysis_id}",
            },
        ))

    async def send_to_analysis(self, analysis_id: str, message: WebSocketMessage):
        """Send a message to all subscribers of an analysis."""
        subscribers = self.analysis_connections.get(analysis_id)
        if subscribers is None:
            return

        await self._fan_out(subscribers, message, "Failed to send analysis update")

    # ------------------------------------------------------------------
    # System-wide operations
    # ------------------------------------------------------------------

    async def send_system_notification(self, notification_type: str, data: dict):
        """Broadcast a ``system_notification`` message to all connections.

        Args:
            notification_type: Stored under the ``notification_type`` key.
            data: Extra payload keys merged into the message data; a
                ``notification_type`` key here overrides the argument.
        """
        message = WebSocketMessage(
            type="system_notification",
            data={
                "notification_type": notification_type,
                **data,
            },
        )
        await self.broadcast_to_all(message)

    def get_connection_stats(self) -> dict:
        """Return the current size of each connection index."""
        return {
            "total_connections": len(self.connection_metadata),
            "users_connected": len(self.user_connections),
            "active_investigations": len(self.investigation_connections),
            "active_analyses": len(self.analysis_connections),
            "notification_subscribers": len(self.notification_connections),
        }

    async def ping_all_connections(self):
        """Send a ``ping`` message to every registered connection.

        ``last_ping`` is refreshed for each socket that accepts the ping;
        sockets whose send fails are logged and disconnected.
        """
        ping_message = WebSocketMessage(
            type="ping",
            data={"timestamp": datetime.utcnow().isoformat()},
        )
        payload = self._serialize(ping_message)
        if payload is None:
            return

        failed = set()
        for websocket in list(self.connection_metadata):
            try:
                await websocket.send_text(payload)
            except Exception as e:
                logger.warning(f"Ping failed, dropping WebSocket connection: {e}")
                failed.add(websocket)
                continue

            # The socket may have been disconnected while the send was
            # awaited, so look the metadata up again instead of indexing.
            metadata = self.connection_metadata.get(websocket)
            if metadata is not None:
                metadata['last_ping'] = datetime.utcnow()

        # Clean up disconnected sockets
        for websocket in failed:
            self.disconnect(websocket)
# Global connection manager instance: created once at import time and used
# by the module-level handler and the maintenance task defined below.
connection_manager = ConnectionManager()
class WebSocketHandler:
    """Handles WebSocket message processing.

    Incoming client messages are dispatched on their ``type`` field to the
    matching operation of the connection manager.
    """

    def __init__(self, connection_manager: ConnectionManager):
        """Bind the handler to the manager that owns the connections."""
        self.connection_manager = connection_manager

    async def handle_message(self, websocket: WebSocket, message: dict):
        """Process one incoming WebSocket message.

        Recognized types are ``subscribe_investigation``,
        ``unsubscribe_investigation``, ``subscribe_analysis`` and ``pong``;
        anything else is logged as unknown. Any exception raised while
        processing is logged and reported back to the client as an
        ``error`` message instead of being propagated.

        Args:
            websocket: Connection the message arrived on.
            message: Decoded message, expected to be a mapping with a
                ``type`` key and an optional ``data`` mapping.
        """
        manager = self.connection_manager
        # Defined before the try block so the error reply can always refer to it.
        message_type = None

        try:
            # Field extraction happens inside the try block: a payload that
            # is valid JSON but not an object (list, string, number) must
            # produce an error reply rather than an uncaught AttributeError.
            message_type = message.get('type')
            data = message.get('data', {})

            if message_type == "subscribe_investigation":
                investigation_id = data.get('investigation_id')
                if investigation_id:
                    await manager.subscribe_to_investigation(websocket, investigation_id)
                else:
                    logger.warning(f"Ignoring {message_type} without investigation_id")

            elif message_type == "unsubscribe_investigation":
                investigation_id = data.get('investigation_id')
                if investigation_id:
                    await manager.unsubscribe_from_investigation(websocket, investigation_id)
                else:
                    logger.warning(f"Ignoring {message_type} without investigation_id")

            elif message_type == "subscribe_analysis":
                analysis_id = data.get('analysis_id')
                if analysis_id:
                    await manager.subscribe_to_analysis(websocket, analysis_id)
                else:
                    logger.warning(f"Ignoring {message_type} without analysis_id")

            elif message_type == "pong":
                # Handle pong response: refresh the liveness timestamp of a
                # connection that is still registered.
                metadata = manager.connection_metadata.get(websocket)
                if metadata is not None:
                    metadata['last_ping'] = datetime.utcnow()

            else:
                logger.warning(f"Unknown WebSocket message type: {message_type}")

        except Exception as e:
            logger.error(f"Error handling WebSocket message: {e}")

            error_message = WebSocketMessage(
                type="error",
                data={
                    "message": f"Failed to process message: {str(e)}",
                    "original_type": message_type,
                },
            )
            await manager.send_personal_message(websocket, error_message)
# Global WebSocket handler, bound to the module-level connection_manager
websocket_handler = WebSocketHandler(connection_manager)
# Background task for connection maintenance
async def connection_maintenance_task(ping_interval: float = 30, error_backoff: float = 60):
    """Background task to maintain WebSocket connections.

    Loops forever, pinging every connection registered with the
    module-level ``connection_manager`` and then sleeping. It only stops
    when the task is cancelled.

    Args:
        ping_interval: Seconds to wait between ping rounds.
        error_backoff: Seconds to wait after a round that raised.

    Raises:
        asyncio.CancelledError: Re-raised so the task can be cancelled.
    """
    while True:
        try:
            await connection_manager.ping_all_connections()
            await asyncio.sleep(ping_interval)
        except asyncio.CancelledError:
            # On Python < 3.8 CancelledError derives from Exception; without
            # this clause the handler below would swallow cancellation and
            # the loop could never be stopped.
            raise
        except Exception as e:
            # logger.exception records the traceback along with the message.
            logger.exception(f"Error in connection maintenance: {e}")
            await asyncio.sleep(error_backoff)  # Wait longer on error