# language-extractor-demo / adverse_mix.py
# (repo page residue: DerivedFunction1 — "add", commit 1f56845)
from __future__ import annotations
import random
from functools import lru_cache
from typing import Any
import pandas as pd
from fleurs_cache import load_fleurs_table
from language import ALL_LANGS, LANG_ISO2_TO_ISO3
from sentence_sampling import sample_multi_group_bundle
from sib200_cache import load_sib200_table
from tatoeba import load_tatoeba_table
# Families of language codes (mostly two-letter ISO 639-1; a few three-letter
# like "sco"/"ckb") that are grouped together when building adverse mixes.
# A family is only usable if at least two of its languages are present in the
# cached pool (see _family_candidates / fetch_random_adverse_mix).
ADVERSE_FAMILY_LANGS: dict[str, tuple[str, ...]] = {
"NordicLatin": ("sv", "no", "da", "is", "fi"),
"BalticLatin": ("lv", "lt", "et"),
"IberianRomance": ("es", "pt", "ca", "gl", "oc"),
"ItaloRomance": ("it", "ro", "mt", "rm"),
"FrenchLatin": ("fr", "oc"),
"GermanicLatin": ("de", "nl", "af", "lb"),
"PeripheralLatin": ("eu", "ku", "uz", "so", "su", "jv"),
# Real confusion clusters from validation.
"SouthSlavicLatin": ("bs", "hr", "sr"),
"Malayic": ("id", "ms"),
"Nguni": ("xh", "zu"),
"CyrillicSlavic": ("ru", "uk", "be", "bg", "mk"),
"CyrillicCentralAsian": ("kk", "mn", "tt", "ky", "tg", "ba", "ce"),
"Arabic": ("ar", "fa", "ps", "sd", "ug", "ur", "ckb"),
"IndicDevanagari": ("hi", "mr", "ne"),
"IndicBengali": ("bn", "as"),
"IndicSouth": ("ta", "te", "gu", "kn", "ml", "pa", "or"),
"CJK": ("zh", "ja"),
"Hebrew": ("he", "yi"),
"English": ("en", "sco"),
}
def _normalize_text_key(text: str) -> str:
return " ".join(str(text or "").split()).casefold().strip()
def _column_or_default(frame: pd.DataFrame, column: str, default: Any) -> pd.Series:
if column in frame.columns:
return frame[column]
return pd.Series([default] * len(frame), index=frame.index)
# Canonical column order shared by every standardized source frame.
_STANDARD_COLUMNS = ["source", "source_lang", "lang_iso2", "lang_iso3", "text", "split", "example_id", "topic", "label"]


def _str_column(frame: pd.DataFrame, column: str) -> pd.Series:
    """Fetch *column* as stripped strings, defaulting missing cells to ""."""
    return _column_or_default(frame, column, "").astype(str).str.strip()


def _int_column(frame: pd.DataFrame, column: str) -> pd.Series:
    """Fetch *column* coerced to int; missing/unparseable values become -1."""
    return pd.to_numeric(_column_or_default(frame, column, -1), errors="coerce").fillna(-1).astype(int)


def _backfill_iso3(standardized: pd.DataFrame) -> pd.Series:
    """Fill empty lang_iso3 cells from the iso2 -> iso3 lookup table."""
    iso3 = standardized["lang_iso3"]
    fallback = standardized["lang_iso2"].map(lambda lang: LANG_ISO2_TO_ISO3.get(lang, ""))
    return iso3.where(iso3.ne(""), fallback)


def _standardize_frame(frame: pd.DataFrame, *, source: str) -> pd.DataFrame:
    """Normalize a cached source table to the shared adverse-mix schema.

    Parameters
    ----------
    frame:
        Raw table loaded from one of the cached sources.
    source:
        One of ``"fleurs"``, ``"tatoeba"``, ``"sib200"``.

    Returns
    -------
    A copy restricted to ``_STANDARD_COLUMNS``; an empty DataFrame when
    *frame* is empty (even for unsupported sources, matching the original
    early return).

    Raises
    ------
    RuntimeError
        If *source* is not one of the supported names and *frame* is non-empty.
    """
    if frame.empty:
        return pd.DataFrame()
    standardized = frame.copy()
    standardized["source"] = source
    if source == "fleurs":
        # FLEURS stores the language code in "model_lang".
        standardized["lang_iso2"] = _str_column(standardized, "model_lang")
        standardized["source_lang"] = _str_column(standardized, "source_lang")
        standardized["split"] = _str_column(standardized, "split")
        standardized["example_id"] = _int_column(standardized, "id")
        standardized["topic"] = ""
        standardized["label"] = -1
    elif source == "tatoeba":
        # Tatoeba has no split and uses "source_lang" as the iso2 code.
        standardized["source_lang"] = _str_column(standardized, "source_lang")
        standardized["lang_iso2"] = standardized["source_lang"]
        standardized["split"] = ""
        standardized["example_id"] = _int_column(standardized, "id")
        standardized["topic"] = ""
        standardized["label"] = -1
    elif source == "sib200":
        # SIB-200 carries topic/label annotations and its own "index_id".
        standardized["source_lang"] = _str_column(standardized, "source_lang")
        standardized["lang_iso2"] = _str_column(standardized, "lang_iso2")
        standardized["split"] = _str_column(standardized, "split")
        standardized["example_id"] = _int_column(standardized, "index_id")
        standardized["topic"] = _str_column(standardized, "topic")
        standardized["label"] = _int_column(standardized, "label")
    else:
        raise RuntimeError(f"Unsupported source: {source}")
    standardized["lang_iso3"] = _str_column(standardized, "lang_iso3")
    standardized["lang_iso3"] = _backfill_iso3(standardized)
    return standardized.loc[:, _STANDARD_COLUMNS].copy()
