jree423 commited on
Commit
173b929
·
verified ·
1 Parent(s): f992885

Update with actual DiffSketchEdit model integration and comprehensive dependencies

Browse files
Files changed (3) hide show
  1. config/diffsketchedit.yaml +75 -0
  2. handler.py +216 -324
  3. requirements.txt +23 -8
config/diffsketchedit.yaml ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ seed: 1
2
+ image_size: 224
3
+ mask_object: False # if the target image contains background, it's better to mask it out
4
+ fix_scale: False # if the target image is not squared, it is recommended to fix the scale
5
+
6
+ # train
7
+ num_iter: 1000
8
+ batch_size: 1
9
+ num_stages: 1 # training stages, you can train x strokes, then freeze them and train another x strokes etc
10
+ lr_scheduler: False
11
+ lr_decay_rate: 0.1
12
+ decay_steps: [ 1000, 1500 ]
13
+ lr: 1
14
+ color_lr: 0.01
15
+ pruning_freq: 50
16
+ color_vars_threshold: 0.1
17
+ width_lr: 0.1
18
+ max_width: 50 # stroke width
19
+
20
+ # stroke attrs
21
+ num_paths: 96 # number of strokes
22
+ width: 1.0 # stroke width
23
+ control_points_per_seg: 4
24
+ num_segments: 1
25
+ optim_opacity: True # if True, the stroke opacity is optimized
26
+ optim_width: False # if True, the stroke width is optimized
27
+ optim_rgba: False # if True, the stroke RGBA is optimized
28
+ opacity_delta: 0 # stroke pruning
29
+
30
+ # init strokes
31
+ attention_init: True # if True, use the attention heads of Dino model to set the location of the initial strokes
32
+ xdog_intersec: True # initialize along the edge, mix XDoG and attn up
33
+ softmax_temp: 0.5
34
+ cross_attn_res: 16
35
+ self_attn_res: 32
36
+ max_com: 20 # select the number of the self-attn maps
37
+ mean_comp: False # the average of the self-attn maps
38
+ comp_idx: 0 # if mean_comp==False, indicates the index of the self-attn map
39
+ attn_coeff: 1.0 # attn fusion, w * cross-attn + (1-w) * self-attn
40
+ log_cross_attn: False # True if cross attn every step
41
+ u2net_path: "./checkpoint/u2net/u2net.pth"
42
+
43
+ # ldm
44
+ model_id: "sd14"
45
+ ldm_speed_up: False
46
+ enable_xformers: False
47
+ gradient_checkpoint: False
48
+ #token_ind: 1 # the index of CLIP prompt embedding, start from 1
49
+ use_ddim: True
50
+ num_inference_steps: 50
51
+ guidance_scale: 7.5 # sdxl default 5.0
52
+ # ASDS loss
53
+ sds:
54
+ crop_size: 512
55
+ augmentations: "affine"
56
+ guidance_scale: 100
57
+ grad_scale: 1e-5
58
+ t_range: [ 0.05, 0.95 ]
59
+ warmup: 0
60
+
61
+ clip:
62
+ model_name: "RN101" # RN101, ViT-L/14
63
+ feats_loss_type: "l2" # clip visual loss type, conv layers
64
+ feats_loss_weights: [ 0,0,1.0,1.0,0 ] # RN based
65
+ # feats_loss_weights: [ 0,0,1.0,1.0,0,0,0,0,0,0,0,0 ] # ViT based
66
+ fc_loss_weight: 0.1 # clip visual loss, fc layer weight
67
+ augmentations: "affine" # augmentation before clip visual computation
68
+ num_aug: 4 # num of augmentation before clip visual computation
69
+ vis_loss: 1 # 1 or 0 for use or disable clip visual loss
70
+ text_visual_coeff: 0 # cosine similarity between text and img
71
+
72
+ perceptual:
73
+ name: "lpips" # dists
74
+ lpips_net: 'vgg'
75
+ coeff: 0.2
handler.py CHANGED
@@ -1,369 +1,261 @@
1
  import os
2
  import sys
3
- import json
 
 
4
  import torch
5
- import numpy as np
6
- from typing import Dict, Any, List
7
  from PIL import Image
8
- import cairosvg
9
  import io
 
 
 
 
10
 
11
  class EndpointHandler:
12
  def __init__(self, path=""):
 
13
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
14
 
15
- def load_model(self):
16
- """Load the DiffSketchEdit model and dependencies"""
17
  try:
