peggy30 committed
Commit 12ee8ba · 1 Parent(s): 79f6bd8
Files changed (4)
  1. .gitignore +2 -0
  2. app.py +11 -0
  3. pages/LIME.py +136 -0
  4. src/prompt_config.py +27 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
+ .DS_Store
+ venv/
app.py ADDED
@@ -0,0 +1,11 @@
+ import streamlit as st
+ import src.prompt_config as prompt_params
+
+
+ def main():
+     st.title("Explainable AI")
+     st.markdown(prompt_params.APP_INTRODUCTION)
+
+
+ if __name__ == '__main__':
+     main()
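app.py is the Streamlit entry point: launching the app with `streamlit run app.py` renders the introduction above, and Streamlit's multipage convention automatically exposes scripts placed in the `pages/` directory (such as `pages/LIME.py` below) as additional pages in the sidebar.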
pages/LIME.py ADDED
@@ -0,0 +1,136 @@
+ # Import Libraries
+ import numpy as np
+ import pandas as pd
+ import matplotlib.pyplot as plt
+ import streamlit as st
+ import src.prompt_config as prompt_params
+ import streamlit.components.v1 as components
+
+ # Models
+ import xgboost
+ from sklearn.model_selection import train_test_split
+
+ # XAI (Explainability)
+ import shap
+ import lime
+ import lime.lime_tabular
+ # from anchor import anchor_tabular
+ from sklearn.inspection import PartialDependenceDisplay
+
+ # Global Variables to Store Model & Data
+ global_model = None
+ X_train, X_test, y_train, y_test = None, None, None, None
+
+
+ def train_model():
+     """Train the XGBoost model only once and store it globally."""
+     global global_model, X_train, X_test, y_train, y_test
+
+     if global_model is None:
+         # Load the Adult census data bundled with the SHAP library
+         X, y = shap.datasets.adult()
+
+         # Split data
+         X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=42)
+
+         # Train XGBoost model
+         global_model = xgboost.XGBClassifier()
+         global_model.fit(X_train, y_train)
+
+         print("XGBoost model training completed!")
+
+
+ def define_features():
+     """Define feature names and categorical mappings."""
+     global feature_names, categorical_features, class_names, categorical_names
+
+     feature_names = ["Age", "Workclass",
+                      "Education-Num", "Marital Status", "Occupation",
+                      "Relationship", "Race", "Sex", "Capital Gain",
+                      "Capital Loss", "Hours per week", "Country"]
+
+     # Column indices of the categorical features (LimeTabularExplainer expects indices):
+     # Workclass, Marital Status, Occupation, Relationship, Race, Sex, Country
+     categorical_features = [1, 3, 4, 5, 6, 7, 11]
+
+     class_names = ['<=50K', '>50K']
+
+     categorical_names = {
+         1: ['Private', 'Self-emp-not-inc', 'Self-emp-inc', 'Federal-gov', 'Local-gov', 'State-gov',
+             'Without-pay', 'Never-worked'],
+         3: ['Married-civ-spouse', 'Divorced', 'Never-married', 'Separated', 'Widowed',
+             'Married-spouse-absent', 'Married-AF-spouse'],
+         4: ['Tech-support', 'Craft-repair', 'Other-service', 'Sales', 'Exec-managerial', 'Prof-specialty',
+             'Handlers-cleaners', 'Machine-op-inspct', 'Adm-clerical', 'Farming-fishing', 'Transport-moving',
+             'Priv-house-serv', 'Protective-serv', 'Armed-Forces'],
+         5: ['Wife', 'Own-child', 'Husband', 'Not-in-family', 'Other-relative', 'Unmarried'],
+         6: ['White', 'Asian-Pac-Islander', 'Amer-Indian-Eskimo', 'Other', 'Black'],
+         7: ['Female', 'Male'],
+         11: ['United-States', 'Cambodia', 'England', 'Puerto-Rico', 'Canada', 'Germany',
+              'Outlying-US(Guam-USVI-etc)', 'India', 'Japan', 'Greece', 'South', 'China', 'Cuba', 'Iran',
+              'Honduras', 'Philippines', 'Italy', 'Poland', 'Jamaica', 'Vietnam', 'Mexico', 'Portugal',
+              'Ireland', 'France', 'Dominican-Republic', 'Laos', 'Ecuador', 'Taiwan', 'Haiti', 'Columbia',
+              'Hungary', 'Guatemala', 'Nicaragua', 'Scotland', 'Thailand', 'Yugoslavia', 'El-Salvador',
+              'Trinadad&Tobago', 'Peru', 'Hong', 'Holand-Netherlands']
+     }
+
+
+ def explain_example(kernel_width, example_idx):
+     """Explain a given sample without retraining the model."""
+     global global_model, X_train, X_test, y_train, y_test
+
+     if global_model is None:
+         train_model()
+
+     # Initialize the LIME explainer on the training data
+     explainer = lime.lime_tabular.LimeTabularExplainer(
+         X_train.values,
+         class_names=class_names,
+         feature_names=feature_names,
+         categorical_features=categorical_features,
+         categorical_names=categorical_names,
+         kernel_width=kernel_width
+     )
+
+     # Explain the selected sample
+     exp = explainer.explain_instance(X_test.values[example_idx], global_model.predict_proba, num_features=12)
+
+     # Generate the HTML explanation
+     explanation_html = exp.as_html()
+
+     # Display the explanation in Streamlit
+     components.html(explanation_html, height=600, scrolling=True)
+
+
+ def main():
+     global global_model
+
+     # Ensure the model is trained only once
+     if global_model is None:
+         train_model()
+
+     # Define the feature metadata used by the LIME explainer
+     define_features()
+
+     # Streamlit UI Controls
+     lime_kernel_width = st.sidebar.slider(
+         label="Set the `kernel_width` value:",
+         min_value=0.1,  # A kernel width of zero is not meaningful for LIME
+         max_value=100.0,
+         value=3.0,  # Default value
+         step=0.1,   # Step size
+         help=prompt_params.LIME_KERNEL_WIDTH_HELP,
+     )
+
+     example_idx = st.sidebar.slider(
+         label="Select the sample index to explain:",
+         min_value=0,
+         max_value=len(X_test) - 1,  # Ensures the index is within range
+         value=1,  # Default value
+         step=1,   # Step size
+         help=prompt_params.EXAMPLE_BE_EXPLAINED_IDX,
+     )
+
+     # Explain the selected sample
+     if st.button("Explain Sample"):
+         explain_example(lime_kernel_width, example_idx)
+
+
+ if __name__ == '__main__':
+     main()
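A caveat for anyone adapting this page: Streamlit re-executes the script on every widget interaction, so the module-level `global_model` is reset and the XGBoost model is retrained on each rerun. A minimal sketch of one common alternative, assuming a Streamlit version that provides `st.cache_resource` (the helper name `get_model_and_data` is illustrative and not part of this commit):

```python
import shap
import streamlit as st
import xgboost
from sklearn.model_selection import train_test_split


@st.cache_resource  # cache the fitted model and data splits across Streamlit reruns
def get_model_and_data():
    # Same data source and split as pages/LIME.py
    X, y = shap.datasets.adult()
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.5, random_state=42
    )
    model = xgboost.XGBClassifier()
    model.fit(X_train, y_train)
    return model, X_train, X_test, y_train, y_test


# Inside the page, this call would replace the module-level globals:
# model, X_train, X_test, y_train, y_test = get_model_and_data()
```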
src/prompt_config.py ADDED
@@ -0,0 +1,27 @@
+ APP_INTRODUCTION = """
+ This application provides explainability for machine learning models using LIME and SHAP.
+ It allows users to explore how different features influence model predictions by selecting
+ specific samples and visualizing their explanations interactively.
+ """
+
+
+ LIME_KERNEL_WIDTH_HELP = """
+ The `kernel_width` parameter in LIME controls the size of the neighborhood used to generate perturbations around a sample
+ for explanation. It determines how far the generated synthetic data points will be from the original instance.
+
+ ### How It Works:
+ - **Smaller values (e.g., 1.0 - 3.0):** Focus on very local explanations, meaning LIME gives more weight to points closer to the original sample.
+ - **Larger values (e.g., 10.0 - 100.0):** Expand the neighborhood, leading to more global explanations that consider a broader range of feature values.
+
+ ### Recommended Settings:
+ - **For simple models or small datasets:** Start with `kernel_width = 3.0`.
+ - **For complex models or high-dimensional data:** A larger value (e.g., `10.0 - 25.0`) may provide better stability.
+ - **For debugging or fine-tuning:** Experiment with different values to see how they affect feature importance rankings.
+
+ ⚠️ **Note:** A very large `kernel_width` can make explanations less interpretable, as it may include too many outliers.
+ """
+
+
+ EXAMPLE_BE_EXPLAINED_IDX = """
+ Select the index of the example you want to explain.
+ """