@lru_cache(maxsize=1)
def _load_adverse_pool() -> pd.DataFrame:
    """Build (once, cached) the combined sentence pool for adverse mixes.

    Loads every available cached source, standardizes each table, then
    drops blank texts, out-of-scope languages, and per-language duplicate
    texts (keyed on the normalized text).
    """
    loaders = {
        "fleurs": load_fleurs_table,
        "tatoeba": load_tatoeba_table,
        "sib200": load_sib200_table,
    }
    frames: list[pd.DataFrame] = []
    for name, loader in loaders.items():
        try:
            raw = loader()
        except FileNotFoundError:
            # A missing cache file just means this source is skipped.
            continue
        table = _standardize_frame(raw, source=name)
        if not table.empty:
            frames.append(table)
    if not frames:
        raise RuntimeError("No cached sources were available for adverse mixes.")
    pool = pd.concat(frames, ignore_index=True)
    pool = pool[pool["text"].astype(str).str.strip().ne("")]
    pool = pool[pool["lang_iso2"].isin(ALL_LANGS)]
    pool["text_key"] = pool["text"].astype(str).map(_normalize_text_key)
    pool = pool[pool["text_key"].ne("")]
    pool = pool.drop_duplicates(subset=["lang_iso2", "text_key"], keep="first")
    return pool.reset_index(drop=True)
def _family_candidates(frame: pd.DataFrame) -> dict[str, tuple[str, ...]]:
    """Restrict each adverse family to languages present in *frame*.

    Only families with at least two available languages survive; the
    original tuple order of each family is preserved.
    """
    present = set(frame["lang_iso2"].dropna().astype(str).tolist())
    usable: dict[str, tuple[str, ...]] = {}
    for name, members in ADVERSE_FAMILY_LANGS.items():
        kept = tuple(code for code in members if code in present)
        if len(kept) > 1:
            usable[name] = kept
    return usable
def _bundle_row(row: pd.Series) -> dict[str, Any]:
return {
"text": str(row.get("text", "")).strip(),
"raw_text": str(row.get("text", "")).strip(),
"source": str(row.get("source", "adverse-mix")).strip(),
"source_lang": str(row.get("source_lang", "")).strip(),
"lang_iso2": str(row.get("lang_iso2", "")).strip(),
"lang_iso3": str(row.get("lang_iso3", "")).strip(),
"language": str(row.get("source_lang", row.get("lang_iso2", ""))).strip(),
"split": str(row.get("split", "")).strip(),
"example_id": int(row.get("example_id", -1)) if str(row.get("example_id", "-1")).strip().lstrip("-").isdigit() else -1,
"topic": str(row.get("topic", "")).strip(),
"label": int(row.get("label", -1)) if str(row.get("label", "-1")).strip().lstrip("-").isdigit() else -1,
}
def fetch_random_adverse_mix() -> dict[str, Any]:
    """Sample a sentence bundle from one randomly chosen adverse family.

    Picks a family that has at least two languages in the cached pool,
    draws 2-3 language groups (one sentence each) from it, and tags the
    result with the family metadata.

    Raises
    ------
    RuntimeError
        If no family has two or more available languages.
    """
    pool = _load_adverse_pool()
    candidates = _family_candidates(pool)
    if not candidates:
        raise RuntimeError("No adverse language families had at least two available languages.")
    family_name, member_langs = random.choice(list(candidates.items()))
    sampled = sample_multi_group_bundle(
        pool,
        group_column="lang_iso2",
        row_to_sentence=_bundle_row,
        min_groups=2,
        max_groups=min(3, len(member_langs)),
        min_sentences_per_group=1,
        max_sentences_per_group=1,
        allowed_groups=set(member_langs),
    )
    result = dict(sampled)
    result["source"] = "adverse-mix"
    result["family"] = family_name
    result["family_langs"] = list(member_langs)
    return result