# mlip-playground / src/streamlit_app.py
# Committed by ManasSharma07 -- "Update src/streamlit_app.py" (commit 104d956, verified)
import streamlit as st
import os
import io
import tempfile
import torch
# FOR CPU only mode
# torch._dynamo.config.suppress_errors = True
# Or disable compilation entirely
# torch.backends.cudnn.enabled = False
import numpy as np
from ase import Atoms
from ase.io import read, write
from ase.optimize import BFGS, LBFGS, FIRE
from ase.constraints import FixAtoms
from ase.filters import FrechetCellFilter
from ase.visualize import view
import py3Dmol
from mace.calculators import mace_mp
from fairchem.core import pretrained_mlip, FAIRChemCalculator
from orb_models.forcefield import pretrained
from orb_models.forcefield.calculator import ORBCalculator
from sevenn.calculator import SevenNetCalculator
import pandas as pd
import yaml # Added for FairChem reference energies
import subprocess
import sys
import pkg_resources
from ase.vibrations import Vibrations
import matplotlib.pyplot as plt
# MatterSim support is switched off on purpose. Earlier revisions tried to
# pip-install ``mattersim`` (and force ``numpy<2.0.0``) at runtime before
# importing it; that approach was dropped. To re-enable MatterSim, install the
# package into the environment and flip this flag to True -- the model-type
# picker only offers "MatterSim" when it is set.
mattersim_available = False
if mattersim_available:
    # Imported lazily so the app still starts when the package is absent.
    from mattersim.forcefield import MatterSimCalculator
from huggingface_hub import login
# Optional Hugging Face Hub login; some model checkpoints may be gated.
# The token is read from the conventional ``HF_TOKEN`` environment variable.
# The previous placeholder name "YOUR SECRET KEY" (which contains spaces and so
# cannot normally be set from a shell) is still checked as a fallback so that
# existing deployments keep working. Streamlit secrets (st.secrets) would be an
# alternative source for the token.
try:
    hf_token = os.getenv("HF_TOKEN") or os.getenv("YOUR SECRET KEY")
    if hf_token:
        login(token=hf_token)
    else:
        print("Hugging Face token not found. Some models might not be accessible.")
except Exception as e:
    # Best effort: a failed login must not prevent the app from starting.
    print(f"hf login error: {e}")

# NOTE(review): this is set after streamlit has already been imported, so it may
# not influence the file watcher; Streamlit's documented switch is the
# ``server.fileWatcherType`` option (env STREAMLIT_SERVER_FILE_WATCHER_TYPE).
# Confirm before relying on it.
os.environ["STREAMLIT_WATCHER_TYPE"] = "none"
# YAML table of per-element reference energies for the FairChem models; it is
# parsed into ELEMENT_REF_ENERGIES right after this definition.
# Top-level keys: oc20_elem_refs, odac_elem_refs, omat_elem_refs,
# omol_elem_refs and omc_elem_refs. Each is a 100-entry list whose entry 0 is 0.0.
# NOTE(review): presumably indexed by atomic number Z (entry 0 unused) and given
# in eV, matching FairChem's per-task element references -- confirm against the
# FairChem source before relying on a specific value.
ELEMENT_REF_ENERGIES_YAML = """
oc20_elem_refs:
- 0.0
- -0.16141512
- 0.03262098
- -0.04787699
- -0.06299825
- -0.14979306
- -0.11657468
- -0.10862579
- -0.10298174
- -0.03420248
- 0.02673997
- -0.03729558
- 0.00515243
- -0.07535697
- -0.13663351
- -0.12922852
- -0.11796547
- -0.07802946
- -0.00672682
- -0.04089589
- -0.00024177
- -1.74545186
- -1.54220241
- -1.0934019
- -1.16168372
- -1.23073475
- -0.78852824
- -0.71851599
- -0.52465053
- -0.02692092
- -0.00317922
- -0.06266862
- -0.10835274
- -0.12394474
- -0.11351727
- -0.07455817
- -0.00258354
- -0.04111325
- -0.02090265
- -1.89306078
- -1.30591887
- -0.63320009
- -0.26230344
- -0.2633669
- -0.5160055
- -0.95950798
- -1.45589361
- -0.0429969
- -0.00026949
- -0.05925609
- -0.09734631
- -0.12406852
- -0.11427538
- -0.07021442
- 0.01091345
- -0.05305289
- -0.02427209
- -0.19975668
- -1.71692859
- -1.53677781
- -3.89987009
- -10.70940462
- -6.71693816
- -0.28102249
- -8.86944824
- -7.95762687
- -7.13041437
- -6.64620014
- -5.11482482
- -4.42548227
- 0.00848295
- -0.06956227
- -2.6748853
- -2.21153293
- -1.67367741
- -1.07636151
- -0.79009981
- -0.16387243
- -0.18164401
- -0.04122529
- -0.00041833
- -0.05259382
- -0.0934314
- -0.11023834
- -0.10039175
- -0.06069209
- 0.01790437
- -0.04694024
- 0.00334084
- -0.06030621
- -0.58793619
- -1.27821808
- -4.97483577
- -5.66985655
- -8.43154622
- -11.15001317
- -12.95770812
- 0.0
- -14.47602729
- 0.0
odac_elem_refs:
- 0.0
- -1.11737936
- -0.00011835
- -0.2941727
- -0.03868426
- -0.34862832
- -1.31552566
- -3.12457285
- -1.6052078
- -0.49653389
- -0.01137327
- -0.21957281
- -0.0008343
- -0.2750172
- -0.88417265
- -1.887378
- -0.94903558
- -0.31628167
- -0.02014536
- -0.15901053
- -0.00731884
- -1.96521355
- -1.89045209
- -2.53057428
- -5.43600675
- -5.09739336
- -3.03088746
- -1.23786562
- -0.40650749
- -0.2416017
- -0.01139188
- -0.26282496
- -0.82446455
- -1.70237206
- -0.84245376
- -0.28544892
- -0.02239991
- -0.14115912
- -0.02840799
- -2.09540994
- -1.85863996
- -1.12257399
- -4.32965355
- -3.30670045
- -1.19460755
- -1.26257601
- -1.46832888
- -0.19779414
- -0.0144274
- -0.23668767
- -0.70836953
- -1.43186113
- -0.71701186
- -0.24883129
- -0.01118184
- -0.13173447
- -0.0318395
- -0.41195547
- -1.23134873
- -2.03082996
- 0.1375954
- -5.45866275
- -7.59139905
- -5.99965965
- -8.43495767
- -2.6578407
- -7.77349787
- -5.30762201
- -5.15109657
- -4.41466995
- -0.02995219
- -0.2544495
- -3.23821202
- -3.45887214
- -4.53635003
- -4.60979468
- -2.90707964
- -1.28286153
- -0.57716664
- -0.18337108
- -0.01135944
- -0.22045398
- -0.66150479
- -1.32506342
- -0.66500178
- -0.22643927
- -0.00728197
- -0.11208472
- -0.00757856
- -0.21798637
- -0.91078787
- -1.78187161
- -3.89912261
- -3.94192659
- -7.59026042
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
omat_elem_refs:
- 0.0
- -1.11700253
- 0.00079886
- -0.29731164
- -0.04129868
- -0.29106192
- -1.27751531
- -3.12342715
- -1.54797136
- -0.43969356
- -0.01250908
- -0.22855413
- -0.00943179
- -0.21707638
- -0.82619133
- -1.88667434
- -0.89093583
- -0.25816211
- -0.02414768
- -0.17662425
- -0.02568319
- -2.13001165
- -2.38688845
- -3.55934233
- -5.44700879
- -5.14749562
- -3.30662847
- -1.42167737
- -0.63181379
- -0.23449167
- -0.01146636
- -0.21291259
- -0.77939897
- -1.70148487
- -0.78386705
- -0.22690657
- -0.02245409
- -0.16092396
- -0.02798717
- -2.25685695
- -2.23690495
- -2.15347771
- -4.60251809
- -3.36416792
- -2.23062607
- -1.15550917
- -1.47553527
- -0.19918102
- -0.01475888
- -0.19767692
- -0.68005773
- -1.43073368
- -0.65790462
- -0.18915279
- -0.01179476
- -0.13507902
- -0.03056979
- -0.36017439
- -0.86279246
- -0.20573327
- -0.2734463
- -0.20046965
- -0.25444338
- -8.37972664
- -9.58424928
- -0.19466184
- -0.24860115
- -0.19531288
- -0.15401392
- -0.14577898
- -0.19655747
- -0.15645898
- -3.49380556
- -3.5317097
- -4.57108006
- -4.63425205
- -2.88247063
- -1.45679675
- -0.50290184
- -0.18521704
- -0.01123956
- -0.17483649
- -0.63132037
- -1.3248562
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
- -0.24135757
- -1.04601971
- -2.04574044
- -3.84544799
- -7.28626119
- -7.3136314
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
omol_elem_refs:
- 0.0
- -13.44558
- -78.82027
- -203.32564
- -398.94742
- -670.75275
- -1029.85403
- -1485.54188
- -2042.97832
- -2714.24015
- -3508.74317
- -4415.24203
- -5443.89712
- -6594.61834
- -7873.6878
- -9285.6593
- -10832.62132
- -12520.66852
- -14354.278
- -16323.54671
- -18436.47845
- -20696.18244
- -23110.5386
- -25682.99429
- -28418.37804
- -31317.92317
- -34383.42519
- -37623.46835
- -41039.92413
- -44637.38634
- -48417.14864
- -52373.87849
- -56512.76952
- -60836.14871
- -65344.28833
- -70041.24251
- -74929.56277
- -653.64777
- -833.31922
- -1038.0281
- -1273.96788
- -1542.45481
- -1850.74158
- -2193.91654
- -2577.18734
- -3004.13604
- -3477.52796
- -3997.31825
- -4563.75804
- -5171.82293
- -5828.85334
- -6535.61529
- -7291.54792
- -8099.87914
- -8962.17916
- -546.03214
- -690.6089
- -854.11237
- -12923.04096
- -14064.26124
- -15272.68689
- -16550.20551
- -17900.36515
- -19323.23406
- -20829.08848
- -22428.73258
- -24078.68008
- -25794.42097
- -27616.6819
- -29523.5526
- -31526.68012
- -33615.37779
- -1300.17791
- -1544.40924
- -1818.62298
- -2123.14417
- -2461.76028
- -2833.76287
- -3242.79895
- -3690.363
- -4174.99772
- -4691.75674
- -5245.36013
- -5838.12005
- -6469.07296
- -7140.86455
- -7854.60638
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
omc_elem_refs:
- 0.0
- -0.02831808
- 4.512e-05
- -0.03227157
- -0.03842519
- -0.05829283
- -0.0845041
- -0.08806738
- -0.09021346
- -0.06669846
- -0.01218631
- -0.03650269
- -0.00059093
- -0.05787736
- -0.08730952
- -0.0975534
- -0.09264199
- -0.07124762
- -0.02374602
- -0.05299112
- -0.02631476
- -1.7772147
- -1.25083444
- -0.79579447
- -0.49099317
- -0.31414986
- -0.20292182
- -0.14011632
- -0.09929659
- -0.03771207
- -0.01117902
- -0.06168715
- -0.08873364
- -0.09512942
- -0.09035978
- -0.06910849
- -0.02244872
- -0.05303651
- -0.02871903
- -1.94805417
- -1.33379896
- -0.69169331
- -0.26184306
- -0.20631599
- -0.48251608
- -0.96911893
- -1.47569462
- -0.03845194
- -0.0142445
- -0.07118991
- -0.09940292
- -0.09235056
- -0.08755943
- -0.06544925
- -0.01246646
- -0.04692937
- -0.03225123
- -0.26086039
- -27.20024339
- -0.08412926
- -0.08225924
- -0.07799715
- -0.07806185
- 0.00043759
- -0.07459766
- 0.0
- -0.06842841
- -0.07758266
- -0.07025152
- -0.08055003
- -0.07118177
- -0.07159568
- -2.69202862
- -2.21926765
- -1.679756
- -1.06135075
- -0.4554231
- -0.14488432
- -0.18377098
- -0.03603118
- -0.01076585
- -0.06381411
- -0.0905623
- -0.10095787
- -0.09501217
- -0.0574478
- -0.00599173
- -0.04134751
- -0.0082683
- -0.08704692
- -0.49656425
- -5.24233138
- -2.32542606
- -4.3376616
- -5.96430676
- 0.0
- 0.0
- -0.03842519
- 0.0
- 0.0
"""
# Parse the embedded reference-energy table once at import time. A malformed
# table is reported on stdout and replaced by an empty mapping instead of
# crashing the app.
try:
    _parsed_ref_energies = yaml.safe_load(ELEMENT_REF_ENERGIES_YAML)
