File size: 5,746 Bytes
4b112ae
 
 
2e6da1f
4b112ae
 
 
 
2e6da1f
4b112ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d906888
 
 
4b112ae
 
 
d906888
 
 
4b112ae
d906888
 
 
 
 
 
 
 
 
 
 
 
4b112ae
 
 
 
 
 
 
 
 
d906888
 
 
4b112ae
 
d906888
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4b112ae
 
 
 
 
 
d906888
 
 
4b112ae
d906888
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4b112ae
d906888
 
 
 
 
4b112ae
 
 
 
d906888
4b112ae
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
# metrics/core.py
"""

Orchestrates batch computation of selected metrics FOR UPLOAD CSV TAB

Now adds precision/recall columns for ROUGE-L and BERTScore.

"""
import pandas as pd
from .bleu import compute_bleu_single, section_bleu, full_bleu, compute_bleu_single
from .bleurt import get_hf_bleurt, compute_bleurt_single
from .rouge import get_hf_rouge, compute_rouge_single, rougeL_score, rougeL_prec_rec_f1
from .bertscore import compute_batch_bertscore
from utils.file_utils import extract_sections, has_sections


def compute_all_metrics_batch(

    df: pd.DataFrame,

    selected_metrics: list = None,

    bert_models: list | None = None

) -> pd.DataFrame:
    if selected_metrics is None:
        selected_metrics = ["BLEU"]

    df = df.dropna(
        subset=["dsc_reference_free_text", "dsc_generated_clinical_report"]
    ).copy()

    if "code_audio_transcription" not in df.columns:
        df["code_audio_transcription"] = list(range(len(df)))

    df["has_sections"] = df.apply(
        lambda r: has_sections(r["dsc_reference_free_text"])
                  and has_sections(r["dsc_generated_clinical_report"]),
        axis=1
    )

    # only_one_metric = len(selected_metrics) == 1
    # only_bertscore_alone = only_one_metric and selected_metrics == ["BERTSCORE"]

    out_cols = ["code_audio_transcription"]
    tags = ["S", "O", "A", "P"]

    # -------------------------
    # BLEU (GLOBAL ONLY)
    # -------------------------
    if "BLEU" in selected_metrics:
        # OLD per-section logic (now disabled):
        # if only_one_metric and "BLEU" in selected_metrics:
        #     for tag in tags:
        #         def _sec_bleu(row, tag=tag):
        #             gen = extract_sections(row["dsc_generated_clinical_report"])[tag]
        #             ref = extract_sections(row["dsc_reference_free_text"])[tag]
        #             if row["has_sections"] and gen and ref:
        #                 return section_bleu(gen, ref) / 100.0
        #             return None
        #         df[f"bleu_{tag}"] = df.apply(_sec_bleu, axis=1)
        #         out_cols.append(f"bleu_{tag}")

        df["bleu_global"] = df.apply(
            lambda r: full_bleu(
                r["dsc_generated_clinical_report"],
                r["dsc_reference_free_text"]
            ) / 100.0,
            axis=1
        )
        out_cols.append("bleu_global")

    # -------------------------
    # BLEURT (GLOBAL ONLY)
    # -------------------------
    if "BLEURT" in selected_metrics:
        bleurt = get_hf_bleurt()

        # OLD per-section logic (now disabled):
        # if only_one_metric and "BLEURT" in selected_metrics:
        #     for tag in tags:
        #         idxs, gens, refs = [], [], []
        #         for i, row in df.iterrows():
        #             gen = extract_sections(row["dsc_generated_clinical_report"])[tag]
        #             ref = extract_sections(row["dsc_reference_free_text"])[tag]
        #             if row["has_sections"] and gen and ref:
        #                 idxs.append(i); gens.append(gen); refs.append(ref)
        #         scores = (
        #             bleurt.compute(predictions=gens, references=refs)["scores"]
        #             if gens else []
        #         )
        #         col = [None] * len(df)
        #         for i, sc in zip(idxs, scores):
        #             col[i] = sc
        #         df[f"bleurt_{tag}"] = col
        #         out_cols.append(f"bleurt_{tag}")

        df["bleurt_global"] = bleurt.compute(
            predictions=df["dsc_generated_clinical_report"].tolist(),
            references=df["dsc_reference_free_text"].tolist()
        )["scores"]
        out_cols.append("bleurt_global")

    # -------------------------
    # ROUGE-L (GLOBAL ONLY, P/R/F1)
    # -------------------------
    if "ROUGE" in selected_metrics:
        # OLD per-section logic (now disabled):
        # if only_one_metric and "ROUGE" in selected_metrics:
        #     for tag in tags:
        #         df[f"rougeL_{tag}_f1"] = df.apply(
        #             lambda row: rougeL_score(
        #                 extract_sections(row["dsc_generated_clinical_report"])[tag],
        #                 extract_sections(row["dsc_reference_free_text"])[tag]
        #             ) if row["has_sections"] else None,
        #             axis=1
        #         )
        #         out_cols.append(f"rougeL_{tag}_f1")

        # Global with P/R/F1
        df[["rougeL_global_p", "rougeL_global_r", "rougeL_global_f1"]] = df.apply(
            lambda row: pd.Series(
                rougeL_prec_rec_f1(
                    row["dsc_generated_clinical_report"],
                    row["dsc_reference_free_text"]
                )
            ),
            axis=1
        )
        out_cols.extend(["rougeL_global_p", "rougeL_global_r", "rougeL_global_f1"])

    # -------------------------
    # BERTScore (GLOBAL ONLY)
    # -------------------------
    if "BERTSCORE" in (selected_metrics or []) and bert_models:
        # OLD per-section option (now disabled):
        # per_section_bertscore = only_bertscore_alone and bert_models and len(bert_models) == 1
        # bert_df = compute_batch_bertscore(df, bert_models, per_section=per_section_bertscore)

        bert_df = compute_batch_bertscore(df, bert_models, per_section=False)  # force global only
        for col in bert_df.columns:
            df[col] = bert_df[col]
            out_cols.append(col)

    # clip BLEU
    for c in df.columns:
        if c.startswith("bleu_"):
            df[c] = df[c].clip(0.0, 1.0)

    return df[out_cols]