| |
| from __future__ import annotations |
|
|
| import json |
| import re |
| import tempfile |
| import unicodedata |
| from pathlib import Path |
| from typing import Any |
|
|
| import numpy as np |
| from huggingface_hub import HfApi, hf_hub_download |
| from transformers import AutoConfig, AutoTokenizer |
|
|
# Tokenizer artifact filenames worth copying when mirroring or sanitizing a
# tokenizer directory (covers fast/slow tokenizers, BPE, WordPiece and
# SentencePiece variants). Used by _sanitize_tokenizer_dir / safe_auto_tokenizer.
TOKENIZER_FILES = [
    "tokenizer_config.json",
    "tokenizer.json",
    "special_tokens_map.json",
    "vocab.txt",
    "vocab.json",
    "merges.txt",
    "added_tokens.json",
    "sentencepiece.bpe.model",
    "spiece.model",
]
# Default per-label cap on decoded span width, in tokens; labels not listed
# here fall back to 8 (see label_max_span_tokens_from_config).
DEFAULT_LABEL_MAX_SPAN_TOKENS = {
    "PPSN": 9,
    "POSTCODE": 8,
    "PHONE_NUMBER": 10,
    "PASSPORT_NUMBER": 8,
    "BANK_ROUTING_NUMBER": 6,
    "ACCOUNT_NUMBER": 19,
    "CREDIT_DEBIT_CARD": 12,
    "SWIFT_BIC": 8,
    "EMAIL": 16,
    "FIRST_NAME": 6,
    "LAST_NAME": 8,
    "AGE": 8,
    "CITY": 8,
    "COUNTY": 8,
    "DATE_OF_BIRTH": 8,
    "STREET_ADDRESS": 10,
}
# Default per-label minimum number of non-whitespace characters a candidate
# span must cover; labels not listed fall back to 1
# (see label_min_nonspace_chars_from_config).
DEFAULT_LABEL_MIN_NONSPACE_CHARS = {
    "PPSN": 8,
    "POSTCODE": 6,
    "PHONE_NUMBER": 7,
    "PASSPORT_NUMBER": 7,
    "BANK_ROUTING_NUMBER": 6,
    "ACCOUNT_NUMBER": 6,
    "CREDIT_DEBIT_CARD": 12,
    "SWIFT_BIC": 8,
    "EMAIL": 6,
    "FIRST_NAME": 2,
    "LAST_NAME": 2,
    "AGE": 1,
    "CITY": 1,
    "COUNTY": 1,
    "DATE_OF_BIRTH": 1,
    "STREET_ADDRESS": 1,
}
# Deterministic ordering used as a tie-breaker when sorting/deduplicating
# spans; lower value wins. Unknown labels get 99 (see dedupe_spans).
OUTPUT_PRIORITY = {
    "PPSN": 0,
    "PASSPORT_NUMBER": 1,
    "ACCOUNT_NUMBER": 2,
    "BANK_ROUTING_NUMBER": 3,
    "CREDIT_DEBIT_CARD": 4,
    "PHONE_NUMBER": 5,
    "SWIFT_BIC": 6,
    "POSTCODE": 7,
    "EMAIL": 8,
    "DATE_OF_BIRTH": 9,
    "AGE": 10,
    "STREET_ADDRESS": 11,
    "CITY": 12,
    "COUNTY": 13,
    "FIRST_NAME": 14,
    "LAST_NAME": 15,
}
|
|
|
|
def normalize_entity_name(label: str) -> str:
    """Strip a BIO prefix ("B-"/"I-") and return the upper-cased entity name."""
    cleaned = (label or '').strip()
    # Only the exact case-sensitive BIO prefixes are removed.
    if cleaned[:2] in ('B-', 'I-'):
        cleaned = cleaned[2:]
    return cleaned.upper()
|
|
|
|
def _sanitize_tokenizer_dir(tokenizer_path: Path) -> str:
    """Return a tokenizer directory whose config lacks 'fix_mistral_regex'.

    If tokenizer_config.json is absent or already clean, the original
    directory path is returned. Otherwise the allow-listed tokenizer files
    are copied into a fresh temp dir together with a cleaned config, and
    that temp dir path is returned.
    """
    config_file = tokenizer_path / 'tokenizer_config.json'
    if not config_file.exists():
        return str(tokenizer_path)
    config = json.loads(config_file.read_text(encoding='utf-8'))
    if 'fix_mistral_regex' not in config:
        return str(tokenizer_path)
    sanitized_dir = Path(tempfile.mkdtemp(prefix='openmed_gp_tokenizer_'))
    allowed = set(TOKENIZER_FILES)
    for entry in tokenizer_path.iterdir():
        if entry.is_file() and entry.name in allowed:
            (sanitized_dir / entry.name).write_bytes(entry.read_bytes())
    config.pop('fix_mistral_regex', None)
    serialized = json.dumps(config, ensure_ascii=False, indent=2) + "\n"
    (sanitized_dir / 'tokenizer_config.json').write_text(serialized, encoding='utf-8')
    return str(sanitized_dir)
|
|
|
|
def safe_auto_tokenizer(tokenizer_ref: str):
    """Load an AutoTokenizer, working around the 'fix_mistral_regex' quirk.

    tokenizer_ref is a local directory or a Hugging Face repo id. Remote repos
    are mirrored (known tokenizer files only) into a temp dir so the config can
    be sanitized before loading. Loading falls back through
    fix_mistral_regex=True, fix_mistral_regex=False, no kwarg, and finally the
    slow tokenizer.
    """
    tokenizer_path = Path(tokenizer_ref)
    if tokenizer_path.exists():
        # Local directory: strip the problematic config key if present.
        tokenizer_ref = _sanitize_tokenizer_dir(tokenizer_path)
    else:
        # Remote repo: download only the recognized tokenizer artifacts.
        api = HfApi()
        files = set(api.list_repo_files(repo_id=tokenizer_ref, repo_type='model'))
        tmpdir = Path(tempfile.mkdtemp(prefix='openmed_gp_remote_tokenizer_'))
        copied = False
        for name in TOKENIZER_FILES:
            if name not in files:
                continue
            src = hf_hub_download(repo_id=tokenizer_ref, filename=name, repo_type='model')
            (tmpdir / Path(name).name).write_bytes(Path(src).read_bytes())
            copied = True
        if copied:
            tokenizer_ref = _sanitize_tokenizer_dir(tmpdir)
    # Some tokenizer versions require the flag, others reject it: try both.
    try:
        return AutoTokenizer.from_pretrained(tokenizer_ref, use_fast=True, fix_mistral_regex=True)
    except Exception:
        pass
    # NOTE(review): only TypeError is caught here (unknown-kwarg case), while
    # the first attempt swallows any Exception — confirm the asymmetry is intended.
    try:
        return AutoTokenizer.from_pretrained(tokenizer_ref, use_fast=True, fix_mistral_regex=False)
    except TypeError:
        pass
    # Plain fast load, then the slow tokenizer as the last resort.
    try:
        return AutoTokenizer.from_pretrained(tokenizer_ref, use_fast=True)
    except Exception:
        return AutoTokenizer.from_pretrained(tokenizer_ref, use_fast=False)
