Spaces:
Runtime error
Runtime error
| import os | |
| from typing import Dict, List | |
| import cv2 | |
| import numpy as np | |
| import streamlit as st | |
| import torch | |
| import wget | |
| from PIL import Image | |
| from streamlit_drawable_canvas import st_canvas | |
| from isegm.inference import clicker as ck | |
| from isegm.inference import utils | |
| from isegm.inference.predictors import BasePredictor, get_predictor | |
###################################
# Global scope objects.
###################################
# Remote location of the pretrained checkpoints; MODELS maps the method name
# shown in the sidebar to the checkpoint file name under URL_PREFIX.
URL_PREFIX = "https://huggingface.co/curt-park/interactive-segmentation/resolve/main"
MODELS = {"RITM": "ritm_coco_lvis_h18_itermask.pth"}

# Stroke colours used for positive / negative click points on the canvas.
POS_COLOR = "#3498DB"
NEG_COLOR = "#C70039"

# Size of the drawing canvas in pixels.
CANVAS_HEIGHT = 600
CANVAS_WIDTH = 600

# Offsets added to a canvas point's "left"/"top" before it is scaled into
# image coordinates (presumably compensating for the drawn point's own
# size — confirm).
ERR_X = 5.5
ERR_Y = 1.0

# Prefer the GPU when one is visible to torch.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Script-wide state: the clicker collects clicks for the current run, while
# `predictor` and `image` are assigned later in the script once a model is
# loaded and an image is uploaded.
clicker = ck.Clicker()
predictor = None
image = None
###################################
# Functions.
###################################
def load_model(model_path: str, device: torch.device) -> BasePredictor:
    """Load an interactive-segmentation checkpoint and wrap it in a predictor.

    Args:
        model_path: Path of the checkpoint file on local disk.
        device: Device passed to both the model loader and the predictor.

    Returns:
        The predictor produced by ``get_predictor`` with ``brs_mode="NoBRS"``.
    """
    net = utils.load_is_model(model_path, device, cpu_dist_maps=True)
    return get_predictor(net, device=device, brs_mode="NoBRS")
def feed_clicks(
    clicker: ck.Clicker,
    clicks: List[Dict[str, float]],
    image_width: int,
    image_height: int,
) -> None:
    """Convert canvas click points to image-space clicks and add them to *clicker*.

    Args:
        clicker: Clicker that receives one ``ck.Click`` per canvas point.
        clicks: Canvas objects (as found in ``json_data["objects"]``); each is
            read for its ``"left"``, ``"top"`` and ``"stroke"`` entries.
        image_width: Width of the source image in pixels.
        image_height: Height of the source image in pixels.
    """
    # Scale factors from canvas pixels to image pixels.
    ratio_h = image_height / CANVAS_HEIGHT
    ratio_w = image_width / CANVAS_WIDTH
    # Largest valid pixel indices. Clamping to the width/height themselves
    # would allow a coordinate one past the last column/row.
    max_x = max(0, image_width - 1)
    max_y = max(0, image_height - 1)
    for obj in clicks:
        x = (obj["left"] + ERR_X) * ratio_w
        y = (obj["top"] + ERR_Y) * ratio_h
        x = min(max_x, max(0, x))
        y = min(max_y, max(0, y))
        # The stroke colour records which click type was selected when the
        # point was placed on the canvas.
        is_positive = obj["stroke"] == POS_COLOR
        # Coordinates are passed as (y, x); presumably ck.Click expects
        # row-first order — confirm against isegm.
        clicker.add_click(ck.Click(is_positive=is_positive, coords=(y, x)))
def predict(
    image: Image.Image, mask: torch.Tensor, threshold: float = 0.5
) -> np.ndarray:
    """Segment *image* from the clicks in the module-level ``clicker``.

    Uses the module-level ``predictor``, so a model must already be loaded.

    Args:
        image: Input image; it is converted with ``np.array`` before being
            handed to the predictor.
        mask: Previous mask forwarded as ``prev_mask``.
        threshold: Values strictly greater than this become 1.0, others 0.

    Returns:
        A binary array resized to the canvas size (CANVAS_HEIGHT rows by
        CANVAS_WIDTH columns).
    """
    predictor.set_input_image(np.array(image))
    with st.spinner("Wait for prediction..."):
        pred = predictor.get_prediction(clicker, prev_mask=mask)
    # cv2.resize expects dsize as (width, height), not (height, width).
    pred = cv2.resize(
        pred,
        dsize=(CANVAS_WIDTH, CANVAS_HEIGHT),
        interpolation=cv2.INTER_CUBIC,
    )
    # Binarize after interpolation (the prediction is presumably a
    # probability map — confirm against the predictor).
    return np.where(pred > threshold, 1.0, 0)
###################################
# Sidebar GUI
###################################
# Controls rendered in the sidebar: method, binarization threshold, click
# type, and the background image to segment.
sidebar = st.sidebar
model = sidebar.selectbox("Select a Method:", tuple(MODELS))
threshold = sidebar.slider("Threshold: ", 0.0, 1.0, 0.5)
marking_type = sidebar.radio("Click Type:", ("Positive", "Negative"))
image_path = sidebar.file_uploader(
    "Background Image:",
    type=["png", "jpg", "jpeg"],
)
# Only replace the placeholder image once something has been uploaded.
if image_path:
    image = Image.open(image_path).convert("RGB")
###################################
# Preparation
###################################
# Model: download the selected checkpoint into the working directory unless a
# file with that name is already present.
checkpoint_file = MODELS[model]
with st.spinner("Wait for downloading a model..."):
    if not os.path.exists(checkpoint_file):
        _ = wget.download(f"{URL_PREFIX}/{checkpoint_file}")
# Predictor.
# NOTE(review): this section runs on every Streamlit rerun (every click), so
# the checkpoint is reloaded each time; consider a Streamlit resource cache if
# the installed version provides one.
with st.spinner("Wait for loading a model..."):
    predictor = load_model(checkpoint_file, device)
###################################
# GUI
###################################
# Drawing canvas: every click drops a point whose stroke colour encodes the
# click type currently selected in the sidebar.
st.title("Canvas:")
canvas_result = st_canvas(
    key="canvas",
    drawing_mode="point",
    point_display_radius=3,
    stroke_width=3,
    stroke_color=NEG_COLOR if marking_type != "Positive" else POS_COLOR,
    fill_color="rgba(255, 165, 0, 0.3)",  # constant translucent orange fill
    background_color="#eee",
    background_image=image,
    update_streamlit=True,
    width=CANVAS_WIDTH,
    height=CANVAS_HEIGHT,
)
###################################
# Prediction
###################################
# Check the user inputs and execute predictions.
st.title("Prediction:")
# Tolerate a missing/empty canvas payload instead of raising KeyError.
canvas_objects = (canvas_result.json_data or {}).get("objects") or []
if canvas_objects and image is not None:
    image_width, image_height = image.size
    feed_clicks(clicker, canvas_objects, image_width, image_height)
    # Run prediction, starting from an all-zero previous mask at the image's
    # resolution.
    mask = torch.zeros((1, 1, image_height, image_width), device=device)
    pred = predict(image, mask, threshold)
    # Show the prediction result.
    st.image(pred, caption="")