abreza commited on
Commit
1c87170
·
1 Parent(s): ea52793

add shap value

Browse files
Files changed (2) hide show
  1. interface.py +9 -5
  2. model_utils.py +95 -3
interface.py CHANGED
@@ -70,16 +70,16 @@ def predict_with_explanation(age, weight, height, gravidity, parity, h_abortion,
70
  if any(field is None or field == "" for field in required_fields):
71
  return "⚠️ لطفاً تمام فیلدها را پر کنید", "برای پیش‌بینی دقیق، تمام اطلاعات مورد نیاز است.", None
72
 
73
- result, detailed_report = predict_outcome(
74
  age, weight, height, gravidity, parity, h_abortion,
75
  living_child, gestational_age, hemoglobin, hematocrit,
76
  platelet, mpv, pdw, neutrophil, lymphocyte
77
  )
78
 
79
- return result, detailed_report
80
 
81
  def clear_all_fields():
82
- return tuple([None] * 17)
83
 
84
  def load_example(example_name):
85
  example_data = EXAMPLE_CASES[example_name]
@@ -103,6 +103,7 @@ def create_interface():
103
  - پیش‌بینی دقیق با استفاده از هوش مصنوعی
104
  - تحلیل SHAP برای توضیح تأثیر هر ویژگی
105
  - گزارش تفصیلی و قابل فهم برای پزشکان
 
106
 
107
  📝 **راهنما:** تمام فیلدها را پر کنید یا از مثال‌های آماده استفاده کنید.
108
  """)
@@ -119,6 +120,9 @@ def create_interface():
119
  with gr.Column(scale=2):
120
  result_text = gr.Textbox(label="نتیجه پیش‌بینی", lines=2)
121
  detailed_report = gr.Markdown(label="گزارش تفصیلی")
 
 
 
122
 
123
  gr.Markdown("---")
124
  gr.Markdown("## 📚 مثال‌های آماده")
@@ -134,12 +138,12 @@ def create_interface():
134
  predict_btn.click(
135
  fn=predict_with_explanation,
136
  inputs=list(patient_inputs) + list(lab_inputs),
137
- outputs=[result_text, detailed_report]
138
  )
139
 
140
  clear_btn.click(
141
  fn=clear_all_fields,
142
- outputs=list(patient_inputs) + list(lab_inputs) + [result_text, detailed_report]
143
  )
144
 
145
  return demo
 
70
  if any(field is None or field == "" for field in required_fields):
71
  return "⚠️ لطفاً تمام فیلدها را پر کنید", "برای پیش‌بینی دقیق، تمام اطلاعات مورد نیاز است.", None
72
 
73
+ result, detailed_report, shap_plot = predict_outcome(
74
  age, weight, height, gravidity, parity, h_abortion,
75
  living_child, gestational_age, hemoglobin, hematocrit,
76
  platelet, mpv, pdw, neutrophil, lymphocyte
77
  )
78
 
79
+ return result, detailed_report, shap_plot
80
 
81
  def clear_all_fields():
82
+ return tuple([None] * 17) + ("", "", None)
83
 
84
  def load_example(example_name):
85
  example_data = EXAMPLE_CASES[example_name]
 
103
  - پیش‌بینی دقیق با استفاده از هوش مصنوعی
104
  - تحلیل SHAP برای توضیح تأثیر هر ویژگی
105
  - گزارش تفصیلی و قابل فهم برای پزشکان
106
+ - نمودار تصویری تأثیر پارامترها
107
 
108
  📝 **راهنما:** تمام فیلدها را پر کنید یا از مثال‌های آماده استفاده کنید.
109
  """)
 
120
  with gr.Column(scale=2):
121
  result_text = gr.Textbox(label="نتیجه پیش‌بینی", lines=2)
122
  detailed_report = gr.Markdown(label="گزارش تفصیلی")
123
+
124
+ with gr.Column(scale=1):
125
+ shap_plot = gr.Image(label="نمودار SHAP - تأثیر ویژگی‌ها", type="filepath")
126
 
127
  gr.Markdown("---")
128
  gr.Markdown("## 📚 مثال‌های آماده")
 
