hy committed
Commit 7134b06 · 1 Parent(s): 7f4feea

Fix server code and sync with remote

Files changed (6)
  1. .gitignore +3 -1
  2. aggro_model.py +188 -0
  3. database.py +5 -1
  4. main.py +24 -19
  5. mismatch_model.py +127 -0
  6. requirements.txt +0 -0
.gitignore CHANGED
@@ -3,4 +3,6 @@ __pycache__/
 *.pyc
 .venv/
 venv/
- info.md
+ info.md
+ *.pth
+ *.pkl
aggro_model.py ADDED
@@ -0,0 +1,188 @@
+ import torch
+ import torch.nn as nn
+ from transformers import BertTokenizer, BertModel
+ import pickle
+ import re
+ import os
+ import numpy as np
+
+ # =============================================================================
+ # 1. Model class definitions
+ # =============================================================================
+
+ # (1) Restore the rule-based scorer class
+ class RuleBasedScorer:
+     def __init__(self):
+         # These attributes are overwritten when the scorer is loaded from the pickle file.
+         self.patterns = {}
+         self.pattern_names = {}
+         self.symbol_patterns = {
+             'repeated': re.compile(r'([!?…~])\1+'),
+             'ellipsis': re.compile(r'\.\.\.|…')
+         }
+
+     def get_score(self, title):
+         words = re.findall(r'[가-힣A-Za-z0-9]+', str(title))
+         scores = {}
+
+         # Symbol score
+         rep = len(self.symbol_patterns['repeated'].findall(title))
+         ell = len(self.symbol_patterns['ellipsis'].findall(title))
+         symbol_score = (rep * 30) + (ell * 10)
+
+         for p in [11, 12, 13, 14]:
+             word_score = 0
+             # Word-matching score
+             for word in words:
+                 if word in self.patterns[p]:
+                     # Apply the word weight
+                     word_score += np.log1p(self.patterns[p][word]) * 2
+
+             total = 0
+             if p == 11: total = symbol_score
+             elif p == 12: total = word_score + (symbol_score * 0.5)
+             else: total = word_score
+
+             scores[p] = total
+
+         # Compute the final score
+         max_pattern = max(scores, key=scores.get)
+         max_score = min(100, scores[max_pattern])
+
+         return {
+             'score': max_score,
+             'pattern': max_pattern,
+             'pattern_name': self.pattern_names.get(max_pattern, 'unknown')
+         }
+
+ # (2) Restore the KoBERT model class
+ class FishingClassifier(nn.Module):
+     def __init__(self, bert, num_classes=2):
+         super().__init__()
+         self.bert = bert
+         self.dropout = nn.Dropout(0.3)
+         self.fc = nn.Linear(768, num_classes)
+     def forward(self, input_ids, mask):
+         _, pooled = self.bert(input_ids=input_ids, attention_mask=mask, return_dict=False)
+         return self.fc(self.dropout(pooled))
+
+ # =============================================================================
+ # 2. Model loading
+ # =============================================================================
+ print("[AggroModel] Starting system load...")
+
+ aggro_model = None
+ tokenizer = None
+ rule_scorer = None
+ device = torch.device("cpu")  # the server runs on CPU
+
+ # --- A. Load the rule-based model (.pkl) ---
+ try:
+     pkl_path = "rule_based_scorer.pkl"
+     if os.path.exists(pkl_path):
+         with open(pkl_path, "rb") as f:
+             rule_scorer = pickle.load(f)
+         print("✅ [Aggro] Rule-based model (PKL) loaded!")
+     else:
+         print(f"⚠️ [Aggro] Rule file not found: {pkl_path}")
+ except Exception as e:
+     print(f"🚨 [Aggro] Failed to load the rule-based model: {e}")
+
+ # --- B. Load the KoBERT model (.pth) ---
+ try:
+     # 1. Load the tokenizer & base model
+     #    (if this errors, try 'monologg/kobert' instead)
+     MODEL_NAME = 'skt/kobert-base-v1'
+     tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
+     bert_base = BertModel.from_pretrained(MODEL_NAME)
+
+     # 2. Build the model skeleton
+     aggro_model = FishingClassifier(bert_base).to(device)
+
+     # 3. Load the weight file
+     pth_path = "bert_fishing_model_best.pth"
+
+     if os.path.exists(pth_path):
+         state_dict = torch.load(pth_path, map_location=device)
+         aggro_model.load_state_dict(state_dict)
+         aggro_model.eval()
+         print("✅ [Aggro] KoBERT model (PTH) loaded!")
+     else:
+         print(f"⚠️ [Aggro] KoBERT file not found: {pth_path}")
+         aggro_model = None
+
+ except Exception as e:
+     print(f"🚨 [Aggro] Failed to load KoBERT: {e}")
+     aggro_model = None
+
+
+ # =============================================================================
+ # 3. Main function (called from main.py)
+ # =============================================================================
+ def get_aggro_score(title: str) -> dict:
+     """
+     Takes a title and returns a clickbait score together with its rationale.
+     """
+
+     # 1. Rule-based score
+     rule_score = 0.0
+     rule_pattern = "analysis unavailable"
+
+     if rule_scorer:
+         try:
+             res = rule_scorer.get_score(title)
+             rule_score = res['score']  # 0-100
+             rule_pattern = res['pattern_name']
+         except Exception as e:
+             print(f"[Aggro] Rule scoring error: {e}")
+
+     # 2. KoBERT score
+     bert_score = 0.0
+     if aggro_model and tokenizer:
+         try:
+             inputs = tokenizer(
+                 title,
+                 return_tensors='pt',
+                 padding="max_length",
+                 truncation=True,
+                 max_length=64
+             )
+             input_ids = inputs['input_ids'].to(device)
+             mask = inputs['attention_mask'].to(device)
+
+             with torch.no_grad():
+                 outputs = aggro_model(input_ids, mask)
+                 probs = torch.softmax(outputs, dim=1)
+                 bert_score = probs[0][1].item() * 100  # 0-100
+         except Exception as e:
+             print(f"[Aggro] KoBERT prediction error: {e}")
+             bert_score = 50.0
+
+     # 3. Final blend (rules 40% + BERT 60%)
+     if rule_scorer and aggro_model:
+         final_score = (rule_score * 0.4) + (bert_score * 0.6)
+     elif aggro_model:
+         final_score = bert_score
+     elif rule_scorer:
+         final_score = rule_score
+     else:
+         final_score = 0.0
+
+     # 4. Result text
+     normalized_score = min(final_score / 100.0, 1.0)  # convert to 0-1
+
+     reason = "The title reads as neutral."
+     recommendation = "Looks fine."
+
+     if final_score >= 70:
+         reason = f"Judged clickbait based on the AI prediction ({bert_score:.0f} pts) and the matched rule ({rule_pattern})."
+         recommendation = "Please rewrite the title around objective facts."
+     elif final_score >= 40:
+         reason = f"Some clickbait signals ({rule_pattern}) were detected."
+         recommendation = "Consider toning down the wording."
+
+     return {
+         "score": round(normalized_score, 2),
+         "reason": reason,
+         "recommendation": recommendation
+     }
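
For reviewers, a minimal smoke-test sketch of how aggro_model.py is meant to be used (the sample title is hypothetical; rule_based_scorer.pkl and bert_fishing_model_best.pth must be present in the working directory, otherwise the module falls back to a 0.0 score as coded above):

# Hedged usage sketch for aggro_model.get_aggro_score (sample title is made up)
from aggro_model import get_aggro_score

result = get_aggro_score("충격!! 결국 드러난 진실...")  # hypothetical clickbait-style title
print(result["score"])            # normalized 0-1 clickbait score
print(result["reason"])           # human-readable rationale
print(result["recommendation"])   # suggested follow-up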
database.py CHANGED
@@ -8,7 +8,11 @@ SQLALCHEMY_DATABASE_URL = os.environ.get("SQLALCHEMY_DATABASE_URL")
 
 # 2. Create the DB connection engine
 engine = create_engine(
-     SQLALCHEMY_DATABASE_URL
+     SQLALCHEMY_DATABASE_URL,
+     pool_pre_ping=True,   # <--- (1) ping each connection before using it
+     pool_recycle=300,     # <--- (2) replace connections every 5 minutes (300 s) to avoid stale ones
+     pool_size=5,          # <--- (3) number of connections kept open at once
+     max_overflow=10       # <--- (4) extra connections allowed during sudden spikes
 )
 
 # 3. Create the Session used to talk to the DB
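
These pool options only take effect through sessions created from this engine; a minimal sketch of the surrounding wiring, assuming database.py builds SessionLocal with a standard sessionmaker (the sessionmaker call is an assumption, not copied from the file):

# Sketch of the assumed engine/session wiring around the change above
import os
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

SQLALCHEMY_DATABASE_URL = os.environ.get("SQLALCHEMY_DATABASE_URL")
engine = create_engine(
    SQLALCHEMY_DATABASE_URL,
    pool_pre_ping=True,   # issues a lightweight ping before reusing a pooled connection
    pool_recycle=300,     # recycles connections older than 300 seconds
    pool_size=5,
    max_overflow=10,
)
SessionLocal = sessionmaker(bind=engine)  # assumed; database.py may pass extra options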
main.py CHANGED
@@ -6,7 +6,8 @@ from database import engine, SessionLocal
 from sqlalchemy.orm import Session  # required import
 from models import Article, AnalysisResult
 from crossref_model import get_crossref_score_and_reason
-
+ from mismatch_model import calculate_mismatch_score  # <-- import the teammate's function
+ from aggro_model import get_aggro_score
 
 models.Base.metadata.create_all(bind=engine)
 from database import SessionLocal  # import the SessionLocal defined in database.py
@@ -51,21 +52,25 @@ app = FastAPI()
 
 @app.post("/api/v1/analyze", response_model=AnalysisResponse)
 def analyze_article(request: ArticleRequest,db: Session = Depends(get_db)):
-     """Article analysis API (Aggro/Mismatch return 'dummy' data)"""
+     """Article analysis API"""
+
+     # 1. AggroScore
+     aggro_result = get_aggro_score(request.article_title)
 
-     # 1. aggro (dummy score)
-     dummy_aggro = ScoreBreakdown(
-         score=0.95,
-         reason="Sensational expressions such as '결국' and '경악' were used.",
-         recommendation="Revise the sensational wording in the title.",
+     real_aggro = ScoreBreakdown(
+         score=aggro_result["score"],
+         reason=aggro_result["reason"],
+         recommendation=aggro_result["recommendation"],
          found_urls=None
      )
 
-     # 2. mismatch (dummy score)
-     dummy_mismatch = ScoreBreakdown(
-         score=0.15,
-         reason="The key points of the title and body are semantically consistent.",
-         recommendation="Looks fine.",
+     # 2. mismatch
+     mismatch_result = calculate_mismatch_score(request.article_title, request.article_body)
+
+     real_mismatch = ScoreBreakdown(
+         score=mismatch_result["score"],
+         reason=mismatch_result["reason"],
+         recommendation=mismatch_result["recommendation"],
          found_urls=None
      )
 
@@ -97,8 +102,8 @@ def analyze_article(request: ArticleRequest,db: Session = Depends(get_db)):
     w_mismatch = 0.4
     w_crossref = 0.2
 
-     final_score = (dummy_aggro.score * w_aggro) + \
-                   (dummy_mismatch.score * w_mismatch) + \
+     final_score = (real_aggro.score * w_aggro) + \
+                   (real_mismatch.score * w_mismatch) + \
                    (final_crossref.score * w_crossref)
 
     final_level = "안전"
@@ -124,8 +129,8 @@ def analyze_article(request: ArticleRequest,db: Session = Depends(get_db)):
     # 3. Store the analysis result in the 'analysis_results' table
     new_result = AnalysisResult(
         article_id=new_article.article_id,  # foreign-key link
-         aggro_score=dummy_aggro.score,
-         mismatch_score=dummy_mismatch.score,
+         aggro_score=real_aggro.score,
+         mismatch_score=real_mismatch.score,
         crossref_score=final_crossref.score,
         final_risk=final_score,
     )
@@ -140,12 +145,12 @@ def analyze_article(request: ArticleRequest,db: Session = Depends(get_db)):
         final_risk_score=round(final_score, 2),  # rounded to two decimal places
         final_risk_level=final_level,
         breakdown={
-             "aggro_score": dummy_aggro,
-             "mismatch_score": dummy_mismatch,
+             "aggro_score": real_aggro,
+             "mismatch_score": real_mismatch,
             "crossref_score": final_crossref
         }
     )
 
 @app.get("/")
 def read_root():
-     return {"message": "AI article analysis server v1.0 (Dummy Server)"}
+     return {"message": "AI article analysis server"}
mismatch_model.py ADDED
@@ -0,0 +1,127 @@
+ import torch
+ import torch.nn.functional as F
+ from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
+ from sentence_transformers import SentenceTransformer, util
+
+ # Device setup (prefer GPU, otherwise CPU)
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ print(f"✅ Running on: {device}")
+
+ # =============================================================================
+ # 2. Model loading (this can take a while)
+ # =============================================================================
+ print("\n⏳ [1/3] Loading the KoBART summarization model...")
+ kobart_summarizer = pipeline(
+     "summarization",
+     model="gogamza/kobart-summarization",
+     device=0 if torch.cuda.is_available() else -1
+ )
+
+ print("⏳ [2/3] Loading the SBERT similarity model...")
+ sbert_model = SentenceTransformer('jhgan/ko-sroberta-multitask')
+
+ print("⏳ [3/3] Loading the NLI (contradiction detection) model...")
+ nli_model_name = "Huffon/klue-roberta-base-nli"
+ nli_tokenizer = AutoTokenizer.from_pretrained(nli_model_name)
+ nli_model = AutoModelForSequenceClassification.from_pretrained(nli_model_name).to(device)
+
+ print("🎉 All models loaded!\n")
+
+ # =============================================================================
+ # 3. Worker functions
+ # =============================================================================
+
+ def summarize_kobart_strict(text):
+     """Summarizes the article body with KoBART."""
+     # Skip summarization for very short bodies (avoids errors)
+     if len(text) < 50:
+         return text
+
+     try:
+         result = kobart_summarizer(
+             text,
+             min_length=15,
+             max_length=128,
+             num_beams=4,
+             no_repeat_ngram_size=3,
+             early_stopping=True
+         )[0]['summary_text']
+         return result.strip()
+     except Exception as e:
+         return text[:100]  # on failure, fall back to the start of the body
+
+ def get_cosine_similarity(title, summary):
+     """Computes the cosine similarity between the title and the summary with SBERT."""
+     emb1 = sbert_model.encode(title, convert_to_tensor=True)
+     emb2 = sbert_model.encode(summary, convert_to_tensor=True)
+     return util.cos_sim(emb1, emb2).item()
+
+ def get_mismatch_score(summary, title):
+     """Computes the contradiction probability between the summary (premise) and the title (hypothesis) with the NLI model."""
+     inputs = nli_tokenizer(
+         summary, title,
+         return_tensors='pt',
+         truncation=True,
+         max_length=512
+     ).to(device)
+
+     # Guard against a RoBERTa error (drop token_type_ids)
+     if "token_type_ids" in inputs:
+         del inputs["token_type_ids"]
+
+     with torch.no_grad():
+         outputs = nli_model(**inputs)
+         probs = F.softmax(outputs.logits, dim=-1)[0]
+
+     # Huffon/klue-roberta-base-nli label order: [Entailment, Neutral, Contradiction]
+     # Return the contradiction probability (index 2)
+     return round(probs[2].item(), 4)
+
+ # =============================================================================
+ # 4. Main function (main logic)
+ # =============================================================================
+
+ def calculate_mismatch_score(article_title, article_body):
+     """
+     Optimal weights from a grid search:
+     - w1 (SBERT, semantic distance): 0.8
+     - w2 (NLI, logical contradiction): 0.2
+     - Threshold: 0.45 or higher is flagged as 'risky'
+     """
+     #if not (kobart_summarizer and sbert_model and nli_model):
+     #    return {"score": 0.0, "reason": "Model loading failed", "recommendation": "Check the server"}
+
+     # 1. Summarize the body
+     summary = summarize_kobart_strict(article_body)
+
+     # 2. SBERT semantic distance (1 - similarity)
+     sbert_sim = get_cosine_similarity(article_title, summary)
+     semantic_distance = 1 - sbert_sim
+
+     # 3. NLI contradiction probability
+     nli_contradiction = get_mismatch_score(summary, article_title)
+
+     # 4. Final score
+     w1, w2 = 0.8, 0.2
+     final_score = (w1 * semantic_distance) + (w2 * nli_contradiction)
+     reason = (
+         f"[Debug mode]\n"
+         f"1. Summary: {summary}\n"
+         f"2. SBERT distance: {semantic_distance:.4f}\n"
+         f"3. NLI contradiction: {nli_contradiction:.4f}"
+     )
+
+     #reason = f"Reflects the semantic distance between title and body ({semantic_distance:.4f}) and the contradiction probability ({nli_contradiction:.4f})."
+
+     # 5. Verdict (threshold 0.45)
+     if final_score >= 0.45:
+         recommendation = "The title is likely to distort or contradict the body."
+     else:
+         recommendation = "The title and the body are consistent."
+
+     # Data passed back to main.py
+     return {
+         "score": round(final_score, 4),
+         "reason": reason,
+         "recommendation": recommendation
+     }
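
A minimal sketch of exercising mismatch_model.py on its own (the title and body strings are made up; note that importing the module loads KoBART, SBERT, and the NLI model at import time, exactly as the code above does):

# Hedged usage sketch for mismatch_model.calculate_mismatch_score
from mismatch_model import calculate_mismatch_score

result = calculate_mismatch_score(
    "제목 예시",                 # hypothetical title
    "본문 예시 텍스트입니다. " * 10  # hypothetical body; bodies under 50 characters skip summarization
)
print(result["score"])           # 0.8 * (1 - SBERT similarity) + 0.2 * NLI contradiction
print(result["recommendation"])  # verdict against the 0.45 threshold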
requirements.txt CHANGED
Binary files a/requirements.txt and b/requirements.txt differ