peggy30 committed on
Commit
e1da86b
·
1 Parent(s): d4ef106

add anchors

Browse files
Files changed (4) hide show
  1. .idea/workspace.xml +9 -1
  2. pages/Anchors.py +135 -0
  3. pages/LIME.py +2 -3
  4. src/prompt_config.py +14 -1
.idea/workspace.xml CHANGED
@@ -2,7 +2,10 @@
2
  <project version="4">
3
  <component name="ChangeListManager">
4
  <list default="true" id="d4d4c856-4e4e-4d5f-b4ca-4c1c8515b14a" name="Default Changelist" comment="">
 
 
5
  <change beforePath="$PROJECT_DIR$/pages/LIME.py" beforeDir="false" afterPath="$PROJECT_DIR$/pages/LIME.py" afterDir="false" />
 
6
  </list>
7
  <option name="SHOW_DIALOG" value="false" />
8
  <option name="HIGHLIGHT_CONFLICTS" value="true" />
@@ -33,9 +36,14 @@
33
  <property name="ASKED_ADD_EXTERNAL_FILES" value="true" />
34
  <property name="ASKED_SHARE_PROJECT_CONFIGURATION_FILES" value="true" />
35
  <property name="RunOnceActivity.ShowReadmeOnStart" value="true" />
36
- <property name="last_opened_file_path" value="$PROJECT_DIR$" />
37
  <property name="settings.editor.selected.configurable" value="com.jetbrains.python.configuration.PyActiveSdkModuleConfigurable" />
38
  </component>
 
 
 
 
 
39
  <component name="SvnConfiguration">
40
  <configuration />
41
  </component>
 
2
  <project version="4">
3
  <component name="ChangeListManager">
4
  <list default="true" id="d4d4c856-4e4e-4d5f-b4ca-4c1c8515b14a" name="Default Changelist" comment="">
5
+ <change afterPath="$PROJECT_DIR$/pages/Anchors.py" afterDir="false" />
6
+ <change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
7
  <change beforePath="$PROJECT_DIR$/pages/LIME.py" beforeDir="false" afterPath="$PROJECT_DIR$/pages/LIME.py" afterDir="false" />
8
+ <change beforePath="$PROJECT_DIR$/src/prompt_config.py" beforeDir="false" afterPath="$PROJECT_DIR$/src/prompt_config.py" afterDir="false" />
9
  </list>
10
  <option name="SHOW_DIALOG" value="false" />
11
  <option name="HIGHLIGHT_CONFLICTS" value="true" />
 
36
  <property name="ASKED_ADD_EXTERNAL_FILES" value="true" />
37
  <property name="ASKED_SHARE_PROJECT_CONFIGURATION_FILES" value="true" />
38
  <property name="RunOnceActivity.ShowReadmeOnStart" value="true" />
39
+ <property name="last_opened_file_path" value="$PROJECT_DIR$/pages" />
40
  <property name="settings.editor.selected.configurable" value="com.jetbrains.python.configuration.PyActiveSdkModuleConfigurable" />
41
  </component>
42
+ <component name="RecentsManager">
43
+ <key name="CopyFile.RECENT_KEYS">
44
+ <recent name="$PROJECT_DIR$/pages" />
45
+ </key>
46
+ </component>
47
  <component name="SvnConfiguration">
48
  <configuration />
49
  </component>
pages/Anchors.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Import Libraries
2
+
3
+ import pandas as pd
4
+ import streamlit as st
5
+ import src.prompt_config as prompt_params
6
+
7
+ # Models
8
+ import xgboost
9
+ from sklearn.model_selection import train_test_split
10
+
11
+ # XAI (Explainability)
12
+ import shap
13
+ import lime
14
+ from anchor import anchor_tabular
15
+
16
+ # Global Variables to Store Model & Data
17
+ global_model = None
18
+ X_train, X_test, y_train, y_test = None, None, None, None
19
+
20
+
21
@st.cache_resource
def _fit_adult_classifier():
    """Load the Adult dataset, split it and fit an XGBoost classifier.

    Streamlit re-executes the page script on every widget interaction, which
    resets module-level globals. Caching this function with
    ``st.cache_resource`` is what keeps the fit from being repeated on each
    rerun. The cached objects are shared, so callers must treat them as
    read-only.

    Returns:
        A tuple ``(model, X_train, X_test, y_train, y_test)``.
    """
    # Load Data from SHAP library
    X, y = shap.datasets.adult()

    # Fixed seed so the test-set indices shown in the UI are stable.
    X_tr, X_te, y_tr, y_te = train_test_split(
        X, y, test_size=0.5, random_state=42
    )

    model = xgboost.XGBClassifier()
    model.fit(X_tr, y_tr)

    print("XGBoost Model training completed!")
    return model, X_tr, X_te, y_tr, y_te


