| """ |
| WebSocket manager for real-time communication in Cidadão.AI |
| Handles investigation streaming, analysis updates, and notifications |
| """ |
|
|
| from src.core import json_utils |
| import asyncio |
| import logging |
| from typing import Dict, List, Set, Optional |
| from datetime import datetime |
| from fastapi import WebSocket, WebSocketDisconnect |
| from pydantic import BaseModel |
|
|
# Module-level logger named after this module; used by every class and task below.
logger = logging.getLogger(__name__)
|
|
class WebSocketMessage(BaseModel):
    """Standard envelope for every message sent over the WebSocket.

    Fields:
        type: Message kind, e.g. ``"ping"``, ``"error"`` or
            ``"system_notification"``.
        data: Message-specific payload.
        timestamp: Creation time. Filled with ``datetime.utcnow()`` (naive UTC)
            when omitted or passed as ``None``.
        user_id: Optional recipient user id (set by the connection manager when
            sending to a specific user).
    """
    type: str
    data: dict
    timestamp: Optional[datetime] = None
    user_id: Optional[str] = None

    def __init__(self, **data):
        # Stamp the message at construction time. An explicit ``None`` is
        # treated like a missing value so every message carries a timestamp.
        if data.get('timestamp') is None:
            data['timestamp'] = datetime.utcnow()
        super().__init__(**data)
