temsa's picture
Publish ContextPII rc4 with decoder hardening and updated evals
3460734 verified
#!/usr/bin/env python3
from __future__ import annotations
import json
import re
import tempfile
import unicodedata
from pathlib import Path
from typing import Any
import numpy as np
from huggingface_hub import HfApi, hf_hub_download
from transformers import AutoConfig, AutoTokenizer
# Tokenizer artefacts copied when a tokenizer directory is staged into a
# temporary folder (see _sanitize_tokenizer_dir and safe_auto_tokenizer).
# Names that are not present in the source are skipped.
TOKENIZER_FILES = [
    "tokenizer_config.json",
    "tokenizer.json",
    "special_tokens_map.json",
    "vocab.txt",
    "vocab.json",
    "merges.txt",
    "added_tokens.json",
    "sentencepiece.bpe.model",
    "spiece.model",
]
# Fallback maximum span width, in tokens, per label. Values in the model
# config's ``span_label_max_span_tokens`` take precedence. Labels found in
# neither place default to 8 (see label_max_span_tokens_from_config).
DEFAULT_LABEL_MAX_SPAN_TOKENS = {
    "PPSN": 9,
    "POSTCODE": 8,
    "PHONE_NUMBER": 10,
    "PASSPORT_NUMBER": 8,
    "BANK_ROUTING_NUMBER": 6,
    "ACCOUNT_NUMBER": 19,
    "CREDIT_DEBIT_CARD": 12,
    "SWIFT_BIC": 8,
    "EMAIL": 16,
    "FIRST_NAME": 6,
    "LAST_NAME": 8,
    "AGE": 8,
    "CITY": 8,
    "COUNTY": 8,
    "DATE_OF_BIRTH": 8,
    "STREET_ADDRESS": 10,
}
# Fallback minimum number of non-whitespace characters a decoded span must
# contain, per label. ``span_label_min_nonspace_chars`` in the config
# overrides these. Labels found in neither place default to 1
# (see label_min_nonspace_chars_from_config).
DEFAULT_LABEL_MIN_NONSPACE_CHARS = {
    "PPSN": 8,
    "POSTCODE": 6,
    "PHONE_NUMBER": 7,
    "PASSPORT_NUMBER": 7,
    "BANK_ROUTING_NUMBER": 6,
    "ACCOUNT_NUMBER": 6,
    "CREDIT_DEBIT_CARD": 12,
    "SWIFT_BIC": 8,
    "EMAIL": 6,
    "FIRST_NAME": 2,
    "LAST_NAME": 2,
    "AGE": 1,
    "CITY": 1,
    "COUNTY": 1,
    "DATE_OF_BIRTH": 1,
    "STREET_ADDRESS": 1,
}
# Tie-break rank used by dedupe_spans. Lower values sort first among spans
# with equal score/start/end (when selecting) or equal start/end (in the final
# positional ordering). Labels missing from this map rank as 99.
OUTPUT_PRIORITY = {
    "PPSN": 0,
    "PASSPORT_NUMBER": 1,
    "ACCOUNT_NUMBER": 2,
    "BANK_ROUTING_NUMBER": 3,
    "CREDIT_DEBIT_CARD": 4,
    "PHONE_NUMBER": 5,
    "SWIFT_BIC": 6,
    "POSTCODE": 7,
    "EMAIL": 8,
    "DATE_OF_BIRTH": 9,
    "AGE": 10,
    "STREET_ADDRESS": 11,
    "CITY": 12,
    "COUNTY": 13,
    "FIRST_NAME": 14,
    "LAST_NAME": 15,
}
def normalize_entity_name(label: str) -> str:
    """Return *label* upper-cased, with surrounding whitespace and a BIO prefix removed.

    Only the upper-case prefixes ``"B-"`` and ``"I-"`` are stripped. A falsy
    *label* (e.g. ``None``) yields ``""``.
    """
    cleaned = (label or '').strip()
    if cleaned[:2] in ('B-', 'I-'):
        cleaned = cleaned[2:]
    return cleaned.upper()
def _sanitize_tokenizer_dir(tokenizer_path: Path) -> str:
tokenizer_cfg_path = tokenizer_path / 'tokenizer_config.json'
if not tokenizer_cfg_path.exists():
return str(tokenizer_path)
data = json.loads(tokenizer_cfg_path.read_text(encoding='utf-8'))
if 'fix_mistral_regex' not in data:
return str(tokenizer_path)
tmpdir = Path(tempfile.mkdtemp(prefix='openmed_gp_tokenizer_'))
keep = set(TOKENIZER_FILES)
for child in tokenizer_path.iterdir():
if child.is_file() and child.name in keep:
(tmpdir / child.name).write_bytes(child.read_bytes())
data.pop('fix_mistral_regex', None)
(tmpdir / 'tokenizer_config.json').write_text(json.dumps(data, ensure_ascii=False, indent=2) + "\n", encoding='utf-8')
return str(tmpdir)
def safe_auto_tokenizer(tokenizer_ref: str):
    """Load a tokenizer from a local directory or a Hub repo id, tolerating version quirks.

    A local path is passed through _sanitize_tokenizer_dir. Otherwise, each
    name in TOKENIZER_FILES that exists in the Hub repo is downloaded into a
    new temporary directory, which is then sanitized. That directory is not
    cleaned up here. If the repo has none of those files, *tokenizer_ref* is
    used unchanged.

    Loading is then attempted in this order until one succeeds:
    1. fast tokenizer with ``fix_mistral_regex=True`` (any exception ignored);
    2. fast tokenizer with ``fix_mistral_regex=False`` (only TypeError ignored);
    3. fast tokenizer without that keyword (any exception ignored);
    4. slow tokenizer (``use_fast=False``), whose errors propagate.

    NOTE(review): presumably ``fix_mistral_regex`` is only understood by some
    transformers versions, hence the ladder; confirm. Step 2 catches only
    TypeError, so any other error raised there propagates without trying
    steps 3-4. Confirm that this is intended.
    """
    tokenizer_path = Path(tokenizer_ref)
    if tokenizer_path.exists():
        tokenizer_ref = _sanitize_tokenizer_dir(tokenizer_path)
    else:
        api = HfApi()
        files = set(api.list_repo_files(repo_id=tokenizer_ref, repo_type='model'))
        tmpdir = Path(tempfile.mkdtemp(prefix='openmed_gp_remote_tokenizer_'))
        copied = False
        for name in TOKENIZER_FILES:
            if name not in files:
                continue
            src = hf_hub_download(repo_id=tokenizer_ref, filename=name, repo_type='model')
            # Flatten into tmpdir so the staged folder looks like a local tokenizer dir.
            (tmpdir / Path(name).name).write_bytes(Path(src).read_bytes())
            copied = True
        if copied:
            tokenizer_ref = _sanitize_tokenizer_dir(tmpdir)
    try:
        return AutoTokenizer.from_pretrained(tokenizer_ref, use_fast=True, fix_mistral_regex=True)
    except Exception:
        pass
    try:
        return AutoTokenizer.from_pretrained(tokenizer_ref, use_fast=True, fix_mistral_regex=False)
    except TypeError:
        pass
    try:
        return AutoTokenizer.from_pretrained(tokenizer_ref, use_fast=True)
    except Exception:
        # Last resort: slow (Python) tokenizer; any error here reaches the caller.
        return AutoTokenizer.from_pretrained(tokenizer_ref, use_fast=False)
def label_names_from_config(config) -> list[str]:
    """Return the normalized span label names declared on *config*.

    Reads ``config.span_label_names`` and normalizes each entry with
    normalize_entity_name. The attribute may be absent, explicitly ``None``
    (e.g. ``"span_label_names": null`` in config.json), or empty. All three
    cases raise the same informative error, instead of ``None`` failing inside
    ``list()`` with a TypeError.

    Raises:
        ValueError: if no span label names are configured.
    """
    raw = getattr(config, 'span_label_names', None)
    names = [] if raw is None else list(raw)
    if not names:
        raise ValueError('Missing span_label_names in config')
    return [normalize_entity_name(name) for name in names]
