# NOTE(review): the lines originally here were Hugging Face Spaces page residue
# captured during extraction (status text "Sleeping", file size, commit hashes
# e1da86b / fe54272 / ec3fab5 / 5963f5d, and a rendered line-number gutter).
# They are not part of the program source and have been neutralized into this
# comment so the module is importable.
# Import Libraries
import pandas as pd
import streamlit as st
import src.prompt_config as prompt_params
# Models
import xgboost
from sklearn.model_selection import train_test_split
# XAI (Explainability)
import shap
import lime
from anchor import anchor_tabular
# Global Variables to Store Model & Data
# Lazily-initialized cache shared across Streamlit reruns: train_model() fills
# these exactly once; every other function only reads them.
global_model = None
# Train/test split of shap.datasets.adult(); populated by train_model().
X_train, X_test, y_train, y_test = None, None, None, None
def train_model():
    """Fit the shared XGBoost classifier exactly once and cache it globally.

    Loads the SHAP adult-income dataset, splits it 50/50 into train/test
    (fixed seed for reproducibility), and stores the fitted model plus the
    split in the module-level globals. Subsequent calls are no-ops.
    """
    global global_model, X_train, X_test, y_train, y_test
    # Guard clause: the model (and the data split) is trained only once.
    if global_model is not None:
        return
    X, y = shap.datasets.adult()
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.5, random_state=42
    )
    classifier = xgboost.XGBClassifier()
    classifier.fit(X_train, y_train)
    global_model = classifier
    print("XGBoost Model training completed!")
def define_features():
    """Describe the adult-income dataset's schema for the Anchors explainer.

    Returns:
        tuple: (feature_names, categorical_features, class_names,
        categorical_names) where ``categorical_names`` maps a feature's
        column index to the ordered list of its category labels.
    """
    names = [
        "Age",
        "Workclass",
        "Education-Num",
        "Marital Status",
        "Occupation",
        "Relationship",
        "Race",
        "Sex",
        "Capital Gain",
        "Capital Loss",
        "Hours per week",
        "Country",
    ]
    cat_columns = [
        "Workclass", "Marital Status", "Occupation",
        "Relationship", "Race", "Sex", "Country",
    ]
    labels = ['<=50K', '>50K']
    # Column index -> ordered category labels (order matches the encoded data).
    index_to_categories = {
        1: ['Private', 'Self-emp-not-inc', 'Self-emp-inc', 'Federal-gov',
            'Local-gov', 'State-gov', 'Without-pay', 'Never-worked'],
        3: ['Married-civ-spouse', 'Divorced', 'Never-married', 'Separated',
            'Widowed', 'Married-spouse-absent', 'Married-AF-spouse'],
        4: ['Tech-support', 'Craft-repair', 'Other-service', 'Sales',
            'Exec-managerial', 'Prof-specialty', 'Handlers-cleaners',
            'Machine-op-inspct', 'Adm-clerical', 'Farming-fishing',
            'Transport-moving', 'Priv-house-serv', 'Protective-serv',
            'Armed-Forces'],
        5: ['Wife', 'Own-child', 'Husband', 'Not-in-family',
            'Other-relative', 'Unmarried'],
        6: ['White', 'Asian-Pac-Islander', 'Amer-Indian-Eskimo', 'Other',
            'Black'],
        7: ['Female', 'Male'],
        11: ['United-States', 'Cambodia', 'England', 'Puerto-Rico', 'Canada',
             'Germany', 'Outlying-US(Guam-USVI-etc)', 'India', 'Japan',
             'Greece', 'South', 'China', 'Cuba', 'Iran', 'Honduras',
             'Philippines', 'Italy', 'Poland', 'Jamaica', 'Vietnam',
             'Mexico', 'Portugal', 'Ireland', 'France', 'Dominican-Republic',
             'Laos', 'Ecuador', 'Taiwan', 'Haiti', 'Columbia', 'Hungary',
             'Guatemala', 'Nicaragua', 'Scotland', 'Thailand', 'Yugoslavia',
             'El-Salvador', 'Trinadad&Tobago', 'Peru', 'Hong',
             'Holand-Netherlands'],
    }
    return names, cat_columns, labels, index_to_categories
def explain_example(anchors_threshold, example_idx):
    """Render an Anchors explanation for one test-set row in Streamlit.

    Args:
        anchors_threshold: Minimum precision the anchor rule must reach.
        example_idx: Row index into the cached test split.

    Trains the shared model first if it has not been fitted yet; never
    retrains an existing model.
    """
    global global_model, X_train, X_test, y_train, y_test
    if global_model is None:
        train_model()
    feature_names, categorical_features, class_names, categorical_names = define_features()
    # The explainer needs the raw training matrix plus the schema metadata.
    explainer = anchor_tabular.AnchorTabularExplainer(
        class_names,
        feature_names,
        X_train.values,
        categorical_names,
    )
    sample = X_test.values[example_idx]
    exp = explainer.explain_instance(
        sample, global_model.predict, threshold=anchors_threshold
    )
    rules = exp.names()
    # One table row per anchor rule; precision/coverage apply to the whole
    # anchor, so the same formatted value is repeated on every row.
    df_explanation = pd.DataFrame({
        "Feature Rule": rules,
        "Precision": [f"{exp.precision():.2f}"] * len(rules),
        "Coverage": [f"{exp.coverage():.2f}"] * len(rules),
    })
    predicted_label = explainer.class_names[
        global_model.predict(sample.reshape(1, -1))[0]
    ]
    st.write(f"### π The prediction of No. {example_idx} example is: **{predicted_label}**")
    st.write("### π Explanation Rules for this example:")
    st.table(df_explanation)
def main():
    """Streamlit entry point: sidebar controls plus the Anchors demo page."""
    global global_model
    # Train lazily on first page load; reruns reuse the cached model.
    if global_model is None:
        train_model()
    # Sidebar controls -------------------------------------------------------
    anchors_threshold = st.sidebar.slider(
        label="Set the `Threshold` value:",
        min_value=0.00,
        max_value=1.00,
        value=0.8,
        step=0.01,
        help=prompt_params.ANCHORS_THRESHOLD_HELP,
    )
    example_idx = st.sidebar.number_input(
        label="Select the sample index to explain:",
        min_value=0,
        # X_test was populated by train_model(), so the bound is valid here.
        max_value=len(X_test) - 1,
        value=100,
        step=1,
        help=prompt_params.EXAMPLE_BE_EXPLAINED_IDX,
    )
    # Main page --------------------------------------------------------------
    st.title("Anchors")
    st.write(prompt_params.ANCHORS_INTRODUCTION)
    if st.button("Explain Sample"):
        explain_example(anchors_threshold, example_idx)
if __name__ == '__main__':
    main()
|