Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from wordings import CONFIRM_MAPPING_DETAILS_FAIL_RAW | |
| import json | |
| import os | |
| import logging | |
| import threading | |
| from io_utils import read_column_mapping, write_column_mapping, save_job_to_pipe, write_log_to_user_file | |
| import datasets | |
| import collections | |
| from text_classification import get_labels_and_features_from_dataset, check_model, get_example_prediction | |
| from transformers.pipelines import TextClassificationPipeline | |
# Number of label / feature dropdown slots rendered in the column-mapping UI;
# unused slots are filled with hidden dropdowns.
MAX_LABELS: int = 20
MAX_FEATURES: int = 20

# Names of the environment variables looked up (os.environ.get) when a job is
# submitted.
HF_REPO_ID: str = "HF_REPO_ID"
HF_SPACE_ID: str = "SPACE_ID"
HF_WRITE_TOKEN: str = "HF_WRITE_TOKEN"

# Location of the persisted column-mapping config (see read_column_mapping).
CONFIG_PATH: str = "./config.yaml"
def check_dataset_and_get_config(dataset_id):
    """Reset the stored column mapping and list the configs of a Hub dataset.

    Args:
        dataset_id: Hugging Face Hub dataset identifier.

    Returns:
        A visible ``gr.Dropdown`` of the dataset's config names with the first
        one preselected, or ``None`` when the configs cannot be fetched (e.g.
        the dataset does not exist) or the dataset reports no configs.
    """
    try:
        # A mapping confirmed for a previous dataset no longer applies.
        write_column_mapping(None)
        configs = datasets.get_dataset_config_names(dataset_id)
    except Exception as e:
        # Best effort: the dataset may not exist. Record why instead of
        # silently swallowing the error.
        logging.warning(f"Failed to get configs of dataset {dataset_id}: {e}")
        return None
    if not configs:
        logging.warning(f"Dataset {dataset_id} reports no configs")
        return None
    return gr.Dropdown(configs, value=configs[0], visible=True)
def check_dataset_and_get_split(dataset_id, dataset_config):
    """List the splits available for one config of a Hub dataset.

    Uses ``datasets.get_dataset_split_names`` so the split names are obtained
    without downloading and preparing every split, which is what calling
    ``datasets.load_dataset`` just to read its keys would do.

    Args:
        dataset_id: Hugging Face Hub dataset identifier.
        dataset_config: Name of the dataset config.

    Returns:
        A visible ``gr.Dropdown`` of split names with the first one
        preselected, or ``None`` if the splits cannot be fetched (e.g. the
        dataset does not exist) or none are reported.
    """
    try:
        splits = list(datasets.get_dataset_split_names(dataset_id, dataset_config))
    except Exception as e:
        # Best effort: the dataset/config may not exist. Log instead of
        # silently swallowing the error.
        logging.warning(f"Failed to load dataset {dataset_id} with config {dataset_config}: {e}")
        return None
    if not splits:
        logging.warning(f"Dataset {dataset_id} with config {dataset_config} reports no splits")
        return None
    return gr.Dropdown(splits, value=splits[0], visible=True)
