peggy30 committed on
Commit
f6d6f0e
·
1 Parent(s): ec3fab5
Files changed (3) hide show
  1. pages/Anchors.py +1 -1
  2. pages/LIME.py +0 -5
  3. pages/SHAP.py +106 -0
pages/Anchors.py CHANGED
@@ -86,7 +86,7 @@ def explain_example(anchors_threshold, example_idx):
86
  class_names,
87
  feature_names,
88
  X_train.values,
89
- categorical_names)
90
 
91
  # Explain the selected sample
92
  exp = explainer.explain_instance(X_test.values[example_idx], global_model.predict, threshold=anchors_threshold)
 
86
  class_names,
87
  feature_names,
88
  X_train.values,
89
+ categorical_names, seed=42)
90
 
91
  # Explain the selected sample
92
  exp = explainer.explain_instance(X_test.values[example_idx], global_model.predict, threshold=anchors_threshold)
pages/LIME.py CHANGED
@@ -1,7 +1,4 @@
1
  # Import Libraries
2
- import numpy as np
3
- import pandas as pd
4
- import matplotlib.pyplot as plt
5
  import streamlit as st
6
  import src.prompt_config as prompt_params
7
  import streamlit.components.v1 as components
@@ -12,8 +9,6 @@ from sklearn.model_selection import train_test_split
12
  # XAI (Explainability)
13
  import shap
14
  import lime
15
- # from anchor import anchor_tabular
16
- from sklearn.inspection import PartialDependenceDisplay
17
 
18
  # Global Variables to Store Model & Data
19
  global_model = None
 
1
  # Import Libraries
 
 
 
2
  import streamlit as st
3
  import src.prompt_config as prompt_params
4
  import streamlit.components.v1 as components
 
9
  # XAI (Explainability)
10
  import shap
11
  import lime
 
 
12
 
13
  # Global Variables to Store Model & Data
14
  global_model = None
pages/SHAP.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Import Libraries
2
+ import matplotlib.pyplot as plt
3
+ import streamlit as st
4
+ import src.prompt_config as prompt_params
5
+ # Models
6
+ import xgboost
7
+ from sklearn.model_selection import train_test_split
8
+
9
+ # XAI (Explainability)
10
+ import shap
11
+
12
+ # Global Variables to Store Model & Data
13
+ global_model = None
14
+ X_train, X_test, y_train, y_test = None, None, None, None
15
+
16
+
17
def train_model():
    """Train the XGBoost classifier once and cache it in module globals.

    Loads the Adult dataset bundled with SHAP, splits it 50/50 into train
    and test partitions (``random_state=42``) and fits an ``XGBClassifier``
    with default hyper-parameters. While ``global_model`` is set, further
    calls return immediately without retraining.

    The module globals (``global_model``, ``X_train``, ``X_test``,
    ``y_train``, ``y_test``) are published only after ``fit`` has succeeded,
    so a failed training run never leaves an unfitted model behind that a
    later call would mistake for a trained one.
    """
    global global_model, X_train, X_test, y_train, y_test

    if global_model is not None:
        return

    # Load data from the SHAP library
    features, labels = shap.datasets.adult()

    # Split data
    train_x, test_x, train_y, test_y = train_test_split(
        features, labels, test_size=0.5, random_state=42
    )

    # Train XGBoost model on locals first; a failure here must not touch
    # the cached state.
    model = xgboost.XGBClassifier()
    model.fit(train_x, train_y)

    # Publish the split and the fitted model together.
    X_train, X_test, y_train, y_test = train_x, test_x, train_y, test_y
    global_model = model

    print("XGBoost Model training completed!")
