Spaces:
Running
Running
| import torch | |
| import torch.nn as nn | |
| from transformers import BertTokenizer, BertModel | |
| import pickle | |
| import re | |
| import os | |
| import sys | |
| import numpy as np | |
| from collections import defaultdict | |
# =============================================================================
# 1. Model class definitions
# =============================================================================
# (1) Rule-based scorer class
class RuleBasedScorer:
    """Score a title against four clickbait patterns (ids 11-14).

    The score combines two signals: counts of exaggerated punctuation
    (repeated ``!``/``?``/``…``/``~`` and ellipses) and per-pattern word
    weights stored in ``self.patterns``. The word dictionaries start empty;
    presumably they are filled before the instance is pickled (not shown in
    this file).

    NOTE(review): the label strings in ``pattern_names`` and the "normal"
    label in ``get_score`` look mis-encoded (Korean UTF-8 read through
    another code page). They are kept verbatim because the original text
    cannot be recovered reliably from here — restore them from the source.
    """

    # Defined on the class rather than in __init__ so that instances restored
    # by pickle (which bypasses __init__) can still reach them.
    _PATTERN_IDS = (11, 12, 13, 14)
    # Hangul syllables plus ASCII letters and digits. The mis-encoded form of
    # this character class could not match Korean text at all.
    _WORD_RE = re.compile(r'[가-힣A-Za-z0-9]+')

    def __init__(self):
        # Per-pattern word -> weight dictionaries.
        self.patterns = {pid: defaultdict(float) for pid in self._PATTERN_IDS}
        self.pattern_names = {
            11: '์๋ฌธ ์ ๋ฐํ(๋ถํธ)', 12: '์๋ฌธ ์ ๋ฐํ(์๋)',
            13: '์ ์ ํํ ์ฌ์ฉํ', 14: '์์ด/์ค์๋ง ์ฌ์ฉํ'
        }
        # Symbol patterns: only exaggerated punctuation counts, a single
        # question mark does not.
        self.symbol_patterns = {
            'repeated': re.compile(r'([!?…~])\1+'),  # repeated marks (??, !!)
            'ellipsis': re.compile(r'\.\.\.|…'),     # "..." or the ellipsis character
        }

    def get_score(self, title):
        """Return the best-matching pattern and its score for ``title``.

        Args:
            title: The headline to score. Any object is accepted; it is
                converted with ``str()`` before matching.

        Returns:
            A dict with ``score`` (float, capped at 100), ``pattern`` (11-14,
            or 0 when nothing matched) and ``pattern_name``.
        """
        # Convert once so the tokenizer and the symbol regexes see the same
        # text; passing a non-string to a compiled regex raises TypeError.
        text = str(title)
        words = self._WORD_RE.findall(text)

        # Symbol score: 30 per repeated-mark run, 10 per ellipsis.
        rep = len(self.symbol_patterns['repeated'].findall(text))
        ell = len(self.symbol_patterns['ellipsis'].findall(text))
        symbol_score = (rep * 30) + (ell * 10)

        scores = {}
        for p in self._PATTERN_IDS:
            weights = self.patterns.get(p, {})
            # Log-scaled weight of every title word found in this pattern's
            # dictionary (membership test avoids inserting defaultdict keys).
            word_score = sum(
                np.log1p(weights[word]) * 2 for word in words if word in weights
            )
            if p == 11:
                total = symbol_score                       # symbols only
            elif p == 12:
                total = word_score + (symbol_score * 0.5)  # words + half the symbols
            else:
                total = word_score                         # 13, 14: words only
            scores[p] = total

        # Highest score wins; ties resolve to the lowest pattern id.
        max_pattern = max(scores, key=scores.get)
        if scores[max_pattern] <= 0:
            # Nothing matched. Without this check the tie-break above would
            # label a clean title as pattern 11.
            return {'score': 0, 'pattern': 0, 'pattern_name': '์ ์'}

        return {
            'score': float(min(100, scores[max_pattern])),  # cap at 100
            'pattern': max_pattern,
            'pattern_name': self.pattern_names.get(max_pattern, '์ ์ ์์')
        }
