Spaces:
				
			
			
	
			
			
					
		Running
		
	
	
	
			
			
	
	
	
	
		
		
					
		Running
		
	add intrinsics and custom cache
Browse files- executor.py +50 -14
    	
        executor.py
    CHANGED
    
    | @@ -9,7 +9,7 @@ from risk_atlas_nexus.blocks.inference import WMLInferenceEngine | |
| 9 | 
             
            from risk_atlas_nexus.blocks.inference.params import WMLInferenceEngineParams
         | 
| 10 | 
             
            from risk_atlas_nexus.library import RiskAtlasNexus
         | 
| 11 |  | 
| 12 | 
            -
            from functools import lru_cache
         | 
| 13 | 
             
            from dotenv import load_dotenv
         | 
| 14 |  | 
| 15 | 
             
            load_dotenv(override=True)
         | 
| @@ -65,7 +65,26 @@ def generate_subgraph(risk): | |
| 65 | 
             
                return gr.Markdown(value = diagram_string)
         | 
| 66 |  | 
| 67 |  | 
| 68 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 69 | 
             
            def risk_identifier(usecase: str, 
         | 
| 70 | 
             
                                model_name_or_path: str = "meta-llama/llama-3-3-70b-instruct", 
         | 
| 71 | 
             
                                taxonomy: str = "ibm-risk-atlas"): # -> List[Dict[str, Any]]: #pd.DataFrame:
         | 
| @@ -79,7 +98,7 @@ def risk_identifier(usecase: str, | |
| 79 | 
             
                        "project_id": os.environ["WML_PROJECT_ID"],
         | 
| 80 | 
             
                    },
         | 
| 81 | 
             
                    parameters=WMLInferenceEngineParams(
         | 
| 82 | 
            -
                        max_new_tokens= | 
| 83 | 
             
                    ),  # type: ignore
         | 
| 84 | 
             
                )
         | 
| 85 |  | 
| @@ -121,6 +140,7 @@ def get_controls_and_actions(riskid, taxonomy): | |
| 121 | 
             
                related_risk_ids = [r.id for r in ran.get_related_risks(id=riskid)]
         | 
| 122 | 
             
                action_ids = []
         | 
| 123 | 
             
                control_ids =[]
         | 
|  | |
| 124 |  | 
| 125 | 
             
                if taxonomy == "ibm-risk-atlas":
         | 
| 126 | 
             
                    # look for actions associated with related risks    
         | 
| @@ -133,16 +153,22 @@ def get_controls_and_actions(riskid, taxonomy): | |
| 133 | 
             
                            rac = ran.get_related_risk_controls(id=i)
         | 
| 134 | 
             
                            if rac:
         | 
| 135 | 
             
                                control_ids += rac
         | 
|  | |
|  | |
|  | |
|  | |
| 136 |  | 
| 137 | 
             
                    else:
         | 
| 138 | 
             
                        action_ids = []
         | 
| 139 | 
             
                        control_ids = []
         | 
|  | |
| 140 | 
             
                else:
         | 
| 141 | 
             
                    # Use only actions related to primary risks
         | 
| 142 | 
             
                    action_ids = ran.get_related_actions(id=riskid)
         | 
| 143 | 
             
                    control_ids = ran.get_related_risk_controls(id=riskid)
         | 
|  | |
| 144 |  | 
| 145 | 
            -
                return [ran.get_action_by_id(i).name for i in action_ids] + [ran.get_risk_control(i.id).name for i in control_ids] #type: ignore
         | 
| 146 |  | 
| 147 |  | 
| 148 | 
             
            @lru_cache
         | 
