Commit
·
2e86bf2
1
Parent(s):
0ffcee6
enhance research mcp server to use more tools and refine output
Browse files- app.py +7 -0
- core/orchestrator.py +21 -5
- mcp_servers/research/config.py +6 -1
- mcp_servers/research/provider.py +1352 -43
- mcp_servers/research/tool_schemas.py +77 -6
- requirements.txt +2 -0
- ui/chat.py +80 -15
app.py
CHANGED
|
@@ -2,6 +2,13 @@ import gradio as gr
|
|
| 2 |
from ui.chat import build_app
|
| 3 |
import subprocess, pathlib, time, os
|
| 4 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
def maybe_start_ideation_server():
|
| 6 |
if os.environ.get("CELEBRATE_AI_IDEATION_ENDPOINT"):
|
| 7 |
return # using SSE/HTTP endpoint elsewhere
|
|
|
|
| 2 |
from ui.chat import build_app
|
| 3 |
import subprocess, pathlib, time, os
|
| 4 |
|
| 5 |
+
# Load .env file if present
|
| 6 |
+
try:
|
| 7 |
+
from dotenv import load_dotenv
|
| 8 |
+
load_dotenv()
|
| 9 |
+
except ImportError:
|
| 10 |
+
pass # python-dotenv not installed
|
| 11 |
+
|
| 12 |
def maybe_start_ideation_server():
|
| 13 |
if os.environ.get("CELEBRATE_AI_IDEATION_ENDPOINT"):
|
| 14 |
return # using SSE/HTTP endpoint elsewhere
|
core/orchestrator.py
CHANGED
|
@@ -249,10 +249,13 @@ class MCPResearchClient:
|
|
| 249 |
def available(self) -> bool:
|
| 250 |
return bool(self._endpoint or self._command or DEFAULT_MCP_RESEARCH_COMMAND)
|
| 251 |
|
| 252 |
-
def generate(self, payload: RequestPayload) -> ResearchResponse:
|
| 253 |
-
return anyio.run(self._generate_async,
|
| 254 |
|
| 255 |
-
async def _generate_async(self, payload: RequestPayload) -> ResearchResponse:
|
|
|
|
|
|
|
|
|
|
| 256 |
arguments = {
|
| 257 |
"payload": {
|
| 258 |
"session_id": payload.session_id,
|
|
@@ -261,6 +264,19 @@ class MCPResearchClient:
|
|
| 261 |
"location": payload.celebration_profile.location,
|
| 262 |
"budget_tier": payload.celebration_profile.budget_tier,
|
| 263 |
"constraints": payload.celebration_profile.constraints,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 264 |
}
|
| 265 |
}
|
| 266 |
|
|
@@ -340,12 +356,12 @@ class ResearchOrchestrator:
|
|
| 340 |
)
|
| 341 |
self._mock_provider = MockResearchProvider()
|
| 342 |
|
| 343 |
-
def generate_plan(self, payload_dict: Dict[str, Any]) -> ResearchResponse:
|
| 344 |
payload = RequestPayload.model_validate(payload_dict)
|
| 345 |
|
| 346 |
if self._mcp_client.available:
|
| 347 |
try:
|
| 348 |
-
return self._mcp_client.generate(payload)
|
| 349 |
except Exception as exc:
|
| 350 |
return self._mock_provider.generate(payload, reason=f"MCP research call failed: {exc}")
|
| 351 |
|
|
|
|
| 249 |
def available(self) -> bool:
|
| 250 |
return bool(self._endpoint or self._command or DEFAULT_MCP_RESEARCH_COMMAND)
|
| 251 |
|
| 252 |
+
def generate(self, payload: RequestPayload, idea_context: dict | None = None) -> ResearchResponse:
|
| 253 |
+
return anyio.run(lambda: self._generate_async(payload, idea_context))
|
| 254 |
|
| 255 |
+
async def _generate_async(self, payload: RequestPayload, idea_context: dict | None = None) -> ResearchResponse:
|
| 256 |
+
# Build rich context for research agent
|
| 257 |
+
honoree = payload.celebration_profile.honoree_profile
|
| 258 |
+
|
| 259 |
arguments = {
|
| 260 |
"payload": {
|
| 261 |
"session_id": payload.session_id,
|
|
|
|
| 264 |
"location": payload.celebration_profile.location,
|
| 265 |
"budget_tier": payload.celebration_profile.budget_tier,
|
| 266 |
"constraints": payload.celebration_profile.constraints,
|
| 267 |
+
# New rich context fields
|
| 268 |
+
"occasion": payload.celebration_profile.occasion,
|
| 269 |
+
"event_date": payload.celebration_profile.date,
|
| 270 |
+
"honoree": {
|
| 271 |
+
"age_range": honoree.age_range,
|
| 272 |
+
"interests": honoree.interests,
|
| 273 |
+
"preferences": honoree.preferences,
|
| 274 |
+
"guest_count": honoree.guest_count,
|
| 275 |
+
},
|
| 276 |
+
# Pass full idea context if available
|
| 277 |
+
"idea": idea_context,
|
| 278 |
+
# Pass conversation history for context
|
| 279 |
+
"conversation_history": payload.interaction_context.history[-10:], # Last 10 messages
|
| 280 |
}
|
| 281 |
}
|
| 282 |
|
|
|
|
| 356 |
)
|
| 357 |
self._mock_provider = MockResearchProvider()
|
| 358 |
|
| 359 |
+
def generate_plan(self, payload_dict: Dict[str, Any], idea_context: dict | None = None) -> ResearchResponse:
|
| 360 |
payload = RequestPayload.model_validate(payload_dict)
|
| 361 |
|
| 362 |
if self._mcp_client.available:
|
| 363 |
try:
|
| 364 |
+
return self._mcp_client.generate(payload, idea_context)
|
| 365 |
except Exception as exc:
|
| 366 |
return self._mock_provider.generate(payload, reason=f"MCP research call failed: {exc}")
|
| 367 |
|
mcp_servers/research/config.py
CHANGED
|
@@ -1,9 +1,14 @@
|
|
| 1 |
from pydantic_settings import BaseSettings
|
| 2 |
|
| 3 |
class ResearchSettings(BaseSettings):
|
| 4 |
-
provider: str = "
|
| 5 |
max_items: int = 3
|
| 6 |
search_safe: str = "moderate" # strict/moderate/off
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
class Config:
|
| 9 |
env_prefix = "CELEBRATE_AI_RESEARCH_"
|
|
|
|
| 1 |
from pydantic_settings import BaseSettings
|
| 2 |
|
| 3 |
class ResearchSettings(BaseSettings):
|
| 4 |
+
provider: str = "agent"
|
| 5 |
max_items: int = 3
|
| 6 |
search_safe: str = "moderate" # strict/moderate/off
|
| 7 |
+
nebius_base_url: str | None = None
|
| 8 |
+
nebius_api_key: str | None = None
|
| 9 |
+
nebius_model: str = "meta-llama/llama-3.1-8b-instruct"
|
| 10 |
+
nebius_temperature: float = 0.6
|
| 11 |
+
max_tool_calls: int = 4
|
| 12 |
|
| 13 |
class Config:
|
| 14 |
env_prefix = "CELEBRATE_AI_RESEARCH_"
|
mcp_servers/research/provider.py
CHANGED
|
@@ -1,33 +1,1306 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
|
|
|
| 3 |
import logging
|
| 4 |
-
import
|
|
|
|
|
|
|
| 5 |
from datetime import datetime, timezone
|
| 6 |
-
from typing import List
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
try:
|
| 9 |
from ddgs import DDGS
|
| 10 |
-
except Exception:
|
| 11 |
-
DDGS = None
|
| 12 |
|
| 13 |
from .config import settings
|
| 14 |
-
from .tool_schemas import
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
logger = logging.getLogger(__name__)
|
| 17 |
|
| 18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
class BaseResearchProvider:
|
|
|
|
|
|
|
| 20 |
def generate(self, payload: ResearchRequest) -> ResearchResponse:
|
| 21 |
raise NotImplementedError
|
| 22 |
|
| 23 |
|
| 24 |
class MockResearchProvider(BaseResearchProvider):
|
|
|
|
|
|
|
| 25 |
def generate(self, payload: ResearchRequest) -> ResearchResponse:
|
| 26 |
-
plans
|
| 27 |
PlanItem(
|
| 28 |
title="Vendor shortlist",
|
| 29 |
summary="List 2-3 vendors that fit the brief.",
|
| 30 |
-
steps=["Identify 3 local options", "Compare pricing
|
| 31 |
estimated_budget="Varies",
|
| 32 |
duration="1-2 days",
|
| 33 |
sources=[],
|
|
@@ -42,7 +1315,8 @@ class MockResearchProvider(BaseResearchProvider):
|
|
| 42 |
sources=[],
|
| 43 |
links=[],
|
| 44 |
),
|
| 45 |
-
][:
|
|
|
|
| 46 |
return ResearchResponse(
|
| 47 |
session_id=payload.session_id,
|
| 48 |
plans=plans,
|
|
@@ -55,42 +1329,64 @@ class MockResearchProvider(BaseResearchProvider):
|
|
| 55 |
|
| 56 |
|
| 57 |
class DuckDuckGoResearchProvider(BaseResearchProvider):
|
| 58 |
-
"""
|
| 59 |
-
|
| 60 |
-
def __init__(self, max_items: int, safe: str = "moderate")
|
| 61 |
self.max_items = max_items
|
| 62 |
self.safe = safe
|
| 63 |
if DDGS is None:
|
| 64 |
-
raise RuntimeError("
|
| 65 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
def generate(self, payload: ResearchRequest) -> ResearchResponse:
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
|
|
|
|
|
|
|
|
|
| 70 |
try:
|
| 71 |
with DDGS() as ddgs:
|
| 72 |
results = ddgs.text(query, safesearch=self.safe, max_results=self.max_items)
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
except Exception as exc:
|
| 79 |
logger.error("DuckDuckGo search failed: %s", exc)
|
| 80 |
-
|
| 81 |
plan = PlanItem(
|
| 82 |
-
title="
|
| 83 |
-
summary=f"
|
| 84 |
steps=[
|
| 85 |
-
"Review the links for
|
| 86 |
-
"
|
| 87 |
-
"
|
|
|
|
| 88 |
],
|
| 89 |
estimated_budget="Varies",
|
| 90 |
-
duration="
|
| 91 |
-
sources=
|
| 92 |
-
links=links[:
|
| 93 |
)
|
|
|
|
| 94 |
return ResearchResponse(
|
| 95 |
session_id=payload.session_id,
|
| 96 |
plans=[plan],
|
|
@@ -98,28 +1394,41 @@ class DuckDuckGoResearchProvider(BaseResearchProvider):
|
|
| 98 |
"source": "duckduckgo",
|
| 99 |
"provider": "duckduckgo",
|
| 100 |
"query": query,
|
|
|
|
| 101 |
"timestamp": datetime.now(tz=timezone.utc).isoformat(),
|
| 102 |
},
|
| 103 |
)
|
| 104 |
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
|
|
|
| 113 |
|
| 114 |
|
| 115 |
def get_provider() -> BaseResearchProvider:
|
|
|
|
| 116 |
name = settings.provider.lower()
|
| 117 |
logger.info("Research provider selected: %s", name)
|
| 118 |
-
|
| 119 |
if name == "mock":
|
| 120 |
return MockResearchProvider()
|
| 121 |
-
|
| 122 |
if name == "duckduckgo":
|
| 123 |
-
return DuckDuckGoResearchProvider(
|
| 124 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 125 |
raise ValueError(f"Unsupported research provider: {settings.provider}")
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Research Agent Provider - ReAct-style agent for celebration planning research.
|
| 3 |
+
|
| 4 |
+
This module implements a research agent that uses a ReAct (Reasoning + Acting) cycle
|
| 5 |
+
to plan celebrations, search for products, generate invites, and create execution plans.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
from __future__ import annotations
|
| 9 |
|
| 10 |
+
import json
|
| 11 |
import logging
|
| 12 |
+
import re
|
| 13 |
+
from abc import ABC, abstractmethod
|
| 14 |
+
from dataclasses import dataclass, field
|
| 15 |
from datetime import datetime, timezone
|
| 16 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 17 |
+
from urllib.parse import urljoin, urlparse
|
| 18 |
+
|
| 19 |
+
import httpx
|
| 20 |
+
from bs4 import BeautifulSoup
|
| 21 |
|
| 22 |
try:
|
| 23 |
from ddgs import DDGS
|
| 24 |
+
except Exception:
|
| 25 |
+
DDGS = None
|
| 26 |
|
| 27 |
from .config import settings
|
| 28 |
+
from .tool_schemas import (
|
| 29 |
+
PlanItem, ResearchRequest, ResearchResponse,
|
| 30 |
+
SourceLink, ShoppingItem, IdeaContext
|
| 31 |
+
)
|
| 32 |
|
| 33 |
logger = logging.getLogger(__name__)
|
| 34 |
|
| 35 |
|
| 36 |
+
# =============================================================================
|
| 37 |
+
# Tool Infrastructure
|
| 38 |
+
# =============================================================================
|
| 39 |
+
|
| 40 |
+
@dataclass
|
| 41 |
+
class ToolResult:
|
| 42 |
+
"""Result from a tool execution."""
|
| 43 |
+
success: bool
|
| 44 |
+
data: Dict[str, Any] = field(default_factory=dict)
|
| 45 |
+
error: Optional[str] = None
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class BaseTool(ABC):
|
| 49 |
+
"""Base class for all research tools."""
|
| 50 |
+
|
| 51 |
+
name: str
|
| 52 |
+
description: str
|
| 53 |
+
|
| 54 |
+
@abstractmethod
|
| 55 |
+
def execute(self, **kwargs) -> ToolResult:
|
| 56 |
+
"""Execute the tool with given parameters."""
|
| 57 |
+
pass
|
| 58 |
+
|
| 59 |
+
def get_schema(self) -> Dict[str, Any]:
|
| 60 |
+
"""Return JSON schema for tool parameters."""
|
| 61 |
+
return {}
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class WebSearchTool(BaseTool):
|
| 65 |
+
"""Search the web using DuckDuckGo."""
|
| 66 |
+
|
| 67 |
+
name = "web_search"
|
| 68 |
+
description = (
|
| 69 |
+
"Search the web for information about celebration planning, venues, activities, "
|
| 70 |
+
"decorations, food ideas, or any general information. Returns search results with "
|
| 71 |
+
"titles, snippets, and URLs."
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
# Domains to filter out as they rarely have relevant party planning content
|
| 75 |
+
IRRELEVANT_DOMAINS = [
|
| 76 |
+
"microsoft.com", "apple.com", "support.", "answers.microsoft",
|
| 77 |
+
"stackoverflow.com", "github.com", "gitlab.com", "bitbucket.org",
|
| 78 |
+
"techsupport", "forum.nvidia", "developer.mozilla",
|
| 79 |
+
]
|
| 80 |
+
|
| 81 |
+
# URL patterns that indicate ads or tracking (not actual content)
|
| 82 |
+
AD_TRACKING_PATTERNS = [
|
| 83 |
+
"/aclick?", "/adclick?", "doubleclick.", "googleadservices.",
|
| 84 |
+
"go.redirectingat.", "tracking.", "click.linksynergy.",
|
| 85 |
+
"bing.com/aclick", "r.search.yahoo.com/cbclk",
|
| 86 |
+
]
|
| 87 |
+
|
| 88 |
+
def __init__(self, max_results: int = 5, safe: str = "moderate"):
|
| 89 |
+
self.max_results = max_results
|
| 90 |
+
self.safe = safe
|
| 91 |
+
if DDGS is None:
|
| 92 |
+
raise RuntimeError("ddgs (DuckDuckGo search) is not installed")
|
| 93 |
+
|
| 94 |
+
def _is_relevant_result(self, url: str, title: str, snippet: str) -> bool:
|
| 95 |
+
"""Check if a search result is likely relevant to party planning."""
|
| 96 |
+
url_lower = url.lower()
|
| 97 |
+
|
| 98 |
+
# Filter out known irrelevant domains
|
| 99 |
+
for domain in self.IRRELEVANT_DOMAINS:
|
| 100 |
+
if domain in url_lower:
|
| 101 |
+
return False
|
| 102 |
+
|
| 103 |
+
# Filter out ad/tracking URLs
|
| 104 |
+
for pattern in self.AD_TRACKING_PATTERNS:
|
| 105 |
+
if pattern in url_lower:
|
| 106 |
+
return False
|
| 107 |
+
|
| 108 |
+
# Check for relevance indicators
|
| 109 |
+
combined = (title + " " + snippet).lower()
|
| 110 |
+
relevant_keywords = [
|
| 111 |
+
"party", "celebration", "ideas", "decor", "activity", "game",
|
| 112 |
+
"host", "plan", "diy", "craft", "food", "drink", "theme",
|
| 113 |
+
"birthday", "event", "invite", "guest", "fun", "entertainment",
|
| 114 |
+
]
|
| 115 |
+
|
| 116 |
+
# At least one relevant keyword should be present
|
| 117 |
+
return any(kw in combined for kw in relevant_keywords)
|
| 118 |
+
|
| 119 |
+
def execute(self, query: str, **kwargs) -> ToolResult:
|
| 120 |
+
"""Execute web search."""
|
| 121 |
+
if not query or not query.strip():
|
| 122 |
+
return ToolResult(success=False, error="Query cannot be empty")
|
| 123 |
+
|
| 124 |
+
try:
|
| 125 |
+
results = []
|
| 126 |
+
# Request more results than needed to allow for filtering
|
| 127 |
+
with DDGS() as ddgs:
|
| 128 |
+
search_results = ddgs.text(
|
| 129 |
+
query.strip(),
|
| 130 |
+
safesearch=self.safe,
|
| 131 |
+
max_results=self.max_results * 3
|
| 132 |
+
)
|
| 133 |
+
for item in search_results or []:
|
| 134 |
+
url = item.get("href", "")
|
| 135 |
+
title = item.get("title", "")
|
| 136 |
+
snippet = item.get("body", "")
|
| 137 |
+
|
| 138 |
+
# Filter for relevance
|
| 139 |
+
if self._is_relevant_result(url, title, snippet):
|
| 140 |
+
results.append({
|
| 141 |
+
"title": title,
|
| 142 |
+
"snippet": snippet,
|
| 143 |
+
"url": url,
|
| 144 |
+
})
|
| 145 |
+
|
| 146 |
+
if len(results) >= self.max_results:
|
| 147 |
+
break
|
| 148 |
+
|
| 149 |
+
return ToolResult(
|
| 150 |
+
success=True,
|
| 151 |
+
data={"query": query, "results": results, "count": len(results)}
|
| 152 |
+
)
|
| 153 |
+
except Exception as exc:
|
| 154 |
+
logger.error("Web search failed: %s", exc)
|
| 155 |
+
return ToolResult(success=False, error=str(exc))
|
| 156 |
+
|
| 157 |
+
def get_schema(self) -> Dict[str, Any]:
|
| 158 |
+
return {
|
| 159 |
+
"type": "object",
|
| 160 |
+
"properties": {
|
| 161 |
+
"query": {
|
| 162 |
+
"type": "string",
|
| 163 |
+
"description": "Search query for finding information"
|
| 164 |
+
}
|
| 165 |
+
},
|
| 166 |
+
"required": ["query"]
|
| 167 |
+
}
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
class PageVisitorTool(BaseTool):
|
| 171 |
+
"""Visit a web page and extract its content."""
|
| 172 |
+
|
| 173 |
+
name = "visit_page"
|
| 174 |
+
description = (
|
| 175 |
+
"Visit a webpage URL and extract its main text content. Use this to get detailed "
|
| 176 |
+
"information from a search result. Returns the page title, description, and main text."
|
| 177 |
+
)
|
| 178 |
+
|
| 179 |
+
def __init__(self, timeout: int = 15):
|
| 180 |
+
self.timeout = timeout
|
| 181 |
+
self._client = httpx.Client(
|
| 182 |
+
timeout=timeout,
|
| 183 |
+
follow_redirects=True,
|
| 184 |
+
headers={
|
| 185 |
+
"User-Agent": (
|
| 186 |
+
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 "
|
| 187 |
+
"(KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36"
|
| 188 |
+
)
|
| 189 |
+
}
|
| 190 |
+
)
|
| 191 |
+
|
| 192 |
+
def execute(self, url: str, **kwargs) -> ToolResult:
|
| 193 |
+
"""Visit a page and extract content."""
|
| 194 |
+
if not url or not url.strip():
|
| 195 |
+
return ToolResult(success=False, error="URL cannot be empty")
|
| 196 |
+
|
| 197 |
+
try:
|
| 198 |
+
response = self._client.get(url.strip())
|
| 199 |
+
response.raise_for_status()
|
| 200 |
+
|
| 201 |
+
soup = BeautifulSoup(response.text, "html.parser")
|
| 202 |
+
|
| 203 |
+
# Remove script and style elements
|
| 204 |
+
for element in soup(["script", "style", "nav", "footer", "header", "aside"]):
|
| 205 |
+
element.decompose()
|
| 206 |
+
|
| 207 |
+
# Extract title
|
| 208 |
+
title = ""
|
| 209 |
+
title_tag = soup.find("title")
|
| 210 |
+
if title_tag:
|
| 211 |
+
title = title_tag.get_text(strip=True)
|
| 212 |
+
|
| 213 |
+
# Extract description
|
| 214 |
+
description = ""
|
| 215 |
+
meta_desc = soup.find("meta", attrs={"name": "description"})
|
| 216 |
+
if meta_desc:
|
| 217 |
+
description = meta_desc.get("content", "")
|
| 218 |
+
|
| 219 |
+
# Extract main text content
|
| 220 |
+
text_parts = []
|
| 221 |
+
for tag in soup.find_all(["h1", "h2", "h3", "p", "li"]):
|
| 222 |
+
text = tag.get_text(strip=True)
|
| 223 |
+
if text and len(text) > 20:
|
| 224 |
+
text_parts.append(text)
|
| 225 |
+
|
| 226 |
+
# Limit content length
|
| 227 |
+
content = "\n".join(text_parts[:30])
|
| 228 |
+
if len(content) > 3000:
|
| 229 |
+
content = content[:3000] + "..."
|
| 230 |
+
|
| 231 |
+
return ToolResult(
|
| 232 |
+
success=True,
|
| 233 |
+
data={
|
| 234 |
+
"url": url,
|
| 235 |
+
"title": title,
|
| 236 |
+
"description": description,
|
| 237 |
+
"content": content,
|
| 238 |
+
}
|
| 239 |
+
)
|
| 240 |
+
except Exception as exc:
|
| 241 |
+
logger.error("Page visit failed for %s: %s", url, exc)
|
| 242 |
+
return ToolResult(success=False, error=f"Failed to fetch page: {str(exc)}")
|
| 243 |
+
|
| 244 |
+
def get_schema(self) -> Dict[str, Any]:
|
| 245 |
+
return {
|
| 246 |
+
"type": "object",
|
| 247 |
+
"properties": {
|
| 248 |
+
"url": {
|
| 249 |
+
"type": "string",
|
| 250 |
+
"description": "URL of the webpage to visit"
|
| 251 |
+
}
|
| 252 |
+
},
|
| 253 |
+
"required": ["url"]
|
| 254 |
+
}
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
class AmazonSearchTool(BaseTool):
|
| 258 |
+
"""Search for products on Amazon."""
|
| 259 |
+
|
| 260 |
+
name = "amazon_search"
|
| 261 |
+
description = (
|
| 262 |
+
"Search for products on Amazon. Use this for finding gifts, party supplies, "
|
| 263 |
+
"decorations, balloons, banners, cakes, or any items to buy for the celebration."
|
| 264 |
+
)
|
| 265 |
+
|
| 266 |
+
# Patterns that indicate a product page vs search/category page
|
| 267 |
+
PRODUCT_URL_PATTERNS = ["/dp/", "/gp/product/", "/gp/aw/d/"]
|
| 268 |
+
|
| 269 |
+
def __init__(self, max_results: int = 5, safe: str = "moderate"):
|
| 270 |
+
self.max_results = max_results
|
| 271 |
+
self.safe = safe
|
| 272 |
+
if DDGS is None:
|
| 273 |
+
raise RuntimeError("ddgs is not installed")
|
| 274 |
+
|
| 275 |
+
def _is_product_url(self, url: str) -> bool:
|
| 276 |
+
"""Check if URL is an actual product page, not a search or category page."""
|
| 277 |
+
url_lower = url.lower()
|
| 278 |
+
# Must be Amazon domain
|
| 279 |
+
if "amazon" not in url_lower:
|
| 280 |
+
return False
|
| 281 |
+
# Check for product URL patterns
|
| 282 |
+
return any(pattern in url_lower for pattern in self.PRODUCT_URL_PATTERNS)
|
| 283 |
+
|
| 284 |
+
def _is_search_or_category_url(self, url: str) -> bool:
|
| 285 |
+
"""Check if URL is a search or category page."""
|
| 286 |
+
url_lower = url.lower()
|
| 287 |
+
search_patterns = ["/s?", "/s/", "/b?", "/b/", "node=", "keywords="]
|
| 288 |
+
return any(pattern in url_lower for pattern in search_patterns)
|
| 289 |
+
|
| 290 |
+
def execute(self, query: str, **kwargs) -> ToolResult:
|
| 291 |
+
"""Search Amazon via DuckDuckGo site search."""
|
| 292 |
+
if not query or not query.strip():
|
| 293 |
+
return ToolResult(success=False, error="Query cannot be empty")
|
| 294 |
+
|
| 295 |
+
try:
|
| 296 |
+
# Search for specific products
|
| 297 |
+
search_query = f"site:amazon.com OR site:amazon.in {query.strip()}"
|
| 298 |
+
results = []
|
| 299 |
+
product_results = []
|
| 300 |
+
category_results = []
|
| 301 |
+
|
| 302 |
+
with DDGS() as ddgs:
|
| 303 |
+
# Request more results to filter for product pages
|
| 304 |
+
search_results = ddgs.text(
|
| 305 |
+
search_query,
|
| 306 |
+
safesearch=self.safe,
|
| 307 |
+
max_results=self.max_results * 4 # Get more to filter
|
| 308 |
+
)
|
| 309 |
+
for item in search_results or []:
|
| 310 |
+
href = item.get("href", "")
|
| 311 |
+
if "amazon" in href.lower():
|
| 312 |
+
# Extract price if present in snippet or title
|
| 313 |
+
snippet = item.get("body", "")
|
| 314 |
+
title = item.get("title", "")
|
| 315 |
+
price = self._extract_price(title + " " + snippet)
|
| 316 |
+
|
| 317 |
+
product_info = {
|
| 318 |
+
"title": title,
|
| 319 |
+
"snippet": snippet,
|
| 320 |
+
"url": href,
|
| 321 |
+
"price": price,
|
| 322 |
+
}
|
| 323 |
+
|
| 324 |
+
# Prioritize actual product pages over search/category pages
|
| 325 |
+
if self._is_product_url(href):
|
| 326 |
+
product_results.append(product_info)
|
| 327 |
+
elif not self._is_search_or_category_url(href):
|
| 328 |
+
category_results.append(product_info)
|
| 329 |
+
|
| 330 |
+
# Prefer product pages, then category pages
|
| 331 |
+
results = product_results[:self.max_results]
|
| 332 |
+
if len(results) < self.max_results:
|
| 333 |
+
results.extend(category_results[:self.max_results - len(results)])
|
| 334 |
+
|
| 335 |
+
return ToolResult(
|
| 336 |
+
success=True,
|
| 337 |
+
data={"query": query, "products": results, "count": len(results)}
|
| 338 |
+
)
|
| 339 |
+
except Exception as exc:
|
| 340 |
+
logger.error("Amazon search failed: %s", exc)
|
| 341 |
+
return ToolResult(success=False, error=str(exc))
|
| 342 |
+
|
| 343 |
+
def _extract_price(self, text: str) -> Optional[str]:
|
| 344 |
+
"""Extract price from text."""
|
| 345 |
+
patterns = [
|
| 346 |
+
r'₹\s*[\d,]+(?:\.\d{2})?',
|
| 347 |
+
r'\$\s*[\d,]+(?:\.\d{2})?',
|
| 348 |
+
r'Rs\.?\s*[\d,]+(?:\.\d{2})?',
|
| 349 |
+
]
|
| 350 |
+
for pattern in patterns:
|
| 351 |
+
match = re.search(pattern, text)
|
| 352 |
+
if match:
|
| 353 |
+
return match.group(0)
|
| 354 |
+
return None
|
| 355 |
+
|
| 356 |
+
def get_schema(self) -> Dict[str, Any]:
|
| 357 |
+
return {
|
| 358 |
+
"type": "object",
|
| 359 |
+
"properties": {
|
| 360 |
+
"query": {
|
| 361 |
+
"type": "string",
|
| 362 |
+
"description": "Product search query (e.g., 'birthday decorations', 'party balloons')"
|
| 363 |
+
}
|
| 364 |
+
},
|
| 365 |
+
"required": ["query"]
|
| 366 |
+
}
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
class MessageGeneratorTool(BaseTool):
|
| 370 |
+
"""Generate invitation messages, birthday wishes, or greeting cards."""
|
| 371 |
+
|
| 372 |
+
name = "generate_message"
|
| 373 |
+
description = (
|
| 374 |
+
"Generate personalized messages like party invitations, birthday wishes, "
|
| 375 |
+
"thank you notes, or greeting card content. Provide the occasion, recipient, "
|
| 376 |
+
"and any specific details to include."
|
| 377 |
+
)
|
| 378 |
+
|
| 379 |
+
def __init__(self, llm_client: Optional[httpx.Client] = None):
|
| 380 |
+
self._client = llm_client or httpx.Client(timeout=60)
|
| 381 |
+
|
| 382 |
+
def execute(
|
| 383 |
+
self,
|
| 384 |
+
occasion: str,
|
| 385 |
+
recipient: str = "Guest",
|
| 386 |
+
tone: str = "warm and friendly",
|
| 387 |
+
details: str = "",
|
| 388 |
+
**kwargs
|
| 389 |
+
) -> ToolResult:
|
| 390 |
+
"""Generate a message using LLM or templates."""
|
| 391 |
+
if not occasion:
|
| 392 |
+
return ToolResult(success=False, error="Occasion is required")
|
| 393 |
+
|
| 394 |
+
# Try LLM generation first
|
| 395 |
+
if settings.nebius_base_url and settings.nebius_api_key:
|
| 396 |
+
try:
|
| 397 |
+
message = self._generate_with_llm(occasion, recipient, tone, details)
|
| 398 |
+
if message:
|
| 399 |
+
return ToolResult(
|
| 400 |
+
success=True,
|
| 401 |
+
data={
|
| 402 |
+
"message": message,
|
| 403 |
+
"occasion": occasion,
|
| 404 |
+
"recipient": recipient,
|
| 405 |
+
"generated_by": "llm"
|
| 406 |
+
}
|
| 407 |
+
)
|
| 408 |
+
except Exception as exc:
|
| 409 |
+
logger.error("LLM message generation failed: %s", exc)
|
| 410 |
+
|
| 411 |
+
# Fallback to template-based generation
|
| 412 |
+
message = self._generate_template_message(occasion, recipient, tone, details)
|
| 413 |
+
return ToolResult(
|
| 414 |
+
success=True,
|
| 415 |
+
data={
|
| 416 |
+
"message": message,
|
| 417 |
+
"occasion": occasion,
|
| 418 |
+
"recipient": recipient,
|
| 419 |
+
"generated_by": "template"
|
| 420 |
+
}
|
| 421 |
+
)
|
| 422 |
+
|
| 423 |
+
def _generate_with_llm(
|
| 424 |
+
self, occasion: str, recipient: str, tone: str, details: str
|
| 425 |
+
) -> Optional[str]:
|
| 426 |
+
"""Generate message using LLM."""
|
| 427 |
+
system_prompt = (
|
| 428 |
+
"You are a creative writer specializing in celebration messages. "
|
| 429 |
+
"Write warm, engaging messages that are personal and memorable. "
|
| 430 |
+
"Keep messages concise (60-100 words) but heartfelt."
|
| 431 |
+
)
|
| 432 |
+
user_prompt = (
|
| 433 |
+
f"Write a {tone} message for a {occasion}.\n"
|
| 434 |
+
f"Recipient: {recipient}\n"
|
| 435 |
+
f"Additional details: {details or 'None'}\n"
|
| 436 |
+
"Write only the message, no explanations."
|
| 437 |
+
)
|
| 438 |
+
|
| 439 |
+
body = {
|
| 440 |
+
"model": settings.nebius_model,
|
| 441 |
+
"temperature": settings.nebius_temperature,
|
| 442 |
+
"messages": [
|
| 443 |
+
{"role": "system", "content": system_prompt},
|
| 444 |
+
{"role": "user", "content": user_prompt},
|
| 445 |
+
],
|
| 446 |
+
}
|
| 447 |
+
|
| 448 |
+
resp = self._client.post(
|
| 449 |
+
f"{settings.nebius_base_url.rstrip('/')}/chat/completions",
|
| 450 |
+
headers={
|
| 451 |
+
"Authorization": f"Bearer {settings.nebius_api_key}",
|
| 452 |
+
"Content-Type": "application/json",
|
| 453 |
+
},
|
| 454 |
+
json=body,
|
| 455 |
+
)
|
| 456 |
+
resp.raise_for_status()
|
| 457 |
+
data = resp.json()
|
| 458 |
+
return data["choices"][0]["message"]["content"]
|
| 459 |
+
|
| 460 |
+
def _generate_template_message(
|
| 461 |
+
self, occasion: str, recipient: str, tone: str, details: str
|
| 462 |
+
) -> str:
|
| 463 |
+
"""Generate message using templates."""
|
| 464 |
+
templates = {
|
| 465 |
+
"birthday": [
|
| 466 |
+
f"🎂 Happy Birthday, {recipient}! 🎉\n\nWishing you a day filled with joy, laughter, and all your favorite things. May this year bring you endless happiness and amazing adventures!\n\nCheers to you! 🥳",
|
| 467 |
+
f"🎈 Happy Birthday, {recipient}!\n\nAnother year older, another year wiser, and definitely another year more awesome! Hope your special day is as wonderful as you are.\n\nEnjoy your day! 🎁",
|
| 468 |
+
],
|
| 469 |
+
"party invitation": [
|
| 470 |
+
f"🎉 You're Invited!\n\nHey {recipient}!\n\nWe're throwing a celebration and it wouldn't be the same without you! Join us for an unforgettable time.\n\n{details if details else 'Details to follow!'}\n\nHope to see you there! 🥳",
|
| 471 |
+
],
|
| 472 |
+
"thank you": [
|
| 473 |
+
f"Dear {recipient},\n\nThank you so much for being part of our celebration! Your presence made the day even more special.\n\nWith gratitude,\n❤️",
|
| 474 |
+
],
|
| 475 |
+
}
|
| 476 |
+
|
| 477 |
+
occasion_lower = occasion.lower()
|
| 478 |
+
for key, msgs in templates.items():
|
| 479 |
+
if key in occasion_lower:
|
| 480 |
+
return msgs[0]
|
| 481 |
+
|
| 482 |
+
# Default template
|
| 483 |
+
return (
|
| 484 |
+
f"Dear {recipient},\n\n"
|
| 485 |
+
f"You're invited to celebrate with us! "
|
| 486 |
+
f"{'Details: ' + details if details else 'More details coming soon.'}\n\n"
|
| 487 |
+
f"Looking forward to seeing you!\n🎉"
|
| 488 |
+
)
|
| 489 |
+
|
| 490 |
+
def get_schema(self) -> Dict[str, Any]:
|
| 491 |
+
return {
|
| 492 |
+
"type": "object",
|
| 493 |
+
"properties": {
|
| 494 |
+
"occasion": {
|
| 495 |
+
"type": "string",
|
| 496 |
+
"description": "Type of occasion (birthday, party invitation, thank you, etc.)"
|
| 497 |
+
},
|
| 498 |
+
"recipient": {
|
| 499 |
+
"type": "string",
|
| 500 |
+
"description": "Name of the recipient"
|
| 501 |
+
},
|
| 502 |
+
"tone": {
|
| 503 |
+
"type": "string",
|
| 504 |
+
"description": "Tone of the message (warm, formal, funny, casual)"
|
| 505 |
+
},
|
| 506 |
+
"details": {
|
| 507 |
+
"type": "string",
|
| 508 |
+
"description": "Additional details to include (date, time, venue, etc.)"
|
| 509 |
+
}
|
| 510 |
+
},
|
| 511 |
+
"required": ["occasion"]
|
| 512 |
+
}
|
| 513 |
+
|
| 514 |
+
|
| 515 |
+
# =============================================================================
|
| 516 |
+
# Agent State and ReAct Implementation
|
| 517 |
+
# =============================================================================
|
| 518 |
+
|
| 519 |
+
@dataclass
|
| 520 |
+
class AgentState:
|
| 521 |
+
"""State maintained during agent execution."""
|
| 522 |
+
idea: str # The idea title
|
| 523 |
+
context: Dict[str, Any] # location, budget, constraints, etc.
|
| 524 |
+
|
| 525 |
+
# Full idea details from ideation (summary, highlights, next_steps, etc.)
|
| 526 |
+
idea_details: Optional[Dict[str, Any]] = None
|
| 527 |
+
|
| 528 |
+
# User profile and preferences
|
| 529 |
+
honoree: Optional[Dict[str, Any]] = None
|
| 530 |
+
occasion: Optional[str] = None
|
| 531 |
+
|
| 532 |
+
messages: List[Dict[str, str]] = field(default_factory=list)
|
| 533 |
+
tool_results: List[Dict[str, Any]] = field(default_factory=list)
|
| 534 |
+
iteration: int = 0
|
| 535 |
+
max_iterations: int = 5
|
| 536 |
+
final_response: Optional[str] = None
|
| 537 |
+
|
| 538 |
+
|
| 539 |
+
class ReActResearchAgent:
|
| 540 |
+
"""
|
| 541 |
+
ReAct-style research agent for celebration planning.
|
| 542 |
+
|
| 543 |
+
Uses a Reasoning + Acting cycle:
|
| 544 |
+
1. Reason about what information is needed
|
| 545 |
+
2. Select and execute appropriate tools
|
| 546 |
+
3. Observe results and reason about next steps
|
| 547 |
+
4. Repeat until sufficient information gathered
|
| 548 |
+
5. Synthesize final response
|
| 549 |
+
"""
|
| 550 |
+
|
| 551 |
+
def __init__(self):
|
| 552 |
+
self.tools: Dict[str, BaseTool] = {}
|
| 553 |
+
self._httpx_client = httpx.Client(timeout=60)
|
| 554 |
+
self._initialize_tools()
|
| 555 |
+
|
| 556 |
+
def _initialize_tools(self):
|
| 557 |
+
"""Initialize available tools."""
|
| 558 |
+
try:
|
| 559 |
+
self.tools["web_search"] = WebSearchTool(
|
| 560 |
+
max_results=settings.max_items,
|
| 561 |
+
safe=settings.search_safe
|
| 562 |
+
)
|
| 563 |
+
except Exception as exc:
|
| 564 |
+
logger.warning("WebSearchTool initialization failed: %s", exc)
|
| 565 |
+
|
| 566 |
+
try:
|
| 567 |
+
self.tools["visit_page"] = PageVisitorTool()
|
| 568 |
+
except Exception as exc:
|
| 569 |
+
logger.warning("PageVisitorTool initialization failed: %s", exc)
|
| 570 |
+
|
| 571 |
+
try:
|
| 572 |
+
self.tools["amazon_search"] = AmazonSearchTool(
|
| 573 |
+
max_results=settings.max_items,
|
| 574 |
+
safe=settings.search_safe
|
| 575 |
+
)
|
| 576 |
+
except Exception as exc:
|
| 577 |
+
logger.warning("AmazonSearchTool initialization failed: %s", exc)
|
| 578 |
+
|
| 579 |
+
self.tools["generate_message"] = MessageGeneratorTool(self._httpx_client)
|
| 580 |
+
|
| 581 |
+
logger.info("Initialized %d tools: %s", len(self.tools), list(self.tools.keys()))
|
| 582 |
+
|
| 583 |
+
def _get_tools_description(self) -> str:
|
| 584 |
+
"""Get formatted description of all available tools."""
|
| 585 |
+
lines = []
|
| 586 |
+
for name, tool in self.tools.items():
|
| 587 |
+
schema = tool.get_schema()
|
| 588 |
+
params = schema.get("properties", {})
|
| 589 |
+
param_str = ", ".join(
|
| 590 |
+
f"{k}: {v.get('description', 'no description')}"
|
| 591 |
+
for k, v in params.items()
|
| 592 |
+
)
|
| 593 |
+
lines.append(f"- {name}: {tool.description}\n Parameters: {param_str}")
|
| 594 |
+
return "\n".join(lines)
|
| 595 |
+
|
| 596 |
+
def _parse_idea_from_message(self, message: str) -> Tuple[str, str]:
|
| 597 |
+
"""
|
| 598 |
+
Extract the celebration idea and additional context from user message.
|
| 599 |
+
|
| 600 |
+
Handles formats like:
|
| 601 |
+
- "Plan this idea in detail: At-Home Studio Night. Additional context: ..."
|
| 602 |
+
- Direct idea names
|
| 603 |
+
"""
|
| 604 |
+
idea = message
|
| 605 |
+
context = ""
|
| 606 |
+
|
| 607 |
+
# Pattern 1: "Plan this idea in detail: X. Additional context: Y"
|
| 608 |
+
if "Plan this idea in detail:" in message:
|
| 609 |
+
try:
|
| 610 |
+
parts = message.split("Plan this idea in detail:")
|
| 611 |
+
if len(parts) > 1:
|
| 612 |
+
after_plan = parts[1]
|
| 613 |
+
if "Additional context:" in after_plan:
|
| 614 |
+
idea_parts = after_plan.split("Additional context:")
|
| 615 |
+
idea = idea_parts[0].strip().rstrip(".")
|
| 616 |
+
context = idea_parts[1].strip() if len(idea_parts) > 1 else ""
|
| 617 |
+
else:
|
| 618 |
+
idea = after_plan.strip().rstrip(".")
|
| 619 |
+
except Exception:
|
| 620 |
+
pass
|
| 621 |
+
|
| 622 |
+
# Clean up the idea
|
| 623 |
+
idea = idea.strip()
|
| 624 |
+
if not idea:
|
| 625 |
+
idea = message[:100] # Fallback to first 100 chars
|
| 626 |
+
|
| 627 |
+
return idea, context
|
| 628 |
+
|
| 629 |
+
def _call_llm(self, messages: List[Dict[str, str]]) -> Optional[str]:
|
| 630 |
+
"""Call LLM API for reasoning."""
|
| 631 |
+
if not settings.nebius_base_url or not settings.nebius_api_key:
|
| 632 |
+
return None
|
| 633 |
+
|
| 634 |
+
try:
|
| 635 |
+
body = {
|
| 636 |
+
"model": settings.nebius_model,
|
| 637 |
+
"temperature": settings.nebius_temperature,
|
| 638 |
+
"messages": messages,
|
| 639 |
+
}
|
| 640 |
+
|
| 641 |
+
resp = self._httpx_client.post(
|
| 642 |
+
f"{settings.nebius_base_url.rstrip('/')}/chat/completions",
|
| 643 |
+
headers={
|
| 644 |
+
"Authorization": f"Bearer {settings.nebius_api_key}",
|
| 645 |
+
"Content-Type": "application/json",
|
| 646 |
+
},
|
| 647 |
+
json=body,
|
| 648 |
+
)
|
| 649 |
+
resp.raise_for_status()
|
| 650 |
+
data = resp.json()
|
| 651 |
+
return data["choices"][0]["message"]["content"]
|
| 652 |
+
except Exception as exc:
|
| 653 |
+
logger.error("LLM call failed: %s", exc)
|
| 654 |
+
return None
|
| 655 |
+
|
| 656 |
+
def _plan_tool_calls(self, state: AgentState) -> List[Dict[str, Any]]:
|
| 657 |
+
"""
|
| 658 |
+
Use LLM to plan which tools to call, or use heuristics as fallback.
|
| 659 |
+
|
| 660 |
+
Returns list of tool calls like:
|
| 661 |
+
[{"tool": "web_search", "params": {"query": "..."}}]
|
| 662 |
+
"""
|
| 663 |
+
# Build system prompt
|
| 664 |
+
system_prompt = f"""You are a celebration planning research assistant. Your task is to help plan "{state.idea}".
|
| 665 |
+
|
| 666 |
+
Available tools:
|
| 667 |
+
{self._get_tools_description()}
|
| 668 |
+
|
| 669 |
+
Based on what you know so far, decide which tools to use to gather information.
|
| 670 |
+
Return a JSON object with "calls" array containing tool calls.
|
| 671 |
+
|
| 672 |
+
Example response:
|
| 673 |
+
{{"calls": [
|
| 674 |
+
{{"tool": "web_search", "params": {{"query": "birthday party at home ideas activities"}}}},
|
| 675 |
+
{{"tool": "amazon_search", "params": {{"query": "birthday party decorations supplies"}}}}
|
| 676 |
+
]}}
|
| 677 |
+
|
| 678 |
+
Keep queries specific and relevant to the celebration idea.
|
| 679 |
+
Maximum 3 tool calls per response."""
|
| 680 |
+
|
| 681 |
+
# Build user prompt with context
|
| 682 |
+
user_content = f"Celebration idea: {state.idea}\n"
|
| 683 |
+
if state.context:
|
| 684 |
+
user_content += f"Additional context: {json.dumps(state.context)}\n"
|
| 685 |
+
|
| 686 |
+
if state.tool_results:
|
| 687 |
+
user_content += "\nPrevious tool results summary:\n"
|
| 688 |
+
for result in state.tool_results[-3:]: # Last 3 results
|
| 689 |
+
user_content += f"- {result.get('tool')}: {result.get('summary', 'completed')}\n"
|
| 690 |
+
|
| 691 |
+
user_content += "\nWhat tools should we use next to gather useful information?"
|
| 692 |
+
|
| 693 |
+
# Try LLM planning
|
| 694 |
+
llm_response = self._call_llm([
|
| 695 |
+
{"role": "system", "content": system_prompt},
|
| 696 |
+
{"role": "user", "content": user_content},
|
| 697 |
+
])
|
| 698 |
+
|
| 699 |
+
if llm_response:
|
| 700 |
+
try:
|
| 701 |
+
# Try to parse JSON from response
|
| 702 |
+
# Handle markdown code blocks
|
| 703 |
+
json_match = re.search(r'\{[\s\S]*\}', llm_response)
|
| 704 |
+
if json_match:
|
| 705 |
+
parsed = json.loads(json_match.group())
|
| 706 |
+
calls = parsed.get("calls", [])
|
| 707 |
+
if calls:
|
| 708 |
+
return calls[:3] # Limit to 3 calls
|
| 709 |
+
except Exception as exc:
|
| 710 |
+
logger.warning("Failed to parse LLM tool planning response: %s", exc)
|
| 711 |
+
|
| 712 |
+
# Fallback: Use heuristics based on the idea
|
| 713 |
+
return self._heuristic_tool_planning(state)
|
| 714 |
+
|
| 715 |
+
def _heuristic_tool_planning(self, state: AgentState) -> List[Dict[str, Any]]:
|
| 716 |
+
"""Fallback heuristic-based tool planning when LLM is unavailable."""
|
| 717 |
+
idea_lower = state.idea.lower()
|
| 718 |
+
calls = []
|
| 719 |
+
|
| 720 |
+
# Extract additional context from idea details if available
|
| 721 |
+
idea_summary = ""
|
| 722 |
+
idea_highlights = []
|
| 723 |
+
if state.idea_details:
|
| 724 |
+
idea_summary = state.idea_details.get("summary", "")
|
| 725 |
+
idea_highlights = state.idea_details.get("highlights", [])
|
| 726 |
+
|
| 727 |
+
# Get interests from honoree if available
|
| 728 |
+
interests = []
|
| 729 |
+
if state.honoree:
|
| 730 |
+
interests = state.honoree.get("interests", [])
|
| 731 |
+
|
| 732 |
+
# First iteration: Do multiple targeted searches
|
| 733 |
+
if state.iteration == 0:
|
| 734 |
+
# Build a more specific search query using full context
|
| 735 |
+
idea_themes = self._extract_themes_from_idea(state.idea)
|
| 736 |
+
|
| 737 |
+
# Build search query with idea summary context
|
| 738 |
+
search_context = state.idea
|
| 739 |
+
if idea_summary:
|
| 740 |
+
# Extract key action words from summary
|
| 741 |
+
search_context = f"{state.idea} {idea_summary[:100]}"
|
| 742 |
+
|
| 743 |
+
# Web search for party/celebration ideas with better keywords
|
| 744 |
+
calls.append({
|
| 745 |
+
"tool": "web_search",
|
| 746 |
+
"params": {"query": f"how to plan {state.idea} party ideas activities guide"}
|
| 747 |
+
})
|
| 748 |
+
|
| 749 |
+
# Search for specific supplies based on highlights
|
| 750 |
+
amazon_terms = [idea_themes]
|
| 751 |
+
if idea_highlights:
|
| 752 |
+
# Extract shopping-related terms from highlights
|
| 753 |
+
for highlight in idea_highlights[:2]:
|
| 754 |
+
highlight_lower = highlight.lower()
|
| 755 |
+
if any(word in highlight_lower for word in ["station", "kit", "supplies", "decor", "activity"]):
|
| 756 |
+
amazon_terms.append(highlight[:30])
|
| 757 |
+
|
| 758 |
+
amazon_query = " ".join(amazon_terms[:3]) + " party supplies"
|
| 759 |
+
calls.append({
|
| 760 |
+
"tool": "amazon_search",
|
| 761 |
+
"params": {"query": amazon_query}
|
| 762 |
+
})
|
| 763 |
+
|
| 764 |
+
# Generate invitation message with more context
|
| 765 |
+
invite_details = []
|
| 766 |
+
if state.context.get("location"):
|
| 767 |
+
invite_details.append(f"Location: {state.context['location']}")
|
| 768 |
+
if state.context.get("event_date"):
|
| 769 |
+
invite_details.append(f"Date: {state.context['event_date']}")
|
| 770 |
+
if state.occasion:
|
| 771 |
+
invite_details.append(f"Occasion: {state.occasion}")
|
| 772 |
+
|
| 773 |
+
calls.append({
|
| 774 |
+
"tool": "generate_message",
|
| 775 |
+
"params": {
|
| 776 |
+
"occasion": f"{state.occasion or 'party'} invitation for {state.idea}",
|
| 777 |
+
"recipient": "Guest",
|
| 778 |
+
"tone": "fun and exciting",
|
| 779 |
+
"details": ". ".join(invite_details) if invite_details else "at my place"
|
| 780 |
+
}
|
| 781 |
+
})
|
| 782 |
+
|
| 783 |
+
# Second iteration: Do more specific searches based on first results
|
| 784 |
+
elif state.iteration == 1:
|
| 785 |
+
# Do a second, more specific search
|
| 786 |
+
calls.append({
|
| 787 |
+
"tool": "web_search",
|
| 788 |
+
"params": {"query": f"{state.idea} step by step planning checklist timeline"}
|
| 789 |
+
})
|
| 790 |
+
|
| 791 |
+
# Check if we have relevant URLs to visit from previous searches
|
| 792 |
+
for result in state.tool_results:
|
| 793 |
+
if result.get("tool") == "web_search" and result.get("success"):
|
| 794 |
+
data = result.get("data", {})
|
| 795 |
+
results = data.get("results", [])
|
| 796 |
+
# Visit pages that look relevant (filter out obvious junk)
|
| 797 |
+
for r in results[:3]:
|
| 798 |
+
url = r.get("url", "")
|
| 799 |
+
title = r.get("title", "").lower()
|
| 800 |
+
url_lower = url.lower()
|
| 801 |
+
|
| 802 |
+
# Skip irrelevant domains and tracking URLs
|
| 803 |
+
skip_patterns = [
|
| 804 |
+
"microsoft.com", "apple.com", "support.", "answers.",
|
| 805 |
+
"forum.", "stackoverflow", "github.com",
|
| 806 |
+
"/aclick?", "/adclick?", "doubleclick.", "tracking.",
|
| 807 |
+
"bing.com/aclick", "amazon.com/s?", "amazon.in/s?",
|
| 808 |
+
]
|
| 809 |
+
if url and not any(p in url_lower for p in skip_patterns):
|
| 810 |
+
# Check if title seems relevant
|
| 811 |
+
relevant_keywords = ["party", "celebration", "ideas", "how to",
|
| 812 |
+
"guide", "tips", "plan", "host", "decor",
|
| 813 |
+
"activity", "game", "diy", "craft"]
|
| 814 |
+
if any(kw in title for kw in relevant_keywords):
|
| 815 |
+
calls.append({
|
| 816 |
+
"tool": "visit_page",
|
| 817 |
+
"params": {"url": url}
|
| 818 |
+
})
|
| 819 |
+
break
|
| 820 |
+
|
| 821 |
+
return calls[:3]
|
| 822 |
+
|
| 823 |
+
def _extract_themes_from_idea(self, idea: str) -> str:
|
| 824 |
+
"""Extract key themes from the idea for better search queries."""
|
| 825 |
+
idea_lower = idea.lower()
|
| 826 |
+
|
| 827 |
+
# Map common celebration types to search themes
|
| 828 |
+
theme_mappings = {
|
| 829 |
+
"studio": "art craft painting",
|
| 830 |
+
"art": "art craft painting canvas",
|
| 831 |
+
"paint": "painting art canvas",
|
| 832 |
+
"movie": "movie night cinema film",
|
| 833 |
+
"game": "gaming video games board games",
|
| 834 |
+
"spa": "spa relaxation wellness",
|
| 835 |
+
"picnic": "outdoor picnic garden",
|
| 836 |
+
"bbq": "barbecue grill outdoor",
|
| 837 |
+
"karaoke": "karaoke music singing",
|
| 838 |
+
"dance": "dance music disco",
|
| 839 |
+
"birthday": "birthday celebration",
|
| 840 |
+
"anniversary": "anniversary romantic",
|
| 841 |
+
"graduation": "graduation celebration",
|
| 842 |
+
"baby shower": "baby shower celebration",
|
| 843 |
+
"bridal": "bridal shower wedding",
|
| 844 |
+
"cocktail": "cocktail drinks mixology",
|
| 845 |
+
"dinner": "dinner party elegant",
|
| 846 |
+
"brunch": "brunch morning",
|
| 847 |
+
"tea": "tea party afternoon",
|
| 848 |
+
"sports": "sports game viewing",
|
| 849 |
+
"trivia": "trivia quiz game night",
|
| 850 |
+
}
|
| 851 |
+
|
| 852 |
+
themes = []
|
| 853 |
+
for keyword, theme in theme_mappings.items():
|
| 854 |
+
if keyword in idea_lower:
|
| 855 |
+
themes.append(theme)
|
| 856 |
+
|
| 857 |
+
if themes:
|
| 858 |
+
return " ".join(themes[:2])
|
| 859 |
+
|
| 860 |
+
# Default to the idea itself cleaned up
|
| 861 |
+
return idea
|
| 862 |
+
|
| 863 |
+
def _execute_tool(self, tool_name: str, params: Dict[str, Any]) -> ToolResult:
|
| 864 |
+
"""Execute a single tool."""
|
| 865 |
+
tool = self.tools.get(tool_name)
|
| 866 |
+
if not tool:
|
| 867 |
+
return ToolResult(success=False, error=f"Unknown tool: {tool_name}")
|
| 868 |
+
|
| 869 |
+
try:
|
| 870 |
+
return tool.execute(**params)
|
| 871 |
+
except Exception as exc:
|
| 872 |
+
logger.error("Tool %s execution failed: %s", tool_name, exc)
|
| 873 |
+
return ToolResult(success=False, error=str(exc))
|
| 874 |
+
|
| 875 |
+
def _synthesize_response(self, state: AgentState) -> str:
|
| 876 |
+
"""
|
| 877 |
+
Synthesize a final user-friendly response from all gathered information.
|
| 878 |
+
Uses LLM if available, otherwise uses template-based synthesis.
|
| 879 |
+
"""
|
| 880 |
+
# Build context from tool results
|
| 881 |
+
context_parts = []
|
| 882 |
+
web_results = []
|
| 883 |
+
products = []
|
| 884 |
+
messages = []
|
| 885 |
+
page_contents = []
|
| 886 |
+
|
| 887 |
+
for result in state.tool_results:
|
| 888 |
+
tool = result.get("tool")
|
| 889 |
+
data = result.get("data", {})
|
| 890 |
+
|
| 891 |
+
if tool == "web_search" and data.get("results"):
|
| 892 |
+
for r in data["results"]:
|
| 893 |
+
web_results.append({
|
| 894 |
+
"title": r.get("title", ""),
|
| 895 |
+
"snippet": r.get("snippet", ""),
|
| 896 |
+
"url": r.get("url", "")
|
| 897 |
+
})
|
| 898 |
+
elif tool == "amazon_search" and data.get("products"):
|
| 899 |
+
products.extend(data["products"])
|
| 900 |
+
elif tool == "generate_message" and data.get("message"):
|
| 901 |
+
messages.append(data["message"])
|
| 902 |
+
elif tool == "visit_page" and data.get("content"):
|
| 903 |
+
page_contents.append({
|
| 904 |
+
"title": data.get("title", ""),
|
| 905 |
+
"content": data.get("content", "")[:500]
|
| 906 |
+
})
|
| 907 |
+
|
| 908 |
+
# Try LLM synthesis
|
| 909 |
+
if settings.nebius_base_url and settings.nebius_api_key:
|
| 910 |
+
synthesis_result = self._llm_synthesize(state, web_results, products, messages, page_contents)
|
| 911 |
+
if synthesis_result:
|
| 912 |
+
return synthesis_result
|
| 913 |
+
|
| 914 |
+
# Fallback to template-based synthesis
|
| 915 |
+
return self._template_synthesize(state, web_results, products, messages, page_contents)
|
| 916 |
+
|
| 917 |
+
def _llm_synthesize(
|
| 918 |
+
self,
|
| 919 |
+
state: AgentState,
|
| 920 |
+
web_results: List[Dict],
|
| 921 |
+
products: List[Dict],
|
| 922 |
+
messages: List[str],
|
| 923 |
+
page_contents: List[Dict]
|
| 924 |
+
) -> Optional[str]:
|
| 925 |
+
"""Use LLM to synthesize a comprehensive response."""
|
| 926 |
+
system_prompt = """You are a celebration planning assistant. Based on the research gathered,
|
| 927 |
+
create a comprehensive, actionable plan for the user.
|
| 928 |
+
|
| 929 |
+
Structure your response as:
|
| 930 |
+
1. **Overview** - Brief summary of the celebration idea
|
| 931 |
+
2. **Key Activities** - 3-5 suggested activities with details
|
| 932 |
+
3. **Shopping List** - Items to buy with links if available
|
| 933 |
+
4. **Timeline** - Suggested schedule for preparation and event day
|
| 934 |
+
5. **Sample Invitation** - If applicable
|
| 935 |
+
6. **Tips** - 2-3 practical tips
|
| 936 |
+
|
| 937 |
+
Be specific, practical, and enthusiastic!"""
|
| 938 |
+
|
| 939 |
+
user_content = f"Celebration idea: {state.idea}\n\n"
|
| 940 |
+
|
| 941 |
+
if web_results:
|
| 942 |
+
user_content += "Research findings:\n"
|
| 943 |
+
for r in web_results[:5]:
|
| 944 |
+
user_content += f"- {r['title']}: {r['snippet'][:200]}\n"
|
| 945 |
+
|
| 946 |
+
if page_contents:
|
| 947 |
+
user_content += "\nDetailed information from visited pages:\n"
|
| 948 |
+
for p in page_contents[:2]:
|
| 949 |
+
user_content += f"- {p['title']}: {p['content'][:300]}...\n"
|
| 950 |
+
|
| 951 |
+
if products:
|
| 952 |
+
user_content += "\nRecommended products:\n"
|
| 953 |
+
for p in products[:5]:
|
| 954 |
+
user_content += f"- {p['title'][:80]}: {p.get('price', 'Check price')} - {p['url']}\n"
|
| 955 |
+
|
| 956 |
+
if messages:
|
| 957 |
+
user_content += f"\nSample message/invitation:\n{messages[0][:300]}\n"
|
| 958 |
+
|
| 959 |
+
user_content += "\nPlease create a comprehensive celebration plan based on this information."
|
| 960 |
+
|
| 961 |
+
return self._call_llm([
|
| 962 |
+
{"role": "system", "content": system_prompt},
|
| 963 |
+
{"role": "user", "content": user_content},
|
| 964 |
+
])
|
| 965 |
+
|
| 966 |
+
def _template_synthesize(
|
| 967 |
+
self,
|
| 968 |
+
state: AgentState,
|
| 969 |
+
web_results: List[Dict],
|
| 970 |
+
products: List[Dict],
|
| 971 |
+
messages: List[str],
|
| 972 |
+
page_contents: List[Dict]
|
| 973 |
+
) -> str:
|
| 974 |
+
"""Template-based response synthesis when LLM is unavailable."""
|
| 975 |
+
parts = []
|
| 976 |
+
|
| 977 |
+
# Header
|
| 978 |
+
parts.append(f"# 🎉 Celebration Plan: {state.idea}\n")
|
| 979 |
+
|
| 980 |
+
# Overview from page content or web results
|
| 981 |
+
if page_contents:
|
| 982 |
+
parts.append("## Overview\n")
|
| 983 |
+
parts.append(page_contents[0].get("content", "")[:500] + "\n")
|
| 984 |
+
elif web_results:
|
| 985 |
+
parts.append("## Key Insights\n")
|
| 986 |
+
for r in web_results[:3]:
|
| 987 |
+
if r.get("snippet"):
|
| 988 |
+
parts.append(f"- {r['snippet'][:150]}\n")
|
| 989 |
+
|
| 990 |
+
# Suggested activities
|
| 991 |
+
parts.append("\n## Suggested Activities\n")
|
| 992 |
+
activity_suggestions = [
|
| 993 |
+
"Set up a themed decoration area",
|
| 994 |
+
"Plan interactive games or activities",
|
| 995 |
+
"Prepare a special playlist",
|
| 996 |
+
"Arrange for food and refreshments",
|
| 997 |
+
"Create a photo booth corner"
|
| 998 |
+
]
|
| 999 |
+
for i, activity in enumerate(activity_suggestions, 1):
|
| 1000 |
+
parts.append(f"{i}. {activity}\n")
|
| 1001 |
+
|
| 1002 |
+
# Shopping list with products
|
| 1003 |
+
if products:
|
| 1004 |
+
parts.append("\n## Shopping List\n")
|
| 1005 |
+
for p in products[:5]:
|
| 1006 |
+
price = p.get("price", "Check price")
|
| 1007 |
+
parts.append(f"- [{p['title'][:60]}]({p['url']}) - {price}\n")
|
| 1008 |
+
|
| 1009 |
+
# Sample invitation
|
| 1010 |
+
if messages:
|
| 1011 |
+
parts.append("\n## Sample Invitation\n")
|
| 1012 |
+
parts.append(f"```\n{messages[0]}\n```\n")
|
| 1013 |
+
|
| 1014 |
+
# Reference links
|
| 1015 |
+
if web_results:
|
| 1016 |
+
parts.append("\n## Helpful Resources\n")
|
| 1017 |
+
for r in web_results[:3]:
|
| 1018 |
+
if r.get("url"):
|
| 1019 |
+
parts.append(f"- [{r['title'][:50]}]({r['url']})\n")
|
| 1020 |
+
|
| 1021 |
+
# Tips
|
| 1022 |
+
parts.append("\n## Tips\n")
|
| 1023 |
+
parts.append("- Start preparations at least a week in advance\n")
|
| 1024 |
+
parts.append("- Create a checklist to track tasks\n")
|
| 1025 |
+
parts.append("- Don't forget to enjoy the celebration yourself!\n")
|
| 1026 |
+
|
| 1027 |
+
return "".join(parts)
|
| 1028 |
+
|
| 1029 |
+
def run(self, payload: ResearchRequest) -> ResearchResponse:
|
| 1030 |
+
"""
|
| 1031 |
+
Execute the ReAct agent loop.
|
| 1032 |
+
|
| 1033 |
+
1. Parse the idea from the request
|
| 1034 |
+
2. Plan tool calls (using LLM or heuristics)
|
| 1035 |
+
3. Execute tools
|
| 1036 |
+
4. Check if more information needed
|
| 1037 |
+
5. Synthesize final response
|
| 1038 |
+
"""
|
| 1039 |
+
# Parse the idea - use idea context if available, otherwise parse from message
|
| 1040 |
+
if payload.idea and payload.idea.title:
|
| 1041 |
+
idea = payload.idea.title
|
| 1042 |
+
idea_details = {
|
| 1043 |
+
"title": payload.idea.title,
|
| 1044 |
+
"summary": payload.idea.summary,
|
| 1045 |
+
"highlights": payload.idea.highlights,
|
| 1046 |
+
"estimated_budget": payload.idea.estimated_budget,
|
| 1047 |
+
"effort": payload.idea.effort,
|
| 1048 |
+
"next_steps": payload.idea.next_steps,
|
| 1049 |
+
"persona_fit": payload.idea.persona_fit,
|
| 1050 |
+
}
|
| 1051 |
+
additional_context = ""
|
| 1052 |
+
else:
|
| 1053 |
+
idea, additional_context = self._parse_idea_from_message(
|
| 1054 |
+
payload.user_message or payload.topic or ""
|
| 1055 |
+
)
|
| 1056 |
+
idea_details = None
|
| 1057 |
+
|
| 1058 |
+
logger.info("Research agent starting for idea: %s", idea)
|
| 1059 |
+
if idea_details:
|
| 1060 |
+
logger.info("Full idea context available: %s", idea_details.get("summary", "")[:100])
|
| 1061 |
+
|
| 1062 |
+
# Extract honoree info if available
|
| 1063 |
+
honoree_info = None
|
| 1064 |
+
if payload.honoree:
|
| 1065 |
+
honoree_info = {
|
| 1066 |
+
"age_range": payload.honoree.age_range,
|
| 1067 |
+
"interests": payload.honoree.interests,
|
| 1068 |
+
"preferences": payload.honoree.preferences,
|
| 1069 |
+
"guest_count": payload.honoree.guest_count,
|
| 1070 |
+
}
|
| 1071 |
+
|
| 1072 |
+
# Initialize state with full context
|
| 1073 |
+
state = AgentState(
|
| 1074 |
+
idea=idea,
|
| 1075 |
+
context={
|
| 1076 |
+
"location": payload.location,
|
| 1077 |
+
"budget_tier": payload.budget_tier,
|
| 1078 |
+
"constraints": payload.constraints,
|
| 1079 |
+
"additional": additional_context,
|
| 1080 |
+
"event_date": payload.event_date,
|
| 1081 |
+
},
|
| 1082 |
+
idea_details=idea_details,
|
| 1083 |
+
honoree=honoree_info,
|
| 1084 |
+
occasion=payload.occasion,
|
| 1085 |
+
max_iterations=settings.max_tool_calls,
|
| 1086 |
+
)
|
| 1087 |
+
|
| 1088 |
+
# ReAct loop
|
| 1089 |
+
while state.iteration < state.max_iterations:
|
| 1090 |
+
logger.info("ReAct iteration %d", state.iteration + 1)
|
| 1091 |
+
|
| 1092 |
+
# Plan tool calls
|
| 1093 |
+
tool_calls = self._plan_tool_calls(state)
|
| 1094 |
+
|
| 1095 |
+
if not tool_calls:
|
| 1096 |
+
logger.info("No more tool calls planned, moving to synthesis")
|
| 1097 |
+
break
|
| 1098 |
+
|
| 1099 |
+
# Execute tools
|
| 1100 |
+
for call in tool_calls:
|
| 1101 |
+
tool_name = call.get("tool")
|
| 1102 |
+
params = call.get("params", {})
|
| 1103 |
+
|
| 1104 |
+
logger.info("Executing tool: %s with params: %s", tool_name, params)
|
| 1105 |
+
result = self._execute_tool(tool_name, params)
|
| 1106 |
+
|
| 1107 |
+
# Store result
|
| 1108 |
+
state.tool_results.append({
|
| 1109 |
+
"tool": tool_name,
|
| 1110 |
+
"params": params,
|
| 1111 |
+
"success": result.success,
|
| 1112 |
+
"data": result.data,
|
| 1113 |
+
"error": result.error,
|
| 1114 |
+
"summary": self._summarize_result(tool_name, result),
|
| 1115 |
+
})
|
| 1116 |
+
|
| 1117 |
+
state.iteration += 1
|
| 1118 |
+
|
| 1119 |
+
# Synthesize response
|
| 1120 |
+
final_text = self._synthesize_response(state)
|
| 1121 |
+
|
| 1122 |
+
# Extract structured data from tool results
|
| 1123 |
+
web_sources = self._get_web_sources(state)
|
| 1124 |
+
products = self._get_products(state)
|
| 1125 |
+
invitation = self._get_invitation(state)
|
| 1126 |
+
|
| 1127 |
+
# Build overview from idea details or synthesized text
|
| 1128 |
+
overview = ""
|
| 1129 |
+
if state.idea_details:
|
| 1130 |
+
overview = state.idea_details.get("summary", "")
|
| 1131 |
+
if state.idea_details.get("highlights"):
|
| 1132 |
+
overview += "\n\n**Highlights:**\n"
|
| 1133 |
+
for h in state.idea_details["highlights"]:
|
| 1134 |
+
overview += f"• {h}\n"
|
| 1135 |
+
else:
|
| 1136 |
+
# Extract overview from synthesized text
|
| 1137 |
+
overview = final_text[:800] if len(final_text) > 800 else final_text
|
| 1138 |
+
|
| 1139 |
+
# Build steps - prefer idea's next_steps if available
|
| 1140 |
+
steps = []
|
| 1141 |
+
if state.idea_details and state.idea_details.get("next_steps"):
|
| 1142 |
+
steps = state.idea_details["next_steps"]
|
| 1143 |
+
else:
|
| 1144 |
+
steps = self._extract_steps_from_text(final_text)
|
| 1145 |
+
|
| 1146 |
+
# Build references (combined sources and links)
|
| 1147 |
+
references = [
|
| 1148 |
+
SourceLink(
|
| 1149 |
+
title=r.get("title", "Reference")[:60],
|
| 1150 |
+
url=r.get("url", ""),
|
| 1151 |
+
snippet=r.get("snippet", "")[:150]
|
| 1152 |
+
)
|
| 1153 |
+
for r in web_sources if r.get("url")
|
| 1154 |
+
]
|
| 1155 |
+
|
| 1156 |
+
# Build shopping list from products
|
| 1157 |
+
shopping_list = [
|
| 1158 |
+
ShoppingItem(
|
| 1159 |
+
name=p.get("title", "Product")[:80],
|
| 1160 |
+
price=p.get("price"),
|
| 1161 |
+
url=p.get("url", ""),
|
| 1162 |
+
source="Amazon"
|
| 1163 |
+
)
|
| 1164 |
+
for p in products if p.get("url")
|
| 1165 |
+
]
|
| 1166 |
+
|
| 1167 |
+
# Convert to PlanItem format with new schema
|
| 1168 |
+
plans = [
|
| 1169 |
+
PlanItem(
|
| 1170 |
+
title=f"Celebration Plan: {state.idea}",
|
| 1171 |
+
summary=overview[:500] if len(overview) > 500 else overview,
|
| 1172 |
+
overview=overview,
|
| 1173 |
+
activities=state.idea_details.get("highlights", []) if state.idea_details else [],
|
| 1174 |
+
steps=steps,
|
| 1175 |
+
timeline=f"Preparation: 1-2 weeks before the event",
|
| 1176 |
+
estimated_budget=state.idea_details.get("estimated_budget") if state.idea_details else (payload.budget_tier or "Varies"),
|
| 1177 |
+
duration=f"Effort: {state.idea_details.get('effort', 'medium')}" if state.idea_details else "Plan ahead: 1-2 weeks",
|
| 1178 |
+
references=references,
|
| 1179 |
+
shopping_list=shopping_list,
|
| 1180 |
+
sample_invitation=invitation,
|
| 1181 |
+
# Legacy fields for backward compatibility
|
| 1182 |
+
sources=[r.get("title", "") for r in web_sources],
|
| 1183 |
+
links=[r.get("url", "") for r in web_sources],
|
| 1184 |
+
)
|
| 1185 |
+
]
|
| 1186 |
+
|
| 1187 |
+
return ResearchResponse(
|
| 1188 |
+
session_id=payload.session_id,
|
| 1189 |
+
plans=plans,
|
| 1190 |
+
metadata={
|
| 1191 |
+
"source": "react_agent",
|
| 1192 |
+
"provider": "agent",
|
| 1193 |
+
"idea": state.idea,
|
| 1194 |
+
"idea_details": state.idea_details,
|
| 1195 |
+
"iterations": state.iteration,
|
| 1196 |
+
"tools_used": list(set(r.get("tool") for r in state.tool_results)),
|
| 1197 |
+
"timestamp": datetime.now(tz=timezone.utc).isoformat(),
|
| 1198 |
+
"full_response": final_text,
|
| 1199 |
+
},
|
| 1200 |
+
)
|
| 1201 |
+
|
| 1202 |
+
def _summarize_result(self, tool_name: str, result: ToolResult) -> str:
|
| 1203 |
+
"""Create a brief summary of a tool result."""
|
| 1204 |
+
if not result.success:
|
| 1205 |
+
return f"Failed: {result.error}"
|
| 1206 |
+
|
| 1207 |
+
if tool_name == "web_search":
|
| 1208 |
+
count = result.data.get("count", 0)
|
| 1209 |
+
return f"Found {count} results"
|
| 1210 |
+
elif tool_name == "amazon_search":
|
| 1211 |
+
count = result.data.get("count", 0)
|
| 1212 |
+
return f"Found {count} products"
|
| 1213 |
+
elif tool_name == "visit_page":
|
| 1214 |
+
title = result.data.get("title", "page")
|
| 1215 |
+
return f"Visited: {title[:50]}"
|
| 1216 |
+
elif tool_name == "generate_message":
|
| 1217 |
+
return "Generated message"
|
| 1218 |
+
|
| 1219 |
+
return "Completed"
|
| 1220 |
+
|
| 1221 |
+
def _extract_steps_from_text(self, text: str) -> List[str]:
|
| 1222 |
+
"""Extract actionable steps from synthesized text."""
|
| 1223 |
+
steps = []
|
| 1224 |
+
|
| 1225 |
+
# Look for numbered items or bullet points
|
| 1226 |
+
lines = text.split("\n")
|
| 1227 |
+
for line in lines:
|
| 1228 |
+
line = line.strip()
|
| 1229 |
+
# Match numbered steps or bullet points
|
| 1230 |
+
if re.match(r'^[\d]+\.|\-|\*|\→', line):
|
| 1231 |
+
clean_line = re.sub(r'^[\d]+\.|\-|\*|\→\s*', '', line).strip()
|
| 1232 |
+
if clean_line and len(clean_line) > 10:
|
| 1233 |
+
steps.append(clean_line[:100])
|
| 1234 |
+
|
| 1235 |
+
# If no steps found, create default ones
|
| 1236 |
+
if not steps:
|
| 1237 |
+
steps = [
|
| 1238 |
+
"Review the plan and customize for your needs",
|
| 1239 |
+
"Purchase necessary supplies and decorations",
|
| 1240 |
+
"Send invitations to guests",
|
| 1241 |
+
"Prepare the venue and activities",
|
| 1242 |
+
"Enjoy the celebration!"
|
| 1243 |
+
]
|
| 1244 |
+
|
| 1245 |
+
return steps[:8] # Limit to 8 steps
|
| 1246 |
+
|
| 1247 |
+
def _get_web_sources(self, state: AgentState) -> List[Dict[str, str]]:
|
| 1248 |
+
"""Extract web sources from tool results."""
|
| 1249 |
+
sources = []
|
| 1250 |
+
for result in state.tool_results:
|
| 1251 |
+
if result.get("tool") == "web_search":
|
| 1252 |
+
for r in result.get("data", {}).get("results", []):
|
| 1253 |
+
if r.get("url"):
|
| 1254 |
+
sources.append({
|
| 1255 |
+
"title": r.get("title", "")[:60],
|
| 1256 |
+
"url": r.get("url", ""),
|
| 1257 |
+
"snippet": r.get("snippet", "")[:150]
|
| 1258 |
+
})
|
| 1259 |
+
return sources[:5]
|
| 1260 |
+
|
| 1261 |
+
def _get_products(self, state: AgentState) -> List[Dict[str, str]]:
|
| 1262 |
+
"""Extract product results from Amazon search."""
|
| 1263 |
+
products = []
|
| 1264 |
+
for result in state.tool_results:
|
| 1265 |
+
if result.get("tool") == "amazon_search" and result.get("success"):
|
| 1266 |
+
for p in result.get("data", {}).get("products", []):
|
| 1267 |
+
if p.get("url"):
|
| 1268 |
+
products.append({
|
| 1269 |
+
"title": p.get("title", "Product")[:80],
|
| 1270 |
+
"price": p.get("price"),
|
| 1271 |
+
"url": p.get("url", ""),
|
| 1272 |
+
"snippet": p.get("snippet", "")[:100]
|
| 1273 |
+
})
|
| 1274 |
+
return products[:5]
|
| 1275 |
+
|
| 1276 |
+
def _get_invitation(self, state: AgentState) -> str:
|
| 1277 |
+
"""Extract generated invitation message from tool results."""
|
| 1278 |
+
for result in state.tool_results:
|
| 1279 |
+
if result.get("tool") == "generate_message" and result.get("success"):
|
| 1280 |
+
return result.get("data", {}).get("message", "")
|
| 1281 |
+
return ""
|
| 1282 |
+
|
| 1283 |
+
|
| 1284 |
+
# =============================================================================
|
| 1285 |
+
# Provider Factory
|
| 1286 |
+
# =============================================================================
|
| 1287 |
+
|
| 1288 |
class BaseResearchProvider:
|
| 1289 |
+
"""Base class for research providers."""
|
| 1290 |
+
|
| 1291 |
def generate(self, payload: ResearchRequest) -> ResearchResponse:
|
| 1292 |
raise NotImplementedError
|
| 1293 |
|
| 1294 |
|
| 1295 |
class MockResearchProvider(BaseResearchProvider):
|
| 1296 |
+
"""Mock provider for testing."""
|
| 1297 |
+
|
| 1298 |
def generate(self, payload: ResearchRequest) -> ResearchResponse:
|
| 1299 |
+
plans = [
|
| 1300 |
PlanItem(
|
| 1301 |
title="Vendor shortlist",
|
| 1302 |
summary="List 2-3 vendors that fit the brief.",
|
| 1303 |
+
steps=["Identify 3 local options", "Compare pricing", "Prepare contact info"],
|
| 1304 |
estimated_budget="Varies",
|
| 1305 |
duration="1-2 days",
|
| 1306 |
sources=[],
|
|
|
|
| 1315 |
sources=[],
|
| 1316 |
links=[],
|
| 1317 |
),
|
| 1318 |
+
][:settings.max_items]
|
| 1319 |
+
|
| 1320 |
return ResearchResponse(
|
| 1321 |
session_id=payload.session_id,
|
| 1322 |
plans=plans,
|
|
|
|
| 1329 |
|
| 1330 |
|
| 1331 |
class DuckDuckGoResearchProvider(BaseResearchProvider):
|
| 1332 |
+
"""Simple DuckDuckGo search provider (no agent loop)."""
|
| 1333 |
+
|
| 1334 |
+
def __init__(self, max_items: int, safe: str = "moderate"):
|
| 1335 |
self.max_items = max_items
|
| 1336 |
self.safe = safe
|
| 1337 |
if DDGS is None:
|
| 1338 |
+
raise RuntimeError("ddgs is not installed")
|
| 1339 |
+
|
| 1340 |
+
def _parse_idea(self, payload: ResearchRequest) -> str:
|
| 1341 |
+
"""Extract the actual idea from the payload."""
|
| 1342 |
+
message = payload.user_message or payload.topic or ""
|
| 1343 |
+
|
| 1344 |
+
# Handle "Plan this idea in detail: X" format
|
| 1345 |
+
if "Plan this idea in detail:" in message:
|
| 1346 |
+
try:
|
| 1347 |
+
parts = message.split("Plan this idea in detail:")
|
| 1348 |
+
if len(parts) > 1:
|
| 1349 |
+
idea = parts[1].split("Additional context:")[0].strip().rstrip(".")
|
| 1350 |
+
if idea:
|
| 1351 |
+
return idea
|
| 1352 |
+
except Exception:
|
| 1353 |
+
pass
|
| 1354 |
+
|
| 1355 |
+
return message[:100] if message else "celebration"
|
| 1356 |
+
|
| 1357 |
def generate(self, payload: ResearchRequest) -> ResearchResponse:
|
| 1358 |
+
idea = self._parse_idea(payload)
|
| 1359 |
+
query = f"{idea} celebration party ideas activities how to plan"
|
| 1360 |
+
|
| 1361 |
+
logger.info("DuckDuckGo search for: %s", query)
|
| 1362 |
+
|
| 1363 |
+
links, titles, snippets = [], [], []
|
| 1364 |
try:
|
| 1365 |
with DDGS() as ddgs:
|
| 1366 |
results = ddgs.text(query, safesearch=self.safe, max_results=self.max_items)
|
| 1367 |
+
for item in results or []:
|
| 1368 |
+
if item.get("href"):
|
| 1369 |
+
links.append(item["href"])
|
| 1370 |
+
titles.append(item.get("title", ""))
|
| 1371 |
+
snippets.append(item.get("body", ""))
|
| 1372 |
+
except Exception as exc:
|
| 1373 |
logger.error("DuckDuckGo search failed: %s", exc)
|
| 1374 |
+
|
| 1375 |
plan = PlanItem(
|
| 1376 |
+
title=f"Research: {idea}",
|
| 1377 |
+
summary=f"Web search results for planning {idea}",
|
| 1378 |
steps=[
|
| 1379 |
+
"Review the links for relevant ideas and inspiration",
|
| 1380 |
+
"Note down activities and supplies that fit your budget",
|
| 1381 |
+
"Create a checklist of items to purchase",
|
| 1382 |
+
"Plan your timeline for preparation",
|
| 1383 |
],
|
| 1384 |
estimated_budget="Varies",
|
| 1385 |
+
duration="30-60 minutes research",
|
| 1386 |
+
sources=titles[:self.max_items],
|
| 1387 |
+
links=links[:self.max_items],
|
| 1388 |
)
|
| 1389 |
+
|
| 1390 |
return ResearchResponse(
|
| 1391 |
session_id=payload.session_id,
|
| 1392 |
plans=[plan],
|
|
|
|
| 1394 |
"source": "duckduckgo",
|
| 1395 |
"provider": "duckduckgo",
|
| 1396 |
"query": query,
|
| 1397 |
+
"idea": idea,
|
| 1398 |
"timestamp": datetime.now(tz=timezone.utc).isoformat(),
|
| 1399 |
},
|
| 1400 |
)
|
| 1401 |
|
| 1402 |
+
|
| 1403 |
+
class AgentResearchProvider(BaseResearchProvider):
|
| 1404 |
+
"""ReAct agent-based research provider."""
|
| 1405 |
+
|
| 1406 |
+
def __init__(self):
|
| 1407 |
+
self._agent = ReActResearchAgent()
|
| 1408 |
+
|
| 1409 |
+
def generate(self, payload: ResearchRequest) -> ResearchResponse:
|
| 1410 |
+
return self._agent.run(payload)
|
| 1411 |
|
| 1412 |
|
| 1413 |
def get_provider() -> BaseResearchProvider:
|
| 1414 |
+
"""Get the configured research provider."""
|
| 1415 |
name = settings.provider.lower()
|
| 1416 |
logger.info("Research provider selected: %s", name)
|
| 1417 |
+
|
| 1418 |
if name == "mock":
|
| 1419 |
return MockResearchProvider()
|
| 1420 |
+
|
| 1421 |
if name == "duckduckgo":
|
| 1422 |
+
return DuckDuckGoResearchProvider(
|
| 1423 |
+
max_items=settings.max_items,
|
| 1424 |
+
safe=settings.search_safe
|
| 1425 |
+
)
|
| 1426 |
+
|
| 1427 |
+
if name == "agent":
|
| 1428 |
+
try:
|
| 1429 |
+
return AgentResearchProvider()
|
| 1430 |
+
except Exception as exc:
|
| 1431 |
+
logger.error("Agent provider failed to initialize: %s", exc)
|
| 1432 |
+
return MockResearchProvider()
|
| 1433 |
+
|
| 1434 |
raise ValueError(f"Unsupported research provider: {settings.provider}")
|
mcp_servers/research/tool_schemas.py
CHANGED
|
@@ -1,23 +1,94 @@
|
|
| 1 |
from pydantic import BaseModel, Field
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
class ResearchRequest(BaseModel):
|
|
|
|
| 4 |
session_id: str
|
| 5 |
user_message: str
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
topic: str | None = None
|
| 7 |
location: str | None = None
|
| 8 |
budget_tier: str | None = None
|
| 9 |
-
constraints:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
class PlanItem(BaseModel):
|
|
|
|
| 12 |
title: str
|
| 13 |
summary: str
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
estimated_budget: str | None = None
|
| 16 |
duration: str | None = None
|
| 17 |
-
|
| 18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
class ResearchResponse(BaseModel):
|
| 21 |
session_id: str
|
| 22 |
-
plans:
|
| 23 |
-
metadata:
|
|
|
|
| 1 |
from pydantic import BaseModel, Field
|
| 2 |
+
from typing import List, Optional, Dict, Any
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class IdeaContext(BaseModel):
|
| 6 |
+
"""Full context of the idea being researched."""
|
| 7 |
+
title: str
|
| 8 |
+
summary: str = ""
|
| 9 |
+
highlights: List[str] = Field(default_factory=list)
|
| 10 |
+
estimated_budget: str | None = None
|
| 11 |
+
effort: str | None = None
|
| 12 |
+
next_steps: List[str] = Field(default_factory=list)
|
| 13 |
+
persona_fit: str | None = None
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class HonoreeProfile(BaseModel):
|
| 17 |
+
"""Profile of the person being celebrated."""
|
| 18 |
+
age_range: str | None = None
|
| 19 |
+
interests: List[str] = Field(default_factory=list)
|
| 20 |
+
preferences: str | None = None
|
| 21 |
+
guest_count: int | None = None
|
| 22 |
+
|
| 23 |
|
| 24 |
class ResearchRequest(BaseModel):
|
| 25 |
+
"""Request payload for research agent with full context."""
|
| 26 |
session_id: str
|
| 27 |
user_message: str
|
| 28 |
+
|
| 29 |
+
# The idea being researched (full context from ideation)
|
| 30 |
+
idea: IdeaContext | None = None
|
| 31 |
+
|
| 32 |
+
# Legacy fields (kept for backward compatibility)
|
| 33 |
topic: str | None = None
|
| 34 |
location: str | None = None
|
| 35 |
budget_tier: str | None = None
|
| 36 |
+
constraints: List[str] = Field(default_factory=list)
|
| 37 |
+
|
| 38 |
+
# New fields for richer context
|
| 39 |
+
occasion: str | None = None
|
| 40 |
+
event_date: str | None = None
|
| 41 |
+
honoree: HonoreeProfile | None = None
|
| 42 |
+
|
| 43 |
+
# Conversation history for context
|
| 44 |
+
conversation_history: List[Dict[str, Any]] = Field(default_factory=list)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class SourceLink(BaseModel):
|
| 48 |
+
"""Combined source with title and URL."""
|
| 49 |
+
title: str
|
| 50 |
+
url: str
|
| 51 |
+
snippet: str = ""
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class ShoppingItem(BaseModel):
|
| 55 |
+
"""A specific product recommendation."""
|
| 56 |
+
name: str
|
| 57 |
+
price: str | None = None
|
| 58 |
+
url: str
|
| 59 |
+
source: str = "Amazon" # Amazon, Flipkart, etc.
|
| 60 |
+
|
| 61 |
|
| 62 |
class PlanItem(BaseModel):
|
| 63 |
+
"""A detailed plan item with combined sources."""
|
| 64 |
title: str
|
| 65 |
summary: str
|
| 66 |
+
|
| 67 |
+
# Detailed sections
|
| 68 |
+
overview: str = "" # Detailed overview of what this plan entails
|
| 69 |
+
activities: List[str] = Field(default_factory=list) # Suggested activities
|
| 70 |
+
steps: List[str] = Field(default_factory=list) # Actionable steps
|
| 71 |
+
timeline: str = "" # Suggested timeline
|
| 72 |
+
|
| 73 |
+
# Budget and duration
|
| 74 |
estimated_budget: str | None = None
|
| 75 |
duration: str | None = None
|
| 76 |
+
|
| 77 |
+
# Combined sources (title + URL together)
|
| 78 |
+
references: List[SourceLink] = Field(default_factory=list)
|
| 79 |
+
|
| 80 |
+
# Shopping recommendations (specific products)
|
| 81 |
+
shopping_list: List[ShoppingItem] = Field(default_factory=list)
|
| 82 |
+
|
| 83 |
+
# Generated content
|
| 84 |
+
sample_invitation: str = ""
|
| 85 |
+
|
| 86 |
+
# Legacy fields (for backward compatibility)
|
| 87 |
+
sources: List[str] = Field(default_factory=list)
|
| 88 |
+
links: List[str] = Field(default_factory=list)
|
| 89 |
+
|
| 90 |
|
| 91 |
class ResearchResponse(BaseModel):
|
| 92 |
session_id: str
|
| 93 |
+
plans: List[PlanItem] = Field(default_factory=list)
|
| 94 |
+
metadata: Dict[str, Any] = Field(default_factory=dict)
|
requirements.txt
CHANGED
|
@@ -4,3 +4,5 @@ pydantic-settings
|
|
| 4 |
pydantic
|
| 5 |
httpx
|
| 6 |
ddgs
|
|
|
|
|
|
|
|
|
| 4 |
pydantic
|
| 5 |
httpx
|
| 6 |
ddgs
|
| 7 |
+
langgraph>=1.0.0
|
| 8 |
+
langchain-core>=1.1.0
|
ui/chat.py
CHANGED
|
@@ -39,7 +39,8 @@ def _make_initial_state() -> Dict[str, Any]:
|
|
| 39 |
"source": "gradio",
|
| 40 |
"last_interaction": datetime.now(tz=timezone.utc).isoformat(),
|
| 41 |
},
|
| 42 |
-
"last_ideas": [],
|
|
|
|
| 43 |
}
|
| 44 |
|
| 45 |
|
|
@@ -148,20 +149,67 @@ def _format_research_response(response: ResearchResponse) -> str:
|
|
| 148 |
|
| 149 |
sections: List[str] = []
|
| 150 |
for idx, plan in enumerate(response.plans, start=1):
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
+ (f"**Links**\n{links}\n\n" if links else "")
|
| 163 |
-
)
|
| 164 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
|
| 166 |
note = response.metadata.get("reason") if response.metadata else None
|
| 167 |
if note:
|
|
@@ -230,7 +278,20 @@ def handle_message(
|
|
| 230 |
payload_for_call = _build_request_payload(state, user_message=message or "")
|
| 231 |
try:
|
| 232 |
ideation_response = orchestrator.generate_ideas(payload_for_call)
|
|
|
|
| 233 |
state["last_ideas"] = [idea.title for idea in ideation_response.ideas]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 234 |
bot_content = _format_ideation_response(ideation_response)
|
| 235 |
except Exception as exc: # pragma: no cover
|
| 236 |
bot_content = (
|
|
@@ -288,6 +349,9 @@ def handle_plan_selection(
|
|
| 288 |
choices = state.get("last_ideas") or []
|
| 289 |
return chat_history, state, _build_request_payload(state), state["session_id"], gr.update(choices=choices, value=None)
|
| 290 |
|
|
|
|
|
|
|
|
|
|
| 291 |
# craft a planning prompt using the selected idea
|
| 292 |
plan_prompt = f"Plan this idea in detail: {plan_choice}. Additional context: {message}"
|
| 293 |
user_entry = {"role": "user", "content": plan_prompt}
|
|
@@ -296,7 +360,8 @@ def handle_plan_selection(
|
|
| 296 |
|
| 297 |
payload_for_call = _build_request_payload(state, user_message=plan_prompt)
|
| 298 |
try:
|
| 299 |
-
|
|
|
|
| 300 |
bot_content = _format_research_response(research_response)
|
| 301 |
except Exception as exc: # pragma: no cover
|
| 302 |
bot_content = (
|
|
|
|
| 39 |
"source": "gradio",
|
| 40 |
"last_interaction": datetime.now(tz=timezone.utc).isoformat(),
|
| 41 |
},
|
| 42 |
+
"last_ideas": [], # List of idea titles for dropdown
|
| 43 |
+
"last_ideas_full": {}, # Dict mapping title -> full idea details
|
| 44 |
}
|
| 45 |
|
| 46 |
|
|
|
|
| 149 |
|
| 150 |
sections: List[str] = []
|
| 151 |
for idx, plan in enumerate(response.plans, start=1):
|
| 152 |
+
parts = [f"### Plan {idx}: {plan.title}\n"]
|
| 153 |
+
|
| 154 |
+
# Overview section (use overview if available, else summary)
|
| 155 |
+
overview = getattr(plan, 'overview', '') or plan.summary
|
| 156 |
+
if overview:
|
| 157 |
+
parts.append(f"{overview}\n")
|
| 158 |
+
|
| 159 |
+
# Budget and duration
|
| 160 |
+
parts.append(
|
| 161 |
+
f"\n**Budget:** {plan.estimated_budget or 'Varies'} · "
|
| 162 |
+
f"**Duration:** {plan.duration or 'TBD'}\n"
|
|
|
|
|
|
|
| 163 |
)
|
| 164 |
+
|
| 165 |
+
# Activities (if available)
|
| 166 |
+
activities = getattr(plan, 'activities', [])
|
| 167 |
+
if activities:
|
| 168 |
+
parts.append("\n**Suggested Activities**")
|
| 169 |
+
for activity in activities:
|
| 170 |
+
parts.append(f"\n• {activity}")
|
| 171 |
+
parts.append("\n")
|
| 172 |
+
|
| 173 |
+
# Steps
|
| 174 |
+
if plan.steps:
|
| 175 |
+
parts.append("\n**Steps**")
|
| 176 |
+
for step in plan.steps:
|
| 177 |
+
parts.append(f"\n→ {step}")
|
| 178 |
+
parts.append("\n")
|
| 179 |
+
|
| 180 |
+
# Shopping list (specific products)
|
| 181 |
+
shopping = getattr(plan, 'shopping_list', [])
|
| 182 |
+
if shopping:
|
| 183 |
+
parts.append("\n**Shopping List**")
|
| 184 |
+
for item in shopping:
|
| 185 |
+
price_str = f" - {item.price}" if getattr(item, 'price', None) else ""
|
| 186 |
+
parts.append(f"\n• [{item.name}]({item.url}){price_str}")
|
| 187 |
+
parts.append("\n")
|
| 188 |
+
|
| 189 |
+
# Sample invitation
|
| 190 |
+
invitation = getattr(plan, 'sample_invitation', '')
|
| 191 |
+
if invitation:
|
| 192 |
+
parts.append(f"\n**Sample Invitation**\n```\n{invitation}\n```\n")
|
| 193 |
+
|
| 194 |
+
# Combined references (title + URL)
|
| 195 |
+
references = getattr(plan, 'references', [])
|
| 196 |
+
if references:
|
| 197 |
+
parts.append("\n**References**")
|
| 198 |
+
for ref in references:
|
| 199 |
+
title = getattr(ref, 'title', 'Link')
|
| 200 |
+
url = getattr(ref, 'url', '')
|
| 201 |
+
if url:
|
| 202 |
+
parts.append(f"\n• [{title}]({url})")
|
| 203 |
+
parts.append("\n")
|
| 204 |
+
elif plan.sources or plan.links:
|
| 205 |
+
# Fallback to legacy format - combine sources and links
|
| 206 |
+
parts.append("\n**References**")
|
| 207 |
+
for i, link in enumerate(plan.links or []):
|
| 208 |
+
title = plan.sources[i] if i < len(plan.sources or []) else "Link"
|
| 209 |
+
parts.append(f"\n• [{title}]({link})")
|
| 210 |
+
parts.append("\n")
|
| 211 |
+
|
| 212 |
+
sections.append("".join(parts))
|
| 213 |
|
| 214 |
note = response.metadata.get("reason") if response.metadata else None
|
| 215 |
if note:
|
|
|
|
| 278 |
payload_for_call = _build_request_payload(state, user_message=message or "")
|
| 279 |
try:
|
| 280 |
ideation_response = orchestrator.generate_ideas(payload_for_call)
|
| 281 |
+
# Store both titles (for dropdown) and full details (for research)
|
| 282 |
state["last_ideas"] = [idea.title for idea in ideation_response.ideas]
|
| 283 |
+
state["last_ideas_full"] = {
|
| 284 |
+
idea.title: {
|
| 285 |
+
"title": idea.title,
|
| 286 |
+
"summary": idea.summary,
|
| 287 |
+
"highlights": idea.highlights,
|
| 288 |
+
"estimated_budget": idea.estimated_budget,
|
| 289 |
+
"effort": idea.effort,
|
| 290 |
+
"next_steps": idea.next_steps,
|
| 291 |
+
"persona_fit": idea.persona_fit,
|
| 292 |
+
}
|
| 293 |
+
for idea in ideation_response.ideas
|
| 294 |
+
}
|
| 295 |
bot_content = _format_ideation_response(ideation_response)
|
| 296 |
except Exception as exc: # pragma: no cover
|
| 297 |
bot_content = (
|
|
|
|
| 349 |
choices = state.get("last_ideas") or []
|
| 350 |
return chat_history, state, _build_request_payload(state), state["session_id"], gr.update(choices=choices, value=None)
|
| 351 |
|
| 352 |
+
# Get full idea context if available
|
| 353 |
+
idea_context = state.get("last_ideas_full", {}).get(plan_choice)
|
| 354 |
+
|
| 355 |
# craft a planning prompt using the selected idea
|
| 356 |
plan_prompt = f"Plan this idea in detail: {plan_choice}. Additional context: {message}"
|
| 357 |
user_entry = {"role": "user", "content": plan_prompt}
|
|
|
|
| 360 |
|
| 361 |
payload_for_call = _build_request_payload(state, user_message=plan_prompt)
|
| 362 |
try:
|
| 363 |
+
# Pass full idea context to research agent
|
| 364 |
+
research_response = research_orchestrator.generate_plan(payload_for_call, idea_context=idea_context)
|
| 365 |
bot_content = _format_research_response(research_response)
|
| 366 |
except Exception as exc: # pragma: no cover
|
| 367 |
bot_content = (
|