#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Simplified DiffSketcher implementation for Hugging Face Inference API.

This version doesn't rely on cloning the repository at runtime.
"""

import base64
import colorsys
import io
import os
import random
from pathlib import Path
from xml.sax.saxutils import escape

import cairosvg
import numpy as np
import torch
from PIL import Image, ImageDraw


class SimplifiedDiffSketcher:
    """Generate a simple car drawing (SVG plus PNG) from a text prompt.

    The prompt is embedded with CLIP when the ``clip`` package imports and the
    model loads; otherwise a random feature vector is used. The first three
    feature components select the car's hue, size and body style.
    """

    def __init__(self, model_dir):
        """Initialize the simplified DiffSketcher model.

        Args:
            model_dir: Model directory. It is stored on the instance; no
                method of this class reads it.
        """
        self.model_dir = model_dir
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Initializing simplified DiffSketcher on device: {self.device}")

        # CLIP is optional. Define both attributes up front so they exist
        # whether or not loading succeeds.
        self.clip_model = None
        self.clip_available = False
        try:
            import clip

            # The OpenAI CLIP checkpoint is registered as "ViT-B/32"; the
            # dashed spelling is not a known model name and makes load() raise.
            self.clip_model, _ = clip.load("ViT-B/32", device=self.device)
            self.clip_available = True
            print("CLIP model loaded successfully")
        except Exception as e:
            # Best effort: fall back to random features instead of failing.
            print(f"Error loading CLIP model: {e}")

    def _encode_prompt(self, prompt):
        """Return a 1-D feature vector for ``prompt``.

        Uses the unit-normalized CLIP text embedding when CLIP is available
        and encoding succeeds; otherwise returns 512 standard-normal samples.
        """
        if self.clip_available:
            try:
                import clip

                with torch.no_grad():
                    tokens = clip.tokenize([prompt]).to(self.device)
                    encoded = self.clip_model.encode_text(tokens)
                features = encoded.cpu().numpy()[0]
                # Normalize features
                return features / np.linalg.norm(features)
            except Exception as e:
                print(f"Error encoding prompt with CLIP: {e}")
        # Random features as fallback
        return np.random.randn(512)

    def generate_svg(self, prompt, num_paths=20, width=512, height=512):
        """Generate an SVG from a text prompt.

        Args:
            prompt: Text describing the drawing; also rendered as a caption.
            num_paths: Forwarded to ``_generate_car_svg`` (unused there).
            width: Canvas width in SVG user units.
            height: Canvas height in SVG user units.

        Returns:
            The SVG document as a string.
        """
        print(f"Generating SVG for prompt: {prompt}")
        text_features = self._encode_prompt(prompt)
        # Generate a car-like SVG based on the prompt
        return self._generate_car_svg(prompt, text_features, num_paths, width, height)

    @staticmethod
    def _circle(cx, cy, radius, fill, extra=""):
        """Return one SVG ``<circle>`` element as a string."""
        return (
            f'<circle cx="{cx:.2f}" cy="{cy:.2f}" r="{radius:.2f}" '
            f'fill="{fill}"{extra}/>'
        )

    def _generate_car_svg(self, prompt, features, num_paths=20, width=512, height=512):
        """Generate a car-like SVG based on the prompt and features.

        NOTE(review): in the reviewed copy of this file every SVG string
        literal was empty, so this method returned only the prompt text. The
        markup below is rebuilt from the inline comments and the geometry that
        was already being computed; confirm colours and proportions against
        the intended design.

        Args:
            prompt: Caption text; XML-escaped before being embedded.
            features: Indexable numeric sequence with at least three items.
                ``features[0]`` picks the hue, ``features[1]`` the size and
                ``features[2]`` the body style; up to the first ten values
                place decorative dots.
            num_paths: Accepted for interface compatibility; not used.
            width: Canvas width in SVG user units.
            height: Canvas height in SVG user units.

        Returns:
            The SVG document as a string.
        """
        # Use the features to determine car properties
        car_color_hue = int((features[0] + 1) * 180) % 360  # Map to 0-360 hue
        # Size variation. Clamped because the random fallback features are not
        # normalized and could otherwise give a zero or negative car size.
        car_size = min(max(0.6 + 0.2 * float(features[1]), 0.2), 1.2)
        car_style = int(abs(features[2] * 3)) % 3  # 0: sedan, 1: SUV, 2: sports car

        # Emit the body colour as hex: it is understood by every SVG renderer,
        # whereas hsl() support varies between rasterizers.
        red, green, blue = colorsys.hls_to_rgb(car_color_hue / 360, 0.5, 0.7)
        body_fill = (
            f"#{round(red * 255):02x}{round(green * 255):02x}{round(blue * 255):02x}"
        )

        # Calculate car dimensions
        car_width = int(width * 0.7 * car_size)
        car_height = int(height * 0.3 * car_size)
        car_x = (width - car_width) // 2
        car_y = height // 2

        # Start SVG
        parts = [
            f'<svg xmlns="http://www.w3.org/2000/svg" width="{width}" '
            f'height="{height}" viewBox="0 0 {width} {height}">',
            f'<rect width="{width}" height="{height}" fill="white"/>',
        ]

        # Generate car body based on style
        if car_style == 0:  # Sedan
            # Car body (rounded rectangle)
            parts.append(
                f'<rect x="{car_x:.2f}" y="{car_y:.2f}" width="{car_width:.2f}" '
                f'height="{car_height:.2f}" rx="{car_height * 0.2:.2f}" '
                f'fill="{body_fill}" stroke="black" stroke-width="2"/>'
            )
            # Windshield
            windshield_width = car_width * 0.7
            windshield_height = car_height * 0.5
            windshield_x = car_x + (car_width - windshield_width) // 2
            windshield_y = car_y - windshield_height * 0.3
            parts.append(
                f'<rect x="{windshield_x:.2f}" y="{windshield_y:.2f}" '
                f'width="{windshield_width:.2f}" height="{windshield_height:.2f}" '
                f'rx="{windshield_height * 0.2:.2f}" fill="lightblue" '
                f'stroke="black" stroke-width="1"/>'
            )
            # Wheels
            wheel_radius = car_height * 0.4
            wheel_y = car_y + car_height * 0.8
        elif car_style == 1:  # SUV
            # Car body (taller rectangle)
            parts.append(
                f'<rect x="{car_x:.2f}" y="{car_y - car_height * 0.2:.2f}" '
                f'width="{car_width:.2f}" height="{car_height * 1.2:.2f}" '
                f'rx="{car_height * 0.1:.2f}" '
                f'fill="{body_fill}" stroke="black" stroke-width="2"/>'
            )
            # Windshield
            windshield_width = car_width * 0.6
            windshield_height = car_height * 0.6
            windshield_x = car_x + (car_width - windshield_width) // 2
            windshield_y = car_y - car_height * 0.2
            parts.append(
                f'<rect x="{windshield_x:.2f}" y="{windshield_y:.2f}" '
                f'width="{windshield_width:.2f}" height="{windshield_height:.2f}" '
                f'fill="lightblue" stroke="black" stroke-width="1"/>'
            )
            # Wheels (larger)
            wheel_radius = car_height * 0.45
            wheel_y = car_y + car_height * 0.7
        else:  # Sports car
            # Car body (low, sleek shape): flat underside with one arched
            # roofline that peaks at windshield_y.
            left = car_x
            right = car_x + car_width
            bottom = car_y + car_height
            shoulder = car_y + car_height * 0.5
            parts.append(
                f'<path d="M {left:.2f},{bottom:.2f} L {left:.2f},{shoulder:.2f} '
                f'Q {car_x + car_width * 0.5:.2f},{car_y - car_height * 0.7:.2f} '
                f'{right:.2f},{shoulder:.2f} L {right:.2f},{bottom:.2f} Z" '
                f'fill="{body_fill}" stroke="black" stroke-width="2"/>'
            )
            # Windshield
            windshield_width = car_width * 0.4
            windshield_x = car_x + car_width * 0.3
            windshield_y = car_y - car_height * 0.1
            windshield_bottom = windshield_y + car_height * 0.4
            parts.append(
                f'<polygon points="'
                f'{windshield_x:.2f},{windshield_bottom:.2f} '
                f'{windshield_x + windshield_width * 0.2:.2f},{windshield_y:.2f} '
                f'{windshield_x + windshield_width * 0.8:.2f},{windshield_y:.2f} '
                f'{windshield_x + windshield_width:.2f},{windshield_bottom:.2f}" '
                f'fill="lightblue" stroke="black" stroke-width="1"/>'
            )
            # Wheels (low profile)
            wheel_radius = car_height * 0.35
            wheel_y = car_y + car_height * 0.7

        # Each wheel is a tyre plus a hub; only radius and height vary by style.
        wheel_centres = (car_x + car_width * 0.2, car_x + car_width * 0.8)
        for wheel_x in wheel_centres:
            parts.append(self._circle(wheel_x, wheel_y, wheel_radius, "black"))
        for wheel_x in wheel_centres:
            parts.append(self._circle(wheel_x, wheel_y, wheel_radius * 0.5, "gray"))

        # Add headlights
        headlight_radius = car_width * 0.05
        headlight_y = car_y + car_height * 0.3
        for light_x in (car_x + car_width * 0.05, car_x + car_width * 0.95):
            parts.append(
                self._circle(
                    light_x, headlight_y, headlight_radius, "yellow",
                    ' stroke="black" stroke-width="1"',
                )
            )

        # Add details based on features
        for i, feature_val in enumerate(features[:10]):
            x = car_x + car_width * ((i / 10) * 0.8 + 0.1)
            y = car_y + car_height * ((feature_val + 1) / 4)
            size = car_width * 0.03 * abs(feature_val)
            parts.append(self._circle(x, y, size, "white", ' opacity="0.6"'))

        # Add prompt as text. Escape it so characters such as "&" or "<" in
        # user input cannot break or inject markup.
        parts.append(
            f'<text x="{width / 2:.2f}" y="{height - 20:.2f}" text-anchor="middle" '
            f'font-family="sans-serif" font-size="16" fill="black">'
            f"{escape(str(prompt))}</text>"
        )

        # Close SVG
        parts.append("</svg>")
        return "\n".join(parts)

    def svg_to_png(self, svg_content):
        """Convert SVG content to PNG.

        Args:
            svg_content: SVG document as a string.

        Returns:
            PNG bytes. If conversion raises, a 512x512 red image carrying the
            error message is returned instead, so callers always get a PNG.
        """
        try:
            return cairosvg.svg2png(bytestring=svg_content.encode("utf-8"))
        except Exception as e:
            print(f"Error converting SVG to PNG: {e}")
            # Create a simple error image
            image = Image.new("RGB", (512, 512), color="#ff0000")
            draw = ImageDraw.Draw(image)
            draw.text((256, 256), f"Error: {str(e)}", fill="white", anchor="mm")
            # Convert PIL Image to PNG data
            buffer = io.BytesIO()
            image.save(buffer, format="PNG")
            return buffer.getvalue()

    def __call__(self, prompt):
        """Generate an SVG from a text prompt and convert to PNG.

        Args:
            prompt: Text describing the drawing.

        Returns:
            A dict with keys ``"svg"`` (SVG string), ``"svg_base64"`` and
            ``"png_base64"`` (base64-encoded text) and ``"image"`` (a PIL
            image opened from the PNG bytes).
        """
        svg_content = self.generate_svg(prompt)
        png_data = self.svg_to_png(svg_content)

        # Create a PIL Image from the PNG data
        image = Image.open(io.BytesIO(png_data))

        # Create the response
        return {
            "svg": svg_content,
            "svg_base64": base64.b64encode(svg_content.encode("utf-8")).decode("utf-8"),
            "png_base64": base64.b64encode(png_data).decode("utf-8"),
            "image": image,
        }