except yaml.YAMLError as e:
    print(f"Error parsing YAML reference energies: {e}")
    _parsed_ref_energies = {}
ELEMENT_REF_ENERGIES = _parsed_ref_energies
# Deployment detection: True only when the runtime reports
# STREAMLIT_RUNTIME_ENV=cloud.
_runtime_env = os.environ.get('STREAMLIT_RUNTIME_ENV')
is_streamlit_cloud = _runtime_env == 'cloud'

# Atom-count caps enforced by check_atom_limit().
MAX_ATOMS_CLOUD = 500      # every model except UMA / ESEN MD
MAX_ATOMS_CLOUD_UMA = 500  # UMA and ESEN MD models
# Page configuration; keep this ahead of every other Streamlit command.
_PAGE_CONFIG = dict(
    page_title="MLIP Playground - Run, Test and Benchmark MLIPs",
    page_icon="🧪",
    layout="wide",
)
st.set_page_config(**_PAGE_CONFIG)

# Heading and short description shown at the top of the main area.
st.markdown('## MLIP Playground', unsafe_allow_html=True)
st.write('#### Run, test and compare 22 state-of-the-art universal machine learning interatomic potentials (MLIPs) for atomistic simulations of molecules and materials')
st.markdown(
    'Upload molecular structure files or select from predefined examples, then compute energies and forces using foundation models such as those from MACE or FairChem (Meta).',
    unsafe_allow_html=True,
)
# Bundled example structures. Files are resolved relative to SAMPLE_DIR; the
# directory is created up front so the path always exists.
SAMPLE_DIR = "sample_structures"
os.makedirs(SAMPLE_DIR, exist_ok=True)

