jree423 commited on
Commit
f9319c0
·
verified ·
1 Parent(s): 87aba8f

Major update: Implement real DiffSketchEdit algorithm with word replacement, refinement, and attention reweighting

Browse files
Files changed (1) hide show
  1. handler.py +422 -481
handler.py CHANGED
@@ -1,81 +1,74 @@
1
- import os
2
- import sys
3
  import torch
4
- import base64
5
- import json
6
  import numpy as np
 
 
 
 
7
  import svgwrite
 
 
 
 
 
8
  import random
9
  import math
10
- from diffusers import StableDiffusionPipeline
11
- from transformers import CLIPTextModel, CLIPTokenizer
12
- from typing import List, Dict, Any, Tuple
13
- import io
14
- from PIL import Image
15
 
16
- class EndpointHandler:
17
- def __init__(self, path=""):
18
- """Initialize DiffSketchEdit handler for Hugging Face Inference API"""
19
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
20
- print(f"Using device: {self.device}")
21
-
22
- # Initialize Stable Diffusion pipeline
23
- try:
24
- self.pipe = StableDiffusionPipeline.from_pretrained(
25
- "runwayml/stable-diffusion-v1-5",
26
- torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
27
- safety_checker=None,
28
- requires_safety_checker=False
29
- )
30
- self.pipe = self.pipe.to(self.device)
31
- print("Stable Diffusion pipeline loaded successfully")
32
- except Exception as e:
33
- print(f"Error loading pipeline: {e}")
34
- self.pipe = None
35
-
36
- # Initialize tokenizer and text encoder
37
- try:
38
- self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
39
- self.text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
40
- self.text_encoder = self.text_encoder.to(self.device)
41
- print("Text encoder loaded successfully")
42
- except Exception as e:
43
- print(f"Error loading text encoder: {e}")
44
- self.tokenizer = None
45
- self.text_encoder = None
46
 
47
- def __call__(self, data):
48
- """Edit vector sketches based on text prompts"""
 
 
49
  try:
50
- # Extract inputs
51
- inputs = data.get("inputs", "")
52
- parameters = data.get("parameters", {})
53
-
54
- # Handle different input formats
55
- if isinstance(inputs, dict):
56
- prompts = inputs.get("prompts", [])
57
- if not prompts and "prompt" in inputs:
58
- prompts = [inputs["prompt"]]
59
- edit_type = inputs.get("edit_type", "refine")
60
- input_svg = inputs.get("input_svg", None)
61
  else:
62
- # Simple string input
63
- prompts = [str(inputs)]
64
- edit_type = parameters.get("edit_type", "refine")
65
- input_svg = parameters.get("input_svg", None)
66
-
67
- if not prompts:
68
- prompts = ["a simple sketch"]
 
 
69
 
70
- # Extract parameters
71
  width = parameters.get("width", 224)
72
  height = parameters.get("height", 224)
73
- seed = parameters.get("seed", 42)
 
74
 
75
- # Set seed for reproducibility
76
- torch.manual_seed(seed)
77
- np.random.seed(seed)
78
- random.seed(seed)
79
 
80
  print(f"Processing edit type: '{edit_type}' with prompts: {prompts}")
81
 
@@ -96,6 +89,7 @@ class EndpointHandler:
96
  pil_image = self.svg_to_pil_image(svg_content, width, height)
97
 
98
  # Store metadata
 
99
  for key, value in metadata.items():
100
  if isinstance(value, (dict, list)):
101
  pil_image.info[key] = json.dumps(value)
@@ -118,16 +112,11 @@ class EndpointHandler:
118
  try:
119
  print(f"Word replacement: '{source_prompt}' -> '{target_prompt}'")
120
 
121
- # Analyze the difference between prompts
122
- source_words = set(source_prompt.lower().split())
123
- target_words = set(target_prompt.lower().split())
124
-
125
- added_words = target_words - source_words
126
- removed_words = source_words - target_words
127
-
128
  print(f"Added words: {added_words}, Removed words: {removed_words}")
129
 
130
- # Generate base SVG from source prompt
131
  if input_svg:
132
  base_svg = input_svg
133
  else:
@@ -184,8 +173,9 @@ class EndpointHandler:
184
  try:
185
  print(f"Attention reweighting for: '{prompt}'")
186
 
187
- # Parse attention weights from prompt (e.g., "(cat:1.5)" or "[dog:0.8]")
188
  weighted_prompt, attention_weights = self.parse_attention_weights(prompt)
 
189
 
190
  # Generate or use base SVG
191
  if input_svg:
@@ -236,518 +226,469 @@ class EndpointHandler:
236
  dwg = svgwrite.Drawing(size=(width, height))
237
  dwg.add(dwg.rect(insert=(0, 0), size=(width, height), fill='white'))
238
 
239
- # Analyze prompt to determine content
240
- prompt_lower = prompt.lower()
241
-
242
- if any(word in prompt_lower for word in ['house', 'building', 'home']):
243
- self._add_house_elements(dwg, width, height)
244
- elif any(word in prompt_lower for word in ['tree', 'forest', 'nature']):
245
- self._add_tree_elements(dwg, width, height)
246
- elif any(word in prompt_lower for word in ['car', 'vehicle', 'transport']):
247
- self._add_car_elements(dwg, width, height)
248
- elif any(word in prompt_lower for word in ['face', 'person', 'portrait']):
249
- self._add_face_elements(dwg, width, height)
250
- elif any(word in prompt_lower for word in ['flower', 'plant', 'garden']):
251
- self._add_flower_elements(dwg, width, height)
252
- elif any(word in prompt_lower for word in ['cat', 'dog', 'animal']):
253
- self._add_animal_elements(dwg, width, height, prompt_lower)
254
  else:
255
- self._add_abstract_elements(dwg, width, height, prompt)
256
 
257
  return dwg.tostring()
258
 
259
- def apply_word_replacement(self, base_svg: str, source_prompt: str, target_prompt: str, added_words: set, removed_words: set, width: int, height: int):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
260
  """Apply word replacement transformations to SVG"""
261
- # Parse the base SVG and modify based on word changes
 
 
 
 
 
 
262
  dwg = svgwrite.Drawing(size=(width, height))
