yuxinjiang11 committed
Commit 74166a2 · verified · 1 Parent(s): 8e77a78

Upload app.py

Files changed (1):
  app.py +189 -67
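In the diff below, load_models() now builds hf_hub_url links for checkpoint/ip_adapter_0.bin and checkpoint/att.bin in yuxinjiang11/Anomagic_model and hands those URLs to Anomagic; the io ("memory buffer") and requests imports suggest the checkpoints are meant to be pulled straight into memory. Assuming the Anomagic constructor actually wants a state_dict rather than a URL, a minimal sketch of resolving such a link without touching disk could look like this (the helper name and arguments are illustrative, not part of the commit):

import io
import requests
import torch
from huggingface_hub import hf_hub_url

def load_state_dict_from_hub(filename, repo_id="yuxinjiang11/Anomagic_model"):
    # Illustrative helper (not in app.py): build the public download URL
    # and stream the checkpoint into memory.
    url = hf_hub_url(repo_id=repo_id, filename=filename, repo_type="model")
    resp = requests.get(url, timeout=60)
    resp.raise_for_status()
    # Deserialize directly from the in-memory buffer; nothing is written to disk.
    return torch.load(io.BytesIO(resp.content), map_location="cpu")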
app.py CHANGED
@@ -1,61 +1,62 @@
1
  import os
2
  import sys
3
  import requests
4
- import io # Memory buffer
5
 
6
- # Spaces environment configuration
7
  os.environ['CUDA_VISIBLE_DEVICES'] = '0'
8
 
9
  import time
10
  import random
11
  import numpy as np
12
  import torch
13
- from PIL import Image
14
  from diffusers import StableDiffusionInpaintPipelineLegacy, DDIMScheduler, AutoencoderKL, DPMSolverMultistepScheduler
15
- from huggingface_hub import hf_hub_url, login # hf_hub_url for building the cloud URL
16
  import gradio as gr
17
 
18
- # Try to import Anomagic (if the ip_adapter module is present)
19
  try:
20
  from ip_adapter.ip_adapter_anomagic import Anomagic
 
21
  HAS_ANOMAGIC = True
22
  except ImportError:
23
  HAS_ANOMAGIC = False
24
- print("Anomagic not imported, will use basic Inpainting")
25
 
26
- # Get the absolute path of the current script (to solve path issues)
27
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
28
 
29
 
30
  class SingleAnomalyGenerator:
31
  def __init__(self, device="cuda:0"):
32
- # Auto-detect GPU and set dtype
33
  if torch.cuda.is_available() and "cuda" in device:
34
  self.device = torch.device(device)
35
  self.dtype = torch.float16
36
- print(f"Using GPU: {device}, dtype: {self.dtype}")
37
  else:
38
  self.device = torch.device("cpu")
39
  self.dtype = torch.float32
40
- print(f"Using CPU, dtype: {self.dtype}")
41
 
42
  self.anomagic_model = None
43
- self.pipe = None # Keep the pipe for reuse
44
  self.clip_vision_model = None
45
  self.clip_image_processor = None
46
- self.ip_ckpt_path = None # IP weights state_dict in memory
47
- self.att_ckpt_path = None # ATT weights state_dict in memory
48
 
49
  def load_models(self):
50
  """Load models with official CLIP"""
51
- print("Loading VAE...")
52
  from diffusers import AutoencoderKL
53
  vae = AutoencoderKL.from_pretrained(
54
  "stabilityai/sd-vae-ft-mse",
55
  torch_dtype=self.dtype
56
  ).to(self.device)
57
 
58
- print("Loading base model...")
59
  from diffusers import StableDiffusionInpaintPipelineLegacy, DDIMScheduler, DPMSolverMultistepScheduler
60
 
61
  noise_scheduler = DDIMScheduler(
@@ -80,7 +81,7 @@ class SingleAnomalyGenerator:
80
 
81
  self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(self.pipe.scheduler.config)
82
 
83
- print("Loading CLIP image encoder...")
84
  from transformers import CLIPVisionModel, CLIPImageProcessor
85
  self.clip_vision_model = CLIPVisionModel.from_pretrained(
86
  "openai/clip-vit-large-patch14",
@@ -88,38 +89,39 @@ class SingleAnomalyGenerator:
88
  ).to(self.device)
89
  self.clip_image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
90
 
91
- print("All models loaded!")
92
 
93
- # Load weights (download from the cloud repo into memory to avoid any disk usage)
94
- print("Loading weights into memory...")
95
  weight_files = [
96
  ("checkpoint/ip_adapter_0.bin", "ip_ckpt_path"),
97
  ("checkpoint/att.bin", "att_ckpt_path")
98
  ]
99
  for filename, attr_name in weight_files:
100
  try:
101
- # Generate cloud URL (public repo, no token needed)
102
  repo_id = "yuxinjiang11/Anomagic_model"
103
  url = hf_hub_url(repo_id=repo_id, filename=filename, repo_type="model")
104
 
105
- # Dynamically set the attribute (or assign explicitly with an if check)
106
  if attr_name == "ip_ckpt_path":
107
  self.ip_ckpt_path = url
108
  elif attr_name == "att_ckpt_path":
109
  self.att_ckpt_path = url
110
 
111
- print(f"Weight file path: {filename} -> {url}")
112
  except Exception as e:
113
- raise FileNotFoundError(f"Unable to get weight file path {filename}: {str(e)}")
114
 
115
- # If Anomagic is available, load the weights into the model
116
  if HAS_ANOMAGIC:
117
- print("Initializing Anomagic model...")
118
- self.anomagic_model = Anomagic(self.pipe, self.clip_vision_model, self.ip_ckpt_path, self.att_ckpt_path, self.device)
 
119
  else:
120
- print("No Anomagic, using basic Pipe")
121
 
122
- print("Model loading complete!")
123
 
124
  def generate_single_image(self, normal_image, reference_image, mask, mask_0, prompt, num_inference_steps=50,
125
  ip_scale=0.3, seed=42, strength=0.3):
@@ -148,10 +150,10 @@ class SingleAnomalyGenerator:
148
  print(f"Generating with seed {seed}...")
149
  torch.manual_seed(seed)
150
 
151
- # If Anomagic is available, use it to generate; otherwise use basic inpainting
152
  if HAS_ANOMAGIC and self.anomagic_model:
153
  # generator = torch.Generator(device=self.device).manual_seed(seed)
154
- # Assume Anomagic.generate supports these parameters (adjust to the actual API)
155
  generated_image = self.anomagic_model.generate(
156
  pil_image=reference_image,
157
  num_samples=1,
@@ -165,10 +167,10 @@ class SingleAnomalyGenerator:
165
  # generator=generator
166
  )[0]
167
  else:
168
- # Basic inpainting
169
  # generator = torch.Generator(device=self.device).manual_seed(seed)
170
  if mask is None:
171
- mask = Image.new('L', target_size, 255) # All-white mask
172
  generated_image = self.pipe(
173
  prompt=prompt,
174
  image=normal_image,
@@ -181,50 +183,66 @@ class SingleAnomalyGenerator:
181
  return generated_image
182
 
183
 
184
- # Global generator and load status
185
  generator = None
186
  load_status = {"loaded": False, "error": None}
187
 
188
 
189
  def load_generator():
190
- """Gradio load function: load the models on first run"""
191
  global generator, load_status
192
 
193
  if load_status["loaded"]:
194
- return "Models already loaded!"
195
 
196
  if load_status["error"]:
197
- return f"Previous load failed: {load_status['error']}"
198
 
199
  try:
 
200
  generator = SingleAnomalyGenerator()
201
  generator.load_models()
202
  load_status["loaded"] = True
203
- return "Model loading complete! You can now generate images."
 
204
  except Exception as e:
205
  load_status["error"] = str(e)
206
- error_msg = f"Model loading failed: {str(e)}"
207
  print(error_msg)
208
  import traceback
209
  print(traceback.format_exc())
210
  return error_msg
211
 
212
213
  def generate_anomaly(normal_img, reference_img, mask_img, mask_0_img, prompt, strength, ip_scale, steps, seed):
214
- """Core generation function called by Gradio (supports two masks)"""
215
  global generator
216
 
217
  if not load_status["loaded"]:
218
- return None, "Please click the 'Load Models' button first to initialize."
219
 
220
  if normal_img is None or reference_img is None or not prompt.strip():
221
- return None, "Please upload a normal image and a reference image, and enter a prompt."
222
 
223
  if mask_img is None:
224
- return None, "Please upload a mask for the normal image."
225
 
226
  try:
227
- # Set seed
228
  random.seed(seed)
229
  np.random.seed(seed)
230
  torch.manual_seed(seed)
@@ -241,61 +259,165 @@ def generate_anomaly(normal_img, reference_img, mask_img, mask_0_img, prompt, st
241
  strength=strength
242
  )
243
 
244
- return generated_img, f"Generation successful! Seed: {seed}, Steps: {steps}"
245
 
246
  except Exception as e:
247
- error_msg = f"Generation error: {str(e)}"
248
  print(error_msg)
249
  import traceback
250
  print(traceback.format_exc())
251
  return None, error_msg
252
 
253
 
254
  # Gradio UI
255
- with gr.Blocks(title="Anomagic Anomaly Image Generator") as demo:
256
- gr.Markdown("# Anomagic: Single Anomaly Image Generation Demo")
 
257
  gr.Markdown(
258
- "Upload a normal image, a reference image, a normal mask and a reference mask (white areas are the inpainting/anomaly-generation regions), enter a prompt, adjust the parameters, and generate a synthetic anomaly image with one click. The model must be loaded before first use (takes a few minutes).")
259
 
260
  with gr.Row():
261
  with gr.Column(scale=1):
262
- normal_img = gr.Image(type="pil", label="Normal Image")
263
- reference_img = gr.Image(type="pil", label="Reference Image")
264
 
265
- mask_img = gr.Image(type="pil", label="Normal Image Mask (white = anomaly generation area)")
266
- mask_0_img = gr.Image(type="pil", label="Reference Image Mask (mask_0)")
 
267
 
268
- prompt = gr.Textbox(label="Prompt Text",
269
  placeholder="e.g., a broken machine part with rust and cracks")
270
 
271
  with gr.Column(scale=1):
272
- strength = gr.Slider(0.1, 1.0, value=0.5, label="Denoising Strength")
273
- ip_scale = gr.Slider(0, 2.0, value=0.3, step=0.1, label="IP Adapter Scale")
274
- steps = gr.Slider(10, 100, value=20, step=5, label="Inference Steps")
275
- seed = gr.Slider(0, 2 ** 32 - 1, value=42, step=1, label="Random Seed")
276
 
277
  with gr.Row():
278
- load_btn = gr.Button("Load Models", variant="secondary")
279
- generate_btn = gr.Button("Generate", variant="primary")
280
 
281
- output_img = gr.Image(type="pil", label="Generated Anomaly Image")
282
- status = gr.Textbox(label="Status", interactive=False)
283
 
284
- # Event bindings (fix duplicate output issue)
285
- load_btn.click(load_generator, outputs=status)
286
  generate_btn.click(
287
  generate_anomaly,
288
  inputs=[normal_img, reference_img, mask_img, mask_0_img, prompt, strength, ip_scale, steps, seed],
289
- outputs=[output_img, status] # Fix duplicate binding issue
290
  )
291
 
292
- # Clear-cache button (simplified, since there are no persistent downloads now)
293
  def clear_cache():
294
  global load_status
295
  load_status = {"loaded": False, "error": None}
296
- return "Cache cleared, please reload the models."
 
297
 
298
- clear_btn = gr.Button("Clear Cache", variant="stop")
299
  clear_btn.click(clear_cache, outputs=status)
300
 
301
  if __name__ == "__main__":
 
1
  import os
2
  import sys
3
  import requests
4
+ import io # Memory buffer
5
 
6
+ # Spaces environment configuration
7
  os.environ['CUDA_VISIBLE_DEVICES'] = '0'
8
 
9
  import time
10
  import random
11
  import numpy as np
12
  import torch
13
+ from PIL import Image, ImageDraw
14
  from diffusers import StableDiffusionInpaintPipelineLegacy, DDIMScheduler, AutoencoderKL, DPMSolverMultistepScheduler
15
+ from huggingface_hub import hf_hub_url, login # hf_hub_url for generating cloud URL
16
  import gradio as gr
17
 
18
+ # Attempt to import Anomagic (if ip_adapter module exists)
19
  try:
20
  from ip_adapter.ip_adapter_anomagic import Anomagic
21
+
22
  HAS_ANOMAGIC = True
23
  except ImportError:
24
  HAS_ANOMAGIC = False
25
+ print("Anomagic not imported, will use basic Inpainting")
26
 
27
+ # Get the absolute path of the current script (to solve path issues)
28
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
29
 
30
 
31
  class SingleAnomalyGenerator:
32
  def __init__(self, device="cuda:0"):
33
+ # Auto-detect GPU and set dtype
34
  if torch.cuda.is_available() and "cuda" in device:
35
  self.device = torch.device(device)
36
  self.dtype = torch.float16
37
+ print(f"Using GPU: {device}, dtype: {self.dtype}")
38
  else:
39
  self.device = torch.device("cpu")
40
  self.dtype = torch.float32
41
+ print(f"Using CPU, dtype: {self.dtype}")
42
 
43
  self.anomagic_model = None
44
+ self.pipe = None # Save pipe for reuse
45
  self.clip_vision_model = None
46
  self.clip_image_processor = None
47
+ self.ip_ckpt_path = None # IP weights state_dict in memory
48
+ self.att_ckpt_path = None # ATT weights state_dict in memory
49
 
50
  def load_models(self):
51
  """Load models with official CLIP"""
52
+ print("Loading VAE...")
53
  from diffusers import AutoencoderKL
54
  vae = AutoencoderKL.from_pretrained(
55
  "stabilityai/sd-vae-ft-mse",
56
  torch_dtype=self.dtype
57
  ).to(self.device)
58
 
59
+ print("Loading base model...")
60
  from diffusers import StableDiffusionInpaintPipelineLegacy, DDIMScheduler, DPMSolverMultistepScheduler
61
 
62
  noise_scheduler = DDIMScheduler(
 
81
 
82
  self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(self.pipe.scheduler.config)
83
 
84
+ print("Loading CLIP image encoder...")
85
  from transformers import CLIPVisionModel, CLIPImageProcessor
86
  self.clip_vision_model = CLIPVisionModel.from_pretrained(
87
  "openai/clip-vit-large-patch14",
 
89
  ).to(self.device)
90
  self.clip_image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
91
 
92
+ print("All models loaded!")
93
 
94
+ # Load weights (download from cloud repo to memory, avoid any disk usage)
95
+ print("Loading weights into memory...")
96
  weight_files = [
97
  ("checkpoint/ip_adapter_0.bin", "ip_ckpt_path"),
98
  ("checkpoint/att.bin", "att_ckpt_path")
99
  ]
100
  for filename, attr_name in weight_files:
101
  try:
102
+ # Generate cloud URL (public repo, no token needed)
103
  repo_id = "yuxinjiang11/Anomagic_model"
104
  url = hf_hub_url(repo_id=repo_id, filename=filename, repo_type="model")
105
 
106
+ # Dynamically set attribute (or use if to assign explicitly)
107
  if attr_name == "ip_ckpt_path":
108
  self.ip_ckpt_path = url
109
  elif attr_name == "att_ckpt_path":
110
  self.att_ckpt_path = url
111
 
112
+ print(f"Weight file path: {filename} -> {url}")
113
  except Exception as e:
114
+ raise FileNotFoundError(f"Unable to get weight file path {filename}: {str(e)}")
115
 
116
+ # If Anomagic is available, load weights into the model
117
  if HAS_ANOMAGIC:
118
+ print("Initializing Anomagic model...")
119
+ self.anomagic_model = Anomagic(self.pipe, self.clip_vision_model, self.ip_ckpt_path, self.att_ckpt_path,
120
+ self.device)
121
  else:
122
+ print("No Anomagic, using basic Pipe.")
123
 
124
+ print("Model loading complete!")
125
 
126
  def generate_single_image(self, normal_image, reference_image, mask, mask_0, prompt, num_inference_steps=50,
127
  ip_scale=0.3, seed=42, strength=0.3):
 
150
  print(f"Generating with seed {seed}...")
151
  torch.manual_seed(seed)
152
 
153
+ # If Anomagic is available, use it to generate; otherwise basic Inpainting
154
  if HAS_ANOMAGIC and self.anomagic_model:
155
  # generator = torch.Generator(device=self.device).manual_seed(seed)
156
+ # Assume Anomagic.generate supports parameters (adjust based on actual)
157
  generated_image = self.anomagic_model.generate(
158
  pil_image=reference_image,
159
  num_samples=1,
 
167
  # generator=generator
168
  )[0]
169
  else:
170
+ # Basic Inpainting
171
  # generator = torch.Generator(device=self.device).manual_seed(seed)
172
  if mask is None:
173
+ mask = Image.new('L', target_size, 255) # Full white mask
174
  generated_image = self.pipe(
175
  prompt=prompt,
176
  image=normal_image,
 
183
  return generated_image
184
 
185
 
186
+ # Global generator and load status
187
  generator = None
188
  load_status = {"loaded": False, "error": None}
189
 
190
 
191
  def load_generator():
192
+ """Background load function: Automatically load model on startup"""
193
  global generator, load_status
194
 
195
  if load_status["loaded"]:
196
+ return "Models loaded!"
197
 
198
  if load_status["error"]:
199
+ return f"Previous load failed: {load_status['error']}"
200
 
201
  try:
202
+ print("Starting background model load...")
203
  generator = SingleAnomalyGenerator()
204
  generator.load_models()
205
  load_status["loaded"] = True
206
+ print("Background model load complete!")
207
+ return "Model loading complete! You can now generate images."
208
  except Exception as e:
209
  load_status["error"] = str(e)
210
+ error_msg = f"Model loading failed: {str(e)}"
211
  print(error_msg)
212
  import traceback
213
  print(traceback.format_exc())
214
  return error_msg
215
 
216
 
217
+ def generate_random_mask(size=(512, 512), num_blobs=3, blob_size_range=(50, 150)):
218
+ """Generate random mask: Create several random blobs as anomaly areas"""
219
+ mask = Image.new('L', size, 0) # Black background
220
+ draw = ImageDraw.Draw(mask)
221
+ for _ in range(num_blobs):
222
+ x = random.randint(0, size[0])
223
+ y = random.randint(0, size[1])
224
+ width = random.randint(*blob_size_range)
225
+ height = random.randint(*blob_size_range)
226
+ # Draw elliptical blobs
227
+ draw.ellipse([x - width // 2, y - height // 2, x + width // 2, y + height // 2], fill=255)
228
+ return mask
229
+
230
+
231
  def generate_anomaly(normal_img, reference_img, mask_img, mask_0_img, prompt, strength, ip_scale, steps, seed):
232
+ """Core generation function: Called by Gradio (supports two masks)"""
233
  global generator
234
 
235
  if not load_status["loaded"]:
236
+ return None, "Please wait for model loading to complete."
237
 
238
  if normal_img is None or reference_img is None or not prompt.strip():
239
+ return None, "Please upload a normal image and a reference image, and enter a prompt."
240
 
241
  if mask_img is None:
242
+ return None, "Please upload or generate a mask for the normal image."
243
 
244
  try:
245
+ # Set seed
246
  random.seed(seed)
247
  np.random.seed(seed)
248
  torch.manual_seed(seed)
 
259
  strength=strength
260
  )
261
 
262
+ return generated_img, f"Generation successful! Seed: {seed}, Steps: {steps}"
263
 
264
  except Exception as e:
265
+ error_msg = f"Generation error: {str(e)}"
266
  print(error_msg)
267
  import traceback
268
  print(traceback.format_exc())
269
  return None, error_msg
270
 
271
 
272
+ # Predefined anomaly examples (using local image paths; assume images are in examples/ folder in the same directory as the script)
273
+ EXAMPLE_PAIRS = [
274
+ {
275
+ "normal": "examples/normal_leather.png", # Local normal image for this example
276
+ "reference": "examples/reference_leather.png", # Local reference image showing the target anomaly
277
+ "mask": "examples/normal_mask_leather.png", # Mask for the normal image
278
+ "mask_0": "examples/ref_mask_leather.png", # Mask for the reference image
279
+ "prompt": "Bagel has a crack running across its surface.",
280
+ "strength": 0.6,
281
+ "ip_scale": 0.1,
282
+ "steps": 20,
283
+ "seed": 42,
284
+ "description": "Bagel has a crack running across its surface."
285
+ },
286
+ {
287
+ "normal": "examples/normal_candle.JPG", # Local normal image for this example
288
+ "reference": "examples/reference_candle.png", # Local reference image showing the target anomaly
289
+ "mask": "examples/normal_mask_candle.png", # Mask for the normal image
290
+ "mask_0": "examples/ref_mask_candle.png", # Mask for the reference image
291
+ "prompt": "Chocolate-chip cookie has a chunk-missing defect with exposed inner texture.",
292
+ "strength": 0.6,
293
+ "ip_scale": 0.1,
294
+ "steps": 20,
295
+ "seed": 42,
296
+ "description": "Chocolate-chip cookie has a chunk-missing defect with exposed inner texture."
297
+ },
298
+ {
299
+ "normal": "examples/normal_apple.png", # Local normal image for this example
300
+ "reference": "examples/reference_apple.png", # Local reference image showing the target anomaly
301
+ "mask": "examples/normal_mask_apple.jpg", # Mask for the normal image
302
+ "mask_0": "examples/ref_mask_apple.png", # Mask for the reference image
303
+ "prompt": "Wood surface has holes with rough-edged circular openings.",
304
+ "strength": 0.6,
305
+ "ip_scale": 0.1,
306
+ "steps": 20,
307
+ "seed": 42,
308
+ "description": "Wood surface has holes with rough-edged circular openings."
309
+ }
310
+ ]
311
+
312
+
313
+ def load_example(idx):
314
+ """Load example: Load images from local path, generate random mask if not provided, and set UI"""
315
+ if idx >= len(EXAMPLE_PAIRS):
316
+ # Out-of-range index: return empty inputs and the UI default parameters
317
+ return None, None, None, None, "", 0.5, 0.3, 20, 42, f"Example {idx + 1} not found."
319
+
320
+ ex = EXAMPLE_PAIRS[idx]
321
+ try:
322
+ # Load normal image
323
+ normal_img = Image.open(ex["normal"]).convert('RGB')
324
+
325
+ # Load reference image
326
+ reference_img = Image.open(ex["reference"]).convert('RGB')
327
+
328
+ # Load or generate normal mask
329
+ if ex["mask"] is not None:
330
+ mask_img = Image.open(ex["mask"]).convert('L')
331
+ else:
332
+ mask_img = generate_random_mask()
333
+
334
+ # Load or generate reference mask (mask_0)
335
+ if ex["mask_0"] is not None:
336
+ mask_0_img = Image.open(ex["mask_0"]).convert('L')
337
+ else:
338
+ mask_0_img = generate_random_mask()
339
+
340
+ return normal_img, reference_img, mask_img, mask_0_img, ex["prompt"], ex["strength"], ex["ip_scale"], ex[
341
+ "steps"], ex["seed"], f"Example {idx + 1}: {ex['description']} loaded!"
342
+ except Exception as e:
343
+ error_msg = f"Example loading failed: {str(e)} (Check if local image paths are correct)"
344
+ print(error_msg)
345
+ # Fallback to placeholder images and random masks
346
+ normal_img = Image.new('RGB', (512, 512), color='gray')
347
+ reference_img = Image.new('RGB', (512, 512), color='blue')
348
+ mask_img = generate_random_mask()
349
+ mask_0_img = generate_random_mask()
350
+ return normal_img, reference_img, mask_img, mask_0_img, ex["prompt"], ex["strength"], ex["ip_scale"], ex[
351
+ "steps"], ex["seed"], error_msg
352
+
353
+
354
+ # Automatically load model on startup
355
+ load_generator()
356
+
357
  # Gradio UI
358
+ with gr.Blocks(title="Anomagic Anomaly Image Generator",
359
+ theme=gr.themes.Soft()) as demo: # Use the Soft theme for a cleaner look
360
+ gr.Markdown("# Anomagic: Single Anomaly Image Generation Demo")
361
  gr.Markdown(
362
+ "Upload a normal image, a reference image, a normal mask, and a reference mask (white areas mark the inpainting/anomaly-generation region), enter a prompt, adjust the parameters, and generate a synthetic anomaly image with one click. The model is loaded in the background.")
363
 
364
  with gr.Row():
365
  with gr.Column(scale=1):
366
+ normal_img = gr.Image(type="pil", label="Normal Image", height=300) # Limit height
367
+ reference_img = gr.Image(type="pil", label="Reference Image", height=300)
368
+
369
+ with gr.Row(): # Mask row: Add buttons
370
+ mask_img = gr.Image(type="pil", label="Normal Image Mask (white for anomaly generation area)",
371
+ height=300, tool="sketch") # Add sketch tool
372
+ gr.Button("Generate Random Normal Mask").click(lambda: generate_random_mask(), outputs=mask_img)
373
 
374
+ mask_0_img = gr.Image(type="pil", label="Reference Image Mask (mask_0)", height=300,
375
+ tool="sketch") # Add sketch tool
376
+ gr.Button("Generate Random Reference Mask").click(lambda: generate_random_mask(), outputs=mask_0_img)
377
 
378
+ prompt = gr.Textbox(label="Prompt Text",
379
  placeholder="e.g., a broken machine part with rust and cracks")
380
 
381
  with gr.Column(scale=1):
382
+ strength = gr.Slider(0.1, 1.0, value=0.5, label="Denoising Strength")
383
+ ip_scale = gr.Slider(0, 2.0, value=0.3, step=0.1, label="IP Adapter Scale")
384
+ steps = gr.Slider(10, 100, value=20, step=5, label="Inference Steps")
385
+ seed = gr.Slider(0, 2 ** 32 - 1, value=42, step=1, label="Random Seed")
386
 
387
  with gr.Row():
388
+ generate_btn = gr.Button("Generate Image", variant="primary", size="lg") # Enlarge button
 
389
 
390
+ output_img = gr.Image(type="pil", label="Generated Anomaly Image", height=400)
391
+ status = gr.Textbox(label="Status", interactive=False)
392
 
393
+ # Event bindings
 
394
  generate_btn.click(
395
  generate_anomaly,
396
  inputs=[normal_img, reference_img, mask_img, mask_0_img, prompt, strength, ip_scale, steps, seed],
397
+ outputs=[output_img, status]
398
  )
399
 
400
+ # Examples section
401
+ gr.Markdown("## Examples")
402
+ gr.Markdown(
403
+ "Click the buttons below to load predefined examples for quick testing. After loading, click 'Generate Image' to view the anomaly synthesis result.")
404
+ with gr.Row():
405
+ for i in range(len(EXAMPLE_PAIRS)):
406
+ with gr.Column():
407
+ ex_btn = gr.Button(f"Example {i + 1}: {EXAMPLE_PAIRS[i]['description']}", variant="secondary")
408
+ ex_btn.click(load_example, inputs=gr.State(i),
409
+ outputs=[normal_img, reference_img, mask_img, mask_0_img, prompt, strength, ip_scale,
410
+ steps, seed, status])
411
+
412
+
413
+ # Clear cache button
414
  def clear_cache():
415
  global load_status
416
  load_status = {"loaded": False, "error": None}
417
+ return "Cache cleared, please restart the app to reload the model."
418
+
419
 
420
+ clear_btn = gr.Button("Clear Cache", variant="stop")
421
  clear_btn.click(clear_cache, outputs=status)
422
 
423
  if __name__ == "__main__":
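The launch call under this final guard is not shown in the diff; for a Gradio Space it is typically just a queue-and-launch block along these lines (a sketch; the queue() call and server options are assumptions, not taken from the commit):

if __name__ == "__main__":
    # Queue requests so long diffusion runs do not time out, then serve the app.
    demo.queue()
    demo.launch(server_name="0.0.0.0", server_port=7860)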