Files changed (1) hide show
  1. executor.py +50 -14
executor.py CHANGED
@@ -9,7 +9,7 @@ from risk_atlas_nexus.blocks.inference import WMLInferenceEngine
9
  from risk_atlas_nexus.blocks.inference.params import WMLInferenceEngineParams
10
  from risk_atlas_nexus.library import RiskAtlasNexus
11
 
12
- from functools import lru_cache
13
  from dotenv import load_dotenv
14
 
15
  load_dotenv(override=True)
@@ -65,7 +65,26 @@ def generate_subgraph(risk):
65
  return gr.Markdown(value = diagram_string)
66
 
67
 
68
- @lru_cache
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  def risk_identifier(usecase: str,
70
  model_name_or_path: str = "meta-llama/llama-3-3-70b-instruct",
71
  taxonomy: str = "ibm-risk-atlas"): # -> List[Dict[str, Any]]: #pd.DataFrame:
@@ -79,7 +98,7 @@ def risk_identifier(usecase: str,
79
  "project_id": os.environ["WML_PROJECT_ID"],
80
  },
81
  parameters=WMLInferenceEngineParams(
82
- max_new_tokens=150, decoding_method="greedy", repetition_penalty=1
83
  ), # type: ignore
84
  )
85
 
@@ -121,6 +140,7 @@ def get_controls_and_actions(riskid, taxonomy):
121
  related_risk_ids = [r.id for r in ran.get_related_risks(id=riskid)]
122
  action_ids = []
123
  control_ids =[]
 
124
 
125
  if taxonomy == "ibm-risk-atlas":
126
  # look for actions associated with related risks
@@ -133,16 +153,22 @@ def get_controls_and_actions(riskid, taxonomy):
133
  rac = ran.get_related_risk_controls(id=i)
134
  if rac:
135
  control_ids += rac
 
 
 
 
136
 
137
  else:
138
  action_ids = []
139
  control_ids = []
 
140
  else:
141
  # Use only actions related to primary risks
142
  action_ids = ran.get_related_actions(id=riskid)
143
  control_ids = ran.get_related_risk_controls(id=riskid)
 
144
 
145
- return [ran.get_action_by_id(i).name for i in action_ids] + [ran.get_risk_control(i.id).name for i in control_ids] #type: ignore
146
 
147
 
148
  @lru_cache