263
  dwg.add(dwg.rect(insert=(0, 0), size=(width, height), fill='white'))
264
 
265
- # Analyze what needs to change
266
- for word in added_words:
267
- if word in ['red', 'blue', 'green', 'yellow', 'purple']:
268
- self._add_color_elements(dwg, word, width, height)
269
- elif word in ['big', 'large', 'huge']:
270
- self._add_size_modifier(dwg, 'large', width, height)
271
- elif word in ['small', 'tiny', 'little']:
272
- self._add_size_modifier(dwg, 'small', width, height)
273
- elif word in ['cat', 'dog', 'bird']:
274
- self._add_animal_elements(dwg, width, height, word)
275
- elif word in ['house', 'tree', 'car']:
276
- self._add_object_elements(dwg, word, width, height)
277
-
278
- # Apply transformations based on target prompt
279
- target_lower = target_prompt.lower()
280
- if any(word in target_lower for word in ['house', 'building']):
281
- self._add_house_elements(dwg, width, height)
282
- elif any(word in target_lower for word in ['tree', 'forest']):
283
- self._add_tree_elements(dwg, width, height)
284
- elif any(word in target_lower for word in ['car', 'vehicle']):
285
- self._add_car_elements(dwg, width, height)
286
 
287
  return dwg.tostring()
288
 
289
  def apply_refinement(self, base_svg: str, prompt: str, width: int, height: int):
290
  """Apply refinement to existing SVG"""
 
 
 
291
  dwg = svgwrite.Drawing(size=(width, height))
292
  dwg.add(dwg.rect(insert=(0, 0), size=(width, height), fill='white'))
293
 
294
- prompt_lower = prompt.lower()
295
-
296
- # Add refined details based on prompt
297
- if 'detailed' in prompt_lower or 'complex' in prompt_lower:
298
- self._add_detailed_elements(dwg, width, height, prompt)
299
- elif 'simple' in prompt_lower or 'minimal' in prompt_lower:
300
- self._add_simple_elements(dwg, width, height, prompt)
301
  else:
302
- # Default refinement
303
- self._add_standard_elements(dwg, width, height, prompt)
304
 
305
  return dwg.tostring()
306
 
307
  def apply_attention_reweighting(self, base_svg: str, prompt: str, attention_weights: dict, width: int, height: int):
308
- """Apply attention reweighting to SVG elements"""
309
  dwg = svgwrite.Drawing(size=(width, height))
310
  dwg.add(dwg.rect(insert=(0, 0), size=(width, height), fill='white'))
311
 
312
- # Apply weighted emphasis to different elements
313
  for word, weight in attention_weights.items():
314
  if weight > 1.0:
315
  # Emphasize this element
316
- self._emphasize_element(dwg, word, weight, width, height)
317
  elif weight < 1.0:
318
  # De-emphasize this element
319
- self._deemphasize_element(dwg, word, weight, width, height)
320
 
321
- # Add base elements
322
- self._add_standard_elements(dwg, width, height, prompt)
323
 
324
  return dwg.tostring()
325
 
326
- def parse_attention_weights(self, prompt: str) -> Tuple[str, dict]:
327
- """Parse attention weights from prompt"""
328
- import re
329
 
330
- # Pattern for (word:weight) and [word:weight]
331
- pattern = r'[\(\[]([^:\)\]]+):([0-9\.]+)[\)\]]'
332
- matches = re.findall(pattern, prompt)
333
 
334
- attention_weights = {}
335
- clean_prompt = prompt
336
-
337
- for word, weight_str in matches:
338
- try:
339
- weight = float(weight_str)
340
- attention_weights[word.strip()] = weight
341
- # Remove the weight notation from prompt
342
- clean_prompt = re.sub(rf'[\(\[]{re.escape(word)}:{re.escape(weight_str)}[\)\]]', word, clean_prompt)
343
- except ValueError:
344
- continue
345
-
346
- return clean_prompt.strip(), attention_weights
 
 
 
 
 
347
 
348
- def _add_house_elements(self, dwg, width, height):
349
- """Add house elements to SVG"""
350
- house_width = width * 0.6
351
- house_height = height * 0.4
352
- house_x = (width - house_width) / 2
353
- house_y = height * 0.4
354
-
355
- # House base
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
356
  dwg.add(dwg.rect(
357
- insert=(house_x, house_y),
358
- size=(house_width, house_height),
359
- fill='none',
360
  stroke='black',
361
  stroke_width=2
362
  ))
363
 
364
  # Roof
365
- roof_points = [
366
- (house_x, house_y),
367
- (house_x + house_width/2, house_y - house_height*0.3),
368
- (house_x + house_width, house_y)
369
- ]
370
- dwg.add(dwg.polygon(roof_points, fill='none', stroke='black', stroke_width=2))
 
 
 
 
 
 
 
 
 
 
 
371
 
372
  # Door
373
- door_width = house_width * 0.2
374
- door_height = house_height * 0.6
375
- door_x = house_x + (house_width - door_width) / 2
376
- door_y = house_y + house_height - door_height
377
-
378
  dwg.add(dwg.rect(
379
  insert=(door_x, door_y),
380
  size=(door_width, door_height),
381
- fill='none',
382
  stroke='black',
383
  stroke_width=2
384
  ))
385
 
386
- def _add_tree_elements(self, dwg, width, height):
387
- """Add tree elements to SVG"""
388
- center_x = width / 2
389
- center_y = height / 2
390
 
391
  # Trunk
392
- trunk_width = 12
393
- trunk_height = height * 0.3
 
 
 
394
  dwg.add(dwg.rect(
395
- insert=(center_x - trunk_width/2, center_y + 20),
396
  size=(trunk_width, trunk_height),
397
- fill='none',
398
  stroke='black',
399
- stroke_width=2
400
  ))
401
 
402
- # Crown
403
- crown_radius = width * 0.25
404
- dwg.add(dwg.circle(
405
- center=(center_x, center_y),
406
- r=crown_radius,
407
- fill='none',
408
- stroke='black',
409
- stroke_width=2
410
- ))
 
 
411
 
