#!/usr/bin/env python3
"""Run ONNX token-classification models and turn their predictions into spans.

Provides model/tokenizer loading (local directory or Hugging Face Hub repo), a
simple BIO-style span aggregator, a word-aligned PPSN span detector, and an
Eircode format check.
"""

import json
import os
import shutil
import tempfile
from pathlib import Path

# Backend-selection environment variables. They are set before `transformers`
# is imported below so they are already in effect at import time.
os.environ.setdefault("TRANSFORMERS_NO_TF", "1")
os.environ.setdefault("TRANSFORMERS_NO_FLAX", "1")
os.environ.setdefault("TRANSFORMERS_NO_TORCHVISION", "1")
os.environ["USE_TF"] = "0"
os.environ["USE_FLAX"] = "0"
os.environ["USE_TORCH"] = "1"

import numpy as np  # noqa: E402
import regex as re  # noqa: E402
from huggingface_hub import HfApi, hf_hub_download  # noqa: E402
from transformers import AutoConfig, AutoTokenizer  # noqa: E402

# Word-level pre-tokenizer. It matches one of:
#   - a run of Unicode letters/digits;
#   - a single character that is neither a word character nor whitespace;
#   - a single underscore (a word character, so it needs its own alternative).
# Every non-whitespace character lands in exactly one token. The previous
# ASCII-only `[A-Za-z0-9]+` silently dropped accented letters such as "á"/"Ó".
TOKEN_RE = re.compile(r"[^\W_]+|[^\w\s]|_", re.UNICODE)

# ONNX artifacts tried, in order, after any file explicitly requested by the caller.
DEFAULT_ONNX_FILES = ["onnx/model_quantized.onnx", "model_quantized.onnx"]

# Eircode shape: a routing key (allowed letter + two digits, or the special key
# "D6W"), an optional whitespace character, then four characters from the
# allowed set. Case-insensitive.
EIRCODE_RE = re.compile(r"^(?:[ACDEFHKNPRTVWXY]\d{2}|D6W)\s?[0-9ACDEFHKNPRTVWXY]{4}$", re.IGNORECASE)

# Files copied when a local tokenizer directory must be sanitised.
# It covers the original BERT-style set, BPE / SentencePiece vocabularies, and
# `config.json`. NOTE(review): AutoTokenizer may consult `config.json` to pick a
# tokenizer class when `tokenizer_config.json` does not name one.
_TOKENIZER_FILES = frozenset(
    {
        "tokenizer_config.json",
        "tokenizer.json",
        "special_tokens_map.json",
        "vocab.txt",
        "added_tokens.json",
        "vocab.json",
        "merges.txt",
        "tokenizer.model",
        "spm.model",
        "sentencepiece.bpe.model",
        "config.json",
    }
)


def tokenize_with_spans(text: str):
    """Split *text* with ``TOKEN_RE`` into ``(token, start, end)`` character-offset triples."""
    return [(m.group(0), m.start(), m.end()) for m in TOKEN_RE.finditer(text)]


def _sanitized_tokenizer_dir(tokenizer_path: Path) -> Path | None:
    """Copy a local tokenizer into a temp dir with ``fix_mistral_regex`` removed from its config.

    Returns None when ``tokenizer_config.json`` is absent or lacks that key, since
    there is nothing to sanitise. Otherwise returns the new directory; the caller
    owns it and must delete it.
    """
    config_path = tokenizer_path / "tokenizer_config.json"
    if not config_path.exists():
        return None
    data = json.loads(config_path.read_text(encoding="utf-8"))
    if "fix_mistral_regex" not in data:
        return None

    tmpdir = Path(tempfile.mkdtemp(prefix="openmed_onnx_tokenizer_"))
    try:
        for child in tokenizer_path.iterdir():
            if child.is_file() and child.name in _TOKENIZER_FILES:
                shutil.copyfile(child, tmpdir / child.name)
        # The key is removed from the file; `_auto_tokenizer` supplies it as an
        # explicit keyword instead (presumably to avoid a duplicate-value clash).
        data.pop("fix_mistral_regex", None)
        (tmpdir / "tokenizer_config.json").write_text(
            json.dumps(data, ensure_ascii=False, indent=2) + "\n", encoding="utf-8"
        )
    except Exception:
        # Don't leave a half-populated directory behind.
        shutil.rmtree(tmpdir, ignore_errors=True)
        raise
    return tmpdir