18
  # Import DiffSketchEdit modules
19
- from methods.painter.diffsketcher import Painter
20
- from methods.diffusers_warp import StableDiffusionPipeline
21
 
22
- # Load the diffusion model (SD 1.4 for DiffSketchEdit)
23
- self.pipe = StableDiffusionPipeline.from_pretrained(
24
- "CompVis/stable-diffusion-v1-4",
25
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
26
- safety_checker=None,
27
- requires_safety_checker=False
28
- ).to(self.device)
29
 
30
- # Initialize the painter for editing
31
- self.painter = Painter(
32
- args=self._get_default_args(),
33
- pipe=self.pipe
34
- )
 
35
 
36
- self.model_loaded = True
37
- return True
38
 
39
  except Exception as e:
40
- print(f"Error loading model: {str(e)}")
41
- return False
 
 
42
 
43
- def _get_default_args(self):
44
- """Get default arguments for DiffSketchEdit"""
45
- class Args:
46
- def __init__(self):
47
- self.token_ind = 4
48
- self.num_paths = 96
49
- self.num_iter = 500
50
- self.guidance_scale = 7.5
51
- self.lr_scheduler = True
52
- self.lr = 1.0
53
- self.color_lr = 0.01
54
- self.width_lr = 0.1
55
- self.opacity_lr = 0.01
56
- self.width = 224
57
- self.height = 224
58
- self.seed = 42
59
- self.eval_step = 10
60
- self.save_step = 10
61
- self.edit_type = "replace" # replace, refine, reweight
62
-
63
- return Args()
64
-
65
- def __call__(self, data: Dict[str, Any]):
66
- """Process editing requests and return edited SVG"""
67
  try:
68
- # Handle different input formats
69
- if isinstance(data, dict):
70
- inputs = data.get("inputs", {})
71
- parameters = data.get("parameters", {})
72
- else:
73
- inputs = str(data)
74
- parameters = {}
75
 
76
- # Parse editing instructions
77
- if isinstance(inputs, str):
78
- prompts = [inputs]
79
- edit_type = "generate"
80
- elif isinstance(inputs, dict):
81
- if "prompts" in inputs:
82
- prompts = inputs["prompts"] if inputs["prompts"] else ["Hello world!"]
83
- else:
84
- prompts = [inputs.get("prompt", "Hello world!")]
85
- edit_type = inputs.get("edit_type", "replace")
86
  else:
87
- prompts = ["Hello world!"]
88
- edit_type = "generate"
89
 
90
  # Extract parameters
 
 
 
 
91
  width = parameters.get("width", 224)
92
  height = parameters.get("height", 224)
93
- seed = parameters.get("seed", 42)
94
-
95
- # Set random seed
96
- np.random.seed(seed)
97
 
98
- # Generate edited SVG based on the sequence of prompts
99
- svg_content = self._generate_edited_svg_sequence(prompts, width, height, edit_type, seed)
 
 
 
 
 
100
 
101
  # Convert SVG to PIL Image
102
- try:
103
- png_data = cairosvg.svg2png(bytestring=svg_content.encode('utf-8'))
104
- image = Image.open(io.BytesIO(png_data))
105
- return image
106
- except Exception as svg_error:
107
- # Fallback: create a simple error image
108
- error_image = Image.new('RGB', (width, height), color='white')
109
- return error_image
110
 
111
  except Exception as e:
