peggy30 commited on
Commit
3303745
·
1 Parent(s): e73bf99

add permutation

Browse files
pages/ICE_and_PDP.py CHANGED
@@ -50,13 +50,12 @@ def main():
50
  if global_model is None:
51
  train_model()
52
  # Define feature names
53
- feature_names = ["Age", "Workclass", "Education-Num", "Marital Status", "Occupation",
54
- "Relationship", "Race", "Sex", "Capital Gain", "Capital Loss", "Hours per week", "Country"]
55
- print(X_test.columns) # Check the actual feature names
56
-
57
- selected_feature = st.sidebar.selectbox("Select a feature for PDP/ICE analysis:", feature_names)
58
 
59
- kind = st.sidebar.selectbox("Select plot type:", ["average", "both", "individual"])
 
 
 
 
60
 
61
  st.title("ICE (Individual Conditional Expectation) and PDP (Partial Dependence Plot)")
62
  st.write(prompt_params.ICE_INTRODUCTION)
 
50
  if global_model is None:
51
  train_model()
52
  # Define feature names
 
 
 
 
 
53
 
54
+
55
+ selected_feature = st.sidebar.selectbox("Select a feature for PDP/ICE analysis:", ("Age", "Workclass", "Education-Num", "Marital Status", "Occupation",
56
+ "Relationship", "Race", "Sex", "Capital Gain", "Capital Loss", "Hours per week", "Country"))
57
+ st.write(f"selected feature is {selected_feature}")
58
+ kind = st.sidebar.selectbox("Select plot type:", ("average", "both", "individual"))
59
 
60
  st.title("ICE (Individual Conditional Expectation) and PDP (Partial Dependence Plot)")
61
  st.write(prompt_params.ICE_INTRODUCTION)
pages/PermutationFeatureImportance.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Import Libraries
2
+ import matplotlib.pyplot as plt
3
+ import streamlit as st
4
+ import src.prompt_config as prompt_params
5
+ # Models
6
+ import xgboost
7
+ from sklearn.model_selection import train_test_split
8
+ from alepython import ale_plot
9
+ # XAI (Explainability)
10
+ import shap
11
+ from sklearn.inspection import permutation_importance
12
+ # Global Variables to Store Model & Data
13
+ global_model = None
14
+ X_train, X_test, y_train, y_test = None, None, None, None
15
+
16
+
17
+ def train_model():
18
+ """ Train the XGBoost model only once and store it globally. """
19
+ global global_model, X_train, X_test, y_train, y_test
20
+
21
+ if global_model is None:
22
+ # Load Data from SHAP library
23
+ X, y = shap.datasets.adult()
24
+
25
+ # Split data
26
+ X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=42)
27
+
28
+ # Train XGBoost model
29
+ global_model = xgboost.XGBClassifier()
30
+ global_model.fit(X_train, y_train)
31
+
32
+ print("XGBoost Model training completed!")
33
+
34
+ def explain_example():
35
+ """ Explain a given sample without retraining the model. """
36
+ global global_model, X_train, X_test, y_train, y_test
37
+
38
+ if global_model is None:
39
+ train_model()
40
+
41
+ fig, ax = plt.subplots(figsize=(10, 5))
42
+ st.write("1D Main Effect ALE Plot")
43
+ perm_imp = permutation_importance(global_model, X_test, y_test,
44
+ n_repeats=30,
45
+ random_state=0)
46
+ sorted_idx = perm_imp.importances_mean.argsort()
47
+ ax.barh(X_test.columns[sorted_idx], perm_imp.importances[sorted_idx].mean(axis=1).T)
48
+ ax.set_title("Permutation Importances")
49
+ fig.tight_layout()
50
+ st.pyplot(fig)
51
+
52
+ fig, ax = plt.subplots(figsize=(10, 5))
53
+ st.write("2D Second-Order ALE Plot")
54
+ ax.boxplot(perm_imp.importances[sorted_idx].T,
55
+ vert=False, labels=X_test.columns[sorted_idx])
56
+ ax.set_title("Permutation Importances")
57
+ fig.tight_layout()
58
+ st.pyplot(fig)
59
+
60
+ def main():
61
+ global global_model
62
+
63
+ # Ensure the model is trained only once
64
+ if global_model is None:
65
+ train_model()
66
+
67
+ st.title("ALE (Accumulated Local Effects)")
68
+ st.write(prompt_params.ALE_INTRODUCTION)
69
+ # Explain the selected sample
70
+ if st.button("Explain Sample"):
71
+ explain_example()
72
+
73
+
74
+ if __name__ == '__main__':
75
+ main()
requirements.txt CHANGED
@@ -3,4 +3,5 @@ lime
3
  xgboost
4
  shap
5
  anchor-exp
6
- scikit-learn
 
 
3
  xgboost
4
  shap
5
  anchor-exp
6
+ scikit-learn
7
+ git+https://github.com/MaximeJumelle/ALEPython.git@dev#egg=alepython