|
|
class ConnectionManager:
    """Track WebSocket connections and route messages to them.

    Connections are indexed by user, by investigation subscription and by
    analysis subscription. Every accepted connection also joins the broadcast
    set used for system-wide notifications, and per-connection metadata is
    kept in ``connection_metadata``.

    Any socket whose send raises is treated as dead and removed from every
    index via :meth:`disconnect`.
    """

    def __init__(self):
        # user_id -> all live sockets belonging to that user
        self.user_connections: Dict[str, Set[WebSocket]] = {}

        # investigation_id -> sockets subscribed to that investigation
        self.investigation_connections: Dict[str, Set[WebSocket]] = {}

        # analysis_id -> sockets subscribed to that analysis
        self.analysis_connections: Dict[str, Set[WebSocket]] = {}

        # every connected socket; target of broadcast_to_all()
        self.notification_connections: Set[WebSocket] = set()

        # socket -> {'user_id', 'connection_type', 'connected_at', 'last_ping'}
        self.connection_metadata: Dict[WebSocket, dict] = {}

    @staticmethod
    def _discard_subscriber(registry: Dict[str, Set[WebSocket]], topic_id: str, websocket: WebSocket) -> None:
        """Remove ``websocket`` from ``registry[topic_id]``; drop the entry once it is empty."""
        subscribers = registry.get(topic_id)
        if subscribers is None:
            return
        subscribers.discard(websocket)
        if not subscribers:
            # Pruning keeps the registry from growing without bound and keeps
            # get_connection_stats() counting only topics with live subscribers.
            del registry[topic_id]

    async def _send_payload(self, websockets: Set[WebSocket], payload: str, error_context: str) -> None:
        """Send an already-serialized ``payload`` to every socket in ``websockets``.

        Failed sends are logged as ``"<error_context>: <exception>"`` and the
        affected sockets are disconnected after the loop.
        """
        failed = set()
        # Iterate a snapshot: disconnect() and concurrent coroutines may mutate
        # the live set while we are suspended in send_text().
        for websocket in list(websockets):
            try:
                await websocket.send_text(payload)
            except Exception as e:
                logger.error(f"{error_context}: {e}")
                failed.add(websocket)

        for websocket in failed:
            self.disconnect(websocket)

    async def connect(self, websocket: WebSocket, user_id: str, connection_type: str = "general"):
        """Accept ``websocket``, register it for ``user_id`` and send a greeting.

        Args:
            websocket: The incoming connection.
            user_id: Owner of the connection.
            connection_type: Free-form label stored in the metadata and echoed
                back in the ``connection_established`` message.
        """
        await websocket.accept()

        now = datetime.utcnow()
        self.connection_metadata[websocket] = {
            'user_id': user_id,
            'connection_type': connection_type,
            'connected_at': now,
            'last_ping': now,
        }

        self.user_connections.setdefault(user_id, set()).add(websocket)

        # Every connection receives system-wide notifications.
        self.notification_connections.add(websocket)

        logger.info(f"WebSocket connected: user_id={user_id}, type={connection_type}")

        await self.send_personal_message(websocket, WebSocketMessage(
            type="connection_established",
            data={
                "message": "WebSocket connection established",
                "user_id": user_id,
                "connection_type": connection_type
            }
        ))

    def disconnect(self, websocket: WebSocket):
        """Remove ``websocket`` from every index.

        Unknown sockets, including ones already disconnected, are ignored, so
        calling this more than once is harmless. Subscription entries left
        empty are deleted.
        """
        metadata = self.connection_metadata.pop(websocket, None)
        if metadata is None:
            return

        user_id = metadata['user_id']

        user_sockets = self.user_connections.get(user_id)
        if user_sockets is not None:
            user_sockets.discard(websocket)
            if not user_sockets:
                del self.user_connections[user_id]

        self.notification_connections.discard(websocket)

        for registry in (self.investigation_connections, self.analysis_connections):
            for topic_id in list(registry):
                self._discard_subscriber(registry, topic_id, websocket)

        logger.info(f"WebSocket disconnected: user_id={user_id}")

    async def send_personal_message(self, websocket: WebSocket, message: WebSocketMessage):
        """Send ``message`` to a single connection; disconnect it if the send fails."""
        try:
            await websocket.send_text(message.json())
        except Exception as e:
            logger.error(f"Failed to send message to WebSocket: {e}")
            self.disconnect(websocket)

    async def send_to_user(self, user_id: str, message: WebSocketMessage):
        """Send ``message`` to every connection of ``user_id``.

        Sets ``message.user_id`` to ``user_id`` before sending. Does nothing if
        the user has no open connection.
        """
        if user_id not in self.user_connections:
            return

        message.user_id = user_id
        # Serialize immediately, with no await in between. Another coroutine
        # sending the same message object to a different user could otherwise
        # overwrite user_id while this loop is suspended.
        payload = message.json()

        await self._send_payload(
            self.user_connections[user_id],
            payload,
            f"Failed to send message to user {user_id}",
        )

    async def broadcast_to_all(self, message: WebSocketMessage):
        """Send ``message`` to every connected socket."""
        await self._send_payload(
            self.notification_connections,
            message.json(),
            "Failed to broadcast message",
        )

    async def subscribe_to_investigation(self, websocket: WebSocket, investigation_id: str):
        """Subscribe ``websocket`` to updates for ``investigation_id`` and confirm it."""
        self.investigation_connections.setdefault(investigation_id, set()).add(websocket)

        await self.send_personal_message(websocket, WebSocketMessage(
            type="subscribed_to_investigation",
            data={
                "investigation_id": investigation_id,
                "message": f"Subscribed to investigation {investigation_id}"
            }
        ))

    async def unsubscribe_from_investigation(self, websocket: WebSocket, investigation_id: str):
        """Stop sending ``investigation_id`` updates to ``websocket`` (no-op if not subscribed)."""
        self._discard_subscriber(self.investigation_connections, investigation_id, websocket)

    async def send_to_investigation(self, investigation_id: str, message: WebSocketMessage):
        """Send ``message`` to all subscribers of ``investigation_id``."""
        if investigation_id not in self.investigation_connections:
            return

        await self._send_payload(
            self.investigation_connections[investigation_id],
            message.json(),
            "Failed to send investigation update",
        )

    async def subscribe_to_analysis(self, websocket: WebSocket, analysis_id: str):
        """Subscribe ``websocket`` to updates for ``analysis_id`` and confirm it."""
        self.analysis_connections.setdefault(analysis_id, set()).add(websocket)

        await self.send_personal_message(websocket, WebSocketMessage(
            type="subscribed_to_analysis",
            data={
                "analysis_id": analysis_id,
                "message": f"Subscribed to analysis {analysis_id}"
            }
        ))

    async def unsubscribe_from_analysis(self, websocket: WebSocket, analysis_id: str):
        """Stop sending ``analysis_id`` updates to ``websocket`` (no-op if not subscribed)."""
        self._discard_subscriber(self.analysis_connections, analysis_id, websocket)

    async def send_to_analysis(self, analysis_id: str, message: WebSocketMessage):
        """Send ``message`` to all subscribers of ``analysis_id``."""
        if analysis_id not in self.analysis_connections:
            return

        await self._send_payload(
            self.analysis_connections[analysis_id],
            message.json(),
            "Failed to send analysis update",
        )

    async def send_system_notification(self, notification_type: str, data: dict):
        """Broadcast a ``system_notification`` whose data is ``data`` plus ``notification_type``."""
        message = WebSocketMessage(
            type="system_notification",
            data={
                "notification_type": notification_type,
                **data
            }
        )

        await self.broadcast_to_all(message)

    def get_connection_stats(self) -> dict:
        """Return counts of connections, users, subscribed topics and broadcast targets."""
        return {
            "total_connections": len(self.connection_metadata),
            "users_connected": len(self.user_connections),
            "active_investigations": len(self.investigation_connections),
            "active_analyses": len(self.analysis_connections),
            "notification_subscribers": len(self.notification_connections)
        }

    async def ping_all_connections(self):
        """Send a ``ping`` message to every connection.

        Records ``last_ping`` on successful sends and disconnects sockets whose
        send fails. Failures are not logged.
        """
        payload = WebSocketMessage(
            type="ping",
            data={"timestamp": datetime.utcnow().isoformat()}
        ).json()

        dead = set()

        for websocket in list(self.connection_metadata):
            try:
                await websocket.send_text(payload)
            except Exception:
                dead.add(websocket)
                continue

            # The socket may have been disconnected elsewhere while we awaited.
            metadata = self.connection_metadata.get(websocket)
            if metadata is not None:
                metadata['last_ping'] = datetime.utcnow()

        for websocket in dead:
            self.disconnect(websocket)
