"""
Batch API endpoints for processing multiple requests efficiently.

This module provides endpoints for batching multiple operations,
reducing network overhead and improving throughput.
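
Typical flow (illustrative): a client POSTs a BatchRequest to
/api/v1/batch/process, each BatchOperation is routed to its handler
(chat, investigate, analyze, search), and the per-operation results
are returned together in a single BatchResponse.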
"""

from typing import List, Dict, Any, Optional, Union
from datetime import datetime
import asyncio
import uuid

from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
from pydantic import BaseModel, Field, validator

from src.core import get_logger
from src.api.dependencies import get_current_user
from src.agents import get_agent_pool, MasterAgent
from src.agents.parallel_processor import (
    ParallelAgentProcessor,
    ParallelTask,
    ParallelStrategy
)
from src.services.chat_service_with_cache import chat_service

logger = get_logger(__name__)
router = APIRouter(prefix="/api/v1/batch", tags=["batch"])


class BatchOperation(BaseModel):
    """Single operation in a batch request."""
    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    operation: str = Field(..., description="Operation type")
    data: Dict[str, Any] = Field(..., description="Operation data")
    priority: int = Field(default=5, ge=1, le=10)
    timeout: Optional[float] = Field(default=30.0, ge=1.0, le=300.0)
    
    @validator('operation')
    def validate_operation(cls, v):
        allowed = ["chat", "investigate", "analyze", "search"]
        if v not in allowed:
            raise ValueError(f"Operation must be one of {allowed}")
        return v


class BatchRequest(BaseModel):
    """Batch request containing multiple operations."""
    operations: List[BatchOperation] = Field(..., max_items=100)
    strategy: ParallelStrategy = Field(
        default=ParallelStrategy.BEST_EFFORT,
        description="Execution strategy"
    )
    max_concurrent: int = Field(default=5, ge=1, le=20)
    return_partial: bool = Field(
        default=True,
        description="Return partial results if some operations fail"
    )


class BatchOperationResult(BaseModel):
    """Result of a single batch operation."""
    id: str
    operation: str
    success: bool
    result: Optional[Dict[str, Any]] = None
    error: Optional[str] = None
    execution_time: float
    timestamp: datetime = Field(default_factory=datetime.utcnow)


class BatchResponse(BaseModel):
    """Response from batch processing."""
    batch_id: str
    total_operations: int
    successful_operations: int
    failed_operations: int
    results: List[BatchOperationResult]
    total_execution_time: float
    metadata: Dict[str, Any] = Field(default_factory=dict)


# Batch processor instance
batch_processor = ParallelAgentProcessor(max_concurrent=10)


@router.post("/process", response_model=BatchResponse)
async def process_batch(
    request: BatchRequest,
    background_tasks: BackgroundTasks,
    current_user=Depends(get_current_user)
) -> BatchResponse:
    """
    Process multiple operations in a single batch request.
    
    Supports operations:
    - chat: Chat completions
    - investigate: Full investigations
    - analyze: Data analysis
    - search: Search operations
    
    Operations are executed in parallel when possible.
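    
    Example request body (illustrative values only):
    
        {
            "operations": [
                {"operation": "chat", "data": {"message": "Hello"}, "priority": 8},
                {"operation": "search", "data": {"query": "recent contracts"}}
            ],
            "max_concurrent": 5,
            "return_partial": true
        }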
    """
    batch_id = str(uuid.uuid4())
    start_time = datetime.utcnow()
    
    logger.info(
        f"Processing batch {batch_id} with {len(request.operations)} operations "
        f"for user {current_user.id}"
    )
    
    # Sort operations by priority
    sorted_ops = sorted(
        request.operations,
        key=lambda x: x.priority,
        reverse=True
    )
    
    # Process operations in parallel, bounded by the requested concurrency
    # (and capped by the shared processor's own limit)
    semaphore = asyncio.Semaphore(
        min(request.max_concurrent, batch_processor.max_concurrent)
    )

    async def _run_bounded(op: BatchOperation) -> BatchOperationResult:
        async with semaphore:
            return await _process_single_operation(op, current_user)

    tasks = [asyncio.create_task(_run_bounded(op)) for op in sorted_ops]
    
    # Collect results according to the requested strategy
    results = []
    if request.strategy == ParallelStrategy.FIRST_SUCCESS:
        # Stop as soon as one operation succeeds
        for task in asyncio.as_completed(tasks):
            result = await task
            results.append(result)
            if result.success:
                # Cancel remaining tasks and wait for the cancellations to settle
                for t in tasks:
                    if not t.done():
                        t.cancel()
                await asyncio.gather(*tasks, return_exceptions=True)
                break
    else:
        # Process all tasks
        batch_results = await asyncio.gather(*tasks, return_exceptions=True)
        
        for i, result in enumerate(batch_results):
            if isinstance(result, Exception):
                results.append(BatchOperationResult(
                    id=sorted_ops[i].id,
                    operation=sorted_ops[i].operation,
                    success=False,
                    error=str(result),
                    execution_time=0.0
                ))
            else:
                results.append(result)
    
    # Calculate statistics
    total_time = (datetime.utcnow() - start_time).total_seconds()
    successful = sum(1 for r in results if r.success)
    failed = len(results) - successful
    
    # Background cleanup if needed
    background_tasks.add_task(_cleanup_batch_resources, batch_id)
    
    return BatchResponse(
        batch_id=batch_id,
        total_operations=len(request.operations),
        successful_operations=successful,
        failed_operations=failed,
        results=results,
        total_execution_time=total_time,
        metadata={
            "strategy": request.strategy,
            "user_id": current_user.id,
            "avg_execution_time": total_time / len(results) if results else 0
        }
    )


async def _process_single_operation(
    operation: BatchOperation,
    user: Any
) -> BatchOperationResult:
    """Process a single operation with error handling."""
    start_time = datetime.utcnow()
    
    try:
        # Route to the appropriate handler
        handlers = {
            "chat": _handle_chat_operation,
            "investigate": _handle_investigate_operation,
            "analyze": _handle_analyze_operation,
            "search": _handle_search_operation,
        }
        handler = handlers.get(operation.operation)
        if handler is None:
            raise ValueError(f"Unknown operation: {operation.operation}")

        # Enforce the per-operation timeout declared on the request
        result = await asyncio.wait_for(
            handler(operation.data, user),
            timeout=operation.timeout
        )
        
        execution_time = (datetime.utcnow() - start_time).total_seconds()
        
        return BatchOperationResult(
            id=operation.id,
            operation=operation.operation,
            success=True,
            result=result,
            execution_time=execution_time
        )
        
    except Exception as e:
        error_message = str(e) or e.__class__.__name__
        logger.error(f"Batch operation {operation.id} failed: {error_message}")

        execution_time = (datetime.utcnow() - start_time).total_seconds()

        return BatchOperationResult(
            id=operation.id,
            operation=operation.operation,
            success=False,
            error=error_message,
            execution_time=execution_time
        )


async def _handle_chat_operation(data: Dict[str, Any], user: Any) -> Dict[str, Any]:
    """Handle chat operation."""
    message = data.get("message", "")
    session_id = data.get("session_id", str(uuid.uuid4()))
    
    # Ensure the session exists before processing the message
    await chat_service.get_or_create_session(session_id, user_id=user.id)
    
    # Process message
    response = await chat_service.process_message(
        session_id=session_id,
        message=message,
        user_id=user.id
    )
    
    return {
        "session_id": session_id,
        "response": response.message,
        "agent": response.agent_name,
        "confidence": response.confidence
    }


async def _handle_investigate_operation(data: Dict[str, Any], user: Any) -> Dict[str, Any]:
    """Handle investigation operation."""
    query = data.get("query", "")
    
    # Get agent pool and master agent
    pool = await get_agent_pool()
    
    # Create investigation context
    from src.agents.deodoro import AgentContext
    context = AgentContext(
        investigation_id=str(uuid.uuid4()),
        user_id=user.id,
        data_sources=data.get("data_sources", [])
    )
    
    # Execute investigation
    async with pool.acquire(MasterAgent, context) as master:
        result = await master._investigate({"query": query}, context)
    
    return {
        "investigation_id": result.investigation_id,
        "findings": result.findings,
        "confidence": result.confidence_score,
        "sources": result.sources,
        "explanation": result.explanation
    }


async def _handle_analyze_operation(data: Dict[str, Any], user: Any) -> Dict[str, Any]:
    """Handle analysis operation."""
    # Simplified for now - extend based on your analysis needs
    return {
        "status": "completed",
        "analysis_type": data.get("type", "general"),
        "results": {
            "summary": "Analysis completed successfully",
            "data": data.get("data", {})
        }
    }


async def _handle_search_operation(data: Dict[str, Any], user: Any) -> Dict[str, Any]:
    """Handle search operation."""
    query = data.get("query", "")
    filters = data.get("filters", {})
    
    # Simplified search - integrate with your search service
    return {
        "query": query,
        "results": [],
        "total": 0,
        "filters_applied": filters
    }


async def _cleanup_batch_resources(batch_id: str):
    """Cleanup any resources used by the batch."""
    # Add cleanup logic if needed
    logger.debug(f"Cleaning up resources for batch {batch_id}")


@router.get("/status/{batch_id}")
async def get_batch_status(
    batch_id: str,
    current_user=Depends(get_current_user)
) -> Dict[str, Any]:
    """
    Get the status of a batch operation.
    
    Note: This is a placeholder for async batch processing.
    Currently all batches are processed synchronously.
    """
    return {
        "batch_id": batch_id,
        "status": "completed",
        "message": "Batch operations are currently processed synchronously"
    }


@router.post("/validate", response_model=Dict[str, Any])
async def validate_batch(
    request: BatchRequest,
    current_user=Depends(get_current_user)
) -> Dict[str, Any]:
    """
    Validate a batch request without executing it.
    
    Useful for checking if operations are valid before submission.
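    
    Each entry in the returned "results" list has the shape
    {"id": ..., "operation": ..., "valid": bool, "errors": [...]},
    mirroring the per-operation checks below.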
    """
    validation_results = []
    
    for op in request.operations:
        is_valid = True
        errors = []
        
        # Validate operation type
        if op.operation not in ["chat", "investigate", "analyze", "search"]:
            is_valid = False
            errors.append(f"Unknown operation: {op.operation}")
        
        # Validate operation data
        if op.operation == "chat" and "message" not in op.data:
            is_valid = False
            errors.append("Chat operation requires 'message' field")
        elif op.operation == "investigate" and "query" not in op.data:
            is_valid = False
            errors.append("Investigate operation requires 'query' field")
        
        validation_results.append({
            "id": op.id,
            "operation": op.operation,
            "valid": is_valid,
            "errors": errors
        })
    
    total_valid = sum(1 for v in validation_results if v["valid"])
    
    return {
        "valid": total_valid == len(request.operations),
        "total_operations": len(request.operations),
        "valid_operations": total_valid,
        "invalid_operations": len(request.operations) - total_valid,
        "results": validation_results
    }