# Spaces status: Running
from typing import Dict, List, Optional

from fastapi import Depends, FastAPI
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session  # used for the DB-session type hint on the endpoint

# Local modules, kept in their original relative import order.
import models
from database import SessionLocal, engine
from models import AnalysisResult, Article
from crossref_model import get_crossref_score_and_reason
from mismatch_model import calculate_mismatch_score  # <-- required scoring function
from aggro_model import get_aggro_score

# Create any tables declared on models.Base that do not exist yet.
# This runs once, at import time, against the shared engine.
models.Base.metadata.create_all(bind=engine)
# Dependency that opens a DB session for a request and closes it afterwards.
def get_db():
    """Yield a fresh ``SessionLocal`` session and always close it.

    Intended for ``Depends(get_db)``: the caller receives the session from the
    ``yield``. The ``finally`` clause closes it once the caller is done,
    whether or not an exception was raised.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()
# --- 1. API specification (request/response schemas) ---
# Pydantic models that define the JSON contract between the frontend and this API.


class ArticleRequest(BaseModel):
    """Request body sent by the frontend: the article to analyze."""

    # Headline of the article (example text shown in the OpenAPI docs).
    article_title: str = Field(..., example="기사 제목이 여기 들어갑니다")
    # Full body text of the article.
    article_body: str = Field(..., example="기사 본문 텍스트...")
class FoundURL(BaseModel):
    """A URL verified by the SBERT cross-reference check, with its similarity."""

    # Address of the matched reference document (required).
    url: str = Field(...)
    # Similarity between the article and this URL (required); 0.85 is the docs example.
    similarity: float = Field(..., example=0.85)
class ScoreBreakdown(BaseModel):
    """Detailed result of a single scoring component."""

    # Component score.
    score: float = Field(..., example=0.95)
    # Human-readable explanation of why the score was given.
    reason: str = Field(..., example="'충격' 키워드 사용")
    # Suggested fix for the article.
    recommendation: str = Field(..., example="중립적인 단어로 수정하세요.")
    # Supporting URLs; optional, defaults to None.
    found_urls: Optional[List[FoundURL]] = None
class AnalysisResponse(BaseModel):
    """Final response returned from the backend to the frontend."""

    # Overall risk score.
    final_risk_score: float = Field(..., example=0.82)
    # Risk level label corresponding to final_risk_score.
    final_risk_level: str = Field(..., example="위험")
    # Per-component results, keyed by component name.
    breakdown: Dict[str, ScoreBreakdown]
# --- 2. Create the FastAPI application ---
# Module-level ASGI application instance; route handlers are expected to be
# registered on it.
app: FastAPI = FastAPI()
# --- 3. API endpoints ---

# Only cross-reference URLs at least this similar to the article are returned.
SIMILARITY_THRESHOLD = 0.7
# Weights of each component in the final risk score (they sum to 1.0).
AGGRO_WEIGHT = 0.4
MISMATCH_WEIGHT = 0.4
CROSSREF_WEIGHT = 0.2
# (inclusive lower bound, label) pairs, checked from the highest bound down.
# A score below every bound is labelled SAFE_LEVEL.
RISK_LEVELS = ((0.7, "위험"), (0.4, "주의"))
SAFE_LEVEL = "안전"
# Provenance stored with every saved article.
# NOTE(review): this is a hard-coded test value; confirm it is intended for
# real traffic.
ARTICLE_SOURCE = "Swagger UI Test"


def _to_breakdown(result, found_urls=None):
    """Wrap a scoring-model result in a ScoreBreakdown.

    ``result`` must support ``result["score"]``, ``result["reason"]`` and
    ``result["recommendation"]``.
    """
    return ScoreBreakdown(
        score=result["score"],
        reason=result["reason"],
        recommendation=result["recommendation"],
        found_urls=found_urls,
    )


def _filter_similar_urls(paired_results, threshold=SIMILARITY_THRESHOLD):
    """Return FoundURLs with similarity >= threshold, most similar first."""
    # Filtering before sorting yields the same list as sort-then-filter
    # (the sort is stable), but only the kept items are sorted.
    kept = [item for item in paired_results if item["similarity"] >= threshold]
    kept.sort(key=lambda item: item["similarity"], reverse=True)
    return [FoundURL(url=item["url"], similarity=item["similarity"]) for item in kept]


def _risk_level(score):
    """Map a final risk score to its level label using RISK_LEVELS."""
    for lower_bound, label in RISK_LEVELS:
        if score >= lower_bound:
            return label
    return SAFE_LEVEL


def _save_analysis(db, request, aggro_score, mismatch_score, crossref_score, final_score):
    """Store the article and its analysis result, committing both together."""
    article = Article(
        title=request.article_title,
        body=request.article_body,
        source=ARTICLE_SOURCE,
    )
    db.add(article)
    # Flush (without committing) so article.article_id is populated; the
    # result row needs it as its foreign key.
    db.flush()
    db.add(
        AnalysisResult(
            article_id=article.article_id,
            aggro_score=aggro_score,
            mismatch_score=mismatch_score,
            crossref_score=crossref_score,
            final_risk=final_score,
        )
    )
    # Persist both rows in one transaction.
    db.commit()


def analyze_article(request: ArticleRequest, db: Session = Depends(get_db)):
    """Article analysis API: score an article, store the result, return a report.

    Runs three scoring components in order:
    - the aggro score on the title;
    - the title/body mismatch score;
    - the cross-reference score on the body.
    Only cross-reference URLs at or above SIMILARITY_THRESHOLD are kept.

    The three scores are combined into a weighted final score, and the article
    and scores are saved via ``db``.

    Args:
        request: Title and body of the article to analyze.
        db: Database session supplied by the ``get_db`` dependency.

    Returns:
        AnalysisResponse containing:
        - the final score rounded to 2 decimals;
        - its level label ("위험" / "주의" / "안전");
        - per-component breakdowns keyed "aggro_score", "mismatch_score" and
          "crossref_score".
    """
    # NOTE(review): no route decorator (e.g. ``@app.post(...)``) is visible on
    # this handler in this copy of the file; confirm it is registered on ``app``.
    aggro = _to_breakdown(get_aggro_score(request.article_title))
    mismatch = _to_breakdown(
        calculate_mismatch_score(request.article_title, request.article_body)
    )
    crossref_data = get_crossref_score_and_reason(request.article_body)
    crossref = _to_breakdown(
        crossref_data,
        found_urls=_filter_similar_urls(crossref_data["paired_results"]),
    )

    final_score = (
        aggro.score * AGGRO_WEIGHT
        + mismatch.score * MISMATCH_WEIGHT
        + crossref.score * CROSSREF_WEIGHT
    )
    final_level = _risk_level(final_score)

    # The unrounded score is stored; only the response value is rounded.
    _save_analysis(db, request, aggro.score, mismatch.score, crossref.score, final_score)

    return AnalysisResponse(
        final_risk_score=round(final_score, 2),
        final_risk_level=final_level,
        breakdown={
            "aggro_score": aggro,
            "mismatch_score": mismatch,
            "crossref_score": crossref,
        },
    )
def read_root():
    """Return a fixed JSON message identifying this service.

    The message reads "AI article analysis server".
    """
    # NOTE(review): no route decorator is visible here; presumably this is meant
    # to be ``@app.get("/")`` — confirm it is registered on ``app``.
    return {"message": "AI 기사 분석 서버"}