# (display name in the "Select Example" picker, file name) -- order is the
# order shown in the picker.
_SAMPLE_ENTRIES = (
    ("Water", "H2O.xyz"),
    ("Methane", "CH4.xyz"),
    ("Benzene", "C6H6.xyz"),
    ("Ethane", "C2H6.xyz"),
    ("Caffeine", "caffeine.xyz"),
    ("Ibuprofen", "ibuprofen.xyz"),
    ("Silicon", "Si.cif"),
    ("hBN Monolayer (4x4)", "hBN_monolayer_4x4_supercell.extxyz"),
)
SAMPLE_STRUCTURES = dict(_SAMPLE_ENTRIES)
def get_trajectory_viz(trajectory, style='stick', show_unit_cell=True, width=400, height=400,
                       show_path=True, path_color='red', path_radius=0.02):
    """
    Visualize an optimization trajectory with all frames overlaid in one viewer.

    Each frame is added as its own py3Dmol model, so every frame is drawn at
    once. Optional cylinders connect each atom's positions in consecutive frames.

    Args:
        trajectory: Sequence of ASE ``Atoms`` objects, one per optimization step.
        style: 'stick', 'ball' or 'ball-stick' (case-insensitive). Any other
            value falls back to thin sticks.
        show_unit_cell: Draw the cell edges of the last frame if it is periodic.
        width, height: Viewer size in pixels.
        show_path: Draw per-atom paths between consecutive frames.
        path_color: Color of the path cylinders.
        path_radius: Radius of the path cylinders.

    Returns:
        A ``py3Dmol.view``, or ``None`` if ``trajectory`` is empty or None.
    """
    if not trajectory:
        return None

    view = py3Dmol.view(width=width, height=height)

    # One XYZ model per frame. Lines are collected and joined once rather than
    # growing a string with += for every atom.
    for frame_idx, atoms_obj in enumerate(trajectory):
        lines = [str(len(atoms_obj)), f"Frame {frame_idx}"]
        lines.extend(
            f"{atom.symbol} {atom.position[0]:.6f} {atom.position[1]:.6f} {atom.position[2]:.6f}"
            for atom in atoms_obj
        )
        view.addModel("\n".join(lines) + "\n", "xyz")

    # Same style presets as the other viewers in this module.
    style_presets = {
        'ball-stick': {'stick': {'radius': 0.1}, 'sphere': {'scale': 0.3}},
        'stick': {'stick': {}},
        'ball': {'sphere': {'scale': 0.4}},
    }
    view.setStyle(style_presets.get(style.lower(), {'stick': {'radius': 0.15}}))

    # Per-atom paths between consecutive frames.
    if show_path and len(trajectory) > 1:
        frame_positions = [frame.get_positions() for frame in trajectory]
        # Only atoms present in every frame can be traced. This avoids an
        # IndexError if frames ever differ in length.
        n_traced = min(len(pos) for pos in frame_positions)
        for atom_idx in range(n_traced):
            for prev_pos, next_pos in zip(frame_positions, frame_positions[1:]):
                start, end = prev_pos[atom_idx], next_pos[atom_idx]
                view.addCylinder({
                    'start': {'x': float(start[0]), 'y': float(start[1]), 'z': float(start[2])},
                    'end': {'x': float(end[0]), 'y': float(end[1]), 'z': float(end[2])},
                    'radius': path_radius,
                    'color': path_color,
                    'alpha': 0.5
                })

    # Unit-cell edges, taken from the last frame.
    if show_unit_cell and trajectory[-1].pbc.any():
        cell = trajectory[-1].get_cell()
        if cell is not None and cell.any():
            a, b, c = cell[0], cell[1], cell[2]
            origin = np.array([0.0, 0.0, 0.0])
            edges = [
                (origin, a), (origin, b), (a, a + b), (b, a + b),
                (c, c + a), (c, c + b), (c + a, c + a + b), (c + b, c + a + b),
                (origin, c), (a, a + c), (b, b + c), (a + b, a + b + c),
            ]
            for start, end in edges:
                view.addCylinder({
                    'start': {'x': float(start[0]), 'y': float(start[1]), 'z': float(start[2])},
                    'end': {'x': float(end[0]), 'y': float(end[1]), 'z': float(end[2])},
                    'radius': 0.05, 'color': 'black', 'alpha': 0.7
                })

    view.zoomTo()
    view.setBackgroundColor('white')
    return view
def get_animated_trajectory_viz(trajectory, style='stick', show_unit_cell=True, width=400, height=400):
    """
    Create a looping animation of an optimization trajectory.

    All frames are packed into one multi-structure XYZ string and loaded with
    ``addModelsAsFrames``, so 3Dmol.js treats them as frames of a single model
    that ``animate`` can step through. The previous version called ``addModel``
    once per frame. That creates independent single-structure models, which are
    drawn on top of each other instead of being animated.

    Args:
        trajectory: Sequence of ASE ``Atoms`` objects, one per optimization step.
        style: 'stick', 'ball' or 'ball-stick' (case-insensitive). Any other
            value falls back to thin sticks.
        show_unit_cell: Draw the cell edges of the last frame if it is periodic.
        width, height: Viewer size in pixels.

    Returns:
        A ``py3Dmol.view``, or ``None`` if ``trajectory`` is empty or None.
    """
    if not trajectory:
        return None

    view = py3Dmol.view(width=width, height=height)

    # Concatenate every frame as a standard XYZ block (count, comment, atoms).
    lines = []
    for frame_idx, atoms_obj in enumerate(trajectory):
        lines.append(str(len(atoms_obj)))
        lines.append(f"Frame {frame_idx}")
        lines.extend(
            f"{atom.symbol} {atom.position[0]:.6f} {atom.position[1]:.6f} {atom.position[2]:.6f}"
            for atom in atoms_obj
        )
    view.addModelsAsFrames("\n".join(lines) + "\n", "xyz")

    style_presets = {
        'ball-stick': {'stick': {'radius': 0.1}, 'sphere': {'scale': 0.3}},
        'stick': {'stick': {}},
        'ball': {'sphere': {'scale': 0.4}},
    }
    view.setStyle(style_presets.get(style.lower(), {'stick': {'radius': 0.15}}))

    # Unit-cell edges from the last frame; shapes are not tied to a frame.
    if show_unit_cell and trajectory[-1].pbc.any():
        cell = trajectory[-1].get_cell()
        if cell is not None and cell.any():
            a, b, c = cell[0], cell[1], cell[2]
            origin = np.array([0.0, 0.0, 0.0])
            edges = [
                (origin, a), (origin, b), (a, a + b), (b, a + b),
                (origin, c), (a, a + c), (b, b + c), (a + b, a + b + c),
                (c, c + a), (c, c + b), (c + a, c + a + b), (c + b, c + a + b),
            ]
            for start, end in edges:
                view.addCylinder({
                    'start': {'x': float(start[0]), 'y': float(start[1]), 'z': float(start[2])},
                    'end': {'x': float(end[0]), 'y': float(end[1]), 'z': float(end[2])},
                    'radius': 0.05, 'color': 'black', 'alpha': 0.7
                })

    view.zoomTo()
    view.setBackgroundColor('white')
    # Loop forever, 500 ms per frame.
    view.animate({'loop': 'forward', 'reps': 0, 'interval': 500})
    return view
# Streamlit front end for the trajectory viewers.
def display_optimization_trajectory(trajectory, viz_style='ball-stick'):
    """
    Display an optimization trajectory in Streamlit with mode controls.

    Modes:
        "Animation": looping playback of all frames.
        "Static with paths": all frames overlaid, with per-atom paths.
        "Step-by-step": a single frame chosen with a slider.

    Args:
        trajectory: Sequence of ASE ``Atoms`` objects, one per optimization step.
        viz_style: Style name forwarded to the viewer helpers.
    """
    if not trajectory:
        st.error("No trajectory data available")
        return

    n_frames = len(trajectory)
    st.subheader(f"Optimization Trajectory ({n_frames} steps)")

    col1, col2 = st.columns(2)
    with col1:
        viz_mode = st.selectbox(
            "Visualization Mode",
            ["Animation", "Static with paths", "Step-by-step"],
            key="viz_mode"
        )

    # Defaults so every mode has well-defined values even when its widgets
    # are not shown.
    show_paths, path_color, frame_idx = True, "red", 0
    with col2:
        if viz_mode == "Static with paths":
            show_paths = st.checkbox("Show trajectory paths", value=True)
            path_color = st.selectbox("Path color", ["red", "blue", "green", "orange"], index=0)
        elif viz_mode == "Step-by-step" and n_frames > 1:
            # A single-frame trajectory would give a 0..0 slider, which has
            # nothing to choose (and Streamlit may reject equal bounds), so
            # frame 0 is shown directly in that case.
            frame_idx = st.slider("Frame", 0, n_frames - 1, 0, key="frame_slider")

    if viz_mode == "Static with paths":
        opt_view = get_trajectory_viz(
            trajectory,
            style=viz_style,
            show_unit_cell=True,
            width=400,
            height=400,
            show_path=show_paths,
            path_color=path_color
        )
    elif viz_mode == "Animation":
        opt_view = get_animated_trajectory_viz(
            trajectory,
            style=viz_style,
            show_unit_cell=True,
            width=400,
            height=400
        )
    else:  # "Step-by-step"
        opt_view = get_structure_viz2(
            trajectory[frame_idx],
            style=viz_style,
            show_unit_cell=True,
            width=400,
            height=400
        )

    st.components.v1.html(opt_view._make_html(), width=400, height=400)
    if viz_mode == "Step-by-step":
        st.write(f"Step {frame_idx + 1} of {n_frames}")
