Spaces:
Sleeping
Sleeping
add permutation
Browse files- pages/ICE_and_PDP.py +5 -6
- pages/PermutationFeatureImportance.py +75 -0
- requirements.txt +2 -1
pages/ICE_and_PDP.py
CHANGED
@@ -50,13 +50,12 @@ def main():
|
|
50 |
if global_model is None:
|
51 |
train_model()
|
52 |
# Define feature names
|
53 |
-
feature_names = ["Age", "Workclass", "Education-Num", "Marital Status", "Occupation",
|
54 |
-
"Relationship", "Race", "Sex", "Capital Gain", "Capital Loss", "Hours per week", "Country"]
|
55 |
-
print(X_test.columns) # Check the actual feature names
|
56 |
-
|
57 |
-
selected_feature = st.sidebar.selectbox("Select a feature for PDP/ICE analysis:", feature_names)
|
58 |
|
59 |
-
|
|
|
|
|
|
|
|
|
60 |
|
61 |
st.title("ICE (Individual Conditional Expectation) and PDP (Partial Dependence Plot)")
|
62 |
st.write(prompt_params.ICE_INTRODUCTION)
|
|
|
50 |
if global_model is None:
|
51 |
train_model()
|
52 |
# Define feature names
|
|
|
|
|
|
|
|
|
|
|
53 |
|
54 |
+
|
55 |
+
selected_feature = st.sidebar.selectbox("Select a feature for PDP/ICE analysis:", ("Age", "Workclass", "Education-Num", "Marital Status", "Occupation",
|
56 |
+
"Relationship", "Race", "Sex", "Capital Gain", "Capital Loss", "Hours per week", "Country"))
|
57 |
+
st.write(f"selected feature is {selected_feature}")
|
58 |
+
kind = st.sidebar.selectbox("Select plot type:", ("average", "both", "individual"))
|
59 |
|
60 |
st.title("ICE (Individual Conditional Expectation) and PDP (Partial Dependence Plot)")
|
61 |
st.write(prompt_params.ICE_INTRODUCTION)
|
pages/PermutationFeatureImportance.py
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Import Libraries
|
2 |
+
import matplotlib.pyplot as plt
|
3 |
+
import streamlit as st
|
4 |
+
import src.prompt_config as prompt_params
|
5 |
+
# Models
|
6 |
+
import xgboost
|
7 |
+
from sklearn.model_selection import train_test_split
|
8 |
+
from alepython import ale_plot
|
9 |
+
# XAI (Explainability)
|
10 |
+
import shap
|
11 |
+
from sklearn.inspection import permutation_importance
|
12 |
+
# Module-level cache for the fitted model and the train/test split.
# Everything starts as None and is filled in by train_model().
|
13 |
+
global_model = None
|
14 |
+
X_train, X_test, y_train, y_test = None, None, None, None
|
15 |
+
|
16 |
+
|
17 |
+
def train_model():
    """Train the XGBoost classifier once and cache it in module globals.

    Loads the adult dataset via ``shap.datasets.adult()``, splits it 50/50
    into train and test halves (``random_state=42``), fits an
    ``xgboost.XGBClassifier`` on the training half, and stores the model and
    the split in ``global_model``, ``X_train``, ``X_test``, ``y_train`` and
    ``y_test``. Returns immediately when a model is already cached.
    """
    global global_model, X_train, X_test, y_train, y_test

    if global_model is not None:
        return

    # Load data from the SHAP library and split it.
    features, target = shap.datasets.adult()
    X_train, X_test, y_train, y_test = train_test_split(
        features, target, test_size=0.5, random_state=42
    )

    # Fit on a local name and publish only after success: if fit() raises,
    # global_model stays None so the next call retries instead of silently
    # reusing an unfitted classifier.
    model = xgboost.XGBClassifier()
    model.fit(X_train, y_train)
    global_model = model

    print("XGBoost Model training completed!")
|
33 |
+
|
34 |
+
def explain_example():
    """Render permutation feature importances for the cached model.

    Trains the model first if none is cached, then computes permutation
    importances on the test split (30 repeats, ``random_state=0``) and draws
    two Streamlit figures with features ordered by ascending mean importance:
    a horizontal bar chart of the mean importances and a horizontal box plot
    of the per-repeat importances.
    """
    if global_model is None:
        train_model()

    perm_imp = permutation_importance(
        global_model, X_test, y_test, n_repeats=30, random_state=0
    )
    order = perm_imp.importances_mean.argsort()
    feature_labels = X_test.columns[order]

    # Bar chart: mean importance per feature.
    st.write("Permutation Importances (mean over repeats)")
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.barh(feature_labels, perm_imp.importances_mean[order])
    ax.set_title("Permutation Importances")
    fig.tight_layout()
    st.pyplot(fig)
    # Release the figure; pyplot otherwise keeps it alive across reruns.
    plt.close(fig)

    # Box plot: spread of the importance over the repeats, one box per feature.
    st.write("Permutation Importances (distribution over repeats)")
    fig, ax = plt.subplots(figsize=(10, 5))
    # NOTE(review): newer Matplotlib releases rename `labels` to
    # `tick_labels`; confirm against the version this Space installs.
    ax.boxplot(perm_imp.importances[order].T, vert=False, labels=feature_labels)
    ax.set_title("Permutation Importances")
    fig.tight_layout()
    st.pyplot(fig)
    plt.close(fig)
|
59 |
+
|
60 |
+
def main():
    """Build the Permutation Feature Importance page.

    Makes sure the model is trained, shows the page title and introduction,
    and draws the importance plots when the "Explain Sample" button is
    pressed.
    """
    # Ensure the model is trained before the page is rendered.
    if global_model is None:
        train_model()

    st.title("Permutation Feature Importance")
    # NOTE(review): this still shows the ALE introduction text; switch to a
    # permutation-importance introduction if prompt_config defines one.
    st.write(prompt_params.ALE_INTRODUCTION)

    if st.button("Explain Sample"):
        explain_example()
|
72 |
+
|
73 |
+
|
74 |
+
# Entry point: build the page when this module is executed as the main script.
if __name__ == '__main__':
|
75 |
+
main()
|
requirements.txt
CHANGED
@@ -3,4 +3,5 @@ lime
|
|
3 |
xgboost
|
4 |
shap
|
5 |
anchor-exp
|
6 |
-
scikit-learn
|
|
|
|
3 |
xgboost
|
4 |
shap
|
5 |
anchor-exp
|
6 |
+
scikit-learn
|
7 |
+
git+https://github.com/MaximeJumelle/ALEPython.git@dev#egg=alepython
|