from pathlib import Path
import ast
import contextlib
import html
import json
import os
import re
import tempfile
import traceback
from datetime import datetime

import pandas as pd
import numpy as np
import gradio as gr
from datasets import Features, Value, load_dataset

from about import (
    PROBLEM_TYPES,
    TOKEN,
    CACHE_PATH,
    API,
    submissions_repo,
    results_repo,
    COLUMN_DISPLAY_NAMES,
    COUNT_BASED_METRICS,
    METRIC_GROUPS,
    METRIC_GROUP_COLORS,
    COLUMN_TO_GROUP,
    TRAINING_DATASETS,
)

# Column schema for the result CSVs loaded from ``results_repo``; passed to
# ``load_dataset`` so every CSV is parsed with the same column types.
RESULT_FEATURES = Features({
    "run_name": Value("string"),
    "timestamp": Value("string"),
    "n_structures": Value("int64"),
    "overall_valid_count": Value("int64"),
    "charge_neutral_count": Value("int64"),
    "distance_valid_count": Value("int64"),
    "plausibility_valid_count": Value("int64"),
    "unique_count": Value("int64"),
    "novel_count": Value("int64"),
    "mean_formation_energy": Value("float64"),
    "formation_energy_std": Value("float64"),
    "stability_mean_above_hull": Value("float64"),
    "stability_std_e_above_hull": Value("float64"),
    "stability_mean_ensemble_std": Value("float64"),
    "mean_relaxation_RMSD": Value("float64"),
    "relaxation_RMSE_std": Value("float64"),
    "stable_count": Value("int64"),
    "unique_in_stable_count": Value("int64"),
    "sun_count": Value("int64"),
    "metastable_count": Value("int64"),
    "unique_in_metastable_count": Value("int64"),
    "msun_count": Value("int64"),
    "JSDistance": Value("float64"),
    "MMD": Value("float64"),
    "FrechetDistance": Value("float64"),
    "element_diversity": Value("float64"),
    "space_group_diversity": Value("float64"),
    "site_diversity": Value("float64"),
    "physical_size_diversity": Value("float64"),
    "hhi_production_mean": Value("float64"),
    "hhi_reserve_mean": Value("float64"),
    "hhi_combined_mean": Value("float64"),
    "model_name": Value("string"),
    "relaxed": Value("bool"),
    "training_set": Value("string"),
    "paper_link": Value("string"),
    "notes": Value("string"),
})

# Rows whose model_name is listed here are kept regardless of the
# training-set filter.
ALWAYS_SHOW_MODELS = frozenset({"AFLOW", "Alexandria", "OQMD"})

# Training-set filter used for the first render; also the initial value of
# the "Filter by Training Set" dropdown so both stay in sync.
DEFAULT_TRAINING_SET_FILTER = "MP-20"

# Opens links in a new tab without giving the target page access to ours.
_LINK_ATTRS = 'target="_blank" rel="noopener noreferrer"'


def get_leaderboard():
    """Download every result CSV from ``results_repo`` into one DataFrame.

    Returns:
        A DataFrame with the ``RESULT_FEATURES`` columns. When both
        ``msun_count`` and ``sun_count`` exist, a derived ``msun_plus_sun``
        column is added and rows are sorted by it in descending order. If the
        repository holds no rows, an empty DataFrame with the
        ``RESULT_FEATURES`` columns is returned instead.
    """
    ds = load_dataset(
        results_repo,
        data_files="*.csv",
        split="train",
        # Re-download on every call rather than reusing a local cache.
        download_mode="force_redownload",
        features=RESULT_FEATURES,
    )
    full_df = pd.DataFrame(ds)
    if len(full_df) == 0:
        return pd.DataFrame(columns=list(RESULT_FEATURES.keys()))

    if "msun_count" in full_df.columns and "sun_count" in full_df.columns:
        full_df["msun_plus_sun"] = full_df["msun_count"] + full_df["sun_count"]
    if "msun_plus_sun" in full_df.columns:
        full_df = full_df.sort_values(by="msun_plus_sun", ascending=False)
    return full_df


