# Gradio demo application for Anomagic single anomaly image generation.
import os
import sys
import requests
import io  # Memory buffer

# Spaces environment configuration
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import time
import random
import numpy as np
import torch
from PIL import Image, ImageDraw
from diffusers import StableDiffusionInpaintPipelineLegacy, DDIMScheduler, AutoencoderKL, DPMSolverMultistepScheduler
from huggingface_hub import hf_hub_url, login  # hf_hub_url for generating cloud URL
import gradio as gr

# Attempt to import Anomagic (if ip_adapter module exists)
try:
    from ip_adapter.ip_adapter_anomagic import Anomagic

    HAS_ANOMAGIC = True
except ImportError:
    HAS_ANOMAGIC = False
    print("Anomagic not imported, will use basic Inpainting")

# Absolute directory containing this script, so that file paths can be
# resolved independently of the current working directory.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))


def extract_image_from_editor_output(editor_output):
    """Extract an image from a ``gr.ImageEditor`` value.

    Args:
        editor_output: ``None``, a PIL image, the editor's dict payload
            (``{"background": image, "layers": [], "composite": image}``),
            or any other object that ``Image.fromarray`` can convert.

    Returns:
        The image found in ``editor_output``, or ``None`` when nothing usable
        is present. For a dict payload the ``"composite"`` entry is preferred
        and ``"background"`` is the fallback; the entry is returned as stored.
    """
    if editor_output is None:
        return None

    # Dict payload from gr.ImageEditor: prefer the flattened composite, then
    # the background. A dict with neither holds no image.
    if isinstance(editor_output, dict):
        for key in ("composite", "background"):
            candidate = editor_output.get(key)
            if candidate is not None:
                return candidate
        return None

    # Already a PIL image: hand it back untouched.
    if isinstance(editor_output, Image.Image):
        return editor_output

    # Last resort: try to interpret the value as array data. This stays a
    # best-effort conversion, but no longer swallows KeyboardInterrupt or
    # SystemExit the way a bare ``except:`` did.
    try:
        return Image.fromarray(editor_output)
    except Exception:
        return None


