#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Versatile SVG Generator that creates different types of objects based on the prompt.
"""

# NOTE(review): every markup template in this file is an empty (or single-space)
# string, so the text returned by the generators currently contains little more
# than the prompt. Presumably the SVG element strings need to be restored; the
# geometry locals computed next to each template are kept for that purpose.

import os
import io
import base64
import random
import re
from pathlib import Path
from xml.sax.saxutils import escape

import torch
import numpy as np

try:
    from PIL import Image
except ImportError:
    # Pillow is only needed for rasterised output (svg_to_png / __call__).
    Image = None

try:
    import cairosvg
except (ImportError, OSError):
    # Rasterisation is optional; svg_to_png reports the failure instead of
    # making the whole module unimportable.
    cairosvg = None


# Object categories in priority order, each with the prompt terms that select it.
_OBJECT_TERMS = (
    ("car", ("car", "vehicle", "truck", "suv", "sedan", "convertible",
             "sports car", "automobile")),
    ("landscape", ("landscape", "mountain", "forest", "beach", "ocean", "sea",
                   "lake", "river", "sunset", "sunrise", "sky")),
    # "puppy"/"kitten" are listed because _generate_animal_svg branches on them.
    ("animal", ("animal", "dog", "cat", "bird", "horse", "lion", "tiger",
                "elephant", "bear", "fish", "pet", "puppy", "kitten")),
    ("building", ("building", "house", "skyscraper", "tower", "castle",
                  "mansion", "apartment", "office", "structure")),
    ("face", ("face", "portrait", "person", "man", "woman", "boy", "girl",
              "human", "head", "smile")),
)


def _whole_word_pattern(terms):
    """Compile a regex matching any of *terms* as a whole word (simple plurals allowed)."""
    alternatives = "|".join(re.escape(term) for term in terms)
    return re.compile(rf"\b(?:{alternatives})(?:e?s)?\b")


_OBJECT_PATTERNS = tuple(
    (object_type, _whole_word_pattern(terms)) for object_type, terms in _OBJECT_TERMS
)


def _clamped_scale(base, spread, feature, lo=0.2, hi=1.0):
    """Return ``base + spread * feature`` limited to ``[lo, hi]``.

    Features with magnitude <= 1 (e.g. unit-norm vectors) never reach the
    limits used in this file; the clamp only matters for un-normalised
    fallback features, where the raw value could be zero or negative.
    """
    return float(min(hi, max(lo, base + spread * feature)))


class VersatileSVGGenerator:
    """Generate prompt-dependent SVG text and, optionally, a PNG rendering of it."""

    def __init__(self, model_dir):
        """Initialize the versatile SVG generator.

        Args:
            model_dir: Stored on the instance as ``self.model_dir``; not
                otherwise used in this file.
        """
        self.model_dir = model_dir
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Initializing versatile SVG generator on device: {self.device}")

        # Load CLIP model if available
        self.clip_model = None
        try:
            import clip

            # The OpenAI CLIP registry name is "ViT-B/32". "ViT-B-32" only
            # works with clip.load when it is the path of a local checkpoint,
            # so keep honouring such a file when it exists.
            legacy_name = "ViT-B-32"
            model_name = legacy_name if os.path.isfile(legacy_name) else "ViT-B/32"
            self.clip_model, _ = clip.load(model_name, device=self.device)
            self.clip_available = True
            print("CLIP model loaded successfully")
        except Exception as e:
            # Best effort: without CLIP the generator falls back to random features.
            print(f"Error loading CLIP model: {e}")
            self.clip_available = False

    def generate_svg(self, prompt, num_paths=20, width=512, height=512):
        """Generate an SVG from a text prompt.

        Args:
            prompt: Text describing what to draw.
            num_paths: Forwarded to the per-type generator.
            width: Canvas width.
            height: Canvas height.

        Returns:
            The string produced by the generator selected for the prompt.
        """
        print(f"Generating SVG for prompt: {prompt}")

        # Use CLIP to encode the prompt if available
        if self.clip_available:
            try:
                import clip

                with torch.no_grad():
                    text = clip.tokenize([prompt]).to(self.device)
                    text_features = self.clip_model.encode_text(text)
                    text_features = text_features.cpu().numpy()[0]
                # Normalize features
                text_features = text_features / np.linalg.norm(text_features)
            except Exception as e:
                print(f"Error encoding prompt with CLIP: {e}")
                text_features = np.random.randn(512)  # Random features as fallback
        else:
            # Generate random features if CLIP is not available
            text_features = np.random.randn(512)

        # Determine what type of object to generate based on the prompt
        object_type = self._determine_object_type(prompt)

        # Anything without a dedicated generator is drawn as an abstract image.
        generators = {
            "car": self._generate_car_svg,
            "landscape": self._generate_landscape_svg,
            "animal": self._generate_animal_svg,
            "building": self._generate_building_svg,
            "face": self._generate_face_svg,
        }
        generate = generators.get(object_type, self._generate_abstract_svg)
        return generate(prompt, text_features, num_paths, width, height)

    def _determine_object_type(self, prompt):
        """Determine what type of object to generate based on the prompt.

        Whole-word matches (simple plurals allowed) are tried first, in
        category priority order, so that an incidental substring such as the
        "car" in "scary" cannot outrank a real keyword. If no whole word
        matches, the original substring scan is used so compounds like
        "goldfish" still resolve.

        Returns:
            One of "car", "landscape", "animal", "building", "face", "abstract".
        """
        prompt = prompt.lower()

        for object_type, pattern in _OBJECT_PATTERNS:
            if pattern.search(prompt):
                return object_type

        for object_type, terms in _OBJECT_TERMS:
            if any(term in prompt for term in terms):
                return object_type

        # Default to abstract
        return "abstract"

    def _generate_car_svg(self, prompt, features, num_paths=20, width=512, height=512):
        """Generate a car-like SVG based on the prompt and features.

        ``features`` is indexed at positions 0-2. The prompt is XML-escaped
        before being appended to the output.
        """
        # Start SVG
        svg_content = f""" """

        # Use the features to determine car properties
        car_color_hue = int((features[0] + 1) * 180) % 360  # Map to 0-360 hue
        car_size = _clamped_scale(0.6, 0.2, features[1])  # Size variation
        car_style = int(abs(features[2] * 3)) % 3  # 0: sedan, 1: SUV, 2: sports car

        # Calculate car dimensions
        car_width = int(width * 0.7 * car_size)
        car_height = int(height * 0.3 * car_size)
        car_x = (width - car_width) // 2
        car_y = height // 2

        # Generate car body based on style
        if car_style == 0:  # Sedan
            # Car body (rounded rectangle)
            svg_content += f""""""

            # Windshield
            windshield_width = car_width * 0.7
            windshield_height = car_height * 0.5
            windshield_x = car_x + (car_width - windshield_width) // 2
            windshield_y = car_y - windshield_height * 0.3
            svg_content += f""""""

            # Wheels
            wheel_radius = car_height * 0.4
            wheel_y = car_y + car_height * 0.8
            svg_content += f""""""
            svg_content += f""""""
            svg_content += f""""""
            svg_content += f""""""
        elif car_style == 1:  # SUV
            # Car body (taller rectangle)
            svg_content += f""""""

            # Windshield
            windshield_width = car_width * 0.6
            windshield_height = car_height * 0.6
            windshield_x = car_x + (car_width - windshield_width) // 2
            windshield_y = car_y - car_height * 0.2
            svg_content += f""""""

            # Wheels (larger)
            wheel_radius = car_height * 0.45
            wheel_y = car_y + car_height * 0.7
            svg_content += f""""""
            svg_content += f""""""
            svg_content += f""""""
            svg_content += f""""""
        else:  # Sports car
            # Car body (low, sleek shape)
            svg_content += f""""""

            # Windshield
            windshield_width = car_width * 0.4
            windshield_x = car_x + car_width * 0.3
            windshield_y = car_y - car_height * 0.1
            svg_content += f""""""

            # Wheels (low profile)
            wheel_radius = car_height * 0.35
            wheel_y = car_y + car_height * 0.7
            svg_content += f""""""
            svg_content += f""""""
            svg_content += f""""""
            svg_content += f""""""

        # Add headlights
        headlight_radius = car_width * 0.05
        headlight_y = car_y + car_height * 0.3
        svg_content += f""""""
        svg_content += f""""""

        # Add prompt as text (escaped so "&", "<" and ">" stay valid XML)
        svg_content += f"""{escape(prompt)}"""

        # Close SVG
        svg_content += ""
        return svg_content

    def _generate_landscape_svg(self, prompt, features, num_paths=20, width=512, height=512):
        """Generate a landscape SVG based on the prompt and features.

        ``features`` drives the mountain and tree counts and whether a sun
        and water are drawn. The prompt is XML-escaped before being appended.
        """
        # Start SVG
        svg_content = f""" """

        # Use features to determine landscape properties
        mountain_count = int(abs(features[0] * 5)) + 3
        tree_count = int(abs(features[1] * 20)) + 5
        has_sun = features[2] > 0
        has_water = features[3] > 0

        # Draw mountains
        for i in range(mountain_count):
            mountain_height = height * (0.3 + 0.2 * abs(features[i % len(features)]))
            mountain_width = width * (0.2 + 0.1 * abs(features[(i + 1) % len(features)]))
            mountain_x = width * (i / mountain_count)
            mountain_color = f"hsl({int(120 + features[i % len(features)] * 20)}, 30%, {30 + int(abs(features[i % len(features)] * 20))}%)"
            svg_content += f""""""

        # Draw sun if present
        if has_sun:
            sun_x = width * (0.1 + 0.8 * abs(features[4]))
            sun_y = height * 0.2
            sun_radius = width * 0.08
            svg_content += f""" """

        # Draw water if present
        if has_water:
            water_height = height * 0.3
            water_y = height - water_height
            svg_content += f""" """

            # Add waves
            for i in range(5):
                wave_y = water_y + i * water_height / 5
                svg_content += f""""""

        # Draw trees
        for i in range(tree_count):
            tree_x = width * (0.1 + 0.8 * (i / tree_count))
            tree_y = height * 0.8
            tree_height = height * (0.1 + 0.1 * abs(features[i % len(features)]))
            tree_width = tree_height * 0.6

            # Tree trunk
            svg_content += f""""""

            # Tree foliage
            svg_content += f""" """

        # Add prompt as text (escaped so "&", "<" and ">" stay valid XML)
        svg_content += f"""{escape(prompt)}"""

        # Close SVG
        svg_content += ""
        return svg_content

    def _generate_animal_svg(self, prompt, features, num_paths=20, width=512, height=512):
        """Generate an animal SVG based on the prompt and features.

        The prompt selects a dog, cat, bird, fish or generic animal layout.
        The prompt is XML-escaped before being appended to the output.
        """
        # Start SVG
        svg_content = f""" """

        # Determine animal type from prompt
        lowered = prompt.lower()
        animal_type = "generic"
        if "dog" in lowered or "puppy" in lowered:
            animal_type = "dog"
        elif "cat" in lowered or "kitten" in lowered:
            animal_type = "cat"
        elif "bird" in lowered:
            animal_type = "bird"
        elif "fish" in lowered:
            animal_type = "fish"

        # Use features to determine animal properties
        animal_color_hue = int((features[0] + 1) * 180) % 360  # Map to 0-360 hue
        animal_size = _clamped_scale(0.5, 0.3, features[1])  # Size variation

        # Calculate animal dimensions
        animal_width = int(width * 0.6 * animal_size)
        animal_height = int(height * 0.4 * animal_size)
        animal_x = (width - animal_width) // 2
        animal_y = height // 2

        if animal_type == "dog":
            # Dog body (oval)
            svg_content += f""""""

            # Dog head (circle)
            head_radius = animal_width * 0.2
            svg_content += f""""""

            # Dog ears
            svg_content += f""""""
            svg_content += f""""""

            # Dog eyes
            svg_content += f""""""
            svg_content += f""""""

            # Dog nose
            svg_content += f""""""

            # Dog legs
            leg_width = animal_width * 0.1
            leg_height = animal_height * 0.4
            svg_content += f""""""
            svg_content += f""""""

            # Dog tail
            svg_content += f""""""
        elif animal_type == "cat":
            # Cat body (oval)
            svg_content += f""""""

            # Cat head (circle)
            head_radius = animal_width * 0.18
            svg_content += f""""""

            # Cat ears (triangles)
            svg_content += f""""""
            svg_content += f""""""

            # Cat eyes (almond shaped)
            svg_content += f""""""
            svg_content += f""""""

            # Cat nose
            svg_content += f""""""

            # Cat whiskers
            svg_content += f""""""
            svg_content += f""""""
            svg_content += f""""""
            svg_content += f""""""

            # Cat legs
            leg_width = animal_width * 0.08
            leg_height = animal_height * 0.3
            svg_content += f""""""
            svg_content += f""""""

            # Cat tail
            svg_content += f""""""
        elif animal_type == "bird":
            # Bird body (oval)
            svg_content += f""""""

            # Bird head
            head_radius = animal_width * 0.15
            svg_content += f""""""

            # Bird beak
            svg_content += f""""""

            # Bird eye
            svg_content += f""""""

            # Bird wings
            svg_content += f""""""

            # Bird tail
            svg_content += f""""""

            # Bird legs
            leg_width = animal_width * 0.02
            leg_height = animal_height * 0.2
            svg_content += f""""""
            svg_content += f""""""
        elif animal_type == "fish":
            # Fish body (oval)
            svg_content += f""""""

            # Fish tail
            svg_content += f""""""

            # Fish eye
            svg_content += f""""""
            svg_content += f""""""

            # Fish fins
            svg_content += f""""""
            svg_content += f""""""

            # Fish scales (simplified)
            for i in range(5):
                for j in range(3):
                    scale_x = animal_x + animal_width * (0.3 + i * 0.1)
                    scale_y = animal_y + animal_height * (0.4 + (j - 1) * 0.1)
                    scale_radius = animal_width * 0.03
                    svg_content += f""""""

            # Water bubbles
            for i in range(3):
                bubble_x = animal_x + animal_width * (0.8 + i * 0.1)
                bubble_y = animal_y + animal_height * (0.3 - i * 0.1)
                bubble_radius = animal_width * (0.02 + i * 0.01)
                svg_content += f""""""
        else:
            # Generic animal
            # Body (oval)
            svg_content += f""""""

            # Head (circle)
            head_radius = animal_width * 0.2
            svg_content += f""""""

            # Eyes
            svg_content += f""""""
            svg_content += f""""""

            # Legs
            leg_width = animal_width * 0.08
            leg_height = animal_height * 0.3
            svg_content += f""""""
            svg_content += f""""""

            # Tail
            svg_content += f""""""

        # Add prompt as text (escaped so "&", "<" and ">" stay valid XML)
        svg_content += f"""{escape(prompt)}"""

        # Close SVG
        svg_content += ""
        return svg_content

    def _generate_building_svg(self, prompt, features, num_paths=20, width=512, height=512):
        """Generate a building SVG based on the prompt and features.

        The prompt selects a house, skyscraper, castle or generic layout.
        The size factor is clamped to a positive range, so the window and
        crenellation loops never divide by a zero-sized building. The prompt
        is XML-escaped before being appended to the output.
        """
        # Start SVG
        svg_content = f""" """

        # Determine building type from prompt
        lowered = prompt.lower()
        building_type = "generic"
        if "house" in lowered:
            building_type = "house"
        elif "skyscraper" in lowered or "tower" in lowered:
            building_type = "skyscraper"
        elif "castle" in lowered:
            building_type = "castle"

        # Use features to determine building properties
        building_color_hue = int((features[0] + 1) * 180) % 360  # Map to 0-360 hue
        building_size = _clamped_scale(0.5, 0.3, features[1])  # Size variation

        # Calculate building dimensions
        building_width = int(width * 0.6 * building_size)
        building_height = int(height * 0.7 * building_size)
        building_x = (width - building_width) // 2
        building_y = height - building_height

        if building_type == "house":
            # House base
            svg_content += f""""""

            # House roof
            svg_content += f""""""

            # House door
            door_width = building_width * 0.2
            door_height = building_height * 0.4
            door_x = building_x + (building_width - door_width) / 2
            door_y = building_y + building_height - door_height
            svg_content += f""""""

            # Door knob
            svg_content += f""""""

            # Windows
            window_width = building_width * 0.15
            window_height = building_height * 0.15
            window_spacing = building_width * 0.25
            for i in range(2):
                for j in range(2):
                    window_x = building_x + window_spacing + i * window_spacing
                    window_y = building_y + building_height * 0.4 + j * window_spacing
                    svg_content += f""""""

                    # Window crossbars
                    svg_content += f""""""
                    svg_content += f""""""

            # Chimney
            chimney_width = building_width * 0.1
            chimney_height = building_height * 0.3
            chimney_x = building_x + building_width * 0.7
            chimney_y = building_y + building_height * 0.1 - chimney_height
            svg_content += f""""""
        elif building_type == "skyscraper":
            # Skyscraper base
            svg_content += f""""""

            # Skyscraper top
            top_width = building_width * 0.7
            top_height = building_height * 0.1
            top_x = building_x + (building_width - top_width) / 2
            svg_content += f""""""

            # Antenna
            antenna_width = building_width * 0.02
            antenna_height = building_height * 0.15
            antenna_x = building_x + building_width / 2 - antenna_width / 2
            antenna_y = building_y - top_height - antenna_height
            svg_content += f""""""

            # Windows (grid pattern)
            window_width = building_width * 0.08
            window_height = building_height * 0.05
            window_spacing_x = building_width * 0.12
            window_spacing_y = building_height * 0.08
            for i in range(int(building_width / window_spacing_x) - 1):
                for j in range(int(building_height / window_spacing_y) - 1):
                    window_x = building_x + window_spacing_x * (i + 0.5)
                    window_y = building_y + window_spacing_y * (j + 0.5)

                    # Randomize window lighting (intentionally unseeded)
                    window_color = "#a8d8ff"
                    if random.random() < 0.3:  # 30% chance of lit window
                        window_color = "#ffff88"
                    svg_content += f""""""

            # Entrance
            entrance_width = building_width * 0.3
            entrance_height = building_height * 0.1
            entrance_x = building_x + (building_width - entrance_width) / 2
            entrance_y = building_y + building_height - entrance_height
            svg_content += f""""""
        elif building_type == "castle":
            # Castle base
            svg_content += f""""""

            # Castle towers
            tower_width = building_width * 0.2
            tower_height = building_height * 1.0

            # Left tower
            svg_content += f""""""

            # Right tower
            svg_content += f""""""

            # Crenellations (castle top)
            crenel_width = building_width * 0.05
            crenel_height = building_height * 0.05
            crenel_count = int(building_width / crenel_width)
            for i in range(crenel_count):
                if i % 2 == 0:
                    crenel_x = building_x + i * crenel_width
                    svg_content += f""""""

            # Tower crenellations
            tower_crenel_count = int(tower_width / crenel_width)

            # Left tower crenellations
            for i in range(tower_crenel_count):
                if i % 2 == 0:
                    crenel_x = building_x - tower_width * 0.5 + i * crenel_width
                    svg_content += f""""""

            # Right tower crenellations
            for i in range(tower_crenel_count):
                if i % 2 == 0:
                    crenel_x = building_x + building_width - tower_width * 0.5 + i * crenel_width
                    svg_content += f""""""

            # Castle door (gate)
            door_width = building_width * 0.25
            door_height = building_height * 0.4
            door_x = building_x + (building_width - door_width) / 2
            door_y = building_y + building_height - door_height

            # Gate arch
            svg_content += f""""""

            # Windows
            window_width = building_width * 0.1
            window_height = building_height * 0.15
            window_spacing = building_width * 0.25
            for i in range(3):
                window_x = building_x + window_spacing * (i + 0.5)
                window_y = building_y + building_height * 0.4

                # Arched window
                svg_content += f""""""

            # Tower windows (slits)
            slit_width = tower_width * 0.1
            slit_height = tower_height * 0.1

            # Left tower slits
            for i in range(3):
                slit_x = building_x - tower_width * 0.5 + tower_width * 0.45
                slit_y = building_y + tower_height * (0.2 + i * 0.2)
                svg_content += f""""""

            # Right tower slits
            for i in range(3):
                slit_x = building_x + building_width - tower_width * 0.5 + tower_width * 0.45
                slit_y = building_y + tower_height * (0.2 + i * 0.2)
                svg_content += f""""""
        else:
            # Generic building
            # Building base
            svg_content += f""""""

            # Building roof
            roof_height = building_height * 0.2
            svg_content += f""""""

            # Building door
            door_width = building_width * 0.2
            door_height = building_height * 0.3
            door_x = building_x + (building_width - door_width) / 2
            door_y = building_y + building_height - door_height
            svg_content += f""""""

            # Windows
            window_width = building_width * 0.15
            window_height = building_height * 0.15
            window_spacing_x = building_width * 0.25
            window_spacing_y = building_height * 0.25
            for i in range(3):
                for j in range(2):
                    window_x = building_x + window_spacing_x * (i + 0.5)
                    window_y = building_y + window_spacing_y * (j + 0.5)
                    svg_content += f""""""

        # Add prompt as text (escaped so "&", "<" and ">" stay valid XML)
        svg_content += f"""{escape(prompt)}"""

        # Close SVG
        svg_content += ""
        return svg_content

    def _generate_face_svg(self, prompt, features, num_paths=20, width=512, height=512):
        """Generate a face SVG based on the prompt and features.

        ``features`` selects hue, size and face shape; the prompt selects the
        hair style. The prompt is XML-escaped before being appended.
        """
        # Start SVG
        svg_content = f""" """

        # Use features to determine face properties
        face_color_hue = int((features[0] + 1) * 20) % 40 + 10  # Map to 10-50 hue (skin tones)
        face_size = _clamped_scale(0.5, 0.2, features[1])  # Size variation
        face_shape = int(abs(features[2] * 3)) % 3  # 0: round, 1: oval, 2: square

        # Calculate face dimensions
        face_width = int(width * 0.6 * face_size)
        face_height = int(height * 0.7 * face_size)
        face_x = (width - face_width) // 2
        face_y = (height - face_height) // 2

        # Draw face shape
        if face_shape == 0:  # Round
            svg_content += f""""""
        elif face_shape == 1:  # Oval
            svg_content += f""""""
        else:  # Square with rounded corners
            svg_content += f""""""

        # Determine gender from prompt
        is_female = any(term in prompt.lower() for term in ["woman", "girl", "female", "lady"])

        # Draw eyes
        eye_width = face_width * 0.15
        eye_height = face_height * 0.08
        eye_y = face_y + face_height * 0.35
        left_eye_x = face_x + face_width * 0.3 - eye_width / 2
        right_eye_x = face_x + face_width * 0.7 - eye_width / 2

        # Eye whites
        svg_content += f""""""
        svg_content += f""""""

        # Pupils
        pupil_size = eye_width * 0.3
        svg_content += f""""""
        svg_content += f""""""

        # Eyebrows
        brow_width = eye_width * 1.2
        brow_height = eye_height * 0.5
        brow_y = eye_y - eye_height * 0.8
        svg_content += f""""""
        svg_content += f""""""

        # Nose
        nose_width = face_width * 0.1
        nose_height = face_height * 0.15
        nose_x = face_x + face_width / 2 - nose_width / 2
        nose_y = face_y + face_height * 0.5 - nose_height / 2
        svg_content += f""""""

        # Mouth
        mouth_width = face_width * 0.4
        mouth_height = face_height * 0.05
        mouth_x = face_x + face_width / 2 - mouth_width / 2
        mouth_y = face_y + face_height * 0.7

        # Smiling mouth
        svg_content += f""""""

        # Hair
        hair_color_hue = int((features[3] + 1) * 180) % 360  # Map to 0-360 hue
        if is_female:
            # Long hair for female
            svg_content += f""""""

            # Hair on top of head
            svg_content += f""""""
        else:
            # Short hair for male
            svg_content += f""""""

            # Hair sides
            svg_content += f""""""
            svg_content += f""""""

        # Add ears
        ear_width = face_width * 0.1
        ear_height = face_height * 0.2
        left_ear_x = face_x - ear_width / 2
        right_ear_x = face_x + face_width - ear_width / 2
        ear_y = face_y + face_height * 0.4
        svg_content += f""""""
        svg_content += f""""""

        # Add prompt as text (escaped so "&", "<" and ">" stay valid XML)
        svg_content += f"""{escape(prompt)}"""

        # Close SVG
        svg_content += ""
        return svg_content

    def _generate_abstract_svg(self, prompt, features, num_paths=20, width=512, height=512):
        """Generate an abstract SVG based on the prompt and features.

        Shape placement is pseudo-random but deterministic for a given prompt:
        it is driven by a private generator seeded with the sum of the
        prompt's character codes, so the process-wide ``random`` state is
        left untouched. The prompt is XML-escaped before being appended.
        """
        # Start SVG
        svg_content = f""" """

        # Use features to determine abstract properties
        color_scheme = int(abs(features[0] * 5)) % 5  # 0-4 color schemes
        shape_complexity = int(abs(features[1] * 10)) + 5  # at least 5 shapes
        use_gradients = features[2] > 0

        # Define color schemes
        color_schemes = [
            # Warm colors
            [f"hsl({h}, 80%, 60%)" for h in range(0, 61, 15)],
            # Cool colors
            [f"hsl({h}, 80%, 60%)" for h in range(180, 241, 15)],
            # Complementary
            [f"hsl({h}, 80%, 60%)" for h in range(0, 361, 180)],
            # Monochromatic
            [f"hsl(210, 80%, {l}%)" for l in range(30, 91, 15)],
            # Rainbow
            [f"hsl({h}, 80%, 60%)" for h in range(0, 361, 60)],
        ]
        colors = color_schemes[color_scheme]

        # Add gradients if needed
        if use_gradients:
            svg_content += """"""
            for i, color in enumerate(colors[:-1]):
                svg_content += f""" """
            svg_content += """"""

        # Generate shapes based on prompt hash. A dedicated Random instance
        # yields the same sequence as seeding the module-level generator, but
        # does not disturb other users of ``random``.
        prompt_hash = sum(ord(c) for c in prompt)
        rng = random.Random(prompt_hash)

        for i in range(shape_complexity):
            shape_type = i % 4  # 0: circle, 1: rectangle, 2: polygon, 3: path
            x = rng.randint(0, width)
            y = rng.randint(0, height)
            size = rng.randint(20, 150)
            color_idx = i % len(colors)
            fill = f"url(#gradient{color_idx})" if use_gradients and color_idx < len(colors) - 1 else colors[color_idx]
            opacity = 0.3 + rng.random() * 0.7

            if shape_type == 0:  # Circle
                svg_content += f""""""
            elif shape_type == 1:  # Rectangle
                svg_content += f""""""
            elif shape_type == 2:  # Polygon
                points = []
                sides = rng.randint(3, 8)
                for j in range(sides):
                    angle = j * 2 * 3.14159 / sides
                    px = x + size / 2 * np.cos(angle)
                    py = y + size / 2 * np.sin(angle)
                    points.append(f"{px},{py}")
                svg_content += f""""""
            else:  # Path (curved)
                path = f"M {x} {y} "
                control_points = rng.randint(2, 5)
                for j in range(control_points):
                    cx1 = x + rng.randint(-size, size)
                    cy1 = y + rng.randint(-size, size)
                    cx2 = x + rng.randint(-size, size)
                    cy2 = y + rng.randint(-size, size)
                    ex = x + rng.randint(-size, size)
                    ey = y + rng.randint(-size, size)
                    path += f"C {cx1} {cy1}, {cx2} {cy2}, {ex} {ey} "
                svg_content += f""""""

        # Add text elements based on the prompt
        words = re.findall(r'\b\w+\b', prompt)
        for i, word in enumerate(words[:5]):  # Use up to 5 words from the prompt
            text_x = rng.randint(width // 4, width * 3 // 4)
            text_y = rng.randint(height // 4, height * 3 // 4)
            text_size = rng.randint(10, 40)
            text_color = colors[i % len(colors)]
            text_opacity = 0.7 + rng.random() * 0.3
            text_rotation = rng.randint(-45, 45)
            svg_content += f"""{word}"""

        # Add prompt as text (escaped so "&", "<" and ">" stay valid XML)
        svg_content += f"""{escape(prompt)}"""

        # Close SVG
        svg_content += ""
        return svg_content

    def svg_to_png(self, svg_content):
        """Convert SVG content to PNG.

        Returns:
            PNG bytes. If the conversion fails for any reason, the error is
            printed and a 512x512 red image carrying the error text is
            returned instead.

        Raises:
            RuntimeError: If the conversion fails and Pillow is not installed,
                so the fallback image cannot be drawn either.
        """
        try:
            if cairosvg is None:
                raise RuntimeError("cairosvg is not available")
            png_data = cairosvg.svg2png(bytestring=svg_content.encode("utf-8"))
            return png_data
        except Exception as e:
            print(f"Error converting SVG to PNG: {e}")
            if Image is None:
                raise RuntimeError("Pillow is required to render the fallback image") from e

            # Create a simple error image
            from PIL import ImageDraw

            image = Image.new("RGB", (512, 512), color="#ff0000")
            draw = ImageDraw.Draw(image)
            draw.text((256, 256), f"Error: {str(e)}", fill="white", anchor="mm")

            # Convert PIL Image to PNG data
            buffer = io.BytesIO()
            image.save(buffer, format="PNG")
            return buffer.getvalue()

    def __call__(self, prompt):
        """Generate an SVG from a text prompt and convert to PNG.

        Returns:
            A dict with keys "svg" (str), "svg_base64" (str), "png_base64"
            (str) and "image" (a PIL image opened from the PNG bytes).

        Raises:
            RuntimeError: If Pillow is not installed.
        """
        if Image is None:
            raise RuntimeError("Pillow is required to build the image response")

        svg_content = self.generate_svg(prompt)
        png_data = self.svg_to_png(svg_content)

        # Create a PIL Image from the PNG data
        image = Image.open(io.BytesIO(png_data))

        # Create the response
        response = {
            "svg": svg_content,
            "svg_base64": base64.b64encode(svg_content.encode("utf-8")).decode("utf-8"),
            "png_base64": base64.b64encode(png_data).decode("utf-8"),
            "image": image,
        }
        return response