#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Simplified DiffSketcher model for text-to-SVG generation.
"""

import base64
import colorsys
import io
import os
import xml.etree.ElementTree as ET

import cairosvg
import clip
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image, ImageDraw


class DiffSketcherModel:
    """Placeholder text-to-SVG generator driven by CLIP text features.

    This is not the full DiffSketcher pipeline: the prompt is encoded with
    CLIP and the leading components of the normalized embedding are mapped
    to the position, size and colour of a handful of circles.
    """

    def __init__(self, model_dir):
        """Load the CLIP model and put it in evaluation mode.

        Args:
            model_dir: Directory searched for a local ``ViT-B-32.pt``
                checkpoint. When the file is absent, CLIP is asked to load
                the model by its registry name instead.
        """
        self.model_dir = model_dir
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Load CLIP model
        self.clip_model_path = os.path.join(model_dir, "ViT-B-32.pt")
        if os.path.exists(self.clip_model_path):
            print(f"Loading CLIP model from {self.clip_model_path}")
            self.clip_model, _ = clip.load(self.clip_model_path, device=self.device)
        else:
            print(f"CLIP model not found at {self.clip_model_path}, downloading...")
            # The registry name uses a slash ("ViT-B/32"); the dashed form is
            # only the checkpoint's file name and is rejected by clip.load.
            self.clip_model, _ = clip.load("ViT-B/32", device=self.device)

        # Set model to evaluation mode
        self.clip_model.eval()

        print(f"DiffSketcher model initialized on device: {self.device}")

    @staticmethod
    def _hue_to_hex(hue_degrees):
        """Return a ``#rrggbb`` colour for a hue given in degrees."""
        red, green, blue = colorsys.hls_to_rgb((hue_degrees % 360.0) / 360.0, 0.5, 0.7)
        return "#{:02x}{:02x}{:02x}".format(
            int(round(red * 255)), int(round(green * 255)), int(round(blue * 255))
        )

    def generate_svg(self, prompt, num_paths=10, width=512, height=512):
        """Generate an SVG document from a text prompt.

        Args:
            prompt: Text to encode with CLIP; also stored in the SVG ``<desc>``.
            num_paths: Maximum number of circles to draw. The count is capped
                at the dimensionality of the CLIP text embedding.
            width: Canvas width in SVG user units.
            height: Canvas height in SVG user units.

        Returns:
            The SVG document serialized as a ``str``.
        """
        print(f"Generating SVG for prompt: {prompt}")

        # Encode the prompt with CLIP
        with torch.no_grad():
            tokens = clip.tokenize([prompt]).to(self.device)
            text_features = self.clip_model.encode_text(tokens)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        # Build the document with ElementTree so that the prompt (arbitrary
        # user text) is XML-escaped instead of being spliced into markup.
        # In a real implementation, this would use the full DiffSketcher model.
        root = ET.Element(
            "svg",
            {
                "xmlns": "http://www.w3.org/2000/svg",
                "width": str(width),
                "height": str(height),
                "viewBox": f"0 0 {width} {height}",
            },
        )
        ET.SubElement(root, "title").text = "Generated by DiffSketcher"
        ET.SubElement(root, "desc").text = str(prompt)

        # One circle per leading embedding component; fetch them in a single
        # transfer rather than one .item() call per element.
        count = max(0, min(num_paths, text_features.shape[1]))
        feature_values = text_features[0, :count].float().cpu().tolist()

        for index, feature_val in enumerate(feature_values):
            # Components of a unit-norm vector lie in [-1, 1]: map that range
            # onto the canvas width and onto a full turn of the hue wheel.
            x = (feature_val + 1) * width / 2
            y = ((index / num_paths) * 0.8 + 0.1) * height
            radius = abs(feature_val) * 50 + 10
            hue = (feature_val + 1) * 180

            # Hex is used rather than hsl() so the colour does not depend on
            # the renderer's support for CSS colour functions.
            ET.SubElement(
                root,
                "circle",
                {
                    "cx": f"{x:.2f}",
                    "cy": f"{y:.2f}",
                    "r": f"{radius:.2f}",
                    "fill": self._hue_to_hex(hue),
                },
            )

        return ET.tostring(root, encoding="unicode")

    def svg_to_png(self, svg_content):
        """Convert SVG content to PNG bytes.

        Args:
            svg_content: SVG document as a ``str``.

        Returns:
            PNG-encoded bytes. If the conversion raises, the error is printed
            and a 512x512 red image carrying the error text is returned
            instead, so this method itself does not propagate the failure.
        """
        try:
            return cairosvg.svg2png(bytestring=svg_content.encode("utf-8"))
        except Exception as e:  # deliberate best-effort: always return an image
            print(f"Error converting SVG to PNG: {e}")

            # Create a simple error image
            image = Image.new("RGB", (512, 512), color="#ff0000")
            draw = ImageDraw.Draw(image)
            draw.text((256, 256), f"Error: {str(e)}", fill="white", anchor="mm")

            # Convert PIL Image to PNG data
            buffer = io.BytesIO()
            image.save(buffer, format="PNG")
            return buffer.getvalue()

    def __call__(self, prompt):
        """Generate an SVG from a text prompt and render it to PNG.

        Args:
            prompt: Text prompt passed to :meth:`generate_svg`.

        Returns:
            A dict with keys ``"svg"`` (SVG text), ``"svg_base64"`` and
            ``"png_base64"`` (base64-encoded strings) and ``"image"`` (a PIL
            image opened from the PNG bytes).
        """
        svg_content = self.generate_svg(prompt)
        png_data = self.svg_to_png(svg_content)

        # Create a PIL Image from the PNG data
        image = Image.open(io.BytesIO(png_data))

        # Create the response
        response = {
            "svg": svg_content,
            "svg_base64": base64.b64encode(svg_content.encode("utf-8")).decode("utf-8"),
            "png_base64": base64.b64encode(png_data).decode("utf-8"),
            "image": image,
        }

        return response