|
|
|
|
def label_names_from_config(config) -> list[str]:
    """Return the normalized span label names declared on the model config.

    Raises ValueError when the config has no (or an empty) span_label_names.
    """
    raw_names = list(getattr(config, 'span_label_names', []))
    if not raw_names:
        raise ValueError('Missing span_label_names in config')
    return [normalize_entity_name(raw) for raw in raw_names]
|
|
|
|
def label_max_span_tokens_from_config(config) -> dict[str, int]:
    """Per-label max span width (tokens): config overrides, then defaults, then 8."""
    overrides = getattr(config, 'span_label_max_span_tokens', None) or {}
    result: dict[str, int] = {}
    for key, value in overrides.items():
        result[normalize_entity_name(key)] = int(value)
    for label, default in DEFAULT_LABEL_MAX_SPAN_TOKENS.items():
        result.setdefault(label, default)
    # Every declared label gets at least a generic cap.
    for label in label_names_from_config(config):
        result.setdefault(label, 8)
    return result
|
|
|
|
def label_min_nonspace_chars_from_config(config) -> dict[str, int]:
    """Per-label minimum non-space chars: config overrides, then defaults, then 1."""
    overrides = getattr(config, 'span_label_min_nonspace_chars', None) or {}
    result: dict[str, int] = {}
    for key, value in overrides.items():
        result[normalize_entity_name(key)] = int(value)
    for label, default in DEFAULT_LABEL_MIN_NONSPACE_CHARS.items():
        result.setdefault(label, default)
    # Every declared label gets at least the permissive minimum.
    for label in label_names_from_config(config):
        result.setdefault(label, 1)
    return result
|
|
|
|
def overlaps(a: dict, b: dict) -> bool:
    """Return True when the half-open [start, end) ranges of two spans intersect."""
    # Disjoint exactly when one span ends at or before the other begins.
    disjoint = a['end'] <= b['start'] or b['end'] <= a['start']
    return not disjoint
|
|
|
|
def dedupe_spans(spans: list[dict]) -> list[dict]:
    """Greedily keep the highest-scoring non-overlapping spans, in text order."""
    def _priority(span: dict) -> int:
        return OUTPUT_PRIORITY.get(span['label'], 99)

    # Highest score first; position and label priority break ties (stable sort).
    by_score = sorted(
        spans,
        key=lambda span: (-float(span.get('score', 0.0)), span['start'], span['end'], _priority(span)),
    )
    selected: list[dict] = []
    for candidate in by_score:
        if not any(overlaps(candidate, chosen) for chosen in selected):
            selected.append(candidate)
    # Emit in reading order for deterministic output.
    selected.sort(key=lambda span: (span['start'], span['end'], _priority(span)))
    return selected
|
|
|
|
def load_onnx_session(model_ref: str, onnx_file: str = 'model_quantized.onnx', onnx_subfolder: str = 'onnx'):
    """Load an ONNX inference session plus the matching tokenizer and config.

    model_ref may be a local directory or a Hugging Face repo id; only the
    ONNX file resolution differs between the two. Returns a
    (session, tokenizer, config) tuple; the session uses the CPU provider only.
    """
    import onnxruntime as ort

    model_path = Path(model_ref)
    if model_path.exists():
        # Prefer the conventional onnx/ subfolder, fall back to the model root.
        candidates = []
        if onnx_subfolder:
            candidates.append(model_path / onnx_subfolder / onnx_file)
        candidates.append(model_path / onnx_file)
        onnx_path = next((path for path in candidates if path.exists()), candidates[0])
    else:
        remote_name = f"{onnx_subfolder}/{onnx_file}" if onnx_subfolder else onnx_file
        onnx_path = Path(hf_hub_download(repo_id=model_ref, filename=remote_name, repo_type='model'))
    # Config/tokenizer loading is identical for local and remote refs, so it is
    # hoisted out of the branch (the original duplicated these two calls).
    config = AutoConfig.from_pretrained(model_ref)
    tokenizer = safe_auto_tokenizer(model_ref)
    session = ort.InferenceSession(str(onnx_path), providers=['CPUExecutionProvider'])
    return session, tokenizer, config
|
|
|
|
def run_onnx_span(session, encoded: dict[str, Any]) -> np.ndarray:
    """Run the span-scoring ONNX session and return its first output.

    Only keys the session declares as inputs are fed; offset_mapping is
    tokenizer metadata and is always excluded. Raises ValueError when the
    session produces no outputs.
    """
    accepted = {spec.name for spec in session.get_inputs()}
    feed = {
        name: tensor
        for name, tensor in encoded.items()
        if name != 'offset_mapping' and name in accepted
    }
    outputs = session.run(None, feed)
    if not outputs:
        raise ValueError('ONNX session returned no outputs')
    return outputs[0]
|
|
|
|
def label_thresholds_from_config(config, default_threshold: float) -> dict[str, float]:
    """Per-label score thresholds: config overrides win, default_threshold otherwise."""
    overrides = getattr(config, "span_label_thresholds", None) or {}
    result: dict[str, float] = {
        normalize_entity_name(name): float(score) for name, score in overrides.items()
    }
    fallback = float(default_threshold)
    for label in label_names_from_config(config):
        result.setdefault(label, fallback)
    return result
|
|
|
|
def valid_offset(offset: tuple[int, int]) -> bool:
    """True for a non-empty (start, end) offset covering at least one character."""
    if not offset:
        return False
    # Special tokens typically carry (0, 0) offsets and are rejected here.
    return int(offset[1]) > int(offset[0])
|
|
|
|
def nonspace_length(text: str, start: int, end: int) -> int:
    """Count the non-whitespace characters in text[start:end]."""
    segment = text[int(start) : int(end)]
    return sum(1 for ch in segment if not ch.isspace())
|
|
|
|
def alnum_upper(text: str) -> str:
    """Upper-case text and drop every non-alphanumeric character."""
    kept = [ch for ch in text.upper() if ch.isalnum()]
    return "".join(kept)
|
|
|
|
def normalize_surface(text: str) -> str:
    """Lower-case, de-accent and whitespace-normalize a surface form.

    NFKD decomposition plus combining-mark removal strips fadas/accents;
    NBSP and narrow NBSP become plain spaces; runs of whitespace collapse.
    """
    decomposed = unicodedata.normalize("NFKD", text)
    without_marks = "".join(ch for ch in decomposed if not unicodedata.combining(ch))
    without_marks = without_marks.replace("\u00A0", " ").replace("\u202F", " ")
    return re.sub(r"\s+", " ", without_marks.strip().lower())
|
|
|
|
# Known Irish city surface forms, in English and Irish. Kept as an ordered
# tuple so repair_city_spans can try longer forms first.
IRISH_CITY_FORMS = (
    "Dublin",
    "Baile Átha Cliath",
    "Galway",
    "Gaillimh",
    "Cork",
    "Cork City",
    "Corcaigh",
    "Limerick",
    "Luimneach",
    "Waterford",
    "Port Láirge",
    "Kilkenny",
    "Cill Chainnigh",
    "Carlow",
    "Ceatharlach",
)
# Normalized (lower-case, de-accented) city surfaces for fast membership tests.
IRISH_CITY_SURFACES = {normalize_surface(value) for value in IRISH_CITY_FORMS}