def label_max_span_tokens_from_config(config) -> dict[str, int]:
    """Return the maximum span width, in tokens, to decode for each label.

    Precedence is as follows. ``config.span_label_max_span_tokens`` comes
    first (keys normalized, values cast to int). DEFAULT_LABEL_MAX_SPAN_TOKENS
    comes next. Any remaining label from the config's span label names gets 8.
    """
    overrides = getattr(config, 'span_label_max_span_tokens', None) or {}
    widths: dict[str, int] = {}
    for raw_name, raw_width in overrides.items():
        name = normalize_entity_name(raw_name)
        widths[name] = int(raw_width)
    for name, width in DEFAULT_LABEL_MAX_SPAN_TOKENS.items():
        if name not in widths:
            widths[name] = width
    for name in label_names_from_config(config):
        if name not in widths:
            widths[name] = 8
    return widths
def label_min_nonspace_chars_from_config(config) -> dict[str, int]:
    """Return the minimum non-whitespace character count required per label.

    Precedence is as follows. ``config.span_label_min_nonspace_chars`` comes
    first (keys normalized, values cast to int).
    DEFAULT_LABEL_MIN_NONSPACE_CHARS comes next. Any remaining label from the
    config's span label names gets 1.
    """
    overrides = getattr(config, 'span_label_min_nonspace_chars', None) or {}
    minimums: dict[str, int] = {}
    for raw_name, raw_count in overrides.items():
        name = normalize_entity_name(raw_name)
        minimums[name] = int(raw_count)
    for name, count in DEFAULT_LABEL_MIN_NONSPACE_CHARS.items():
        if name not in minimums:
            minimums[name] = count
    for name in label_names_from_config(config):
        if name not in minimums:
            minimums[name] = 1
    return minimums
def overlaps(a: dict, b: dict) -> bool:
    """Return True when the half-open ranges ``[start, end)`` of *a* and *b* intersect."""
    return a['end'] > b['start'] and b['end'] > a['start']
def dedupe_spans(spans: list[dict]) -> list[dict]:
    """Greedily resolve overlapping spans and return them in positional order.

    Candidates are visited by descending score, then start, end and
    OUTPUT_PRIORITY rank. A candidate is kept only if it does not overlap any
    span already kept. Survivors are sorted by (start, end, priority rank).
    """
    def selection_rank(item: dict):
        return (
            -float(item.get('score', 0.0)),
            item['start'],
            item['end'],
            OUTPUT_PRIORITY.get(item['label'], 99),
        )

    survivors: list[dict] = []
    for candidate in sorted(spans, key=selection_rank):
        if not any(overlaps(candidate, accepted) for accepted in survivors):
            survivors.append(candidate)
    survivors.sort(key=lambda item: (item['start'], item['end'], OUTPUT_PRIORITY.get(item['label'], 99)))
    return survivors
def load_onnx_session(model_ref: str, onnx_file: str = 'model_quantized.onnx', onnx_subfolder: str = 'onnx'):
    """Create a CPU ONNX Runtime session plus the matching tokenizer and config.

    If *model_ref* is an existing local path, ``<model_ref>/<onnx_subfolder>/<onnx_file>``
    is preferred over ``<model_ref>/<onnx_file>``. If neither file exists, the
    first candidate is still passed to ONNX Runtime. Otherwise *model_ref* is
    treated as a Hub repo id and the model file is downloaded.

    Returns:
        ``(session, tokenizer, config)``.
    """
    import onnxruntime as ort
    local_root = Path(model_ref)
    if local_root.exists():
        search = [local_root / onnx_subfolder / onnx_file] if onnx_subfolder else []
        search.append(local_root / onnx_file)
        onnx_path = next((path for path in search if path.exists()), search[0])
    else:
        hub_filename = f"{onnx_subfolder}/{onnx_file}" if onnx_subfolder else onnx_file
        onnx_path = Path(hf_hub_download(repo_id=model_ref, filename=hub_filename, repo_type='model'))
    # Both branches load config and tokenizer the same way, after resolving the model file.
    config = AutoConfig.from_pretrained(model_ref)
    tokenizer = safe_auto_tokenizer(model_ref)
    session = ort.InferenceSession(str(onnx_path), providers=['CPUExecutionProvider'])
    return session, tokenizer, config
def run_onnx_span(session, encoded: dict[str, Any]) -> np.ndarray:
    """Run *session* on tokenizer output and return its first output.

    Only keys of *encoded* that the session declares as inputs are fed, and
    ``offset_mapping`` is always left out.

    Raises:
        ValueError: if the session returns no outputs.
    """
    accepted = {node.name for node in session.get_inputs()}
    feed = {
        name: tensor
        for name, tensor in encoded.items()
        if name != 'offset_mapping' and name in accepted
    }
    results = session.run(None, feed)
    if not results:
        raise ValueError('ONNX session returned no outputs')
    return results[0]
def label_thresholds_from_config(config, default_threshold: float) -> dict[str, float]:
    """Return the decode score threshold for every label.

    ``config.span_label_thresholds`` (keys normalized, values cast to float)
    takes precedence. Labels from the config's span label names that are not
    listed there fall back to *default_threshold*.
    """
    configured = getattr(config, "span_label_thresholds", None) or {}
    thresholds: dict[str, float] = {}
    for raw_name, raw_value in configured.items():
        name = normalize_entity_name(raw_name)
        thresholds[name] = float(raw_value)
    for name in label_names_from_config(config):
        if name not in thresholds:
            thresholds[name] = float(default_threshold)
    return thresholds
def valid_offset(offset: tuple[int, int]) -> bool:
    """Return True when *offset* is a ``(start, end)`` pair covering at least one character.

    Accepts any indexable pair: tuple, list, or a row of a NumPy
    ``offset_mapping`` array. The previous ``bool(offset)`` check raised
    ValueError for 2-element arrays ("truth value of an array ... is
    ambiguous"). ``None``, empty and too-short sequences are treated as
    invalid (special tokens typically carry ``(0, 0)``).
    """
    if offset is None or len(offset) < 2:
        return False
    return int(offset[1]) > int(offset[0])
def nonspace_length(text: str, start: int, end: int) -> int:
    """Count the non-whitespace characters in ``text[start:end]``."""
    window = text[int(start) : int(end)]
    return len([ch for ch in window if not ch.isspace()])
def alnum_upper(text: str) -> str:
    """Upper-case *text* and keep only its alphanumeric characters."""
    return "".join(filter(str.isalnum, text.upper()))
def normalize_surface(text: str) -> str:
    """Fold *text* to a comparison key.

    The steps are: NFKD-decompose, drop combining marks (accents), map
    no-break and narrow no-break spaces to plain spaces, trim, lower-case,
    and collapse whitespace runs to a single space.
    """
    folded = unicodedata.normalize("NFKD", text)
    folded = "".join(ch for ch in folded if unicodedata.combining(ch) == 0)
    for special_space in ("\u00A0", "\u202F"):
        folded = folded.replace(special_space, " ")
    return re.sub(r"\s+", " ", folded.strip().lower())
