# metrics/bertscore.py
"""
BERTScore helpers: scorer init, single and batch computation.
Adds precision/recall alongside F1 (UI shows F1; CSV export includes P/R too).
"""
from functools import lru_cache

import pandas as pd
from bert_score import BERTScorer
from transformers import AutoConfig, AutoTokenizer

from utils.file_utils import extract_sections, has_sections

# Manual layer mapping (fallback when config.num_hidden_layers cannot be read).
_MANUAL_BERT_LAYERS = {
    "neuralmind/bert-base-portuguese-cased": 12,
    "pucpr/biobertpt-clin": 12,
    "xlm-roberta-large": 24,
    "medicalai/ClinicalBERT": 12,
}

# friendly label ↔ model id mapping
BERT_FRIENDLY_TO_MODEL = {
    "Portuguese (Br) Bert": "neuralmind/bert-base-portuguese-cased",
    "Portuguese (Br) Clinical BioBert": "pucpr/biobertpt-clin",
    "Multilingual Bert ( RoBerta)": "xlm-roberta-large",
    "ClinicalBERT (medicalai)": "medicalai/ClinicalBERT",
}
BERT_MODEL_TO_FRIENDLY = {v: k for k, v in BERT_FRIENDLY_TO_MODEL.items()}

_USE_RESCALE_BASELINE = False

def _safe_num_layers(model_type: str) -> int | None:
    # Try to read from the HF config; fall back to the manual table.
    try:
        cfg = AutoConfig.from_pretrained(model_type)
        if hasattr(cfg, "num_hidden_layers") and isinstance(cfg.num_hidden_layers, int):
            return cfg.num_hidden_layers
    except Exception:
        pass
    return _MANUAL_BERT_LAYERS.get(model_type)

@lru_cache(maxsize=None)
def get_bertscore_scorer(model_type: str):
    # Cached so repeated single/batch calls reuse one scorer per checkpoint
    # instead of re-initialising the model on every call.
    lang = "pt" if any(model_type.startswith(p) for p in ("neuralmind", "pucpr")) else ""
    num_layers = _safe_num_layers(model_type)
    kwargs = {"lang": lang, "rescale_with_baseline": _USE_RESCALE_BASELINE}
    if num_layers is not None:
        kwargs["num_layers"] = num_layers
    return BERTScorer(model_type=model_type, **kwargs)
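
# Usage sketch (illustrative only; the model id is one of the checkpoints in
# BERT_FRIENDLY_TO_MODEL, and weights are downloaded on first use):
#
#   scorer = get_bertscore_scorer("neuralmind/bert-base-portuguese-cased")
#   P, R, F1 = scorer.score(["texto gerado"], ["texto de referência"])
#   print(float(F1[0]))
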
def chunk_text_with_stride(text: str, tokenizer, max_len: int = 512, stride: int = 50):
    """
    Split text into overlapping windows of at most max_len token ids
    (consecutive windows share `stride` tokens) and decode each back to text.
    """
    ids = tokenizer.encode(text, add_special_tokens=True)
    if len(ids) <= max_len:
        return [tokenizer.decode(ids, skip_special_tokens=True)]
    chunks, step = [], max_len - stride
    for i in range(0, len(ids), step):
        subset = ids[i : i + max_len]
        if not subset:
            break
        chunks.append(tokenizer.decode(subset, skip_special_tokens=True))
        if i + max_len >= len(ids):
            break
    return chunks
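
# Illustrative behaviour (hypothetical long input; the tokenizer id is one of
# the checkpoints mapped above):
#
#   tok = AutoTokenizer.from_pretrained("neuralmind/bert-base-portuguese-cased", use_fast=True)
#   parts = chunk_text_with_stride("dor abdominal difusa " * 400, tok, max_len=512, stride=50)
#   # short inputs come back as a single-element list; longer ones as several
#   # overlapping chunks (consecutive chunks share roughly `stride` tokens).
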
def bertscore_prec_rec_f1(reference: str, prediction: str, model_type: str):
    """
    Return (precision, recall, f1) for a single reference/prediction pair.
    Long texts are chunked; chunks are paired positionally and the per-chunk
    scores are averaged (extra chunks on the longer side are ignored).
    On error, returns (None, None, None).
    """
    if not reference or not prediction:
        return (None, None, None)
    try:
        scorer = get_bertscore_scorer(model_type)
        tokenizer = AutoTokenizer.from_pretrained(model_type, use_fast=True)
        gen_chunks = chunk_text_with_stride(prediction, tokenizer)
        ref_chunks = chunk_text_with_stride(reference, tokenizer)
        paired = list(zip(gen_chunks, ref_chunks))
        if not paired:
            return (0.0, 0.0, 0.0)
        p_vals, r_vals, f_vals = [], [], []
        for gc, rc in paired:
            P, R, F1 = scorer.score([gc], [rc])
            p_vals.append(float(P[0]))
            r_vals.append(float(R[0]))
            f_vals.append(float(F1[0]))
        n = float(len(p_vals))
        return (sum(p_vals) / n, sum(r_vals) / n, sum(f_vals) / n)
    except Exception:
        return (None, None, None)
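
# Usage sketch (illustrative texts; any of the model ids listed above works as
# model_type):
#
#   p, r, f1 = bertscore_prec_rec_f1(
#       reference="Paciente com dor torácica há dois dias.",
#       prediction="Paciente refere dor no peito há 2 dias.",
#       model_type="pucpr/biobertpt-clin",
#   )
#   # floats (raw, unrescaled BERTScore) on success, (None, None, None) on error
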
def compute_bertscore_single(reference: str, prediction: str, model_type: str, per_section: bool = False):
    """
    If per_section=False: returns the global F1 as a float (0..1), or None on error.
    If per_section=True: returns a dict (empty on error) with keys:
      - bertscore_global_{p,r,f1}
      - bertscore_{S,O,A,P}_{p,r,f1} (when both texts have sections; else None)
    """
    if not reference or not prediction:
        return None if not per_section else {}
    try:
        scorer = get_bertscore_scorer(model_type)
        tokenizer = AutoTokenizer.from_pretrained(model_type, use_fast=True)

        def score_pair(pred_text, ref_text):
            if not pred_text or not ref_text:
                return None, None, None
            try:
                P, R, F1 = scorer.score([pred_text], [ref_text])
                return float(P[0]), float(R[0]), float(F1[0])
            except Exception:
                return None, None, None

        # global (average over positionally paired chunks)
        pred_chunks = chunk_text_with_stride(prediction, tokenizer)
        ref_chunks = chunk_text_with_stride(reference, tokenizer)
        paired = list(zip(pred_chunks, ref_chunks))
        ps, rs, f1s = [], [], []
        for pc, rc in paired:
            p, r, f1 = score_pair(pc, rc)
            if p is not None:
                ps.append(p)
            if r is not None:
                rs.append(r)
            if f1 is not None:
                f1s.append(f1)
        global_p = sum(ps) / len(ps) if ps else 0.0
        global_r = sum(rs) / len(rs) if rs else 0.0
        global_f1 = sum(f1s) / len(f1s) if f1s else 0.0
        if not per_section:
            return global_f1
        out = {
            "bertscore_global_p": global_p,
            "bertscore_global_r": global_r,
            "bertscore_global_f1": global_f1,
        }
        # per-section only if both texts have sections
        if has_sections(reference) and has_sections(prediction):
            sections_ref = extract_sections(reference)
            sections_pred = extract_sections(prediction)
            for tag in ["S", "O", "A", "P"]:
                pred_sec = sections_pred.get(tag, "")
                ref_sec = sections_ref.get(tag, "")
                if pred_sec and ref_sec:
                    ps, rs, f1s = [], [], []
                    pred_chunks = chunk_text_with_stride(pred_sec, tokenizer)
                    ref_chunks = chunk_text_with_stride(ref_sec, tokenizer)
                    for pc, rc in zip(pred_chunks, ref_chunks):
                        p, r, f1 = score_pair(pc, rc)
                        if p is not None:
                            ps.append(p)
                        if r is not None:
                            rs.append(r)
                        if f1 is not None:
                            f1s.append(f1)
                    out[f"bertscore_{tag}_p"] = sum(ps) / len(ps) if ps else 0.0
                    out[f"bertscore_{tag}_r"] = sum(rs) / len(rs) if rs else 0.0
                    out[f"bertscore_{tag}_f1"] = sum(f1s) / len(f1s) if f1s else 0.0
                else:
                    out[f"bertscore_{tag}_p"] = None
                    out[f"bertscore_{tag}_r"] = None
                    out[f"bertscore_{tag}_f1"] = None
        else:
            for tag in ["S", "O", "A", "P"]:
                out[f"bertscore_{tag}_p"] = None
                out[f"bertscore_{tag}_r"] = None
                out[f"bertscore_{tag}_f1"] = None
        return out
    except Exception:
        return None if not per_section else {}
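
# Usage sketch for the per-section variant (illustrative ref_text/gen_text;
# keys follow the docstring above, with section tags S/O/A/P coming from
# utils.file_utils.extract_sections):
#
#   scores = compute_bertscore_single(ref_text, gen_text,
#                                     "neuralmind/bert-base-portuguese-cased",
#                                     per_section=True)
#   scores["bertscore_global_f1"]   # always present on success
#   scores.get("bertscore_S_f1")    # None when either text lacks that section
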
def compute_batch_bertscore(df: pd.DataFrame, bert_models: list, per_section: bool = False) -> pd.DataFrame:
    """
    Expects df columns 'dsc_generated_clinical_report' (prediction) and
    'dsc_reference_free_text' (reference). bert_models may contain friendly
    labels or raw model ids.

    If per_section=True and a single model is selected:
        returns per-section + global BERTScore columns for {p,r,f1}.
    Otherwise:
        per-model global {p,r,f1} columns: bertscore_{modelshort}_{p,r,f1}
    """
    if not bert_models:
        return pd.DataFrame(index=df.index)
    preds = df["dsc_generated_clinical_report"].astype(str).tolist()
    refs = df["dsc_reference_free_text"].astype(str).tolist()
    add = {}
    single_model = len(bert_models) == 1
    for friendly in bert_models:
        model_id = BERT_FRIENDLY_TO_MODEL.get(friendly, friendly)
        short = model_id.split("/")[-1].replace("-", "_")
        if per_section and single_model:
            col_data = {
                "bertscore_global_p": [],
                "bertscore_global_r": [],
                "bertscore_global_f1": [],
                "bertscore_S_p": [], "bertscore_S_r": [], "bertscore_S_f1": [],
                "bertscore_O_p": [], "bertscore_O_r": [], "bertscore_O_f1": [],
                "bertscore_A_p": [], "bertscore_A_r": [], "bertscore_A_f1": [],
                "bertscore_P_p": [], "bertscore_P_r": [], "bertscore_P_f1": [],
            }
            for pred, ref in zip(preds, refs):
                scores = compute_bertscore_single(ref, pred, model_id, per_section=True)
                if not scores:
                    for k in col_data:
                        col_data[k].append(None)
                else:
                    for k in col_data:
                        col_data[k].append(scores.get(k))
            add.update(col_data)
        else:
            scorer = get_bertscore_scorer(model_id)
            tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
            p_list, r_list, f1_list = [], [], []
            for pred, ref in zip(preds, refs):
                try:
                    pred_chunks = chunk_text_with_stride(pred, tokenizer)
                    ref_chunks = chunk_text_with_stride(ref, tokenizer)
                    paired = list(zip(pred_chunks, ref_chunks))
                    if not paired:
                        p_list.append(None); r_list.append(None); f1_list.append(None)
                        continue
                    Ps, Rs, F1s = [], [], []
                    for pc, rc in paired:
                        P, R, F1 = scorer.score([pc], [rc])
                        Ps.append(float(P[0])); Rs.append(float(R[0])); F1s.append(float(F1[0]))
                    p_list.append(sum(Ps) / len(Ps) if Ps else None)
                    r_list.append(sum(Rs) / len(Rs) if Rs else None)
                    f1_list.append(sum(F1s) / len(F1s) if F1s else None)
                except Exception:
                    p_list.append(None); r_list.append(None); f1_list.append(None)
            add[f"bertscore_{short}_p"] = p_list
            add[f"bertscore_{short}_r"] = r_list
            add[f"bertscore_{short}_f1"] = f1_list
    return pd.DataFrame(add, index=df.index)
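
# Batch usage sketch (illustrative two-row frame; column names match the ones
# this module expects):
#
#   df = pd.DataFrame({
#       "dsc_reference_free_text": ["Paciente com febre.", "Sem queixas."],
#       "dsc_generated_clinical_report": ["Paciente apresenta febre.", "Paciente sem queixas."],
#   })
#   scores = compute_batch_bertscore(df, ["Portuguese (Br) Bert"], per_section=False)
#   out = pd.concat([df, scores], axis=1)  # adds e.g. bertscore_bert_base_portuguese_cased_f1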