# project-tdm / main.py
# hy: "Fix server code and sync with remote" (commit 7134b06)
from fastapi import FastAPI, Depends
from pydantic import BaseModel, Field
from typing import Dict, List, Optional
import models
from database import engine, SessionLocal
from sqlalchemy.orm import Session # ํ•„์ˆ˜ import
from models import Article, AnalysisResult
from crossref_model import get_crossref_score_and_reason
from mismatch_model import calculate_mismatch_score # <-- ํŒ€์› ํ•จ์ˆ˜ ๋ถˆ๋Ÿฌ์˜ค๊ธฐ
from aggro_model import get_aggro_score
# Create the ORM tables registered on models.Base in the configured database.
# This runs once, as an import-time side effect of loading this module
# (SQLAlchemy's create_all skips tables that already exist by default).
models.Base.metadata.create_all(engine)
from database import SessionLocal # database.py์—์„œ ์ •์˜ํ•œ SessionLocal ๊ฐ€์ ธ์˜ค๊ธฐ
# FastAPI dependency that opens a DB session for a request and closes it afterwards.
def get_db():
    """Yield a new session from SessionLocal and always close it when done.

    Intended for ``Depends(get_db)``: FastAPI injects the yielded session into
    the endpoint and resumes the generator afterwards, so the ``finally``
    clause closes the session even if the endpoint raised.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()
# --- 1. API specification (request/response schemas) ---
# Pydantic models forming the draft API contract.
class ArticleRequest(BaseModel):
    """Request body sent by the frontend: an article's title and body text."""

    article_title: str = Field(default=..., example="๊ธฐ์‚ฌ ์ œ๋ชฉ์ด ์—ฌ๊ธฐ ๋“ค์–ด๊ฐ‘๋‹ˆ๋‹ค")
    article_body: str = Field(default=..., example="๊ธฐ์‚ฌ ๋ณธ๋ฌธ ํ…์ŠคํŠธ...")
class FoundURL(BaseModel):
    """A URL matched against the article (verified by SBERT) and its similarity."""

    url: str
    similarity: float = Field(default=..., example=0.85)
class ScoreBreakdown(BaseModel):
    """Detailed result of one scoring check: score, reason and a suggested fix."""

    score: float = Field(default=..., example=0.95)
    reason: str = Field(default=..., example="'์ถฉ๊ฒฉ' ํ‚ค์›Œ๋“œ ์‚ฌ์šฉ")
    recommendation: str = Field(default=..., example="์ค‘๋ฆฝ์ ์ธ ๋‹จ์–ด๋กœ ์ˆ˜์ •ํ•˜์„ธ์š”.")
    # Supporting URLs; optional and defaults to None (checks without URLs leave it unset).
    found_urls: Optional[List[FoundURL]] = Field(default=None)
class AnalysisResponse(BaseModel):
    """Final response returned by the backend to the frontend."""

    final_risk_score: float = Field(default=..., example=0.82)
    final_risk_level: str = Field(default=..., example="์œ„ํ—˜")
    # Per-check details, keyed by the name of each score.
    breakdown: Dict[str, ScoreBreakdown]
# --- 2. FastAPI application instance (routes below register on it) ---
app: FastAPI = FastAPI()
# --- 3. API endpoints ---

# Weight of each sub-score in the final risk score (the weights sum to 1.0).
_W_AGGRO = 0.4
_W_MISMATCH = 0.4
_W_CROSSREF = 0.2
# Cross-reference URLs below this similarity are not returned to the client.
_SIMILARITY_THRESHOLD = 0.7
# Inclusive lower bounds of the two upper risk levels.
_DANGER_THRESHOLD = 0.7
_CAUTION_THRESHOLD = 0.4


def _to_breakdown(result: Dict, found_urls: Optional[List[FoundURL]] = None) -> ScoreBreakdown:
    """Build a ScoreBreakdown from a model result dict.

    ``result`` must provide the "score", "reason" and "recommendation" keys;
    a missing key raises KeyError.
    """
    return ScoreBreakdown(
        score=result["score"],
        reason=result["reason"],
        recommendation=result["recommendation"],
        found_urls=found_urls,
    )


