Spaces:
Sleeping
Sleeping
import asyncio
import logging
import os
from typing import List

import google.generativeai as genai
import requests
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
# --- Configuration ---
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# API key and search-backend base URL come from the environment. Either may be
# None when the variable is unset; the request handlers check before using them.
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
SEARCH_API_BASE_URL = os.environ.get("SEARCH_API_BASE_URL")

# Google Gemini client. The Flash model is used because it is cost-effective.
_GEMINI_MODEL_NAME = "gemini-2.5-flash"
genai.configure(api_key=GEMINI_API_KEY)
gemini_model = genai.GenerativeModel(_GEMINI_MODEL_NAME)
# --- FastAPI application setup ---
app = FastAPI(
    title="AI Search Agent",
    description="一个使用 Gemini-2.5-Flash 将自然语言转换为学术搜索查询的智能中间层。",
    version="1.0.0",
)

# CORS: lets the frontend Space call this API from the browser.
# Set CORS_ALLOW_ORIGINS to a comma-separated list of origins (e.g. the
# frontend Space URL) to restrict access. Unset or empty means "*" (any origin).
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

# With allow_origins=["*"] plus allow_credentials=True, Starlette echoes the
# caller's Origin back with credentials allowed, so any website could make
# credentialed requests. This API reads no cookies or auth headers, so only
# enable credentials when an explicit origin list is configured.
_cors_allow_credentials = "*" not in _cors_origins

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=_cors_allow_credentials,
    allow_methods=["*"],
    allow_headers=["*"],
)
# --- Data models ---
class SearchRequest(BaseModel):
    """Request body accepted from the frontend by the search endpoint.

    Attributes:
        platform: Forwarded unchanged as the backend payload's ``platform``.
        query: Natural-language query; passed through get_ai_keywords (Gemini
            rewrite, falling back to the raw text) before being forwarded.
        max_results: Forwarded unchanged as ``max_results``; defaults to 10.
    """
    # NOTE(review): no bounds are enforced on max_results, so 0, negative or
    # very large values are accepted and forwarded as-is. Consider
    # pydantic Field(ge=1, le=...) once the backend's limits are confirmed.
    platform: str
    query: str
    max_results: int = 10
# --- Core AI functionality ---
def _clean_keyword_string(raw: str) -> str:
    """Normalise a raw Gemini reply into a single-line keyword string.

    The prompt asks for no markdown or quotation marks, but a model can still
    wrap its answer in a ``` fence or a pair of quotes, or split it across
    lines. Without cleanup those wrappers would be forwarded verbatim to the
    search backend.
    """
    text = raw.strip()
    # Drop a surrounding ``` fence, plus an optional language tag such as
    # "```text" on the opening line.
    if len(text) >= 6 and text.startswith("```") and text.endswith("```"):
        text = text[3:-3]
        tag, newline, rest = text.partition("\n")
        if newline and tag.strip().isalnum():
            text = rest
        text = text.strip()
    # Drop one pair of wrapping quotes/backticks, but only when that character
    # does not also occur inside, so phrase queries like "a b" AND "c" survive.
    if (
        len(text) >= 2
        and text[0] == text[-1]
        and text[0] in "\"'`"
        and text[0] not in text[1:-1]
    ):
        text = text[1:-1].strip()
    # Collapse newlines and repeated whitespace into single spaces.
    return " ".join(text.split())


