# File: diffsketcher/handler.py
# Author avatar: jree423
# Commit message: Add: diffsketcher handler.py with original implementation
# Commit: 69b2d5c (verified)
# Size: 5.07 kB
import os
import io
import sys
import torch
import numpy as np
from PIL import Image
import traceback
# Add debug logging
def debug_log(message):
    """Print *message* to standard output with a ``DEBUG: `` prefix.

    The stream is flushed after every call so the line shows up
    immediately even when stdout is buffered.
    """
    print(f"DEBUG: {message}", flush=True)
debug_log("Starting handler initialization")

# Safely import cairosvg with fallback: if it is missing, install it (and the
# packages listed alongside it) at import time and retry the import.
try:
    import cairosvg
    debug_log("Successfully imported cairosvg")
except ImportError:
    debug_log("cairosvg not found. Installing...")
    import importlib
    import subprocess

    # Run pip through the interpreter executing this module so the packages
    # are installed into the environment we import from; a bare "pip" found
    # on PATH may belong to a different Python installation.
    subprocess.check_call(
        [
            sys.executable, "-m", "pip", "install",
            "cairosvg", "cairocffi", "cssselect2", "defusedxml", "tinycss2",
        ]
    )
    # Packages installed while the process is running may be hidden by the
    # import system's cached directory listings; drop those caches first.
    importlib.invalidate_caches()
    import cairosvg
    debug_log("Installed and imported cairosvg")

# Add the model directory to the path so the ``models`` package resolves.
sys.path.append('/code/diffsketcher')

# Try to import the model classes; the handler cannot work without them, so a
# failure is logged and re-raised with the original error chained as cause.
try:
    from models.clip_model import ClipModel
    from models.diffusion_model import DiffusionModel
    from models.sketch_model import SketchModel
    debug_log("Successfully imported DiffSketcher models")
except ImportError as e:
    debug_log(f"Error importing DiffSketcher models: {e}")
    debug_log(traceback.format_exc())
    raise ImportError(f"Failed to import DiffSketcher models: {e}") from e
