# peggy30 — "add introduction" (5963f5d)
# Import Libraries
import streamlit as st
import src.prompt_config as prompt_params
import streamlit.components.v1 as components
# Models
import xgboost
from sklearn.model_selection import train_test_split
# XAI (Explainability)
import shap
import lime
# Module-level slots for the fitted classifier and the data split it was
# trained on. train_model() populates them; explain_example() and main()
# read them.
global_model = None
X_train = X_test = y_train = y_test = None
@st.cache_resource
def _fit_adult_model():
    """Load the Adult census data, split it 50/50 and fit an XGBoost classifier.

    Cached with ``st.cache_resource`` because Streamlit re-executes the whole
    script on every widget interaction, which resets the module-level
    ``global_model`` to ``None``. Without the cache, the model would be
    retrained on every rerun.

    Returns:
        ``(model, X_train, X_test, y_train, y_test)``: the fitted
        ``xgboost.XGBClassifier`` and the split it was trained on.
    """
    X, y = shap.datasets.adult()
    # A fixed random_state keeps the split stable across runs. This also keeps
    # stable the sample indices offered in the UI.
    X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.5, random_state=42)
    model = xgboost.XGBClassifier()
    model.fit(X_tr, y_tr)
    print("XGBoost Model training completed!")
    return model, X_tr, X_te, y_tr, y_te


def train_model():
    """Populate the module-level model and data split, training at most once.

    Sets ``global_model``, ``X_train``, ``X_test``, ``y_train`` and ``y_test``.
    Later calls in the same script run are no-ops. The expensive fit is also
    cached across Streamlit reruns by ``_fit_adult_model``.
    """
    global global_model, X_train, X_test, y_train, y_test
    if global_model is None:
        global_model, X_train, X_test, y_train, y_test = _fit_adult_model()
def define_features():
    """Return the feature metadata used to configure the LIME explainer.

    Returns:
        A 4-tuple ``(feature_names, categorical_features, class_names,
        categorical_names)``:

        * ``feature_names``: display names for the 12 input columns, in
          column order (assumed to match the column order of ``X_train``).
        * ``categorical_features``: the subset of ``feature_names`` that is
          categorical, given as names rather than column indices.
        * ``class_names``: labels for the two target classes.
        * ``categorical_names``: maps a column index to that column's
          category labels. LIME shows ``categorical_names[col][code]`` for
          integer category ``code``.

    NOTE(review): each label list must be in the same order as the integer
    codes that ``shap.datasets.adult()`` assigns to that column. That
    ordering is not visible here; confirm it before trusting the labels LIME
    displays.
    """
    label_lists = {
        1: ['Private', 'Self-emp-not-inc', 'Self-emp-inc', 'Federal-gov',
            'Local-gov', 'State-gov', 'Without-pay', 'Never-worked'],
        3: ['Married-civ-spouse', 'Divorced', 'Never-married', 'Separated',
            'Widowed', 'Married-spouse-absent', 'Married-AF-spouse'],
        4: ['Tech-support', 'Craft-repair', 'Other-service', 'Sales',
            'Exec-managerial', 'Prof-specialty', 'Handlers-cleaners',
            'Machine-op-inspct', 'Adm-clerical', 'Farming-fishing',
            'Transport-moving', 'Priv-house-serv', 'Protective-serv',
            'Armed-Forces'],
        5: ['Wife', 'Own-child', 'Husband', 'Not-in-family',
            'Other-relative', 'Unmarried'],
        6: ['White', 'Asian-Pac-Islander', 'Amer-Indian-Eskimo', 'Other',
            'Black'],
        7: ['Female', 'Male'],
        11: ['United-States', 'Cambodia', 'England', 'Puerto-Rico', 'Canada',
             'Germany', 'Outlying-US(Guam-USVI-etc)', 'India', 'Japan',
             'Greece', 'South', 'China', 'Cuba', 'Iran', 'Honduras',
             'Philippines', 'Italy', 'Poland', 'Jamaica', 'Vietnam', 'Mexico',
             'Portugal', 'Ireland', 'France', 'Dominican-Republic', 'Laos',
             'Ecuador', 'Taiwan', 'Haiti', 'Columbia', 'Hungary', 'Guatemala',
             'Nicaragua', 'Scotland', 'Thailand', 'Yugoslavia', 'El-Salvador',
             'Trinadad&Tobago', 'Peru', 'Hong', 'Holand-Netherlands'],
    }
    names = [
        "Age", "Workclass", "Education-Num", "Marital Status", "Occupation",
        "Relationship", "Race", "Sex", "Capital Gain", "Capital Loss",
        "Hours per week", "Country",
    ]
    # A column is categorical exactly when it has a label list. Deriving the
    # names from those indices keeps the two structures from drifting apart.
    categorical = [names[col] for col in sorted(label_lists)]
    targets = ['<=50K', '>50K']
    return names, categorical, targets, label_lists
def explain_example(kernel_width, example_idx):
    """Render a LIME explanation for one test-set row without retraining.

    Calls ``train_model()`` first if the model has not been trained yet in
    this script run, then shows LIME's HTML report in the Streamlit page.

    Args:
        kernel_width: Width of LIME's exponential proximity kernel. Must be
            positive: LIME divides by ``kernel_width ** 2``, so 0 produces
            NaN sample weights and a crash.
        example_idx: Row position in ``X_test`` to explain.
    """
    # `import lime` alone does not load the `lime_tabular` submodule, so the
    # attribute access `lime.lime_tabular` would fail. Import it explicitly.
    from lime.lime_tabular import LimeTabularExplainer

    if kernel_width <= 0:
        st.error("The kernel width must be greater than 0.")
        return

    # The module-level model/data are only read here; train_model() assigns them.
    if global_model is None:
        train_model()

    feature_names, categorical_features, class_names, categorical_names = define_features()
    # LIME identifies categorical columns by integer index, not by name.
    categorical_idx = [feature_names.index(name) for name in categorical_features]

    explainer = LimeTabularExplainer(
        X_train.values,
        class_names=class_names,
        feature_names=feature_names,
        categorical_features=categorical_idx,
        categorical_names=categorical_names,
        kernel_width=kernel_width,
    )
    exp = explainer.explain_instance(
        X_test.values[example_idx],
        global_model.predict_proba,
        num_features=len(feature_names),  # report every feature's contribution
    )
    components.html(exp.as_html(), height=700, scrolling=True)
def main():
    """Streamlit entry point: build the sidebar controls and explain on demand."""
    # Train (or fetch) the model before building widgets, because the index
    # picker's upper bound depends on len(X_test).
    if global_model is None:
        train_model()

    lime_kernel_width = st.sidebar.slider(
        label="Set the `kernel` value:",
        # LIME divides by kernel_width ** 2, so 0 is not a usable width.
        # Start at the smallest step instead.
        min_value=0.1,
        max_value=100.0,
        value=3.0,  # Default value
        step=0.1,  # Step size
        help=prompt_params.LIME_KERNEL_WIDTH_HELP,
    )

    example_idx = st.sidebar.number_input(
        label="Select the sample index to explain:",
        min_value=0,
        max_value=len(X_test) - 1,  # Keeps the index within X_test
        value=1,  # Default value
        step=1,  # Step size
        help=prompt_params.EXAMPLE_BE_EXPLAINED_IDX,
    )

    st.title("LIME: Local Interpretable Model-agnostic Explanations")
    st.write(prompt_params.LIME_INTRODUCTION)

    if st.button("Explain Sample"):
        explain_example(lime_kernel_width, example_idx)
if __name__ == "__main__":
    # `streamlit run` also executes this file with __name__ set to "__main__".
    main()