Spaces:
Running
Running
| from pathlib import Path | |
| import json | |
| import pandas as pd | |
| import numpy as np | |
| import gradio as gr | |
| from datasets import load_dataset | |
| from datetime import datetime | |
| import os | |
| from about import ( | |
| PROBLEM_TYPES, TOKEN, CACHE_PATH, API, submissions_repo, results_repo, | |
| COLUMN_DISPLAY_NAMES, COUNT_BASED_METRICS, METRIC_GROUPS, | |
| METRIC_GROUP_COLORS, COLUMN_TO_GROUP, TRAINING_DATASETS | |
| ) | |
def get_leaderboard():
    """Download the results dataset and return it as a DataFrame.

    The dataset is re-downloaded on every call (``force_redownload``). When both
    ``msun_count`` and ``sun_count`` are present, a computed ``msun_plus_sun``
    column is added, and rows are ordered by it (highest first).

    Returns:
        pandas.DataFrame: the results table, or an empty placeholder frame when
        the dataset has no rows.
    """
    import logging

    ds = load_dataset(results_repo, split='train', download_mode="force_redownload")
    full_df = pd.DataFrame(ds)
    # Column listing is diagnostic output; keep it out of stdout.
    logging.getLogger(__name__).debug("Leaderboard columns: %s", list(full_df.columns))
    if len(full_df) == 0:
        return pd.DataFrame({'date': [], 'model': [], 'score': [], 'verified': []})

    # Add computed column for MSUN+SUN
    if 'msun_count' in full_df.columns and 'sun_count' in full_df.columns:
        full_df['msun_plus_sun'] = full_df['msun_count'] + full_df['sun_count']

    # Sort by MSUN+SUN in descending order by default. A stable sort keeps tied
    # rows in dataset order (the default quicksort does not guarantee that), and
    # resetting the index keeps row labels equal to display positions.
    if 'msun_plus_sun' in full_df.columns:
        full_df = full_df.sort_values(
            by='msun_plus_sun', ascending=False, kind='stable'
        ).reset_index(drop=True)
    return full_df
# External pages linked from specific model names in the leaderboard.
_MODEL_LINKS = {
    'CrystaLLM-pi': 'https://huggingface.co/c-bone/CrystaLLM-pi_base',
}


def _select_display_columns(df, selected_groups, compact_view):
    """Return the ordered list of raw column names to show for this view."""
    if compact_view:
        # Use predefined compact columns
        from about import COMPACT_VIEW_COLUMNS
        return [col for col in COMPACT_VIEW_COLUMNS if col in df.columns]

    selected_cols = ['model_name']
    for col in ('training_set', 'n_structures'):
        if col in df.columns:
            selected_cols.append(col)
    # If no groups selected, show all
    groups = selected_groups or list(METRIC_GROUPS.keys())
    for group in groups:
        for col in METRIC_GROUPS.get(group, ()):
            if col in df.columns and col not in selected_cols:
                selected_cols.append(col)
    return selected_cols


def _safe_href(url):
    """Return *url* HTML-escaped for an ``href`` attribute, or None if unusable.

    Paper links come from user submissions and end up inside raw HTML, so only
    explicit ``http``/``https`` URLs are accepted. Anything else (e.g.
    ``javascript:`` or scheme-less text) would be an injection vector or a
    broken relative link.
    """
    import html
    from urllib.parse import urlsplit

    # Browsers ignore leading/trailing whitespace and C0 control characters.
    url = url.strip(''.join(chr(code) for code in range(0x21)))
    if not url:
        return None
    try:
        scheme = urlsplit(url).scheme.lower()
    except ValueError:
        return None
    if scheme not in ('http', 'https'):
        return None
    return html.escape(url, quote=True)