# Normalized "Co. X" county surfaces (English and Irish genitive forms)
# accepted for COUNTY spans.
IRISH_COUNTY_SURFACES = {
    normalize_surface(value)
    for value in {
        "Co. Dublin",
        "Co. Bhaile Átha Cliath",
        "Co. Galway",
        "Co. na Gaillimhe",
        "Co. Cork",
        "Co. Chorcaí",
        "Co. Limerick",
        "Co. Luimnigh",
        "Co. Waterford",
        "Co. Phort Láirge",
        "Co. Kilkenny",
        "Co. Chill Chainnigh",
        "Co. Carlow",
        "Co. Cheatharlach",
    }
}
|
|
# Common street-name suffixes (English and Irish) used to validate
# STREET_ADDRESS candidates.
STREET_SUFFIX_RE = re.compile(
    r"(?i)\b(street|road|avenue|lane|park|view|square|terrace|drive|close|way|place|bóthar|bothar|sráid|sraid|lána|lana)\b"
)
# Whole-surface shape of a phone number: digits/+/parentheses with space,
# hyphen, NBSP or narrow-NBSP separators; must start and end on digit-ish chars.
PHONE_SURFACE_RE = re.compile(r"^[+()\d][+()\d \-\u00A0\u202F]*\d$")
# Digit-only account surface (spaces/hyphens/NBSPs allowed between groups).
ACCOUNT_DIGIT_SURFACE_RE = re.compile(r"^[\d \-\u00A0\u202F]+$")
# Full-string date shapes accepted for a DATE_OF_BIRTH surface:
# dd/mm/yyyy, yyyy-mm-dd, or "1st January 1990"-style (EN/GA month words).
DATE_OF_BIRTH_RE = re.compile(
    r"(?i)^(?:\d{2}/\d{2}/\d{4}|\d{4}-\d{2}-\d{2}|\d{1,2}(?:st|nd|rd|th)?\s+[A-Za-zÁÉÍÓÚáéíóú]+(?:\s+[A-Za-zÁÉÍÓÚáéíóú]+)?\s+\d{4})$"
)
# The same date shapes, but anchored by non-alphanumeric boundaries so they
# can be extracted from anywhere inside free text.
DATE_OF_BIRTH_VALUE_RE = re.compile(
    r"(?<![A-Za-z0-9])(\d{2}/\d{2}/\d{4}|\d{4}-\d{2}-\d{2}|\d{1,2}(?:st|nd|rd|th)?\s+[A-Za-zÁÉÍÓÚáéíóú]+(?:\s+[A-Za-zÁÉÍÓÚáéíóú]+)?\s+\d{4})(?![A-Za-z0-9])"
)
# Context cues (English/Irish) that a nearby bare number denotes an age.
AGE_CONTEXT_RE = re.compile(r"(?i)\b(age|years?\s+old|year\s+old|aois|bliana\s+d['’]aois)\b")
# Context cues (English/Irish) that a nearby date is a date of birth.
DOB_CONTEXT_RE = re.compile(
    r"(?i)\b(dob|date\s+of\s+birth|born\s+on|data\s+breithe|dáta\s+breithe|dhata\s+breithe|dháta\s+breithe|rugadh)\b"
)
# Context cues (English/Irish) for Irish PPS numbers.
PPSN_CUE_RE = re.compile(
    r"(?i)\b(ppsn|upsp|personal public service(?:\s+number)?|uimhir\s+(?:mo\s+)?upsp|uimhir\s+(?:mo\s+)?ppsn)\b"
)
# Normalized surfaces that must never be treated as personal names:
# form labels, verbs, and month/weekday names in English and Irish.
NAME_STOP_SURFACES = {
    normalize_surface(value)
    for value in {
        "Address",
        "Name",
        "Phone",
        "Email",
        "Seoladh",
        "Ainm",
        "Teagmháil",
        "Teagmhail",
        "Ríomhphost",
        "Riomhphost",
        "PPSN",
        "UPSP",
        "Call",
        "Glao",
        "Glaoigh",
        "Rugadh",
        "Ionad",
        "Intreo",
        "Cill",
        "Sampla",
        "Leithdháilte",
        "Leithdhailte",
        "Leithdháil",
        "Leithdhail",
        "Leithdh",
        "Fón",
        "Fon",
        "January",
        "February",
        "March",
        "April",
        "May",
        "June",
        "July",
        "August",
        "September",
        "October",
        "November",
        "December",
        "Monday",
        "Tuesday",
        "Wednesday",
        "Thursday",
        "Friday",
        "Saturday",
        "Sunday",
        "Eanáir",
        "Feabhra",
        "Márta",
        "Aibreán",
        "Aibrean",
        "Bealtaine",
        "Meitheamh",
        "Iúil",
        "Iuil",
        "Lúnasa",
        "Lunasa",
        "Meán Fómhair",
        "Mean Fomhair",
        "Deireadh Fómhair",
        "Deireadh Fomhair",
        "Samhain",
        "Nollaig",
        "Luan",
        "Máirt",
        "Mairt",
        "Céadaoin",
        "Ceadaoin",
        "Déardaoin",
        "Deardaoin",
        "Aoine",
        "Satharn",
        "Domhnach",
    }
}
# Surname particles (Irish and continental) that may legitimately be
# lower-case or very short within a surname sequence.
NAME_PARTICLE_SURFACES = {
    normalize_surface(value)
    for value in {"Ó", "O", "Ní", "Ni", "Nic", "Mac", "Mc", "de", "van", "von"}
}
# Tokens that, when trailing a street suffix, indicate the match ran into
# ordinary sentence text rather than an actual address.
STREET_TRAILING_BLOCK_SURFACES = {
    normalize_surface(value)
    for value in {
        "are",
        "public",
        "contact",
        "details",
        "website",
        "open",
        "before",
        "visiting",
        "roimh",
        "chuairt",
        "agus",
        "and",
        "the",
        "is",
        "ta",
    }
}
|
|
|
|
def is_plausible_last_name_sequence(value: str) -> bool:
    """True when every whitespace-separated token looks like a surname part.

    Each token must contain a letter, use only name characters, and be either
    capitalized, a lower-prefix form (e.g. "hÓgáin" — lower letter followed by
    a capital), or a known surname particle.
    """
    tokens = value.split()
    if not tokens:
        return False
    for token in tokens:
        letters = [ch for ch in token if ch.isalpha()]
        if not letters:
            return False
        if any(not is_name_token_char(ch) for ch in token):
            return False
        if letters[0].isupper():
            continue
        if len(letters) >= 2 and letters[0].islower() and letters[1].isupper():
            continue
        if normalize_surface(token) in NAME_PARTICLE_SURFACES:
            continue
        return False
    return True