def get_structure_viz2(atoms_obj, style='stick', show_unit_cell=True, width=400, height=400):
    """
    Render a single ASE structure as a py3Dmol view.

    Args:
        atoms_obj: ASE ``Atoms`` object to draw.
        style: 'stick', 'ball' or 'ball-stick' (case-insensitive). Any other
            value falls back to thin sticks.
        show_unit_cell: Draw the cell edges if the structure is periodic and
            has a non-zero cell.
        width, height: Viewer size in pixels.

    Returns:
        A ``py3Dmol.view`` ready to be embedded with ``_make_html()``.
    """
    # XYZ text (count, comment, one line per atom), built with a single join
    # rather than repeated string concatenation.
    lines = [str(len(atoms_obj)), "Structure"]
    lines.extend(
        f"{atom.symbol} {atom.position[0]:.6f} {atom.position[1]:.6f} {atom.position[2]:.6f}"
        for atom in atoms_obj
    )
    xyz_str = "\n".join(lines) + "\n"

    view = py3Dmol.view(width=width, height=height)
    view.addModel(xyz_str, "xyz")

    style_presets = {
        'ball-stick': {'stick': {'radius': 0.1}, 'sphere': {'scale': 0.3}},
        'stick': {'stick': {}},
        'ball': {'sphere': {'scale': 0.4}},
    }
    view.setStyle(style_presets.get(style.lower(), {'stick': {'radius': 0.15}}))

    if show_unit_cell and atoms_obj.pbc.any():
        cell = atoms_obj.get_cell()
        # Skip cells that are missing or all zeros.
        if cell is not None and cell.any():
            a, b, c = cell[0], cell[1], cell[2]
            origin = np.array([0.0, 0.0, 0.0])
            edges = [
                (origin, a), (origin, b), (a, a + b), (b, a + b),
                (c, c + a), (c, c + b), (c + a, c + a + b), (c + b, c + a + b),
                (origin, c), (a, a + c), (b, b + c), (a + b, a + b + c),
            ]
            for start, end in edges:
                view.addCylinder({
                    'start': {'x': float(start[0]), 'y': float(start[1]), 'z': float(start[2])},
                    'end': {'x': float(end[0]), 'y': float(end[1]), 'z': float(end[2])},
                    'radius': 0.05, 'color': 'black', 'alpha': 0.7
                })

    view.zoomTo()
    view.setBackgroundColor('white')
    return view
# Shared state for the optimization logger: the rows collected so far and the
# Streamlit placeholder that displays them. The calculation runner resets
# opt_log for each run and recreates the placeholder for optimization tasks.
opt_log = []
table_placeholder = st.empty()


def streamlit_log(opt):
    """Record the optimizer's current step in ``opt_log`` and refresh the live table.

    ``opt`` must expose ``atoms`` (with a calculator attached) and ``nsteps``.
    Presumably it is attached as an ASE optimizer observer; the attach call is
    not in this part of the file. Any error is shown as a Streamlit warning
    rather than interrupting the optimization.
    """
    global opt_log, table_placeholder
    try:
        current_atoms = opt.atoms
        energy = current_atoms.get_potential_energy()
        forces = current_atoms.get_forces()
        # Largest per-atom force norm; 0.0 for a structure with no atoms.
        if forces.shape[0] > 0:
            fmax_step = np.max(np.linalg.norm(forces, axis=1))
        else:
            fmax_step = 0.0
        row = {
            "Step": opt.nsteps,
            "Energy (eV)": round(energy, 6),
            "Fmax (eV/Å)": round(fmax_step, 6),
        }
        opt_log.append(row)
        table_placeholder.dataframe(pd.DataFrame(opt_log))
    except Exception as e:
        st.warning(f"Error in optimization logger: {e}")
def check_atom_limit(atoms_obj, selected_model):
    """Return True if ``atoms_obj`` is small enough to run with ``selected_model``.

    Model names containing 'UMA' or 'ESEN MD' are checked against
    MAX_ATOMS_CLOUD_UMA; all other models against MAX_ATOMS_CLOUD. If the
    structure is too large, an error is shown in the app and False is
    returned. ``None`` (no structure loaded) is always accepted.

    NOTE(review): the cap is applied whether or not ``is_streamlit_cloud`` is
    set, even though the message mentions Streamlit Cloud -- confirm this is
    intended.
    """
    if atoms_obj is None:
        return True

    uses_uma_cap = any(tag in selected_model for tag in ('UMA', 'ESEN MD'))
    limit = MAX_ATOMS_CLOUD_UMA if uses_uma_cap else MAX_ATOMS_CLOUD
    num_atoms = len(atoms_obj)
    if num_atoms <= limit:
        return True

    st.error(f"⚠️ Error: Your structure contains {num_atoms} atoms, exceeding the {limit} atom limit for this model on Streamlit Cloud. Please run locally for larger systems.")
    return False
# Registries of selectable models: display name -> checkpoint URL / model id.
# Dict order is the order shown in the sidebar pickers.

# MACE checkpoints are published as GitHub release assets of two repositories.
_MACE_MP_RELEASES = "https://github.com/ACEsuit/mace-mp/releases/download/"
_MACE_FOUNDATIONS_RELEASES = "https://github.com/ACEsuit/mace-foundations/releases/download/"
MACE_MODELS = {
    "MACE MPA Medium": _MACE_MP_RELEASES + "mace_mpa_0/mace-mpa-0-medium.model",
    "MACE OMAT Medium": _MACE_MP_RELEASES + "mace_omat_0/mace-omat-0-medium.model",
    "MACE MATPES r2SCAN Medium": _MACE_FOUNDATIONS_RELEASES + "mace_matpes_0/MACE-matpes-r2scan-omat-ft.model",
    "MACE MATPES PBE Medium": _MACE_FOUNDATIONS_RELEASES + "mace_matpes_0/MACE-matpes-pbe-omat-ft.model",
    "MACE MP 0a Small": _MACE_MP_RELEASES + "mace_mp_0/2023-12-10-mace-128-L0_energy_epoch-249.model",
    "MACE MP 0a Medium": _MACE_MP_RELEASES + "mace_mp_0/2023-12-03-mace-128-L1_epoch-199.model",
    "MACE MP 0a Large": _MACE_MP_RELEASES + "mace_mp_0/2024-01-07-mace-128-L2_epoch-199.model",
    "MACE MP 0b Small": _MACE_FOUNDATIONS_RELEASES + "mace_mp_0b/mace_agnesi_small.model",
    "MACE MP 0b Medium": _MACE_FOUNDATIONS_RELEASES + "mace_mp_0b/mace_agnesi_medium.model",
    "MACE MP 0b2 Small": _MACE_FOUNDATIONS_RELEASES + "mace_mp_0b2/mace-small-density-agnesi-stress.model",
    "MACE MP 0b2 Medium": _MACE_FOUNDATIONS_RELEASES + "mace_mp_0b2/mace-medium-density-agnesi-stress.model",
    "MACE MP 0b2 Large": _MACE_FOUNDATIONS_RELEASES + "mace_mp_0b2/mace-large-density-agnesi-stress.model",
    "MACE MP 0b3 Medium": _MACE_FOUNDATIONS_RELEASES + "mace_mp_0b3/mace-mp-0b3-medium.model",
}

# FairChem model names passed to pretrained_mlip.get_predict_unit().
FAIRCHEM_MODELS = {
    "UMA Small": "uma-s-1",
    "ESEN MD Direct All OMOL": "esen-md-direct-all-omol",
    "ESEN SM Conserving All OMOL": "esen-sm-conserving-all-omol",
    "ESEN SM Direct All OMOL": "esen-sm-direct-all-omol",
}

