diffsketcher / simplified_diffsketcher.py
jree423's picture
Update: Add simplified model implementation
4eb5b6e verified
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Simplified DiffSketcher implementation for Hugging Face Inference API.
This version doesn't rely on cloning the repository at runtime.
"""
import os
import io
import base64
import torch
import numpy as np
from PIL import Image
import cairosvg
import random
from pathlib import Path
class SimplifiedDiffSketcher:
    """Procedural stand-in for DiffSketcher that draws a parametric car.

    No stroke optimization is performed. A feature vector (the normalized
    CLIP text embedding of the prompt when CLIP can be loaded, otherwise
    random Gaussian features) picks the car's hue, size and body style,
    and the prompt itself is drawn as a caption.
    """

    def __init__(self, model_dir):
        """Initialize the simplified DiffSketcher model.

        Args:
            model_dir: Stored as ``self.model_dir``; not otherwise used here.
        """
        self.model_dir = model_dir
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Initializing simplified DiffSketcher on device: {self.device}")
        # CLIP is optional: any failure (package missing, download error, ...)
        # only switches generate_svg() over to random features.
        try:
            import clip
            # The OpenAI `clip` package spells this checkpoint "ViT-B/32";
            # the dashed "ViT-B-32" is open_clip's name and makes load() raise.
            self.clip_model, _ = clip.load("ViT-B/32", device=self.device)
            self.clip_available = True
            print("CLIP model loaded successfully")
        except Exception as e:
            print(f"Error loading CLIP model: {e}")
            self.clip_available = False

    def generate_svg(self, prompt, num_paths=20, width=512, height=512):
        """Generate an SVG document for a text prompt.

        Args:
            prompt: Text prompt; also rendered as a caption in the image.
            num_paths: Forwarded to ``_generate_car_svg``, which ignores it.
            width: Canvas width in pixels.
            height: Canvas height in pixels.

        Returns:
            The SVG markup as a ``str``.
        """
        print(f"Generating SVG for prompt: {prompt}")
        if self.clip_available:
            try:
                import clip
                with torch.no_grad():
                    # truncate=True: prompts longer than CLIP's 77-token
                    # context would otherwise raise and lose the CLIP features.
                    text = clip.tokenize([prompt], truncate=True).to(self.device)
                    text_features = self.clip_model.encode_text(text)
                # .float(): on CUDA the CLIP encoder returns float16.
                text_features = text_features.float().cpu().numpy()[0]
                norm = np.linalg.norm(text_features)
                if norm > 0:  # avoid producing NaNs from a zero vector
                    text_features = text_features / norm
            except Exception as e:
                print(f"Error encoding prompt with CLIP: {e}")
                text_features = np.random.randn(512)  # Random features as fallback
        else:
            # Generate random features if CLIP is not available
            text_features = np.random.randn(512)
        return self._generate_car_svg(prompt, text_features, num_paths, width, height)

    @staticmethod
    def _xml_text(value):
        """Return ``str(value)`` made safe for use as XML character data.

        Drops code points XML 1.0 forbids (most C0 controls, lone surrogates,
        U+FFFE/U+FFFF) and escapes ``&``, ``<`` and ``>``.
        """
        from xml.sax.saxutils import escape

        cleaned = "".join(
            ch
            for ch in str(value)
            if ch in "\t\n\r"
            or "\x20" <= ch <= "\ud7ff"
            or "\ue000" <= ch <= "\ufffd"
            or ch >= "\U00010000"
        )
        return escape(cleaned)

    def _generate_car_svg(self, prompt, features, num_paths=20, width=512, height=512):
        """Build SVG markup for a car whose look is derived from ``features``.

        Args:
            prompt: Caption text drawn near the bottom of the canvas.
            features: 1-D sequence of at least 3 numbers. Entry 0 sets the
                hue, entry 1 the size, entry 2 the body style (sedan, SUV or
                sports car); up to the first 10 entries place detail dots.
            num_paths: Accepted for interface compatibility; not used.
            width: Canvas width in pixels.
            height: Canvas height in pixels.

        Returns:
            The SVG document as a ``str``.
        """
        # Plain float64 array: CLIP may hand over float16, callers may pass lists.
        features = np.asarray(features, dtype=np.float64).ravel()

        svg_content = f"""<svg width="{width}" height="{height}" xmlns="http://www.w3.org/2000/svg">