def _model_name_cell(row, has_paper_link, has_relaxed):
    """Build the HTML shown in one model-name cell: the name plus badges."""
    import html

    name = row['model_name']
    symbols = []
    # Add paper link emoji
    if has_paper_link:
        paper_val = row.get('paper_link', None)
        if paper_val and isinstance(paper_val, str) and paper_val.strip():
            href = _safe_href(paper_val)
            if href is not None:
                symbols.append(f'<a href="{href}" target="_blank">π</a>')
    # Add relaxed symbol. NaN is truthy in Python, so a missing value must not
    # be treated as "relaxed".
    if has_relaxed:
        relaxed_val = row.get('relaxed', False)
        if pd.notna(relaxed_val) and bool(relaxed_val):
            symbols.append('β‘')
    # Reference-dataset markers.
    # NOTE(review): these symbol literals look like mis-decoded emoji in this
    # copy of the file; verify them against the source encoding.
    if name in ('Alexandria', 'OQMD'):
        symbols.append('β ')
    elif name == 'AFLOW':
        symbols.append('β')
    elif name == 'CrystaLLM-pi':
        symbols.append('β ')
    symbol_str = f" {' '.join(symbols)}" if symbols else ""

    # The cell is rendered as HTML, so escape the (user-supplied) model name.
    label = html.escape(str(name))
    link = _MODEL_LINKS.get(name) if isinstance(name, str) else None
    if link:
        return f'<a href="{link}" target="_blank">{label}</a>{symbol_str}'
    return f"{label}{symbol_str}"


def _format_training_set_cell(val):
    """Render a raw ``training_set`` value as plain comma-separated text.

    Missing values (None, NaN, ``'[]'``, ``''``, ``'nan'``, ``'None'``) become
    an empty string. List-like values lose their brackets and quotes.
    """
    if val is None or (isinstance(val, float) and np.isnan(val)):
        return ''
    text = str(val).strip()
    if text in ('[]', '', 'nan', 'None'):
        return ''
    # Strip brackets and quotes for list-like strings
    return text.strip('[]').replace("'", "").replace('"', '')


def _as_percent_text(value):
    """Format an already-rounded percentage as ``'12.3%'``, or '' if not finite."""
    if pd.notna(value) and np.isfinite(value):
        return f'{value}%'
    return ''


def format_dataframe(df, show_percentage=False, selected_groups=None, compact_view=True):
    """Format the results table for display in the leaderboard.

    Processing steps:
    - Pick columns: the compact preset, or the chosen metric groups.
    - Decorate model names with links and badges (as escaped HTML).
    - Clean up ``training_set`` text.
    - Optionally convert count metrics to percentages of ``n_structures``.
    - Round float columns to 4 decimals.
    - Move rows whose ``notes`` mention "baseline" to the bottom.
    - Rename to display names and apply metric-group colour styling.

    Args:
        df: Raw results DataFrame.
        show_percentage: Render ``COUNT_BASED_METRICS`` as ``'12.3%'`` strings
            (blank when the ratio is NaN/inf, e.g. ``n_structures == 0``).
        selected_groups: Metric-group names to show when ``compact_view`` is
            False; falsy means all groups.
        compact_view: Use ``COMPACT_VIEW_COLUMNS`` instead of metric groups.

    Returns:
        ``df`` itself when it is empty, otherwise the Styler produced by
        ``apply_color_styling``.
    """
    if len(df) == 0:
        return df

    selected_cols = _select_display_columns(df, selected_groups, compact_view)
    # Create a copy with selected columns
    display_df = df[selected_cols].copy()

    if 'model_name' in display_df.columns:
        has_paper_link = 'paper_link' in df.columns
        has_relaxed = 'relaxed' in df.columns
        # Rows come from the full frame so paper_link/relaxed are available.
        display_df['model_name'] = df.apply(
            lambda row: _model_name_cell(row, has_paper_link, has_relaxed), axis=1
        )

    if 'training_set' in display_df.columns:
        display_df['training_set'] = display_df['training_set'].apply(_format_training_set_cell)

    # Convert count-based metrics to percentages if requested
    if show_percentage and 'n_structures' in df.columns:
        n_structures = df['n_structures']
        for col in COUNT_BASED_METRICS:
            if col in display_df.columns:
                pct = (df[col] / n_structures * 100).round(1)
                display_df[col] = pct.map(_as_percent_text)

    # Round numeric columns for cleaner display
    for col in display_df.columns:
        if display_df[col].dtype in ('float64', 'float32'):
            display_df[col] = display_df[col].round(4)

    # Separate baseline models to the bottom
    baseline_indices = set()
    if 'notes' in df.columns:
        # astype(str) keeps the .str accessor usable for non-string note columns.
        is_baseline = (
            df['notes'].fillna('').astype(str).str.contains('baseline', case=False, na=False)
        )
        non_baseline_df = display_df[~is_baseline]
        baseline_df = display_df[is_baseline]
        display_df = pd.concat([non_baseline_df, baseline_df]).reset_index(drop=True)
        # Positions of the baseline rows in the re-indexed frame
        baseline_indices = set(range(len(non_baseline_df), len(display_df)))

    # Rename columns for display, then colour by metric group
    display_df = display_df.rename(columns=COLUMN_DISPLAY_NAMES)
    return apply_color_styling(display_df, selected_cols, baseline_indices)