def _auto_tokenizer(tokenizer_ref: str):
    """Load a fast tokenizer, preferring ``fix_mistral_regex=True``.

    On any failure it retries with ``fix_mistral_regex=False``. If that raises
    TypeError (treated as the keyword being unsupported by the installed
    transformers), it loads without the keyword.
    """
    try:
        return AutoTokenizer.from_pretrained(tokenizer_ref, use_fast=True, fix_mistral_regex=True)
    except Exception:
        try:
            return AutoTokenizer.from_pretrained(tokenizer_ref, use_fast=True, fix_mistral_regex=False)
        except TypeError:
            return AutoTokenizer.from_pretrained(tokenizer_ref, use_fast=True)


def _load_tokenizer(tokenizer_ref: str):
    """Load the tokenizer for *tokenizer_ref*, a local directory or Hub repo id.

    A local directory whose ``tokenizer_config.json`` contains
    ``fix_mistral_regex`` is loaded from a sanitised temporary copy. That copy
    is deleted as soon as loading finishes, whether it succeeds or fails.
    """
    tokenizer_path = Path(tokenizer_ref)
    sanitized = _sanitized_tokenizer_dir(tokenizer_path) if tokenizer_path.exists() else None
    if sanitized is None:
        return _auto_tokenizer(tokenizer_ref)
    try:
        return _auto_tokenizer(str(sanitized))
    finally:
        shutil.rmtree(sanitized, ignore_errors=True)


def _onnx_candidates(onnx_file: str | None) -> list[str]:
    """Return the ONNX file names to try: *onnx_file* first (if given), then the defaults."""
    return ([onnx_file] if onnx_file else []) + DEFAULT_ONNX_FILES


def _resolve_onnx_path(model_ref: str, onnx_file: str | None) -> Path:
    """Locate the ONNX artifact for *model_ref*, downloading it from the Hub if it is not local.

    Raises:
        FileNotFoundError: no candidate file exists in the local directory, or
            none is published in the Hub repo.
    """
    candidates = _onnx_candidates(onnx_file)
    tried = ", ".join(candidates)

    model_path = Path(model_ref)
    if model_path.exists():
        for candidate in candidates:
            path = model_path / candidate
            if path.exists():
                return path
        raise FileNotFoundError(f"Missing ONNX artifact in {model_path} (tried: {tried})")

    published = set(HfApi().list_repo_files(repo_id=model_ref, repo_type="model"))
    chosen = next((candidate for candidate in candidates if candidate in published), None)
    if chosen is None:
        raise FileNotFoundError(f"No ONNX artifact published in {model_ref} (tried: {tried})")
    return Path(hf_hub_download(repo_id=model_ref, filename=chosen, repo_type="model"))


def load_onnx_token_classifier(model_ref: str, onnx_file: str | None = None):
    """Load an ONNX token classifier with its tokenizer and config.

    Args:
        model_ref: Local model directory, or a Hugging Face Hub repo id.
        onnx_file: Preferred ONNX file (relative to the model root). It is tried
            before ``DEFAULT_ONNX_FILES``.

    Returns:
        ``(session, tokenizer, config)``, where *session* is an onnxruntime
        ``InferenceSession`` on the CPU execution provider.

    Raises:
        FileNotFoundError: no ONNX artifact could be found for *model_ref*.
    """
    # Imported lazily so the rest of the module is usable without onnxruntime.
    import onnxruntime as ort

    onnx_path = _resolve_onnx_path(model_ref, onnx_file)
    config = AutoConfig.from_pretrained(model_ref)
    tokenizer = _load_tokenizer(model_ref)
    session = ort.InferenceSession(str(onnx_path), providers=["CPUExecutionProvider"])
    return session, tokenizer, config


