Spaces:
Running
Running
| from dotenv import load_dotenv | |
| load_dotenv(override=True) | |
| import re | |
| import os | |
| import json | |
| from typing import List, Dict, Any | |
| import pandas as pd | |
| import gradio as gr | |
| from risk_atlas_nexus.blocks.inference import WMLInferenceEngine | |
| from risk_atlas_nexus.blocks.inference.params import WMLInferenceEngineParams | |
| from risk_atlas_nexus.library import RiskAtlasNexus | |
| from functools import lru_cache | |
# Load the taxonomies
# Module-level singleton: instantiated once at import time; both functions
# below query its ontology / helper methods. NOTE(review): construction cost
# and any I/O it performs happen on import — confirm acceptable for the app.
ran = RiskAtlasNexus()
def risk_identifier(usecase: str,
                    model_name_or_path: str = "ibm/granite-20b-code-instruct",
                    taxonomy: str = "ibm-risk-atlas"):
    """Identify AI risks for *usecase* and return Gradio component updates.

    Parameters
    ----------
    usecase : str
        Free-text description of the AI use case to assess.
    model_name_or_path : str
        WatsonX model id; only used by the (currently disabled) live
        inference path below.
    taxonomy : str
        Risk taxonomy id; only used by the disabled live inference path.

    Returns
    -------
    tuple
        (gr.State holding the matched risk dicts,
         gr.Dataset of risk ids labelled with risk names).
    """
    # NOTE(review): live WML inference is disabled and replaced by a canned,
    # deliberately truncated LLM response — json.loads fails on it, so the
    # regex fallback path below is what actually runs. Re-enable to go live.
    # inference_engine = WMLInferenceEngine(
    #     model_name_or_path=model_name_or_path,
    #     credentials={
    #         "api_key": os.environ["WML_API_KEY"],
    #         "api_url": os.environ["WML_API_URL"],
    #         "project_id": os.environ["WML_PROJECT_ID"],
    #     },
    #     parameters=WMLInferenceEngineParams(
    #         max_new_tokens=100, decoding_method="greedy", repetition_penalty=1
    #     ),  # type: ignore
    # )
    # risks = ran.identify_risks_from_usecase(
    #     usecase=usecase,
    #     inference_engine=inference_engine,
    #     taxonomy=taxonomy,
    # )
    risks = ' ["Harmful code generation", "Hallucination", "Harmful output", "Toxic output", "Spreading toxicity", "Spreading disinformation", "Nonconsensual use", "Non-disclosure", "Data contamination", "Data acquisition restrictions", "Data usage rights restrictions", "Confidential data in prompt", "Confidential information in data", "Personal information in prompt", "Personal information in data", "IP information in prompt",'

    # FIX: initialize before branching — previously `out` was undefined (and
    # the return raised NameError) when `risks` was neither str nor list.
    out: List[Dict[str, Any]] = []
    if isinstance(risks, str):
        # Translate LLM output to Risk catalog
        try:
            risks = json.loads(risks)
        except json.JSONDecodeError:
            # Fallback to regex - will skip any partial categories
            risks = re.findall(r'"(.*?)"', risks)
        for risk in risks:
            # Match the extracted risk name against the ontology's entries.
            matches = [r for r in ran._ontology.risks if r.name == risk]  # type: ignore
            out += [m.model_dump() for m in matches]
    elif isinstance(risks, list):  # FIX: builtin `list`, not deprecated typing.List
        # FIXME: assumes that the output is structured - not sure if that's correct.
        out = risks
    return gr.State(out), gr.Dataset(samples=[i['id'] for i in out],
                                     sample_labels=[i['name'] for i in out],
                                     samples_per_page=50, visible=True,
                                     label="Estimated by an LLM.")
def mitigations(riskid: str, taxonomy: str) -> tuple[gr.Dataset, gr.Dataset]:
    """Look up related risks and mitigating actions for a selected risk.

    Parameters
    ----------
    riskid : str
        Id of the risk selected in the UI.
    taxonomy : str
        Active taxonomy; for "ibm-risk-atlas" actions are aggregated over
        all related risks, otherwise only the primary risk's actions are used.

    Returns
    -------
    tuple[gr.Dataset, gr.Dataset]
        (related-risks dataset, mitigation-actions dataset); each falls back
        to an empty, labelled dataset when nothing was found.
    """
    related_risk_ids = ran.get_related_risk_ids_by_risk_id(riskid)
    if taxonomy == "ibm-risk-atlas":
        # Aggregate actions across every related risk; the helper may return
        # a falsy value, which we skip. (FIX: dropped the redundant
        # `else: action_ids = []` — the list already starts empty.)
        action_ids = []
        for rel_id in related_risk_ids or []:
            rel_actions = ran.get_risk_actions_by_risk_id(rel_id)
            if rel_actions:
                action_ids += rel_actions
    else:
        # Use only actions related to primary risks
        action_ids = ran.get_risk_actions_by_risk_id(riskid)
    # Sanitize outputs: gr.Dataset wants explicit None samples + a label
    # when there is nothing to show.
    if related_risk_ids:
        label = "Related risks"
        samples = related_risk_ids
        sample_labels = [r.name for r in ran.get_related_risks_by_risk_id(riskid)]
    else:
        label = "No related risks found."
        samples = None
        sample_labels = None
    if action_ids:
        alabel = ""
        asamples = action_ids
        asample_labels = [ran.get_action_by_id(a).description for a in action_ids]
    else:
        alabel = "No mitigations found."
        asamples = None
        asample_labels = None
    return (gr.Dataset(samples=samples, label=label, sample_labels=sample_labels, visible=True),
            gr.Dataset(samples=asamples, label=alabel, sample_labels=asample_labels, visible=True))