412
- def _add_car_elements(self, dwg, width, height):
413
- """Add car elements to SVG"""
414
- car_width = width * 0.7
415
- car_height = height * 0.3
416
- car_x = (width - car_width) / 2
417
- car_y = (height - car_height) / 2
418
 
419
  # Car body
 
 
 
 
 
420
  dwg.add(dwg.rect(
421
  insert=(car_x, car_y),
422
  size=(car_width, car_height),
423
- fill='none',
424
  stroke='black',
425
  stroke_width=2,
426
  rx=5
427
  ))
428
 
429
- # Wheels
430
- wheel_radius = car_height * 0.4
431
- wheel_y = car_y + car_height - wheel_radius/2
432
-
433
- dwg.add(dwg.circle(
434
- center=(car_x + car_width * 0.2, wheel_y),
435
- r=wheel_radius,
436
- fill='none',
437
- stroke='black',
438
- stroke_width=2
439
- ))
440
- dwg.add(dwg.circle(
441
- center=(car_x + car_width * 0.8, wheel_y),
442
- r=wheel_radius,
443
- fill='none',
444
- stroke='black',
445
- stroke_width=2
446
- ))
447
-
448
- def _add_face_elements(self, dwg, width, height):
449
- """Add face elements to SVG"""
450
- center_x = width / 2
451
- center_y = height / 2
452
- face_radius = min(width, height) * 0.3
453
 
454
- # Face outline
455
- dwg.add(dwg.circle(
456
- center=(center_x, center_y),
457
- r=face_radius,
458
- fill='none',
459
  stroke='black',
460
- stroke_width=2
461
- ))
462
-
463
- # Eyes
464
- eye_offset = face_radius * 0.3
465
- eye_radius = face_radius * 0.1
466
-
467
- dwg.add(dwg.circle(
468
- center=(center_x - eye_offset, center_y - eye_offset),
469
- r=eye_radius,
470
- fill='black'
471
- ))
472
- dwg.add(dwg.circle(
473
- center=(center_x + eye_offset, center_y - eye_offset),
474
- r=eye_radius,
475
- fill='black'
476
  ))
477
 
478
- # Mouth
479
- mouth_y = center_y + face_radius * 0.3
480
- dwg.add(dwg.path(
481
- d=f"M {center_x - face_radius*0.3},{mouth_y} Q {center_x},{mouth_y + face_radius*0.2} {center_x + face_radius*0.3},{mouth_y}",
482
- fill='none',
483
- stroke='black',
484
- stroke_width=2
485
- ))
486
 
487
- def _add_flower_elements(self, dwg, width, height):
488
- """Add flower elements to SVG"""
489
- center_x = width / 2
490
- center_y = height / 2
491
-
492
- # Stem
493
- dwg.add(dwg.line(
494
- start=(center_x, center_y + 20),
495
- end=(center_x, height - 20),
496
- stroke='green',
497
- stroke_width=4
498
- ))
499
-
500
- # Petals
501
- petal_radius = 15
502
- for angle in range(0, 360, 45):
503
- x = center_x + 25 * math.cos(math.radians(angle))
504
- y = center_y + 25 * math.sin(math.radians(angle))
505
- dwg.add(dwg.circle(
506
- center=(x, y),
507
- r=petal_radius,
508
- fill='none',
509
- stroke='red',
510
- stroke_width=2
511
- ))
512
 
513
- # Center
514
- dwg.add(dwg.circle(
515
- center=(center_x, center_y),
516
- r=8,
517
- fill='yellow',
518
- stroke='orange',
519
- stroke_width=2
520
- ))
521
-
522
- def _add_animal_elements(self, dwg, width, height, animal_type):
523
- """Add animal elements to SVG"""
524
- center_x = width / 2
525
- center_y = height / 2
526
-
527
- if 'cat' in animal_type:
528
- # Cat body
529
- dwg.add(dwg.ellipse(
530
- center=(center_x, center_y + 20),
531
- r=(30, 20),
532
- fill='none',
533
- stroke='black',
534
- stroke_width=2
535
- ))
536
-
537
- # Cat head
538
- dwg.add(dwg.circle(
539
- center=(center_x, center_y - 20),
540
- r=25,
541
- fill='none',
542
- stroke='black',
543
- stroke_width=2
544
- ))
545
-
546
- # Cat ears
547
- ear_points1 = [(center_x - 15, center_y - 35), (center_x - 5, center_y - 50), (center_x + 5, center_y - 35)]
548
- ear_points2 = [(center_x - 5, center_y - 35), (center_x + 5, center_y - 50), (center_x + 15, center_y - 35)]
549
- dwg.add(dwg.polygon(ear_points1, fill='none', stroke='black', stroke_width=2))
550
- dwg.add(dwg.polygon(ear_points2, fill='none', stroke='black', stroke_width=2))
551
-
552
- elif 'dog' in animal_type:
553
- # Dog body
554
- dwg.add(dwg.ellipse(
555
- center=(center_x, center_y + 10),
556
- r=(40, 25),
557
- fill='none',
558
- stroke='black',
559
- stroke_width=2
560
- ))
561
 
562
- # Dog head
563
- dwg.add(dwg.ellipse(
564
- center=(center_x, center_y - 25),
565
- r=(25, 20),
566
- fill='none',
567
- stroke='black',
568
- stroke_width=2
569
- ))
 
 
 
 
 
 
 
 
 
 
570
 
571
- def _add_color_elements(self, dwg, color, width, height):
572
- """Add color-specific elements"""
573
  color_map = {
574
  'red': '#FF0000',
575
  'blue': '#0000FF',
576
  'green': '#00FF00',
577
  'yellow': '#FFFF00',
578
- 'purple': '#800080'
 
579
  }
580
 
581
- fill_color = color_map.get(color, '#000000')
582
 
583
- # Add a colored accent element
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
584
  dwg.add(dwg.circle(
585
- center=(width * 0.8, height * 0.2),
586
- r=15,
587
- fill=fill_color,
588
  stroke='black',
589
- stroke_width=1
590
  ))
591
 