class SingleAnomalyGenerator:
    """Generate one anomaly image from a normal image, a reference image and masks.

    Generation goes through Anomagic when the ``ip_adapter`` import succeeded
    (``HAS_ANOMAGIC``); otherwise the plain inpainting pipeline is used.
    """

    def __init__(self, device="cuda:0"):
        """Choose the compute device/dtype and initialise empty model slots.

        Args:
            device: Requested torch device string. CUDA (float16) is used only
                when CUDA is available and the string contains ``"cuda"``;
                otherwise the generator falls back to CPU (float32).
        """
        if torch.cuda.is_available() and "cuda" in device:
            self.device = torch.device(device)
            self.dtype = torch.float16
            print(f"Using GPU: {device}, dtype: {self.dtype}")
        else:
            self.device = torch.device("cpu")
            self.dtype = torch.float32
            print(f"Using CPU, dtype: {self.dtype}")

        self.anomagic_model = None  # Set by load_models() when Anomagic is importable
        self.pipe = None  # Inpainting pipeline, kept for reuse
        self.clip_vision_model = None
        self.clip_image_processor = None
        self.ip_ckpt_path = None  # Hub URL of checkpoint/anomagic.bin
        self.att_ckpt_path = None  # Hub URL of checkpoint/attention_module.bin

    def load_models(self):
        """Load the VAE, the inpainting pipeline and the CLIP encoder, then set up Anomagic.

        Raises:
            FileNotFoundError: If the URL of a checkpoint file cannot be built.
        """
        print("Loading VAE...")
        vae = AutoencoderKL.from_pretrained(
            "stabilityai/sd-vae-ft-mse",
            torch_dtype=self.dtype
        ).to(self.device)

        print("Loading base model...")
        noise_scheduler = DDIMScheduler(
            num_train_timesteps=1000,
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            clip_sample=False,
            set_alpha_to_one=False,
            steps_offset=1,
        )

        self.pipe = StableDiffusionInpaintPipelineLegacy.from_pretrained(
            "SG161222/Realistic_Vision_V4.0_noVAE",
            torch_dtype=self.dtype,
            scheduler=noise_scheduler,
            vae=vae,
            feature_extractor=None,
            safety_checker=None,
            low_cpu_mem_usage=True
        ).to(self.device, dtype=self.dtype)

        # The DDIM scheduler above only supplies the config; sampling itself
        # uses a DPM-Solver multistep scheduler built from that config.
        self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(self.pipe.scheduler.config)

        print("Loading CLIP image encoder...")
        from transformers import CLIPVisionModel, CLIPImageProcessor
        self.clip_vision_model = CLIPVisionModel.from_pretrained(
            "openai/clip-vit-large-patch14",
            torch_dtype=self.dtype
        ).to(self.device)
        self.clip_image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")

        print("All models loaded!")

        # Record where the Anomagic checkpoints live on the Hub. Only the URLs
        # are built here; this method does not download the files itself.
        # NOTE(review): whether Anomagic accepts a URL in place of a local
        # path cannot be confirmed from this file.
        print("Loading weights into memory...")
        repo_id = "yuxinjiang11/Anomagic_model"
        weight_files = [
            ("checkpoint/anomagic.bin", "ip_ckpt_path"),
            ("checkpoint/attention_module.bin", "att_ckpt_path"),
        ]
        for filename, attr_name in weight_files:
            try:
                url = hf_hub_url(repo_id=repo_id, filename=filename, repo_type="model")
            except Exception as e:
                # Chain the cause so the original failure stays visible in the traceback.
                raise FileNotFoundError(f"Unable to get weight file path {filename}: {str(e)}") from e
            setattr(self, attr_name, url)
            print(f"Weight file path: {filename} -> {url}")

        if HAS_ANOMAGIC:
            print("Initializing Anomagic model...")
            self.anomagic_model = Anomagic(self.pipe, self.clip_vision_model, self.ip_ckpt_path, self.att_ckpt_path,
                                           self.device)
        else:
            print("No Anomagic, using basic Pipe.")

        print("Model loading complete!")

    @staticmethod
    def _prepare_mask(mask_input, target_size):
        """Turn an editor value into a binary ``'L'`` mask of ``target_size``.

        Every non-zero pixel becomes 255 and every zero pixel stays 0.
        ``None`` is returned unchanged, and a value from which no PIL image
        can be extracted is passed through as extracted.
        """
        if mask_input is None:
            return None
        extracted = extract_image_from_editor_output(mask_input)
        if not isinstance(extracted, Image.Image):
            return extracted
        binary = np.array(extracted.resize(target_size).convert('L')) > 0
        return Image.fromarray(binary.astype(np.uint8) * 255).convert('L')

    def generate_single_image(self, normal_image, reference_image, mask, mask_0, prompt, num_inference_steps=50,
                              ip_scale=0.3, seed=42, strength=0.3):
        """Generate an anomaly image, with ``mask_0`` as the reference image mask.

        Args:
            normal_image: PIL image to edit; resized to 512x512.
            reference_image: PIL image passed to Anomagic as ``pil_image``;
                resized to 512x512. Not used by the basic inpainting fallback.
            mask: Mask for ``normal_image`` (editor value or PIL image), or ``None``.
            mask_0: Mask for ``reference_image`` (editor value or PIL image), or ``None``.
            prompt: Text prompt.
            num_inference_steps: Number of sampling steps.
            ip_scale: Value passed to Anomagic as ``scale``.
            seed: Seed applied globally through ``torch.manual_seed``.
            strength: Denoising strength.

        Returns:
            The first generated image.

        Raises:
            ValueError: If ``normal_image`` or ``reference_image`` is ``None``.
        """
        if normal_image is None or reference_image is None:
            raise ValueError("Normal or reference image is None. Please upload valid images.")

        target_size = (512, 512)
        normal_image = normal_image.resize(target_size)
        reference_image = reference_image.resize(target_size)

        # Both masks go through the same resize + binarise step.
        mask = self._prepare_mask(mask, target_size)
        mask_0 = self._prepare_mask(mask_0, target_size)

        print(f"Generating with seed {seed}...")
        torch.manual_seed(seed)

        if HAS_ANOMAGIC and self.anomagic_model:
            # NOTE(review): the keyword names below are assumed to match
            # Anomagic.generate; confirm against the ip_adapter module.
            generated_image = self.anomagic_model.generate(
                pil_image=reference_image,
                num_samples=1,
                num_inference_steps=num_inference_steps,
                prompt=prompt,
                scale=ip_scale,
                image=normal_image,
                mask_image=mask,
                mask_image_0=mask_0,  # Reference image mask
                strength=strength,
            )[0]
        else:
            # Basic inpainting: without a mask, repaint the whole image.
            if mask is None:
                mask = Image.new('L', target_size, 255)  # Full white mask
            generated_image = self.pipe(
                prompt=prompt,
                image=normal_image,
                mask_image=mask,
                strength=strength,
                num_inference_steps=num_inference_steps,
            ).images[0]

        return generated_image


# Module-wide generator instance plus a record of how model loading went:
# "loaded" flips to True on success, "error" holds the failure text otherwise.
generator = None
load_status = dict(loaded=False, error=None)