def _select_display_columns(df, selected_groups, compact_view):
    """Return the raw column names to display, in display order.

    Only columns present in ``df`` are returned.
    """
    if compact_view:
        from about import COMPACT_VIEW_COLUMNS
        return [col for col in COMPACT_VIEW_COLUMNS if col in df.columns]

    selected_cols = [
        col for col in ("model_name", "training_set", "n_structures") if col in df.columns
    ]
    # An empty/None selection means "show every metric group".
    for group in selected_groups or list(METRIC_GROUPS.keys()):
        if group not in METRIC_GROUPS:
            continue
        for col in METRIC_GROUPS[group]:
            if col in df.columns and col not in selected_cols:
                selected_cols.append(col)
    return selected_cols


def _format_model_cell(row):
    """Render the HTML for one model-name cell.

    The cell holds the (HTML-escaped) model name, linked when a model page is
    known, followed by the symbols described in the UI's Symbol Legend.
    """
    model_links = {
        "CrystaLLM-pi": "https://huggingface.co/c-bone/CrystaLLM-pi_base",
        "OMatG": "https://huggingface.co/OMatG/MP-20-DNG/tree/main/EncDec-ODE-Gamma",
    }
    name = row["model_name"]
    symbols = []

    paper_val = row.get("paper_link", None)
    if isinstance(paper_val, str) and paper_val.strip():
        # paper_link is stored data, not trusted markup: escape it before it
        # goes into an HTML attribute.
        href = html.escape(paper_val.strip(), quote=True)
        symbols.append(f'<a href="{href}" {_LINK_ATTRS}>📄</a>')

    relaxed = row.get("relaxed", False)
    # pd.notna guards against a missing value (NaN is truthy).
    if pd.notna(relaxed) and bool(relaxed):
        symbols.append("⚡")

    if name in ["Alexandria", "OQMD"]:
        symbols.append("★")
    elif name == "AFLOW":
        symbols.append("◆")
    elif name in ["CrystaLLM-pi", "OMatG", "Zatom-1-WD"]:
        symbols.append("✅")

    symbol_str = f" {' '.join(symbols)}" if symbols else ""
    label = html.escape(str(name))
    if name in model_links:
        return f'<a href="{model_links[name]}" {_LINK_ATTRS}>{label}</a>{symbol_str}'
    return f"{label}{symbol_str}"


def _format_training_set(val):
    """Turn a stored training-set value such as "['MP-20']" into "MP-20"."""
    if val is None or (isinstance(val, float) and np.isnan(val)):
        return ""
    val = str(val).strip()
    if val in ("[]", "", "nan", "None"):
        return ""
    val = val.strip("[]")
    return val.replace("'", "").replace('"', "")


def format_dataframe(df, show_percentage=False, selected_groups=None, compact_view=True):
    """Format the dataframe with proper column names and optional percentages.

    Args:
        df: Raw leaderboard rows (columns as in ``RESULT_FEATURES``).
        show_percentage: When True and ``n_structures`` exists, render each
            column in ``COUNT_BASED_METRICS`` as a percentage of
            ``n_structures`` (e.g. ``"42.0%"``). Rows whose ``n_structures``
            is missing or not positive get an empty cell.
        selected_groups: Keys of ``METRIC_GROUPS`` to show when
            ``compact_view`` is False; None/empty selects every group.
        compact_view: When True, show only ``COMPACT_VIEW_COLUMNS``.

    Returns:
        For an empty ``df``, a plain empty DataFrame with the selected columns
        renamed via ``COLUMN_DISPLAY_NAMES``. Otherwise a pandas ``Styler``
        over the formatted rows, with baselines (``notes`` containing
        "baseline") moved to the bottom.
    """
    selected_cols = _select_display_columns(df, selected_groups, compact_view)
    if len(df) == 0:
        return df[selected_cols].rename(columns=COLUMN_DISPLAY_NAMES)

    display_df = df[selected_cols].copy()

    if "model_name" in display_df.columns:
        display_df["model_name"] = df.apply(_format_model_cell, axis=1)

    if "training_set" in display_df.columns:
        display_df["training_set"] = display_df["training_set"].apply(_format_training_set)

    if show_percentage and "n_structures" in df.columns:
        # Treat non-positive totals as unknown so we never print "inf%"/"nan%".
        n_structures = df["n_structures"].where(df["n_structures"] > 0)
        for col in COUNT_BASED_METRICS:
            if col in display_df.columns:
                pct = (df[col] / n_structures * 100).round(1)
                display_df[col] = pct.map(lambda v: "" if pd.isna(v) else f"{v}%")

    for col in display_df.columns:
        if display_df[col].dtype in ["float64", "float32"]:
            display_df[col] = display_df[col].round(4)

    baseline_indices = set()
    if "notes" in df.columns:
        is_baseline = (
            df["notes"].fillna("").astype(str).str.contains("baseline", case=False, na=False)
        )
        non_baseline_df = display_df[~is_baseline]
        baseline_df = display_df[is_baseline]
        display_df = pd.concat([non_baseline_df, baseline_df]).reset_index(drop=True)
        # After reset_index, positions equal row labels.
        baseline_indices = set(range(len(non_baseline_df), len(display_df)))

    display_df = display_df.rename(columns=COLUMN_DISPLAY_NAMES)
    return apply_color_styling(display_df, selected_cols, baseline_indices)