def train_model():
    """Populate the module-level model and data splits.

    The expensive work is delegated to ``_fit_adult_classifier``, which is
    cached by Streamlit; this function only binds the cached objects to the
    module globals (``global_model``, ``X_train``, ``X_test``, ``y_train``,
    ``y_test``) that the rest of the page reads.
    """
    global global_model, X_train, X_test, y_train, y_test

    if global_model is None:
        (global_model,
         X_train, X_test,
         y_train, y_test) = _fit_adult_classifier()
37
+
38
+
39
def define_features():
    """Return the static feature metadata used by the explainer.

    Returns:
        A 4-tuple ``(feature_names, categorical_features, class_names,
        categorical_names)`` where ``categorical_names`` maps the positional
        index of each categorical column in ``feature_names`` to its list of
        category labels.
    """
    feature_names = [
        "Age", "Workclass", "Education-Num", "Marital Status", "Occupation",
        "Relationship", "Race", "Sex", "Capital Gain", "Capital Loss",
        "Hours per week", "Country",
    ]

    class_names = ['<=50K', '>50K']

    # Category labels keyed by column name; insertion order is significant
    # because it determines the order of both derived structures below.
    labels_by_column = {
        "Workclass": [
            'Private', 'Self-emp-not-inc', 'Self-emp-inc', 'Federal-gov',
            'Local-gov', 'State-gov', 'Without-pay', 'Never-worked',
        ],
        "Marital Status": [
            'Married-civ-spouse', 'Divorced', 'Never-married', 'Separated',
            'Widowed', 'Married-spouse-absent', 'Married-AF-spouse',
        ],
        "Occupation": [
            'Tech-support', 'Craft-repair', 'Other-service', 'Sales',
            'Exec-managerial', 'Prof-specialty', 'Handlers-cleaners',
            'Machine-op-inspct', 'Adm-clerical', 'Farming-fishing',
            'Transport-moving', 'Priv-house-serv', 'Protective-serv',
            'Armed-Forces',
        ],
        "Relationship": [
            'Wife', 'Own-child', 'Husband', 'Not-in-family',
            'Other-relative', 'Unmarried',
        ],
        "Race": [
            'White', 'Asian-Pac-Islander', 'Amer-Indian-Eskimo', 'Other',
            'Black',
        ],
        "Sex": ['Female', 'Male'],
        "Country": [
            'United-States', 'Cambodia', 'England', 'Puerto-Rico', 'Canada',
            'Germany', 'Outlying-US(Guam-USVI-etc)', 'India', 'Japan',
            'Greece', 'South', 'China', 'Cuba', 'Iran', 'Honduras',
            'Philippines', 'Italy', 'Poland', 'Jamaica', 'Vietnam', 'Mexico',
            'Portugal', 'Ireland', 'France', 'Dominican-Republic', 'Laos',
            'Ecuador', 'Taiwan', 'Haiti', 'Columbia', 'Hungary', 'Guatemala',
            'Nicaragua', 'Scotland', 'Thailand', 'Yugoslavia', 'El-Salvador',
            'Trinadad&Tobago', 'Peru', 'Hong', 'Holand-Netherlands',
        ],
    }

    categorical_features = list(labels_by_column)

    # Re-key by the column's position in feature_names.
    categorical_names = {
        feature_names.index(column): labels
        for column, labels in labels_by_column.items()
    }

    return feature_names, categorical_features, class_names, categorical_names
74
+
75
+
76
def explain_example(anchors_threshold, example_idx):
    """Compute an Anchors explanation for one test row and render it.

    Args:
        anchors_threshold: Forwarded as ``threshold`` to
            ``explain_instance``.
        example_idx: Positional index into ``X_test`` of the row to explain.

    Returns:
        None. The outcome is rendered with Streamlit: a table with one row
        per anchor rule, or an informational message when the anchor
        contains no rules.
    """
    # Lazily make sure the model and data splits exist.
    if global_model is None:
        train_model()
    feature_names, _, class_names, categorical_names = define_features()

    # Initialize Anchors explainer
    explainer = anchor_tabular.AnchorTabularExplainer(
        class_names,
        feature_names,
        X_train.values,
        categorical_names,
    )

    # Explain the selected sample
    exp = explainer.explain_instance(
        X_test.values[example_idx],
        global_model.predict,
        threshold=anchors_threshold,
    )

    rules = exp.names()
    precision = f"{exp.precision():.2f}"
    coverage = f"{exp.coverage():.2f}"

    if not rules:
        # Without this branch the page would render an empty table with no
        # hint as to why; report the anchor-level statistics instead.
        st.info(
            "The anchor for this sample contains no feature rules "
            f"(precision {precision}, coverage {coverage}). "
            "Try a higher threshold."
        )
        return

    # Precision and coverage describe the anchor as a whole, so the same
    # value is repeated on every rule row.
    df_explanation = pd.DataFrame({
        "Feature Rule": rules,
        "Precision": [precision] * len(rules),
        "Coverage": [coverage] * len(rules),
    })

    st.table(df_explanation)
