File size: 8,159 Bytes
8a63f11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cd22687
 
 
 
574ad54
8a63f11
 
 
 
 
 
1f56845
 
8a63f11
 
 
 
 
 
 
ed7fe38
 
 
 
 
 
 
8a63f11
ed7fe38
8a63f11
 
ed7fe38
 
 
 
 
 
 
 
 
 
 
8a63f11
 
ed7fe38
 
 
 
 
 
 
 
 
 
 
8a63f11
 
ed7fe38
 
 
 
 
 
 
 
 
 
 
8a63f11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ed7fe38
 
 
8a63f11
 
 
 
 
ed7fe38
8a63f11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
from __future__ import annotations

import random
from functools import lru_cache
from typing import Any

import pandas as pd

from fleurs_cache import load_fleurs_table
from language import ALL_LANGS, LANG_ISO2_TO_ISO3
from sentence_sampling import sample_multi_group_bundle
from sib200_cache import load_sib200_table
from tatoeba import load_tatoeba_table


# Families of ISO 639-1 language codes that are easy to confuse with one
# another (shared script and/or close linguistic relation).  Used by
# fetch_random_adverse_mix(): sentences are sampled from 2-3 languages of a
# single family so the resulting language-ID task is deliberately hard.
# Only languages also present in ALL_LANGS and in the cached pool are used;
# a family needs at least two available languages to be eligible.
ADVERSE_FAMILY_LANGS: dict[str, tuple[str, ...]] = {
    "NordicLatin": ("sv", "no", "da", "is", "fi"),
    "BalticLatin": ("lv", "lt", "et"),
    "IberianRomance": ("es", "pt", "ca", "gl", "oc"),
    "ItaloRomance": ("it", "ro", "mt", "rm"),
    "FrenchLatin": ("fr", "oc"),
    "GermanicLatin": ("de", "nl", "af", "lb"),
    "PeripheralLatin": ("eu", "ku", "uz", "so", "su", "jv"),
    # Real confusion clusters from validation.
    "SouthSlavicLatin": ("bs", "hr", "sr"),
    "Malayic": ("id", "ms"),
    "Nguni": ("xh", "zu"),
    "CyrillicSlavic": ("ru", "uk", "be", "bg", "mk"),
    "CyrillicCentralAsian": ("kk", "mn", "tt", "ky", "tg", "ba", "ce"),
    "Arabic": ("ar", "fa", "ps", "sd", "ug", "ur", "ckb"),
    "IndicDevanagari": ("hi", "mr", "ne"),
    "IndicBengali": ("bn", "as"),
    "IndicSouth": ("ta", "te", "gu", "kn", "ml", "pa", "or"),
    "CJK": ("zh", "ja"),
    "Hebrew": ("he", "yi"),
    "English": ("en", "sco"),
}


def _normalize_text_key(text: str) -> str:
    return " ".join(str(text or "").split()).casefold().strip()


def _column_or_default(frame: pd.DataFrame, column: str, default: Any) -> pd.Series:
    if column in frame.columns:
        return frame[column]
    return pd.Series([default] * len(frame), index=frame.index)


# Canonical column order shared by every standardized source table.
_STANDARD_COLUMNS: list[str] = [
    "source", "source_lang", "lang_iso2", "lang_iso3",
    "text", "split", "example_id", "topic", "label",
]

# Per-source layout: which raw column feeds each standardized field.
#   iso2_col   — column holding the ISO2 code; None means "copy source_lang".
#   split_col  — column holding the split name; None means always "".
#   id_col     — column coerced into the integer example_id.
#   has_topic_label — whether the source carries real topic/label columns.
_SOURCE_SPECS: dict[str, dict[str, Any]] = {
    "fleurs": {"iso2_col": "model_lang", "split_col": "split", "id_col": "id", "has_topic_label": False},
    "tatoeba": {"iso2_col": None, "split_col": None, "id_col": "id", "has_topic_label": False},
    "sib200": {"iso2_col": "lang_iso2", "split_col": "split", "id_col": "index_id", "has_topic_label": True},
}


def _clean_str_column(frame: pd.DataFrame, column: str) -> pd.Series:
    """Fetch *column* (default "") as stripped strings."""
    return _column_or_default(frame, column, "").astype(str).str.strip()


def _clean_int_column(frame: pd.DataFrame, column: str, default: int = -1) -> pd.Series:
    """Fetch *column* coerced to int; unparseable values become *default*."""
    raw = _column_or_default(frame, column, default)
    return pd.to_numeric(raw, errors="coerce").fillna(default).astype(int)