async def get_ai_keywords(natural_language_query: str) -> str:
    """
    Convert a natural-language query into a boolean keyword string using Gemini.

    Args:
        natural_language_query: The user's free-text query.

    Returns:
        The optimised keyword string. ``natural_language_query`` is returned
        unchanged when it is blank, when GEMINI_API_KEY is unset, when Gemini
        returns an empty answer, or when the Gemini call fails.
    """
    # Nothing to optimise: skip the API round-trip entirely.
    if not natural_language_query.strip():
        return natural_language_query

    if not GEMINI_API_KEY:
        logger.warning("GEMINI_API_KEY 未设置,将使用原始查询。")
        return natural_language_query

    # Carefully designed prompt.
    prompt = f"""
You are an expert academic researcher. Your task is to convert a user's natural language query into a highly effective, concise, boolean-logic keyword string for searching academic databases like PubMed.
- Use boolean operators like AND, OR.
- Use parentheses for grouping.
- Focus on core concepts.
- Keep the string concise and in English.
- Do not add any explanation, markdown, or quotation marks. Just return the pure keyword string.
User Query: "{natural_language_query}"
Keyword String:
"""
    try:
        logger.info("向 Gemini 发送请求,查询: '%s'", natural_language_query)
        response = await gemini_model.generate_content_async(prompt)
        optimized_query = _clean_keyword_string(response.text)
    except Exception as e:
        # Best-effort: any Gemini failure (network, quota, or a response whose
        # text cannot be read) falls back to the raw query instead of failing
        # the whole search request.
        logger.exception("调用 Gemini API 失败: %s", e)
        return natural_language_query

    logger.info(
        "原始查询: '%s' -> Gemini 优化关键词: '%s'",
        natural_language_query,
        optimized_query,
    )
    # Fall back to the original query if the AI returned nothing usable.
    if not optimized_query:
        logger.warning("Gemini 返回空结果,回退到原始查询。")
        return natural_language_query
    return optimized_query
# --- API endpoints ---
def read_root():
    """Return a static status payload indicating the service is running.

    NOTE(review): no route decorator (e.g. ``@app.get("/")``) is visible on
    this function, so as shown it is not registered on ``app``. Confirm
    whether the decorator was lost.
    """
    status_message = "AI Search Agent is running"
    return dict(status=status_message)
async def intelligent_search(request: SearchRequest):
    """
    Optimise the frontend's query with Gemini, then proxy it to the search backend.

    NOTE(review): no route decorator (e.g. ``@app.post(...)``) is visible on
    this function, so as shown it is not registered on ``app``. Confirm
    whether the decorator was lost.

    Args:
        request: Validated request body (platform, natural-language query,
            max_results).

    Returns:
        The backend's JSON object, augmented with ``original_query`` and
        ``optimized_query`` to make debugging easier.

    Raises:
        HTTPException: 500 if SEARCH_API_BASE_URL is not configured; 503 if the
            backend request fails or returns a non-2xx status; 502 if the
            backend's JSON is not an object; 500 for any other unexpected error.
    """
    if not SEARCH_API_BASE_URL:
        raise HTTPException(status_code=500, detail="SEARCH_API_BASE_URL 未配置")

    # 1. Optimise the query with Gemini (falls back to the raw query on failure).
    optimized_query = await get_ai_keywords(request.query)

    # 2. Build the request body for the `paper-mcp-agent` backend.
    search_payload = {
        "platform": request.platform,
        "query": optimized_query,  # the optimised query, not the raw one
        "max_results": request.max_results,
    }
    # Strip any trailing slash so a base URL like "https://host/" does not
    # produce "https://host//search".
    search_url = f"{SEARCH_API_BASE_URL.rstrip('/')}/search"

    # 3. Call the `paper-mcp-agent` search backend.
    try:
        logger.info("向搜索后端发送请求: %s", search_payload)
        # requests is synchronous. Run it in a worker thread so the event loop
        # is not blocked (for up to the 30 s timeout) while the backend responds.
        response = await asyncio.to_thread(
            requests.post, search_url, json=search_payload, timeout=30
        )
        response.raise_for_status()  # raises HTTPError for non-2xx statuses
        search_results = response.json()
    except requests.exceptions.RequestException as e:
        logger.error("调用搜索后端失败: %s", e)
        raise HTTPException(status_code=503, detail=f"无法连接到搜索服务: {str(e)}") from e
    except Exception as e:
        logger.exception("处理搜索时发生未知错误: %s", e)
        raise HTTPException(status_code=500, detail=f"内部服务器错误: {str(e)}") from e

    # The debug fields below can only be attached to a JSON object. Report any
    # other shape clearly instead of failing with an opaque TypeError.
    if not isinstance(search_results, dict):
        logger.error("搜索后端返回了非对象 JSON: %s", type(search_results).__name__)
        raise HTTPException(status_code=502, detail="搜索服务返回了非预期的响应格式")

    # Include both the original and the optimised query for debugging.
    search_results["original_query"] = request.query
    search_results["optimized_query"] = optimized_query
    return search_results