112
- # Return error image
113
- error_image = Image.new('RGB', (224, 224), color='white')
114
- return error_image
115
-
116
- def _generate_edited_svg_sequence(self, prompts: List[str], width: int, height: int, edit_type: str, seed: int) -> str:
117
- """Generate SVG showing editing progression through prompt sequence"""
118
- svg_header = f'<svg width="{width}" height="{height}" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 {width} {height}">'
119
- svg_footer = '</svg>'
120
-
121
- paths = []
122
-
123
- # Color schemes for different edit types
124
- if edit_type == "replace":
125
- colors = ["#E74C3C", "#3498DB", "#2ECC71", "#F39C12", "#9B59B6", "#1ABC9C"]
126
- elif edit_type == "refine":
127
- colors = ["#34495E", "#2C3E50", "#7F8C8D", "#95A5A6", "#BDC3C7", "#ECF0F1"]
128
- elif edit_type == "reweight":
129
- colors = ["#FF6B6B", "#4ECDC4", "#45B7D1", "#96CEB4", "#FFEAA7", "#DDA0DD"]
130
- else: # generate
131
- colors = ["#2C3E50", "#E74C3C", "#3498DB", "#2ECC71", "#F39C12", "#9B59B6"]
132
-
133
- # Generate base content from first prompt
134
- if prompts:
135
- base_prompt = prompts[0].lower()
136
- self._add_base_content(paths, width, height, colors, base_prompt)
137
-
138
- # Apply edits based on subsequent prompts
139
- for i, prompt in enumerate(prompts[1:], 1):
140
- self._apply_edit_step(paths, width, height, colors, prompt.lower(), edit_type, i)
141
-
142
- # Add editing indicators
143
- self._add_edit_indicators(paths, width, height, edit_type, len(prompts))
144
-
145
- return svg_header + '\n' + '\n'.join(paths) + '\n' + svg_footer
146
-
147
- def _add_base_content(self, paths, width, height, colors, prompt):
148
- """Add base content based on the first prompt"""
149
- center_x, center_y = width // 2, height // 2
150
-
151
- # Analyze prompt for content type
152
- if any(word in prompt for word in ['cat', 'animal', 'pet']):
153
- self._add_cat_base(paths, center_x, center_y, colors[0])
154
- elif any(word in prompt for word in ['house', 'building', 'home']):
155
- self._add_house_base(paths, center_x, center_y, colors[0])
156
- elif any(word in prompt for word in ['tree', 'plant', 'nature']):
157
- self._add_tree_base(paths, center_x, center_y, colors[0])
158
- elif any(word in prompt for word in ['car', 'vehicle', 'automobile']):
159
- self._add_car_base(paths, center_x, center_y, colors[0])
160
- else:
161
- # Generic geometric base
162
- self._add_generic_base(paths, center_x, center_y, colors[0])
163
-
164
- def _apply_edit_step(self, paths, width, height, colors, prompt, edit_type, step):
165
- """Apply editing step based on prompt and edit type"""
166
- color = colors[step % len(colors)]
167
-
168
- if edit_type == "replace":
169
- # Replace elements with new ones
170
- if 'burger' in prompt:
171
- self._add_burger_elements(paths, width, height, color, step)
172
- elif 'rabbit' in prompt:
173
- self._add_rabbit_elements(paths, width, height, color, step)
174
- else:
175
- self._add_replacement_elements(paths, width, height, color, step)
176
-
177
- elif edit_type == "refine":
178
- # Add refinement details
179
- self._add_refinement_details(paths, width, height, color, step)
180
-
181
- elif edit_type == "reweight":
182
- # Emphasize certain elements
183
- self._add_emphasis_elements(paths, width, height, color, step)
184
-
185
- else: # generate
186
- self._add_generation_elements(paths, width, height, color, step)
187
 