def load_generator():
    """Create the global generator and load its models, once.

    The outcome is recorded in ``load_status``: a previous success or failure
    short-circuits the call, so a failed load is not retried until the status
    is reset.

    Returns:
        A human-readable status message.
    """
    global generator, load_status

    # Reuse the recorded outcome of an earlier attempt.
    if load_status["loaded"]:
        return "Models loaded!"

    previous_error = load_status["error"]
    if previous_error:
        return f"Previous load failed: {previous_error}"

    try:
        print("Starting background model load...")
        generator = SingleAnomalyGenerator()
        generator.load_models()
        load_status["loaded"] = True
        print("Background model load complete!")
    except Exception as exc:
        import traceback

        failure = str(exc)
        load_status["error"] = failure
        message = f"Model loading failed: {failure}"
        print(message)
        print(traceback.format_exc())
        return message

    return "Model loading complete! You can now generate images."


def generate_random_mask(size=(512, 512), num_blobs=3, blob_size_range=(50, 150)):
    """Build a black ``'L'`` mask with randomly placed white ellipses.

    Args:
        size: ``(width, height)`` of the mask.
        num_blobs: Number of ellipses to draw.
        blob_size_range: Inclusive ``(min, max)`` range for each ellipse's
            width and height.

    Returns:
        The mask image; ellipse pixels are 255, everything else 0.
    """
    canvas = Image.new('L', size, 0)  # Black background
    pen = ImageDraw.Draw(canvas)
    max_x, max_y = size[0], size[1]

    for _ in range(num_blobs):
        # Centre first, then extents: the draw order of the random numbers
        # matters for reproducibility under a fixed seed.
        centre_x = random.randint(0, max_x)
        centre_y = random.randint(0, max_y)
        half_w = random.randint(*blob_size_range) // 2
        half_h = random.randint(*blob_size_range) // 2
        bounding_box = [centre_x - half_w, centre_y - half_h, centre_x + half_w, centre_y + half_h]
        pen.ellipse(bounding_box, fill=255)

    return canvas


def generate_anomaly(normal_img, reference_img, mask_img, mask_0_img, prompt, strength, ip_scale, steps, seed):
    """Validate the UI inputs and run one generation; called by Gradio.

    Args:
        normal_img: Normal image (PIL) or ``None``.
        reference_img: Reference image (PIL) or ``None``.
        mask_img: Mask for the normal image (editor value) or ``None``.
        mask_0_img: Mask for the reference image (editor value) or ``None``.
        prompt: Prompt text; ``None``, empty or whitespace-only is rejected.
        strength: Denoising strength.
        ip_scale: IP adapter scale.
        steps: Number of inference steps; coerced to ``int``.
        seed: Random seed; coerced to ``int``.

    Returns:
        ``(image, status)``. ``image`` is ``None`` whenever validation or
        generation fails, and ``status`` then explains why.
    """
    # Also guard against a missing generator object, not just the status flag.
    if not load_status["loaded"] or generator is None:
        return None, "Please wait for model loading to complete."

    # ``prompt`` may be None, so test it before calling .strip().
    if normal_img is None or reference_img is None or not (prompt and prompt.strip()):
        return None, "Please upload normal image, reference image, and enter prompt text."

    if mask_img is None:
        return None, "Please upload or generate mask image for normal image."

    try:
        # Slider values can arrive as floats; np.random.seed rejects those,
        # and the step count must be a whole number.
        seed = int(seed)
        steps = int(steps)

        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

        generated_img = generator.generate_single_image(
            normal_image=normal_img,
            reference_image=reference_img,
            mask=mask_img,
            mask_0=mask_0_img,
            prompt=prompt,
            num_inference_steps=steps,
            ip_scale=ip_scale,
            seed=seed,
            strength=strength
        )

        return generated_img, f"Generation successful! Seed: {seed}, Steps: {steps}"

    except Exception as e:
        # UI boundary: report the failure in the status box instead of crashing the app.
        error_msg = f"Generation error: {str(e)}"
        print(error_msg)
        import traceback
        print(traceback.format_exc())
        return None, error_msg


