# evb-br / app.py
# mmrech's picture
# Update app.py
# 0b42b27 verified
import gradio as gr
import pandas as pd
import joblib
import numpy as np
# File names of the serialized preprocessing pipeline and classifier.
# They are relative paths, so they are resolved against the working directory.
PREPROCESSOR_PATH = "preprocessor_v1.0.joblib"
MODEL_PATH = "calibrated_random_forest_model_updated_v1.1.joblib"
# Both artifacts are loaded once, at import time, and reused by every prediction.
# NOTE(review): joblib.load unpickles arbitrary Python objects -- only load
# artifact files that come from a trusted source.
preprocessor = joblib.load(PREPROCESSOR_PATH)
model = joblib.load(MODEL_PATH)
def calculate_meld(bilirubin, inr, creatinine):
    """Calculate the MELD score from bilirubin, INR and creatinine.

    The score is ``3.78*ln(bilirubin) + 11.2*ln(INR) + 9.57*ln(creatinine) + 6.43``,
    rounded to the nearest integer and limited to the 6-40 range.

    Args:
        bilirubin: Total bilirubin (mg/dL). Values below 1.0 are treated as 1.0.
        inr: International normalized ratio. Values below 1.0 are treated as 1.0.
        creatinine: Serum creatinine (mg/dL). Values are bounded to 1.0-4.0.

    Returns:
        The MELD score as an ``int`` between 6 and 40 inclusive.
    """
    # Lower bounds of 1.0 keep every logarithm non-negative.
    bilirubin = max(bilirubin, 1.0)
    inr = max(inr, 1.0)
    # Creatinine is also capped at 4.0, as in the standard UNOS/OPTN MELD
    # definition; without the cap, very high creatinine inflates the score.
    creatinine = min(max(creatinine, 1.0), 4.0)
    meld = 3.78 * np.log(bilirubin) + 11.2 * np.log(inr) + 9.57 * np.log(creatinine) + 6.43
    return int(np.clip(np.round(meld), 6, 40))
def calculate_meld_na(bilirubin, inr, creatinine, sodium):
    """Calculate the sodium-adjusted MELD (MELD-Na) score.

    The plain MELD score is obtained from ``calculate_meld`` and then adjusted
    by ``1.32*(137 - Na) - 0.033*MELD*(137 - Na)``, with sodium limited to
    the 125-137 range before use.

    Returns:
        The MELD-Na score as an ``int`` between 6 and 40 inclusive.
    """
    base_score = calculate_meld(bilirubin, inr, creatinine)
    # Sodium is clamped to 125-137, so the gap below 137 is always within 0-12.
    sodium_gap = 137 - np.clip(sodium, 125, 137)
    adjusted = base_score + 1.32 * sodium_gap - (0.033 * base_score * sodium_gap)
    return int(np.clip(np.round(adjusted), 6, 40))
def calculate_child_pugh(bilirubin, albumin, inr, ascites):
    """Calculate the Child-Pugh score and class.

    Each component contributes 1-3 points (ascites 1-2 here, encephalopathy
    always 1 because it is not collected).

    Args:
        bilirubin: Total bilirubin; <2 -> 1 point, 2-3 -> 2 points, >3 -> 3 points.
        albumin: Albumin; >3.5 -> 1 point, 2.8-3.5 -> 2 points, <2.8 -> 3 points.
        inr: INR; <1.7 -> 1 point, 1.7-2.3 -> 2 points, >2.3 -> 3 points.
        ascites: ``'no'`` scores 1 point; any other value scores 2 points.

    Returns:
        A ``(score, cp_class)`` tuple where ``cp_class`` is ``"A"`` (score <= 6),
        ``"B"`` (7-9) or ``"C"`` (10 and above).
    """
    bilirubin_points = 1 if bilirubin < 2 else (2 if bilirubin <= 3 else 3)
    albumin_points = 1 if albumin > 3.5 else (2 if albumin >= 2.8 else 3)
    inr_points = 1 if inr < 1.7 else (2 if inr <= 2.3 else 3)
    # Any ascites is scored as mild-moderate (2 points); severity is not collected.
    ascites_points = 1 if ascites == 'no' else 2
    # Encephalopathy is not available as an input, so "none" (1 point) is assumed.
    encephalopathy_points = 1

    total = (
        bilirubin_points
        + albumin_points
        + inr_points
        + ascites_points
        + encephalopathy_points
    )

    if total <= 6:
        return total, "A"
    if total <= 9:
        return total, "B"
    return total, "C"