|
|
|
|
def is_reasonable_span_text(label: str, text: str, start: int, end: int) -> bool:
    """Label-specific sanity check for the candidate surface text[start:end].

    Returns True when the stripped surface form (and, for context-dependent
    labels such as AGE/DOB/BANK_ROUTING_NUMBER, the surrounding text) is
    plausible for the given entity label. Labels without a specific rule are
    accepted unconditionally.
    """
    value = text[int(start) : int(end)].strip()
    if not value:
        return False
    upper = alnum_upper(value)


    # Person names: alphabetic only, no digits/punctuation, and not a known
    # non-name keyword (form labels, months, weekdays, ...).
    if label in {"FIRST_NAME", "LAST_NAME"}:
        if not any(ch.isalpha() for ch in value):
            return False
        if any(ch.isdigit() for ch in value):
            return False
        if normalize_surface(value) in NAME_STOP_SURFACES:
            return False
        if label == "FIRST_NAME" and any(ch.isspace() for ch in value):
            return False
        if any(ch in ".,;:/@()" for ch in value):
            return False
        if label == "FIRST_NAME":
            first_alpha = next((ch for ch in value if ch.isalpha()), "")
            if not first_alpha or not first_alpha.isupper():
                return False
        if label == "LAST_NAME" and not is_plausible_last_name_sequence(value):
            return False
        # Reject a name glued directly onto a preceding digit.
        if start > 0 and text[int(start) - 1].isdigit():
            return False
        return True


    # Email: minimal local@domain.tld shape.
    if label == "EMAIL":
        if "@" not in value:
            return False
        local, _, domain = value.partition("@")
        return bool(local) and "." in domain


    # Irish phone numbers: digit-only surface, clean word boundaries, and a
    # digit count consistent with the dialing prefix.
    if label == "PHONE_NUMBER":
        normalized = value.replace("\u00A0", " ").replace("\u202F", " ").strip()
        if any(ch.isalpha() for ch in normalized):
            return False
        if any(ch in "/@" for ch in normalized):
            return False
        if int(start) > 0 and text[int(start) - 1].isalnum():
            return False
        if int(end) < len(text) and text[int(end)].isalnum():
            return False
        if not PHONE_SURFACE_RE.match(normalized):
            return False
        digits = "".join(ch for ch in value if ch.isdigit())
        if normalized.startswith("+353"):
            return 11 <= len(digits) <= 12
        if not digits.startswith("0"):
            return False
        # NOTE(review): the "1800" arm is unreachable — digits must start with
        # "0" by the guard above. Confirm whether "01800..." was intended.
        if digits.startswith("0818") or digits.startswith("1800"):
            return len(digits) == 10
        if digits.startswith("08"):
            return len(digits) == 10
        if digits.startswith("01"):
            return len(digits) == 9
        return 9 <= len(digits) <= 10


    # PPSN: 7 digits followed by 1-2 letters once separators are removed.
    if label == "PPSN":
        return bool(len(upper) in {8, 9} and upper[:7].isdigit() and upper[7:].isalpha())


    # Eircode: 3-char routing key (letter + 2 digits, or the special D6W)
    # plus a 4-char unique identifier starting with a letter.
    if label == "POSTCODE":
        compact = value.replace(" ", "").replace("\u00A0", "").replace("\u202F", "")
        if any(not (ch.isalnum() or ch.isspace()) for ch in value):
            return False
        if len(compact) != 7:
            return False
        routing = compact[:3]
        unique = compact[3:]
        routing_ok = bool(
            (routing[0].isalpha() and routing[1:].isdigit())
            or routing == "D6W"
        )
        unique_ok = bool(
            len(unique) == 4
            and unique[0].isalpha()
            and unique[1:].isalnum()
        )
        return routing_ok and unique_ok


    # Passport: 1-2 upper-case letters then 7 digits.
    if label == "PASSPORT_NUMBER":
        return bool(re.fullmatch(r"[A-Z]{1,2}\s?\d{7}", value.strip()))


    # Sort code: exactly 6 digits plus a nearby banking cue in context.
    if label == "BANK_ROUTING_NUMBER":
        digits = "".join(ch for ch in value if ch.isdigit())
        if len(digits) != 6:
            return False
        context = text[max(0, int(start) - 32) : min(len(text), int(end) + 24)]
        return bool(BANK_ROUTING_CONTEXT_RE.search(context))


    # SWIFT/BIC: 8 or 11 alphanumeric characters.
    if label == "SWIFT_BIC":
        return len(upper) in {8, 11} and upper.isalnum()


    # Card number: 12-19 digits once separators are removed.
    if label == "CREDIT_DEBIT_CARD":
        digits = "".join(ch for ch in value if ch.isdigit())
        return 12 <= len(digits) <= 19


    # Account: Irish IBAN shape, or a separator-grouped 6-34 digit number.
    if label == "ACCOUNT_NUMBER":
        if upper.startswith("IE"):
            return bool(re.fullmatch(r"IE\d{2}[A-Z0-9]{18}", upper))
        if not ACCOUNT_DIGIT_SURFACE_RE.fullmatch(value.strip()):
            return False
        digits = "".join(ch for ch in value if ch.isdigit())
        return 6 <= len(digits) <= 34


    # Age: a bare 1-120 integer with clean boundaries (not a date fragment)
    # and an age cue in the surrounding context.
    if label == "AGE":
        digits = "".join(ch for ch in value if ch.isdigit())
        if digits != value.strip():
            return False
        if not digits:
            return False
        if int(start) > 0 and text[int(start) - 1].isalnum():
            return False
        if int(end) < len(text) and text[int(end)].isalnum():
            return False
        # Reject fragments of dates such as "12/04" or "1990-05".
        if int(start) > 0 and text[int(start) - 1] in "/-":
            return False
        if int(end) < len(text) and text[int(end)] in "/-":
            return False
        age = int(digits)
        if not (0 < age <= 120):
            return False
        context = text[max(0, int(start) - 24) : min(len(text), int(end) + 24)]
        return bool(AGE_CONTEXT_RE.search(context))


    # Date of birth: recognized date format plus a DOB cue in context.
    if label == "DATE_OF_BIRTH":
        if not any(ch.isdigit() for ch in value):
            return False
        if not DATE_OF_BIRTH_RE.match(value.strip()):
            return False
        context = text[max(0, int(start) - 32) : min(len(text), int(end) + 32)]
        return bool(DOB_CONTEXT_RE.search(context))


    # City/county: must match a known Irish surface form, digits forbidden.
    if label == "CITY":
        if any(ch.isdigit() for ch in value):
            return False
        return normalize_surface(value) in IRISH_CITY_SURFACES


    if label == "COUNTY":
        if any(ch.isdigit() for ch in value):
            return False
        return normalize_surface(value) in IRISH_COUNTY_SURFACES


    # Street address: requires a street suffix, no sentence punctuation, at
    # most three tokens after the suffix, and either a leading house number
    # or at least two tokens overall.
    if label == "STREET_ADDRESS":
        cleaned = value.strip()
        suffix_match = STREET_SUFFIX_RE.search(cleaned)
        if not suffix_match:
            return False
        if any(ch in "@,:;" for ch in cleaned):
            return False
        trailing = cleaned[int(suffix_match.end()) :].strip()
        trailing_tokens = [token for token in re.split(r"\s+", trailing) if token]
        if len(trailing_tokens) > 3:
            return False
        if any(normalize_surface(token) in STREET_TRAILING_BLOCK_SURFACES for token in trailing_tokens):
            return False
        has_digit = any(ch.isdigit() for ch in cleaned)
        if has_digit and not re.match(r"^\s*\d{1,4}\b", cleaned):
            return False
        title_tokens = [token for token in re.split(r"\s+", cleaned) if token]
        return has_digit or len(title_tokens) >= 2


    # Labels without a specific rule pass; downstream filters decide.
    return True
