| |
| """Gradio demo for the multilingual token-classification language ID model.""" |
|
|
| from __future__ import annotations |
|
|
| from collections import Counter, defaultdict |
| from functools import lru_cache |
| import json |
| import random |
| import os |
| import re |
| from typing import Any |
|
|
| import pandas as pd |
| import gradio as gr |
| import pycountry |
| import fasttext |
| import numpy as np |
| import torch |
| from huggingface_hub import hf_hub_download |
| from transformers import AutoModelForSequenceClassification, AutoModelForTokenClassification, AutoTokenizer, pipeline |
|
|
| from adverse_mix import fetch_random_adverse_mix |
| from fleurs_cache import fetch_random_fleurs_sentence, fetch_random_fleurs_sentence_mix |
| from language import ALL_LANGS, LANG_ALIASES, LANG_ISO2_TO_ISO3, canonical_lang |
| from sib200_cache import fetch_random_sib200_sentence, fetch_random_sib200_sentence_mix |
| from tatoeba import fetch_random_tatoeba_sentence, fetch_random_tatoeba_sentence_mix |
|
|
|
|
# Hugging Face Hub ids: in-house token-level tagger, in-house multi-label
# sentence classifier, and the fastText baseline used for comparison.
MODEL_CHECKPOINT = "polyglot-tagger/language-identification"
MULTI_MODEL = "polyglot-tagger/multilabel-language-identification"
FASTTEXT_MODEL_REPO = "facebook/fasttext-language-identification"
FASTTEXT_MODEL_FILENAME = "model.bin"
# fastText predictions below this probability are dropped (see predict_fasttext).
FASTTEXT_MIN_CONFIDENCE = 0.15
# Multi-label head keeps at most TOP_N labels, each scoring at least MIN_SCORE.
MULTI_LABEL_TOP_N = 6
MULTI_LABEL_MIN_SCORE = 0.15
# Thresholds for treating short/low-confidence spans as tagging artifacts;
# presumably consumed by is_artifact_span (defined elsewhere in this file) —
# TODO confirm. Artifact spans contribute only ARTIFACT_SPAN_WEIGHT of their
# length to coverage stats (see build_lang_stats).
MIN_ARTIFACT_SPAN_CHARS = 4
MIN_ARTIFACT_CONFIDENCE = 0.5
ARTIFACT_SPAN_WEIGHT = 0.35

# Samplers backing the "random example" buttons; tried in shuffled order until
# one has a local cache available (see fetch_random_cached_sentence*).
RANDOM_SENTENCE_SAMPLERS = (
    fetch_random_fleurs_sentence,
    fetch_random_tatoeba_sentence,
    fetch_random_sib200_sentence,
)
RANDOM_MIX_SAMPLERS = (
    fetch_random_fleurs_sentence_mix,
    fetch_random_tatoeba_sentence_mix,
    fetch_random_sib200_sentence_mix,
)
|
|
|
|
@lru_cache(maxsize=1)
def get_tokenizer():
    """Load and memoize the shared XLM-R tokenizer used by both local models."""
    tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
    return tokenizer
|
|
|
|
@lru_cache(maxsize=1)
def get_pipeline():
    """Build (once) the token-classification pipeline used for span tagging."""
    tagger = AutoModelForTokenClassification.from_pretrained(MODEL_CHECKPOINT)
    tagger.eval()
    # "simple" aggregation merges adjacent word pieces with the same tag into
    # one entity span before they reach build_lang_stats.
    return pipeline(
        "token-classification",
        model=tagger,
        tokenizer=get_tokenizer(),
        aggregation_strategy="simple",
    )
|
|
|
|
@lru_cache(maxsize=1)
def get_multilabel_model():
    """Load (once) the multi-label sequence classifier, switched to eval mode."""
    classifier = AutoModelForSequenceClassification.from_pretrained(MULTI_MODEL)
    # Module.eval() returns the module itself, so this keeps a single statement.
    return classifier.eval()
|
|
|
|
@lru_cache(maxsize=1)
def get_fasttext_model():
    """Load the reference fastText language ID model once."""
    local_path = hf_hub_download(
        repo_id=FASTTEXT_MODEL_REPO,
        filename=FASTTEXT_MODEL_FILENAME,
    )
    return fasttext.load_model(local_path)
|
|
|
|
def normalize_label(label: str) -> str:
    """Strip a single BIO prefix and map the tag to its canonical language code."""
    for bio_prefix in ("B-", "I-"):
        if label.startswith(bio_prefix):
            label = label[len(bio_prefix):]
            break
    return canonical_lang(label.lower())
|
|
|
|
def build_lang_stats(
    entities: list[dict[str, Any]],
) -> tuple[dict[str, dict[str, float | int]], int]:
    """Aggregate merged entity spans into per-language coverage stats.

    Returns the per-language stats dict plus the number of spans that were
    down-weighted as likely tagging artifacts.
    """
    coverage_by_lang: defaultdict[str, float] = defaultdict(float)
    weighted_conf_by_lang: defaultdict[str, float] = defaultdict(float)
    span_counts: defaultdict[str, int] = defaultdict(int)
    tagged_total = 0.0
    artifact_count = 0

    for span in entities:
        lang = normalize_label(span.get("entity_group", span.get("entity", "O")))
        if lang == "o":
            # "O" spans carry no language information.
            continue
        start = span.get("start")
        end = span.get("end")
        if start is None or end is None:
            continue
        length = max(int(end) - int(start), 1)
        confidence = float(span.get("score", 0.0))
        # Suspected artifacts contribute only a fraction of their length.
        weight = ARTIFACT_SPAN_WEIGHT if is_artifact_span(length, confidence) else 1.0
        if weight < 1.0:
            artifact_count += 1
        weighted_len = length * weight
        coverage_by_lang[lang] += weighted_len
        weighted_conf_by_lang[lang] += weighted_len * confidence
        span_counts[lang] += 1
        tagged_total += weighted_len

    if tagged_total == 0:
        return {}, artifact_count

    stats: dict[str, dict[str, float | int]] = {}
    for lang, covered in coverage_by_lang.items():
        mean_conf = weighted_conf_by_lang[lang] / covered if covered else 0.0
        share = covered / tagged_total
        stats[lang] = {
            "char_coverage": covered,
            "coverage_pct": share,
            "avg_confidence": mean_conf,
            "entity_count": span_counts[lang],
            "rank_score": share * mean_conf,
        }

    return stats, artifact_count
|
|
|
|
| def to_classifier_scores(lang_stats: dict[str, dict[str, float | int]]) -> dict[str, float]: |
| """Normalize coverage-confidence weights into a classifier-like distribution.""" |
| raw = { |
| lang: float(stat["coverage_pct"]) * float(stat["avg_confidence"]) |
| for lang, stat in lang_stats.items() |
| } |
| total = sum(raw.values()) |
| if total == 0: |
| return {} |
| return dict(sorted(((lang, weight / total) for lang, weight in raw.items()), key=lambda item: item[1], reverse=True)) |
|
|
|
|
| def make_lang_chip_label(lang: str, stat: dict[str, float | int], score: float) -> str: |
| """Render a compact label for a clickable language chip.""" |
| return f"{lang.upper()} {score:.0%}" |
|
|
|
|
def build_chip_button_updates(
    ranked: list[tuple[str, dict[str, float | int]]],
    token_scores: dict[str, float],
    multi_scores: dict[str, float] | None = None,
    reference_scores: dict[str, float] | None = None,
    max_chips: int = 6,
) -> list[dict[str, Any]]:
    """Return button updates for the top-ranked languages.

    Token-model languages come first (in rank order), followed by any extra
    languages only the multi-label or fastText models proposed.  Each visible
    chip shows per-model scores plus the language's multi/fastText rank;
    leftover chip slots up to ``max_chips`` are hidden.
    """
    multi_scores = multi_scores or {}
    reference_scores = reference_scores or {}
    multi_rank = {
        lang: idx
        for idx, (lang, _) in enumerate(
            sorted(multi_scores.items(), key=lambda item: item[1], reverse=True)
        )
    }
    reference_rank = {
        lang: idx
        for idx, (lang, _) in enumerate(
            sorted(reference_scores.items(), key=lambda item: item[1], reverse=True)
        )
    }
    model_avg_confidence = {lang: float(stat["avg_confidence"]) for lang, stat in ranked}
    model_ranked = [lang for lang, _ in ranked]
    # Order-preserving dedupe. The previous concatenation of multi and
    # reference keys let a language present in both maps (but absent from the
    # token ranking) appear twice, producing duplicate chips.
    ordered_langs = list(dict.fromkeys(model_ranked + list(multi_scores) + list(reference_scores)))
    updates: list[dict[str, Any]] = []
    for idx in range(max_chips):
        if idx >= len(ordered_langs):
            # Unused chip slot: blank it out and hide it.
            updates.append(gr.update(value="", visible=False))
            continue
        lang = ordered_langs[idx]
        token_score = token_scores.get(lang, 0.0)
        token_label_score = model_avg_confidence.get(lang, token_score)
        multi_score = multi_scores.get(lang, 0.0)
        reference_score = reference_scores.get(lang, 0.0)
        in_token = token_score > 0.0
        in_multi = lang in multi_scores
        in_reference = lang in reference_scores
        # Variant encodes agreement: all three models -> primary,
        # at least one auxiliary model -> secondary, token-only -> stop.
        if in_token and in_multi and in_reference:
            variant = "primary"
        elif in_multi or in_reference:
            variant = "secondary"
        else:
            variant = "stop"
        multi_rank_text = f" M#{multi_rank[lang] + 1}" if lang in multi_rank else ""
        reference_rank_text = f" FT#{reference_rank[lang] + 1}" if lang in reference_rank else ""
        updates.append(
            gr.update(
                value=f"{lang.upper()} T {token_label_score:.0%} | M {multi_score:.0%}{multi_rank_text} | FT {reference_score:.0%}{reference_rank_text}",
                visible=True,
                variant=variant,
            )
        )
    return updates
|
|
|
|
| def build_ui_state( |
| *, |
| text: str, |
| lang_stats: dict[str, dict[str, float | int]], |
| classifier_scores: dict[str, float], |
| fasttext_result: dict[str, Any] | None, |
| model_label: str, |
| dominant_lang: str, |
| overall_confidence: float, |
| ignored_artifacts: int, |
| ) -> dict[str, Any]: |
| """Package the bits the interactive chips need to redraw the card.""" |
| fasttext_scores = {} |
| if fasttext_result: |
| fasttext_scores = {item["lang"]: float(item["score"]) for item in fasttext_result.get("predictions", [])} |
| model_ranked = [lang for lang, _ in sorted(lang_stats.items(), key=lambda item: classifier_scores.get(item[0], 0.0), reverse=True)] |
| fasttext_ranked = sorted(fasttext_scores, key=lambda lang: fasttext_scores.get(lang, 0.0), reverse=True) |
| chip_langs = [] |
| seen = set() |
| for lang in model_ranked: |
| chip_langs.append({"lang": lang, "source": "model"}) |
| seen.add(lang) |
| for lang in fasttext_ranked: |
| if lang not in seen: |
| chip_langs.append({"lang": lang, "source": "fasttext"}) |
| seen.add(lang) |
| return { |
| "text": text, |
| "lang_stats": lang_stats, |
| "classifier_scores": classifier_scores, |
| "fasttext": fasttext_result, |
| "reference": fasttext_result, |
| "model_label": model_label, |
| "dominant_lang": dominant_lang, |
| "selected_lang": dominant_lang, |
| "overall_confidence": overall_confidence, |
| "ignored_artifacts": ignored_artifacts, |
| "ranked_langs": sorted(lang_stats.keys(), key=lambda lang: classifier_scores.get(lang, 0.0), reverse=True), |
| "chip_langs": chip_langs, |
| } |
|
|
|
|
| def _scores_to_validation_metrics( |
| classifier_scores: dict[str, float], |
| expected_set: set[str], |
| ) -> dict[str, Any]: |
| """Compute validation metrics for a single score map.""" |
| ranked_predictions = sorted(classifier_scores.items(), key=lambda item: item[1], reverse=True) |
| top_lang = ranked_predictions[0][0] if ranked_predictions else None |
| top_score = float(ranked_predictions[0][1]) if ranked_predictions else 0.0 |
| predicted_langs = [lang for lang, score in ranked_predictions if score > 0.0] |
| if not predicted_langs and ranked_predictions: |
| predicted_langs = [lang for lang, _ in ranked_predictions[:1]] |
| predicted_set = set(predicted_langs) |
| true_positive = len(expected_set & predicted_set) |
| false_positive = len(predicted_set - expected_set) |
| false_negative = len(expected_set - predicted_set) |
| precision = true_positive / (true_positive + false_positive) if (true_positive + false_positive) else 0.0 |
| recall = true_positive / (true_positive + false_negative) if (true_positive + false_negative) else 0.0 |
| validation_score = (2 * precision * recall / (precision + recall)) if (precision + recall) else 0.0 |
| return { |
| "predicted_langs": predicted_langs, |
| "top_lang": top_lang, |
| "top_score": top_score, |
| "true_positive": true_positive, |
| "false_positive": false_positive, |
| "false_negative": false_negative, |
| "expected_count": len(expected_set), |
| "predicted_count": len(predicted_set), |
| "precision": precision, |
| "recall": recall, |
| "top_match": false_positive == 0 and false_negative == 0, |
| "validation_score": validation_score, |
| } |
|
|
|
|
def build_example_validation(
    classifier_scores: dict[str, float],
    multi_scores: dict[str, float] | None,
    reference_scores: dict[str, float] | None,
    expected_langs: list[str],
) -> dict[str, Any]:
    """Compare the token, multi-label, and reference scores against known source languages."""
    canonical_expected = [canonical_lang(lang) for lang in expected_langs if lang]
    expected_set = set(canonical_expected)

    token_metrics = _scores_to_validation_metrics(classifier_scores, expected_set)
    multi_metrics = _scores_to_validation_metrics(multi_scores or {}, expected_set)
    reference_metrics = _scores_to_validation_metrics(reference_scores or {}, expected_set)

    return {
        "expected_langs": canonical_expected,
        "token": token_metrics,
        "multi": multi_metrics,
        "reference": reference_metrics,
        # Convenience duplicates of the reference-model results for callers
        # that only care about the baseline.
        "reference_langs": reference_metrics.get("predicted_langs", []),
        "reference_score": float(reference_metrics.get("validation_score", 0.0)),
        # The UI card starts out showing the token model's metrics.
        "active_model": "token",
    }
