|
|
""" |
|
|
Parallel processing utilities for multi-agent system. |
|
|
|
|
|
This module provides utilities for executing multiple agent tasks |
|
|
in parallel, significantly improving investigation speed. |
|
|
""" |
|
|
|
|
|
import asyncio |
|
|
from typing import List, Dict, Any, Optional, Tuple, Callable, Union |
|
|
from datetime import datetime |
|
|
from dataclasses import dataclass |
|
|
from enum import Enum |
|
|
import traceback |
|
|
|
|
|
from src.core import get_logger, AgentStatus |
|
|
from src.agents.deodoro import BaseAgent, AgentContext, AgentMessage, AgentResponse |
|
|
from src.agents.agent_pool import get_agent_pool |
|
|
|
|
|
logger = get_logger(__name__) |
|
|
|
|
|
|
|
|
class ParallelStrategy(str, Enum):
    """Strategies for parallel execution."""

    # Expect every task to succeed; a failure only triggers a warning log.
    ALL_SUCCEED = "all_succeed"
    # Return whatever completed, successful or not (processor default).
    BEST_EFFORT = "best_effort"
    # Stop as soon as one task succeeds and cancel the remaining ones.
    FIRST_SUCCESS = "first_success"
    # Warn when fewer than half of the tasks succeeded.
    MAJORITY_VOTE = "majority_vote"
|
|
|
|
|
|
|
|
@dataclass
class ParallelTask:
    """Task to be executed in parallel."""

    # Agent class to instantiate (or acquire from the pool) for this task.
    agent_type: type[BaseAgent]
    # Message delivered to the agent's process() call.
    message: AgentMessage
    # Per-task timeout in seconds; processor default is used when None.
    timeout: Optional[float] = None
    # Relative weight of this task — not consumed by the visible processing
    # code; presumably intended for weighted aggregation (TODO confirm).
    weight: float = 1.0
    # Optional async callable invoked when the agent raises; its return
    # value is reported as a successful result.
    fallback: Optional[Callable] = None
|
|
|
|
|
|
|
|
@dataclass
class ParallelResult:
    """Result from parallel execution."""

    # Identifier of the executed task (investigation id + task index).
    task_id: str
    # Name of the agent that produced the result ("fallback"/"unknown" on
    # the fallback and exception paths).
    agent_name: str
    # True when the agent completed (or the fallback succeeded).
    success: bool
    # Agent response when available; None on failure paths.
    result: Optional[AgentResponse] = None
    # Error message when the task failed.
    error: Optional[str] = None
    # Wall-clock duration of the task in seconds.
    execution_time: float = 0.0
    # NOTE(review): default is None, not an empty dict, so consumers must
    # handle the missing case; annotation widened to Optional to match.
    metadata: Optional[Dict[str, Any]] = None
