ManasSharma07 commited on
Commit
ea2c499
·
verified ·
1 Parent(s): 8afc436

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +196 -0
src/streamlit_app.py CHANGED
@@ -616,6 +616,202 @@ SAMPLE_STRUCTURES = {
616
  "hBN Monolayer (4x4)": "hBN_monolayer_4x4_supercell.extxyz",
617
  }
618
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
619
  def get_structure_viz2(atoms_obj, style='stick', show_unit_cell=True, width=400, height=400):
620
  xyz_str = ""
621
  xyz_str += f"{len(atoms_obj)}\n"
 
616
  "hBN Monolayer (4x4)": "hBN_monolayer_4x4_supercell.extxyz",
617
  }
618
 
619
+ def get_trajectory_viz(trajectory, style='stick', show_unit_cell=True, width=400, height=400,
620
+ show_path=True, path_color='red', path_radius=0.02):
621
+ """
622
+ Visualize optimization trajectory with multiple frames
623
+
624
+ Args:
625
+ trajectory: List of ASE atoms objects representing the optimization steps
626
+ style: Visualization style ('stick', 'ball', 'ball-stick')
627
+ show_unit_cell: Whether to show unit cell
628
+ show_path: Whether to show trajectory paths for each atom
629
+ path_color: Color of trajectory paths
630
+ path_radius: Radius of trajectory path cylinders
631
+ """
632
+ if not trajectory:
633
+ return None
634
+
635
+ view = py3Dmol.view(width=width, height=height)
636
+
637
+ # Add all frames to the viewer
638
+ for frame_idx, atoms_obj in enumerate(trajectory):
639
+ xyz_str = ""
640
+ xyz_str += f"{len(atoms_obj)}\n"
641
+ xyz_str += f"Frame {frame_idx}\n"
642
+ for atom in atoms_obj:
643
+ xyz_str += f"{atom.symbol} {atom.position[0]:.6f} {atom.position[1]:.6f} {atom.position[2]:.6f}\n"
644
+
645
+ view.addModel(xyz_str, "xyz")
646
+
647
+ # Set style for all models
648
+ if style.lower() == 'ball-stick':
649
+ view.setStyle({'stick': {'radius': 0.1}, 'sphere': {'scale': 0.3}})
650
+ elif style.lower() == 'stick':
651
+ view.setStyle({'stick': {}})
652
+ elif style.lower() == 'ball':
653
+ view.setStyle({'sphere': {'scale': 0.4}})
654
+ else:
655
+ view.setStyle({'stick': {'radius': 0.15}})
656
+
657
+ # Add trajectory paths
658
+ if show_path and len(trajectory) > 1:
659
+ for atom_idx in range(len(trajectory[0])):
660
+ for frame_idx in range(len(trajectory) - 1):
661
+ start_pos = trajectory[frame_idx][atom_idx].position
662
+ end_pos = trajectory[frame_idx + 1][atom_idx].position
663
+
664
+ view.addCylinder({
665
+ 'start': {'x': start_pos[0], 'y': start_pos[1], 'z': start_pos[2]},
666
+ 'end': {'x': end_pos[0], 'y': end_pos[1], 'z': end_pos[2]},
667
+ 'radius': path_radius,
668
+ 'color': path_color,
669
+ 'alpha': 0.5
670
+ })
671
+
672
+ # Add unit cell for the last frame
673
+ if show_unit_cell and trajectory[-1].pbc.any():
674
+ cell = trajectory[-1].get_cell()
675
+ origin = np.array([0.0, 0.0, 0.0])
676
+ if cell is not None and cell.any():
677
+ edges = [
678
+ (origin, cell[0]), (origin, cell[1]), (cell[0], cell[0] + cell[1]), (cell[1], cell[0] + cell[1]),
679
+ (cell[2], cell[2] + cell[0]), (cell[2], cell[2] + cell[1]),
680
+ (cell[2] + cell[0], cell[2] + cell[0] + cell[1]), (cell[2] + cell[1], cell[2] + cell[0] + cell[1]),
681
+ (origin, cell[2]), (cell[0], cell[0] + cell[2]), (cell[1], cell[1] + cell[2]),
682
+ (cell[0] + cell[1], cell[0] + cell[1] + cell[2])
683
+ ]
684
+ for start, end in edges:
685
+ view.addCylinder({
686
+ 'start': {'x': start[0], 'y': start[1], 'z': start[2]},
687
+ 'end': {'x': end[0], 'y': end[1], 'z': end[2]},
688
+ 'radius': 0.05, 'color': 'black', 'alpha': 0.7
689
+ })
690
+
691
+ view.zoomTo()
692
+ view.setBackgroundColor('white')
693
+ return view
694
+
695
+
696
+ def get_animated_trajectory_viz(trajectory, style='stick', show_unit_cell=True, width=400, height=400):
697
+ """
698
+ Create an animated trajectory visualization
699
+ """
700
+ if not trajectory:
701
+ return None
702
+
703
+ view = py3Dmol.view(width=width, height=height)
704
+
705
+ # Add all frames
706
+ for frame_idx, atoms_obj in enumerate(trajectory):
707
+ xyz_str = ""
708
+ xyz_str += f"{len(atoms_obj)}\n"
709
+ xyz_str += f"Frame {frame_idx}\n"
710
+ for atom in atoms_obj:
711
+ xyz_str += f"{atom.symbol} {atom.position[0]:.6f} {atom.position[1]:.6f} {atom.position[2]:.6f}\n"
712
+
713
+ view.addModel(xyz_str, "xyz")
714
+
715
+ # Set style
716
+ if style.lower() == 'ball-stick':
717
+ view.setStyle({'stick': {'radius': 0.1}, 'sphere': {'scale': 0.3}})
718
+ elif style.lower() == 'stick':
719
+ view.setStyle({'stick': {}})
720
+ elif style.lower() == 'ball':
721
+ view.setStyle({'sphere': {'scale': 0.4}})
722
+ else:
723
+ view.setStyle({'stick': {'radius': 0.15}})
724
+
725
+ # Add unit cell for last frame
726
+ if show_unit_cell and trajectory[-1].pbc.any():
727
+ cell = trajectory[-1].get_cell()
728
+ origin = np.array([0.0, 0.0, 0.0])
729
+ if cell is not None and cell.any():
730
+ edges = [
731
+ (origin, cell[0]), (origin, cell[1]), (cell[0], cell[0] + cell[1]), (cell[1], cell[0] + cell[1]),
732
+ (origin, cell[2]), (cell[0], cell[0] + cell[2]), (cell[1], cell[1] + cell[2]),
733
+ (cell[0] + cell[1], cell[0] + cell[1] + cell[2]),
734
+ (cell[2], cell[2] + cell[0]), (cell[2], cell[2] + cell[1]),
735
+ (cell[2] + cell[0], cell[2] + cell[0] + cell[1]), (cell[2] + cell[1], cell[2] + cell[0] + cell[1])
736
+ ]
737
+ for start, end in edges:
738
+ view.addCylinder({
739
+ 'start': {'x': start[0], 'y': start[1], 'z': start[2]},
740
+ 'end': {'x': end[0], 'y': end[1], 'z': end[2]},
741
+ 'radius': 0.05, 'color': 'black', 'alpha': 0.7
742
+ })
743
+
744
+ view.zoomTo()
745
+ view.setBackgroundColor('white')
746
+
747
+ # Enable animation
748
+ view.animate({'loop': 'forward', 'reps': 0, 'interval': 500})
749
+
750
+ return view
751
+
752
+
753
+ # Streamlit implementation example
754
+ def display_optimization_trajectory(trajectory, viz_style='stick'):
755
+ """
756
+ Display optimization trajectory in Streamlit with controls
757
+ """
758
+ if not trajectory:
759
+ st.error("No trajectory data available")
760
+ return
761
+
762
+ st.subheader(f"Optimization Trajectory ({len(trajectory)} steps)")
763
+
764
+ # Trajectory options
765
+ col1, col2 = st.columns(2)
766
+
767
+ with col1:
768
+ viz_mode = st.selectbox(
769
+ "Visualization Mode",
770
+ ["Static with paths", "Animation", "Step-by-step"],
771
+ key="viz_mode"
772
+ )
773
+
774
+ with col2:
775
+ if viz_mode == "Static with paths":
776
+ show_paths = st.checkbox("Show trajectory paths", value=True)
777
+ path_color = st.selectbox("Path color", ["red", "blue", "green", "orange"], index=0)
778
+ elif viz_mode == "Step-by-step":
779
+ frame_idx = st.slider("Frame", 0, len(trajectory)-1, 0, key="frame_slider")
780
+
781
+ # Display visualization based on mode
782
+ if viz_mode == "Static with paths":
783
+ opt_view = get_trajectory_viz(
784
+ trajectory,
785
+ style=viz_style,
786
+ show_unit_cell=True,
787
+ width=400,
788
+ height=400,
789
+ show_path=show_paths,
790
+ path_color=path_color
791
+ )
792
+ st.components.v1.html(opt_view._make_html(), width=400, height=400)
793
+
794
+ elif viz_mode == "Animation":
795
+ opt_view = get_animated_trajectory_viz(
796
+ trajectory,
797
+ style=viz_style,
798
+ show_unit_cell=True,
799
+ width=400,
800
+ height=400
801
+ )
802
+ st.components.v1.html(opt_view._make_html(), width=400, height=400)
803
+
804
+ elif viz_mode == "Step-by-step":
805
+ opt_view = get_structure_viz2(
806
+ trajectory[frame_idx],
807
+ style=viz_style,
808
+ show_unit_cell=True,
809
+ width=400,
810
+ height=400
811
+ )
812
+ st.components.v1.html(opt_view._make_html(), width=400, height=400)
813
+ st.write(f"Step {frame_idx + 1} of {len(trajectory)}")
814
+
815
  def get_structure_viz2(atoms_obj, style='stick', show_unit_cell=True, width=400, height=400):
816
  xyz_str = ""
817
  xyz_str += f"{len(atoms_obj)}\n"