# metrics/bertscore.py
"""

BERTScore helpers: scorer init, single and batch computation.

Adds precision/recall alongside F1 (UI shows F1; CSV export includes P/R too).

"""
from bert_score import BERTScorer
from functools import lru_cache
from transformers import AutoTokenizer, AutoConfig
from utils.file_utils import extract_sections, has_sections
import pandas as pd


# manual layer mapping, used as a fallback when the HF config does not expose num_hidden_layers
_MANUAL_BERT_LAYERS = {
    "neuralmind/bert-base-portuguese-cased": 12,
    "pucpr/biobertpt-clin": 12,
    "xlm-roberta-large": 24,
    "medicalai/ClinicalBERT": 12,
}

# friendly label ↔ model id mapping
BERT_FRIENDLY_TO_MODEL = {
    "Portuguese (Br) Bert": "neuralmind/bert-base-portuguese-cased",
    "Portuguese (Br) Clinical BioBert": "pucpr/biobertpt-clin",
    "Multilingual Bert ( RoBerta)": "xlm-roberta-large",
    "ClinicalBERT (medicalai)": "medicalai/ClinicalBERT",
}
BERT_MODEL_TO_FRIENDLY = {v: k for k, v in BERT_FRIENDLY_TO_MODEL.items()}

_USE_RESCALE_BASELINE = False  # set True to rescale scores with bert-score's per-language baselines


def _safe_num_layers(model_type: str) -> int | None:
    # Try to read the layer count from the HF config; fall back to the manual table
    try:
        cfg = AutoConfig.from_pretrained(model_type)
        if hasattr(cfg, "num_hidden_layers") and isinstance(cfg.num_hidden_layers, int):
            return cfg.num_hidden_layers
    except Exception:
        pass
    return _MANUAL_BERT_LAYERS.get(model_type)


@lru_cache(maxsize=6)
def get_bertscore_scorer(model_type: str):
    lang = "pt" if any(model_type.startswith(p) for p in ("neuralmind", "pucpr")) else ""
    num_layers = _safe_num_layers(model_type)
    kwargs = {"lang": lang, "rescale_with_baseline": _USE_RESCALE_BASELINE}
    if num_layers is not None:
        kwargs["num_layers"] = num_layers
    return BERTScorer(model_type=model_type, **kwargs)
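
# Note: because of lru_cache, repeated calls with the same model_type reuse the
# same BERTScorer instance (and its loaded weights) instead of re-initialising it.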


def chunk_text_with_stride(text: str, tokenizer, max_len: int = 512, stride: int = 50):
    ids = tokenizer.encode(text, add_special_tokens=True)
    if len(ids) <= max_len:
        return [tokenizer.decode(ids, skip_special_tokens=True)]
    chunks, step = [], max_len - stride
    for i in range(0, len(ids), step):
        subset = ids[i : i + max_len]
        if not subset:
            break
        chunks.append(tokenizer.decode(subset, skip_special_tokens=True))
        if i + max_len >= len(ids):
            break
    return chunks
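
# Illustrative chunking behaviour (a sketch, assuming max_len=512 and stride=50,
# i.e. a step of 462 tokens): a 900-token document yields chunks over token
# ranges [0:512] and [462:900], so consecutive chunks overlap by `stride` tokens
# and the averaged per-chunk scores keep some context at each boundary.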


def bertscore_prec_rec_f1(reference: str, prediction: str, model_type: str):
    """
    Return (precision, recall, f1) for a single reference/prediction pair.

    Handles long texts by chunking and averaging the per-chunk scores.
    On error, returns (None, None, None).
    """
    if not reference or not prediction:
        return (None, None, None)
    try:
        scorer = get_bertscore_scorer(model_type)
        tokenizer = AutoTokenizer.from_pretrained(model_type, use_fast=True)

        gen_chunks = chunk_text_with_stride(prediction, tokenizer)
        ref_chunks = chunk_text_with_stride(reference, tokenizer)
        paired = list(zip(gen_chunks, ref_chunks))  # pair chunks index-wise; extra chunks on the longer side are ignored
        if not paired:
            return (0.0, 0.0, 0.0)

        p_vals, r_vals, f_vals = [], [], []
        for gc, rc in paired:
            P, R, F1 = scorer.score([gc], [rc])
            p_vals.append(float(P[0]))
            r_vals.append(float(R[0]))
            f_vals.append(float(F1[0]))

        n = float(len(p_vals))
        return (sum(p_vals) / n, sum(r_vals) / n, sum(f_vals) / n)
    except Exception:
        return (None, None, None)
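
# Example call (a sketch; the texts are made up, and actual scores depend on the model):
#   p, r, f1 = bertscore_prec_rec_f1(
#       reference="Paciente refere cefaleia ha 3 dias.",
#       prediction="Paciente relata dor de cabeca ha tres dias.",
#       model_type="neuralmind/bert-base-portuguese-cased",
#   )
#   # -> three floats (typically ~0.7-0.95 without baseline rescaling);
#   #    (None, None, None) if either text is empty or scoring fails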


def compute_bertscore_single(reference: str, prediction: str, model_type: str, per_section: bool = False):
    """

    If per_section=False: returns float global F1 (0..1) or None on error.

    If per_section=True: returns dict with keys:

        - bertscore_global_{p,r,f1}

        - bertscore_{S,O,A,P}_{p,r,f1} (when sections exist; else None)

    """
    if not reference or not prediction:
        return None if not per_section else {}

    try:
        scorer = get_bertscore_scorer(model_type)
        tokenizer = AutoTokenizer.from_pretrained(model_type, use_fast=True)

        def score_pair(pred_text, ref_text):
            if not pred_text or not ref_text:
                return None, None, None
            try:
                P, R, F1 = scorer.score([pred_text], [ref_text])
                return float(P[0]), float(R[0]), float(F1[0])
            except Exception:
                return None, None, None

        # global (average over chunk pairs)
        pred_chunks = chunk_text_with_stride(prediction, tokenizer)
        ref_chunks = chunk_text_with_stride(reference, tokenizer)
        paired = list(zip(pred_chunks, ref_chunks))
        ps, rs, f1s = [], [], []
        for pc, rc in paired:
            p, r, f1 = score_pair(pc, rc)
            if p is not None:
                ps.append(p)
            if r is not None:
                rs.append(r)
            if f1 is not None:
                f1s.append(f1)
        global_p = sum(ps) / len(ps) if ps else 0.0
        global_r = sum(rs) / len(rs) if rs else 0.0
        global_f1 = sum(f1s) / len(f1s) if f1s else 0.0

        if not per_section:
            return global_f1

        out = {
            "bertscore_global_p": global_p,
            "bertscore_global_r": global_r,
            "bertscore_global_f1": global_f1,
        }

        # per-section only if both texts have sections
        if has_sections(reference) and has_sections(prediction):
            sections_ref = extract_sections(reference)
            sections_pred = extract_sections(prediction)
            for tag in ["S", "O", "A", "P"]:
                pred_sec = sections_pred.get(tag, "")
                ref_sec = sections_ref.get(tag, "")
                if pred_sec and ref_sec:
                    ps, rs, f1s = [], [], []
                    pred_chunks = chunk_text_with_stride(pred_sec, tokenizer)
                    ref_chunks = chunk_text_with_stride(ref_sec, tokenizer)
                    for pc, rc in zip(pred_chunks, ref_chunks):
                        p, r, f1 = score_pair(pc, rc)
                        if p is not None:
                            ps.append(p)
                        if r is not None:
                            rs.append(r)
                        if f1 is not None:
                            f1s.append(f1)
                    out[f"bertscore_{tag}_p"] = sum(ps) / len(ps) if ps else 0.0
                    out[f"bertscore_{tag}_r"] = sum(rs) / len(rs) if rs else 0.0
                    out[f"bertscore_{tag}_f1"] = sum(f1s) / len(f1s) if f1s else 0.0
                else:
                    out[f"bertscore_{tag}_p"] = None
                    out[f"bertscore_{tag}_r"] = None
                    out[f"bertscore_{tag}_f1"] = None
        else:
            for tag in ["S", "O", "A", "P"]:
                out[f"bertscore_{tag}_p"] = None
                out[f"bertscore_{tag}_r"] = None
                out[f"bertscore_{tag}_f1"] = None

        return out
    except Exception:
        return None if not per_section else {}


def compute_batch_bertscore(df: pd.DataFrame, bert_models: list, per_section: bool = False) -> pd.DataFrame:
    """

    If per_section=True and single model:

        returns per-section + global BERTScore columns for {p,r,f1}.

    Otherwise:

        per-model global {p,r,f1} columns: bertscore_{modelshort}_{p,r,f1}

    """
    if not bert_models:
        return pd.DataFrame(index=df.index)

    preds = df["dsc_generated_clinical_report"].astype(str).tolist()
    refs = df["dsc_reference_free_text"].astype(str).tolist()

    add = {}
    single_model = len(bert_models) == 1

    for friendly in bert_models:
        model_id = BERT_FRIENDLY_TO_MODEL.get(friendly, friendly)
        short = model_id.split("/")[-1].replace("-", "_")

        if per_section and single_model:
            col_data = {
                "bertscore_global_p": [],
                "bertscore_global_r": [],
                "bertscore_global_f1": [],
                "bertscore_S_p": [], "bertscore_S_r": [], "bertscore_S_f1": [],
                "bertscore_O_p": [], "bertscore_O_r": [], "bertscore_O_f1": [],
                "bertscore_A_p": [], "bertscore_A_r": [], "bertscore_A_f1": [],
                "bertscore_P_p": [], "bertscore_P_r": [], "bertscore_P_f1": [],
            }
            for pred, ref in zip(preds, refs):
                scores = compute_bertscore_single(ref, pred, model_id, per_section=True)
                if not scores:
                    for k in col_data:
                        col_data[k].append(None)
                else:
                    for k in col_data:
                        col_data[k].append(scores.get(k))
            add.update(col_data)
        else:
            scorer = get_bertscore_scorer(model_id)
            tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
            p_list, r_list, f1_list = [], [], []
            for pred, ref in zip(preds, refs):
                try:
                    pred_chunks = chunk_text_with_stride(pred, tokenizer)
                    ref_chunks = chunk_text_with_stride(ref, tokenizer)
                    paired = list(zip(pred_chunks, ref_chunks))
                    if not paired:
                        p_list.append(None); r_list.append(None); f1_list.append(None)
                        continue
                    Ps, Rs, F1s = [], [], []
                    for pc, rc in paired:
                        P, R, F1 = scorer.score([pc], [rc])
                        Ps.append(float(P[0])); Rs.append(float(R[0])); F1s.append(float(F1[0]))
                    p_list.append(sum(Ps)/len(Ps) if Ps else None)
                    r_list.append(sum(Rs)/len(Rs) if Rs else None)
                    f1_list.append(sum(F1s)/len(F1s) if F1s else None)
                except Exception:
                    p_list.append(None); r_list.append(None); f1_list.append(None)
            add[f"bertscore_{short}_p"] = p_list
            add[f"bertscore_{short}_r"] = r_list
            add[f"bertscore_{short}_f1"] = f1_list

    return pd.DataFrame(add, index=df.index)
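

if __name__ == "__main__":
    # Minimal smoke-test sketch (not used by the app): assumes the HF models can be
    # downloaded, and uses made-up SOAP-style strings; whether they count as
    # "sectioned" depends on utils.file_utils.has_sections/extract_sections.
    _ref = "S: Paciente refere cefaleia. O: PA 120x80 mmHg. A: Cefaleia tensional. P: Dipirona 1g."
    _gen = "S: Dor de cabeca relatada. O: PA 120x80 mmHg. A: Cefaleia. P: Analgesico prescrito."
    _model_id = BERT_FRIENDLY_TO_MODEL["Portuguese (Br) Bert"]

    # Single pair, with per-section breakdown when sections are detected.
    print(compute_bertscore_single(_ref, _gen, _model_id, per_section=True))

    # Batch over a one-row DataFrame with the column names this module expects.
    _df = pd.DataFrame(
        {
            "dsc_reference_free_text": [_ref],
            "dsc_generated_clinical_report": [_gen],
        }
    )
    print(compute_batch_bertscore(_df, ["Portuguese (Br) Bert"], per_section=False))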