def write_column_mapping_to_config(dataset_id, dataset_config, dataset_split, *labels):
    """Merge the column-mapping dropdown selections into the stored mapping.

    ``labels`` holds the values of every mapping dropdown in UI order. The
    first ``MAX_LABELS`` are label dropdowns (the i-th one belongs to the i-th
    dataset label). The next ``MAX_FEATURES`` are feature dropdowns (the i-th
    one belongs to the i-th dataset feature).

    The existing mapping (read from ``CONFIG_PATH``) is updated and written
    back via ``write_column_mapping`` as::

        {"labels": {selected_value: ds_label, ...},
         "features": {selected_value: ds_feature, ...}}

    Empty dropdowns, and dropdowns with no corresponding dataset label or
    feature, are skipped. Nothing is written if no values are passed or if the
    dataset does not expose label/feature lists.
    """
    # ``*labels`` is always a tuple (never None), so test for emptiness; do it
    # before the dataset lookup below, which may be expensive.
    if not labels:
        return
    ds_labels, ds_features = get_labels_and_features_from_dataset(dataset_id, dataset_config, dataset_split)
    # Presumably non-list values mean the dataset has no labels/features;
    # indexing them would fail, so leave the stored mapping untouched.
    if not isinstance(ds_labels, list) or not isinstance(ds_features, list):
        logging.warning(
            f"Cannot map columns of {dataset_id} ({dataset_config}, {dataset_split}): "
            f"labels or features unavailable"
        )
        return

    all_mappings = read_column_mapping(CONFIG_PATH)
    if all_mappings is None:
        all_mappings = {}

    label_mapping = all_mappings.setdefault("labels", {})
    # zip stops at the shorter sequence: extra non-empty dropdowns beyond the
    # dataset's labels are ignored instead of raising IndexError.
    for ds_label, selected in zip(ds_labels, labels[:MAX_LABELS]):
        if selected:
            label_mapping[selected] = ds_label

    # NOTE(review): the feature dropdown is labelled with the model-side name
    # ("text") and its value is a dataset column, yet the key stored here is
    # that dataset column. Confirm this direction is what cli.py's
    # --feature_mapping expects.
    feature_mapping = all_mappings.setdefault("features", {})
    for ds_feature, selected in zip(ds_features, labels[MAX_LABELS:MAX_LABELS + MAX_FEATURES]):
        if selected:
            feature_mapping[selected] = ds_feature

    write_column_mapping(all_mappings)
def list_labels_and_features_from_dataset(ds_labels, ds_features, model_id2label):
    """Build the column-mapping dropdown components.

    Args:
        ds_labels: Label names of the dataset, one dropdown each (at most
            ``MAX_LABELS``).
        ds_features: Column names of the dataset; used as choices of the
            feature dropdowns.
        model_id2label: The model's ``id -> label`` mapping; its values are
            the choices of every label dropdown.

    Returns:
        A list of exactly ``MAX_LABELS + MAX_FEATURES`` ``gr.Dropdown``
        components. The visible label dropdowns come first, padded with hidden
        ones. Then come the feature dropdowns (currently only ``text``),
        padded with hidden ones.
    """
    model_labels = list(model_id2label.values())
    num_model_labels = len(model_labels)

    def _default_label(index):
        # Preselect model labels round-robin by id. A model without labels
        # has nothing to preselect (and ``% 0`` would raise).
        if num_model_labels == 0:
            return None
        return model_id2label.get(index % num_model_labels)

    label_dropdowns = [
        gr.Dropdown(
            label=f"{label}",
            choices=model_labels,
            value=_default_label(i),
            interactive=True,
            visible=True,
        )
        for i, label in enumerate(ds_labels[:MAX_LABELS])
    ]
    label_dropdowns += [gr.Dropdown(visible=False) for _ in range(MAX_LABELS - len(label_dropdowns))]

    # TODO: Substitute 'text' with more features for zero-shot
    default_feature = ds_features[0] if ds_features else None
    feature_dropdowns = [
        gr.Dropdown(
            label=f"{feature}",
            choices=ds_features,
            value=default_feature,
            interactive=True,
            visible=True,
        )
        for feature in ["text"]
    ]
    feature_dropdowns += [gr.Dropdown(visible=False) for _ in range(MAX_FEATURES - len(feature_dropdowns))]

    return label_dropdowns + feature_dropdowns