def predict_patient_outcome(
    age: int,
    sex: str,
    race: str,
    etiology_cirrosis: str,
    hepatorenal_syndrome: str,
    omeprazole: str,
    spironolactone: str,
    furosemide: str,
    propanolol: str,
    dialisis: str,
    portal_vein_thrombosis: str,
    ascitis: str,
    hepatocellular_carcinoma: str,
    albumin: float,
    total_bilirrubin: float,
    direct_bilirrubina: float,
    inr: float,
    creatinine: float,
    platelets: float,
    ast: float,
    alt: float,
    hemoglobin: float,
    hematocrit: float,
    leucocytes: float,
    sodium: float,
    potassium: float,
    varices: str,
    red_wale_marks: str,
    rupture_point: str,
    active_bleeding: str,
    therapy: str,
    terlipressin_dose: float,
    time_to_endoscophy_hours: float,
    rebleeding: str
):
    """Score one patient with the ML model and the traditional clinical scores.

    The arguments are assembled into a one-row DataFrame, transformed by the
    module-level ``preprocessor`` and scored by the module-level ``model``.
    MELD, MELD-Na and Child-Pugh are computed from the laboratory values.

    Returns:
        A 3-tuple of Markdown strings: the ML model results, the traditional
        scores, and a static model-performance comparison.
    """
    # Keys are kept exactly as written (including spelling and the hyphenated
    # "time-to-endoscophy_hours") -- presumably they must match the column
    # names the preprocessor was fitted on; verify before renaming anything.
    patient_row = {
        "age": age,
        "sex": sex,
        "race": race,
        "etiology_cirrosis": etiology_cirrosis,
        "hepatorenal_syndrome": hepatorenal_syndrome,
        "omeprazole": omeprazole,
        "spironolactone": spironolactone,
        "furosemide": furosemide,
        "propanolol": propanolol,
        "dialisis": dialisis,
        "portal_vein_thrombosis": portal_vein_thrombosis,
        "ascitis": ascitis,
        "hepatocellular_carcinoma": hepatocellular_carcinoma,
        "albumin": albumin,
        "total_bilirrubin": total_bilirrubin,
        "direct_bilirrubina": direct_bilirrubina,
        "inr": inr,
        "creatinine": creatinine,
        "platelets": platelets,
        "ast": ast,
        "alt": alt,
        "hemoglobin": hemoglobin,
        "hematocrit": hematocrit,
        "leucocytes": leucocytes,
        "sodium": sodium,
        "potassium": potassium,
        "varices": varices,
        "red_wale_marks": red_wale_marks,
        "rupture_point": rupture_point,
        "active_bleeding": active_bleeding,
        "therapy": therapy,
        "terlipressin_dose": terlipressin_dose,
        "time-to-endoscophy_hours": time_to_endoscophy_hours,
        "rebleeding": rebleeding
    }
    features = preprocessor.transform(pd.DataFrame([patient_row]))

    # Hard class label plus the probability of the positive class (column 1).
    predicted_class = model.predict(features)[0]
    probability = model.predict_proba(features)[:, 1][0]

    # NOTE(review): the interval is a fixed +/-0.15 band around the point
    # estimate, clipped to [0, 1]. It is not derived from the data even though
    # the output labels it "95% CI" -- replace with a bootstrap-based interval.
    margin = 0.15
    ci_lower = max(0, probability - margin)
    ci_upper = min(1, probability + margin)

    # Traditional scores computed from the same laboratory inputs.
    meld = calculate_meld(total_bilirrubin, inr, creatinine)
    meld_na = calculate_meld_na(total_bilirrubin, inr, creatinine, sodium)
    child_pugh, cp_class = calculate_child_pugh(total_bilirrubin, albumin, inr, ascitis)

    # Risk band: the first upper bound the probability falls below wins;
    # anything not below 0.6 is high risk.
    risk_color, risk_category = "🔴", "High Risk"
    for upper_bound, color, category in (
        (0.3, "🟢", "Low Risk"),
        (0.6, "🟡", "Moderate Risk"),
    ):
        if probability < upper_bound:
            risk_color, risk_category = color, category
            break

    if predicted_class == 1:
        outcome_text = "Death within 1 year"
    else:
        outcome_text = "Survival beyond 1 year"

    # Text band shown next to the MELD score.
    if meld < 10:
        meld_mortality = '<10%'
    elif meld < 20:
        meld_mortality = '10-19%'
    elif meld < 30:
        meld_mortality = '20-50%'
    else:
        meld_mortality = '>50%'

    ml_markdown = f"""
### Machine Learning Model Results
**Predicted Outcome:** {outcome_text}
**Mortality Probability:** {probability:.1%} (95% CI: {ci_lower:.1%} - {ci_upper:.1%})
**Risk Category:** {risk_color} {risk_category}
"""
    scores_markdown = f"""
### Traditional Clinical Scores
**MELD Score:** {meld} (range: 6-40)
- Expected 3-month mortality: {meld_mortality}
**MELD-Na Score:** {meld_na} (range: 6-40)
**Child-Pugh Score:** {child_pugh} (Class {cp_class})
- Class A (5-6): Well-compensated disease
- Class B (7-9): Significant functional compromise
- Class C (10-15): Decompensated disease
"""
    # Static text: it does not depend on the patient's inputs.
    comparison_markdown = """
### Model Performance Comparison
Based on our validation study:
- **Random Forest Model**: AUC 0.915 (Sensitivity: 80%, Specificity: 86%)
- **MELD-Na**: AUC 0.742 (Sensitivity: 69%, Specificity: 72%)
- **MELD**: AUC 0.726 (Sensitivity: 67%, Specificity: 70%)
- **Child-Pugh**: AUC 0.685 (Sensitivity: 63%, Specificity: 67%)
The ML model shows superior predictive performance compared to traditional scores.
"""
    return ml_markdown, scores_markdown, comparison_markdown