|
|
|
|
|
|
|
|
class ParallelAgentProcessor:
    """
    Processor for executing multiple agent tasks in parallel.

    Features:
    - Concurrent execution with configurable strategies
    - Automatic retry and fallback handling
    - Performance monitoring and optimization
    - Result aggregation and voting
    """

    def __init__(
        self,
        max_concurrent: int = 5,
        default_timeout: float = 30.0,
        enable_pooling: bool = True
    ):
        """
        Initialize parallel processor.

        Args:
            max_concurrent: Maximum number of tasks running at the same time.
            default_timeout: Timeout (seconds) for tasks that do not set their own.
            enable_pooling: Acquire agents from the shared agent pool instead of
                constructing a new agent instance per task.
        """
        self.max_concurrent = max_concurrent
        self.default_timeout = default_timeout
        self.enable_pooling = enable_pooling
        # Caps how many task bodies may execute concurrently.
        self._semaphore = asyncio.Semaphore(max_concurrent)
        # Lifetime counters, exposed via get_stats().
        self._stats = {
            "total_tasks": 0,
            "successful_tasks": 0,
            "failed_tasks": 0,
            "total_time": 0.0
        }

    async def execute_parallel(
        self,
        tasks: List[ParallelTask],
        context: AgentContext,
        strategy: ParallelStrategy = ParallelStrategy.BEST_EFFORT
    ) -> List[ParallelResult]:
        """
        Execute multiple agent tasks in parallel.

        Args:
            tasks: List of tasks to execute
            context: Agent execution context shared by all tasks
            strategy: Execution strategy

        Returns:
            List of results. With FIRST_SUCCESS the list may be shorter
            than ``tasks`` (pending tasks are cancelled after a success).
        """
        start_time = datetime.now()
        self._stats["total_tasks"] += len(tasks)

        logger.info(f"Starting parallel execution of {len(tasks)} tasks with strategy {strategy}")

        coroutines = []
        for i, task in enumerate(tasks):
            task_id = f"{context.investigation_id}_{i}"
            coroutines.append(self._execute_single_task(task_id, task, context))

        if strategy == ParallelStrategy.FIRST_SUCCESS:
            results = await self._execute_first_success(coroutines)
        else:
            results = await self._execute_all_tasks(coroutines)

        final_results = self._process_results(results, strategy)

        execution_time = (datetime.now() - start_time).total_seconds()
        successful = sum(1 for r in final_results if r.success)
        self._stats["total_time"] += execution_time
        self._stats["successful_tasks"] += successful
        self._stats["failed_tasks"] += len(final_results) - successful

        logger.info(
            f"Parallel execution completed: {len(final_results)} results, "
            f"{successful} successful, "
            f"time: {execution_time:.2f}s"
        )

        return final_results

    async def _execute_single_task(
        self,
        task_id: str,
        task: ParallelTask,
        context: AgentContext
    ) -> ParallelResult:
        """Execute a single task with error handling.

        Acquires the concurrency semaphore, runs the agent (pooled or
        freshly constructed), and converts any exception into a failed
        ParallelResult, trying ``task.fallback`` first when provided.
        """
        async with self._semaphore:
            start_time = datetime.now()

            try:
                if self.enable_pooling:
                    pool = await get_agent_pool()
                    async with pool.acquire(task.agent_type, context) as agent:
                        result = await self._run_agent_task(agent, task, context)
                else:
                    agent = task.agent_type()
                    result = await self._run_agent_task(agent, task, context)

                execution_time = (datetime.now() - start_time).total_seconds()

                return ParallelResult(
                    task_id=task_id,
                    agent_name=agent.name,
                    success=result.status == AgentStatus.COMPLETED,
                    result=result,
                    execution_time=execution_time,
                    metadata={"task_type": task.agent_type.__name__}
                )

            except Exception as e:
                logger.error(f"Task {task_id} failed: {str(e)}\n{traceback.format_exc()}")

                # Best-effort fallback: report its return value as a success.
                if task.fallback:
                    try:
                        fallback_result = await task.fallback()
                        return ParallelResult(
                            task_id=task_id,
                            agent_name="fallback",
                            success=True,
                            result=fallback_result,
                            execution_time=(datetime.now() - start_time).total_seconds(),
                            metadata={"used_fallback": True}
                        )
                    except Exception as fb_error:
                        logger.error(f"Fallback also failed: {fb_error}")

                return ParallelResult(
                    task_id=task_id,
                    agent_name=task.agent_type.__name__,
                    success=False,
                    error=str(e),
                    execution_time=(datetime.now() - start_time).total_seconds()
                )

    async def _run_agent_task(
        self,
        agent: BaseAgent,
        task: ParallelTask,
        context: AgentContext
    ) -> AgentResponse:
        """Run agent task with timeout.

        A timeout is reported as an ERROR response rather than raised, so
        it does not trigger the fallback path in _execute_single_task.
        """
        timeout = task.timeout or self.default_timeout

        try:
            return await asyncio.wait_for(
                agent.process(task.message, context),
                timeout=timeout
            )
        except asyncio.TimeoutError:
            logger.error(f"Agent {agent.name} timed out after {timeout}s")
            return AgentResponse(
                agent_name=agent.name,
                status=AgentStatus.ERROR,
                error=f"Timeout after {timeout} seconds"
            )

    async def _execute_all_tasks(
        self,
        coroutines: List[Any]
    ) -> List[ParallelResult]:
        """Execute all task coroutines and gather their results.

        Exceptions that escape a coroutine are converted into failed
        ParallelResult entries so the caller always gets one result per task.
        """
        results = await asyncio.gather(*coroutines, return_exceptions=True)

        final_results = []
        for i, result in enumerate(results):
            if isinstance(result, Exception):
                final_results.append(ParallelResult(
                    task_id=f"task_{i}",
                    agent_name="unknown",
                    success=False,
                    error=str(result)
                ))
            else:
                final_results.append(result)

        return final_results

    async def _execute_first_success(
        self,
        coroutines: List[Any]
    ) -> List[ParallelResult]:
        """Execute tasks until first success.

        Bare coroutines are wrapped with asyncio.ensure_future first:
        asyncio.wait() only accepts Task/Future objects (passing coroutines
        raises TypeError on Python 3.11+), and cancellation requires a Task.
        Remaining tasks are cancelled once a success is observed.
        """
        # BUGFIX: previously bare coroutine objects were handed to
        # asyncio.wait() and had .cancel() called on them; neither works.
        pending = {asyncio.ensure_future(coro) for coro in coroutines}
        results: List[ParallelResult] = []

        try:
            while pending:
                done, pending = await asyncio.wait(
                    pending,
                    return_when=asyncio.FIRST_COMPLETED
                )

                for finished in done:
                    try:
                        result = finished.result()
                        results.append(result)

                        if result.success:
                            return results
                    except Exception as e:
                        logger.error(f"Task failed: {e}")
        finally:
            # Cancel anything still running (early return or error path) and
            # wait for the cancellations to be processed.
            for p in pending:
                p.cancel()
            if pending:
                await asyncio.gather(*pending, return_exceptions=True)

        return results

    def _process_results(
        self,
        results: List[ParallelResult],
        strategy: ParallelStrategy
    ) -> List[ParallelResult]:
        """Process results based on strategy.

        Only logs strategy violations; results are returned unchanged.
        """
        if strategy == ParallelStrategy.ALL_SUCCEED:
            if not all(r.success for r in results):
                logger.warning("Not all tasks succeeded with ALL_SUCCEED strategy")

        elif strategy == ParallelStrategy.MAJORITY_VOTE:
            successes = sum(1 for r in results if r.success)
            # BUGFIX: a strict majority needs successes > half; the previous
            # `successes < len(results) / 2` let an exact tie pass silently.
            if successes * 2 <= len(results):
                logger.warning("Majority vote failed")

        return results

    def aggregate_results(
        self,
        results: List[ParallelResult],
        aggregation_key: str = "findings"
    ) -> Dict[str, Any]:
        """
        Aggregate results from multiple agents.

        Args:
            results: List of parallel results
            aggregation_key: Key extracted from each successful response's
                dict payload and merged into a single list

        Returns:
            Aggregated data: summary counters, responses grouped by agent
            name, and the merged values found under ``aggregation_key``.
        """
        aggregated: Dict[str, Any] = {
            "total_tasks": len(results),
            "successful_tasks": sum(1 for r in results if r.success),
            "failed_tasks": sum(1 for r in results if not r.success),
            "total_execution_time": sum(r.execution_time for r in results),
            "results_by_agent": {},
            aggregation_key: []
        }

        for result in results:
            if not (result.success and result.result):
                continue

            aggregated["results_by_agent"].setdefault(result.agent_name, []).append(result.result)

            # Responses carry their payload in `.result`; only dict payloads
            # can contribute values for the aggregation key.
            payload = getattr(result.result, 'result', None)
            if isinstance(payload, dict):
                data = payload.get(aggregation_key, [])
                if isinstance(data, list):
                    aggregated[aggregation_key].extend(data)
                else:
                    aggregated[aggregation_key].append(data)

        return aggregated

    def get_stats(self) -> Dict[str, Any]:
        """Get processor statistics.

        Returns the raw lifetime counters plus derived averages
        (success rate and mean execution time per task).
        """
        return {
            **self._stats,
            "avg_success_rate": (
                self._stats["successful_tasks"] / self._stats["total_tasks"]
                if self._stats["total_tasks"] > 0 else 0
            ),
            "avg_execution_time": (
                self._stats["total_time"] / self._stats["total_tasks"]
                if self._stats["total_tasks"] > 0 else 0
            )
        }
|
|
|
|
|
|
|
|
|
|
|
# Module-level shared processor instance using the default settings
# (max_concurrent=5, default_timeout=30.0, pooling enabled).
parallel_processor = ParallelAgentProcessor()