def apply_color_styling(display_df, original_cols, baseline_indices=None):
    """Return a pandas Styler that shades columns by metric group.

    Display columns are matched to ``original_cols`` by position. A column
    whose raw name has a group in ``COLUMN_TO_GROUP`` (and that group has a
    colour) gets that background colour. When ``baseline_indices`` is
    non-empty, the row labelled ``min(baseline_indices)`` gets a thick top
    border as a separator.
    """
    baseline_rows = set() if baseline_indices is None else baseline_indices
    separator_css = 'border-top: 3px solid #555'

    def _css_for(frame):
        css = pd.DataFrame('', index=frame.index, columns=frame.columns)

        # zip() stops at the shorter sequence, like the positional bound check.
        for shown_name, raw_name in zip(frame.columns, original_cols):
            if raw_name not in COLUMN_TO_GROUP:
                continue
            colour = METRIC_GROUP_COLORS.get(COLUMN_TO_GROUP[raw_name], '')
            if colour:
                css[shown_name] = f'background-color: {colour}'

        if baseline_rows:
            top_row = min(baseline_rows)
            for shown_name in frame.columns:
                existing = css.at[top_row, shown_name]
                css.at[top_row, shown_name] = (
                    f'{existing}; {separator_css}' if existing else separator_css
                )
        return css

    return display_df.style.apply(_css_for, axis=None)
def _training_set_names(value):
    """Return the set of dataset names contained in a raw ``training_set`` value.

    Accepts missing values (None/NaN), real sequences, and list-like strings
    such as ``"['MP-20', 'Alexandria']"``. Quotes, brackets and blanks are
    removed.
    """
    if value is None or (isinstance(value, float) and np.isnan(value)):
        return set()
    if isinstance(value, (list, tuple, set, np.ndarray)):
        items = [str(item) for item in value]
    else:
        items = str(value).strip().strip('[]').split(',')
    names = set()
    for item in items:
        cleaned = item.replace("'", "").replace('"', '').strip()
        if cleaned and cleaned not in ('nan', 'None'):
            names.add(cleaned)
    return names


def _training_set_matches(value, wanted):
    """Return True if the raw ``training_set`` *value* includes dataset *wanted*."""
    # Whole-value match first, so a dataset name containing a comma still works.
    if isinstance(value, str) and value.strip() == wanted:
        return True
    return wanted in _training_set_names(value)