| @@ -169,26 +195,34 @@ def mitigations(riskid: str, taxonomy: str) -> tuple[gr.Markdown, gr.Dataset, gr | |
| 169 |  | 
| 170 | 
             
                action_ids = []
         | 
| 171 | 
             
                control_ids =[]
         | 
|  | |
| 172 |  | 
| 173 | 
             
                if taxonomy == "ibm-risk-atlas":
         | 
| 174 | 
             
                    # look for actions associated with related risks    
         | 
| 175 | 
             
                    if related_risk_ids:
         | 
| 176 | 
             
                        for i in related_risk_ids:
         | 
| 177 | 
            -
                             | 
| 178 | 
            -
                            if  | 
| 179 | 
            -
                                action_ids +=  | 
| 180 |  | 
| 181 | 
            -
                             | 
| 182 | 
            -
                            if  | 
| 183 | 
            -
                                control_ids +=  | 
|  | |
|  | |
|  | |
|  | |
| 184 |  | 
| 185 | 
             
                    else:
         | 
| 186 | 
             
                        action_ids = []
         | 
| 187 | 
             
                        control_ids = []
         | 
|  | |
| 188 | 
             
                else:
         | 
| 189 | 
             
                    # Use only actions related to primary risks
         | 
| 190 | 
             
                    action_ids = ran.get_related_actions(id=riskid)
         | 
| 191 | 
             
                    control_ids = ran.get_related_risk_controls(id=riskid)
         | 
|  | |
|  | |
| 192 |  | 
| 193 | 
             
                # Sanitize outputs
         | 
| 194 | 
             
                if not related_risk_ids:
         | 
| @@ -200,7 +234,7 @@ def mitigations(riskid: str, taxonomy: str) -> tuple[gr.Markdown, gr.Dataset, gr | |
| 200 | 
             
                    samples = related_risk_ids
         | 
| 201 | 
             
                    sample_labels = [i.name for i in ran.get_related_risks(id=riskid)] #type: ignore
         | 
| 202 |  | 
| 203 | 
            -
                if not action_ids and not control_ids:
         | 
| 204 | 
             
                    alabel = "No mitigations found."
         | 
| 205 | 
             
                    asamples = None
         | 
| 206 | 
             
                    asample_labels = None
         | 
| @@ -210,9 +244,11 @@ def mitigations(riskid: str, taxonomy: str) -> tuple[gr.Markdown, gr.Dataset, gr | |
| 210 | 
             
                    alabel = f"Mitigation actions and controls related to risk {riskid}."
         | 
| 211 | 
             
                    asamples = action_ids
         | 
| 212 | 
             
                    asamples_ctl = control_ids
         | 
| 213 | 
            -
                     | 
| 214 | 
            -
                     | 
| 215 | 
            -
                     | 
|  | |
|  | |
| 216 |  | 
| 217 | 
             
                if not related_ai_eval_ids:
         | 
| 218 | 
             
                    blabel = "No related AI evaluations found."
         | 
|  | |
| 9 | 
             
            from risk_atlas_nexus.blocks.inference.params import WMLInferenceEngineParams
         | 
| 10 | 
             
            from risk_atlas_nexus.library import RiskAtlasNexus
         | 
| 11 |  | 
| 12 | 
            +
            from functools import lru_cache, wraps
         | 
| 13 | 
             
            from dotenv import load_dotenv
         | 
| 14 |  | 
| 15 | 
             
            load_dotenv(override=True)
         | 
|  | |
| 65 | 
             
                return gr.Markdown(value = diagram_string)
         | 
| 66 |  | 
| 67 |  | 
| 68 | 
            +
            def custom_lru_cache(maxsize=128, exclude_values=(None,[],[[]])):
         | 
| 69 | 
            +
                """
         | 
| 70 | 
            +
                Make the LRU cache not cache result when empty result was returned
         | 
| 71 | 
            +
                """
         | 
| 72 | 
            +
                def decorator(func):
         | 
| 73 | 
            +
                    cached_func = lru_cache(maxsize=maxsize)(func)
         | 
| 74 | 
            +
                    
         | 
| 75 | 
            +
                    @wraps(func)
         | 
| 76 | 
            +
                    def wrapper(*args, **kwargs):
         | 
| 77 | 
            +
                        result = cached_func(*args, **kwargs)
         | 
| 78 | 
            +
                        # check for empty df of risks
         | 
| 79 | 
            +
                        if result[2].constructor_args["samples"] in exclude_values:
         | 
| 80 | 
            +
                            return func(*args, **kwargs)
         | 
| 81 | 
            +
                        return result
         | 
| 82 | 
            +
                    
         | 
| 83 | 
            +
                    return wrapper
         | 
| 84 | 
            +
                return decorator
         | 
| 85 | 
            +
             | 
| 86 | 
            +
             | 
| 87 | 
            +
            @custom_lru_cache(exclude_values=(None, []))
         | 
| 88 | 
             
            def risk_identifier(usecase: str, 
         | 
| 89 | 
             
                                model_name_or_path: str = "meta-llama/llama-3-3-70b-instruct", 
         | 
| 90 | 
             
                                taxonomy: str = "ibm-risk-atlas"): # -> List[Dict[str, Any]]: #pd.DataFrame:
         | 
|  | |
| 98 | 
             
                        "project_id": os.environ["WML_PROJECT_ID"],
         | 
| 99 | 
             
                    },
         | 
| 100 | 
             
                    parameters=WMLInferenceEngineParams(
         | 
| 101 | 
            +
                        max_new_tokens=1000, decoding_method="greedy", repetition_penalty=1
         | 
| 102 | 
             
                    ),  # type: ignore
         | 
| 103 | 
             
                )
         | 
| 104 |  | 
|  | |
| 140 | 
             
                related_risk_ids = [r.id for r in ran.get_related_risks(id=riskid)]
         | 
| 141 | 
             
                action_ids = []
         | 
| 142 | 
             
                control_ids =[]
         | 
| 143 | 
            +
                intrinsic_ids=[]
         | 
| 144 |  | 
| 145 | 
             
                if taxonomy == "ibm-risk-atlas":
         | 
| 146 | 
             
                    # look for actions associated with related risks    
         | 
|  | |
| 153 | 
             
                            rac = ran.get_related_risk_controls(id=i)
         | 
| 154 | 
             
                            if rac:
         | 
| 155 | 
             
                                control_ids += rac
         | 
| 156 | 
            +
                            
         | 
| 157 | 
            +
                            ran_intrinsics = ran.get_related_intrinsics(risk_id=i)
         | 
| 158 | 
            +
                            if ran_intrinsics: 
         | 
| 159 | 
            +
                                intrinsic_ids += ran_intrinsics
         | 
| 160 |  | 
| 161 | 
             
                    else:
         | 
| 162 | 
             
                        action_ids = []
         | 
| 163 | 
             
                        control_ids = []
         | 
| 164 | 
            +
                        intrinsic_ids=[]
         | 
| 165 | 
             
                else:
         | 
| 166 | 
             
                    # Use only actions related to primary risks
         | 
| 167 | 
             
                    action_ids = ran.get_related_actions(id=riskid)
         | 
| 168 | 
             
                    control_ids = ran.get_related_risk_controls(id=riskid)
         | 
| 169 | 
            +
                    intrinsic_ids = ran.get_related_intrinsics(risk_id=riskid)
         | 
| 170 |  | 
| 171 | 
            +
                return [ran.get_action_by_id(i).name for i in action_ids] + [ran.get_risk_control(i.id).name for i in control_ids] + [ran.get_intrinsic(i.id).name for i in intrinsic_ids]#type: ignore
         | 
| 172 |  | 
| 173 |  | 
| 174 | 
             
            @lru_cache
         | 
|  | |
| 195 |  | 
| 196 | 
             
                action_ids = []
         | 
| 197 | 
             
                control_ids =[]
         | 
| 198 | 
            +
                intrinsic_ids=[]
         | 
| 199 |  | 
| 200 | 
             
                if taxonomy == "ibm-risk-atlas":
         | 
| 201 | 
             
                    # look for actions associated with related risks    
         | 
| 202 | 
             
                    if related_risk_ids:
         | 
| 203 | 
             
                        for i in related_risk_ids:
         | 
| 204 | 
            +
                            ran_actions = ran.get_related_actions(id=i)
         | 
| 205 | 
            +
                            if ran_actions:
         | 
| 206 | 
            +
                                action_ids += ran_actions
         | 
| 207 |  | 
| 208 | 
            +
                            ran_controls = ran.get_related_risk_controls(id=i)
         | 
| 209 | 
            +
                            if ran_controls:
         | 
| 210 | 
            +
                                control_ids += ran_controls
         | 
| 211 | 
            +
                            
         | 
| 212 | 
            +
                            ran_intrinsics = ran.get_related_intrinsics(risk_id=i)
         | 
| 213 | 
            +
                            if ran_intrinsics: 
         | 
| 214 | 
            +
                                intrinsic_ids += ran_intrinsics
         | 
| 215 |  | 
| 216 | 
             
                    else:
         | 
| 217 | 
             
                        action_ids = []
         | 
| 218 | 
             
                        control_ids = []
         | 
| 219 | 
            +
                        intrinsic_ids=[]
         | 
| 220 | 
             
                else:
         | 
| 221 | 
             
                    # Use only actions related to primary risks
         | 
| 222 | 
             
                    action_ids = ran.get_related_actions(id=riskid)
         | 
| 223 | 
             
                    control_ids = ran.get_related_risk_controls(id=riskid)
         | 
| 224 | 
            +
                    intrinsic_ids = ran.get_related_intrinsics(risk_id=riskid)
         | 
| 225 | 
            +
                    
         | 
| 226 |  | 
| 227 | 
             
                # Sanitize outputs
         | 
| 228 | 
             
                if not related_risk_ids:
         | 
|  | |
| 234 | 
             
                    samples = related_risk_ids
         | 
| 235 | 
             
                    sample_labels = [i.name for i in ran.get_related_risks(id=riskid)] #type: ignore
         | 
| 236 |  | 
| 237 | 
            +
                if not action_ids and not control_ids and not intrinsic_ids:
         | 
| 238 | 
             
                    alabel = "No mitigations found."
         | 
| 239 | 
             
                    asamples = None
         | 
| 240 | 
             
                    asample_labels = None
         | 
|  | |
| 244 | 
             
                    alabel = f"Mitigation actions and controls related to risk {riskid}."
         | 
| 245 | 
             
                    asamples = action_ids
         | 
| 246 | 
             
                    asamples_ctl = control_ids
         | 
| 247 | 
            +
                    asamples_int = intrinsic_ids
         | 
| 248 | 
            +
                    asample_labels = [ran.get_action_by_id(i).description for i in asamples] + [ran.get_risk_control(i.id).name for i in asamples_ctl] + [ran.get_intrinsic(i.id).description for i in asamples_int]# type: ignore
         | 
| 249 | 
            +
                    asample_name = [ran.get_action_by_id(i).name for i in asamples] + [ran.get_risk_control(i.id).name for i in asamples_ctl] + [ran.get_intrinsic(i.id).name for i in asamples_int] #type: ignore
         | 
| 250 | 
            +
                    asample_types = ["Action" for i in asamples] + ["Control" for i in asamples_ctl] + ["Intrinsic" for i in asamples_int]
         | 
| 251 | 
            +
                    mitdf = pd.DataFrame({"Type": asample_types, "Mitigation": asample_name, "Description": asample_labels})
         | 
| 252 |  | 
| 253 | 
             
                if not related_ai_eval_ids:
         | 
| 254 | 
             
                    blabel = "No related AI evaluations found."
         | 
