from __future__ import annotations
import random
from functools import lru_cache
from typing import Any
import pandas as pd
from fleurs_cache import load_fleurs_table
from language import ALL_LANGS, LANG_ISO2_TO_ISO3
from sentence_sampling import sample_multi_group_bundle
from sib200_cache import load_sib200_table
from tatoeba import load_tatoeba_table
# Clusters of languages that language-ID models tend to confuse with each
# other, keyed by an informal family/script label. `_family_candidates`
# intersects each tuple with the languages actually present in the pooled
# data and keeps only families with at least two survivors.
ADVERSE_FAMILY_LANGS: dict[str, tuple[str, ...]] = {
    "NordicLatin": ("sv", "no", "da", "is", "fi"),
    "BalticLatin": ("lv", "lt", "et"),
    "IberianRomance": ("es", "pt", "ca", "gl", "oc"),
    "ItaloRomance": ("it", "ro", "mt", "rm"),
    "FrenchLatin": ("fr", "oc"),
    "GermanicLatin": ("de", "nl", "af", "lb"),
    "PeripheralLatin": ("eu", "ku", "uz", "so", "su", "jv"),
    # Real confusion clusters from validation.
    "SouthSlavicLatin": ("bs", "hr", "sr"),
    "Malayic": ("id", "ms"),
    "Nguni": ("xh", "zu"),
    "CyrillicSlavic": ("ru", "uk", "be", "bg", "mk"),
    "CyrillicCentralAsian": ("kk", "mn", "tt", "ky", "tg", "ba", "ce"),
    "Arabic": ("ar", "fa", "ps", "sd", "ug", "ur", "ckb"),
    "IndicDevanagari": ("hi", "mr", "ne"),
    "IndicBengali": ("bn", "as"),
    "IndicSouth": ("ta", "te", "gu", "kn", "ml", "pa", "or"),
    "CJK": ("zh", "ja"),
    "Hebrew": ("he", "yi"),
    "English": ("en", "sco"),
}
def _normalize_text_key(text: str) -> str:
return " ".join(str(text or "").split()).casefold().strip()
def _column_or_default(frame: pd.DataFrame, column: str, default: Any) -> pd.Series:
if column in frame.columns:
return frame[column]
return pd.Series([default] * len(frame), index=frame.index)
# Columns every standardized frame exposes, in fixed order.
_OUTPUT_COLUMNS = [
    "source", "source_lang", "lang_iso2", "lang_iso3",
    "text", "split", "example_id", "topic", "label",
]

# Per-source column mapping. ``iso2_column`` is the raw column holding the
# ISO-639-1 code (``None`` means "mirror the cleaned source_lang", as the
# tatoeba cache stores only one language column); ``id_column`` holds the
# numeric example id; the ``has_*`` flags say whether split/topic/label are
# real columns in that cache or should be filled with neutral defaults.
_SOURCE_SPECS: dict[str, dict[str, Any]] = {
    "fleurs": {"iso2_column": "model_lang", "id_column": "id",
               "has_split": True, "has_topic": False, "has_label": False},
    "tatoeba": {"iso2_column": None, "id_column": "id",
                "has_split": False, "has_topic": False, "has_label": False},
    "sib200": {"iso2_column": "lang_iso2", "id_column": "index_id",
               "has_split": True, "has_topic": True, "has_label": True},
}


def _cleaned_str_col(frame: pd.DataFrame, column: str) -> pd.Series:
    # Shared cleanup: stringify and trim, defaulting missing columns to "".
    return _column_or_default(frame, column, "").astype(str).str.strip()


def _standardize_frame(frame: pd.DataFrame, *, source: str) -> pd.DataFrame:
    """Normalize a cached source table to the shared adverse-pool schema.

    The three caches (fleurs/tatoeba/sib200) differ only in which raw columns
    feed each output column; those differences live in ``_SOURCE_SPECS`` and
    one shared pipeline does the work.

    Raises:
        RuntimeError: if *source* is not a known cache name (only checked for
            non-empty frames, matching the original short-circuit on empty).
    """
    if frame.empty:
        return pd.DataFrame()
    spec = _SOURCE_SPECS.get(source)
    if spec is None:
        raise RuntimeError(f"Unsupported source: {source}")
    standardized = frame.copy()
    standardized["source"] = source
    standardized["source_lang"] = _cleaned_str_col(standardized, "source_lang")
    iso2_column = spec["iso2_column"]
    standardized["lang_iso2"] = (
        standardized["source_lang"] if iso2_column is None
        else _cleaned_str_col(standardized, iso2_column)
    )
    # Fill missing/blank ISO3 codes from the ISO2 -> ISO3 lookup.
    iso3 = _cleaned_str_col(standardized, "lang_iso3")
    standardized["lang_iso3"] = iso3.where(
        iso3.ne(""),
        standardized["lang_iso2"].map(lambda lang: LANG_ISO2_TO_ISO3.get(lang, "")),
    )
    standardized["split"] = _cleaned_str_col(standardized, "split") if spec["has_split"] else ""
    standardized["example_id"] = (
        pd.to_numeric(_column_or_default(standardized, spec["id_column"], -1), errors="coerce")
        .fillna(-1)
        .astype(int)
    )
    standardized["topic"] = _cleaned_str_col(standardized, "topic") if spec["has_topic"] else ""
    standardized["label"] = (
        pd.to_numeric(_column_or_default(standardized, "label", -1), errors="coerce")
        .fillna(-1)
        .astype(int)
        if spec["has_label"]
        else -1
    )
    return standardized.loc[:, _OUTPUT_COLUMNS].copy()