33
+
34
+
35
+
36
def _render_shap_figure(fig):
    """Send the figure SHAP just drew on to Streamlit, then free it.

    Some SHAP plotting functions draw on the current pyplot figure rather
    than on one passed in, so the current figure is rendered instead of
    assuming it is ``fig``. Both are closed afterwards so figures do not
    accumulate across Streamlit reruns.
    """
    drawn = plt.gcf()
    st.pyplot(drawn)
    plt.close(drawn)
    if drawn is not fig:
        plt.close(fig)


def explain_example(kernel_width, example_idx):
    """Render SHAP plots for the test set and for one selected sample.

    Trains the model first if it has not been trained yet, then draws a
    beeswarm summary plot, a bar summary plot, a dependence plot for the
    ``'Age'`` feature and a waterfall plot for row ``example_idx`` of
    ``X_test``.

    Args:
        kernel_width: Accepted for interface compatibility with the caller;
            it is not used by any of the SHAP computations in this function.
        example_idx: Positional index into ``X_test`` of the sample whose
            waterfall plot is shown.
    """
    global global_model, X_train, X_test, y_train, y_test

    if global_model is None:
        train_model()

    # Background distribution: 100 rows sampled from the cached training
    # split, so the dataset is not reloaded on every button click.
    background = shap.utils.sample(X_train, 100)
    explainer = shap.TreeExplainer(global_model, background)

    # Compute the explanation once; the raw value matrix used by the legacy
    # plotting functions is taken from it instead of running the explainer a
    # second time over the whole test set.
    explanation = explainer(X_test)
    shap_values = explanation.values

    # SHAP Summary Plot (BeeSwarm)
    st.write("### 📊 SHAP Summary Plot")
    fig, ax = plt.subplots(figsize=(10, 5))
    shap.summary_plot(shap_values, X_test, show=False)
    _render_shap_figure(fig)

    # SHAP Summary Bar Plot
    st.write("### 📊 SHAP Feature Importance (Bar Plot)")
    fig, ax = plt.subplots(figsize=(10, 5))
    shap.summary_plot(shap_values, X_test, plot_type="bar", show=False)
    _render_shap_figure(fig)

    # SHAP Dependence Plot
    st.write("### 🔍 SHAP Dependence Plot for 'Age'")
    fig, ax = plt.subplots(figsize=(10, 5))
    shap.dependence_plot('Age', shap_values, X_test, ax=ax, show=False)
    _render_shap_figure(fig)

    # SHAP Waterfall Plot
    st.write(f"### 🌊 SHAP Waterfall Plot for Example {example_idx}")
    fig, ax = plt.subplots(figsize=(10, 5))
    # int() guards against a float coming back from the number input widget.
    shap.plots.waterfall(explanation[int(example_idx)], show=False)
    _render_shap_figure(fig)
72
+
73
+
74
def main():
    """Build the sidebar controls and explain the chosen sample on demand.

    Trains the model if ``global_model`` is still ``None``, shows a slider
    and a sample-index input in the sidebar, and calls ``explain_example``
    with both values when the "Explain Sample" button is pressed.
    """
    # Make sure a trained model (and the test split) exists before the
    # widgets below read len(X_test).
    if global_model is None:
        train_model()

    sidebar = st.sidebar

    # Kernel slider
    kernel_value = sidebar.slider(
        label="Set the `kernel` value:",
        min_value=0.0,
        max_value=100.0,
        value=3.0,
        step=0.1,
        help=prompt_params.LIME_KERNEL_WIDTH_HELP,
    )

    # Sample selector; the upper bound keeps the index inside X_test.
    last_index = len(X_test) - 1
    chosen_idx = sidebar.number_input(
        label="Select the sample index to explain:",
        min_value=0,
        max_value=last_index,
        value=1,
        step=1,
        help=prompt_params.EXAMPLE_BE_EXPLAINED_IDX,
    )

    # Only run the explanation when the user asks for it.
    explain_clicked = st.button("Explain Sample")
    if explain_clicked:
        explain_example(kernel_value, chosen_idx)
103
+
104
+
105
+ if __name__ == '__main__':
106
+ main()