# ORB v3 variants: training data (outer) x force head x neighbour cutoff,
# e.g. "V3 OMAT Conservative (inf)" -> "orb-v3-conservative-inf-omat".
ORB_MODELS = {
    f"V3 {dataset.upper()} {head.capitalize()} ({cutoff})": f"orb-v3-{head}-{cutoff}-{dataset}"
    for dataset in ("omat", "mpa")
    for head in ("conservative", "direct")
    for cutoff in ("inf", "20")
}

# MatterSim checkpoint files (only offered when mattersim_available is True).
MATTERSIM_MODELS = {
    "V1 SMALL": "MatterSim-v1.0.0-1M.pth",
    "V1 LARGE": "MatterSim-v1.0.0-5M.pth",
}

# SevenNet models use the same string as display name and model id.
SEVEN_NET_MODELS = {
    name: name for name in ("7net-0", "7net-l3i5", "7net-omat", "7net-mf-ompa")
}
@st.cache_resource
def get_mace_model(model_path, device, selected_default_dtype):
    """Build a MACE calculator with ``mace_mp``; Streamlit caches it per argument set.

    Args:
        model_path: Checkpoint passed as ``mace_mp(model=...)`` (the app passes
            URLs from MACE_MODELS).
        device: Torch device string, "cpu" or "cuda".
        selected_default_dtype: Passed through as ``default_dtype``.
    """
    calculator = mace_mp(
        model=model_path,
        device=device,
        default_dtype=selected_default_dtype,
    )
    return calculator
@st.cache_resource
def get_fairchem_model(selected_model_name, model_path_or_name, device, selected_task_type_fc): # Renamed args to avoid conflict
    """Build a FAIRChem calculator; Streamlit caches it per argument set.

    Only "UMA Small" uses ``selected_task_type_fc``. Every other FairChem model
    is run with the "omol" task.

    Args:
        selected_model_name: Display name from FAIRCHEM_MODELS.
        model_path_or_name: Model id passed to ``pretrained_mlip.get_predict_unit``.
        device: Torch device string, "cpu" or "cuda".
        selected_task_type_fc: UMA task name (e.g. "omol", "omat").
    """
    predictor = pretrained_mlip.get_predict_unit(model_path_or_name, device=device)
    if selected_model_name != "UMA Small":
        task_name = "omol"
    else:
        task_name = selected_task_type_fc
    return FAIRChemCalculator(predictor, task_name=task_name)
def _read_structure_from_bytes(data, suffix):
    """Parse raw structure-file bytes with ASE ``read``.

    ``read`` is given a path, so the bytes are written to a named temporary
    file whose ``suffix`` lets ASE infer the format. The file is closed before
    parsing and is always deleted, including when writing or parsing raises.
    Exceptions from ASE propagate to the caller.
    """
    tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
    try:
        with tmp_file:
            tmp_file.write(data)
        return read(tmp_file.name)
    finally:
        if os.path.exists(tmp_file.name):
            os.unlink(tmp_file.name)


st.sidebar.markdown("## Input Options")
input_method = st.sidebar.radio("Choose Input Method:", ["Select Example", "Upload File", "Paste Content"])
atoms = None
if input_method == "Upload File":
    # "extxyz" is accepted so files in the same format as the bundled hBN
    # example (also offered by "Paste Content") can be uploaded.
    uploaded_file = st.sidebar.file_uploader(
        "Upload structure file",
        type=["xyz", "extxyz", "cif", "POSCAR", "mol", "tmol", "vasp", "sdf", "CONTCAR"],
    )
    if uploaded_file:
        try:
            # Keep the original extension so ASE can detect the format.
            atoms = _read_structure_from_bytes(
                uploaded_file.getvalue(), os.path.splitext(uploaded_file.name)[1]
            )
            st.sidebar.success(f"Successfully loaded structure with {len(atoms)} atoms!")
        except Exception as e:
            st.sidebar.error(f"Error loading file: {str(e)}")
elif input_method == "Select Example":
    example_name = st.sidebar.selectbox("Select Example Structure:", list(SAMPLE_STRUCTURES.keys()))
    if example_name:
        file_path = os.path.join(SAMPLE_DIR, SAMPLE_STRUCTURES[example_name])
        try:
            atoms = read(file_path)
            st.sidebar.success(f"Loaded {example_name} with {len(atoms)} atoms!")
        except Exception as e:
            st.sidebar.error(f"Error loading example: {str(e)}")
elif input_method == "Paste Content":
    file_format = st.sidebar.selectbox("File Format:", ["XYZ", "CIF", "extXYZ", "POSCAR (VASP)", "Turbomole", "MOL"])
    content = st.sidebar.text_area("Paste file content here:", height=200)
    if content:
        # The temp-file suffix tells ASE which parser to use.
        suffix_map = {"XYZ": ".xyz", "CIF": ".cif", "extXYZ": ".extxyz", "POSCAR (VASP)": ".vasp", "Turbomole": ".tmol", "MOL": ".mol"}
        suffix = suffix_map.get(file_format, ".xyz")
        try:
            atoms = _read_structure_from_bytes(content.encode(), suffix)
            st.sidebar.success(f"Successfully parsed structure with {len(atoms)} atoms!")
        except Exception as e:
            st.sidebar.error(f"Error parsing content: {str(e)}")
# Make sure charge/spin metadata exist for the widgets and calculators that
# read them; values already present in atoms.info are left untouched.
if atoms is not None:
    if not hasattr(atoms, 'info'):
        atoms.info = {}
    atoms.info.setdefault("charge", 0)
    # NOTE(review): the UMA "omol" widget treats this value as S and converts it
    # to 2S + 1 -- confirm the intended convention for spin read from files.
    atoms.info.setdefault("spin", 0)
st.sidebar.markdown("## Model Selection")
_model_type_options = ["MACE", "FairChem", "ORB", "SEVEN_NET"]
if mattersim_available:
    _model_type_options.append("MatterSim")
model_type = st.sidebar.radio("Select Model Type:", _model_type_options)

selected_task_type = None  # Only set for FairChem UMA
if model_type == "MACE":
    selected_model = st.sidebar.selectbox("Select MACE Model:", list(MACE_MODELS.keys()))
    model_path = MACE_MODELS[selected_model]
    if selected_model == "MACE OMAT Medium":
        st.sidebar.warning("Using model under Academic Software License (ASL).")
    # Precision is fixed; the dtype picker is disabled.
    # NOTE(review): the run code passes 'float32' to get_mace_model, so this
    # value looks unused for MACE -- confirm.
    selected_default_dtype = 'float64'
elif model_type == "FairChem":
    selected_model = st.sidebar.selectbox("Select FairChem Model:", list(FAIRCHEM_MODELS.keys()))
    model_path = FAIRCHEM_MODELS[selected_model]
    if selected_model == "UMA Small":
        st.sidebar.warning("Meta FAIR Acceptable Use Policy applies.")
        selected_task_type = st.sidebar.selectbox("Select UMA Model Task Type:", ["omol", "omat", "omc", "odac", "oc20"])
        if selected_task_type == "omol" and atoms is not None:
            # Defaults come from atoms.info, coerced to int and clamped into
            # each widget's range so number_input accepts them.
            stored_charge = atoms.info.get("charge", 0)
            default_charge = int(stored_charge) if stored_charge is not None else 0
            charge = st.sidebar.number_input(
                "Total Charge", min_value=-10, max_value=10,
                value=min(max(default_charge, -10), 10),
            )
            # atoms.info["spin"] is read as S here and shown as 2S + 1.
            stored_spin = atoms.info.get("spin", 0)
            default_multiplicity = int(stored_spin * 2 + 1) if stored_spin is not None else 1
            # step=1: the former step=2 (from min 1) only allowed 1, 3, 5, ...,
            # so doublets/quartets (odd electron counts) could not be entered.
            spin_multiplicity = st.sidebar.number_input(
                "Spin Multiplicity (2S + 1)", min_value=1, max_value=11, step=1,
                value=min(max(default_multiplicity, 1), 11),
            )
            atoms.info["charge"] = charge
            atoms.info["spin"] = spin_multiplicity  # FairChem expects multiplicity