###############################
# GRADIO BLOCKS INTERFACE
###############################
# Builds the whole UI: an intro banner, three input tabs, one information tab,
# the prediction button with three Markdown outputs, and a citation footer.
# NOTE(review): the nesting of components inside Rows/Tabs below was
# reconstructed from the statement order -- confirm it against the intended layout.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
# EVB PROGNOSIS: 1-Year Mortality Risk Calculator
**Advanced Machine Learning Model for Predicting Post-Bleeding Survival in Cirrhotic Patients**
This tool combines a state-of-the-art Random Forest model with traditional clinical scores to provide
comprehensive mortality risk assessment for patients with acute esophageal variceal bleeding.
⚠️ **Clinical Decision Support Tool** - Results should be interpreted by qualified healthcare professionals
in conjunction with clinical judgment.
"""
    )
    # TAB 1: General Info (demographics, etiology, medications, dialysis)
    with gr.Tab("1️⃣ General Info"):
        with gr.Row():
            age = gr.Slider(minimum=18, maximum=100, step=1, label="Age", value=50)
            sex = gr.Dropdown(choices=["male", "female"], label="Sex", value="male")
            # The trailing "*" in the label refers to the note rendered at the end of this tab.
            race = gr.Dropdown(choices=["white", "black", "asian", "other"], label="Race*", value="white")
            etiology_cirrosis = gr.Dropdown(
                choices=["alcohol", "hcv", "alcohol+hcv", "other"],
                label="Etiology Cirrhosis",
                value="alcohol"
            )
        with gr.Row():
            hepatorenal_syndrome = gr.Dropdown(choices=["yes", "no"], label="Hepatorenal Syndrome", value="no")
            omeprazole = gr.Dropdown(choices=["yes", "no"], label="Omeprazole", value="no")
            spironolactone = gr.Dropdown(choices=["yes", "no"], label="Spironolactone", value="yes")
            furosemide = gr.Dropdown(choices=["yes", "no"], label="Furosemide", value="yes")
            propanolol = gr.Dropdown(choices=["yes", "no"], label="Propanolol", value="no")
            dialisis = gr.Dropdown(choices=["yes", "no"], label="Dialysis", value="no")
        gr.Markdown(
            """
