Spaces:
Sleeping
Sleeping
# Import Libraries | |
import pandas as pd | |
import streamlit as st | |
import src.prompt_config as prompt_params | |
# Models | |
import xgboost | |
from sklearn.model_selection import train_test_split | |
# XAI (Explainability) | |
import shap | |
import lime | |
from anchor import anchor_tabular | |
# Module-level slots for the fitted classifier and the train/test split.
# They start empty and are filled in by train_model().
global_model = None
X_train = X_test = y_train = y_test = None
@st.cache_resource
def _fit_adult_model():
    """Load the Adult dataset, split it, and fit an XGBoost classifier.

    Wrapped in ``st.cache_resource`` so the work is done once per server
    process: Streamlit re-executes the script on each interaction, which
    resets plain module-level globals and would otherwise retrain every time.

    Returns:
        ``(model, splits)`` where ``splits`` is the sequence
        ``[X_train, X_test, y_train, y_test]`` from ``train_test_split``.
    """
    # Load Data from SHAP library
    X, y = shap.datasets.adult()

    # 50/50 split with a fixed seed so sample indices stay stable across runs
    splits = train_test_split(X, y, test_size=0.5, random_state=42)

    model = xgboost.XGBClassifier()
    model.fit(splits[0], splits[2])
    print("XGBoost Model training completed!")
    return model, splits


def train_model():
    """Populate the module globals with a fitted model and its data split.

    Does nothing when ``global_model`` is already set. Otherwise it fetches
    the (cached) fitted classifier and split from ``_fit_adult_model`` and
    stores them in ``global_model``, ``X_train``, ``X_test``, ``y_train`` and
    ``y_test``.
    """
    global global_model, X_train, X_test, y_train, y_test

    if global_model is None:
        model, (X_train, X_test, y_train, y_test) = _fit_adult_model()
        # Assign the model last so a failure above never leaves a model
        # published without its matching data split.
        global_model = model
def define_features():
    """Return the column metadata used to configure the Anchors explainer.

    Returns:
        A 4-tuple ``(feature_names, categorical_features, class_names,
        categorical_names)``:

        * ``feature_names`` - list of the 12 column names, in column order.
        * ``categorical_features`` - list of the names of the categorical
          columns.
        * ``class_names`` - list with the two target labels.
        * ``categorical_names`` - dict mapping a categorical column's
          position in ``feature_names`` to its list of category labels.
    """
    columns = [
        "Age", "Workclass", "Education-Num", "Marital Status", "Occupation",
        "Relationship", "Race", "Sex", "Capital Gain", "Capital Loss",
        "Hours per week", "Country",
    ]
    labels = ['<=50K', '>50K']

    # Category labels keyed by column name; insertion order is column order.
    # NOTE(review): the label order is presumably meant to match the integer
    # codes used in the dataset - verify against the data loader.
    levels = {
        "Workclass": [
            'Private', 'Self-emp-not-inc', 'Self-emp-inc', 'Federal-gov',
            'Local-gov', 'State-gov', 'Without-pay', 'Never-worked',
        ],
        "Marital Status": [
            'Married-civ-spouse', 'Divorced', 'Never-married', 'Separated',
            'Widowed', 'Married-spouse-absent', 'Married-AF-spouse',
        ],
        "Occupation": [
            'Tech-support', 'Craft-repair', 'Other-service', 'Sales',
            'Exec-managerial', 'Prof-specialty', 'Handlers-cleaners',
            'Machine-op-inspct', 'Adm-clerical', 'Farming-fishing',
            'Transport-moving', 'Priv-house-serv', 'Protective-serv',
            'Armed-Forces',
        ],
        "Relationship": [
            'Wife', 'Own-child', 'Husband', 'Not-in-family', 'Other-relative',
            'Unmarried',
        ],
        "Race": [
            'White', 'Asian-Pac-Islander', 'Amer-Indian-Eskimo', 'Other',
            'Black',
        ],
        "Sex": ['Female', 'Male'],
        "Country": [
            'United-States', 'Cambodia', 'England', 'Puerto-Rico', 'Canada',
            'Germany', 'Outlying-US(Guam-USVI-etc)', 'India', 'Japan',
            'Greece', 'South', 'China', 'Cuba', 'Iran', 'Honduras',
            'Philippines', 'Italy', 'Poland', 'Jamaica', 'Vietnam', 'Mexico',
            'Portugal', 'Ireland', 'France', 'Dominican-Republic', 'Laos',
            'Ecuador', 'Taiwan', 'Haiti', 'Columbia', 'Hungary', 'Guatemala',
            'Nicaragua', 'Scotland', 'Thailand', 'Yugoslavia', 'El-Salvador',
            'Trinadad&Tobago', 'Peru', 'Hong', 'Holand-Netherlands',
        ],
    }

    categorical = list(levels)
    # Re-key the labels by each column's position in the feature list.
    levels_by_position = {
        columns.index(column): values for column, values in levels.items()
    }
    return columns, categorical, labels, levels_by_position