elif model_type == "ORB":
    selected_model = st.sidebar.selectbox("Select ORB Model:", list(ORB_MODELS.keys()))
    model_path = ORB_MODELS[selected_model]
    # Display names spell the dataset in upper case ("V3 OMAT ..."), so testing
    # the display name for "omat" never matched. The model id
    # ("orb-v3-...-omat") carries the lower-case tag.
    if "omat" in model_path:
        st.sidebar.warning("Using model under Academic Software License (ASL) license, see [https://github.com/gabor1/ASL](https://github.com/gabor1/ASL). To use this model you accept the terms of the license.")
    selected_default_dtype = st.sidebar.selectbox("Select Precision (default_dtype):", ['float32-high', 'float32-highest'])
elif model_type == "MatterSim":
    selected_model = st.sidebar.selectbox("Select MatterSim Model:", list(MATTERSIM_MODELS.keys()))
    model_path = MATTERSIM_MODELS[selected_model]
elif model_type == "SEVEN_NET":
    selected_model = st.sidebar.selectbox("Select SEVENNET Model:", list(SEVEN_NET_MODELS.keys()))
    if selected_model == '7net-mf-ompa':
        selected_modal_7net = st.sidebar.selectbox("Select Modal (multi fidelity model):", ['omat24', 'mpa'])
    model_path = SEVEN_NET_MODELS[selected_model]
# Stop this script run early when the loaded structure exceeds the model's
# atom cap; check_atom_limit has already shown the error.
if atoms is not None and not check_atom_limit(atoms, selected_model):
    st.stop()

# Default to the GPU option when CUDA is available. "cuda" is used only when it
# was both requested and available.
_cuda_available = torch.cuda.is_available()
_device_label = st.sidebar.radio(
    "Computation Device:", ["CPU", "CUDA (GPU)"], index=1 if _cuda_available else 0
)
device = "cuda" if (_device_label == "CUDA (GPU)" and _cuda_available) else "cpu"
if device == "cpu":
    if _cuda_available:
        st.sidebar.info("GPU is available but CPU was selected.")
    else:
        st.sidebar.info("No GPU detected. Using CPU.")
_TASK_CHOICES = [
    "Energy Calculation",
    "Energy + Forces Calculation",
    "Atomization/Cohesive Energy",
    "Geometry Optimization",
    "Cell + Geometry Optimization",
    "Vibrational Mode Analysis",
]
st.sidebar.markdown("## Task Selection")
task = st.sidebar.selectbox("Select Calculation Task:", _TASK_CHOICES)

# Both optimization tasks contain "Optimization" and share these controls.
if "Optimization" in task:
    st.sidebar.markdown("### Optimization Parameters")
    max_steps = st.sidebar.slider(
        "Maximum Steps:", min_value=10, max_value=200, value=50, step=1
    )
    fmax = st.sidebar.slider(
        "Convergence Threshold (eV/Å):",
        min_value=0.001, max_value=0.1, value=0.01, step=0.001, format="%.3f",
    )
    optimizer_type = st.sidebar.selectbox("Optimizer:", ["BFGS", "LBFGS", "FIRE"], index=1)