|
|
|
|
def spans_overlap(a: dict, b: dict) -> bool:
    """True when two spans' half-open [start, end) ranges intersect (int-coerced)."""
    a_start, a_end = int(a["start"]), int(a["end"])
    b_start, b_end = int(b["start"]), int(b["end"])
    return a_start < b_end and b_start < a_end
|
|
|
|
def is_name_token_char(ch: str) -> bool:
    """True for characters allowed inside a personal-name token (letters, hyphen, apostrophes)."""
    if ch.isalpha():
        return True
    return ch in "-'’"
|
|
|
|
def is_plausible_first_name(value: str) -> bool:
    """True when value looks like a single, capitalized given-name token."""
    if not value:
        return False
    # Single token only: no whitespace, digits or list punctuation.
    if any(ch.isspace() or ch.isdigit() or ch in ",;:/@()" for ch in value):
        return False
    first_letter = next((ch for ch in value if ch.isalpha()), "")
    if not first_letter or not first_letter.isupper():
        return False
    return all(is_name_token_char(ch) for ch in value)
|
|
|
|
def repair_first_name_from_last_name(text: str, spans: list[dict]) -> list[dict]:
    """Recover a missed FIRST_NAME token directly before each LAST_NAME span.

    For every LAST_NAME that has no FIRST_NAME ending within 2 characters of
    its start, scan backwards across one whitespace gap, take the preceding
    run of name characters, and — if it is a plausible first name — add it as
    a FIRST_NAME span at a discounted score.
    """
    repaired = list(spans)
    for last_name in [span for span in repaired if span["label"] == "LAST_NAME"]:
        # Skip surnames that already have an adjacent first name.
        if any(
            span["label"] == "FIRST_NAME"
            and int(span["end"]) <= int(last_name["start"])
            and int(last_name["start"]) - int(span["end"]) <= 2
            for span in repaired
        ):
            continue


        # Require at least one whitespace character directly before the surname.
        cursor = int(last_name["start"]) - 1
        if cursor < 0 or not text[cursor].isspace():
            continue
        while cursor >= 0 and text[cursor].isspace():
            cursor -= 1
        token_end = cursor + 1
        # Walk left over name characters to find the token start.
        while cursor >= 0 and is_name_token_char(text[cursor]):
            cursor -= 1
        token_start = cursor + 1
        if token_end <= token_start:
            continue
        candidate = text[token_start:token_end]
        if not is_plausible_first_name(candidate):
            continue
        candidate_span = {
            "start": token_start,
            "end": token_end,
            "label": "FIRST_NAME",
            # Discounted score: inferred from adjacency, not model-scored.
            "score": float(last_name.get("score", 0.5)) * 0.6,
            "text": candidate,
        }
        # Do not fight an existing FIRST_NAME prediction over the same text.
        if any(spans_overlap(candidate_span, other) for other in repaired if other["label"] == "FIRST_NAME"):
            continue
        repaired.append(candidate_span)
    return repaired
|
|
|
|
# Context cue for passport numbers (English, and Irish "pas" forms).
PASSPORT_CUE_RE = re.compile(
    r"(?i)(passport(?:\s+number)?|phas|uimhir\s+(?:mo\s+)?phas)"
)
# Passport value shape: 1-2 upper-case letters then 7 digits, not embedded
# inside a longer alphanumeric run.
PASSPORT_VALUE_RE = re.compile(r"(?<![A-Za-z0-9])([A-Z]{1,2}\s?\d{7})(?![A-Za-z0-9])")
# Loose email extractor shared by repair_emails and canonicalize_email_spans.
EMAIL_EXTRACT_RE = re.compile(r"([^\s@,;:()<>]+@[A-Za-z0-9.\-]+\.[A-Za-z]{2,})", re.UNICODE)
# Context cue (English/Irish) that a nearby number is a phone number.
PHONE_CUE_RE = re.compile(
    r"(?i)\b(phone|call|contact|reach\s+me|glaoigh\s+ar|teagmh[aá]il|uimhir|m['’]uimhir)\b"
)
# Marker that addresses in the text are organisational contact info, not PII.
PUBLIC_CONTACT_DETAILS_RE = re.compile(r"(?i)\bpublic\s+contact\s+details\b")
# Context cue that a city mention is part of an address or location.
CITY_CUE_RE = re.compile(
    r"(?i)\b(address|seoladh|located|suite|centre|center|ionad|intreo|clinic|hospital|ospid[eé]al|hse|fss)\b"
)
# Banking context required before a 6-digit run is accepted as a sort code.
BANK_ROUTING_CONTEXT_RE = re.compile(
    r"(?i)\b(sort\s+code|routing\s+number|bank\s+of\s+ireland|aib|cod\s+sort[aá]la|sort[aá]la)\b"
)
# Irish phone value shape: +353 or a leading 0, then 8-14 further digits with
# optional space/hyphen separators.
PHONE_VALUE_RE = re.compile(
    r"(?<![A-Za-z0-9])((?:\+353|0)\d(?:[\s\-]?\d){7,13})(?![A-Za-z0-9])"
)
# PPSN value shape: 7 digits then 1-2 letters, separators tolerated.
PPSN_VALUE_RE = re.compile(r"(?<![A-Za-z0-9])(\d{7}(?:[\s-]*[A-Za-z]){1,2})(?![A-Za-z0-9])")
# Eircode shape: routing key (letter + 2 digits, or D6W) then the 4-char
# unique part, optionally separated by a (narrow/non-breaking) space.
POSTCODE_VALUE_RE = re.compile(
    r"(?<![A-Za-z0-9])((?:[A-Za-z]\d{2}|D6W)[\s\u00A0\u202F]?[A-Za-z][A-Za-z0-9]{3})(?![A-Za-z0-9])"
)
|
|
|
|
def repair_contextual_passport_numbers(text: str, spans: list[dict]) -> list[dict]:
    """Add PASSPORT_NUMBER spans for passport-shaped values preceded by a cue.

    Values already fully covered by a PASSPORT_NUMBER span are skipped; on a
    hit, overlapping conflicting spans are displaced by the new one.
    """
    result = list(spans)
    for match in PASSPORT_VALUE_RE.finditer(text):
        start, end = match.span(1)
        candidate = {
            "start": start,
            "end": end,
            "label": "PASSPORT_NUMBER",
            "score": 0.67,
            "text": text[start:end],
        }
        already_covered = any(
            existing["label"] == "PASSPORT_NUMBER"
            and int(existing["start"]) <= start
            and int(existing["end"]) >= end
            for existing in result
        )
        if already_covered:
            continue
        # A passport cue must appear within the 32 chars before the value.
        if not PASSPORT_CUE_RE.search(text[max(0, start - 32) : start]):
            continue
        blocked = {"PHONE_NUMBER", "PPSN", "ACCOUNT_NUMBER", "AGE", "PASSPORT_NUMBER"}
        result = [
            existing
            for existing in result
            if existing["label"] not in blocked or not spans_overlap(candidate, existing)
        ]
        result.append(candidate)
    return result
