Spaces:
Sleeping
Sleeping
# Import Libraries | |
import matplotlib.pyplot as plt | |
import streamlit as st | |
import src.prompt_config as prompt_params | |
# Models | |
import xgboost | |
from sklearn.model_selection import train_test_split | |
from alepython import ale_plot | |
# XAI (Explainability) | |
import shap | |
from sklearn.inspection import permutation_importance | |
# Module-level slots for the trained classifier and its train/test split.
# They start empty and are filled in by train_model().
global_model = None
X_train = X_test = y_train = y_test = None
@st.cache_resource
def _fit_adult_model():
    """Load the SHAP ``adult`` dataset, split it 50/50 and fit an XGBoost classifier.

    Decorated with ``st.cache_resource`` because Streamlit re-executes the
    script on every widget interaction. When this file is the script being
    run, that also re-runs the module-level ``global_model = None``
    assignment, so a module global alone cannot keep the model between
    reruns. The cache makes the expensive fit happen once per server process.

    Returns:
        tuple: ``(model, X_train, X_test, y_train, y_test)``.
    """
    X, y = shap.datasets.adult()
    # Fixed random_state keeps the split (and therefore the plots) reproducible.
    X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.5, random_state=42)
    model = xgboost.XGBClassifier()
    model.fit(X_tr, y_tr)
    print("XGBoost Model training completed!")
    return model, X_tr, X_te, y_tr, y_te


def train_model():
    """Populate the module-level model and train/test split, training at most once.

    If ``global_model`` is already set this is a no-op. Otherwise the model
    and split are taken from the Streamlit resource cache. The cache only
    trains on its first call; later reruns reuse the cached objects.
    """
    global global_model, X_train, X_test, y_train, y_test
    if global_model is None:
        global_model, X_train, X_test, y_train, y_test = _fit_adult_model()
def _show_figure(fig):
    """Lay out *fig*, render it on the Streamlit page, then close it.

    pyplot keeps every figure created via ``plt.subplots`` alive until it is
    explicitly closed. Without the close, each click of "Explain Sample"
    would leak two figures.
    """
    fig.tight_layout()
    st.pyplot(fig)
    plt.close(fig)


def explain_example():
    """Compute permutation feature importance on the test split and plot it.

    Trains the model first if it is not available yet (see ``train_model``).
    Each column of ``X_test`` is shuffled 30 times (``random_state=0``). Two
    horizontal charts are rendered, ordered by mean importance:
    - a bar chart of the mean importance;
    - a box plot of the per-repeat importances.
    """
    if global_model is None:
        train_model()

    st.write("Permutation Importance Plots")
    perm_imp = permutation_importance(global_model, X_test, y_test,
                                      n_repeats=30,
                                      random_state=0)
    # Ascending order: barh and horizontal boxplot both draw the last entry
    # at the top, so the most important feature ends up on top.
    sorted_idx = perm_imp.importances_mean.argsort()
    feature_names = X_test.columns[sorted_idx]

    fig, ax = plt.subplots(figsize=(10, 5))
    ax.barh(feature_names, perm_imp.importances_mean[sorted_idx])
    ax.set_title("Permutation Importances")
    _show_figure(fig)

    fig, ax = plt.subplots(figsize=(10, 5))
    # One box per feature, built from its n_repeats importance samples.
    # NOTE: matplotlib >= 3.9 renames boxplot's ``labels`` to ``tick_labels``;
    # ``labels`` still works there but emits a deprecation warning.
    ax.boxplot(perm_imp.importances[sorted_idx].T,
               vert=False, labels=feature_names)
    ax.set_title("Permutation Importances")
    _show_figure(fig)
def main():
    """Render the Permutation Feature Importance page.

    Makes sure a trained model is available, shows the page title and the
    introductory text, and runs the explanation when the user clicks the
    "Explain Sample" button.
    """
    # train_model() checks for an existing model itself, so this is cheap
    # once the model has been loaded.
    train_model()

    st.title("Permutation Feature Importance")
    st.write(prompt_params.PERMUTATION_INTRODUCTION)

    clicked = st.button("Explain Sample")
    if clicked:
        explain_example()
# Entry point: render the page when this file is executed as the main script.
if __name__ == '__main__':
    main()