@lru_cache(maxsize=1)
def _load_adverse_pool() -> pd.DataFrame:
    """Build (once) the combined, deduplicated sentence pool from all caches.

    Sources whose cache file is missing are skipped best-effort; if none are
    available a RuntimeError is raised.
    """
    tables: list[pd.DataFrame] = []
    for name, load in (
        ("fleurs", load_fleurs_table),
        ("tatoeba", load_tatoeba_table),
        ("sib200", load_sib200_table),
    ):
        try:
            raw = load()
        except FileNotFoundError:
            # Cache not present on disk: skip this source.
            continue
        table = _standardize_frame(raw, source=name)
        if not table.empty:
            tables.append(table)
    if not tables:
        raise RuntimeError("No cached sources were available for adverse mixes.")
    pool = pd.concat(tables, ignore_index=True)
    # Drop blank texts and languages outside the supported set in one pass.
    keep = pool["text"].astype(str).str.strip().ne("") & pool["lang_iso2"].isin(ALL_LANGS)
    pool = pool[keep]
    pool["text_key"] = pool["text"].astype(str).map(_normalize_text_key)
    pool = pool[pool["text_key"].ne("")]
    pool = pool.drop_duplicates(subset=["lang_iso2", "text_key"], keep="first")
    return pool.reset_index(drop=True)
def _family_candidates(frame: pd.DataFrame) -> dict[str, tuple[str, ...]]:
    """Restrict ADVERSE_FAMILY_LANGS to languages present in *frame*.

    Only families keeping more than one member are returned, since an
    adversarial mix needs at least two confusable languages.
    """
    present = set(frame["lang_iso2"].dropna().astype(str))
    usable: dict[str, tuple[str, ...]] = {}
    for family_name, members in ADVERSE_FAMILY_LANGS.items():
        survivors = tuple(code for code in members if code in present)
        if len(survivors) > 1:
            usable[family_name] = survivors
    return usable
def _bundle_row(row: pd.Series) -> dict[str, Any]:
return {
"text": str(row.get("text", "")).strip(),
"raw_text": str(row.get("text", "")).strip(),
"source": str(row.get("source", "adverse-mix")).strip(),
"source_lang": str(row.get("source_lang", "")).strip(),
"lang_iso2": str(row.get("lang_iso2", "")).strip(),
"lang_iso3": str(row.get("lang_iso3", "")).strip(),
"language": str(row.get("source_lang", row.get("lang_iso2", ""))).strip(),
"split": str(row.get("split", "")).strip(),
"example_id": int(row.get("example_id", -1)) if str(row.get("example_id", "-1")).strip().lstrip("-").isdigit() else -1,
"topic": str(row.get("topic", "")).strip(),
"label": int(row.get("label", -1)) if str(row.get("label", "-1")).strip().lstrip("-").isdigit() else -1,
}
def fetch_random_adverse_mix() -> dict[str, Any]:
    """Sample one adversarial bundle from a randomly chosen confusable family.

    Picks a family uniformly at random, then draws one sentence each from
    2-3 of its languages via ``sample_multi_group_bundle``.

    Raises:
        RuntimeError: when no family has two or more languages in the pool.
    """
    pool = _load_adverse_pool()
    candidates = _family_candidates(pool)
    if not candidates:
        raise RuntimeError("No adverse language families had at least two available languages.")
    family, langs = random.choice(list(candidates.items()))
    bundle = sample_multi_group_bundle(
        pool,
        group_column="lang_iso2",
        row_to_sentence=_bundle_row,
        min_groups=2,
        max_groups=min(3, len(langs)),
        min_sentences_per_group=1,
        max_sentences_per_group=1,
        allowed_groups=set(langs),
    )
    result = dict(bundle)
    result["source"] = "adverse-mix"
    result["family"] = family
    result["family_langs"] = list(langs)
    return result