from __future__ import annotations import random from functools import lru_cache from typing import Any import pandas as pd from fleurs_cache import load_fleurs_table from language import ALL_LANGS, LANG_ISO2_TO_ISO3 from sentence_sampling import sample_multi_group_bundle from sib200_cache import load_sib200_table from tatoeba import load_tatoeba_table ADVERSE_FAMILY_LANGS: dict[str, tuple[str, ...]] = { "NordicLatin": ("sv", "no", "da", "is", "fi"), "BalticLatin": ("lv", "lt", "et"), "IberianRomance": ("es", "pt", "ca", "gl", "oc"), "ItaloRomance": ("it", "ro", "mt", "rm"), "FrenchLatin": ("fr", "oc"), "GermanicLatin": ("de", "nl", "af", "lb"), "PeripheralLatin": ("eu", "ku", "uz", "so", "su", "jv"), # Real confusion clusters from validation. "SouthSlavicLatin": ("bs", "hr", "sr"), "Malayic": ("id", "ms"), "Nguni": ("xh", "zu"), "CyrillicSlavic": ("ru", "uk", "be", "bg", "mk"), "CyrillicCentralAsian": ("kk", "mn", "tt", "ky", "tg", "ba", "ce"), "Arabic": ("ar", "fa", "ps", "sd", "ug", "ur", "ckb"), "IndicDevanagari": ("hi", "mr", "ne"), "IndicBengali": ("bn", "as"), "IndicSouth": ("ta", "te", "gu", "kn", "ml", "pa", "or"), "CJK": ("zh", "ja"), "Hebrew": ("he", "yi"), "English": ("en", "sco"), } def _normalize_text_key(text: str) -> str: return " ".join(str(text or "").split()).casefold().strip() def _column_or_default(frame: pd.DataFrame, column: str, default: Any) -> pd.Series: if column in frame.columns: return frame[column] return pd.Series([default] * len(frame), index=frame.index) def _standardize_frame(frame: pd.DataFrame, *, source: str) -> pd.DataFrame: if frame.empty: return pd.DataFrame() if source == "fleurs": standardized = frame.copy() standardized["source"] = source standardized["lang_iso2"] = _column_or_default(standardized, "model_lang", "").astype(str).str.strip() standardized["source_lang"] = _column_or_default(standardized, "source_lang", "").astype(str).str.strip() standardized["lang_iso3"] = _column_or_default(standardized, "lang_iso3", 
"").astype(str).str.strip() standardized["lang_iso3"] = standardized["lang_iso3"].where(standardized["lang_iso3"].ne(""), standardized["lang_iso2"].map(lambda lang: LANG_ISO2_TO_ISO3.get(lang, ""))) standardized["split"] = _column_or_default(standardized, "split", "").astype(str).str.strip() standardized["example_id"] = pd.to_numeric(_column_or_default(standardized, "id", -1), errors="coerce").fillna(-1).astype(int) standardized["topic"] = "" standardized["label"] = -1 return standardized.loc[:, ["source", "source_lang", "lang_iso2", "lang_iso3", "text", "split", "example_id", "topic", "label"]].copy() if source == "tatoeba": standardized = frame.copy() standardized["source"] = source standardized["source_lang"] = _column_or_default(standardized, "source_lang", "").astype(str).str.strip() standardized["lang_iso2"] = standardized["source_lang"] standardized["lang_iso3"] = _column_or_default(standardized, "lang_iso3", "").astype(str).str.strip() standardized["lang_iso3"] = standardized["lang_iso3"].where(standardized["lang_iso3"].ne(""), standardized["lang_iso2"].map(lambda lang: LANG_ISO2_TO_ISO3.get(lang, ""))) standardized["split"] = "" standardized["example_id"] = pd.to_numeric(_column_or_default(standardized, "id", -1), errors="coerce").fillna(-1).astype(int) standardized["topic"] = "" standardized["label"] = -1 return standardized.loc[:, ["source", "source_lang", "lang_iso2", "lang_iso3", "text", "split", "example_id", "topic", "label"]].copy() if source == "sib200": standardized = frame.copy() standardized["source"] = source standardized["source_lang"] = _column_or_default(standardized, "source_lang", "").astype(str).str.strip() standardized["lang_iso2"] = _column_or_default(standardized, "lang_iso2", "").astype(str).str.strip() standardized["lang_iso3"] = _column_or_default(standardized, "lang_iso3", "").astype(str).str.strip() standardized["lang_iso3"] = standardized["lang_iso3"].where(standardized["lang_iso3"].ne(""), standardized["lang_iso2"].map(lambda 
lang: LANG_ISO2_TO_ISO3.get(lang, ""))) standardized["split"] = _column_or_default(standardized, "split", "").astype(str).str.strip() standardized["example_id"] = pd.to_numeric(_column_or_default(standardized, "index_id", -1), errors="coerce").fillna(-1).astype(int) standardized["topic"] = _column_or_default(standardized, "topic", "").astype(str).str.strip() standardized["label"] = pd.to_numeric(_column_or_default(standardized, "label", -1), errors="coerce").fillna(-1).astype(int) return standardized.loc[:, ["source", "source_lang", "lang_iso2", "lang_iso3", "text", "split", "example_id", "topic", "label"]].copy() raise RuntimeError(f"Unsupported source: {source}") @lru_cache(maxsize=1) def _load_adverse_pool() -> pd.DataFrame: frames: list[pd.DataFrame] = [] source_loaders = ( ("fleurs", load_fleurs_table), ("tatoeba", load_tatoeba_table), ("sib200", load_sib200_table), ) for source, loader in source_loaders: try: source_frame = loader() except FileNotFoundError: continue standardized = _standardize_frame(source_frame, source=source) if not standardized.empty: frames.append(standardized) if not frames: raise RuntimeError("No cached sources were available for adverse mixes.") combined = pd.concat(frames, ignore_index=True) combined = combined[combined["text"].astype(str).str.strip().ne("")] combined = combined[combined["lang_iso2"].isin(ALL_LANGS)] combined["text_key"] = combined["text"].astype(str).map(_normalize_text_key) combined = combined[combined["text_key"].ne("")].drop_duplicates(subset=["lang_iso2", "text_key"], keep="first") return combined.reset_index(drop=True) def _family_candidates(frame: pd.DataFrame) -> dict[str, tuple[str, ...]]: available_langs = set(frame["lang_iso2"].dropna().astype(str).tolist()) families: dict[str, tuple[str, ...]] = {} for family, langs in ADVERSE_FAMILY_LANGS.items(): candidates = tuple(lang for lang in langs if lang in available_langs) if len(candidates) >= 2: families[family] = candidates return families def 
def _coerce_int(value: Any, default: int = -1) -> int:
    """Best-effort int conversion, returning *default* on any bad input.

    The previous guard ``str(v).strip().lstrip("-").isdigit()`` was unsafe:
    ``lstrip`` removes ALL leading dashes, so a value like ``"--5"`` passed
    the check and then ``int("--5")`` raised ValueError. A try/except is
    both simpler and strictly more robust.
    """
    try:
        return int(str(value).strip())
    except (TypeError, ValueError):
        return default