|
|
|
|
def repair_ppsn_variants(text: str, spans: list[dict]) -> list[dict]:
    """Add PPSN spans for 7-digit + 1-2 letter values, displacing conflicts.

    A nearby PPSN cue raises the assigned score (0.72 vs 0.58).
    """
    result = list(spans)
    for match in PPSN_VALUE_RE.finditer(text):
        start, end = match.span(1)
        surface = text[start:end]
        compact = alnum_upper(surface)
        shape_ok = len(compact) in {8, 9} and compact[:7].isdigit() and compact[7:].isalpha()
        if not shape_ok:
            continue
        window = text[max(0, start - 32) : min(len(text), end + 24)]
        cued = bool(PPSN_CUE_RE.search(window))
        candidate = {
            "start": start,
            "end": end,
            "label": "PPSN",
            "score": 0.72 if cued else 0.58,
            "text": surface,
        }
        blocked = {"PHONE_NUMBER", "PASSPORT_NUMBER", "ACCOUNT_NUMBER", "AGE", "FIRST_NAME", "LAST_NAME", "PPSN"}
        result = [
            existing
            for existing in result
            if existing["label"] not in blocked or not spans_overlap(candidate, existing)
        ]
        result.append(candidate)
    return result
|
|
|
|
def repair_contextual_date_of_birth(text: str, spans: list[dict]) -> list[dict]:
    """Add DATE_OF_BIRTH spans for date values with a nearby DOB cue."""
    result = list(spans)
    for match in DATE_OF_BIRTH_VALUE_RE.finditer(text):
        start, end = match.span(1)
        window = text[max(0, start - 40) : min(len(text), end + 24)]
        if not DOB_CONTEXT_RE.search(window):
            continue
        candidate = {
            "start": start,
            "end": end,
            "label": "DATE_OF_BIRTH",
            "score": 0.66,
            "text": text[start:end],
        }
        blocked = {"DATE_OF_BIRTH", "PHONE_NUMBER", "AGE", "FIRST_NAME", "LAST_NAME", "ACCOUNT_NUMBER"}
        result = [
            existing
            for existing in result
            if existing["label"] not in blocked or not spans_overlap(candidate, existing)
        ]
        result.append(candidate)
    return result
|
|
|
|
# Context cue (English/Irish) that a digit run is a bank account number.
ACCOUNT_CUE_RE = re.compile(
    r"(?i)(account\s+number|bank\s+account|uimhir\s+chuntais|cuntas\s+bainc)"
)
# Candidate account value: a standalone run of 6-12 digits.
ACCOUNT_VALUE_RE = re.compile(r"(?<![A-Za-z0-9])(\d{6,12})(?![A-Za-z0-9])")
|
|
|
|
def repair_contextual_account_numbers(text: str, spans: list[dict]) -> list[dict]:
    """Add ACCOUNT_NUMBER spans for 6-12 digit runs preceded by an account cue.

    Unlike the other repairs, this never displaces existing spans: a value
    overlapping a more specific identifier span is simply skipped.
    """
    result = list(spans)
    for match in ACCOUNT_VALUE_RE.finditer(text):
        start, end = match.span(1)
        candidate = {
            "start": start,
            "end": end,
            "label": "ACCOUNT_NUMBER",
            "score": 0.51,
            "text": text[start:end],
        }
        already_covered = any(
            existing["label"] == "ACCOUNT_NUMBER"
            and int(existing["start"]) <= start
            and int(existing["end"]) >= end
            for existing in result
        )
        if already_covered:
            continue
        if not ACCOUNT_CUE_RE.search(text[max(0, start - 40) : start]):
            continue
        clashing = {"PHONE_NUMBER", "BANK_ROUTING_NUMBER", "PPSN", "POSTCODE", "PASSPORT_NUMBER"}
        if any(
            spans_overlap(candidate, existing) and existing["label"] in clashing
            for existing in result
        ):
            continue
        result.append(candidate)
    return result
|
|
|
|
def repair_emails(text: str, spans: list[dict]) -> list[dict]:
    """Add an EMAIL span for every address-like match, displacing name/email overlaps."""
    result = list(spans)
    for match in EMAIL_EXTRACT_RE.finditer(text):
        start, end = match.span(1)
        candidate = {
            "start": start,
            "end": end,
            "label": "EMAIL",
            "score": 0.74,
            "text": text[start:end],
        }
        blocked = {"EMAIL", "FIRST_NAME", "LAST_NAME"}
        result = [
            existing
            for existing in result
            if existing["label"] not in blocked or not spans_overlap(candidate, existing)
        ]
        result.append(candidate)
    return result
|
|
|
|
def repair_phone_numbers(text: str, spans: list[dict]) -> list[dict]:
    """Add PHONE_NUMBER spans for Irish-format numbers, displacing conflicts.

    A value is accepted when it has a nearby phone cue or overlaps an
    existing PHONE_NUMBER span, and passes the phone plausibility check.
    """
    result = list(spans)
    for match in PHONE_VALUE_RE.finditer(text):
        start, end = match.span(1)
        candidate = {
            "start": start,
            "end": end,
            "label": "PHONE_NUMBER",
            "score": 0.69,
            "text": text[start:end],
        }
        window = text[max(0, start - 32) : min(len(text), end + 16)]
        cued = bool(PHONE_CUE_RE.search(window))
        overlaps_existing_phone = any(
            spans_overlap(candidate, existing) and existing["label"] == "PHONE_NUMBER"
            for existing in result
        )
        if not cued and not overlaps_existing_phone:
            continue
        if not is_reasonable_span_text("PHONE_NUMBER", text, start, end):
            continue
        blocked = {"PHONE_NUMBER", "PPSN", "ACCOUNT_NUMBER", "BANK_ROUTING_NUMBER"}
        result = [
            existing
            for existing in result
            if existing["label"] not in blocked or not spans_overlap(candidate, existing)
        ]
        result.append(candidate)
    return result
|
|
|
|
def repair_postcodes(text: str, spans: list[dict]) -> list[dict]:
    """Add a POSTCODE span for every Eircode-shaped match, displacing conflicts."""
    result = list(spans)
    for match in POSTCODE_VALUE_RE.finditer(text):
        start, end = match.span(1)
        candidate = {
            "start": start,
            "end": end,
            "label": "POSTCODE",
            "score": 0.71,
            "text": text[start:end],
        }
        blocked = {"POSTCODE", "PHONE_NUMBER", "ACCOUNT_NUMBER", "FIRST_NAME", "LAST_NAME"}
        result = [
            existing
            for existing in result
            if existing["label"] not in blocked or not spans_overlap(candidate, existing)
        ]
        result.append(candidate)
    return result
