#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Versatile SVG Generator that creates different types of objects based on the prompt.
"""
import os
import io
import base64
import torch
import numpy as np
from PIL import Image
import cairosvg
import random
from pathlib import Path
import re
class VersatileSVGGenerator:
    """Turn a text prompt into an SVG document and a PNG rendering of it.

    The prompt is classified into one of six object types ("car",
    "landscape", "animal", "building", "face" or "abstract") by keyword
    matching, and the matching ``_generate_<type>_svg`` method produces the
    SVG markup. When the ``clip`` package can be imported and its model
    loaded, the prompt is also encoded into a normalized feature vector that
    is handed to the generator; otherwise a random vector is used instead.
    """

    # Length of the random feature vector used when CLIP cannot be used.
    _FEATURE_DIM = 512

    # (object type, keywords) pairs. Order matters: the first type with a
    # matching keyword wins.
    _OBJECT_TERMS = (
        ("car", ("car", "vehicle", "truck", "suv", "sedan", "convertible",
                 "sports car", "automobile")),
        ("landscape", ("landscape", "mountain", "forest", "beach", "ocean",
                       "sea", "lake", "river", "sunset", "sunrise", "sky")),
        ("animal", ("animal", "dog", "cat", "bird", "horse", "lion", "tiger",
                    "elephant", "bear", "fish", "pet")),
        ("building", ("building", "house", "skyscraper", "tower", "castle",
                      "mansion", "apartment", "office", "structure")),
        ("face", ("face", "portrait", "person", "man", "woman", "boy", "girl",
                  "human", "head", "smile")),
    )

    # One compiled pattern per object type. Keywords must match whole words
    # (an optional trailing "s"/"es" is accepted), so "sky" no longer fires
    # inside "skyscraper" and "car" no longer fires inside "scarf".
    _OBJECT_PATTERNS = tuple(
        (
            object_type,
            re.compile(
                r"\b(?:" + "|".join(re.escape(term) for term in terms)
                + r")(?:e?s)?\b"
            ),
        )
        for object_type, terms in _OBJECT_TERMS
    )

    def __init__(self, model_dir):
        """Store the model directory, pick a device and try to load CLIP.

        Args:
            model_dir: Path kept on the instance as ``self.model_dir``.

        A failure to import or load CLIP is reported on stdout and leaves
        ``self.clip_available`` set to ``False``; it does not raise.
        """
        self.model_dir = model_dir
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Initializing versatile SVG generator on device: {self.device}")

        # Load CLIP model if available
        self.clip_model = None
        try:
            import clip

            # "ViT-B/32" is the name the clip package registers for this
            # model; the dashed spelling "ViT-B-32" is not a known model
            # name and makes clip.load() raise.
            self.clip_model, _ = clip.load("ViT-B/32", device=self.device)
            self.clip_available = True
            print("CLIP model loaded successfully")
        except Exception as e:
            # Best effort: the generator still works with random features.
            print(f"Error loading CLIP model: {e}")
            self.clip_available = False

    def _encode_prompt(self, prompt):
        """Return a feature vector for ``prompt``.

        Uses the loaded CLIP model when ``self.clip_available`` is true and
        returns its text embedding scaled to unit L2 norm. If CLIP is not
        available, or encoding raises, a random vector of length
        ``_FEATURE_DIM`` is returned instead.
        """
        if not self.clip_available:
            # Generate random features if CLIP is not available
            return np.random.randn(self._FEATURE_DIM)

        try:
            import clip

            with torch.no_grad():
                tokens = clip.tokenize([prompt]).to(self.device)
                embedding = self.clip_model.encode_text(tokens)
            features = embedding.cpu().numpy()[0]
            # Normalize features
            return features / np.linalg.norm(features)
        except Exception as e:
            print(f"Error encoding prompt with CLIP: {e}")
            # Random features as fallback
            return np.random.randn(self._FEATURE_DIM)

    def generate_svg(self, prompt, num_paths=20, width=512, height=512):
        """Generate an SVG document for a text prompt.

        Args:
            prompt: Text describing the desired image.
            num_paths: Forwarded unchanged to the type-specific generator.
            width: Width forwarded to the generator.
            height: Height forwarded to the generator.

        Returns:
            The SVG markup produced by the generator selected for the prompt.
        """
        print(f"Generating SVG for prompt: {prompt}")
        text_features = self._encode_prompt(prompt)

        # Determine what type of object to generate based on the prompt
        object_type = self._determine_object_type(prompt)

        # Any type without a dedicated generator falls back to abstract.
        generator = getattr(
            self, f"_generate_{object_type}_svg", self._generate_abstract_svg
        )
        return generator(prompt, text_features, num_paths, width, height)

    def _determine_object_type(self, prompt):
        """Classify ``prompt`` into an object type by keyword.

        Matching is case-insensitive and on whole words (simple plurals
        included). Types are tried in the order car, landscape, animal,
        building, face; the first one with a matching keyword is returned.

        Returns:
            One of "car", "landscape", "animal", "building", "face", or
            "abstract" when no keyword matches.
        """
        lowered = prompt.lower()
        for object_type, pattern in self._OBJECT_PATTERNS:
            if pattern.search(lowered):
                return object_type
        # Default to abstract
        return "abstract"

    @staticmethod
    def _svg_canvas(width, height):
        """Return a well-formed, empty SVG document of the given size."""
        return (
            f'<svg xmlns="http://www.w3.org/2000/svg" '
            f'width="{width}" height="{height}" '
            f'viewBox="0 0 {width} {height}"></svg>'
        )

    def _generate_car_svg(self, prompt, features, num_paths=20, width=512, height=512):
        """Generate the SVG for a car prompt.

        TODO: the car-specific shape markup is not implemented here; this
        currently returns an empty canvas of the requested size, and
        ``prompt``, ``features`` and ``num_paths`` are unused.
        """
        return self._svg_canvas(width, height)

    def _generate_landscape_svg(self, prompt, features, num_paths=20, width=512, height=512):
        """Generate the SVG for a landscape prompt.

        TODO: the landscape-specific shape markup is not implemented here;
        this currently returns an empty canvas of the requested size, and
        ``prompt``, ``features`` and ``num_paths`` are unused.
        """
        return self._svg_canvas(width, height)

    def _generate_animal_svg(self, prompt, features, num_paths=20, width=512, height=512):
        """Generate the SVG for an animal prompt.

        TODO: the animal-specific shape markup is not implemented here; this
        currently returns an empty canvas of the requested size, and
        ``prompt``, ``features`` and ``num_paths`` are unused.
        """
        return self._svg_canvas(width, height)

    def _generate_building_svg(self, prompt, features, num_paths=20, width=512, height=512):
        """Generate the SVG for a building prompt.

        TODO: the building-specific shape markup is not implemented here;
        this currently returns an empty canvas of the requested size, and
        ``prompt``, ``features`` and ``num_paths`` are unused.
        """
        return self._svg_canvas(width, height)

    def _generate_face_svg(self, prompt, features, num_paths=20, width=512, height=512):
        """Generate the SVG for a face prompt.

        TODO: the face-specific shape markup is not implemented here; this
        currently returns an empty canvas of the requested size, and
        ``prompt``, ``features`` and ``num_paths`` are unused.
        """
        return self._svg_canvas(width, height)

    def _generate_abstract_svg(self, prompt, features, num_paths=20, width=512, height=512):
        """Generate the SVG used when no specific object type matched.

        TODO: the abstract shape markup is not implemented here; this
        currently returns an empty canvas of the requested size, and
        ``prompt``, ``features`` and ``num_paths`` are unused.
        """
        return self._svg_canvas(width, height)

    def svg_to_png(self, svg_content):
        """Render SVG markup to PNG bytes.

        Args:
            svg_content: SVG document as a string.

        Returns:
            PNG-encoded bytes. If rendering raises, the error is printed and
            the bytes of a 512x512 red image carrying the error text are
            returned instead.
        """
        try:
            return cairosvg.svg2png(bytestring=svg_content.encode("utf-8"))
        except Exception as e:
            print(f"Error converting SVG to PNG: {e}")
            # Create a simple error image
            from PIL import ImageDraw

            image = Image.new("RGB", (512, 512), color="#ff0000")
            draw = ImageDraw.Draw(image)
            draw.text((256, 256), f"Error: {str(e)}", fill="white", anchor="mm")

            # Convert PIL Image to PNG data
            buffer = io.BytesIO()
            image.save(buffer, format="PNG")
            return buffer.getvalue()

    def __call__(self, prompt):
        """Generate an SVG for ``prompt`` and render it to PNG.

        Returns:
            A dict with keys "svg" (markup string), "svg_base64" and
            "png_base64" (base64 text of the SVG and PNG bytes), and "image"
            (a PIL image opened from the PNG bytes).
        """
        svg_content = self.generate_svg(prompt)
        png_data = self.svg_to_png(svg_content)

        # Create a PIL Image from the PNG data
        image = Image.open(io.BytesIO(png_data))

        return {
            "svg": svg_content,
            "svg_base64": base64.b64encode(svg_content.encode("utf-8")).decode("utf-8"),
            "png_base64": base64.b64encode(png_data).decode("utf-8"),
            "image": image,
        }