def check_model_and_show_prediction(model_id, dataset_id, dataset_config, dataset_split):
    """Validate the model against the dataset and preview one prediction.

    Every return path yields the same number of outputs, in this order:

    1. update for the example-prediction input,
    2. update for the example-prediction output,
    3. update for the column-mapping container (receives ``visible`` and
       ``open``; presumably a ``gr.Accordion``),
    4. ``MAX_LABELS + MAX_FEATURES`` column-mapping dropdowns.

    The mapping container is opened for manual editing when the model labels
    do not match the dataset labels, or when the first dataset feature is not
    ``text``.
    """
    hidden_dropdowns = [gr.Dropdown(visible=False) for _ in range(MAX_LABELS + MAX_FEATURES)]
    # Shared "hide everything" result. It must have the same arity as the
    # other return paths, or Gradio rejects the output count.
    hidden_result = (
        gr.update(visible=False),
        gr.update(visible=False),
        gr.update(visible=False, open=False),
        *hidden_dropdowns,
    )

    ppl = check_model(model_id)
    # Check None first so "Model not found" is actually reachable.
    if ppl is None:  # pipeline not found
        gr.Warning("Model not found")
        return hidden_result
    if not isinstance(ppl, TextClassificationPipeline):
        gr.Warning("Please check your model.")
        return hidden_result

    model_id2label = ppl.model.config.id2label
    ds_labels, ds_features = get_labels_and_features_from_dataset(dataset_id, dataset_config, dataset_split)
    # when dataset does not have labels or features
    if not isinstance(ds_labels, list) or not isinstance(ds_features, list):
        return hidden_result

    column_mappings = list_labels_and_features_from_dataset(
        ds_labels,
        ds_features,
        model_id2label,
    )

    # When labels or features are not aligned, show the manual column mapping.
    # An empty feature list counts as "not aligned" rather than raising.
    labels_match = collections.Counter(model_id2label.values()) == collections.Counter(ds_labels)
    has_text_feature = bool(ds_features) and ds_features[0] == "text"
    if not labels_match or not has_text_feature:
        gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
        return (
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=True, open=True),
            *column_mappings,
        )

    prediction_input, prediction_output = get_example_prediction(ppl, dataset_id, dataset_config, dataset_split)
    return (
        gr.update(value=prediction_input, visible=True),
        gr.update(value=prediction_output, visible=True),
        gr.update(visible=True, open=False),
        *column_mappings,
    )
# Shared by every submission. A Lock created inside try_submit would be private
# to that single call and could never exclude a concurrent submission.
_pipe_lock = threading.Lock()


def try_submit(m_id, d_id, config, split, local, uid):
    """Queue a scan job for the selected model/dataset/config/split.

    The column mapping stored at ``CONFIG_PATH`` must contain both ``labels``
    and ``features``. For ``local`` runs, a ``cli.py`` command is queued with
    ``save_job_to_pipe`` and a start message is written to the user's log.

    Returns:
        A pair of updates: (submit button, job-log output; presumably a
        Textbox). On failure, or when not ``local``, the button stays
        interactive and the log is hidden.
    """
    retry = (
        gr.update(interactive=True),  # Submit button
        gr.update(visible=False),
    )

    all_mappings = read_column_mapping(CONFIG_PATH)
    if all_mappings is None or "labels" not in all_mappings or "features" not in all_mappings:
        gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
        return retry
    label_mapping = all_mappings["labels"]
    feature_mapping = all_mappings["features"]

    # TODO: Set column mapping for some dataset such as `amazon_polarity`
    if not local:
        gr.Info("TODO: Submit task to an endpoint")
        return retry

    hf_token = os.environ.get(HF_WRITE_TOKEN)
    discussion_repo = os.environ.get(HF_REPO_ID) or os.environ.get(HF_SPACE_ID)
    # Without these, the argv below would contain None. The queued job could
    # not authenticate or post its report, so refuse up front.
    if not hf_token or not discussion_repo:
        gr.Warning(
            f"Cannot start evaluation: set the {HF_WRITE_TOKEN} and "
            f"{HF_REPO_ID} (or {HF_SPACE_ID}) environment variables."
        )
        return retry

    command = [
        "python",
        "cli.py",
        "--loader", "huggingface",
        "--model", m_id,
        "--dataset", d_id,
        "--dataset_config", config,
        "--dataset_split", split,
        "--hf_token", hf_token,
        "--discussion_repo", discussion_repo,
        "--output_format", "markdown",
        "--output_portal", "huggingface",
        "--feature_mapping", json.dumps(feature_mapping),
        "--label_mapping", json.dumps(label_mapping),
        # NOTE(review): differs from CONFIG_PATH ("./config.yaml"); presumably
        # cli.py runs from a different working directory. Confirm.
        "--scan_config", "../config.yaml",
    ]

    eval_str = f"[{m_id}]<{d_id}({config}, {split} set)>"
    logging.info(f"Start local evaluation on {eval_str}")
    save_job_to_pipe(uid, command, _pipe_lock)
    write_log_to_user_file(uid, f"Start local evaluation on {eval_str}. Please wait for your job to start...\n")
    gr.Info(f"Start local evaluation on {eval_str}")
    return (
        gr.update(interactive=False),
        gr.update(lines=5, visible=True, interactive=False),
    )