def apply_color_styling(display_df, original_cols, baseline_indices=None):
    """Apply background colors to dataframe based on metric groups using pandas Styler.

    Args:
        display_df: DataFrame whose columns are the display-named versions of
            ``original_cols``, in the same order.
        original_cols: Raw column names. Each one is looked up in
            ``COLUMN_TO_GROUP`` and then ``METRIC_GROUP_COLORS``.
        baseline_indices: Row labels of baseline rows. A thick top border is
            drawn on the smallest one to separate baselines from models.

    Returns:
        A pandas ``Styler`` for ``display_df``.
    """
    if baseline_indices is None:
        baseline_indices = set()

    def style_by_group(frame):
        styles = pd.DataFrame("", index=frame.index, columns=frame.columns)
        # Display and raw columns are matched by position; zip stops at the
        # shorter list, so surplus display columns stay unstyled.
        for display_col, original_col in zip(frame.columns, original_cols):
            if original_col not in COLUMN_TO_GROUP:
                continue
            color = METRIC_GROUP_COLORS.get(COLUMN_TO_GROUP[original_col], "")
            if color:
                styles[display_col] = f"background-color: {color}"

        if baseline_indices:
            first_baseline_idx = min(baseline_indices)
            separator_style = "border-top: 3px solid #555"
            for col in frame.columns:
                current = styles.at[first_baseline_idx, col]
                styles.at[first_baseline_idx, col] = (
                    f"{current}; {separator_style}" if current else separator_style
                )
        return styles

    return display_df.style.apply(style_by_group, axis=None)


def parse_training_set(val):
    """Parse a stored ``training_set`` value into a list of dataset names.

    Values are normally the string form of a Python list, e.g. "['MP-20']".
    A bare name or comma-separated names ("MP-20", "MP-20, Alex-MP-20") are
    accepted too. Missing values (None/NaN), "None", "[]", empty strings, and
    anything unparsable yield ``[]``. The result is always a list of str, so
    ``name in parse_training_set(x)`` never raises.
    """
    if isinstance(val, (list, tuple, set, frozenset, np.ndarray)):
        return [str(item).strip() for item in val]
    if val is None or pd.isna(val):
        return []

    text = str(val).strip()
    if text in ("", "[]", "nan", "None"):
        return []
    try:
        parsed = ast.literal_eval(text)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        parsed = text

    if isinstance(parsed, (list, tuple, set, frozenset)):
        return [str(item).strip() for item in parsed]
    # Not a container: mirror _format_training_set's cleanup, then split on
    # commas so a bare name becomes a one-element list.
    cleaned = str(parsed).strip().strip("[]").replace("'", "").replace('"', "")
    return [part.strip() for part in cleaned.split(",") if part.strip()]


