Spaces:
Sleeping
Sleeping
# Import Libraries | |
import matplotlib.pyplot as plt | |
import streamlit as st | |
import src.prompt_config as prompt_params | |
# Models | |
import xgboost | |
from sklearn.model_selection import train_test_split | |
from alepython import ale_plot | |
# XAI (Explainability) | |
import shap | |
# Module-level slots for the fitted model and its train/test split.
# They start empty and are filled in by train_model().
global_model = None
X_train = X_test = y_train = y_test = None
@st.cache_resource
def _fit_adult_model():
    """Load the SHAP "adult" dataset, split it 50/50 and fit an XGBClassifier.

    Cached with ``st.cache_resource`` so the fit runs once per server process
    rather than once per Streamlit script rerun.

    Returns:
        Tuple ``(model, X_train, X_test, y_train, y_test)``.
    """
    # Load Data from SHAP library
    X, y = shap.datasets.adult()
    # Fixed random_state keeps the split reproducible between runs.
    X_tr, X_te, y_tr, y_te = train_test_split(
        X, y, test_size=0.5, random_state=42
    )
    model = xgboost.XGBClassifier()
    model.fit(X_tr, y_tr)
    print("XGBoost Model training completed!")
    return model, X_tr, X_te, y_tr, y_te


def train_model():
    """Train the XGBoost model only once and store it globally.

    Populates the module globals ``global_model``, ``X_train``, ``X_test``,
    ``y_train`` and ``y_test``. Returns immediately if ``global_model`` is
    already set.
    """
    global global_model, X_train, X_test, y_train, y_test
    if global_model is not None:
        return
    # Streamlit re-executes this script on every widget interaction, which
    # re-runs the module-level ``global_model = None``. The global alone
    # therefore cannot prevent retraining; the cached helper keeps the fitted
    # model (and the split it was trained on) across reruns.
    global_model, X_train, X_test, y_train, y_test = _fit_adult_model()
def explain_example():
    """Render a 2D second-order ALE plot for the first two feature columns.

    Trains the model via train_model() only if it has not been fitted yet;
    an already-fitted model is reused as is. The ALE is computed on
    ``X_test``.
    """
    if global_model is None:
        train_model()

    st.write("2D Second-Order ALE Plot")
    # Pass plain labels (a tuple of column names) rather than a pandas Index,
    # matching ale_plot's documented "string or tuple of strings" form.
    features = tuple(X_train.columns[:2])
    # alepython's ale_plot creates its own figure instead of drawing on one
    # made beforehand, so render the figure it actually drew into (the
    # current figure) instead of a pre-created, empty one.
    ale_plot(global_model, X_test, features, bins=10)
    fig = plt.gcf()
    st.pyplot(fig)
    # Release the figure; otherwise every click leaves another open figure
    # in pyplot's global figure registry.
    plt.close(fig)
def main():
    """Build the ALE page: title, introduction, and an on-demand plot button."""
    # Fit before rendering anything so explain_example() finds a model ready.
    if global_model is None:
        train_model()

    st.title("ALE (Accumulated Local Effects)")
    st.write(prompt_params.ALE_INTRODUCTION)
    st.write("now has bug, waiting for fix")

    # The plot is only produced when the user presses the button.
    explain_requested = st.button("Explain Sample")
    if explain_requested:
        explain_example()
# Render the page only when this file is executed as the main script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()