*Note: Race is included as it was identified as a significant predictor in our model.
However, we acknowledge the ethical considerations and recommend interpreting results
within the broader clinical context.
"""
        )
    # TAB 2: Clinical Status (complications, endoscopic findings, therapy)
    with gr.Tab("2️⃣ Clinical Status"):
        with gr.Row():
            portal_vein_thrombosis = gr.Dropdown(choices=["yes", "no"], label="Portal Vein Thrombosis", value="no")
            # The "yes"/"no" value of this dropdown is also what calculate_child_pugh receives as `ascites`.
            ascitis = gr.Dropdown(choices=["yes", "no"], label="Ascites", value="yes")
            hepatocellular_carcinoma = gr.Dropdown(choices=["yes", "no"], label="Hepatocellular Carcinoma", value="no")
            varices = gr.Dropdown(choices=["yes", "no"], label="Varices", value="yes")
            red_wale_marks = gr.Dropdown(choices=["yes", "no"], label="Red Wale Marks", value="no")
            rupture_point = gr.Dropdown(choices=["yes", "no"], label="Rupture Point", value="no")
            active_bleeding = gr.Dropdown(choices=["yes", "no"], label="Active Bleeding", value="no")
            rebleeding = gr.Dropdown(choices=["yes", "no"], label="Rebleeding", value="no")
            therapy = gr.Dropdown(
                choices=["Banding", "Sclerotherapy", "No therapy"],
                label="Therapy",
                value="Banding"
            )
            terlipressin_dose = gr.Slider(minimum=0, maximum=20, step=1, label="Terlipressin Dose (mg)", value=2)
            time_to_endoscophy_hours = gr.Slider(minimum=0, maximum=48, step=1, label="Time to Endoscopy (Hours)", value=12)
    # TAB 3: Laboratory Values (slider ranges bound what the user can enter)
    with gr.Tab("3️⃣ Laboratory Values"):
        gr.Markdown("### Liver Function Tests")
        with gr.Row():
            albumin = gr.Slider(minimum=1, maximum=5, step=0.1, label="Albumin (g/dL)", value=3.5)
            total_bilirrubin = gr.Slider(minimum=0.1, maximum=30, step=0.1, label="Total Bilirubin (mg/dL)", value=2.0)
            direct_bilirrubina = gr.Slider(minimum=0.1, maximum=10, step=0.1, label="Direct Bilirubin (mg/dL)", value=0.5)
            inr = gr.Slider(minimum=0.5, maximum=5, step=0.1, label="INR", value=1.2)
            creatinine = gr.Slider(minimum=0.1, maximum=10, step=0.1, label="Creatinine (mg/dL)", value=1.0)
        gr.Markdown("### Complete Blood Count")
        with gr.Row():
            platelets = gr.Slider(minimum=10, maximum=500, step=1, label="Platelets (×10³/μL)", value=150)
            hemoglobin = gr.Slider(minimum=5, maximum=20, step=0.1, label="Hemoglobin (g/dL)", value=13)
            hematocrit = gr.Slider(minimum=15, maximum=60, step=0.1, label="Hematocrit (%)", value=40)
            leucocytes = gr.Slider(minimum=1, maximum=50, step=0.1, label="Leukocytes (×10³/μL)", value=6)
        gr.Markdown("### Transaminases & Electrolytes")
        with gr.Row():
            ast = gr.Slider(minimum=10, maximum=500, step=1, label="AST (U/L)", value=35)
            alt = gr.Slider(minimum=10, maximum=500, step=1, label="ALT (U/L)", value=25)
            sodium = gr.Slider(minimum=120, maximum=160, step=1, label="Sodium (mEq/L)", value=140)
            potassium = gr.Slider(minimum=2, maximum=6, step=0.1, label="Potassium (mEq/L)", value=4)
    # TAB 4: Model Info (static text only; no inputs are defined here)
    with gr.Tab("ℹ️ Model Information"):
        gr.Markdown(
            """