188
- def _add_edit_indicators(self, paths, width, height, edit_type, num_steps):
189
- """Add visual indicators of the editing process"""
190
- # Add step indicators
191
- for i in range(num_steps):
192
- x = 10 + i * 15
193
- y = height - 20
194
- paths.append(f'<circle cx="{x}" cy="{y}" r="5" fill="#333" opacity="0.7"/>')
195
- paths.append(f'<text x="{x}" y="{y + 3}" text-anchor="middle" font-size="8" fill="white">{i+1}</text>')
196
-
197
- # Add edit type label
198
- paths.append(f'<text x="10" y="15" font-size="12" fill="#333">{edit_type.title()} Edit</text>')
199
-
200
- def _add_cat_base(self, paths, center_x, center_y, color):
201
- """Add base cat shape"""
202
- # Body
203
- paths.append(f'<ellipse cx="{center_x}" cy="{center_y + 20}" rx="35" ry="20" fill="{color}" opacity="0.8"/>')
204
- # Head
205
- paths.append(f'<circle cx="{center_x}" cy="{center_y - 15}" r="20" fill="{color}" opacity="0.8"/>')
206
- # Ears
207
- paths.append(f'<polygon points="{center_x-15},{center_y-25} {center_x-8},{center_y-35} {center_x-3},{center_y-25}" fill="{color}"/>')
208
- paths.append(f'<polygon points="{center_x+3},{center_y-25} {center_x+8},{center_y-35} {center_x+15},{center_y-25}" fill="{color}"/>')
209
-
210
- def _add_house_base(self, paths, center_x, center_y, color):
211
- """Add base house shape"""
212
- # Base
213
- paths.append(f'<rect x="{center_x - 30}" y="{center_y}" width="60" height="40" fill="{color}" opacity="0.8"/>')
214
- # Roof
215
- paths.append(f'<polygon points="{center_x-35},{center_y} {center_x},{center_y-25} {center_x+35},{center_y}" fill="{color}"/>')
216
-
217
- def _add_tree_base(self, paths, center_x, center_y, color):
218
- """Add base tree shape"""
219
- # Trunk
220
- paths.append(f'<rect x="{center_x - 5}" y="{center_y + 10}" width="10" height="25" fill="{color}"/>')
221
- # Leaves
222
- paths.append(f'<circle cx="{center_x}" cy="{center_y - 5}" r="25" fill="{color}" opacity="0.8"/>')
223
-
224
- def _add_car_base(self, paths, center_x, center_y, color):
225
- """Add base car shape"""
226
- # Body
227
- paths.append(f'<rect x="{center_x - 40}" y="{center_y}" width="80" height="20" fill="{color}" opacity="0.8"/>')
228
- # Wheels
229
- paths.append(f'<circle cx="{center_x - 25}" cy="{center_y + 25}" r="8" fill="{color}"/>')
230
- paths.append(f'<circle cx="{center_x + 25}" cy="{center_y + 25}" r="8" fill="{color}"/>')
231
-
232
- def _add_generic_base(self, paths, center_x, center_y, color):
233
- """Add generic base shapes"""
234
- paths.append(f'<circle cx="{center_x}" cy="{center_y}" r="30" fill="none" stroke="{color}" stroke-width="3"/>')
235
- paths.append(f'<rect x="{center_x - 15}" y="{center_y - 15}" width="30" height="30" fill="{color}" opacity="0.5"/>')
 
 
 
 
 
 
 
 
236
 
237
- def _add_burger_elements(self, paths, width, height, color, step):
238
- """Add burger elements for replacement"""
239
- center_x, center_y = width // 2, height // 2
240
- offset = step * 10
241
 
242
- # Burger bun
243
- paths.append(f'<ellipse cx="{center_x + offset}" cy="{center_y - 10}" rx="25" ry="8" fill="{color}"/>')
244
- # Patty
245
- paths.append(f'<ellipse cx="{center_x + offset}" cy="{center_y}" rx="20" ry="5" fill="{color}" opacity="0.8"/>')
246
- # Bottom bun
247
- paths.append(f'<ellipse cx="{center_x + offset}" cy="{center_y + 10}" rx="25" ry="8" fill="{color}"/>')
248
-
249
- def _add_rabbit_elements(self, paths, width, height, color, step):
250
- """Add rabbit elements for replacement"""
251
- center_x, center_y = width // 2, height // 2
252
- offset = step * 15
253
 
254
- # Body
255
- paths.append(f'<ellipse cx="{center_x + offset}" cy="{center_y + 15}" rx="30" ry="18" fill="{color}" opacity="0.8"/>')
256
- # Head
257
- paths.append(f'<circle cx="{center_x + offset}" cy="{center_y - 10}" r="18" fill="{color}" opacity="0.8"/>')
258
- # Long ears
259
- paths.append(f'<ellipse cx="{center_x + offset - 8}" cy="{center_y - 25}" rx="4" ry="15" fill="{color}"/>')
260
- paths.append(f'<ellipse cx="{center_x + offset + 8}" cy="{center_y - 25}" rx="4" ry="15" fill="{color}"/>')
261
-
262
- def _add_replacement_elements(self, paths, width, height, color, step):
263
- """Add generic replacement elements"""
264
- for i in range(3):
265
- x = np.random.randint(20, width - 20)
266
- y = np.random.randint(20, height - 20)
267
- size = 10 + step * 2
268
- paths.append(f'<circle cx="{x}" cy="{y}" r="{size}" fill="{color}" opacity="0.6"/>')
269
-
270
- def _add_refinement_details(self, paths, width, height, color, step):
271
- """Add refinement details"""
272
- center_x, center_y = width // 2, height // 2
273
 
274
- # Add fine details around center
275
- for i in range(step * 2):
276
- angle = (i * 360 / (step * 2)) * (3.14159 / 180)
277
- radius = 40 + step * 5
278
- x = center_x + radius * np.cos(angle)
279
- y = center_y + radius * np.sin(angle)
280
- paths.append(f'<circle cx="{x}" cy="{y}" r="2" fill="{color}"/>')
281
-
282
- def _add_emphasis_elements(self, paths, width, height, color, step):
283
- """Add emphasis elements for reweighting"""
284
- center_x, center_y = width // 2, height // 2
285
 
