DerivedFunction1 commited on
Commit
ed7fe38
·
1 Parent(s): 8a63f11
Files changed (1) hide show
  1. adverse_mix.py +45 -59
adverse_mix.py CHANGED
@@ -35,69 +35,54 @@ def _normalize_text_key(text: str) -> str:
35
  return " ".join(str(text or "").split()).casefold().strip()
36
 
37
 
38
- def _source_frame_to_records(frame: pd.DataFrame, *, source: str) -> list[dict[str, Any]]:
 
 
 
 
 
 
39
  if frame.empty:
40
- return []
41
 
42
- records: list[dict[str, Any]] = []
43
  if source == "fleurs":
44
- for _, row in frame.iterrows():
45
- lang_iso2 = str(row.get("model_lang", "")).strip()
46
- text = str(row.get("text", "")).strip()
47
- if not lang_iso2 or not text:
48
- continue
49
- records.append(
50
- {
51
- "source": source,
52
- "source_lang": str(row.get("source_lang", "")).strip(),
53
- "lang_iso2": lang_iso2,
54
- "lang_iso3": str(row.get("lang_iso3", "")).strip() or LANG_ISO2_TO_ISO3.get(lang_iso2, ""),
55
- "text": text,
56
- "split": str(row.get("split", "")).strip(),
57
- "example_id": int(row.get("id", -1)) if str(row.get("id", "-1")).strip().lstrip("-").isdigit() else -1,
58
- }
59
- )
60
- return records
61
 
62
  if source == "tatoeba":
63
- for _, row in frame.iterrows():
64
- lang_iso2 = str(row.get("source_lang", "")).strip()
65
- text = str(row.get("text", "")).strip()
66
- if not lang_iso2 or not text:
67
- continue
68
- records.append(
69
- {
70
- "source": source,
71
- "source_lang": lang_iso2,
72
- "lang_iso2": lang_iso2,
73
- "lang_iso3": str(row.get("lang_iso3", "")).strip() or LANG_ISO2_TO_ISO3.get(lang_iso2, ""),
74
- "text": text,
75
- "split": "",
76
- "example_id": int(row.get("id", -1)) if str(row.get("id", "-1")).strip().lstrip("-").isdigit() else -1,
77
- }
78
- )
79
- return records
80
 
81
  if source == "sib200":
82
- for _, row in frame.iterrows():
83
- lang_iso2 = str(row.get("lang_iso2", "")).strip()
84
- text = str(row.get("text", "")).strip()
85
- if not lang_iso2 or not text:
86
- continue
87
- records.append(
88
- {
89
- "source": source,
90
- "source_lang": str(row.get("source_lang", "")).strip(),
91
- "lang_iso2": lang_iso2,
92
- "lang_iso3": str(row.get("lang_iso3", "")).strip() or LANG_ISO2_TO_ISO3.get(lang_iso2, ""),
93
- "text": text,
94
- "split": str(row.get("split", "")).strip(),
95
- "example_id": int(row.get("index_id", -1)) if str(row.get("index_id", "-1")).strip().lstrip("-").isdigit() else -1,
96
- "topic": str(row.get("topic", "")).strip(),
97
- "label": int(row.get("label", -1)) if str(row.get("label", "-1")).strip().lstrip("-").isdigit() else -1,
98
- }
99
- )
100
- return records
101
 
102
  raise RuntimeError(f"Unsupported source: {source}")
103
 
@@ -115,14 +100,15 @@ def _load_adverse_pool() -> pd.DataFrame:
115
  source_frame = loader()
116
  except FileNotFoundError:
117
  continue
118
- records = _source_frame_to_records(source_frame, source=source)
119
- if records:
120
- frames.append(pd.DataFrame.from_records(records))
121
 
122
  if not frames:
123
  raise RuntimeError("No cached sources were available for adverse mixes.")
124
 
125
  combined = pd.concat(frames, ignore_index=True)
 
126
  combined = combined[combined["lang_iso2"].isin(ALL_LANGS)]
127
  combined["text_key"] = combined["text"].astype(str).map(_normalize_text_key)
128
  combined = combined[combined["text_key"].ne("")].drop_duplicates(subset=["lang_iso2", "text_key"], keep="first")
 
35
  return " ".join(str(text or "").split()).casefold().strip()
36
 
37
 
38
+ def _column_or_default(frame: pd.DataFrame, column: str, default: Any) -> pd.Series:
39
+ if column in frame.columns:
40
+ return frame[column]
41
+ return pd.Series([default] * len(frame), index=frame.index)
42
+
43
+
44
+ def _standardize_frame(frame: pd.DataFrame, *, source: str) -> pd.DataFrame:
45
  if frame.empty:
