"""Streamlit page showing permutation feature importance for an XGBoost model.

The model is trained on SHAP's ``adult`` dataset (50/50 train/test split,
``random_state=42``) and permutation importance is computed on the held-out
test split.
"""

# Import Libraries
import matplotlib.pyplot as plt
import streamlit as st

import src.prompt_config as prompt_params

# Models
import xgboost
from sklearn.model_selection import train_test_split
from alepython import ale_plot  # NOTE(review): not used on this page — confirm before removing

# XAI (Explainability)
import shap
from sklearn.inspection import permutation_importance

# Global Variables to Store Model & Data
global_model = None
X_train, X_test, y_train, y_test = None, None, None, None


@st.cache_resource
def _fit_model():
    """Load the adult dataset, split it 50/50, and fit an XGBoost classifier.

    Memoized with ``st.cache_resource``: Streamlit re-executes this script on
    every widget interaction, which resets the module-level globals above, so
    without this cache the model would be retrained on every rerun.

    Returns:
        Tuple ``(model, X_train, X_test, y_train, y_test)``.
    """
    X, y = shap.datasets.adult()
    X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.5, random_state=42)
    model = xgboost.XGBClassifier()
    model.fit(X_tr, y_tr)
    print("XGBoost Model training completed!")
    return model, X_tr, X_te, y_tr, y_te


def train_model():
    """Ensure the XGBoost model is trained and publish it via the module globals.

    Populates ``global_model`` and ``X_train``/``X_test``/``y_train``/``y_test``.
    The expensive training step is memoized by ``_fit_model`` so it runs only
    once even across Streamlit reruns.
    """
    global global_model, X_train, X_test, y_train, y_test
    if global_model is None:
        global_model, X_train, X_test, y_train, y_test = _fit_model()


def _show_figure(fig):
    """Lay out *fig*, render it in the Streamlit page, then close it."""
    fig.tight_layout()
    st.pyplot(fig)
    # pyplot keeps every figure alive until it is closed; close after
    # rendering so repeated button clicks do not accumulate open figures.
    plt.close(fig)


def explain_example():
    """Show permutation feature importance of the trained model on the test split.

    Runs ``permutation_importance`` with 30 shuffling repeats on
    ``X_test``/``y_test`` and renders two charts: a bar chart of the mean
    importance per feature and a box plot of the per-repeat importances.
    Trains the model first if that has not happened yet.
    """
    if global_model is None:
        train_model()

    st.write("Permutation Feature Importance (test set)")
    perm_imp = permutation_importance(
        global_model, X_test, y_test, n_repeats=30, random_state=0
    )
    # Ascending order: horizontal charts draw the first entry at the bottom,
    # so the most important feature ends up at the top.
    sorted_idx = perm_imp.importances_mean.argsort()
    feature_names = X_test.columns[sorted_idx]
    importances = perm_imp.importances[sorted_idx]  # (n_features, n_repeats)

    fig, ax = plt.subplots(figsize=(10, 5))
    ax.barh(feature_names, perm_imp.importances_mean[sorted_idx])
    ax.set_title("Permutation Importances")
    _show_figure(fig)

    fig, ax = plt.subplots(figsize=(10, 5))
    ax.boxplot(importances.T, vert=False)
    # Boxes sit at y = 1..n; label them through the ticks because boxplot's
    # ``labels=`` keyword is deprecated since Matplotlib 3.9.
    ax.set_yticks(range(1, len(feature_names) + 1), labels=feature_names)
    ax.set_title("Permutation Importances")
    _show_figure(fig)


def main():
    """Build the page: ensure the model is trained, show the intro, explain on demand."""
    # Cheap after the first run thanks to the cached _fit_model.
    train_model()

    st.title("Permutation Feature Importance")
    st.write(prompt_params.PERMUTATION_INTRODUCTION)

    if st.button("Explain Sample"):
        explain_example()


if __name__ == '__main__':
    main()