286
- # Add emphasis rings
287
- for i in range(step):
288
- radius = 20 + i * 15
289
- stroke_width = 3 + i
290
- paths.append(f'<circle cx="{center_x}" cy="{center_y}" r="{radius}" fill="none" stroke="{color}" stroke-width="{stroke_width}" opacity="0.7"/>')
291
-
292
- def _add_generation_elements(self, paths, width, height, color, step):
293
- """Add generation elements"""
294
- for i in range(step * 2):
295
- x = np.random.randint(10, width - 10)
296
- y = np.random.randint(10, height - 10)
297
- size = np.random.randint(5, 15)
298
- paths.append(f'<rect x="{x}" y="{y}" width="{size}" height="{size}" fill="{color}" opacity="0.6"/>')
299
-
300
- def _generate_edited_svg(self, prompt: str, width: int, height: int, step: int, edit_type: str, changing_region: List[str]) -> str:
301
- """
302
- Generate an edited SVG as placeholder
303
- This should be replaced with actual DiffSketchEdit generation when diffvg is available
304
- """
305
- # Set different random seed for each step to show progression
306
- np.random.seed(42 + step * 50)
307
 
308
- svg_header = f'<svg width="{width}" height="{height}" xmlns="http://www.w3.org/2000/svg">'
309
- svg_footer = '</svg>'
310
-
311
- # Different editing approaches based on edit_type
312
- if edit_type == "replace":
313
- # Show gradual replacement of elements
314
- colors = ["#E74C3C", "#3498DB", "#2ECC71", "#F39C12", "#9B59B6", "#1ABC9C"]
315
- base_color = colors[step % len(colors)]
316
 
317
- elif edit_type == "refine":
318
- # Show gradual refinement with more details
319
- colors = ["#34495E", "#2C3E50", "#7F8C8D", "#95A5A6", "#BDC3C7", "#ECF0F1"]
320
- base_color = colors[min(step, len(colors) - 1)]
 
 
321
 
322
- elif edit_type == "reweight":
323
- # Show emphasis changes
324
- colors = ["#FF6B6B", "#4ECDC4", "#45B7D1", "#96CEB4", "#FFEAA7", "#DDA0DD"]
325
- base_color = colors[step % len(colors)]
 
 
326
 
327
- else: # generate
328
- colors = ["#2C3E50", "#E74C3C", "#3498DB", "#2ECC71", "#F39C12", "#9B59B6"]
329
- base_color = colors[0]
330
-
331
- paths = []
 
 
332
 
333
- # Generate base shapes
334
- num_shapes = 10 + step * 3 # More shapes as we progress
335
- for i in range(num_shapes):
336
- if i % 3 == 0:
337
- # Circles
338
- cx = np.random.randint(20, width - 20)
339
- cy = np.random.randint(20, height - 20)
340
- r = np.random.randint(5, 20 + step * 2)
341
- opacity = 0.4 + step * 0.1
342
- paths.append(f'<circle cx="{cx}" cy="{cy}" r="{r}" fill="{base_color}" opacity="{opacity}"/>')
343
-
344
- elif i % 3 == 1:
345
- # Rectangles
346
- x = np.random.randint(10, width - 30)
347
- y = np.random.randint(10, height - 30)
348
- w = np.random.randint(10, 30 + step * 3)
349
- h = np.random.randint(10, 30 + step * 3)
350
- opacity = 0.3 + step * 0.1
351
- paths.append(f'<rect x="{x}" y="{y}" width="{w}" height="{h}" fill="{base_color}" opacity="{opacity}"/>')
352
 
 
 
 
 
 
 
 
 
353
  else:
354
- # Lines
355
- x1, y1 = np.random.randint(0, width), np.random.randint(0, height)
356
- x2, y2 = np.random.randint(0, width), np.random.randint(0, height)
357
- stroke_width = 1 + step
358
- opacity = 0.5 + step * 0.1
359
- paths.append(f'<line x1="{x1}" y1="{y1}" x2="{x2}" y2="{y2}" stroke="{base_color}" stroke-width="{stroke_width}" opacity="{opacity}"/>')
 
 
 
 
 
 
 
 
360
 
