# INP-Former / Zero_Shot_App.py
# luoweibetter's picture
# Update Zero_Shot_App.py
# 334ba02 verified
import argparse
from functools import partial
import gradio as gr
from torch.nn import functional as F
from torch import nn
from dataset import get_data_transforms
from PIL import Image
import os
from utils import get_gaussian_kernel
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
import os
import torch
import cv2
import numpy as np
# # Model-Related Modules
from models import vit_encoder
from models.uad import INP_Former
from models.vision_transformer import Mlp, Aggregation_Block, Prototype_Block
# Configurations
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
parser = argparse.ArgumentParser(description='INP-Former zero-shot anomaly detection demo')
# model info
parser.add_argument('--encoder', type=str, default='dinov2reg_vit_base_14',
                    help="Encoder name; must contain 'small', 'base' or 'large'.")
parser.add_argument('--input_size', type=int, default=448,
                    help='Side length the input image is resized to.')
parser.add_argument('--crop_size', type=int, default=392,
                    help='Side length of the center crop taken after resizing.')
parser.add_argument('--INP_num', type=int, default=6,
                    help='Number of rows in the INP parameter tensor.')
# This runs at import time, so tolerate arguments that belong to whatever
# started the process (a launcher, a notebook kernel, a test runner) instead of
# aborting with SystemExit; report them so that typos remain visible.
args, _unknown_args = parser.parse_known_args()
if _unknown_args:
    print(f'Ignoring unrecognized command-line arguments: {_unknown_args}')
############ Init Model
# Checkpoint files; process_image picks ckt_path1 for "Real-IAD" and ckt_path2 for "VisA".
# NOTE(review): both paths point at the Real-IAD checkpoint, so the "VisA"
# choice loads the same weights -- confirm this is intended.
ckt_path1, ckt_path2 = 'weights/Real-IAD/model.pth', "weights/Real-IAD/model.pth"
# device: use the GPU when torch can see one, otherwise fall back to the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Preprocessing built from the CLI sizes; the second returned value is unused here.
data_transform, _ = get_data_transforms(args.input_size, args.crop_size)
# Adopting a grouping-based reconstruction strategy similar to Dinomaly
target_layers = list(range(2, 10))
fuse_layer_encoder = [[0, 1, 2, 3], [4, 5, 6, 7]]
# Same grouping for the decoder, kept as an independent list object.
fuse_layer_decoder = [list(layer_group) for layer_group in fuse_layer_encoder]
# Encoder info
encoder = vit_encoder.load(args.encoder)
# Pick the embedding width and head count from the size keyword in the encoder name.
if 'small' in args.encoder:
    embed_dim, num_heads = 384, 6
elif 'base' in args.encoder:
    embed_dim, num_heads = 768, 12
elif 'large' in args.encoder:
    embed_dim, num_heads = 1024, 16
    # The large variant overrides the default target layers chosen earlier.
    target_layers = [4, 6, 8, 10, 12, 14, 16, 18]
else:
    # Raising a bare string is itself a TypeError in Python 3; raise a real
    # exception so the intended message reaches the user.
    raise ValueError("Architecture not in small, base, large.")
# Model Preparation
def _build_block(block_cls):
    """Instantiate one transformer block with the settings shared by the extractor and decoder."""
    return block_cls(dim=embed_dim, num_heads=num_heads, mlp_ratio=4.,
                     qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-8))


# bottleneck: a single MLP with a 4x hidden width and no dropout
Bottleneck = nn.ModuleList([Mlp(embed_dim, embed_dim * 4, embed_dim, drop=0.)])
# INP: one parameter tensor of shape (INP_num, embed_dim), randomly initialised
INP = nn.ParameterList([nn.Parameter(torch.randn(args.INP_num, embed_dim))])
# INP Extractor: one aggregation block
INP_Extractor = nn.ModuleList([_build_block(Aggregation_Block)])
# INP_Guided_Decoder: eight prototype blocks
INP_Guided_Decoder = nn.ModuleList([_build_block(Prototype_Block) for _ in range(8)])
# Assemble the full model and move it, plus the smoothing kernel, to the chosen device.
model = INP_Former(
    encoder=encoder,
    bottleneck=Bottleneck,
    aggregation=INP_Extractor,
    decoder=INP_Guided_Decoder,
    target_layers=target_layers,
    remove_class_token=True,
    fuse_layer_encoder=fuse_layer_encoder,
    fuse_layer_decoder=fuse_layer_decoder,
    prototype_token=INP,
).to(device)
gaussian_kernel = get_gaussian_kernel(kernel_size=5, sigma=4).to(device)
def resize_and_center_crop(image, resize_size=448, crop_size=392):
    """Resize an image to a square and return its central square crop.

    Args:
        image: NumPy image array accepted by ``cv2.resize``; may be 2-D
            (single channel) or 3-D (height, width, channels).
        resize_size: Side length, in pixels, of the intermediate square image.
        crop_size: Side length, in pixels, of the returned central crop.

    Returns:
        The central ``crop_size`` x ``crop_size`` region of the resized image.

    Raises:
        ValueError: If a size is not positive or ``crop_size`` is larger than
            ``resize_size``.
    """
    if resize_size <= 0 or crop_size <= 0:
        raise ValueError(
            f'resize_size and crop_size must be positive, got {resize_size} and {crop_size}.')
    if crop_size > resize_size:
        # A negative start offset would silently slice the wrong region.
        raise ValueError(
            f'crop_size ({crop_size}) must not exceed resize_size ({resize_size}).')
    # Resize to a resize_size x resize_size square (aspect ratio is not preserved).
    resized = cv2.resize(image, (resize_size, resize_size), interpolation=cv2.INTER_LINEAR)
    # Compute crop coordinates; the same offsets apply to rows and columns.
    start = (resize_size - crop_size) // 2
    end = start + crop_size
    # Slice only the two spatial axes so 2-D and 3-D arrays are both handled.
    return resized[start:end, start:end]
