|
|
|
|
|
|
|
""" |
|
Simplified DiffSketcher implementation for Hugging Face Inference API. |
|
This version doesn't rely on cloning the repository at runtime. |
|
""" |
|
|
|
import os |
|
import io |
|
import base64 |
|
import torch |
|
import numpy as np |
|
from PIL import Image |
|
import cairosvg |
|
import random |
|
from pathlib import Path |
|
|
|
class SimplifiedDiffSketcher:
    """Draw a stylised car as SVG from a text prompt.

    The prompt is encoded with CLIP when the ``clip`` package and its weights
    can be loaded; otherwise random numbers are used.  The first few feature
    values select the body colour (hue), overall size and one of three body
    styles.  The result can additionally be rasterised to PNG via cairosvg.
    """

    # Lower bound for the size factor so that extreme feature values (possible
    # with the unnormalised random fallback) never produce zero or negative
    # width/height attributes in the SVG.
    _MIN_CAR_SIZE = 0.1

    def __init__(self, model_dir):
        """Store the model directory, pick a device and try to load CLIP.

        Args:
            model_dir: Value supplied by the caller; it is only stored on the
                instance here.
        """
        self.model_dir = model_dir
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Initializing simplified DiffSketcher on device: {self.device}")

        # CLIP is optional: any failure (package missing, weights cannot be
        # fetched, ...) downgrades the instance to random features.
        self.clip_model = None
        try:
            import clip

            # OpenAI CLIP registers this checkpoint as "ViT-B/32" (with a
            # slash); the dashed spelling is rejected by clip.load().
            self.clip_model, _ = clip.load("ViT-B/32", device=self.device)
            self.clip_available = True
            print("CLIP model loaded successfully")
        except Exception as e:
            print(f"Error loading CLIP model: {e}")
            self.clip_available = False

    def generate_svg(self, prompt, num_paths=20, width=512, height=512):
        """Generate an SVG document for ``prompt``.

        Args:
            prompt: Text to encode and to print as the caption.
            num_paths: Forwarded to ``_generate_car_svg`` (currently unused
                by the drawing code).
            width: SVG canvas width.
            height: SVG canvas height.

        Returns:
            The SVG markup as a string.
        """
        print(f"Generating SVG for prompt: {prompt}")

        text_features = None
        if self.clip_available:
            try:
                import clip

                with torch.no_grad():
                    text = clip.tokenize([prompt]).to(self.device)
                    text_features = self.clip_model.encode_text(text)
                    text_features = text_features.cpu().numpy()[0]
                # Scale to unit length so every component lies in [-1, 1].
                text_features = text_features / np.linalg.norm(text_features)
            except Exception as e:
                print(f"Error encoding prompt with CLIP: {e}")
                text_features = None

        if text_features is None:
            # No usable CLIP embedding: fall back to random features.
            text_features = np.random.randn(512)

        return self._generate_car_svg(prompt, text_features, num_paths, width, height)

    def _round_wheels(self, left_cx, right_cx, cy, radius):
        """Return the four circles (two tyres, then two hubs) for round wheels."""
        hub_radius = radius * 0.6
        return [
            f'<circle cx="{left_cx}" cy="{cy}" r="{radius}" fill="black" />',
            f'<circle cx="{right_cx}" cy="{cy}" r="{radius}" fill="black" />',
            f'<circle cx="{left_cx}" cy="{cy}" r="{hub_radius}" fill="#444" />',
            f'<circle cx="{right_cx}" cy="{cy}" r="{hub_radius}" fill="#444" />',
        ]

    def _rounded_car(self, car_x, car_y, car_width, car_height, hue):
        """Style 0: rounded-rectangle body, rectangular windshield, round wheels."""
        windshield_width = car_width * 0.7
        windshield_height = car_height * 0.5
        windshield_x = car_x + (car_width - windshield_width) // 2
        windshield_y = car_y - windshield_height * 0.3
        wheel_radius = car_height * 0.4
        wheel_y = car_y + car_height * 0.8
        return [
            f'<rect x="{car_x}" y="{car_y}" width="{car_width}" height="{car_height}" '
            f'rx="20" ry="20" fill="hsl({hue}, 80%, 50%)" stroke="black" stroke-width="2" />',
            f'<rect x="{windshield_x}" y="{windshield_y}" '
            f'width="{windshield_width}" height="{windshield_height}" '
            f'rx="10" ry="10" fill="#a8d8ff" stroke="black" stroke-width="1" />',
            *self._round_wheels(
                car_x + car_width * 0.2, car_x + car_width * 0.8, wheel_y, wheel_radius
            ),
        ]

    def _tall_car(self, car_x, car_y, car_width, car_height, hue):
        """Style 1: taller body that starts above the centre line, larger wheels."""
        windshield_width = car_width * 0.6
        windshield_height = car_height * 0.6
        windshield_x = car_x + (car_width - windshield_width) // 2
        windshield_y = car_y - car_height * 0.2
        wheel_radius = car_height * 0.45
        wheel_y = car_y + car_height * 0.7
        return [
            f'<rect x="{car_x}" y="{car_y - car_height * 0.3}" '
            f'width="{car_width}" height="{car_height * 1.3}" '
            f'rx="15" ry="15" fill="hsl({hue}, 80%, 50%)" stroke="black" stroke-width="2" />',
            f'<rect x="{windshield_x}" y="{windshield_y}" '
            f'width="{windshield_width}" height="{windshield_height}" '
            f'rx="8" ry="8" fill="#a8d8ff" stroke="black" stroke-width="1" />',
            *self._round_wheels(
                car_x + car_width * 0.2, car_x + car_width * 0.8, wheel_y, wheel_radius
            ),
        ]

    def _curved_car(self, car_x, car_y, car_width, car_height, hue):
        """Style 2: Bezier-curve body and windshield with elliptical wheels."""
        body_path = (
            f"M {car_x} {car_y + car_height * 0.5} "
            f"C {car_x + car_width * 0.1} {car_y - car_height * 0.2}, "
            f"{car_x + car_width * 0.3} {car_y - car_height * 0.3}, "
            f"{car_x + car_width * 0.5} {car_y - car_height * 0.2} "
            f"S {car_x + car_width * 0.9} {car_y}, "
            f"{car_x + car_width} {car_y + car_height * 0.3} "
            f"L {car_x + car_width} {car_y + car_height * 0.7} "
            f"C {car_x + car_width * 0.9} {car_y + car_height}, "
            f"{car_x + car_width * 0.1} {car_y + car_height}, "
            f"{car_x} {car_y + car_height * 0.7} Z"
        )

        windshield_width = car_width * 0.4
        windshield_x = car_x + car_width * 0.3
        windshield_y = car_y - car_height * 0.1
        windshield_path = (
            f"M {windshield_x} {windshield_y} "
            f"C {windshield_x + windshield_width * 0.1} {windshield_y - car_height * 0.15}, "
            f"{windshield_x + windshield_width * 0.9} {windshield_y - car_height * 0.15}, "
            f"{windshield_x + windshield_width} {windshield_y} Z"
        )

        wheel_radius = car_height * 0.35
        wheel_y = car_y + car_height * 0.7
        left_cx = car_x + car_width * 0.2
        right_cx = car_x + car_width * 0.8
        return [
            f'<path d="{body_path}" '
            f'fill="hsl({hue}, 90%, 45%)" stroke="black" stroke-width="2" />',
            f'<path d="{windshield_path}" '
            f'fill="#a8d8ff" stroke="black" stroke-width="1" />',
            f'<ellipse cx="{left_cx}" cy="{wheel_y}" '
            f'rx="{wheel_radius * 1.2}" ry="{wheel_radius}" fill="black" />',
            f'<ellipse cx="{right_cx}" cy="{wheel_y}" '
            f'rx="{wheel_radius * 1.2}" ry="{wheel_radius}" fill="black" />',
            f'<ellipse cx="{left_cx}" cy="{wheel_y}" '
            f'rx="{wheel_radius * 0.7}" ry="{wheel_radius * 0.6}" fill="#444" />',
            f'<ellipse cx="{right_cx}" cy="{wheel_y}" '
            f'rx="{wheel_radius * 0.7}" ry="{wheel_radius * 0.6}" fill="#444" />',
        ]

    def _generate_car_svg(self, prompt, features, num_paths=20, width=512, height=512):
        """Generate a car-like SVG based on the prompt and features.

        Args:
            prompt: Caption text; XML-escaped before it is embedded.
            features: Indexable sequence with at least three numeric values.
                ``features[0]`` picks the hue, ``features[1]`` the size and
                ``features[2]`` the body style; up to the first ten values
                also place small translucent dots on the car.
            num_paths: Accepted for interface compatibility; not used.
            width: SVG canvas width.
            height: SVG canvas height.

        Returns:
            The complete SVG document as a string.
        """
        # Local import, in line with the other function-scope imports here.
        from xml.sax.saxutils import escape

        parts = [
            f'<svg width="{width}" height="{height}" xmlns="http://www.w3.org/2000/svg">\n',
            '<rect width="100%" height="100%" fill="#f8f8f8"/>\n',
        ]

        # Map the leading feature values onto colour, size and style.
        car_color_hue = int((features[0] + 1) * 180) % 360
        car_size = max(self._MIN_CAR_SIZE, 0.6 + 0.2 * features[1])
        car_style = int(abs(features[2] * 3)) % 3

        # Car bounding box: horizontally centred, top edge on the mid line.
        car_width = int(width * 0.7 * car_size)
        car_height = int(height * 0.3 * car_size)
        car_x = (width - car_width) // 2
        car_y = height // 2

        if car_style == 0:
            draw_style = self._rounded_car
        elif car_style == 1:
            draw_style = self._tall_car
        else:
            draw_style = self._curved_car
        parts.extend(draw_style(car_x, car_y, car_width, car_height, car_color_hue))

        # Headlights are shared by all three styles.
        headlight_radius = car_width * 0.05
        headlight_y = car_y + car_height * 0.3
        for fraction in (0.1, 0.9):
            parts.append(
                f'<circle cx="{car_x + car_width * fraction}" cy="{headlight_y}" '
                f'r="{headlight_radius}" fill="yellow" stroke="black" stroke-width="1" />'
            )

        # Decorative dots: one per feature value, at most ten, spread across
        # the car width; the value controls vertical offset and radius.
        for i, feature_val in enumerate(features[:10]):
            x = car_x + car_width * ((i / 10) * 0.8 + 0.1)
            y = car_y + car_height * ((feature_val + 1) / 4)
            size = car_width * 0.03 * abs(feature_val)
            parts.append(f'<circle cx="{x}" cy="{y}" r="{size}" fill="rgba(0,0,0,0.2)" />')

        # The prompt is caller-supplied text: escape &, < and > so it cannot
        # break the XML or inject markup.
        caption = escape(str(prompt))
        parts.append(
            f'<text x="{width/2}" y="{height - 20}" font-family="Arial" '
            f'font-size="12" text-anchor="middle">{caption}</text>'
        )

        parts.append("</svg>")
        return "".join(parts)

    def svg_to_png(self, svg_content):
        """Convert SVG content to PNG.

        Args:
            svg_content: SVG markup as a string.

        Returns:
            PNG data as bytes.  If cairosvg raises, a 512x512 red placeholder
            image with the error message drawn on it is returned instead.
        """
        try:
            png_data = cairosvg.svg2png(bytestring=svg_content.encode("utf-8"))
            return png_data
        except Exception as e:
            print(f"Error converting SVG to PNG: {e}")

            # Build a visible placeholder so callers still receive valid PNG data.
            image = Image.new("RGB", (512, 512), color="#ff0000")
            from PIL import ImageDraw

            draw = ImageDraw.Draw(image)
            draw.text((256, 256), f"Error: {str(e)}", fill="white", anchor="mm")

            buffer = io.BytesIO()
            image.save(buffer, format="PNG")
            return buffer.getvalue()

    def __call__(self, prompt):
        """Generate an SVG from a text prompt and convert to PNG.

        Args:
            prompt: Text prompt to draw.

        Returns:
            A dict with the keys ``"svg"`` (markup), ``"svg_base64"``,
            ``"png_base64"`` (both base64 text) and ``"image"`` (the PNG
            opened as a PIL image).
        """
        svg_content = self.generate_svg(prompt)
        png_data = self.svg_to_png(svg_content)

        image = Image.open(io.BytesIO(png_data))

        response = {
            "svg": svg_content,
            "svg_base64": base64.b64encode(svg_content.encode("utf-8")).decode("utf-8"),
            "png_base64": base64.b64encode(png_data).decode("utf-8"),
            "image": image,
        }

        return response