Spaces:
Running
Running
bug-fix-label-mapping-align-with-correct-idx (#80)
Browse files- fix the label mapping order; fix out of scope error (bc6f52dabdda8f4e5aa9c9d980ffd3d2c8a55c49)
Co-authored-by: zcy <[email protected]>
- app.py +1 -3
- app_leaderboard.py +6 -1
- text_classification_ui_helpers.py +17 -9
app.py
CHANGED
|
@@ -12,12 +12,10 @@ try:
|
|
| 12 |
with gr.Tab("Text Classification"):
|
| 13 |
get_demo_text_classification()
|
| 14 |
with gr.Tab("Leaderboard") as leaderboard_tab:
|
| 15 |
-
get_demo_leaderboard()
|
| 16 |
with gr.Tab("Logs(Debug)"):
|
| 17 |
get_demo_debug()
|
| 18 |
|
| 19 |
-
leaderboard_tab.select(fn=get_demo_leaderboard)
|
| 20 |
-
|
| 21 |
start_process_run_job()
|
| 22 |
|
| 23 |
demo.queue(max_size=1000)
|
|
|
|
| 12 |
with gr.Tab("Text Classification"):
|
| 13 |
get_demo_text_classification()
|
| 14 |
with gr.Tab("Leaderboard") as leaderboard_tab:
|
| 15 |
+
get_demo_leaderboard(leaderboard_tab)
|
| 16 |
with gr.Tab("Logs(Debug)"):
|
| 17 |
get_demo_debug()
|
| 18 |
|
|
|
|
|
|
|
| 19 |
start_process_run_job()
|
| 20 |
|
| 21 |
demo.queue(max_size=1000)
|
app_leaderboard.py
CHANGED
|
@@ -73,8 +73,11 @@ def get_display_df(df):
|
|
| 73 |
)
|
| 74 |
return display_df
|
| 75 |
|
|
|
|
|
|
|
|
|
|
| 76 |
|
| 77 |
-
def get_demo():
|
| 78 |
logger.info("Loading leaderboard records")
|
| 79 |
leaderboard.records = get_records_from_dataset_repo(leaderboard.LEADERBOARD)
|
| 80 |
records = leaderboard.records
|
|
@@ -116,6 +119,8 @@ def get_demo():
|
|
| 116 |
with gr.Row():
|
| 117 |
leaderboard_df = gr.DataFrame(display_df, datatype=types, interactive=False)
|
| 118 |
|
|
|
|
|
|
|
| 119 |
@gr.on(
|
| 120 |
triggers=[
|
| 121 |
model_select.change,
|
|
|
|
| 73 |
)
|
| 74 |
return display_df
|
| 75 |
|
| 76 |
+
def update_leaderboard_records():
|
| 77 |
+
logger.info("Updating leaderboard records")
|
| 78 |
+
leaderboard.records = get_records_from_dataset_repo(leaderboard.LEADERBOARD)
|
| 79 |
|
| 80 |
+
def get_demo(leaderboard_tab):
|
| 81 |
logger.info("Loading leaderboard records")
|
| 82 |
leaderboard.records = get_records_from_dataset_repo(leaderboard.LEADERBOARD)
|
| 83 |
records = leaderboard.records
|
|
|
|
| 119 |
with gr.Row():
|
| 120 |
leaderboard_df = gr.DataFrame(display_df, datatype=types, interactive=False)
|
| 121 |
|
| 122 |
+
leaderboard_tab.select(fn=update_leaderboard_records)
|
| 123 |
+
|
| 124 |
@gr.on(
|
| 125 |
triggers=[
|
| 126 |
model_select.change,
|
text_classification_ui_helpers.py
CHANGED
|
@@ -30,7 +30,6 @@ MAX_FEATURES = 20
|
|
| 30 |
ds_dict = None
|
| 31 |
ds_config = None
|
| 32 |
|
| 33 |
-
|
| 34 |
def get_related_datasets_from_leaderboard(model_id):
|
| 35 |
records = leaderboard.records
|
| 36 |
model_records = records[records["model_id"] == model_id]
|
|
@@ -100,7 +99,7 @@ def export_mappings(all_mappings, key, subkeys, values):
|
|
| 100 |
if subkeys is None:
|
| 101 |
subkeys = list(all_mappings[key].keys())
|
| 102 |
|
| 103 |
-
if not subkeys:
|
| 104 |
logging.debug(f"subkeys is empty for {key}")
|
| 105 |
return all_mappings
|
| 106 |
|
|
@@ -121,6 +120,8 @@ def list_labels_and_features_from_dataset(ds_labels, ds_features, model_labels,
|
|
| 121 |
ds_labels = ds_labels[:MAX_LABELS]
|
| 122 |
gr.Warning(f"The number of labels is truncated to length {MAX_LABELS}")
|
| 123 |
|
|
|
|
|
|
|
| 124 |
ds_labels.sort()
|
| 125 |
model_labels.sort()
|
| 126 |
|
|
@@ -293,17 +294,20 @@ def check_column_mapping_keys_validity(all_mappings):
|
|
| 293 |
return (gr.update(interactive=True), gr.update(visible=False))
|
| 294 |
|
| 295 |
|
| 296 |
-
def construct_label_and_feature_mapping(all_mappings):
|
| 297 |
label_mapping = {}
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 302 |
label_mapping.update({str(i): all_mappings["labels"][label]})
|
| 303 |
|
| 304 |
if "features" not in all_mappings.keys():
|
| 305 |
gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
|
| 306 |
-
return (gr.update(interactive=True), gr.update(visible=False))
|
| 307 |
feature_mapping = all_mappings["features"]
|
| 308 |
return label_mapping, feature_mapping
|
| 309 |
|
|
@@ -311,7 +315,11 @@ def construct_label_and_feature_mapping(all_mappings):
|
|
| 311 |
def try_submit(m_id, d_id, config, split, inference, inference_token, uid):
|
| 312 |
all_mappings = read_column_mapping(uid)
|
| 313 |
check_column_mapping_keys_validity(all_mappings)
|
| 314 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 315 |
|
| 316 |
eval_str = f"[{m_id}]<{d_id}({config}, {split} set)>"
|
| 317 |
save_job_to_pipe(
|
|
|
|
| 30 |
ds_dict = None
|
| 31 |
ds_config = None
|
| 32 |
|
|
|
|
| 33 |
def get_related_datasets_from_leaderboard(model_id):
|
| 34 |
records = leaderboard.records
|
| 35 |
model_records = records[records["model_id"] == model_id]
|
|
|
|
| 99 |
if subkeys is None:
|
| 100 |
subkeys = list(all_mappings[key].keys())
|
| 101 |
|
| 102 |
+
if not subkeys:
|
| 103 |
logging.debug(f"subkeys is empty for {key}")
|
| 104 |
return all_mappings
|
| 105 |
|
|
|
|
| 120 |
ds_labels = ds_labels[:MAX_LABELS]
|
| 121 |
gr.Warning(f"The number of labels is truncated to length {MAX_LABELS}")
|
| 122 |
|
| 123 |
+
# sort labels to make sure the order is consistent
|
| 124 |
+
# prediction gives the order based on probability
|
| 125 |
ds_labels.sort()
|
| 126 |
model_labels.sort()
|
| 127 |
|
|
|
|
| 294 |
return (gr.update(interactive=True), gr.update(visible=False))
|
| 295 |
|
| 296 |
|
| 297 |
+
def construct_label_and_feature_mapping(all_mappings, ds_labels, ds_features):
|
| 298 |
label_mapping = {}
|
| 299 |
+
if len(all_mappings["labels"].keys()) != len(ds_labels):
|
| 300 |
+
gr.Warning("Label mapping corrupted: " + CONFIRM_MAPPING_DETAILS_FAIL_RAW)
|
| 301 |
+
|
| 302 |
+
if len(all_mappings["features"].keys()) != len(ds_features):
|
| 303 |
+
gr.Warning("Feature mapping corrupted: " + CONFIRM_MAPPING_DETAILS_FAIL_RAW)
|
| 304 |
+
|
| 305 |
+
for i, label in zip(range(len(ds_labels)), ds_labels):
|
| 306 |
+
# align the saved labels with dataset labels order
|
| 307 |
label_mapping.update({str(i): all_mappings["labels"][label]})
|
| 308 |
|
| 309 |
if "features" not in all_mappings.keys():
|
| 310 |
gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
|
|
|
|
| 311 |
feature_mapping = all_mappings["features"]
|
| 312 |
return label_mapping, feature_mapping
|
| 313 |
|
|
|
|
| 315 |
def try_submit(m_id, d_id, config, split, inference, inference_token, uid):
|
| 316 |
all_mappings = read_column_mapping(uid)
|
| 317 |
check_column_mapping_keys_validity(all_mappings)
|
| 318 |
+
|
| 319 |
+
# get ds labels and features again for alignment
|
| 320 |
+
ds = datasets.load_dataset(d_id, config)[split]
|
| 321 |
+
ds_labels, ds_features = get_labels_and_features_from_dataset(ds)
|
| 322 |
+
label_mapping, feature_mapping = construct_label_and_feature_mapping(all_mappings, ds_labels, ds_features)
|
| 323 |
|
| 324 |
eval_str = f"[{m_id}]<{d_id}({config}, {split} set)>"
|
| 325 |
save_job_to_pipe(
|