def update_leaderboard(show_percentage, selected_groups, compact_view, cached_df, sort_by, sort_direction, training_set_filter):
    """Update the leaderboard based on user selections.

    Uses the cached dataframe to avoid re-downloading data on every change.
    Steps:
    - Keep only rows whose ``training_set`` includes ``training_set_filter``
      (unless the filter is empty or "All").
    - Sort by the raw column whose display name is ``sort_by`` (unless it is
      empty or "None").
    - Format the result with ``format_dataframe``.

    Returns:
        Whatever ``format_dataframe`` returns for the filtered/sorted frame.
    """
    df_to_format = cached_df.copy()

    # Apply training set filter. Values may be lists or list-like strings, so
    # test membership rather than equality of the whole value.
    if training_set_filter and training_set_filter != "All" and 'training_set' in df_to_format.columns:
        mask = df_to_format['training_set'].map(
            lambda value: _training_set_matches(value, training_set_filter)
        )
        df_to_format = df_to_format[mask.astype(bool)]

    # Convert display name back to raw column name for sorting
    if sort_by and sort_by != "None":
        display_to_raw = {v: k for k, v in COLUMN_DISPLAY_NAMES.items()}
        raw_column_name = display_to_raw.get(sort_by, sort_by)
        if raw_column_name in df_to_format.columns:
            ascending = sort_direction == "Ascending"
            # Stable sort: tied rows keep the cached (default MSUN+SUN) order.
            df_to_format = df_to_format.sort_values(
                by=raw_column_name, ascending=ascending, kind='stable'
            )

    return format_dataframe(df_to_format, show_percentage, selected_groups, compact_view)
def show_output_box(message):
    """Return a Gradio update that reveals the status box with *message*."""
    reveal = gr.update(visible=True, value=message)
    return reveal
def submit_cif_files(model_name, problem_type, cif_files, relaxed, relaxation_settings, training_datasets, training_dataset_other, paper_link, hf_model_link, email, profile: gr.OAuthProfile | None):
    """Submit structures to the leaderboard.

    Validates the form, then writes the uploaded structure file and a
    ``metadata.json`` under ``submissions/<submission_id>/`` in the
    submissions dataset repo. Both files go in a single commit, so a failed
    upload cannot leave a structure file without its metadata. The metadata
    is uploaded from memory, so no temporary file can leak.

    Returns:
        tuple: ``(status_message, submission_id)``. ``submission_id`` is None
        when validation or the upload fails.
    """
    import re
    from huggingface_hub import CommitOperationAdd, create_commit

    # Validate inputs
    if not model_name or not model_name.strip():
        return "Error: Please provide a model name.", None
    if not problem_type:
        return "Error: Please select a problem type.", None
    if not cif_files:
        return "Error: Please upload a file.", None
    if not profile:
        return "Error: Please log in to submit.", None
    if not email or not email.strip():
        return "Error: Please provide an email address.", None
    # The "Others (must specify)" option is only meaningful with a description.
    if (
        training_datasets
        and "Others (must specify)" in training_datasets
        and not (training_dataset_other and training_dataset_other.strip())
    ):
        return "Error: Please specify the other training dataset.", None

    try:
        username = profile.username
        timestamp = datetime.now().isoformat()
        clean_model_name = model_name.strip()
        file_path = Path(cif_files)

        # Create submission metadata
        submission_data = {
            "username": username,
            "model_name": clean_model_name,
            "problem_type": problem_type,
            "relaxed": relaxed,
            "relaxation_settings": relaxation_settings.strip() if relaxed and relaxation_settings else None,
            "training_datasets": training_datasets or [],
            "training_dataset_other": training_dataset_other.strip() if training_dataset_other else None,
            "paper_link": paper_link.strip() if paper_link else None,
            "hf_model_link": hf_model_link.strip() if hf_model_link else None,
            "email": email.strip(),
            "timestamp": timestamp,
            "file_name": file_path.name,
        }

        # Create a unique submission ID. Whitespace and path separators in the
        # user-supplied name are replaced, so they cannot add spaces or extra
        # directory levels to the repo path.
        safe_model_name = re.sub(r"[\s/\\]", "_", clean_model_name)
        submission_id = f"{username}_{safe_model_name}_{timestamp.replace(':', '-')}"
        submission_dir = f"submissions/{submission_id}"

        metadata_bytes = json.dumps(submission_data, indent=2).encode("utf-8")
        create_commit(
            repo_id=submissions_repo,
            repo_type="dataset",
            token=TOKEN,
            operations=[
                CommitOperationAdd(
                    path_in_repo=f"{submission_dir}/{file_path.name}",
                    path_or_fileobj=str(file_path),
                ),
                CommitOperationAdd(
                    path_in_repo=f"{submission_dir}/metadata.json",
                    path_or_fileobj=metadata_bytes,
                ),
            ],
            commit_message=f"Add submission {submission_id}",
        )
        return f"Success! Submitted {model_name} for {problem_type} evaluation. Submission ID: {submission_id}", submission_id
    except Exception as e:
        # UI boundary: surface the failure to the user instead of crashing.
        return f"Error during submission: {str(e)}", None
