"""Streamlit page: explain XGBoost predictions on the SHAP Adult dataset with LIME."""

# Import Libraries
import streamlit as st
import src.prompt_config as prompt_params
import streamlit.components.v1 as components

# Models
import xgboost
from sklearn.model_selection import train_test_split

# XAI (Explainability)
import shap
import lime
# `import lime` alone does not guarantee that the `lime_tabular` submodule is
# loaded; import it explicitly so `lime.lime_tabular` never depends on another
# package having imported it as a side effect.
import lime.lime_tabular

# Global Variables to Store Model & Data
global_model = None
X_train, X_test, y_train, y_test = None, None, None, None

# Smallest kernel width offered in the UI. LIME divides distances by the
# kernel width, so zero (or a negative value) must never reach the explainer.
_MIN_KERNEL_WIDTH = 0.1


def train_model():
    """
    Train the XGBoost model if it has not been trained yet and store it globally.

    Loads the Adult dataset from the SHAP library, splits it 50/50 with a fixed
    random state, fits an ``xgboost.XGBClassifier`` on the training half, and
    stores the model and the four splits in the module-level globals.
    Does nothing when ``global_model`` is already set.

    NOTE(review): the cache lives in module globals. When this file is run
    directly by Streamlit the script is presumably re-executed on every
    interaction, which would reset these globals and retrain the model each
    time -- confirm, and consider ``st.cache_resource`` if so.
    """
    global global_model, X_train, X_test, y_train, y_test

    if global_model is not None:
        return

    # Load Data from SHAP library
    X, y = shap.datasets.adult()

    # Split data
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.5, random_state=42
    )

    # Train XGBoost model
    global_model = xgboost.XGBClassifier()
    global_model.fit(X_train, y_train)

    print("XGBoost Model training completed!")


def define_features():
    """
    Define feature names and categorical mappings.

    Returns:
        A 4-tuple ``(feature_names, categorical_features, class_names,
        categorical_names)`` where ``feature_names`` lists the 12 column
        labels, ``categorical_features`` lists the labels of the categorical
        columns, ``class_names`` holds the two target labels, and
        ``categorical_names`` maps a column position in ``feature_names`` to
        the list of category labels for that column.
    """
    feature_names = [
        "Age", "Workclass", "Education-Num", "Marital Status", "Occupation",
        "Relationship", "Race", "Sex", "Capital Gain", "Capital Loss",
        "Hours per week", "Country",
    ]

    categorical_features = [
        "Workclass", "Marital Status", "Occupation", "Relationship",
        "Race", "Sex", "Country",
    ]

    class_names = ['<=50K', '>50K']

    categorical_names = {
        1: ['Private', 'Self-emp-not-inc', 'Self-emp-inc', 'Federal-gov',
            'Local-gov', 'State-gov', 'Without-pay', 'Never-worked'],
        3: ['Married-civ-spouse', 'Divorced', 'Never-married', 'Separated',
            'Widowed', 'Married-spouse-absent', 'Married-AF-spouse'],
        4: ['Tech-support', 'Craft-repair', 'Other-service', 'Sales',
            'Exec-managerial', 'Prof-specialty', 'Handlers-cleaners',
            'Machine-op-inspct', 'Adm-clerical', 'Farming-fishing',
            'Transport-moving', 'Priv-house-serv', 'Protective-serv',
            'Armed-Forces'],
        5: ['Wife', 'Own-child', 'Husband', 'Not-in-family', 'Other-relative',
            'Unmarried'],
        6: ['White', 'Asian-Pac-Islander', 'Amer-Indian-Eskimo', 'Other',
            'Black'],
        7: ['Female', 'Male'],
        11: ['United-States', 'Cambodia', 'England', 'Puerto-Rico', 'Canada',
             'Germany', 'Outlying-US(Guam-USVI-etc)', 'India', 'Japan',
             'Greece', 'South', 'China', 'Cuba', 'Iran', 'Honduras',
             'Philippines', 'Italy', 'Poland', 'Jamaica', 'Vietnam', 'Mexico',
             'Portugal', 'Ireland', 'France', 'Dominican-Republic', 'Laos',
             'Ecuador', 'Taiwan', 'Haiti', 'Columbia', 'Hungary', 'Guatemala',
             'Nicaragua', 'Scotland', 'Thailand', 'Yugoslavia', 'El-Salvador',
             'Trinadad&Tobago', 'Peru', 'Hong', 'Holand-Netherlands'],
    }

    return feature_names, categorical_features, class_names, categorical_names