|
|
|
|
def render_validation_html(validation: dict[str, Any], *, source_label: str) -> str:
    """Render a compact validation card for a labeled example source.

    ``validation`` is the payload built by ``build_example_validation``:
    per-model metric dicts under "token" / "multi" / "reference", the
    canonical "expected_langs", and the currently "active_model".
    ``source_label`` is the dataset display name (e.g. "Tatoeba").
    Returns an HTML snippet, or "" when there is nothing to render.
    """
    if not validation:
        return ""

    expected_langs = ", ".join(lang.upper() for lang in validation.get("expected_langs", [])) or "n/a"
    model_data = {
        "token": validation.get("token", {}),
        "multi": validation.get("multi", {}),
        "reference": validation.get("reference", {}),
    }
    active_model = validation.get("active_model", "token")
    active_data = model_data.get(active_model, model_data["token"])
    # NOTE: the previous revision also built an `active_label` display string
    # here, but nothing consumed it — dead code, removed.
    validation_score = float(active_data.get("validation_score", 0.0))
    top_match = bool(active_data.get("top_match"))
    status_label = "Match" if top_match else "Mismatch"
    status_class = "validation-pass" if top_match else "validation-warn"
    # CSS theme class keyed on whichever model's metrics are being displayed.
    model_theme_class = {
        "token": "validation-model-token",
        "multi": "validation-model-multi",
        "reference": "validation-model-reference",
    }.get(active_model, "validation-model-token")
    # Color tier for the headline F1-style score: >=0.8 good, >=0.45 middling.
    if validation_score >= 0.8:
        tier_class = "validation-high"
    elif validation_score >= 0.45:
        tier_class = "validation-mid"
    else:
        tier_class = "validation-low"

    def _metric_button(model_key: str, label: str) -> str:
        # One clickable percent pill per model; front-end JS toggles the
        # matching detail panel via data-validation-toggle.
        data = model_data.get(model_key, {})
        value = float(data.get("validation_score", 0.0))
        active_class = " is-active" if model_key == active_model else ""
        return f"""
<button type="button" class="validation-metric-button{active_class}" data-validation-toggle="{model_key}">
<span class="validation-metric-label">{label}</span>
<span class="validation-metric-value">{value:.1%}</span>
</button>
"""

    def _detail_panel(model_key: str, label: str) -> str:
        # Expanded per-model panel: predictions, TP/FP/FN counts, P/R, status.
        data = model_data.get(model_key, {})
        predicted_langs = ", ".join(lang.upper() for lang in data.get("predicted_langs", [])) or "n/a"
        top_lang = data.get("top_lang") or "n/a"
        top_score = float(data.get("top_score", 0.0))
        expected_count = int(data.get("expected_count", 0))
        predicted_count = int(data.get("predicted_count", 0))
        true_positive = int(data.get("true_positive", 0))
        false_positive = int(data.get("false_positive", 0))
        false_negative = int(data.get("false_negative", 0))
        precision = float(data.get("precision", 0.0))
        recall = float(data.get("recall", 0.0))
        validation_score_local = float(data.get("validation_score", 0.0))
        top_match_local = bool(data.get("top_match"))
        return f"""
<div class="validation-detail {'is-active' if model_key == active_model else ''}" data-validation-panel="{model_key}">
<div class="validation-detail-head">
<div class="validation-detail-title">{label}</div>
<div class="validation-detail-score">{validation_score_local:.1%}</div>
</div>
<div class="validation-detail-row"><span>Predicted</span><strong>{predicted_langs}</strong></div>
<div class="validation-detail-row"><span>Top</span><strong>{top_lang.upper()} <span class="validation-chip-muted">{top_score:.1%}</span></strong></div>
<div class="validation-detail-row"><span>Counts</span><strong>TP {true_positive} / FP {false_positive} / FN {false_negative}</strong></div>
<div class="validation-detail-row"><span>Precision</span><strong>{precision:.1%} <span class="validation-chip-muted">({predicted_count} predicted)</span></strong></div>
<div class="validation-detail-row"><span>Recall</span><strong>{recall:.1%} <span class="validation-chip-muted">({expected_count} expected)</span></strong></div>
<div class="validation-detail-row"><span>Status</span><strong>{'Match' if top_match_local else 'Mismatch'}</strong></div>
</div>
"""

    return f"""
<div class="validation-strip {tier_class} {model_theme_class}" data-validation-card data-active-model="{active_model}">
<div class="validation-kicker">{source_label} validation</div>
<div class="validation-main">
{_metric_button("token", "Token")}
{_metric_button("multi", "Multi-label")}
{_metric_button("reference", "fastText")}
</div>
<div class="validation-vs">Tap a percent to inspect that model's state.</div>
<div class="validation-status {status_class}">{status_label}</div>
<div class="validation-meta">Expected: {expected_langs}</div>
<div class="validation-panels">
{_detail_panel("token", "Token classifier")}
{_detail_panel("multi", "Multi-label")}
{_detail_panel("reference", "fastText")}
</div>
</div>
"""
|
|
|
|
def build_tatoeba_validation(
    classifier_scores: dict[str, float],
    expected_langs: list[str],
) -> dict[str, Any]:
    """Backward-compatible wrapper for existing Tatoeba callers.

    Delegates to the three-model validator with empty multi-label and
    reference score maps.
    """
    no_scores: dict[str, float] = {}
    return build_example_validation(classifier_scores, no_scores, dict(no_scores), expected_langs)
|
|
|
|
def render_tatoeba_validation_html(validation: dict[str, Any]) -> str:
    """Backward-compatible wrapper for existing Tatoeba callers."""
    # Delegate to the generic renderer with the historical source label.
    label = "Tatoeba"
    return render_validation_html(validation, source_label=label)
|
|
|
|
| def _source_key(source: str) -> str: |
| return (source or "").strip().split("-", 1)[0].lower() |
|
|
|
|
| def _source_label(source: str) -> str: |
| key = _source_key(source) |
| if key == "fleurs": |
| return "FLEURS" |
| if key == "tatoeba": |
| return "Tatoeba" |
| if key == "sib200": |
| return "SIB-200" |
| if key == "adverse": |
| return "Adverse mix" |
| return key.upper() or "Example" |
|
|
|
|
| def _validation_key(source: str) -> str: |
| key = _source_key(source) or "example" |
| return f"{key}_validation" |
|
|
|
|
| def _sentence_id_keys(sentence: dict[str, Any]) -> list[str]: |
| keys = [] |
| for candidate in ("fleurs_id", "sentence_id", "sib200_id", "id"): |
| value = sentence.get(candidate) |
| if value is not None: |
| keys.append(value) |
| return keys |
|
|
|
|
def _language_name(lang_code: str) -> str:
    """Best-effort human readable language name for a code."""
    code = (lang_code or "").strip()
    if not code:
        return "Unknown"

    # Try ISO-639-1 first, then fall back through the project's ISO-3 mapping.
    record = pycountry.languages.get(alpha_2=code)
    if record is None:
        record = pycountry.languages.get(alpha_3=LANG_ISO2_TO_ISO3.get(code, ""))
    if record is None:
        return code.upper()

    # Some pycountry records lack a usable name; fall back to the raw code.
    return getattr(record, "name", None) or code.upper()
|
|
|
|
def render_language_reference_html() -> str:
    """Render a clickable footer that expands to code-to-name mappings."""
    rows_html = "".join(
        f"<li><span class='lang-code'>{code}</span><span class='lang-name'>{_language_name(code)}</span></li>"
        for code in sorted(LANG_ISO2_TO_ISO3)
    )
    return f"""
<details class="footer-note footer-langs">
<summary>Supported model languages: {len(LANG_ISO2_TO_ISO3)}. Click to view code-to-name mapping.</summary>
<div class="footer-langs-body">
<ul class="footer-lang-list">{rows_html}</ul>
</div>
</details>
"""
|
|
|
|
| def _split_sentences_for_fasttext(text: str) -> list[str]: |
| blocks = re.split(r"\n\s*\n+", text) |
| sentences: list[str] = [] |
| for block in blocks: |
| block = block.strip() |
| if not block: |
| continue |
| chunks = re.split(r"(?<=[.!?。!?])\s+|\n+", block) |
| sentences.extend(chunk.strip() for chunk in chunks if chunk and chunk.strip()) |
| return sentences |
|
|
|
|
def predict_fasttext(text: str, k: int = 5, mode: str = "full") -> dict[str, Any]:
    """Return fastText language predictions for comparison.

    Parameters
    ----------
    text:
        Raw input text; newline handling depends on ``mode``.
    k:
        Number of labels requested from fastText per prediction call.
    mode:
        "full" flattens the text to one line and scores it once;
        "sentences" splits the text and aggregates per-sentence top
        predictions via summed score votes.
    """
    model = get_fasttext_model()
    original_array = np.array

    # Compatibility shim: strips copy=False before delegating to np.array.
    # NOTE(review): presumably works around fastText's wrapper passing
    # np.array(..., copy=False), which newer NumPy versions reject — confirm.
    def _array_compat(obj, *args, **kwargs):
        if kwargs.get("copy") is False:
            kwargs = {**kwargs}
            kwargs.pop("copy", None)
        return original_array(obj, *args, **kwargs)

    def _predict_one(sample: str) -> tuple[list[str], list[float]]:
        # Monkey-patch np.array only for the duration of the fastText call;
        # the finally block guarantees the global is restored.
        np.array = _array_compat
        try:
            labels, scores = model.predict(sample, k=k)
        finally:
            np.array = original_array
        return list(labels), [float(score) for score in scores]

    def _normalize_predictions(labels: list[str], scores: list[float], *, keep_best: bool = False) -> list[dict[str, Any]]:
        # Map raw "__label__xxx_Scrp" labels into ISO-2 codes with scores;
        # either keep only the single best prediction or everything above
        # FASTTEXT_MIN_CONFIDENCE.
        predictions = [
            {
                "raw_label": label.removeprefix("__label__"),
                "lang": fasttext_label_to_iso2(label.removeprefix("__label__")),
                "score": float(score),
            }
            for label, score in zip(labels, scores)
        ]
        if keep_best and predictions:
            return [predictions[0]]
        return [item for item in predictions if item["score"] >= FASTTEXT_MIN_CONFIDENCE]

    if mode == "sentences":
        # Score each rough sentence independently, keeping only its top label.
        sentence_predictions: list[dict[str, Any]] = []
        for sentence in _split_sentences_for_fasttext(text):
            labels, scores = _predict_one(sentence)
            if not labels:
                continue
            sentence_predictions.append(
                {
                    "sentence": sentence,
                    "top_raw_label": labels[0].removeprefix("__label__"),
                    "top_family": fasttext_label_to_iso2(labels[0].removeprefix("__label__")),
                    "top_score": float(scores[0]) if scores else 0.0,
                    "predictions": _normalize_predictions(labels, scores, keep_best=True),
                }
            )

        # Aggregate: each sentence votes for its top language, weighted by score.
        votes: defaultdict[str, float] = defaultdict(float)
        for item in sentence_predictions:
            top_lang = item["top_family"]
            top_score = float(item["top_score"])
            votes[top_lang] += top_score

        # Average the vote mass over the sentence count for a 0-1-ish score.
        predictions = [
            {"lang": lang, "score": score / max(len(sentence_predictions), 1)}
            for lang, score in sorted(votes.items(), key=lambda item: item[1], reverse=True)
            if score > 0.0
        ]
        top_raw_label = sentence_predictions[0]["top_raw_label"] if sentence_predictions else None
        top_family = sentence_predictions[0]["top_family"] if sentence_predictions else None
        # Flag results that were reached through an alias/proxy label.
        variant_warning = any(fasttext_label_is_proxy(item["top_raw_label"]) for item in sentence_predictions)
        return {
            "model": FASTTEXT_MODEL_REPO,
            "mode": mode,
            "sentences": sentence_predictions,
            "sentence_count": len(sentence_predictions),
            "predictions": predictions,
            "top_lang": predictions[0]["lang"] if predictions else None,
            "top_score": predictions[0]["score"] if predictions else 0.0,
            "top_raw_label": top_raw_label,
            "top_family": top_family,
            "variant_warning": variant_warning,
        }

    # "full" mode: fastText scores a single line, so flatten all newlines.
    line = " ".join(part.strip() for part in text.splitlines() if part.strip())
    labels, scores = _predict_one(line)
    predictions = _normalize_predictions(labels, scores)
    top_raw_label = labels[0].removeprefix("__label__") if labels else None
    top_family = fasttext_label_to_iso2(top_raw_label) if top_raw_label else None
    variant_warning = any(fasttext_label_is_proxy(item["raw_label"]) for item in predictions)
    return {
        "model": FASTTEXT_MODEL_REPO,
        "mode": mode,
        "predictions": predictions,
        "sentence_count": 1,
        "top_lang": predictions[0]["lang"] if predictions else None,
        "top_score": predictions[0]["score"] if predictions else 0.0,
        "top_raw_label": top_raw_label,
        "top_family": top_family,
        "variant_warning": variant_warning,
    }