def generate_metric_legend_html():
    """Build an HTML table describing each metric group and its colour.

    There is one row per entry of ``METRIC_GROUP_COLORS``, in that mapping's
    order. Each row has four cells:
    - a colour swatch;
    - the group name, with the direction marker removed;
    - the metrics in the group;
    - whether higher or lower is better.
    Groups missing from the local description table get blank metric and
    direction cells.
    """
    descriptions = {
        'Validity β': ('Valid, Charge Neutral, Distance Valid, Plausibility Valid', 'β Higher is better'),
        'Uniqueness & Novelty β': ('Unique, Novel', 'β Higher is better'),
        'Energy Metrics β': ('E Above Hull, Formation Energy, Relaxation RMSD (with std)', 'β Lower is better'),
        'Stability β': ('Stable, Unique in Stable, SUN', 'β Higher is better'),
        'Metastability β': ('Metastable, Unique in Metastable, MSUN', 'β Higher is better'),
        'Distribution β': ('JS Distance, MMD, FID', 'β Lower is better'),
        'Diversity β': ('Element, Space Group, Atomic Site, Crystal Size', 'β Higher is better'),
        'HHI β': ('HHI Production, HHI Reserve', 'β Lower is better'),
    }
    cell_css = 'border: 1px solid #ddd; padding: 8px;'
    header_css = 'border: 1px solid #ddd; padding: 8px; text-align: left;'

    header_cells = ''.join(
        f'<th style="{header_css}">{title}</th>'
        for title in ('Color', 'Group', 'Metrics', 'Direction')
    )
    pieces = [
        '<table style="width: 100%; border-collapse: collapse;">',
        f'<thead><tr>{header_cells}</tr></thead><tbody>',
    ]

    for group, color in METRIC_GROUP_COLORS.items():
        metrics, direction = descriptions.get(group, ('', ''))
        label = group.replace('β', '').replace('β', '').strip()
        swatch = (
            f'<div style="width: 30px; height: 20px; background-color: {color}; '
            f'border: 1px solid #999;"></div>'
        )
        row_cells = (swatch, f'<strong>{label}</strong>', metrics, direction)
        pieces.append(
            '<tr>'
            + ''.join(f'<td style="{cell_css}">{content}</td>' for content in row_cells)
            + '</tr>'
        )

    pieces.append('</tbody></table>')
    return ''.join(pieces)
