# Anomagic / app.py
# Author: yuxinjiang11 — commit e97455b ("Upload app.py", verified)
import os
import sys
import requests
import io # Memory buffer
# Spaces environment configuration
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import time
import random
import numpy as np
import torch
from PIL import Image, ImageDraw
from diffusers import StableDiffusionInpaintPipelineLegacy, DDIMScheduler, AutoencoderKL, DPMSolverMultistepScheduler
from huggingface_hub import hf_hub_url, login # hf_hub_url for generating cloud URL
import gradio as gr
# Optional dependency: the Anomagic wrapper from the ip_adapter package. When it
# cannot be imported, the app falls back to plain Stable Diffusion inpainting.
try:
    from ip_adapter.ip_adapter_anomagic import Anomagic
except ImportError:
    HAS_ANOMAGIC = False
    print("Anomagic not imported, will use basic Inpainting")
else:
    HAS_ANOMAGIC = True

# Absolute directory of this script, so file lookups do not depend on the CWD.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
def extract_image_from_editor_output(editor_output):
    """Extract an image from the output of a ``gr.ImageEditor`` (or a plain image).

    Args:
        editor_output: ``None``, a PIL image, a ``gr.ImageEditor`` value dict
            (``{"background": ..., "layers": [...], "composite": ...}``), or
            anything ``Image.fromarray`` accepts (e.g. a numpy array).

    Returns:
        The dict's ``"composite"`` entry when present and not ``None`` (the
        background merged with the drawn layers), else its ``"background"``
        entry; the PIL image itself; an image built via ``Image.fromarray``;
        or ``None`` when nothing usable is found.
    """
    if editor_output is None:
        return None
    # gr.ImageEditor value: prefer the composite, fall back to the bare background.
    if isinstance(editor_output, dict):
        for key in ("composite", "background"):
            candidate = editor_output.get(key)
            if candidate is not None:
                return candidate
        return None
    # Already a PIL image: hand it back unchanged.
    if isinstance(editor_output, Image.Image):
        return editor_output
    # Best effort for array-like values. Exception (not a bare except) so that
    # KeyboardInterrupt / SystemExit still propagate.
    try:
        return Image.fromarray(editor_output)
    except Exception:
        return None
