#!/usr/bin/env python3
"""Span-based PII detection utilities for an ONNX span-scoring model.

This module loads a tokenizer/config/ONNX session, decodes a
``[num_labels, seq_len, seq_len]`` span-score matrix into character spans,
and post-processes those spans with Irish-specific heuristics (PPSN,
Eircode-style postcodes, +353 phone formats, Irish/English city and county
surface forms).

NOTE(review): this file was recovered from a formatting-mangled copy.  Four
raw-string regex literals are truncated at ``r"(?`` (marked with
NOTE(review) comments below) and the definitions that followed each of them
were lost.  The code is otherwise reproduced token-for-token; the corrupted
fragments are kept verbatim so nothing is silently invented.  Recover the
missing definitions from version control before running this module.
"""
from __future__ import annotations
import json
import re
import tempfile
import unicodedata
from pathlib import Path
from typing import Any
import numpy as np
from huggingface_hub import HfApi, hf_hub_download
from transformers import AutoConfig, AutoTokenizer

# Files copied when materialising a tokenizer into a scratch directory.
TOKENIZER_FILES = [
    "tokenizer_config.json",
    "tokenizer.json",
    "special_tokens_map.json",
    "vocab.txt",
    "vocab.json",
    "merges.txt",
    "added_tokens.json",
    "sentencepiece.bpe.model",
    "spiece.model",
]

# Per-label cap on span width, measured in tokens (fallbacks for configs
# that do not carry `span_label_max_span_tokens`).
DEFAULT_LABEL_MAX_SPAN_TOKENS = {
    "PPSN": 9,
    "POSTCODE": 8,
    "PHONE_NUMBER": 10,
    "PASSPORT_NUMBER": 8,
    "BANK_ROUTING_NUMBER": 6,
    "ACCOUNT_NUMBER": 19,
    "CREDIT_DEBIT_CARD": 12,
    "SWIFT_BIC": 8,
    "EMAIL": 16,
    "FIRST_NAME": 6,
    "LAST_NAME": 8,
    "AGE": 8,
    "CITY": 8,
    "COUNTY": 8,
    "DATE_OF_BIRTH": 8,
    "STREET_ADDRESS": 10,
}

# Per-label minimum count of non-whitespace characters a span must contain.
DEFAULT_LABEL_MIN_NONSPACE_CHARS = {
    "PPSN": 8,
    "POSTCODE": 6,
    "PHONE_NUMBER": 7,
    "PASSPORT_NUMBER": 7,
    "BANK_ROUTING_NUMBER": 6,
    "ACCOUNT_NUMBER": 6,
    "CREDIT_DEBIT_CARD": 12,
    "SWIFT_BIC": 8,
    "EMAIL": 6,
    "FIRST_NAME": 2,
    "LAST_NAME": 2,
    "AGE": 1,
    "CITY": 1,
    "COUNTY": 1,
    "DATE_OF_BIRTH": 1,
    "STREET_ADDRESS": 1,
}

# Tie-break ordering used when sorting/deduplicating output spans
# (lower value = higher priority).
OUTPUT_PRIORITY = {
    "PPSN": 0,
    "PASSPORT_NUMBER": 1,
    "ACCOUNT_NUMBER": 2,
    "BANK_ROUTING_NUMBER": 3,
    "CREDIT_DEBIT_CARD": 4,
    "PHONE_NUMBER": 5,
    "SWIFT_BIC": 6,
    "POSTCODE": 7,
    "EMAIL": 8,
    "DATE_OF_BIRTH": 9,
    "AGE": 10,
    "STREET_ADDRESS": 11,
    "CITY": 12,
    "COUNTY": 13,
    "FIRST_NAME": 14,
    "LAST_NAME": 15,
}


def normalize_entity_name(label: str) -> str:
    """Strip a BIO prefix (``B-``/``I-``) and upper-case the entity label."""
    label = (label or '').strip()
    if label.startswith('B-') or label.startswith('I-'):
        label = label[2:]
    return label.upper()


def _sanitize_tokenizer_dir(tokenizer_path: Path) -> str:
    """Return a tokenizer directory whose config has no ``fix_mistral_regex`` key.

    If the key is present, the tokenizer files are copied into a fresh temp
    directory and the key is removed from the copied ``tokenizer_config.json``;
    otherwise the original directory path is returned unchanged.
    """
    tokenizer_cfg_path = tokenizer_path / 'tokenizer_config.json'
    if not tokenizer_cfg_path.exists():
        return str(tokenizer_path)
    data = json.loads(tokenizer_cfg_path.read_text(encoding='utf-8'))
    if 'fix_mistral_regex' not in data:
        return str(tokenizer_path)
    tmpdir = Path(tempfile.mkdtemp(prefix='openmed_gp_tokenizer_'))
    keep = set(TOKENIZER_FILES)
    for child in tokenizer_path.iterdir():
        if child.is_file() and child.name in keep:
            (tmpdir / child.name).write_bytes(child.read_bytes())
    data.pop('fix_mistral_regex', None)
    (tmpdir / 'tokenizer_config.json').write_text(json.dumps(data, ensure_ascii=False, indent=2) + "\n", encoding='utf-8')
    return str(tmpdir)


def safe_auto_tokenizer(tokenizer_ref: str):
    """Load an ``AutoTokenizer`` from a local path or Hub repo, degrading gracefully.

    Local paths are sanitised via :func:`_sanitize_tokenizer_dir`; for remote
    repos the known tokenizer files are first downloaded into a temp dir and
    sanitised.  Loading is then attempted with progressively fewer options:
    ``fix_mistral_regex=True`` -> ``=False`` -> fast tokenizer -> slow tokenizer.
    """
    tokenizer_path = Path(tokenizer_ref)
    if tokenizer_path.exists():
        tokenizer_ref = _sanitize_tokenizer_dir(tokenizer_path)
    else:
        api = HfApi()
        files = set(api.list_repo_files(repo_id=tokenizer_ref, repo_type='model'))
        tmpdir = Path(tempfile.mkdtemp(prefix='openmed_gp_remote_tokenizer_'))
        copied = False
        for name in TOKENIZER_FILES:
            if name not in files:
                continue
            src = hf_hub_download(repo_id=tokenizer_ref, filename=name, repo_type='model')
            (tmpdir / Path(name).name).write_bytes(Path(src).read_bytes())
            copied = True
        if copied:
            tokenizer_ref = _sanitize_tokenizer_dir(tmpdir)
    try:
        return AutoTokenizer.from_pretrained(tokenizer_ref, use_fast=True, fix_mistral_regex=True)
    except Exception:
        pass
    try:
        return AutoTokenizer.from_pretrained(tokenizer_ref, use_fast=True, fix_mistral_regex=False)
    except TypeError:
        # Older transformers versions do not accept the keyword at all.
        pass
    try:
        return AutoTokenizer.from_pretrained(tokenizer_ref, use_fast=True)
    except Exception:
        return AutoTokenizer.from_pretrained(tokenizer_ref, use_fast=False)