<rect width="100%" height="100%" fill="#f8f8f8"/>
"""
        # Use the features to determine car properties
        car_color_hue = int((features[0] + 1) * 180) % 360  # Map to 0-360 hue
        # Clamp the scale: the unnormalized random fallback features often
        # exceed |x| = 2, which would give a negative-sized (invalid SVG) or
        # off-canvas car.
        car_size = min(max(0.6 + 0.2 * features[1], 0.2), 1.2)
        car_style = int(abs(features[2] * 3)) % 3  # 0: sedan, 1: SUV, 2: sports car

        # Calculate car dimensions
        car_width = int(width * 0.7 * car_size)
        car_height = int(height * 0.3 * car_size)
        car_x = (width - car_width) // 2
        car_y = height // 2

        if car_style == 0:  # Sedan
            # Car body (rounded rectangle)
            svg_content += f"""<rect x="{car_x}" y="{car_y}" width="{car_width}" height="{car_height}"
            rx="20" ry="20" fill="hsl({car_color_hue}, 80%, 50%)" stroke="black" stroke-width="2" />"""
            # Windshield
            windshield_width = car_width * 0.7
            windshield_height = car_height * 0.5
            windshield_x = car_x + (car_width - windshield_width) // 2
            windshield_y = car_y - windshield_height * 0.3
            svg_content += f"""<rect x="{windshield_x}" y="{windshield_y}" width="{windshield_width}" height="{windshield_height}"
            rx="10" ry="10" fill="#a8d8ff" stroke="black" stroke-width="1" />"""
            # Wheels
            wheel_radius = car_height * 0.4
            wheel_y = car_y + car_height * 0.8
            svg_content += f"""<circle cx="{car_x + car_width * 0.2}" cy="{wheel_y}" r="{wheel_radius}" fill="black" />"""
            svg_content += f"""<circle cx="{car_x + car_width * 0.8}" cy="{wheel_y}" r="{wheel_radius}" fill="black" />"""
            svg_content += f"""<circle cx="{car_x + car_width * 0.2}" cy="{wheel_y}" r="{wheel_radius * 0.6}" fill="#444" />"""
            svg_content += f"""<circle cx="{car_x + car_width * 0.8}" cy="{wheel_y}" r="{wheel_radius * 0.6}" fill="#444" />"""
        elif car_style == 1:  # SUV
            # Car body (taller rectangle)
            svg_content += f"""<rect x="{car_x}" y="{car_y - car_height * 0.3}" width="{car_width}" height="{car_height * 1.3}"
            rx="15" ry="15" fill="hsl({car_color_hue}, 80%, 50%)" stroke="black" stroke-width="2" />"""
            # Windshield
            windshield_width = car_width * 0.6
            windshield_height = car_height * 0.6
            windshield_x = car_x + (car_width - windshield_width) // 2
            windshield_y = car_y - car_height * 0.2
            svg_content += f"""<rect x="{windshield_x}" y="{windshield_y}" width="{windshield_width}" height="{windshield_height}"
            rx="8" ry="8" fill="#a8d8ff" stroke="black" stroke-width="1" />"""
            # Wheels (larger)
            wheel_radius = car_height * 0.45
            wheel_y = car_y + car_height * 0.7
            svg_content += f"""<circle cx="{car_x + car_width * 0.2}" cy="{wheel_y}" r="{wheel_radius}" fill="black" />"""
            svg_content += f"""<circle cx="{car_x + car_width * 0.8}" cy="{wheel_y}" r="{wheel_radius}" fill="black" />"""
            svg_content += f"""<circle cx="{car_x + car_width * 0.2}" cy="{wheel_y}" r="{wheel_radius * 0.6}" fill="#444" />"""
            svg_content += f"""<circle cx="{car_x + car_width * 0.8}" cy="{wheel_y}" r="{wheel_radius * 0.6}" fill="#444" />"""
        else:  # Sports car
            # Car body (low, sleek shape)
            svg_content += f"""<path d="M {car_x} {car_y + car_height * 0.5}
            C {car_x + car_width * 0.1} {car_y - car_height * 0.2},
            {car_x + car_width * 0.3} {car_y - car_height * 0.3},
            {car_x + car_width * 0.5} {car_y - car_height * 0.2}
            S {car_x + car_width * 0.9} {car_y},
            {car_x + car_width} {car_y + car_height * 0.3}
            L {car_x + car_width} {car_y + car_height * 0.7}
            C {car_x + car_width * 0.9} {car_y + car_height},
            {car_x + car_width * 0.1} {car_y + car_height},
            {car_x} {car_y + car_height * 0.7} Z"
            fill="hsl({car_color_hue}, 90%, 45%)" stroke="black" stroke-width="2" />"""
            # Windshield
            windshield_width = car_width * 0.4
            windshield_x = car_x + car_width * 0.3
            windshield_y = car_y - car_height * 0.1
            svg_content += f"""<path d="M {windshield_x} {windshield_y}
            C {windshield_x + windshield_width * 0.1} {windshield_y - car_height * 0.15},
            {windshield_x + windshield_width * 0.9} {windshield_y - car_height * 0.15},
            {windshield_x + windshield_width} {windshield_y} Z"
            fill="#a8d8ff" stroke="black" stroke-width="1" />"""
            # Wheels (low profile)
            wheel_radius = car_height * 0.35
            wheel_y = car_y + car_height * 0.7
            svg_content += f"""<ellipse cx="{car_x + car_width * 0.2}" cy="{wheel_y}" rx="{wheel_radius * 1.2}" ry="{wheel_radius}" fill="black" />"""
            svg_content += f"""<ellipse cx="{car_x + car_width * 0.8}" cy="{wheel_y}" rx="{wheel_radius * 1.2}" ry="{wheel_radius}" fill="black" />"""
            svg_content += f"""<ellipse cx="{car_x + car_width * 0.2}" cy="{wheel_y}" rx="{wheel_radius * 0.7}" ry="{wheel_radius * 0.6}" fill="#444" />"""
            svg_content += f"""<ellipse cx="{car_x + car_width * 0.8}" cy="{wheel_y}" rx="{wheel_radius * 0.7}" ry="{wheel_radius * 0.6}" fill="#444" />"""

        # Add headlights
        headlight_radius = car_width * 0.05
        headlight_y = car_y + car_height * 0.3
        svg_content += f"""<circle cx="{car_x + car_width * 0.1}" cy="{headlight_y}" r="{headlight_radius}" fill="yellow" stroke="black" stroke-width="1" />"""
        svg_content += f"""<circle cx="{car_x + car_width * 0.9}" cy="{headlight_y}" r="{headlight_radius}" fill="yellow" stroke="black" stroke-width="1" />"""

        # Add translucent detail dots driven by the first (up to) 10 features
        for i in range(min(10, len(features))):
            feature_val = features[i]
            x = car_x + car_width * ((i / 10) * 0.8 + 0.1)
            y = car_y + car_height * ((feature_val + 1) / 4)
            size = car_width * 0.03 * abs(feature_val)
            svg_content += f"""<circle cx="{x}" cy="{y}" r="{size}" fill="rgba(0,0,0,0.2)" />"""

        # Add prompt as text; escaped because it is untrusted user input and
        # raw '<' / '&' would make the document malformed.
        caption = self._xml_text(prompt)
        svg_content += f"""<text x="{width/2}" y="{height - 20}" font-family="Arial" font-size="12" text-anchor="middle">{caption}</text>"""
        # Close SVG
        svg_content += "</svg>"
        return svg_content

    def svg_to_png(self, svg_content):
        """Rasterize SVG markup to PNG bytes with cairosvg.

        On any conversion error the error is printed and a 512x512 red PNG
        (drawn with Pillow) containing the error message is returned instead
        of raising.
        """
        try:
            png_data = cairosvg.svg2png(bytestring=svg_content.encode("utf-8"))
            return png_data
        except Exception as e:
            print(f"Error converting SVG to PNG: {e}")
            # Create a simple error image
            from PIL import ImageDraw
            image = Image.new("RGB", (512, 512), color="#ff0000")
            draw = ImageDraw.Draw(image)
            draw.text((256, 256), f"Error: {str(e)}", fill="white", anchor="mm")
            # Convert PIL Image to PNG data
            buffer = io.BytesIO()
            image.save(buffer, format="PNG")
            return buffer.getvalue()

    def __call__(self, prompt):
        """Generate an SVG for ``prompt`` and rasterize it to PNG.

        Returns:
            dict with ``"svg"`` (markup str), ``"svg_base64"`` and
            ``"png_base64"`` (base64-encoded str), and ``"image"`` (the PNG
            bytes opened as a PIL Image).
        """
        svg_content = self.generate_svg(prompt)
        png_data = self.svg_to_png(svg_content)
        image = Image.open(io.BytesIO(png_data))
        response = {
            "svg": svg_content,
            "svg_base64": base64.b64encode(svg_content.encode("utf-8")).decode("utf-8"),
            "png_base64": base64.b64encode(png_data).decode("utf-8"),
            "image": image,
        }
        return response