# Import Libraries
import matplotlib.pyplot as plt
import streamlit as st
import src.prompt_config as prompt_params

# Models
import xgboost
from sklearn.model_selection import train_test_split
from sklearn.inspection import PartialDependenceDisplay

# XAI (Explainability)
import shap

# Global Variables to Store Model & Data
# NOTE(review): Streamlit re-executes the script on user interaction, so these
# module-level globals may not survive a rerun -- confirm, and consider
# Streamlit's resource caching if retraining on every interaction is observed.
global_model = None
X_train, X_test, y_train, y_test = None, None, None, None


def train_model():
    """Train the XGBoost classifier once and store it in module globals.

    Loads the adult dataset bundled with SHAP, splits it 50/50 into train and
    test partitions (``random_state=42``), fits an ``xgboost.XGBClassifier``
    on the training half, and stores the model and all four splits in the
    module-level globals. Does nothing if ``global_model`` is already set.
    """
    global global_model, X_train, X_test, y_train, y_test

    if global_model is not None:
        return

    # Load Data from SHAP library
    X, y = shap.datasets.adult()

    # Split data
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.5, random_state=42
    )

    # Train XGBoost model
    global_model = xgboost.XGBClassifier()
    global_model.fit(X_train, y_train)
    print("XGBoost Model training completed!")


def explain_example(features, kind):
    """Render a PDP/ICE plot for ``features`` in the Streamlit app.

    The model is trained first if it has not been trained yet.

    Args:
        features: A single feature name/index, or a list of features as
            accepted by ``PartialDependenceDisplay.from_estimator``.
        kind: Plot type forwarded to scikit-learn: ``"average"`` (PDP),
            ``"individual"`` (ICE) or ``"both"``.
    """
    if global_model is None:
        train_model()

    # scikit-learn expects a list of features; a bare string would not be
    # treated as a single feature name, so wrap single names/indices.
    if isinstance(features, (str, int)):
        features = [features]

    fig, ax = plt.subplots(figsize=(10, 5))
    # Draw onto the axes we created; without ``ax`` scikit-learn plots on a
    # separate figure and ``fig`` would be displayed empty.
    PartialDependenceDisplay.from_estimator(
        global_model, X_test, features, kind=kind, ax=ax
    )
    st.pyplot(fig)
    # Release the figure so repeated reruns do not accumulate open figures.
    plt.close(fig)


def main():
    """Build the Streamlit page: sidebar controls, intro text and the plot."""
    # Ensure the model is trained only once
    if global_model is None:
        train_model()

    # Define feature names
    selected_feature = st.sidebar.selectbox(
        "Select a feature for PDP/ICE analysis:",
        (
            "Age",
            "Workclass",
            "Education-Num",
            "Marital Status",
            "Occupation",
            "Relationship",
            "Race",
            "Sex",
            "Capital Gain",
            "Capital Loss",
            "Hours per week",
            "Country",
        ),
    )
    print(f"selected feature is {selected_feature}")

    kind = st.sidebar.selectbox(
        "Select plot type:",
        ("individual", "average", "both"),
    )

    st.title("ICE and PDP")
    st.write(prompt_params.ICE_INTRODUCTION)

    # Explain the selected sample
    if st.button("Explain Sample"):
        # Use the feature picked in the sidebar rather than a fixed one.
        explain_example(selected_feature, kind)


if __name__ == '__main__':
    main()