# metrics/core.py
"""
Orchestrates batch computation of selected metrics for the upload-CSV tab.
Adds precision/recall columns for ROUGE-L and BERTScore.
"""
import pandas as pd

from .bertscore import compute_batch_bertscore
from .bleu import compute_bleu_single, full_bleu, section_bleu
from .bleurt import compute_bleurt_single, get_hf_bleurt
from .rouge import compute_rouge_single, get_hf_rouge, rougeL_prec_rec_f1, rougeL_score
from utils.file_utils import extract_sections, has_sections
def compute_all_metrics_batch(
    df: pd.DataFrame,
    selected_metrics: list | None = None,
    bert_models: list | None = None
) -> pd.DataFrame:
    """Compute the selected global metrics for each row of *df*.

    Parameters
    ----------
    df : pd.DataFrame
        Must contain ``dsc_reference_free_text`` and
        ``dsc_generated_clinical_report`` columns; rows missing either
        are dropped before scoring.
    selected_metrics : list | None
        Any of ``"BLEU"``, ``"BLEURT"``, ``"ROUGE"``, ``"BERTSCORE"``.
        Defaults to ``["BLEU"]`` when omitted.
    bert_models : list | None
        Model names forwarded to ``compute_batch_bertscore``; BERTScore
        is skipped when empty or None.

    Returns
    -------
    pd.DataFrame
        One row per scored transcript, keyed by ``code_audio_transcription``,
        with one column per computed metric. All metrics are global
        (whole-document); per-section S/O/A/P scoring is intentionally
        disabled.
    """
    if selected_metrics is None:
        selected_metrics = ["BLEU"]

    # Rows missing either text cannot be scored by any metric.
    df = df.dropna(
        subset=["dsc_reference_free_text", "dsc_generated_clinical_report"]
    ).copy()

    # Synthesize a row key when the upload lacks one.
    if "code_audio_transcription" not in df.columns:
        df["code_audio_transcription"] = list(range(len(df)))

    # Flag rows where BOTH texts are sectioned. Not used by the global
    # metrics below, but *df* (including this column) is handed to
    # compute_batch_bertscore — kept for downstream consumers.
    df["has_sections"] = df.apply(
        lambda r: has_sections(r["dsc_reference_free_text"])
        and has_sections(r["dsc_generated_clinical_report"]),
        axis=1
    )

    out_cols = ["code_audio_transcription"]

    # -------------------------
    # BLEU (global only)
    # -------------------------
    if "BLEU" in selected_metrics:
        df["bleu_global"] = df.apply(
            # full_bleu returns a 0-100 score; normalize to [0, 1].
            lambda r: full_bleu(
                r["dsc_generated_clinical_report"],
                r["dsc_reference_free_text"]
            ) / 100.0,
            axis=1
        )
        out_cols.append("bleu_global")

    # -------------------------
    # BLEURT (global only)
    # -------------------------
    if "BLEURT" in selected_metrics:
        bleurt = get_hf_bleurt()
        # "scores" is a plain list in row order, so positional assignment
        # to the column is correct regardless of df's index values.
        df["bleurt_global"] = bleurt.compute(
            predictions=df["dsc_generated_clinical_report"].tolist(),
            references=df["dsc_reference_free_text"].tolist()
        )["scores"]
        out_cols.append("bleurt_global")

    # -------------------------
    # ROUGE-L (global only, P/R/F1)
    # -------------------------
    if "ROUGE" in selected_metrics:
        rouge_cols = ["rougeL_global_p", "rougeL_global_r", "rougeL_global_f1"]
        # rougeL_prec_rec_f1 yields (precision, recall, f1); pd.Series
        # fans the triple out into the three columns.
        df[rouge_cols] = df.apply(
            lambda row: pd.Series(
                rougeL_prec_rec_f1(
                    row["dsc_generated_clinical_report"],
                    row["dsc_reference_free_text"]
                )
            ),
            axis=1
        )
        out_cols.extend(rouge_cols)

    # -------------------------
    # BERTScore (global only)
    # -------------------------
    # selected_metrics is guaranteed non-None by the guard above.
    if "BERTSCORE" in selected_metrics and bert_models:
        bert_df = compute_batch_bertscore(df, bert_models, per_section=False)
        for col in bert_df.columns:
            # NOTE(review): assumes bert_df shares df's index (or aligns
            # positionally) — confirm against compute_batch_bertscore.
            df[col] = bert_df[col]
            out_cols.append(col)

    # Guard against rounding drift pushing normalized BLEU outside [0, 1].
    for c in df.columns:
        if c.startswith("bleu_"):
            df[c] = df[c].clip(0.0, 1.0)

    return df[out_cols]