def _bundle_row(row: pd.Series) -> dict[str, Any]:
    """Convert one pooled sentence row into the bundle sentence dict."""
    text = str(row.get("text", "")).strip()
    return {
        "text": text,
        "raw_text": text,
        "source": str(row.get("source", "adverse-mix")).strip(),
        "source_lang": str(row.get("source_lang", "")).strip(),
        "lang_iso2": str(row.get("lang_iso2", "")).strip(),
        "lang_iso3": str(row.get("lang_iso3", "")).strip(),
        # Fall back to the ISO2 code only when ``source_lang`` is absent.
        "language": str(row.get("source_lang", row.get("lang_iso2", ""))).strip(),
        "split": str(row.get("split", "")).strip(),
        "example_id": _coerce_int(row.get("example_id", -1)),
        "topic": str(row.get("topic", "")).strip(),
        "label": _coerce_int(row.get("label", -1)),
    }


def fetch_random_adverse_mix() -> dict[str, Any]:
    """Sample a bundle of sentences from one randomly chosen confusable family.

    Picks a family with at least two available languages, then draws one
    sentence each from 2-3 of its languages.

    Raises:
        RuntimeError: if no family has two or more languages in the pool.
    """
    frame = _load_adverse_pool()
    families = _family_candidates(frame)
    if not families:
        raise RuntimeError("No adverse language families had at least two available languages.")
    family, langs = random.choice(list(families.items()))
    bundle = sample_multi_group_bundle(
        frame,
        group_column="lang_iso2",
        row_to_sentence=_bundle_row,
        min_groups=2,
        max_groups=min(3, len(langs)),
        min_sentences_per_group=1,
        max_sentences_per_group=1,
        allowed_groups=set(langs),
    )
    return {
        **bundle,
        "source": "adverse-mix",
        "family": family,
        "family_langs": list(langs),
    }