def explain_example(kernel_width, example_idx):
    """
    Explain a given test sample with LIME and render the result in Streamlit.

    Trains the model first if it has not been trained yet, builds a
    ``LimeTabularExplainer`` on the training data, explains the selected row of
    ``X_test`` against ``global_model.predict_proba`` using 12 features, and
    embeds the explanation HTML in the page.

    Args:
        kernel_width: Width of the LIME kernel; must be positive. ``None`` is
            passed through unchanged to the explainer.
        example_idx: Row position in ``X_test`` of the sample to explain.

    Raises:
        ValueError: If ``kernel_width`` is zero or negative.
    """
    global global_model, X_train, X_test, y_train, y_test

    # Reject non-positive widths up front instead of letting a division by
    # zero inside the kernel surface as NaN weights further down.
    if kernel_width is not None and kernel_width <= 0:
        raise ValueError(
            f"kernel_width must be positive, got {kernel_width!r}"
        )

    if global_model is None:
        train_model()

    feature_names, categorical_features, class_names, categorical_names = define_features()

    # Initialize LIME Explainer
    # NOTE(review): LIME documents `categorical_features` as a list of column
    # indices, but column names are passed here. Left unchanged on purpose:
    # switching to indices makes LIME look values up in `categorical_names`,
    # and those lists would first need to be verified against the integer
    # codes used by `shap.datasets.adult()` -- TODO confirm and fix together.
    explainer = lime.lime_tabular.LimeTabularExplainer(
        X_train.values,
        class_names=class_names,
        feature_names=feature_names,
        categorical_features=categorical_features,
        categorical_names=categorical_names,
        kernel_width=kernel_width,
    )

    # Explain the selected sample
    exp = explainer.explain_instance(
        X_test.values[example_idx],
        global_model.predict_proba,
        num_features=12,
    )

    # Generate HTML explanation
    explanation_html = exp.as_html()

    # Display explanation in Streamlit
    components.html(explanation_html, height=700, scrolling=True)


def main():
    """
    Build the Streamlit page: sidebar controls, title, intro and explain button.

    Ensures the model and data splits exist before the sidebar is built,
    because the sample-index input is bounded by the size of ``X_test``.
    """
    global global_model

    # Ensure the model is trained only once
    if global_model is None:
        train_model()

    # Streamlit UI Controls
    lime_kernel_width = st.sidebar.slider(
        label="Set the `kernel` value:",
        # Strictly positive lower bound: a width of 0 breaks the LIME kernel.
        min_value=_MIN_KERNEL_WIDTH,
        max_value=100.0,
        value=3.0,  # Default value
        step=0.1,  # Step size
        help=prompt_params.LIME_KERNEL_WIDTH_HELP,
    )

    example_idx = st.sidebar.number_input(
        label="Select the sample index to explain:",
        min_value=0,
        max_value=len(X_test) - 1,  # Ensures the index is within range
        value=1,  # Default value
        step=1,  # Step size
        help=prompt_params.EXAMPLE_BE_EXPLAINED_IDX,
    )

    st.title("LIME: Local Interpretable Model-agnostic Explanations")
    st.write(prompt_params.LIME_INTRODUCTION)

    # Explain the selected sample
    if st.button("Explain Sample"):
        explain_example(lime_kernel_width, example_idx)


if __name__ == '__main__':
    main()