|
|
|
|
def fasttext_label_to_iso2(label: str) -> str:
    """Convert fastText labels like `bos_Latn` or `eng` into our ISO-2 space."""
    # Drop the script suffix, then run the alias table + canonicalizer.
    code = label.split("_", 1)[0].lower()
    code = canonical_lang(LANG_ALIASES.get(code, code))
    # A second alias hop for codes whose canonical form is itself aliased.
    if code in LANG_ALIASES:
        code = canonical_lang(LANG_ALIASES[code])
    if len(code) == 2:
        return code

    # Still a 3-letter (or odd) code: try pycountry for an ISO-639-1 mapping.
    record = pycountry.languages.get(alpha_3=code)
    two_letter = getattr(record, "alpha_2", None) if record is not None else None
    if two_letter:
        return canonical_lang(two_letter.lower())
    return code
|
|
|
|
def fasttext_label_is_proxy(label: str) -> bool:
    """Return True when a fastText label maps through an explicit alias/proxy."""
    base = label.split("_", 1)[0].lower()
    if base not in LANG_ALIASES:
        return False
    return LANG_ALIASES[base] != base
|
|
|
|
def fasttext_alias_hint_for_lang(fasttext_result: dict[str, Any] | None, lang: str) -> str | None:
    """Return the raw fastText label when the selected language was reached via an explicit alias."""
    if not fasttext_result or not lang:
        return None

    def _hint(entry: dict[str, Any], lang_key: str, raw_key: str) -> str | None:
        # Matches the selected language AND arrived through an alias.
        if entry.get(lang_key) == lang and fasttext_label_is_proxy(str(entry.get(raw_key, ""))):
            return str(entry.get(raw_key))
        return None

    for item in fasttext_result.get("predictions", []):
        hint = _hint(item, "lang", "raw_label")
        if hint is not None:
            return hint

    for sentence in fasttext_result.get("sentences", []):
        hint = _hint(sentence, "top_family", "top_raw_label")
        if hint is not None:
            return hint
        for item in sentence.get("predictions", []):
            hint = _hint(item, "lang", "raw_label")
            if hint is not None:
                return hint

    return None
|
|
|
|
def predict_multilabel_model(text: str) -> tuple[list[tuple[str, dict[str, float | int]]], dict[str, float], float]:
    """Return ranked multi-label language scores for one text.

    Returns a tuple of:
      * ranked ``(lang, stats)`` pairs shaped like ``build_lang_stats`` output,
      * a ``lang -> sigmoid score`` map, and
      * the raw score of the top-ranked language (0.0 when nothing was kept).
    """
    tokenizer = get_tokenizer()
    model = get_multilabel_model()
    # NOTE(review): the lru-cached model is moved to the device on every call;
    # a no-op once resident, but confirm this is intended for multi-GPU setups.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    encoded = tokenizer(
        text,
        truncation=True,
        padding=True,
        max_length=512,
        return_tensors="pt",
    )
    encoded = {key: value.to(device) for key, value in encoded.items()}

    with torch.inference_mode():
        logits = model(**encoded).logits[0]
        # Multi-label head: independent per-language sigmoid probabilities.
        probs = torch.sigmoid(logits).detach().cpu().tolist()

    # Config id2label keys may be strings; coerce to int for indexing.
    id2label = {
        int(key): str(value)
        for key, value in getattr(model.config, "id2label", {}).items()
    }
    ranked_pairs = sorted(
        ((idx, float(score)) for idx, score in enumerate(probs)),
        key=lambda item: item[1],
        reverse=True,
    )
    # Keep at most TOP_N labels above the score floor; if nothing qualifies,
    # fall back to the single best label so the UI always has a prediction.
    kept_pairs = [(idx, score) for idx, score in ranked_pairs[:MULTI_LABEL_TOP_N] if score >= MULTI_LABEL_MIN_SCORE]
    if not kept_pairs:
        kept_pairs = ranked_pairs[:1]

    # `or 1.0` guards against a zero total when every kept score is 0.
    total_score = sum(score for _, score in kept_pairs) or 1.0
    lang_stats: dict[str, dict[str, float | int]] = {}
    classifier_scores: dict[str, float] = {}
    for idx, score in kept_pairs:
        label = id2label.get(idx, f"label_{idx}")
        lang = canonical_lang(label.replace("B-", "").replace("I-", "").lower())
        # Skip the outside tag and any language already seen (first = best score).
        if not lang or lang == "o" or lang in lang_stats:
            continue
        # Shape the entry like build_lang_stats output so downstream renderers
        # can treat both models uniformly.
        normalized_share = score / total_score
        lang_stats[lang] = {
            "char_coverage": normalized_share * len(text),
            "coverage_pct": normalized_share,
            "avg_confidence": score,
            "entity_count": 1,
            "rank_score": normalized_share * score,
        }
        classifier_scores[lang] = score

    if not lang_stats:
        return [], {}, 0.0

    ranked = sorted(lang_stats.items(), key=lambda item: item[1]["rank_score"], reverse=True)
    top_score = float(classifier_scores.get(ranked[0][0], 0.0))
    return ranked, classifier_scores, top_score
|
|
|
|
def fetch_random_cached_sentence() -> dict[str, Any]:
    """Randomly sample a sentence from either cached source."""
    order = list(RANDOM_SENTENCE_SAMPLERS)
    random.shuffle(order)
    missing_cache: FileNotFoundError | None = None
    for sampler in order:
        try:
            return sampler()
        except FileNotFoundError as exc:
            # Remember the failure and try the next sampler.
            missing_cache = exc
    if missing_cache is not None:
        raise missing_cache
    raise RuntimeError("No cached sentence samplers are registered.")
|
|
|
|
def fetch_random_cached_sentence_mix() -> dict[str, Any]:
    """Randomly sample a mixed-language example from either cached source."""
    order = list(RANDOM_MIX_SAMPLERS)
    random.shuffle(order)
    missing_cache: FileNotFoundError | None = None
    for sampler in order:
        try:
            return sampler()
        except FileNotFoundError as exc:
            # Remember the failure and try the next sampler.
            missing_cache = exc
    if missing_cache is not None:
        raise missing_cache
    raise RuntimeError("No cached mix samplers are registered.")
|
|
|
|
def render_prediction_body(
    *,
    text: str,
    selected_lang: str,
    dominant_lang: str,
    lang_stats: dict[str, dict[str, float | int]],
    classifier_scores: dict[str, float],
    fasttext_result: dict[str, Any] | None,
    overall_confidence: float,
    ignored_artifacts: int,
    model_label: str = "Token classifier",
) -> str:
    """Render the prediction card for a selected language.

    ``fasttext_result`` is accepted for signature parity with callers; it is
    not used in the rendered card.
    """
    selected_stat = lang_stats[selected_lang]
    iso3 = LANG_ISO2_TO_ISO3.get(selected_lang, "n/a")
    selected_score = classifier_scores.get(selected_lang, 0.0)
    tagged_chars = sum(float(entry["char_coverage"]) for entry in lang_stats.values())
    view_label = "Dominant" if selected_lang == dominant_lang else "Selected"

    notes: list[str] = []
    if overall_confidence < 0.75:
        notes.append(f"Low confidence overall: {overall_confidence:.2f}")
    if selected_lang != dominant_lang:
        notes.append(f"Top prediction: {dominant_lang.upper()}")
    warning_html = "".join(f"<div class='ambiguity-warning'>{note}</div>" for note in notes)

    meter_pct = max(0.0, min(100.0, overall_confidence * 100.0))
    return f"""
<div class="summary-kicker" data-summary-kicker>Prediction</div>
<div class="summary-main" data-summary-main>{selected_lang.upper()}</div>
<div class="summary-note" data-summary-note>{model_label} · {view_label.lower()} view · derived score {selected_score:.1%}</div>
{warning_html}
<div class="metric-grid">
<div class="metric">
<span class="metric-label">ISO-3</span>
<span class="metric-value" data-metric-iso3>{iso3}</span>
</div>
<div class="metric">
<span class="metric-label">Derived score</span>
<span class="metric-value" data-metric-derived-score><strong>{selected_score:.1%}</strong></span>
</div>
<div class="metric">
<span class="metric-label">Coverage</span>
<span class="metric-value" data-metric-coverage>{float(selected_stat['coverage_pct']):.1%}</span>
</div>
<div class="metric">
<span class="metric-label">Avg confidence</span>
<span class="metric-value" data-metric-avg-confidence>{float(selected_stat['avg_confidence']):.3f}</span>
</div>
<div class="metric" style="grid-column: 1 / -1; display: flex; justify-content: space-between; align-items: baseline; gap: 12px;">
<span class="metric-label" style="margin: 0;">Tagged chars</span>
<span class="metric-value" data-metric-tagged-chars>{tagged_chars:.0f} / {len(text)}</span>
</div>
</div>
<div class="meter">
<div class="meter-head">
<span class="metric-label">Overall confidence</span>
<span class="meter-value" data-metric-overall-confidence>{overall_confidence:.3f}</span>
</div>
<div class="meter-track">
<div class="meter-fill" style="width: {meter_pct:.1f}%"></div>
</div>
</div>
<div class="summary-note" data-metric-ignored-artifacts>Ignored artifacts: {ignored_artifacts}</div>
"""
|
|
|
|
def render_comparison_strip(
    *,
    token_summary: dict[str, Any],
    multi_summary: dict[str, Any],
    reference_summary: dict[str, Any],
) -> str:
    """Render the three-way model comparison strip (token / multi-label / fastText)."""

    def _pick(summary: dict[str, Any]) -> tuple[str, float]:
        # Empty/missing language falls back to "n/a"; score defaults to 0.
        return summary.get("selected_lang", "") or "n/a", float(summary.get("selected_score", 0.0))

    token_label, token_score = _pick(token_summary)
    multi_label, multi_score = _pick(multi_summary)
    reference_label, reference_score = _pick(reference_summary)
    return f"""
<div class="comparison-strip">
<div class="comparison-kicker">Three-way comparison</div>
<div class="comparison-row">
<div class="comparison-pill comparison-token" title="Token classifier">
<span class="comparison-label">Token classifier</span>
<span class="comparison-value">{token_label.upper()} {token_score:.1%}</span>
</div>
<div class="comparison-pill comparison-multi" title="Multi-label model">
<span class="comparison-label">Multi-label</span>
<span class="comparison-value">{multi_label.upper()} {multi_score:.1%}</span>
</div>
<div class="comparison-pill comparison-ref" title="Reference baseline">
<span class="comparison-label">fastText</span>
<span class="comparison-value">{reference_label.upper()} {reference_score:.1%}</span>
</div>
</div>
</div>
"""
|
|
|
|
| def fasttext_mode_from_choice(choice: str | None) -> str: |
| choice = (choice or "").strip().lower() |
| return "sentences" if choice in {"sentences", "sentence by sentence"} else "full" |
|
|
|
|
def render_selected_language_summary(ui_state: dict[str, Any], selected_lang: str) -> str:
    """Redraw the summary card for a clicked language chip."""
    if not ui_state:
        return """
        <div class="summary-card">
            <div class="summary-kicker">Prediction</div>
            <div class="summary-main">No language selected</div>
        </div>
        """

    stats = ui_state.get("lang_stats", {})
    # Fall back to the state's own selection, then the dominant language,
    # when the clicked language has no stats entry.
    if selected_lang not in stats:
        selected_lang = ui_state.get("selected_lang") or ui_state.get("dominant_lang") or ""
    if not selected_lang:
        return """
        <div class="summary-card">
            <div class="summary-kicker">Prediction</div>
            <div class="summary-main">No language selected</div>
        </div>
        """

    updated_state = dict(ui_state, selected_lang=selected_lang)
    return render_prediction_body(
        text=updated_state.get("text", ""),
        selected_lang=selected_lang,
        dominant_lang=updated_state.get("dominant_lang", selected_lang),
        lang_stats=stats,
        classifier_scores=updated_state.get("classifier_scores", {}),
        fasttext_result=updated_state.get("reference") or updated_state.get("fasttext"),
        overall_confidence=float(updated_state.get("overall_confidence", 0.0)),
        ignored_artifacts=int(updated_state.get("ignored_artifacts", 0)),
        model_label=str(updated_state.get("model_label", "Token classifier")),
    )
|
|
|
|
def make_model_summary(
    *,
    text: str,
    selected_lang: str,
    lang_stats: dict[str, dict[str, float | int]],
    classifier_scores: dict[str, float],
    overall_confidence: float,
    model_label: str,
    fasttext_result: dict[str, Any] | None,
    ignored_artifacts: int,
) -> dict[str, Any]:
    """Package one model's prediction into the summary dict consumed by the UI."""
    score = classifier_scores.get(selected_lang, 0.0)
    body_html = render_prediction_body(
        text=text,
        selected_lang=selected_lang,
        dominant_lang=selected_lang,
        lang_stats=lang_stats,
        classifier_scores=classifier_scores,
        fasttext_result=fasttext_result,
        overall_confidence=overall_confidence,
        ignored_artifacts=ignored_artifacts,
        model_label=model_label,
    )
    return {
        "selected_lang": selected_lang,
        "selected_score": score,
        "html": body_html,
    }