# City names, in English and Irish-language spellings, recognised as CITY.
# Searched literally by repair_city_spans, and compared in folded form by the
# CITY check in is_reasonable_span_text.
IRISH_CITY_FORMS = (
    "Dublin",
    "Baile Átha Cliath",
    "Galway",
    "Gaillimh",
    "Cork",
    "Cork City",
    "Corcaigh",
    "Limerick",
    "Luimneach",
    "Waterford",
    "Port Láirge",
    "Kilkenny",
    "Cill Chainnigh",
    "Carlow",
    "Ceatharlach",
)
# normalize_surface()-folded IRISH_CITY_FORMS, for accent- and case-insensitive lookup.
IRISH_CITY_SURFACES = {normalize_surface(value) for value in IRISH_CITY_FORMS}
# Accepted COUNTY surfaces ("Co. <name>" in English and Irish), folded the same way.
IRISH_COUNTY_SURFACES = {
    normalize_surface(value)
    for value in {
        "Co. Dublin",
        "Co. Bhaile Átha Cliath",
        "Co. Galway",
        "Co. na Gaillimhe",
        "Co. Cork",
        "Co. Chorcaí",
        "Co. Limerick",
        "Co. Luimnigh",
        "Co. Waterford",
        "Co. Phort Láirge",
        "Co. Kilkenny",
        "Co. Chill Chainnigh",
        "Co. Carlow",
        "Co. Cheatharlach",
    }
}
# Street-type words. A STREET_ADDRESS span must contain one of them.
STREET_SUFFIX_RE = re.compile(
    r"(?i)\b(street|road|avenue|lane|park|view|square|terrace|drive|close|way|place|bóthar|bothar|sráid|sraid|lána|lana)\b"
)
# Whole-string phone shape: starts with +, ( ) or a digit, contains only digits,
# + ( ), spaces, hyphens and (narrow) no-break spaces, and ends in a digit.
PHONE_SURFACE_RE = re.compile(r"^[+()\d][+()\d \-\u00A0\u202F]*\d$")
# Account numbers made only of digits and optional space/hyphen separators.
ACCOUNT_DIGIT_SURFACE_RE = re.compile(r"^[\d \-\u00A0\u202F]+$")
# Whole-span date formats: DD/MM/YYYY, YYYY-MM-DD, or
# "<day>[st|nd|rd|th] <word> [<word>] YYYY".
DATE_OF_BIRTH_RE = re.compile(
    r"(?i)^(?:\d{2}/\d{2}/\d{4}|\d{4}-\d{2}-\d{2}|\d{1,2}(?:st|nd|rd|th)?\s+[A-Za-zÁÉÍÓÚáéíóú]+(?:\s+[A-Za-zÁÉÍÓÚáéíóú]+)?\s+\d{4})$"
)
# The same date formats, used to search free text. A match may not be glued
# to surrounding letters or digits.
DATE_OF_BIRTH_VALUE_RE = re.compile(
    r"(?<![A-Za-z0-9])(\d{2}/\d{2}/\d{4}|\d{4}-\d{2}-\d{2}|\d{1,2}(?:st|nd|rd|th)?\s+[A-Za-zÁÉÍÓÚáéíóú]+(?:\s+[A-Za-zÁÉÍÓÚáéíóú]+)?\s+\d{4})(?![A-Za-z0-9])"
)
# Cue words required near an AGE candidate (English and Irish).
AGE_CONTEXT_RE = re.compile(r"(?i)\b(age|years?\s+old|year\s+old|aois|bliana\s+d['’]aois)\b")
# Cue phrases required near a DATE_OF_BIRTH candidate (English and Irish).
DOB_CONTEXT_RE = re.compile(
    r"(?i)\b(dob|date\s+of\s+birth|born\s+on|data\s+breithe|dáta\s+breithe|dhata\s+breithe|dháta\s+breithe|rugadh)\b"
)
# Cue phrases that raise the score of a regex-found PPSN (see repair_ppsn_variants).
PPSN_CUE_RE = re.compile(
    r"(?i)\b(ppsn|upsp|personal public service(?:\s+number)?|uimhir\s+(?:mo\s+)?upsp|uimhir\s+(?:mo\s+)?ppsn)\b"
)
# Words never accepted as FIRST_NAME / LAST_NAME. The set covers form-field
# labels and cue words, a few truncated fragments (e.g. "Leithdh"), and
# English/Irish month and weekday names. Entries are compared after
# normalize_surface() folding.
NAME_STOP_SURFACES = {
    normalize_surface(value)
    for value in {
        "Address",
        "Name",
        "Phone",
        "Email",
        "Seoladh",
        "Ainm",
        "Teagmháil",
        "Teagmhail",
        "Ríomhphost",
        "Riomhphost",
        "PPSN",
        "UPSP",
        "Call",
        "Glao",
        "Glaoigh",
        "Rugadh",
        "Ionad",
        "Intreo",
        "Cill",
        "Sampla",
        "Leithdháilte",
        "Leithdhailte",
        "Leithdháil",
        "Leithdhail",
        "Leithdh",
        "Fón",
        "Fon",
        "January",
        "February",
        "March",
        "April",
        "May",
        "June",
        "July",
        "August",
        "September",
        "October",
        "November",
        "December",
        "Monday",
        "Tuesday",
        "Wednesday",
        "Thursday",
        "Friday",
        "Saturday",
        "Sunday",
        "Eanáir",
        "Feabhra",
        "Márta",
        "Aibreán",
        "Aibrean",
        "Bealtaine",
        "Meitheamh",
        "Iúil",
        "Iuil",
        "Lúnasa",
        "Lunasa",
        "Meán Fómhair",
        "Mean Fomhair",
        "Deireadh Fómhair",
        "Deireadh Fomhair",
        "Samhain",
        "Nollaig",
        "Luan",
        "Máirt",
        "Mairt",
        "Céadaoin",
        "Ceadaoin",
        "Déardaoin",
        "Deardaoin",
        "Aoine",
        "Satharn",
        "Domhnach",
    }
}
# Surname particles (folded). A LAST_NAME token that is neither capitalised
# nor camel-cased is still accepted when it is one of these
# (see is_plausible_last_name_sequence).
NAME_PARTICLE_SURFACES = {
    normalize_surface(value)
    for value in {"Ó", "O", "Ní", "Ni", "Nic", "Mac", "Mc", "de", "van", "von"}
}
# Tokens that, when they follow the street suffix, indicate a STREET_ADDRESS
# span has run into ordinary prose. Such spans are rejected in
# is_reasonable_span_text.
STREET_TRAILING_BLOCK_SURFACES = {
    normalize_surface(value)
    for value in {
        "are",
        "public",
        "contact",
        "details",
        "website",
        "open",
        "before",
        "visiting",
        "roimh",
        "chuairt",
        "agus",
        "and",
        "the",
        "is",
        "ta",
    }
}
def is_plausible_last_name_sequence(value: str) -> bool:
    """Return True if every whitespace-separated token of *value* looks like part of a surname.

    Each token must contain a letter and consist only of name characters
    (letters, hyphen, apostrophes). In addition, it must meet one of these
    conditions:
    * its first letter is upper-case;
    * it is camel-cased, i.e. lower then upper (e.g. "mcDonald");
    * it is a known surname particle such as "de" or "van".

    Empty or whitespace-only input returns False.
    """
    words = [piece for piece in re.split(r"\s+", value.strip()) if piece]
    if not words:
        return False
    for word in words:
        letters = [ch for ch in word if ch.isalpha()]
        if not letters:
            return False
        if not all(is_name_token_char(ch) for ch in word):
            return False
        capitalised = letters[0].isupper()
        camel_cased = len(letters) >= 2 and letters[0].islower() and letters[1].isupper()
        if capitalised or camel_cased:
            continue
        if normalize_surface(word) not in NAME_PARTICLE_SURFACES:
            return False
    return True