def _select_found_urls(
    paired_results: List[Dict], threshold: float = _SIMILARITY_THRESHOLD
) -> List[FoundURL]:
    """Return the results with similarity >= threshold, most similar first.

    Each item must provide "url" and "similarity". Items are filtered before
    sorting; because the sort is stable, the order matches sorting everything
    first and filtering afterwards. The caller's list is not modified.
    """
    kept = [item for item in paired_results if item["similarity"] >= threshold]
    kept.sort(key=lambda item: item["similarity"], reverse=True)
    return [FoundURL(url=item["url"], similarity=item["similarity"]) for item in kept]


def _risk_level(score: float) -> str:
    """Map a final risk score to its level label (danger / caution / safe)."""
    if score >= _DANGER_THRESHOLD:
        return "์œ„ํ—˜"
    if score >= _CAUTION_THRESHOLD:
        return "์ฃผ์˜"
    return "์•ˆ์ „"


@app.post("/api/v1/analyze", response_model=AnalysisResponse)
def analyze_article(request: ArticleRequest, db: Session = Depends(get_db)) -> AnalysisResponse:
    """Analyze an article, store the result, and return its risk assessment.

    Runs the aggro, title/body mismatch and cross-reference checks, combines
    them into a weighted final risk score and level, saves the article and
    the scores to the database in one transaction, and returns the rounded
    score, the level and a per-check breakdown.
    """
    # 1. Aggro score (computed from the title only).
    real_aggro = _to_breakdown(get_aggro_score(request.article_title))

    # 2. Title/body mismatch score.
    real_mismatch = _to_breakdown(
        calculate_mismatch_score(request.article_title, request.article_body)
    )

    # 3. Cross-reference score plus the matched URLs that pass the threshold.
    #    A missing/None "paired_results" is treated as "no URLs found" instead
    #    of failing the whole request.
    crossref_data = get_crossref_score_and_reason(request.article_body)
    final_crossref = _to_breakdown(
        crossref_data,
        found_urls=_select_found_urls(crossref_data.get("paired_results") or []),
    )

    # Weighted final risk score.
    final_score = (
        real_aggro.score * _W_AGGRO
        + real_mismatch.score * _W_MISMATCH
        + final_crossref.score * _W_CROSSREF
    )
    # Classify the same rounded value that is returned, so the level always
    # agrees with the reported score (e.g. 0.696 is shown as 0.7 and must then
    # be labelled with the >= 0.7 level).
    rounded_score = round(final_score, 2)
    final_level = _risk_level(rounded_score)

    # --- Persist the article and its analysis in a single transaction ---
    new_article = Article(
        title=request.article_title,
        body=request.article_body,
        source="Swagger UI Test",  # placeholder source used while testing
    )
    db.add(new_article)
    # Flush (not commit) so new_article.article_id is populated before the
    # analysis row that references it is created.
    db.flush()
    db.add(
        AnalysisResult(
            article_id=new_article.article_id,  # foreign key to the article row
            aggro_score=real_aggro.score,
            mismatch_score=real_mismatch.score,
            crossref_score=final_crossref.score,
            final_risk=final_score,  # stored unrounded
        )
    )
    # Commit both rows together; if anything above raised, nothing is
    # committed and get_db still closes the session.
    db.commit()

    return AnalysisResponse(
        final_risk_score=rounded_score,
        final_risk_level=final_level,
        breakdown={
            "aggro_score": real_aggro,
            "mismatch_score": real_mismatch,
            "crossref_score": final_crossref,
        },
    )
@app.get("/")
def read_root():
    """Root endpoint: return a fixed JSON message identifying this server."""
    payload = dict(message="AI ๊ธฐ์‚ฌ ๋ถ„์„ ์„œ๋ฒ„")
    return payload