|
|
|
|
def render_fasttext_summary(ui_state: dict[str, Any], selected_lang: str) -> str:
    """Render a summary card for the fastText view.

    Pulls the fastText result out of ``ui_state`` (preferring the "reference"
    key, falling back to "fasttext"), looks up ``selected_lang``'s score, and
    returns the card's HTML fragment with an optional proxy-alias warning.
    """
    fasttext_result = ui_state.get("reference") or ui_state.get("fasttext") or {}
    predictions = fasttext_result.get("predictions", [])
    fasttext_scores = {item["lang"]: float(item["score"]) for item in predictions}
    score = fasttext_scores.get(selected_lang, 0.0)
    top_lang = fasttext_result.get("top_lang") or selected_lang
    top_score = float(fasttext_result.get("top_score", 0.0))
    mode = fasttext_result.get("mode", "full")
    mode_label = "sentence" if mode == "sentences" else "full text"
    # Fix: the previously computed `top_raw_label` local was never referenced
    # by the template below, so it has been removed.
    alias_hint = fasttext_alias_hint_for_lang(fasttext_result, selected_lang)

    warnings = []
    if alias_hint:
        warnings.append(f"Proxy hint: {alias_hint} -> {selected_lang.upper()}")
    warning_html = "".join(f"<div class='ambiguity-warning'>{note}</div>" for note in warnings)

    return f"""
    <div class="summary-kicker" data-summary-kicker>fastText</div>
    <div class="summary-main" data-summary-main>{selected_lang.upper()}</div>
    <div class="summary-note" data-summary-note>Mode: {mode_label} · top prediction {top_lang.upper()} {top_score:.1%}</div>
    {warning_html}
    <div class="metric-grid">
        <div class="metric">
            <span class="metric-label">fastText score</span>
            <span class="metric-value" data-metric-fasttext-score><strong>{score:.1%}</strong></span>
        </div>
        <div class="metric">
            <span class="metric-label">Top score</span>
            <span class="metric-value" data-metric-top-score>{top_score:.1%}</span>
        </div>
    </div>
    """
|
|
|
|
def render_multilabel_summary(ui_state: dict[str, Any], selected_lang: str) -> str:
    """Render a compact summary card for the multi-label model.

    Falls back to the state's own selection (then the dominant language) when
    the requested language has no stats entry, and renders the empty card when
    no usable language remains.
    """
    lang_stats = ui_state.get("lang_stats", {})
    classifier_scores = ui_state.get("classifier_scores", {})
    if selected_lang not in lang_stats:
        selected_lang = ui_state.get("selected_lang") or ui_state.get("dominant_lang") or ""
    # Fix: also bail out when the fallback language is absent from lang_stats;
    # previously `lang_stats[selected_lang]` below could raise KeyError for a
    # stale selection carried over in ui_state.
    if not selected_lang or selected_lang not in lang_stats:
        return """
        <div class="summary-kicker">Multi-label</div>
        <div class="summary-main">No multi-label prediction</div>
        """

    stat = lang_stats[selected_lang]
    iso3 = LANG_ISO2_TO_ISO3.get(selected_lang, "n/a")
    selected_score = classifier_scores.get(selected_lang, 0.0)
    top_lang = ui_state.get("dominant_lang") or selected_lang
    top_score = float(classifier_scores.get(top_lang, selected_score))
    sequence_score = float(ui_state.get("overall_confidence", top_score))
    return f"""
    <div class="summary-kicker" data-summary-kicker>Multi-label</div>
    <div class="summary-main" data-summary-main>{selected_lang.upper()}</div>
    <div class="summary-note" data-summary-note>Multi-label · top prediction {top_lang.upper()} {top_score:.1%}</div>
    <div class="metric-grid">
        <div class="metric">
            <span class="metric-label">Top score</span>
            <span class="metric-value" data-metric-top-score><strong>{selected_score:.1%}</strong></span>
        </div>
        <div class="metric">
            <span class="metric-label">Seq score</span>
            <span class="metric-value" data-metric-seq-score>{sequence_score:.1%}</span>
        </div>
        <div class="metric">
            <span class="metric-label">ISO-3</span>
            <span class="metric-value" data-metric-iso3>{iso3}</span>
        </div>
        <div class="metric">
            <span class="metric-label">Coverage</span>
            <span class="metric-value" data-metric-coverage>{float(stat['coverage_pct']):.1%}</span>
        </div>
    </div>
    """
|
|
|
|
def render_prediction_bundle(
    *,
    token_ui_state: dict[str, Any],
    multi_ui_state: dict[str, Any],
    reference_ui_state: dict[str, Any],
    active_model: str = "token",
) -> str:
    """Render the main interactive prediction card with in-card model switching.

    Produces tab buttons, up to three language chips, and one panel per model
    (token / multi / reference). The ``data-*`` attributes mirror each model's
    state; presumably they are consumed by custom front-end JS to switch panels
    and re-score chips client-side — confirm against the page scripts.

    Fix: removed the unused locals ``active_lang``, ``active_label`` and
    ``active_score`` (they were computed but never referenced).
    """
    models = {
        "token": {
            "label": "Token classifier",
            "short": "T",
            "ui_state": token_ui_state,
            "summary": render_selected_language_summary(token_ui_state, token_ui_state.get("selected_lang", "")) if token_ui_state.get("selected_lang") else "<div class='summary-subtitle'>No token prediction</div>",
        },
        "multi": {
            "label": "Multi-label",
            "short": "M",
            "ui_state": multi_ui_state,
            "summary": render_multilabel_summary(multi_ui_state, multi_ui_state.get("selected_lang", "")),
        },
        "reference": {
            "label": "FastText",
            "short": "FT",
            "ui_state": reference_ui_state,
            "summary": render_fasttext_summary(reference_ui_state, reference_ui_state.get("selected_lang", "")),
        },
    }
    # Unknown keys silently fall back to the token panel.
    active_model = active_model if active_model in models else "token"
    active_state = models[active_model]["ui_state"]
    top_langs = active_state.get("chip_langs", [])
    active_scores = {
        "token": token_ui_state.get("classifier_scores", {}),
        "multi": multi_ui_state.get("classifier_scores", {}),
        "reference": reference_ui_state.get("classifier_scores", {}),
    }
    token_lang_stats = token_ui_state.get("lang_stats", {})
    multi_lang_stats = multi_ui_state.get("lang_stats", {})
    reference_lang_stats = reference_ui_state.get("lang_stats", {})
    token_stats_json = json.dumps(token_lang_stats, ensure_ascii=True)
    multi_stats_json = json.dumps(multi_lang_stats, ensure_ascii=True)
    reference_stats_json = json.dumps(reference_lang_stats, ensure_ascii=True)
    # At most three chips; each carries all three models' scores so the
    # front-end can re-label chips when the active model changes.
    chips_html = "".join(
        f"""
        <button type="button" class="prediction-chip{' is-active' if idx == 0 else ''}" data-lang-toggle="{idx}" data-lang-code="{chip['lang']}" data-token-score="{float(token_ui_state.get('classifier_scores', {}).get(chip['lang'], 0.0)):.4f}" data-multi-score="{float(multi_ui_state.get('classifier_scores', {}).get(chip['lang'], 0.0)):.4f}" data-reference-score="{float(reference_ui_state.get('classifier_scores', {}).get(chip['lang'], 0.0)):.4f}">
            <span class="prediction-chip-dot"></span>
            <span class="prediction-chip-lang">{chip['lang'].upper()}</span>
            <span class="prediction-chip-score">{float(active_scores[active_model].get(chip['lang'], 0.0)):.1%}</span>
        </button>
        """
        for idx, chip in enumerate(top_langs[:3])
    ) or "<div class='prediction-chip-empty'>No ranked languages</div>"

    def _model_tab(key: str, label: str, short: str) -> str:
        # One tab button; the active model gets the highlighted style.
        active_class = " is-active" if key == active_model else ""
        return f"""
        <button type="button" class="prediction-model-tab{active_class}" data-model-toggle="{key}">
            <span class="prediction-model-short">{short}</span>
            <span class="prediction-model-label">{label}</span>
        </button>
        """

    return f"""
    <div class="prediction-card" data-prediction-card>
        <div class="prediction-tabs">
            {_model_tab("token", "Token classifier", "T")}
            {_model_tab("multi", "Multi-label", "M")}
            {_model_tab("reference", "FastText", "FT")}
        </div>
        <div class="prediction-chip-row">
            <div class="prediction-chip-title">Language</div>
            <div class="prediction-chip-list">{chips_html}</div>
        </div>
        <div class="prediction-body">
            <div class="prediction-panel is-active" data-model-panel="token" data-model-label="Token classifier" data-stats='{token_stats_json}' data-selected-lang="{token_ui_state.get('selected_lang', '')}" data-dominant-lang="{token_ui_state.get('dominant_lang', '')}" data-iso3="{LANG_ISO2_TO_ISO3.get(token_ui_state.get('selected_lang', ''), 'n/a')}" data-derived-score="{float(token_ui_state.get('classifier_scores', {}).get(token_ui_state.get('selected_lang', ''), 0.0)):.4f}" data-coverage="{float(token_ui_state.get('lang_stats', {}).get(token_ui_state.get('selected_lang', ''), {}).get('coverage_pct', 0.0)):.4f}" data-avg-confidence="{float(token_ui_state.get('lang_stats', {}).get(token_ui_state.get('selected_lang', ''), {}).get('avg_confidence', 0.0)):.4f}" data-tagged-chars="{sum(float(s['char_coverage']) for s in token_ui_state.get('lang_stats', {}).values()):.0f}" data-total-chars="{len(token_ui_state.get('text', ''))}" data-overall-confidence="{float(token_ui_state.get('overall_confidence', 0.0)):.4f}" data-ignored-artifacts="{int(token_ui_state.get('ignored_artifacts', 0))}">{models["token"]["summary"]}</div>
            <div class="prediction-panel" data-model-panel="multi" data-model-label="Multi-label" data-stats='{multi_stats_json}' data-selected-lang="{multi_ui_state.get('selected_lang', '')}" data-dominant-lang="{multi_ui_state.get('dominant_lang', '')}" data-iso3="{LANG_ISO2_TO_ISO3.get(multi_ui_state.get('selected_lang', ''), 'n/a')}" data-top-score="{float(multi_ui_state.get('classifier_scores', {}).get(multi_ui_state.get('selected_lang', ''), 0.0)):.4f}" data-seq-score="{float(multi_ui_state.get('overall_confidence', 0.0)):.4f}" data-coverage="{float(multi_ui_state.get('lang_stats', {}).get(multi_ui_state.get('selected_lang', ''), {}).get('coverage_pct', 0.0)):.4f}">{models["multi"]["summary"]}</div>
            <div class="prediction-panel" data-model-panel="reference" data-model-label="FastText" data-stats='{reference_stats_json}' data-selected-lang="{reference_ui_state.get('selected_lang', '')}" data-dominant-lang="{reference_ui_state.get('dominant_lang', '')}" data-top-lang="{reference_ui_state.get('selected_lang', '')}" data-top-score="{float(reference_ui_state.get('classifier_scores', {}).get(reference_ui_state.get('selected_lang', ''), 0.0)):.4f}" data-fasttext-score="{float(reference_ui_state.get('classifier_scores', {}).get(reference_ui_state.get('selected_lang', ''), 0.0)):.4f}" data-mode="{reference_ui_state.get('fasttext', {}).get('mode', 'full')}">{models["reference"]["summary"]}</div>
        </div>
    </div>
    """
|
|
|
|
def select_language_from_chip(chip_index: int, ui_state: dict[str, Any]) -> tuple[str, dict[str, Any]]:
    """Pick a ranked language chip and redraw the summary for that language."""
    if not ui_state:
        return "<div class='empty-state'>Run a prediction first.</div>", {}

    chips = ui_state.get("chip_langs", [])
    if not chips:
        return "<div class='empty-state'>No ranked languages available.</div>", ui_state

    # Clamp out-of-range indices into the valid chip range.
    idx = min(max(int(chip_index), 0), len(chips) - 1)
    chip = chips[idx]
    lang = chip["lang"]
    source = chip.get("source", "model")
    # fastText chips get the fastText card; everything else uses the model card.
    renderer = render_fasttext_summary if source == "fasttext" else render_selected_language_summary
    summary = renderer(ui_state, lang)
    return summary, {**ui_state, "selected_lang": lang, "selected_source": source}
|
|
|
|
def is_artifact_span(span_len: int, score: float) -> bool:
    """Identify tiny, low-confidence spans that are likely trailing noise."""
    too_short = span_len < MIN_ARTIFACT_SPAN_CHARS
    too_uncertain = score < MIN_ARTIFACT_CONFIDENCE
    return too_short and too_uncertain