def is_reasonable_span_text(label: str, text: str, start: int, end: int) -> bool:
    """Apply label-specific surface checks to the candidate span ``text[start:end]``.

    Each label has its own heuristic gate, covering character classes,
    length and format, plus nearby cue words for some labels. Labels
    without a rule are accepted.

    Args:
        label: Normalized entity label, e.g. ``"PHONE_NUMBER"``.
        text: Full source text that the offsets index into.
        start: Inclusive character offset of the span.
        end: Exclusive character offset of the span.

    Returns:
        True if the stripped span text passes the checks for *label*.
    """
    value = text[int(start) : int(end)].strip()
    if not value:
        return False
    upper = alnum_upper(value)
    if label in {"FIRST_NAME", "LAST_NAME"}:
        if not any(ch.isalpha() for ch in value):
            return False
        if any(ch.isdigit() for ch in value):
            return False
        if normalize_surface(value) in NAME_STOP_SURFACES:
            return False
        if label == "FIRST_NAME" and any(ch.isspace() for ch in value):
            return False
        if any(ch in ".,;:/@()" for ch in value):
            return False
        if label == "FIRST_NAME":
            first_alpha = next((ch for ch in value if ch.isalpha()), "")
            if not first_alpha or not first_alpha.isupper():
                return False
        if label == "LAST_NAME" and not is_plausible_last_name_sequence(value):
            return False
        # Reject names that directly follow a digit.
        if int(start) > 0 and text[int(start) - 1].isdigit():
            return False
        return True
    if label == "EMAIL":
        if "@" not in value:
            return False
        local, _, domain = value.partition("@")
        return bool(local) and "." in domain
    if label == "PHONE_NUMBER":
        normalized = value.replace("\u00A0", " ").replace("\u202F", " ").strip()
        if any(ch.isalpha() for ch in normalized):
            return False
        if any(ch in "/@" for ch in normalized):
            return False
        # The number must not be glued to surrounding letters or digits.
        if int(start) > 0 and text[int(start) - 1].isalnum():
            return False
        if int(end) < len(text) and text[int(end)].isalnum():
            return False
        if not PHONE_SURFACE_RE.match(normalized):
            return False
        digits = "".join(ch for ch in value if ch.isdigit())
        if normalized.startswith("+353"):
            return 11 <= len(digits) <= 12
        # 1800 freephone numbers are the one accepted national format without
        # a leading 0. They must be checked before the leading-0 rule below,
        # which would otherwise reject them and make this branch unreachable.
        if digits.startswith("1800"):
            return len(digits) == 10
        if not digits.startswith("0"):
            return False
        # 08... numbers (including 0818) must have 10 digits; 01... must have 9.
        if digits.startswith("08"):
            return len(digits) == 10
        if digits.startswith("01"):
            return len(digits) == 9
        return 9 <= len(digits) <= 10
    if label == "PPSN":
        # Seven digits followed by one or two letters.
        return bool(len(upper) in {8, 9} and upper[:7].isdigit() and upper[7:].isalpha())
    if label == "POSTCODE":
        compact = value.replace(" ", "").replace("\u00A0", "").replace("\u202F", "")
        if any(not (ch.isalnum() or ch.isspace()) for ch in value):
            return False
        if len(compact) != 7:
            return False
        routing = compact[:3]
        unique = compact[3:]
        # Routing key: a letter plus two digits, or the special key D6W.
        routing_ok = bool(
            (routing[0].isalpha() and routing[1:].isdigit())
            or routing == "D6W"
        )
        unique_ok = bool(
            len(unique) == 4
            and unique[0].isalpha()
            and unique[1:].isalnum()
        )
        return routing_ok and unique_ok
    if label == "PASSPORT_NUMBER":
        return bool(re.fullmatch(r"[A-Z]{1,2}\s?\d{7}", value.strip()))
    if label == "BANK_ROUTING_NUMBER":
        digits = "".join(ch for ch in value if ch.isdigit())
        if len(digits) != 6:
            return False
        # A bank cue must appear within 32 chars before or 24 chars after.
        context = text[max(0, int(start) - 32) : min(len(text), int(end) + 24)]
        return bool(BANK_ROUTING_CONTEXT_RE.search(context))
    if label == "SWIFT_BIC":
        return len(upper) in {8, 11} and upper.isalnum()
    if label == "CREDIT_DEBIT_CARD":
        digits = "".join(ch for ch in value if ch.isdigit())
        return 12 <= len(digits) <= 19
    if label == "ACCOUNT_NUMBER":
        if upper.startswith("IE"):
            # IBAN shape: "IE", 2 check digits, then 18 alphanumerics.
            return bool(re.fullmatch(r"IE\d{2}[A-Z0-9]{18}", upper))
        if not ACCOUNT_DIGIT_SURFACE_RE.fullmatch(value.strip()):
            return False
        digits = "".join(ch for ch in value if ch.isdigit())
        return 6 <= len(digits) <= 34
    if label == "AGE":
        digits = "".join(ch for ch in value if ch.isdigit())
        if digits != value.strip():
            return False
        if not digits:
            return False
        if int(start) > 0 and text[int(start) - 1].isalnum():
            return False
        if int(end) < len(text) and text[int(end)].isalnum():
            return False
        # Reject numbers that are part of a date or range such as 12/05 or 10-12.
        if int(start) > 0 and text[int(start) - 1] in "/-":
            return False
        if int(end) < len(text) and text[int(end)] in "/-":
            return False
        age = int(digits)
        if not (0 < age <= 120):
            return False
        context = text[max(0, int(start) - 24) : min(len(text), int(end) + 24)]
        return bool(AGE_CONTEXT_RE.search(context))
    if label == "DATE_OF_BIRTH":
        if not any(ch.isdigit() for ch in value):
            return False
        if not DATE_OF_BIRTH_RE.match(value.strip()):
            return False
        context = text[max(0, int(start) - 32) : min(len(text), int(end) + 32)]
        return bool(DOB_CONTEXT_RE.search(context))
    if label == "CITY":
        if any(ch.isdigit() for ch in value):
            return False
        return normalize_surface(value) in IRISH_CITY_SURFACES
    if label == "COUNTY":
        if any(ch.isdigit() for ch in value):
            return False
        return normalize_surface(value) in IRISH_COUNTY_SURFACES
    if label == "STREET_ADDRESS":
        cleaned = value.strip()
        suffix_match = STREET_SUFFIX_RE.search(cleaned)
        if not suffix_match:
            return False
        if any(ch in "@,:;" for ch in cleaned):
            return False
        # At most three tokens may follow the street suffix, and none of them
        # may be prose words.
        trailing = cleaned[int(suffix_match.end()) :].strip()
        trailing_tokens = [token for token in re.split(r"\s+", trailing) if token]
        if len(trailing_tokens) > 3:
            return False
        if any(normalize_surface(token) in STREET_TRAILING_BLOCK_SURFACES for token in trailing_tokens):
            return False
        # Any digits must come from a leading house number of 1-4 digits.
        has_digit = any(ch.isdigit() for ch in cleaned)
        if has_digit and not re.match(r"^\s*\d{1,4}\b", cleaned):
            return False
        title_tokens = [token for token in re.split(r"\s+", cleaned) if token]
        return has_digit or len(title_tokens) >= 2
    return True
def spans_overlap(a: dict, b: dict) -> bool:
    """Return True when spans *a* and *b* share at least one character.

    Offsets are coerced with ``int()`` and treated as half-open ``[start, end)``.
    """
    return not (int(a["start"]) >= int(b["end"]) or int(b["start"]) >= int(a["end"]))
def is_name_token_char(ch: str) -> bool:
    """Return True if *ch* may appear inside a personal-name token.

    Allowed characters are letters, hyphens, and ASCII or typographic
    apostrophes.
    """
    if ch.isalpha():
        return True
    return ch in ("-", "'", "’")
def is_plausible_first_name(value: str) -> bool:
    """Return True if *value* looks like a single capitalised given name.

    The value is rejected if it is empty or contains any whitespace, digit,
    or one of ``,;:/@()``. Otherwise its first letter must be upper-case and
    every character must be a name character.
    """
    if not value:
        return False
    forbidden = ",;:/@()"
    for ch in value:
        if ch.isspace() or ch.isdigit() or ch in forbidden:
            return False
    letters = [ch for ch in value if ch.isalpha()]
    if not letters or not letters[0].isupper():
        return False
    return all(is_name_token_char(ch) for ch in value)
