import base64
import io
import json
import math
import random
import re
import zlib
from typing import Any, Dict, List, Optional, Union

import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
import svgwrite
import diffusers
from diffusers import StableDiffusionPipeline, DDIMScheduler
from transformers import CLIPTextModel, CLIPTokenizer
import torchvision.transforms as transforms

# NOTE(review): F, diffusers, CLIPTextModel, CLIPTokenizer and transforms are
# imported but not referenced in this module.


def _keyword_pattern(*keywords: str) -> "re.Pattern[str]":
    """Compile a case-insensitive whole-word matcher for ``keywords``.

    An optional plural suffix ("s"/"es") is accepted, so ``cat`` matches
    "cats" but not "category" or "education" (which plain substring tests
    would wrongly accept).
    """
    alternatives = "|".join(re.escape(word) for word in keywords)
    return re.compile(rf"\b(?:{alternatives})(?:e?s)?\b", re.IGNORECASE)


# Subject keywords that select the iconography sub-generator.
_ANIMAL_WORDS = _keyword_pattern('animal', 'cat', 'dog', 'bird', 'lion')
_BUILDING_WORDS = _keyword_pattern('house', 'building', 'home', 'castle')
_NATURE_WORDS = _keyword_pattern('tree', 'flower', 'plant', 'nature')
_VEHICLE_WORDS = _keyword_pattern('car', 'vehicle', 'transport')
# Shape keywords that select the pixel-art placement pattern.
_ROUND_WORDS = _keyword_pattern('circle', 'round', 'ball')
_SQUARE_WORDS = _keyword_pattern('square', 'box', 'cube')
# Descriptive keywords mapped to semantic feature flags.
_MINIMAL_WORDS = _keyword_pattern('simple', 'minimal', 'clean')
_DETAILED_WORDS = _keyword_pattern('detailed', 'complex', 'intricate')
_COLORFUL_WORDS = _keyword_pattern('colorful', 'bright', 'vibrant')
_ORGANIC_WORDS = _keyword_pattern('organic', 'natural', 'flowing')
_GEOMETRIC_WORDS = _keyword_pattern('geometric', 'angular', 'structured')