|
|
|
|
def predict(
    text: str,
    fasttext_mode: str = "full",
    validation_expected_langs: list[str] | None = None,
    validation_source_label: str = "Example",
) -> tuple[Any, ...]:
    """Run all three language-ID models on ``text`` and build every UI output.

    Returns an 11-element tuple:
    ``(summary_html, spans_df, raw_json_dict, token_ui_state,
    comparison_html + validation_html, *six_chip_button_updates)``.
    (The previous annotation listed ten dicts, which did not match the actual
    runtime tuple.)

    Parameters:
        text: raw input; stripped, with a placeholder tuple returned when blank.
        fasttext_mode: UI choice mapped through ``fasttext_mode_from_choice``.
        validation_expected_langs: optional gold languages; when given, a
            validation card is appended to the comparison HTML.
        validation_source_label: label shown on that validation card.
    """
    text = (text or "").strip()
    if not text:
        # Blank input: emit placeholders with the same arity as the populated
        # path below (5 fixed outputs + 6 hidden chip buttons).
        empty = pd.DataFrame(columns=["token", "language", "score", "start", "end"])
        hidden = {}
        hidden_buttons = [gr.update(value="", visible=False) for _ in range(6)]
        return (
            "<div class='empty-state'>Paste some text to see the model's language signal.</div>",
            empty,
            {},
            hidden,
            "",
            *hidden_buttons,
        )

    # --- Token-classifier pass plus the fastText reference pass. ---
    nlp = get_pipeline()
    entities = nlp(text)
    fasttext_result = predict_fasttext(text, mode=fasttext_mode_from_choice(fasttext_mode))
    fasttext_scores = {item["lang"]: item["score"] for item in fasttext_result.get("predictions", [])}
    rows: list[dict[str, Any]] = []
    token_counts: Counter[str] = Counter()

    for entity in entities:
        # Pipeline output may carry "entity_group" (aggregated) or "entity".
        label = normalize_label(entity.get("entity_group", entity.get("entity", "O")))
        if label == "o":
            # "o" marks non-language / outside spans; skip them entirely.
            continue
        token_counts[label] += 1
        rows.append(
            {
                "token": entity.get("word", ""),
                "language": label,
                "score": round(float(entity.get("score", 0.0)), 4),
                "start": entity.get("start", None),
                "end": entity.get("end", None),
            }
        )

    # Table of per-token spans, ordered by character offsets.
    spans = pd.DataFrame(rows, columns=["token", "language", "score", "start", "end"])
    spans = spans.sort_values(by=["start", "end"], na_position="last") if not spans.empty else spans
    token_lang_stats, token_ignored_artifacts = build_lang_stats(entities)
    token_ranked = sorted(token_lang_stats.items(), key=lambda item: item[1]["rank_score"], reverse=True)
    token_classifier_scores = to_classifier_scores(token_lang_stats)
    token_dominant_lang = token_ranked[0][0] if token_ranked else ""
    # Coverage-weighted mean of per-language confidence.
    # NOTE(review): if lang_stats is non-empty but every char_coverage is 0,
    # this divides by zero — confirm build_lang_stats precludes that case.
    token_overall_confidence = (
        sum(float(stat["char_coverage"]) * float(stat["avg_confidence"]) for stat in token_lang_stats.values()) / sum(float(stat["char_coverage"]) for stat in token_lang_stats.values())
        if token_lang_stats
        else 0.0
    )
    if token_lang_stats:
        token_ui_state = build_ui_state(
            text=text,
            lang_stats=token_lang_stats,
            classifier_scores=token_classifier_scores,
            fasttext_result=fasttext_result,
            model_label="Token classifier",
            dominant_lang=token_dominant_lang,
            overall_confidence=token_overall_confidence,
            ignored_artifacts=token_ignored_artifacts,
        )
    else:
        # NOTE(review): this `summary` assignment appears dead — `summary` is
        # unconditionally rebuilt by render_prediction_bundle further down.
        summary = """
        <div class="summary-card">
            <div class="summary-kicker">Prediction</div>
            <div class="summary-main">No language spans detected</div>
            <div class="summary-subtitle">Try a longer sample or a cleaner single-language paragraph.</div>
        </div>
        """
        # Minimal empty UI state so downstream renderers still have all keys.
        token_ui_state = {
            "text": text,
            "lang_stats": {},
            "classifier_scores": {},
            "fasttext": fasttext_result,
            "reference": fasttext_result,
            "model_label": "Token classifier",
            "dominant_lang": "",
            "selected_lang": "",
            "overall_confidence": 0.0,
            "ignored_artifacts": 0,
            "ranked_langs": [],
            "chip_langs": [],
        }

    # --- Multi-label model pass. ---
    multi_ranked, multi_classifier_scores, multi_top_score = predict_multilabel_model(text)
    if multi_ranked:
        multi_lang_stats = {lang: stat for lang, stat in multi_ranked}
        multi_dominant_lang = multi_ranked[0][0]
        multi_ui_state = build_ui_state(
            text=text,
            lang_stats=multi_lang_stats,
            classifier_scores=multi_classifier_scores,
            fasttext_result=fasttext_result,
            model_label="Multi-label",
            dominant_lang=multi_dominant_lang,
            overall_confidence=multi_top_score,
            ignored_artifacts=0,
        )
    else:
        multi_lang_stats = {}
        multi_classifier_scores = {}
        multi_dominant_lang = ""
        # NOTE(review): `multi_summary` is never referenced afterwards —
        # appears to be dead code.
        multi_summary = """
        <div class="summary-card">
            <div class="summary-kicker">Prediction</div>
            <div class="summary-main">No multi-label predictions</div>
            <div class="summary-subtitle">Try a longer sample or a cleaner multilingual paragraph.</div>
        </div>
        """
        multi_ui_state = {
            "text": text,
            "lang_stats": {},
            "classifier_scores": {},
            "fasttext": fasttext_result,
            "reference": fasttext_result,
            "model_label": "Multi-label",
            "dominant_lang": "",
            "selected_lang": "",
            "overall_confidence": 0.0,
            "ignored_artifacts": 0,
            "ranked_langs": [],
            "chip_langs": [],
        }

    # --- fastText reference state, reshaped to mimic the model stats dicts
    # (each fastText probability doubles as coverage/confidence/rank). ---
    reference_lang_stats = {lang: {"coverage_pct": score, "avg_confidence": score, "char_coverage": score, "entity_count": 1, "rank_score": score} for lang, score in fasttext_scores.items()}
    reference_ranked = sorted(reference_lang_stats.items(), key=lambda item: item[1]["rank_score"], reverse=True)
    reference_selected_lang = reference_ranked[0][0] if reference_ranked else ""
    reference_ui_state = {
        "text": text,
        "lang_stats": reference_lang_stats,
        "classifier_scores": fasttext_scores,
        "fasttext": fasttext_result,
        "reference": fasttext_result,
        "model_label": "Reference",
        "dominant_lang": reference_selected_lang,
        "selected_lang": reference_selected_lang,
        "overall_confidence": float(fasttext_result.get("top_score", 0.0)),
        "ignored_artifacts": 0,
        "ranked_langs": [lang for lang, _ in reference_ranked],
        "chip_langs": [{"lang": lang, "source": "reference"} for lang, _ in reference_ranked],
    }
    # Interactive card with all three panels, defaulting to the token view.
    summary = render_prediction_bundle(
        token_ui_state=token_ui_state,
        multi_ui_state=multi_ui_state,
        reference_ui_state=reference_ui_state,
        active_model="token",
    )
    comparison_html = render_comparison_strip(
        token_summary={"selected_lang": token_dominant_lang, "selected_score": token_classifier_scores.get(token_dominant_lang, 0.0)},
        multi_summary={"selected_lang": multi_dominant_lang, "selected_score": multi_classifier_scores.get(multi_dominant_lang, 0.0)},
        reference_summary={"selected_lang": reference_selected_lang, "selected_score": float(fasttext_scores.get(reference_selected_lang, 0.0))},
    )

    # Machine-readable dump of all three models for the "raw" JSON panel.
    # Scores are pre-formatted to 3 decimals as strings for display.
    raw = {
        "models": {
            "token_classifier": {
                "name": MODEL_CHECKPOINT,
                "kind": "token-classification",
                "top_predictions": token_counts.most_common(10),
                "classifier_scores": token_classifier_scores if token_lang_stats else {},
                "overall_confidence": f"{token_overall_confidence:.3f}" if token_lang_stats else "0.000",
                "ignored_artifacts": token_ignored_artifacts,
                "lang_stats": {
                    lang: {
                        **stat,
                        "coverage_pct": f"{float(stat['coverage_pct']):.3f}",
                        "avg_confidence": f"{float(stat['avg_confidence']):.3f}",
                        "rank_score": f"{float(stat['rank_score']):.3f}",
                    }
                    for lang, stat in token_ranked
                },
                "selected_lang": token_dominant_lang,
                "ranked_langs": [lang for lang, _ in token_ranked],
            },
            "multi_label": {
                "name": MULTI_MODEL,
                "kind": "multi-label-classification",
                "top_predictions": [(lang, round(float(multi_classifier_scores.get(lang, 0.0)), 4)) for lang, _ in multi_ranked[:10]],
                "classifier_scores": multi_classifier_scores,
                "overall_confidence": f"{multi_top_score:.3f}" if multi_ranked else "0.000",
                "ignored_artifacts": 0,
                "lang_stats": {
                    lang: {
                        **stat,
                        "coverage_pct": f"{float(stat['coverage_pct']):.3f}",
                        "avg_confidence": f"{float(stat['avg_confidence']):.3f}",
                        "rank_score": f"{float(stat['rank_score']):.3f}",
                    }
                    for lang, stat in multi_ranked
                },
                "selected_lang": multi_dominant_lang,
                "ranked_langs": [lang for lang, _ in multi_ranked],
            },
            "reference": {
                "name": FASTTEXT_MODEL_REPO,
                "kind": "fasttext-language-id",
                "top_predictions": fasttext_result.get("predictions", []),
                "classifier_scores": fasttext_scores,
                "overall_confidence": f"{float(fasttext_result.get('top_score', 0.0)):.3f}",
                "selected_lang": reference_selected_lang,
                "ranked_langs": [lang for lang, _ in reference_ranked],
            },
        },
        "text": text,
    }
    # Six chip-button updates (hidden when the token model found nothing).
    chip_updates = build_chip_button_updates(token_ranked, token_classifier_scores, multi_classifier_scores, fasttext_scores) if token_lang_stats else [gr.update(value="", visible=False) for _ in range(6)]

    # Optional validation card comparing predictions against expected langs.
    validation_html = ""
    if validation_expected_langs:
        validation = build_example_validation(
            token_classifier_scores,
            multi_classifier_scores,
            fasttext_scores,
            validation_expected_langs,
        )
        validation_html = render_validation_html(validation, source_label=validation_source_label)

    return summary, spans, raw, token_ui_state, comparison_html + validation_html, *chip_updates
|
|
|
|
def load_random_tatoeba_example(
    fasttext_mode: str = "full",
) -> tuple[str, str, pd.DataFrame, dict[str, Any], dict[str, Any], str]:
    """Fetch one random Tatoeba sentence, run the full prediction pipeline, and
    enrich the raw payload with the sample's provenance and validation info."""
    sample = fetch_random_tatoeba_sentence()
    sample_text = sample["text"]
    summary, spans, raw, ui_state, _, *chip_updates = predict(sample_text, fasttext_mode=fasttext_mode)
    models = raw.get("models", {})
    rows = sample.get("sentences") or [sample]
    id_keys = _sentence_id_keys(sample)
    langs = [row.get("lang_iso2", "") for row in rows]
    iso3s = [row.get("lang_iso3", "") for row in rows]
    validation = build_example_validation(
        models.get("token_classifier", {}).get("classifier_scores", {}),
        models.get("multi_label", {}).get("classifier_scores", {}),
        models.get("reference", {}).get("classifier_scores", {}),
        langs,
    )
    source = sample.get("source", "tatoeba")
    raw = {
        **raw,
        "source": "tatoeba",
        "sentence_id": id_keys[0] if id_keys else sample.get("sentence_id", sample.get("id")),
        "sentence_ids": id_keys,
        "lang_count": sample.get("lang_count", len(rows)),
        "sentence_langs": langs,
        "sentence_lang_iso3s": iso3s,
        "sentences": rows,
        "sentence_lang": sample.get("source_lang", sample.get("lang")),
        "sentence_lang_iso2": sample.get("lang_iso2", sample.get("source_lang")),
        "sentence_lang_iso3": sample.get("lang_iso3", ""),
        _validation_key(source): validation,
    }
    validation_html = render_validation_html(validation, source_label=_source_label(source))
    return sample_text, summary, spans, raw, ui_state, validation_html, *chip_updates
|
|
|
|
def load_random_tatoeba_mix_example(
    fasttext_mode: str = "full",
) -> tuple[str, str, pd.DataFrame, dict[str, Any], dict[str, Any], str]:
    """Fetch a random multilingual Tatoeba mix, run the pipeline, and attach
    the mix's language metadata plus validation results to the raw payload."""
    sample = fetch_random_tatoeba_sentence_mix()
    sample_text = sample["text"]
    summary, spans, raw, ui_state, _, *chip_updates = predict(sample_text, fasttext_mode=fasttext_mode)
    models = raw.get("models", {})
    validation = build_example_validation(
        models.get("token_classifier", {}).get("classifier_scores", {}),
        models.get("multi_label", {}).get("classifier_scores", {}),
        models.get("reference", {}).get("classifier_scores", {}),
        sample.get("langs", []),
    )
    source = sample.get("source", "tatoeba-mix")
    raw = {
        **raw,
        "source": "tatoeba-mix",
        "lang_count": sample["lang_count"],
        "sentence_langs": sample["langs"],
        "sentence_lang_iso3s": sample["lang_iso3s"],
        "sentences": sample["sentences"],
        _validation_key(source): validation,
    }
    validation_html = render_validation_html(validation, source_label=_source_label(source))
    return sample_text, summary, spans, raw, ui_state, validation_html, *chip_updates