102
+
103
+
104
def main():
    """Render the Anchors page: sidebar controls plus an explain button."""
    # Ensure the model is trained before any widget needs the test set size.
    if global_model is None:
        train_model()

    sidebar = st.sidebar

    anchors_threshold = sidebar.slider(
        label="Set the `Threshold` value:",
        min_value=0.00,
        max_value=1.00,
        value=0.8,
        step=0.01,
        help=prompt_params.ANCHORS_THRESHOLD_HELP,
    )

    # Upper bound keeps the chosen index inside the test split.
    last_valid_idx = len(X_test) - 1
    example_idx = sidebar.number_input(
        label="Select the sample index to explain:",
        min_value=0,
        max_value=last_valid_idx,
        value=1,
        step=1,
        help=prompt_params.EXAMPLE_BE_EXPLAINED_IDX,
    )

    # The explanation is only computed on demand.
    explain_clicked = st.button("Explain Sample")
    if explain_clicked:
        explain_example(anchors_threshold, example_idx)


if __name__ == '__main__':
    main()
pages/LIME.py CHANGED
@@ -41,7 +41,6 @@ def train_model():
41
  def define_features():
42
  """ Define feature names and categorical mappings. """
43
 
44
-
45
  feature_names = ["Age", "Workclass",
46
  "Education-Num", "Marital Status", "Occupation",
47
  "Relationship", "Race", "Sex", "Capital Gain",
@@ -100,7 +99,7 @@ def explain_example(kernel_width, example_idx):
100
  explanation_html = exp.as_html()
101
 
102
  # Display explanation in Streamlit
103
- components.html(explanation_html, height=600, scrolling=True)
104
 
105
 
106
  def main():
@@ -120,7 +119,7 @@ def main():
120
  help=prompt_params.LIME_KERNEL_WIDTH_HELP,
121
  )
122
 
123
- example_idx = st.sidebar.slider(
124
  label="Select the sample index to explain:",
125
  min_value=0,
126
  max_value=len(X_test) - 1, # Ensures the index is within range
 
41
  def define_features():
42
  """ Define feature names and categorical mappings. """
43
 
 
44
  feature_names = ["Age", "Workclass",
45
  "Education-Num", "Marital Status", "Occupation",
46
  "Relationship", "Race", "Sex", "Capital Gain",
 
99
  explanation_html = exp.as_html()
100
 
101
  # Display explanation in Streamlit
102
+ components.html(explanation_html, height=700, scrolling=True)
103
 
104
 
105
  def main():
 
119
  help=prompt_params.LIME_KERNEL_WIDTH_HELP,
120
  )
121
 
122
+ example_idx = st.sidebar.number_input(
123
  label="Select the sample index to explain:",
124
  min_value=0,
125
  max_value=len(X_test) - 1, # Ensures the index is within range
src/prompt_config.py CHANGED
@@ -24,4 +24,17 @@ for explanation. It determines how far the generated synthetic data points will
24
 
25
  EXAMPLE_BE_EXPLAINED_IDX="""
26
  Select the index of the example you want to explain.
27
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  EXAMPLE_BE_EXPLAINED_IDX="""
26
  Select the index of the example you want to explain.
27
+ e.g., Example 100 is higher than 50K
28
+ """
29
+
30
+ ANCHORS_THRESHOLD_HELP = """
31
+ The `threshold` parameter controls the precision (confidence) level of the Anchor rule.
32
+
33
+ - It defines the **minimum confidence** required for an Anchor rule to be considered valid.
34
+ - The typical range is **0.8 to 0.95**:
35
+ - Lower values (e.g., 0.7) allow more flexible rules but may include some noise.
36
+ - Higher values (e.g., 0.95) ensure highly reliable rules but make them harder to find.
37
+ - If set to **1.0**, only rules with 100% confidence will be accepted, which may result in no valid rules being found.
38
+
39
+ Choosing an appropriate threshold balances **rule reliability** and **availability**.
40
+ """