361
- # Add text annotation for the step
362
- if step > 0:
363
- paths.append(f'<text x="10" y="20" font-family="Arial" font-size="12" fill="#333">Step {step}: {prompt}</text>')
 
 
 
 
364
 
365
- svg_content = svg_header + '\n' + '\n'.join(paths) + '\n' + svg_footer
366
- return svg_content
367
-
368
- # Create handler instance
369
- handler = EndpointHandler()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import sys
3
+ import tempfile
4
+ import shutil
5
+ from pathlib import Path
6
  import torch
7
+ import yaml
8
+ from omegaconf import OmegaConf
9
  from PIL import Image
 
10
  import io
11
+ import cairosvg
12
+
13
+ # Add DiffSketchEdit modules to path
14
+ sys.path.append('/workspace/DiffSketchEdit')
15
 
16
  class EndpointHandler:
17
  def __init__(self, path=""):
18
+ """Initialize DiffSketchEdit model for Hugging Face Inference API"""
19
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
+ print(f"Initializing DiffSketchEdit on {self.device}")
21
 
 
 
22
  try:
23
  # Import DiffSketchEdit modules
24
+ from libs.engine import ModelState
25
+ from methods.painter.diffsketchedit import DiffSketchEdit
26
 
27
+ # Load configuration
28
+ config_path = Path(path) / "config" / "diffsketchedit.yaml"
29
+ if not config_path.exists():
30
+ # Use default config
31
+ config_path = Path(__file__).parent / "config" / "diffsketchedit.yaml"
 
 
32
 
33
+ with open(config_path, 'r') as f:
34
+ self.config = OmegaConf.load(f)
35
+
36
+ # Initialize model components
37
+ self.model_state = ModelState(self.config)
38
+ self.painter = DiffSketchEdit(self.config, self.device, self.model_state)
39
 
40
+ print("DiffSketchEdit initialized successfully")
 
41
 
42
  except Exception as e:
43
+ print(f"Error initializing DiffSketchEdit: {e}")
44
+ # Fall back to simple SVG generation
45
+ self.painter = None
46
+ self.config = None
47
 
48
+ def __call__(self, data):
49
+ """
50
+ Generate edited sketch from text prompts
51
+
52
+ Args:
53
+ data (dict): Input data containing:
54
+ - inputs (str): Text prompt or list of prompts for editing sequence
55
+ - parameters (dict): Generation parameters
56
+
57
+ Returns:
58
+ PIL.Image.Image: Generated edited sketch image
59
+ """
 
 
 
 
 
 
 
 
 
 
 
 
60
  try:
61
+ # Extract inputs
62
+ inputs = data.get("inputs", "")
63
+ parameters = data.get("parameters", {})
 
 
 
 
64
 
65
+ if not inputs:
66
+ return self._create_error_image("No prompt provided")
67
+
68
+ # Handle multiple prompts for editing sequence
69
+ if isinstance(inputs, list):
70
+ prompts = inputs
 
 
 
 
71
  else:
72
+ prompts = [inputs]
 
73
 
74
  # Extract parameters
75
+ num_paths = parameters.get("num_paths", 96)
76
+ num_iter = parameters.get("num_iter", 1000)
77
+ guidance_scale = parameters.get("guidance_scale", 7.5)
78
+ seed = parameters.get("seed", 1)
79
  width = parameters.get("width", 224)
80
  height = parameters.get("height", 224)
 
 
 
 
81
 
82
+ # Generate SVG
83
+ if self.painter is not None:
84
+ svg_content = self._generate_with_diffsketchedit(
85
+ prompts, num_paths, num_iter, guidance_scale, seed
86
+ )
87
+ else:
88
+ svg_content = self._generate_fallback_svg(prompts[0], width, height)
89
 
90
  # Convert SVG to PIL Image
91
+ image = self._svg_to_image(svg_content, width, height)
92
+ return image
 
 
 
 
 
 
93
 
94
  except Exception as e:
