# Hugging Face repository commit 2a38b32 (verified), authored by temsa:
# "Publish v2-rc4 release with bundled ONNX q8 artifact and cleaned benchmarks"
#!/usr/bin/env python3
import json
import os
import tempfile
from pathlib import Path
os.environ.setdefault("TRANSFORMERS_NO_TF", "1")
os.environ.setdefault("TRANSFORMERS_NO_FLAX", "1")
os.environ.setdefault("TRANSFORMERS_NO_TORCHVISION", "1")
os.environ["USE_TF"] = "0"
os.environ["USE_FLAX"] = "0"
os.environ["USE_TORCH"] = "1"
import numpy as np
import regex as re
from huggingface_hub import HfApi, hf_hub_download
from transformers import AutoConfig, AutoTokenizer
# Word splitter used for word-aligned scoring: a run of ASCII letters/digits,
# or a single character that is neither a word character nor whitespace.
# NOTE(review): non-ASCII letters (e.g. the "á" in "Seán") are word characters,
# so they match neither alternative and are silently skipped by finditer —
# confirm this matches how the model's training data was pre-split.
TOKEN_RE = re.compile(r"[A-Za-z0-9]+|[^\w\s]", re.UNICODE)
# ONNX artifact paths, relative to the model directory / Hub repo root, tried
# in this order after any explicitly requested file.
DEFAULT_ONNX_FILES = ["onnx/model_quantized.onnx", "model_quantized.onnx"]
# Irish Eircode: a routing key (one letter from the allowed set plus two
# digits, or the special key "D6W"), an optional single whitespace character,
# then a four-character unique identifier; anchored at both ends and
# case-insensitive.
# NOTE(review): with the `regex` module, `\d` on str patterns may match
# non-ASCII decimal digits as well — confirm whether that is acceptable.
EIRCODE_RE = re.compile(r"^(?:[ACDEFHKNPRTVWXY]\d{2}|D6W)\s?[0-9ACDEFHKNPRTVWXY]{4}$", re.IGNORECASE)
def tokenize_with_spans(text: str):
    """Split ``text`` with ``TOKEN_RE`` into ``(token, start, end)`` triples.

    ``start``/``end`` are character offsets into ``text``, so
    ``text[start:end] == token`` for every triple.
    """
    return [(match.group(), *match.span()) for match in TOKEN_RE.finditer(text)]
def _from_pretrained_with_fallbacks(tokenizer_ref: str):
    """Load a fast tokenizer, retrying with progressively fewer options.

    First tries ``fix_mistral_regex=True``. On any exception it retries with
    ``fix_mistral_regex=False``. If that raises ``TypeError`` (presumably a
    ``transformers`` version that does not accept the keyword), it retries
    without it.
    """
    try:
        return AutoTokenizer.from_pretrained(tokenizer_ref, use_fast=True, fix_mistral_regex=True)
    except Exception:
        try:
            return AutoTokenizer.from_pretrained(tokenizer_ref, use_fast=True, fix_mistral_regex=False)
        except TypeError:
            return AutoTokenizer.from_pretrained(tokenizer_ref, use_fast=True)


def _load_tokenizer(tokenizer_ref: str):
    """Load a fast tokenizer from a local directory or a Hugging Face Hub id.

    Some local directories have a ``tokenizer_config.json`` containing a
    ``fix_mistral_regex`` key. For those, a fixed set of tokenizer files is
    copied into a temporary directory, the key is removed from the copied
    config, and the tokenizer is loaded from the copy. The temporary
    directory is deleted once loading returns or raises.

    Args:
        tokenizer_ref: Local directory path or Hub repo id.

    Returns:
        The tokenizer returned by ``AutoTokenizer.from_pretrained``.
    """
    tokenizer_path = Path(tokenizer_ref)
    tokenizer_cfg_path = tokenizer_path / "tokenizer_config.json"
    if not (tokenizer_path.exists() and tokenizer_cfg_path.exists()):
        return _from_pretrained_with_fallbacks(tokenizer_ref)

    data = json.loads(tokenizer_cfg_path.read_text(encoding="utf-8"))
    if "fix_mistral_regex" not in data:
        return _from_pretrained_with_fallbacks(tokenizer_ref)

    keep = {"tokenizer_config.json", "tokenizer.json", "special_tokens_map.json", "vocab.txt"}
    # The scratch copy is removed when the `with` block exits. This assumes
    # from_pretrained reads everything it needs while loading. Cleanup errors
    # are ignored so a failed delete never masks a successful load.
    with tempfile.TemporaryDirectory(prefix="openmed_onnx_tokenizer_", ignore_cleanup_errors=True) as tmp:
        tmpdir = Path(tmp)
        for child in tokenizer_path.iterdir():
            if child.is_file() and child.name in keep:
                (tmpdir / child.name).write_bytes(child.read_bytes())
        data.pop("fix_mistral_regex", None)
        (tmpdir / "tokenizer_config.json").write_text(
            json.dumps(data, ensure_ascii=False, indent=2) + "\n", encoding="utf-8"
        )
        return _from_pretrained_with_fallbacks(str(tmpdir))