|
|
|
|
def repair_city_spans(text: str, spans: list[dict]) -> list[dict]:
    """Add CITY spans for known Irish city surface forms found in the text.

    A match is only added when supported by context: an adjacent
    STREET_ADDRESS / COUNTY / POSTCODE span, an address-like continuation
    (", Co. ..." or an Eircode routing key), or a location cue nearby.
    Overlapping CITY and name spans are displaced by the new span.
    """
    repaired = list(spans)
    seen: set[tuple[int, int]] = set()
    # Longest forms first so e.g. "Cork City" wins over "Cork".
    ordered_forms = sorted(IRISH_CITY_FORMS, key=len, reverse=True)
    for form in ordered_forms:
        for match in re.finditer(re.escape(form), text, flags=re.IGNORECASE):
            start, end = match.span()
            key = (start, end)
            if key in seen:
                continue
            seen.add(key)
            candidate_span = {
                "start": start,
                "end": end,
                "label": "CITY",
                "score": 0.64,
                "text": text[start:end],
            }
            has_context = False
            for other in repaired:
                other_start = int(other["start"])
                other_end = int(other["end"])
                # City directly after a street address (gap <= 4 chars).
                if other["label"] == "STREET_ADDRESS" and 0 <= start - other_end <= 4:
                    has_context = True
                    break
                # County or postcode shortly after the city (gap <= 6 chars).
                if other["label"] in {"COUNTY", "POSTCODE"} and 0 <= other_start - end <= 6:
                    has_context = True
                    break
            # Textual continuation: ", Co. X" or an Eircode routing key follows.
            if not has_context and re.match(r"^\s*,\s*(?:Co\.\s+|[A-Z]\d{2}|D6W)", text[end:]):
                has_context = True
            if not has_context:
                cue_window = text[max(0, start - 40) : min(len(text), end + 32)]
                has_context = bool(CITY_CUE_RE.search(cue_window))
            if not has_context:
                continue
            conflicting_labels = {"CITY", "FIRST_NAME", "LAST_NAME"}
            repaired = [
                other
                for other in repaired
                if not (
                    spans_overlap(candidate_span, other)
                    and other["label"] in conflicting_labels
                )
            ]
            repaired.append(candidate_span)
    return repaired
|
|
|
|
def drop_public_contact_detail_spans(text: str, spans: list[dict]) -> list[dict]:
    """Drop location spans when the text advertises public contact details.

    If the "public contact details" cue occurs anywhere in the text, street/
    city/county spans are treated as organisational (non-PII) and removed;
    otherwise the spans are returned unchanged.
    """
    if not PUBLIC_CONTACT_DETAILS_RE.search(text):
        return spans
    # The guard above already matched, so re-searching the whole text for
    # every span (as the original comprehension did) was redundant work.
    return [span for span in spans if span["label"] not in {"STREET_ADDRESS", "CITY", "COUNTY"}]
|
|
|
|
def drop_city_org_prefix_spans(text: str, spans: list[dict]) -> list[dict]:
    """Drop CITY spans that are actually the prefix of an "X Intreo Centre" name."""
    kept: list[dict] = []
    for span in spans:
        if span["label"] == "CITY":
            span_end = int(span["end"])
            lookahead = text[span_end : min(len(text), span_end + 24)]
            # "Galway Intreo Centre" is an organisation, not a city mention.
            if re.match(r"^\s+Intreo\s+Centre\b", lookahead):
                continue
        kept.append(span)
    return kept
|
|
|
|
def canonicalize_email_spans(text: str, spans: list[dict]) -> list[dict]:
    """Tighten each EMAIL span to the exact address found inside it.

    Non-EMAIL spans and EMAIL spans with no extractable address pass through
    unchanged; otherwise the span is re-anchored to the matched address.
    """
    result: list[dict] = []
    for span in spans:
        if span["label"] != "EMAIL":
            result.append(span)
            continue
        span_start = int(span["start"])
        segment = text[span_start : int(span["end"])]
        match = EMAIL_EXTRACT_RE.search(segment)
        if match is None:
            result.append(span)
            continue
        new_start = span_start + int(match.start(1))
        new_end = span_start + int(match.end(1))
        result.append({**span, "start": new_start, "end": new_end, "text": text[new_start:new_end]})
    return result
|
|
|
|
def drop_stacked_first_names(spans: list[dict]) -> list[dict]:
    """Remove a FIRST_NAME that directly precedes another FIRST_NAME+LAST_NAME pair.

    With "A B C" tagged FIRST FIRST LAST, the leading FIRST ("A") is shadowed
    by the inner name pair and is dropped. No-op when there is no FIRST_NAME
    or no LAST_NAME at all.
    """
    if not spans:
        return spans
    firsts = [span for span in spans if span["label"] == "FIRST_NAME"]
    lasts = [span for span in spans if span["label"] == "LAST_NAME"]
    if not firsts or not lasts:
        return spans

    def _has_adjacent_last(candidate: dict) -> bool:
        cand_end = int(candidate["end"])
        return any(0 <= int(last["start"]) - cand_end <= 2 for last in lasts)

    def _shadowed(span: dict) -> bool:
        span_start = int(span["start"])
        span_end = int(span["end"])
        for other in firsts:
            if other is span or int(other["start"]) <= span_start:
                continue
            # Shadowing FIRST_NAME must start within 2 chars of this span's end.
            if int(other["start"]) - span_end > 2:
                continue
            if _has_adjacent_last(other):
                return True
        return False

    return [span for span in spans if span["label"] != "FIRST_NAME" or not _shadowed(span)]
|
|
|
|
def decode_span_matrix(
    text: str,
    offsets: list[tuple[int, int]],
    span_scores: np.ndarray,
    config,
    min_score: float,
) -> list[dict]:
    """Decode a [num_labels, seq_len, seq_len] score matrix into text spans.

    offsets maps token indices to character offsets (tokenizer offset_mapping;
    special tokens carry invalid offsets and are skipped). Candidates are
    thresholded per label, width- and length-limited, shape-checked via
    is_reasonable_span_text, then run through the repair/filter pipeline and
    deduplicated. Raises ValueError when span_scores is not 3-D.
    """
    label_names = label_names_from_config(config)
    thresholds = label_thresholds_from_config(config, min_score)
    max_span_tokens = label_max_span_tokens_from_config(config)
    min_nonspace_chars = label_min_nonspace_chars_from_config(config)


    if span_scores.ndim != 3:
        raise ValueError(f"Expected [num_labels, seq_len, seq_len] span scores, got shape {span_scores.shape}")


    num_labels, seq_len, _ = span_scores.shape
    spans: list[dict] = []
    for label_index in range(min(num_labels, len(label_names))):
        label = label_names[label_index]
        threshold = thresholds.get(label, min_score)
        max_width = max(1, int(max_span_tokens.get(label, 8)))
        min_chars = max(1, int(min_nonspace_chars.get(label, 1)))


        # Each (start token, end token) cell scores one candidate span.
        for start_idx in range(seq_len):
            start_offset = offsets[start_idx]
            if not valid_offset(start_offset):
                continue
            max_end = min(seq_len, start_idx + max_width)
            for end_idx in range(start_idx, max_end):
                end_offset = offsets[end_idx]
                if not valid_offset(end_offset):
                    continue
                score = float(span_scores[label_index, start_idx, end_idx])
                if score < threshold:
                    continue
                start_char = int(start_offset[0])
                end_char = int(end_offset[1])
                if end_char <= start_char:
                    continue
                if nonspace_length(text, start_char, end_char) < min_chars:
                    continue
                if not is_reasonable_span_text(label, text, start_char, end_char):
                    continue
                spans.append(
                    {
                        "start": start_char,
                        "end": end_char,
                        "label": label,
                        "score": score,
                        "text": text[start_char:end_char],
                    }
                )
    # Post-processing pipeline: prefer longer candidates, regex-based repairs,
    # context filters, and finally overlap-aware deduplication. Order matters.
    spans = prefer_long_name_spans(spans, thresholds)
    spans = prefer_long_structured_spans(spans, thresholds)
    spans = repair_first_name_from_last_name(text, spans)
    spans = repair_emails(text, spans)
    spans = repair_phone_numbers(text, spans)
    spans = repair_ppsn_variants(text, spans)
    spans = repair_postcodes(text, spans)
    spans = repair_city_spans(text, spans)
    spans = repair_contextual_date_of_birth(text, spans)
    spans = repair_contextual_passport_numbers(text, spans)
    spans = repair_contextual_account_numbers(text, spans)
    spans = drop_public_contact_detail_spans(text, spans)
    spans = drop_city_org_prefix_spans(text, spans)
    spans = drop_stacked_first_names(spans)
    spans = canonicalize_email_spans(text, spans)
    return dedupe_spans(spans)