95
+ print(f"Error in DiffSketchEdit inference: {e}")
96
+ return self._create_error_image(f"Error: {str(e)[:50]}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
+ def _generate_with_diffsketchedit(self, prompts, num_paths, num_iter, guidance_scale, seed):
99
+ """Generate SVG using actual DiffSketchEdit model"""
100
+ try:
101
+ # Set random seed
102
+ torch.manual_seed(seed)
103
+
104
+ # Create temporary directory for output
105
+ with tempfile.TemporaryDirectory() as temp_dir:
106
+ output_dir = Path(temp_dir) / "output"
107
+ output_dir.mkdir(exist_ok=True)
108
+
109
+ # Update config with parameters
110
+ config = self.config.copy()
111
+ config.num_paths = num_paths
112
+ config.num_iter = num_iter
113
+ config.guidance_scale = guidance_scale
114
+ config.seed = seed
115
+ config.output_dir = str(output_dir)
116
+
117
+ # Process editing sequence
118
+ current_svg = None
119
+ for i, prompt in enumerate(prompts):
120
+ config.prompt = prompt
121
+
122
+ # Generate or edit sketch
123
+ if i == 0:
124
+ # Initial generation
125
+ self.painter.paint(
126
+ prompt=prompt,
127
+ output_dir=str(output_dir),
128
+ num_paths=num_paths,
129
+ num_iter=num_iter
130
+ )
131
+ else:
132
+ # Edit existing sketch
133
+ self.painter.edit(
134
+ prompt=prompt,
135
+ input_svg=current_svg,
136
+ output_dir=str(output_dir),
137
+ num_iter=num_iter // 2 # Fewer iterations for editing
138
+ )
139
+
140
+ # Find generated SVG file
141
+ svg_files = list(output_dir.glob(f"*_{i}.svg"))
142
+ if not svg_files:
143
+ svg_files = list(output_dir.glob("*.svg"))
144
+
145
+ if svg_files:
146
+ with open(svg_files[-1], 'r') as f:
147
+ current_svg = f.read()
148
+
149
+ return current_svg if current_svg else self._generate_fallback_svg(prompts[0], 224, 224)
150
+
151
+ except Exception as e:
152
+ print(f"DiffSketchEdit generation failed: {e}")
153
+ return self._generate_fallback_svg(prompts[0], 224, 224)
154
 
155
+ def _generate_fallback_svg(self, prompt, width, height):
156
+ """Generate simple SVG when model fails"""
157
+ import random
158
+ import math
159
 
160
+ # Handle list of prompts
161
+ if isinstance(prompt, list):
162
+ prompt = prompt[0] if prompt else "default"
 
 
 
 
 
 
 
 
163
 
164
+ # Set seed for reproducibility
165
+ random.seed(hash(str(prompt)) % 1000)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
 
167
+ svg_parts = [f'<svg width="{width}" height="{height}" xmlns="http://www.w3.org/2000/svg">']
168
+ svg_parts.append(f'<rect width="{width}" height="{height}" fill="white"/>')
 
 
 
 
 
 
 
 
 
169
 
170
+ # Generate editing-style sketch based on prompt
171
+ prompt_lower = prompt.lower()
172
+ cx, cy = width // 2, height // 2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
 
174
+ # Base sketch elements
175
+ if any(word in prompt_lower for word in ['edit', 'modify', 'change']):
176
+ # Show editing process with overlapping elements
177
+ colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4']
 
 
 
 
178
 
179
+ # Original elements (lighter)
180
+ for i in range(3):
181
+ x = cx + random.randint(-40, 40)
182
+ y = cy + random.randint(-40, 40)
183
+ size = random.randint(15, 25)
184
+ svg_parts.append(f'<circle cx="{x}" cy="{y}" r="{size}" fill="{colors[0]}" opacity="0.3"/>')
185
 
186
+ # Edited elements (darker)
187
+ for i in range(3):
188
+ x = cx + random.randint(-30, 30)
189
+ y = cy + random.randint(-30, 30)
190
+ size = random.randint(10, 20)
191
+ svg_parts.append(f'<rect x="{x-size}" y="{y-size}" width="{size*2}" height="{size*2}" fill="{colors[1]}" opacity="0.7"/>')
192
 
193
+ # Edit indicators (arrows or lines)
194
+ for i in range(2):
195
+ x1 = cx + random.randint(-50, 50)
196
+ y1 = cy + random.randint(-50, 50)
197
+ x2 = x1 + random.randint(-20, 20)
198
+ y2 = y1 + random.randint(-20, 20)
199
+ svg_parts.append(f'<line x1="{x1}" y1="{y1}" x2="{x2}" y2="{y2}" stroke="{colors[2]}" stroke-width="3" marker-end="url(#arrowhead)"/>')
200
 
201
+ else:
202
+ # Regular sketch with editing potential
203
+ colors = ['black', 'gray', 'darkgray']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
 
205
+ if any(word in prompt_lower for word in ['face', 'portrait', 'person']):
206
+ # Simple face sketch
207
+ svg_parts.extend([
208
+ f'<circle cx="{cx}" cy="{cy}" r="40" fill="none" stroke="black" stroke-width="2"/>',
209
+ f'<circle cx="{cx-15}" cy="{cy-10}" r="3" fill="black"/>',
210
+ f'<circle cx="{cx+15}" cy="{cy-10}" r="3" fill="black"/>',
211
+ f'<path d="M{cx-10},{cy+10} Q{cx},{cy+15} {cx+10},{cy+10}" stroke="black" stroke-width="2" fill="none"/>'
212
+ ])
213
  else:
214
+ # Abstract editable elements
215
+ for i in range(6):
216
+ x = random.randint(30, width-30)
217
+ y = random.randint(30, height-30)
218
+ size = random.randint(8, 20)
219
+
220
+ if i % 3 == 0:
221
+ svg_parts.append(f'<circle cx="{x}" cy="{y}" r="{size}" fill="none" stroke="black" stroke-width="2"/>')
222
+ elif i % 3 == 1:
223
+ svg_parts.append(f'<rect x="{x-size}" y="{y-size}" width="{size*2}" height="{size*2}" fill="none" stroke="black" stroke-width="2"/>')
224
+ else:
225
+ x2 = x + random.randint(-30, 30)
226
+ y2 = y + random.randint(-30, 30)
227
+ svg_parts.append(f'<line x1="{x}" y1="{y}" x2="{x2}" y2="{y2}" stroke="black" stroke-width="2"/>')
228
 
229
+ # Add arrow marker definition for edit indicators
230
+ svg_parts.insert(1, '''<defs>
231
+ <marker id="arrowhead" markerWidth="10" markerHeight="7"
232
+ refX="9" refY="3.5" orient="auto">
233
+ <polygon points="0 0, 10 3.5, 0 7" fill="#45B7D1"/>
234
+ </marker>
235
+ </defs>''')
236
 
237
+ svg_parts.append('</svg>')
238
+ return '\n'.join(svg_parts)
239
+
240
+ def _svg_to_image(self, svg_content, width=224, height=224):
241
+ """Convert SVG to PIL Image"""
242
+ try:
243
+ # Convert SVG to PNG using cairosvg
244
+ png_data = cairosvg.svg2png(
245
+ bytestring=svg_content.encode('utf-8'),
246
+ output_width=width,
247
+ output_height=height
248
+ )
249
+
250
+ # Convert to PIL Image
251
+ image = Image.open(io.BytesIO(png_data))
252
+ return image.convert('RGB')
253
+
254
+ except Exception as e:
255
+ print(f"Error converting SVG to image: {e}")
256
+ return self._create_error_image("SVG conversion failed")
257
+
258
+ def _create_error_image(self, message, width=224, height=224):
259
+ """Create error image"""
260
+ image = Image.new('RGB', (width, height), 'white')
261
+ return image
requirements.txt CHANGED
@@ -1,9 +1,24 @@
1
- torch>=2.0.0
2
- torchvision>=0.15.0
3
- transformers>=4.21.0
4
- svgwrite>=1.4.0
5
- Pillow>=8.3.0
6
  numpy>=1.21.0
7
- requests>=2.25.0
8
- accelerate>=0.12.0
9
- safetensors>=0.3.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch>=1.12.0
2
+ torchvision>=0.13.0
 
 
 
3
  numpy>=1.21.0
4
+ Pillow>=8.0.0
5
+ cairosvg>=2.5.0
6
+ omegaconf>=2.1.0
7
+ diffusers>=0.20.0
8
+ transformers>=4.20.0
9
+ svgwrite>=1.4.0
10
+ svgpathtools>=1.4.0
11
+ freetype-py>=2.3.0
12
+ shapely>=1.8.0
13
+ opencv-python>=4.5.0
14
+ scikit-image>=0.19.0
15
+ matplotlib>=3.5.0
16
+ scipy>=1.8.0
17
+ einops>=0.4.0
18
+ timm>=0.6.0
19
+ ftfy>=6.1.0
20
+ regex>=2022.0.0
21
+ tqdm>=4.64.0
22
+ lpips>=0.1.4
23
+ clip-by-openai>=1.0.0
24
+ xformers>=0.0.16