class EndpointHandler:
    """Callable handler that turns a text prompt into a procedurally drawn SVG.

    Several candidate SVGs ("particles") are generated in one of four styles
    (iconography, pixel_art, sketch, painting). One is selected and
    rasterised to a PIL image. The SVG source and all particles are attached
    to ``image.info``.
    """

    def __init__(self, path=""):
        """Load the Stable Diffusion pipeline whose CLIP encoder embeds prompts.

        Args:
            path: Accepted for interface compatibility; not used. The model is
                always loaded from ``self.model_id``.

        If loading fails, the handler still works. ``pipe``, ``clip_model``
        and ``clip_tokenizer`` are set to ``None`` and zero embeddings are
        used instead.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model_id = "runwayml/stable-diffusion-v1-5"

        try:
            # Initialize the diffusion pipeline
            self.pipe = StableDiffusionPipeline.from_pretrained(
                self.model_id,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                safety_checker=None,
                requires_safety_checker=False,
            ).to(self.device)

            # Use DDIM scheduler for better control
            self.pipe.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config)

            # The pipeline's text encoder/tokenizer provide prompt embeddings.
            self.clip_model = self.pipe.text_encoder
            self.clip_tokenizer = self.pipe.tokenizer

            print("SVGDreamer handler initialized successfully!")
        except Exception as e:
            print(f"Warning: Could not load diffusion model: {e}")
            self.pipe = None
            self.clip_model = None
            self.clip_tokenizer = None

    def __call__(self, inputs: Union[str, Dict[str, Any]]) -> Image.Image:
        """Generate an SVG for a prompt and return it rasterised.

        Args:
            inputs: Either the prompt string, or a dict with the prompt under
                ``"inputs"`` (or ``"prompt"``; default ``"a simple icon"``)
                and an optional ``"parameters"`` dict. Supported parameters:
                ``n_particle`` (default 4, at least 1), ``style``
                (``iconography``/``pixel_art``/``sketch``/``painting``;
                default ``iconography``), ``width``/``height`` (default 256,
                must be positive) and ``seed`` (optional int).

        Returns:
            A PIL RGB image of the selected particle. ``image.info`` holds
            ``svg_content``, ``prompt``, ``style``, ``n_particle``,
            ``particles`` (a JSON string) and ``method``. On any error, a
            256x256 fallback image is returned instead, with the message in
            ``image.info['error']``.
        """
        # Bound before the try so the error path can always use it.
        prompt = "error"
        try:
            if isinstance(inputs, str):
                prompt = inputs
                parameters = {}
            else:
                raw_prompt = inputs.get("inputs", inputs.get("prompt", "a simple icon"))
                if not isinstance(raw_prompt, str):
                    raise TypeError(f"prompt must be a string, got {type(raw_prompt).__name__}")
                prompt = raw_prompt
                # A JSON ``"parameters": null`` behaves like "no parameters".
                parameters = inputs.get("parameters") or {}

            # Coerce numerics: JSON clients may send "4" or 256.0.
            n_particle = max(1, int(parameters.get("n_particle", 4)))
            style = parameters.get("style", "iconography")  # iconography, pixel_art, sketch, painting
            width = int(parameters.get("width", 256))
            height = int(parameters.get("height", 256))
            if width <= 0 or height <= 0:
                raise ValueError(f"width and height must be positive, got {width}x{height}")

            seed = parameters.get("seed", None)
            if seed is not None:
                seed = int(seed)
                torch.manual_seed(seed)
                np.random.seed(seed)
                random.seed(seed)

            print(f"Generating SVGDreamer for: '{prompt}' with {n_particle} particles, style: {style}")

            # The user seed is forwarded so it actually influences the particles.
            particles = self.generate_svgdreamer_particles(
                prompt, width, height, n_particle, style, seed=seed
            )

            best_particle = self.select_best_particle(particles, prompt)

            pil_image = self.svg_to_pil_image(best_particle['svg'], width, height)

            pil_image.info['svg_content'] = best_particle['svg']
            pil_image.info['prompt'] = prompt
            pil_image.info['style'] = style
            pil_image.info['n_particle'] = str(n_particle)
            pil_image.info['particles'] = json.dumps(particles)
            pil_image.info['method'] = 'svgdreamer'

            return pil_image

        except Exception as e:
            print(f"Error in SVGDreamer handler: {e}")
            fallback_svg = self.create_fallback_svg(prompt, 256, 256, "iconography")
            fallback_image = self.svg_to_pil_image(fallback_svg, 256, 256)
            fallback_image.info['error'] = str(e)
            return fallback_image

    @staticmethod
    def _particle_seed(prompt: str, particle_id: int, style: str,
                       base_seed: Optional[int] = None) -> int:
        """Derive a stable per-particle seed in ``[0, 1000000)``.

        Uses CRC-32 instead of built-in ``hash()``. The value of ``hash()``
        for a string changes between interpreter runs, which made particle
        seeds irreproducible. ``base_seed``, when given, is mixed in so
        different request seeds give different particles.
        """
        key = f"{prompt}_{particle_id}_{style}"
        if base_seed is not None:
            key = f"{base_seed}_{key}"
        return zlib.crc32(key.encode('utf-8')) % 1000000

    def generate_svgdreamer_particles(self, prompt: str, width: int, height: int,
                                      n_particle: int, style: str,
                                      seed: Optional[int] = None):
        """Generate ``n_particle`` candidate SVGs for ``prompt``.

        Before each particle, ``random``, ``numpy`` and ``torch`` are reseeded
        with a value derived from (``seed``, ``prompt``, particle index,
        ``style``). Results are therefore reproducible for the same inputs.

        Returns:
            A list of dicts with keys ``particle_id``, ``svg``,
            ``svg_base64``, ``prompt``, ``style`` and ``parameters``
            (``width``, ``height``, ``seed``).
        """
        particles = []

        # Computed for guidance; the procedural generators do not use it yet.
        text_embeddings = self.get_text_embeddings(prompt)

        for particle_id in range(n_particle):
            print(f"Generating particle {particle_id + 1}/{n_particle}")

            particle_seed = self._particle_seed(prompt, particle_id, style, seed)
            torch.manual_seed(particle_seed)
            np.random.seed(particle_seed)
            random.seed(particle_seed)

            svg_content = self.generate_particle_by_style(
                prompt, width, height, style, text_embeddings, particle_id
            )

            particles.append({
                'particle_id': particle_id,
                'svg': svg_content,
                'svg_base64': base64.b64encode(svg_content.encode('utf-8')).decode('utf-8'),
                'prompt': prompt,
                'style': style,
                'parameters': {
                    'width': width,
                    'height': height,
                    'seed': particle_seed,
                },
            })

        return particles

    def generate_particle_by_style(self, prompt: str, width: int, height: int, style: str,
                                   text_embeddings: torch.Tensor, particle_id: int):
        """Dispatch to the generator for ``style``.

        Unknown styles fall back to iconography. ``particle_id`` is accepted
        but not used.
        """
        if style == "iconography":
            return self.generate_iconography_svg(prompt, width, height, text_embeddings)
        elif style == "pixel_art":
            return self.generate_pixel_art_svg(prompt, width, height, text_embeddings)
        elif style == "sketch":
            return self.generate_sketch_svg(prompt, width, height, text_embeddings)
        elif style == "painting":
            return self.generate_painting_svg(prompt, width, height, text_embeddings)
        else:
            return self.generate_iconography_svg(prompt, width, height, text_embeddings)

    def generate_iconography_svg(self, prompt: str, width: int, height: int,
                                 text_embeddings: torch.Tensor) -> str:
        """Generate an icon-style SVG built from simple geometric shapes.

        The subject template (animal, building, nature, vehicle, else
        abstract) is chosen by whole-word keyword matches in ``prompt``,
        checked in that order.
        """
        dwg = svgwrite.Drawing(size=(width, height))
        dwg.add(dwg.rect(insert=(0, 0), size=(width, height), fill='white'))

        features = self.extract_semantic_features(prompt)

        if _ANIMAL_WORDS.search(prompt):
            self.add_animal_icon_elements(dwg, width, height, features)
        elif _BUILDING_WORDS.search(prompt):
            self.add_building_icon_elements(dwg, width, height, features)
        elif _NATURE_WORDS.search(prompt):
            self.add_nature_icon_elements(dwg, width, height, features)
        elif _VEHICLE_WORDS.search(prompt):
            self.add_vehicle_icon_elements(dwg, width, height, features)
        else:
            self.add_abstract_icon_elements(dwg, width, height, features)

        return dwg.tostring()

    def generate_pixel_art_svg(self, prompt: str, width: int, height: int,
                               text_embeddings: torch.Tensor) -> str:
        """Generate a pixel-art SVG from 8x8 squares on a white background.

        Which cells are filled is decided by ``should_place_pixel``. Each
        filled cell gets a random colour from the pixel-art palette.
        """
        dwg = svgwrite.Drawing(size=(width, height))
        dwg.add(dwg.rect(insert=(0, 0), size=(width, height), fill='white'))

        pixel_size = 8
        cols = width // pixel_size
        rows = height // pixel_size

        features = self.extract_semantic_features(prompt)
        colors = self.get_style_colors("pixel_art", features)

        for row in range(rows):
            for col in range(cols):
                if self.should_place_pixel(row, col, rows, cols, prompt, features):
                    color = random.choice(colors)
                    dwg.add(dwg.rect(
                        insert=(col * pixel_size, row * pixel_size),
                        size=(pixel_size, pixel_size),
                        fill=color,
                        stroke='none',
                    ))

        return dwg.tostring()

    def generate_sketch_svg(self, prompt: str, width: int, height: int,
                            text_embeddings: torch.Tensor) -> str:
        """Generate a sketch-style SVG of 15-40 loose, dark, semi-transparent strokes."""
        dwg = svgwrite.Drawing(size=(width, height))
        dwg.add(dwg.rect(insert=(0, 0), size=(width, height), fill='white'))

        features = self.extract_semantic_features(prompt)

        num_strokes = random.randint(15, 40)
        for _ in range(num_strokes):
            path_data = self.generate_sketchy_path(width, height, features)
            stroke_color = f"rgb({random.randint(20, 80)},{random.randint(20, 80)},{random.randint(20, 80)})"
            stroke_width = random.uniform(0.5, 2.5)
            opacity = random.uniform(0.3, 0.8)

            dwg.add(dwg.path(
                d=path_data,
                stroke=stroke_color,
                stroke_width=stroke_width,
                stroke_opacity=opacity,
                fill='none',
                stroke_linecap='round',
                stroke_linejoin='round',
            ))

        return dwg.tostring()

    def generate_painting_svg(self, prompt: str, width: int, height: int,
                              text_embeddings: torch.Tensor) -> str:
        """Generate a painting-style SVG of 20-60 thick coloured brush strokes."""
        dwg = svgwrite.Drawing(size=(width, height))
        dwg.add(dwg.rect(insert=(0, 0), size=(width, height), fill='white'))

        features = self.extract_semantic_features(prompt)
        colors = self.get_style_colors("painting", features)

        num_strokes = random.randint(20, 60)
        for _ in range(num_strokes):
            path_data = self.generate_brush_stroke(width, height, features)
            color = random.choice(colors)
            stroke_width = random.uniform(2.0, 8.0)
            opacity = random.uniform(0.4, 0.9)

            dwg.add(dwg.path(
                d=path_data,
                stroke=color,
                stroke_width=stroke_width,
                stroke_opacity=opacity,
                fill='none',
                stroke_linecap='round',
                stroke_linejoin='round',
            ))

        return dwg.tostring()

    @staticmethod
    def _icon_scale(width: int, height: int) -> float:
        """Return the factor for icon offsets that were laid out on a 256x256 canvas.

        It equals 1.0 at 256x256, so default-size output keeps its geometry.
        Other sizes keep the same proportions instead of overflowing.
        """
        return min(width, height) / 256.0

    def add_animal_icon_elements(self, dwg, width, height, features):
        """Draw an animal icon: a body circle, an overlapping head circle and two eyes."""
        s = self._icon_scale(width, height)
        center_x, center_y = width // 2, height // 2

        body_radius = min(width, height) // 4
        dwg.add(dwg.circle(
            center=(center_x, center_y + 10 * s),
            r=body_radius,
            fill='#4A90E2',
            stroke='#2E5C8A',
            stroke_width=2,
        ))

        head_radius = body_radius * 0.7
        dwg.add(dwg.circle(
            center=(center_x, center_y - 20 * s),
            r=head_radius,
            fill='#5BA0F2',
            stroke='#2E5C8A',
            stroke_width=2,
        ))

        eye_size = 4 * s
        dwg.add(dwg.circle(center=(center_x - 15 * s, center_y - 25 * s), r=eye_size, fill='black'))
        dwg.add(dwg.circle(center=(center_x + 15 * s, center_y - 25 * s), r=eye_size, fill='black'))

    def add_building_icon_elements(self, dwg, width, height, features):
        """Draw a building icon: a red block, a brown triangular roof and a 2x3 window grid."""
        s = self._icon_scale(width, height)

        building_width = width * 0.6
        building_height = height * 0.7
        x = (width - building_width) // 2
        y = height - building_height - 20 * s

        dwg.add(dwg.rect(
            insert=(x, y),
            size=(building_width, building_height),
            fill='#E74C3C',
            stroke='#C0392B',
            stroke_width=2,
        ))

        roof_points = [
            (x, y),
            (x + building_width // 2, y - 30 * s),
            (x + building_width, y),
        ]
        dwg.add(dwg.polygon(
            points=roof_points,
            fill='#8B4513',
            stroke='#654321',
            stroke_width=2,
        ))

        window_size = 20 * s
        for i in range(2):
            for j in range(3):
                wx = x + (20 + i * 40) * s
                wy = y + (20 + j * 30) * s
                dwg.add(dwg.rect(
                    insert=(wx, wy),
                    size=(window_size, window_size),
                    fill='#3498DB',
                    stroke='#2980B9',
                    stroke_width=1,
                ))

    def add_nature_icon_elements(self, dwg, width, height, features):
        """Draw a tree icon: a brown trunk with a green circular crown."""
        s = self._icon_scale(width, height)
        center_x, center_y = width // 2, height // 2

        trunk_width = 20 * s
        trunk_height = height // 3
        trunk_x = center_x - trunk_width / 2
        trunk_y = height - trunk_height - 10 * s

        dwg.add(dwg.rect(
            insert=(trunk_x, trunk_y),
            size=(trunk_width, trunk_height),
            fill='#8B4513',
            stroke='#654321',
            stroke_width=1,
        ))

        crown_radius = min(width, height) // 3
        dwg.add(dwg.circle(
            center=(center_x, center_y - 20 * s),
            r=crown_radius,
            fill='#27AE60',
            stroke='#1E8449',
            stroke_width=2,
        ))

    def add_vehicle_icon_elements(self, dwg, width, height, features):
        """Draw a car icon: a rounded red body with two dark wheels."""
        s = self._icon_scale(width, height)
        center_y = height // 2

        car_width = width * 0.7
        car_height = height * 0.4
        car_x = (width - car_width) // 2
        car_y = center_y

        dwg.add(dwg.rect(
            insert=(car_x, car_y),
            size=(car_width, car_height),
            fill='#E74C3C',
            stroke='#C0392B',
            stroke_width=2,
            rx=10 * s,
        ))

        wheel_radius = 15 * s
        wheel_y = car_y + car_height - 5 * s
        dwg.add(dwg.circle(center=(car_x + 30 * s, wheel_y), r=wheel_radius, fill='#2C3E50'))
        dwg.add(dwg.circle(center=(car_x + car_width - 30 * s, wheel_y), r=wheel_radius, fill='#2C3E50'))

    def add_abstract_icon_elements(self, dwg, width, height, features):
        """Scatter three random semi-transparent shapes (circle, rectangle or polygon).

        Shape sizes are clamped to the canvas so that ``random.randint``
        always gets a valid range, even on small canvases.
        """
        colors = ['#3498DB', '#E74C3C', '#F39C12', '#27AE60', '#9B59B6']
        # Largest radius whose circle still fits; may be 0 on a 1-px canvas.
        max_radius = min(width, height) // 2

        for _ in range(3):
            shape_type = random.choice(['circle', 'rect', 'polygon'])
            color = random.choice(colors)

            if shape_type == 'circle':
                radius = min(random.randint(20, 50), max_radius)
                x = random.randint(radius, width - radius)
                y = random.randint(radius, height - radius)
                dwg.add(dwg.circle(center=(x, y), r=radius, fill=color, opacity=0.7))
            elif shape_type == 'rect':
                w = min(random.randint(30, 80), width)
                h = min(random.randint(30, 80), height)
                x = random.randint(0, width - w)
                y = random.randint(0, height - h)
                dwg.add(dwg.rect(insert=(x, y), size=(w, h), fill=color, opacity=0.7))
            else:
                # 'polygon': a randomly rotated regular polygon with 3-6 sides.
                sides = random.randint(3, 6)
                radius = min(random.randint(20, 50), max_radius)
                cx = random.randint(radius, width - radius)
                cy = random.randint(radius, height - radius)
                rotation = random.uniform(0, 2 * math.pi)
                points = [
                    (cx + radius * math.cos(rotation + 2 * math.pi * k / sides),
                     cy + radius * math.sin(rotation + 2 * math.pi * k / sides))
                    for k in range(sides)
                ]
                dwg.add(dwg.polygon(points=points, fill=color, opacity=0.7))

    def get_style_colors(self, style: str, features: Dict):
        """Return the colour palette (list of hex strings) for ``style``; greys for unknown styles."""
        if style == "pixel_art":
            return ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4', '#FFEAA7', '#DDA0DD']
        elif style == "painting":
            return ['#8B4513', '#228B22', '#4169E1', '#DC143C', '#FFD700', '#9370DB']
        elif style == "iconography":
            return ['#3498DB', '#E74C3C', '#F39C12', '#27AE60', '#9B59B6', '#34495E']
        else:
            return ['#333333', '#666666', '#999999', '#CCCCCC']

    def should_place_pixel(self, row: int, col: int, rows: int, cols: int,
                           prompt: str, features: Dict) -> bool:
        """Decide whether the pixel-art cell at (row, col) is filled.

        The decision depends on keywords in ``prompt``:

        - "circle"/"round"/"ball" fills a central disc.
        - "square"/"box"/"cube" fills a central square.
        - Otherwise, cells are filled randomly with a bias toward the centre.
        """
        center_row, center_col = rows // 2, cols // 2
        distance_from_center = math.sqrt((row - center_row) ** 2 + (col - center_col) ** 2)
        max_distance = math.sqrt(center_row ** 2 + center_col ** 2)

        if _ROUND_WORDS.search(prompt):
            return distance_from_center < max_distance * 0.4
        if _SQUARE_WORDS.search(prompt):
            return abs(row - center_row) < rows * 0.3 and abs(col - center_col) < cols * 0.3

        # On a 1x1 grid max_distance is 0; the only cell is the centre, so
        # probability 1 (this previously raised ZeroDivisionError).
        if max_distance == 0:
            probability = 1.0
        else:
            probability = 1.0 - (distance_from_center / max_distance) * 0.7
        return random.random() < probability

    def generate_sketchy_path(self, width: int, height: int, features: Dict) -> str:
        """Build SVG path data for a wobbly stroke of 2-5 quadratic segments.

        Segment end points are clamped to the canvas. Control points may
        fall slightly outside it. Coordinates are rounded to 0.01 px to keep
        the SVG compact.
        """
        start_x = random.uniform(width * 0.1, width * 0.9)
        start_y = random.uniform(height * 0.1, height * 0.9)

        commands = [f"M {start_x:.2f},{start_y:.2f}"]
        current_x, current_y = start_x, start_y

        num_segments = random.randint(2, 5)
        for _ in range(num_segments):
            dx = random.uniform(-width * 0.3, width * 0.3)
            dy = random.uniform(-height * 0.3, height * 0.3)

            end_x = max(0, min(width, current_x + dx))
            end_y = max(0, min(height, current_y + dy))

            # Jittered control point near the segment midpoint for a hand-drawn feel.
            cp_x = current_x + dx * 0.5 + random.uniform(-20, 20)
            cp_y = current_y + dy * 0.5 + random.uniform(-20, 20)

            commands.append(f"Q {cp_x:.2f},{cp_y:.2f} {end_x:.2f},{end_y:.2f}")
            current_x, current_y = end_x, end_y

        return " ".join(commands)

    def generate_brush_stroke(self, width: int, height: int, features: Dict) -> str:
        """Build SVG path data for one curved brush stroke, 30-100 px long in a random direction.

        The end point is clamped to the canvas. Coordinates are rounded to
        0.01 px.
        """
        start_x = random.uniform(width * 0.1, width * 0.9)
        start_y = random.uniform(height * 0.1, height * 0.9)

        length = random.uniform(30, 100)
        angle = random.uniform(0, 2 * math.pi)
        end_x = max(0, min(width, start_x + length * math.cos(angle)))
        end_y = max(0, min(height, start_y + length * math.sin(angle)))

        mid_x = (start_x + end_x) / 2 + random.uniform(-20, 20)
        mid_y = (start_y + end_y) / 2 + random.uniform(-20, 20)

        return f"M {start_x:.2f},{start_y:.2f} Q {mid_x:.2f},{mid_y:.2f} {end_x:.2f},{end_y:.2f}"

    def get_text_embeddings(self, prompt: str):
        """Return CLIP text embeddings for ``prompt``.

        Uses the pipeline's tokenizer and text encoder. Returns a
        ``(1, 77, 768)`` zero tensor when they are not loaded or when
        encoding fails.
        """
        if self.clip_model is None or self.clip_tokenizer is None:
            return torch.zeros((1, 77, 768))

        try:
            with torch.no_grad():
                text_inputs = self.clip_tokenizer(
                    prompt,
                    padding="max_length",
                    max_length=self.clip_tokenizer.model_max_length,
                    truncation=True,
                    return_tensors="pt",
                ).to(self.device)
                return self.clip_model(text_inputs.input_ids)[0]
        except Exception as e:
            print(f"Error getting text embeddings: {e}")
            return torch.zeros((1, 77, 768))

    def extract_semantic_features(self, prompt: str) -> Dict[str, Any]:
        """Map whole-word keywords in ``prompt`` to a feature dict.

        Returns a dict with keys ``complexity`` ('low'/'medium'/'high'),
        ``organic``, ``geometric``, ``colorful`` and ``minimal`` (bools).
        "Minimal" keywords take precedence over "detailed" ones when
        setting complexity.
        """
        features = {
            'complexity': 'medium',
            'organic': False,
            'geometric': False,
            'colorful': False,
            'minimal': False,
        }

        if _MINIMAL_WORDS.search(prompt):
            features['minimal'] = True
            features['complexity'] = 'low'
        elif _DETAILED_WORDS.search(prompt):
            features['complexity'] = 'high'

        if _COLORFUL_WORDS.search(prompt):
            features['colorful'] = True
        if _ORGANIC_WORDS.search(prompt):
            features['organic'] = True
        if _GEOMETRIC_WORDS.search(prompt):
            features['geometric'] = True

        return features

    def select_best_particle(self, particles: List[Dict], prompt: str):
        """Return the first particle, or a fallback particle if ``particles`` is empty.

        Placeholder: no quality ranking is performed yet.
        """
        return particles[0] if particles else self.create_fallback_particle(prompt)

    def create_fallback_particle(self, prompt: str):
        """Build a particle dict wrapping the 256x256 iconography fallback SVG (seed 0)."""
        fallback_svg = self.create_fallback_svg(prompt, 256, 256, "iconography")
        return {
            'particle_id': 0,
            'svg': fallback_svg,
            'svg_base64': base64.b64encode(fallback_svg.encode('utf-8')).decode('utf-8'),
            'prompt': prompt,
            'style': 'iconography',
            'parameters': {'width': 256, 'height': 256, 'seed': 0},
        }

    def svg_to_pil_image(self, svg_content: str, width: int, height: int):
        """Rasterise ``svg_content`` to a ``width`` x ``height`` RGB PIL image.

        Uses ``cairosvg`` if importable. If it is missing or rendering fails,
        a blank white image of the requested size is returned instead.
        """
        try:
            import cairosvg

            png_bytes = cairosvg.svg2png(
                bytestring=svg_content.encode('utf-8'),
                output_width=width,
                output_height=height,
            )
            return Image.open(io.BytesIO(png_bytes)).convert('RGB')

        except ImportError:
            print("cairosvg not available, creating simple image representation")
            return Image.new('RGB', (width, height), 'white')
        except Exception as e:
            print(f"Error converting SVG to image: {e}")
            return Image.new('RGB', (width, height), 'white')

    def create_fallback_svg(self, prompt: str, width: int, height: int, style: str) -> str:
        """Create a white SVG with three centred text lines: "SVGDreamer", ``style``, and the prompt.

        The prompt is truncated to 20 characters, with "..." appended only
        when it was actually cut.
        """
        dwg = svgwrite.Drawing(size=(width, height))
        dwg.add(dwg.rect(insert=(0, 0), size=(width, height), fill='white'))

        label = prompt[:20] + ("..." if len(prompt) > 20 else "")
        lines = ["SVGDreamer", style, label]

        # SVG <text> does not break lines on "\n", so each line is its own
        # element, vertically centred as a group.
        line_height = 14
        first_y = height / 2 - line_height * (len(lines) - 1) / 2
        for index, line in enumerate(lines):
            dwg.add(dwg.text(
                line,
                insert=(width / 2, first_y + index * line_height),
                text_anchor="middle",
                font_size="12px",
                fill="black",
            ))

        return dwg.tostring()