# Hugging Face Spaces status text ("Spaces: Sleeping") — page-scrape artifact, not code.
# Import Libraries
import matplotlib.pyplot as plt
import streamlit as st

import src.prompt_config as prompt_params

# Models
import xgboost
from sklearn.model_selection import train_test_split

# XAI (Explainability)
import shap

# Module-level cache for the trained model and the train/test split, so the
# Streamlit script does not retrain on every rerun of the page.
global_model = None
X_train, X_test, y_train, y_test = None, None, None, None
def train_model():
    """Train the XGBoost model once and cache it in module-level globals.

    Loads the Adult census dataset bundled with SHAP, splits it into equal
    train/test halves, fits an ``XGBClassifier``, and stores the model and
    the split in the module globals. Subsequent calls are no-ops while
    ``global_model`` is already set.
    """
    global global_model, X_train, X_test, y_train, y_test
    if global_model is None:
        # Load data from the SHAP library's bundled datasets.
        X, y = shap.datasets.adult()
        # Split data; test_size=0.5 keeps a large pool for SHAP explanations.
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.5, random_state=42
        )
        # Train XGBoost model with default hyperparameters.
        global_model = xgboost.XGBClassifier()
        global_model.fit(X_train, y_train)
        print("XGBoost Model training completed!")
def explain_example(baseline_number, example_idx):
    """Render SHAP explanation plots for the cached model in Streamlit.

    Does not retrain the model; trains it first only if it is missing.

    Args:
        baseline_number: Number of rows sampled from the dataset to serve
            as the SHAP background distribution.
        example_idx: Index into ``X_test`` of the instance shown in the
            waterfall plot.
    """
    global global_model, X_train, X_test, y_train, y_test
    if global_model is None:
        train_model()
    X, y = shap.datasets.adult()
    X_base = shap.utils.sample(X, baseline_number)
    # Use the TreeExplainer algorithm with an explicit background distribution.
    explainer = shap.TreeExplainer(global_model, X_base)
    shap_values = explainer.shap_values(X_test)  # Raw SHAP value array
    shap_values_exp = explainer(X_test)          # Explanation object for X_test

    # SHAP Summary Plot (BeeSwarm)
    st.write("### π SHAP Summary Plot")
    st.write("This plot provides an intuitive way to see how different features contribute to individual predictions, making model interpretations easier!")
    fig, ax = plt.subplots(figsize=(10, 5))
    shap.summary_plot(shap_values, X_test, show=False)
    st.pyplot(fig)

    # SHAP Summary Bar Plot
    st.write("### π SHAP Feature Importance (Bar Plot)")
    st.write("It helps understand which features the model relies on most.")
    fig, ax = plt.subplots(figsize=(10, 5))
    shap.summary_plot(shap_values, X_test, plot_type="bar", show=False)
    st.pyplot(fig)

    # SHAP Dependence Plot
    st.write("### π SHAP Dependence Plot for 'Age'")
    fig, ax = plt.subplots(figsize=(10, 5))
    shap.dependence_plot('Age', shap_values, X_test, ax=ax, show=False)
    st.pyplot(fig)

    # SHAP Waterfall Plot for the selected instance
    st.write(f"### π SHAP Waterfall Plot for Example {example_idx}")
    st.write(f"Visualize the SHAP values for an instance of interest")
    fig, ax = plt.subplots(figsize=(10, 5))
    shap.plots.waterfall(shap_values_exp[example_idx], show=False)
    st.pyplot(fig)
def main():
    """Streamlit entry point: sidebar controls plus on-demand SHAP plots."""
    global global_model
    # Ensure the model is trained only once per process.
    if global_model is None:
        train_model()
    # Streamlit UI Controls
    baseline_number = st.sidebar.number_input(
        label="Select the number of baseline:",
        min_value=20,
        max_value=1000,
        value=100,  # Default value
        step=1,
    )
    example_idx = st.sidebar.number_input(
        label="Select the sample index to explain:",
        min_value=0,
        max_value=len(X_test) - 1,  # Ensures the index is within range
        value=100,  # Default value
        step=1,  # Step size
        help=prompt_params.EXAMPLE_BE_EXPLAINED_IDX,
    )
    st.title("SHAP")
    st.write(prompt_params.SHAP_INTRODUCTION)
    # Explain the selected sample only when the user asks for it.
    if st.button("Explain Sample"):
        explain_example(baseline_number, example_idx)
if __name__ == '__main__':
    main()