# Predefined anomaly examples. Paths are relative; the images are assumed to
# live in an examples/ folder in the same directory as the script.
# NOTE(review): the first two prompts describe a wood surface and a cookie
# while the files are named *_apple / *_candle. Presumably each prompt
# describes the anomaly shown in the reference image; confirm.
EXAMPLE_PAIRS = [
    {
        "normal": "examples/normal_apple.png",  # Normal image
        "reference": "examples/reference_apple.png",  # Reference image
        "mask": "examples/normal_mask_apple.jpg",  # Mask for the normal image
        "mask_0": "examples/ref_mask_apple.png",  # Mask for the reference image
        "prompt": "Wood surface has holes with rough - edged circular openings.",
        "strength": 0.6,
        "ip_scale": 0.1,
        "steps": 20,
        "seed": 42,
        "description": "Apple with wormholes and rough edges"
    },
    {
        "normal": "examples/normal_candle.JPG",  # Normal image
        "reference": "examples/reference_candle.png",  # Reference image
        "mask": "examples/normal_mask_candle.png",  # Mask for the normal image
        "mask_0": "examples/ref_mask_candle.png",  # Mask for the reference image
        "prompt": "Chocolate - chip cookie has a chunk - missing defect with exposed inner texture. ",
        "strength": 0.6,
        "ip_scale": 1,
        "steps": 20,
        "seed": 42,
        "description": "Candle with deformed surface"
    },
    {
        "normal": "examples/normal_wood.png",  # Normal image
        "reference": "examples/reference_wood.png",  # Reference image
        "mask": "examples/normal_mask_wood.png",  # Mask for the normal image
        "mask_0": "examples/ref_mask_wood.png",  # Mask for the reference image
        "prompt": "Wood surface has a crack with a long, dark - hued split.",
        "strength": 0.6,
        "ip_scale": 0.1,
        "steps": 20,
        "seed": 42,
        "description": "Wood with long dark crack and split"
    },
]


def _open_example_image(path, mode):
    """Open an example image, convert it to ``mode`` and release the file handle."""
    # Resolve relative paths against the script directory rather than the
    # current working directory; os.path.join leaves absolute paths untouched.
    with Image.open(os.path.join(BASE_DIR, path)) as img:
        # convert() returns an independent in-memory image, so the file can
        # be closed as soon as the ``with`` block exits.
        return img.convert(mode)


def load_example(idx):
    """Load a predefined example and return the values for the UI components.

    Args:
        idx: Zero-based index into ``EXAMPLE_PAIRS``.

    Returns:
        A 10-tuple ``(normal_img, reference_img, mask_img, mask_0_img, prompt,
        strength, ip_scale, steps, seed, status)``. If ``idx`` is past the end
        of the list, the images are ``None`` and default settings are
        returned. If an image cannot be loaded, grey/blue placeholder images
        and random masks are returned together with an error status.
    """
    if idx >= len(EXAMPLE_PAIRS):
        return None, None, None, None, "", 0.5, 0.3, 20, 42, f"Example {idx + 1} not found"

    ex = EXAMPLE_PAIRS[idx]
    settings = (ex["prompt"], ex["strength"], ex["ip_scale"], ex["steps"], ex["seed"])

    try:
        normal_img = _open_example_image(ex["normal"], 'RGB')
        reference_img = _open_example_image(ex["reference"], 'RGB')

        # A mask path of None means "no mask shipped": fall back to a random one.
        if ex["mask"] is not None:
            mask_img = _open_example_image(ex["mask"], 'L')
        else:
            mask_img = generate_random_mask()

        if ex["mask_0"] is not None:
            mask_0_img = _open_example_image(ex["mask_0"], 'L')
        else:
            mask_0_img = generate_random_mask()

        status = f"Example {idx + 1}: {ex['description']} loaded!"
        return (normal_img, reference_img, mask_img, mask_0_img, *settings, status)
    except Exception as e:
        error_msg = f"Example loading failed: {str(e)} (Check if local image paths are correct)"
        print(error_msg)
        # Fallback to placeholder images and random masks so the UI still updates.
        normal_img = Image.new('RGB', (512, 512), color='gray')
        reference_img = Image.new('RGB', (512, 512), color='blue')
        mask_img = generate_random_mask()
        mask_0_img = generate_random_mask()
        return (normal_img, reference_img, mask_img, mask_0_img, *settings, error_msg)


# Automatically load the model on startup. The call is synchronous, so the UI
# below is only built once it returns; a loading failure is recorded in
# load_status and returned as a message rather than raised.
load_generator()

