Spaces:
Running
Running
File size: 20,485 Bytes
b6b4cfc cba29f4 74166a2 b6b4cfc 74166a2 b6b4cfc 74166a2 b6b4cfc 74166a2 b6b4cfc 74166a2 cba29f4 1f63c25 cba29f4 74166a2 cba29f4 74166a2 d0be031 b6b4cfc 1f63c25 b6b4cfc 74166a2 b6b4cfc 74166a2 b6b4cfc 74166a2 b6b4cfc 74166a2 d0be031 74166a2 b6b4cfc 97dbe85 74166a2 97dbe85 cba29f4 74166a2 97dbe85 cba29f4 b6b4cfc 97dbe85 cba29f4 97dbe85 b6b4cfc 74166a2 d0be031 97dbe85 d0be031 97dbe85 74166a2 d0be031 74166a2 cba29f4 e97455b cba29f4 74166a2 f7f8b26 cba29f4 74166a2 d0f6c7a cba29f4 74166a2 cba29f4 74166a2 cba29f4 74166a2 cba29f4 74166a2 cba29f4 74166a2 d0be031 74166a2 b6b4cfc c95caea b6b4cfc 1f63c25 b6b4cfc 1f63c25 b6b4cfc 74166a2 cba29f4 be92808 74166a2 cba29f4 be92808 cba29f4 be92808 cba29f4 74166a2 be92808 cba29f4 74166a2 cba29f4 be92808 cba29f4 b6b4cfc d0be031 b6b4cfc 74166a2 b6b4cfc 74166a2 b6b4cfc 74166a2 b6b4cfc 74166a2 b6b4cfc 74166a2 b6b4cfc 74166a2 b6b4cfc 74166a2 b6b4cfc 74166a2 b6b4cfc 74166a2 b6b4cfc 74166a2 b6b4cfc 74166a2 b6b4cfc 74166a2 b6b4cfc 74166a2 b6b4cfc 74166a2 b6b4cfc 74166a2 b6b4cfc d0be031 b6b4cfc 74166a2 65b5df8 74166a2 56ca380 74166a2 65b5df8 74166a2 3e1a471 65b5df8 74166a2 65b5df8 3e1a471 74166a2 41e483a 74166a2 b6b4cfc 777f38f 74166a2 b6b4cfc 74166a2 b6b4cfc ff6f0ad 74166a2 4f07be3 f13fca8 eafa172 afb9bd5 ff6f0ad 4f07be3 7a76e18 f13fca8 afb9bd5 41e483a afb9bd5 4f07be3 f13fca8 eafa172 afb9bd5 ff6f0ad 4f07be3 7a76e18 f13fca8 afb9bd5 41e483a b6b4cfc 74166a2 b6b4cfc 74166a2 b6b4cfc ff6f0ad 41e483a 1f63c25 41e483a ff6f0ad 74166a2 41e483a ff6f0ad 1f63c25 41e483a ff6f0ad 41e483a b6b4cfc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 |
import os
import sys
import requests
import io # Memory buffer
# Spaces environment configuration
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import time
import random
import numpy as np
import torch
from PIL import Image, ImageDraw
from diffusers import StableDiffusionInpaintPipelineLegacy, DDIMScheduler, AutoencoderKL, DPMSolverMultistepScheduler
from huggingface_hub import hf_hub_url, login # hf_hub_url for generating cloud URL
import gradio as gr
# Attempt to import Anomagic (only works if the ip_adapter module is present
# in the Space); when missing we fall back to plain SD inpainting, so the app
# still runs end to end without the adapter weights.
try:
    from ip_adapter.ip_adapter_anomagic import Anomagic
    HAS_ANOMAGIC = True
except ImportError:
    HAS_ANOMAGIC = False
    print("Anomagic not imported, will use basic Inpainting")