def explain_example(anchors_threshold, example_idx):
    """Render the model's prediction and its Anchors rules for one test row.

    Args:
        anchors_threshold: Passed to ``explain_instance`` as ``threshold``.
        example_idx: Row position in ``X_test`` of the sample to explain.

    Trains the model first if ``global_model`` is still unset, then writes
    the predicted class name and a table of anchor rules (each row repeating
    the anchor's precision and coverage) to the Streamlit page.
    """
    if global_model is None:
        train_model()

    feature_names, _unused_categorical, class_names, categorical_names = define_features()

    # Build the Anchors explainer on the training data
    anchors = anchor_tabular.AnchorTabularExplainer(
        class_names, feature_names, X_train.values, categorical_names
    )

    # Explain the requested test row
    sample = X_test.values[example_idx]
    result = anchors.explain_instance(
        sample, global_model.predict, threshold=anchors_threshold
    )

    # One table row per rule; precision and coverage belong to the whole
    # anchor, so the same formatted value is repeated on every row.
    rules = result.names()
    precision_text = f"{result.precision():.2f}"
    coverage_text = f"{result.coverage():.2f}"
    rules_table = pd.DataFrame(
        {
            "Feature Rule": rules,
            "Precision": [precision_text for _ in rules],
            "Coverage": [coverage_text for _ in rules],
        }
    )

    predicted_code = global_model.predict(sample.reshape(1, -1))[0]
    predicted_label = anchors.class_names[predicted_code]

    st.write(f"### π The prediction of No. {example_idx} example is: **{predicted_label}**")
    st.write("### π Explanation Rules for this example:")
    st.table(rules_table)
def main():
    """Build the Streamlit page: sidebar controls, intro text, explain button.

    Makes sure the model is available, reads the anchor threshold and the
    sample index from the sidebar, and calls ``explain_example`` when the
    "Explain Sample" button is pressed.
    """
    # Train up front: the index widget below needs len(X_test) for its bound
    if global_model is None:
        train_model()

    sidebar = st.sidebar

    threshold = sidebar.slider(
        label="Set the `Threshold` value:",
        min_value=0.00,
        max_value=1.00,
        value=0.8,
        step=0.01,
        help=prompt_params.ANCHORS_THRESHOLD_HELP,
    )

    last_index = len(X_test) - 1  # keeps the selectable index within range
    sample_index = sidebar.number_input(
        label="Select the sample index to explain:",
        min_value=0,
        max_value=last_index,
        value=100,
        step=1,
        help=prompt_params.EXAMPLE_BE_EXPLAINED_IDX,
    )

    st.title("Anchors")
    st.write(prompt_params.ANCHORS_INTRODUCTION)

    # Nothing more to render until the user asks for an explanation
    if not st.button("Explain Sample"):
        return
    explain_example(threshold, sample_index)
# Build the page only when this file is executed as a script
if __name__ == "__main__":
    main()