# Main page: once a structure is loaded, show it next to the calculation setup.
# On "Run Calculation", attach the selected MLIP calculator to a copy of the structure
# and perform the selected task (energy, forces, atomization/cohesive energy,
# geometry/cell optimization or vibrational analysis), then render the results.
if atoms is not None:
    col1, col2 = st.columns(2)
    with col1:
        st.markdown('### Structure Visualization', unsafe_allow_html=True)
        viz_style = st.selectbox("Select Visualization Style:",
                                 ["ball-stick",
                                  "stick",
                                  "ball"])
        view_3d = get_structure_viz2(atoms, style=viz_style, show_unit_cell=True, width=400, height=400)
        st.components.v1.html(view_3d._make_html(), width=400, height=400)
        st.markdown("### Structure Information")
        atoms_info = {
            "Number of Atoms": len(atoms),
            "Chemical Formula": atoms.get_chemical_formula(),
            "Periodic Boundary Conditions (PBC)": atoms.pbc.tolist(),
            "Cell Dimensions": np.round(atoms.cell.cellpar(), 3).tolist() if atoms.pbc.any() and atoms.cell is not None and atoms.cell.any() else "No cell / Non-periodic",
            "Atom Types": ", ".join(sorted(set(atoms.get_chemical_symbols())))
        }
        for key, value in atoms_info.items():
            st.write(f"**{key}:** {value}")
    with col2:
        st.markdown('## Calculation Setup', unsafe_allow_html=True)
        st.markdown("### Selected Model")
        st.write(f"**Model Type:** {model_type}")
        st.write(f"**Model:** {selected_model}")
        if model_type == "FairChem" and selected_model == "UMA Small":
            st.write(f"**UMA Task Type:** {selected_task_type}")
        st.write(f"**Device:** {device}")
        st.markdown("### Selected Task")
        st.write(f"**Task:** {task}")
        if "Optimization" in task:
            st.write(f"**Max Steps:** {max_steps}")
            st.write(f"**Convergence Threshold:** {fmax} eV/Å")
            st.write(f"**Optimizer:** {optimizer_type}")
        run_calculation = st.button("Run Calculation", type="primary")
    if run_calculation:
        # Start each run from a clean session state (e.g. drop trajectory frames cached
        # by a previous optimization). Iterate over a snapshot because entries are deleted.
        for key in list(st.session_state.keys()):
            del st.session_state[key]
        results = {}
        # Module-level names; presumably read by the streamlit_log callback defined
        # elsewhere in this file (NOTE(review): confirm).
        opt_log = []  # Reset log for each run
        if "Optimization" in task:
            table_placeholder = st.empty()  # Recreate placeholder for table
        try:
            torch.set_default_dtype(torch.float32)
            with st.spinner("Running calculation... Please wait."):
                calc_atoms = atoms.copy()
                if model_type == "MACE":
                    calc = get_mace_model(model_path, device, 'float32')
                elif model_type == "FairChem":
                    calc = get_fairchem_model(selected_model, model_path, device, selected_task_type)
                elif model_type == "ORB":
                    orbff = pretrained.orb_v3_conservative_inf_omat(device=device, precision=selected_default_dtype)
                    calc = ORBCalculator(orbff, device=device)
                elif model_type == "MatterSim":
                    # NOTE: Running mattersim on windows requires changing source code file
                    # https://github.com/microsoft/mattersim/issues/112
                    # mattersim/datasets/utils/convertor.py: 117
                    # to pbc_ = np.array(structure.pbc, dtype=np.int64)
                    calc = MatterSimCalculator(load_path=model_path, device=device)
                elif model_type == "SEVEN_NET":
                    if model_path == '7net-mf-ompa':
                        calc = SevenNetCalculator(model=model_path, modal=selected_modal_7net, device=device)
                    else:
                        calc = SevenNetCalculator(model=model_path, device=device)
                else:
                    # Clear message instead of a NameError on `calc` below.
                    raise ValueError(f"Unsupported model type: {model_type}")
                calc_atoms.calc = calc

                if task == "Energy Calculation":
                    energy = calc_atoms.get_potential_energy()
                    results["Energy"] = f"{energy:.6f} eV"
                elif task == "Energy + Forces Calculation":
                    energy = calc_atoms.get_potential_energy()
                    forces = calc_atoms.get_forces()
                    max_force = np.max(np.sqrt(np.sum(forces**2, axis=1))) if forces.shape[0] > 0 else 0.0
                    results["Energy"] = f"{energy:.6f} eV"
                    results["Maximum Force"] = f"{max_force:.6f} eV/Å"
                elif task == "Atomization/Cohesive Energy":
                    st.write("Calculating system energy...")
                    E_system = calc_atoms.get_potential_energy()
                    num_atoms = len(calc_atoms)
                    if num_atoms == 0:
                        st.error("Cannot calculate atomization/cohesive energy for a system with zero atoms.")
                        results["Error"] = "System has no atoms."
                    else:
                        atomic_numbers = calc_atoms.get_atomic_numbers()
                        E_isolated_atoms_total = 0.0
                        calculation_possible = True
                        if model_type == "FairChem":
                            # FairChem: sum tabulated per-element reference energies.
                            st.write("Fetching FairChem reference energies for isolated atoms...")
                            ref_key_suffix = "_elem_refs"
                            chosen_ref_list_name = None
                            if selected_model == "UMA Small":
                                if selected_task_type:
                                    chosen_ref_list_name = selected_task_type + ref_key_suffix
                            elif "ESEN" in selected_model:
                                chosen_ref_list_name = "omol" + ref_key_suffix
                            if chosen_ref_list_name and chosen_ref_list_name in ELEMENT_REF_ENERGIES:
                                ref_energies = ELEMENT_REF_ENERGIES[chosen_ref_list_name]
                                missing_Z_refs = set()
                                for Z_val in atomic_numbers:
                                    if 0 < Z_val < len(ref_energies):
                                        E_isolated_atoms_total += ref_energies[Z_val]
                                    else:
                                        missing_Z_refs.add(int(Z_val))  # plain int for a readable warning
                                if missing_Z_refs:
                                    st.warning(f"Reference energy for atomic number(s) {sorted(missing_Z_refs)} "
                                               f"not found in '{chosen_ref_list_name}' list (max Z defined: {len(ref_energies)-1}). "
                                               "These atoms are treated as having 0 reference energy.")
                            else:
                                st.error(f"Could not find or determine reference energy list for FairChem model: '{selected_model}' "
                                         f"and UMA task type: '{selected_task_type}'. Cannot calculate atomization/cohesive energy.")
                                results["Error"] = "Missing FairChem reference energies."
                                calculation_possible = False
                        else:
                            # All other backends: evaluate each element once as an isolated
                            # neutral atom with the same calculator, weighted by its count.
                            st.write(f"Calculating isolated atom energies with {model_type}...")
                            unique_atomic_numbers, unique_counts = np.unique(atomic_numbers, return_counts=True)
                            n_unique = len(unique_atomic_numbers)
                            progress_text = "Calculating isolated atom energies: 0% complete"
                            iso_progress_bar = st.progress(0, text=progress_text)
                            for i, (Z_unique, count) in enumerate(zip(unique_atomic_numbers, unique_counts)):
                                isolated_atom = Atoms(numbers=[Z_unique], cell=[20, 20, 20], pbc=False)
                                # Charge/spin metadata, presumably honoured by calculators that read atoms.info.
                                isolated_atom.info["charge"] = 0
                                isolated_atom.info["spin"] = 0
                                isolated_atom.calc = calc  # Reuse the calculator built above
                                E_isolated_atom_type = isolated_atom.get_potential_energy()
                                E_isolated_atoms_total += E_isolated_atom_type * count
                                progress_val = (i + 1) / n_unique
                                iso_progress_bar.progress(progress_val, text=f"Calculating isolated atom energies for Z={Z_unique}: {int(progress_val*100)}% complete")
                            iso_progress_bar.empty()
                        if calculation_possible:
                            is_periodic = any(calc_atoms.pbc)
                            if is_periodic:
                                cohesive_E = (E_isolated_atoms_total - E_system) / num_atoms
                                results["Cohesive Energy"] = f"{cohesive_E:.6f} eV/atom"
                            else:
                                atomization_E = E_isolated_atoms_total - E_system
                                results["Atomization Energy"] = f"{atomization_E:.6f} eV"
                            results["System Energy ($E_{system}$)"] = f"{E_system:.6f} eV"
                            # Raw string: "\s" is an invalid escape in a normal literal.
                            results[r"Total Isolated Atom Energy ($\sum E_{atoms}$)"] = f"{E_isolated_atoms_total:.6f} eV"
                elif "Optimization" in task:  # Handles both Geometry and Cell+Geometry Opt
                    is_periodic = any(calc_atoms.pbc)
                    # For cell relaxation the optimizer drives a FrechetCellFilter wrapping
                    # calc_atoms; results are read back from calc_atoms afterwards.
                    opt_atoms_obj = FrechetCellFilter(calc_atoms) if task == "Cell + Geometry Optimization" else calc_atoms
                    # Temporary trajectory file; removed once the frames are cached (see fragment below).
                    traj_fd, traj_filename = tempfile.mkstemp(suffix=".traj")
                    os.close(traj_fd)
                    optimizer_cls = {"BFGS": BFGS, "LBFGS": LBFGS}.get(optimizer_type, FIRE)
                    opt = optimizer_cls(opt_atoms_obj, trajectory=traj_filename)
                    opt.attach(lambda: streamlit_log(opt), interval=1)
                    st.write(f"Running {task.lower()}...")
                    # Optimizer.run returns whether the fmax criterion was met.
                    converged = opt.run(fmax=fmax, steps=max_steps)
                    energy = calc_atoms.get_potential_energy()
                    forces = calc_atoms.get_forces()
                    max_force = np.max(np.sqrt(np.sum(forces**2, axis=1))) if forces.shape[0] > 0 else 0.0
                    results["Final Energy"] = f"{energy:.6f} eV"
                    results["Final Maximum Force"] = f"{max_force:.6f} eV/Å"
                    results["Steps Taken"] = opt.get_number_of_steps()
                    results["Converged"] = "Yes" if converged else "No"
                    if task == "Cell + Geometry Optimization":
                        results["Final Cell Parameters"] = np.round(calc_atoms.cell.cellpar(), 4).tolist()

            st.success("Calculation completed successfully!")
            st.markdown("### Results")
            for key, value in results.items():
                st.write(f"**{key}:** {value}")

            if "Optimization" in task and "Final Energy" in results:  # Check if opt was successful
                st.markdown("### Optimized Structure")
                # Visualize the relaxed Atoms object itself: for cell optimization
                # opt_atoms_obj is a FrechetCellFilter wrapper, not an Atoms.
                opt_view = get_structure_viz2(calc_atoms, style=viz_style, show_unit_cell=True, width=400, height=400)
                st.components.v1.html(opt_view._make_html(), width=400, height=400)
                # Serialize in memory: no temp file to clean up and no reopening of a
                # still-open NamedTemporaryFile (which fails on Windows).
                xyz_buffer = io.StringIO()
                write(xyz_buffer, calc_atoms, format="extxyz" if is_periodic else "xyz")
                xyz_content_opt = xyz_buffer.getvalue()

                @st.fragment
                def show_optimized_structure_download_button():
                    # Fragment so the download click reruns only this widget, not the whole
                    # script (which would drop the results shown above).
                    st.download_button(
                        label="Download Optimized Structure (XYZ)",
                        data=xyz_content_opt,
                        file_name="optimized_structure.xyz",
                        mime="chemical/x-xyz"
                    )
                show_optimized_structure_download_button()

                @st.fragment
                def show_trajectory_and_controls():
                    # Read the trajectory once per run and cache the frames in session state,
                    # so navigation clicks (fragment reruns) do not touch the file again.
                    if "traj_frames" not in st.session_state:
                        if not os.path.exists(traj_filename):
                            st.warning("Trajectory file not found.")
                            return
                        try:
                            st.session_state.traj_frames = read(traj_filename, index=":")
                            st.session_state.traj_index = 0
                        except Exception as e:
                            st.error(f"Error reading trajectory: {e}")
                            return
                        finally:
                            # The frames are cached (or unreadable), so the temp file is no
                            # longer needed. Best effort: removal can fail, e.g. on Windows
                            # while the optimizer still holds the file open.
                            try:
                                os.unlink(traj_filename)
                            except OSError:
                                pass
                    trajectory = st.session_state.traj_frames
                    index = st.session_state.traj_index
                    st.markdown("### Optimization Trajectory")
                    st.write(f"Captured {len(trajectory)} optimization steps")
                    # Navigation Buttons
                    col1, col2, col3, col4 = st.columns(4)
                    with col1:
                        if st.button("⏮ First"):
                            st.session_state.traj_index = 0
                    with col2:
                        if st.button("◀ Previous") and index > 0:
                            st.session_state.traj_index -= 1
                    with col3:
                        if st.button("Next ▶") and index < len(trajectory) - 1:
                            st.session_state.traj_index += 1
                    with col4:
                        if st.button("Last ⏭"):
                            st.session_state.traj_index = len(trajectory) - 1
                    # Show current frame
                    current_atoms = trajectory[st.session_state.traj_index]
                    st.write(f"Frame {st.session_state.traj_index + 1}/{len(trajectory)}")

                    def atoms_to_xyz_string(frame, step_idx=None):
                        """Return one plain-XYZ frame whose comment line carries the frame energy."""
                        energy_line = f"Energy = {frame.get_potential_energy():.6f} eV"
                        if step_idx is not None:
                            energy_line = f"Step {step_idx}, {energy_line}"
                        lines = [str(len(frame)), energy_line]
                        lines.extend(
                            f"{atom.symbol} {atom.position[0]:.6f} {atom.position[1]:.6f} {atom.position[2]:.6f}"
                            for atom in frame
                        )
                        return "\n".join(lines) + "\n"

                    traj_view = get_structure_viz2(current_atoms, style=viz_style, show_unit_cell=True, width=400, height=400)
                    st.components.v1.html(traj_view._make_html(), width=400, height=400)
                    # Download button for entire trajectory (join avoids quadratic concatenation).
                    trajectory_xyz = "".join(atoms_to_xyz_string(frame, i) for i, frame in enumerate(trajectory))
                    st.download_button(
                        label="Download Optimization Trajectory (XYZ)",
                        data=trajectory_xyz,
                        file_name="optimization_trajectory.xyz",
                        mime="chemical/x-xyz"
                    )
                show_trajectory_and_controls()
            elif task == "Vibrational Mode Analysis":
                st.write("Running vibrational mode analysis using finite differences...")
                natoms = len(calc_atoms)
                is_linear = False  # Set manually or auto-detect
                nmodes_expected = 3 * natoms - (5 if is_linear else 6)
                # Vibrations caches its displacement results under `name`; keep them in a
                # throwaway directory and read the frequencies before it is removed.
                with tempfile.TemporaryDirectory() as tmpdir:
                    vib = Vibrations(calc_atoms, name=os.path.join(tmpdir, 'vib'))
                    with st.spinner("Calculating vibrational modes... This may take a few minutes."):
                        vib.run()
                    freqs = np.asarray(vib.get_frequencies())
                # ASE reports frequencies in cm⁻¹ and may return complex values, with
                # imaginary modes carried in the imaginary part. Map those to negative
                # real numbers (the usual convention) so they can be classified,
                # rounded and plotted.
                if np.iscomplexobj(freqs):
                    freqs_cm = np.where(np.abs(freqs.imag) > np.abs(freqs.real),
                                        -np.abs(freqs.imag), freqs.real)
                else:
                    freqs_cm = freqs.astype(float)
                # Classify frequencies
                mode_data = []
                for i, freq in enumerate(freqs_cm):
                    if freq < 0:
                        label = "Imaginary"
                    elif abs(freq) < 500:
                        label = "Low"
                    else:
                        label = "Physical"
                    mode_data.append({
                        "Mode": i + 1,
                        "Frequency (cm⁻¹)": round(float(freq), 2),
                        "Type": label
                    })
                df_modes = pd.DataFrame(mode_data)
                # Display summary and mode count
                st.success("Vibrational analysis completed.")
                st.write(f"Number of atoms: {natoms}")
                st.write(f"Expected vibrational modes: {nmodes_expected}")
                st.write(f"Found {len(freqs_cm)} modes (including translational/rotational modes).")
                # Show table of modes
                st.write("### Vibrational Mode Summary")
                st.dataframe(df_modes, use_container_width=True)
                # Store in results dictionary
                results["Vibrational Modes"] = df_modes.to_dict(orient="records")
                # Histogram plot of vibrational frequencies
                st.write("### Frequency Distribution Histogram")
                fig, ax = plt.subplots()
                ax.hist(freqs_cm, bins=30, color='skyblue', edgecolor='black')
                ax.set_xlabel("Frequency (cm⁻¹)")
                ax.set_ylabel("Number of Modes")
                ax.set_title("Distribution of Vibrational Frequencies")
                st.pyplot(fig)
                plt.close(fig)  # pyplot keeps figures alive until closed; avoid leaking across reruns
                # CSV download
                csv_buffer = io.StringIO()
                df_modes.to_csv(csv_buffer, index=False)
                st.download_button(
                    label="Download Vibrational Frequencies (CSV)",
                    data=csv_buffer.getvalue(),
                    file_name="vibrational_modes.csv",
                    mime="text/csv"
                )
        except Exception as e:
            # Top-level UI boundary: report any failure (model load, calculator, ASE) to the user.
            st.error(f"🔴 Calculation error: {str(e)}")
            st.error("Please check the structure, model compatibility, and parameters. For FairChem UMA, ensure the task type (omol, omat etc.) is appropriate for your system (e.g. omol for molecules, omat for materials).")
            import traceback
            st.error(f"Traceback: {traceback.format_exc()}")