# Guard against pickle loading errors: publish the class on the __main__
# module as well (presumably because the pickle recorded it as
# ``__main__.RuleBasedScorer`` — confirm against how the .pkl was created).
import __main__
__main__.RuleBasedScorer = RuleBasedScorer
# (2) KoBERT model class
class FishingClassifier(nn.Module):
    """Classification head on top of a BERT encoder.

    The encoder's pooled output goes through dropout (p=0.3) and a single
    linear layer that produces ``num_classes`` logits.
    """

    def __init__(self, bert, num_classes=2):
        """Build the classifier.

        Args:
            bert: Encoder module. It is called with ``input_ids``,
                ``attention_mask`` and ``return_dict=False`` and must return
                a 2-tuple whose second item is the pooled output.
            num_classes: Number of output logits.
        """
        super().__init__()
        self.bert = bert
        self.dropout = nn.Dropout(0.3)
        # Take the input width from the encoder's config when it exposes one;
        # fall back to 768, the width this layer was previously hard-coded to.
        hidden_size = getattr(getattr(bert, "config", None), "hidden_size", 768)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, input_ids, mask):
        """Return logits for the given token ids and attention mask."""
        # return_dict=False makes the encoder return a tuple; only the pooled
        # output (second item) is used.
        _, pooled = self.bert(
            input_ids=input_ids, attention_mask=mask, return_dict=False
        )
        return self.fc(self.dropout(pooled))
# =============================================================================
# 2. Model loading
# =============================================================================
print("[AggroModel] ์์คํ ๋ก๋ฉ ์์...")

# Module-level handles, assigned by the loading steps later in this file.
# None means the component is unavailable.
aggro_model = tokenizer = rule_scorer = None

# Inference runs on CPU.
device = torch.device("cpu")

# Absolute directory of this file; the model artifacts are looked up here.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# --- A. Load the rule-based scorer (.pkl) ---
# NOTE: pickle.load can execute arbitrary code, so the .pkl next to this file
# must come from a trusted source.
try:
    pkl_path = os.path.join(BASE_DIR, "rule_based_scorer.pkl")
    if not os.path.exists(pkl_path):
        # Missing file is reported but not fatal; rule_scorer stays None.
        print(f"โ ๏ธ [Aggro] ๊ท์น ํ์ผ ์์: {pkl_path}")
    else:
        with open(pkl_path, "rb") as f:
            rule_scorer = pickle.load(f)
        print("โ [Aggro] ๊ท์น ๊ธฐ๋ฐ ๋ชจ๋ธ ๋ก๋ ์ฑ๊ณต!")
except Exception as e:
    # Any load failure is reported and leaves rule_scorer as it was.
    print(f"๐จ [Aggro] ๊ท์น ๋ชจ๋ธ ๋ก๋ ์คํจ: {e}")
# --- B. Load the KoBERT classifier (.pth) ---
try:
    MODEL_NAME = 'skt/kobert-base-v1'
    try:
        tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
        bert_base = BertModel.from_pretrained(MODEL_NAME)
    except Exception:
        # Narrowed from a bare ``except:`` so Ctrl-C / SystemExit are no
        # longer swallowed. Any ordinary failure falls back to the
        # monologg checkpoint.
        print("โ ๏ธ skt ๋ชจ๋ธ ๋ก๋ ์คํจ, monologg๋ก ์ฌ์๋...")
        tokenizer = BertTokenizer.from_pretrained('monologg/kobert')
        bert_base = BertModel.from_pretrained('monologg/kobert')

    aggro_model = FishingClassifier(bert_base).to(device)
    pth_path = os.path.join(BASE_DIR, "bert_fishing_model_best.pth")
    if os.path.exists(pth_path):
        # NOTE(review): torch.load unpickles the file; the checkpoint must be
        # trusted.
        state_dict = torch.load(pth_path, map_location=device)
        aggro_model.load_state_dict(state_dict)
        aggro_model.eval()
        print("โ [Aggro] KoBERT ๋ชจ๋ธ ๋ก๋ ์ฑ๊ณต!")
    else:
        # Without fine-tuned weights the classifier is disabled entirely.
        print(f"โ ๏ธ [Aggro] KoBERT ํ์ผ ์์: {pth_path}")
        aggro_model = None
