Spaces:
Running
on
Zero
Running
on
Zero
import os
import subprocess
import sys

# Pin fastapi at runtime, before gradio is imported below.
# NOTE(review): the reason for this exact pin is not visible here (presumably a
# gradio/fastapi compatibility issue) — confirm before changing it.
# `sys.executable -m pip` installs into the interpreter running this script
# rather than whichever `pip` happens to be first on PATH. The result is not
# checked: a failed install does not stop start-up.
subprocess.run([sys.executable, "-m", "pip", "install", "fastapi==0.108.0"])

import gradio as gr  # noqa: E402  (must follow the install above)
from UniVAD.tools import process_image  # noqa: E402


def _download_checkpoint(url):
    """Download ``url`` into the working directory unless it is already there.

    The file is saved under the URL's last path segment. A plain
    ``wget URL`` would save a second copy as ``<name>.1`` whenever the file
    already exists, re-downloading the checkpoint on every start. Downloading
    to a ``.part`` name and renaming only on success means an interrupted
    download never leaves a truncated file that a later run would then skip.

    Raises:
        subprocess.CalledProcessError: if wget exits with a non-zero status.
    """
    filename = url.rsplit("/", 1)[-1]
    if os.path.exists(filename):
        return
    partial = filename + ".part"
    subprocess.run(["wget", "-q", "-O", partial, url], check=True)
    os.replace(partial, filename)


# Checkpoints that the model setup later loads from "./<filename>".
_download_checkpoint(
    "https://huggingface.co/xinyu1205/recognize-anything-plus-model/resolve/main/ram_plus_swin_large_14m.pth"
)
_download_checkpoint("https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_h.pth")
| from transformers import AutoModelForZeroShotObjectDetection, AutoProcessor | |
| import torch | |
| import torchvision.transforms as transforms | |
| from UniVAD.univad import UniVAD | |
| from ram.models import ram_plus | |
| from UniVAD.models.segment_anything import ( | |
| sam_hq_model_registry, | |
| SamPredictor, | |
| ) | |
| import spaces | |
# Every model below is moved to this device.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Square side length used for the UniVAD model and for `transform`.
image_size = 448

# Hugging Face id shared by the Grounding DINO processor and model.
GROUNDING_DINO_ID = "IDEA-Research/grounding-dino-tiny"

univad_model = UniVAD(image_size=image_size).to(device)

# Resize to image_size x image_size, then convert to a tensor.
transform = transforms.Compose(
    [
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
    ]
)

# Recognize Anything Plus (RAM++) model, built from the downloaded checkpoint.
# It is built with its own image_size=384, independent of `image_size` above.
ram_model = ram_plus(
    pretrained="./ram_plus_swin_large_14m.pth",
    image_size=384,
    vit="swin_l",
)
ram_model.eval()
ram_model = ram_model.to(device)

# Grounding DINO (tiny) zero-shot object detector and its processor.
grounding_processor = AutoProcessor.from_pretrained(GROUNDING_DINO_ID)
# Use `device` like every other model here (previously hard-coded to "cuda",
# which failed on hosts without a GPU).
grounding_model = AutoModelForZeroShotObjectDetection.from_pretrained(GROUNDING_DINO_ID).to(device)

# HQ-SAM ViT-H from the downloaded checkpoint, wrapped in a SamPredictor.
sam = sam_hq_model_registry["vit_h"]("./sam_hq_vit_h.pth").to(device)
sam_predictor = SamPredictor(sam)
def preprocess_image(img, size=None):
    """Return ``img`` resized to a ``size`` x ``size`` square.

    Args:
        img: an object with a PIL-style ``resize((width, height))`` method;
            the image components in this app use ``type="pil"``.
        size: side length in pixels. Defaults to the module-level
            ``image_size``, so the preview stays in sync with the model input
            size instead of duplicating the literal 448.

    Returns:
        Whatever ``img.resize`` returns (a new PIL image for PIL inputs).
    """
    if size is None:
        size = image_size
    return img.resize((size, size))
def update_image(image):
    """Gradio change-handler: square-resize a newly set image.

    Returns None when the component was cleared (``image`` is None);
    otherwise returns ``preprocess_image(image)``.
    """
    if image is None:
        return None
    return preprocess_image(image)