def repair_first_name_from_last_name(text: str, spans: list[dict]) -> list[dict]:
    """Infer a missing FIRST_NAME from the word just before each LAST_NAME.

    A LAST_NAME is skipped when a FIRST_NAME already ends at most 2
    characters before it. It is also skipped when it is not preceded by
    whitespace. Otherwise the run of name characters before that whitespace
    is proposed as FIRST_NAME, with score 0.6 × the surname score (0.5 if
    unset). The proposal is added only if it passes is_plausible_first_name
    and does not overlap an existing FIRST_NAME. Returns a new list and does
    not mutate *spans*.
    """
    result = list(spans)
    surnames = [item for item in result if item["label"] == "LAST_NAME"]
    for surname in surnames:
        surname_start = int(surname["start"])
        already_named = any(
            item["label"] == "FIRST_NAME"
            and int(item["end"]) <= surname_start
            and surname_start - int(item["end"]) <= 2
            for item in result
        )
        if already_named:
            continue
        pos = surname_start - 1
        if pos < 0 or not text[pos].isspace():
            continue
        # Walk left over the whitespace gap, then over the candidate word.
        while pos >= 0 and text[pos].isspace():
            pos -= 1
        word_end = pos + 1
        while pos >= 0 and is_name_token_char(text[pos]):
            pos -= 1
        word_start = pos + 1
        if word_end <= word_start:
            continue
        word = text[word_start:word_end]
        if not is_plausible_first_name(word):
            continue
        proposal = {
            "start": word_start,
            "end": word_end,
            "label": "FIRST_NAME",
            "score": float(surname.get("score", 0.5)) * 0.6,
            "text": word,
        }
        clashes = any(spans_overlap(proposal, item) for item in result if item["label"] == "FIRST_NAME")
        if not clashes:
            result.append(proposal)
    return result
# Cue phrases that must precede a regex-found PASSPORT_NUMBER (32-char window).
# NOTE(review): the pattern has no word boundaries, so "phas" also matches
# inside words such as "emphasis". Confirm that this is acceptable.
PASSPORT_CUE_RE = re.compile(
    r"(?i)(passport(?:\s+number)?|phas|uimhir\s+(?:mo\s+)?phas)"
)
# One or two capitals, an optional space, then seven digits, not glued to
# other letters or digits.
PASSPORT_VALUE_RE = re.compile(r"(?<![A-Za-z0-9])([A-Z]{1,2}\s?\d{7})(?![A-Za-z0-9])")
# Loose email extractor. The local part excludes whitespace, @ and ,;:()<>;
# the domain is followed by an alphabetic TLD of two or more letters.
EMAIL_EXTRACT_RE = re.compile(r"([^\s@,;:()<>]+@[A-Za-z0-9.\-]+\.[A-Za-z]{2,})", re.UNICODE)
# Cue words that make a regex-found phone number acceptable (see repair_phone_numbers).
PHONE_CUE_RE = re.compile(
    r"(?i)\b(phone|call|contact|reach\s+me|glaoigh\s+ar|teagmh[aá]il|uimhir|m['’]uimhir)\b"
)
# When this phrase occurs anywhere in the text, address-like spans are dropped
# (see drop_public_contact_detail_spans).
PUBLIC_CONTACT_DETAILS_RE = re.compile(r"(?i)\bpublic\s+contact\s+details\b")
# Words near a city name that indicate an address or location context.
CITY_CUE_RE = re.compile(
    r"(?i)\b(address|seoladh|located|suite|centre|center|ionad|intreo|clinic|hospital|ospid[eé]al|hse|fss)\b"
)
# Cue phrases required near a BANK_ROUTING_NUMBER candidate.
BANK_ROUTING_CONTEXT_RE = re.compile(
    r"(?i)\b(sort\s+code|routing\s+number|bank\s+of\s+ireland|aib|cod\s+sort[aá]la|sort[aá]la)\b"
)
# "+353" or "0", then 8-14 further digits with optional single space/hyphen separators.
PHONE_VALUE_RE = re.compile(
    r"(?<![A-Za-z0-9])((?:\+353|0)\d(?:[\s\-]?\d){7,13})(?![A-Za-z0-9])"
)
# Seven digits followed by one or two letters (optionally separated), i.e. PPSN-shaped.
PPSN_VALUE_RE = re.compile(r"(?<![A-Za-z0-9])(\d{7}(?:[\s-]*[A-Za-z]){1,2})(?![A-Za-z0-9])")
# Eircode-style token: a routing key (a letter and two digits, or D6W), an
# optional space, then a letter and three alphanumerics.
POSTCODE_VALUE_RE = re.compile(
    r"(?<![A-Za-z0-9])((?:[A-Za-z]\d{2}|D6W)[\s\u00A0\u202F]?[A-Za-z][A-Za-z0-9]{3})(?![A-Za-z0-9])"
)
def repair_contextual_passport_numbers(text: str, spans: list[dict]) -> list[dict]:
    """Add PASSPORT_NUMBER spans for passport-shaped values preceded by a passport cue.

    A match already fully covered by a PASSPORT_NUMBER span is ignored. So is
    a match with no PASSPORT_CUE_RE hit in the 32 characters before it.
    Otherwise it is added with score 0.67, replacing any overlapping
    PHONE_NUMBER, PPSN, ACCOUNT_NUMBER, AGE or PASSPORT_NUMBER span.
    """
    displaced = {"PHONE_NUMBER", "PPSN", "ACCOUNT_NUMBER", "AGE", "PASSPORT_NUMBER"}
    result = list(spans)
    for hit in PASSPORT_VALUE_RE.finditer(text):
        lo, hi = hit.span(1)
        proposal = {
            "start": lo,
            "end": hi,
            "label": "PASSPORT_NUMBER",
            "score": 0.67,
            "text": text[lo:hi],
        }
        already_covered = any(
            item["label"] == "PASSPORT_NUMBER" and int(item["start"]) <= lo and int(item["end"]) >= hi
            for item in result
        )
        if already_covered:
            continue
        if PASSPORT_CUE_RE.search(text[max(0, lo - 32) : lo]) is None:
            continue
        result = [
            item
            for item in result
            if not (spans_overlap(proposal, item) and item["label"] in displaced)
        ]
        result.append(proposal)
    return result
def repair_ppsn_variants(text: str, spans: list[dict]) -> list[dict]:
    """Add a PPSN span for every PPSN-shaped value in *text*.

    After removing separators, the value must be 7 digits plus 1-2 letters.
    Its score is 0.72 when a PPSN cue lies within 32 chars before or 24 after,
    and 0.58 otherwise. Each added span replaces overlapping PHONE_NUMBER,
    PASSPORT_NUMBER, ACCOUNT_NUMBER, AGE, FIRST_NAME, LAST_NAME and PPSN spans.
    """
    displaced = {"PHONE_NUMBER", "PASSPORT_NUMBER", "ACCOUNT_NUMBER", "AGE", "FIRST_NAME", "LAST_NAME", "PPSN"}
    result = list(spans)
    for hit in PPSN_VALUE_RE.finditer(text):
        lo, hi = hit.span(1)
        surface = text[lo:hi]
        compact = alnum_upper(surface)
        well_formed = len(compact) in {8, 9} and compact[:7].isdigit() and compact[7:].isalpha()
        if not well_formed:
            continue
        context = text[max(0, lo - 32) : min(len(text), hi + 24)]
        proposal = {
            "start": lo,
            "end": hi,
            "label": "PPSN",
            "score": 0.72 if PPSN_CUE_RE.search(context) else 0.58,
            "text": surface,
        }
        result = [
            item
            for item in result
            if not (spans_overlap(proposal, item) and item["label"] in displaced)
        ]
        result.append(proposal)
    return result