def load_onnx_token_classifier(model_ref: str, onnx_file: str | None = None):
    """Load an ONNX token-classification model with its tokenizer and config.

    Args:
        model_ref: A local model directory, or a Hugging Face Hub model repo
            id when no such local path exists.
        onnx_file: Optional ONNX path (relative to the model directory or
            repo root). It is tried before ``DEFAULT_ONNX_FILES``.

    Returns:
        ``(session, tokenizer, config)``. ``session`` is an
        ``onnxruntime.InferenceSession`` using ``CPUExecutionProvider``.

    Raises:
        FileNotFoundError: If none of the candidate ONNX files exist locally,
            or none are listed in the Hub repo.
    """
    import onnxruntime as ort

    # Requested file first, then the defaults. Empty entries and repeats are
    # dropped while the priority order is preserved.
    candidates = list(dict.fromkeys(name for name in [onnx_file, *DEFAULT_ONNX_FILES] if name))
    tried = ", ".join(candidates)

    model_path = Path(model_ref)
    if model_path.exists():
        onnx_path = next(
            (model_path / name for name in candidates if (model_path / name).exists()),
            None,
        )
        if onnx_path is None:
            raise FileNotFoundError(f"Missing ONNX artifact in {model_path} (tried: {tried})")
    else:
        published = set(HfApi().list_repo_files(repo_id=model_ref, repo_type="model"))
        chosen = next((name for name in candidates if name in published), None)
        if chosen is None:
            raise FileNotFoundError(f"No ONNX artifact published in {model_ref} (tried: {tried})")
        onnx_path = Path(hf_hub_download(repo_id=model_ref, filename=chosen, repo_type="model"))

    config = AutoConfig.from_pretrained(model_ref)
    tokenizer = _load_tokenizer(model_ref)
    session = ort.InferenceSession(str(onnx_path), providers=["CPUExecutionProvider"])
    return session, tokenizer, config
def _softmax(logits):
shifted = logits - np.max(logits, axis=-1, keepdims=True)
exp = np.exp(shifted)
return exp / np.clip(np.sum(exp, axis=-1, keepdims=True), 1e-12, None)
def _run(session, encoded):
input_names = {item.name for item in session.get_inputs()}
feed = {}
for key, value in encoded.items():
if key == "offset_mapping":
continue
if key in input_names:
feed[key] = value
return session.run(None, feed)[0]
def normalize_label(label: str) -> str:
    """Return the upper-cased entity name of a tag, without its B-/I- prefix.

    ``None`` or empty input yields ``""``. Surrounding whitespace is removed
    before the prefix check, and only one leading ``B-``/``I-`` is stripped.
    """
    tag = (label or "").strip()
    entity = tag[2:] if tag.startswith(("B-", "I-")) else tag
    return entity.upper()
def looks_like_eircode(value: str) -> bool:
    """Return True if ``value``, ignoring surrounding whitespace, matches ``EIRCODE_RE``."""
    candidate = value.strip()
    return bool(EIRCODE_RE.match(candidate))
def ppsn_label_ids_from_config(config):
    """Return the sorted ids whose label (whitespace-stripped) ends with "PPSN".

    Every key of ``config.id2label`` is converted with ``int`` (so both str
    and int keys work). ``None`` labels are treated as ``""``.
    """
    pairs = [(int(key), str(name or "").strip()) for key, name in config.id2label.items()]
    return sorted(label_id for label_id, name in pairs if name.endswith("PPSN"))