138
  predict_btn.click(
139
  fn=predict_with_explanation,
140
  inputs=list(patient_inputs) + list(lab_inputs),
141
+ outputs=[result_text, detailed_report, shap_plot]
142
  )
143
 
144
  clear_btn.click(
145
  fn=clear_all_fields,
146
+ outputs=list(patient_inputs) + list(lab_inputs) + [result_text, detailed_report, shap_plot]
147
  )
148
 
149
  return demo
model_utils.py CHANGED
@@ -1,10 +1,20 @@
1
  import numpy as np
2
  import joblib
3
  import warnings
4
- from config import MODEL_PATH
 
 
 
 
 
5
 
6
  warnings.filterwarnings('ignore')
7
 
 
 
 
 
 
8
  def calculate_derived_features(age, weight, height, neutrophil, lymphocyte, platelet):
9
  height_m = height / 100
10
  bmi = weight / (height_m ** 2)
@@ -12,11 +22,67 @@ def calculate_derived_features(age, weight, height, neutrophil, lymphocyte, plat
12
  plr = platelet / lymphocyte if lymphocyte > 0 else 0
13
  return bmi, nlr, plr
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  def predict_outcome(age, weight, height, gravidity, parity, h_abortion,
16
  living_child, gestational_age, hemoglobin, hematocrit,
17
  platelet, mpv, pdw, neutrophil, lymphocyte):
18
  model = get_model()
19
 
 
 
 
20
  try:
21
  bmi, nlr, plr = calculate_derived_features(age, weight, height, neutrophil, lymphocyte, platelet)
22
 
@@ -48,13 +114,26 @@ def predict_outcome(age, weight, height, gravidity, parity, h_abortion,
48
  - NLR (نسبت نوتروفیل به لنفوسیت): {nlr:.2f}
49
  - PLR (نسبت پلاکت به لنفوسیت): {plr:.2f}
50
 
 
 
 
 
51
  ⚠️ **توجه:** این پیش‌بینی صرفاً جهت کمک به تشخیص است و نباید جایگزین نظر پزشک شود.
52
  """
53
 
54
- return result, detailed_report
 
 
 
 
 
 
 
 
 
55
 
56
  except Exception as e:
57
- return f"خطا در پردازش: {str(e)}", ""
58
 
59
 
60
  model = None
@@ -64,8 +143,21 @@ def get_model():
64
  if model is None:
65
  try:
66
  model = joblib.load(MODEL_PATH)
 
67
  return model
68
  except Exception as e:
69
  print(f"Error loading model: {e}")
70
  return None
71
  return model
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import numpy as np
2
  import joblib
3
  import warnings
4
+ import matplotlib.pyplot as plt
5
+ import matplotlib
6
+ import shap
7
+ import os
8
+ import tempfile
9
+ from config import MODEL_PATH, FEATURE_NAMES
10
 
11
  warnings.filterwarnings('ignore')
12
 
13
+ matplotlib.use('Agg')
14
+
15
+ plt.rcParams['font.family'] = ['DejaVu Sans']
16
+ plt.rcParams['axes.unicode_minus'] = False
17
+
18
  def calculate_derived_features(age, weight, height, neutrophil, lymphocyte, platelet):
19
  height_m = height / 100
20
  bmi = weight / (height_m ** 2)
 
22
  plr = platelet / lymphocyte if lymphocyte > 0 else 0
23
  return bmi, nlr, plr
24
 
25
+ def create_shap_plot(shap_values, feature_values, feature_names, prediction_proba):
26
+ shap_vals = shap_values[0][:, 1] # Shape: (18,) - SHAP values for class 1
27
+
28
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
29
+ temp_filename = temp_file.name
30
+ temp_file.close()
31
+
32
+ fig, ax = plt.subplots(figsize=(10, 12))
33
+
34
+ sorted_indices = np.argsort(np.abs(shap_vals))
35
+ sorted_shap_vals = shap_vals[sorted_indices]
36
+ sorted_feature_names = [feature_names[i] for i in sorted_indices]
37
+ sorted_feature_values = feature_values[sorted_indices]
38
+
39
+ colors = ['red' if val > 0 else 'blue' for val in sorted_shap_vals]
40
+ bars = ax.barh(range(len(sorted_shap_vals)), sorted_shap_vals, color=colors, alpha=0.7)
41
+
42
+ ax.set_yticks(range(len(sorted_feature_names)))
43
+ ax.set_yticklabels([f"{name} = {val:.2f}" for name, val in zip(sorted_feature_names, sorted_feature_values)])
44
+ ax.set_xlabel('SHAP Value (Impact on Prediction)', fontsize=12)
45
+ ax.set_title(f'Feature Impact Analysis\nComplication Risk: {prediction_proba[1]*100:.1f}%',
46
+ fontsize=14, pad=20)
47
+
48
+ ax.axvline(x=0, color='black', linestyle='-', alpha=0.3)
49
+
50
+ for i, (bar, val) in enumerate(zip(bars, sorted_shap_vals)):
51
+ if val != 0:
52
+ ax.text(val + (0.001 if val > 0 else -0.001), i, f'{val:.3f}',
53
+ va='center', ha='left' if val > 0 else 'right', fontsize=9)
54
+
55
+ ax.text(0.02, 0.98, 'Red: Increases risk\nBlue: Decreases risk',
56
+ transform=ax.transAxes, va='top', ha='left',
57
+ bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
58
+
59
+ plt.tight_layout()
60
+ plt.savefig(temp_filename, dpi=300, bbox_inches='tight',
61
+ facecolor='white', edgecolor='none')
62
+ plt.close()
63
+
64
+ return temp_filename
65
+
66
+ def get_shap_explainer_and_values(model, input_data):
67
+ background_data = np.array([[
68
+ 28, 65, 162, 24.7, 2, 1, 0, 1, 28, 11.5, 34.0,
69
+ 250, 8.5, 12.0, 6.0, 1.8, 3.33, 139
70
+ ]])
71
+
72
+ explainer = shap.KernelExplainer(model.predict_proba, background_data)
73
+ shap_values = explainer.shap_values(input_data, nsamples=100)
74
+
75
+ return shap_values
76
+
77
+
78
  def predict_outcome(age, weight, height, gravidity, parity, h_abortion,
79
  living_child, gestational_age, hemoglobin, hematocrit,
80
  platelet, mpv, pdw, neutrophil, lymphocyte):
81
  model = get_model()
82
 
83
+ if model is None:
84
+ return "خطا: مدل بارگذاری نشده است", "", None
85
+
86
  try:
87
  bmi, nlr, plr = calculate_derived_features(age, weight, height, neutrophil, lymphocyte, platelet)
88
 
 
114
  - NLR (نسبت نوتروفیل به لنفوسیت): {nlr:.2f}
115
  - PLR (نسبت پلاکت به لنفوسیت): {plr:.2f}
116
 
117
+ **احتمالات تفصیلی:**
118
+ - احتمال سالم بودن: {prediction_proba[0]*100:.1f}%
119
+ - احتمال بروز عوارض: {prediction_proba[1]*100:.1f}%
120
+
121
  ⚠️ **توجه:** این پیش‌بینی صرفاً جهت کمک به تشخیص است و نباید جایگزین نظر پزشک شود.
122
  """
123
 
124
+ shap_values = get_shap_explainer_and_values(model, input_data)
125
+
126
+ shap_plot_path = create_shap_plot(
127
+ shap_values,
128
+ input_data[0],
129
+ FEATURE_NAMES,
130
+ prediction_proba
131
+ )
132
+
133
+ return result, detailed_report, shap_plot_path
134
 
135
  except Exception as e:
136
+ return f"خطا در پردازش: {str(e)}", "", None
137
 
138
 
139
  model = None
 
143
  if model is None:
144
  try:
145
  model = joblib.load(MODEL_PATH)
146
+ print("Model loaded successfully!")
147
  return model
148
  except Exception as e:
149
  print(f"Error loading model: {e}")
150
  return None
151
  return model
152
+
153
+ def cleanup_temp_files():
154
+ try:
155
+ temp_dir = tempfile.gettempdir()
156
+ for filename in os.listdir(temp_dir):
157
+ if filename.endswith('.png') and 'tmp' in filename:
158
+ try:
159
+ os.remove(os.path.join(temp_dir, filename))
160
+ except:
161
+ pass
162
+ except:
163
+ pass