# project-tdm / aggro_model.py
# hy
# round 4
# 474acae
# raw
# history blame
# 7.48 kB
import torch
import torch.nn as nn
from transformers import BertTokenizer, BertModel
import pickle
import re
import os
import sys
import numpy as np
from collections import defaultdict
# =============================================================================
# 1. Model class definitions
# =============================================================================
# (1) Rule-based scorer class
class RuleBasedScorer:
    """Heuristic clickbait scorer driven by word dictionaries and punctuation.

    Four patterns are scored independently and the strongest one wins:

    * 11 - question-inducing, punctuation (repeated ``!?…~``, ellipses)
    * 12 - question-inducing, concealment (dictionary words + half of the
      punctuation score)
    * 13 - sensational expressions (dictionary words only)
    * 14 - slang / abbreviations (dictionary words only)

    Instances are normally restored from a pickle. Unpickling restores the
    instance ``__dict__`` without calling ``__init__``, so the attribute names
    ``patterns``, ``pattern_names`` and ``symbol_patterns`` must stay stable.
    """

    def __init__(self):
        # Per-pattern word dictionaries (word -> weight). Nothing in this
        # class writes to them; presumably they were populated before the
        # scorer was pickled.
        self.patterns = {
            11: defaultdict(float), 12: defaultdict(float),
            13: defaultdict(float), 14: defaultdict(float),
        }
        self.pattern_names = {
            11: '의문 유발형(부호)', 12: '의문 유발형(은닉)',
            13: '선정표현 사용형', 14: '속어/줄임말 사용형',
        }
        # Punctuation patterns. A single '?' is deliberately not counted;
        # only exaggerated punctuation is.
        self.symbol_patterns = {
            'repeated': re.compile(r'([!?…~])\1+'),  # runs such as "??", "!!"
            'ellipsis': re.compile(r'\.\.\.|…'),     # "..." or the "…" character
        }

    def get_score(self, title):
        """Score *title* and report the strongest clickbait pattern.

        Args:
            title: Headline to score; non-string values are passed through
                ``str()`` first.

        Returns:
            dict with ``score`` (float, capped at 100), ``pattern`` (11-14,
            or 0 when no pattern produced a positive score) and
            ``pattern_name``.
        """
        # Normalise once so the word pass and the punctuation pass see the
        # same text; non-str input previously crashed in the regex calls.
        text = str(title)

        # 1. Tokenise into runs of Hangul syllables, Latin letters and digits.
        words = re.findall(r'[가-힣A-Za-z0-9]+', text)

        # 2. Punctuation score, shared by patterns 11 and 12.
        repeated_runs = len(self.symbol_patterns['repeated'].findall(text))
        ellipses = len(self.symbol_patterns['ellipsis'].findall(text))
        symbol_score = (repeated_runs * 30) + (ellipses * 10)

        # 3. Per-pattern scores.
        scores = {}
        for p in (11, 12, 13, 14):
            word_score = 0
            if p in self.patterns:  # safety guard: skip patterns absent from the pickle
                weights = self.patterns[p]
                for word in words:
                    # `in` does not trigger the defaultdict factory, so no
                    # entries are inserted by scoring.
                    if word in weights:
                        # Log-scale the weight so heavy words do not dominate.
                        word_score += np.log1p(weights[word]) * 2

            if p == 11:    # punctuation only
                total = symbol_score
            elif p == 12:  # concealment, e.g. "...의 이유는"
                total = word_score + (symbol_score * 0.5)
            else:          # 13 (sensational) / 14 (slang): words only
                total = word_score
            scores[p] = total

        # 4. Pick the strongest pattern (ties go to the lowest pattern id).
        best = max(scores, key=scores.get)
        if scores[best] <= 0:
            # Nothing fired: report "normal" instead of defaulting to pattern 11.
            return {'score': 0.0, 'pattern': 0, 'pattern_name': '정상'}
        return {
            'score': float(min(100, scores[best])),  # 100-point ceiling
            'pattern': best,
            'pattern_name': self.pattern_names.get(best, '알 수 없음'),
        }