|
|
|
|
def load_random_fleurs_example(
    fasttext_mode: str = "full",
) -> tuple[str, str, pd.DataFrame, dict[str, Any], dict[str, Any], str]:
    """Fetch a random cached FLEURS sentence, run the pipeline, and attach the
    cache metadata and validation results to the raw payload."""
    try:
        sample = fetch_random_cached_sentence()
    except FileNotFoundError as exc:
        # Cache missing: show the error in the summary slot and keep the
        # output tuple arity intact (hidden chip buttons included).
        empty = pd.DataFrame(columns=["token", "language", "score", "start", "end"])
        message = f"<div class='empty-state'>{exc}</div>"
        return "", message, empty, {}, {}, "", *[gr.update(value="", visible=False) for _ in range(6)]
    sample_text = sample["text"]
    summary, spans, raw, ui_state, _, *chip_updates = predict(sample_text, fasttext_mode=fasttext_mode)
    models = raw.get("models", {})
    rows = sample.get("sentences") or [sample]
    id_keys = _sentence_id_keys(sample)
    langs = [row.get("lang_iso2", "") for row in rows]
    iso3s = [row.get("lang_iso3", "") for row in rows]
    validation = build_example_validation(
        models.get("token_classifier", {}).get("classifier_scores", {}),
        models.get("multi_label", {}).get("classifier_scores", {}),
        models.get("reference", {}).get("classifier_scores", {}),
        langs,
    )
    # Collect the first id key (or None) for every cached row.
    cached_ids = []
    for row in rows:
        row_keys = _sentence_id_keys(row)
        cached_ids.append(row_keys[0] if row_keys else None)
    source = sample.get("source", "fleurs")
    raw = {
        **raw,
        "source": source,
        "cached_sentence_id": id_keys[0] if id_keys else None,
        "cached_sentence_ids": cached_ids,
        "lang_count": sample.get("lang_count", len(rows)),
        "cached_split": sample.get("split"),
        "cached_source_lang": sample.get("source_lang"),
        "cached_model_lang": sample.get("model_lang", sample.get("lang_iso2")),
        "cached_language": sample.get("language"),
        "sentence_langs": langs,
        "sentence_lang_iso3s": iso3s,
        "sentences": rows,
        _validation_key(source): validation,
    }
    validation_html = render_validation_html(validation, source_label=_source_label(source))
    return sample_text, summary, spans, raw, ui_state, validation_html, *chip_updates
|
|
|
|
def load_random_mix_example(
    fasttext_mode: str = "full",
    adverse_mix: bool = False,
) -> tuple[str, str, pd.DataFrame, dict[str, Any], dict[str, Any], str]:
    """Fetch a random multilingual mix (adverse or cached), run the pipeline,
    and attach mix metadata plus validation results to the raw payload."""
    try:
        sample = fetch_random_adverse_mix() if adverse_mix else fetch_random_cached_sentence_mix()
    except FileNotFoundError as exc:
        # Cache missing: surface the error while preserving the tuple arity.
        empty = pd.DataFrame(columns=["token", "language", "score", "start", "end"])
        message = f"<div class='empty-state'>{exc}</div>"
        return "", message, empty, {}, {}, "", *[gr.update(value="", visible=False) for _ in range(6)]
    sample_text = sample["text"]
    summary, spans, raw, ui_state, _, *chip_updates = predict(sample_text, fasttext_mode=fasttext_mode)
    models = raw.get("models", {})
    validation = build_example_validation(
        models.get("token_classifier", {}).get("classifier_scores", {}),
        models.get("multi_label", {}).get("classifier_scores", {}),
        models.get("reference", {}).get("classifier_scores", {}),
        sample.get("langs", []),
    )
    source = sample.get("source", "fleurs-mix")
    raw = {
        **raw,
        "source": source,
        "lang_count": sample["lang_count"],
        "sentence_langs": sample["langs"],
        "sentence_lang_iso3s": sample["lang_iso3s"],
        "sentences": sample["sentences"],
        "adverse_mix": adverse_mix,
        "adverse_family": sample.get("family", ""),
        "adverse_family_langs": sample.get("family_langs", []),
        _validation_key(source): validation,
    }
    validation_html = render_validation_html(validation, source_label=_source_label(source))
    return sample_text, summary, spans, raw, ui_state, validation_html, *chip_updates
