from pathlib import Path import ast import json import traceback import pandas as pd import numpy as np import gradio as gr from datasets import Features, Value, load_dataset from datetime import datetime import os from about import ( PROBLEM_TYPES, TOKEN, CACHE_PATH, API, submissions_repo, results_repo, COLUMN_DISPLAY_NAMES, COUNT_BASED_METRICS, METRIC_GROUPS, METRIC_GROUP_COLORS, COLUMN_TO_GROUP, TRAINING_DATASETS ) RESULT_FEATURES = Features({ "run_name": Value("string"), "timestamp": Value("string"), "n_structures": Value("int64"), "overall_valid_count": Value("int64"), "charge_neutral_count": Value("int64"), "distance_valid_count": Value("int64"), "plausibility_valid_count": Value("int64"), "unique_count": Value("int64"), "novel_count": Value("int64"), "mean_formation_energy": Value("float64"), "formation_energy_std": Value("float64"), "stability_mean_above_hull": Value("float64"), "stability_std_e_above_hull": Value("float64"), "stability_mean_ensemble_std": Value("float64"), "mean_relaxation_RMSD": Value("float64"), "relaxation_RMSE_std": Value("float64"), "stable_count": Value("int64"), "unique_in_stable_count": Value("int64"), "sun_count": Value("int64"), "metastable_count": Value("int64"), "unique_in_metastable_count": Value("int64"), "msun_count": Value("int64"), "JSDistance": Value("float64"), "MMD": Value("float64"), "FrechetDistance": Value("float64"), "element_diversity": Value("float64"), "space_group_diversity": Value("float64"), "site_diversity": Value("float64"), "physical_size_diversity": Value("float64"), "hhi_production_mean": Value("float64"), "hhi_reserve_mean": Value("float64"), "hhi_combined_mean": Value("float64"), "model_name": Value("string"), "relaxed": Value("bool"), "training_set": Value("string"), "paper_link": Value("string"), "notes": Value("string"), }) def get_leaderboard(): ds = load_dataset( results_repo, data_files="*.csv", split="train", download_mode="force_redownload", features=RESULT_FEATURES, ) full_df = pd.DataFrame(ds) if len(full_df) == 0: return pd.DataFrame(columns=list(RESULT_FEATURES.keys())) if "msun_count" in full_df.columns and "sun_count" in full_df.columns: full_df["msun_plus_sun"] = full_df["msun_count"] + full_df["sun_count"] if "msun_plus_sun" in full_df.columns: full_df = full_df.sort_values(by="msun_plus_sun", ascending=False) return full_df def format_dataframe(df, show_percentage=False, selected_groups=None, compact_view=True): """Format the dataframe with proper column names and optional percentages.""" if len(df) == 0: return df selected_cols = ["model_name"] if compact_view: from about import COMPACT_VIEW_COLUMNS selected_cols = [col for col in COMPACT_VIEW_COLUMNS if col in df.columns] else: if "training_set" in df.columns: selected_cols.append("training_set") if "n_structures" in df.columns: selected_cols.append("n_structures") if not selected_groups: selected_groups = list(METRIC_GROUPS.keys()) for group in selected_groups: if group in METRIC_GROUPS: for col in METRIC_GROUPS[group]: if col in df.columns and col not in selected_cols: selected_cols.append(col) display_df = df[selected_cols].copy() if "model_name" in display_df.columns: model_links = { "CrystaLLM-pi": "https://huggingface.co/c-bone/CrystaLLM-pi_base", "OMatG": "https://huggingface.co/OMatG/MP-20-DNG/tree/main/EncDec-ODE-Gamma", } def add_model_symbols(row): name = row["model_name"] symbols = [] if "paper_link" in df.columns: paper_val = row.get("paper_link", None) if paper_val and isinstance(paper_val, str) and paper_val.strip(): symbols.append(f'📄') if "relaxed" in df.columns and row.get("relaxed", False): symbols.append("⚡") if name in ["Alexandria", "OQMD"]: symbols.append("★") elif name == "AFLOW": symbols.append("◆") elif name in ["CrystaLLM-pi", "OMatG", "Zatom-1-WD"]: symbols.append("✅") symbol_str = f" {' '.join(symbols)}" if symbols else "" if name in model_links: return f'{name}{symbol_str}' return f"{name}{symbol_str}" display_df["model_name"] = df.apply(add_model_symbols, axis=1) if "training_set" in display_df.columns: def format_training_set(val): if val is None or (isinstance(val, float) and np.isnan(val)): return "" val = str(val).strip() if val in ("[]", "", "nan", "None"): return "" val = val.strip("[]") val = val.replace("'", "").replace('"', "") return val display_df["training_set"] = display_df["training_set"].apply(format_training_set) if show_percentage and "n_structures" in df.columns: n_structures = df["n_structures"] for col in COUNT_BASED_METRICS: if col in display_df.columns: display_df[col] = (df[col] / n_structures * 100).round(1).astype(str) + "%" for col in display_df.columns: if display_df[col].dtype in ["float64", "float32"]: display_df[col] = display_df[col].round(4) baseline_indices = set() if "notes" in df.columns: is_baseline = df["notes"].fillna("").str.contains("baseline", case=False, na=False) non_baseline_df = display_df[~is_baseline] baseline_df = display_df[is_baseline] display_df = pd.concat([non_baseline_df, baseline_df]).reset_index(drop=True) baseline_indices = set(range(len(non_baseline_df), len(display_df))) display_df = display_df.rename(columns=COLUMN_DISPLAY_NAMES) return apply_color_styling(display_df, selected_cols, baseline_indices) def apply_color_styling(display_df, original_cols, baseline_indices=None): """Apply background colors to dataframe based on metric groups using pandas Styler.""" if baseline_indices is None: baseline_indices = set() def style_by_group(x): styles = pd.DataFrame("", index=x.index, columns=x.columns) for i, display_col in enumerate(x.columns): if i < len(original_cols): original_col = original_cols[i] if original_col in COLUMN_TO_GROUP: group = COLUMN_TO_GROUP[original_col] color = METRIC_GROUP_COLORS.get(group, "") if color: styles[display_col] = f"background-color: {color}" if baseline_indices: first_baseline_idx = min(baseline_indices) for col in x.columns: current = styles.at[first_baseline_idx, col] separator_style = "border-top: 3px solid #555" styles.at[first_baseline_idx, col] = ( f"{current}; {separator_style}" if current else separator_style ) return styles return display_df.style.apply(style_by_group, axis=None) def parse_training_set(val): """Parse a training_set value stored as a string like "['MP-20']" into a list.""" try: return ast.literal_eval(str(val)) except (ValueError, SyntaxError): return [] def update_leaderboard(show_percentage, selected_groups, compact_view, cached_df, sort_by, sort_direction, training_set_filter): """Update the leaderboard based on user selections.""" df_to_format = cached_df.copy() ALWAYS_SHOW_MODELS = {"AFLOW", "Alexandria", "OQMD"} if training_set_filter and training_set_filter != "All" and "training_set" in df_to_format.columns: mask = ( df_to_format["training_set"].apply(lambda x: training_set_filter in parse_training_set(x)) | df_to_format["model_name"].isin(ALWAYS_SHOW_MODELS) ) df_to_format = df_to_format[mask] if sort_by and sort_by != "None": display_to_raw = {v: k for k, v in COLUMN_DISPLAY_NAMES.items()} raw_column_name = display_to_raw.get(sort_by, sort_by) if raw_column_name in df_to_format.columns: ascending = sort_direction == "Ascending" df_to_format = df_to_format.sort_values(by=raw_column_name, ascending=ascending) return format_dataframe(df_to_format, show_percentage, selected_groups, compact_view) def show_output_box(message): return gr.update(value=message, visible=True) def submit_cif_files(model_name, problem_type, cif_files, relaxed, relaxation_settings, training_datasets, training_dataset_other, paper_link, hf_model_link, email, profile: gr.OAuthProfile | None): """Submit structures to the leaderboard.""" from huggingface_hub import upload_file if not model_name or not model_name.strip(): return "Error: Please provide a model name.", None if not problem_type: return "Error: Please select a problem type.", None if not cif_files: return "Error: Please upload a file.", None if not profile: return "Error: Please log in to submit.", None if not email or not email.strip(): return "Error: Please provide an email address.", None try: username = profile.username timestamp = datetime.now().isoformat() submission_data = { "username": username, "model_name": model_name.strip(), "problem_type": problem_type, "relaxed": relaxed, "relaxation_settings": relaxation_settings.strip() if relaxed and relaxation_settings else None, "training_datasets": training_datasets or [], "training_dataset_other": training_dataset_other.strip() if training_dataset_other else None, "paper_link": paper_link.strip() if paper_link else None, "hf_model_link": hf_model_link.strip() if hf_model_link else None, "email": email.strip(), "timestamp": timestamp, "file_name": Path(cif_files).name, } submission_id = f"{username}_{model_name.strip().replace(' ', '_')}_{timestamp.replace(':', '-')}" file_path = Path(cif_files) uploaded_file_path = f"submissions/{submission_id}/{file_path.name}" upload_file( path_or_fileobj=str(file_path), path_in_repo=uploaded_file_path, repo_id=submissions_repo, token=TOKEN, repo_type="dataset", ) metadata_path = f"submissions/{submission_id}/metadata.json" import tempfile with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: json.dump(submission_data, f, indent=2) temp_metadata_path = f.name upload_file( path_or_fileobj=temp_metadata_path, path_in_repo=metadata_path, repo_id=submissions_repo, token=TOKEN, repo_type="dataset", ) os.unlink(temp_metadata_path) return f"Success! Submitted {model_name} for {problem_type} evaluation. Submission ID: {submission_id}", submission_id except Exception as e: return f"Error during submission: {str(e)}", None def generate_metric_legend_html(): """Generate HTML table with color-coded metric group legend.""" metric_details = { "Validity ↑": ("Valid, Charge Neutral, Distance Valid, Plausibility Valid", "↑ Higher is better"), "Uniqueness & Novelty ↑": ("Unique, Novel", "↑ Higher is better"), "Energy Metrics ↓": ("E Above Hull, Formation Energy, Relaxation RMSD (with std)", "↓ Lower is better"), "Stability ↑": ("Stable, Unique in Stable, SUN", "↑ Higher is better"), "Metastability ↑": ("Metastable, Unique in Metastable, MSUN", "↑ Higher is better"), "Distribution ↓": ("JS Distance, MMD, FID", "↓ Lower is better"), "Diversity ↑": ("Element, Space Group, Atomic Site, Crystal Size", "↑ Higher is better"), "HHI ↓": ("HHI Production, HHI Reserve", "↓ Lower is better"), } html = '
| Color | ' html += 'Group | ' html += 'Metrics | ' html += 'Direction | ' html += "
|---|---|---|---|
| ' html += f' | {group_name} | ' html += f'{metrics} | ' html += f'{direction} | ' html += "