def _standardize_frame(frame: pd.DataFrame, *, source: str) -> pd.DataFrame:
    """Normalize a cached source table to the shared adverse-pool schema.

    Every output frame carries exactly ``_STANDARD_COLUMNS``.  A missing
    ``lang_iso3`` is back-filled from ``lang_iso2`` via ``LANG_ISO2_TO_ISO3``.
    An empty input returns an empty DataFrame regardless of *source*.

    Raises:
        RuntimeError: if *source* is not one of fleurs/tatoeba/sib200
            (and *frame* is non-empty).
    """
    if frame.empty:
        return pd.DataFrame()

    try:
        spec = _SOURCE_SPECS[source]
    except KeyError:
        raise RuntimeError(f"Unsupported source: {source}") from None

    standardized = frame.copy()
    standardized["source"] = source
    standardized["source_lang"] = _clean_str_column(standardized, "source_lang")

    iso2_col = spec["iso2_col"]
    if iso2_col is None:
        # Tatoeba stores the ISO2 code directly in source_lang.
        standardized["lang_iso2"] = standardized["source_lang"]
    else:
        standardized["lang_iso2"] = _clean_str_column(standardized, iso2_col)

    iso3 = _clean_str_column(standardized, "lang_iso3")
    standardized["lang_iso3"] = iso3.where(
        iso3.ne(""),
        standardized["lang_iso2"].map(lambda lang: LANG_ISO2_TO_ISO3.get(lang, "")),
    )

    split_col = spec["split_col"]
    standardized["split"] = "" if split_col is None else _clean_str_column(standardized, split_col)
    standardized["example_id"] = _clean_int_column(standardized, spec["id_col"])

    if spec["has_topic_label"]:
        standardized["topic"] = _clean_str_column(standardized, "topic")
        standardized["label"] = _clean_int_column(standardized, "label")
    else:
        standardized["topic"] = ""
        standardized["label"] = -1

    return standardized.loc[:, _STANDARD_COLUMNS].copy()


@lru_cache(maxsize=1)
def _load_adverse_pool() -> pd.DataFrame:
    """Build (once) the combined, deduplicated sentence pool for adverse mixes.

    Each cached source table is standardized and concatenated; rows with
    blank text or a language outside ``ALL_LANGS`` are dropped, and duplicate
    (lang, normalized text) pairs keep only their first occurrence.

    Raises:
        RuntimeError: when no source cache is present on disk.
    """
    loaders = {
        "fleurs": load_fleurs_table,
        "tatoeba": load_tatoeba_table,
        "sib200": load_sib200_table,
    }

    tables: list[pd.DataFrame] = []
    for name, loader in loaders.items():
        try:
            raw = loader()
        except FileNotFoundError:
            # Cache for this source is absent; fall back to the others.
            continue
        table = _standardize_frame(raw, source=name)
        if not table.empty:
            tables.append(table)

    if not tables:
        raise RuntimeError("No cached sources were available for adverse mixes.")

    pool = pd.concat(tables, ignore_index=True)
    pool = pool[pool["text"].astype(str).str.strip().ne("")]
    pool = pool[pool["lang_iso2"].isin(ALL_LANGS)]
    pool["text_key"] = pool["text"].astype(str).map(_normalize_text_key)
    pool = pool[pool["text_key"].ne("")]
    pool = pool.drop_duplicates(subset=["lang_iso2", "text_key"], keep="first")
    return pool.reset_index(drop=True)


def _family_candidates(frame: pd.DataFrame) -> dict[str, tuple[str, ...]]:
    """Map each adverse family to the subset of its languages found in *frame*.

    Families left with fewer than two available languages are omitted, since
    a confusion mix needs at least two candidates.
    """
    present = {str(code) for code in frame["lang_iso2"].dropna()}
    narrowed = {
        family: tuple(code for code in codes if code in present)
        for family, codes in ADVERSE_FAMILY_LANGS.items()
    }
    return {family: codes for family, codes in narrowed.items() if len(codes) >= 2}


def _bundle_row(row: pd.Series) -> dict[str, Any]:
    return {
        "text": str(row.get("text", "")).strip(),
        "raw_text": str(row.get("text", "")).strip(),
        "source": str(row.get("source", "adverse-mix")).strip(),
        "source_lang": str(row.get("source_lang", "")).strip(),
        "lang_iso2": str(row.get("lang_iso2", "")).strip(),
        "lang_iso3": str(row.get("lang_iso3", "")).strip(),
        "language": str(row.get("source_lang", row.get("lang_iso2", ""))).strip(),
        "split": str(row.get("split", "")).strip(),
        "example_id": int(row.get("example_id", -1)) if str(row.get("example_id", "-1")).strip().lstrip("-").isdigit() else -1,
        "topic": str(row.get("topic", "")).strip(),
        "label": int(row.get("label", -1)) if str(row.get("label", "-1")).strip().lstrip("-").isdigit() else -1,
    }


def fetch_random_adverse_mix() -> dict[str, Any]:
    """Sample one adversarial language-ID bundle from a random confusion family.

    Chooses a family with at least two pooled languages, then draws one
    sentence each from 2-3 of its languages.  The returned dict is the
    sampler's bundle plus ``source``/``family``/``family_langs`` metadata.

    Raises:
        RuntimeError: when no family has two or more available languages.
    """
    pool = _load_adverse_pool()
    candidates = _family_candidates(pool)
    if not candidates:
        raise RuntimeError("No adverse language families had at least two available languages.")

    family_name, family_langs = random.choice(list(candidates.items()))
    bundle = sample_multi_group_bundle(
        pool,
        group_column="lang_iso2",
        row_to_sentence=_bundle_row,
        min_groups=2,
        max_groups=min(3, len(family_langs)),
        min_sentences_per_group=1,
        max_sentences_per_group=1,
        allowed_groups=set(family_langs),
    )
    result = dict(bundle)
    result["source"] = "adverse-mix"
    result["family"] = family_name
    result["family_langs"] = list(family_langs)
    return result