#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
DiffSketcher endpoint implementation for Hugging Face.
"""

import base64
import html
import importlib
import io
import json
import os
import shutil
import subprocess
import sys
import tempfile
from pathlib import Path

import cairosvg
import numpy as np
import torch
from PIL import Image


# Source of the stub ``diffvg`` package written at install time. Every render
# function returns an all-zero (height, width, 4) float32 tensor and every
# I/O / settings function is a no-op, so DiffSketcher imports resolve without
# the real (compiled) diffvg library.
_MOCK_DIFFVG_SOURCE = """
# Mock diffvg module
import torch

def render(scene, width, height, samples=2, seed=None):
    return torch.zeros((height, width, 4), dtype=torch.float32)

def render_wrt_shapes(scene, shapes, width, height, samples=2, seed=None):
    return torch.zeros((height, width, 4), dtype=torch.float32)

def render_wrt_camera(scene, camera, width, height, samples=2, seed=None):
    return torch.zeros((height, width, 4), dtype=torch.float32)

def imwrite(img, filename, gamma=2.2):
    pass

def save_svg(scene, filename):
    pass

def set_use_gpu(use_gpu):
    pass

def set_print_timing(print_timing):
    pass
"""


class DiffSketcherEndpoint:
    """Text-to-sketch endpoint wrapping the DiffSketcher repository.

    On construction the upstream repository is cloned into a private temporary
    directory, Python dependencies are pip-installed, and a stub ``diffvg``
    package is written so DiffSketcher's imports resolve. Calling the instance
    with a prompt returns the SVG, base64-encoded SVG/PNG, and a PIL image.
    Whenever the real model cannot run, a placeholder SVG displaying the prompt
    is produced instead.
    """

    def __init__(self, model_dir):
        """Initialize the DiffSketcher endpoint.

        Args:
            model_dir: Directory supplied by the hosting runtime. It is stored
                on the instance but not otherwise used by this class.

        Raises:
            subprocess.CalledProcessError: If cloning the repository fails.
        """
        self.model_dir = model_dir
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Initializing DiffSketcher endpoint on device: {self.device}")

        # Private scratch directory for the clone and the stub diffvg package;
        # removed again in __del__.
        self.temp_dir = tempfile.mkdtemp()
        self.temp_model_dir = Path(self.temp_dir) / "DiffSketcher"

        # Clone the repository if it doesn't exist
        if not self.temp_model_dir.exists():
            print("Cloning DiffSketcher repository...")
            try:
                subprocess.run(
                    ["git", "clone", "https://github.com/ximinng/DiffSketcher.git",
                     str(self.temp_model_dir)],
                    check=True,
                )
            except Exception:
                # Don't leak the scratch directory when construction fails.
                shutil.rmtree(self.temp_dir, ignore_errors=True)
                raise

        # Make ``DiffSketcher.<submodule>`` importable: the clone's parent
        # directory is temp_dir.
        self._add_to_sys_path(self.temp_model_dir.parent)

        self._install_dependencies()
        self._initialize_model()

    @staticmethod
    def _add_to_sys_path(directory):
        """Append *directory* to ``sys.path`` unless it is already present."""
        entry = str(directory)
        if entry not in sys.path:
            sys.path.append(entry)

    def _install_dependencies(self):
        """Install Python dependencies and write a stub ``diffvg`` package.

        Best-effort: any failure is printed and swallowed so that the endpoint
        can still serve placeholder output.
        """
        # Run pip through the current interpreter so packages land in the same
        # environment that will import them (a bare "pip" on PATH may not).
        pip_install = [sys.executable, "-m", "pip", "install"]
        try:
            # NOTE: the real diffvg is not installed here; a stub is written
            # below. This step installs the remaining Python dependencies.
            print("Installing diffvg...")
            subprocess.run(
                pip_install + [
                    "svgwrite", "svgpathtools", "cssutils", "numba", "torch",
                    "torchvision", "diffusers", "transformers", "accelerate",
                    "xformers", "omegaconf", "einops", "kornia",
                ],
                check=True,
            )

            print("Installing CLIP...")
            subprocess.run(
                pip_install + ["git+https://github.com/openai/CLIP.git"],
                check=True,
            )

            # Create a mock diffvg module
            diffvg_dir = Path(self.temp_dir) / "diffvg"
            diffvg_dir.mkdir(exist_ok=True)
            (diffvg_dir / "__init__.py").write_text(_MOCK_DIFFVG_SOURCE, encoding="utf-8")

            # The package was created after temp_dir may already have been
            # scanned by the import system, so drop cached finder state.
            self._add_to_sys_path(diffvg_dir.parent)
            importlib.invalidate_caches()
        except Exception as e:
            print(f"Error installing dependencies: {e}")

    def _initialize_model(self):
        """Probe whether the DiffSketcher modules can be imported.

        Sets ``self.model_initialized`` to True when the imports succeed and to
        False otherwise. No weights are loaded here.
        """
        try:
            # Imported only to verify availability; the names are not kept.
            from DiffSketcher.methods.painter.diffsketcher import Painter  # noqa: F401
            from DiffSketcher.methods.diffusers_warp import init_diffusion_pipeline  # noqa: F401

            self.model_initialized = True
            print("DiffSketcher model initialized successfully")
        except Exception as e:
            print(f"Error initializing DiffSketcher model: {e}")
            self.model_initialized = False

    @staticmethod
    def _write_config(config_path, prompt, num_paths, width):
        """Write the YAML config file for a single DiffSketcher run.

        The prompt is emitted as a JSON string literal, which YAML parses as a
        double-quoted scalar. Prompts containing ':', '#', quotes or newlines
        therefore cannot corrupt the file.
        """
        config = (
            "task: diffsketcher\n"
            "model_id: sd15\n"
            f"prompt: {json.dumps(str(prompt), ensure_ascii=False)}\n"
            'negative_prompt: ""\n'
            f"num_paths: {num_paths}\n"
            # NOTE(review): "width" looks like a stroke width, distinct from
            # the image size below — confirm against DiffSketcher's config.
            "width: 1.5\n"
            f"image_size: {width}\n"
            "num_iter: 500\n"
            "lr: 1.0\n"
            "sds:\n"
            "  warmup: 0\n"
            "  grad_scale: 1.0\n"
            "  t_range: [0.02, 0.98]\n"
            "  guidance_scale: 7.5\n"
        )
        with open(config_path, "w", encoding="utf-8") as f:
            f.write(config)

    @staticmethod
    def _run_model(config_path, output_dir, prompt, num_paths, width):
        """Run DiffSketcher and return the first SVG found under *output_dir*.

        Raises:
            FileNotFoundError: If no ``.svg`` file appears under *output_dir*.
            Exception: Anything raised by the DiffSketcher imports or run.
        """
        from DiffSketcher.run_painterly_render import main
        from DiffSketcher.libs.engine import merge_and_update_config
        from omegaconf import OmegaConf

        args = OmegaConf.create({
            "task": "diffsketcher",
            "config": str(config_path),
            "prompt": prompt,
            "negative_prompt": "",
            "num_paths": num_paths,
            "width": 1.5,
            "image_size": width,
            "num_iter": 500,
            "lr": 1.0,
            "sds": {
                "warmup": 0,
                "grad_scale": 1.0,
                "t_range": [0.02, 0.98],
                "guidance_scale": 7.5,
            },
            "seed": 42,
            "batch_size": 1,
            "render_batch": False,
            "make_video": False,
            "print_timing": False,
            "download": True,
            "force_download": False,
            "resume_download": False,
        })

        args = merge_and_update_config(args)
        main(args, None)

        svg_files = list(output_dir.glob("**/*.svg"))
        if not svg_files:
            raise FileNotFoundError("No SVG file generated")
        return svg_files[0].read_text(encoding="utf-8")

    def generate_svg(self, prompt, num_paths=10, width=512, height=512):
        """Generate an SVG from a text prompt.

        Args:
            prompt: Text prompt describing the sketch.
            num_paths: Number of paths requested from DiffSketcher.
            width: Passed to the model as ``image_size``; also the placeholder
                width.
            height: Placeholder height (the model path only uses ``width``).

        Returns:
            SVG markup as a string. Falls back to a placeholder SVG whenever the
            model is unavailable or any step fails.
        """
        print(f"Generating SVG for prompt: {prompt}")
        output_dir = None
        try:
            # Per-call scratch directory; always removed in ``finally``.
            output_dir = Path(tempfile.mkdtemp())
            config_path = output_dir / "config.yaml"
            self._write_config(config_path, prompt, num_paths, width)

            if not self.model_initialized:
                return self._generate_placeholder_svg(prompt, width, height)

            try:
                return self._run_model(config_path, output_dir, prompt, num_paths, width)
            except Exception as e:
                print(f"Error running DiffSketcher model: {e}")
                return self._generate_placeholder_svg(prompt, width, height)
        except Exception as e:
            print(f"Error generating SVG: {e}")
            return self._generate_placeholder_svg(prompt, width, height)
        finally:
            if output_dir is not None:
                shutil.rmtree(output_dir, ignore_errors=True)

    def _generate_placeholder_svg(self, prompt, width=512, height=512):
        """Return a minimal, valid SVG that shows *prompt* centred on white.

        The prompt is XML-escaped so characters such as '<' or '&' cannot
        break the markup.
        """
        text = html.escape(str(prompt))
        return (
            f'<svg xmlns="http://www.w3.org/2000/svg" width="{width}" '
            f'height="{height}" viewBox="0 0 {width} {height}">\n'
            '  <rect width="100%" height="100%" fill="#ffffff"/>\n'
            '  <text x="50%" y="50%" text-anchor="middle" '
            'dominant-baseline="middle" font-family="sans-serif" '
            f'font-size="20" fill="#000000">{text}</text>\n'
            "</svg>\n"
        )

    def svg_to_png(self, svg_content):
        """Convert SVG markup to PNG bytes.

        On conversion failure, returns a 512x512 red PNG with the error text
        drawn on it instead of raising.
        """
        try:
            return cairosvg.svg2png(bytestring=svg_content.encode("utf-8"))
        except Exception as e:
            print(f"Error converting SVG to PNG: {e}")
            # Create a simple error image
            from PIL import ImageDraw

            image = Image.new("RGB", (512, 512), color="#ff0000")
            draw = ImageDraw.Draw(image)
            draw.text((256, 256), f"Error: {str(e)}", fill="white", anchor="mm")

            buffer = io.BytesIO()
            image.save(buffer, format="PNG")
            return buffer.getvalue()

    def __call__(self, prompt):
        """Generate an SVG for *prompt* and return it in several encodings.

        Returns:
            dict with keys ``"svg"`` (markup string), ``"svg_base64"`` and
            ``"png_base64"`` (base64 text), and ``"image"`` (PIL image decoded
            from the PNG).
        """
        svg_content = self.generate_svg(prompt)
        png_data = self.svg_to_png(svg_content)

        image = Image.open(io.BytesIO(png_data))

        return {
            "svg": svg_content,
            "svg_base64": base64.b64encode(svg_content.encode("utf-8")).decode("utf-8"),
            "png_base64": base64.b64encode(png_data).decode("utf-8"),
            "image": image,
        }

    def __del__(self):
        """Remove the private scratch directory, ignoring any errors."""
        temp_dir = getattr(self, "temp_dir", None)
        if temp_dir and os.path.exists(temp_dir):
            shutil.rmtree(temp_dir, ignore_errors=True)