# Pickle-loading guard: expose the class as `__main__.RuleBasedScorer` so a
# pickle that refers to it under that name (presumably how the .pkl below was
# saved) can be resolved when this module is imported from elsewhere.
import __main__

__main__.RuleBasedScorer = RuleBasedScorer
# (2) KoBERT model class
class FishingClassifier(nn.Module):
    """Clickbait classifier: BERT pooled output -> dropout -> linear head.

    Args:
        bert: Encoder called as ``bert(input_ids=..., attention_mask=...,
            return_dict=False)``; element 1 of the returned tuple is used as
            the pooled representation.
        num_classes: Number of output logits (default 2).
        dropout_rate: Dropout probability before the head (default 0.3).
        hidden_size: Input width of the head. When None, taken from
            ``bert.config.hidden_size`` if available, otherwise 768.
    """

    def __init__(self, bert, num_classes=2, dropout_rate=0.3, hidden_size=None):
        super().__init__()
        self.bert = bert
        if hidden_size is None:
            config = getattr(bert, "config", None)
            hidden_size = getattr(config, "hidden_size", 768)
        # Submodule names (bert / dropout / fc) must stay as-is: saved
        # state_dicts are keyed by them.
        self.dropout = nn.Dropout(dropout_rate)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, input_ids, mask):
        """Return raw logits for a batch of token ids and attention mask."""
        outputs = self.bert(input_ids=input_ids, attention_mask=mask, return_dict=False)
        # Index instead of 2-way unpacking: the tuple grows when the encoder
        # is configured to also return hidden states / attentions.
        pooled = outputs[1]
        return self.fc(self.dropout(pooled))
# =============================================================================
# 2. Model loading
# =============================================================================
print("[AggroModel] 시스템 로딩 시작...")

# Module-level handles filled in by the loaders below. Each stays None when
# its artefact cannot be loaded; get_aggro_score() checks them before use.
aggro_model = None
tokenizer = None
rule_scorer = None
# All inference in this module runs on CPU.
device = torch.device("cpu")

# Resolve artefact paths relative to this file rather than the current
# working directory, so loading works regardless of where Python was started.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# --- A. Rule-based scorer (.pkl) ---
# Best-effort: on any failure rule_scorer keeps its previous value (None at
# import) and scoring proceeds without the rule component.
try:
    pkl_path = os.path.join(BASE_DIR, "rule_based_scorer.pkl")
    if os.path.exists(pkl_path):
        # SECURITY: pickle.load can execute arbitrary code from the file.
        # Only acceptable for the artefact bundled next to this module;
        # never point this at user-supplied or downloaded files.
        with open(pkl_path, "rb") as f:
            rule_scorer = pickle.load(f)
        print("✅ [Aggro] 규칙 기반 모델 로드 성공!")
    else:
        print(f"⚠️ [Aggro] 규칙 파일 없음: {pkl_path}")
except Exception as e:
    print(f"🚨 [Aggro] 규칙 모델 로드 실패: {e}")
