# Import Libraries
import matplotlib.pyplot as plt
import streamlit as st

import src.prompt_config as prompt_params

# Models
import xgboost
from sklearn.model_selection import train_test_split

# XAI (Explainability)
import shap


@st.cache_resource
def train_model():
    """Train the XGBoost model once and cache it across Streamlit reruns.

    Streamlit re-executes the whole script on every widget interaction, so
    module-level globals are reset each time; st.cache_resource ensures the
    model really is trained only once.
    """
    # Load the Adult census dataset bundled with the SHAP library
    X, y = shap.datasets.adult()

    # Split the data into train/test halves
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.5, random_state=42
    )

    # Train the XGBoost classifier
    model = xgboost.XGBClassifier()
    model.fit(X_train, y_train)
    print("XGBoost model training completed!")

    return model, X_train, X_test, y_train, y_test


def explain_example(model, X_train, X_test, baseline_number, example_idx):
    """Explain predictions on the test set without retraining the model."""
    # Sample a background distribution from the training data as the baseline
    X_base = shap.utils.sample(X_train, baseline_number)

    # TreeExplainer with a background distribution
    explainer = shap.TreeExplainer(model, X_base)
    shap_values = explainer.shap_values(X_test)  # Raw SHAP values (array)
    shap_values_exp = explainer(X_test)          # Explanation object for the newer plots API

    # SHAP Summary Plot (beeswarm)
    st.write("### 📊 SHAP Summary Plot")
    st.write(
        "This plot shows how each feature contributes to individual "
        "predictions, making the model's behavior easier to interpret."
    )
    fig, ax = plt.subplots(figsize=(10, 5))
    shap.summary_plot(shap_values, X_test, show=False)
    st.pyplot(fig)

    # SHAP Summary Bar Plot
    st.write("### 📊 SHAP Feature Importance (Bar Plot)")
    st.write("This plot shows which features the model relies on most overall.")
    fig, ax = plt.subplots(figsize=(10, 5))
    shap.summary_plot(shap_values, X_test, plot_type="bar", show=False)
    st.pyplot(fig)

    # SHAP Dependence Plot
    st.write("### 🔍 SHAP Dependence Plot for 'Age'")
    fig, ax = plt.subplots(figsize=(10, 5))
    shap.dependence_plot("Age", shap_values, X_test, ax=ax, show=False)
    st.pyplot(fig)

    # SHAP Waterfall Plot
    st.write(f"### 🌊 SHAP Waterfall Plot for Example {example_idx}")
    st.write("Visualize the SHAP values for a single instance of interest.")
    fig, ax = plt.subplots(figsize=(10, 5))
    shap.plots.waterfall(shap_values_exp[example_idx], show=False)
    st.pyplot(fig)


def main():
    # Train the model (or fetch the cached one) along with the data splits
    model, X_train, X_test, y_train, y_test = train_model()

    # Streamlit UI Controls
    baseline_number = st.sidebar.number_input(
        label="Select the number of baseline samples:",
        min_value=20,
        max_value=1000,
        value=100,  # Default value
        step=1,
    )

    example_idx = st.sidebar.number_input(
        label="Select the sample index to explain:",
        min_value=0,
        max_value=len(X_test) - 1,  # Keeps the index within range
        value=100,  # Default value
        step=1,  # Step size
        help=prompt_params.EXAMPLE_BE_EXPLAINED_IDX,
    )

    st.title("SHAP")
    st.write(prompt_params.SHAP_INTRODUCTION)

    # Explain the selected sample
    if st.button("Explain Sample"):
        explain_example(model, X_train, X_test, baseline_number, example_idx)


if __name__ == "__main__":
    main()
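
# Usage note (a sketch, assuming this script is saved as app.py inside a
# project that provides the src.prompt_config module imported above):
#
#   streamlit run app.py
#
# Streamlit re-runs this file top to bottom on every widget interaction,
# which is why train_model() is wrapped in st.cache_resource above instead
# of relying on module-level globals to remember the trained model.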