Spaces:
Sleeping
Sleeping
# Import Libraries
import matplotlib.pyplot as plt
import streamlit as st

# Models
import xgboost
from sklearn.inspection import PartialDependenceDisplay
from sklearn.model_selection import train_test_split

# XAI (Explainability)
import shap

# Local
import src.prompt_config as prompt_params
# Module-level slots for the trained model and its train/test split.
# All start empty and are filled in by train_model().
global_model = None
X_train = None
X_test = None
y_train = None
y_test = None
@st.cache_resource
def _fit_adult_model():
    """Load the SHAP "adult" data, split it 50/50, and fit an XGBoost classifier.

    Wrapped in ``st.cache_resource`` because Streamlit re-executes the whole
    script on every widget interaction, which resets module-level globals.
    The cache keeps the fitted model across reruns, so training happens once
    per server process rather than on every click.

    Returns:
        tuple: ``(model, X_train, X_test, y_train, y_test)``.
    """
    X, y = shap.datasets.adult()
    # Fixed seed so every rerun sees the same hold-out set.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.5, random_state=42
    )
    model = xgboost.XGBClassifier()
    model.fit(X_train, y_train)
    print("XGBoost Model training completed!")
    return model, X_train, X_test, y_train, y_test


def train_model():
    """Train the XGBoost model only once and store it (and its data split) globally.

    Populates the module globals ``global_model``, ``X_train``, ``X_test``,
    ``y_train`` and ``y_test``. Does nothing if ``global_model`` is already set.
    The actual fitting is cached by ``_fit_adult_model``, so a fresh Streamlit
    rerun reuses the existing model instead of retraining.
    """
    global global_model, X_train, X_test, y_train, y_test
    if global_model is None:
        global_model, X_train, X_test, y_train, y_test = _fit_adult_model()
def explain_example(features, kind):
    """Render a PDP/ICE plot for the given feature(s) using the trained model.

    Trains the model first if it has not been trained yet; never retrains.

    Args:
        features: A single feature name/index, or a list of them, as accepted
            by ``PartialDependenceDisplay.from_estimator``. The values are
            looked up in the columns of the global ``X_test``.
        kind: ``"individual"`` (ICE), ``"average"`` (PDP) or ``"both"``.
    """
    # Only reading the globals here, so no `global` statement is needed.
    if global_model is None:
        train_model()

    # from_estimator expects a list of features. A bare string would be
    # iterated as single characters ("Age" -> "A", "g", "e").
    if isinstance(features, (str, int)):
        features = [features]

    fig, ax = plt.subplots(figsize=(10, 5))
    # Draw onto our own axes. Without ax= sklearn creates a separate figure,
    # and the figure handed to st.pyplot would be empty.
    PartialDependenceDisplay.from_estimator(
        global_model, X_test, features, kind=kind, ax=ax
    )
    st.pyplot(fig)
    # Release the figure so repeated button clicks don't pile up open figures.
    plt.close(fig)
def main():
    """Streamlit entry point: sidebar controls plus an on-demand PDP/ICE plot."""
    # Ensure the model is trained; train_model() is a no-op once it has been.
    if global_model is None:
        train_model()

    # These names are passed straight to PartialDependenceDisplay, so they
    # must match the column names of X_test.
    selected_feature = st.sidebar.selectbox(
        "Select a feature for PDP/ICE analysis:",
        (
            "Age",
            "Workclass",
            "Education-Num",
            "Marital Status",
            "Occupation",
            "Relationship",
            "Race",
            "Sex",
            "Capital Gain",
            "Capital Loss",
            "Hours per week",
            "Country",
        ),
    )
    print(f"selected feature is {selected_feature}")
    kind = st.sidebar.selectbox("Select plot type:", ("individual", "average", "both"))

    st.title("ICE and PDP")
    st.write(prompt_params.ICE_INTRODUCTION)

    # Plot the feature the user picked in the sidebar.
    if st.button("Explain Sample"):
        explain_example(selected_feature, kind)
# Script entry point (`streamlit run` also executes the file as __main__).
if __name__ == "__main__":
    main()