Commit 7134b06 · hy committed · Parent(s): 7f4feea

Fix server code and sync with remote

Files changed:
- .gitignore +3 -1
- aggro_model.py +188 -0
- database.py +5 -1
- main.py +24 -19
- mismatch_model.py +127 -0
- requirements.txt +0 -0
.gitignore CHANGED

```diff
@@ -3,4 +3,6 @@ __pycache__/
 *.pyc
 .venv/
 venv/
-info.md
+info.md
+*.pth
+*.pkl
```
aggro_model.py ADDED (+188 lines)
```python
import torch
import torch.nn as nn
from transformers import BertTokenizer, BertModel
import pickle
import re
import os
import numpy as np

# =============================================================================
# 1. Model class definitions
# =============================================================================

# (1) Re-declare the rule-based scorer class (pickle needs it to restore the saved instance)
class RuleBasedScorer:
    def __init__(self):
        # These attributes are overwritten when the instance is restored from the pickle file.
        self.patterns = {}
        self.pattern_names = {}
        self.symbol_patterns = {
            'repeated': re.compile(r'([!?…~])\1+'),
            'ellipsis': re.compile(r'\.\.\.|…')
        }

    def get_score(self, title):
        words = re.findall(r'[가-힣A-Za-z0-9]+', str(title))
        scores = {}

        # Symbol score
        rep = len(self.symbol_patterns['repeated'].findall(title))
        ell = len(self.symbol_patterns['ellipsis'].findall(title))
        symbol_score = (rep * 30) + (ell * 10)

        for p in [11, 12, 13, 14]:
            word_score = 0
            # Word-matching score
            for word in words:
                if word in self.patterns[p]:
                    # Apply the frequency weight
                    word_score += np.log1p(self.patterns[p][word]) * 2

            if p == 11:
                total = symbol_score
            elif p == 12:
                total = word_score + (symbol_score * 0.5)
            else:
                total = word_score

            scores[p] = total

        # Derive the final score
        max_pattern = max(scores, key=scores.get)
        max_score = min(100, scores[max_pattern])

        return {
            'score': max_score,
            'pattern': max_pattern,
            'pattern_name': self.pattern_names.get(max_pattern, 'unknown')
        }

# (2) Re-declare the KoBERT classifier
class FishingClassifier(nn.Module):
    def __init__(self, bert, num_classes=2):
        super().__init__()
        self.bert = bert
        self.dropout = nn.Dropout(0.3)
        self.fc = nn.Linear(768, num_classes)

    def forward(self, input_ids, mask):
        _, pooled = self.bert(input_ids=input_ids, attention_mask=mask, return_dict=False)
        return self.fc(self.dropout(pooled))

# =============================================================================
# 2. Model loading
# =============================================================================
print("[AggroModel] Starting system load...")

aggro_model = None
tokenizer = None
rule_scorer = None
device = torch.device("cpu")  # the server runs on CPU

# --- A. Load the rule-based model (.pkl) ---
try:
    pkl_path = "rule_based_scorer.pkl"
    if os.path.exists(pkl_path):
        with open(pkl_path, "rb") as f:
            rule_scorer = pickle.load(f)
        print("✅ [Aggro] Rule-based model (PKL) loaded!")
    else:
        print(f"⚠️ [Aggro] Rule file missing: {pkl_path}")
except Exception as e:
    print(f"🚨 [Aggro] Failed to load the rule model: {e}")

# --- B. Load the KoBERT model (.pth) ---
try:
    # 1. Load the tokenizer & base model
    # (if this errors, try switching to 'monologg/kobert')
    MODEL_NAME = 'skt/kobert-base-v1'
    tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
    bert_base = BertModel.from_pretrained(MODEL_NAME)

    # 2. Build the bare model
    aggro_model = FishingClassifier(bert_base).to(device)

    # 3. Load the weight file
    pth_path = "bert_fishing_model_best.pth"

    if os.path.exists(pth_path):
        state_dict = torch.load(pth_path, map_location=device)
        aggro_model.load_state_dict(state_dict)
        aggro_model.eval()
        print("✅ [Aggro] KoBERT model (PTH) loaded!")
    else:
        print(f"⚠️ [Aggro] KoBERT file missing: {pth_path}")
        aggro_model = None

except Exception as e:
    print(f"🚨 [Aggro] Failed to load KoBERT: {e}")
    aggro_model = None


# =============================================================================
# 3. Main function (called from main.py)
# =============================================================================
def get_aggro_score(title: str) -> dict:
    """
    Takes a title and returns a clickbait score together with its rationale.
    """

    # 1. Rule-based score
    rule_score = 0.0
    rule_pattern = "not available"

    if rule_scorer:
        try:
            res = rule_scorer.get_score(title)
            rule_score = res['score']  # 0-100
            rule_pattern = res['pattern_name']
        except Exception as e:
            print(f"[Aggro] Rule scoring error: {e}")

    # 2. KoBERT score
    bert_score = 0.0
    if aggro_model and tokenizer:
        try:
            inputs = tokenizer(
                title,
                return_tensors='pt',
                padding="max_length",
                truncation=True,
                max_length=64
            )
            input_ids = inputs['input_ids'].to(device)
            mask = inputs['attention_mask'].to(device)

            with torch.no_grad():
                outputs = aggro_model(input_ids, mask)
                probs = torch.softmax(outputs, dim=1)
                bert_score = probs[0][1].item() * 100  # 0-100
        except Exception as e:
            print(f"[Aggro] KoBERT prediction error: {e}")
            bert_score = 50.0

    # 3. Final blend (rules 40% + BERT 60%)
    if rule_scorer and aggro_model:
        final_score = (rule_score * 0.4) + (bert_score * 0.6)
    elif aggro_model:
        final_score = bert_score
    elif rule_scorer:
        final_score = rule_score
    else:
        final_score = 0.0

    # 4. Result text
    normalized_score = min(final_score / 100.0, 1.0)  # map to 0-1

    reason = "The title reads as plain."
    recommendation = "Looks fine."

    if final_score >= 70:
        reason = f"Judged clickbait by the AI prediction ({bert_score:.0f} pts) and the rule pattern ({rule_pattern})."
        recommendation = "Please rewrite the title around objective facts."
    elif final_score >= 40:
        reason = f"Some clickbait signals ({rule_pattern}) were detected."
        recommendation = "Consider toning the wording down a little."

    return {
        "score": round(normalized_score, 2),
        "reason": reason,
        "recommendation": recommendation
    }
```
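Since `get_aggro_score` degrades gracefully when either artifact is missing, the module is easy to smoke-test on its own. A minimal sketch with made-up titles; it assumes `rule_based_scorer.pkl` and `bert_fishing_model_best.pth` sit in the working directory:

```python
# Standalone check for aggro_model.py; the sample titles are invented.
from aggro_model import get_aggro_score

for title in ["충격!! 이것만 알면 인생이 바뀐다...", "3분기 GDP 0.6% 성장"]:
    result = get_aggro_score(title)
    print(f"{title} -> score={result['score']}, reason={result['reason']}")
```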
database.py CHANGED

```diff
@@ -8,7 +8,11 @@ SQLALCHEMY_DATABASE_URL = os.environ.get("SQLALCHEMY_DATABASE_URL")
 
 # 2. Create the DB engine
 engine = create_engine(
-    SQLALCHEMY_DATABASE_URL
+    SQLALCHEMY_DATABASE_URL,
+    pool_pre_ping=True,   # <--- (1) ping the connection before using it
+    pool_recycle=300,     # <--- (2) recycle connections every 5 min (300 s) to avoid stale ones
+    pool_size=5,          # <--- (3) number of connections kept open concurrently
+    max_overflow=10       # <--- (4) extra connections allowed under sudden load
 )
 
 # 3. Create the session (Session) used to talk to the DB
```
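main.py consumes this engine through a `Session` dependency (`Depends(get_db)`); the body of that dependency is not part of this diff, so the following is only a sketch of the usual `sessionmaker`/`get_db` pairing these pool options feed into:

```python
# Assumed shape of the session plumbing around `engine`; not shown in this diff.
from sqlalchemy.orm import sessionmaker

SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

def get_db():
    db = SessionLocal()
    try:
        yield db      # FastAPI injects this session via Depends(get_db)
    finally:
        db.close()    # always hand the connection back to the pool
```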
main.py CHANGED

```diff
@@ -6,7 +6,8 @@ from database import engine, SessionLocal
 from sqlalchemy.orm import Session  # required import
 from models import Article, AnalysisResult
 from crossref_model import get_crossref_score_and_reason
-
+from mismatch_model import calculate_mismatch_score  # <-- pull in the required function
+from aggro_model import get_aggro_score
 
 models.Base.metadata.create_all(bind=engine)
 from database import SessionLocal  # the SessionLocal defined in database.py
@@ -51,21 +52,25 @@ app = FastAPI()
 
 @app.post("/api/v1/analyze", response_model=AnalysisResponse)
 def analyze_article(request: ArticleRequest, db: Session = Depends(get_db)):
-    """Article analysis API
-        recommendation="Please revise the sensational wording in the title.",
+    """Article analysis API"""
+
+    # 1. AggroScore
+    aggro_result = get_aggro_score(request.article_title)
+
+    real_aggro = ScoreBreakdown(
+        score=aggro_result["score"],
+        reason=aggro_result["reason"],
+        recommendation=aggro_result["recommendation"],
         found_urls=None
     )
 
-    # 2. mismatch
+    # 2. mismatch
+    mismatch_result = calculate_mismatch_score(request.article_title, request.article_body)
+
+    real_mismatch = ScoreBreakdown(
+        score=mismatch_result["score"],
+        reason=mismatch_result["reason"],
+        recommendation=mismatch_result["recommendation"],
         found_urls=None
     )
@@ -97,8 +102,8 @@ def analyze_article(request: ArticleRequest, db: Session = Depends(get_db)):
     w_mismatch = 0.4
     w_crossref = 0.2
 
-    final_score = (
-        (
+    final_score = (real_aggro.score * w_aggro) + \
+                  (real_mismatch.score * w_mismatch) + \
                   (final_crossref.score * w_crossref)
 
     final_level = "Safe"
@@ -124,8 +129,8 @@ def analyze_article(request: ArticleRequest, db: Session = Depends(get_db)):
     # 3. Save the analysis result to the 'analysis_results' table
     new_result = AnalysisResult(
         article_id=new_article.article_id,  # foreign-key link
-        aggro_score=
-        mismatch_score=
+        aggro_score=real_aggro.score,
+        mismatch_score=real_mismatch.score,
         crossref_score=final_crossref.score,
         final_risk=final_score,
     )
@@ -140,12 +145,12 @@ def analyze_article(request: ArticleRequest, db: Session = Depends(get_db)):
         final_risk_score=round(final_score, 2),  # rounded to 2 decimal places
         final_risk_level=final_level,
         breakdown={
-            "aggro_score":
-            "mismatch_score":
+            "aggro_score": real_aggro,
+            "mismatch_score": real_mismatch,
             "crossref_score": final_crossref
         }
     )
 
 @app.get("/")
 def read_root():
-    return {"message": "AI article analysis server
+    return {"message": "AI article analysis server"}
```
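With both model modules wired in, the endpoint can be exercised end to end once the server is up. A hedged sketch: the request fields and response keys come from this diff, while the host, port, and sample article are assumptions:

```python
# Hypothetical client call against a locally running instance (e.g. uvicorn main:app).
import requests

payload = {
    "article_title": "충격!! 주가 하루 만에 반토막",  # invented example
    "article_body": "지수는 화요일 0.4% 하락했다가 장 마감 전 대부분 회복했다...",
}
resp = requests.post("http://localhost:8000/api/v1/analyze", json=payload)
data = resp.json()
print(data["final_risk_score"], data["final_risk_level"])
```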
mismatch_model.py ADDED (+127 lines)
```python
import torch
import torch.nn.functional as F
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
from sentence_transformers import SentenceTransformer, util

# Device setup (prefer GPU, fall back to CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"✅ Running on: {device}")

# =============================================================================
# 2. Model loading (this can take a little while)
# =============================================================================
print("\n⏳ [1/3] Loading the KoBART summarization model...")
kobart_summarizer = pipeline(
    "summarization",
    model="gogamza/kobart-summarization",
    device=0 if torch.cuda.is_available() else -1
)

print("⏳ [2/3] Loading the SBERT similarity model...")
sbert_model = SentenceTransformer('jhgan/ko-sroberta-multitask')

print("⏳ [3/3] Loading the NLI (contradiction detection) model...")
nli_model_name = "Huffon/klue-roberta-base-nli"
nli_tokenizer = AutoTokenizer.from_pretrained(nli_model_name)
nli_model = AutoModelForSequenceClassification.from_pretrained(nli_model_name).to(device)

print("🎉 All models loaded!\n")

# =============================================================================
# 3. Worker functions
# =============================================================================

def summarize_kobart_strict(text):
    """Summarize the article body with KoBART."""
    # Skip summarization when the body is too short (avoids pipeline errors)
    if len(text) < 50:
        return text

    try:
        result = kobart_summarizer(
            text,
            min_length=15,
            max_length=128,
            num_beams=4,
            no_repeat_ngram_size=3,
            early_stopping=True
        )[0]['summary_text']
        return result.strip()
    except Exception:
        return text[:100]  # on failure, fall back to the opening of the body

def get_cosine_similarity(title, summary):
    """Compute the SBERT cosine similarity between the title and the summary."""
    emb1 = sbert_model.encode(title, convert_to_tensor=True)
    emb2 = sbert_model.encode(summary, convert_to_tensor=True)
    return util.cos_sim(emb1, emb2).item()

def get_mismatch_score(summary, title):
    """Compute the NLI contradiction probability between the summary (premise) and the title (hypothesis)."""
    inputs = nli_tokenizer(
        summary, title,
        return_tensors='pt',
        truncation=True,
        max_length=512
    ).to(device)

    # Guard against a RoBERTa error (drop token_type_ids)
    if "token_type_ids" in inputs:
        del inputs["token_type_ids"]

    with torch.no_grad():
        outputs = nli_model(**inputs)
        probs = F.softmax(outputs.logits, dim=-1)[0]

    # Label order for Huffon/klue-roberta-base-nli: [Entailment, Neutral, Contradiction]
    # Return the contradiction probability (index 2)
    return round(probs[2].item(), 4)

# =============================================================================
# 4. Final main function (Main Logic)
# =============================================================================

def calculate_mismatch_score(article_title, article_body):
    """
    Applies the best weights found by grid search:
    - w1 (SBERT, semantic distance): 0.8
    - w2 (NLI, logical contradiction): 0.2
    - Threshold: a score of 0.45 or above means 'risky'
    """
    #if not (kobart_summarizer and sbert_model and nli_model):
    #    return {"score": 0.0, "reason": "Model loading failed", "recommendation": "Server check required"}

    # 1. Summarize the body
    summary = summarize_kobart_strict(article_body)

    # 2. SBERT semantic distance (1 - similarity)
    sbert_sim = get_cosine_similarity(article_title, summary)
    semantic_distance = 1 - sbert_sim

    # 3. NLI contradiction probability
    nli_contradiction = get_mismatch_score(summary, article_title)

    # 4. Final score
    w1, w2 = 0.8, 0.2
    final_score = (w1 * semantic_distance) + (w2 * nli_contradiction)
    reason = (
        f"[Debug mode]\n"
        f"1. Summary: {summary}\n"
        f"2. SBERT distance: {semantic_distance:.4f}\n"
        f"3. NLI contradiction: {nli_contradiction:.4f}"
    )

    #reason = f"Reflects the semantic distance between title and body ({semantic_distance:.4f}) and the contradiction probability ({nli_contradiction:.4f})."

    # 5. Verdict (threshold 0.45)
    if final_score >= 0.45:
        recommendation = "The title is likely to distort or contradict the body."
    else:
        recommendation = "The title matches the body."

    # Payload handed back to main.py
    return {
        "score": round(final_score, 4),
        "reason": reason,
        "recommendation": recommendation
    }
```
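`calculate_mismatch_score` can likewise be run standalone; a minimal sketch with an invented Korean title/body pair (all three models are fetched from the Hub on first import):

```python
# Standalone check for mismatch_model.py; the article below is made up.
from mismatch_model import calculate_mismatch_score

title = "주식시장 하루 만에 붕괴"
body = ("지수는 화요일 0.4% 하락했으나 장 마감 전 낙폭 대부분을 회복했고, "
        "전문가들은 이를 통상적인 변동성이라고 평가했다. ") * 3

result = calculate_mismatch_score(title, body)
print(result["score"], result["recommendation"])
```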
requirements.txt CHANGED

Binary files a/requirements.txt and b/requirements.txt differ