def label_names_from_config(config) -> list[str]:
    """Return normalized entity names from ``config.span_label_names`` (required)."""
    names = list(getattr(config, 'span_label_names', []))
    if not names:
        raise ValueError('Missing span_label_names in config')
    return [normalize_entity_name(name) for name in names]


def label_max_span_tokens_from_config(config) -> dict[str, int]:
    """Per-label max span width (tokens): config values, then defaults, then 8."""
    raw = getattr(config, 'span_label_max_span_tokens', None) or {}
    out = {normalize_entity_name(key): int(value) for key, value in raw.items()}
    for label, value in DEFAULT_LABEL_MAX_SPAN_TOKENS.items():
        out.setdefault(label, value)
    for label in label_names_from_config(config):
        out.setdefault(label, 8)
    return out


def label_min_nonspace_chars_from_config(config) -> dict[str, int]:
    """Per-label minimum non-space characters: config values, then defaults, then 1."""
    raw = getattr(config, 'span_label_min_nonspace_chars', None) or {}
    out = {normalize_entity_name(key): int(value) for key, value in raw.items()}
    for label, value in DEFAULT_LABEL_MIN_NONSPACE_CHARS.items():
        out.setdefault(label, value)
    for label in label_names_from_config(config):
        out.setdefault(label, 1)
    return out


def overlaps(a: dict, b: dict) -> bool:
    """True when half-open ranges [start, end) of the two spans intersect."""
    return not (a['end'] <= b['start'] or b['end'] <= a['start'])


def dedupe_spans(spans: list[dict]) -> list[dict]:
    """Greedy overlap removal: keep highest-scoring spans first, then sort by position."""
    ordered = sorted(
        spans,
        key=lambda item: (-float(item.get('score', 0.0)), item['start'], item['end'], OUTPUT_PRIORITY.get(item['label'], 99)),
    )
    kept = []
    for span in ordered:
        if any(overlaps(span, other) for other in kept):
            continue
        kept.append(span)
    kept.sort(key=lambda item: (item['start'], item['end'], OUTPUT_PRIORITY.get(item['label'], 99)))
    return kept


def load_onnx_session(model_ref: str, onnx_file: str = 'model_quantized.onnx', onnx_subfolder: str = 'onnx'):
    """Resolve the ONNX model (local dir or Hub repo) and return (session, tokenizer, config)."""
    import onnxruntime as ort
    model_path = Path(model_ref)
    if model_path.exists():
        candidates = []
        if onnx_subfolder:
            candidates.append(model_path / onnx_subfolder / onnx_file)
        candidates.append(model_path / onnx_file)
        # Fall back to the first candidate path even if neither exists, so the
        # InferenceSession constructor raises a descriptive error.
        onnx_path = next((path for path in candidates if path.exists()), candidates[0])
        config = AutoConfig.from_pretrained(model_ref)
        tokenizer = safe_auto_tokenizer(model_ref)
    else:
        remote_name = f"{onnx_subfolder}/{onnx_file}" if onnx_subfolder else onnx_file
        onnx_path = Path(hf_hub_download(repo_id=model_ref, filename=remote_name, repo_type='model'))
        config = AutoConfig.from_pretrained(model_ref)
        tokenizer = safe_auto_tokenizer(model_ref)
    session = ort.InferenceSession(str(onnx_path), providers=['CPUExecutionProvider'])
    return session, tokenizer, config


def run_onnx_span(session, encoded: dict[str, Any]) -> np.ndarray:
    """Feed matching encoder inputs to the session and return the first output.

    ``offset_mapping`` is excluded (tokenizer bookkeeping, not a model input).
    NOTE(review): this function is defined a second time, identically, at the
    bottom of the file — the duplicate should be deleted.
    """
    feed = {}
    input_names = {item.name for item in session.get_inputs()}
    for key, value in encoded.items():
        if key == 'offset_mapping':
            continue
        if key in input_names:
            feed[key] = value
    outputs = session.run(None, feed)
    if not outputs:
        raise ValueError('ONNX session returned no outputs')
    return outputs[0]


def label_thresholds_from_config(config, default_threshold: float) -> dict[str, float]:
    """Per-label score thresholds: config values, else ``default_threshold``."""
    raw = getattr(config, "span_label_thresholds", None) or {}
    out = {normalize_entity_name(key): float(value) for key, value in raw.items()}
    for label in label_names_from_config(config):
        out.setdefault(label, float(default_threshold))
    return out


def valid_offset(offset: tuple[int, int]) -> bool:
    """True for a non-empty (start, end) character offset with end > start."""
    return bool(offset) and int(offset[1]) > int(offset[0])


def nonspace_length(text: str, start: int, end: int) -> int:
    """Count non-whitespace characters in ``text[start:end]``."""
    return sum(0 if ch.isspace() else 1 for ch in text[int(start) : int(end)])


def alnum_upper(text: str) -> str:
    """Upper-case ``text`` and keep only alphanumeric characters."""
    return "".join(ch for ch in text.upper() if ch.isalnum())


def normalize_surface(text: str) -> str:
    """Accent-strip (NFKD), replace NBSP/narrow-NBSP, collapse whitespace, lower-case."""
    value = unicodedata.normalize("NFKD", text)
    value = "".join(ch for ch in value if not unicodedata.combining(ch))
    value = value.replace("\u00A0", " ").replace("\u202F", " ")
    value = re.sub(r"\s+", " ", value.strip().lower())
    return value