# Absolute directory of this script — avoids CWD-dependent relative-path issues
# when Spaces launches the app from a different working directory.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
def extract_image_from_editor_output(editor_output):
    """Extract a PIL Image from a gr.ImageEditor output.

    gr.ImageEditor may hand back either a PIL Image directly or a dict of
    the form {"background": ..., "layers": [...], "composite": ...}.

    Args:
        editor_output: whatever the ImageEditor component produced.

    Returns:
        The extracted PIL Image, or None when nothing usable is found.
    """
    if editor_output is None:
        return None
    # Already a PIL Image: return as-is.
    if isinstance(editor_output, Image.Image):
        return editor_output
    # Dict form (gr.ImageEditor output): prefer the composited image
    # (background + drawn layers), fall back to the raw background.
    if isinstance(editor_output, dict):
        if editor_output.get("composite") is not None:
            return editor_output["composite"]
        if editor_output.get("background") is not None:
            return editor_output["background"]
    # Any other array-like format: attempt conversion to an image.
    try:
        return Image.fromarray(editor_output)
    except Exception:  # narrowed from bare `except:`; conversion failure is expected here
        pass
    return None
class SingleAnomalyGenerator:
    """Anomaly-image generator built on a Stable Diffusion inpainting pipeline.

    Loads a VAE, the Realistic Vision base model, a CLIP image encoder and
    (when available) the Anomagic IP-adapter weights, then synthesizes an
    anomaly into a normal image guided by a reference image and two masks.
    """

    def __init__(self, device="cuda:0"):
        """Pick compute device and dtype: fp16 on CUDA for speed/memory, fp32 on CPU."""
        if torch.cuda.is_available() and "cuda" in device:
            self.device = torch.device(device)
            self.dtype = torch.float16
            print(f"Using GPU: {device}, dtype: {self.dtype}")
        else:
            self.device = torch.device("cpu")
            self.dtype = torch.float32
            print(f"Using CPU, dtype: {self.dtype}")
        self.anomagic_model = None
        self.pipe = None  # diffusers pipeline, kept for reuse across generations
        self.clip_vision_model = None
        self.clip_image_processor = None
        # NOTE: despite the "_path" names, these end up holding Hugging Face
        # Hub *URLs* for the checkpoint files (see load_models), not local
        # paths or in-memory state dicts.
        self.ip_ckpt_path = None   # URL of the IP-adapter weights
        self.att_ckpt_path = None  # URL of the attention-module weights

    def load_models(self):
        """Load the VAE, base pipeline, CLIP encoder and Anomagic weights.

        Raises:
            FileNotFoundError: when a checkpoint URL cannot be resolved.
        """
        print("Loading VAE...")
        vae = AutoencoderKL.from_pretrained(
            "stabilityai/sd-vae-ft-mse",
            torch_dtype=self.dtype
        ).to(self.device)
        print("Loading base model...")
        noise_scheduler = DDIMScheduler(
            num_train_timesteps=1000,
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            clip_sample=False,
            set_alpha_to_one=False,
            steps_offset=1,
        )
        self.pipe = StableDiffusionInpaintPipelineLegacy.from_pretrained(
            "SG161222/Realistic_Vision_V4.0_noVAE",
            torch_dtype=self.dtype,
            scheduler=noise_scheduler,
            vae=vae,
            feature_extractor=None,
            safety_checker=None,
            low_cpu_mem_usage=True
        ).to(self.device, dtype=self.dtype)
        # Swap in the faster multistep solver for inference.
        self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(self.pipe.scheduler.config)
        print("Loading CLIP image encoder...")
        from transformers import CLIPVisionModel, CLIPImageProcessor
        self.clip_vision_model = CLIPVisionModel.from_pretrained(
            "openai/clip-vit-large-patch14",
            torch_dtype=self.dtype
        ).to(self.device)
        self.clip_image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
        print("All models loaded!")
        # Resolve checkpoint URLs on the public Hub repo (no token needed,
        # nothing written to local disk at this point).
        print("Loading weights into memory...")
        repo_id = "yuxinjiang11/Anomagic_model"
        weight_files = [
            ("checkpoint/anomagic.bin", "ip_ckpt_path"),
            ("checkpoint/attention_module.bin", "att_ckpt_path")
        ]
        for filename, attr_name in weight_files:
            try:
                url = hf_hub_url(repo_id=repo_id, filename=filename, repo_type="model")
                setattr(self, attr_name, url)
                print(f"Weight file path: {filename} -> {url}")
            except Exception as e:
                raise FileNotFoundError(f"Unable to get weight file path {filename}: {str(e)}")
        # If Anomagic is available, build the adapter around the pipeline.
        if HAS_ANOMAGIC:
            print("Initializing Anomagic model...")
            self.anomagic_model = Anomagic(self.pipe, self.clip_vision_model, self.ip_ckpt_path, self.att_ckpt_path,
                                           self.device)
        else:
            print("No Anomagic, using basic Pipe.")
        print("Model loading complete!")

    @staticmethod
    def _prepare_mask(raw_mask, target_size):
        """Normalize a (possibly gr.ImageEditor-shaped) mask input to a
        binary 0/255 'L' image at target_size; returns None when no usable
        mask was provided."""
        if raw_mask is None:
            return None
        mask = extract_image_from_editor_output(raw_mask)
        if mask is None or not isinstance(mask, Image.Image):
            return mask
        mask = mask.resize(target_size).convert('L')
        binary = np.array(mask) > 0  # any non-black pixel counts as masked
        return Image.fromarray(binary.astype(np.uint8) * 255).convert('L')

    def generate_single_image(self, normal_image, reference_image, mask, mask_0, prompt, num_inference_steps=50,
                              ip_scale=0.3, seed=42, strength=0.3):
        """Generate one anomaly image.

        Args:
            normal_image: PIL image to inpaint the anomaly into.
            reference_image: PIL image showing the desired anomaly appearance.
            mask: mask over normal_image (white = region to inpaint).
            mask_0: mask over reference_image (white = anomaly region).
            prompt: text prompt describing the anomaly.
            num_inference_steps: diffusion steps.
            ip_scale: IP-adapter guidance scale.
            seed: torch RNG seed.
            strength: denoising strength.

        Returns:
            The generated PIL image.

        Raises:
            ValueError: when either input image is missing.
        """
        if normal_image is None or reference_image is None:
            raise ValueError("Normal or reference image is None. Please upload valid images.")
        target_size = (512, 512)
        normal_image = normal_image.resize(target_size)
        reference_image = reference_image.resize(target_size)
        # Both masks go through the same binarization path (was duplicated inline).
        mask = self._prepare_mask(mask, target_size)
        mask_0 = self._prepare_mask(mask_0, target_size)
        print(f"Generating with seed {seed}...")
        torch.manual_seed(seed)
        # Use Anomagic when available; otherwise fall back to basic inpainting.
        if HAS_ANOMAGIC and self.anomagic_model:
            generated_image = self.anomagic_model.generate(
                pil_image=reference_image,
                num_samples=1,
                num_inference_steps=num_inference_steps,
                prompt=prompt,
                scale=ip_scale,
                image=normal_image,
                mask_image=mask,
                mask_image_0=mask_0,  # reference image mask
                strength=strength,
            )[0]
        else:
            # Basic inpainting (no IP-adapter guidance).
            if mask is None:
                mask = Image.new('L', target_size, 255)  # full white mask
            generated_image = self.pipe(
                prompt=prompt,
                image=normal_image,
                mask_image=mask,
                strength=strength,
                num_inference_steps=num_inference_steps,
            ).images[0]
        return generated_image