46
+ return pd.DataFrame()
47
 
 
48
  if source == "fleurs":
49
+ standardized = frame.copy()
50
+ standardized["source"] = source
51
+ standardized["lang_iso2"] = _column_or_default(standardized, "model_lang", "").astype(str).str.strip()
52
+ standardized["source_lang"] = _column_or_default(standardized, "source_lang", "").astype(str).str.strip()
53
+ standardized["lang_iso3"] = _column_or_default(standardized, "lang_iso3", "").astype(str).str.strip()
54
+ standardized["lang_iso3"] = standardized["lang_iso3"].where(standardized["lang_iso3"].ne(""), standardized["lang_iso2"].map(lambda lang: LANG_ISO2_TO_ISO3.get(lang, "")))
55
+ standardized["split"] = _column_or_default(standardized, "split", "").astype(str).str.strip()
56
+ standardized["example_id"] = pd.to_numeric(_column_or_default(standardized, "id", -1), errors="coerce").fillna(-1).astype(int)
57
+ standardized["topic"] = ""
58
+ standardized["label"] = -1
59
+ return standardized.loc[:, ["source", "source_lang", "lang_iso2", "lang_iso3", "text", "split", "example_id", "topic", "label"]].copy()
 
 
 
 
 
 
60
 
61
  if source == "tatoeba":
62
+ standardized = frame.copy()
63
+ standardized["source"] = source
64
+ standardized["source_lang"] = _column_or_default(standardized, "source_lang", "").astype(str).str.strip()
65
+ standardized["lang_iso2"] = standardized["source_lang"]
66
+ standardized["lang_iso3"] = _column_or_default(standardized, "lang_iso3", "").astype(str).str.strip()
67
+ standardized["lang_iso3"] = standardized["lang_iso3"].where(standardized["lang_iso3"].ne(""), standardized["lang_iso2"].map(lambda lang: LANG_ISO2_TO_ISO3.get(lang, "")))
68
+ standardized["split"] = ""
69
+ standardized["example_id"] = pd.to_numeric(_column_or_default(standardized, "id", -1), errors="coerce").fillna(-1).astype(int)
70
+ standardized["topic"] = ""
71
+ standardized["label"] = -1
72
+ return standardized.loc[:, ["source", "source_lang", "lang_iso2", "lang_iso3", "text", "split", "example_id", "topic", "label"]].copy()
 
 
 
 
 
 
73
 
74
  if source == "sib200":
75
+ standardized = frame.copy()
76
+ standardized["source"] = source
77
+ standardized["source_lang"] = _column_or_default(standardized, "source_lang", "").astype(str).str.strip()
78
+ standardized["lang_iso2"] = _column_or_default(standardized, "lang_iso2", "").astype(str).str.strip()
79
+ standardized["lang_iso3"] = _column_or_default(standardized, "lang_iso3", "").astype(str).str.strip()
80
+ standardized["lang_iso3"] = standardized["lang_iso3"].where(standardized["lang_iso3"].ne(""), standardized["lang_iso2"].map(lambda lang: LANG_ISO2_TO_ISO3.get(lang, "")))
81
+ standardized["split"] = _column_or_default(standardized, "split", "").astype(str).str.strip()
82
+ standardized["example_id"] = pd.to_numeric(_column_or_default(standardized, "index_id", -1), errors="coerce").fillna(-1).astype(int)
83
+ standardized["topic"] = _column_or_default(standardized, "topic", "").astype(str).str.strip()
84
+ standardized["label"] = pd.to_numeric(_column_or_default(standardized, "label", -1), errors="coerce").fillna(-1).astype(int)
85
+ return standardized.loc[:, ["source", "source_lang", "lang_iso2", "lang_iso3", "text", "split", "example_id", "topic", "label"]].copy()
 
 
 
 
 
 
 
 
86
 
87
  raise RuntimeError(f"Unsupported source: {source}")
88
 
 
100
  source_frame = loader()
101
  except FileNotFoundError:
102
  continue
103
+ standardized = _standardize_frame(source_frame, source=source)
104
+ if not standardized.empty:
105
+ frames.append(standardized)
106
 
107
  if not frames:
108
  raise RuntimeError("No cached sources were available for adverse mixes.")
109
 
110
  combined = pd.concat(frames, ignore_index=True)
111
+ combined = combined[combined["text"].astype(str).str.strip().ne("")]
112
  combined = combined[combined["lang_iso2"].isin(ALL_LANGS)]
113
  combined["text_key"] = combined["text"].astype(str).map(_normalize_text_key)
114
  combined = combined[combined["text_key"].ne("")].drop_duplicates(subset=["lang_iso2", "text_key"], keep="first")