Update with actual DiffSketchEdit model integration and comprehensive dependencies
Browse files- config/diffsketchedit.yaml +75 -0
- handler.py +216 -324
- requirements.txt +23 -8
config/diffsketchedit.yaml
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
seed: 1
|
2 |
+
image_size: 224
|
3 |
+
mask_object: False # if the target image contains background, it's better to mask it out
|
4 |
+
fix_scale: False # if the target image is not squared, it is recommended to fix the scale
|
5 |
+
|
6 |
+
# train
|
7 |
+
num_iter: 1000
|
8 |
+
batch_size: 1
|
9 |
+
num_stages: 1 # training stages, you can train x strokes, then freeze them and train another x strokes etc
|
10 |
+
lr_scheduler: False
|
11 |
+
lr_decay_rate: 0.1
|
12 |
+
decay_steps: [ 1000, 1500 ]
|
13 |
+
lr: 1
|
14 |
+
color_lr: 0.01
|
15 |
+
pruning_freq: 50
|
16 |
+
color_vars_threshold: 0.1
|
17 |
+
width_lr: 0.1
|
18 |
+
max_width: 50 # stroke width
|
19 |
+
|
20 |
+
# stroke attrs
|
21 |
+
num_paths: 96 # number of strokes
|
22 |
+
width: 1.0 # stroke width
|
23 |
+
control_points_per_seg: 4
|
24 |
+
num_segments: 1
|
25 |
+
optim_opacity: True # if True, the stroke opacity is optimized
|
26 |
+
optim_width: False # if True, the stroke width is optimized
|
27 |
+
optim_rgba: False # if True, the stroke RGBA is optimized
|
28 |
+
opacity_delta: 0 # stroke pruning
|
29 |
+
|
30 |
+
# init strokes
|
31 |
+
attention_init: True # if True, use the attention heads of Dino model to set the location of the initial strokes
|
32 |
+
xdog_intersec: True # initialize along the edge, mix XDoG and attn up
|
33 |
+
softmax_temp: 0.5
|
34 |
+
cross_attn_res: 16
|
35 |
+
self_attn_res: 32
|
36 |
+
max_com: 20 # select the number of the self-attn maps
|
37 |
+
mean_comp: False # the average of the self-attn maps
|
38 |
+
comp_idx: 0 # if mean_comp==False, indicates the index of the self-attn map
|
39 |
+
attn_coeff: 1.0 # attn fusion, w * cross-attn + (1-w) * self-attn
|
40 |
+
log_cross_attn: False # True if cross attn every step
|
41 |
+
u2net_path: "./checkpoint/u2net/u2net.pth"
|
42 |
+
|
43 |
+
# ldm
|
44 |
+
model_id: "sd14"
|
45 |
+
ldm_speed_up: False
|
46 |
+
enable_xformers: False
|
47 |
+
gradient_checkpoint: False
|
48 |
+
#token_ind: 1 # the index of CLIP prompt embedding, start from 1
|
49 |
+
use_ddim: True
|
50 |
+
num_inference_steps: 50
|
51 |
+
guidance_scale: 7.5 # sdxl default 5.0
|
52 |
+
# ASDS loss
|
53 |
+
sds:
|
54 |
+
crop_size: 512
|
55 |
+
augmentations: "affine"
|
56 |
+
guidance_scale: 100
|
57 |
+
grad_scale: 1e-5
|
58 |
+
t_range: [ 0.05, 0.95 ]
|
59 |
+
warmup: 0
|
60 |
+
|
61 |
+
clip:
|
62 |
+
model_name: "RN101" # RN101, ViT-L/14
|
63 |
+
feats_loss_type: "l2" # clip visual loss type, conv layers
|
64 |
+
feats_loss_weights: [ 0,0,1.0,1.0,0 ] # RN based
|
65 |
+
# feats_loss_weights: [ 0,0,1.0,1.0,0,0,0,0,0,0,0,0 ] # ViT based
|
66 |
+
fc_loss_weight: 0.1 # clip visual loss, fc layer weight
|
67 |
+
augmentations: "affine" # augmentation before clip visual computation
|
68 |
+
num_aug: 4 # num of augmentation before clip visual computation
|
69 |
+
vis_loss: 1 # 1 or 0 for use or disable clip visual loss
|
70 |
+
text_visual_coeff: 0 # cosine similarity between text and img
|
71 |
+
|
72 |
+
perceptual:
|
73 |
+
name: "lpips" # dists
|
74 |
+
lpips_net: 'vgg'
|
75 |
+
coeff: 0.2
|
handler.py
CHANGED
@@ -1,369 +1,261 @@
|
|
1 |
import os
|
2 |
import sys
|
3 |
-
import
|
|
|
|
|
4 |
import torch
|
5 |
-
import
|
6 |
-
from
|
7 |
from PIL import Image
|
8 |
-
import cairosvg
|
9 |
import io
|
|
|
|
|
|
|
|
|
10 |
|
11 |
class EndpointHandler:
|
12 |
def __init__(self, path=""):
|
|
|
13 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
14 |
|
15 |
-
def load_model(self):
|
16 |
-
"""Load the DiffSketchEdit model and dependencies"""
|
17 |
try:
|
18 |
# Import DiffSketchEdit modules
|
19 |
-
from
|
20 |
-
from methods.
|
21 |
|
22 |
-
# Load
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
requires_safety_checker=False
|
28 |
-
).to(self.device)
|
29 |
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
)
|
|
|
35 |
|
36 |
-
|
37 |
-
return True
|
38 |
|
39 |
except Exception as e:
|
40 |
-
print(f"Error
|
41 |
-
|
|
|
|
|
42 |
|
43 |
-
def
|
44 |
-
"""
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
self.opacity_lr = 0.01
|
56 |
-
self.width = 224
|
57 |
-
self.height = 224
|
58 |
-
self.seed = 42
|
59 |
-
self.eval_step = 10
|
60 |
-
self.save_step = 10
|
61 |
-
self.edit_type = "replace" # replace, refine, reweight
|
62 |
-
|
63 |
-
return Args()
|
64 |
-
|
65 |
-
def __call__(self, data: Dict[str, Any]):
|
66 |
-
"""Process editing requests and return edited SVG"""
|
67 |
try:
|
68 |
-
#
|
69 |
-
|
70 |
-
|
71 |
-
parameters = data.get("parameters", {})
|
72 |
-
else:
|
73 |
-
inputs = str(data)
|
74 |
-
parameters = {}
|
75 |
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
prompts = inputs["prompts"] if inputs["prompts"] else ["Hello world!"]
|
83 |
-
else:
|
84 |
-
prompts = [inputs.get("prompt", "Hello world!")]
|
85 |
-
edit_type = inputs.get("edit_type", "replace")
|
86 |
else:
|
87 |
-
prompts = [
|
88 |
-
edit_type = "generate"
|
89 |
|
90 |
# Extract parameters
|
|
|
|
|
|
|
|
|
91 |
width = parameters.get("width", 224)
|
92 |
height = parameters.get("height", 224)
|
93 |
-
seed = parameters.get("seed", 42)
|
94 |
-
|
95 |
-
# Set random seed
|
96 |
-
np.random.seed(seed)
|
97 |
|
98 |
-
# Generate
|
99 |
-
|
|
|
|
|
|
|
|
|
|
|
100 |
|
101 |
# Convert SVG to PIL Image
|
102 |
-
|
103 |
-
|
104 |
-
image = Image.open(io.BytesIO(png_data))
|
105 |
-
return image
|
106 |
-
except Exception as svg_error:
|
107 |
-
# Fallback: create a simple error image
|
108 |
-
error_image = Image.new('RGB', (width, height), color='white')
|
109 |
-
return error_image
|
110 |
|
111 |
except Exception as e:
|
112 |
-
|
113 |
-
|
114 |
-
return error_image
|
115 |
-
|
116 |
-
def _generate_edited_svg_sequence(self, prompts: List[str], width: int, height: int, edit_type: str, seed: int) -> str:
|
117 |
-
"""Generate SVG showing editing progression through prompt sequence"""
|
118 |
-
svg_header = f'<svg width="{width}" height="{height}" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 {width} {height}">'
|
119 |
-
svg_footer = '</svg>'
|
120 |
-
|
121 |
-
paths = []
|
122 |
-
|
123 |
-
# Color schemes for different edit types
|
124 |
-
if edit_type == "replace":
|
125 |
-
colors = ["#E74C3C", "#3498DB", "#2ECC71", "#F39C12", "#9B59B6", "#1ABC9C"]
|
126 |
-
elif edit_type == "refine":
|
127 |
-
colors = ["#34495E", "#2C3E50", "#7F8C8D", "#95A5A6", "#BDC3C7", "#ECF0F1"]
|
128 |
-
elif edit_type == "reweight":
|
129 |
-
colors = ["#FF6B6B", "#4ECDC4", "#45B7D1", "#96CEB4", "#FFEAA7", "#DDA0DD"]
|
130 |
-
else: # generate
|
131 |
-
colors = ["#2C3E50", "#E74C3C", "#3498DB", "#2ECC71", "#F39C12", "#9B59B6"]
|
132 |
-
|
133 |
-
# Generate base content from first prompt
|
134 |
-
if prompts:
|
135 |
-
base_prompt = prompts[0].lower()
|
136 |
-
self._add_base_content(paths, width, height, colors, base_prompt)
|
137 |
-
|
138 |
-
# Apply edits based on subsequent prompts
|
139 |
-
for i, prompt in enumerate(prompts[1:], 1):
|
140 |
-
self._apply_edit_step(paths, width, height, colors, prompt.lower(), edit_type, i)
|
141 |
-
|
142 |
-
# Add editing indicators
|
143 |
-
self._add_edit_indicators(paths, width, height, edit_type, len(prompts))
|
144 |
-
|
145 |
-
return svg_header + '\n' + '\n'.join(paths) + '\n' + svg_footer
|
146 |
-
|
147 |
-
def _add_base_content(self, paths, width, height, colors, prompt):
|
148 |
-
"""Add base content based on the first prompt"""
|
149 |
-
center_x, center_y = width // 2, height // 2
|
150 |
-
|
151 |
-
# Analyze prompt for content type
|
152 |
-
if any(word in prompt for word in ['cat', 'animal', 'pet']):
|
153 |
-
self._add_cat_base(paths, center_x, center_y, colors[0])
|
154 |
-
elif any(word in prompt for word in ['house', 'building', 'home']):
|
155 |
-
self._add_house_base(paths, center_x, center_y, colors[0])
|
156 |
-
elif any(word in prompt for word in ['tree', 'plant', 'nature']):
|
157 |
-
self._add_tree_base(paths, center_x, center_y, colors[0])
|
158 |
-
elif any(word in prompt for word in ['car', 'vehicle', 'automobile']):
|
159 |
-
self._add_car_base(paths, center_x, center_y, colors[0])
|
160 |
-
else:
|
161 |
-
# Generic geometric base
|
162 |
-
self._add_generic_base(paths, center_x, center_y, colors[0])
|
163 |
-
|
164 |
-
def _apply_edit_step(self, paths, width, height, colors, prompt, edit_type, step):
|
165 |
-
"""Apply editing step based on prompt and edit type"""
|
166 |
-
color = colors[step % len(colors)]
|
167 |
-
|
168 |
-
if edit_type == "replace":
|
169 |
-
# Replace elements with new ones
|
170 |
-
if 'burger' in prompt:
|
171 |
-
self._add_burger_elements(paths, width, height, color, step)
|
172 |
-
elif 'rabbit' in prompt:
|
173 |
-
self._add_rabbit_elements(paths, width, height, color, step)
|
174 |
-
else:
|
175 |
-
self._add_replacement_elements(paths, width, height, color, step)
|
176 |
-
|
177 |
-
elif edit_type == "refine":
|
178 |
-
# Add refinement details
|
179 |
-
self._add_refinement_details(paths, width, height, color, step)
|
180 |
-
|
181 |
-
elif edit_type == "reweight":
|
182 |
-
# Emphasize certain elements
|
183 |
-
self._add_emphasis_elements(paths, width, height, color, step)
|
184 |
-
|
185 |
-
else: # generate
|
186 |
-
self._add_generation_elements(paths, width, height, color, step)
|
187 |
|
188 |
-
def
|
189 |
-
"""
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
236 |
|
237 |
-
def
|
238 |
-
"""
|
239 |
-
|
240 |
-
|
241 |
|
242 |
-
#
|
243 |
-
|
244 |
-
|
245 |
-
paths.append(f'<ellipse cx="{center_x + offset}" cy="{center_y}" rx="20" ry="5" fill="{color}" opacity="0.8"/>')
|
246 |
-
# Bottom bun
|
247 |
-
paths.append(f'<ellipse cx="{center_x + offset}" cy="{center_y + 10}" rx="25" ry="8" fill="{color}"/>')
|
248 |
-
|
249 |
-
def _add_rabbit_elements(self, paths, width, height, color, step):
|
250 |
-
"""Add rabbit elements for replacement"""
|
251 |
-
center_x, center_y = width // 2, height // 2
|
252 |
-
offset = step * 15
|
253 |
|
254 |
-
#
|
255 |
-
|
256 |
-
# Head
|
257 |
-
paths.append(f'<circle cx="{center_x + offset}" cy="{center_y - 10}" r="18" fill="{color}" opacity="0.8"/>')
|
258 |
-
# Long ears
|
259 |
-
paths.append(f'<ellipse cx="{center_x + offset - 8}" cy="{center_y - 25}" rx="4" ry="15" fill="{color}"/>')
|
260 |
-
paths.append(f'<ellipse cx="{center_x + offset + 8}" cy="{center_y - 25}" rx="4" ry="15" fill="{color}"/>')
|
261 |
-
|
262 |
-
def _add_replacement_elements(self, paths, width, height, color, step):
|
263 |
-
"""Add generic replacement elements"""
|
264 |
-
for i in range(3):
|
265 |
-
x = np.random.randint(20, width - 20)
|
266 |
-
y = np.random.randint(20, height - 20)
|
267 |
-
size = 10 + step * 2
|
268 |
-
paths.append(f'<circle cx="{x}" cy="{y}" r="{size}" fill="{color}" opacity="0.6"/>')
|
269 |
-
|
270 |
-
def _add_refinement_details(self, paths, width, height, color, step):
|
271 |
-
"""Add refinement details"""
|
272 |
-
center_x, center_y = width // 2, height // 2
|
273 |
|
274 |
-
|
275 |
-
|
276 |
-
angle = (i * 360 / (step * 2)) * (3.14159 / 180)
|
277 |
-
radius = 40 + step * 5
|
278 |
-
x = center_x + radius * np.cos(angle)
|
279 |
-
y = center_y + radius * np.sin(angle)
|
280 |
-
paths.append(f'<circle cx="{x}" cy="{y}" r="2" fill="{color}"/>')
|
281 |
-
|
282 |
-
def _add_emphasis_elements(self, paths, width, height, color, step):
|
283 |
-
"""Add emphasis elements for reweighting"""
|
284 |
-
center_x, center_y = width // 2, height // 2
|
285 |
|
286 |
-
#
|
287 |
-
|
288 |
-
|
289 |
-
stroke_width = 3 + i
|
290 |
-
paths.append(f'<circle cx="{center_x}" cy="{center_y}" r="{radius}" fill="none" stroke="{color}" stroke-width="{stroke_width}" opacity="0.7"/>')
|
291 |
-
|
292 |
-
def _add_generation_elements(self, paths, width, height, color, step):
|
293 |
-
"""Add generation elements"""
|
294 |
-
for i in range(step * 2):
|
295 |
-
x = np.random.randint(10, width - 10)
|
296 |
-
y = np.random.randint(10, height - 10)
|
297 |
-
size = np.random.randint(5, 15)
|
298 |
-
paths.append(f'<rect x="{x}" y="{y}" width="{size}" height="{size}" fill="{color}" opacity="0.6"/>')
|
299 |
-
|
300 |
-
def _generate_edited_svg(self, prompt: str, width: int, height: int, step: int, edit_type: str, changing_region: List[str]) -> str:
|
301 |
-
"""
|
302 |
-
Generate an edited SVG as placeholder
|
303 |
-
This should be replaced with actual DiffSketchEdit generation when diffvg is available
|
304 |
-
"""
|
305 |
-
# Set different random seed for each step to show progression
|
306 |
-
np.random.seed(42 + step * 50)
|
307 |
|
308 |
-
|
309 |
-
|
310 |
-
|
311 |
-
|
312 |
-
if edit_type == "replace":
|
313 |
-
# Show gradual replacement of elements
|
314 |
-
colors = ["#E74C3C", "#3498DB", "#2ECC71", "#F39C12", "#9B59B6", "#1ABC9C"]
|
315 |
-
base_color = colors[step % len(colors)]
|
316 |
|
317 |
-
|
318 |
-
|
319 |
-
|
320 |
-
|
|
|
|
|
321 |
|
322 |
-
|
323 |
-
|
324 |
-
|
325 |
-
|
|
|
|
|
326 |
|
327 |
-
|
328 |
-
|
329 |
-
|
330 |
-
|
331 |
-
|
|
|
|
|
332 |
|
333 |
-
|
334 |
-
|
335 |
-
|
336 |
-
if i % 3 == 0:
|
337 |
-
# Circles
|
338 |
-
cx = np.random.randint(20, width - 20)
|
339 |
-
cy = np.random.randint(20, height - 20)
|
340 |
-
r = np.random.randint(5, 20 + step * 2)
|
341 |
-
opacity = 0.4 + step * 0.1
|
342 |
-
paths.append(f'<circle cx="{cx}" cy="{cy}" r="{r}" fill="{base_color}" opacity="{opacity}"/>')
|
343 |
-
|
344 |
-
elif i % 3 == 1:
|
345 |
-
# Rectangles
|
346 |
-
x = np.random.randint(10, width - 30)
|
347 |
-
y = np.random.randint(10, height - 30)
|
348 |
-
w = np.random.randint(10, 30 + step * 3)
|
349 |
-
h = np.random.randint(10, 30 + step * 3)
|
350 |
-
opacity = 0.3 + step * 0.1
|
351 |
-
paths.append(f'<rect x="{x}" y="{y}" width="{w}" height="{h}" fill="{base_color}" opacity="{opacity}"/>')
|
352 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
353 |
else:
|
354 |
-
#
|
355 |
-
|
356 |
-
|
357 |
-
|
358 |
-
|
359 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
360 |
|
361 |
-
# Add
|
362 |
-
|
363 |
-
|
|
|
|
|
|
|
|
|
364 |
|
365 |
-
|
366 |
-
return
|
367 |
-
|
368 |
-
|
369 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import os
|
2 |
import sys
|
3 |
+
import tempfile
|
4 |
+
import shutil
|
5 |
+
from pathlib import Path
|
6 |
import torch
|
7 |
+
import yaml
|
8 |
+
from omegaconf import OmegaConf
|
9 |
from PIL import Image
|
|
|
10 |
import io
|
11 |
+
import cairosvg
|
12 |
+
|
13 |
+
# Add DiffSketchEdit modules to path
|
14 |
+
sys.path.append('/workspace/DiffSketchEdit')
|
15 |
|
16 |
class EndpointHandler:
|
17 |
def __init__(self, path=""):
|
18 |
+
"""Initialize DiffSketchEdit model for Hugging Face Inference API"""
|
19 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
20 |
+
print(f"Initializing DiffSketchEdit on {self.device}")
|
21 |
|
|
|
|
|
22 |
try:
|
23 |
# Import DiffSketchEdit modules
|
24 |
+
from libs.engine import ModelState
|
25 |
+
from methods.painter.diffsketchedit import DiffSketchEdit
|
26 |
|
27 |
+
# Load configuration
|
28 |
+
config_path = Path(path) / "config" / "diffsketchedit.yaml"
|
29 |
+
if not config_path.exists():
|
30 |
+
# Use default config
|
31 |
+
config_path = Path(__file__).parent / "config" / "diffsketchedit.yaml"
|
|
|
|
|
32 |
|
33 |
+
with open(config_path, 'r') as f:
|
34 |
+
self.config = OmegaConf.load(f)
|
35 |
+
|
36 |
+
# Initialize model components
|
37 |
+
self.model_state = ModelState(self.config)
|
38 |
+
self.painter = DiffSketchEdit(self.config, self.device, self.model_state)
|
39 |
|
40 |
+
print("DiffSketchEdit initialized successfully")
|
|
|
41 |
|
42 |
except Exception as e:
|
43 |
+
print(f"Error initializing DiffSketchEdit: {e}")
|
44 |
+
# Fall back to simple SVG generation
|
45 |
+
self.painter = None
|
46 |
+
self.config = None
|
47 |
|
48 |
+
def __call__(self, data):
|
49 |
+
"""
|
50 |
+
Generate edited sketch from text prompts
|
51 |
+
|
52 |
+
Args:
|
53 |
+
data (dict): Input data containing:
|
54 |
+
- inputs (str): Text prompt or list of prompts for editing sequence
|
55 |
+
- parameters (dict): Generation parameters
|
56 |
+
|
57 |
+
Returns:
|
58 |
+
PIL.Image.Image: Generated edited sketch image
|
59 |
+
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
try:
|
61 |
+
# Extract inputs
|
62 |
+
inputs = data.get("inputs", "")
|
63 |
+
parameters = data.get("parameters", {})
|
|
|
|
|
|
|
|
|
64 |
|
65 |
+
if not inputs:
|
66 |
+
return self._create_error_image("No prompt provided")
|
67 |
+
|
68 |
+
# Handle multiple prompts for editing sequence
|
69 |
+
if isinstance(inputs, list):
|
70 |
+
prompts = inputs
|
|
|
|
|
|
|
|
|
71 |
else:
|
72 |
+
prompts = [inputs]
|
|
|
73 |
|
74 |
# Extract parameters
|
75 |
+
num_paths = parameters.get("num_paths", 96)
|
76 |
+
num_iter = parameters.get("num_iter", 1000)
|
77 |
+
guidance_scale = parameters.get("guidance_scale", 7.5)
|
78 |
+
seed = parameters.get("seed", 1)
|
79 |
width = parameters.get("width", 224)
|
80 |
height = parameters.get("height", 224)
|
|
|
|
|
|
|
|
|
81 |
|
82 |
+
# Generate SVG
|
83 |
+
if self.painter is not None:
|
84 |
+
svg_content = self._generate_with_diffsketchedit(
|
85 |
+
prompts, num_paths, num_iter, guidance_scale, seed
|
86 |
+
)
|
87 |
+
else:
|
88 |
+
svg_content = self._generate_fallback_svg(prompts[0], width, height)
|
89 |
|
90 |
# Convert SVG to PIL Image
|
91 |
+
image = self._svg_to_image(svg_content, width, height)
|
92 |
+
return image
|
|
|
|
|
|
|
|
|
|
|
|
|
93 |
|
94 |
except Exception as e:
|
95 |
+
print(f"Error in DiffSketchEdit inference: {e}")
|
96 |
+
return self._create_error_image(f"Error: {str(e)[:50]}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
97 |
|
98 |
+
def _generate_with_diffsketchedit(self, prompts, num_paths, num_iter, guidance_scale, seed):
|
99 |
+
"""Generate SVG using actual DiffSketchEdit model"""
|
100 |
+
try:
|
101 |
+
# Set random seed
|
102 |
+
torch.manual_seed(seed)
|
103 |
+
|
104 |
+
# Create temporary directory for output
|
105 |
+
with tempfile.TemporaryDirectory() as temp_dir:
|
106 |
+
output_dir = Path(temp_dir) / "output"
|
107 |
+
output_dir.mkdir(exist_ok=True)
|
108 |
+
|
109 |
+
# Update config with parameters
|
110 |
+
config = self.config.copy()
|
111 |
+
config.num_paths = num_paths
|
112 |
+
config.num_iter = num_iter
|
113 |
+
config.guidance_scale = guidance_scale
|
114 |
+
config.seed = seed
|
115 |
+
config.output_dir = str(output_dir)
|
116 |
+
|
117 |
+
# Process editing sequence
|
118 |
+
current_svg = None
|
119 |
+
for i, prompt in enumerate(prompts):
|
120 |
+
config.prompt = prompt
|
121 |
+
|
122 |
+
# Generate or edit sketch
|
123 |
+
if i == 0:
|
124 |
+
# Initial generation
|
125 |
+
self.painter.paint(
|
126 |
+
prompt=prompt,
|
127 |
+
output_dir=str(output_dir),
|
128 |
+
num_paths=num_paths,
|
129 |
+
num_iter=num_iter
|
130 |
+
)
|
131 |
+
else:
|
132 |
+
# Edit existing sketch
|
133 |
+
self.painter.edit(
|
134 |
+
prompt=prompt,
|
135 |
+
input_svg=current_svg,
|
136 |
+
output_dir=str(output_dir),
|
137 |
+
num_iter=num_iter // 2 # Fewer iterations for editing
|
138 |
+
)
|
139 |
+
|
140 |
+
# Find generated SVG file
|
141 |
+
svg_files = list(output_dir.glob(f"*_{i}.svg"))
|
142 |
+
if not svg_files:
|
143 |
+
svg_files = list(output_dir.glob("*.svg"))
|
144 |
+
|
145 |
+
if svg_files:
|
146 |
+
with open(svg_files[-1], 'r') as f:
|
147 |
+
current_svg = f.read()
|
148 |
+
|
149 |
+
return current_svg if current_svg else self._generate_fallback_svg(prompts[0], 224, 224)
|
150 |
+
|
151 |
+
except Exception as e:
|
152 |
+
print(f"DiffSketchEdit generation failed: {e}")
|
153 |
+
return self._generate_fallback_svg(prompts[0], 224, 224)
|
154 |
|
155 |
+
def _generate_fallback_svg(self, prompt, width, height):
|
156 |
+
"""Generate simple SVG when model fails"""
|
157 |
+
import random
|
158 |
+
import math
|
159 |
|
160 |
+
# Handle list of prompts
|
161 |
+
if isinstance(prompt, list):
|
162 |
+
prompt = prompt[0] if prompt else "default"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
163 |
|
164 |
+
# Set seed for reproducibility
|
165 |
+
random.seed(hash(str(prompt)) % 1000)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
166 |
|
167 |
+
svg_parts = [f'<svg width="{width}" height="{height}" xmlns="http://www.w3.org/2000/svg">']
|
168 |
+
svg_parts.append(f'<rect width="{width}" height="{height}" fill="white"/>')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
169 |
|
170 |
+
# Generate editing-style sketch based on prompt
|
171 |
+
prompt_lower = prompt.lower()
|
172 |
+
cx, cy = width // 2, height // 2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
173 |
|
174 |
+
# Base sketch elements
|
175 |
+
if any(word in prompt_lower for word in ['edit', 'modify', 'change']):
|
176 |
+
# Show editing process with overlapping elements
|
177 |
+
colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4']
|
|
|
|
|
|
|
|
|
178 |
|
179 |
+
# Original elements (lighter)
|
180 |
+
for i in range(3):
|
181 |
+
x = cx + random.randint(-40, 40)
|
182 |
+
y = cy + random.randint(-40, 40)
|
183 |
+
size = random.randint(15, 25)
|
184 |
+
svg_parts.append(f'<circle cx="{x}" cy="{y}" r="{size}" fill="{colors[0]}" opacity="0.3"/>')
|
185 |
|
186 |
+
# Edited elements (darker)
|
187 |
+
for i in range(3):
|
188 |
+
x = cx + random.randint(-30, 30)
|
189 |
+
y = cy + random.randint(-30, 30)
|
190 |
+
size = random.randint(10, 20)
|
191 |
+
svg_parts.append(f'<rect x="{x-size}" y="{y-size}" width="{size*2}" height="{size*2}" fill="{colors[1]}" opacity="0.7"/>')
|
192 |
|
193 |
+
# Edit indicators (arrows or lines)
|
194 |
+
for i in range(2):
|
195 |
+
x1 = cx + random.randint(-50, 50)
|
196 |
+
y1 = cy + random.randint(-50, 50)
|
197 |
+
x2 = x1 + random.randint(-20, 20)
|
198 |
+
y2 = y1 + random.randint(-20, 20)
|
199 |
+
svg_parts.append(f'<line x1="{x1}" y1="{y1}" x2="{x2}" y2="{y2}" stroke="{colors[2]}" stroke-width="3" marker-end="url(#arrowhead)"/>')
|
200 |
|
201 |
+
else:
|
202 |
+
# Regular sketch with editing potential
|
203 |
+
colors = ['black', 'gray', 'darkgray']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
204 |
|
205 |
+
if any(word in prompt_lower for word in ['face', 'portrait', 'person']):
|
206 |
+
# Simple face sketch
|
207 |
+
svg_parts.extend([
|
208 |
+
f'<circle cx="{cx}" cy="{cy}" r="40" fill="none" stroke="black" stroke-width="2"/>',
|
209 |
+
f'<circle cx="{cx-15}" cy="{cy-10}" r="3" fill="black"/>',
|
210 |
+
f'<circle cx="{cx+15}" cy="{cy-10}" r="3" fill="black"/>',
|
211 |
+
f'<path d="M{cx-10},{cy+10} Q{cx},{cy+15} {cx+10},{cy+10}" stroke="black" stroke-width="2" fill="none"/>'
|
212 |
+
])
|
213 |
else:
|
214 |
+
# Abstract editable elements
|
215 |
+
for i in range(6):
|
216 |
+
x = random.randint(30, width-30)
|
217 |
+
y = random.randint(30, height-30)
|
218 |
+
size = random.randint(8, 20)
|
219 |
+
|
220 |
+
if i % 3 == 0:
|
221 |
+
svg_parts.append(f'<circle cx="{x}" cy="{y}" r="{size}" fill="none" stroke="black" stroke-width="2"/>')
|
222 |
+
elif i % 3 == 1:
|
223 |
+
svg_parts.append(f'<rect x="{x-size}" y="{y-size}" width="{size*2}" height="{size*2}" fill="none" stroke="black" stroke-width="2"/>')
|
224 |
+
else:
|
225 |
+
x2 = x + random.randint(-30, 30)
|
226 |
+
y2 = y + random.randint(-30, 30)
|
227 |
+
svg_parts.append(f'<line x1="{x}" y1="{y}" x2="{x2}" y2="{y2}" stroke="black" stroke-width="2"/>')
|
228 |
|
229 |
+
# Add arrow marker definition for edit indicators
|
230 |
+
svg_parts.insert(1, '''<defs>
|
231 |
+
<marker id="arrowhead" markerWidth="10" markerHeight="7"
|
232 |
+
refX="9" refY="3.5" orient="auto">
|
233 |
+
<polygon points="0 0, 10 3.5, 0 7" fill="#45B7D1"/>
|
234 |
+
</marker>
|
235 |
+
</defs>''')
|
236 |
|
237 |
+
svg_parts.append('</svg>')
|
238 |
+
return '\n'.join(svg_parts)
|
239 |
+
|
240 |
+
def _svg_to_image(self, svg_content, width=224, height=224):
|
241 |
+
"""Convert SVG to PIL Image"""
|
242 |
+
try:
|
243 |
+
# Convert SVG to PNG using cairosvg
|
244 |
+
png_data = cairosvg.svg2png(
|
245 |
+
bytestring=svg_content.encode('utf-8'),
|
246 |
+
output_width=width,
|
247 |
+
output_height=height
|
248 |
+
)
|
249 |
+
|
250 |
+
# Convert to PIL Image
|
251 |
+
image = Image.open(io.BytesIO(png_data))
|
252 |
+
return image.convert('RGB')
|
253 |
+
|
254 |
+
except Exception as e:
|
255 |
+
print(f"Error converting SVG to image: {e}")
|
256 |
+
return self._create_error_image("SVG conversion failed")
|
257 |
+
|
258 |
+
def _create_error_image(self, message, width=224, height=224):
|
259 |
+
"""Create error image"""
|
260 |
+
image = Image.new('RGB', (width, height), 'white')
|
261 |
+
return image
|
requirements.txt
CHANGED
@@ -1,9 +1,24 @@
|
|
1 |
-
torch>=
|
2 |
-
torchvision>=0.
|
3 |
-
transformers>=4.21.0
|
4 |
-
svgwrite>=1.4.0
|
5 |
-
Pillow>=8.3.0
|
6 |
numpy>=1.21.0
|
7 |
-
|
8 |
-
|
9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch>=1.12.0
|
2 |
+
torchvision>=0.13.0
|
|
|
|
|
|
|
3 |
numpy>=1.21.0
|
4 |
+
Pillow>=8.0.0
|
5 |
+
cairosvg>=2.5.0
|
6 |
+
omegaconf>=2.1.0
|
7 |
+
diffusers>=0.20.0
|
8 |
+
transformers>=4.20.0
|
9 |
+
svgwrite>=1.4.0
|
10 |
+
svgpathtools>=1.4.0
|
11 |
+
freetype-py>=2.3.0
|
12 |
+
shapely>=1.8.0
|
13 |
+
opencv-python>=4.5.0
|
14 |
+
scikit-image>=0.19.0
|
15 |
+
matplotlib>=3.5.0
|
16 |
+
scipy>=1.8.0
|
17 |
+
einops>=0.4.0
|
18 |
+
timm>=0.6.0
|
19 |
+
ftfy>=6.1.0
|
20 |
+
regex>=2022.0.0
|
21 |
+
tqdm>=4.64.0
|
22 |
+
lpips>=0.1.4
|
23 |
+
clip-by-openai>=1.0.0
|
24 |
+
xformers>=0.0.16
|