# ui/csv_tab.py
"""
Builds the CSV-upload tab (batch metrics).
- Summary table: **only global scores** (no S/O/A/P). Labels are short (e.g., "BLEU", not "BLEU GLOBAL").
- Detailed table: shows only global F1 columns (colored) and, when available, dark badges for P/R.
- CSV export includes whatever columns the backend produced; UI renders only the globals.
- Upload "Status" is collapsed into the file input's label.
- Errors (missing CSV, columns not chosen, etc.) are displayed in the status textbox under "Run Evaluation".
"""
import os
import time
import tempfile

import gradio as gr
import pandas as pd

from metrics import compute_all_metrics_batch, BERT_FRIENDLY_TO_MODEL
from ui.widgets import MetricCheckboxGroup, BertCheckboxGroup
from utils.file_utils import smart_read_csv
from utils.colors_utils import get_metric_color
from utils.tokenizer_refgen import generate_diff_html
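
# The uploaded CSV is free-form; the user maps their own column names in the UI.
# A minimal illustrative input (hypothetical column names and values):
#
#   audio_id,reference_text,generated_text
#   101,"patient reports intermittent chest pain ...","Patient reports intermittent chest pain ..."
#   102,"no known drug allergies ...","No known drug allergies ..."
#
# After mapping, each row is standardized to the columns code_audio_transcription,
# dsc_reference_free_text and dsc_generated_clinical_report (see run_batch below).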


# ------------------- Summary HTML builder (GLOBAL ONLY) -------------------
def build_summary_html(result_df: pd.DataFrame, selected_metrics: list, bert_models: list | None = None) -> str:
    def safe_stats(col):
        if col not in result_df.columns:
            return None
        s = result_df[col].dropna()
        if s.empty:
            return None
        s = s.astype(float)
        avg, mn, mx = s.mean(), s.min(), s.max()

        def audio_id_for(v):
            subset = result_df[result_df[col].astype(float) == v]
            if not subset.empty and "code_audio_transcription" in subset.columns:
                aid = subset.iloc[0]["code_audio_transcription"]
                try:
                    return int(aid)
                except Exception:
                    return aid
            return ""

        return {"avg": avg, "min": mn, "min_id": audio_id_for(mn), "max": mx, "max_id": audio_id_for(mx)}

    rows = []
    # NOTE: We used to show per-section rows (S/O/A/P) when a single metric was selected.
    # That logic has been **removed**; we now present **only global** rows for all metrics.
    if "BLEU" in selected_metrics:
        s = safe_stats("bleu_global")
        if s:
            rows.append(("bleu_global", s))
    if "BLEURT" in selected_metrics:
        s = safe_stats("bleurt_global")
        if s:
            rows.append(("bleurt_global", s))
    if "ROUGE" in selected_metrics:
        s = safe_stats("rougeL_global_f1")
        if s:
            rows.append(("rougeL_global_f1", s))

    # BERTScore (global only)
    if "BERTSCORE" in selected_metrics and bert_models:
        # NOTE: Previously, if only BERTScore with one model was selected, we added per-section rows.
        # That behavior is **disabled**. We only show global columns:
        #   - bertscore_<short>_f1 (multi-model)
        #   - or bertscore_global_f1 (if that's what backend produced)
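        #
        # Illustrative naming (hypothetical model id): a friendly label that
        # BERT_FRIENDLY_TO_MODEL maps to "microsoft/deberta-xlarge-mnli" gives
        # short = "deberta_xlarge_mnli", hence the column "bertscore_deberta_xlarge_mnli_f1".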
        for friendly in bert_models:
            mid = BERT_FRIENDLY_TO_MODEL.get(friendly)
            if not mid:
                continue
            short = mid.split("/")[-1].replace("-", "_")
            col = f"bertscore_{short}_f1" if f"bertscore_{short}_f1" in result_df.columns else "bertscore_global_f1"
            s = safe_stats(col)
            if s:
                rows.append((col, s))

    if not rows:
        return "<div style='padding:8px;background:#1f1f1f;color:#eee;border-radius:6px;'>No summary available.</div>"

    # Build HTML table
    html = """
    <div style="margin-bottom:12px;overflow-x:auto;">
      <div style="font-weight:600;margin-bottom:4px;color:#f5f5f5;font-size:16px;">Summary Statistics</div>
      <table style="border-collapse:collapse;width:100%;font-family:system-ui,-apple-system,BlinkMacSystemFont,Segoe UI,Roboto,sans-serif;border-radius:8px;overflow:hidden;min-width:500px;">
        <thead><tr>
          <th style="padding:8px 12px;background:#2d3748;color:#fff;text-align:left;font-weight:600;">Metric</th>
          <th style="padding:8px 12px;background:#2d3748;color:#fff;text-align:center;font-weight:600;">Avg</th>
          <th style="padding:8px 12px;background:#2d3748;color:#fff;text-align:center;font-weight:600;">Min (ID)</th>
          <th style="padding:8px 12px;background:#2d3748;color:#fff;text-align:center;font-weight:600;">Max (ID)</th>
        </tr></thead><tbody>
    """
    for col, stat in rows:
        # Pretty names (drop "GLOBAL")
        if col == "bleu_global":
            name = "BLEU"
        elif col == "bleurt_global":
            name = "BLEURT"
        elif col == "rougeL_global_f1":
            name = "ROUGE-L"
        elif col.startswith("bertscore_"):
            if col == "bertscore_global_f1":
                name = "BERTSCORE"
            else:
                label = " ".join(col.split("_")[1:-1]).upper()
                name = f"BERTSCORE {label}" if label else "BERTSCORE"
        else:
            name = col.replace("_", " ").upper()

        avg = f"{stat['avg']:.4f}"
        mn = f"{stat['min']:.4f} ({stat['min_id']})" if stat['min_id'] != "" else f"{stat['min']:.4f}"
        mx = f"{stat['max']:.4f} ({stat['max_id']})" if stat['max_id'] != "" else f"{stat['max']:.4f}"

        # Color scale by metric family (F1)
        if col.startswith("bleu_"):
            ca, cm, cx = get_metric_color(stat['avg'], "BLEU"), get_metric_color(stat['min'], "BLEU"), get_metric_color(stat['max'], "BLEU")
        elif col.startswith("bleurt_"):
            ca, cm, cx = get_metric_color(stat['avg'], "BLEURT"), get_metric_color(stat['min'], "BLEURT"), get_metric_color(stat['max'], "BLEURT")
        elif col.startswith("rougeL_"):
            ca, cm, cx = get_metric_color(stat['avg'], "ROUGE"), get_metric_color(stat['min'], "ROUGE"), get_metric_color(stat['max'], "ROUGE")
        else:
            ca, cm, cx = get_metric_color(stat['avg'], "BERTSCORE"), get_metric_color(stat['min'], "BERTSCORE"), get_metric_color(stat['max'], "BERTSCORE")

        html += f"""
        <tr style="background:#0f1218;">
          <td style="padding:8px 12px;border:1px solid #2f3240;color:#fff;white-space:nowrap;">{name}</td>
          <td style="padding:8px 12px;border:1px solid #2f3240;background:{ca};color:#fff;text-align:center;white-space:nowrap;">{avg}</td>
          <td style="padding:8px 12px;border:1px solid #2f3240;background:{cm};color:#fff;text-align:center;white-space:nowrap;">{mn}</td>
          <td style="padding:8px 12px;border:1px solid #2f3240;background:{cx};color:#fff;text-align:center;white-space:nowrap;">{mx}</td>
        </tr>
        """
    html += "</tbody></table></div>"
    return html


# ------------------- Detailed table (GLOBAL ONLY, F1 colored + dark P/R badges) -------------------
def render_results_table_html(result_df: pd.DataFrame) -> str:
    if result_df is None or result_df.empty:
        return "<div style='padding:8px;background:#1f1f1f;color:#eee;border-radius:6px;'>No results.</div>"

    # Keep only *global* F1 columns (skip *_p/_r and any S/O/A/P)
    def is_global_f1(col: str) -> bool:
        if col == "code_audio_transcription":
            return False
        if col.endswith("_p") or col.endswith("_r"):
            return False
        if col.startswith("bleu_"):
            return col == "bleu_global"
        if col.startswith("bleurt_"):
            return col == "bleurt_global"
        if col.startswith("rougeL_"):
            return col == "rougeL_global_f1"
        if col.startswith("bertscore_"):
            parts = col.split("_")
            # Exclude per-section: bertscore_S_f1, etc.
            if len(parts) >= 2 and parts[1] in {"S", "O", "A", "P"}:
                return False
            # Allow model-specific or "bertscore_global_f1"
            return parts[-1] == "f1" or col == "bertscore_global_f1"
        return False
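
    # Quick sanity examples of the filter above:
    #   is_global_f1("bleu_global")       -> True
    #   is_global_f1("rougeL_global_f1")  -> True
    #   is_global_f1("rougeL_global_p")   -> False  (precision is shown as a badge instead)
    #   is_global_f1("bertscore_S_f1")    -> False  (per-section column)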
    f1_cols = [c for c in result_df.columns if is_global_f1(c)]

    # Sort for readability: BLEU, BLEURT, ROUGE-L, BERTSCORE (...)
    def _grp_key(col):
        if col.startswith("bleu_"):
            g = 0
        elif col.startswith("bleurt_"):
            g = 1
        elif col.startswith("rougeL_"):
            g = 2
        elif col.startswith("bertscore_"):
            g = 3
        else:
            g = 9
        return (g, col)

    f1_cols = sorted(f1_cols, key=_grp_key)

    # HTML table
    html = [
        "<div style='overflow-x:auto;'>",
        "<div style='font-weight:600;margin:8px 0;color:#f5f5f5;font-size:16px;'>Individual Results</div>",
        "<table style='border-collapse:collapse;width:100%;font-family:system-ui,-apple-system,BlinkMacSystemFont,Segoe UI,Roboto,sans-serif;border-radius:8px;overflow:hidden;'>",
        "<thead><tr>",
        "<th style='padding:8px 12px;background:#2d3748;color:#fff;text-align:left;font-weight:600;white-space:nowrap;'>ID</th>",
    ]

    def pretty_header(col: str) -> str:
        if col == "bleu_global":
            return "BLEU"
        if col == "bleurt_global":
            return "BLEURT"
        if col == "rougeL_global_f1":
            return "ROUGE-L"
        if col.startswith("bertscore_"):
            if col == "bertscore_global_f1":
                return "BERTSCORE"
            label = " ".join(col.split("_")[1:-1]).upper()
            return f"BERTSCORE {label}" if label else "BERTSCORE"
        return col.replace("_", " ").upper()

    for col in f1_cols:
        html.append(
            f"<th style='padding:8px 12px;background:#2d3748;color:#fff;text-align:center;font-weight:600;white-space:nowrap;'>{pretty_header(col)}</th>"
        )
    html.append("</tr></thead><tbody>")

    for _, row in result_df.iterrows():
        rid = row.get("code_audio_transcription", "")
        try:
            rid = int(rid)
        except Exception:
            pass
        html.append("<tr style='background:#0f1218;'>")
        html.append(f"<td style='padding:8px 12px;border:1px solid #2f3240;color:#fff;white-space:nowrap;'>{rid}</td>")
        for col in f1_cols:
            val = row.get(col, None)
            # figure metric family & pick P/R columns accordingly
            metric_kind = "BERTSCORE"
            p_text = r_text = ""
            if col.startswith("bleu_"):
                metric_kind = "BLEU"
                # BLEU: no P/R
            elif col.startswith("bleurt_"):
                metric_kind = "BLEURT"
            elif col.startswith("rougeL_"):
                metric_kind = "ROUGE"
                base = "rougeL_global"  # global root
                pcol, rcol = f"{base}_p", f"{base}_r"
                p = row.get(pcol, None)
                r = row.get(rcol, None)
                p_text = f"P: {p:.4f}" if isinstance(p, (int, float)) and pd.notna(p) else ""
                r_text = f"R: {r:.4f}" if isinstance(r, (int, float)) and pd.notna(r) else ""
            elif col.startswith("bertscore_"):
                metric_kind = "BERTSCORE"
                # try model-specific first
                base = col[:-3] if col.endswith("_f1") else col  # strip trailing _f1
                pcol, rcol = f"{base}_p", f"{base}_r"
                if pcol not in result_df.columns and rcol not in result_df.columns:
                    # fallback to "bertscore_global" naming
                    pcol, rcol = "bertscore_global_p", "bertscore_global_r"
                p = row.get(pcol, None)
                r = row.get(rcol, None)
                p_text = f"P: {p:.4f}" if isinstance(p, (int, float)) and pd.notna(p) else ""
                r_text = f"R: {r:.4f}" if isinstance(r, (int, float)) and pd.notna(r) else ""
            if isinstance(val, (int, float)) and pd.notna(val):  # NaN counts as missing
                bg = get_metric_color(float(val), metric_kind)
                val_text = f"{float(val):.4f}"
            else:
                bg = "transparent"
                val_text = "—"
            # Dark badges for P/R
            pills = []
            if p_text:
                pills.append(
                    "<span style='padding:1px 6px;border-radius:999px;background:rgba(0,0,0,.48);color:#fff;display:inline-block;'>"
                    f"{p_text}</span>"
                )
            if r_text:
                pills.append(
                    "<span style='padding:1px 6px;border-radius:999px;background:rgba(0,0,0,.48);color:#fff;display:inline-block;margin-left:6px;'>"
                    f"{r_text}</span>"
                )
            badges = ""
            if pills:
                badges = "<div style='font-size:12px;margin-top:4px;line-height:1.2;'>" + "".join(pills) + "</div>"
            html.append(
                f"<td style='padding:8px 12px;border:1px solid #2f3240;background:{bg};color:#fff;text-align:center;white-space:nowrap;'>"
                f"{val_text}{badges}</td>"
            )
        html.append("</tr>")

    html.append("</tbody></table></div>")
    return "".join(html)


# ------------------- Tab builder -------------------
def build_csv_tab():
    with gr.Blocks() as tab:
        state_df = gr.State()      # original uploaded DataFrame
        state_pairs = gr.State()   # standardized pairs: id + reference + generated
        state_result = gr.State()  # metrics result DataFrame for export

        gr.Markdown("# RUN AN EXPERIMENT VIA CSV UPLOAD")
        gr.Markdown(
            "Upload a CSV of reference/generated text pairs, map the columns, pick metrics, and run a batch evaluation. \n "
            "F1 is highlighted in color; Precision/Recall appear as small dark badges."
        )
        gr.Markdown("## Experiment Configuration")

        # 1) Upload CSV (status collapsed into the label)
        gr.Markdown("### Upload CSV")
        gr.Markdown("Provide a CSV file containing your data. It should include columns for the reference text, the generated text, and an identifier (e.g., audio ID).")
        with gr.Row():
            file_input = gr.File(label="Upload CSV", file_types=[".csv"])

        # 2) Map Columns
        gr.Markdown("### Map Columns")
        gr.Markdown("Select which columns in your CSV correspond to the reference text, generated text, and audio/example ID.")
        with gr.Row(visible=False) as mapping:
            ref_col = gr.Dropdown(label="Reference Column", choices=[])
            gen_col = gr.Dropdown(label="Generated Column", choices=[])
            id_col = gr.Dropdown(label="Audio ID Column", choices=[])

        # 3) Select Metrics
        gr.Markdown("### Select Metrics")
        metric_selector = MetricCheckboxGroup()
        bert_model_selector = BertCheckboxGroup()

        # ---------- Divider before RESULTS ----------
        gr.HTML("""<div style="height:1px;margin:22px 0;background:
            linear-gradient(90deg, rgba(0,0,0,0) 0%, #4a5568 35%, #4a5568 65%, rgba(0,0,0,0) 100%);"></div>""")
        gr.Markdown("# RESULTS")

        # Emphasize the run button
        gr.HTML("""
        <style>
        #run-eval-btn button {
            background: linear-gradient(135deg, #3b82f6 0%, #2563eb 100%) !important;
            color: #fff !important;
            border: none !important;
            box-shadow: 0 6px 16px rgba(0,0,0,.25);
        }
        #run-eval-btn button:hover { filter: brightness(1.08); transform: translateY(-1px); }
        </style>
        """)

        # 4) Run Evaluation (+ Export control)
        with gr.Row():
            run_btn = gr.Button("🚀 Run Evaluation", variant="primary", elem_id="run-eval-btn")
            download_btn = gr.DownloadButton(label="⬇️ Export full results (CSV)", visible=False)
        # This Text box will display both success and error messages
        output_status = gr.Text()
        summary_output = gr.HTML()
        table_output = gr.HTML()

        # 5) Inspect example
        gr.Markdown("### Inspect an Example")
        gr.Markdown("Pick an example by its ID to view the reference vs generated text with token-level differences highlighted.")
        with gr.Accordion("🔍 Show reference & generated text", open=False):
            pick_id = gr.Dropdown(label="Pick an Audio ID", choices=[])
            ref_disp = gr.Textbox(label="Reference Text", lines=6, interactive=False)
            gen_disp = gr.Textbox(label="Generated Text", lines=6, interactive=False)
            diff_disp = gr.HTML()

        # ---- Handlers ----
        def handle_upload(f):
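            # Returns, in order: (state_df, ref_col update, gen_col update, id_col update,
            # mapping-row visibility update, file_input label update), matching the outputs
            # wired to file_input.change below.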
            if not f:
                # reset label & hide mapping
                return (
                    None,
                    gr.update(choices=[]), gr.update(choices=[]), gr.update(choices=[]),
                    gr.update(visible=False),
                    gr.update(label="Upload CSV")
                )
            df = smart_read_csv(f.name)
            cols = list(df.columns)
            return (
                df,
                gr.update(choices=cols, value=None),
                gr.update(choices=cols, value=None),
                gr.update(choices=cols, value=None),
                gr.update(visible=True),
                gr.update(label="Upload CSV — OK: now select the columns.")
            )

        def run_batch(df, r, g, i, mets, berts):
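            # Returns, in order: (status message, summary HTML, detailed-table HTML,
            # pick_id dropdown update, result DataFrame, standardized pairs DataFrame,
            # download-button visibility update), matching the outputs of run_btn.click below.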
            # Pre-flight validation: CSV uploaded?
            if df is None:
                return (
                    "Error: please upload a CSV and select the columns.",
                    "", "", gr.update(choices=[]), None, None, gr.update(visible=False)
                )
            # Columns chosen?
            if not r or not g or not i:
                return (
                    "Error: select the Reference, Generated and Audio ID columns.",
                    "", "", gr.update(choices=[]), None, None, gr.update(visible=False)
                )
            # Columns exist?
            missing = [c for c in [i, r, g] if c not in df.columns]
            if missing:
                return (
                    f"Error: these columns do not exist in the CSV: {missing}",
                    "", "", gr.update(choices=[]), None, None, gr.update(visible=False)
                )
            # Metrics chosen?
            if not mets:
                return (
                    "Error: select at least one metric.",
                    "", "", gr.update(choices=[]), None, None, gr.update(visible=False)
                )

            # Rename into standard schema (this is what we'll use for "Inspect an Example")
            try:
                sub = df[[i, r, g]].rename(
                    columns={i: "code_audio_transcription", r: "dsc_reference_free_text", g: "dsc_generated_clinical_report"}
                )
            except Exception as e:
                return (
                    f"Error preparing data: {e}",
                    "", "", gr.update(choices=[]), None, None, gr.update(visible=False)
                )

            # Compute metrics
            try:
                result = compute_all_metrics_batch(
                    sub,
                    mets,
                    berts if "BERTSCORE" in (mets or []) else None
                )
            except Exception as e:
                return (
                    f"Error computing metrics: {e}",
                    "", "", gr.update(choices=[]), None, None, gr.update(visible=False)
                )

            # Normalize IDs for dropdown
            try:
                raw_ids = result["code_audio_transcription"].dropna().unique().tolist()
                ids = []
                for x in raw_ids:
                    try:
                        ids.append(int(x))
                    except Exception:
                        ids.append(x)
                ids = sorted(ids, key=lambda z: (not isinstance(z, int), z))
            except Exception:
                ids = []

            # Build HTML views
            try:
                summary = build_summary_html(result, mets, berts if "BERTSCORE" in (mets or []) else None)
                table = render_results_table_html(result)
            except Exception as e:
                return (
                    f"Error rendering results: {e}",
                    "", "", gr.update(choices=ids, value=None), None, None, gr.update(visible=False)
                )

            # Keep results for export & show download button
            # Also keep standardized pairs (sub) for the "Inspect an Example" view
            return (
                "Metrics computed successfully.",
                summary,
                table,
                gr.update(choices=ids, value=None),
                result,
                sub,
                gr.update(visible=True),
            )
        def show_example(pairs_df, audio_id):
            # Use the standardized pairs dataframe (id + reference + generated)
            if pairs_df is None or audio_id is None:
                return "", "", ""
            try:
                row = pairs_df[pairs_df["code_audio_transcription"] == audio_id]
                if row.empty:
                    # Fallback for IDs whose dtype differs between the dropdown (int)
                    # and the uploaded CSV (e.g., string IDs): compare as strings.
                    try:
                        mask = pairs_df["code_audio_transcription"].astype(str) == str(audio_id)
                        row = pairs_df[mask]
                    except Exception:
                        return "", "", ""
                if row.empty:
                    return "", "", ""
                row = row.iloc[0]
                ref_txt = row["dsc_reference_free_text"]
                gen_txt = row["dsc_generated_clinical_report"]
                return ref_txt, gen_txt, generate_diff_html(ref_txt, gen_txt)
            except Exception:
                return "", "", ""
        def _export_results_csv(df: pd.DataFrame | None) -> str:
            # Always export with comma separator; include ALL columns that were computed
            if df is None or df.empty:
                tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".csv")
                with open(tmp.name, "w", encoding="utf-8") as f:
                    f.write("no_data\n")
                return tmp.name
            ts = time.strftime("%Y%m%d_%H%M%S")
            tmp_path = os.path.join(tempfile.gettempdir(), f"automatic_metrics_{ts}.csv")
            df.to_csv(tmp_path, sep=",", index=False)
            return tmp_path

        # ---- Wiring ----
        file_input.change(
            fn=handle_upload,
            inputs=[file_input],
            outputs=[state_df, ref_col, gen_col, id_col, mapping, file_input],  # update label in place
        )
        metric_selector.change(
            fn=lambda ms: gr.update(visible="BERTSCORE" in ms),
            inputs=[metric_selector],
            outputs=[bert_model_selector],
        )
        run_btn.click(
            fn=run_batch,
            inputs=[state_df, ref_col, gen_col, id_col, metric_selector, bert_model_selector],
            outputs=[output_status, summary_output, table_output, pick_id, state_result, state_pairs, download_btn],
        )
        # Use standardized pairs DF for example view (fixes KeyError on original DF)
        pick_id.change(
            fn=show_example,
            inputs=[state_pairs, pick_id],
            outputs=[ref_disp, gen_disp, diff_disp],
        )
        download_btn.click(
            fn=_export_results_csv,
            inputs=[state_result],
            outputs=download_btn,  # path returned; Gradio serves it
        )

    return tab
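

# Optional standalone smoke test: a convenience sketch only. It assumes the imports at the
# top of this module resolve when the file is run directly (e.g., from the repo root with
# the Space's dependencies installed); the real app mounts this tab elsewhere.
if __name__ == "__main__":
    demo = build_csv_tab()
    demo.launch()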