def simple_aggregate_spans_onnx(text, session, tokenizer, config, min_score=0.5):
    """Group per-token argmax predictions into labelled character spans.

    Each token takes its argmax label from the model's softmax output. A
    token closes the current span if any of these hold:

    - it is masked out by ``attention_mask``;
    - its offset range is empty (``start == end``);
    - its label is ``"O"``;
    - its argmax probability is below ``min_score``.

    Otherwise it extends the current span only if all of these hold:

    - its raw label starts with ``"I-"``;
    - its normalized entity equals the span's label;
    - it starts at most one character after the span's end.

    Any other token starts a new span.

    Returns:
        A list of dicts with keys ``label``, ``start``, ``end``, ``score``
        (maximum token probability within the span) and ``text``.
    """
    encoded = tokenizer(text, return_offsets_mapping=True, return_tensors="np", truncation=True)
    probs = _softmax(_run(session, encoded)[0])
    best_ids = probs.argmax(axis=-1)
    labels_by_id = {int(key): value for key, value in config.id2label.items()}
    offsets = encoded["offset_mapping"][0].tolist()
    mask = encoded["attention_mask"][0].tolist()

    finished = []
    current = None
    for position, ((start, end), keep) in enumerate(zip(offsets, mask)):
        label = None
        score = 0.0
        usable = bool(keep) and start != end
        if usable:
            class_id = int(best_ids[position])
            label = labels_by_id[class_id]
            usable = label != "O"
            if usable:
                score = float(probs[position, class_id])
                # Written as `not <` so NaN scores behave exactly like a
                # passing score.
                usable = not score < min_score

        if not usable:
            if current is not None:
                finished.append(current)
                current = None
            continue

        is_inside = label.startswith("I-")
        entity = normalize_label(label)
        extends = (
            current is not None
            and is_inside
            and entity == current["label"]
            and int(start) <= int(current["end"]) + 1
        )
        if extends:
            current["end"] = int(end)
            current["score"] = max(float(current["score"]), score)
        else:
            if current is not None:
                finished.append(current)
            current = {"label": entity, "start": int(start), "end": int(end), "score": score}

    if current is not None:
        finished.append(current)
    for span in finished:
        span["text"] = text[span["start"]:span["end"]]
    return finished
def word_aligned_ppsn_spans_onnx(text, session, tokenizer, config, threshold=0.4):
    """Detect PPSN spans by scoring whole words instead of sub-word tokens.

    ``text`` is split with ``tokenize_with_spans`` and the words are passed to
    the tokenizer with ``is_split_into_words=True``. A word's score is the
    highest softmax probability of any PPSN label (from
    ``ppsn_label_ids_from_config``) over all of its sub-word tokens. A word
    with no tokens, e.g. one removed by truncation, scores 0.0. Consecutive
    words scoring at least ``threshold`` are merged into one span.

    Returns:
        A list of dicts with keys ``label`` (always ``"PPSN"``), ``start`` and
        ``end`` (character offsets into ``text``), ``score`` (maximum word
        score in the span) and ``text``. Empty if ``text`` has no tokens.
    """
    pieces = tokenize_with_spans(text)
    if not pieces:
        return []
    words = [word for word, _, _ in pieces]
    encoded = tokenizer(words, is_split_into_words=True, return_tensors="np", truncation=True)
    word_ids = encoded.word_ids(batch_index=0)
    logits = _run(session, encoded)[0]
    probs = _softmax(logits)
    label_ids = ppsn_label_ids_from_config(config)

    # One pass over the tokens instead of rescanning every token for each
    # word. For any given word, the updates happen in the same token-then-label
    # order as a per-word scan, so the resulting maxima are identical.
    word_scores = [0.0] * len(pieces)
    for token_index, wid in enumerate(word_ids):
        # None marks special tokens; out-of-range ids are ignored defensively.
        if wid is None or not 0 <= wid < len(word_scores):
            continue
        for label_id in label_ids:
            word_scores[wid] = max(word_scores[wid], float(probs[token_index, label_id]))

    spans = []
    active = None
    for (_, start, end), score in zip(pieces, word_scores):
        if score >= threshold:
            if active is None:
                active = {"label": "PPSN", "start": start, "end": end, "score": score}
            else:
                active["end"] = end
                active["score"] = max(float(active["score"]), score)
        elif active is not None:
            spans.append(active)
            active = None
    if active is not None:
        spans.append(active)
    for span in spans:
        span["text"] = text[span["start"]:span["end"]]
    return spans