# --- B. KoBERT ๋ชจ๋ธ ๋กœ๋“œ (.pth) ---
try:
MODEL_NAME = 'skt/kobert-base-v1'
try:
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
bert_base = BertModel.from_pretrained(MODEL_NAME)
except:
print("โš ๏ธ skt ๋ชจ๋ธ ๋กœ๋“œ ์‹คํŒจ, monologg๋กœ ์žฌ์‹œ๋„...")
tokenizer = BertTokenizer.from_pretrained('monologg/kobert')
bert_base = BertModel.from_pretrained('monologg/kobert')
aggro_model = FishingClassifier(bert_base).to(device)
pth_path = os.path.join(BASE_DIR, "bert_fishing_model_best.pth")
if os.path.exists(pth_path):
state_dict = torch.load(pth_path, map_location=device)
aggro_model.load_state_dict(state_dict)
aggro_model.eval()
print(f"โœ… [Aggro] KoBERT ๋ชจ๋ธ ๋กœ๋“œ ์„ฑ๊ณต!")
else:
print(f"โš ๏ธ [Aggro] KoBERT ํŒŒ์ผ ์—†์Œ: {pth_path}")
aggro_model = None
except Exception as e:
print(f"๐Ÿšจ [Aggro] KoBERT ๋กœ๋“œ ์‹คํŒจ: {e}")
aggro_model = None
# =============================================================================
# 3. Main function
# =============================================================================
def get_aggro_score(title: str) -> dict:
    """Score how clickbait-like ("aggro") a headline is.

    Blends the rule-based scorer (40%) with the KoBERT classifier (60%) when
    both are loaded; otherwise uses whichever one is available, or 0.

    Args:
        title: Headline to analyse.

    Returns:
        dict with ``score`` (0.0-1.0, rounded to 4 places), ``reason`` and
        ``recommendation`` (user-facing Korean text).
    """
    has_rule = rule_scorer is not None
    # The BERT path needs both the model and its tokenizer; use the same
    # condition for computing the score and for blending it in.
    has_bert = aggro_model is not None and tokenizer is not None

    # 1. Rule-based component (0-100).
    rule_score = 0.0
    rule_pattern = "분석 불가"
    if has_rule:
        try:
            res = rule_scorer.get_score(title)
            rule_score = float(res['score'])
            rule_pattern = res.get('pattern_name', '알 수 없음')
        except Exception as e:
            print(f"규칙 계산 에러: {e}")
            rule_score = 50.0  # neutral fallback on scorer failure

    # 2. KoBERT component (0-100): softmax probability of class 1, in percent.
    bert_score = 0.0
    if has_bert:
        try:
            inputs = tokenizer(
                title, return_tensors='pt', padding="max_length", truncation=True, max_length=64
            )
            input_ids = inputs['input_ids'].to(device)
            mask = inputs['attention_mask'].to(device)
            with torch.no_grad():
                logits = aggro_model(input_ids, mask)
            probs = torch.softmax(logits, dim=1)
            # NOTE(review): assumes label 1 is the "clickbait" class.
            bert_score = probs[0][1].item() * 100
        except Exception as e:
            # Log instead of silently swallowing, mirroring the rule path.
            print(f"BERT 계산 에러: {e}")
            bert_score = 50.0  # neutral fallback on model failure

    # 3. Blend whichever components are available (rule 40% + BERT 60%).
    if has_rule and has_bert:
        final_score = (rule_score * 0.4) + (bert_score * 0.6)
    elif has_bert:
        final_score = bert_score
    elif has_rule:
        final_score = rule_score
    else:
        final_score = 0.0

    # 4. Normalise to 0-1. Clamp both ends: nothing upstream guarantees the
    #    rule component is non-negative (e.g. a negative dictionary weight).
    normalized_score = min(max(final_score / 100.0, 0.0), 1.0)

    # NOTE(review): without BERT, bert_score is 0 and the >= 70 message
    # reports "AI 예측(0점)"; consider a rule-only wording.
    if final_score >= 70:
        reason = f"AI 예측({bert_score:.0f}점) 및 규칙({rule_pattern})에 의해 낚시성으로 판단되었습니다."
        recommendation = "객관적인 사실 위주의 제목으로 수정해주세요."
    elif final_score >= 40:
        reason = f"일부 낚시성 요소({rule_pattern})가 감지되었습니다."
        recommendation = "표현을 조금 더 다듬는 것을 권장합니다."
    else:
        reason = "제목이 평이합니다."
        recommendation = "양호합니다."

    return {
        "score": round(float(normalized_score), 4),
        "reason": reason,
        "recommendation": recommendation,
    }