class SingleAnomalyGenerator:
    """Synthesise one anomaly image from a normal image, a reference image and masks.

    Wraps a Stable Diffusion legacy inpainting pipeline and a CLIP vision encoder.
    When the optional ``ip_adapter`` package imported successfully
    (``HAS_ANOMAGIC``), it also wraps the Anomagic model built on top of them.
    Call :meth:`load_models` once before :meth:`generate_single_image`.
    """

    # Every input image and mask is resized to this (width, height) before generation.
    TARGET_SIZE = (512, 512)
    # A greyscale mask pixel counts as "generate here" only above this level. The
    # mid-grey cut-off (rather than > 0) stops JPEG noise and resampling halos
    # around a mask's edges from silently growing the inpainting region.
    MASK_THRESHOLD = 127

    def __init__(self, device="cuda:0"):
        """Choose device and dtype; models are loaded later by :meth:`load_models`.

        Args:
            device: torch device string. It is used with float16 only when CUDA
                is available and ``"cuda"`` appears in the string; otherwise
                the CPU is used with float32.
        """
        if torch.cuda.is_available() and "cuda" in device:
            self.device = torch.device(device)
            self.dtype = torch.float16
            print(f"Using GPU: {device}, dtype: {self.dtype}")
        else:
            self.device = torch.device("cpu")
            self.dtype = torch.float32
            print(f"Using CPU, dtype: {self.dtype}")
        self.anomagic_model = None
        self.pipe = None  # inpainting pipeline, reused for every generation
        self.clip_vision_model = None
        self.clip_image_processor = None
        self.ip_ckpt_path = None  # Hub URL of checkpoint/anomagic.bin
        self.att_ckpt_path = None  # Hub URL of checkpoint/attention_module.bin

    def load_models(self):
        """Load every model component onto ``self.device`` with ``self.dtype``.

        Sets ``self.pipe`` (legacy inpainting pipeline with the ft-MSE VAE and a
        DPMSolverMultistepScheduler). Also sets ``self.clip_vision_model`` and
        ``self.clip_image_processor`` (openai/clip-vit-large-patch14), the two
        checkpoint URLs, and ``self.anomagic_model`` when ``HAS_ANOMAGIC`` is true.

        Raises:
            FileNotFoundError: if a checkpoint URL cannot be built.
            Any exception raised by the ``from_pretrained`` calls (network, missing repo, ...).
        """
        print("Loading VAE...")
        vae = AutoencoderKL.from_pretrained(
            "stabilityai/sd-vae-ft-mse",
            torch_dtype=self.dtype
        ).to(self.device)

        print("Loading base model...")
        # The DDIM scheduler only supplies the initial scheduler config. It is replaced
        # just below by a DPMSolverMultistepScheduler built from that config.
        noise_scheduler = DDIMScheduler(
            num_train_timesteps=1000,
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            clip_sample=False,
            set_alpha_to_one=False,
            steps_offset=1,
        )
        self.pipe = StableDiffusionInpaintPipelineLegacy.from_pretrained(
            "SG161222/Realistic_Vision_V4.0_noVAE",
            torch_dtype=self.dtype,
            scheduler=noise_scheduler,
            vae=vae,
            feature_extractor=None,
            safety_checker=None,
            low_cpu_mem_usage=True
        ).to(self.device, dtype=self.dtype)
        self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(self.pipe.scheduler.config)

        print("Loading CLIP image encoder...")
        from transformers import CLIPVisionModel, CLIPImageProcessor
        self.clip_vision_model = CLIPVisionModel.from_pretrained(
            "openai/clip-vit-large-patch14",
            torch_dtype=self.dtype
        ).to(self.device)
        # NOTE(review): the processor is stored but not used by the code in this file.
        self.clip_image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
        print("All models loaded!")

        # Resolve the Anomagic checkpoints to Hub URLs (public repo, no token needed).
        # hf_hub_url only builds the URL; the URLs are handed to Anomagic below.
        print("Loading weights into memory...")
        repo_id = "yuxinjiang11/Anomagic_model"
        weight_files = [
            ("checkpoint/anomagic.bin", "ip_ckpt_path"),
            ("checkpoint/attention_module.bin", "att_ckpt_path"),
        ]
        for filename, attr_name in weight_files:
            try:
                url = hf_hub_url(repo_id=repo_id, filename=filename, repo_type="model")
            except Exception as e:
                raise FileNotFoundError(f"Unable to get weight file path {filename}: {str(e)}") from e
            setattr(self, attr_name, url)
            print(f"Weight file path: {filename} -> {url}")

        if HAS_ANOMAGIC:
            print("Initializing Anomagic model...")
            self.anomagic_model = Anomagic(self.pipe, self.clip_vision_model, self.ip_ckpt_path,
                                           self.att_ckpt_path, self.device)
        else:
            print("No Anomagic, using basic Pipe.")
        print("Model loading complete!")

    @classmethod
    def _prepare_mask(cls, mask, size):
        """Turn editor output / an image into a binary 0/255 'L' mask of ``size``.

        Returns ``None`` when no image can be extracted. A non-PIL value extracted
        from the editor output is passed through unchanged.
        """
        if mask is None:
            return None
        mask = extract_image_from_editor_output(mask)
        if not isinstance(mask, Image.Image):
            return mask
        grey = mask.resize(size).convert('L')
        binary = np.array(grey) > cls.MASK_THRESHOLD
        return Image.fromarray(binary.astype(np.uint8) * 255).convert('L')

    def generate_single_image(self, normal_image, reference_image, mask, mask_0, prompt, num_inference_steps=50,
                              ip_scale=0.3, seed=42, strength=0.3):
        """Generate one anomaly image, with ``mask_0`` masking the reference image.

        Args:
            normal_image: PIL image that receives the anomaly.
            reference_image: PIL reference image.
            mask: editor output / PIL mask for ``normal_image`` (white = generate), or None.
            mask_0: editor output / PIL mask for ``reference_image``, or None.
            prompt: text prompt.
            num_inference_steps: denoising steps.
            ip_scale: passed to Anomagic as ``scale``; unused by the basic pipeline.
            seed: seeds the global torch RNG (no explicit generator is passed on).
            strength: denoising strength.

        Returns:
            The first generated image. It is a PIL image for the basic pipeline;
            presumably the same for Anomagic (confirm).

        Raises:
            ValueError: if either input image is None.
            RuntimeError: if :meth:`load_models` has not been run.
        """
        if normal_image is None or reference_image is None:
            raise ValueError("Normal or reference image is None. Please upload valid images.")
        if self.pipe is None:
            raise RuntimeError("Models are not loaded. Call load_models() first.")

        target_size = self.TARGET_SIZE
        # RGB-normalise so RGBA/greyscale inputs do not reach the VAE with the wrong channel count.
        normal_image = normal_image.convert("RGB").resize(target_size)
        reference_image = reference_image.convert("RGB").resize(target_size)
        mask = self._prepare_mask(mask, target_size)
        mask_0 = self._prepare_mask(mask_0, target_size)

        print(f"Generating with seed {seed}...")
        torch.manual_seed(seed)

        if HAS_ANOMAGIC and self.anomagic_model is not None:
            generated_image = self.anomagic_model.generate(
                pil_image=reference_image,
                num_samples=1,
                num_inference_steps=num_inference_steps,
                prompt=prompt,
                scale=ip_scale,
                image=normal_image,
                mask_image=mask,
                mask_image_0=mask_0,  # reference image mask
                strength=strength,
            )[0]
        else:
            # Basic inpainting: without a mask, regenerate the whole image.
            if mask is None:
                mask = Image.new('L', target_size, 255)
            generated_image = self.pipe(
                prompt=prompt,
                image=normal_image,
                mask_image=mask,
                strength=strength,
                num_inference_steps=num_inference_steps,
            ).images[0]
        return generated_image