# This Space runs on ZeroGPU ("Running on Zero"): a GPU is attached only while
# a function decorated with `spaces.GPU` executes. `spaces` was imported for
# this purpose but previously never applied.
@spaces.GPU
def ad(image_pil, normal_image, box_threshold, text_threshold, text_prompt, background_prompt, cluster_num):
    """Run UniVAD anomaly detection on a query image against a normal reference.

    Submit-button handler: forwards the UI inputs, together with the
    module-level models loaded at startup, to ``process_image`` and returns its
    result unchanged.

    Args:
        image_pil: query image from the "Query Image" component (PIL).
        normal_image: reference image from the "Normal Image" component (PIL).
        box_threshold: slider value in [0, 1].
        text_threshold: slider value in [0, 1].
        text_prompt: optional period-separated list of expected objects.
        background_prompt: optional period-separated list of things to ignore.
        cluster_num: raw textbox string, possibly empty — presumably parsed
            by ``process_image``; confirm.

    Returns:
        Whatever ``process_image`` returns. The Submit binding maps it onto
        (query mask, normal mask, localization map, score HTML).
    """
    return process_image(
        image_pil,
        normal_image,
        box_threshold,
        text_threshold,
        sam_predictor,
        grounding_model,
        univad_model,
        ram_model,
        text_prompt,
        background_prompt,
        cluster_num,
        image_size,
        grounding_processor,
    )
# Initial and reset content of the detection-result panel. It is shared so the
# component's starting value and the Clear handler cannot drift apart.
DEFAULT_RESULT_HTML = "<span style='font-size: 30px;'>Detection Result:</span>"

with gr.Blocks() as demo:
    gr.HTML("""<h1 align="center" style='margin-top: 30px;'>Demo of UniVAD</h1>""")
    with gr.Row():
        # Column 1: normal reference image, its component mask, Clear button.
        with gr.Column():
            with gr.Row():
                gr.Markdown("### Upload normal image here for reference.")
            with gr.Row():
                normal_img = gr.Image(type="pil", label="Normal Image", value=None, height=475, width=440)
                # Square-resize the image via update_image, writing it back to the same component.
                # NOTE(review): the handler outputs to the component that
                # triggered it; confirm .change does not re-fire (.upload may be intended).
                normal_img.change(fn=update_image, inputs=normal_img, outputs=normal_img)
            with gr.Row():
                normal_mask = gr.Image(type="filepath", label="Normal Component Mask", value=None, height=450, interactive=False)
            with gr.Row():
                clearBtn = gr.Button("Clear", variant="secondary")

        # Column 2: query image, its component mask, Submit button.
        with gr.Column():
            with gr.Row():
                gr.Markdown("### Upload query image here for anomaly detection.")
            with gr.Row():
                query_img = gr.Image(type="pil", label="Query Image", value=None, height=475, width=440)
                query_img.change(fn=update_image, inputs=query_img, outputs=query_img)
            with gr.Row():
                query_mask = gr.Image(type="filepath", label="Query Component Mask", value=None, height=450)
            with gr.Row():
                submitBtn = gr.Button("Submit", variant="primary")

        # Column 3: detection settings and results.
        with gr.Column():
            with gr.Row():
                gr.Markdown("### Settings:")
            with gr.Row():
                box_threshold = gr.Slider(minimum=0.0, maximum=1.0, value=0.25, label="Box Threshold")
            with gr.Row():
                text_threshold = gr.Slider(minimum=0.0, maximum=1.0, value=0.25, label="Text Threshold")
            with gr.Row():
                text_prompt = gr.Textbox(label="Specify what should be in the image. Separate them with periods (.)", placeholder="(optional)")
            with gr.Row():
                background_prompt = gr.Textbox(label="Specify what should be IGNORED in the image. Separate them with periods (.)", placeholder="(optional)")
            with gr.Row():
                cluster_num = gr.Textbox(label="Number of Clusters", placeholder="(optional)")
            with gr.Row():
                anomaly_map_raw = gr.Image(type="filepath", label="Localization Result", value=None, height=450)
            with gr.Row():
                anomaly_score = gr.HTML(value=DEFAULT_RESULT_HTML)

    # Event listeners are attached inside the Blocks context.
    submitBtn.click(
        ad,
        [
            query_img,
            normal_img,
            box_threshold,
            text_threshold,
            text_prompt,
            background_prompt,
            cluster_num,
        ],
        [
            query_mask,
            normal_mask,
            anomaly_map_raw,
            anomaly_score,
        ],
        show_progress=True,
    )
    # Reset both inputs, all three result images, and the result text.
    clearBtn.click(
        lambda: (None, None, None, None, None, DEFAULT_RESULT_HTML),
        outputs=[query_img, normal_img, query_mask, normal_mask, anomaly_map_raw, anomaly_score],
    )

demo.queue().launch()