"""Adversarial language mixes: sample bundles of sentences from confusable
languages, pooling the cached FLEURS, Tatoeba, and SIB-200 tables."""

from __future__ import annotations

import random
from functools import lru_cache
from typing import Any

import pandas as pd

from fleurs_cache import load_fleurs_table
from language import ALL_LANGS, LANG_ISO2_TO_ISO3
from sentence_sampling import sample_multi_group_bundle
from sib200_cache import load_sib200_table
from tatoeba import load_tatoeba_table


# Families of mutually confusable languages; each value lists the language
# codes (lang_iso2) eligible for that family's adversarial mixes.
ADVERSE_FAMILY_LANGS: dict[str, tuple[str, ...]] = {
    "NordicLatin": ("sv", "no", "da", "is", "fi"),
    "BalticLatin": ("lv", "lt", "et"),
    "IberianRomance": ("es", "pt", "ca", "gl", "oc"),
    "ItaloRomance": ("it", "ro", "mt", "rm"),
    "FrenchLatin": ("fr", "oc"),
    "GermanicLatin": ("de", "nl", "af", "lb"),
    "PeripheralLatin": ("eu", "ku", "uz", "so", "su", "jv"),

    "SouthSlavicLatin": ("bs", "hr", "sr"),
    "Malayic": ("id", "ms"),
    "Nguni": ("xh", "zu"),
    "CyrillicSlavic": ("ru", "uk", "be", "bg", "mk"),
    "CyrillicCentralAsian": ("kk", "mn", "tt", "ky", "tg", "ba", "ce"),
    "Arabic": ("ar", "fa", "ps", "sd", "ug", "ur", "ckb"),
    "IndicDevanagari": ("hi", "mr", "ne"),
    "IndicBengali": ("bn", "as"),
    "IndicSouth": ("ta", "te", "gu", "kn", "ml", "pa", "or"),
    "CJK": ("zh", "ja"),
    "Hebrew": ("he", "yi"),
    "English": ("en", "sco"),
}


def _normalize_text_key(text: str) -> str:
    # Collapse whitespace runs and casefold so near-duplicate sentences map
    # to the same deduplication key.
    return " ".join(str(text or "").split()).casefold()


def _column_or_default(frame: pd.DataFrame, column: str, default: Any) -> pd.Series:
    # Return the named column when present, otherwise a constant Series
    # aligned with the frame's index.
    if column in frame.columns:
        return frame[column]
    return pd.Series([default] * len(frame), index=frame.index)


def _standardize_frame(frame: pd.DataFrame, *, source: str) -> pd.DataFrame:
    """Project a source-specific cache table onto the shared pool schema."""
    if frame.empty:
        return pd.DataFrame()

    # Per-source differences: the column holding the language code, the
    # column holding the split (None when the source has no splits), the
    # column holding the example id, and whether topic/label annotations
    # exist.
    specs = {
        "fleurs": ("model_lang", "split", "id", False),
        "tatoeba": ("source_lang", None, "id", False),
        "sib200": ("lang_iso2", "split", "index_id", True),
    }
    if source not in specs:
        raise RuntimeError(f"Unsupported source: {source}")
    iso2_column, split_column, id_column, has_topic_label = specs[source]

    standardized = frame.copy()
    standardized["source"] = source
    standardized["source_lang"] = _column_or_default(standardized, "source_lang", "").astype(str).str.strip()
    standardized["lang_iso2"] = _column_or_default(standardized, iso2_column, "").astype(str).str.strip()
    standardized["lang_iso3"] = _column_or_default(standardized, "lang_iso3", "").astype(str).str.strip()
    # Backfill missing ISO 639-3 codes from the ISO 639-1 mapping.
    standardized["lang_iso3"] = standardized["lang_iso3"].where(
        standardized["lang_iso3"].ne(""),
        standardized["lang_iso2"].map(lambda lang: LANG_ISO2_TO_ISO3.get(lang, "")),
    )
    if split_column is None:
        standardized["split"] = ""
    else:
        standardized["split"] = _column_or_default(standardized, split_column, "").astype(str).str.strip()
    standardized["example_id"] = pd.to_numeric(
        _column_or_default(standardized, id_column, -1), errors="coerce"
    ).fillna(-1).astype(int)
    if has_topic_label:
        standardized["topic"] = _column_or_default(standardized, "topic", "").astype(str).str.strip()
        standardized["label"] = pd.to_numeric(
            _column_or_default(standardized, "label", -1), errors="coerce"
        ).fillna(-1).astype(int)
    else:
        standardized["topic"] = ""
        standardized["label"] = -1
    columns = ["source", "source_lang", "lang_iso2", "lang_iso3", "text", "split", "example_id", "topic", "label"]
    return standardized.loc[:, columns].copy()


@lru_cache(maxsize=1)
def _load_adverse_pool() -> pd.DataFrame:
    """Concatenate every available cached source into one deduplicated pool.

    Cached with lru_cache so repeated mix requests reuse the same frame.
    """
    frames: list[pd.DataFrame] = []
    source_loaders = (
        ("fleurs", load_fleurs_table),
        ("tatoeba", load_tatoeba_table),
        ("sib200", load_sib200_table),
    )
    for source, loader in source_loaders:
        try:
            source_frame = loader()
        except FileNotFoundError:
            # Skip sources whose cache file is absent; the pool is built
            # from whatever is available.
            continue
        standardized = _standardize_frame(source_frame, source=source)
        if not standardized.empty:
            frames.append(standardized)

    if not frames:
        raise RuntimeError("No cached sources were available for adverse mixes.")

    combined = pd.concat(frames, ignore_index=True)
    combined = combined[combined["text"].astype(str).str.strip().ne("")]
    combined = combined[combined["lang_iso2"].isin(ALL_LANGS)]
    # Deduplicate per language on a whitespace- and case-insensitive key.
    combined["text_key"] = combined["text"].astype(str).map(_normalize_text_key)
    combined = combined[combined["text_key"].ne("")].drop_duplicates(subset=["lang_iso2", "text_key"], keep="first")
    return combined.reset_index(drop=True)


def _family_candidates(frame: pd.DataFrame) -> dict[str, tuple[str, ...]]:
    """Restrict ADVERSE_FAMILY_LANGS to families with at least two
    languages actually present in the pool."""
    available_langs = set(frame["lang_iso2"].dropna().astype(str).tolist())
    families: dict[str, tuple[str, ...]] = {}
    for family, langs in ADVERSE_FAMILY_LANGS.items():
        candidates = tuple(lang for lang in langs if lang in available_langs)
        if len(candidates) >= 2:
            families[family] = candidates
    return families


def _bundle_row(row: pd.Series) -> dict[str, Any]:
    """Convert one pool row into the sentence dict the sampler expects."""

    def _as_int(value: Any) -> int:
        # int() tolerates a sign and surrounding whitespace; anything else
        # (floats, blanks, junk such as "--5") falls back to -1.
        try:
            return int(str(value).strip())
        except (TypeError, ValueError):
            return -1

    text = str(row.get("text", "")).strip()
    source_lang = str(row.get("source_lang", "")).strip()
    lang_iso2 = str(row.get("lang_iso2", "")).strip()
    return {
        "text": text,
        "raw_text": text,
        "source": str(row.get("source", "adverse-mix")).strip(),
        "source_lang": source_lang,
        "lang_iso2": lang_iso2,
        "lang_iso3": str(row.get("lang_iso3", "")).strip(),
        # Prefer the source's own language name; fall back to the ISO code
        # when the source_lang field is empty.
        "language": source_lang or lang_iso2,
        "split": str(row.get("split", "")).strip(),
        "example_id": _as_int(row.get("example_id", -1)),
        "topic": str(row.get("topic", "")).strip(),
        "label": _as_int(row.get("label", -1)),
    }


def fetch_random_adverse_mix() -> dict[str, Any]:
    """Sample a bundle of sentences from two or three languages drawn from
    one randomly chosen confusable family."""
    frame = _load_adverse_pool()
    families = _family_candidates(frame)
    if not families:
        raise RuntimeError("No adverse language families had at least two available languages.")

    family, langs = random.choice(list(families.items()))
    bundle = sample_multi_group_bundle(
        frame,
        group_column="lang_iso2",
        row_to_sentence=_bundle_row,
        min_groups=2,
        max_groups=min(3, len(langs)),
        min_sentences_per_group=1,
        max_sentences_per_group=1,
        allowed_groups=set(langs),
    )
    return {
        **bundle,
        "source": "adverse-mix",
        "family": family,
        "family_langs": list(langs),
    }
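

# Minimal manual smoke test. This block is an illustrative sketch, not part
# of the sampling API; it assumes the FLEURS/Tatoeba/SIB-200 cache files are
# present locally (otherwise _load_adverse_pool() raises RuntimeError).
if __name__ == "__main__":
    mix = fetch_random_adverse_mix()
    print(mix["family"], mix["family_langs"])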