@@ -169,26 +195,34 @@ def mitigations(riskid: str, taxonomy: str) -> tuple[gr.Markdown, gr.Dataset, gr
169
 
170
  action_ids = []
171
  control_ids =[]
 
172
 
173
  if taxonomy == "ibm-risk-atlas":
174
  # look for actions associated with related risks
175
  if related_risk_ids:
176
  for i in related_risk_ids:
177
- rai = ran.get_related_actions(id=i)
178
- if rai:
179
- action_ids += rai
180
 
181
- rac = ran.get_related_risk_controls(id=i)
182
- if rac:
183
- control_ids += rac
 
 
 
 
184
 
185
  else:
186
  action_ids = []
187
  control_ids = []
 
188
  else:
189
  # Use only actions related to primary risks
190
  action_ids = ran.get_related_actions(id=riskid)
191
  control_ids = ran.get_related_risk_controls(id=riskid)
 
 
192
 
193
  # Sanitize outputs
194
  if not related_risk_ids:
@@ -200,7 +234,7 @@ def mitigations(riskid: str, taxonomy: str) -> tuple[gr.Markdown, gr.Dataset, gr
200
  samples = related_risk_ids
201
  sample_labels = [i.name for i in ran.get_related_risks(id=riskid)] #type: ignore
202
 
203
- if not action_ids and not control_ids:
204
  alabel = "No mitigations found."
205
  asamples = None
206
  asample_labels = None
@@ -210,9 +244,11 @@ def mitigations(riskid: str, taxonomy: str) -> tuple[gr.Markdown, gr.Dataset, gr
210
  alabel = f"Mitigation actions and controls related to risk {riskid}."
211
  asamples = action_ids
212
  asamples_ctl = control_ids
213
- asample_labels = [ran.get_action_by_id(i).description for i in asamples] + [ran.get_risk_control(i.id).name for i in asamples_ctl]# type: ignore
214
- asample_name = [ran.get_action_by_id(i).name for i in asamples] + [ran.get_risk_control(i.id).name for i in asamples_ctl] #type: ignore
215
- mitdf = pd.DataFrame({"Mitigation": asample_name, "Description": asample_labels})
 
 
216
 
217
  if not related_ai_eval_ids:
218
  blabel = "No related AI evaluations found."
 
9
  from risk_atlas_nexus.blocks.inference.params import WMLInferenceEngineParams
10
  from risk_atlas_nexus.library import RiskAtlasNexus
11
 
12
+ from functools import lru_cache, wraps
13
  from dotenv import load_dotenv
14
 
15
  load_dotenv(override=True)
 
65
  return gr.Markdown(value = diagram_string)
66
 
67
 
68
+ def custom_lru_cache(maxsize=128, exclude_values=(None,[],[[]])):
69
+ """
70
+ Make the LRU cache not cache result when empty result was returned
71
+ """
72
+ def decorator(func):
73
+ cached_func = lru_cache(maxsize=maxsize)(func)
74
+
75
+ @wraps(func)
76
+ def wrapper(*args, **kwargs):
77
+ result = cached_func(*args, **kwargs)
78
+ # check for empty df of risks
79
+ if result[2].constructor_args["samples"] in exclude_values:
80
+ return func(*args, **kwargs)
81
+ return result
82
+
83
+ return wrapper
84
+ return decorator
85
+
86
+
87
+ @custom_lru_cache(exclude_values=(None, []))
88
  def risk_identifier(usecase: str,
89
  model_name_or_path: str = "meta-llama/llama-3-3-70b-instruct",
90
  taxonomy: str = "ibm-risk-atlas"): # -> List[Dict[str, Any]]: #pd.DataFrame:
 
98
  "project_id": os.environ["WML_PROJECT_ID"],
99
  },
100
  parameters=WMLInferenceEngineParams(
101
+ max_new_tokens=1000, decoding_method="greedy", repetition_penalty=1
102
  ), # type: ignore
103
  )
104
 
 
140
  related_risk_ids = [r.id for r in ran.get_related_risks(id=riskid)]
141
  action_ids = []
142
  control_ids =[]
143
+ intrinsic_ids=[]
144
 
145
  if taxonomy == "ibm-risk-atlas":
146
  # look for actions associated with related risks
 
153
  rac = ran.get_related_risk_controls(id=i)
154
  if rac:
155
  control_ids += rac
156
+
157
+ ran_intrinsics = ran.get_related_intrinsics(risk_id=i)
158
+ if ran_intrinsics:
159
+ intrinsic_ids += ran_intrinsics
160
 
161
  else:
162
  action_ids = []
163
  control_ids = []
164
+ intrinsic_ids=[]
165
  else:
166
  # Use only actions related to primary risks
167
  action_ids = ran.get_related_actions(id=riskid)
168
  control_ids = ran.get_related_risk_controls(id=riskid)
169
+ intrinsic_ids = ran.get_related_intrinsics(risk_id=riskid)
170
 
171
+ return [ran.get_action_by_id(i).name for i in action_ids] + [ran.get_risk_control(i.id).name for i in control_ids] + [ran.get_intrinsic(i.id).name for i in intrinsic_ids]#type: ignore
172
 
173
 
174
  @lru_cache
 
195
 
196
  action_ids = []
197
  control_ids =[]
198
+ intrinsic_ids=[]
199
 
200
  if taxonomy == "ibm-risk-atlas":
201
  # look for actions associated with related risks
202
  if related_risk_ids:
203
  for i in related_risk_ids:
204
+ ran_actions = ran.get_related_actions(id=i)
205
+ if ran_actions:
206
+ action_ids += ran_actions
207
 
208
+ ran_controls = ran.get_related_risk_controls(id=i)
209
+ if ran_controls:
210
+ control_ids += ran_controls
211
+
212
+ ran_intrinsics = ran.get_related_intrinsics(risk_id=i)
213
+ if ran_intrinsics:
214
+ intrinsic_ids += ran_intrinsics
215
 
216
  else:
217
  action_ids = []
218
  control_ids = []
219
+ intrinsic_ids=[]
220
  else:
221
  # Use only actions related to primary risks
222
  action_ids = ran.get_related_actions(id=riskid)
223
  control_ids = ran.get_related_risk_controls(id=riskid)
224
+ intrinsic_ids = ran.get_related_intrinsics(risk_id=riskid)
225
+
226
 
227
  # Sanitize outputs
228
  if not related_risk_ids:
 
234
  samples = related_risk_ids
235
  sample_labels = [i.name for i in ran.get_related_risks(id=riskid)] #type: ignore
236
 
237
+ if not action_ids and not control_ids and not intrinsic_ids:
238
  alabel = "No mitigations found."
239
  asamples = None
240
  asample_labels = None
 
244
  alabel = f"Mitigation actions and controls related to risk {riskid}."
245
  asamples = action_ids
246
  asamples_ctl = control_ids
247
+ asamples_int = intrinsic_ids
248
+ asample_labels = [ran.get_action_by_id(i).description for i in asamples] + [ran.get_risk_control(i.id).name for i in asamples_ctl] + [ran.get_intrinsic(i.id).description for i in asamples_int]# type: ignore
249
+ asample_name = [ran.get_action_by_id(i).name for i in asamples] + [ran.get_risk_control(i.id).name for i in asamples_ctl] + [ran.get_intrinsic(i.id).name for i in asamples_int] #type: ignore
250
+ asample_types = ["Action" for i in asamples] + ["Control" for i in asamples_ctl] + ["Intrinsic" for i in asamples_int]
251
+ mitdf = pd.DataFrame({"Type": asample_types, "Mitigation": asample_name, "Description": asample_labels})
252
 
253
  if not related_ai_eval_ids:
254
  blabel = "No related AI evaluations found."