def repair_contextual_date_of_birth(text: str, spans: list[dict]) -> list[dict]:
    """Add DATE_OF_BIRTH spans for dates that have a birth-date cue nearby.

    A date qualifies when DOB_CONTEXT_RE matches within 40 characters before
    or 24 after it. It is added with score 0.66, replacing overlapping
    DATE_OF_BIRTH, PHONE_NUMBER, AGE, FIRST_NAME, LAST_NAME and ACCOUNT_NUMBER
    spans.
    """
    displaced = {"DATE_OF_BIRTH", "PHONE_NUMBER", "AGE", "FIRST_NAME", "LAST_NAME", "ACCOUNT_NUMBER"}
    result = list(spans)
    for hit in DATE_OF_BIRTH_VALUE_RE.finditer(text):
        lo, hi = hit.span(1)
        window = text[max(0, lo - 40) : min(len(text), hi + 24)]
        if DOB_CONTEXT_RE.search(window) is None:
            continue
        proposal = {
            "start": lo,
            "end": hi,
            "label": "DATE_OF_BIRTH",
            "score": 0.66,
            "text": text[lo:hi],
        }
        result = [
            item
            for item in result
            if not (spans_overlap(proposal, item) and item["label"] in displaced)
        ]
        result.append(proposal)
    return result
# Cue phrases (English and Irish) that must precede, within 40 characters, a
# regex-found account number (see repair_contextual_account_numbers).
ACCOUNT_CUE_RE = re.compile(
    r"(?i)(account\s+number|bank\s+account|uimhir\s+chuntais|cuntas\s+bainc)"
)
# A standalone run of 6-12 digits.
ACCOUNT_VALUE_RE = re.compile(r"(?<![A-Za-z0-9])(\d{6,12})(?![A-Za-z0-9])")
def repair_contextual_account_numbers(text: str, spans: list[dict]) -> list[dict]:
    """Add ACCOUNT_NUMBER spans for digit runs preceded by an account cue.

    A match is skipped in any of these cases:
    * it is already fully covered by an ACCOUNT_NUMBER;
    * there is no ACCOUNT_CUE_RE hit in the 40 characters before it;
    * it overlaps a PHONE_NUMBER, BANK_ROUTING_NUMBER, PPSN, POSTCODE or
      PASSPORT_NUMBER span.

    Unlike the other repair passes, this one never removes existing spans.
    """
    blockers = {"PHONE_NUMBER", "BANK_ROUTING_NUMBER", "PPSN", "POSTCODE", "PASSPORT_NUMBER"}
    result = list(spans)
    for hit in ACCOUNT_VALUE_RE.finditer(text):
        lo, hi = hit.span(1)
        proposal = {
            "start": lo,
            "end": hi,
            "label": "ACCOUNT_NUMBER",
            "score": 0.51,
            "text": text[lo:hi],
        }
        covered = any(
            item["label"] == "ACCOUNT_NUMBER" and int(item["start"]) <= lo and int(item["end"]) >= hi
            for item in result
        )
        if covered:
            continue
        if ACCOUNT_CUE_RE.search(text[max(0, lo - 40) : lo]) is None:
            continue
        blocked = any(spans_overlap(proposal, item) and item["label"] in blockers for item in result)
        if not blocked:
            result.append(proposal)
    return result
def repair_emails(text: str, spans: list[dict]) -> list[dict]:
    """Add an EMAIL span (score 0.74) for every email address found by EMAIL_EXTRACT_RE.

    Each added span replaces overlapping EMAIL, FIRST_NAME and LAST_NAME spans.
    """
    displaced = {"EMAIL", "FIRST_NAME", "LAST_NAME"}
    result = list(spans)
    for hit in EMAIL_EXTRACT_RE.finditer(text):
        lo, hi = hit.span(1)
        proposal = {
            "start": lo,
            "end": hi,
            "label": "EMAIL",
            "score": 0.74,
            "text": text[lo:hi],
        }
        result = [
            item
            for item in result
            if not (spans_overlap(proposal, item) and item["label"] in displaced)
        ]
        result.append(proposal)
    return result
def repair_phone_numbers(text: str, spans: list[dict]) -> list[dict]:
    """Add PHONE_NUMBER spans for phone-shaped values that are cued or already detected.

    A regex match qualifies when either condition holds:
    * PHONE_CUE_RE hits within 32 characters before or 16 after it;
    * it overlaps an existing PHONE_NUMBER span.

    It must also pass the PHONE_NUMBER check in is_reasonable_span_text. It
    is then added with score 0.69, replacing overlapping PHONE_NUMBER, PPSN,
    ACCOUNT_NUMBER and BANK_ROUTING_NUMBER spans.
    """
    displaced = {"PHONE_NUMBER", "PPSN", "ACCOUNT_NUMBER", "BANK_ROUTING_NUMBER"}
    result = list(spans)
    for hit in PHONE_VALUE_RE.finditer(text):
        lo, hi = hit.span(1)
        proposal = {
            "start": lo,
            "end": hi,
            "label": "PHONE_NUMBER",
            "score": 0.69,
            "text": text[lo:hi],
        }
        window = text[max(0, lo - 32) : min(len(text), hi + 16)]
        cued = PHONE_CUE_RE.search(window) is not None
        corroborated = any(spans_overlap(proposal, item) and item["label"] == "PHONE_NUMBER" for item in result)
        if not cued and not corroborated:
            continue
        if not is_reasonable_span_text("PHONE_NUMBER", text, lo, hi):
            continue
        result = [
            item
            for item in result
            if not (spans_overlap(proposal, item) and item["label"] in displaced)
        ]
        result.append(proposal)
    return result
def repair_postcodes(text: str, spans: list[dict]) -> list[dict]:
    """Add a POSTCODE span (score 0.71) for every match of POSTCODE_VALUE_RE.

    Each added span replaces overlapping POSTCODE, PHONE_NUMBER,
    ACCOUNT_NUMBER, FIRST_NAME and LAST_NAME spans.
    """
    displaced = {"POSTCODE", "PHONE_NUMBER", "ACCOUNT_NUMBER", "FIRST_NAME", "LAST_NAME"}
    result = list(spans)
    for hit in POSTCODE_VALUE_RE.finditer(text):
        lo, hi = hit.span(1)
        proposal = {
            "start": lo,
            "end": hi,
            "label": "POSTCODE",
            "score": 0.71,
            "text": text[lo:hi],
        }
        result = [
            item
            for item in result
            if not (spans_overlap(proposal, item) and item["label"] in displaced)
        ]
        result.append(proposal)
    return result
