Spaces:
Sleeping
Sleeping
File size: 4,948 Bytes
12ee8ba d4ef106 12ee8ba d4ef106 12ee8ba d4ef106 12ee8ba e1da86b 12ee8ba e1da86b 12ee8ba 5963f5d 12ee8ba |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
# Import Libraries
import streamlit as st
import src.prompt_config as prompt_params
import streamlit.components.v1 as components
# Models
import xgboost
from sklearn.model_selection import train_test_split
# XAI (Explainability)
import shap
import lime
# Module-level slots for the fitted model and the train/test split.
# They start empty and are filled in by train_model().
global_model = None
X_train = X_test = y_train = y_test = None
# Streamlit re-executes this script on every widget interaction, which resets the
# module-level globals, so a plain "is None" check cannot keep the model between
# reruns. st.cache_resource keeps the fitted objects alive across reruns; Streamlit
# builds without cache_resource fall back to an uncached call.
_cache_across_reruns = getattr(st, "cache_resource", lambda func: func)


@_cache_across_reruns
def _fit_adult_model():
    """Load the SHAP adult dataset, split it 50/50 and fit an XGBoost classifier.

    Returns:
        Tuple ``(model, X_train, X_test, y_train, y_test)``.
    """
    # Load Data from SHAP library
    X, y = shap.datasets.adult()
    # Split data (fixed seed so the split is reproducible across runs)
    X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.5, random_state=42)
    # Train XGBoost model
    model = xgboost.XGBClassifier()
    model.fit(X_tr, y_tr)
    print("XGBoost Model training completed!")
    return model, X_tr, X_te, y_tr, y_te


def train_model():
    """Populate the module globals with the trained model and the train/test split.

    Does nothing if ``global_model`` is already set in the current script run.
    Where Streamlit supports ``cache_resource``, the expensive fit in
    ``_fit_adult_model`` runs only once and is reused on later reruns.
    """
    global global_model, X_train, X_test, y_train, y_test
    if global_model is None:
        global_model, X_train, X_test, y_train, y_test = _fit_adult_model()
def define_features():
    """Return the LIME metadata for the adult dataset.

    Returns:
        tuple: ``(feature_names, categorical_features, class_names,
        categorical_names)``. ``feature_names`` lists the columns in order,
        ``categorical_features`` holds the names of the categorical columns,
        ``class_names`` the two target labels, and ``categorical_names`` maps a
        column index to the label list for that categorical column. New lists
        are built on every call, so callers may mutate them freely.
    """
    # NOTE(review): labels are looked up by integer code per column index;
    # confirm this order matches the codes produced by shap.datasets.adult().
    categorical_names = {
        1: ['Private', 'Self-emp-not-inc', 'Self-emp-inc', 'Federal-gov',
            'Local-gov', 'State-gov', 'Without-pay', 'Never-worked'],
        3: ['Married-civ-spouse', 'Divorced', 'Never-married', 'Separated',
            'Widowed', 'Married-spouse-absent', 'Married-AF-spouse'],
        4: ['Tech-support', 'Craft-repair', 'Other-service', 'Sales',
            'Exec-managerial', 'Prof-specialty', 'Handlers-cleaners',
            'Machine-op-inspct', 'Adm-clerical', 'Farming-fishing',
            'Transport-moving', 'Priv-house-serv', 'Protective-serv',
            'Armed-Forces'],
        5: ['Wife', 'Own-child', 'Husband', 'Not-in-family', 'Other-relative',
            'Unmarried'],
        6: ['White', 'Asian-Pac-Islander', 'Amer-Indian-Eskimo', 'Other',
            'Black'],
        7: ['Female', 'Male'],
        11: ['United-States', 'Cambodia', 'England', 'Puerto-Rico', 'Canada',
             'Germany', 'Outlying-US(Guam-USVI-etc)', 'India', 'Japan',
             'Greece', 'South', 'China', 'Cuba', 'Iran', 'Honduras',
             'Philippines', 'Italy', 'Poland', 'Jamaica', 'Vietnam', 'Mexico',
             'Portugal', 'Ireland', 'France', 'Dominican-Republic', 'Laos',
             'Ecuador', 'Taiwan', 'Haiti', 'Columbia', 'Hungary', 'Guatemala',
             'Nicaragua', 'Scotland', 'Thailand', 'Yugoslavia', 'El-Salvador',
             'Trinadad&Tobago', 'Peru', 'Hong', 'Holand-Netherlands'],
    }

    feature_names = [
        "Age", "Workclass", "Education-Num", "Marital Status", "Occupation",
        "Relationship", "Race", "Sex", "Capital Gain", "Capital Loss",
        "Hours per week", "Country",
    ]

    # The categorical columns are exactly the indices that carry label lists,
    # listed in column order.
    categorical_features = [feature_names[idx] for idx in sorted(categorical_names)]

    class_names = ["<=50K", ">50K"]

    return feature_names, categorical_features, class_names, categorical_names