def _filter_by_training_set(df, training_set_filter):
    """Keep rows trained on ``training_set_filter`` plus ``ALWAYS_SHOW_MODELS``.

    A falsy filter or "All" disables filtering. A ``df`` without a
    ``training_set`` column is returned unchanged.
    """
    if (
        not training_set_filter
        or training_set_filter == "All"
        or "training_set" not in df.columns
    ):
        return df
    # astype(bool) keeps the mask boolean even when df is empty.
    mask = df["training_set"].apply(
        lambda value: training_set_filter in parse_training_set(value)
    ).astype(bool)
    if "model_name" in df.columns:
        mask = mask | df["model_name"].isin(ALWAYS_SHOW_MODELS)
    return df.loc[mask]


def update_leaderboard(show_percentage, selected_groups, compact_view, cached_df, sort_by, sort_direction, training_set_filter):
    """Update the leaderboard based on user selections.

    Filters ``cached_df`` by training set, optionally sorts by the raw column
    behind the display name ``sort_by`` ("None" keeps the cached order), and
    formats the result with ``format_dataframe``.
    """
    df_to_format = _filter_by_training_set(cached_df.copy(), training_set_filter)

    if sort_by and sort_by != "None":
        display_to_raw = {display: raw for raw, display in COLUMN_DISPLAY_NAMES.items()}
        raw_column_name = display_to_raw.get(sort_by, sort_by)
        if raw_column_name in df_to_format.columns:
            df_to_format = df_to_format.sort_values(
                by=raw_column_name, ascending=sort_direction == "Ascending"
            )

    return format_dataframe(df_to_format, show_percentage, selected_groups, compact_view)


def show_output_box(message):
    """Return a Gradio update that shows ``message`` and makes its box visible."""
    return gr.update(value=message, visible=True)


def _safe_path_component(text):
    """Replace every character outside ``[A-Za-z0-9._-]`` with ``_``.

    Applied to user-supplied text embedded in repository paths so it cannot
    inject path separators. Spaces become ``_`` as before.
    """
    return re.sub(r"[^A-Za-z0-9._-]", "_", text)


def submit_cif_files(model_name, problem_type, cif_files, relaxed, relaxation_settings, training_datasets, training_dataset_other, paper_link, hf_model_link, email, profile: gr.OAuthProfile | None):
    """Submit structures to the leaderboard.

    Validates the form, uploads the structure file and a ``metadata.json`` to
    ``submissions/<submission_id>/`` in the ``submissions_repo`` dataset.

    Returns:
        ``(status_message, submission_id)``. ``submission_id`` is None on any
        validation or upload error.
    """
    from huggingface_hub import upload_file

    if not model_name or not model_name.strip():
        return "Error: Please provide a model name.", None
    if not problem_type:
        return "Error: Please select a problem type.", None
    if not cif_files:
        return "Error: Please upload a file.", None
    if not profile:
        return "Error: Please log in to submit.", None
    if not email or not email.strip():
        return "Error: Please provide an email address.", None

    temp_metadata_path = None
    try:
        username = profile.username
        clean_model_name = model_name.strip()
        timestamp = datetime.now().isoformat()
        file_path = Path(cif_files)

        submission_data = {
            "username": username,
            "model_name": clean_model_name,
            "problem_type": problem_type,
            "relaxed": relaxed,
            "relaxation_settings": relaxation_settings.strip() if relaxed and relaxation_settings else None,
            "training_datasets": training_datasets or [],
            "training_dataset_other": training_dataset_other.strip() if training_dataset_other else None,
            "paper_link": paper_link.strip() if paper_link else None,
            "hf_model_link": hf_model_link.strip() if hf_model_link else None,
            "email": email.strip(),
            "timestamp": timestamp,
            "file_name": file_path.name,
        }

        # The model name is free text: sanitize it so it cannot add "/"
        # segments to the repository path.
        submission_id = (
            f"{_safe_path_component(username)}_"
            f"{_safe_path_component(clean_model_name)}_"
            f"{timestamp.replace(':', '-')}"
        )
        submission_dir = f"submissions/{submission_id}"

        upload_file(
            path_or_fileobj=str(file_path),
            path_in_repo=f"{submission_dir}/{file_path.name}",
            repo_id=submissions_repo,
            token=TOKEN,
            repo_type="dataset",
        )

        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".json", delete=False, encoding="utf-8"
        ) as f:
            temp_metadata_path = f.name
            json.dump(submission_data, f, indent=2)

        upload_file(
            path_or_fileobj=temp_metadata_path,
            path_in_repo=f"{submission_dir}/metadata.json",
            repo_id=submissions_repo,
            token=TOKEN,
            repo_type="dataset",
        )

        return f"Success! Submitted {clean_model_name} for {problem_type} evaluation. Submission ID: {submission_id}", submission_id
    except Exception as e:
        # UI boundary: log the full traceback server-side, show a short message.
        traceback.print_exc()
        return f"Error during submission: {str(e)}", None
    finally:
        # Remove the temp metadata file whether or not the upload succeeded.
        if temp_metadata_path is not None:
            with contextlib.suppress(OSError):
                os.unlink(temp_metadata_path)