def _softmax(logits):
    """Numerically stable softmax over the last axis."""
    shifted = logits - np.max(logits, axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / np.clip(np.sum(exp, axis=-1, keepdims=True), 1e-12, None)


def _run(session, encoded):
    """Run *session* on the tokenizer output and return its first output.

    Only inputs the ONNX graph declares are fed; ``offset_mapping`` is always excluded.
    """
    input_names = {item.name for item in session.get_inputs()}
    feed = {
        key: value
        for key, value in encoded.items()
        if key != "offset_mapping" and key in input_names
    }
    return session.run(None, feed)[0]


def normalize_label(label: str) -> str:
    """Strip surrounding whitespace and a ``B-``/``I-`` prefix, then upper-case (None -> "")."""
    label = (label or "").strip()
    if label.startswith("B-") or label.startswith("I-"):
        label = label[2:]
    return label.upper()


def looks_like_eircode(value: str) -> bool:
    """Return True if *value* (ignoring surrounding whitespace) matches ``EIRCODE_RE``."""
    return EIRCODE_RE.match(value.strip()) is not None


def ppsn_label_ids_from_config(config):
    """Return the sorted label ids in ``config.id2label`` whose label ends with ``"PPSN"``."""
    ids = []
    for raw_id, raw_label in config.id2label.items():
        label_id = int(raw_id)
        label = str(raw_label or "").strip()
        if label.endswith("PPSN"):
            ids.append(label_id)
    return sorted(ids)


def simple_aggregate_spans_onnx(text, session, tokenizer, config, min_score=0.5):
    """Group per-token predictions into labelled character spans (simple BIO aggregation).

    A token is skipped, and any open span is closed, when any of these hold:
      - it is masked out by the attention mask;
      - it has an empty offset (e.g. special tokens);
      - it is predicted ``O``;
      - its top probability is below *min_score*.

    A token extends the open span only when all of these hold; otherwise it
    starts a new span:
      - its label is ``I-`` (a ``B-`` or prefix-less label always starts a span);
      - it has the same normalised entity type;
      - it starts at most one character after the span's current end.

    The input is tokenized with ``truncation=True``, so text beyond what the
    tokenizer keeps is not scored.

    Returns:
        List of dicts with ``label``, ``start``, ``end``, ``score`` (max token
        probability in the span) and ``text``.
    """
    encoded = tokenizer(text, return_offsets_mapping=True, return_tensors="np", truncation=True)
    logits = _run(session, encoded)[0]
    probs = _softmax(logits)
    pred_ids = probs.argmax(axis=-1)
    id2label = {int(k): v for k, v in config.id2label.items()}
    offsets = encoded["offset_mapping"][0].tolist()
    attention = encoded["attention_mask"][0].tolist()

    spans = []
    active = None
    for idx, ((start, end), keep) in enumerate(zip(offsets, attention)):
        if not keep or start == end:
            if active is not None:
                spans.append(active)
                active = None
            continue
        label = id2label[int(pred_ids[idx])]
        if label == "O":
            if active is not None:
                spans.append(active)
                active = None
            continue
        score = float(probs[idx, int(pred_ids[idx])])
        if score < min_score:
            if active is not None:
                spans.append(active)
                active = None
            continue
        # Labels without a B-/I- prefix are treated as span starts.
        prefix = label[:1] if label.startswith(("B-", "I-")) else "B"
        entity = normalize_label(label)
        # A gap of up to one character (e.g. a single space) still continues the span.
        if active is None or prefix == "B" or entity != active["label"] or int(start) > int(active["end"]) + 1:
            if active is not None:
                spans.append(active)
            active = {"label": entity, "start": int(start), "end": int(end), "score": score}
        else:
            active["end"] = int(end)
            active["score"] = max(float(active["score"]), score)
    if active is not None:
        spans.append(active)

    for span in spans:
        span["text"] = text[span["start"]:span["end"]]
    return spans


def word_aligned_ppsn_spans_onnx(text, session, tokenizer, config, threshold=0.4):
    """Detect PPSN spans by scoring whole words rather than sub-word tokens.

    Scoring proceeds in three steps:
      1. *text* is pre-split with ``tokenize_with_spans`` and fed to the
         tokenizer as pre-split words.
      2. Each word's score is the highest probability, over its sub-word tokens
         and all PPSN label ids, or 0.0 if it has no tokens (e.g. truncated).
      3. Consecutive words scoring at least *threshold* are merged into one span.

    Returns:
        List of dicts with ``label`` ("PPSN"), ``start``, ``end``, ``score`` and ``text``.
    """
    pieces = tokenize_with_spans(text)
    if not pieces:
        return []
    words = [word for word, _, _ in pieces]
    encoded = tokenizer(words, is_split_into_words=True, return_tensors="np", truncation=True)
    word_ids = encoded.word_ids(batch_index=0)
    logits = _run(session, encoded)[0]
    probs = _softmax(logits)
    label_ids = ppsn_label_ids_from_config(config)

    # Single pass over tokens instead of rescanning every token for every word.
    word_scores = [0.0] * len(pieces)
    if label_ids:
        token_scores = probs[:, label_ids].max(axis=-1)
        for token_index, word_index in enumerate(word_ids):
            if word_index is None:  # special tokens
                continue
            token_score = float(token_scores[token_index])
            if token_score > word_scores[word_index]:
                word_scores[word_index] = token_score

    spans = []
    active = None
    for (_, start, end), score in zip(pieces, word_scores):
        if score >= threshold:
            if active is None:
                active = {"label": "PPSN", "start": start, "end": end, "score": score}
            else:
                active["end"] = end
                active["score"] = max(float(active["score"]), score)
        elif active is not None:
            spans.append(active)
            active = None
    if active is not None:
        spans.append(active)

    for span in spans:
        span["text"] = text[span["start"]:span["end"]]
    return spans