# Import Libraries
import matplotlib.pyplot as plt
import streamlit as st

import src.prompt_config as prompt_params

# Models
import xgboost
from sklearn.model_selection import train_test_split
from alepython import ale_plot

# XAI (Explainability)
import shap

# Global Variables to Store Model & Data
global_model = None
X_train, X_test, y_train, y_test = None, None, None, None


def train_model():
    """Train the XGBoost classifier once and store it in the module globals.

    Loads the SHAP "adult" dataset, splits it 50/50 (``random_state=42``) into
    train/test sets, fits an ``xgboost.XGBClassifier`` on the training half,
    and stores the model plus all four splits in the module-level globals.
    Does nothing if ``global_model`` is already set.
    """
    global global_model, X_train, X_test, y_train, y_test

    if global_model is not None:
        return

    # Load Data from SHAP library
    X, y = shap.datasets.adult()

    # Split data
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.5, random_state=42
    )

    # Train XGBoost model. Fit a local instance and publish it only after
    # fit() succeeds, so a failed fit does not leave an unfitted model cached
    # in global_model (which would make later calls skip training).
    model = xgboost.XGBClassifier()
    model.fit(X_train, y_train)
    global_model = model
    print("XGBoost Model training completed!")


def _render_ale_plot(title, features, **ale_kwargs):
    """Write ``title``, draw an ALE plot of ``features`` on ``X_test``, and show it.

    ``ale_kwargs`` are forwarded unchanged to alepython's ``ale_plot``.
    """
    st.write(title)
    try:
        ale_plot(global_model, X_test, features, **ale_kwargs)
        # ale_plot draws on a figure it creates itself, so a figure created by
        # the caller beforehand stays empty. Fetching the current figure
        # *after* the call picks up the figure that was actually drawn.
        st.pyplot(plt.gcf())
    finally:
        # Close the figure so repeated button clicks do not accumulate open
        # pyplot figures in the long-running server process.
        plt.close(plt.gcf())


def explain_example():
    """Render ALE plots for the trained model on the held-out test split.

    Trains the model first if that has not happened yet. Currently renders a
    2D second-order ALE plot for the first two feature columns.
    """
    if global_model is None:
        train_model()

    # A 1D main-effect plot can be rendered the same way, e.g.:
    # _render_ale_plot("1D Main Effect ALE Plot", ["Age"], bins=5,
    #                  monte_carlo=True, monte_carlo_rep=30,
    #                  monte_carlo_ratio=0.5)

    _render_ale_plot(
        "2D Second-Order ALE Plot",
        list(X_train.columns[:2]),
        bins=10,
    )


def main():
    """Build the Streamlit ALE page.

    Ensures the model is trained, shows the introduction text, and renders the
    ALE plots when the "Explain Sample" button is pressed.
    """
    # NOTE(review): Streamlit re-executes this script on each interaction,
    # which appears to reset the module globals above and retrain the model
    # on every rerun. Consider caching the trained model with
    # st.cache_resource after confirming the installed Streamlit supports it.
    if global_model is None:
        train_model()

    st.title("ALE (Accumulated Local Effects)")
    st.write(prompt_params.ALE_INTRODUCTION)
    # NOTE(review): remove this banner once the rendered ALE plot is verified.
    st.write("now has bug, waiting for fix")

    # Explain the selected sample
    if st.button("Explain Sample"):
        explain_example()


if __name__ == '__main__':
    main()