def _build_leaderboard_tab():
    """Populate the leaderboard tab: controls, legend, table and symbol key.

    Must be called inside an open ``gr.TabItem`` context. Results are loaded
    once with ``get_leaderboard()`` and held in a ``gr.State``, so control
    changes only re-format the cached frame via ``update_leaderboard``.
    """
    gr.Markdown("# LeMat-GenBench")

    # Display options
    with gr.Row():
        with gr.Column(scale=1):
            compact_view = gr.Checkbox(
                value=True,
                label="Compact View",
                info="Show only key metrics",
            )
            show_percentage = gr.Checkbox(
                value=True,
                label="Show as Percentages",
                info="Display count-based metrics as percentages of total structures",
            )
        with gr.Column(scale=1):
            # Choices are display names; update_leaderboard maps them back to
            # raw column names. The initial MSUN+SUN sort happens in get_leaderboard().
            sort_choices = ["None"] + list(COLUMN_DISPLAY_NAMES.values())
            sort_by = gr.Dropdown(
                choices=sort_choices,
                value="None",
                label="Sort By",
                info="Select column to sort by (default: sorted by MSUN+SUN descending)",
            )
            sort_direction = gr.Radio(
                choices=["Ascending", "Descending"],
                value="Descending",
                label="Sort Direction",
            )
        with gr.Column(scale=1):
            training_set_filter = gr.Dropdown(
                choices=["All"] + TRAINING_DATASETS,
                value="All",
                label="Filter by Training Set",
                info="Show only models trained on a specific dataset",
            )
        with gr.Column(scale=2):
            selected_groups = gr.CheckboxGroup(
                choices=list(METRIC_GROUPS.keys()),
                value=list(METRIC_GROUPS.keys()),
                label="Metric Families (only active when Compact View is off)",
                info="Select which metric groups to display",
            )

    # Metric legend with color coding
    with gr.Accordion("Metric Groups Legend", open=False):
        gr.HTML(generate_metric_legend_html())

    try:
        # Initial dataframe - load once and cache
        initial_df = get_leaderboard()
        cached_df_state = gr.State(initial_df)
        formatted_df = format_dataframe(
            initial_df,
            show_percentage=True,
            selected_groups=list(METRIC_GROUPS.keys()),
            compact_view=True,
        )
        n_cols = len(formatted_df.columns)
        leaderboard_table = gr.Dataframe(
            label="GenBench Leaderboard",
            value=formatted_df,
            interactive=False,
            wrap=True,
            # First column holds the HTML model-name cells.
            datatype=(["html"] + [None] * (n_cols - 1)) if n_cols > 0 else None,
            column_widths=(["180px"] + ["160px"] * (n_cols - 1)) if n_cols > 0 else None,
            show_fullscreen_button=True,
        )

        # Update dataframe when options change (using cached data)
        leaderboard_inputs = [
            show_percentage, selected_groups, compact_view, cached_df_state,
            sort_by, sort_direction, training_set_filter,
        ]
        for control in (show_percentage, selected_groups, compact_view,
                        sort_by, sort_direction, training_set_filter):
            control.change(
                fn=update_leaderboard,
                inputs=leaderboard_inputs,
                outputs=leaderboard_table,
            )
    except Exception as e:
        gr.Markdown(f"Leaderboard is empty or error loading: {str(e)}")

    # Blank lines keep the trailing sentences out of the last bullet item.
    gr.Markdown("""
    **Symbol Legend:**

    - π Paper available (click to view)
    - β Model output verified
    - β‘ Structures were already relaxed
    - β Contributes to LeMat-Bulk reference dataset (in-distribution)
    - β Out-of-distribution relative to LeMat-Bulk reference dataset

    Verified submissions mean the results came from a model submission rather than a CIF submission.

    Models marked as baselines appear below the separator line at the bottom of the table.
    """)


