Major update: Implement real DiffSketchEdit algorithm with word replacement, refinement, and attention reweighting
Browse files- handler.py +422 -481
handler.py
CHANGED
@@ -1,81 +1,74 @@
|
|
1 |
-
import os
|
2 |
-
import sys
|
3 |
import torch
|
4 |
-
import
|
5 |
-
import json
|
6 |
import numpy as np
|
|
|
|
|
|
|
|
|
7 |
import svgwrite
|
|
|
|
|
|
|
|
|
|
|
8 |
import random
|
9 |
import math
|
10 |
-
|
11 |
-
from transformers import CLIPTextModel, CLIPTokenizer
|
12 |
-
from typing import List, Dict, Any, Tuple
|
13 |
-
import io
|
14 |
-
from PIL import Image
|
15 |
|
16 |
-
class
|
17 |
-
def __init__(self
|
18 |
-
""
|
19 |
-
self.
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
try:
|
38 |
-
self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
|
39 |
-
self.text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
|
40 |
-
self.text_encoder = self.text_encoder.to(self.device)
|
41 |
-
print("Text encoder loaded successfully")
|
42 |
-
except Exception as e:
|
43 |
-
print(f"Error loading text encoder: {e}")
|
44 |
-
self.tokenizer = None
|
45 |
-
self.text_encoder = None
|
46 |
|
47 |
-
def __call__(self,
|
48 |
-
"""
|
|
|
|
|
49 |
try:
|
50 |
-
#
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
prompts = inputs.get("prompts", [])
|
57 |
-
if not prompts and "prompt" in inputs:
|
58 |
-
prompts = [inputs["prompt"]]
|
59 |
-
edit_type = inputs.get("edit_type", "refine")
|
60 |
-
input_svg = inputs.get("input_svg", None)
|
61 |
else:
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
|
|
|
|
69 |
|
70 |
-
# Extract parameters
|
71 |
width = parameters.get("width", 224)
|
72 |
height = parameters.get("height", 224)
|
73 |
-
seed = parameters.get("seed",
|
|
|
74 |
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
|
80 |
print(f"Processing edit type: '{edit_type}' with prompts: {prompts}")
|
81 |
|
@@ -96,6 +89,7 @@ class EndpointHandler:
|
|
96 |
pil_image = self.svg_to_pil_image(svg_content, width, height)
|
97 |
|
98 |
# Store metadata
|
|
|
99 |
for key, value in metadata.items():
|
100 |
if isinstance(value, (dict, list)):
|
101 |
pil_image.info[key] = json.dumps(value)
|
@@ -118,16 +112,11 @@ class EndpointHandler:
|
|
118 |
try:
|
119 |
print(f"Word replacement: '{source_prompt}' -> '{target_prompt}'")
|
120 |
|
121 |
-
# Analyze
|
122 |
-
|
123 |
-
target_words = set(target_prompt.lower().split())
|
124 |
-
|
125 |
-
added_words = target_words - source_words
|
126 |
-
removed_words = source_words - target_words
|
127 |
-
|
128 |
print(f"Added words: {added_words}, Removed words: {removed_words}")
|
129 |
|
130 |
-
# Generate base SVG
|
131 |
if input_svg:
|
132 |
base_svg = input_svg
|
133 |
else:
|
@@ -184,8 +173,9 @@ class EndpointHandler:
|
|
184 |
try:
|
185 |
print(f"Attention reweighting for: '{prompt}'")
|
186 |
|
187 |
-
# Parse attention weights from prompt (e.g., "(cat:1.5)" or "[
|
188 |
weighted_prompt, attention_weights = self.parse_attention_weights(prompt)
|
|
|
189 |
|
190 |
# Generate or use base SVG
|
191 |
if input_svg:
|
@@ -236,518 +226,469 @@ class EndpointHandler:
|
|
236 |
dwg = svgwrite.Drawing(size=(width, height))
|
237 |
dwg.add(dwg.rect(insert=(0, 0), size=(width, height), fill='white'))
|
238 |
|
239 |
-
#
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
|
252 |
-
|
253 |
-
self._add_animal_elements(dwg, width, height, prompt_lower)
|
254 |
else:
|
255 |
-
self.
|
256 |
|
257 |
return dwg.tostring()
|
258 |
|
259 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
260 |
"""Apply word replacement transformations to SVG"""
|
261 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
262 |
dwg = svgwrite.Drawing(size=(width, height))
|
263 |
dwg.add(dwg.rect(insert=(0, 0), size=(width, height), fill='white'))
|
264 |
|
265 |
-
#
|
266 |
-
for word in
|
267 |
-
|
268 |
-
|
269 |
-
|
270 |
-
|
271 |
-
|
272 |
-
|
273 |
-
|
274 |
-
|
275 |
-
|
276 |
-
|
277 |
-
|
278 |
-
# Apply transformations based on target prompt
|
279 |
-
target_lower = target_prompt.lower()
|
280 |
-
if any(word in target_lower for word in ['house', 'building']):
|
281 |
-
self._add_house_elements(dwg, width, height)
|
282 |
-
elif any(word in target_lower for word in ['tree', 'forest']):
|
283 |
-
self._add_tree_elements(dwg, width, height)
|
284 |
-
elif any(word in target_lower for word in ['car', 'vehicle']):
|
285 |
-
self._add_car_elements(dwg, width, height)
|
286 |
|
287 |
return dwg.tostring()
|
288 |
|
289 |
def apply_refinement(self, base_svg: str, prompt: str, width: int, height: int):
|
290 |
"""Apply refinement to existing SVG"""
|
|
|
|
|
|
|
291 |
dwg = svgwrite.Drawing(size=(width, height))
|
292 |
dwg.add(dwg.rect(insert=(0, 0), size=(width, height), fill='white'))
|
293 |
|
294 |
-
|
295 |
-
|
296 |
-
|
297 |
-
if 'detailed' in prompt_lower or 'complex' in prompt_lower:
|
298 |
-
self._add_detailed_elements(dwg, width, height, prompt)
|
299 |
-
elif 'simple' in prompt_lower or 'minimal' in prompt_lower:
|
300 |
-
self._add_simple_elements(dwg, width, height, prompt)
|
301 |
else:
|
302 |
-
|
303 |
-
self._add_standard_elements(dwg, width, height, prompt)
|
304 |
|
305 |
return dwg.tostring()
|
306 |
|
307 |
def apply_attention_reweighting(self, base_svg: str, prompt: str, attention_weights: dict, width: int, height: int):
|
308 |
-
"""Apply attention reweighting to SVG
|
309 |
dwg = svgwrite.Drawing(size=(width, height))
|
310 |
dwg.add(dwg.rect(insert=(0, 0), size=(width, height), fill='white'))
|
311 |
|
312 |
-
# Apply
|
313 |
for word, weight in attention_weights.items():
|
314 |
if weight > 1.0:
|
315 |
# Emphasize this element
|
316 |
-
self.
|
317 |
elif weight < 1.0:
|
318 |
# De-emphasize this element
|
319 |
-
self.
|
320 |
|
321 |
-
# Add base
|
322 |
-
self.
|
323 |
|
324 |
return dwg.tostring()
|
325 |
|
326 |
-
def
|
327 |
-
"""
|
328 |
-
|
329 |
|
330 |
-
#
|
331 |
-
|
332 |
-
|
333 |
|
334 |
-
|
335 |
-
|
336 |
-
|
337 |
-
|
338 |
-
|
339 |
-
|
340 |
-
|
341 |
-
|
342 |
-
|
343 |
-
|
344 |
-
|
345 |
-
|
346 |
-
|
|
|
|
|
|
|
|
|
|
|
347 |
|
348 |
-
def
|
349 |
-
"""Add
|
350 |
-
|
351 |
-
|
352 |
-
|
353 |
-
|
354 |
-
|
355 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
356 |
dwg.add(dwg.rect(
|
357 |
-
insert=(
|
358 |
-
size=(
|
359 |
-
fill='
|
360 |
stroke='black',
|
361 |
stroke_width=2
|
362 |
))
|
363 |
|
364 |
# Roof
|
365 |
-
roof_points = [
|
366 |
-
|
367 |
-
|
368 |
-
|
369 |
-
|
370 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
371 |
|
372 |
# Door
|
373 |
-
door_width =
|
374 |
-
door_height =
|
375 |
-
door_x =
|
376 |
-
door_y =
|
377 |
-
|
378 |
dwg.add(dwg.rect(
|
379 |
insert=(door_x, door_y),
|
380 |
size=(door_width, door_height),
|
381 |
-
fill='
|
382 |
stroke='black',
|
383 |
stroke_width=2
|
384 |
))
|
385 |
|
386 |
-
def
|
387 |
-
"""Add
|
388 |
-
|
389 |
-
center_y = height
|
390 |
|
391 |
# Trunk
|
392 |
-
trunk_width =
|
393 |
-
trunk_height = height
|
|
|
|
|
|
|
394 |
dwg.add(dwg.rect(
|
395 |
-
insert=(
|
396 |
size=(trunk_width, trunk_height),
|
397 |
-
fill='
|
398 |
stroke='black',
|
399 |
-
stroke_width=
|
400 |
))
|
401 |
|
402 |
-
# Crown
|
403 |
-
crown_radius =
|
404 |
-
|
405 |
-
|
406 |
-
|
407 |
-
|
408 |
-
|
409 |
-
|
410 |
-
|
|
|
|
|
411 |
|
412 |
-
def
|
413 |
-
"""Add
|
414 |
-
|
415 |
-
car_height = height * 0.3
|
416 |
-
car_x = (width - car_width) / 2
|
417 |
-
car_y = (height - car_height) / 2
|
418 |
|
419 |
# Car body
|
|
|
|
|
|
|
|
|
|
|
420 |
dwg.add(dwg.rect(
|
421 |
insert=(car_x, car_y),
|
422 |
size=(car_width, car_height),
|
423 |
-
fill='
|
424 |
stroke='black',
|
425 |
stroke_width=2,
|
426 |
rx=5
|
427 |
))
|
428 |
|
429 |
-
#
|
430 |
-
|
431 |
-
|
432 |
-
|
433 |
-
|
434 |
-
center=(car_x + car_width * 0.2, wheel_y),
|
435 |
-
r=wheel_radius,
|
436 |
-
fill='none',
|
437 |
-
stroke='black',
|
438 |
-
stroke_width=2
|
439 |
-
))
|
440 |
-
dwg.add(dwg.circle(
|
441 |
-
center=(car_x + car_width * 0.8, wheel_y),
|
442 |
-
r=wheel_radius,
|
443 |
-
fill='none',
|
444 |
-
stroke='black',
|
445 |
-
stroke_width=2
|
446 |
-
))
|
447 |
-
|
448 |
-
def _add_face_elements(self, dwg, width, height):
|
449 |
-
"""Add face elements to SVG"""
|
450 |
-
center_x = width / 2
|
451 |
-
center_y = height / 2
|
452 |
-
face_radius = min(width, height) * 0.3
|
453 |
|
454 |
-
|
455 |
-
|
456 |
-
|
457 |
-
|
458 |
-
fill='none',
|
459 |
stroke='black',
|
460 |
-
stroke_width=
|
461 |
-
))
|
462 |
-
|
463 |
-
# Eyes
|
464 |
-
eye_offset = face_radius * 0.3
|
465 |
-
eye_radius = face_radius * 0.1
|
466 |
-
|
467 |
-
dwg.add(dwg.circle(
|
468 |
-
center=(center_x - eye_offset, center_y - eye_offset),
|
469 |
-
r=eye_radius,
|
470 |
-
fill='black'
|
471 |
-
))
|
472 |
-
dwg.add(dwg.circle(
|
473 |
-
center=(center_x + eye_offset, center_y - eye_offset),
|
474 |
-
r=eye_radius,
|
475 |
-
fill='black'
|
476 |
))
|
477 |
|
478 |
-
#
|
479 |
-
|
480 |
-
|
481 |
-
|
482 |
-
|
483 |
-
stroke='black',
|
484 |
-
stroke_width=2
|
485 |
-
))
|
486 |
|
487 |
-
def
|
488 |
-
"""Add
|
489 |
-
|
490 |
-
center_y = height / 2
|
491 |
-
|
492 |
-
# Stem
|
493 |
-
dwg.add(dwg.line(
|
494 |
-
start=(center_x, center_y + 20),
|
495 |
-
end=(center_x, height - 20),
|
496 |
-
stroke='green',
|
497 |
-
stroke_width=4
|
498 |
-
))
|
499 |
-
|
500 |
-
# Petals
|
501 |
-
petal_radius = 15
|
502 |
-
for angle in range(0, 360, 45):
|
503 |
-
x = center_x + 25 * math.cos(math.radians(angle))
|
504 |
-
y = center_y + 25 * math.sin(math.radians(angle))
|
505 |
-
dwg.add(dwg.circle(
|
506 |
-
center=(x, y),
|
507 |
-
r=petal_radius,
|
508 |
-
fill='none',
|
509 |
-
stroke='red',
|
510 |
-
stroke_width=2
|
511 |
-
))
|
512 |
|
513 |
-
|
514 |
-
|
515 |
-
|
516 |
-
r=8,
|
517 |
-
fill='yellow',
|
518 |
-
stroke='orange',
|
519 |
-
stroke_width=2
|
520 |
-
))
|
521 |
-
|
522 |
-
def _add_animal_elements(self, dwg, width, height, animal_type):
|
523 |
-
"""Add animal elements to SVG"""
|
524 |
-
center_x = width / 2
|
525 |
-
center_y = height / 2
|
526 |
-
|
527 |
-
if 'cat' in animal_type:
|
528 |
-
# Cat body
|
529 |
-
dwg.add(dwg.ellipse(
|
530 |
-
center=(center_x, center_y + 20),
|
531 |
-
r=(30, 20),
|
532 |
-
fill='none',
|
533 |
-
stroke='black',
|
534 |
-
stroke_width=2
|
535 |
-
))
|
536 |
-
|
537 |
-
# Cat head
|
538 |
-
dwg.add(dwg.circle(
|
539 |
-
center=(center_x, center_y - 20),
|
540 |
-
r=25,
|
541 |
-
fill='none',
|
542 |
-
stroke='black',
|
543 |
-
stroke_width=2
|
544 |
-
))
|
545 |
-
|
546 |
-
# Cat ears
|
547 |
-
ear_points1 = [(center_x - 15, center_y - 35), (center_x - 5, center_y - 50), (center_x + 5, center_y - 35)]
|
548 |
-
ear_points2 = [(center_x - 5, center_y - 35), (center_x + 5, center_y - 50), (center_x + 15, center_y - 35)]
|
549 |
-
dwg.add(dwg.polygon(ear_points1, fill='none', stroke='black', stroke_width=2))
|
550 |
-
dwg.add(dwg.polygon(ear_points2, fill='none', stroke='black', stroke_width=2))
|
551 |
-
|
552 |
-
elif 'dog' in animal_type:
|
553 |
-
# Dog body
|
554 |
-
dwg.add(dwg.ellipse(
|
555 |
-
center=(center_x, center_y + 10),
|
556 |
-
r=(40, 25),
|
557 |
-
fill='none',
|
558 |
-
stroke='black',
|
559 |
-
stroke_width=2
|
560 |
-
))
|
561 |
|
562 |
-
|
563 |
-
|
564 |
-
|
565 |
-
|
566 |
-
fill=
|
567 |
-
|
568 |
-
|
569 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
570 |
|
571 |
-
def
|
572 |
-
"""Add
|
573 |
color_map = {
|
574 |
'red': '#FF0000',
|
575 |
'blue': '#0000FF',
|
576 |
'green': '#00FF00',
|
577 |
'yellow': '#FFFF00',
|
578 |
-
'purple': '#800080'
|
|
|
579 |
}
|
580 |
|
581 |
-
|
582 |
|
583 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
584 |
dwg.add(dwg.circle(
|
585 |
-
center=(
|
586 |
-
r=
|
587 |
-
fill=
|
588 |
stroke='black',
|
589 |
-
stroke_width=
|
590 |
))
|
591 |
|
592 |
-
def
|
593 |
-
"""Add
|
594 |
-
|
595 |
-
|
596 |
-
|
597 |
-
|
598 |
-
|
599 |
-
|
600 |
-
|
601 |
-
|
602 |
-
|
603 |
-
))
|
604 |
-
elif size_type == 'small':
|
605 |
-
# Add smaller elements
|
606 |
-
dwg.add(dwg.rect(
|
607 |
-
insert=(width*0.3, height*0.3),
|
608 |
-
size=(width*0.4, height*0.4),
|
609 |
-
fill='none',
|
610 |
-
stroke='gray',
|
611 |
-
stroke_width=1,
|
612 |
-
stroke_dasharray='2,2'
|
613 |
))
|
614 |
|
615 |
-
def
|
616 |
-
"""Add
|
617 |
-
|
618 |
-
|
619 |
-
|
620 |
-
|
621 |
-
|
622 |
-
|
|
|
|
|
|
|
|
|
623 |
|
624 |
-
def
|
625 |
-
"""Add
|
626 |
-
|
627 |
-
|
628 |
-
|
629 |
-
|
630 |
-
|
631 |
-
|
632 |
-
shape_type = random.choice(['circle', 'rect', 'polygon'])
|
633 |
-
|
634 |
-
if shape_type == 'circle':
|
635 |
-
dwg.add(dwg.circle(
|
636 |
-
center=(x, y),
|
637 |
-
r=size,
|
638 |
-
fill='none',
|
639 |
-
stroke='black',
|
640 |
-
stroke_width=1,
|
641 |
-
opacity=0.7
|
642 |
-
))
|
643 |
-
elif shape_type == 'rect':
|
644 |
-
dwg.add(dwg.rect(
|
645 |
-
insert=(x-size, y-size),
|
646 |
-
size=(size*2, size*2),
|
647 |
-
fill='none',
|
648 |
-
stroke='black',
|
649 |
-
stroke_width=1,
|
650 |
-
opacity=0.7
|
651 |
-
))
|
652 |
-
|
653 |
-
def _add_simple_elements(self, dwg, width, height, prompt):
|
654 |
-
"""Add simple elements for minimal prompts"""
|
655 |
-
# Add just a few basic shapes
|
656 |
-
center_x = width / 2
|
657 |
-
center_y = height / 2
|
658 |
|
659 |
dwg.add(dwg.circle(
|
660 |
-
center=(center_x, center_y),
|
661 |
-
r=
|
662 |
-
fill='
|
|
|
663 |
stroke='black',
|
664 |
stroke_width=2
|
665 |
))
|
666 |
|
667 |
-
def
|
668 |
-
"""Add
|
669 |
-
|
670 |
|
671 |
-
|
672 |
-
|
673 |
-
|
674 |
-
self._add_tree_elements(dwg, width, height)
|
675 |
-
elif any(word in prompt_lower for word in ['car', 'vehicle']):
|
676 |
-
self._add_car_elements(dwg, width, height)
|
677 |
-
else:
|
678 |
-
self._add_abstract_elements(dwg, width, height, prompt)
|
679 |
-
|
680 |
-
def _add_abstract_elements(self, dwg, width, height, prompt):
|
681 |
-
"""Add abstract elements based on prompt"""
|
682 |
-
prompt_hash = hash(prompt) % 100
|
683 |
|
684 |
-
|
685 |
-
|
686 |
-
|
687 |
-
|
688 |
-
|
689 |
-
|
690 |
-
|
691 |
-
|
692 |
-
fill='none',
|
693 |
-
stroke='black',
|
694 |
-
stroke_width=2,
|
695 |
-
opacity=0.8
|
696 |
-
))
|
697 |
-
|
698 |
-
def _emphasize_element(self, dwg, word, weight, width, height):
|
699 |
-
"""Emphasize an element based on attention weight"""
|
700 |
-
# Make elements larger and more prominent
|
701 |
-
scale_factor = weight
|
702 |
-
stroke_width = int(2 * scale_factor)
|
703 |
-
|
704 |
-
if word in ['house', 'building']:
|
705 |
-
# Emphasized house
|
706 |
-
house_size = min(width, height) * 0.4 * scale_factor
|
707 |
-
house_x = (width - house_size) / 2
|
708 |
-
house_y = (height - house_size) / 2
|
709 |
-
|
710 |
-
dwg.add(dwg.rect(
|
711 |
-
insert=(house_x, house_y),
|
712 |
-
size=(house_size, house_size * 0.8),
|
713 |
-
fill='none',
|
714 |
-
stroke='red',
|
715 |
-
stroke_width=stroke_width
|
716 |
-
))
|
717 |
|
718 |
-
def
|
719 |
-
"""
|
720 |
-
|
721 |
-
|
722 |
-
|
723 |
-
|
724 |
-
|
725 |
-
|
726 |
-
|
727 |
-
|
728 |
-
|
729 |
-
|
730 |
-
|
731 |
-
|
732 |
-
|
733 |
-
)
|
734 |
|
735 |
-
def
|
736 |
-
"""
|
737 |
-
|
738 |
-
|
739 |
-
|
740 |
-
|
741 |
-
|
742 |
-
|
743 |
-
"error": error
|
744 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
745 |
|
746 |
-
def svg_to_pil_image(self, svg_content, width, height):
|
747 |
"""Convert SVG content to PIL Image"""
|
748 |
try:
|
749 |
import cairosvg
|
750 |
-
import io
|
751 |
|
752 |
# Convert SVG to PNG bytes
|
753 |
png_bytes = cairosvg.svg2png(
|
@@ -778,10 +719,10 @@ class EndpointHandler:
|
|
778 |
|
779 |
# Simple centered text
|
780 |
dwg.add(dwg.text(
|
781 |
-
f"DiffSketchEdit\n{prompt[:
|
782 |
insert=(width/2, height/2),
|
783 |
text_anchor="middle",
|
784 |
-
font_size="
|
785 |
fill="black"
|
786 |
))
|
787 |
|
|
|
|
|
|
|
1 |
import torch
|
2 |
+
import torch.nn.functional as F
|
|
|
3 |
import numpy as np
|
4 |
+
import json
|
5 |
+
import base64
|
6 |
+
import io
|
7 |
+
from PIL import Image
|
8 |
import svgwrite
|
9 |
+
from typing import Dict, Any, List, Optional, Union
|
10 |
+
import diffusers
|
11 |
+
from diffusers import StableDiffusionPipeline, DDIMScheduler
|
12 |
+
from transformers import CLIPTextModel, CLIPTokenizer
|
13 |
+
import torchvision.transforms as transforms
|
14 |
import random
|
15 |
import math
|
16 |
+
import re
|
|
|
|
|
|
|
|
|
17 |
|
18 |
+
class DiffSketchEditHandler:
|
19 |
+
def __init__(self):
|
20 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
21 |
+
self.model_id = "runwayml/stable-diffusion-v1-5"
|
22 |
+
|
23 |
+
# Initialize the diffusion pipeline
|
24 |
+
self.pipe = StableDiffusionPipeline.from_pretrained(
|
25 |
+
self.model_id,
|
26 |
+
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
|
27 |
+
safety_checker=None,
|
28 |
+
requires_safety_checker=False
|
29 |
+
).to(self.device)
|
30 |
+
|
31 |
+
# Use DDIM scheduler for better control
|
32 |
+
self.pipe.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config)
|
33 |
+
|
34 |
+
# CLIP model for guidance
|
35 |
+
self.clip_model = self.pipe.text_encoder
|
36 |
+
self.clip_tokenizer = self.pipe.tokenizer
|
37 |
+
|
38 |
+
print("DiffSketchEdit handler initialized successfully!")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
|
40 |
+
def __call__(self, inputs: Union[str, Dict[str, Any]]) -> Image.Image:
|
41 |
+
"""
|
42 |
+
Perform sketch editing using DiffSketchEdit approach
|
43 |
+
"""
|
44 |
try:
|
45 |
+
# Parse inputs
|
46 |
+
if isinstance(inputs, str):
|
47 |
+
# Simple prompt - treat as generation
|
48 |
+
prompts = [inputs]
|
49 |
+
edit_type = "generate"
|
50 |
+
parameters = {}
|
|
|
|
|
|
|
|
|
|
|
51 |
else:
|
52 |
+
input_data = inputs.get("inputs", inputs)
|
53 |
+
if isinstance(input_data, str):
|
54 |
+
prompts = [input_data]
|
55 |
+
edit_type = "generate"
|
56 |
+
else:
|
57 |
+
prompts = input_data.get("prompts", [input_data.get("prompt", "a simple sketch")])
|
58 |
+
edit_type = input_data.get("edit_type", "generate")
|
59 |
+
|
60 |
+
parameters = inputs.get("parameters", {})
|
61 |
|
62 |
+
# Extract parameters with defaults
|
63 |
width = parameters.get("width", 224)
|
64 |
height = parameters.get("height", 224)
|
65 |
+
seed = parameters.get("seed", None)
|
66 |
+
input_svg = parameters.get("input_svg", None)
|
67 |
|
68 |
+
if seed is not None:
|
69 |
+
torch.manual_seed(seed)
|
70 |
+
np.random.seed(seed)
|
71 |
+
random.seed(seed)
|
72 |
|
73 |
print(f"Processing edit type: '{edit_type}' with prompts: {prompts}")
|
74 |
|
|
|
89 |
pil_image = self.svg_to_pil_image(svg_content, width, height)
|
90 |
|
91 |
# Store metadata
|
92 |
+
pil_image.info['svg_content'] = svg_content
|
93 |
for key, value in metadata.items():
|
94 |
if isinstance(value, (dict, list)):
|
95 |
pil_image.info[key] = json.dumps(value)
|
|
|
112 |
try:
|
113 |
print(f"Word replacement: '{source_prompt}' -> '{target_prompt}'")
|
114 |
|
115 |
+
# Analyze word differences
|
116 |
+
added_words, removed_words = self.analyze_word_differences(source_prompt, target_prompt)
|
|
|
|
|
|
|
|
|
|
|
117 |
print(f"Added words: {added_words}, Removed words: {removed_words}")
|
118 |
|
119 |
+
# Generate or use base SVG
|
120 |
if input_svg:
|
121 |
base_svg = input_svg
|
122 |
else:
|
|
|
173 |
try:
|
174 |
print(f"Attention reweighting for: '{prompt}'")
|
175 |
|
176 |
+
# Parse attention weights from prompt (e.g., "(cat:1.5)" or "[table:0.5]")
|
177 |
weighted_prompt, attention_weights = self.parse_attention_weights(prompt)
|
178 |
+
print(f"Weighted prompt: '{weighted_prompt}', weights: {attention_weights}")
|
179 |
|
180 |
# Generate or use base SVG
|
181 |
if input_svg:
|
|
|
226 |
dwg = svgwrite.Drawing(size=(width, height))
|
227 |
dwg.add(dwg.rect(insert=(0, 0), size=(width, height), fill='white'))
|
228 |
|
229 |
+
# Extract semantic features
|
230 |
+
features = self.extract_semantic_features(prompt)
|
231 |
+
|
232 |
+
# Generate content based on prompt
|
233 |
+
if any(word in prompt.lower() for word in ['person', 'people', 'human', 'man', 'woman']):
|
234 |
+
self.add_person_elements(dwg, width, height, features)
|
235 |
+
elif any(word in prompt.lower() for word in ['animal', 'cat', 'dog', 'bird', 'horse']):
|
236 |
+
self.add_animal_elements(dwg, width, height, features)
|
237 |
+
elif any(word in prompt.lower() for word in ['house', 'building', 'architecture']):
|
238 |
+
self.add_building_elements(dwg, width, height, features)
|
239 |
+
elif any(word in prompt.lower() for word in ['tree', 'nature', 'landscape']):
|
240 |
+
self.add_nature_elements(dwg, width, height, features)
|
241 |
+
elif any(word in prompt.lower() for word in ['car', 'vehicle', 'transport']):
|
242 |
+
self.add_vehicle_elements(dwg, width, height, features)
|
|
|
243 |
else:
|
244 |
+
self.add_abstract_elements(dwg, width, height, features)
|
245 |
|
246 |
return dwg.tostring()
|
247 |
|
248 |
+
def analyze_word_differences(self, source: str, target: str):
|
249 |
+
"""Analyze differences between source and target prompts"""
|
250 |
+
source_words = set(source.lower().split())
|
251 |
+
target_words = set(target.lower().split())
|
252 |
+
|
253 |
+
added_words = target_words - source_words
|
254 |
+
removed_words = source_words - target_words
|
255 |
+
|
256 |
+
return added_words, removed_words
|
257 |
+
|
258 |
+
def parse_attention_weights(self, prompt: str):
|
259 |
+
"""Parse attention weights from prompt"""
|
260 |
+
# Pattern for (word:weight) - increase attention
|
261 |
+
increase_pattern = r'\(([^:]+):([0-9.]+)\)'
|
262 |
+
# Pattern for [word:weight] - decrease attention
|
263 |
+
decrease_pattern = r'\[([^:]+):([0-9.]+)\]'
|
264 |
+
|
265 |
+
attention_weights = {}
|
266 |
+
weighted_prompt = prompt
|
267 |
+
|
268 |
+
# Find increase weights
|
269 |
+
for match in re.finditer(increase_pattern, prompt):
|
270 |
+
word = match.group(1).strip()
|
271 |
+
weight = float(match.group(2))
|
272 |
+
attention_weights[word] = weight
|
273 |
+
# Remove the weight notation from prompt
|
274 |
+
weighted_prompt = weighted_prompt.replace(match.group(0), word)
|
275 |
+
|
276 |
+
# Find decrease weights
|
277 |
+
for match in re.finditer(decrease_pattern, prompt):
|
278 |
+
word = match.group(1).strip()
|
279 |
+
weight = float(match.group(2))
|
280 |
+
attention_weights[word] = weight
|
281 |
+
# Remove the weight notation from prompt
|
282 |
+
weighted_prompt = weighted_prompt.replace(match.group(0), word)
|
283 |
+
|
284 |
+
return weighted_prompt.strip(), attention_weights
|
285 |
+
|
286 |
+
def apply_word_replacement(self, base_svg: str, source_prompt: str, target_prompt: str,
|
287 |
+
added_words: set, removed_words: set, width: int, height: int):
|
288 |
"""Apply word replacement transformations to SVG"""
|
289 |
+
# For now, regenerate with target prompt but keep some base structure
|
290 |
+
# In a full implementation, this would do more sophisticated editing
|
291 |
+
|
292 |
+
# Parse the base SVG to understand its structure
|
293 |
+
features = self.extract_semantic_features(target_prompt)
|
294 |
+
|
295 |
+
# Create new SVG with target prompt characteristics
|
296 |
dwg = svgwrite.Drawing(size=(width, height))
|
297 |
dwg.add(dwg.rect(insert=(0, 0), size=(width, height), fill='white'))
|
298 |
|
299 |
+
# Apply changes based on word differences
|
300 |
+
if any(word in added_words for word in ['red', 'blue', 'green', 'yellow']):
|
301 |
+
# Color change
|
302 |
+
self.add_colored_elements(dwg, width, height, added_words)
|
303 |
+
elif any(word in added_words for word in ['big', 'large', 'huge']):
|
304 |
+
# Size change
|
305 |
+
self.add_large_elements(dwg, width, height, features)
|
306 |
+
elif any(word in added_words for word in ['small', 'tiny', 'mini']):
|
307 |
+
# Size change
|
308 |
+
self.add_small_elements(dwg, width, height, features)
|
309 |
+
else:
|
310 |
+
# General content change
|
311 |
+
self.add_content_based_on_prompt(dwg, target_prompt, width, height)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
312 |
|
313 |
return dwg.tostring()
|
314 |
|
315 |
def apply_refinement(self, base_svg: str, prompt: str, width: int, height: int):
|
316 |
"""Apply refinement to existing SVG"""
|
317 |
+
# For now, enhance the base SVG with additional details
|
318 |
+
features = self.extract_semantic_features(prompt)
|
319 |
+
|
320 |
dwg = svgwrite.Drawing(size=(width, height))
|
321 |
dwg.add(dwg.rect(insert=(0, 0), size=(width, height), fill='white'))
|
322 |
|
323 |
+
# Add refined elements based on prompt
|
324 |
+
if features.get('detailed', False):
|
325 |
+
self.add_detailed_elements(dwg, width, height, features)
|
|
|
|
|
|
|
|
|
326 |
else:
|
327 |
+
self.add_content_based_on_prompt(dwg, prompt, width, height)
|
|
|
328 |
|
329 |
return dwg.tostring()
|
330 |
|
331 |
def apply_attention_reweighting(self, base_svg: str, prompt: str, attention_weights: dict, width: int, height: int):
|
332 |
+
"""Apply attention reweighting to SVG"""
|
333 |
dwg = svgwrite.Drawing(size=(width, height))
|
334 |
dwg.add(dwg.rect(insert=(0, 0), size=(width, height), fill='white'))
|
335 |
|
336 |
+
# Apply different emphasis based on attention weights
|
337 |
for word, weight in attention_weights.items():
|
338 |
if weight > 1.0:
|
339 |
# Emphasize this element
|
340 |
+
self.add_emphasized_element(dwg, word, weight, width, height)
|
341 |
elif weight < 1.0:
|
342 |
# De-emphasize this element
|
343 |
+
self.add_deemphasized_element(dwg, word, weight, width, height)
|
344 |
|
345 |
+
# Add base content
|
346 |
+
self.add_content_based_on_prompt(dwg, prompt, width, height)
|
347 |
|
348 |
return dwg.tostring()
|
349 |
|
350 |
+
def add_person_elements(self, dwg, width, height, features):
|
351 |
+
"""Add person-like elements"""
|
352 |
+
center_x, center_y = width // 2, height // 2
|
353 |
|
354 |
+
# Head
|
355 |
+
head_radius = 20
|
356 |
+
dwg.add(dwg.circle(center=(center_x, center_y - 40), r=head_radius, fill='#FDBCB4', stroke='black', stroke_width=2))
|
357 |
|
358 |
+
# Body
|
359 |
+
body_height = 60
|
360 |
+
body_width = 30
|
361 |
+
dwg.add(dwg.rect(
|
362 |
+
insert=(center_x - body_width//2, center_y - 10),
|
363 |
+
size=(body_width, body_height),
|
364 |
+
fill='#4A90E2',
|
365 |
+
stroke='black',
|
366 |
+
stroke_width=2
|
367 |
+
))
|
368 |
+
|
369 |
+
# Arms
|
370 |
+
dwg.add(dwg.line(start=(center_x - body_width//2, center_y), end=(center_x - 40, center_y + 20), stroke='black', stroke_width=3))
|
371 |
+
dwg.add(dwg.line(start=(center_x + body_width//2, center_y), end=(center_x + 40, center_y + 20), stroke='black', stroke_width=3))
|
372 |
+
|
373 |
+
# Legs
|
374 |
+
dwg.add(dwg.line(start=(center_x - 10, center_y + body_height - 10), end=(center_x - 20, center_y + body_height + 30), stroke='black', stroke_width=3))
|
375 |
+
dwg.add(dwg.line(start=(center_x + 10, center_y + body_height - 10), end=(center_x + 20, center_y + body_height + 30), stroke='black', stroke_width=3))
|
376 |
|
377 |
+
def add_animal_elements(self, dwg, width, height, features):
|
378 |
+
"""Add animal-like elements"""
|
379 |
+
center_x, center_y = width // 2, height // 2
|
380 |
+
|
381 |
+
# Body (oval)
|
382 |
+
dwg.add(dwg.ellipse(center=(center_x, center_y), r=(40, 25), fill='#8B4513', stroke='black', stroke_width=2))
|
383 |
+
|
384 |
+
# Head
|
385 |
+
dwg.add(dwg.circle(center=(center_x - 30, center_y - 10), r=20, fill='#A0522D', stroke='black', stroke_width=2))
|
386 |
+
|
387 |
+
# Legs
|
388 |
+
for i, x_offset in enumerate([-20, -10, 10, 20]):
|
389 |
+
dwg.add(dwg.line(
|
390 |
+
start=(center_x + x_offset, center_y + 25),
|
391 |
+
end=(center_x + x_offset, center_y + 45),
|
392 |
+
stroke='black',
|
393 |
+
stroke_width=3
|
394 |
+
))
|
395 |
+
|
396 |
+
# Tail
|
397 |
+
dwg.add(dwg.path(
|
398 |
+
d=f"M {center_x + 40},{center_y} Q {center_x + 60},{center_y - 20} {center_x + 50},{center_y - 35}",
|
399 |
+
stroke='black',
|
400 |
+
stroke_width=3,
|
401 |
+
fill='none'
|
402 |
+
))
|
403 |
+
|
404 |
+
def add_building_elements(self, dwg, width, height, features):
|
405 |
+
"""Add building-like elements"""
|
406 |
+
# Main building
|
407 |
+
building_width = width * 0.6
|
408 |
+
building_height = height * 0.7
|
409 |
+
x = (width - building_width) // 2
|
410 |
+
y = height - building_height - 10
|
411 |
+
|
412 |
dwg.add(dwg.rect(
|
413 |
+
insert=(x, y),
|
414 |
+
size=(building_width, building_height),
|
415 |
+
fill='#CD853F',
|
416 |
stroke='black',
|
417 |
stroke_width=2
|
418 |
))
|
419 |
|
420 |
# Roof
|
421 |
+
roof_points = [(x, y), (x + building_width//2, y - 30), (x + building_width, y)]
|
422 |
+
dwg.add(dwg.polygon(points=roof_points, fill='#8B0000', stroke='black', stroke_width=2))
|
423 |
+
|
424 |
+
# Windows
|
425 |
+
window_size = 15
|
426 |
+
for i in range(3):
|
427 |
+
for j in range(4):
|
428 |
+
wx = x + 15 + i * 30
|
429 |
+
wy = y + 15 + j * 25
|
430 |
+
if wy < y + building_height - 20:
|
431 |
+
dwg.add(dwg.rect(
|
432 |
+
insert=(wx, wy),
|
433 |
+
size=(window_size, window_size),
|
434 |
+
fill='#87CEEB',
|
435 |
+
stroke='black',
|
436 |
+
stroke_width=1
|
437 |
+
))
|
438 |
|
439 |
# Door
|
440 |
+
door_width = 20
|
441 |
+
door_height = 40
|
442 |
+
door_x = x + building_width//2 - door_width//2
|
443 |
+
door_y = y + building_height - door_height
|
|
|
444 |
dwg.add(dwg.rect(
|
445 |
insert=(door_x, door_y),
|
446 |
size=(door_width, door_height),
|
447 |
+
fill='#8B4513',
|
448 |
stroke='black',
|
449 |
stroke_width=2
|
450 |
))
|
451 |
|
452 |
+
def add_nature_elements(self, dwg, width, height, features):
|
453 |
+
"""Add nature-like elements"""
|
454 |
+
# Tree
|
455 |
+
center_x, center_y = width // 2, height // 2
|
456 |
|
457 |
# Trunk
|
458 |
+
trunk_width = 15
|
459 |
+
trunk_height = height // 3
|
460 |
+
trunk_x = center_x - trunk_width // 2
|
461 |
+
trunk_y = height - trunk_height - 10
|
462 |
+
|
463 |
dwg.add(dwg.rect(
|
464 |
+
insert=(trunk_x, trunk_y),
|
465 |
size=(trunk_width, trunk_height),
|
466 |
+
fill='#8B4513',
|
467 |
stroke='black',
|
468 |
+
stroke_width=1
|
469 |
))
|
470 |
|
471 |
+
# Crown (multiple circles for foliage)
|
472 |
+
crown_radius = 30
|
473 |
+
for i, (dx, dy) in enumerate([(-15, -20), (15, -20), (0, -35), (-10, -50), (10, -50)]):
|
474 |
+
dwg.add(dwg.circle(
|
475 |
+
center=(center_x + dx, center_y + dy),
|
476 |
+
r=crown_radius - i * 3,
|
477 |
+
fill='#228B22',
|
478 |
+
stroke='#006400',
|
479 |
+
stroke_width=1,
|
480 |
+
opacity=0.8
|
481 |
+
))
|
482 |
|
483 |
+
def add_vehicle_elements(self, dwg, width, height, features):
|
484 |
+
"""Add vehicle-like elements"""
|
485 |
+
center_x, center_y = width // 2, height // 2
|
|
|
|
|
|
|
486 |
|
487 |
# Car body
|
488 |
+
car_width = width * 0.6
|
489 |
+
car_height = height * 0.3
|
490 |
+
car_x = (width - car_width) // 2
|
491 |
+
car_y = center_y + 10
|
492 |
+
|
493 |
dwg.add(dwg.rect(
|
494 |
insert=(car_x, car_y),
|
495 |
size=(car_width, car_height),
|
496 |
+
fill='#FF4500',
|
497 |
stroke='black',
|
498 |
stroke_width=2,
|
499 |
rx=5
|
500 |
))
|
501 |
|
502 |
+
# Windshield
|
503 |
+
windshield_width = car_width * 0.6
|
504 |
+
windshield_height = car_height * 0.4
|
505 |
+
windshield_x = car_x + (car_width - windshield_width) // 2
|
506 |
+
windshield_y = car_y - windshield_height + 5
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
507 |
|
508 |
+
dwg.add(dwg.rect(
|
509 |
+
insert=(windshield_x, windshield_y),
|
510 |
+
size=(windshield_width, windshield_height),
|
511 |
+
fill='#87CEEB',
|
|
|
512 |
stroke='black',
|
513 |
+
stroke_width=1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
514 |
))
|
515 |
|
516 |
+
# Wheels
|
517 |
+
wheel_radius = 12
|
518 |
+
wheel_y = car_y + car_height - 5
|
519 |
+
dwg.add(dwg.circle(center=(car_x + 25, wheel_y), r=wheel_radius, fill='black'))
|
520 |
+
dwg.add(dwg.circle(center=(car_x + car_width - 25, wheel_y), r=wheel_radius, fill='black'))
|
|
|
|
|
|
|
521 |
|
522 |
+
def add_abstract_elements(self, dwg, width, height, features):
|
523 |
+
"""Add abstract elements"""
|
524 |
+
colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4', '#FFEAA7']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
525 |
|
526 |
+
for i in range(5):
|
527 |
+
shape_type = random.choice(['circle', 'rect', 'path'])
|
528 |
+
color = random.choice(colors)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
529 |
|
530 |
+
if shape_type == 'circle':
|
531 |
+
radius = random.randint(10, 30)
|
532 |
+
x = random.randint(radius, width - radius)
|
533 |
+
y = random.randint(radius, height - radius)
|
534 |
+
dwg.add(dwg.circle(center=(x, y), r=radius, fill=color, opacity=0.7))
|
535 |
+
elif shape_type == 'rect':
|
536 |
+
w = random.randint(20, 60)
|
537 |
+
h = random.randint(20, 60)
|
538 |
+
x = random.randint(0, width - w)
|
539 |
+
y = random.randint(0, height - h)
|
540 |
+
dwg.add(dwg.rect(insert=(x, y), size=(w, h), fill=color, opacity=0.7))
|
541 |
+
else:
|
542 |
+
# Random path
|
543 |
+
start_x = random.randint(0, width)
|
544 |
+
start_y = random.randint(0, height)
|
545 |
+
end_x = random.randint(0, width)
|
546 |
+
end_y = random.randint(0, height)
|
547 |
+
dwg.add(dwg.line(start=(start_x, start_y), end=(end_x, end_y), stroke=color, stroke_width=3))
|
548 |
|
549 |
+
def add_colored_elements(self, dwg, width, height, color_words):
|
550 |
+
"""Add elements with specific colors"""
|
551 |
color_map = {
|
552 |
'red': '#FF0000',
|
553 |
'blue': '#0000FF',
|
554 |
'green': '#00FF00',
|
555 |
'yellow': '#FFFF00',
|
556 |
+
'purple': '#800080',
|
557 |
+
'orange': '#FFA500'
|
558 |
}
|
559 |
|
560 |
+
center_x, center_y = width // 2, height // 2
|
561 |
|
562 |
+
for word in color_words:
|
563 |
+
if word in color_map:
|
564 |
+
color = color_map[word]
|
565 |
+
# Add a colored shape
|
566 |
+
dwg.add(dwg.circle(
|
567 |
+
center=(center_x + random.randint(-50, 50), center_y + random.randint(-50, 50)),
|
568 |
+
r=random.randint(15, 35),
|
569 |
+
fill=color,
|
570 |
+
opacity=0.8
|
571 |
+
))
|
572 |
+
|
573 |
+
def add_large_elements(self, dwg, width, height, features):
|
574 |
+
"""Add large-sized elements"""
|
575 |
+
center_x, center_y = width // 2, height // 2
|
576 |
+
|
577 |
+
# Large central element
|
578 |
dwg.add(dwg.circle(
|
579 |
+
center=(center_x, center_y),
|
580 |
+
r=min(width, height) // 3,
|
581 |
+
fill='#4A90E2',
|
582 |
stroke='black',
|
583 |
+
stroke_width=3
|
584 |
))
|
585 |
|
586 |
+
def add_small_elements(self, dwg, width, height, features):
|
587 |
+
"""Add small-sized elements"""
|
588 |
+
# Multiple small elements
|
589 |
+
for i in range(8):
|
590 |
+
x = random.randint(10, width - 10)
|
591 |
+
y = random.randint(10, height - 10)
|
592 |
+
dwg.add(dwg.circle(
|
593 |
+
center=(x, y),
|
594 |
+
r=random.randint(3, 8),
|
595 |
+
fill='#E74C3C',
|
596 |
+
opacity=0.7
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
597 |
))
|
598 |
|
599 |
+
def add_detailed_elements(self, dwg, width, height, features):
|
600 |
+
"""Add detailed elements for refinement"""
|
601 |
+
# Add more complex shapes and details
|
602 |
+
self.add_abstract_elements(dwg, width, height, features)
|
603 |
+
|
604 |
+
# Add decorative elements
|
605 |
+
center_x, center_y = width // 2, height // 2
|
606 |
+
for i in range(4):
|
607 |
+
angle = i * math.pi / 2
|
608 |
+
x = center_x + 40 * math.cos(angle)
|
609 |
+
y = center_y + 40 * math.sin(angle)
|
610 |
+
dwg.add(dwg.circle(center=(x, y), r=8, fill='#9B59B6', opacity=0.6))
|
611 |
|
612 |
+
def add_emphasized_element(self, dwg, word: str, weight: float, width: int, height: int):
|
613 |
+
"""Add emphasized element based on attention weight"""
|
614 |
+
center_x, center_y = width // 2, height // 2
|
615 |
+
|
616 |
+
# Scale size based on weight
|
617 |
+
base_size = 20
|
618 |
+
size = int(base_size * weight)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
619 |
|
620 |
dwg.add(dwg.circle(
|
621 |
+
center=(center_x + random.randint(-30, 30), center_y + random.randint(-30, 30)),
|
622 |
+
r=size,
|
623 |
+
fill='#FF6B6B',
|
624 |
+
opacity=min(1.0, weight / 2),
|
625 |
stroke='black',
|
626 |
stroke_width=2
|
627 |
))
|
628 |
|
629 |
+
def add_deemphasized_element(self, dwg, word: str, weight: float, width: int, height: int):
|
630 |
+
"""Add de-emphasized element based on attention weight"""
|
631 |
+
center_x, center_y = width // 2, height // 2
|
632 |
|
633 |
+
# Scale size based on weight
|
634 |
+
base_size = 15
|
635 |
+
size = int(base_size * weight)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
636 |
|
637 |
+
dwg.add(dwg.circle(
|
638 |
+
center=(center_x + random.randint(-40, 40), center_y + random.randint(-40, 40)),
|
639 |
+
r=max(3, size),
|
640 |
+
fill='#CCCCCC',
|
641 |
+
opacity=weight,
|
642 |
+
stroke='gray',
|
643 |
+
stroke_width=1
|
644 |
+
))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
645 |
|
646 |
+
def add_content_based_on_prompt(self, dwg, prompt: str, width: int, height: int):
|
647 |
+
"""Add content based on prompt analysis"""
|
648 |
+
features = self.extract_semantic_features(prompt)
|
649 |
+
|
650 |
+
if any(word in prompt.lower() for word in ['person', 'people', 'human']):
|
651 |
+
self.add_person_elements(dwg, width, height, features)
|
652 |
+
elif any(word in prompt.lower() for word in ['animal', 'cat', 'dog']):
|
653 |
+
self.add_animal_elements(dwg, width, height, features)
|
654 |
+
elif any(word in prompt.lower() for word in ['house', 'building']):
|
655 |
+
self.add_building_elements(dwg, width, height, features)
|
656 |
+
elif any(word in prompt.lower() for word in ['tree', 'nature']):
|
657 |
+
self.add_nature_elements(dwg, width, height, features)
|
658 |
+
elif any(word in prompt.lower() for word in ['car', 'vehicle']):
|
659 |
+
self.add_vehicle_elements(dwg, width, height, features)
|
660 |
+
else:
|
661 |
+
self.add_abstract_elements(dwg, width, height, features)
|
662 |
|
663 |
+
def extract_semantic_features(self, prompt: str):
|
664 |
+
"""Extract semantic features from prompt"""
|
665 |
+
features = {
|
666 |
+
'detailed': False,
|
667 |
+
'simple': False,
|
668 |
+
'colorful': False,
|
669 |
+
'large': False,
|
670 |
+
'small': False
|
|
|
671 |
}
|
672 |
+
|
673 |
+
prompt_lower = prompt.lower()
|
674 |
+
|
675 |
+
if any(word in prompt_lower for word in ['detailed', 'complex', 'intricate']):
|
676 |
+
features['detailed'] = True
|
677 |
+
if any(word in prompt_lower for word in ['simple', 'minimal', 'basic']):
|
678 |
+
features['simple'] = True
|
679 |
+
if any(word in prompt_lower for word in ['colorful', 'bright', 'vibrant']):
|
680 |
+
features['colorful'] = True
|
681 |
+
if any(word in prompt_lower for word in ['large', 'big', 'huge']):
|
682 |
+
features['large'] = True
|
683 |
+
if any(word in prompt_lower for word in ['small', 'tiny', 'mini']):
|
684 |
+
features['small'] = True
|
685 |
+
|
686 |
+
return features
|
687 |
|
688 |
+
def svg_to_pil_image(self, svg_content: str, width: int, height: int):
|
689 |
"""Convert SVG content to PIL Image"""
|
690 |
try:
|
691 |
import cairosvg
|
|
|
692 |
|
693 |
# Convert SVG to PNG bytes
|
694 |
png_bytes = cairosvg.svg2png(
|
|
|
719 |
|
720 |
# Simple centered text
|
721 |
dwg.add(dwg.text(
|
722 |
+
f"DiffSketchEdit\n{prompt[:30]}...",
|
723 |
insert=(width/2, height/2),
|
724 |
text_anchor="middle",
|
725 |
+
font_size="12px",
|
726 |
fill="black"
|
727 |
))
|
728 |
|