def generate_metric_legend_html():
    """Generate HTML table with color-coded metric group legend.

    Emits one row per entry of ``METRIC_GROUP_COLORS``, in its iteration
    order. Each row has a color swatch, the group name without its ↑/↓ arrow,
    the metrics in the group, and the better direction. Groups absent from
    ``metric_details`` get empty text cells.
    """
    metric_details = {
        "Validity ↑": ("Valid, Charge Neutral, Distance Valid, Plausibility Valid", "↑ Higher is better"),
        "Uniqueness & Novelty ↑": ("Unique, Novel", "↑ Higher is better"),
        "Energy Metrics ↓": ("E Above Hull, Formation Energy, Relaxation RMSD (with std)", "↓ Lower is better"),
        "Stability ↑": ("Stable, Unique in Stable, SUN", "↑ Higher is better"),
        "Metastability ↑": ("Metastable, Unique in Metastable, MSUN", "↑ Higher is better"),
        "Distribution ↓": ("JS Distance, MMD, FID", "↓ Lower is better"),
        "Diversity ↑": ("Element, Space Group, Atomic Site, Crystal Size", "↑ Higher is better"),
        "HHI ↓": ("HHI Production, HHI Reserve", "↓ Lower is better"),
    }
    cell_style = "padding: 6px 10px; border: 1px solid #ddd; text-align: left;"
    header = "".join(
        f'<th style="{cell_style}">{title}</th>'
        for title in ("Color", "Group", "Metrics", "Direction")
    )

    rows = []
    for group, color in METRIC_GROUP_COLORS.items():
        metrics, direction = metric_details.get(group, ("", ""))
        group_name = group.replace("↑", "").replace("↓", "").strip()
        swatch = html.escape(str(color), quote=True)
        rows.append(
            "<tr>"
            f'<td style="{cell_style} width: 48px; background-color: {swatch};"></td>'
            f'<td style="{cell_style}"><b>{html.escape(group_name)}</b></td>'
            f'<td style="{cell_style}">{html.escape(metrics)}</td>'
            f'<td style="{cell_style}">{html.escape(direction)}</td>'
            "</tr>"
        )
    body = "".join(rows)

    return (
        '<table style="width: 100%; border-collapse: collapse;">'
        f"<thead><tr>{header}</tr></thead>"
        f"<tbody>{body}</tbody>"
        "</table>"
    )