def repair_city_spans(text: str, spans: list[dict]) -> list[dict]:
    """Add CITY spans for known city names that appear in an address-like context.

    Every form in IRISH_CITY_FORMS is searched case-insensitively as a whole
    word, so "Cork" no longer fires inside "Corkery" and "Dublin" no longer
    fires inside "Dubliner". Forms are tried longest first. A match that
    overlaps a city already accepted by this pass is skipped, so "Cork City"
    is kept instead of being replaced by the shorter "Cork".

    A match is accepted when at least one context signal holds:
    * it starts 0-4 characters after the end of a STREET_ADDRESS span;
    * a COUNTY or POSTCODE span starts 0-6 characters after it;
    * it is followed by a comma and then "Co. ", a routing key like "D02",
      or "D6W";
    * CITY_CUE_RE matches within 40 characters before or 32 after it.

    Accepted spans get score 0.64 and replace overlapping CITY, FIRST_NAME
    and LAST_NAME spans.
    """
    repaired = list(spans)
    seen: set[tuple[int, int]] = set()
    # Cities accepted by this pass, used to keep longer forms over shorter ones.
    accepted: list[dict] = []
    conflicting_labels = {"CITY", "FIRST_NAME", "LAST_NAME"}
    ordered_forms = sorted(IRISH_CITY_FORMS, key=len, reverse=True)
    for form in ordered_forms:
        # Whole-word match: the form may not be glued to other word characters.
        pattern = r"(?<!\w)" + re.escape(form) + r"(?!\w)"
        for match in re.finditer(pattern, text, flags=re.IGNORECASE):
            start, end = match.span()
            key = (start, end)
            if key in seen:
                continue
            seen.add(key)
            candidate_span = {
                "start": start,
                "end": end,
                "label": "CITY",
                "score": 0.64,
                "text": text[start:end],
            }
            if any(spans_overlap(candidate_span, prior) for prior in accepted):
                continue
            has_context = False
            for other in repaired:
                other_start = int(other["start"])
                other_end = int(other["end"])
                if other["label"] == "STREET_ADDRESS" and 0 <= start - other_end <= 4:
                    has_context = True
                    break
                if other["label"] in {"COUNTY", "POSTCODE"} and 0 <= other_start - end <= 6:
                    has_context = True
                    break
            if not has_context and re.match(r"^\s*,\s*(?:Co\.\s+|[A-Z]\d{2}|D6W)", text[end:]):
                has_context = True
            if not has_context:
                cue_window = text[max(0, start - 40) : min(len(text), end + 32)]
                has_context = bool(CITY_CUE_RE.search(cue_window))
            if not has_context:
                continue
            repaired = [
                other
                for other in repaired
                if not (
                    spans_overlap(candidate_span, other)
                    and other["label"] in conflicting_labels
                )
            ]
            repaired.append(candidate_span)
            accepted.append(candidate_span)
    return repaired
def drop_public_contact_detail_spans(text: str, spans: list[dict]) -> list[dict]:
    """Drop STREET_ADDRESS, CITY and COUNTY spans when the text mentions "public contact details".

    If the phrase is absent, *spans* itself is returned unchanged.
    """
    if PUBLIC_CONTACT_DETAILS_RE.search(text) is None:
        return spans
    # The phrase test does not depend on the span, so it is done once above.
    address_labels = {"STREET_ADDRESS", "CITY", "COUNTY"}
    return [span for span in spans if span["label"] not in address_labels]
def drop_city_org_prefix_spans(text: str, spans: list[dict]) -> list[dict]:
    """Remove CITY spans that are really the start of an "<City> Intreo Centre" name.

    Only the 24 characters after the span are inspected. Spans with other
    labels are kept.
    """
    def names_an_intreo_centre(span: dict) -> bool:
        span_end = int(span["end"])
        following = text[span_end : min(len(text), span_end + 24)]
        return re.match(r"^\s+Intreo\s+Centre\b", following) is not None

    return [
        span
        for span in spans
        if span["label"] != "CITY" or not names_an_intreo_centre(span)
    ]
def canonicalize_email_spans(text: str, spans: list[dict]) -> list[dict]:
    """Trim each EMAIL span to the first email address found inside it.

    EMAIL spans without an extractable address, and spans with other labels,
    are passed through as-is. Trimmed spans are new dicts with updated
    ``start``, ``end`` and ``text``.
    """
    out: list[dict] = []
    for span in spans:
        if span["label"] == "EMAIL":
            offset = int(span["start"])
            found = EMAIL_EXTRACT_RE.search(text[offset : int(span["end"])])
            if found is not None:
                lo = offset + int(found.start(1))
                hi = offset + int(found.end(1))
                out.append({**span, "start": lo, "end": hi, "text": text[lo:hi]})
                continue
        out.append(span)
    return out
def drop_stacked_first_names(spans: list[dict]) -> list[dict]:
    """Drop a FIRST_NAME that is immediately followed by another FIRST_NAME + LAST_NAME.

    A FIRST_NAME is dropped when both of the following hold:
    * another FIRST_NAME starts after it, within 2 characters of its end;
    * a LAST_NAME starts within 2 characters after the end of that second
      FIRST_NAME.

    In "Mary Anne Murphy", for example, only "Anne" is kept as FIRST_NAME.
    When there are no FIRST_NAME or no LAST_NAME spans, *spans* itself is
    returned.
    """
    if not spans:
        return spans
    givens = [span for span in spans if span["label"] == "FIRST_NAME"]
    surnames = [span for span in spans if span["label"] == "LAST_NAME"]
    if not givens or not surnames:
        return spans

    def followed_by_surname(given: dict) -> bool:
        given_end = int(given["end"])
        return any(given_end <= int(last["start"]) <= given_end + 2 for last in surnames)

    def is_shadowed(given: dict) -> bool:
        given_start, given_end = int(given["start"]), int(given["end"])
        for later in givens:
            if later is given:
                continue
            later_start = int(later["start"])
            if later_start <= given_start or later_start - given_end > 2:
                continue
            if followed_by_surname(later):
                return True
        return False

    return [span for span in spans if span["label"] != "FIRST_NAME" or not is_shadowed(span)]
def decode_span_matrix(
    text: str,
    offsets: list[tuple[int, int]],
    span_scores: np.ndarray,
    config,
    min_score: float,
) -> list[dict]:
    """Decode a per-label span score tensor into character-level entity spans.

    For each configured label, every token pair (start, end) with
    ``start <= end < start + max_width`` (and ``end < seq_len``) becomes a
    candidate if all of the following hold:
    * its score reaches that label's threshold;
    * both tokens have a non-empty offset;
    * the character span has enough non-whitespace characters;
    * it passes is_reasonable_span_text.

    Candidates then go through a fixed sequence of preference, rule-based
    repair and filtering passes. Finally dedupe_spans removes the remaining
    overlaps.

    Args:
        text: Source text that the offsets refer to.
        offsets: Per-token ``(start_char, end_char)`` pairs, indexed like the
            token axes of *span_scores*. Presumably the tokenizer's
            offset_mapping for one sequence; confirm at the call site.
            NOTE(review): must have at least ``seq_len`` entries, or indexing
            raises IndexError.
        span_scores: Rank-3 array indexed ``[label, start_token, end_token]``.
            Scores are compared directly with the thresholds, so they are
            presumably probabilities rather than logits; confirm at the call
            site.
        config: Model config supplying label names and per-label limits.
        min_score: Threshold for labels without a configured one.

    Returns:
        Non-overlapping span dicts with ``start``, ``end``, ``label``,
        ``score`` and ``text`` keys, sorted by position.

    Raises:
        ValueError: if *span_scores* is not 3-dimensional, or if the config
            has no span label names.
    """
    label_names = label_names_from_config(config)
    thresholds = label_thresholds_from_config(config, min_score)
    max_span_tokens = label_max_span_tokens_from_config(config)
    min_nonspace_chars = label_min_nonspace_chars_from_config(config)
    if span_scores.ndim != 3:
        raise ValueError(f"Expected [num_labels, seq_len, seq_len] span scores, got shape {span_scores.shape}")
    num_labels, seq_len, _ = span_scores.shape
    spans: list[dict] = []
    # Extra label planes beyond the configured names (or vice versa) are ignored.
    for label_index in range(min(num_labels, len(label_names))):
        label = label_names[label_index]
        threshold = thresholds.get(label, min_score)
        max_width = max(1, int(max_span_tokens.get(label, 8)))
        min_chars = max(1, int(min_nonspace_chars.get(label, 1)))
        for start_idx in range(seq_len):
            start_offset = offsets[start_idx]
            # Tokens with empty offsets cannot start a span.
            if not valid_offset(start_offset):
                continue
            max_end = min(seq_len, start_idx + max_width)
            for end_idx in range(start_idx, max_end):
                end_offset = offsets[end_idx]
                if not valid_offset(end_offset):
                    continue
                score = float(span_scores[label_index, start_idx, end_idx])
                if score < threshold:
                    continue
                start_char = int(start_offset[0])
                end_char = int(end_offset[1])
                if end_char <= start_char:
                    continue
                if nonspace_length(text, start_char, end_char) < min_chars:
                    continue
                if not is_reasonable_span_text(label, text, start_char, end_char):
                    continue
                spans.append(
                    {
                        "start": start_char,
                        "end": end_char,
                        "label": label,
                        "score": score,
                        "text": text[start_char:end_char],
                    }
                )
    # Later passes see, and may remove, spans added by earlier ones, so this
    # order is significant. First collapse competing same-label candidates,
    # then add or replace spans via regex/context rules, then drop known false
    # positives, and finally trim emails before the overlap resolution.
    spans = prefer_long_name_spans(spans, thresholds)
    spans = prefer_long_structured_spans(spans, thresholds)
    spans = repair_first_name_from_last_name(text, spans)
    spans = repair_emails(text, spans)
    spans = repair_phone_numbers(text, spans)
    spans = repair_ppsn_variants(text, spans)
    spans = repair_postcodes(text, spans)
    spans = repair_city_spans(text, spans)
    spans = repair_contextual_date_of_birth(text, spans)
    spans = repair_contextual_passport_numbers(text, spans)
    spans = repair_contextual_account_numbers(text, spans)
    spans = drop_public_contact_detail_spans(text, spans)
    spans = drop_city_org_prefix_spans(text, spans)
    spans = drop_stacked_first_names(spans)
    spans = canonicalize_email_spans(text, spans)
    return dedupe_spans(spans)