|
|
|
|
| CSS = """ |
| :root { |
| --bg-1: #06111f; |
| --bg-2: #0b1f33; |
| --card: rgba(10, 20, 33, 0.72); |
| --card-border: rgba(255, 255, 255, 0.12); |
| --text: #f4f7fb; |
| --muted: #b7c3d6; |
| --accent: #7dd3fc; |
| --accent-2: #f59e0b; |
| } |
| html { |
| color-scheme: dark; |
| } |
| body { |
| background: |
| radial-gradient(circle at top left, rgba(125, 211, 252, 0.22), transparent 28%), |
| radial-gradient(circle at top right, rgba(245, 158, 11, 0.16), transparent 24%), |
| linear-gradient(135deg, var(--bg-1), var(--bg-2)); |
| color: var(--text); |
| } |
| .gradio-container { |
| color: var(--text); |
| } |
| .gradio-container *, |
| .gradio-container :is(p, span, div, label, strong, em, li, summary) { |
| color: inherit; |
| } |
| .gradio-container input, |
| .gradio-container textarea, |
| .gradio-container select { |
| color: var(--text) !important; |
| } |
| .wrap { |
| max-width: 1180px; |
| margin: 0 auto; |
| } |
| .hero { |
| padding: 28px 28px 22px; |
| border: 1px solid var(--card-border); |
| border-radius: 24px; |
| background: linear-gradient(180deg, rgba(255,255,255,0.08), rgba(255,255,255,0.03)); |
| box-shadow: 0 24px 80px rgba(0, 0, 0, 0.28); |
| backdrop-filter: blur(14px); |
| } |
| .eyebrow { |
| text-transform: uppercase; |
| letter-spacing: 0.22em; |
| color: var(--accent); |
| font-size: 12px; |
| font-weight: 700; |
| margin-bottom: 8px; |
| } |
| .title { |
| font-size: clamp(32px, 5vw, 56px); |
| line-height: 1.02; |
| margin: 0; |
| color: var(--text); |
| font-weight: 800; |
| } |
| .subtitle { |
| margin-top: 12px; |
| color: var(--muted); |
| font-size: 16px; |
| max-width: 820px; |
| } |
| .summary-card { |
| border: 1px solid var(--card-border); |
| border-radius: 22px; |
| padding: 22px; |
| background: rgba(7, 13, 24, 0.7); |
| color: var(--text); |
| min-height: 240px; |
| display: flex; |
| flex-direction: column; |
| gap: 10px; |
| } |
| .prediction-card { |
| border: 1px solid var(--card-border); |
| border-radius: 22px; |
| background: rgba(7, 13, 24, 0.82); |
| overflow: hidden; |
| } |
| .prediction-tabs { |
| display: grid; |
| grid-template-columns: repeat(3, minmax(0, 1fr)); |
| } |
| .prediction-model-tab { |
| border: 0; |
| border-right: 1px solid rgba(255, 255, 255, 0.10); |
| background: rgba(255, 255, 255, 0.02); |
| color: var(--muted); |
| padding: 12px 14px; |
| display: grid; |
| justify-items: center; |
| gap: 2px; |
| cursor: pointer; |
| } |
| .prediction-model-tab:last-child { |
| border-right: 0; |
| } |
| .prediction-model-tab.is-active { |
| color: white; |
| box-shadow: inset 0 -2px 0 var(--accent); |
| } |
| .prediction-model-short { |
| font-size: 18px; |
| font-weight: 900; |
| line-height: 1; |
| } |
| .prediction-model-label { |
| font-size: 14px; |
| font-weight: 700; |
| } |
| .prediction-chip-row { |
| display: grid; |
| grid-template-columns: auto 1fr; |
| gap: 12px; |
| align-items: center; |
| padding: 10px 14px; |
| border-top: 1px solid rgba(255, 255, 255, 0.08); |
| border-bottom: 1px solid rgba(255, 255, 255, 0.08); |
| } |
| .prediction-chip-title { |
| color: var(--muted); |
| text-transform: uppercase; |
| letter-spacing: 0.12em; |
| font-weight: 800; |
| } |
| .prediction-chip-list { |
| display: flex; |
| gap: 10px; |
| flex-wrap: wrap; |
| } |
| .prediction-chip { |
| border: 1px solid rgba(255, 255, 255, 0.10); |
| background: rgba(255, 255, 255, 0.03); |
| color: white; |
| border-radius: 14px; |
| padding: 10px 12px; |
| cursor: pointer; |
| display: inline-flex; |
| align-items: center; |
| gap: 8px; |
| } |
| .prediction-chip.is-active { |
| border-color: rgba(125, 211, 252, 0.72); |
| background: rgba(125, 211, 252, 0.16); |
| } |
| .prediction-chip-dot { |
| width: 8px; |
| height: 8px; |
| border-radius: 999px; |
| background: var(--accent); |
| } |
| .prediction-chip-lang { |
| font-size: 15px; |
| font-weight: 900; |
| } |
| .prediction-chip-score { |
| font-size: 14px; |
| color: var(--muted); |
| font-weight: 700; |
| } |
| .prediction-body { |
| padding: 14px; |
| } |
| .prediction-panel { |
| display: none; |
| } |
| .prediction-panel.is-active { |
| display: block; |
| } |
| .prediction-topline { |
| color: var(--muted); |
| padding: 0 14px 14px; |
| } |
| .summary-kicker { |
| color: var(--accent); |
| text-transform: uppercase; |
| letter-spacing: 0.18em; |
| font-size: 11px; |
| font-weight: 700; |
| } |
| .summary-main { |
| font-size: 48px; |
| font-weight: 900; |
| margin-top: 8px; |
| color: white; |
| line-height: 0.95; |
| letter-spacing: -0.03em; |
| } |
| .summary-note { |
| color: var(--muted); |
| margin-top: 2px; |
| line-height: 1.45; |
| } |
| .metric-grid { |
| display: grid; |
| grid-template-columns: repeat(2, minmax(0, 1fr)); |
| gap: 8px; |
| } |
| .metric { |
| border: 1px solid rgba(255, 255, 255, 0.10); |
| background: rgba(255, 255, 255, 0.03); |
| border-radius: 16px; |
| padding: 10px 12px; |
| } |
| .metric-label { |
| display: block; |
| color: var(--muted); |
| font-size: 11px; |
| text-transform: uppercase; |
| letter-spacing: 0.14em; |
| margin-bottom: 6px; |
| } |
| .metric-value { |
| display: block; |
| color: var(--text); |
| font-size: 16px; |
| font-weight: 700; |
| } |
| .metric-value strong { |
| color: white; |
| } |
| .summary-row { |
| display: flex; |
| flex-wrap: wrap; |
| gap: 10px; |
| align-items: flex-start; |
| } |
| .meter { |
| margin-top: 2px; |
| } |
| .meter-head { |
| display: flex; |
| justify-content: space-between; |
| align-items: baseline; |
| gap: 10px; |
| margin-bottom: 8px; |
| } |
| .meter-value { |
| color: var(--text); |
| font-size: 16px; |
| font-weight: 800; |
| } |
| .meter-track { |
| width: 100%; |
| height: 10px; |
| border-radius: 999px; |
| overflow: hidden; |
| background: rgba(255, 255, 255, 0.08); |
| } |
| .meter-fill { |
| height: 100%; |
| border-radius: 999px; |
| background: linear-gradient(90deg, var(--accent), #60a5fa); |
| } |
| .chip-row { |
| display: flex; |
| flex-wrap: wrap; |
| gap: 8px; |
| margin-top: 2px; |
| } |
| .chip { |
| border: 1px solid rgba(125, 211, 252, 0.25); |
| background: rgba(125, 211, 252, 0.08); |
| color: var(--text); |
| padding: 8px 10px; |
| border-radius: 999px; |
| font-size: 12px; |
| min-width: 140px; |
| white-space: nowrap; |
| } |
| .chip strong { |
| margin-left: 4px; |
| color: white; |
| } |
| .chip-conf { |
| display: block; |
| color: var(--muted); |
| font-size: 11px; |
| margin-top: 2px; |
| } |
| .ambiguity-warning { |
| margin-top: 10px; |
| padding: 10px 12px; |
| border-radius: 14px; |
| border: 1px solid rgba(245, 158, 11, 0.35); |
| background: rgba(245, 158, 11, 0.12); |
| color: #fbd38d; |
| font-size: 13px; |
| font-weight: 600; |
| } |
| .empty-state { |
| padding: 18px 20px; |
| border-radius: 18px; |
| border: 1px dashed rgba(255,255,255,0.16); |
| color: var(--muted); |
| background: rgba(255,255,255,0.03); |
| } |
| .gradio-container .gr-textbox textarea { |
| font-size: 15px !important; |
| } |
| .footer-note { |
| color: var(--muted); |
| font-size: 13px; |
| margin-top: 8px; |
| } |
| .footer-langs { |
| cursor: pointer; |
| padding: 14px 16px; |
| border-radius: 14px; |
| border: 1px solid var(--card-border); |
| background: rgba(255, 255, 255, 0.04); |
| } |
| .footer-langs summary { |
| list-style: none; |
| font-weight: 700; |
| } |
| .footer-langs summary::-webkit-details-marker { |
| display: none; |
| } |
| .footer-langs-body { |
| margin-top: 10px; |
| max-height: 240px; |
| overflow: auto; |
| padding-right: 4px; |
| } |
| .footer-lang-list { |
| margin: 0; |
| padding: 0; |
| list-style: none; |
| display: grid; |
| grid-template-columns: repeat(auto-fit, minmax(180px, 1fr)); |
| gap: 8px 12px; |
| } |
| .footer-lang-list li { |
| display: flex; |
| gap: 8px; |
| align-items: baseline; |
| } |
| .lang-code { |
| font-family: monospace; |
| color: var(--accent); |
| min-width: 2.2rem; |
| } |
| .lang-name { |
| color: var(--text); |
| } |
| .validation-strip { |
| border-radius: 18px; |
| padding: 12px 14px; |
| margin-top: 10px; |
| } |
| .validation-strip.validation-model-token { |
| border: 1px solid rgba(34, 197, 94, 0.34); |
| background: rgba(34, 197, 94, 0.12); |
| } |
| .validation-strip.validation-model-multi { |
| border: 1px solid rgba(245, 158, 11, 0.34); |
| background: rgba(245, 158, 11, 0.12); |
| } |
| .validation-strip.validation-model-reference { |
| border: 1px solid rgba(96, 165, 250, 0.34); |
| background: rgba(96, 165, 250, 0.12); |
| } |
| .validation-high { |
| border: 1px solid rgba(34, 197, 94, 0.30); |
| background: rgba(34, 197, 94, 0.10); |
| } |
| .validation-mid { |
| border: 1px solid rgba(245, 158, 11, 0.35); |
| background: rgba(245, 158, 11, 0.10); |
| } |
| .validation-low { |
| border: 1px solid rgba(239, 68, 68, 0.35); |
| background: rgba(239, 68, 68, 0.10); |
| } |
| .validation-kicker { |
| color: #86efac; |
| text-transform: uppercase; |
| letter-spacing: 0.18em; |
| font-size: 11px; |
| font-weight: 700; |
| } |
| .validation-main { |
| display: flex; |
| flex-wrap: wrap; |
| gap: 10px; |
| margin-top: 6px; |
| } |
| .validation-metric-button { |
| appearance: none; |
| border: 1px solid rgba(255, 255, 255, 0.12); |
| background: rgba(255, 255, 255, 0.04); |
| color: white; |
| border-radius: 14px; |
| padding: 10px 12px; |
| min-width: 110px; |
| text-align: left; |
| cursor: pointer; |
| display: grid; |
| gap: 2px; |
| pointer-events: auto; |
| position: relative; |
| z-index: 1; |
| } |
| .validation-metric-button.is-active { |
| border-color: rgba(134, 239, 172, 0.55); |
| background: rgba(134, 239, 172, 0.12); |
| box-shadow: 0 0 0 1px rgba(134, 239, 172, 0.15) inset; |
| } |
| .validation-metric-label { |
| font-size: 11px; |
| text-transform: uppercase; |
| letter-spacing: 0.14em; |
| color: #a7f3d0; |
| font-weight: 800; |
| } |
| .validation-metric-value { |
| font-size: 28px; |
| line-height: 1; |
| font-weight: 900; |
| } |
| .validation-vs { |
| color: var(--muted); |
| font-size: 16px; |
| font-weight: 700; |
| margin-top: 8px; |
| } |
| .validation-status { |
| margin-top: 4px; |
| font-size: 13px; |
| font-weight: 700; |
| } |
| .validation-meta { |
| margin-top: 8px; |
| color: var(--muted); |
| font-size: 13px; |
| } |
| .validation-panels { |
| margin-top: 10px; |
| } |
| .validation-detail { |
| display: none; |
| gap: 8px; |
| padding: 12px; |
| border-radius: 14px; |
| background: rgba(255, 255, 255, 0.05); |
| border: 1px solid rgba(255, 255, 255, 0.08); |
| } |
| .validation-detail.is-active { |
| display: grid; |
| } |
| .validation-detail-head { |
| display: flex; |
| justify-content: space-between; |
| gap: 12px; |
| align-items: baseline; |
| } |
| .validation-detail-title { |
| font-size: 13px; |
| font-weight: 800; |
| color: white; |
| } |
| .validation-detail-score { |
| font-size: 22px; |
| font-weight: 900; |
| color: white; |
| } |
| .validation-detail-row { |
| display: flex; |
| justify-content: space-between; |
| gap: 12px; |
| color: var(--muted); |
| font-size: 13px; |
| } |
| .validation-detail-row strong { |
| color: white; |
| text-align: right; |
| } |
| .validation-pass { |
| color: #86efac; |
| } |
| .validation-warn { |
| color: #fbbf24; |
| } |
| .validation-subtitle { |
| color: var(--muted); |
| margin-top: 6px; |
| font-size: 13px; |
| line-height: 1.4; |
| } |
| .comparison-strip { |
| border-radius: 18px; |
| padding: 12px 14px; |
| margin-top: 10px; |
| border: 1px solid rgba(125, 211, 252, 0.22); |
| background: linear-gradient(180deg, rgba(125, 211, 252, 0.08), rgba(255, 255, 255, 0.03)); |
| } |
| .comparison-kicker { |
| color: var(--accent); |
| text-transform: uppercase; |
| letter-spacing: 0.18em; |
| font-size: 11px; |
| font-weight: 700; |
| } |
| .comparison-row { |
| display: grid; |
| grid-template-columns: repeat(3, minmax(0, 1fr)); |
| gap: 10px; |
| margin-top: 10px; |
| } |
| .comparison-pill { |
| border-radius: 14px; |
| padding: 10px 12px; |
| border: 1px solid rgba(255, 255, 255, 0.10); |
| display: flex; |
| flex-direction: column; |
| gap: 4px; |
| } |
| .comparison-label { |
| font-size: 10px; |
| letter-spacing: 0.14em; |
| text-transform: uppercase; |
| color: var(--muted); |
| font-weight: 700; |
| } |
| .comparison-value { |
| font-size: 15px; |
| font-weight: 900; |
| color: white; |
| } |
| .comparison-token { |
| background: rgba(34, 197, 94, 0.12); |
| border-color: rgba(34, 197, 94, 0.25); |
| } |
| .comparison-multi { |
| background: rgba(245, 158, 11, 0.12); |
| border-color: rgba(245, 158, 11, 0.25); |
| } |
| .comparison-ref { |
| background: rgba(96, 165, 250, 0.12); |
| border-color: rgba(96, 165, 250, 0.25); |
| } |
| .validation-grid { |
| display: grid; |
| grid-template-columns: repeat(2, minmax(0, 1fr)); |
| gap: 8px; |
| margin-top: 10px; |
| } |
| .validation-chip { |
| border: 1px solid rgba(255, 255, 255, 0.10); |
| background: rgba(255, 255, 255, 0.04); |
| border-radius: 14px; |
| padding: 10px 12px; |
| display: flex; |
| flex-direction: column; |
| gap: 4px; |
| } |
| .validation-chip-label { |
| color: var(--muted); |
| text-transform: uppercase; |
| letter-spacing: 0.14em; |
| font-size: 10px; |
| font-weight: 700; |
| } |
| .validation-chip-value { |
| color: white; |
| font-size: 14px; |
| font-weight: 700; |
| line-height: 1.35; |
| } |
| .validation-chip-muted { |
| color: var(--muted); |
| font-size: 12px; |
| font-weight: 600; |
| } |
| .validation-grid .validation-chip:first-child, |
| .validation-grid .validation-chip:nth-child(2) { |
| grid-column: span 1; |
| } |
| .validation-grid .validation-chip:nth-child(4) { |
| grid-column: span 2; |
| } |
| .validation-grid .validation-chip:nth-child(7), |
| .validation-grid .validation-chip:nth-child(8) { |
| grid-column: span 1; |
| } |
| .chip-strip { |
| display: none !important; |
| } |
| .chip-btn { |
| display: none !important; |
| } |
| .results-shell { |
| gap: 18px; |
| align-items: start; |
| } |
| .results-grid { |
| gap: 14px; |
| align-items: stretch; |
| } |
| .results-panel { |
| min-width: 0 !important; |
| } |
| .results-panel .gr-panel { |
| height: 100%; |
| } |
| .results-panel .gr-dataframe, |
| .results-panel .gr-json { |
| min-height: 280px; |
| max-height: 420px; |
| overflow-y: auto; |
| } |
| .gradio-container .gr-dataframe table { |
| table-layout: fixed !important; |
| width: 100% !important; |
| } |
| .gradio-container .gr-dataframe th:nth-child(1), |
| .gradio-container .gr-dataframe td:nth-child(1) { |
| width: 42% !important; |
| } |
| .gradio-container .gr-dataframe th:nth-child(2), |
| .gradio-container .gr-dataframe td:nth-child(2) { |
| width: 12% !important; |
| } |
| .gradio-container .gr-dataframe th:nth-child(3), |
| .gradio-container .gr-dataframe td:nth-child(3) { |
| width: 14% !important; |
| } |
| .gradio-container .gr-dataframe th:nth-child(4), |
| .gradio-container .gr-dataframe td:nth-child(4) { |
| width: 16% !important; |
| } |
| .gradio-container .gr-dataframe th:nth-child(5), |
| .gradio-container .gr-dataframe td:nth-child(5) { |
| width: 16% !important; |
| } |
| .gradio-container .gr-dataframe td:nth-child(1) { |
| overflow: hidden; |
| text-overflow: ellipsis; |
| white-space: nowrap; |
| } |
| @media (max-width: 900px) { |
| .chip-strip { display: none !important; } |
| } |
| @media (max-width: 640px) { |
| .chip-strip { display: none !important; } |
| .chip-btn { display: none !important; } |
| } |
| .action-btn { |
| width: 100%; |
| } |
| .action-btn button { |
| width: 100%; |
| min-height: 56px; |
| padding: 0 16px; |
| white-space: normal; |
| line-height: 1.15; |
| display: flex; |
| align-items: center; |
| justify-content: center; |
| text-align: center; |
| } |
| .action-primary button { |
| min-height: 58px; |
| font-weight: 800; |
| } |
| .action-secondary button { |
| min-height: 58px; |
| } |
| .action-clear button { |
| min-height: 48px; |
| opacity: 0.9; |
| } |
| .action-strip { |
| gap: 12px; |
| } |
| .action-strip > .gr-column { |
| min-width: 0 !important; |
| } |
| .action-stack { |
| margin-top: 10px; |
| gap: 10px; |
| } |
| """ |
|
|
|
|
# Client-side script injected into the Gradio page (passed to ``demo.launch(js=APP_JS)``
# at the bottom of this file).  It is an IIFE that wires up the interactive HTML the
# Python side renders into ``gr.HTML`` components:
#   * prediction cards ([data-prediction-card]): model tabs ([data-model-toggle]) and
#     per-language chips ([data-lang-toggle]) that re-render the summary/metric text
#     from ``data-*`` attributes (JSON stats stored in ``panel.dataset.stats``);
#   * validation cards ([data-validation-card]): metric buttons ([data-validation-toggle])
#     that switch the visible detail panel and the card's colour class.
# A MutationObserver re-activates validation cards after Gradio re-renders the DOM,
# and the ``boot`` function handles both the loading and already-loaded document states.
APP_JS = """
(() => {
  const setText = (root, selector, value) => {
    const node = root.querySelector(selector);
    if (node) node.textContent = value;
  };

  const selectedChip = (card, langIndex) => {
    const chips = Array.from(card.querySelectorAll('[data-lang-toggle]'));
    return chips.find((chip) => Number(chip.dataset.langToggle) === Number(langIndex)) || chips[0] || null;
  };

  const syncTokenPanel = (panel, chip) => {
    const stats = JSON.parse(panel.dataset.stats || '{}');
    const lang = chip ? chip.dataset.langCode : panel.dataset.selectedLang;
    const stat = stats[lang] || {};
    const score = Number(chip ? chip.dataset.tokenScore : panel.dataset.derivedScore || 0);
    panel.dataset.selectedLang = lang || '';
    setText(panel, '[data-summary-kicker]', 'Prediction');
    setText(panel, '[data-summary-main]', (lang || 'n/a').toUpperCase());
    setText(panel, '[data-summary-note]', `Token classifier · dominant view · derived score ${(score * 100).toFixed(1)}%`);
    setText(panel, '[data-metric-iso3]', stat.iso3 || panel.dataset.iso3 || 'n/a');
    setText(panel, '[data-metric-derived-score]', `${(score * 100).toFixed(1)}%`);
    setText(panel, '[data-metric-coverage]', `${(Number(stat.coverage_pct || 0) * 100).toFixed(1)}%`);
    setText(panel, '[data-metric-avg-confidence]', Number(stat.avg_confidence || 0).toFixed(3));
    setText(panel, '[data-metric-tagged-chars]', `${Number(stat.char_coverage || 0).toFixed(0)} / ${Number(panel.dataset.totalChars || 0).toFixed(0)}`);
    setText(panel, '[data-metric-overall-confidence]', Number(panel.dataset.overallConfidence || 0).toFixed(3));
  };

  const syncMultiPanel = (panel, chip) => {
    const stats = JSON.parse(panel.dataset.stats || '{}');
    const lang = chip ? chip.dataset.langCode : panel.dataset.selectedLang;
    const stat = stats[lang] || {};
    const score = Number(chip ? chip.dataset.multiScore : panel.dataset.topScore || 0);
    panel.dataset.selectedLang = lang || '';
    setText(panel, '[data-summary-kicker]', 'Multi-label');
    setText(panel, '[data-summary-main]', (lang || 'n/a').toUpperCase());
    setText(panel, '[data-summary-note]', `Multi-label · top prediction ${(lang || 'n/a').toUpperCase()} ${(score * 100).toFixed(1)}%`);
    setText(panel, '[data-metric-top-score]', `${(score * 100).toFixed(1)}%`);
    setText(panel, '[data-metric-seq-score]', `${(Number(panel.dataset.seqScore || score) * 100).toFixed(1)}%`);
    setText(panel, '[data-metric-iso3]', stat.iso3 || panel.dataset.iso3 || 'n/a');
    setText(panel, '[data-metric-coverage]', `${(Number(stat.coverage_pct || 0) * 100).toFixed(1)}%`);
  };

  const syncFastTextPanel = (panel, chip) => {
    const stats = JSON.parse(panel.dataset.stats || '{}');
    const lang = chip ? chip.dataset.langCode : panel.dataset.selectedLang;
    const stat = stats[lang] || {};
    const score = Number(chip ? chip.dataset.referenceScore : panel.dataset.fasttextScore || 0);
    const modeLabel = panel.dataset.mode === 'sentences' ? 'sentence' : 'full text';
    panel.dataset.selectedLang = lang || '';
    setText(panel, '[data-summary-kicker]', 'fastText');
    setText(panel, '[data-summary-main]', (lang || 'n/a').toUpperCase());
    setText(panel, '[data-summary-note]', `Mode: ${modeLabel} · top prediction ${(lang || 'n/a').toUpperCase()} ${(score * 100).toFixed(1)}%`);
    setText(panel, '[data-metric-fasttext-score]', `${(score * 100).toFixed(1)}%`);
    setText(panel, '[data-metric-top-score]', `${(score * 100).toFixed(1)}%`);
  };

  const syncPredictionCard = (card, modelKey, langIndex) => {
    const panels = card.querySelectorAll('[data-model-panel]');
    const tabs = card.querySelectorAll('[data-model-toggle]');
    tabs.forEach((tab) => {
      const active = tab.dataset.modelToggle === modelKey;
      tab.classList.toggle('is-active', active);
      tab.setAttribute('aria-pressed', active ? 'true' : 'false');
    });
    panels.forEach((panel) => {
      panel.classList.toggle('is-active', panel.dataset.modelPanel === modelKey);
    });
    const activePanel = card.querySelector(`[data-model-panel="${modelKey}"]`);
    const chip = selectedChip(card, langIndex);
    const chips = card.querySelectorAll('[data-lang-toggle]');
    chips.forEach((chipNode) => {
      const active = Number(chipNode.dataset.langToggle) === Number(langIndex);
      chipNode.classList.toggle('is-active', active);
      const scoreKey = modelKey === 'reference' ? 'referenceScore' : modelKey === 'multi' ? 'multiScore' : 'tokenScore';
      const score = chipNode.dataset[scoreKey] || chipNode.dataset.tokenScore || '0';
      const scoreNode = chipNode.querySelector('.prediction-chip-score');
      if (scoreNode) {
        scoreNode.textContent = `${(Number(score) * 100).toFixed(1)}%`;
      }
    });
    if (activePanel) {
      if (modelKey === 'token') {
        syncTokenPanel(activePanel, chip);
      } else if (modelKey === 'multi') {
        syncMultiPanel(activePanel, chip);
      } else {
        syncFastTextPanel(activePanel, chip);
      }
    }
  };

  document.addEventListener('click', (event) => {
    const modelTab = event.target.closest('[data-model-toggle]');
    if (modelTab) {
      const card = modelTab.closest('[data-prediction-card]');
      if (!card) return;
      const activeChip = card.querySelector('[data-lang-toggle].is-active') || card.querySelector('[data-lang-toggle]');
      syncPredictionCard(card, modelTab.dataset.modelToggle, activeChip ? activeChip.dataset.langToggle : 0);
      return;
    }
    const langChip = event.target.closest('[data-lang-toggle]');
    if (langChip) {
      const card = langChip.closest('[data-prediction-card]');
      if (!card) return;
      const activeTab = card.querySelector('[data-model-toggle].is-active') || card.querySelector('[data-model-toggle]');
      syncPredictionCard(card, activeTab ? activeTab.dataset.modelToggle : 'token', langChip.dataset.langToggle);
    }
  });

  const syncCard = (card, key) => {
    const buttons = card.querySelectorAll('[data-validation-toggle]');
    const panels = card.querySelectorAll('[data-validation-panel]');
    card.classList.remove('validation-model-token', 'validation-model-multi', 'validation-model-reference');
    card.classList.add(key === 'multi' ? 'validation-model-multi' : key === 'reference' ? 'validation-model-reference' : 'validation-model-token');
    card.dataset.activeModel = key;
    buttons.forEach((button) => {
      const active = button.dataset.validationToggle === key;
      button.classList.toggle('is-active', active);
      button.setAttribute('aria-pressed', active ? 'true' : 'false');
    });
    panels.forEach((panel) => {
      panel.classList.toggle('is-active', panel.dataset.validationPanel === key);
    });
  };

  const activateCard = (card) => {
    const defaultButton = card.querySelector('[data-validation-toggle].is-active') || card.querySelector('[data-validation-toggle]');
    if (defaultButton) {
      syncCard(card, defaultButton.dataset.validationToggle);
    }
  };

  document.addEventListener('click', (event) => {
    const button = event.target.closest('[data-validation-toggle]');
    if (!button) return;
    const card = button.closest('[data-validation-card]');
    if (!card) return;
    syncCard(card, button.dataset.validationToggle);
  });

  const observer = new MutationObserver(() => {
    document.querySelectorAll('[data-validation-card]').forEach(activateCard);
  });

  const boot = () => {
    document.querySelectorAll('[data-prediction-card]').forEach((card) => {
      const activeTab = card.querySelector('[data-model-toggle].is-active') || card.querySelector('[data-model-toggle]');
      const activeChip = card.querySelector('[data-lang-toggle].is-active') || card.querySelector('[data-lang-toggle]');
      if (activeTab) {
        syncPredictionCard(card, activeTab.dataset.modelToggle, activeChip ? activeChip.dataset.langToggle : 0);
      }
    });
    document.querySelectorAll('[data-validation-card]').forEach(activateCard);
    observer.observe(document.body, { childList: true, subtree: true });
  };

  if (document.readyState === 'loading') {
    document.addEventListener('DOMContentLoaded', boot, { once: true });
  } else {
    boot();
  }
})();
"""
|
|
# Dark navy theme for the app.  Every paired colour is applied to both the
# light and the ``*_dark`` variant so the palette is identical regardless of
# the visitor's colour-scheme preference.
_THEME_PAIRED_COLORS = {
    "body_background_fill": "#06111f",
    "body_text_color": "#f4f7fb",
    "body_text_color_subdued": "#b7c3d6",
    "background_fill_primary": "#0b1220",
    "background_fill_secondary": "#10192b",
    "panel_background_fill": "#0b1220",
    "panel_border_color": "rgba(255, 255, 255, 0.12)",
    "block_background_fill": "#0b1220",
    "block_border_color": "rgba(255, 255, 255, 0.12)",
    "input_background_fill": "#111827",
    "input_border_color": "rgba(255, 255, 255, 0.12)",
    "input_placeholder_color": "#94a3b8",
    "button_secondary_background_fill": "#202938",
    "button_secondary_background_fill_hover": "#283244",
    "button_secondary_border_color": "rgba(255, 255, 255, 0.12)",
    "button_secondary_text_color": "#f4f7fb",
    "button_primary_background_fill": "#2563eb",
    "button_primary_text_color": "#ffffff",
    "checkbox_label_text_color": "#f4f7fb",
}

