Spaces:
Runtime error
Runtime error
| import torch | |
| import cv2 | |
| import os | |
| import gradio as gr | |
| import numpy as np | |
| import random | |
| from pathlib import Path | |
| import json | |
| from scalelsd.ssl.models.detector import ScaleLSD | |
| from scalelsd.base import show, WireframeGraph | |
| from scalelsd.ssl.misc.train_utils import fix_seeds, load_scalelsd_model | |
| # Title for the Gradio interface | |
| _TITLE = 'Gradio Demo of ScaleLSD for Structured Representation of Images' | |
| MAX_SEED = 1000 | |
| def randomize_seed_fn(seed: int, randomize_seed: bool) -> int: | |
| """random seed""" | |
| if randomize_seed: | |
| seed = random.randint(0, MAX_SEED) | |
| return seed | |
| def stop_run(): | |
| """stop run""" | |
| return ( | |
| gr.update(value="Run", variant="primary", visible=True), | |
| gr.update(visible=False), | |
| ) | |
| def process_image( | |
| input_image, | |
| model_name='scalelsd-vitbase-v2-train-sa1b.pt', | |
| save_name='temp_output', | |
| threshold=10, | |
| junction_threshold_hm=0.008, | |
| num_junctions_inference=512, | |
| width=512, | |
| height=512, | |
| line_width=2, | |
| juncs_size=4, | |
| whitebg=0.0, | |
| draw_junctions_only=False, | |
| use_lsd=False, | |
| use_nms=False, | |
| edge_color='orange', | |
| vertex_color='Cyan', | |
| output_format='png', | |
| seed=0, | |
| randomize_seed=False | |
| ): | |
| """core processing function for image inference""" | |
| # set random seed | |
| seed = int(randomize_seed_fn(seed, randomize_seed)) | |
| fix_seeds(seed) | |
| # initialize model | |
| ckpt = "models/" + model_name | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| model = load_scalelsd_model(ckpt, device) | |
| # set model parameters | |
| model.junction_threshold_hm = junction_threshold_hm | |
| model.num_junctions_inference = num_junctions_inference | |
| # transform input image | |
| if isinstance(input_image, np.ndarray): | |
| image = cv2.cvtColor(input_image, cv2.COLOR_RGB2GRAY) | |
| else: | |
| image = cv2.imread(input_image, 0) | |
| # resize | |
| ori_shape = image.shape[:2] | |
| image_resized = cv2.resize(image.copy(), (width, height)) | |
| image_tensor = torch.from_numpy(image_resized).float() / 255.0 | |
| image_tensor = image_tensor[None, None].to('cuda') | |
| # meta data | |
| meta = { | |
| 'width': ori_shape[1], | |
| 'height': ori_shape[0], | |
| 'filename': '', | |
| 'use_lsd': use_lsd, | |
| 'use_nms': use_nms, | |
| } | |
| # inference | |
| with torch.no_grad(): | |
| outputs, _ = model(image_tensor, meta) | |
| outputs = outputs[0] | |
| # visual results | |
| painter = show.painters.HAWPainter() | |
| painter.confidence_threshold = threshold | |
| painter.line_width = line_width | |
| painter.marker_size = juncs_size | |
| if whitebg > 0.0: | |
| show.Canvas.white_overlay = whitebg | |
| temp_folder = "temp_output" | |
| os.makedirs(temp_folder, exist_ok=True) | |
| fig_file = f"{temp_folder}/{save_name}.png" | |
| with show.image_canvas(input_image, fig_file=fig_file) as ax: | |
| if draw_junctions_only: | |
| painter.draw_junctions(ax, outputs) | |
| else: | |
| painter.draw_wireframe(ax, outputs, edge_color=edge_color, vertex_color=vertex_color) | |
| # read the result image | |
| result_image = cv2.imread(fig_file) | |
| if output_format != 'png': | |
| fig_file = f"{temp_folder}/{save_name}.{output_format}" | |
| with show.image_canvas(input_image, fig_file=fig_file) as ax: | |
| if draw_junctions_only: | |
| painter.draw_junctions(ax, outputs) | |
| else: | |
| painter.draw_wireframe(ax, outputs, edge_color=edge_color, vertex_color=vertex_color) | |
| json_file = f"{temp_folder}/{save_name}.json" | |
| indices = WireframeGraph.xyxy2indices(outputs['juncs_pred'],outputs['lines_pred']) | |
| wireframe = WireframeGraph(outputs['juncs_pred'], outputs['juncs_score'], indices, outputs['lines_score'], outputs['width'], outputs['height']) | |
| with open(json_file, 'w') as f: | |
| json.dump(wireframe.jsonize(),f) | |
| return result_image[:, :, ::-1], json_file, fig_file | |
| def run_demo(): | |
| """create the Gradio demo interface""" | |
| css = """ | |
| #col-container { | |
| margin: 0 auto; | |
| max-width: 800px; | |
| } | |
| """ | |
| with gr.Blocks(css=css, title=_TITLE) as demo: | |
| with gr.Column(elem_id="col-container"): | |
| gr.Markdown(f'# {_TITLE}') | |
| gr.Markdown("Detect wireframe structures in images using ScaleLSD model") | |
| pid = gr.State() | |
| figs_root = "assets/figs" | |
| example_images = [os.path.join(figs_root, iname) for iname in os.listdir(figs_root)] | |
| with gr.Row(): | |
| input_image = gr.Image(example_images[0], label="Input Image", type="numpy") | |
| output_image = gr.Image(label="Detection Result") | |
| with gr.Row(): | |
| run_btn = gr.Button(value="Run", variant="primary") | |
| stop_btn = gr.Button(value="Stop", variant="stop", visible=False) | |
| with gr.Row(): | |
| json_file = gr.File(label="Download JSON Output", type="filepath") | |
| image_file = gr.File(label="Download Image Output", type="filepath") | |
| with gr.Accordion("Advanced Settings", open=True): | |
| with gr.Row(): | |
| model_name = gr.Dropdown( | |
| [ckpt for ckpt in os.listdir('models') if ckpt.endswith('.pt')], | |
| value='scalelsd-vitbase-v1-train-sa1b.pt', | |
| label="Model Selection" | |
| ) | |
| with gr.Row(): | |
| save_name = gr.Textbox('temp_output', label="Save Name", placeholder="Name for saving output files") | |
| with gr.Row(): | |
| with gr.Column(): | |
| threshold = gr.Number(10, label="Line Threshold") | |
| junction_threshold_hm = gr.Number(0.008, label="Junction Threshold") | |
| num_junctions_inference = gr.Number(1024, label="Max Number of Junctions") | |
| width = gr.Number(512, label="Input Width") | |
| height = gr.Number(512, label="Input Height") | |
| with gr.Column(): | |
| draw_junctions_only = gr.Checkbox(False, label="Show Junctions Only") | |
| use_lsd = gr.Checkbox(False, label="Use LSD-Rectifier") | |
| use_nms = gr.Checkbox(True, label="Use NMS") | |
| output_format = gr.Dropdown( | |
| ['png', 'jpg', 'pdf'], | |
| value='png', | |
| label="Output Format" | |
| ) | |
| whitebg = gr.Slider(0.0, 1.0, value=0.7, label="White Background Opacity") | |
| line_width = gr.Number(2, label="Line Width") | |
| juncs_size = gr.Number(8, label="Junctions Size") | |
| with gr.Row(): | |
| edge_color = gr.Dropdown( | |
| ['orange', 'midnightblue', 'red', 'green'], | |
| value='orange', | |
| label="Edge Color" | |
| ) | |
| vertex_color = gr.Dropdown( | |
| ['Cyan', 'deeppink', 'yellow', 'purple'], | |
| value='Cyan', | |
| label="Vertex Color" | |
| ) | |
| with gr.Row(): | |
| randomize_seed = gr.Checkbox(False, label="Randomize Seed") | |
| seed = gr.Slider(0, MAX_SEED, value=42, step=1, label="Seed") | |
| gr.Examples( | |
| examples=example_images, | |
| inputs=input_image, | |
| ) | |
| # star event handlers | |
| run_event = run_btn.click( | |
| fn=process_image, | |
| inputs=[ | |
| input_image, | |
| model_name, | |
| save_name, | |
| threshold, | |
| junction_threshold_hm, | |
| num_junctions_inference, | |
| width, | |
| height, | |
| line_width, | |
| juncs_size, | |
| whitebg, | |
| draw_junctions_only, | |
| use_lsd, | |
| use_nms, | |
| edge_color, | |
| vertex_color, | |
| output_format, | |
| seed, | |
| randomize_seed | |
| ], | |
| outputs=[output_image, json_file, image_file], | |
| ) | |
| # stop event handlers | |
| stop_btn.click( | |
| fn=stop_run, | |
| outputs=[run_btn, stop_btn], | |
| cancels=[run_event], | |
| queue=False, | |
| ) | |
| return demo | |
| run_demo().launch() | |