|
|
| |
# Shared module-level ConnectionManager; passed to websocket_handler and
# pinged by connection_maintenance_task below.
connection_manager = ConnectionManager()
|
|
class WebSocketHandler:
    """Interpret messages received from clients and act on a ConnectionManager.

    Recognised ``type`` values of an incoming message object:

    * ``"subscribe_investigation"`` / ``"unsubscribe_investigation"`` - read
      ``data["investigation_id"]``.
    * ``"subscribe_analysis"`` - reads ``data["analysis_id"]``.
    * ``"pong"`` - refreshes the connection's ``last_ping`` timestamp.

    Unknown types are logged and otherwise ignored. Processing errors are
    reported back to the sender as an ``"error"`` message.
    """

    def __init__(self, connection_manager: ConnectionManager):
        self.connection_manager = connection_manager

    @staticmethod
    def _field(data, key: str):
        """Return ``data.get(key)``; raise ValueError if ``data`` is not a dict."""
        if not isinstance(data, dict):
            raise ValueError("'data' must be a JSON object")
        return data.get(key)

    async def handle_message(self, websocket: WebSocket, message: dict):
        """Process one decoded client message.

        Args:
            websocket: Connection the message arrived on.
            message: Decoded JSON payload. Expected to be an object with a
                ``type`` key and an optional ``data`` object.

        A non-object ``message`` or ``data`` is answered with an ``"error"``
        message instead of propagating an exception to the caller.
        """
        is_object = isinstance(message, dict)
        message_type = message.get('type') if is_object else None
        manager = self.connection_manager

        try:
            if not is_object:
                # Valid JSON that is not an object (list, string, number) has
                # no .get(). Report it like any other processing error.
                raise ValueError("message must be a JSON object")
            data = message.get('data', {})

            if message_type == "subscribe_investigation":
                investigation_id = self._field(data, 'investigation_id')
                if investigation_id:
                    await manager.subscribe_to_investigation(websocket, investigation_id)

            elif message_type == "unsubscribe_investigation":
                investigation_id = self._field(data, 'investigation_id')
                if investigation_id:
                    await manager.unsubscribe_from_investigation(websocket, investigation_id)

            elif message_type == "subscribe_analysis":
                analysis_id = self._field(data, 'analysis_id')
                if analysis_id:
                    await manager.subscribe_to_analysis(websocket, analysis_id)

            elif message_type == "pong":
                # Client answered a ping; record liveness if still registered.
                metadata = manager.connection_metadata.get(websocket)
                if metadata is not None:
                    metadata['last_ping'] = datetime.utcnow()

            else:
                logger.warning(f"Unknown WebSocket message type: {message_type}")

        except Exception as e:
            logger.exception(f"Error handling WebSocket message: {e}")

            # NOTE(review): the raw exception text is echoed to the client;
            # consider a generic message if exceptions can carry internal details.
            error_message = WebSocketMessage(
                type="error",
                data={
                    "message": f"Failed to process message: {str(e)}",
                    "original_type": message_type
                }
            )

            await manager.send_personal_message(websocket, error_message)
|
|
| |
# Shared module-level handler that dispatches client messages to connection_manager.
websocket_handler = WebSocketHandler(connection_manager)
|
|
| |
async def connection_maintenance_task(
    interval: float = 30,
    error_backoff: float = 60,
    manager: Optional["ConnectionManager"] = None,
):
    """Ping every open WebSocket in a loop to keep connections alive.

    Meant to run as a background asyncio task. It only stops when cancelled:
    ``asyncio.CancelledError`` derives from ``BaseException`` (Python 3.8+), so
    the ``except Exception`` below does not swallow it.

    Args:
        interval: Seconds to wait between successful ping rounds.
        error_backoff: Seconds to wait after a ping round raised.
        manager: Connection manager to ping. Defaults to the module-level
            ``connection_manager``.
    """
    if manager is None:
        manager = connection_manager

    while True:
        try:
            await manager.ping_all_connections()
            await asyncio.sleep(interval)
        except Exception as e:
            # Keep the loop alive on unexpected errors, but keep the traceback.
            logger.exception(f"Error in connection maintenance: {e}")
            await asyncio.sleep(error_backoff)