# Module-wide generator instance, assigned by load_generator().
generator = None
# Load bookkeeping read by the UI callbacks: "loaded" becomes True after a
# successful load; "error" holds the failure message (None until a failure).
load_status = dict(loaded=False, error=None)
def load_generator():
    """Build and load the global ``SingleAnomalyGenerator`` once.

    This runs synchronously (it is called at import time). The outcome is
    recorded in ``load_status``. The global ``generator`` is only assigned
    after ``load_models`` succeeds, so a failed load never leaves a
    half-initialised object reachable.

    Returns:
        A human-readable status message for the UI.
    """
    global generator
    if load_status["loaded"]:
        return "Models loaded!"
    if load_status["error"]:
        return f"Previous load failed: {load_status['error']}"
    try:
        print("Starting background model load...")
        candidate = SingleAnomalyGenerator()
        candidate.load_models()
    except Exception as e:
        # repr() fallback: an exception with an empty str() must still register as a failure.
        load_status["error"] = str(e) or repr(e)
        error_msg = f"Model loading failed: {str(e)}"
        print(error_msg)
        import traceback
        print(traceback.format_exc())
        return error_msg
    generator = candidate
    load_status["loaded"] = True
    print("Background model load complete!")
    return "Model loading complete! You can now generate images."
def generate_random_mask(size=(512, 512), num_blobs=3, blob_size_range=(50, 150), rng=None):
    """Draw ``num_blobs`` random filled ellipses (value 255) on a black 'L' canvas.

    Args:
        size: (width, height) of the mask.
        num_blobs: number of ellipses to draw.
        blob_size_range: inclusive (min, max) for each ellipse's width and height.
        rng: optional ``random.Random`` for reproducible masks. Defaults to the
            module-level ``random`` functions.

    Returns:
        A PIL 'L' image: 255 inside the blobs, 0 elsewhere.
    """
    rng = random if rng is None else rng
    width, height = size
    mask = Image.new('L', size, 0)
    draw = ImageDraw.Draw(mask)
    for _ in range(num_blobs):
        # Centres always lie on the canvas (randrange excludes the upper bound);
        # blobs near an edge are still clipped by PIL.
        cx = rng.randrange(width)
        cy = rng.randrange(height)
        blob_w = rng.randint(*blob_size_range)
        blob_h = rng.randint(*blob_size_range)
        draw.ellipse([cx - blob_w // 2, cy - blob_h // 2, cx + blob_w // 2, cy + blob_h // 2], fill=255)
    return mask
def generate_anomaly(normal_img, reference_img, mask_img, mask_0_img, prompt, strength, ip_scale, steps, seed):
    """Gradio callback: validate the UI inputs and synthesise one anomaly image.

    Args mirror the UI components (two images, two mask editors, prompt and sliders).

    Returns:
        ``(image_or_None, status_message)`` for the output image and status box.
    """
    if load_status["error"]:
        # Without this check a failed load left users on "please wait" forever.
        return None, f"Model loading failed: {load_status['error']}"
    if not load_status["loaded"] or generator is None:
        return None, "Please wait for model loading to complete."
    if normal_img is None or reference_img is None or not (prompt or "").strip():
        return None, "Please upload normal image, reference image, and enter prompt text."
    # gr.ImageEditor yields a dict even when empty, so check for an actual image.
    if extract_image_from_editor_output(mask_img) is None:
        return None, "Please upload or generate mask image for normal image."
    try:
        # Sliders may deliver floats; np.random.seed and the scheduler need ints.
        seed = int(seed)
        steps = int(steps)
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        generated_img = generator.generate_single_image(
            normal_image=normal_img,
            reference_image=reference_img,
            mask=mask_img,
            mask_0=mask_0_img,
            prompt=prompt,
            num_inference_steps=steps,
            ip_scale=ip_scale,
            seed=seed,
            strength=strength
        )
        return generated_img, f"Generation successful! Seed: {seed}, Steps: {steps}"
    except Exception as e:
        error_msg = f"Generation error: {str(e)}"
        print(error_msg)
        import traceback
        print(traceback.format_exc())
        return None, error_msg
# Predefined examples shown as one-click buttons in the UI.
# Image paths are relative (examples/ folder shipped with the app); load_example
# decides which base directory they are resolved against.
# Keys: "normal"/"reference" = input images; "mask" = white-on-black mask for the
# normal image; "mask_0" = mask for the reference image (None makes load_example
# draw a random mask); "prompt", "strength", "ip_scale", "steps" and "seed" pre-fill
# the UI; "description" labels the button and the status message.
# NOTE(review): the prompts of examples 1 and 2 describe a wood surface / a cookie
# rather than the apple / candle named in their descriptions — confirm intentional.
# NOTE(review): "normal_candle.JPG" uses an upper-case extension; on case-sensitive
# filesystems the file name must match exactly.
EXAMPLE_PAIRS = [
    {
        "normal": "examples/normal_apple.png",  # normal (defect-free) image
        "reference": "examples/reference_apple.png",  # reference image
        "mask": "examples/normal_mask_apple.jpg",  # mask for the normal image
        "mask_0": "examples/ref_mask_apple.png",  # mask for the reference image
        "prompt": "Wood surface has holes with rough - edged circular openings.",
        "strength": 0.6,
        "ip_scale": 0.1,
        "steps": 20,
        "seed": 42,
        "description": "Apple with wormholes and rough edges"
    },
    {
        "normal": "examples/normal_candle.JPG",  # normal (defect-free) image
        "reference": "examples/reference_candle.png",  # reference image
        "mask": "examples/normal_mask_candle.png",  # mask for the normal image
        "mask_0": "examples/ref_mask_candle.png",  # mask for the reference image
        "prompt": "Chocolate - chip cookie has a chunk - missing defect with exposed inner texture. ",
        "strength": 0.6,
        "ip_scale": 1,
        "steps": 20,
        "seed": 42,
        "description": "Candle with deformed surface"
    },
    {
        "normal": "examples/normal_wood.png",  # normal (defect-free) image
        "reference": "examples/reference_wood.png",  # reference image
        "mask": "examples/normal_mask_wood.png",  # mask for the normal image
        "mask_0": "examples/ref_mask_wood.png",  # mask for the reference image
        "prompt": "Wood surface has a crack with a long, dark - hued split.",
        "strength": 0.6,
        "ip_scale": 0.1,
        "steps": 20,
        "seed": 42,
        "description": "Wood with long dark crack and split"
    },
]
def _example_path(path):
    """Anchor a relative example path at the script directory rather than the CWD."""
    # os.path.join returns `path` unchanged when it is already absolute.
    return os.path.join(BASE_DIR, path)


def _open_example_image(path, mode):
    """Open an example image, convert it to ``mode`` and release the file handle."""
    with Image.open(_example_path(path)) as img:
        return img.convert(mode)


def load_example(idx):
    """Return the UI values for ``EXAMPLE_PAIRS[idx]``.

    Images are loaded from disk. A mask entry of None is replaced by a random
    mask. If loading fails, placeholder images and random masks are returned
    together with an error message.

    Returns:
        ``(normal_img, reference_img, mask_img, mask_0_img, prompt, strength,
        ip_scale, steps, seed, status_message)``.
    """
    # Reject negative indices too: they would silently pick examples from the end.
    if not 0 <= idx < len(EXAMPLE_PAIRS):
        return None, None, None, None, "", 0.5, 0.3, 20, 42, f"Example {idx + 1} not found"
    ex = EXAMPLE_PAIRS[idx]
    try:
        normal_img = _open_example_image(ex["normal"], 'RGB')
        reference_img = _open_example_image(ex["reference"], 'RGB')
        mask_img = (_open_example_image(ex["mask"], 'L')
                    if ex["mask"] is not None else generate_random_mask())
        mask_0_img = (_open_example_image(ex["mask_0"], 'L')
                      if ex["mask_0"] is not None else generate_random_mask())
        message = f"Example {idx + 1}: {ex['description']} loaded!"
    except Exception as e:
        # Best effort: keep the UI usable with placeholders instead of failing the click.
        message = f"Example loading failed: {str(e)} (Check if local image paths are correct)"
        print(message)
        normal_img = Image.new('RGB', (512, 512), color='gray')
        reference_img = Image.new('RGB', (512, 512), color='blue')
        mask_img = generate_random_mask()
        mask_0_img = generate_random_mask()
    return (normal_img, reference_img, mask_img, mask_0_img, ex["prompt"], ex["strength"],
            ex["ip_scale"], ex["steps"], ex["seed"], message)
# Load the models once at import time (synchronously). The returned message
# pre-fills the Status box, so a failed load is visible as soon as the page
# opens instead of the status staying empty.
startup_message = load_generator()


def _blank_mask():
    """Return an all-black 512x512 'L' canvas, i.e. a mask that marks no pixels."""
    return Image.new('L', (512, 512), 0)


# Gradio UI
with gr.Blocks(title="Anomagic Anomaly Image Generator") as demo:  # Removed theme to fix compatibility
    gr.Markdown("# Anomagic: Single Anomaly Image Generation Demo")
    gr.Markdown(
        "Upload normal image, reference image, normal mask and reference mask (white areas are for inpainting/anomaly generation), enter prompt, adjust parameters, and generate synthetic anomaly images with one click. Model is loaded in the background.")
    with gr.Row():
        # Left column: input images, the two editable masks and the prompt.
        with gr.Column(scale=1):
            normal_img = gr.Image(type="pil", label="Normal Image", height=256)  # Limit height
            reference_img = gr.Image(type="pil", label="Reference Image", height=256)
            with gr.Row():  # Normal-image mask editor
                mask_img = gr.ImageEditor(
                    type="pil",
                    sources=['upload', 'webcam', 'clipboard'],
                    label="Normal Image Mask (draw white anomaly areas on black background)",
                    height=256,
                    interactive=True,
                    brush=gr.Brush(default_color="white", default_size=15, color_mode="fixed"),
                    value=_blank_mask()  # Initial black canvas
                )
            with gr.Row():
                generate_mask_btn = gr.Button("Generate Random Normal Mask", variant="secondary")
                clear_mask_btn = gr.Button("Clear Normal Mask", variant="secondary")
            mask_0_img = gr.ImageEditor(
                type="pil",
                sources=['upload', 'webcam', 'clipboard'],
                label="Reference Image Mask (draw white areas on black background)",
                height=256,
                interactive=True,
                brush=gr.Brush(default_color="white", default_size=15, color_mode="fixed"),
                value=_blank_mask()  # Initial black canvas
            )
            with gr.Row():
                generate_mask_0_btn = gr.Button("Generate Random Reference Mask", variant="secondary")
                clear_mask_0_btn = gr.Button("Clear Reference Mask", variant="secondary")
            prompt = gr.Textbox(label="Prompt Text",
                                placeholder="e.g., a broken machine part with rust and cracks")
        # Right column: generation parameters, example loaders and outputs.
        with gr.Column(scale=1):
            strength = gr.Slider(0.1, 1.0, value=0.5, label="Denoising Strength")
            ip_scale = gr.Slider(0, 2.0, value=0.3, step=0.1, label="IP Adapter Scale")
            steps = gr.Slider(10, 100, value=20, step=5, label="Inference Steps")
            seed = gr.Slider(0, 2 ** 32 - 1, value=42, step=1, label="Random Seed")
            gr.Markdown("## Examples")
            gr.Markdown(
                "Click the buttons below to load predefined examples for quick testing. After loading, click 'Generate Image' to view the anomaly synthesis result.")
            # One button per predefined example.
            example_buttons = []
            for i, example in enumerate(EXAMPLE_PAIRS):
                example_btn = gr.Button(f"Example {i + 1}: {example['description']}", variant="secondary")
                example_buttons.append(example_btn)
            # Output components
            output_img = gr.Image(type="pil", label="Generated Anomaly Image", height=256)
            status = gr.Textbox(label="Status", value=startup_message, interactive=False)
            generate_btn = gr.Button("Generate Image", variant="primary", size="lg")  # Enlarge button

            def clear_cache():
                """Reset the load bookkeeping; the model is only reloaded after a restart."""
                global load_status
                load_status = {"loaded": False, "error": None}
                return "Cache cleared, please restart the app to reload the model."

            clear_btn = gr.Button("Clear Cache", variant="stop")

    # Wire up all event handlers (after every component has been created).
    generate_btn.click(
        generate_anomaly,
        inputs=[normal_img, reference_img, mask_img, mask_0_img, prompt, strength, ip_scale, steps, seed],
        outputs=[output_img, status]
    )
    clear_btn.click(clear_cache, outputs=status)
    # Example buttons: bind the index as a default argument to avoid late binding of `i`.
    for i, btn in enumerate(example_buttons):
        btn.click(lambda idx=i: load_example(idx),
                  outputs=[normal_img, reference_img, mask_img, mask_0_img, prompt, strength, ip_scale, steps, seed,
                           status])
    # Mask helper buttons.
    generate_mask_btn.click(lambda: generate_random_mask(), outputs=mask_img)
    clear_mask_btn.click(lambda: _blank_mask(), outputs=mask_img)
    generate_mask_0_btn.click(lambda: generate_random_mask(), outputs=mask_0_img)
    clear_mask_0_btn.click(lambda: _blank_mask(), outputs=mask_0_img)
if __name__ == "__main__":
    # Allow at most 10 pending requests, then serve on all interfaces, port 7860.
    # When running locally, launch(share=True) instead creates a temporary public link.
    demo.queue(max_size=10).launch(server_name="0.0.0.0", server_port=7860)