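"""Streamlit demo: train an XGBoost classifier on the SHAP adult census
dataset and render SHAP explanation plots (beeswarm summary, feature
importance bar, dependence, waterfall) for a user-selected sample."""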
# Import Libraries
import matplotlib.pyplot as plt
import streamlit as st
import src.prompt_config as prompt_params
# Models
import xgboost
from sklearn.model_selection import train_test_split
# XAI (Explainability)
import shap
# Global Variables to Store Model & Data
global_model = None
X_train, X_test, y_train, y_test = None, None, None, None
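
# NOTE: Streamlit re-executes this script from the top on every interaction,
# so these module-level globals are re-initialized on each rerun; the
# "train once" guard below only avoids retraining within a single run.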


def train_model():
    """Train the XGBoost model only once and store it globally."""
    global global_model, X_train, X_test, y_train, y_test
    if global_model is None:
        # Load data from the SHAP library
        X, y = shap.datasets.adult()
        # Split data
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=42)
        # Train XGBoost model
        global_model = xgboost.XGBClassifier()
        global_model.fit(X_train, y_train)
        print("XGBoost Model training completed!")


def explain_example(baseline_number, example_idx):
    """Explain a given sample without retraining the model."""
    global global_model, X_train, X_test, y_train, y_test
    if global_model is None:
        train_model()
    X, y = shap.datasets.adult()
    X_base = shap.utils.sample(X, baseline_number)
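    # The sampled background set stands in for the data distribution: SHAP
    # values measure how far each feature pushes a prediction away from the
    # model's average output over this baseline, so larger samples give more
    # stable attributions at higher cost.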
    explainer = shap.TreeExplainer(global_model, X_base)  # TreeExplainer with a background distribution
    shap_values = explainer.shap_values(X_test)  # Raw SHAP value array (used by the summary and dependence plots)
    shap_values_exp = explainer(X_test)  # Explanation object for X_test (used by the waterfall plot)
    # SHAP Summary Plot (Beeswarm)
    st.write("### 📊 SHAP Summary Plot")
    st.write("This plot provides an intuitive view of how different features contribute to individual predictions, making the model's behavior easier to interpret.")
    fig, ax = plt.subplots(figsize=(10, 5))
    shap.summary_plot(shap_values, X_test, show=False)
    st.pyplot(fig)
    # SHAP Summary Bar Plot
    st.write("### 📊 SHAP Feature Importance (Bar Plot)")
    st.write("Mean absolute SHAP values show which features the model relies on most overall.")
    fig, ax = plt.subplots(figsize=(10, 5))
    shap.summary_plot(shap_values, X_test, plot_type="bar", show=False)
    st.pyplot(fig)
    # SHAP Dependence Plot
    st.write("### 🔍 SHAP Dependence Plot for 'Age'")
    fig, ax = plt.subplots(figsize=(10, 5))
    shap.dependence_plot('Age', shap_values, X_test, ax=ax, show=False)
    st.pyplot(fig)
    # SHAP Waterfall Plot
    st.write(f"### 🌊 SHAP Waterfall Plot for Example {example_idx}")
    st.write("Visualizes how each feature pushes this single prediction away from the model's expected output.")
    fig, ax = plt.subplots(figsize=(10, 5))
    shap.plots.waterfall(shap_values_exp[example_idx], show=False)
    st.pyplot(fig)
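    # Note: matplotlib keeps figures open until closed explicitly; a
    # plt.close("all") here would prevent figures accumulating across reruns.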


def main():
    global global_model
    # Ensure the model is trained only once
    if global_model is None:
        train_model()
    # Streamlit UI controls
    baseline_number = st.sidebar.number_input(
        label="Select the number of baseline samples:",
        min_value=20,
        max_value=1000,
        value=100,  # Default value
        step=1,
    )
    example_idx = st.sidebar.number_input(
        label="Select the sample index to explain:",
        min_value=0,
        max_value=len(X_test) - 1,  # Ensures the index stays within range
        value=100,  # Default value
        step=1,  # Step size
        help=prompt_params.EXAMPLE_BE_EXPLAINED_IDX,
    )
    st.title("SHAP")
    st.write(prompt_params.SHAP_INTRODUCTION)
    # Explain the selected sample
    if st.button("Explain Sample"):
        explain_example(baseline_number, example_idx)


if __name__ == '__main__':
    main()
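
# To run the app locally (the filename here is illustrative):
#   streamlit run shap_app.py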