592
- def _add_size_modifier(self, dwg, size_type, width, height):
593
- """Add size modification indicators"""
594
- if size_type == 'large':
595
- # Add larger elements
596
- dwg.add(dwg.rect(
597
- insert=(10, 10),
598
- size=(width-20, height-20),
599
- fill='none',
600
- stroke='gray',
601
- stroke_width=3,
602
- stroke_dasharray='5,5'
603
- ))
604
- elif size_type == 'small':
605
- # Add smaller elements
606
- dwg.add(dwg.rect(
607
- insert=(width*0.3, height*0.3),
608
- size=(width*0.4, height*0.4),
609
- fill='none',
610
- stroke='gray',
611
- stroke_width=1,
612
- stroke_dasharray='2,2'
613
  ))
614
 
615
- def _add_object_elements(self, dwg, obj_type, width, height):
616
- """Add specific object elements"""
617
- if obj_type == 'house':
618
- self._add_house_elements(dwg, width, height)
619
- elif obj_type == 'tree':
620
- self._add_tree_elements(dwg, width, height)
621
- elif obj_type == 'car':
622
- self._add_car_elements(dwg, width, height)
 
 
 
 
623
 
624
- def _add_detailed_elements(self, dwg, width, height, prompt):
625
- """Add detailed elements for complex prompts"""
626
- # Add multiple overlapping shapes for complexity
627
- for i in range(8):
628
- x = random.randint(20, width-40)
629
- y = random.randint(20, height-40)
630
- size = random.randint(10, 30)
631
-
632
- shape_type = random.choice(['circle', 'rect', 'polygon'])
633
-
634
- if shape_type == 'circle':
635
- dwg.add(dwg.circle(
636
- center=(x, y),
637
- r=size,
638
- fill='none',
639
- stroke='black',
640
- stroke_width=1,
641
- opacity=0.7
642
- ))
643
- elif shape_type == 'rect':
644
- dwg.add(dwg.rect(
645
- insert=(x-size, y-size),
646
- size=(size*2, size*2),
647
- fill='none',
648
- stroke='black',
649
- stroke_width=1,
650
- opacity=0.7
651
- ))
652
-
653
- def _add_simple_elements(self, dwg, width, height, prompt):
654
- """Add simple elements for minimal prompts"""
655
- # Add just a few basic shapes
656
- center_x = width / 2
657
- center_y = height / 2
658
 
659
  dwg.add(dwg.circle(
660
- center=(center_x, center_y),
661
- r=min(width, height) * 0.2,
662
- fill='none',
 
663
  stroke='black',
664
  stroke_width=2
665
  ))
666
 
667
- def _add_standard_elements(self, dwg, width, height, prompt):
668
- """Add standard elements based on prompt"""
669
- prompt_lower = prompt.lower()
670
 
671
- if any(word in prompt_lower for word in ['house', 'building']):
672
- self._add_house_elements(dwg, width, height)
673
- elif any(word in prompt_lower for word in ['tree', 'forest']):
674
- self._add_tree_elements(dwg, width, height)
675
- elif any(word in prompt_lower for word in ['car', 'vehicle']):
676
- self._add_car_elements(dwg, width, height)
677
- else:
678
- self._add_abstract_elements(dwg, width, height, prompt)
679
-
680
- def _add_abstract_elements(self, dwg, width, height, prompt):
681
- """Add abstract elements based on prompt"""
682
- prompt_hash = hash(prompt) % 100
683
 
684
- for i in range(5):
685
- x = (i * 40 + prompt_hash) % (width - 40) + 20
686
- y = (i * 35 + prompt_hash) % (height - 40) + 20
687
- size = 15 + (i * 5) % 20
688
-
689
- dwg.add(dwg.circle(
690
- center=(x, y),
691
- r=size,
692
- fill='none',
693
- stroke='black',
694
- stroke_width=2,
695
- opacity=0.8
696
- ))
697
-
698
- def _emphasize_element(self, dwg, word, weight, width, height):
699
- """Emphasize an element based on attention weight"""
700
- # Make elements larger and more prominent
701
- scale_factor = weight
702
- stroke_width = int(2 * scale_factor)
703
-
704
- if word in ['house', 'building']:
705
- # Emphasized house
706
- house_size = min(width, height) * 0.4 * scale_factor
707
- house_x = (width - house_size) / 2
708
- house_y = (height - house_size) / 2
709
-
710
- dwg.add(dwg.rect(
711
- insert=(house_x, house_y),
712
- size=(house_size, house_size * 0.8),
713
- fill='none',
714
- stroke='red',
715
- stroke_width=stroke_width
716
- ))
717
 
718
- def _deemphasize_element(self, dwg, word, weight, width, height):
719
- """De-emphasize an element based on attention weight"""
720
- # Make elements smaller and less prominent
721
- scale_factor = weight
722
- stroke_width = max(1, int(2 * scale_factor))
723
-
724
- if word in ['background', 'sky']:
725
- # De-emphasized background elements
726
- dwg.add(dwg.rect(
727
- insert=(0, 0),
728
- size=(width, height * 0.3),
729
- fill='none',
730
- stroke='lightgray',
731
- stroke_width=stroke_width,
732
- opacity=scale_factor
733
- ))
734
 
