"""
Message queue service for async processing.
This module implements a distributed task queue using Redis
for background processing and async operations.
"""
import asyncio
import time
import uuid
from dataclasses import dataclass
from datetime import datetime, timedelta
from enum import Enum
from typing import Any, Dict, List, Optional
import redis.asyncio as redis
from src.core import get_logger, settings
from src.core.json_utils import dumps, loads
logger = get_logger(__name__)
class TaskStatus(str, Enum):
"""Task execution status."""
PENDING = "pending"
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
RETRY = "retry"
CANCELLED = "cancelled"
class TaskPriority(str, Enum):
"""Task priority levels."""
LOW = "low"
MEDIUM = "medium"
HIGH = "high"
CRITICAL = "critical"
@dataclass
class Task:
"""Task definition."""
id: str
queue: str
task_type: str
payload: Dict[str, Any]
priority: TaskPriority
status: TaskStatus
created_at: datetime
scheduled_at: Optional[datetime] = None
started_at: Optional[datetime] = None
completed_at: Optional[datetime] = None
max_retries: int = 3
retry_count: int = 0
error: Optional[str] = None
result: Optional[Any] = None
metadata: Optional[Dict[str, Any]] = None
@classmethod
def create(
cls,
queue: str,
task_type: str,
payload: Dict[str, Any],
priority: TaskPriority = TaskPriority.MEDIUM,
scheduled_at: Optional[datetime] = None,
max_retries: int = 3,
metadata: Optional[Dict[str, Any]] = None
) -> "Task":
"""Create a new task."""
return cls(
id=str(uuid.uuid4()),
queue=queue,
task_type=task_type,
payload=payload,
priority=priority,
status=TaskStatus.PENDING,
created_at=datetime.utcnow(),
scheduled_at=scheduled_at,
max_retries=max_retries,
metadata=metadata or {}
)
def to_dict(self) -> Dict[str, Any]:
"""Convert task to dictionary."""
return {
"id": self.id,
"queue": self.queue,
"task_type": self.task_type,
"payload": self.payload,
"priority": self.priority.value,
"status": self.status.value,
"created_at": self.created_at.isoformat(),
"scheduled_at": self.scheduled_at.isoformat() if self.scheduled_at else None,
"started_at": self.started_at.isoformat() if self.started_at else None,
"completed_at": self.completed_at.isoformat() if self.completed_at else None,
"max_retries": self.max_retries,
"retry_count": self.retry_count,
"error": self.error,
"result": self.result,
"metadata": self.metadata
}
class TaskHandler:
"""Base class for task handlers."""
def __init__(self, task_types: List[str]):
"""
Initialize task handler.
Args:
task_types: List of task types this handler can process
"""
self.task_types = task_types
self.logger = get_logger(self.__class__.__name__)
async def handle(self, task: Task) -> Any:
"""
Handle a task.
Args:
task: Task to handle
Returns:
Task result
"""
raise NotImplementedError("Subclasses must implement handle()")
def can_handle(self, task_type: str) -> bool:
"""Check if this handler can handle the task type."""
return task_type in self.task_types
class QueueService:
"""
Distributed task queue service using Redis.
Features:
- Multiple queue support
- Priority-based processing
- Scheduled tasks
- Retry mechanism with exponential backoff
- Dead letter queue
- Task monitoring and metrics
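    Example (a minimal sketch; assumes an initialized async Redis client and
    a handler registered for the hypothetical "send_email" task type):
        service = QueueService(redis_client)
        service.register_handler(EmailHandler())
        task_id = await service.enqueue("emails", "send_email", {"to": "user@example.com"})
        await service.start_worker(["emails"])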
"""
def __init__(
self,
redis_client: redis.Redis,
queue_prefix: str = "queue",
worker_name: Optional[str] = None,
max_concurrent_tasks: int = 10
):
"""
Initialize queue service.
Args:
redis_client: Redis async client
queue_prefix: Prefix for queue names
worker_name: Unique worker name
max_concurrent_tasks: Maximum concurrent tasks per worker
"""
self.redis = redis_client
self.queue_prefix = queue_prefix
self.worker_name = worker_name or f"worker-{uuid.uuid4().hex[:8]}"
self.max_concurrent_tasks = max_concurrent_tasks
# Task handlers
self._handlers: Dict[str, TaskHandler] = {}
# Running tasks
self._running_tasks: Dict[str, asyncio.Task] = {}
# Worker state
self._running = False
self._worker_task: Optional[asyncio.Task] = None
# Statistics
self._stats = {
"tasks_processed": 0,
"tasks_succeeded": 0,
"tasks_failed": 0,
"tasks_retried": 0,
"total_processing_time_ms": 0.0
}
def _get_queue_name(self, queue: str) -> str:
"""Get Redis queue name."""
return f"{self.queue_prefix}:{queue}"
def _get_priority_score(self, priority: TaskPriority) -> float:
"""Get priority score for Redis sorted set."""
scores = {
TaskPriority.LOW: 1.0,
TaskPriority.MEDIUM: 2.0,
TaskPriority.HIGH: 3.0,
TaskPriority.CRITICAL: 4.0
}
return scores.get(priority, 1.0)
async def enqueue(
self,
queue: str,
task_type: str,
payload: Dict[str, Any],
priority: TaskPriority = TaskPriority.MEDIUM,
delay: Optional[timedelta] = None,
max_retries: int = 3,
metadata: Optional[Dict[str, Any]] = None
) -> str:
"""
Enqueue a task for processing.
Args:
queue: Queue name
task_type: Type of task
payload: Task payload
priority: Task priority
delay: Delay before execution
max_retries: Maximum retry attempts
metadata: Additional metadata
Returns:
Task ID
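        Example (illustrative values; `service` is a QueueService instance):
            task_id = await service.enqueue(
                "reports", "generate_report", {"report_id": 42},
                priority=TaskPriority.HIGH, delay=timedelta(minutes=5),
            )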
"""
# Create task
scheduled_at = datetime.utcnow() + delay if delay else None
task = Task.create(
queue=queue,
task_type=task_type,
payload=payload,
priority=priority,
scheduled_at=scheduled_at,
max_retries=max_retries,
metadata=metadata
)
# Store task data
await self.redis.hset(
f"task:{task.id}",
mapping={
"data": dumps(task.to_dict()),
"created_at": task.created_at.isoformat()
}
)
# Add to queue
queue_name = self._get_queue_name(queue)
if scheduled_at:
# Add to delayed queue (sorted by timestamp)
await self.redis.zadd(
f"{queue_name}:delayed",
{task.id: scheduled_at.timestamp()}
)
        else:
            # Add to the priority queue. The score packs priority into the
            # high digits and recency into the low digits: subtracting the
            # Unix timestamp gives earlier tasks a higher score within the
            # same priority band, so ZPOPMAX pops them in FIFO order.
            priority_score = self._get_priority_score(priority)
            final_score = priority_score * 1e10 - time.time()
await self.redis.zadd(
queue_name,
{task.id: final_score}
)
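            # Worked example of the scoring above (illustrative numbers):
            # HIGH (3.0) enqueued at t=1_700_000_000 scores 3e10 - 1.7e9 = 2.83e10,
            # while LOW (1.0) always stays below 1e10, so ZPOPMAX prefers the
            # higher priority band regardless of enqueue time.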
logger.info(f"Enqueued task {task.id} in queue {queue}")
return task.id
async def get_task(self, task_id: str) -> Optional[Task]:
"""Get task by ID."""
task_data = await self.redis.hget(f"task:{task_id}", "data")
if not task_data:
return None
data = loads(task_data)
# Reconstruct task
task = Task(
id=data["id"],
queue=data["queue"],
task_type=data["task_type"],
payload=data["payload"],
priority=TaskPriority(data["priority"]),
status=TaskStatus(data["status"]),
created_at=datetime.fromisoformat(data["created_at"]),
scheduled_at=datetime.fromisoformat(data["scheduled_at"]) if data["scheduled_at"] else None,
started_at=datetime.fromisoformat(data["started_at"]) if data["started_at"] else None,
completed_at=datetime.fromisoformat(data["completed_at"]) if data["completed_at"] else None,
max_retries=data["max_retries"],
retry_count=data["retry_count"],
error=data["error"],
result=data["result"],
metadata=data["metadata"]
)
return task
async def cancel_task(self, task_id: str) -> bool:
"""Cancel a pending task."""
task = await self.get_task(task_id)
if not task or task.status not in [TaskStatus.PENDING, TaskStatus.RUNNING]:
return False
# Update task status
task.status = TaskStatus.CANCELLED
task.completed_at = datetime.utcnow()
await self._update_task(task)
# Remove from queues
await self.redis.zrem(self._get_queue_name(task.queue), task_id)
await self.redis.zrem(f"{self._get_queue_name(task.queue)}:delayed", task_id)
logger.info(f"Cancelled task {task_id}")
return True
def register_handler(self, handler: TaskHandler):
"""
Register a task handler.
Args:
handler: Task handler to register
"""
for task_type in handler.task_types:
self._handlers[task_type] = handler
logger.info(f"Registered handler {handler.__class__.__name__} for {task_type}")
async def start_worker(self, queues: List[str]):
"""
Start worker to process tasks.
Args:
queues: List of queues to process
"""
if self._running:
logger.warning("Worker already running")
return
self._running = True
self._worker_task = asyncio.create_task(
self._worker_loop(queues)
)
logger.info(f"Worker {self.worker_name} started for queues: {queues}")
async def stop_worker(self):
"""Stop worker."""
self._running = False
if self._worker_task:
self._worker_task.cancel()
try:
await self._worker_task
except asyncio.CancelledError:
pass
# Cancel running tasks
for task in self._running_tasks.values():
task.cancel()
await asyncio.gather(*self._running_tasks.values(), return_exceptions=True)
self._running_tasks.clear()
logger.info(f"Worker {self.worker_name} stopped")
async def _worker_loop(self, queues: List[str]):
"""Main worker loop."""
while self._running:
try:
# Check for delayed tasks that are ready
await self._process_delayed_tasks(queues)
# Process pending tasks
if len(self._running_tasks) < self.max_concurrent_tasks:
task = await self._get_next_task(queues)
if task:
# Start processing task
task_coro = asyncio.create_task(
self._process_task(task)
)
self._running_tasks[task.id] = task_coro
# Clean up completed tasks
await self._cleanup_completed_tasks()
else:
# No tasks available, wait a bit
await asyncio.sleep(0.1)
else:
# Max concurrent tasks reached, wait for completion
await asyncio.sleep(0.1)
await self._cleanup_completed_tasks()
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"Worker loop error: {e}")
await asyncio.sleep(1)
async def _process_delayed_tasks(self, queues: List[str]):
"""Move delayed tasks that are ready to main queues."""
now = datetime.utcnow().timestamp()
for queue in queues:
queue_name = self._get_queue_name(queue)
delayed_queue = f"{queue_name}:delayed"
# Get tasks ready for execution
ready_tasks = await self.redis.zrangebyscore(
delayed_queue,
0,
now,
withscores=True
)
            for task_id, _ in ready_tasks:
                # Claim the task by removing it from the delayed set first;
                # ZREM returns 0 when another worker already claimed it, which
                # avoids double-enqueueing in multi-worker deployments.
                if not await self.redis.zrem(delayed_queue, task_id):
                    continue
                task = await self.get_task(task_id)
                if task:
                    # Same scoring scheme as enqueue(): priority band first,
                    # FIFO within the band.
                    priority_score = self._get_priority_score(task.priority)
                    final_score = priority_score * 1e10 - time.time()
                    await self.redis.zadd(queue_name, {task_id: final_score})
async def _get_next_task(self, queues: List[str]) -> Optional[Task]:
"""Get next task from queues (highest priority first)."""
for queue in queues:
queue_name = self._get_queue_name(queue)
# Get highest priority task
result = await self.redis.zpopmax(queue_name, count=1)
if result:
task_id, _ = result[0]
task = await self.get_task(task_id)
                # Accept RETRY as well: retried tasks re-enter via the delayed
                # set and would otherwise be popped here and silently dropped.
                if task and task.status in (TaskStatus.PENDING, TaskStatus.RETRY):
return task
return None
async def _process_task(self, task: Task):
"""Process a single task."""
start_time = datetime.utcnow()
try:
# Update task status
task.status = TaskStatus.RUNNING
task.started_at = start_time
await self._update_task(task)
# Find handler
handler = self._handlers.get(task.task_type)
if not handler:
raise ValueError(f"No handler found for task type: {task.task_type}")
# Execute task
result = await handler.handle(task)
# Update task with result
task.status = TaskStatus.COMPLETED
task.completed_at = datetime.utcnow()
task.result = result
await self._update_task(task)
# Update statistics
processing_time = (task.completed_at - start_time).total_seconds() * 1000
self._stats["tasks_processed"] += 1
self._stats["tasks_succeeded"] += 1
self._stats["total_processing_time_ms"] += processing_time
logger.info(f"Task {task.id} completed successfully")
except Exception as e:
logger.error(f"Task {task.id} failed: {e}")
# Update task with error
task.error = str(e)
task.completed_at = datetime.utcnow()
# Check if we should retry
if task.retry_count < task.max_retries:
# Schedule retry with exponential backoff
delay_seconds = 2 ** task.retry_count
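                # Backoff doubles per attempt: 1s, 2s, 4s (with the default
                # max_retries of 3, the final retry waits 4 seconds).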
retry_at = datetime.utcnow() + timedelta(seconds=delay_seconds)
task.status = TaskStatus.RETRY
task.retry_count += 1
task.scheduled_at = retry_at
# Add to delayed queue
queue_name = self._get_queue_name(task.queue)
await self.redis.zadd(
f"{queue_name}:delayed",
{task.id: retry_at.timestamp()}
)
self._stats["tasks_retried"] += 1
logger.info(f"Task {task.id} scheduled for retry {task.retry_count}")
else:
# Max retries exceeded, move to dead letter queue
task.status = TaskStatus.FAILED
await self.redis.zadd(
f"{self.queue_prefix}:dlq",
{task.id: time.time()}
)
self._stats["tasks_failed"] += 1
logger.error(f"Task {task.id} moved to DLQ after {task.max_retries} retries")
await self._update_task(task)
finally:
# Remove from running tasks
if task.id in self._running_tasks:
del self._running_tasks[task.id]
async def _cleanup_completed_tasks(self):
"""Clean up completed task coroutines."""
completed = []
for task_id, task_coro in self._running_tasks.items():
if task_coro.done():
completed.append(task_id)
for task_id in completed:
del self._running_tasks[task_id]
async def _update_task(self, task: Task):
"""Update task in Redis."""
await self.redis.hset(
f"task:{task.id}",
mapping={
"data": dumps(task.to_dict()),
"updated_at": datetime.utcnow().isoformat()
}
)
def get_stats(self) -> Dict[str, Any]:
"""Get queue service statistics."""
return {
**self._stats,
"worker_name": self.worker_name,
"running_tasks": len(self._running_tasks),
"handlers_registered": len(self._handlers),
"avg_processing_time_ms": (
self._stats["total_processing_time_ms"] / self._stats["tasks_succeeded"]
if self._stats["tasks_succeeded"] > 0 else 0
)
}
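    # The dead letter queue is a plain sorted set keyed by failure time, so a
    # small helper (a convenience sketch, not part of the original interface)
    # is enough to inspect it:
    async def get_dlq_task_ids(self, limit: int = 100) -> List[str]:
        """Return up to `limit` task IDs from the dead letter queue."""
        return await self.redis.zrange(f"{self.queue_prefix}:dlq", 0, limit - 1)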
# Example task handlers
class InvestigationTaskHandler(TaskHandler):
"""Handler for investigation tasks."""
def __init__(self):
super().__init__(["create_investigation", "analyze_contract", "detect_anomaly"])
async def handle(self, task: Task) -> Any:
"""Handle investigation tasks."""
if task.task_type == "create_investigation":
# Simulate investigation creation
await asyncio.sleep(2) # Simulate processing time
return {
"investigation_id": task.payload.get("investigation_id"),
"status": "completed",
"findings": ["Sample finding 1", "Sample finding 2"]
}
elif task.task_type == "analyze_contract":
# Simulate contract analysis
await asyncio.sleep(1)
return {
"contract_id": task.payload.get("contract_id"),
"analysis": "Contract appears normal",
"score": 0.85
}
elif task.task_type == "detect_anomaly":
# Simulate anomaly detection
await asyncio.sleep(0.5)
return {
"anomalies_found": 2,
"severity": "medium",
"details": ["Price anomaly", "Vendor concentration"]
            }
        # Should not be reachable for registered task types, but fail loudly
        # rather than returning None for an unknown one.
        raise ValueError(f"Unsupported task type: {task.task_type}")
# Global queue service instance
_queue_service: Optional[QueueService] = None
async def get_queue_service() -> QueueService:
"""Get or create the global queue service instance."""
global _queue_service
if _queue_service is None:
# Initialize Redis client
redis_client = redis.from_url(
settings.redis_url,
decode_responses=True
)
_queue_service = QueueService(redis_client)
# Register default handlers
_queue_service.register_handler(InvestigationTaskHandler())
return _queue_service
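# A minimal end-to-end sketch (assumes settings.redis_url points at a
# reachable Redis instance; task types come from InvestigationTaskHandler):
if __name__ == "__main__":
    async def _demo() -> None:
        service = await get_queue_service()
        await service.start_worker(["investigations"])
        task_id = await service.enqueue(
            queue="investigations",
            task_type="analyze_contract",
            payload={"contract_id": "demo-contract"},
            priority=TaskPriority.HIGH,
        )
        # Poll until the task settles; production code would prefer a
        # timeout or a pub/sub notification over busy polling.
        while True:
            task = await service.get_task(task_id)
            if task and task.status in (TaskStatus.COMPLETED, TaskStatus.FAILED):
                break
            await asyncio.sleep(0.2)
        print(f"{task.status.value}: {task.result}")
        await service.stop_worker()
    asyncio.run(_demo())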