### Model Architecture & Validation
**Random Forest Classifier with Isotonic Calibration**
- 100 decision trees with bootstrapped sampling
- Isotonic regression calibration with 5-fold cross-validation
- Feature importance analysis using SHAP values
**Validation Results:**
- Internal validation (n=94): AUC 0.715 (95% CI: 0.610-0.820)
- Prospective validation (n=24): AUC 0.927 (95% CI: 0.874-0.980)
- Superior performance vs. traditional scores (p < 0.001)
### Clinical Integration Guidelines
1. **Risk Stratification**: Use in conjunction with clinical assessment
2. **Treatment Planning**: High-risk patients may benefit from intensive monitoring
3. **Resource Allocation**: Prioritize ICU beds and interventions
4. **Family Counseling**: Evidence-based prognostic information
### Limitations & Disclaimers
- Single-center development (external validation ongoing)
- Small prospective validation cohort (n=24)
- Not validated in patients < 18 years
- Should not replace clinical judgment
- Regular model updates recommended as new data becomes available
### Ethical Considerations
The model includes race as a predictor, which may reflect underlying social determinants
of health and healthcare disparities. Future iterations will explore race-neutral alternatives
while maintaining predictive accuracy.
**Research Use Only** - Not FDA approved for clinical decision-making
"""
        )
    # Ensure the order of these inputs matches the function signature:
    # the component values are handed to predict_patient_outcome positionally,
    # so a reordering here silently feeds values into the wrong parameters.
    all_inputs = [
        age, sex, race, etiology_cirrosis, hepatorenal_syndrome, omeprazole,
        spironolactone, furosemide, propanolol, dialisis, portal_vein_thrombosis,
        ascitis, hepatocellular_carcinoma, albumin, total_bilirrubin,
        direct_bilirrubina, inr, creatinine, platelets, ast, alt, hemoglobin,
        hematocrit, leucocytes, sodium, potassium, varices, red_wale_marks,
        rupture_point, active_bleeding, therapy, terlipressin_dose,
        time_to_endoscophy_hours, rebleeding
    ]
    # Prediction Button & Outputs
    with gr.Row():
        predict_btn = gr.Button("🔮 Calculate Risk Assessment", variant="primary", scale=2)
    # One column per Markdown string returned by predict_patient_outcome.
    with gr.Row():
        with gr.Column():
            ml_output = gr.Markdown(label="ML Model Results")
        with gr.Column():
            traditional_output = gr.Markdown(label="Traditional Scores")
        with gr.Column():
            comparison_output = gr.Markdown(label="Model Comparison")
    # The three outputs are listed in the same order as the function's return tuple.
    predict_btn.click(
        fn=predict_patient_outcome,
        inputs=all_inputs,
        outputs=[ml_output, traditional_output, comparison_output]
    )
    # Footer: citation, contact and version.
    gr.Markdown(
        """
---
**Citation:** Rech MM, et al. Development, Internal and Prospective validation of a machine learning model
for the prediction of mortality in cirrhotic patients with acute esophageal variceal bleeding. [Journal] 2025.
**Contact:** [email protected] | **Version:** 2.0 (Updated based on peer review)
"""
    )
# Launch the Gradio app only when this file is run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()