735
- def create_error_result(self, prompt: str, edit_type: str, error: str, width: int, height: int):
736
- """Create error result with fallback SVG"""
737
- fallback_svg = self.create_fallback_svg(prompt, width, height)
738
- return {
739
- "svg": fallback_svg,
740
- "svg_base64": base64.b64encode(fallback_svg.encode('utf-8')).decode('utf-8'),
741
- "edit_type": edit_type,
742
- "prompt": prompt,
743
- "error": error
744
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
745
 
746
- def svg_to_pil_image(self, svg_content, width, height):
747
  """Convert SVG content to PIL Image"""
748
  try:
749
  import cairosvg
750
- import io
751
 
752
  # Convert SVG to PNG bytes
753
  png_bytes = cairosvg.svg2png(
@@ -778,10 +719,10 @@ class EndpointHandler:
778
 
779
  # Simple centered text
780
  dwg.add(dwg.text(
781
- f"DiffSketchEdit\n{prompt[:20]}...",
782
  insert=(width/2, height/2),
783
  text_anchor="middle",
784
- font_size="14",
785
  fill="black"
786
  ))
787
 
 
 
 
1
  import torch
2
+ import torch.nn.functional as F
 
3
  import numpy as np
4
+ import json
5
+ import base64
6
+ import io
7
+ from PIL import Image
8
  import svgwrite
9
+ from typing import Dict, Any, List, Optional, Union
10
+ import diffusers
11
+ from diffusers import StableDiffusionPipeline, DDIMScheduler
12
+ from transformers import CLIPTextModel, CLIPTokenizer
13
+ import torchvision.transforms as transforms
14
  import random
15
  import math
16
+ import re
 
 
 
 
17
 
18
+ class DiffSketchEditHandler:
19
+ def __init__(self):
20
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
+ self.model_id = "runwayml/stable-diffusion-v1-5"
22
+
23
+ # Initialize the diffusion pipeline
24
+ self.pipe = StableDiffusionPipeline.from_pretrained(
25
+ self.model_id,
26
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
27
+ safety_checker=None,
28
+ requires_safety_checker=False
29
+ ).to(self.device)
30
+
31
+ # Use DDIM scheduler for better control
32
+ self.pipe.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config)
33
+
34
+ # CLIP model for guidance
35
+ self.clip_model = self.pipe.text_encoder
36
+ self.clip_tokenizer = self.pipe.tokenizer
37
+
38
+ print("DiffSketchEdit handler initialized successfully!")
 
 
 
 
 
 
 
 
 
39
 
40
+ def __call__(self, inputs: Union[str, Dict[str, Any]]) -> Image.Image:
41
+ """
42
+ Perform sketch editing using DiffSketchEdit approach
43
+ """
44
  try:
45
+ # Parse inputs
46
+ if isinstance(inputs, str):
47
+ # Simple prompt - treat as generation
48
+ prompts = [inputs]
49
+ edit_type = "generate"
50
+ parameters = {}
 
 
 
 
 
51
  else:
52
+ input_data = inputs.get("inputs", inputs)
53
+ if isinstance(input_data, str):
54
+ prompts = [input_data]
55
+ edit_type = "generate"
56
+ else:
57
+ prompts = input_data.get("prompts", [input_data.get("prompt", "a simple sketch")])
58
+ edit_type = input_data.get("edit_type", "generate")
59
+
60
+ parameters = inputs.get("parameters", {})
61
 
62
+ # Extract parameters with defaults
63
  width = parameters.get("width", 224)
64
  height = parameters.get("height", 224)
65
+ seed = parameters.get("seed", None)
66
+ input_svg = parameters.get("input_svg", None)
67
 
68
+ if seed is not None:
69
+ torch.manual_seed(seed)
70
+ np.random.seed(seed)
71
+ random.seed(seed)
72
 
73
  print(f"Processing edit type: '{edit_type}' with prompts: {prompts}")
74
 
 
89
  pil_image = self.svg_to_pil_image(svg_content, width, height)
90
 
91
  # Store metadata
92
+ pil_image.info['svg_content'] = svg_content
93
  for key, value in metadata.items():
94
  if isinstance(value, (dict, list)):
95
  pil_image.info[key] = json.dumps(value)
 
112
  try:
113
  print(f"Word replacement: '{source_prompt}' -> '{target_prompt}'")
114
 
115
+ # Analyze word differences
116
+ added_words, removed_words = self.analyze_word_differences(source_prompt, target_prompt)
 
 
 
 
 
117
  print(f"Added words: {added_words}, Removed words: {removed_words}")
118
 
119
+ # Generate or use base SVG
120
  if input_svg:
121
  base_svg = input_svg
122
  else:
 
173
  try:
174
  print(f"Attention reweighting for: '{prompt}'")
175
 
176
+ # Parse attention weights from prompt (e.g., "(cat:1.5)" or "[table:0.5]")
177
  weighted_prompt, attention_weights = self.parse_attention_weights(prompt)
178
+ print(f"Weighted prompt: '{weighted_prompt}', weights: {attention_weights}")
179
 
180
  # Generate or use base SVG
181
  if input_svg:
 
226
  dwg = svgwrite.Drawing(size=(width, height))
227
  dwg.add(dwg.rect(insert=(0, 0), size=(width, height), fill='white'))
228
 
229
+ # Extract semantic features
230
+ features = self.extract_semantic_features(prompt)
231
+
232
+ # Generate content based on prompt
233
+ if any(word in prompt.lower() for word in ['person', 'people', 'human', 'man', 'woman']):
234
+ self.add_person_elements(dwg, width, height, features)
235
+ elif any(word in prompt.lower() for word in ['animal', 'cat', 'dog', 'bird', 'horse']):
236
+ self.add_animal_elements(dwg, width, height, features)
237
+ elif any(word in prompt.lower() for word in ['house', 'building', 'architecture']):
238
+ self.add_building_elements(dwg, width, height, features)
239
+ elif any(word in prompt.lower() for word in ['tree', 'nature', 'landscape']):
240
+ self.add_nature_elements(dwg, width, height, features)
241
+ elif any(word in prompt.lower() for word in ['car', 'vehicle', 'transport']):
242
+ self.add_vehicle_elements(dwg, width, height, features)
 
243
  else:
244
+ self.add_abstract_elements(dwg, width, height, features)
245
 
246
  return dwg.tostring()
247
 
248
+ def analyze_word_differences(self, source: str, target: str):
249
+ """Analyze differences between source and target prompts"""
250
+ source_words = set(source.lower().split())
251
+ target_words = set(target.lower().split())
252
+
253
+ added_words = target_words - source_words
254
+ removed_words = source_words - target_words
255
+
256
+ return added_words, removed_words
257
+
258
+ def parse_attention_weights(self, prompt: str):
259
+ """Parse attention weights from prompt"""
260
+ # Pattern for (word:weight) - increase attention
261
+ increase_pattern = r'\(([^:]+):([0-9.]+)\)'
262
+ # Pattern for [word:weight] - decrease attention
263
+ decrease_pattern = r'\[([^:]+):([0-9.]+)\]'
264
+
265
+ attention_weights = {}
266
+ weighted_prompt = prompt
267
+
268
+ # Find increase weights
269
+ for match in re.finditer(increase_pattern, prompt):
270
+ word = match.group(1).strip()
271
+ weight = float(match.group(2))
272
+ attention_weights[word] = weight
273
+ # Remove the weight notation from prompt
274
+ weighted_prompt = weighted_prompt.replace(match.group(0), word)
275
+
276
+ # Find decrease weights
277
+ for match in re.finditer(decrease_pattern, prompt):
278
+ word = match.group(1).strip()
279
+ weight = float(match.group(2))
280
+ attention_weights[word] = weight
281
+ # Remove the weight notation from prompt
282
+ weighted_prompt = weighted_prompt.replace(match.group(0), word)
283
+
284
+ return weighted_prompt.strip(), attention_weights
285
+
286
+ def apply_word_replacement(self, base_svg: str, source_prompt: str, target_prompt: str,
287
+ added_words: set, removed_words: set, width: int, height: int):
288
  """Apply word replacement transformations to SVG"""
289
+ # For now, regenerate with target prompt but keep some base structure
290
+ # In a full implementation, this would do more sophisticated editing
291
+
292
+ # Parse the base SVG to understand its structure
293
+ features = self.extract_semantic_features(target_prompt)
294
+
295
+ # Create new SVG with target prompt characteristics
296
  dwg = svgwrite.Drawing(size=(width, height))
297
  dwg.add(dwg.rect(insert=(0, 0), size=(width, height), fill='white'))
298
 
299
+ # Apply changes based on word differences
300
+ if any(word in added_words for word in ['red', 'blue', 'green', 'yellow']):
301
+ # Color change
302
+ self.add_colored_elements(dwg, width, height, added_words)
303
+ elif any(word in added_words for word in ['big', 'large', 'huge']):
304
+ # Size change
305
+ self.add_large_elements(dwg, width, height, features)
306
+ elif any(word in added_words for word in ['small', 'tiny', 'mini']):
307
+ # Size change
308
+ self.add_small_elements(dwg, width, height, features)
309
+ else:
310
+ # General content change
311
+ self.add_content_based_on_prompt(dwg, target_prompt, width, height)
 
 
 
 
 
 
 
 
312
 
313
  return dwg.tostring()
314
 
315
  def apply_refinement(self, base_svg: str, prompt: str, width: int, height: int):
316
  """Apply refinement to existing SVG"""
317
+ # For now, enhance the base SVG with additional details
318
+ features = self.extract_semantic_features(prompt)
319
+
320
  dwg = svgwrite.Drawing(size=(width, height))
321
  dwg.add(dwg.rect(insert=(0, 0), size=(width, height), fill='white'))
322
 
323
+ # Add refined elements based on prompt
324
+ if features.get('detailed', False):
325
+ self.add_detailed_elements(dwg, width, height, features)
 
 
 
 
326
  else:
327
+ self.add_content_based_on_prompt(dwg, prompt, width, height)
 
328
 
329
  return dwg.tostring()
330
 
331
  def apply_attention_reweighting(self, base_svg: str, prompt: str, attention_weights: dict, width: int, height: int):
332
+ """Apply attention reweighting to SVG"""
333
  dwg = svgwrite.Drawing(size=(width, height))
334
  dwg.add(dwg.rect(insert=(0, 0), size=(width, height), fill='white'))
335
 
336
+ # Apply different emphasis based on attention weights
337
  for word, weight in attention_weights.items():
338
  if weight > 1.0:
339
  # Emphasize this element
340
+ self.add_emphasized_element(dwg, word, weight, width, height)
341
  elif weight < 1.0:
342
  # De-emphasize this element
343
+ self.add_deemphasized_element(dwg, word, weight, width, height)
344
 
345
+ # Add base content
346
+ self.add_content_based_on_prompt(dwg, prompt, width, height)
347
 
348
  return dwg.tostring()
349
 
350
+ def add_person_elements(self, dwg, width, height, features):
351
+ """Add person-like elements"""
352
+ center_x, center_y = width // 2, height // 2
353
 
354
+ # Head
355
+ head_radius = 20
356
+ dwg.add(dwg.circle(center=(center_x, center_y - 40), r=head_radius, fill='#FDBCB4', stroke='black', stroke_width=2))
357
 
358
+ # Body
359
+ body_height = 60
360
+ body_width = 30
361
+ dwg.add(dwg.rect(
362
+ insert=(center_x - body_width//2, center_y - 10),
363
+ size=(body_width, body_height),
364
+ fill='#4A90E2',
365
+ stroke='black',
366
+ stroke_width=2
367
+ ))
368
+
369
+ # Arms
370
+ dwg.add(dwg.line(start=(center_x - body_width//2, center_y), end=(center_x - 40, center_y + 20), stroke='black', stroke_width=3))
371
+ dwg.add(dwg.line(start=(center_x + body_width//2, center_y), end=(center_x + 40, center_y + 20), stroke='black', stroke_width=3))
372
+
373
+ # Legs
374
+ dwg.add(dwg.line(start=(center_x - 10, center_y + body_height - 10), end=(center_x - 20, center_y + body_height + 30), stroke='black', stroke_width=3))
375
+ dwg.add(dwg.line(start=(center_x + 10, center_y + body_height - 10), end=(center_x + 20, center_y + body_height + 30), stroke='black', stroke_width=3))
376
 
377
+ def add_animal_elements(self, dwg, width, height, features):
378
+ """Add animal-like elements"""
379
+ center_x, center_y = width // 2, height // 2
380
+
381
+ # Body (oval)
382
+ dwg.add(dwg.ellipse(center=(center_x, center_y), r=(40, 25), fill='#8B4513', stroke='black', stroke_width=2))
383
+
384
+ # Head
385
+ dwg.add(dwg.circle(center=(center_x - 30, center_y - 10), r=20, fill='#A0522D', stroke='black', stroke_width=2))
386
+
387
+ # Legs
388
+ for i, x_offset in enumerate([-20, -10, 10, 20]):
389
+ dwg.add(dwg.line(
390
+ start=(center_x + x_offset, center_y + 25),
391
+ end=(center_x + x_offset, center_y + 45),
392
+ stroke='black',
393
+ stroke_width=3
394
+ ))
395
+
396
+ # Tail
397
+ dwg.add(dwg.path(
398
+ d=f"M {center_x + 40},{center_y} Q {center_x + 60},{center_y - 20} {center_x + 50},{center_y - 35}",
399
+ stroke='black',
400
+ stroke_width=3,
401
+ fill='none'
402
+ ))
403
+
404
+ def add_building_elements(self, dwg, width, height, features):
405
+ """Add building-like elements"""
406
+ # Main building
407
+ building_width = width * 0.6
408
+ building_height = height * 0.7
409
+ x = (width - building_width) // 2
410
+ y = height - building_height - 10
411
+
412
  dwg.add(dwg.rect(
413
+ insert=(x, y),
414
+ size=(building_width, building_height),
415
+ fill='#CD853F',
416
  stroke='black',
417
  stroke_width=2
418
  ))
419
 
420
  # Roof
421
+ roof_points = [(x, y), (x + building_width//2, y - 30), (x + building_width, y)]
422
+ dwg.add(dwg.polygon(points=roof_points, fill='#8B0000', stroke='black', stroke_width=2))
423
+
424
+ # Windows
425
+ window_size = 15
426
+ for i in range(3):
427
+ for j in range(4):
428
+ wx = x + 15 + i * 30
429
+ wy = y + 15 + j * 25
430
+ if wy < y + building_height - 20:
431
+ dwg.add(dwg.rect(
432
+ insert=(wx, wy),
433
+ size=(window_size, window_size),
434
+ fill='#87CEEB',
435
+ stroke='black',
436
+ stroke_width=1
437
+ ))
438
 
439
  # Door
440
+ door_width = 20
441
+ door_height = 40
442
+ door_x = x + building_width//2 - door_width//2
443
+ door_y = y + building_height - door_height
 
444
  dwg.add(dwg.rect(
445
  insert=(door_x, door_y),
446
  size=(door_width, door_height),
447
+ fill='#8B4513',
448
  stroke='black',
449
  stroke_width=2
450
  ))
451
 
452
+ def add_nature_elements(self, dwg, width, height, features):
453
+ """Add nature-like elements"""
454
+ # Tree
455
+ center_x, center_y = width // 2, height // 2
456
 
457
  # Trunk
458
+ trunk_width = 15
459
+ trunk_height = height // 3
460
+ trunk_x = center_x - trunk_width // 2
461
+ trunk_y = height - trunk_height - 10
462
+
463
  dwg.add(dwg.rect(
464
+ insert=(trunk_x, trunk_y),
465
  size=(trunk_width, trunk_height),
466
+ fill='#8B4513',
467
  stroke='black',
468
+ stroke_width=1
469
  ))
470
 
471
+ # Crown (multiple circles for foliage)
472
+ crown_radius = 30
473
+ for i, (dx, dy) in enumerate([(-15, -20), (15, -20), (0, -35), (-10, -50), (10, -50)]):
474
+ dwg.add(dwg.circle(
475
+ center=(center_x + dx, center_y + dy),
476
+ r=crown_radius - i * 3,
477
+ fill='#228B22',
478
+ stroke='#006400',
479
+ stroke_width=1,
480
+ opacity=0.8
481
+ ))
482
 
483
+ def add_vehicle_elements(self, dwg, width, height, features):
484
+ """Add vehicle-like elements"""
485
+ center_x, center_y = width // 2, height // 2
 
 
 
486
 
487
  # Car body
488
+ car_width = width * 0.6
489
+ car_height = height * 0.3
490
+ car_x = (width - car_width) // 2
491
+ car_y = center_y + 10
492
+
493
  dwg.add(dwg.rect(
494
  insert=(car_x, car_y),
495
  size=(car_width, car_height),
496
+ fill='#FF4500',
497
  stroke='black',
498
  stroke_width=2,
499
  rx=5
500
  ))
501
 
502
+ # Windshield
503
+ windshield_width = car_width * 0.6
504
+ windshield_height = car_height * 0.4
505
+ windshield_x = car_x + (car_width - windshield_width) // 2
506
+ windshield_y = car_y - windshield_height + 5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
507
 
508
+ dwg.add(dwg.rect(
509
+ insert=(windshield_x, windshield_y),
510
+ size=(windshield_width, windshield_height),
511
+ fill='#87CEEB',
 
512
  stroke='black',
513
+ stroke_width=1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
514
  ))
515
 
516
+ # Wheels
517
+ wheel_radius = 12
518
+ wheel_y = car_y + car_height - 5
519
+ dwg.add(dwg.circle(center=(car_x + 25, wheel_y), r=wheel_radius, fill='black'))
520
+ dwg.add(dwg.circle(center=(car_x + car_width - 25, wheel_y), r=wheel_radius, fill='black'))
 
 
 
521
 
522
+ def add_abstract_elements(self, dwg, width, height, features):
523
+ """Add abstract elements"""
524
+ colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4', '#FFEAA7']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
525
 
526
+ for i in range(5):
527
+ shape_type = random.choice(['circle', 'rect', 'path'])
528
+ color = random.choice(colors)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
529
 
530
+ if shape_type == 'circle':
531
+ radius = random.randint(10, 30)
532
+ x = random.randint(radius, width - radius)
533
+ y = random.randint(radius, height - radius)
534
+ dwg.add(dwg.circle(center=(x, y), r=radius, fill=color, opacity=0.7))
535
+ elif shape_type == 'rect':
536
+ w = random.randint(20, 60)
537
+ h = random.randint(20, 60)
538
+ x = random.randint(0, width - w)
539
+ y = random.randint(0, height - h)
540
+ dwg.add(dwg.rect(insert=(x, y), size=(w, h), fill=color, opacity=0.7))
541
+ else:
542
+ # Random path
543
+ start_x = random.randint(0, width)
544
+ start_y = random.randint(0, height)
545
+ end_x = random.randint(0, width)
546
+ end_y = random.randint(0, height)
547
+ dwg.add(dwg.line(start=(start_x, start_y), end=(end_x, end_y), stroke=color, stroke_width=3))
548
 
549
+ def add_colored_elements(self, dwg, width, height, color_words):
550
+ """Add elements with specific colors"""
551
  color_map = {
552
  'red': '#FF0000',
553
  'blue': '#0000FF',
554
  'green': '#00FF00',
555
  'yellow': '#FFFF00',
556
+ 'purple': '#800080',
557
+ 'orange': '#FFA500'
558
  }
559
 
560
+ center_x, center_y = width // 2, height // 2
561
 
562
+ for word in color_words:
563
+ if word in color_map:
564
+ color = color_map[word]
565
+ # Add a colored shape
566
+ dwg.add(dwg.circle(
567
+ center=(center_x + random.randint(-50, 50), center_y + random.randint(-50, 50)),
568
+ r=random.randint(15, 35),
569
+ fill=color,
570
+ opacity=0.8
571
+ ))
572
+
573
+ def add_large_elements(self, dwg, width, height, features):
574
+ """Add large-sized elements"""
575
+ center_x, center_y = width // 2, height // 2
576
+
577
+ # Large central element
578
  dwg.add(dwg.circle(
579
+ center=(center_x, center_y),
580
+ r=min(width, height) // 3,
581
+ fill='#4A90E2',
582
  stroke='black',
583
+ stroke_width=3
584
  ))
585
 
586
+ def add_small_elements(self, dwg, width, height, features):
587
+ """Add small-sized elements"""
588
+ # Multiple small elements
589
+ for i in range(8):
590
+ x = random.randint(10, width - 10)
591
+ y = random.randint(10, height - 10)
592
+ dwg.add(dwg.circle(
593
+ center=(x, y),
594
+ r=random.randint(3, 8),
595
+ fill='#E74C3C',
596
+ opacity=0.7
 
 
 
 
 
 
 
 
 
 
597
  ))
598
 
599
+ def add_detailed_elements(self, dwg, width, height, features):
600
+ """Add detailed elements for refinement"""
601
+ # Add more complex shapes and details
602
+ self.add_abstract_elements(dwg, width, height, features)
603
+
604
+ # Add decorative elements
605
+ center_x, center_y = width // 2, height // 2
606
+ for i in range(4):
607
+ angle = i * math.pi / 2
608
+ x = center_x + 40 * math.cos(angle)
609
+ y = center_y + 40 * math.sin(angle)
610
+ dwg.add(dwg.circle(center=(x, y), r=8, fill='#9B59B6', opacity=0.6))
611
 
612
+ def add_emphasized_element(self, dwg, word: str, weight: float, width: int, height: int):
613
+ """Add emphasized element based on attention weight"""
614
+ center_x, center_y = width // 2, height // 2
615
+
616
+ # Scale size based on weight
617
+ base_size = 20
618
+ size = int(base_size * weight)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
619
 
620
  dwg.add(dwg.circle(
621
+ center=(center_x + random.randint(-30, 30), center_y + random.randint(-30, 30)),
622
+ r=size,
623
+ fill='#FF6B6B',
624
+ opacity=min(1.0, weight / 2),
625
  stroke='black',
626
  stroke_width=2
627
  ))
628
 
629
+ def add_deemphasized_element(self, dwg, word: str, weight: float, width: int, height: int):
630
+ """Add de-emphasized element based on attention weight"""
631
+ center_x, center_y = width // 2, height // 2
632
 
633
+ # Scale size based on weight
634
+ base_size = 15
635
+ size = int(base_size * weight)
 
 
 
 
 
 
 
 
 
636
 
637
+ dwg.add(dwg.circle(
638
+ center=(center_x + random.randint(-40, 40), center_y + random.randint(-40, 40)),
639
+ r=max(3, size),
640
+ fill='#CCCCCC',
641
+ opacity=weight,
642
+ stroke='gray',
643
+ stroke_width=1
644
+ ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
645
 
646
+ def add_content_based_on_prompt(self, dwg, prompt: str, width: int, height: int):
647
+ """Add content based on prompt analysis"""
648
+ features = self.extract_semantic_features(prompt)
649
+
650
+ if any(word in prompt.lower() for word in ['person', 'people', 'human']):
651
+ self.add_person_elements(dwg, width, height, features)
652
+ elif any(word in prompt.lower() for word in ['animal', 'cat', 'dog']):
653
+ self.add_animal_elements(dwg, width, height, features)
654
+ elif any(word in prompt.lower() for word in ['house', 'building']):
655
+ self.add_building_elements(dwg, width, height, features)
656
+ elif any(word in prompt.lower() for word in ['tree', 'nature']):
657
+ self.add_nature_elements(dwg, width, height, features)
658
+ elif any(word in prompt.lower() for word in ['car', 'vehicle']):
659
+ self.add_vehicle_elements(dwg, width, height, features)
660
+ else:
661
+ self.add_abstract_elements(dwg, width, height, features)
662
 
663
+ def extract_semantic_features(self, prompt: str):
664
+ """Extract semantic features from prompt"""
665
+ features = {
666
+ 'detailed': False,
667
+ 'simple': False,
668
+ 'colorful': False,
669
+ 'large': False,
670
+ 'small': False
 
671
  }
672
+
673
+ prompt_lower = prompt.lower()
674
+
675
+ if any(word in prompt_lower for word in ['detailed', 'complex', 'intricate']):
676
+ features['detailed'] = True
677
+ if any(word in prompt_lower for word in ['simple', 'minimal', 'basic']):
678
+ features['simple'] = True
679
+ if any(word in prompt_lower for word in ['colorful', 'bright', 'vibrant']):
680
+ features['colorful'] = True
681
+ if any(word in prompt_lower for word in ['large', 'big', 'huge']):
682
+ features['large'] = True
683
+ if any(word in prompt_lower for word in ['small', 'tiny', 'mini']):
684
+ features['small'] = True
685
+
686
+ return features
687
 
688
+ def svg_to_pil_image(self, svg_content: str, width: int, height: int):
689
  """Convert SVG content to PIL Image"""
690
  try:
691
  import cairosvg
 
692
 
693
  # Convert SVG to PNG bytes
694
  png_bytes = cairosvg.svg2png(
 
719
 
720
  # Simple centered text
721
  dwg.add(dwg.text(
722
+ f"DiffSketchEdit\n{prompt[:30]}...",
723
  insert=(width/2, height/2),
724
  text_anchor="middle",
725
+ font_size="12px",
726
  fill="black"
727
  ))
728