# Gradio UI: inputs and masks in the left column; parameters, examples,
# output and actions in the right column.
with gr.Blocks(title="Anomagic Anomaly Image Generator") as demo:  # Removed theme to fix compatibility
    gr.Markdown("# Anomagic: Single Anomaly Image Generation Demo")
    gr.Markdown(
        "Upload normal image, reference image, normal mask and reference mask (white areas are for inpainting/anomaly generation), enter prompt, adjust parameters, and generate synthetic anomaly images with one click. Model is loaded in the background.")

    with gr.Row():
        with gr.Column(scale=1):
            normal_img = gr.Image(type="pil", label="Normal Image", height=256)  # Limit height
            reference_img = gr.Image(type="pil", label="Reference Image", height=256)

            with gr.Row():  # Mask row: editor plus its buttons
                mask_img = gr.ImageEditor(
                    type="pil",
                    sources=['upload', 'webcam', 'clipboard'],
                    label="Normal Image Mask (draw white anomaly areas on black background)",
                    height=256,
                    interactive=True,
                    brush=gr.Brush(default_color="white", default_size=15, color_mode="fixed"),
                    value=Image.new('L', (512, 512), 0)  # Initial black canvas
                )
                with gr.Row():
                    generate_mask_btn = gr.Button("Generate Random Normal Mask", variant="secondary")
                    clear_mask_btn = gr.Button("Clear Normal Mask", variant="secondary")

            mask_0_img = gr.ImageEditor(
                type="pil",
                sources=['upload', 'webcam', 'clipboard'],
                label="Reference Image Mask (draw white areas on black background)",
                height=256,
                interactive=True,
                brush=gr.Brush(default_color="white", default_size=15, color_mode="fixed"),
                value=Image.new('L', (512, 512), 0)  # Initial black canvas
            )
            with gr.Row():
                generate_mask_0_btn = gr.Button("Generate Random Reference Mask", variant="secondary")
                clear_mask_0_btn = gr.Button("Clear Reference Mask", variant="secondary")

            prompt = gr.Textbox(label="Prompt Text",
                                placeholder="e.g., a broken machine part with rust and cracks")

        with gr.Column(scale=1):
            strength = gr.Slider(0.1, 1.0, value=0.5, label="Denoising Strength")
            ip_scale = gr.Slider(0, 2.0, value=0.3, step=0.1, label="IP Adapter Scale")
            steps = gr.Slider(10, 100, value=20, step=5, label="Inference Steps")
            seed = gr.Slider(0, 2 ** 32 - 1, value=42, step=1, label="Random Seed")

            gr.Markdown("## Examples")
            gr.Markdown(
                "Click the buttons below to load predefined examples for quick testing. After loading, click 'Generate Image' to view the anomaly synthesis result.")

            # Create one button per predefined example
            example_buttons = []
            for i in range(len(EXAMPLE_PAIRS)):
                example_btn = gr.Button(f"Example {i + 1}: {EXAMPLE_PAIRS[i]['description']}", variant="secondary")
                example_buttons.append(example_btn)

            # Output components
            output_img = gr.Image(type="pil", label="Generated Anomaly Image", height=256)
            status = gr.Textbox(label="Status", interactive=False)

            generate_btn = gr.Button("Generate Image", variant="primary", size="lg")  # Enlarge button


            # Clear cache button
            def clear_cache():
                """Reset ``load_status`` to "not loaded, no error" and return a status message.

                Only the status dict is replaced; the global ``generator``
                object is not touched here.
                """
                global load_status
                load_status = {"loaded": False, "error": None}
                return "Cache cleared, please restart the app to reload the model."


            clear_btn = gr.Button("Clear Cache", variant="stop")

    # Wire up all event handlers (after the components are defined)

    # Generate button click
    generate_btn.click(
        generate_anomaly,
        inputs=[normal_img, reference_img, mask_img, mask_0_img, prompt, strength, ip_scale, steps, seed],
        outputs=[output_img, status]
    )

    # Clear cache button click
    clear_btn.click(clear_cache, outputs=status)

    # Example button clicks (each button is wired separately; idx=i binds the
    # loop value at definition time to avoid the late-binding closure pitfall)
    for i, btn in enumerate(example_buttons):
        btn.click(lambda idx=i: load_example(idx),
                  outputs=[normal_img, reference_img, mask_img, mask_0_img, prompt, strength, ip_scale, steps, seed,
                           status])

    # Mask button clicks: random mask or fresh black canvas for each editor
    generate_mask_btn.click(lambda: generate_random_mask(), outputs=mask_img)
    clear_mask_btn.click(lambda: Image.new('L', (512, 512), 0), outputs=mask_img)
    generate_mask_0_btn.click(lambda: generate_random_mask(), outputs=mask_0_img)
    clear_mask_0_btn.click(lambda: Image.new('L', (512, 512), 0), outputs=mask_0_img)

if __name__ == "__main__":
    # Enable the request queue (max_size=10), then serve on all network
    # interfaces at port 7860.
    demo.queue(max_size=10)
    demo.launch(server_name="0.0.0.0", server_port=7860)
    # Alternative launch with a public share link:
    # demo.launch(share=True)