# Path of the checkpoint whose weights are currently loaded into `model`;
# None until the first request has been served.
_loaded_ckt_path = None


def _ensure_weights(ckt_path):
    """Load `ckt_path` into the global model unless it is already loaded, then set eval mode."""
    global _loaded_ckt_path
    if _loaded_ckt_path != ckt_path:
        # map_location keeps loading functional on machines without a GPU.
        state_dict = torch.load(ckt_path, map_location=torch.device('cpu'))
        model.load_state_dict(state_dict, strict=True)
        # Record the path only after a successful load so a failure is retried.
        _loaded_ckt_path = ckt_path
    model.eval()


def process_image(image, options):
    """Run anomaly detection on one image and return a heat-map overlay.

    Args:
        image: PIL image supplied by the Gradio image input.
        options: Selected dataset name ("Real-IAD" or "VisA"). ``None`` or any
            other value falls back to the first checkpoint.

    Returns:
        A PIL RGB image: the anomaly heat map blended 50/50 with the resized,
        center-cropped input, scaled back to the input's width and height.

    Raises:
        ValueError: If no image was provided.
    """
    if image is None:
        raise ValueError('No image provided. Please upload an image first.')
    # The radio input yields None when nothing is selected; `in None` would raise.
    options = options or ''
    # Load the model based on selected options
    if 'Real-IAD' in options:
        ckt_path = ckt_path1
    elif 'VisA' in options:
        ckt_path = ckt_path2
    else:
        # Default to 'All' if no valid option is provided
        ckt_path = ckt_path1
        print('Invalid option. Defaulting to All.')
    _ensure_weights(ckt_path)
    # Ensure image is in RGB mode
    image = image.convert('RGB')
    # Convert PIL image to NumPy array
    np_image = np.array(image)
    orig_height, orig_width = np_image.shape[:2]
    # Convert RGB to BGR for OpenCV
    np_image = cv2.cvtColor(np_image, cv2.COLOR_RGB2BGR)
    np_image = resize_and_center_crop(np_image, resize_size=args.input_size, crop_size=args.crop_size)
    # Preprocess the image and run the model
    input_image = data_transform(image).to(device)
    with torch.no_grad():
        # The forward output is unused; the scores are read from model.distance.
        _ = model(input_image.unsqueeze(0))
        anomaly_map = model.distance
        # assumes model.distance is (batch, N) with N a perfect square -- TODO confirm
        side = int(anomaly_map.shape[1] ** 0.5)
        anomaly_map = anomaly_map.reshape([anomaly_map.shape[0], side, side]).contiguous()
        anomaly_map = torch.unsqueeze(anomaly_map, dim=1)
        # Upsample the score grid to the side length of the preprocessed input.
        anomaly_map = F.interpolate(anomaly_map, size=input_image.shape[-1], mode='bilinear', align_corners=True)
        anomaly_map = gaussian_kernel(anomaly_map)
    # Process anomaly map
    anomaly_map = anomaly_map.squeeze().cpu().numpy()
    # Clip before the uint8 cast so out-of-range scores saturate instead of wrapping.
    # NOTE(review): assumes scores are meant to lie in [0, 1] -- confirm.
    anomaly_map = (np.clip(anomaly_map, 0.0, 1.0) * 255).astype(np.uint8)
    # Apply color map and blend with original image
    heat_map = cv2.applyColorMap(anomaly_map, cv2.COLORMAP_JET)
    vis_map = cv2.addWeighted(heat_map, 0.5, np_image, 0.5, 0)
    # Convert OpenCV image back to PIL image for Gradio. cv2.resize takes
    # (width, height); using both restores the aspect ratio of non-square inputs.
    vis_rgb = cv2.cvtColor(vis_map, cv2.COLOR_BGR2RGB)
    vis_rgb = cv2.resize(vis_rgb, (orig_width, orig_height))
    return Image.fromarray(vis_rgb)
# Define examples
# Each row holds one value per input component: image path, then dataset choice.
examples = [
    ["assets/img2.png", "Real-IAD"],
    ["assets/img.png", "VisA"],
]
# Gradio interface layout
# Components are created in the same order as they are passed to the interface.
image_input = gr.Image(type="pil", label="Upload Image")
dataset_choice = gr.Radio(["Real-IAD", "VisA"], label="Pre-trained Datasets")
result_view = gr.Image(type="pil", label="Output Image")
demo = gr.Interface(
    fn=process_image,
    inputs=[image_input, dataset_choice],
    outputs=[result_view],
    examples=examples,
    title="INP-Former -- Zero-shot Anomaly Detection",
    description="Upload an image and select pre-trained datasets to do zero-shot anomaly detection",
)
# Launch the demo
demo.launch()
# Alternative launch with an explicit host and port:
# demo.launch(server_name="0.0.0.0", server_port=10002)