import torch
import torch.nn.functional as F
import numpy as np
import json
import base64
import io
from PIL import Image
import svgwrite
from typing import Dict, Any, List, Optional, Union
import diffusers
from diffusers import StableDiffusionPipeline, DDIMScheduler
from transformers import CLIPTextModel, CLIPTokenizer
import torchvision.transforms as transforms
import random
import math
import re


class EndpointHandler:
    """Inference endpoint that produces simple SVG sketches from text prompts.

    Each request is routed to one of four edit modes ("generate", "replace",
    "refine", "reweight"). Every mode builds an SVG out of hand-written
    primitive shapes chosen by keywords in the prompt. The SVG is then
    rasterized to a PIL image whose ``info`` dict carries the SVG source and
    the edit metadata.

    NOTE(review): the Stable Diffusion pipeline loaded in ``__init__`` is not
    referenced by any of the drawing code in this class.
    """

    # Object categories and the keywords that select them. Order matters:
    # the first category with a keyword present in the prompt wins.
    _CATEGORY_KEYWORDS = (
        ("person", ("person", "people", "human", "man", "woman")),
        ("animal", ("animal", "cat", "dog", "bird", "horse")),
        ("building", ("house", "building", "architecture")),
        ("nature", ("tree", "nature", "landscape")),
        ("vehicle", ("car", "vehicle", "transport")),
    )

    # Colour words recognised by word-replacement edits, mapped to SVG fills.
    _COLOR_MAP = {
        'red': '#FF0000',
        'blue': '#0000FF',
        'green': '#00FF00',
        'yellow': '#FFFF00',
        'purple': '#800080',
        'orange': '#FFA500',
    }

    # "(word:weight)" raises attention and "[word:weight]" lowers it. The
    # word may not contain ':' or brackets. The weight must be a plain
    # non-negative decimal, so float() cannot fail on e.g. "." or "1.2.3".
    _INCREASE_PATTERN = re.compile(r"\(([^:()\[\]]+):(\d+(?:\.\d*)?|\.\d+)\)")
    _DECREASE_PATTERN = re.compile(r"\[([^:()\[\]]+):(\d+(?:\.\d*)?|\.\d+)\]")

    _DEFAULT_PROMPT = "a simple sketch"

    def __init__(self, path=""):
        """Load the Stable Diffusion pipeline on a best-effort basis.

        Args:
            path: Unused. NOTE(review): presumably kept for the hosting
                platform's handler constructor signature; confirm.

        If loading fails, ``self.pipe``, ``self.clip_model`` and
        ``self.clip_tokenizer`` are set to ``None`` and a warning is printed.
        The SVG drawing methods below do not use them.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model_id = "runwayml/stable-diffusion-v1-5"

        try:
            # Initialize the diffusion pipeline
            self.pipe = StableDiffusionPipeline.from_pretrained(
                self.model_id,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                safety_checker=None,
                requires_safety_checker=False,
            ).to(self.device)

            # Use DDIM scheduler for better control
            self.pipe.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config)

            # Text encoder / tokenizer reused from the pipeline
            self.clip_model = self.pipe.text_encoder
            self.clip_tokenizer = self.pipe.tokenizer

            print("DiffSketchEdit handler initialized successfully!")
        except Exception as e:
            # Best effort: the handler still serves SVG sketches without the model.
            print(f"Warning: Could not load diffusion model: {e}")
            self.pipe = None
            self.clip_model = None
            self.clip_tokenizer = None

    def __call__(self, inputs: Union[str, Dict[str, Any]]) -> Image.Image:
        """Handle one request and return the rendered sketch as a PIL image.

        Args:
            inputs: A plain prompt string, a JSON string encoding a dict, or a
                dict. See ``_parse_inputs`` for the accepted shapes.
                Recognised ``parameters`` keys are ``width``, ``height``
                (default 224), ``seed`` and ``input_svg``.

        Returns:
            An RGB image. ``info['svg_content']`` holds the SVG source, and the
            remaining edit metadata is stored as strings (JSON for dicts and
            lists). On any error, a fallback image is returned whose ``info``
            holds ``error`` and ``edit_type``.
        """
        # Bind everything the fallback path needs *before* the try, so the
        # except-handler can never hit an UnboundLocalError.
        prompts: List[str] = []
        edit_type = "generate"
        width, height = 224, 224

        try:
            prompts, edit_type, parameters = self._parse_inputs(inputs)

            # Coerce to int: float/str sizes break Image.new / cairosvg.
            width = max(1, int(parameters.get("width", 224)))
            height = max(1, int(parameters.get("height", 224)))
            seed = parameters.get("seed", None)
            input_svg = parameters.get("input_svg", None)

            if seed is not None:
                # np.random.seed rejects floats, so normalise once for all RNGs.
                seed = int(seed)
                torch.manual_seed(seed)
                np.random.seed(seed)
                random.seed(seed)

            print(f"Processing edit type: '{edit_type}' with prompts: {prompts}")

            # Process based on edit type
            if edit_type == "replace" and len(prompts) >= 2:
                svg_content, metadata = self.word_replacement_edit(
                    prompts[0], prompts[1], width, height, input_svg)
            elif edit_type == "refine":
                svg_content, metadata = self.prompt_refinement_edit(
                    prompts[0], width, height, input_svg)
            elif edit_type == "reweight":
                svg_content, metadata = self.attention_reweighting_edit(
                    prompts[0], width, height, input_svg)
            elif edit_type == "generate":
                svg_content, metadata = self.simple_generation(prompts[0], width, height)
            else:
                # Unknown types (and "replace" with < 2 prompts) default to refinement.
                svg_content, metadata = self.prompt_refinement_edit(
                    prompts[0], width, height, input_svg)

            # Convert SVG to PIL Image for HF API compatibility
            pil_image = self.svg_to_pil_image(svg_content, width, height)

            # Store metadata
            pil_image.info['svg_content'] = svg_content
            for key, value in metadata.items():
                if isinstance(value, (dict, list)):
                    pil_image.info[key] = json.dumps(value)
                else:
                    pil_image.info[key] = str(value)

            return pil_image

        except Exception as e:
            print(f"Error in handler: {e}")
            # Return fallback image
            fallback_svg = self.create_fallback_svg(prompts[0] if prompts else "error", width, height)
            fallback_image = self.svg_to_pil_image(fallback_svg, width, height)
            fallback_image.info['error'] = str(e)
            fallback_image.info['edit_type'] = edit_type
            return fallback_image

    def _parse_inputs(self, inputs) -> tuple:
        """Normalise a request payload into ``(prompts, edit_type, parameters)``.

        Accepted shapes:
          * a plain prompt string, which is treated as "generate";
          * a JSON string that decodes to a dict, handled like the dict case;
          * a dict whose ``"inputs"`` entry is a prompt string or a dict with
            ``prompts`` / ``prompt`` / ``edit_type``. A dict without an
            ``"inputs"`` key is itself used as that inner dict. Optional
            top-level ``"parameters"`` dict.

        ``prompts`` is always a non-empty list of strings, and ``parameters``
        is always a dict.
        """
        default_prompt = self._DEFAULT_PROMPT

        if isinstance(inputs, str):
            try:
                parsed = json.loads(inputs)
            except ValueError:  # json.JSONDecodeError subclasses ValueError
                parsed = None
            if not isinstance(parsed, dict):
                # Simple prompt - treat as generation
                return [inputs], "generate", {}
            inputs = parsed

        if not isinstance(inputs, dict):
            return [default_prompt], "generate", {}

        input_data = inputs.get("inputs", inputs)
        if isinstance(input_data, str):
            prompts: Any = [input_data]
            edit_type = "generate"
        elif isinstance(input_data, dict):
            prompts = input_data.get("prompts", [input_data.get("prompt", default_prompt)])
            edit_type = input_data.get("edit_type", "generate")
        else:
            prompts = [default_prompt]
            edit_type = "generate"

        # A bare string under "prompts" would otherwise be indexed per character.
        if isinstance(prompts, str):
            prompts = [prompts]
        elif isinstance(prompts, (list, tuple)):
            prompts = [str(p) for p in prompts]
        else:
            prompts = []
        if not prompts:
            prompts = [default_prompt]

        parameters = inputs.get("parameters", {})
        if not isinstance(parameters, dict):
            parameters = {}

        return prompts, edit_type, parameters

    def word_replacement_edit(self, source_prompt: str, target_prompt: str, width: int, height: int,
                              input_svg: Optional[str] = None):
        """Edit a sketch by swapping words between a source and target prompt.

        Returns:
            ``(svg_string, metadata)``. On error, a fallback SVG and
            ``{"edit_type": "replace", "error": ...}`` are returned instead.
        """
        try:
            print(f"Word replacement: '{source_prompt}' -> '{target_prompt}'")

            # Analyze word differences
            added_words, removed_words = self.analyze_word_differences(source_prompt, target_prompt)
            print(f"Added words: {added_words}, Removed words: {removed_words}")

            # Generate or use base SVG
            if input_svg:
                base_svg = input_svg
            else:
                base_svg = self.generate_base_svg(source_prompt, width, height)

            # Apply word replacement transformations
            edited_svg = self.apply_word_replacement(
                base_svg, source_prompt, target_prompt, added_words, removed_words, width, height)

            metadata = {
                "edit_type": "replace",
                "source_prompt": source_prompt,
                "target_prompt": target_prompt,
                # Sorted so the metadata is stable across runs (set order is hash-dependent).
                "added_words": sorted(added_words),
                "removed_words": sorted(removed_words),
            }
            return edited_svg, metadata

        except Exception as e:
            print(f"Error in word_replacement_edit: {e}")
            fallback_svg = self.create_fallback_svg(source_prompt, width, height)
            metadata = {"edit_type": "replace", "error": str(e)}
            return fallback_svg, metadata

    def prompt_refinement_edit(self, prompt: str, width: int, height: int,
                               input_svg: Optional[str] = None):
        """Refine a sketch according to *prompt*.

        Returns:
            ``(svg_string, metadata)``, or a fallback SVG with error metadata.
        """
        try:
            print(f"Prompt refinement for: '{prompt}'")

            # Generate or use base SVG
            if input_svg:
                base_svg = input_svg
            else:
                base_svg = self.generate_base_svg(prompt, width, height)

            # Apply refinement based on prompt analysis
            refined_svg = self.apply_refinement(base_svg, prompt, width, height)

            metadata = {
                "edit_type": "refine",
                "prompt": prompt,
            }
            return refined_svg, metadata

        except Exception as e:
            print(f"Error in prompt_refinement_edit: {e}")
            fallback_svg = self.create_fallback_svg(prompt, width, height)
            metadata = {"edit_type": "refine", "error": str(e)}
            return fallback_svg, metadata

    def attention_reweighting_edit(self, prompt: str, width: int, height: int,
                                   input_svg: Optional[str] = None):
        """Re-draw a sketch emphasising/de-emphasising weighted words.

        Weights use the ``(word:1.5)`` / ``[word:0.5]`` notation parsed by
        ``parse_attention_weights``.

        Returns:
            ``(svg_string, metadata)``, or a fallback SVG with error metadata.
        """
        try:
            print(f"Attention reweighting for: '{prompt}'")

            # Parse attention weights from prompt (e.g., "(cat:1.5)" or "[table:0.5]")
            weighted_prompt, attention_weights = self.parse_attention_weights(prompt)
            print(f"Weighted prompt: '{weighted_prompt}', weights: {attention_weights}")

            # Generate or use base SVG
            if input_svg:
                base_svg = input_svg
            else:
                base_svg = self.generate_base_svg(weighted_prompt, width, height)

            # Apply attention reweighting
            reweighted_svg = self.apply_attention_reweighting(
                base_svg, weighted_prompt, attention_weights, width, height)

            metadata = {
                "edit_type": "reweight",
                "prompt": prompt,
                "weighted_prompt": weighted_prompt,
                "attention_weights": attention_weights,
            }
            return reweighted_svg, metadata

        except Exception as e:
            print(f"Error in attention_reweighting_edit: {e}")
            fallback_svg = self.create_fallback_svg(prompt, width, height)
            metadata = {"edit_type": "reweight", "error": str(e)}
            return fallback_svg, metadata

    def simple_generation(self, prompt: str, width: int, height: int):
        """Generate a fresh sketch for *prompt*; returns ``(svg_string, metadata)``."""
        try:
            print(f"Simple generation for: '{prompt}'")

            svg_content = self.generate_base_svg(prompt, width, height)

            metadata = {
                "edit_type": "generate",
                "prompt": prompt,
            }
            return svg_content, metadata

        except Exception as e:
            print(f"Error in simple_generation: {e}")
            fallback_svg = self.create_fallback_svg(prompt, width, height)
            metadata = {"edit_type": "generate", "error": str(e)}
            return fallback_svg, metadata

    @staticmethod
    def _mentions_any(text: str, keywords) -> bool:
        """Return True if any of *keywords* occurs in *text* as a whole word.

        Matching is case-insensitive and accepts a trailing plural "s"/"es"
        ("cats" matches "cat"). Unlike plain substring tests, "cartoon" does
        not match "car" and "category" does not match "cat".
        """
        lowered = text.lower()
        return any(
            re.search(rf"\b{re.escape(keyword)}(?:e?s)?\b", lowered)
            for keyword in keywords
        )

    def _draw_category(self, dwg, prompt: str, width: int, height: int, features) -> None:
        """Draw the primitive scene for the first keyword category found in *prompt*.

        Falls back to random abstract shapes when no category matches.
        """
        drawers = {
            "person": self.add_person_elements,
            "animal": self.add_animal_elements,
            "building": self.add_building_elements,
            "nature": self.add_nature_elements,
            "vehicle": self.add_vehicle_elements,
        }
        for category, keywords in self._CATEGORY_KEYWORDS:
            if self._mentions_any(prompt, keywords):
                drawers[category](dwg, width, height, features)
                return
        self.add_abstract_elements(dwg, width, height, features)

    @staticmethod
    def _randint_within(low: int, high: int) -> int:
        """``random.randint(low, high)``, but return *low* for an empty range.

        When the range is valid this makes exactly the same RNG call as
        ``random.randint``, so seeded output is unchanged. It only differs
        where ``random.randint`` would raise ``ValueError``, e.g. on canvases
        too small for the shape.
        """
        return random.randint(low, high) if high >= low else low

    def generate_base_svg(self, prompt: str, width: int, height: int) -> str:
        """Return an SVG string: white background plus the scene chosen by *prompt*."""
        dwg = svgwrite.Drawing(size=(width, height))
        dwg.add(dwg.rect(insert=(0, 0), size=(width, height), fill='white'))

        # Extract semantic features
        features = self.extract_semantic_features(prompt)

        # Generate content based on prompt
        self._draw_category(dwg, prompt, width, height, features)

        return dwg.tostring()

    def analyze_word_differences(self, source: str, target: str):
        """Return ``(added_words, removed_words)`` as sets of lowercase words.

        Words are runs of letters/digits/apostrophes/hyphens, so trailing
        punctuation ("red,") does not prevent a match against "red".
        """
        word_pattern = r"[\w'-]+"
        source_words = set(re.findall(word_pattern, source.lower()))
        target_words = set(re.findall(word_pattern, target.lower()))

        added_words = target_words - source_words
        removed_words = source_words - target_words

        return added_words, removed_words

    def parse_attention_weights(self, prompt: str):
        """Strip ``(word:w)`` / ``[word:w]`` markup and collect the weights.

        Returns:
            ``(clean_prompt, {word: weight})``. Round-bracket entries are
            collected before square-bracket ones; if a word appears more than
            once, the last weight wins.
        """
        attention_weights: Dict[str, float] = {}

        def _collect(match) -> str:
            word = match.group(1).strip()
            attention_weights[word] = float(match.group(2))
            # Remove the weight notation from prompt
            return word

        weighted_prompt = self._INCREASE_PATTERN.sub(_collect, prompt)
        weighted_prompt = self._DECREASE_PATTERN.sub(_collect, weighted_prompt)

        return weighted_prompt.strip(), attention_weights

    def apply_word_replacement(self, base_svg: str, source_prompt: str, target_prompt: str,
                               added_words: set, removed_words: set, width: int, height: int):
        """Build a new SVG reflecting words added in the target prompt.

        NOTE(review): *base_svg*, *source_prompt* and *removed_words* are not
        used; the SVG is regenerated from scratch.
        """
        features = self.extract_semantic_features(target_prompt)

        # Create new SVG with target prompt characteristics
        dwg = svgwrite.Drawing(size=(width, height))
        dwg.add(dwg.rect(insert=(0, 0), size=(width, height), fill='white'))

        # Apply changes based on word differences
        if any(word in self._COLOR_MAP for word in added_words):
            # Color change (same vocabulary add_colored_elements understands)
            self.add_colored_elements(dwg, width, height, added_words)
        elif any(word in added_words for word in ('big', 'large', 'huge')):
            # Size change
            self.add_large_elements(dwg, width, height, features)
        elif any(word in added_words for word in ('small', 'tiny', 'mini')):
            # Size change
            self.add_small_elements(dwg, width, height, features)
        else:
            # General content change
            self.add_content_based_on_prompt(dwg, target_prompt, width, height)

        return dwg.tostring()

    def apply_refinement(self, base_svg: str, prompt: str, width: int, height: int):
        """Build a refined SVG for *prompt*.

        If the prompt asks for detail, decorated abstract shapes are drawn;
        otherwise the prompt's category scene is drawn. NOTE(review):
        *base_svg* is not used.
        """
        features = self.extract_semantic_features(prompt)

        dwg = svgwrite.Drawing(size=(width, height))
        dwg.add(dwg.rect(insert=(0, 0), size=(width, height), fill='white'))

        # Add refined elements based on prompt
        if features.get('detailed', False):
            self.add_detailed_elements(dwg, width, height, features)
        else:
            self.add_content_based_on_prompt(dwg, prompt, width, height)

        return dwg.tostring()

    def apply_attention_reweighting(self, base_svg: str, prompt: str, attention_weights: dict,
                                    width: int, height: int):
        """Draw one marker per weighted word, then the prompt's category scene.

        Weights above 1.0 get a large red circle and weights below 1.0 get a
        faint grey one. A weight of exactly 1.0 draws nothing extra.
        NOTE(review): *base_svg* is not used.
        """
        dwg = svgwrite.Drawing(size=(width, height))
        dwg.add(dwg.rect(insert=(0, 0), size=(width, height), fill='white'))

        # Apply different emphasis based on attention weights
        for word, weight in attention_weights.items():
            if weight > 1.0:
                self.add_emphasized_element(dwg, word, weight, width, height)
            elif weight < 1.0:
                self.add_deemphasized_element(dwg, word, weight, width, height)

        # Add base content
        self.add_content_based_on_prompt(dwg, prompt, width, height)

        return dwg.tostring()

    def add_person_elements(self, dwg, width, height, features):
        """Draw a stick-figure person (head, body, arms, legs) centred on the canvas."""
        center_x, center_y = width // 2, height // 2

        # Head
        head_radius = 20
        dwg.add(dwg.circle(center=(center_x, center_y - 40), r=head_radius,
                           fill='#FDBCB4', stroke='black', stroke_width=2))

        # Body
        body_height = 60
        body_width = 30
        dwg.add(dwg.rect(
            insert=(center_x - body_width // 2, center_y - 10),
            size=(body_width, body_height),
            fill='#4A90E2', stroke='black', stroke_width=2
        ))

        # Arms
        dwg.add(dwg.line(start=(center_x - body_width // 2, center_y),
                         end=(center_x - 40, center_y + 20),
                         stroke='black', stroke_width=3))
        dwg.add(dwg.line(start=(center_x + body_width // 2, center_y),
                         end=(center_x + 40, center_y + 20),
                         stroke='black', stroke_width=3))

        # Legs
        dwg.add(dwg.line(start=(center_x - 10, center_y + body_height - 10),
                         end=(center_x - 20, center_y + body_height + 30),
                         stroke='black', stroke_width=3))
        dwg.add(dwg.line(start=(center_x + 10, center_y + body_height - 10),
                         end=(center_x + 20, center_y + body_height + 30),
                         stroke='black', stroke_width=3))

    def add_animal_elements(self, dwg, width, height, features):
        """Draw a four-legged animal (oval body, head, legs, curved tail)."""
        center_x, center_y = width // 2, height // 2

        # Body (oval)
        dwg.add(dwg.ellipse(center=(center_x, center_y), r=(40, 25),
                            fill='#8B4513', stroke='black', stroke_width=2))

        # Head
        dwg.add(dwg.circle(center=(center_x - 30, center_y - 10), r=20,
                           fill='#A0522D', stroke='black', stroke_width=2))

        # Legs
        for x_offset in (-20, -10, 10, 20):
            dwg.add(dwg.line(
                start=(center_x + x_offset, center_y + 25),
                end=(center_x + x_offset, center_y + 45),
                stroke='black', stroke_width=3
            ))

        # Tail
        dwg.add(dwg.path(
            d=f"M {center_x + 40},{center_y} Q {center_x + 60},{center_y - 20} {center_x + 50},{center_y - 35}",
            stroke='black', stroke_width=3, fill='none'
        ))

    def add_building_elements(self, dwg, width, height, features):
        """Draw a house: walls, triangular roof, a grid of windows and a door."""
        # Main building
        building_width = width * 0.6
        building_height = height * 0.7
        x = (width - building_width) // 2
        y = height - building_height - 10

        dwg.add(dwg.rect(
            insert=(x, y),
            size=(building_width, building_height),
            fill='#CD853F', stroke='black', stroke_width=2
        ))

        # Roof
        roof_points = [(x, y), (x + building_width // 2, y - 30), (x + building_width, y)]
        dwg.add(dwg.polygon(points=roof_points, fill='#8B0000', stroke='black', stroke_width=2))

        # Windows: skip any that would not fit inside the walls. The same
        # 20px margin is applied horizontally and vertically, so narrow
        # canvases no longer get windows drawn outside the building.
        window_size = 15
        for i in range(3):
            for j in range(4):
                wx = x + 15 + i * 30
                wy = y + 15 + j * 25
                if wy < y + building_height - 20 and wx < x + building_width - 20:
                    dwg.add(dwg.rect(
                        insert=(wx, wy),
                        size=(window_size, window_size),
                        fill='#87CEEB', stroke='black', stroke_width=1
                    ))

        # Door
        door_width = 20
        door_height = 40
        door_x = x + building_width // 2 - door_width // 2
        door_y = y + building_height - door_height
        dwg.add(dwg.rect(
            insert=(door_x, door_y),
            size=(door_width, door_height),
            fill='#8B4513', stroke='black', stroke_width=2
        ))

    def add_nature_elements(self, dwg, width, height, features):
        """Draw a tree: a rectangular trunk and overlapping foliage circles."""
        center_x, center_y = width // 2, height // 2

        # Trunk
        trunk_width = 15
        trunk_height = height // 3
        trunk_x = center_x - trunk_width // 2
        trunk_y = height - trunk_height - 10

        dwg.add(dwg.rect(
            insert=(trunk_x, trunk_y),
            size=(trunk_width, trunk_height),
            fill='#8B4513', stroke='black', stroke_width=1
        ))

        # Crown (multiple circles for foliage, each slightly smaller)
        crown_radius = 30
        for i, (dx, dy) in enumerate([(-15, -20), (15, -20), (0, -35), (-10, -50), (10, -50)]):
            dwg.add(dwg.circle(
                center=(center_x + dx, center_y + dy),
                r=crown_radius - i * 3,
                fill='#228B22', stroke='#006400', stroke_width=1, opacity=0.8
            ))

    def add_vehicle_elements(self, dwg, width, height, features):
        """Draw a car: rounded body, windshield and two wheels."""
        center_x, center_y = width // 2, height // 2

        # Car body
        car_width = width * 0.6
        car_height = height * 0.3
        car_x = (width - car_width) // 2
        car_y = center_y + 10

        dwg.add(dwg.rect(
            insert=(car_x, car_y),
            size=(car_width, car_height),
            fill='#FF4500', stroke='black', stroke_width=2, rx=5
        ))

        # Windshield
        windshield_width = car_width * 0.6
        windshield_height = car_height * 0.4
        windshield_x = car_x + (car_width - windshield_width) // 2
        windshield_y = car_y - windshield_height + 5

        dwg.add(dwg.rect(
            insert=(windshield_x, windshield_y),
            size=(windshield_width, windshield_height),
            fill='#87CEEB', stroke='black', stroke_width=1
        ))

        # Wheels
        wheel_radius = 12
        wheel_y = car_y + car_height - 5
        dwg.add(dwg.circle(center=(car_x + 25, wheel_y), r=wheel_radius, fill='black'))
        dwg.add(dwg.circle(center=(car_x + car_width - 25, wheel_y), r=wheel_radius, fill='black'))

    def add_abstract_elements(self, dwg, width, height, features):
        """Draw five random circles/rectangles/lines in a fixed palette.

        Shape sizes are clamped to the canvas so tiny canvases do not make
        ``random.randint`` raise. For canvases where the original sizes
        already fit, the RNG call sequence (and thus seeded output) is
        unchanged.
        """
        colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4', '#FFEAA7']

        for _ in range(5):
            shape_type = random.choice(['circle', 'rect', 'path'])
            color = random.choice(colors)

            if shape_type == 'circle':
                radius = min(random.randint(10, 30), max(1, min(width, height) // 2))
                x = self._randint_within(radius, width - radius)
                y = self._randint_within(radius, height - radius)
                dwg.add(dwg.circle(center=(x, y), r=radius, fill=color, opacity=0.7))
            elif shape_type == 'rect':
                w = min(random.randint(20, 60), width)
                h = min(random.randint(20, 60), height)
                x = self._randint_within(0, width - w)
                y = self._randint_within(0, height - h)
                dwg.add(dwg.rect(insert=(x, y), size=(w, h), fill=color, opacity=0.7))
            else:
                # Random straight line
                start_x = random.randint(0, width)
                start_y = random.randint(0, height)
                end_x = random.randint(0, width)
                end_y = random.randint(0, height)
                dwg.add(dwg.line(start=(start_x, start_y), end=(end_x, end_y),
                                 stroke=color, stroke_width=3))

    def add_colored_elements(self, dwg, width, height, color_words):
        """Draw one randomly placed circle per recognised colour word.

        Words are visited in sorted order so a seeded run is reproducible
        (iteration order of a set of strings depends on hash randomisation).
        """
        center_x, center_y = width // 2, height // 2

        for word in sorted(color_words):
            color = self._COLOR_MAP.get(word)
            if color is None:
                continue
            dwg.add(dwg.circle(
                center=(center_x + random.randint(-50, 50), center_y + random.randint(-50, 50)),
                r=random.randint(15, 35),
                fill=color, opacity=0.8
            ))

    def add_large_elements(self, dwg, width, height, features):
        """Draw a single large circle (radius = a third of the shorter side)."""
        center_x, center_y = width // 2, height // 2

        dwg.add(dwg.circle(
            center=(center_x, center_y),
            r=min(width, height) // 3,
            fill='#4A90E2', stroke='black', stroke_width=3
        ))

    def add_small_elements(self, dwg, width, height, features):
        """Draw eight small random circles, kept 10px inside the canvas edges when it fits."""
        for _ in range(8):
            x = self._randint_within(10, width - 10)
            y = self._randint_within(10, height - 10)
            dwg.add(dwg.circle(
                center=(x, y),
                r=random.randint(3, 8),
                fill='#E74C3C', opacity=0.7
            ))

    def add_detailed_elements(self, dwg, width, height, features):
        """Draw abstract shapes plus four decorative dots on a 40px ring around the centre."""
        self.add_abstract_elements(dwg, width, height, features)

        center_x, center_y = width // 2, height // 2
        for i in range(4):
            angle = i * math.pi / 2
            x = center_x + 40 * math.cos(angle)
            y = center_y + 40 * math.sin(angle)
            dwg.add(dwg.circle(center=(x, y), r=8, fill='#9B59B6', opacity=0.6))

    def add_emphasized_element(self, dwg, word: str, weight: float, width: int, height: int):
        """Draw a red circle whose radius (20 * weight) and opacity grow with *weight*."""
        center_x, center_y = width // 2, height // 2

        base_size = 20
        size = int(base_size * weight)

        dwg.add(dwg.circle(
            center=(center_x + random.randint(-30, 30), center_y + random.randint(-30, 30)),
            r=size,
            fill='#FF6B6B', opacity=min(1.0, weight / 2),
            stroke='black', stroke_width=2
        ))

    def add_deemphasized_element(self, dwg, word: str, weight: float, width: int, height: int):
        """Draw a faint grey circle (radius 15 * weight, at least 3; opacity = weight)."""
        center_x, center_y = width // 2, height // 2

        base_size = 15
        size = int(base_size * weight)

        dwg.add(dwg.circle(
            center=(center_x + random.randint(-40, 40), center_y + random.randint(-40, 40)),
            r=max(3, size),
            fill='#CCCCCC', opacity=weight,
            stroke='gray', stroke_width=1
        ))

    def add_content_based_on_prompt(self, dwg, prompt: str, width: int, height: int):
        """Draw the scene for *prompt* onto an existing drawing.

        Uses the same category keywords as ``generate_base_svg``.
        """
        features = self.extract_semantic_features(prompt)
        self._draw_category(dwg, prompt, width, height, features)

    def extract_semantic_features(self, prompt: str):
        """Return boolean style flags detected in *prompt*.

        Keys: ``detailed``, ``simple``, ``colorful``, ``large``, ``small``.
        Matching is by substring, so stems like "mini" also catch
        "miniature".
        """
        features = {
            'detailed': False,
            'simple': False,
            'colorful': False,
            'large': False,
            'small': False,
        }

        prompt_lower = prompt.lower()

        if any(word in prompt_lower for word in ['detailed', 'complex', 'intricate']):
            features['detailed'] = True
        if any(word in prompt_lower for word in ['simple', 'minimal', 'basic']):
            features['simple'] = True
        if any(word in prompt_lower for word in ['colorful', 'bright', 'vibrant']):
            features['colorful'] = True
        if any(word in prompt_lower for word in ['large', 'big', 'huge']):
            features['large'] = True
        if any(word in prompt_lower for word in ['small', 'tiny', 'mini']):
            features['small'] = True

        return features

    def svg_to_pil_image(self, svg_content: str, width: int, height: int):
        """Rasterize *svg_content* to an RGB PIL image of ``width x height``.

        Uses ``cairosvg`` when importable. If it is missing, or rendering
        fails, a blank white image of the requested size is returned instead.
        """
        try:
            import cairosvg

            png_bytes = cairosvg.svg2png(
                bytestring=svg_content.encode('utf-8'),
                output_width=width,
                output_height=height
            )
            return Image.open(io.BytesIO(png_bytes)).convert('RGB')

        except ImportError:
            print("cairosvg not available, creating simple image representation")
            # Fallback: blank white image (the SVG is still attached via info)
            return Image.new('RGB', (width, height), 'white')
        except Exception as e:
            print(f"Error converting SVG to image: {e}")
            # Fallback: blank white image
            return Image.new('RGB', (width, height), 'white')

    def create_fallback_svg(self, prompt: str, width: int, height: int):
        """Return a white SVG with a centred label showing up to 30 prompt characters."""
        dwg = svgwrite.Drawing(size=(width, height))
        dwg.add(dwg.rect(insert=(0, 0), size=(width, height), fill='white'))

        prompt_str = str(prompt)[:30] if prompt else "error"
        # NOTE(review): SVG <text> collapses the "\n" to whitespace, so this
        # renders on a single line.
        dwg.add(dwg.text(
            f"DiffSketchEdit\n{prompt_str}...",
            insert=(width / 2, height / 2),
            text_anchor="middle",
            font_size="12px",
            fill="black"
        ))

        return dwg.tostring()