_theme_settings: dict[str, str] = {}
for _name, _color in _THEME_PAIRED_COLORS.items():
    _theme_settings[_name] = _color
    _theme_settings[f"{_name}_dark"] = _color
# Accent controls that have no separate dark-mode variant.
_theme_settings["radio_circle"] = "#7dd3fc"
_theme_settings["checkbox_check"] = "#7dd3fc"

APP_THEME = gr.themes.Soft().set(**_theme_settings)
|
|
|
|
with gr.Blocks(
    title="Polyglot Tagger Studio",
    theme=APP_THEME,
    css=CSS,
) as demo:
    # Static hero header (styled by the CSS string above).
    gr.HTML(
        """
        <div class="wrap hero">
            <div class="eyebrow">Multilingual Language ID</div>
            <h1 class="title">Polyglot Tagger Studio</h1>
            <div class="subtitle">
                A Gradio demo for comparing the token-classification model, the multi-label model, and a fastText reference baseline.
                Paste a sentence or paragraph, and the app will surface the dominant language signal, token-level spans when available, and raw predictions.
                Note that this is experimental and does not replace a text classifier: be prepared for unexpected results.
            </div>
        </div>
        """
    )

    with gr.Row():
        with gr.Column(scale=5):
            input_text = gr.Textbox(
                label="Text",
                lines=12,
                placeholder="Paste a sentence or a short paragraph here...",
                value="",
            )
            fasttext_mode = gr.Radio(
                choices=["Full text", "Sentence by sentence"],
                value="Full text",
                label="fastText mode",
                info="Choose whether fastText sees the whole input at once or one sentence at a time.",
            )
            adverse_mix = gr.Checkbox(
                label="Adverse mix",
                value=False,
                info="When enabled, Random mix prefers same-script-family languages like Nordic Latin or CJK pairs.",
            )
            gr.Markdown("Sentence-by-sentence mode splits on double newlines first, then sentence punctuation inside each block.")
            validation_strip = gr.HTML()
            gr.Markdown(
                "Use the buttons for fresh examples, or paste your own text."
            )
            with gr.Row(elem_classes=["action-strip"]):
                with gr.Column(scale=1, min_width=0):
                    analyze_btn = gr.Button("Analyze", variant="primary", elem_classes=["action-btn", "action-primary"])
                with gr.Column(scale=1, min_width=0):
                    clear_btn = gr.Button("Clear", elem_classes=["action-btn", "action-clear"])
            with gr.Row(elem_classes=["action-strip", "action-stack"]):
                with gr.Column(scale=1, min_width=0):
                    random_btn = gr.Button("Random sentence", elem_classes=["action-btn", "action-secondary"])
                with gr.Column(scale=1, min_width=0):
                    random_mix_btn = gr.Button("Random mix", elem_classes=["action-btn", "action-secondary"])
        with gr.Column(scale=7):
            summary = gr.HTML()
            prediction_state = gr.State({})
            with gr.Row(elem_classes=["chip-strip"]):
                # Six language-chip buttons (hidden via the .chip-btn CSS rule),
                # created in a loop instead of six copy-pasted declarations.
                chip_buttons = [
                    gr.Button("", visible=True, elem_classes=["chip-btn"])
                    for _ in range(6)
                ]
            # Preserve the historical individual names for any external references.
            chip_0, chip_1, chip_2, chip_3, chip_4, chip_5 = chip_buttons

    with gr.Row(elem_classes=["results-shell"]):
        with gr.Column(scale=7, min_width=0, elem_classes=["results-panel"]):
            spans = gr.Dataframe(
                headers=["token", "language", "score", "start", "end"],
                datatype=["str", "str", "number", "number", "number"],
                label="Token-level spans",
                interactive=False,
                wrap=True,
            )
        with gr.Column(scale=5, min_width=0, elem_classes=["results-panel"]):
            raw = gr.JSON(label="Raw output")

    # Single source of truth for the output list shared by all analysis-style
    # events; previously this 11-element list was copy-pasted four times and
    # could silently drift between handlers.
    analysis_outputs = [summary, spans, raw, prediction_state, validation_strip, *chip_buttons]

    analyze_btn.click(
        fn=predict,
        inputs=[input_text, fasttext_mode],
        outputs=analysis_outputs,
        api_name="analyze",
    )
    random_btn.click(
        fn=load_random_fleurs_example,
        inputs=[fasttext_mode],
        outputs=[input_text, *analysis_outputs],
        api_name="random_fleurs_sentence",
    )
    random_mix_btn.click(
        fn=load_random_mix_example,
        inputs=[fasttext_mode, adverse_mix],
        outputs=[input_text, *analysis_outputs],
        api_name="random_mix",
    )
    # Pressing Enter in the textbox mirrors the Analyze button.
    input_text.submit(
        fn=predict,
        inputs=[input_text, fasttext_mode],
        outputs=analysis_outputs,
        api_name="analyze_text",
    )
    clear_btn.click(
        fn=lambda: (
            "",
            pd.DataFrame(columns=["token", "language", "score", "start", "end"]),
            {},
            {},
            "",
            *[gr.update(value="", visible=False) for _ in range(6)],
        ),
        inputs=None,
        outputs=analysis_outputs,
        api_name="clear",
    )

    # One handler per chip.  The chip index is bound through a default argument
    # so each lambda captures its own value (avoiding the classic late-binding
    # closure bug when creating lambdas in a loop).
    for chip_index, chip_button in enumerate(chip_buttons):
        chip_button.click(
            fn=lambda state, chip_index=chip_index: select_language_from_chip(chip_index, state),
            inputs=prediction_state,
            outputs=[summary, prediction_state],
            api_name=f"select_chip_{chip_index}",
        )

    gr.HTML(render_language_reference_html())
|
|
|
|
if __name__ == "__main__":
    # Enable request queuing before serving the app.
    demo.queue()
    # Sharing defaults to on; set GRADIO_SHARE=0 to serve locally only.
    # NOTE(review): `js=` is conventionally a `gr.Blocks(...)` constructor
    # argument; confirm the installed Gradio version's `launch()` accepts it.
    share_requested = os.getenv("GRADIO_SHARE", "1") != "0"
    demo.launch(js=APP_JS, share=share_requested)
|