def prefer_long_name_spans(spans: list[dict], thresholds: dict[str, float]) -> list[dict]:
    """Collapse FIRST_NAME/LAST_NAME candidates that share a label and a start offset.

    Within each (label, start) group, the longest candidate is kept when its
    score is at least ``max(threshold + 0.15, 0.7 * best_score)``. Otherwise
    the highest-scoring candidate is kept. The survivor takes the position of
    the group's first member, and other labels pass through unchanged. The
    result is then passed to prefer_same_end_extensions.
    """
    if not spans:
        return spans
    name_labels = ("FIRST_NAME", "LAST_NAME")

    def score_of(item: dict) -> float:
        return float(item.get("score", 0.0))

    # Group candidates once instead of rescanning the list for every span.
    groups: dict[tuple, list[dict]] = {}
    for item in spans:
        if item["label"] in name_labels:
            groups.setdefault((item["label"], item["start"]), []).append(item)

    chosen: list[dict] = []
    handled: set[tuple] = set()
    for item in spans:
        label = item["label"]
        if label not in name_labels:
            chosen.append(item)
            continue
        key = (label, item["start"])
        if key in handled:
            continue
        handled.add(key)
        members = groups[key]
        if len(members) == 1:
            chosen.append(item)
            continue
        top = max(members, key=score_of)
        widest = max(members, key=lambda cand: (cand["end"] - cand["start"], score_of(cand)))
        floor = max(float(thresholds.get(label, 0.5)) + 0.15, score_of(top) * 0.7)
        chosen.append(widest if score_of(widest) >= floor else top)
    return prefer_same_end_extensions(chosen, thresholds)
def prefer_same_end_extensions(spans: list[dict], thresholds: dict[str, float]) -> list[dict]:
    """Collapse FIRST_NAME/LAST_NAME/EMAIL candidates that share a label and an end offset.

    Within each (label, end) group, the selection rule depends on the label:
    * EMAIL: the longest candidate wins when it contains "@" or is strictly
      longer than the top-scoring one, and its score is within 0.02 of the
      best score.
    * Names: the longest candidate wins when it has no inner space and its
      score is at least ``max(0.8 * threshold, 0.55 * best_score)``.

    Otherwise the top-scoring candidate is kept. The survivor takes the
    position of the group's first member, and other labels pass through
    unchanged.
    """
    if not spans:
        return spans
    tracked = ("FIRST_NAME", "LAST_NAME", "EMAIL")

    def score_of(item: dict) -> float:
        return float(item.get("score", 0.0))

    def width_of(item: dict):
        return item["end"] - item["start"]

    def choose(label: str, members: list[dict]) -> dict:
        top = max(members, key=score_of)
        widest = max(members, key=lambda cand: (width_of(cand), score_of(cand)))
        widest_score, top_score = score_of(widest), score_of(top)
        if label == "EMAIL":
            extends = "@" in widest.get("text", "") or width_of(widest) > width_of(top)
            take_widest = extends and widest_score >= top_score - 0.02
        else:
            floor = max(float(thresholds.get(label, 0.5)) * 0.8, top_score * 0.55)
            take_widest = " " not in widest.get("text", "").strip() and widest_score >= floor
        return widest if take_widest else top

    groups: dict[tuple, list[dict]] = {}
    for item in spans:
        if item["label"] in tracked:
            groups.setdefault((item["label"], item["end"]), []).append(item)

    chosen: list[dict] = []
    handled: set[tuple] = set()
    for item in spans:
        label = item["label"]
        if label not in tracked:
            chosen.append(item)
            continue
        key = (label, item["end"])
        if key in handled:
            continue
        handled.add(key)
        members = groups[key]
        chosen.append(item if len(members) == 1 else choose(label, members))
    return chosen
def prefer_long_structured_spans(spans: list[dict], thresholds: dict[str, float]) -> list[dict]:
    """Collapse overlapping STREET_ADDRESS / DATE_OF_BIRTH candidates of the same label.

    Spans are visited in order. Each unconsumed target span gathers every
    unconsumed same-label span that overlaps it (itself included). Overlap is
    not transitive, so the visiting order matters. From that cluster the
    longest span is kept if its score is at least
    ``max(threshold, 0.75 * best_score)``; otherwise the top-scoring span is
    kept. Other labels pass through unchanged.
    """
    if not spans:
        return spans
    structured = {"STREET_ADDRESS", "DATE_OF_BIRTH"}
    kept: list[dict] = []
    absorbed: set[int] = set()
    for pos, span in enumerate(spans):
        if pos in absorbed:
            continue
        kind = span["label"]
        if kind not in structured:
            kept.append(span)
            continue
        cluster = [
            (idx, other)
            for idx, other in enumerate(spans)
            if idx not in absorbed and other["label"] == kind and spans_overlap(span, other)
        ]
        if len(cluster) == 1:
            kept.append(span)
            continue
        absorbed.update(idx for idx, _ in cluster)
        members = [other for _, other in cluster]
        top = max(members, key=lambda cand: float(cand.get("score", 0.0)))
        widest = max(
            members,
            key=lambda cand: (cand["end"] - cand["start"], float(cand.get("score", 0.0))),
        )
        cutoff = max(float(thresholds.get(kind, 0.5)), float(top.get("score", 0.0)) * 0.75)
        kept.append(widest if float(widest.get("score", 0.0)) >= cutoff else top)
    return kept
def sigmoid_np(values: np.ndarray) -> np.ndarray:
    """Element-wise logistic function.

    Inputs are clipped to [-60, 60] first, which keeps ``np.exp`` from
    overflowing.
    """
    bounded = np.clip(values, -60.0, 60.0)
    decay = np.exp(-bounded)
    return 1.0 / (1.0 + decay)
def run_onnx_span(session, encoded: dict[str, Any]) -> np.ndarray:
    """Run *session* on tokenizer output and return its first output.

    Only keys of *encoded* that the session declares as inputs are fed, and
    ``offset_mapping`` is always left out.

    NOTE(review): this is a second definition of ``run_onnx_span``. An
    earlier one with the same behaviour appears above in this module. Being
    defined last, this binding is the one in effect after import, so one copy
    can presumably be removed.

    Raises:
        ValueError: if the session returns no outputs.
    """
    feed = {}
    input_names = {item.name for item in session.get_inputs()}
    for key, value in encoded.items():
        # Offsets are only needed for decoding, never as a model input.
        if key == "offset_mapping":
            continue
        if key in input_names:
            feed[key] = value
    outputs = session.run(None, feed)
    if not outputs:
        raise ValueError("ONNX session returned no outputs")
    return outputs[0]