else:
    st.info("👋 Welcome! Please select or upload a structure using the sidebar options to begin.")
# Footer: "About" expander with usage notes, followed by credits.
st.markdown("---")
with st.expander('ℹ️ About This App & Foundational MLIPs'):
    # Raw string: the LaTeX below uses backslash commands. In a normal literal,
    # the "\t" in "\text" would become a TAB character and break the formulas.
    st.write(r"""
**Test, compare, and benchmark universal machine learning interatomic potentials (MLIPs).**
This application allows you to perform atomistic simulations using pre-trained foundational MLIPs
from the MACE, FairChem (by Meta AI), ORB and SevenNet libraries.
**Features:**
- Upload structure files (XYZ, CIF, POSCAR, etc.) or use built-in examples.
- Select from various MACE, FairChem, ORB and SevenNet models.
- Calculate energies, forces, and vibrational modes, and perform geometry/cell optimizations.
- **New**: Calculate atomization energy (for molecules) or cohesive energy (for periodic systems).
- Visualize atomic structures in 3D and download results.
**Quick Start:**
1. **Input**: Choose an input method in the sidebar (e.g., "Select Example").
2. **Model**: Pick a model type (e.g., MACE, FairChem, ORB, SevenNet) and specific model. For FairChem UMA, select the appropriate task type (e.g., `omol` for molecules, `omat` for materials).
3. **Task**: Select a calculation task (e.g., "Energy Calculation", "Atomization/Cohesive Energy", "Geometry Optimization").
4. **Run**: Click "Run Calculation" and view the results.
**Atomization/Cohesive Energy Notes:**
- **Atomization Energy** ($E_{\text{atomization}} = \sum E_{\text{isolated atoms}} - E_{\text{molecule}}$) is typically for non-periodic systems (molecules).
- **Cohesive Energy** ($E_{\text{cohesive}} = (\sum E_{\text{isolated atoms}} - E_{\text{bulk system}}) / N_{\text{atoms}}$) is for periodic systems.
- For **MACE and the other non-FairChem models**, isolated atom energies are computed on-the-fly with the selected model.
- For **FairChem models**, isolated atom energies are based on pre-tabulated reference values (provided in a YAML-like structure within the app). Ensure the selected FairChem task type (`omol`, `omat`, etc. for UMA models) or model type (ESEN models use `omol` references) aligns with the system and has the necessary elemental references.
""")
st.markdown("Universal MLIP Playground App | Created with Streamlit, ASE, MACE, FairChem, SevenNet, ORB and ❤️")
st.markdown("Developed by [Manas Sharma](https://manas.bragitoff.com/) in the groups of [Prof. Ananth Govind Rajan Group](https://www.agrgroup.org/) and [Prof. Sudeep Punnathanam](https://chemeng.iisc.ac.in/sudeep/) at [IISc Bangalore](https://iisc.ac.in/)")