# Global generator and load status (shared between the startup loader and the
# Gradio callbacks).
generator = None  # SingleAnomalyGenerator instance, created by load_generator()
load_status = {"loaded": False, "error": None}  # load-state flag surfaced in the UI
def load_generator():
    """Load the model stack once at startup; safe to call repeatedly.

    Returns:
        A human-readable status string for display in the UI.
    """
    global generator, load_status
    # Idempotent guards: already loaded, or a previous attempt failed.
    if load_status["loaded"]:
        return "Models loaded!"
    if load_status["error"]:
        return f"Previous load failed: {load_status['error']}"
    try:
        print("Starting background model load...")
        generator = SingleAnomalyGenerator()
        generator.load_models()
        load_status["loaded"] = True
        print("Background model load complete!")
        return "Model loading complete! You can now generate images."
    except Exception as exc:
        # Remember the failure so later calls report it instead of retrying.
        load_status["error"] = str(exc)
        message = f"Model loading failed: {str(exc)}"
        print(message)
        import traceback
        print(traceback.format_exc())
        return message
def generate_random_mask(size=(512, 512), num_blobs=3, blob_size_range=(50, 150)):
    """Generate a random binary mask: white elliptical blobs on black.

    Args:
        size: (width, height) of the mask canvas.
        num_blobs: number of elliptical blobs to draw.
        blob_size_range: inclusive (min, max) for each blob's width/height.

    Returns:
        A PIL 'L'-mode image where white (255) marks anomaly regions.
    """
    mask = Image.new('L', size, 0)  # black background
    draw = ImageDraw.Draw(mask)
    for _ in range(num_blobs):
        # randint's upper bound is inclusive: use size-1 so the blob center is
        # always on the canvas (previously it could land one pixel past the
        # right/bottom edge, leaving the blob mostly clipped away).
        x = random.randint(0, size[0] - 1)
        y = random.randint(0, size[1] - 1)
        width = random.randint(*blob_size_range)
        height = random.randint(*blob_size_range)
        # Blobs may still overlap the border; PIL clips the ellipse to the canvas.
        draw.ellipse([x - width // 2, y - height // 2, x + width // 2, y + height // 2], fill=255)
    return mask
def generate_anomaly(normal_img, reference_img, mask_img, mask_0_img, prompt, strength, ip_scale, steps, seed):
    """Gradio callback: synthesize one anomaly image (supports both masks).

    Returns:
        (generated_image_or_None, status_message) for the output components.
    """
    global generator
    # Guard clauses: the model must be ready and required inputs present.
    if not load_status["loaded"]:
        return None, "Please wait for model loading to complete."
    if normal_img is None or reference_img is None or not prompt.strip():
        return None, "Please upload normal image, reference image, and enter prompt text."
    if mask_img is None:
        return None, "Please upload or generate mask image for normal image."
    try:
        # Seed every RNG so the run is reproducible.
        for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
            seed_fn(seed)
        result = generator.generate_single_image(
            normal_image=normal_img,
            reference_image=reference_img,
            mask=mask_img,
            mask_0=mask_0_img,
            prompt=prompt,
            num_inference_steps=steps,
            ip_scale=ip_scale,
            seed=seed,
            strength=strength
        )
        return result, f"Generation successful! Seed: {seed}, Steps: {steps}"
    except Exception as exc:
        # Boundary handler: report the failure to the UI instead of crashing.
        failure = f"Generation error: {str(exc)}"
        print(failure)
        import traceback
        print(traceback.format_exc())
        return None, failure
# Predefined anomaly examples (local image paths; the images are expected in
# an examples/ folder next to this script).
EXAMPLE_PAIRS = [
    {
        "normal": "examples/normal_apple.png",        # normal apple image
        "reference": "examples/reference_apple.png",  # reference image showing the anomaly
        "mask": "examples/normal_mask_apple.jpg",     # mask for the normal image
        "mask_0": "examples/ref_mask_apple.png",      # mask for the reference image
        # NOTE(review): prompt says "Wood surface" but this example is an
        # apple — looks like a copy-paste slip; confirm the intended prompt.
        "prompt": "Wood surface has holes with rough - edged circular openings.",
        "strength": 0.6,
        "ip_scale": 0.1,
        "steps": 20,
        "seed": 42,
        "description": "Apple with wormholes and rough edges"
    },
    {
        "normal": "examples/normal_candle.JPG",        # normal candle image
        "reference": "examples/reference_candle.png",  # reference image showing the anomaly
        "mask": "examples/normal_mask_candle.png",     # mask for the normal image
        "mask_0": "examples/ref_mask_candle.png",      # mask for the reference image
        # NOTE(review): prompt describes a chocolate-chip cookie but this
        # example is a candle — confirm the intended prompt.
        "prompt": "Chocolate - chip cookie has a chunk - missing defect with exposed inner texture. ",
        "strength": 0.6,
        "ip_scale": 1,
        "steps": 20,
        "seed": 42,
        "description": "Candle with deformed surface"
    },
    {
        "normal": "examples/normal_wood.png",        # normal wood image
        "reference": "examples/reference_wood.png",  # reference image showing the anomaly
        "mask": "examples/normal_mask_wood.png",     # mask for the normal image
        "mask_0": "examples/ref_mask_wood.png",      # mask for the reference image
        "prompt": "Wood surface has a crack with a long, dark - hued split.",
        "strength": 0.6,
        "ip_scale": 0.1,
        "steps": 20,
        "seed": 42,
        "description": "Wood with long dark crack and split"
    },
]
def load_example(idx):
    """Load example `idx` from disk and return every UI value it sets.

    Returns:
        A 10-tuple matching the Gradio outputs: (normal, reference, mask,
        mask_0, prompt, strength, ip_scale, steps, seed, status_message).
        On any loading failure, placeholder images and random masks are
        returned together with an error status.
    """
    if idx >= len(EXAMPLE_PAIRS):
        return None, None, None, None, "", 0.5, 0.3, 20, 42, f"Example {idx + 1} not found"
    ex = EXAMPLE_PAIRS[idx]
    try:
        normal_img = Image.open(ex["normal"]).convert('RGB')
        reference_img = Image.open(ex["reference"]).convert('RGB')
        # Fall back to a random mask whenever the example does not ship one.
        mask_img = (Image.open(ex["mask"]).convert('L')
                    if ex["mask"] is not None else generate_random_mask())
        mask_0_img = (Image.open(ex["mask_0"]).convert('L')
                      if ex["mask_0"] is not None else generate_random_mask())
        status_msg = f"Example {idx + 1}: {ex['description']} loaded!"
    except Exception as e:
        status_msg = f"Example loading failed: {str(e)} (Check if local image paths are correct)"
        print(status_msg)
        # Placeholder images + random masks so the UI still shows something.
        normal_img = Image.new('RGB', (512, 512), color='gray')
        reference_img = Image.new('RGB', (512, 512), color='blue')
        mask_img = generate_random_mask()
        mask_0_img = generate_random_mask()
    return (normal_img, reference_img, mask_img, mask_0_img,
            ex["prompt"], ex["strength"], ex["ip_scale"], ex["steps"], ex["seed"],
            status_msg)
# Eagerly load the model at import time so the UI is ready once the Space
# starts (any failure is remembered in load_status and shown to the user).
load_generator()
# Gradio UI: components are created first, then event handlers are wired up
# below (handlers must be connected after every referenced component exists).
with gr.Blocks(title="Anomagic Anomaly Image Generator") as demo:  # Removed theme to fix compatibility
    gr.Markdown("# Anomagic: Single Anomaly Image Generation Demo")
    gr.Markdown(
        "Upload normal image, reference image, normal mask and reference mask (white areas are for inpainting/anomaly generation), enter prompt, adjust parameters, and generate synthetic anomaly images with one click. Model is loaded in the background.")
    with gr.Row():
        with gr.Column(scale=1):
            normal_img = gr.Image(type="pil", label="Normal Image", height=256)  # Limit height
            reference_img = gr.Image(type="pil", label="Reference Image", height=256)
            with gr.Row():  # Mask row: editor plus helper buttons
                mask_img = gr.ImageEditor(
                    type="pil",
                    sources=['upload', 'webcam', 'clipboard'],
                    label="Normal Image Mask (draw white anomaly areas on black background)",
                    height=256,
                    interactive=True,
                    brush=gr.Brush(default_color="white", default_size=15, color_mode="fixed"),
                    value=Image.new('L', (512, 512), 0)  # Initial black canvas
                )
            with gr.Row():
                generate_mask_btn = gr.Button("Generate Random Normal Mask", variant="secondary")
                clear_mask_btn = gr.Button("Clear Normal Mask", variant="secondary")
            mask_0_img = gr.ImageEditor(
                type="pil",
                sources=['upload', 'webcam', 'clipboard'],
                label="Reference Image Mask (draw white areas on black background)",
                height=256,
                interactive=True,
                brush=gr.Brush(default_color="white", default_size=15, color_mode="fixed"),
                value=Image.new('L', (512, 512), 0)  # Initial black canvas
            )
            with gr.Row():
                generate_mask_0_btn = gr.Button("Generate Random Reference Mask", variant="secondary")
                clear_mask_0_btn = gr.Button("Clear Reference Mask", variant="secondary")
            prompt = gr.Textbox(label="Prompt Text",
                                placeholder="e.g., a broken machine part with rust and cracks")
        with gr.Column(scale=1):
            strength = gr.Slider(0.1, 1.0, value=0.5, label="Denoising Strength")
            ip_scale = gr.Slider(0, 2.0, value=0.3, step=0.1, label="IP Adapter Scale")
            steps = gr.Slider(10, 100, value=20, step=5, label="Inference Steps")
            seed = gr.Slider(0, 2 ** 32 - 1, value=42, step=1, label="Random Seed")
    gr.Markdown("## Examples")
    gr.Markdown(
        "Click the buttons below to load predefined examples for quick testing. After loading, click 'Generate Image' to view the anomaly synthesis result.")
    # One button per predefined example.
    example_buttons = []
    for i in range(len(EXAMPLE_PAIRS)):
        example_btn = gr.Button(f"Example {i + 1}: {EXAMPLE_PAIRS[i]['description']}", variant="secondary")
        example_buttons.append(example_btn)
    # Output components.
    output_img = gr.Image(type="pil", label="Generated Anomaly Image", height=256)
    status = gr.Textbox(label="Status", interactive=False)
    generate_btn = gr.Button("Generate Image", variant="primary", size="lg")  # Enlarge button
    # Reset the load-status flag.
    # NOTE(review): this only clears the flag — the loaded model stays in
    # memory until the process restarts, as the returned message says.
    def clear_cache():
        global load_status
        load_status = {"loaded": False, "error": None}
        return "Cache cleared, please restart the app to reload the model."
    clear_btn = gr.Button("Clear Cache", variant="stop")
    # Wire up all event handlers (after the components are defined).
    # Generate button click.
    generate_btn.click(
        generate_anomaly,
        inputs=[normal_img, reference_img, mask_img, mask_0_img, prompt, strength, ip_scale, steps, seed],
        outputs=[output_img, status]
    )
    # Clear-cache button click.
    clear_btn.click(clear_cache, outputs=status)
    # Example buttons, wired individually; `idx=i` binds the loop variable as
    # a default to avoid the late-binding closure pitfall.
    for i, btn in enumerate(example_buttons):
        btn.click(lambda idx=i: load_example(idx),
                  outputs=[normal_img, reference_img, mask_img, mask_0_img, prompt, strength, ip_scale, steps, seed,
                           status])
    # Mask helper buttons: random mask generation and reset to a black canvas.
    generate_mask_btn.click(lambda: generate_random_mask(), outputs=mask_img)
    clear_mask_btn.click(lambda: Image.new('L', (512, 512), 0), outputs=mask_img)
    generate_mask_0_btn.click(lambda: generate_random_mask(), outputs=mask_0_img)
    clear_mask_0_btn.click(lambda: Image.new('L', (512, 512), 0), outputs=mask_0_img)
if __name__ == "__main__":
    # Queue limits concurrent generation requests; Spaces routes traffic to
    # port 7860, and 0.0.0.0 makes the server reachable from the container host.
    demo.queue(max_size=10)
    demo.launch(server_name="0.0.0.0", server_port=7860)
    # demo.launch(share=True)