def _build_submit_tab():
    """Populate the submission tab: login, form fields and submit handler.

    Must be called inside an open ``gr.TabItem`` context.
    """
    gr.Markdown(
        """
        # Materials Submission
        Upload a CSV, pkl, or a ZIP of CIFs with your structures.
        """
    )
    # Receives the submission ID returned by submit_cif_files.
    filename = gr.State(value=None)
    gr.LoginButton()

    with gr.Row():
        with gr.Column():
            model_name_input = gr.Textbox(
                label="Model Name",
                placeholder="Enter your model name",
                info="Provide a name for your model/method",
            )
            email_input = gr.Textbox(
                label="Email Address",
                placeholder="Enter your email address",
                info="Contact email for correspondence about this submission",
            )
            paper_link_input = gr.Textbox(
                label="Paper Link (optional)",
                placeholder="https://arxiv.org/abs/...",
                info="Link to the paper describing your model/method",
            )
            hf_model_link_input = gr.Textbox(
                label="HuggingFace Model Link (optional)",
                placeholder="https://huggingface.co/...",
                info="Link to your model on HuggingFace",
            )
            problem_type = gr.Dropdown(PROBLEM_TYPES, label="Problem Type")
        with gr.Column():
            cif_file = gr.File(label="Upload a CSV, a pkl, or a ZIP of CIF files.")
            relaxed = gr.Checkbox(
                value=False,
                label="Structures are pre-relaxed",
                info="Check this box if your submitted structures have already been relaxed",
            )
            relaxation_settings_input = gr.Textbox(
                label="Relaxation Settings",
                placeholder="e.g., VASP PBE, 520 eV cutoff, ...",
                info="Describe the relaxation settings used",
                visible=False,
            )
            training_dataset_input = gr.Dropdown(
                choices=TRAINING_DATASETS,
                label="Training Dataset",
                info="Select all datasets used for training",
                multiselect=True,
            )
            training_dataset_other_input = gr.Textbox(
                label="Other Training Dataset",
                placeholder="Specify your training dataset",
                info="Provide details if you selected 'Others (must specify)'",
                visible=False,
            )

    # Show/hide relaxation settings based on pre-relaxed checkbox
    relaxed.change(
        fn=lambda x: gr.update(visible=x),
        inputs=[relaxed],
        outputs=[relaxation_settings_input],
    )
    # Show/hide other dataset text box based on dropdown selection
    training_dataset_input.change(
        fn=lambda x: gr.update(visible="Others (must specify)" in (x or [])),
        inputs=[training_dataset_input],
        outputs=[training_dataset_other_input],
    )

    submit_btn = gr.Button("Submission")
    message = gr.Textbox(label="Status", lines=1, visible=False)
    # help message
    gr.Markdown("If you have issues with submission or using the leaderboard, please start a discussion in the Community tab of this Space.")
    submit_btn.click(
        submit_cif_files,
        inputs=[
            model_name_input, problem_type, cif_file, relaxed,
            relaxation_settings_input, training_dataset_input,
            training_dataset_other_input, paper_link_input,
            hf_model_link_input, email_input,
        ],
        outputs=[message, filename],
    ).then(
        fn=show_output_box,
        inputs=[message],
        outputs=[message],
    )


def gradio_interface() -> gr.Blocks:
    """Build the Gradio app: an intro header, a Leaderboard tab and a Submit tab.

    NOTE(review): several emoji literals in the UI text (e.g. 'π', 'β') look
    mis-decoded in this copy of the file; verify against the source encoding.
    """
    with gr.Blocks() as demo:
        # Blank lines separate the Markdown paragraphs; without them the
        # heading's following lines render as a single paragraph.
        gr.Markdown("""
        # π¬ LeMat-GenBench: A Unified Benchmark for Generative Models of Crystalline Materials

        Generative machine learning models hold great promise for accelerating materials discovery, particularly through the inverse design of inorganic crystals, enabling an unprecedented exploration of chemical space. Yet, the lack of standardized evaluation frameworks makes it difficult to evaluate, compare and further develop these ML models meaningfully.

        **LeMat-GenBench** introduces a unified benchmark for generative models of crystalline materials, with standardized evaluation metrics for meaningful model comparison, diverse tasks, and this leaderboard to encourage and track community progress.

        π **Paper**: [arXiv](https://arxiv.org/abs/2512.04562) | π» **Code**: [GitHub](https://github.com/LeMaterial/lemat-genbench) | π§ **Contact**: siddharth.betala-ext [at] entalpic.ai, alexandre.duval [at] entalpic.ai
        """)
        with gr.Tabs(elem_classes="tab-buttons"):
            with gr.TabItem("π Leaderboard", elem_id="boundary-benchmark-tab-table"):
                _build_leaderboard_tab()
            # Distinct elem_id: HTML ids must be unique within the page.
            with gr.TabItem("βοΈ Submit", elem_id="boundary-benchmark-tab-submit"):
                _build_submit_tab()
    return demo
if __name__ == "__main__":
    # Build the Blocks app, then start the Gradio server.
    app = gradio_interface()
    app.launch()