class EndpointHandler:
    """Endpoint handler that renders a text prompt as a sketch image.

    A request is processed by encoding the prompt with ``ClipModel``,
    passing the result through ``DiffusionModel`` and ``SketchModel`` to get
    SVG data, and rasterising that SVG with cairosvg into a PIL image.
    """

    # File name of the sketch-model checkpoint looked up inside ``model_dir``.
    CHECKPOINT_NAME = "checkpoint.pth"
    # Location the checkpoint is fetched from when it is not present locally.
    CHECKPOINT_URL = (
        "https://github.com/ximinng/DiffSketcher/releases/download/"
        "v0.1-weights/diffvg_checkpoint.pth"
    )
    # Prompt used when the request carries no usable text.
    DEFAULT_PROMPT = "No prompt provided"

    def __init__(self, model_dir):
        """Build the three models and load the sketch-model checkpoint.

        Args:
            model_dir: Directory expected to contain ``checkpoint.pth``.

        If the checkpoint file exists, it is loaded and any error propagates.
        If it does not exist, a download is attempted; a failure there is
        logged and the handler continues with uninitialized weights.
        """
        debug_log(f"Initializing handler with model_dir: {model_dir}")
        self.model_dir = model_dir
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        debug_log(f"Using device: {self.device}")

        # Initialize the model
        self.clip_model = ClipModel(device=self.device)
        self.diffusion_model = DiffusionModel(device=self.device)
        self.sketch_model = SketchModel(device=self.device)

        # Load checkpoint if available
        weights_path = os.path.join(model_dir, self.CHECKPOINT_NAME)
        if os.path.exists(weights_path):
            debug_log(f"Loading checkpoint from {weights_path}")
            self._load_checkpoint(weights_path)
            debug_log("Successfully loaded checkpoint")
            return

        debug_log(f"Checkpoint not found at {weights_path}, using model without pre-trained weights")
        # Download the checkpoint if not available (best effort).
        try:
            debug_log("Attempting to download checkpoint...")
            self._download_checkpoint(weights_path)
            debug_log(f"Downloaded checkpoint to {weights_path}")
            self._load_checkpoint(weights_path)
            debug_log("Successfully loaded downloaded checkpoint")
        except Exception as e:
            debug_log(f"Error downloading checkpoint: {e}")
            debug_log(traceback.format_exc())
            debug_log("Continuing with uninitialized weights")

    def _load_checkpoint(self, path):
        """Load the ``'sketch_model'`` state dict stored in the file at *path*."""
        # NOTE(security): torch.load unpickles the file, and the checkpoint may
        # have been downloaded from the network. Consider weights_only=True if
        # the installed torch version and the checkpoint contents allow it.
        checkpoint = torch.load(path, map_location=self.device)
        self.sketch_model.load_state_dict(checkpoint['sketch_model'])

    def _download_checkpoint(self, dest_path):
        """Fetch ``CHECKPOINT_URL`` and store it at *dest_path*.

        The data is written to ``dest_path + ".part"`` and moved into place
        only after the transfer finished, so an interrupted download never
        leaves a truncated file where the next start-up would try to load it.
        """
        import urllib.request

        # dirname is "" when dest_path has no directory part, and
        # os.makedirs("") raises; fall back to the current directory.
        os.makedirs(os.path.dirname(dest_path) or ".", exist_ok=True)
        partial_path = dest_path + ".part"
        try:
            urllib.request.urlretrieve(self.CHECKPOINT_URL, partial_path)
            os.replace(partial_path, dest_path)
        finally:
            # Only present here if the download or the rename failed.
            if os.path.exists(partial_path):
                os.remove(partial_path)

    def generate_svg(self, prompt, width=512, height=512):
        """Generate an SVG from a text prompt.

        Args:
            prompt: Text passed to ``clip_model.encode_text``.
            width: Forwarded to ``sketch_model.generate``.
            height: Forwarded to ``sketch_model.generate``.

        Returns:
            Whatever ``sketch_model.generate`` returns; ``__call__`` treats it
            as SVG markup (presumably a ``str`` - confirm against SketchModel).
        """
        debug_log(f"Generating SVG for prompt: {prompt}")
        # Generate SVG using DiffSketcher
        text_features = self.clip_model.encode_text(prompt)
        latent = self.diffusion_model.generate(text_features)
        svg_data = self.sketch_model.generate(latent, num_paths=20, width=width, height=height)
        debug_log("Generated SVG using DiffSketcher")
        return svg_data

    @classmethod
    def _extract_prompt(cls, data):
        """Return the prompt carried by request *data*.

        Accepts ``{"inputs": "<text>"}`` and ``{"inputs": {"text": ...}}``;
        any other shape yields ``DEFAULT_PROMPT``.
        """
        inputs = data.get("inputs") if isinstance(data, dict) else None
        if isinstance(inputs, str):
            return inputs
        if isinstance(inputs, dict) and "text" in inputs:
            return inputs["text"]
        return cls.DEFAULT_PROMPT

    @staticmethod
    def _svg_to_bytes(svg_content):
        """Return *svg_content* as bytes, UTF-8 encoding it when it is text."""
        if isinstance(svg_content, (bytes, bytearray)):
            return bytes(svg_content)
        return svg_content.encode("utf-8")

    def __call__(self, data):
        """Handle a request to the model.

        Args:
            data: Request payload; see ``_extract_prompt`` for accepted shapes.

        Returns:
            A PIL image opened from the PNG rendering of the generated SVG.

        Raises:
            Exception: Wraps any error raised while handling the request; the
                original error is attached as ``__cause__``.
        """
        try:
            debug_log(f"Handling request: {data}")
            # Extract the prompt
            prompt = self._extract_prompt(data)
            debug_log(f"Extracted prompt: {prompt}")
            # Generate SVG
            svg_content = self.generate_svg(prompt)
            # Convert SVG to PNG
            png_data = cairosvg.svg2png(bytestring=self._svg_to_bytes(svg_content))
            image = Image.open(io.BytesIO(png_data))
            debug_log("Generated image from SVG")
            # Return the PIL Image directly
            debug_log("Returning image")
            return image
        except Exception as e:
            debug_log(f"Error in handler: {e}")
            debug_log(traceback.format_exc())
            raise Exception(f"Error generating image: {str(e)}") from e