|
|
|
|
def prefer_long_name_spans(spans: list[dict], thresholds: dict[str, float]) -> list[dict]:
    """Collapse FIRST_NAME/LAST_NAME spans sharing a start position.

    Among candidates with the same label and the same start character, keep the
    widest one when its score is competitive (at least threshold + 0.15 and at
    least 70% of the best score); otherwise keep the highest-scoring one.
    Non-name spans pass through untouched. Finally defers to
    ``prefer_same_end_extensions`` for the symmetric same-end case.
    """
    if not spans:
        return spans

    kept: list[dict] = []
    used: set[int] = set()

    def score_of(item: tuple[int, dict]) -> float:
        return float(item[1].get("score", 0.0))

    for idx, current in enumerate(spans):
        if idx in used:
            continue
        label = current["label"]
        if label != "FIRST_NAME" and label != "LAST_NAME":
            kept.append(current)
            continue
        group = [
            (j, cand)
            for j, cand in enumerate(spans)
            if j not in used and cand["label"] == label and cand["start"] == current["start"]
        ]
        if len(group) == 1:
            kept.append(current)
            continue
        used.update(j for j, _ in group)
        top = max(group, key=score_of)[1]
        widest = max(group, key=lambda item: (item[1]["end"] - item[1]["start"], score_of(item)))[1]
        floor = float(thresholds.get(label, 0.5))
        # Prefer the wider span only when its confidence is close enough.
        if float(widest.get("score", 0.0)) >= max(floor + 0.15, float(top.get("score", 0.0)) * 0.7):
            kept.append(widest)
        else:
            kept.append(top)
    return prefer_same_end_extensions(kept, thresholds)
|
|
|
|
def prefer_same_end_extensions(spans: list[dict], thresholds: dict[str, float]) -> list[dict]:
    """Collapse FIRST_NAME/LAST_NAME/EMAIL spans that share an end position.

    For each group of same-label spans ending at the same character, choose
    between the widest span and the highest-scoring one:

    * EMAIL: the widest wins if it contains ``@`` (or is strictly wider) and
      scores within 0.02 of the best.
    * Names: the widest wins if its text has no internal space and its score
      clears ``max(threshold * 0.8, best * 0.55)``.

    Spans with other labels pass through unchanged.
    """
    if not spans:
        return spans

    result: list[dict] = []
    taken: set[int] = set()

    def score_of(item: tuple[int, dict]) -> float:
        return float(item[1].get("score", 0.0))

    for idx, current in enumerate(spans):
        if idx in taken:
            continue
        label = current["label"]
        if label not in ("FIRST_NAME", "LAST_NAME", "EMAIL"):
            result.append(current)
            continue
        group = [
            (j, cand)
            for j, cand in enumerate(spans)
            if j not in taken and cand["label"] == label and cand["end"] == current["end"]
        ]
        if len(group) == 1:
            result.append(current)
            continue
        taken.update(j for j, _ in group)
        top = max(group, key=score_of)[1]
        widest = max(group, key=lambda item: (item[1]["end"] - item[1]["start"], score_of(item)))[1]
        widest_score = float(widest.get("score", 0.0))
        top_score = float(top.get("score", 0.0))

        pick_widest = False
        if label == "EMAIL":
            wider = widest["end"] - widest["start"] > top["end"] - top["start"]
            if ("@" in widest.get("text", "") or wider) and widest_score >= top_score - 0.02:
                pick_widest = True
        else:
            body = widest.get("text", "")
            if " " not in body.strip() and widest_score >= max(float(thresholds.get(label, 0.5)) * 0.8, top_score * 0.55):
                pick_widest = True
        result.append(widest if pick_widest else top)
    return result
|
|
|
|
def prefer_long_structured_spans(spans: list[dict], thresholds: dict[str, float]) -> list[dict]:
    """Collapse overlapping STREET_ADDRESS/DATE_OF_BIRTH spans per label.

    Within each cluster of overlapping same-label spans, the widest one is
    kept if it scores at least the label threshold and at least 75% of the
    cluster's best score; otherwise the highest-scoring span wins. Spans of
    other labels are passed through unchanged.
    """
    if not spans:
        return spans

    structured = ("STREET_ADDRESS", "DATE_OF_BIRTH")
    result: list[dict] = []
    taken: set[int] = set()

    def score_of(item: tuple[int, dict]) -> float:
        return float(item[1].get("score", 0.0))

    for idx, current in enumerate(spans):
        if idx in taken:
            continue
        label = current["label"]
        if label not in structured:
            result.append(current)
            continue
        cluster = [
            (j, cand)
            for j, cand in enumerate(spans)
            if j not in taken and cand["label"] == label and spans_overlap(current, cand)
        ]
        if len(cluster) == 1:
            result.append(current)
            continue
        taken.update(j for j, _ in cluster)
        top = max(cluster, key=score_of)[1]
        widest = max(
            cluster,
            key=lambda item: (item[1]["end"] - item[1]["start"], score_of(item)),
        )[1]
        # Wider structured spans are usually more complete; require a
        # competitive score before preferring them.
        if float(widest.get("score", 0.0)) >= max(float(thresholds.get(label, 0.5)), float(top.get("score", 0.0)) * 0.75):
            result.append(widest)
        else:
            result.append(top)
    return result
|
|
|
|
def sigmoid_np(values: np.ndarray) -> np.ndarray:
    """Elementwise logistic sigmoid, clipped to [-60, 60] to avoid overflow in exp."""
    safe = np.clip(values, -60.0, 60.0)
    return np.reciprocal(1.0 + np.exp(-safe))
|
|
|
|
def run_onnx_span(session, encoded: dict[str, Any]) -> np.ndarray:
    """Run an ONNX inference session and return its first output.

    Builds the feed dict from ``encoded``, keeping only keys the session
    declares as inputs and always dropping ``offset_mapping`` (tokenizer
    metadata the model does not consume).

    Raises:
        ValueError: If the session produces no outputs.
    """
    accepted = {spec.name for spec in session.get_inputs()}
    feed = {
        name: tensor
        for name, tensor in encoded.items()
        if name != "offset_mapping" and name in accepted
    }
    outputs = session.run(None, feed)
    if not outputs:
        raise ValueError("ONNX session returned no outputs")
    return outputs[0]
|
|