except Exception as e:
    print(f"๐จ [Aggro] KoBERT ๋ก๋ ์คํจ: {e}")
    aggro_model = None
# =============================================================================
# 3. Main function
# =============================================================================
def get_aggro_score(title: str) -> dict:
    """Combine the rule-based and KoBERT scores for ``title``.

    Both component scores are on a 0-100 scale. When both models are loaded
    they are blended 40% rules / 60% KoBERT; when only one is loaded its
    score is used alone; when neither is loaded the score is 0. If a
    component raises while scoring, 50.0 is used for that component.

    Args:
        title: The headline to evaluate.

    Returns:
        A dict with ``score`` (float in [0, 1], rounded to 4 decimals),
        ``reason`` and ``recommendation`` (user-facing messages chosen by
        the 0-100 thresholds 70 and 40).
    """
    rule_weight, bert_weight = 0.4, 0.6

    # Explicit None checks: truthiness of a tokenizer/model object can depend
    # on __len__ and is not a reliable "is it loaded" test.
    has_rules = rule_scorer is not None
    has_bert = aggro_model is not None and tokenizer is not None

    # 1. Rule-based score.
    rule_score = 0.0
    rule_pattern = "๋ถ์ ๋ถ๊ฐ"
    if has_rules:
        try:
            res = rule_scorer.get_score(title)
            rule_score = res['score']
            rule_pattern = res.get('pattern_name', '์ ์ ์์')
        except Exception as e:
            print(f"๊ท์น ๊ณ์ฐ ์๋ฌ: {e}")
            rule_score = 50.0

    # 2. KoBERT score: probability of class index 1, scaled to 0-100.
    bert_score = 0.0
    if has_bert:
        try:
            inputs = tokenizer(
                title,
                return_tensors='pt',
                padding="max_length",
                truncation=True,
                max_length=64,
            )
            input_ids = inputs['input_ids'].to(device)
            mask = inputs['attention_mask'].to(device)
            with torch.no_grad():
                logits = aggro_model(input_ids, mask)
                probs = torch.softmax(logits, dim=1)
            bert_score = probs[0][1].item() * 100
        except Exception as e:
            # Previously a bare ``except:`` that hid the failure completely;
            # log it like the rule branch does and keep the same fallback.
            print(f"KoBERT scoring error: {e}")
            bert_score = 50.0

    # 3. Blend (rules 40% + BERT 60%) or use whichever component exists.
    if has_rules and has_bert:
        final_score = (rule_score * rule_weight) + (bert_score * bert_weight)
    elif has_bert:
        final_score = bert_score
    elif has_rules:
        final_score = rule_score
    else:
        final_score = 0.0

    # 4. Result: normalise to [0, 1] as a plain float and pick the messages.
    normalized_score = max(0.0, min(float(final_score) / 100.0, 1.0))
    reason = "์ ๋ชฉ์ด ํ์ดํฉ๋๋ค."
    recommendation = "์ํธํฉ๋๋ค."
    if final_score >= 70:
        reason = f"AI ์์ธก({bert_score:.0f}์ ) ๋ฐ ๊ท์น({rule_pattern})์ ์ํด ๋์์ฑ์ผ๋ก ํ๋จ๋์์ต๋๋ค."
        recommendation = "๊ฐ๊ด์ ์ธ ์ฌ์ค ์์ฃผ์ ์ ๋ชฉ์ผ๋ก ์์ ํด์ฃผ์ธ์."
    elif final_score >= 40:
        reason = f"์ผ๋ถ ๋์์ฑ ์์({rule_pattern})๊ฐ ๊ฐ์ง๋์์ต๋๋ค."
        recommendation = "ํํ์ ์กฐ๊ธ ๋ ๋ค๋ฌ๋ ๊ฒ์ ๊ถ์ฅํฉ๋๋ค."

    return {
        "score": round(normalized_score, 4),
        "reason": reason,
        "recommendation": recommendation
    }