# English and Irish-language surface forms of Irish cities recognised for CITY spans.
IRISH_CITY_FORMS = (
    "Dublin",
    "Baile Átha Cliath",
    "Galway",
    "Gaillimh",
    "Cork",
    "Cork City",
    "Corcaigh",
    "Limerick",
    "Luimneach",
    "Waterford",
    "Port Láirge",
    "Kilkenny",
    "Cill Chainnigh",
    "Carlow",
    "Ceatharlach",
)
IRISH_CITY_SURFACES = {normalize_surface(value) for value in IRISH_CITY_FORMS}
IRISH_COUNTY_SURFACES = {
    normalize_surface(value)
    for value in {
        "Co. Dublin",
        "Co. Bhaile Átha Cliath",
        "Co. Galway",
        "Co. na Gaillimhe",
        "Co. Cork",
        "Co. Chorcaí",
        "Co. Limerick",
        "Co. Luimnigh",
        "Co. Waterford",
        "Co. Phort Láirge",
        "Co. Kilkenny",
        "Co. Chill Chainnigh",
        "Co. Carlow",
        "Co. Cheatharlach",
    }
}

STREET_SUFFIX_RE = re.compile(
    r"(?i)\b(street|road|avenue|lane|park|view|square|terrace|drive|close|way|place|bóthar|bothar|sráid|sraid|lána|lana)\b"
)
# Phone surfaces: digits/+/() separated by spaces, hyphens, NBSP or narrow NBSP.
PHONE_SURFACE_RE = re.compile(r"^[+()\d][+()\d \-\u00A0\u202F]*\d$")
ACCOUNT_DIGIT_SURFACE_RE = re.compile(r"^[\d \-\u00A0\u202F]+$")
DATE_OF_BIRTH_RE = re.compile(
    r"(?i)^(?:\d{2}/\d{2}/\d{4}|\d{4}-\d{2}-\d{2}|\d{1,2}(?:st|nd|rd|th)?\s+[A-Za-zÁÉÍÓÚáéíóú]+(?:\s+[A-Za-zÁÉÍÓÚáéíóú]+)?\s+\d{4})$"
)
# NOTE(review): SOURCE CORRUPTED — the raw string below is truncated at `r"(?`.
# The lost region evidently also contained the definitions of DOB_CONTEXT_RE,
# AGE_CONTEXT_RE, NAME_STOP_SURFACES, NAME_PARTICLE_SURFACES and
# STREET_TRAILING_BLOCK_SURFACES (all referenced further down), plus the
# signature of the `-> bool` function whose orphaned body follows (from its
# body and call sites this is `is_plausible_last_name_sequence(value: str)`).
# Recover from version control; the fragment is reproduced verbatim.
DATE_OF_BIRTH_VALUE_RE = re.compile(
    r"(? bool:
# def is_plausible_last_name_sequence(value: str) -> bool:  # NOTE(review): reconstructed signature — original lost above
    # Orphaned body: every whitespace-separated token must be alphabetic
    # name-material; tokens pass if capitalised, if in O'Brien-style
    # lower-then-upper form, or if they are known name particles.
    tokens = [token for token in re.split(r"\s+", value.strip()) if token]
    if not tokens:
        return False
    for token in tokens:
        if not any(ch.isalpha() for ch in token):
            return False
        if not all(is_name_token_char(ch) for ch in token):
            return False
        alpha_chars = [ch for ch in token if ch.isalpha()]
        first_alpha = alpha_chars[0] if alpha_chars else ""
        if first_alpha.isupper():
            continue
        if len(alpha_chars) >= 2 and alpha_chars[0].islower() and alpha_chars[1].isupper():
            continue
        if normalize_surface(token) in NAME_PARTICLE_SURFACES:
            continue
        return False
    return True


def is_reasonable_span_text(label: str, text: str, start: int, end: int) -> bool:
    """Label-specific surface/context validation for a candidate span.

    Returns True when ``text[start:end]`` looks like a plausible value for
    ``label`` (per-label rules below); unknown labels pass unconditionally.
    """
    value = text[int(start) : int(end)].strip()
    if not value:
        return False
    upper = alnum_upper(value)
    if label in {"FIRST_NAME", "LAST_NAME"}:
        if not any(ch.isalpha() for ch in value):
            return False
        if any(ch.isdigit() for ch in value):
            return False
        if normalize_surface(value) in NAME_STOP_SURFACES:
            return False
        if label == "FIRST_NAME" and any(ch.isspace() for ch in value):
            return False
        if any(ch in ".,;:/@()" for ch in value):
            return False
        if label == "FIRST_NAME":
            first_alpha = next((ch for ch in value if ch.isalpha()), "")
            if not first_alpha or not first_alpha.isupper():
                return False
        if label == "LAST_NAME" and not is_plausible_last_name_sequence(value):
            return False
        # Reject names glued to a preceding digit.
        if start > 0 and text[int(start) - 1].isdigit():
            return False
        return True
    if label == "EMAIL":
        if "@" not in value:
            return False
        local, _, domain = value.partition("@")
        return bool(local) and "." in domain
    if label == "PHONE_NUMBER":
        normalized = value.replace("\u00A0", " ").replace("\u202F", " ").strip()
        if any(ch.isalpha() for ch in normalized):
            return False
        if any(ch in "/@" for ch in normalized):
            return False
        # Must not abut alphanumerics on either side.
        if int(start) > 0 and text[int(start) - 1].isalnum():
            return False
        if int(end) < len(text) and text[int(end)].isalnum():
            return False
        if not PHONE_SURFACE_RE.match(normalized):
            return False
        digits = "".join(ch for ch in value if ch.isdigit())
        # Digit-count rules for Irish numbering: +353 international form,
        # 0818/1800 non-geographic, 08x mobile, 01 Dublin, generic 0-prefixed.
        if normalized.startswith("+353"):
            return 11 <= len(digits) <= 12
        if not digits.startswith("0"):
            return False
        if digits.startswith("0818") or digits.startswith("1800"):
            return len(digits) == 10
        if digits.startswith("08"):
            return len(digits) == 10
        if digits.startswith("01"):
            return len(digits) == 9
        return 9 <= len(digits) <= 10
    if label == "PPSN":
        # 7 digits followed by 1–2 letters.
        return bool(len(upper) in {8, 9} and upper[:7].isdigit() and upper[7:].isalpha())
    if label == "POSTCODE":
        # Eircode-shaped: 3-char routing key (letter+2 digits, or D6W) + 4-char
        # unique identifier starting with a letter; 7 chars after space removal.
        compact = value.replace(" ", "").replace("\u00A0", "").replace("\u202F", "")
        if any(not (ch.isalnum() or ch.isspace()) for ch in value):
            return False
        if len(compact) != 7:
            return False
        routing = compact[:3]
        unique = compact[3:]
        routing_ok = bool(
            (routing[0].isalpha() and routing[1:].isdigit()) or routing == "D6W"
        )
        unique_ok = bool(
            len(unique) == 4 and unique[0].isalpha() and unique[1:].isalnum()
        )
        return routing_ok and unique_ok
    if label == "PASSPORT_NUMBER":
        return bool(re.fullmatch(r"[A-Z]{1,2}\s?\d{7}", value.strip()))
    if label == "BANK_ROUTING_NUMBER":
        # 6-digit sort code, but only accepted near a banking cue word.
        digits = "".join(ch for ch in value if ch.isdigit())
        if len(digits) != 6:
            return False
        context = text[max(0, int(start) - 32) : min(len(text), int(end) + 24)]
        return bool(BANK_ROUTING_CONTEXT_RE.search(context))
    if label == "SWIFT_BIC":
        return len(upper) in {8, 11} and upper.isalnum()
    if label == "CREDIT_DEBIT_CARD":
        digits = "".join(ch for ch in value if ch.isdigit())
        return 12 <= len(digits) <= 19
    if label == "ACCOUNT_NUMBER":
        if upper.startswith("IE"):
            # Irish IBAN: IE + 2 check digits + 18 alphanumerics.
            return bool(re.fullmatch(r"IE\d{2}[A-Z0-9]{18}", upper))
        if not ACCOUNT_DIGIT_SURFACE_RE.fullmatch(value.strip()):
            return False
        digits = "".join(ch for ch in value if ch.isdigit())
        return 6 <= len(digits) <= 34
    if label == "AGE":
        # A bare 1–120 integer, not adjacent to alphanumerics or '/'/'-'
        # (avoids date fragments), and near an age cue word.
        digits = "".join(ch for ch in value if ch.isdigit())
        if digits != value.strip():
            return False
        if not digits:
            return False
        if int(start) > 0 and text[int(start) - 1].isalnum():
            return False
        if int(end) < len(text) and text[int(end)].isalnum():
            return False
        if int(start) > 0 and text[int(start) - 1] in "/-":
            return False
        if int(end) < len(text) and text[int(end)] in "/-":
            return False
        age = int(digits)
        if not (0 < age <= 120):
            return False
        context = text[max(0, int(start) - 24) : min(len(text), int(end) + 24)]
        return bool(AGE_CONTEXT_RE.search(context))
    if label == "DATE_OF_BIRTH":
        if not any(ch.isdigit() for ch in value):
            return False
        if not DATE_OF_BIRTH_RE.match(value.strip()):
            return False
        context = text[max(0, int(start) - 32) : min(len(text), int(end) + 32)]
        return bool(DOB_CONTEXT_RE.search(context))
    if label == "CITY":
        if any(ch.isdigit() for ch in value):
            return False
        return normalize_surface(value) in IRISH_CITY_SURFACES
    if label == "COUNTY":
        if any(ch.isdigit() for ch in value):
            return False
        return normalize_surface(value) in IRISH_COUNTY_SURFACES
    if label == "STREET_ADDRESS":
        cleaned = value.strip()
        suffix_match = STREET_SUFFIX_RE.search(cleaned)
        if not suffix_match:
            return False
        if any(ch in "@,:;" for ch in cleaned):
            return False
        trailing = cleaned[int(suffix_match.end()) :].strip()
        trailing_tokens = [token for token in re.split(r"\s+", trailing) if token]
        if len(trailing_tokens) > 3:
            return False
        if any(normalize_surface(token) in STREET_TRAILING_BLOCK_SURFACES for token in trailing_tokens):
            return False
        has_digit = any(ch.isdigit() for ch in cleaned)
        # If digits are present at all they must be a 1–4 digit house number prefix.
        if has_digit and not re.match(r"^\s*\d{1,4}\b", cleaned):
            return False
        title_tokens = [token for token in re.split(r"\s+", cleaned) if token]
        return has_digit or len(title_tokens) >= 2
    return True


def spans_overlap(a: dict, b: dict) -> bool:
    """True when two span dicts' [start, end) ranges intersect.

    NOTE(review): duplicates the semantics of :func:`overlaps` above
    (string-key vs int-coerced comparison) — consider consolidating.
    """
    return int(a["start"]) < int(b["end"]) and int(b["start"]) < int(a["end"])


def is_name_token_char(ch: str) -> bool:
    """Characters allowed inside a personal-name token (letters, hyphen, apostrophes)."""
    return ch.isalpha() or ch in {"-", "'", "’"}


def is_plausible_first_name(value: str) -> bool:
    """Single capitalised name-like token: no spaces, digits, or punctuation."""
    if not value:
        return False
    if any(ch.isspace() for ch in value):
        return False
    if any(ch.isdigit() for ch in value):
        return False
    if any(ch in ",;:/@()" for ch in value):
        return False
    if not any(ch.isalpha() for ch in value):
        return False
    first_alpha = next((ch for ch in value if ch.isalpha()), "")
    if not first_alpha or not first_alpha.isupper():
        return False
    return all(is_name_token_char(ch) for ch in value)


def repair_first_name_from_last_name(text: str, spans: list[dict]) -> list[dict]:
    """For each LAST_NAME with no adjacent FIRST_NAME, promote the preceding token.

    Scans backwards over whitespace then name characters; the recovered token
    gets a discounted score (0.6 x the last name's score).
    """
    repaired = list(spans)
    for last_name in [span for span in repaired if span["label"] == "LAST_NAME"]:
        if any(
            span["label"] == "FIRST_NAME"
            and int(span["end"]) <= int(last_name["start"])
            and int(last_name["start"]) - int(span["end"]) <= 2
            for span in repaired
        ):
            continue
        cursor = int(last_name["start"]) - 1
        if cursor < 0 or not text[cursor].isspace():
            continue
        while cursor >= 0 and text[cursor].isspace():
            cursor -= 1
        token_end = cursor + 1
        while cursor >= 0 and is_name_token_char(text[cursor]):
            cursor -= 1
        token_start = cursor + 1
        if token_end <= token_start:
            continue
        candidate = text[token_start:token_end]
        if not is_plausible_first_name(candidate):
            continue
        candidate_span = {
            "start": token_start,
            "end": token_end,
            "label": "FIRST_NAME",
            "score": float(last_name.get("score", 0.5)) * 0.6,
            "text": candidate,
        }
        if any(spans_overlap(candidate_span, other) for other in repaired if other["label"] == "FIRST_NAME"):
            continue
        repaired.append(candidate_span)
    return repaired


PASSPORT_CUE_RE = re.compile(
    r"(?i)(passport(?:\s+number)?|phas|uimhir\s+(?:mo\s+)?phas)"
)
# NOTE(review): SOURCE CORRUPTED — the pattern below is truncated at `r"(?`
# and has been fused with the tail of a following definition (by its
# `...]+@[A-Za-z0-9.\-]+\.[A-Za-z]{2,}` tail, evidently EMAIL_EXTRACT_RE,
# which is referenced later but never visibly defined).  Reproduced verbatim;
# recover both patterns from version control.
PASSPORT_VALUE_RE = re.compile(r"(?]+@[A-Za-z0-9.\-]+\.[A-Za-z]{2,})", re.UNICODE)
PHONE_CUE_RE = re.compile(
    r"(?i)\b(phone|call|contact|reach\s+me|glaoigh\s+ar|teagmh[aá]il|uimhir|m['’]uimhir)\b"
)
PUBLIC_CONTACT_DETAILS_RE = re.compile(r"(?i)\bpublic\s+contact\s+details\b")
CITY_CUE_RE = re.compile(
    r"(?i)\b(address|seoladh|located|suite|centre|center|ionad|intreo|clinic|hospital|ospid[eé]al|hse|fss)\b"
)
BANK_ROUTING_CONTEXT_RE = re.compile(
    r"(?i)\b(sort\s+code|routing\s+number|bank\s+of\s+ireland|aib|cod\s+sort[aá]la|sort[aá]la)\b"
)
# NOTE(review): SOURCE CORRUPTED — truncated at `r"(?`.  The lost region must
# also have defined PPSN_VALUE_RE, PPSN_CUE_RE and POSTCODE_VALUE_RE
# (referenced below but never visibly defined) and the signature of the
# `-> list[dict]` function whose orphaned body follows (by its body this is
# `repair_contextual_passport_numbers(text: str, spans: list[dict])`,
# which decode_span_matrix calls).  Reproduced verbatim.
PHONE_VALUE_RE = re.compile(
    r"(? list[dict]:
# def repair_contextual_passport_numbers(text: str, spans: list[dict]) -> list[dict]:  # NOTE(review): reconstructed signature — original lost above
    # Orphaned body: add PASSPORT_NUMBER spans for regex hits that follow a
    # passport cue, evicting overlapping conflicting-label spans.
    repaired = list(spans)
    for match in PASSPORT_VALUE_RE.finditer(text):
        start, end = match.span(1)
        candidate_span = {
            "start": start,
            "end": end,
            "label": "PASSPORT_NUMBER",
            "score": 0.67,
            "text": text[start:end],
        }
        if any(
            other["label"] == "PASSPORT_NUMBER"
            and int(other["start"]) <= start
            and int(other["end"]) >= end
            for other in repaired
        ):
            continue
        cue_window = text[max(0, start - 32) : start]
        if not PASSPORT_CUE_RE.search(cue_window):
            continue
        conflicting_labels = {"PHONE_NUMBER", "PPSN", "ACCOUNT_NUMBER", "AGE", "PASSPORT_NUMBER"}
        repaired = [
            other
            for other in repaired
            if not (
                spans_overlap(candidate_span, other)
                and other["label"] in conflicting_labels
            )
        ]
        repaired.append(candidate_span)
    return repaired


def repair_ppsn_variants(text: str, spans: list[dict]) -> list[dict]:
    """Add PPSN spans for 7-digit+1–2-letter matches; higher score with a nearby cue."""
    repaired = list(spans)
    for match in PPSN_VALUE_RE.finditer(text):
        start, end = match.span(1)
        value = text[start:end]
        compact = alnum_upper(value)
        if not (len(compact) in {8, 9} and compact[:7].isdigit() and compact[7:].isalpha()):
            continue
        cue_window = text[max(0, start - 32) : min(len(text), end + 24)]
        has_cue = bool(PPSN_CUE_RE.search(cue_window))
        candidate_span = {
            "start": start,
            "end": end,
            "label": "PPSN",
            "score": 0.72 if has_cue else 0.58,
            "text": value,
        }
        conflicting_labels = {"PHONE_NUMBER", "PASSPORT_NUMBER", "ACCOUNT_NUMBER", "AGE", "FIRST_NAME", "LAST_NAME"}
        repaired = [
            other
            for other in repaired
            if not (
                spans_overlap(candidate_span, other)
                and other["label"] in conflicting_labels.union({"PPSN"})
            )
        ]
        repaired.append(candidate_span)
    return repaired


def repair_contextual_date_of_birth(text: str, spans: list[dict]) -> list[dict]:
    """Add DATE_OF_BIRTH spans for date-shaped matches near a DOB context cue."""
    repaired = list(spans)
    for match in DATE_OF_BIRTH_VALUE_RE.finditer(text):
        start, end = match.span(1)
        cue_window = text[max(0, start - 40) : min(len(text), end + 24)]
        if not DOB_CONTEXT_RE.search(cue_window):
            continue
        candidate_span = {
            "start": start,
            "end": end,
            "label": "DATE_OF_BIRTH",
            "score": 0.66,
            "text": text[start:end],
        }
        conflicting_labels = {"DATE_OF_BIRTH", "PHONE_NUMBER", "AGE", "FIRST_NAME", "LAST_NAME", "ACCOUNT_NUMBER"}
        repaired = [
            other
            for other in repaired
            if not (
                spans_overlap(candidate_span, other)
                and other["label"] in conflicting_labels
            )
        ]
        repaired.append(candidate_span)
    return repaired


ACCOUNT_CUE_RE = re.compile(
    r"(?i)(account\s+number|bank\s+account|uimhir\s+chuntais|cuntas\s+bainc)"
)
# NOTE(review): SOURCE CORRUPTED — truncated at `r"(?`; the lost region also
# contained the signature of the `-> list[dict]` function whose orphaned body
# follows (by its body this is
# `repair_contextual_account_numbers(text: str, spans: list[dict])`,
# called from decode_span_matrix).  Reproduced verbatim.
ACCOUNT_VALUE_RE = re.compile(r"(? list[dict]:
# def repair_contextual_account_numbers(text: str, spans: list[dict]) -> list[dict]:  # NOTE(review): reconstructed signature — original lost above
    # Orphaned body: add low-confidence ACCOUNT_NUMBER spans for regex hits
    # preceded by an account cue, skipping overlaps with stronger labels.
    repaired = list(spans)
    for match in ACCOUNT_VALUE_RE.finditer(text):
        start, end = match.span(1)
        candidate_span = {
            "start": start,
            "end": end,
            "label": "ACCOUNT_NUMBER",
            "score": 0.51,
            "text": text[start:end],
        }
        if any(
            other["label"] == "ACCOUNT_NUMBER"
            and int(other["start"]) <= start
            and int(other["end"]) >= end
            for other in repaired
        ):
            continue
        cue_window = text[max(0, start - 40) : start]
        if not ACCOUNT_CUE_RE.search(cue_window):
            continue
        if any(
            spans_overlap(candidate_span, other)
            and other["label"] in {"PHONE_NUMBER", "BANK_ROUTING_NUMBER", "PPSN", "POSTCODE", "PASSPORT_NUMBER"}
            for other in repaired
        ):
            continue
        repaired.append(candidate_span)
    return repaired


def repair_emails(text: str, spans: list[dict]) -> list[dict]:
    """Add EMAIL spans for every regex match, evicting overlapping name/email spans."""
    repaired = list(spans)
    for match in EMAIL_EXTRACT_RE.finditer(text):
        start, end = match.span(1)
        candidate_span = {
            "start": start,
            "end": end,
            "label": "EMAIL",
            "score": 0.74,
            "text": text[start:end],
        }
        conflicting_labels = {"EMAIL", "FIRST_NAME", "LAST_NAME"}
        repaired = [
            other
            for other in repaired
            if not (
                spans_overlap(candidate_span, other)
                and other["label"] in conflicting_labels
            )
        ]
        repaired.append(candidate_span)
    return repaired


def repair_phone_numbers(text: str, spans: list[dict]) -> list[dict]:
    """Add PHONE_NUMBER spans for regex hits with a cue or an overlapping model span."""
    repaired = list(spans)
    for match in PHONE_VALUE_RE.finditer(text):
        start, end = match.span(1)
        candidate_span = {
            "start": start,
            "end": end,
            "label": "PHONE_NUMBER",
            "score": 0.69,
            "text": text[start:end],
        }
        cue_window = text[max(0, start - 32) : min(len(text), end + 16)]
        has_cue = bool(PHONE_CUE_RE.search(cue_window))
        has_overlap = any(spans_overlap(candidate_span, other) and other["label"] == "PHONE_NUMBER" for other in repaired)
        if not (has_cue or has_overlap):
            continue
        if not is_reasonable_span_text("PHONE_NUMBER", text, start, end):
            continue
        conflicting_labels = {"PHONE_NUMBER", "PPSN", "ACCOUNT_NUMBER", "BANK_ROUTING_NUMBER"}
        repaired = [
            other
            for other in repaired
            if not (
                spans_overlap(candidate_span, other)
                and other["label"] in conflicting_labels
            )
        ]
        repaired.append(candidate_span)
    return repaired


def repair_postcodes(text: str, spans: list[dict]) -> list[dict]:
    """Add POSTCODE spans for every regex match, evicting overlapping conflicts."""
    repaired = list(spans)
    for match in POSTCODE_VALUE_RE.finditer(text):
        start, end = match.span(1)
        candidate_span = {
            "start": start,
            "end": end,
            "label": "POSTCODE",
            "score": 0.71,
            "text": text[start:end],
        }
        conflicting_labels = {"POSTCODE", "PHONE_NUMBER", "ACCOUNT_NUMBER", "FIRST_NAME", "LAST_NAME"}
        repaired = [
            other
            for other in repaired
            if not (
                spans_overlap(candidate_span, other)
                and other["label"] in conflicting_labels
            )
        ]
        repaired.append(candidate_span)
    return repaired


def repair_city_spans(text: str, spans: list[dict]) -> list[dict]:
    """Add CITY spans for known city surface forms that have supporting context.

    Context = adjacency to a STREET_ADDRESS/COUNTY/POSTCODE span, a following
    ", Co. …"/Eircode-like pattern, or a nearby address cue word.  Longer
    surface forms are tried first so e.g. "Cork City" wins over "Cork".
    """
    repaired = list(spans)
    seen: set[tuple[int, int]] = set()
    ordered_forms = sorted(IRISH_CITY_FORMS, key=len, reverse=True)
    for form in ordered_forms:
        for match in re.finditer(re.escape(form), text, flags=re.IGNORECASE):
            start, end = match.span()
            key = (start, end)
            if key in seen:
                continue
            seen.add(key)
            candidate_span = {
                "start": start,
                "end": end,
                "label": "CITY",
                "score": 0.64,
                "text": text[start:end],
            }
            has_context = False
            for other in repaired:
                other_start = int(other["start"])
                other_end = int(other["end"])
                if other["label"] == "STREET_ADDRESS" and 0 <= start - other_end <= 4:
                    has_context = True
                    break
                if other["label"] in {"COUNTY", "POSTCODE"} and 0 <= other_start - end <= 6:
                    has_context = True
                    break
            if not has_context and re.match(r"^\s*,\s*(?:Co\.\s+|[A-Z]\d{2}|D6W)", text[end:]):
                has_context = True
            if not has_context:
                cue_window = text[max(0, start - 40) : min(len(text), end + 32)]
                has_context = bool(CITY_CUE_RE.search(cue_window))
            if not has_context:
                continue
            conflicting_labels = {"CITY", "FIRST_NAME", "LAST_NAME"}
            repaired = [
                other
                for other in repaired
                if not (
                    spans_overlap(candidate_span, other)
                    and other["label"] in conflicting_labels
                )
            ]
            repaired.append(candidate_span)
    return repaired


def drop_public_contact_detail_spans(text: str, spans: list[dict]) -> list[dict]:
    """Drop all address-type spans when the text mentions "public contact details".

    NOTE(review): the per-span `PUBLIC_CONTACT_DETAILS_RE.search(text)` inside
    the comprehension is redundant — the early return already guarantees it.
    """
    if not PUBLIC_CONTACT_DETAILS_RE.search(text):
        return spans
    return [
        span
        for span in spans
        if not (span["label"] in {"STREET_ADDRESS", "CITY", "COUNTY"} and PUBLIC_CONTACT_DETAILS_RE.search(text))
    ]


def drop_city_org_prefix_spans(text: str, spans: list[dict]) -> list[dict]:
    """Drop CITY spans that are immediately followed by "Intreo Centre" (org name, not a place)."""
    keep: list[dict] = []
    for span in spans:
        if span["label"] != "CITY":
            keep.append(span)
            continue
        tail = text[int(span["end"]) : min(len(text), int(span["end"]) + 24)]
        if re.match(r"^\s+Intreo\s+Centre\b", tail):
            continue
        keep.append(span)
    return keep


def canonicalize_email_spans(text: str, spans: list[dict]) -> list[dict]:
    """Tighten EMAIL spans to the exact address matched inside their text."""
    repaired: list[dict] = []
    for span in spans:
        if span["label"] != "EMAIL":
            repaired.append(span)
            continue
        segment = text[int(span["start"]) : int(span["end"])]
        match = EMAIL_EXTRACT_RE.search(segment)
        if not match:
            repaired.append(span)
            continue
        start = int(span["start"]) + int(match.start(1))
        end = int(span["start"]) + int(match.end(1))
        repaired.append(
            {
                **span,
                "start": start,
                "end": end,
                "text": text[start:end],
            }
        )
    return repaired


def drop_stacked_first_names(spans: list[dict]) -> list[dict]:
    """Drop a FIRST_NAME that is "shadowed" by a later FIRST_NAME+LAST_NAME pair.

    A span is shadowed when another FIRST_NAME starts within 2 chars after it
    and that other span is itself followed (within 2 chars) by a LAST_NAME.
    """
    if not spans:
        return spans
    first_names = [span for span in spans if span["label"] == "FIRST_NAME"]
    last_names = [span for span in spans if span["label"] == "LAST_NAME"]
    if not first_names or not last_names:
        return spans
    keep: list[dict] = []
    for span in spans:
        if span["label"] != "FIRST_NAME":
            keep.append(span)
            continue
        shadowed = False
        for other in first_names:
            if other is span:
                continue
            if int(other["start"]) <= int(span["start"]):
                continue
            if int(other["start"]) - int(span["end"]) > 2:
                continue
            if not any(
                int(last["start"]) >= int(other["end"])
                and int(last["start"]) - int(other["end"]) <= 2
                for last in last_names
            ):
                continue
            shadowed = True
            break
        if not shadowed:
            keep.append(span)
    return keep


def decode_span_matrix(
    text: str,
    offsets: list[tuple[int, int]],
    span_scores: np.ndarray,
    config,
    min_score: float,
) -> list[dict]:
    """Decode a [num_labels, seq_len, seq_len] score matrix into final spans.

    Enumerates (start, end) token pairs within each label's width budget,
    keeps candidates above the per-label threshold that pass surface
    validation, then runs the repair/preference/drop pipeline and dedupes.

    NOTE(review): `repair_contextual_passport_numbers` and
    `repair_contextual_account_numbers` are called here but their `def` lines
    were lost in the source corruption (only their bodies survive above).
    """
    label_names = label_names_from_config(config)
    thresholds = label_thresholds_from_config(config, min_score)
    max_span_tokens = label_max_span_tokens_from_config(config)
    min_nonspace_chars = label_min_nonspace_chars_from_config(config)
    if span_scores.ndim != 3:
        raise ValueError(f"Expected [num_labels, seq_len, seq_len] span scores, got shape {span_scores.shape}")
    num_labels, seq_len, _ = span_scores.shape
    spans: list[dict] = []
    for label_index in range(min(num_labels, len(label_names))):
        label = label_names[label_index]
        threshold = thresholds.get(label, min_score)
        max_width = max(1, int(max_span_tokens.get(label, 8)))
        min_chars = max(1, int(min_nonspace_chars.get(label, 1)))
        for start_idx in range(seq_len):
            start_offset = offsets[start_idx]
            if not valid_offset(start_offset):
                continue
            max_end = min(seq_len, start_idx + max_width)
            for end_idx in range(start_idx, max_end):
                end_offset = offsets[end_idx]
                if not valid_offset(end_offset):
                    continue
                score = float(span_scores[label_index, start_idx, end_idx])
                if score < threshold:
                    continue
                start_char = int(start_offset[0])
                end_char = int(end_offset[1])
                if end_char <= start_char:
                    continue
                if nonspace_length(text, start_char, end_char) < min_chars:
                    continue
                if not is_reasonable_span_text(label, text, start_char, end_char):
                    continue
                spans.append(
                    {
                        "start": start_char,
                        "end": end_char,
                        "label": label,
                        "score": score,
                        "text": text[start_char:end_char],
                    }
                )
    # Post-processing pipeline: preference passes, regex-based repairs,
    # context-based drops, then overlap dedupe.
    spans = prefer_long_name_spans(spans, thresholds)
    spans = prefer_long_structured_spans(spans, thresholds)
    spans = repair_first_name_from_last_name(text, spans)
    spans = repair_emails(text, spans)
    spans = repair_phone_numbers(text, spans)
    spans = repair_ppsn_variants(text, spans)
    spans = repair_postcodes(text, spans)
    spans = repair_city_spans(text, spans)
    spans = repair_contextual_date_of_birth(text, spans)
    spans = repair_contextual_passport_numbers(text, spans)
    spans = repair_contextual_account_numbers(text, spans)
    spans = drop_public_contact_detail_spans(text, spans)
    spans = drop_city_org_prefix_spans(text, spans)
    spans = drop_stacked_first_names(spans)
    spans = canonicalize_email_spans(text, spans)
    return dedupe_spans(spans)


def prefer_long_name_spans(spans: list[dict], thresholds: dict[str, float]) -> list[dict]:
    """Among name spans sharing a start, keep the longest if its score justifies it.

    The longest wins when its score beats both threshold+0.15 and 70% of the
    best score; otherwise the best-scoring span wins.  Finishes by applying
    :func:`prefer_same_end_extensions`.
    """
    if not spans:
        return spans
    preferred: list[dict] = []
    consumed: set[int] = set()
    for index, span in enumerate(spans):
        if index in consumed:
            continue
        label = span["label"]
        if label not in {"FIRST_NAME", "LAST_NAME"}:
            preferred.append(span)
            continue
        same_start = [
            (other_index, other)
            for other_index, other in enumerate(spans)
            if other_index not in consumed and other["label"] == label and other["start"] == span["start"]
        ]
        if len(same_start) == 1:
            preferred.append(span)
            continue
        for other_index, _ in same_start:
            consumed.add(other_index)
        best_by_score = max(same_start, key=lambda item: float(item[1].get("score", 0.0)))[1]
        longest = max(same_start, key=lambda item: (item[1]["end"] - item[1]["start"], float(item[1].get("score", 0.0))))[1]
        threshold = float(thresholds.get(label, 0.5))
        if float(longest.get("score", 0.0)) >= max(threshold + 0.15, float(best_by_score.get("score", 0.0)) * 0.7):
            preferred.append(longest)
        else:
            preferred.append(best_by_score)
    return preferred, prefer_same_end_extensions(preferred, thresholds) if False else prefer_same_end_extensions(preferred, thresholds)


def prefer_same_end_extensions(spans: list[dict], thresholds: dict[str, float]) -> list[dict]:
    """Among name/email spans sharing an end, prefer a longer leftward extension.

    EMAIL: the longer span wins if it contains '@' (or is strictly longer)
    and scores within 0.02 of the best.  Names: the longer span wins only if
    it has no internal space and clears the relaxed score bar.
    """
    if not spans:
        return spans
    preferred: list[dict] = []
    consumed: set[int] = set()
    for index, span in enumerate(spans):
        if index in consumed:
            continue
        label = span["label"]
        if label not in {"FIRST_NAME", "LAST_NAME", "EMAIL"}:
            preferred.append(span)
            continue
        same_end = [
            (other_index, other)
            for other_index, other in enumerate(spans)
            if other_index not in consumed and other["label"] == label and other["end"] == span["end"]
        ]
        if len(same_end) == 1:
            preferred.append(span)
            continue
        for other_index, _ in same_end:
            consumed.add(other_index)
        best_by_score = max(same_end, key=lambda item: float(item[1].get("score", 0.0)))[1]
        longest = max(same_end, key=lambda item: (item[1]["end"] - item[1]["start"], float(item[1].get("score", 0.0))))[1]
        longest_score = float(longest.get("score", 0.0))
        best_score = float(best_by_score.get("score", 0.0))
        if label == "EMAIL":
            if "@" in longest.get("text", "") or longest["end"] - longest["start"] > best_by_score["end"] - best_by_score["start"]:
                if longest_score >= best_score - 0.02:
                    preferred.append(longest)
                    continue
        else:
            longest_text = longest.get("text", "")
            if " " not in longest_text.strip() and longest_score >= max(float(thresholds.get(label, 0.5)) * 0.8, best_score * 0.55):
                preferred.append(longest)
                continue
        preferred.append(best_by_score)
    return preferred


def prefer_long_structured_spans(spans: list[dict], thresholds: dict[str, float]) -> list[dict]:
    """Among overlapping STREET_ADDRESS/DATE_OF_BIRTH spans, prefer the longest.

    The longest wins when its score reaches max(threshold, 75% of the best
    score); otherwise keep the best-scoring overlap.
    """
    if not spans:
        return spans
    preferred: list[dict] = []
    consumed: set[int] = set()
    target_labels = {"STREET_ADDRESS", "DATE_OF_BIRTH"}
    for index, span in enumerate(spans):
        if index in consumed:
            continue
        label = span["label"]
        if label not in target_labels:
            preferred.append(span)
            continue
        overlapping = [
            (other_index, other)
            for other_index, other in enumerate(spans)
            if other_index not in consumed and other["label"] == label and spans_overlap(span, other)
        ]
        if len(overlapping) == 1:
            preferred.append(span)
            continue
        for other_index, _ in overlapping:
            consumed.add(other_index)
        best_by_score = max(overlapping, key=lambda item: float(item[1].get("score", 0.0)))[1]
        longest = max(
            overlapping,
            key=lambda item: (item[1]["end"] - item[1]["start"], float(item[1].get("score", 0.0))),
        )[1]
        longest_score = float(longest.get("score", 0.0))
        best_score = float(best_by_score.get("score", 0.0))
        threshold = float(thresholds.get(label, 0.5))
        if longest_score >= max(threshold, best_score * 0.75):
            preferred.append(longest)
        else:
            preferred.append(best_by_score)
    return preferred


def sigmoid_np(values: np.ndarray) -> np.ndarray:
    """Numerically-safe sigmoid (inputs clipped to [-60, 60] before exp)."""
    clipped = np.clip(values, -60.0, 60.0)
    return 1.0 / (1.0 + np.exp(-clipped))


# NOTE(review): exact duplicate of run_onnx_span defined earlier in this file
# (harmless re-binding at import time, but it should be deleted).
def run_onnx_span(session, encoded: dict[str, Any]) -> np.ndarray:
    """Feed matching encoder inputs to the session and return the first output."""
    feed = {}
    input_names = {item.name for item in session.get_inputs()}
    for key, value in encoded.items():
        if key == "offset_mapping":
            continue
        if key in input_names:
            feed[key] = value
    outputs = session.run(None, feed)
    if not outputs:
        raise ValueError("ONNX session returned no outputs")
    return outputs[0]