from dotenv import load_dotenv

# Load environment variables (API credentials, etc.) before anything else is configured.
load_dotenv(override=True)

import re
import os
import json
import datetime
from pathlib import Path
from typing import List, Dict, Any
from functools import lru_cache

import pandas as pd
import gradio as gr

from risk_atlas_nexus.blocks.inference import WMLInferenceEngine
from risk_atlas_nexus.blocks.inference.params import WMLInferenceEngineParams
from risk_atlas_nexus.library import RiskAtlasNexus
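
# The WML inference engine used below reads its credentials from the environment
# (typically the .env file loaded above): WML_API_KEY, WML_API_URL and
# WML_PROJECT_ID must all be set, otherwise risk_identifier() raises a KeyError.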
# Load the taxonomies
ran = RiskAtlasNexus() # type: ignore
def clear_previous_risks():
    """Reset the risk-related UI components to their initial, empty state."""
    return (gr.Markdown("Potential Risks"),
            [],
            gr.Dataset(samples=[], sample_labels=[], samples_per_page=50, visible=False),
            gr.DownloadButton("Download JSON", visible=False),
            "",
            gr.Dataset(samples=[], sample_labels=[], visible=False),
            gr.DataFrame([], wrap=True, show_copy_button=True, show_search="search", visible=False),
            gr.DataFrame([], wrap=True, show_copy_button=True, show_search="search", visible=False),
            gr.Markdown(" "))
def clear_previous_mitigations():
    """Reset the mitigation-related UI components to their initial, empty state."""
    return ("",
            gr.Dataset(samples=[], sample_labels=[], visible=False),
            gr.DataFrame([], wrap=True, show_copy_button=True, show_search="search", visible=False),
            gr.DataFrame([], wrap=True, show_copy_button=True, show_search="search", visible=False),
            gr.Markdown(" "))
@lru_cache
def risk_identifier(usecase: str,
                    model_name_or_path: str = "ibm/granite-3-3-8b-instruct",
                    taxonomy: str = "ibm-risk-atlas"):
    """Identify potential risks for a use case with an LLM and return the Gradio outputs
    (header Markdown, risk state, risk Dataset, and download button)."""
downloadable = False
inference_engine = WMLInferenceEngine(
        model_name_or_path=model_name_or_path,
credentials={
"api_key": os.environ["WML_API_KEY"],
"api_url": os.environ["WML_API_URL"],
"project_id": os.environ["WML_PROJECT_ID"],
},
parameters=WMLInferenceEngineParams(
max_new_tokens=150, decoding_method="greedy", repetition_penalty=1
), # type: ignore
)
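    # Greedy decoding keeps generation deterministic, so the @lru_cache on this
    # function (keyed on usecase, model and taxonomy) can safely reuse results for
    # repeated identical requests.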
    # identify_risks_from_usecases returns one list of risks per use case;
    # a single use case is passed in, so take the first (and only) entry.
    risks = ran.identify_risks_from_usecases(  # type: ignore
        usecases=[usecase],
        inference_engine=inference_engine,
        taxonomy=taxonomy,
        max_risk=5,
    )[0]
    # Prefer the human-readable name; fall back to the id when the name is missing.
    sample_labels = [r.name if r.name else r.id for r in risks]
    out_sec = gr.Markdown("Potential Risks")
    # Write the identified risks to a JSON file so they can be downloaded from the UI.
    data = {
        'time': str(datetime.datetime.now(datetime.timezone.utc)),
        'intent': usecase,
        'model': model_name_or_path,
        'taxonomy': taxonomy,
        'risks': [json.loads(r.json()) for r in risks],
    }
    file_path = Path("static/download.json")
    file_path.parent.mkdir(parents=True, exist_ok=True)  # make sure static/ exists
    with open(file_path, mode='w') as f:
        f.write(json.dumps(data, indent=4))
    downloadable = True
    return (out_sec,
            gr.State(risks),
            gr.Dataset(samples=[r.id for r in risks],
                       sample_labels=sample_labels,
                       samples_per_page=50,
                       visible=True,
                       label="Estimated by an LLM."),
            gr.DownloadButton("Download JSON", "static/download.json",
                              visible=(downloadable and len(risks) > 0)))
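

# Illustrative sketch only (not called by the app): the shape of a direct call to
# risk_identifier(). The use case text below is made up for demonstration.
def _example_risk_identifier_call():
    out_sec, risks_state, risks_dataset, download_btn = risk_identifier(
        "A chatbot that answers customer questions about insurance policies",
        model_name_or_path="ibm/granite-3-3-8b-instruct",
        taxonomy="ibm-risk-atlas",
    )
    # out_sec is the Markdown header, risks_state wraps the raw risk objects,
    # risks_dataset lists the identified risks, and download_btn points at the JSON file.
    return out_sec, risks_state, risks_dataset, download_btn
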
@lru_cache
def mitigations(riskid: str, taxonomy: str) -> tuple[gr.Markdown, gr.Dataset, gr.DataFrame, gr.DataFrame, gr.Markdown]:
"""
For a specific risk (riskid), returns
(a) a risk description
(b) related risks - as a dataset
(c) mitigations
(d) related ai evaluations
"""
try:
risk_desc = ran.get_risk(id=riskid).description # type: ignore
        risk_sec = f"Description:\n{risk_desc}"
except AttributeError:
risk_sec = ""
related_risk_ids = [r.id for r in ran.get_related_risks(id=riskid)]
related_ai_eval_ids = [ai_eval.id for ai_eval in ran.get_related_evaluations(risk_id=riskid)]
    action_ids = []
    control_ids = []
    if taxonomy == "ibm-risk-atlas":
        # Collect actions and controls associated with each related risk.
        for i in related_risk_ids:
            rai = ran.get_related_actions(id=i)
            if rai:
                action_ids += rai
            rac = ran.get_related_risk_controls(id=i)
            if rac:
                control_ids += rac
else:
        # Use only actions and controls related to the primary risk.
action_ids = ran.get_related_actions(id=riskid)
control_ids = ran.get_related_risk_controls(id=riskid)
# Sanitize outputs
if not related_risk_ids:
label = "No related risks found."
samples = None
sample_labels = None
else:
label = f"Risks from other taxonomies related to {riskid}"
samples = related_risk_ids
sample_labels = [i.name for i in ran.get_related_risks(id=riskid)] #type: ignore
if not action_ids and not control_ids:
alabel = "No mitigations found."
asamples = None
asample_labels = None
mitdf = pd.DataFrame()
    else:
        alabel = f"Mitigation actions and controls related to risk {riskid}."
        asamples = action_ids
        asamples_ctl = control_ids
        asample_labels = ([ran.get_action_by_id(i).description for i in asamples]
                          + [ran.get_risk_control(i.id).name for i in asamples_ctl])  # type: ignore
        asample_name = ([ran.get_action_by_id(i).name for i in asamples]
                        + [ran.get_risk_control(i.id).name for i in asamples_ctl])  # type: ignore
        mitdf = pd.DataFrame({"Mitigation": asample_name, "Description": asample_labels})
if not related_ai_eval_ids:
blabel = "No related AI evaluations found."
bsamples = None
bsample_labels = None
aievalsdf = pd.DataFrame()
else:
blabel = f"AI Evaluations related to {riskid}"
bsamples = related_ai_eval_ids
bsample_labels = [ran.get_evaluation(i).description for i in bsamples] # type: ignore
bsample_name = [ran.get_evaluation(i).name for i in bsamples] #type: ignore
aievalsdf = pd.DataFrame({"AI Evaluation": bsample_name, "Description": bsample_labels})
status = gr.Markdown(" ") if len(mitdf) > 0 else gr.Markdown("No mitigations found.")
return (gr.Markdown(risk_sec),
gr.Dataset(samples=samples, label=label, sample_labels=sample_labels, visible=True),
gr.DataFrame(mitdf, wrap=True, show_copy_button=True, show_search="search", label=alabel, visible=True),
gr.DataFrame(aievalsdf, wrap=True, show_copy_button=True, show_search="search", label=blabel, visible=True),
status)
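

# Illustrative sketch only (not called by the app): how the five outputs of
# mitigations() map onto the UI components. The risk id below is hypothetical.
def _example_mitigations_call():
    desc_md, related_dataset, mitigations_df, evaluations_df, status_md = mitigations(
        "hypothetical-risk-id",  # placeholder, not a real taxonomy id
        "ibm-risk-atlas",
    )
    # desc_md carries the risk description, related_dataset lists risks from other
    # taxonomies, the two DataFrames hold mitigations and AI evaluations, and
    # status_md reports when no mitigations were found.
    return desc_md, related_dataset, mitigations_df, evaluations_df, status_md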