# Import Libraries
import streamlit as st
import streamlit.components.v1 as components

import src.prompt_config as prompt_params

# Models
import xgboost
from sklearn.model_selection import train_test_split

# XAI (Explainability)
import shap
import lime
import lime.lime_tabular  # needed explicitly: `import lime` alone does not expose lime.lime_tabular

# Global Variables to Store Model & Data
global_model = None
X_train, X_test, y_train, y_test = None, None, None, None

def train_model():
    """Train the XGBoost model only once and store it globally."""
    global global_model, X_train, X_test, y_train, y_test
    if global_model is None:
        # Load the Adult Census income dataset bundled with the SHAP library
        X, y = shap.datasets.adult()
        # Split data
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=42)
        # Train XGBoost model
        global_model = xgboost.XGBClassifier()
        global_model.fit(X_train, y_train)
        print("XGBoost Model training completed!")

def define_features():
    """Define feature names, categorical feature indices, and categorical value mappings."""
    feature_names = ["Age", "Workclass", "Education-Num", "Marital Status", "Occupation",
                     "Relationship", "Race", "Sex", "Capital Gain", "Capital Loss",
                     "Hours per week", "Country"]
    # LimeTabularExplainer expects categorical features as column indices, not names:
    # Workclass, Marital Status, Occupation, Relationship, Race, Sex, Country
    categorical_features = [1, 3, 4, 5, 6, 7, 11]
    class_names = ['<=50K', '>50K']
    categorical_names = {
        1: ['Private', 'Self-emp-not-inc', 'Self-emp-inc', 'Federal-gov', 'Local-gov', 'State-gov', 'Without-pay',
            'Never-worked'],
        3: ['Married-civ-spouse', 'Divorced', 'Never-married', 'Separated', 'Widowed', 'Married-spouse-absent',
            'Married-AF-spouse'],
        4: ['Tech-support', 'Craft-repair', 'Other-service', 'Sales', 'Exec-managerial', 'Prof-specialty',
            'Handlers-cleaners', 'Machine-op-inspct', 'Adm-clerical', 'Farming-fishing', 'Transport-moving',
            'Priv-house-serv', 'Protective-serv', 'Armed-Forces'],
        5: ['Wife', 'Own-child', 'Husband', 'Not-in-family', 'Other-relative', 'Unmarried'],
        6: ['White', 'Asian-Pac-Islander', 'Amer-Indian-Eskimo', 'Other', 'Black'],
        7: ['Female', 'Male'],
        11: ['United-States', 'Cambodia', 'England', 'Puerto-Rico', 'Canada', 'Germany', 'Outlying-US(Guam-USVI-etc)',
             'India', 'Japan', 'Greece', 'South', 'China', 'Cuba', 'Iran', 'Honduras', 'Philippines', 'Italy',
             'Poland', 'Jamaica', 'Vietnam', 'Mexico', 'Portugal', 'Ireland', 'France', 'Dominican-Republic', 'Laos',
             'Ecuador', 'Taiwan', 'Haiti', 'Columbia', 'Hungary', 'Guatemala', 'Nicaragua', 'Scotland', 'Thailand',
             'Yugoslavia', 'El-Salvador', 'Trinadad&Tobago', 'Peru', 'Hong', 'Holand-Netherlands']
    }
    return feature_names, categorical_features, class_names, categorical_names

def explain_example(kernel_width, example_idx):
    """Explain a given test sample without retraining the model."""
    global global_model, X_train, X_test, y_train, y_test
    if global_model is None:
        train_model()
    feature_names, categorical_features, class_names, categorical_names = define_features()
    # Initialize the LIME explainer on the training data
    explainer = lime.lime_tabular.LimeTabularExplainer(
        X_train.values,
        class_names=class_names,
        feature_names=feature_names,
        categorical_features=categorical_features,
        categorical_names=categorical_names,
        kernel_width=kernel_width
    )
    # Explain the selected sample
    exp = explainer.explain_instance(X_test.values[example_idx], global_model.predict_proba, num_features=12)
    # Generate HTML explanation
    explanation_html = exp.as_html()
    # Display explanation in Streamlit
    components.html(explanation_html, height=700, scrolling=True)

def main():
    global global_model
    # Ensure the model is trained only once
    if global_model is None:
        train_model()
    # Streamlit UI Controls
    lime_kernel_width = st.sidebar.slider(
        label="Set the LIME `kernel_width` value:",
        min_value=0.1,  # keep strictly positive; a zero kernel width degenerates LIME's exponential kernel
        max_value=100.0,
        value=3.0,  # Default value
        step=0.1,  # Step size
        help=prompt_params.LIME_KERNEL_WIDTH_HELP,
    )
    example_idx = st.sidebar.number_input(
        label="Select the sample index to explain:",
        min_value=0,
        max_value=len(X_test) - 1,  # Ensures the index is within range
        value=1,  # Default value
        step=1,  # Step size
        help=prompt_params.EXAMPLE_BE_EXPLAINED_IDX,
    )
    st.title("LIME: Local Interpretable Model-agnostic Explanations")
    st.write(prompt_params.LIME_INTRODUCTION)
    # Explain the selected sample when the button is pressed
    if st.button("Explain Sample"):
        explain_example(lime_kernel_width, example_idx)


if __name__ == '__main__':
    main()
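
# Usage note (a sketch, not part of the app logic): assuming this script is saved
# as app.py and src/prompt_config.py defines LIME_KERNEL_WIDTH_HELP,
# EXAMPLE_BE_EXPLAINED_IDX, and LIME_INTRODUCTION, it can be launched locally with:
#   streamlit run app.py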