Spaces:
Running
Running
Upload app.py
Browse files
app.py
CHANGED
|
@@ -1,61 +1,62 @@
|
|
| 1 |
import os
|
| 2 |
import sys
|
| 3 |
import requests
|
| 4 |
-
import io #
|
| 5 |
|
| 6 |
-
# Spaces
|
| 7 |
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
|
| 8 |
|
| 9 |
import time
|
| 10 |
import random
|
| 11 |
import numpy as np
|
| 12 |
import torch
|
| 13 |
-
from PIL import Image
|
| 14 |
from diffusers import StableDiffusionInpaintPipelineLegacy, DDIMScheduler, AutoencoderKL, DPMSolverMultistepScheduler
|
| 15 |
-
from huggingface_hub import hf_hub_url, login # hf_hub_url
|
| 16 |
import gradio as gr
|
| 17 |
|
| 18 |
-
#
|
| 19 |
try:
|
| 20 |
from ip_adapter.ip_adapter_anomagic import Anomagic
|
|
|
|
| 21 |
HAS_ANOMAGIC = True
|
| 22 |
except ImportError:
|
| 23 |
HAS_ANOMAGIC = False
|
| 24 |
-
print("Anomagic
|
| 25 |
|
| 26 |
-
#
|
| 27 |
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 28 |
|
| 29 |
|
| 30 |
class SingleAnomalyGenerator:
|
| 31 |
def __init__(self, device="cuda:0"):
|
| 32 |
-
#
|
| 33 |
if torch.cuda.is_available() and "cuda" in device:
|
| 34 |
self.device = torch.device(device)
|
| 35 |
self.dtype = torch.float16
|
| 36 |
-
print(f"
|
| 37 |
else:
|
| 38 |
self.device = torch.device("cpu")
|
| 39 |
self.dtype = torch.float32
|
| 40 |
-
print(f"
|
| 41 |
|
| 42 |
self.anomagic_model = None
|
| 43 |
-
self.pipe = None #
|
| 44 |
self.clip_vision_model = None
|
| 45 |
self.clip_image_processor = None
|
| 46 |
-
self.ip_ckpt_path = None #
|
| 47 |
-
self.att_ckpt_path = None #
|
| 48 |
|
| 49 |
def load_models(self):
|
| 50 |
"""Load models with official CLIP"""
|
| 51 |
-
print("
|
| 52 |
from diffusers import AutoencoderKL
|
| 53 |
vae = AutoencoderKL.from_pretrained(
|
| 54 |
"stabilityai/sd-vae-ft-mse",
|
| 55 |
torch_dtype=self.dtype
|
| 56 |
).to(self.device)
|
| 57 |
|
| 58 |
-
print("
|
| 59 |
from diffusers import StableDiffusionInpaintPipelineLegacy, DDIMScheduler, DPMSolverMultistepScheduler
|
| 60 |
|
| 61 |
noise_scheduler = DDIMScheduler(
|
|
@@ -80,7 +81,7 @@ class SingleAnomalyGenerator:
|
|
| 80 |
|
| 81 |
self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(self.pipe.scheduler.config)
|
| 82 |
|
| 83 |
-
print("
|
| 84 |
from transformers import CLIPVisionModel, CLIPImageProcessor
|
| 85 |
self.clip_vision_model = CLIPVisionModel.from_pretrained(
|
| 86 |
"openai/clip-vit-large-patch14",
|
|
@@ -88,38 +89,39 @@ class SingleAnomalyGenerator:
|
|
| 88 |
).to(self.device)
|
| 89 |
self.clip_image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
|
| 90 |
|
| 91 |
-
print("
|
| 92 |
|
| 93 |
-
#
|
| 94 |
-
print("
|
| 95 |
weight_files = [
|
| 96 |
("checkpoint/ip_adapter_0.bin", "ip_ckpt_path"),
|
| 97 |
("checkpoint/att.bin", "att_ckpt_path")
|
| 98 |
]
|
| 99 |
for filename, attr_name in weight_files:
|
| 100 |
try:
|
| 101 |
-
#
|
| 102 |
repo_id = "yuxinjiang11/Anomagic_model"
|
| 103 |
url = hf_hub_url(repo_id=repo_id, filename=filename, repo_type="model")
|
| 104 |
|
| 105 |
-
#
|
| 106 |
if attr_name == "ip_ckpt_path":
|
| 107 |
self.ip_ckpt_path = url
|
| 108 |
elif attr_name == "att_ckpt_path":
|
| 109 |
self.att_ckpt_path = url
|
| 110 |
|
| 111 |
-
print(f"
|
| 112 |
except Exception as e:
|
| 113 |
-
raise FileNotFoundError(f"
|
| 114 |
|
| 115 |
-
#
|
| 116 |
if HAS_ANOMAGIC:
|
| 117 |
-
print("
|
| 118 |
-
self.anomagic_model = Anomagic(self.pipe, self.clip_vision_model, self.ip_ckpt_path, self.att_ckpt_path,
|
|
|
|
| 119 |
else:
|
| 120 |
-
print("
|
| 121 |
|
| 122 |
-
print("
|
| 123 |
|
| 124 |
def generate_single_image(self, normal_image, reference_image, mask, mask_0, prompt, num_inference_steps=50,
|
| 125 |
ip_scale=0.3, seed=42, strength=0.3):
|
|
@@ -148,10 +150,10 @@ class SingleAnomalyGenerator:
|
|
| 148 |
print(f"Generating with seed {seed}...")
|
| 149 |
torch.manual_seed(seed)
|
| 150 |
|
| 151 |
-
#
|
| 152 |
if HAS_ANOMAGIC and self.anomagic_model:
|
| 153 |
# generator = torch.Generator(device=self.device).manual_seed(seed)
|
| 154 |
-
#
|
| 155 |
generated_image = self.anomagic_model.generate(
|
| 156 |
pil_image=reference_image,
|
| 157 |
num_samples=1,
|
|
@@ -165,10 +167,10 @@ class SingleAnomalyGenerator:
|
|
| 165 |
# generator=generator
|
| 166 |
)[0]
|
| 167 |
else:
|
| 168 |
-
#
|
| 169 |
# generator = torch.Generator(device=self.device).manual_seed(seed)
|
| 170 |
if mask is None:
|
| 171 |
-
mask = Image.new('L', target_size, 255) #
|
| 172 |
generated_image = self.pipe(
|
| 173 |
prompt=prompt,
|
| 174 |
image=normal_image,
|
|
@@ -181,50 +183,66 @@ class SingleAnomalyGenerator:
|
|
| 181 |
return generated_image
|
| 182 |
|
| 183 |
|
| 184 |
-
#
|
| 185 |
generator = None
|
| 186 |
load_status = {"loaded": False, "error": None}
|
| 187 |
|
| 188 |
|
| 189 |
def load_generator():
|
| 190 |
-
"""
|
| 191 |
global generator, load_status
|
| 192 |
|
| 193 |
if load_status["loaded"]:
|
| 194 |
-
return "
|
| 195 |
|
| 196 |
if load_status["error"]:
|
| 197 |
-
return f"
|
| 198 |
|
| 199 |
try:
|
|
|
|
| 200 |
generator = SingleAnomalyGenerator()
|
| 201 |
generator.load_models()
|
| 202 |
load_status["loaded"] = True
|
| 203 |
-
|
|
|
|
| 204 |
except Exception as e:
|
| 205 |
load_status["error"] = str(e)
|
| 206 |
-
error_msg = f"
|
| 207 |
print(error_msg)
|
| 208 |
import traceback
|
| 209 |
print(traceback.format_exc())
|
| 210 |
return error_msg
|
| 211 |
|
| 212 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 213 |
def generate_anomaly(normal_img, reference_img, mask_img, mask_0_img, prompt, strength, ip_scale, steps, seed):
|
| 214 |
-
"""
|
| 215 |
global generator
|
| 216 |
|
| 217 |
if not load_status["loaded"]:
|
| 218 |
-
return None, "
|
| 219 |
|
| 220 |
if normal_img is None or reference_img is None or not prompt.strip():
|
| 221 |
-
return None, "
|
| 222 |
|
| 223 |
if mask_img is None:
|
| 224 |
-
return None, "
|
| 225 |
|
| 226 |
try:
|
| 227 |
-
#
|
| 228 |
random.seed(seed)
|
| 229 |
np.random.seed(seed)
|
| 230 |
torch.manual_seed(seed)
|
|
@@ -241,61 +259,165 @@ def generate_anomaly(normal_img, reference_img, mask_img, mask_0_img, prompt, st
|
|
| 241 |
strength=strength
|
| 242 |
)
|
| 243 |
|
| 244 |
-
return generated_img, f"
|
| 245 |
|
| 246 |
except Exception as e:
|
| 247 |
-
error_msg = f"
|
| 248 |
print(error_msg)
|
| 249 |
import traceback
|
| 250 |
print(traceback.format_exc())
|
| 251 |
return None, error_msg
|
| 252 |
|
| 253 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 254 |
# Gradio UI
|
| 255 |
-
with gr.Blocks(title="Anomagic
|
| 256 |
-
|
|
|
|
| 257 |
gr.Markdown(
|
| 258 |
-
"
|
| 259 |
|
| 260 |
with gr.Row():
|
| 261 |
with gr.Column(scale=1):
|
| 262 |
-
normal_img = gr.Image(type="pil", label="
|
| 263 |
-
reference_img = gr.Image(type="pil", label="
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 264 |
|
| 265 |
-
|
| 266 |
-
|
|
|
|
| 267 |
|
| 268 |
-
prompt = gr.Textbox(label="
|
| 269 |
placeholder="e.g., a broken machine part with rust and cracks")
|
| 270 |
|
| 271 |
with gr.Column(scale=1):
|
| 272 |
-
strength = gr.Slider(0.1, 1.0, value=0.5, label="
|
| 273 |
-
ip_scale = gr.Slider(0, 2.0, value=0.3, step=0.1, label="IP
|
| 274 |
-
steps = gr.Slider(10, 100, value=20, step=5, label="
|
| 275 |
-
seed = gr.Slider(0, 2 ** 32 - 1, value=42, step=1, label="
|
| 276 |
|
| 277 |
with gr.Row():
|
| 278 |
-
|
| 279 |
-
generate_btn = gr.Button("生成图像 (Generate)", variant="primary")
|
| 280 |
|
| 281 |
-
output_img = gr.Image(type="pil", label="
|
| 282 |
-
status = gr.Textbox(label="
|
| 283 |
|
| 284 |
-
#
|
| 285 |
-
load_btn.click(load_generator, outputs=status)
|
| 286 |
generate_btn.click(
|
| 287 |
generate_anomaly,
|
| 288 |
inputs=[normal_img, reference_img, mask_img, mask_0_img, prompt, strength, ip_scale, steps, seed],
|
| 289 |
-
outputs=[output_img, status]
|
| 290 |
)
|
| 291 |
|
| 292 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 293 |
def clear_cache():
|
| 294 |
global load_status
|
| 295 |
load_status = {"loaded": False, "error": None}
|
| 296 |
-
return "
|
|
|
|
| 297 |
|
| 298 |
-
clear_btn = gr.Button("
|
| 299 |
clear_btn.click(clear_cache, outputs=status)
|
| 300 |
|
| 301 |
if __name__ == "__main__":
|
|
|
|
| 1 |
import os
|
| 2 |
import sys
|
| 3 |
import requests
|
| 4 |
+
import io # Memory buffer
|
| 5 |
|
| 6 |
+
# Spaces environment configuration
|
| 7 |
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
|
| 8 |
|
| 9 |
import time
|
| 10 |
import random
|
| 11 |
import numpy as np
|
| 12 |
import torch
|
| 13 |
+
from PIL import Image, ImageDraw
|
| 14 |
from diffusers import StableDiffusionInpaintPipelineLegacy, DDIMScheduler, AutoencoderKL, DPMSolverMultistepScheduler
|
| 15 |
+
from huggingface_hub import hf_hub_url, login # hf_hub_url for generating cloud URL
|
| 16 |
import gradio as gr
|
| 17 |
|
| 18 |
+
# Attempt to import Anomagic (if ip_adapter module exists)
|
| 19 |
try:
|
| 20 |
from ip_adapter.ip_adapter_anomagic import Anomagic
|
| 21 |
+
|
| 22 |
HAS_ANOMAGIC = True
|
| 23 |
except ImportError:
|
| 24 |
HAS_ANOMAGIC = False
|
| 25 |
+
print("Anomagic not imported, will use basic Inpainting")
|
| 26 |
|
| 27 |
+
# Get the absolute path of the current script (to solve path issues)
|
| 28 |
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 29 |
|
| 30 |
|
| 31 |
class SingleAnomalyGenerator:
|
| 32 |
def __init__(self, device="cuda:0"):
|
| 33 |
+
# Auto-detect GPU and set dtype
|
| 34 |
if torch.cuda.is_available() and "cuda" in device:
|
| 35 |
self.device = torch.device(device)
|
| 36 |
self.dtype = torch.float16
|
| 37 |
+
print(f"Using GPU: {device}, dtype: {self.dtype}")
|
| 38 |
else:
|
| 39 |
self.device = torch.device("cpu")
|
| 40 |
self.dtype = torch.float32
|
| 41 |
+
print(f"Using CPU, dtype: {self.dtype}")
|
| 42 |
|
| 43 |
self.anomagic_model = None
|
| 44 |
+
self.pipe = None # Save pipe for reuse
|
| 45 |
self.clip_vision_model = None
|
| 46 |
self.clip_image_processor = None
|
| 47 |
+
self.ip_ckpt_path = None # IP weights state_dict in memory
|
| 48 |
+
self.att_ckpt_path = None # ATT weights state_dict in memory
|
| 49 |
|
| 50 |
def load_models(self):
|
| 51 |
"""Load models with official CLIP"""
|
| 52 |
+
print("Loading VAE...")
|
| 53 |
from diffusers import AutoencoderKL
|
| 54 |
vae = AutoencoderKL.from_pretrained(
|
| 55 |
"stabilityai/sd-vae-ft-mse",
|
| 56 |
torch_dtype=self.dtype
|
| 57 |
).to(self.device)
|
| 58 |
|
| 59 |
+
print("Loading base model...")
|
| 60 |
from diffusers import StableDiffusionInpaintPipelineLegacy, DDIMScheduler, DPMSolverMultistepScheduler
|
| 61 |
|
| 62 |
noise_scheduler = DDIMScheduler(
|
|
|
|
| 81 |
|
| 82 |
self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(self.pipe.scheduler.config)
|
| 83 |
|
| 84 |
+
print("Loading CLIP image encoder...")
|
| 85 |
from transformers import CLIPVisionModel, CLIPImageProcessor
|
| 86 |
self.clip_vision_model = CLIPVisionModel.from_pretrained(
|
| 87 |
"openai/clip-vit-large-patch14",
|
|
|
|
| 89 |
).to(self.device)
|
| 90 |
self.clip_image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
|
| 91 |
|
| 92 |
+
print("All models loaded!")
|
| 93 |
|
| 94 |
+
# Load weights (download from cloud repo to memory, avoid any disk usage)
|
| 95 |
+
print("Loading weights into memory...")
|
| 96 |
weight_files = [
|
| 97 |
("checkpoint/ip_adapter_0.bin", "ip_ckpt_path"),
|
| 98 |
("checkpoint/att.bin", "att_ckpt_path")
|
| 99 |
]
|
| 100 |
for filename, attr_name in weight_files:
|
| 101 |
try:
|
| 102 |
+
# Generate cloud URL (public repo, no token needed)
|
| 103 |
repo_id = "yuxinjiang11/Anomagic_model"
|
| 104 |
url = hf_hub_url(repo_id=repo_id, filename=filename, repo_type="model")
|
| 105 |
|
| 106 |
+
# Dynamically set attribute (or use if to assign explicitly)
|
| 107 |
if attr_name == "ip_ckpt_path":
|
| 108 |
self.ip_ckpt_path = url
|
| 109 |
elif attr_name == "att_ckpt_path":
|
| 110 |
self.att_ckpt_path = url
|
| 111 |
|
| 112 |
+
print(f"Weight file path: {filename} -> {url}")
|
| 113 |
except Exception as e:
|
| 114 |
+
raise FileNotFoundError(f"Unable to get weight file path {filename}: {str(e)}")
|
| 115 |
|
| 116 |
+
# If Anomagic is available, load weights into the model
|
| 117 |
if HAS_ANOMAGIC:
|
| 118 |
+
print("Initializing Anomagic model...")
|
| 119 |
+
self.anomagic_model = Anomagic(self.pipe, self.clip_vision_model, self.ip_ckpt_path, self.att_ckpt_path,
|
| 120 |
+
self.device)
|
| 121 |
else:
|
| 122 |
+
print("No Anomagic, using basic Pipe.")
|
| 123 |
|
| 124 |
+
print("Model loading complete!")
|
| 125 |
|
| 126 |
def generate_single_image(self, normal_image, reference_image, mask, mask_0, prompt, num_inference_steps=50,
|
| 127 |
ip_scale=0.3, seed=42, strength=0.3):
|
|
|
|
| 150 |
print(f"Generating with seed {seed}...")
|
| 151 |
torch.manual_seed(seed)
|
| 152 |
|
| 153 |
+
# If Anomagic is available, use it to generate; otherwise basic Inpainting
|
| 154 |
if HAS_ANOMAGIC and self.anomagic_model:
|
| 155 |
# generator = torch.Generator(device=self.device).manual_seed(seed)
|
| 156 |
+
# Assume Anomagic.generate supports parameters (adjust based on actual)
|
| 157 |
generated_image = self.anomagic_model.generate(
|
| 158 |
pil_image=reference_image,
|
| 159 |
num_samples=1,
|
|
|
|
| 167 |
# generator=generator
|
| 168 |
)[0]
|
| 169 |
else:
|
| 170 |
+
# Basic Inpainting
|
| 171 |
# generator = torch.Generator(device=self.device).manual_seed(seed)
|
| 172 |
if mask is None:
|
| 173 |
+
mask = Image.new('L', target_size, 255) # Full white mask
|
| 174 |
generated_image = self.pipe(
|
| 175 |
prompt=prompt,
|
| 176 |
image=normal_image,
|
|
|
|
| 183 |
return generated_image
|
| 184 |
|
| 185 |
|
| 186 |
+
# Global generator and load status
|
| 187 |
generator = None
|
| 188 |
load_status = {"loaded": False, "error": None}
|
| 189 |
|
| 190 |
|
| 191 |
def load_generator():
|
| 192 |
+
"""Background load function: Automatically load model on startup"""
|
| 193 |
global generator, load_status
|
| 194 |
|
| 195 |
if load_status["loaded"]:
|
| 196 |
+
return "Models loaded!"
|
| 197 |
|
| 198 |
if load_status["error"]:
|
| 199 |
+
return f"Previous load failed: {load_status['error']}"
|
| 200 |
|
| 201 |
try:
|
| 202 |
+
print("Starting background model load...")
|
| 203 |
generator = SingleAnomalyGenerator()
|
| 204 |
generator.load_models()
|
| 205 |
load_status["loaded"] = True
|
| 206 |
+
print("Background model load complete!")
|
| 207 |
+
return "Model loading complete! You can now generate images."
|
| 208 |
except Exception as e:
|
| 209 |
load_status["error"] = str(e)
|
| 210 |
+
error_msg = f"Model loading failed: {str(e)}"
|
| 211 |
print(error_msg)
|
| 212 |
import traceback
|
| 213 |
print(traceback.format_exc())
|
| 214 |
return error_msg
|
| 215 |
|
| 216 |
|
| 217 |
+
def generate_random_mask(size=(512, 512), num_blobs=3, blob_size_range=(50, 150)):
|
| 218 |
+
"""Generate random mask: Create several random blobs as anomaly areas"""
|
| 219 |
+
mask = Image.new('L', size, 0) # Black background
|
| 220 |
+
draw = ImageDraw.Draw(mask)
|
| 221 |
+
for _ in range(num_blobs):
|
| 222 |
+
x = random.randint(0, size[0])
|
| 223 |
+
y = random.randint(0, size[1])
|
| 224 |
+
width = random.randint(*blob_size_range)
|
| 225 |
+
height = random.randint(*blob_size_range)
|
| 226 |
+
# Draw elliptical blobs
|
| 227 |
+
draw.ellipse([x - width // 2, y - height // 2, x + width // 2, y + height // 2], fill=255)
|
| 228 |
+
return mask
|
| 229 |
+
|
| 230 |
+
|
| 231 |
def generate_anomaly(normal_img, reference_img, mask_img, mask_0_img, prompt, strength, ip_scale, steps, seed):
|
| 232 |
+
"""Core generation function: Called by Gradio (supports two masks)"""
|
| 233 |
global generator
|
| 234 |
|
| 235 |
if not load_status["loaded"]:
|
| 236 |
+
return None, "Please wait for model loading to complete."
|
| 237 |
|
| 238 |
if normal_img is None or reference_img is None or not prompt.strip():
|
| 239 |
+
return None, "Please upload normal image, reference image, and enter prompt text."
|
| 240 |
|
| 241 |
if mask_img is None:
|
| 242 |
+
return None, "Please upload or generate mask image for normal image."
|
| 243 |
|
| 244 |
try:
|
| 245 |
+
# Set seed
|
| 246 |
random.seed(seed)
|
| 247 |
np.random.seed(seed)
|
| 248 |
torch.manual_seed(seed)
|
|
|
|
| 259 |
strength=strength
|
| 260 |
)
|
| 261 |
|
| 262 |
+
return generated_img, f"Generation successful! Seed: {seed}, Steps: {steps}"
|
| 263 |
|
| 264 |
except Exception as e:
|
| 265 |
+
error_msg = f"Generation error: {str(e)}"
|
| 266 |
print(error_msg)
|
| 267 |
import traceback
|
| 268 |
print(traceback.format_exc())
|
| 269 |
return None, error_msg
|
| 270 |
|
| 271 |
|
| 272 |
+
# Predefined anomaly examples (using local image paths; assume images are in examples/ folder in the same directory as the script)
|
| 273 |
+
EXAMPLE_PAIRS = [
|
| 274 |
+
{
|
| 275 |
+
"normal": "examples/normal_leather.png", # Your local normal gear image
|
| 276 |
+
"reference": "examples/reference_leather.png", # Your local rusty gear reference image
|
| 277 |
+
"mask": "examples/normal_mask_leather.png", # Your local mask for normal gear
|
| 278 |
+
"mask_0": "examples/ref_mask_leather.png", # Your local mask for reference gear
|
| 279 |
+
"prompt": "Bagel has a crack running across its surface.",
|
| 280 |
+
"strength": 0.6,
|
| 281 |
+
"ip_scale": 0.1,
|
| 282 |
+
"steps": 20,
|
| 283 |
+
"seed": 42,
|
| 284 |
+
"description": "Bagel has a crack running across its surface."
|
| 285 |
+
},
|
| 286 |
+
{
|
| 287 |
+
"normal": "examples/normal_candle.JPG", # Your local normal gear image
|
| 288 |
+
"reference": "examples/reference_candle.png", # Your local rusty gear reference image
|
| 289 |
+
"mask": "examples/normal_mask_candle.png", # Your local mask for normal gear
|
| 290 |
+
"mask_0": "examples/ref_mask_candle.png", # Your local mask for reference gear
|
| 291 |
+
"prompt": "Chocolate - chip cookie has a chunk - missing defect with exposed inner texture. ",
|
| 292 |
+
"strength": 0.6,
|
| 293 |
+
"ip_scale": 0.1,
|
| 294 |
+
"steps": 20,
|
| 295 |
+
"seed": 42,
|
| 296 |
+
"description": "Chocolate - chip cookie has a chunk - missing defect with exposed inner texture. "
|
| 297 |
+
},
|
| 298 |
+
{
|
| 299 |
+
"normal": "examples/normal_apple.png", # Your local normal gear image
|
| 300 |
+
"reference": "examples/reference_apple.png", # Your local rusty gear reference image
|
| 301 |
+
"mask": "examples/normal_mask_apple.jpg", # Your local mask for normal gear
|
| 302 |
+
"mask_0": "examples/ref_mask_apple.png", # Your local mask for reference gear
|
| 303 |
+
"prompt": "Wood surface has holes with rough - edged circular openings.",
|
| 304 |
+
"strength": 0.6,
|
| 305 |
+
"ip_scale": 0.1,
|
| 306 |
+
"steps": 20,
|
| 307 |
+
"seed": 42,
|
| 308 |
+
"description": "Wood surface has holes with rough - edged circular openings."
|
| 309 |
+
}
|
| 310 |
+
]
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
def load_example(idx):
|
| 314 |
+
"""Load example: Load images from local path, generate random mask if not provided, and set UI"""
|
| 315 |
+
if idx >= len(EXAMPLE_PAIRS):
|
| 316 |
+
return None, None, None, None, EXAMPLE_PAIRS[idx]["prompt"], EXAMPLE_PAIRS[idx]["strength"], EXAMPLE_PAIRS[idx][
|
| 317 |
+
"ip_scale"], EXAMPLE_PAIRS[idx]["steps"], EXAMPLE_PAIRS[idx][
|
| 318 |
+
"seed"], f"Example {idx + 1}: {EXAMPLE_PAIRS[idx]['description']}"
|
| 319 |
+
|
| 320 |
+
ex = EXAMPLE_PAIRS[idx]
|
| 321 |
+
try:
|
| 322 |
+
# Load normal image
|
| 323 |
+
normal_img = Image.open(ex["normal"]).convert('RGB')
|
| 324 |
+
|
| 325 |
+
# Load reference image
|
| 326 |
+
reference_img = Image.open(ex["reference"]).convert('RGB')
|
| 327 |
+
|
| 328 |
+
# Load or generate normal mask
|
| 329 |
+
if ex["mask"] is not None:
|
| 330 |
+
mask_img = Image.open(ex["mask"]).convert('L')
|
| 331 |
+
else:
|
| 332 |
+
mask_img = generate_random_mask()
|
| 333 |
+
|
| 334 |
+
# Load or generate reference mask (mask_0)
|
| 335 |
+
if ex["mask_0"] is not None:
|
| 336 |
+
mask_0_img = Image.open(ex["mask_0"]).convert('L')
|
| 337 |
+
else:
|
| 338 |
+
mask_0_img = generate_random_mask()
|
| 339 |
+
|
| 340 |
+
return normal_img, reference_img, mask_img, mask_0_img, ex["prompt"], ex["strength"], ex["ip_scale"], ex[
|
| 341 |
+
"steps"], ex["seed"], f"Example {idx + 1}: {ex['description']} loaded!"
|
| 342 |
+
except Exception as e:
|
| 343 |
+
error_msg = f"Example loading failed: {str(e)} (Check if local image paths are correct)"
|
| 344 |
+
print(error_msg)
|
| 345 |
+
# Fallback to placeholder images and random masks
|
| 346 |
+
normal_img = Image.new('RGB', (512, 512), color='gray')
|
| 347 |
+
reference_img = Image.new('RGB', (512, 512), color='blue')
|
| 348 |
+
mask_img = generate_random_mask()
|
| 349 |
+
mask_0_img = generate_random_mask()
|
| 350 |
+
return normal_img, reference_img, mask_img, mask_0_img, ex["prompt"], ex["strength"], ex["ip_scale"], ex[
|
| 351 |
+
"steps"], ex["seed"], error_msg
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
# Automatically load model on startup
|
| 355 |
+
load_generator()
|
| 356 |
+
|
| 357 |
# Gradio UI
|
| 358 |
+
with gr.Blocks(title="Anomagic Anomaly Image Generator",
|
| 359 |
+
theme=gr.themes.Soft()) as demo: # Use Soft theme for beautification
|
| 360 |
+
gr.Markdown("# Anomagic: Single Anomaly Image Generation Demo")
|
| 361 |
gr.Markdown(
|
| 362 |
+
"Upload normal image, reference image, normal mask and reference mask (white areas are for inpainting/anomaly generation), enter prompt, adjust parameters, and generate synthetic anomaly images with one click. Model is loaded in the background.")
|
| 363 |
|
| 364 |
with gr.Row():
|
| 365 |
with gr.Column(scale=1):
|
| 366 |
+
normal_img = gr.Image(type="pil", label="Normal Image", height=300) # Limit height
|
| 367 |
+
reference_img = gr.Image(type="pil", label="Reference Image", height=300)
|
| 368 |
+
|
| 369 |
+
with gr.Row(): # Mask row: Add buttons
|
| 370 |
+
mask_img = gr.Image(type="pil", label="Normal Image Mask (white for anomaly generation area)",
|
| 371 |
+
height=300, tool="sketch") # Add sketch tool
|
| 372 |
+
gr.Button("Generate Random Normal Mask").click(lambda: generate_random_mask(), outputs=mask_img)
|
| 373 |
|
| 374 |
+
mask_0_img = gr.Image(type="pil", label="Reference Image Mask (mask_0)", height=300,
|
| 375 |
+
tool="sketch") # Add sketch tool
|
| 376 |
+
gr.Button("Generate Random Reference Mask").click(lambda: generate_random_mask(), outputs=mask_0_img)
|
| 377 |
|
| 378 |
+
prompt = gr.Textbox(label="Prompt Text",
|
| 379 |
placeholder="e.g., a broken machine part with rust and cracks")
|
| 380 |
|
| 381 |
with gr.Column(scale=1):
|
| 382 |
+
strength = gr.Slider(0.1, 1.0, value=0.5, label="Denoising Strength")
|
| 383 |
+
ip_scale = gr.Slider(0, 2.0, value=0.3, step=0.1, label="IP Adapter Scale")
|
| 384 |
+
steps = gr.Slider(10, 100, value=20, step=5, label="Inference Steps")
|
| 385 |
+
seed = gr.Slider(0, 2 ** 32 - 1, value=42, step=1, label="Random Seed")
|
| 386 |
|
| 387 |
with gr.Row():
|
| 388 |
+
generate_btn = gr.Button("Generate Image", variant="primary", size="lg") # Enlarge button
|
|
|
|
| 389 |
|
| 390 |
+
output_img = gr.Image(type="pil", label="Generated Anomaly Image", height=400)
|
| 391 |
+
status = gr.Textbox(label="Status", interactive=False)
|
| 392 |
|
| 393 |
+
# Event bindings
|
|
|
|
| 394 |
generate_btn.click(
|
| 395 |
generate_anomaly,
|
| 396 |
inputs=[normal_img, reference_img, mask_img, mask_0_img, prompt, strength, ip_scale, steps, seed],
|
| 397 |
+
outputs=[output_img, status]
|
| 398 |
)
|
| 399 |
|
| 400 |
+
# Examples section
|
| 401 |
+
gr.Markdown("## Examples")
|
| 402 |
+
gr.Markdown(
|
| 403 |
+
"Click the buttons below to load predefined examples for quick testing. After loading, click 'Generate Image' to view the anomaly synthesis result.")
|
| 404 |
+
with gr.Row():
|
| 405 |
+
for i in range(len(EXAMPLE_PAIRS)):
|
| 406 |
+
with gr.Column():
|
| 407 |
+
ex_btn = gr.Button(f"Example {i + 1}: {EXAMPLE_PAIRS[i]['description']}", variant="secondary")
|
| 408 |
+
ex_btn.click(load_example, inputs=gr.State(i),
|
| 409 |
+
outputs=[normal_img, reference_img, mask_img, mask_0_img, prompt, strength, ip_scale,
|
| 410 |
+
steps, seed, status])
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
# Clear cache button
|
| 414 |
def clear_cache():
|
| 415 |
global load_status
|
| 416 |
load_status = {"loaded": False, "error": None}
|
| 417 |
+
return "Cache cleared, please restart the app to reload the model."
|
| 418 |
+
|
| 419 |
|
| 420 |
+
clear_btn = gr.Button("Clear Cache", variant="stop")
|
| 421 |
clear_btn.click(clear_cache, outputs=status)
|
| 422 |
|
| 423 |
if __name__ == "__main__":
|