# Import Libraries
import pandas as pd
import streamlit as st

import src.prompt_config as prompt_params

# Models
import xgboost
from sklearn.model_selection import train_test_split

# XAI (Explainability)
import shap
import lime
from anchor import anchor_tabular

# Global Variables to Store Model & Data
global_model = None
X_train, X_test, y_train, y_test = None, None, None, None


@st.cache_resource
def _load_data_and_train():
    """Load the SHAP Adult dataset, split it, and fit an XGBoost classifier.

    `streamlit run` re-executes this script from scratch on every widget
    interaction, so module-level globals alone do not survive between reruns.
    `st.cache_resource` keeps the returned objects alive across reruns, so
    the expensive fit happens once per server process instead of on every
    click.

    Returns:
        Tuple ``(model, X_train, X_test, y_train, y_test)``.
    """
    # Load Data from SHAP library
    X, y = shap.datasets.adult()

    # 50/50 split with a fixed seed so sample indices are reproducible.
    x_tr, x_te, y_tr, y_te = train_test_split(X, y, test_size=0.5, random_state=42)

    # Train XGBoost model
    model = xgboost.XGBClassifier()
    model.fit(x_tr, y_tr)
    print("XGBoost Model training completed!")
    return model, x_tr, x_te, y_tr, y_te


def train_model():
    """Populate the module globals with the trained model and data splits.

    Safe to call repeatedly: the data loading / fitting is cached by
    `_load_data_and_train`, and the globals are only filled while empty.
    """
    global global_model, X_train, X_test, y_train, y_test
    if global_model is None:
        global_model, X_train, X_test, y_train, y_test = _load_data_and_train()


def define_features():
    """Define feature names and categorical mappings.

    Returns:
        Tuple ``(feature_names, categorical_features, class_names,
        categorical_names)`` where ``categorical_names`` maps a column index
        in ``feature_names`` to the list of labels for that column's integer
        codes.
    """
    feature_names = ["Age", "Workclass", "Education-Num", "Marital Status", "Occupation",
                     "Relationship", "Race", "Sex", "Capital Gain", "Capital Loss",
                     "Hours per week", "Country"]
    # NOTE(review): not passed to the Anchors explainer below; kept for callers.
    categorical_features = ["Workclass", "Marital Status", "Occupation", "Relationship",
                            "Race", "Sex", "Country"]
    class_names = ['<=50K', '>50K']
    # NOTE(review): the explainer looks labels up as categorical_names[col][code],
    # so each list's order must match the integer codes produced by
    # shap.datasets.adult() (which may be alphabetical / custom, and may use -1
    # for missing values). The orderings below follow the UCI description —
    # confirm they line up with the dataset's codes.
    categorical_names = {
        1: ['Private', 'Self-emp-not-inc', 'Self-emp-inc', 'Federal-gov', 'Local-gov',
            'State-gov', 'Without-pay', 'Never-worked'],
        3: ['Married-civ-spouse', 'Divorced', 'Never-married', 'Separated', 'Widowed',
            'Married-spouse-absent', 'Married-AF-spouse'],
        4: ['Tech-support', 'Craft-repair', 'Other-service', 'Sales', 'Exec-managerial',
            'Prof-specialty', 'Handlers-cleaners', 'Machine-op-inspct', 'Adm-clerical',
            'Farming-fishing', 'Transport-moving', 'Priv-house-serv', 'Protective-serv',
            'Armed-Forces'],
        5: ['Wife', 'Own-child', 'Husband', 'Not-in-family', 'Other-relative', 'Unmarried'],
        6: ['White', 'Asian-Pac-Islander', 'Amer-Indian-Eskimo', 'Other', 'Black'],
        7: ['Female', 'Male'],
        11: ['United-States', 'Cambodia', 'England', 'Puerto-Rico', 'Canada', 'Germany',
             'Outlying-US(Guam-USVI-etc)', 'India', 'Japan', 'Greece', 'South', 'China',
             'Cuba', 'Iran', 'Honduras', 'Philippines', 'Italy', 'Poland', 'Jamaica',
             'Vietnam', 'Mexico', 'Portugal', 'Ireland', 'France', 'Dominican-Republic',
             'Laos', 'Ecuador', 'Taiwan', 'Haiti', 'Columbia', 'Hungary', 'Guatemala',
             'Nicaragua', 'Scotland', 'Thailand', 'Yugoslavia', 'El-Salvador',
             'Trinadad&Tobago', 'Peru', 'Hong', 'Holand-Netherlands'],
    }
    return feature_names, categorical_features, class_names, categorical_names


def explain_example(anchors_threshold, example_idx):
    """Explain a given sample without retraining the model.

    Builds an Anchors tabular explainer on the training data, explains the
    selected test row, and renders the prediction plus the anchor rules
    (with precision and coverage) in the Streamlit page.

    Args:
        anchors_threshold: Precision threshold passed to
            ``explain_instance(threshold=...)``.
        example_idx: Positional row index into ``X_test``.
    """
    if global_model is None:
        train_model()

    feature_names, _categorical_features, class_names, categorical_names = define_features()

    # Initialize Anchors explainer
    explainer = anchor_tabular.AnchorTabularExplainer(
        class_names, feature_names, X_train.values, categorical_names)

    # Explain the selected sample (positional lookup, same row used below).
    instance = X_test.values[example_idx]
    exp = explainer.explain_instance(instance, global_model.predict,
                                     threshold=anchors_threshold)

    rules = exp.names()
    precision = f"{exp.precision():.2f}"
    coverage = f"{exp.coverage():.2f}"
    # Precision/coverage describe the anchor as a whole, so they repeat per rule row.
    df_explanation = pd.DataFrame({
        "Feature Rule": rules,
        "Precision": [precision] * len(rules),
        "Coverage": [coverage] * len(rules),
    })

    predicted_label = int(global_model.predict(instance.reshape(1, -1))[0])
    prediction = explainer.class_names[predicted_label]

    st.write(f"### 🔍 The prediction of No. {example_idx} example is: **{prediction}**")
    st.write("### 📌 Explanation Rules for this example:")
    if rules:
        st.table(df_explanation)
    else:
        # An empty anchor has no feature conditions; a blank table would be confusing.
        st.write(f"The anchor for this example is empty (no feature conditions) — "
                 f"precision {precision}, coverage {coverage}.")


def main():
    """Render the Anchors page: sidebar controls, intro text, and explain button."""
    # Ensure the model is trained (cached across Streamlit reruns).
    if global_model is None:
        train_model()

    anchors_threshold = st.sidebar.slider(
        label="Set the `Threshold` value:",
        min_value=0.00,
        max_value=1.00,
        value=0.8,  # Default value
        step=0.01,  # Step size
        help=prompt_params.ANCHORS_THRESHOLD_HELP,
    )

    example_idx = st.sidebar.number_input(
        label="Select the sample index to explain:",
        min_value=0,
        max_value=len(X_test) - 1,  # Ensures the index is within range
        value=100,  # Default value
        step=1,  # Step size
        help=prompt_params.EXAMPLE_BE_EXPLAINED_IDX,
    )

    st.title("Anchors")
    st.write(prompt_params.ANCHORS_INTRODUCTION)

    # Explain the selected sample
    if st.button("Explain Sample"):
        explain_example(anchors_threshold, example_idx)


if __name__ == '__main__':
    main()