def gradio_interface() -> gr.Blocks:
    """Build the Gradio app with a Leaderboard tab and a Submit tab."""
    with gr.Blocks() as demo:
        gr.Markdown("""
        # 🔬 LeMat-GenBench: A Unified Benchmark for Generative Models of Crystalline Materials

        Generative machine learning models hold great promise for accelerating materials discovery, particularly through the inverse design of inorganic crystals, enabling an unprecedented exploration of chemical space. Yet, the lack of standardized evaluation frameworks makes it difficult to evaluate, compare and further develop these ML models meaningfully.

        **LeMat-GenBench** introduces a unified benchmark for generative models of crystalline materials, with standardized evaluation metrics for meaningful model comparison, diverse tasks, and this leaderboard to encourage and track community progress.

        📄 **Paper**: [arXiv](https://arxiv.org/abs/2512.04562) | 💻 **Code**: [GitHub](https://github.com/LeMaterial/lemat-genbench) | 📧 **Contact**: siddharth.betala [at] entalpic.ai, alexandre.duval [at] entalpic.ai
        """)

        with gr.Tabs(elem_classes="tab-buttons"):
            with gr.TabItem("🚀 Leaderboard", elem_id="boundary-benchmark-tab-table"):
                gr.Markdown("# LeMat-GenBench")

                with gr.Row():
                    with gr.Column(scale=1):
                        compact_view = gr.Checkbox(
                            value=True,
                            label="Compact View",
                            info="Show only key metrics",
                        )
                        show_percentage = gr.Checkbox(
                            value=True,
                            label="Show as Percentages",
                            info="Display count-based metrics as percentages of total structures",
                        )
                    with gr.Column(scale=1):
                        sort_choices = ["None"] + list(COLUMN_DISPLAY_NAMES.values())
                        sort_by = gr.Dropdown(
                            choices=sort_choices,
                            value="None",
                            label="Sort By",
                            info="Select column to sort by (default: sorted by MSUN+SUN descending)",
                        )
                        sort_direction = gr.Radio(
                            choices=["Ascending", "Descending"],
                            value="Descending",
                            label="Sort Direction",
                        )
                    with gr.Column(scale=1):
                        training_set_filter = gr.Dropdown(
                            choices=["All"] + TRAINING_DATASETS,
                            value=DEFAULT_TRAINING_SET_FILTER,
                            label="Filter by Training Set",
                            info="Show only models trained on a specific dataset",
                        )
                    with gr.Column(scale=2):
                        selected_groups = gr.CheckboxGroup(
                            choices=list(METRIC_GROUPS.keys()),
                            value=list(METRIC_GROUPS.keys()),
                            label="Metric Families (only active when Compact View is off)",
                            info="Select which metric groups to display",
                        )

                with gr.Accordion("Metric Groups Legend", open=False):
                    gr.HTML(generate_metric_legend_html())

                try:
                    initial_df = get_leaderboard()
                    cached_df_state = gr.State(initial_df)

                    # Render with the same defaults as the controls above, through
                    # the same code path the change handlers use.
                    formatted_df = update_leaderboard(
                        show_percentage=True,
                        selected_groups=list(METRIC_GROUPS.keys()),
                        compact_view=True,
                        cached_df=initial_df,
                        sort_by="None",
                        sort_direction="Descending",
                        training_set_filter=DEFAULT_TRAINING_SET_FILTER,
                    )
                    # A Styler exposes the underlying frame as .data.
                    formatted_columns = (
                        list(formatted_df.data.columns)
                        if hasattr(formatted_df, "data")
                        else list(formatted_df.columns)
                    )

                    leaderboard_table = gr.Dataframe(
                        label="GenBench Leaderboard",
                        value=formatted_df,
                        interactive=False,
                        wrap=True,
                        # First column holds the HTML model-name cell.
                        datatype=["html"] + [None] * (len(formatted_columns) - 1) if formatted_columns else None,
                        column_widths=["180px"] + ["160px"] * (len(formatted_columns) - 1) if formatted_columns else None,
                        show_fullscreen_button=True,
                    )

                    inputs = [
                        show_percentage,
                        selected_groups,
                        compact_view,
                        cached_df_state,
                        sort_by,
                        sort_direction,
                        training_set_filter,
                    ]
                    for control in (
                        show_percentage,
                        selected_groups,
                        compact_view,
                        sort_by,
                        sort_direction,
                        training_set_filter,
                    ):
                        control.change(fn=update_leaderboard, inputs=inputs, outputs=leaderboard_table)

                except Exception as e:
                    traceback.print_exc()
                    gr.Markdown(f"Leaderboard is empty or error loading: {type(e).__name__}: {str(e)}")

                gr.Markdown("""
                **Symbol Legend:**
                - 📄 Paper available (click to view)
                - ✅ Model output verified
                - ⚡ Structures were already relaxed
                - ★ Contributes to LeMat-Bulk reference dataset (in-distribution)
                - ◆ Out-of-distribution relative to LeMat-Bulk reference dataset

                Verified submissions mean the results came from a model submission rather than a CIF submission.

                Models marked as baselines appear below the separator line at the bottom of the table.
                """)

            with gr.TabItem("✉️ Submit", elem_id="submit-tab"):
                gr.Markdown("""
                # Materials Submission
                Upload your structures as a CSV, a pickle (.pkl), or a ZIP of CIF files. To ensure eligibility for the leaderboard, please provide exactly 2,500 representative structures.
                """)
                filename = gr.State(value=None)

                gr.LoginButton()

                with gr.Row():
                    with gr.Column():
                        model_name_input = gr.Textbox(
                            label="Model Name",
                            placeholder="Enter your model name",
                            info="Provide a name for your model/method",
                        )
                        email_input = gr.Textbox(
                            label="Email Address",
                            placeholder="Enter your email address",
                            info="Contact email for correspondence about this submission",
                        )
                        paper_link_input = gr.Textbox(
                            label="Paper Link (optional)",
                            placeholder="https://arxiv.org/abs/...",
                            info="Link to the paper describing your model/method",
                        )
                        hf_model_link_input = gr.Textbox(
                            label="HuggingFace Model Link (optional)",
                            placeholder="https://huggingface.co/...",
                            info="Link to your model on HuggingFace",
                        )
                        problem_type = gr.Dropdown(PROBLEM_TYPES, label="Problem Type")
                    with gr.Column():
                        cif_file = gr.File(label="Upload a CSV, a pkl, or a ZIP of CIF files.")
                        relaxed = gr.Checkbox(
                            value=False,
                            label="Structures are pre-relaxed",
                            info="Check this box if your submitted structures have already been relaxed",
                        )
                        relaxation_settings_input = gr.Textbox(
                            label="Relaxation Settings",
                            placeholder="e.g., VASP PBE, 520 eV cutoff, ...",
                            info="Describe the relaxation settings used",
                            visible=False,
                        )
                        training_dataset_input = gr.Dropdown(
                            choices=TRAINING_DATASETS,
                            label="Training Dataset",
                            info="Select all datasets used for training",
                            multiselect=True,
                        )
                        training_dataset_other_input = gr.Textbox(
                            label="Other Training Dataset",
                            placeholder="Specify your training dataset",
                            info="Provide details if you selected 'Others (must specify)'",
                            visible=False,
                        )

                # Show the optional detail boxes only when they are relevant.
                relaxed.change(
                    fn=lambda x: gr.update(visible=x),
                    inputs=[relaxed],
                    outputs=[relaxation_settings_input],
                )
                training_dataset_input.change(
                    fn=lambda x: gr.update(visible="Others (must specify)" in (x or [])),
                    inputs=[training_dataset_input],
                    outputs=[training_dataset_other_input],
                )

                submit_btn = gr.Button("Submission")
                message = gr.Textbox(label="Status", lines=1, visible=False)

                gr.Markdown(
                    "If you have issues with submission or using the leaderboard, please start a discussion in the Community tab of this Space."
                )

                # The OAuth profile argument of submit_cif_files is not listed
                # in inputs; Gradio supplies it from the type annotation.
                submit_btn.click(
                    submit_cif_files,
                    inputs=[
                        model_name_input,
                        problem_type,
                        cif_file,
                        relaxed,
                        relaxation_settings_input,
                        training_dataset_input,
                        training_dataset_other_input,
                        paper_link_input,
                        hf_model_link_input,
                        email_input,
                    ],
                    outputs=[message, filename],
                ).then(
                    fn=show_output_box,
                    inputs=[message],
                    outputs=[message],
                )

    return demo


if __name__ == "__main__":
    gradio_interface().launch(show_error=True)