PatoFlamejanteTV commited on
Commit
ce99828
·
verified ·
1 Parent(s): 1173745

Create classifier.py

Browse files
Files changed (1) hide show
  1. classifier.py +217 -0
classifier.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ classifier.py
3
+ Core pipeline: normalization, heuristics, multi-model inference, aggregation & explanations.
4
+
5
+ Designed to be defensive: flags suspicious content and explains why.
6
+ """
7
+
8
+ from typing import List, Dict, Any, Optional, Tuple
9
+ import re
10
+ import math
11
+ import logging
12
+
13
+ # Model imports
14
+ import torch
15
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
16
+
17
+ logger = logging.getLogger(__name__)
18
+ logging.basicConfig(level=logging.INFO)
19
+
20
+
21
+ ###########################
22
+ # Configuration / models
23
+ ###########################
24
+
25
+ # Candidate model names (change to the exact models you prefer)
26
+ HARM_MODELS = [
27
+ "unitary/toxic-bert", # English toxic classifier
28
+ "unitary/multilingual-toxic-xlm-roberta" # multilingual toxic classifier
29
+ ]
30
+ URL_MODEL = "r3ddkahili/final-complete-malicious-url-model" # malicious URL detector
31
+
32
+ # thresholds (tunable)
33
+ THRESHOLDS = {
34
+ "harm": 0.5, # generic threshold for harmful label(s) (individual mapping below)
35
+ "url": 0.7, # suspicious/malicious probability threshold
36
+ "ascii_entropy": 3.0 # lower entropy -> suspicious
37
+ }
38
+
39
+ # model handles (populated by load_models())
40
+ MODEL_HANDLES = {
41
+ "harm": [], # list of tuples (name, tokenizer, model, label_map)
42
+ "url": None # tuple (name, tokenizer, model, label_map)
43
+ }
44
+
45
+
46
+ ###########################
47
+ # Utilities: normalization
48
+ ###########################
49
+
50
+ # Minimal homoglyph map (extend this in production)
51
+ HOMOGLYPH_MAP = {
52
+ '\u0430': 'a', # cyrillic a -> a
53
+ '\u0435': 'e', # cyrillic e -> e
54
+ '\u03BF': 'o', # greek omicron -> o
55
+ '0': 'o',
56
+ '1': 'l',
57
+ '@': 'a',
58
+ }
59
+
60
+ ZERO_WIDTH_PATTERN = re.compile('[\u200B-\u200F\uFEFF]')
61
+
62
+ def normalize_obfuscation(text: str) -> str:
63
+ """Normalize text: collapse whitespace, remove zero-width, apply basic homoglyph map."""
64
+ t = ZERO_WIDTH_PATTERN.sub('', text)
65
+ t = re.sub(r'\s+', ' ', t)
66
+ out_chars = []
67
+ for ch in t:
68
+ out_chars.append(HOMOGLYPH_MAP.get(ch, ch))
69
+ return ''.join(out_chars).strip()
70
+
71
+
72
+ def shannon_entropy(s: str) -> float:
73
+ """Return Shannon character entropy of string s."""
74
+ if not s:
75
+ return 0.0
76
+ s = s.replace(" ", "")
77
+ freq = {}
78
+ for c in s:
79
+ freq[c] = freq.get(c, 0) + 1
80
+ ent = 0.0
81
+ L = len(s)
82
+ for v in freq.values():
83
+ p = v / L
84
+ ent -= p * math.log2(p)
85
+ return ent
86
+
87
+
88
+ ###########################
89
+ # Heuristic detectors
90
+ ###########################
91
+
92
+ # suspicious URL-like tokens (shortlist of TLDs frequently used for obfuscation)
93
+ URL_OBFUSCATION_RE = re.compile(
94
+ r'([a-z0-9\-]{1,20}\s*[\.\[\(]? ?(?:link|site|xyz|to|ly|pw|click)\b)|' # e.g. site.link or site . link
95
+ r'(https?://)?[^\s]{1,64}\.(?:link|site|xyz|to|ly|pw|click)\b',
96
+ re.I
97
+ )
98
+
99
+ JAILBREAK_PATTERNS = [
100
+ re.compile(r"ignore (?:previous|all) instructions", re.I),
101
+ re.compile(r"(?:bypass|disable) (?:filters|moderation|safety)", re.I),
102
+ re.compile(r"rewire the (?:system|assistant) prompt", re.I),
103
+ re.compile(r"output (?:the|the full) system prompt", re.I),
104
+ ]
105
+
106
+ ASCII_ART_RE = re.compile(r'[\u2500-\u259F]|[_\-\|]{6,}|(?:\bASCII\b)', re.I)
107
+
108
+ # catches long runs of punctuation / separators (often used to hide tokens)
109
+ OBFUSCATION_SEP_RE = re.compile(r'([^\w\s]{2,}\s*){2,}')
110
+
111
+
112
+ def heuristic_scan(raw: str, normalized: str) -> List[Dict[str, Any]]:
113
+ flags = []
114
+ # URL heuristics
115
+ if URL_OBFUSCATION_RE.search(raw) or URL_OBFUSCATION_RE.search(normalized):
116
+ flags.append({"type": "hidden_link_heuristic", "explain": "Suspicious or obfuscated URL-like token detected by regex."})
117
+
118
+ # ascii-art / block text / low entropy
119
+ ent = shannon_entropy(re.sub(r'\s+', '', normalized))
120
+ if ASCII_ART_RE.search(raw) or ent < THRESHOLDS["ascii_entropy"]:
121
+ flags.append({"type": "ascii_art_heuristic", "explain": f"ASCII-art-like characters or low entropy text (entropy={ent:.2f})."})
122
+
123
+ # jailbreak heuristics
124
+ jail_matches = [p.pattern for p in JAILBREAK_PATTERNS if p.search(normalized)]
125
+ if jail_matches:
126
+ flags.append({"type": "ai_jailbreak_heuristic", "explain": "Patterns commonly used to override model safety detected.", "matches": jail_matches})
127
+
128
+ # obfuscation separators
129
+ if OBFUSCATION_SEP_RE.search(normalized):
130
+ flags.append({"type": "filter_obfuscation_heuristic", "explain": "Many non-alphanumeric separators or repeated punctuation — possible obfuscation."})
131
+
132
+ return flags
133
+
134
+
135
+ ###########################
136
+ # Model loading & helpers
137
+ ###########################
138
+
139
+ def safe_load_tokenizer_and_model(name: str) -> Optional[Tuple]:
140
+ """Try to load tokenizer and model; return None on failure gracefully."""
141
+ try:
142
+ tokenizer = AutoTokenizer.from_pretrained(name, use_fast=True)
143
+ model = AutoModelForSequenceClassification.from_pretrained(name)
144
+ model.eval()
145
+ if torch.cuda.is_available():
146
+ try:
147
+ model.to("cuda")
148
+ except Exception:
149
+ logger.warning("Could not move model to cuda.")
150
+ logger.info(f"Loaded model {name}")
151
+ return tokenizer, model
152
+ except Exception as e:
153
+ logger.warning(f"Failed to load {name}: {e}")
154
+ return None
155
+
156
+
157
+ def load_models():
158
+ """Populate MODEL_HANDLES with tokenizer+model pairs. Called once at import or app init."""
159
+ # load harm models list
160
+ for mname in HARM_MODELS:
161
+ res = safe_load_tokenizer_and_model(mname)
162
+ if res:
163
+ tokenizer, model = res
164
+ # attempt to extract label mapping (if model has config.id2label)
165
+ label_map = getattr(model.config, "id2label", None) or {}
166
+ MODEL_HANDLES["harm"].append((mname, tokenizer, model, label_map))
167
+
168
+ # load URL model
169
+ res = safe_load_tokenizer_and_model(URL_MODEL)
170
+ if res:
171
+ tokenizer, model = res
172
+ label_map = getattr(model.config, "id2label", None) or {}
173
+ MODEL_HANDLES["url"] = (URL_MODEL, tokenizer, model, label_map)
174
+
175
+ # Call at import
176
+ try:
177
+ load_models()
178
+ except Exception as e:
179
+ # keep running: models may be loaded later or in Space with more resources
180
+ logger.warning(f"Model loading raised: {e}")
181
+
182
+
183
+ ###########################
184
+ # Model runners
185
+ ###########################
186
+
187
+ def run_sequence_model(tokenizer, model, text: str, max_length=512) -> Dict[str, float]:
188
+ """Run a sequence classification model and return label->prob mapping (softmax)."""
189
+ inputs = tokenizer(text, truncation=True, max_length=max_length, return_tensors="pt")
190
+ if torch.cuda.is_available() and next(model.parameters(), None) is not None:
191
+ inputs = {k: v.to("cuda") for k, v in inputs.items()}
192
+ with torch.no_grad():
193
+ outputs = model(**inputs)
194
+ logits = outputs.logits
195
+ probs = torch.softmax(logits, dim=-1).cpu().numpy()[0]
196
+ # build mapping
197
+ id2label = getattr(model.config, "id2label", {})
198
+ if id2label:
199
+ return {id2label.get(i, str(i)): float(probs[i]) for i in range(len(probs))}
200
+ else:
201
+ # fallback: numeric labels
202
+ return {str(i): float(probs[i]) for i in range(len(probs))}
203
+
204
+
205
+ def aggregate_harm_predictions(preds: List[Dict[str, float]]) -> Dict[str, Any]:
206
+ """
207
+ Combine multiple harm model outputs.
208
+ We compute per-label averages and maxes, and decide whether to flag.
209
+ """
210
+ if not preds:
211
+ return {"combined": {}, "note": "no harm models loaded"}
212
+ label_set = set()
213
+ for p in preds:
214
+ label_set.update(p.keys())
215
+ combined = {}
216
+ for lbl in label_set:
217
+ vals = [p.get(lbl, 0.0)