import argparse
import os
from functools import partial

import gradio as gr
from PIL import Image

# Must be set before CUDA/cuBLAS is initialized, so do it before importing torch
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'

import torch
from torch import nn
from torch.nn import functional as F
import cv2
import numpy as np

from dataset import get_data_transforms
from utils import get_gaussian_kernel
# Model-related modules
from models import vit_encoder
from models.uad import INP_Former
from models.vision_transformer import Mlp, Aggregation_Block, Prototype_Block
# Configurations
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
parser = argparse.ArgumentParser(description='INP-Former zero-shot anomaly detection demo')
# Model info
parser.add_argument('--encoder', type=str, default='dinov2reg_vit_base_14')
parser.add_argument('--input_size', type=int, default=448)
parser.add_argument('--crop_size', type=int, default=392)
parser.add_argument('--INP_num', type=int, default=6)
args = parser.parse_args()
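# The defaults can be overridden on the command line, e.g. (assuming this script
# is the Space's app.py entry point):
#   python app.py --encoder dinov2reg_vit_large_14 --INP_num 6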
# Init model
ckt_path1 = 'weights/Real-IAD/model.pth'
ckt_path2 = 'weights/Real-IAD/model.pth'  # NOTE: currently identical to ckt_path1, so the VisA option loads the Real-IAD checkpoint
data_transform, _ = get_data_transforms(args.input_size, args.crop_size)
# Device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Adopting a grouping-based reconstruction strategy similar to Dinomaly
target_layers = [2, 3, 4, 5, 6, 7, 8, 9]
fuse_layer_encoder = [[0, 1, 2, 3], [4, 5, 6, 7]]
fuse_layer_decoder = [[0, 1, 2, 3], [4, 5, 6, 7]]
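# The eight tapped encoder layers are fused into two groups of four on both the
# encoder and decoder side, so reconstruction is compared on two fused feature
# maps rather than layer by layer (the grouping idea borrowed from Dinomaly).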
# Encoder info
encoder = vit_encoder.load(args.encoder)
if 'small' in args.encoder:
    embed_dim, num_heads = 384, 6
elif 'base' in args.encoder:
    embed_dim, num_heads = 768, 12
elif 'large' in args.encoder:
    embed_dim, num_heads = 1024, 16
    target_layers = [4, 6, 8, 10, 12, 14, 16, 18]  # ViT-L is deeper, so spread the taps further apart
else:
    raise ValueError("Architecture not in small, base, large.")
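# With the 14-pixel patch size implied by the encoder name and a 392x392 crop,
# the encoder produces a 28x28 grid of patch tokens (392 / 14 = 28); the anomaly
# map computed below is reshaped back onto this grid.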
# Model preparation
Bottleneck = []
INP_Guided_Decoder = []
INP_Extractor = []
# Bottleneck: a single MLP block over the fused encoder features
Bottleneck.append(Mlp(embed_dim, embed_dim * 4, embed_dim, drop=0.))
Bottleneck = nn.ModuleList(Bottleneck)
# INP: learnable intrinsic normal prototype tokens (INP_num tokens of width embed_dim)
INP = nn.ParameterList(
    [nn.Parameter(torch.randn(args.INP_num, embed_dim))
     for _ in range(1)])
# INP Extractor: one aggregation block that pools patch features into the prototypes
for i in range(1):
    blk = Aggregation_Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=4.,
                            qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-8))
    INP_Extractor.append(blk)
INP_Extractor = nn.ModuleList(INP_Extractor)
# INP-Guided Decoder: eight prototype blocks that reconstruct features under INP guidance
for i in range(8):
    blk = Prototype_Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=4.,
                          qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-8))
    INP_Guided_Decoder.append(blk)
INP_Guided_Decoder = nn.ModuleList(INP_Guided_Decoder)
model = INP_Former(encoder=encoder, bottleneck=Bottleneck, aggregation=INP_Extractor, decoder=INP_Guided_Decoder,
                   target_layers=target_layers, remove_class_token=True, fuse_layer_encoder=fuse_layer_encoder,
                   fuse_layer_decoder=fuse_layer_decoder, prototype_token=INP)
model = model.to(device)
model.eval()  # inference only; no training happens in this demo
gaussian_kernel = get_gaussian_kernel(kernel_size=5, sigma=4).to(device)
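# For reference, a minimal sketch of what `get_gaussian_kernel` is assumed to
# return: a fixed (non-trainable) depthwise Gaussian-blur module used to smooth
# the anomaly map. The real implementation lives in utils.py; this sketch is
# illustrative only and is not called anywhere.
def _gaussian_blur_sketch(kernel_size=5, sigma=4.0, channels=1):
    # 1-D Gaussian, normalized to sum to 1
    coords = torch.arange(kernel_size, dtype=torch.float32) - (kernel_size - 1) / 2
    gauss_1d = torch.exp(-coords ** 2 / (2 * sigma ** 2))
    gauss_1d = gauss_1d / gauss_1d.sum()
    # Separable kernel: outer product gives the 2-D Gaussian
    kernel_2d = torch.outer(gauss_1d, gauss_1d)
    # Depthwise conv so each channel is blurred independently
    blur = nn.Conv2d(channels, channels, kernel_size, padding=kernel_size // 2,
                     groups=channels, bias=False)
    blur.weight.data.copy_(kernel_2d.expand(channels, 1, kernel_size, kernel_size))
    blur.weight.requires_grad_(False)
    return blur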
def resize_and_center_crop(image, resize_size=448, crop_size=392):
    # Resize to resize_size x resize_size
    image_resized = cv2.resize(image, (resize_size, resize_size), interpolation=cv2.INTER_LINEAR)
    # Compute crop coordinates
    start = (resize_size - crop_size) // 2
    end = start + crop_size
    # Center crop to crop_size x crop_size
    image_cropped = image_resized[start:end, start:end, :]
    return image_cropped
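# Quick sanity check of the default geometry (resize to 448x448, then center-crop 392x392):
#   resize_and_center_crop(np.zeros((600, 800, 3), np.uint8)).shape == (392, 392, 3)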
def process_image(image, options):
    # Load the checkpoint matching the selected option
    if options and 'Real-IAD' in options:
        model.load_state_dict(torch.load(ckt_path1, map_location=torch.device('cpu')), strict=True)
    elif options and 'VisA' in options:
        model.load_state_dict(torch.load(ckt_path2, map_location=torch.device('cpu')), strict=True)
    else:
        # Fall back to the Real-IAD checkpoint if no valid option is provided
        model.load_state_dict(torch.load(ckt_path1, map_location=torch.device('cpu')), strict=True)
        print('Invalid option. Defaulting to Real-IAD.')
    # Ensure image is in RGB mode
    image = image.convert('RGB')
    # Convert PIL image to NumPy array and remember the original size
    np_image = np.array(image)
    orig_h, orig_w = np_image.shape[:2]
    # Convert RGB to BGR for OpenCV
    np_image = cv2.cvtColor(np_image, cv2.COLOR_RGB2BGR)
    np_image = resize_and_center_crop(np_image, resize_size=args.input_size, crop_size=args.crop_size)
    # Preprocess the image and run the model
    input_image = data_transform(image)
    input_image = input_image.to(device)
    with torch.no_grad():
        _ = model(input_image.unsqueeze(0))
        # model.distance holds per-patch reconstruction distances, shape [B, N];
        # reshape them back onto the sqrt(N) x sqrt(N) patch grid
        anomaly_map = model.distance
        side = int(model.distance.shape[1] ** 0.5)
        anomaly_map = anomaly_map.reshape([anomaly_map.shape[0], side, side]).contiguous()
        anomaly_map = torch.unsqueeze(anomaly_map, dim=1)
        # Upsample to the network input resolution and smooth
        anomaly_map = F.interpolate(anomaly_map, size=input_image.shape[-1], mode='bilinear', align_corners=True)
        anomaly_map = gaussian_kernel(anomaly_map)
    # Process anomaly map: scale to 8-bit (assumes distances lie roughly in [0, 1])
    anomaly_map = anomaly_map.squeeze().cpu().numpy()
    anomaly_map = (anomaly_map * 255).astype(np.uint8)
    # Apply color map and blend with the cropped input
    heat_map = cv2.applyColorMap(anomaly_map, cv2.COLORMAP_JET)
    vis_map = cv2.addWeighted(heat_map, 0.5, np_image, 0.5, 0)
    # Convert back to RGB / PIL at the original image size for Gradio
    vis_map_pil = Image.fromarray(cv2.resize(cv2.cvtColor(vis_map, cv2.COLOR_BGR2RGB), (orig_w, orig_h)))
    return vis_map_pil
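# Programmatic usage, outside the Gradio UI (paths taken from the demo examples below):
#   result = process_image(Image.open('assets/img2.png'), 'Real-IAD')
#   result.save('anomaly_overlay.png')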
# Define examples
examples = [
    ["assets/img2.png", "Real-IAD"],
    ["assets/img.png", "VisA"]
]
# Gradio interface layout
demo = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Radio(["Real-IAD", "VisA"], label="Pre-trained Dataset")
    ],
    outputs=[
        gr.Image(type="pil", label="Output Image")
    ],
    examples=examples,
    title="INP-Former -- Zero-shot Anomaly Detection",
    description="Upload an image and select a pre-trained dataset to perform zero-shot anomaly detection."
)
# Launch the demo
demo.launch()
# demo.launch(server_name="0.0.0.0", server_port=10002)