def _ensure_codes_labelled(categorical_names, *arrays):
    """Return a copy of ``categorical_names`` whose label lists cover every code in ``arrays``.

    LIME reads the label of a categorical value as
    ``categorical_names[column][code]``, so a code past the end of a label list
    would raise IndexError. Any such code is labelled with its own number.

    Args:
        categorical_names: Mapping of column index -> list of category labels.
        *arrays: 2-D arrays of integer-coded rows (e.g. ``X_train.values``).

    Returns:
        A new dict with the same keys; the input mapping is not modified.
    """
    covered = {}
    for column, labels in categorical_names.items():
        highest_code = max(int(array[:, column].max()) for array in arrays)
        padding = [str(code) for code in range(len(labels), highest_code + 1)]
        covered[column] = list(labels) + padding
    return covered


def explain_example(kernel_width, example_idx):
    """Explain one test sample with LIME and render the result in Streamlit.

    The model is trained on first use (see ``train_model``); it is not retrained here.

    Args:
        kernel_width: Width of LIME's exponential kernel; must be > 0.
        example_idx: Row index into ``X_test`` of the sample to explain.
    """
    # `import lime` alone does not load the lime_tabular submodule.
    import lime.lime_tabular

    if global_model is None:
        train_model()

    # A zero width makes LIME's kernel divide by zero; a bad index would either
    # fail or silently wrap around for negative values.
    if kernel_width <= 0:
        st.error("Kernel width must be greater than 0.")
        return
    if not 0 <= example_idx < len(X_test):
        st.error(f"Sample index must be between 0 and {len(X_test) - 1}.")
        return

    feature_names, categorical_features, class_names, categorical_names = define_features()
    # LimeTabularExplainer expects categorical columns as integer indices. Column
    # names match no index, so the categorical codes would be quartile-discretized
    # like numeric features and the supplied labels discarded.
    categorical_idx = [feature_names.index(name) for name in categorical_features]
    # NOTE(review): confirm the label order in define_features() matches the integer
    # codes produced by shap.datasets.adult(); otherwise categories are mislabelled.
    labelled_names = _ensure_codes_labelled(categorical_names, X_train.values, X_test.values)

    # Initialize LIME Explainer
    explainer = lime.lime_tabular.LimeTabularExplainer(
        X_train.values,
        class_names=class_names,
        feature_names=feature_names,
        categorical_features=categorical_idx,
        categorical_names=labelled_names,
        kernel_width=kernel_width,
    )

    # Explain the selected sample, reporting every feature
    exp = explainer.explain_instance(
        X_test.values[example_idx],
        global_model.predict_proba,
        num_features=len(feature_names),
    )

    # Display the HTML explanation in Streamlit
    components.html(exp.as_html(), height=700, scrolling=True)
def main():
    """Build the Streamlit page and explain the chosen sample when requested."""
    # The sample-index input below is bounded by len(X_test), so the data must
    # exist before the widgets are created.
    if global_model is None:
        train_model()

    # Streamlit UI Controls
    lime_kernel_width = st.sidebar.slider(
        label="Set the `kernel` value:",
        # A width of 0 makes LIME's exponential kernel divide by zero (NaN sample
        # weights), so the smallest selectable width is one step above 0.
        min_value=0.1,
        max_value=100.0,
        value=3.0,  # Default value
        step=0.1,  # Step size
        help=prompt_params.LIME_KERNEL_WIDTH_HELP,
    )
    example_idx = st.sidebar.number_input(
        label="Select the sample index to explain:",
        min_value=0,
        max_value=len(X_test) - 1,  # Keep the index within X_test
        value=1,  # Default value
        step=1,  # Step size
        help=prompt_params.EXAMPLE_BE_EXPLAINED_IDX,
    )

    st.title("LIME: Local Interpretable Model-agnostic Explanations")
    st.write(prompt_params.LIME_INTRODUCTION)

    # Explain the selected sample
    if st.button("Explain Sample"):
        explain_example(lime_kernel_width, example_idx)


if __name__ == '__main__':
    main()
|