File size: 5,070 Bytes
ffc93fc f8b22af 8d52022 ffc93fc c198636 6d9283b 8d52022 55dc40f 8d52022 55dc40f 8d52022 6d9283b 55dc40f 8d52022 1d1055f 6d9283b 69b2d5c 6d9283b 54f01ed 69b2d5c 6d9283b 69b2d5c 6d9283b 69b2d5c 6d9283b 69b2d5c 1d1055f 8d52022 69b2d5c 54f01ed 8d52022 54f01ed f8b22af d87b721 54f01ed 8d52022 69b2d5c 54f01ed 7739491 8d52022 7739491 55dc40f 8d52022 69b2d5c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 |
import os
import io
import sys
import torch
import numpy as np
from PIL import Image
import traceback
# Add debug logging
def debug_log(message):
    """Write ``DEBUG: <message>`` to standard output and flush it right away.

    Flushing after every line keeps debug output visible even when stdout is
    block-buffered (e.g. redirected to a log collector).
    """
    out = sys.stdout
    print(f"DEBUG: {message}", file=out)
    out.flush()
debug_log("Starting handler initialization")

# cairosvg is used to rasterize the generated SVG. If it is not installed,
# install it (together with its rendering dependencies) at import time.
try:
    import cairosvg
    debug_log("Successfully imported cairosvg")
except ImportError:
    debug_log("cairosvg not found. Installing...")
    import importlib
    import subprocess

    # Run pip through the current interpreter: a bare "pip" on PATH may belong
    # to a different Python, which would leave cairosvg unimportable here.
    subprocess.check_call(
        [
            sys.executable, "-m", "pip", "install",
            "cairosvg", "cairocffi", "cssselect2", "defusedxml", "tinycss2",
        ]
    )
    # Modules installed after interpreter start-up may be missed by cached
    # path finders unless their caches are invalidated first.
    importlib.invalidate_caches()
    import cairosvg
    debug_log("Installed and imported cairosvg")
# Add the model directory to the import path. Presumably this checkout
# provides the `models` package imported below — TODO confirm deployment layout.
sys.path.append('/code/diffsketcher')

# Import the model classes; without them the handler cannot work, so fail fast.
try:
    from models.clip_model import ClipModel
    from models.diffusion_model import DiffusionModel
    from models.sketch_model import SketchModel
    debug_log("Successfully imported DiffSketcher models")
except ImportError as e:
    debug_log(f"Error importing DiffSketcher models: {e}")
    debug_log(traceback.format_exc())
    # Chain the original error so its traceback is reported as the cause.
    raise ImportError(f"Failed to import DiffSketcher models: {e}") from e
class EndpointHandler:
    """Inference endpoint that turns a text prompt into a rasterized sketch.

    The prompt is encoded with ``ClipModel.encode_text``, turned into a latent by
    ``DiffusionModel.generate``, converted to SVG by ``SketchModel.generate`` and
    finally rasterized to a PIL image with cairosvg.
    """

    # Release asset fetched when ``<model_dir>/checkpoint.pth`` is missing.
    CHECKPOINT_URL = "https://github.com/ximinng/DiffSketcher/releases/download/v0.1-weights/diffvg_checkpoint.pth"

    def __init__(self, model_dir):
        """Build the sub-models and load sketch-model weights when available.

        Args:
            model_dir: Directory expected to hold ``checkpoint.pth``. If the file
                is missing it is downloaded from ``CHECKPOINT_URL``; if that (or
                loading the downloaded file) fails, the sketch model keeps its
                freshly constructed weights.

        Raises:
            Whatever ``torch.load`` / ``load_state_dict`` raise when an existing
            ``checkpoint.pth`` cannot be loaded (e.g. ``KeyError`` when it has no
            ``'sketch_model'`` entry). This path is intentionally not guarded.
        """
        debug_log(f"Initializing handler with model_dir: {model_dir}")
        self.model_dir = model_dir
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        debug_log(f"Using device: {self.device}")

        self.clip_model = ClipModel(device=self.device)
        self.diffusion_model = DiffusionModel(device=self.device)
        self.sketch_model = SketchModel(device=self.device)

        weights_path = os.path.join(model_dir, "checkpoint.pth")
        if os.path.exists(weights_path):
            debug_log(f"Loading checkpoint from {weights_path}")
            self._load_checkpoint(weights_path)
            debug_log("Successfully loaded checkpoint")
            return

        debug_log(f"Checkpoint not found at {weights_path}, using model without pre-trained weights")
        try:
            debug_log("Attempting to download checkpoint...")
            self._download_checkpoint(weights_path)
            debug_log(f"Downloaded checkpoint to {weights_path}")
            self._load_checkpoint(weights_path)
            debug_log("Successfully loaded downloaded checkpoint")
        except Exception as e:
            # Deliberate best effort: an unreachable checkpoint must not keep
            # the endpoint from starting.
            debug_log(f"Error downloading checkpoint: {e}")
            debug_log(traceback.format_exc())
            debug_log("Continuing with uninitialized weights")

    def _load_checkpoint(self, path):
        """Load the ``'sketch_model'`` state dict stored at *path* into the sketch model."""
        # NOTE(review): torch.load unpickles the file, and this checkpoint may
        # come from the network. Consider weights_only=True if it holds only
        # tensors — confirm against the checkpoint format.
        checkpoint = torch.load(path, map_location=self.device)
        self.sketch_model.load_state_dict(checkpoint['sketch_model'])

    def _download_checkpoint(self, path):
        """Download ``CHECKPOINT_URL`` to *path* without leaving partial files.

        The data is written to ``<path>.part`` and only renamed onto *path*
        once the transfer completed. An interrupted download therefore never
        leaves a truncated checkpoint that a later start-up would treat as valid.
        """
        import urllib.request

        directory = os.path.dirname(path)
        if directory:  # os.makedirs("") raises, e.g. when model_dir is ""
            os.makedirs(directory, exist_ok=True)

        partial_path = path + ".part"
        try:
            urllib.request.urlretrieve(self.CHECKPOINT_URL, partial_path)
            os.replace(partial_path, path)
        finally:
            # Only still present when the download or the rename failed.
            if os.path.exists(partial_path):
                os.remove(partial_path)

    def generate_svg(self, prompt, width=512, height=512, num_paths=20):
        """Generate an SVG for a text prompt.

        Args:
            prompt: Text passed to ``ClipModel.encode_text``.
            width: Canvas width forwarded to ``SketchModel.generate``.
            height: Canvas height forwarded to ``SketchModel.generate``.
            num_paths: Number of paths requested from ``SketchModel.generate``.

        Returns:
            The value returned by ``SketchModel.generate``. ``__call__`` treats it
            as SVG markup (``str`` or ``bytes``).
        """
        debug_log(f"Generating SVG for prompt: {prompt}")
        text_features = self.clip_model.encode_text(prompt)
        latent = self.diffusion_model.generate(text_features)
        svg_data = self.sketch_model.generate(latent, num_paths=num_paths, width=width, height=height)
        debug_log("Generated SVG using DiffSketcher")
        return svg_data

    @staticmethod
    def _extract_prompt(data):
        """Return the prompt carried by a request payload, or ``"No prompt provided"``."""
        inputs = data.get("inputs") if isinstance(data, dict) else None
        if isinstance(inputs, str):
            return inputs
        if isinstance(inputs, dict) and "text" in inputs:
            return inputs["text"]
        # NOTE(review): unrecognized payloads still render this literal text
        # instead of being rejected.
        return "No prompt provided"

    def __call__(self, data):
        """Handle one request and return the rendered sketch.

        Args:
            data: Request payload. The prompt is read from ``data["inputs"]``
                when that is a string, or from ``data["inputs"]["text"]`` when it
                is a dict. Any other shape falls back to ``"No prompt provided"``.

        Returns:
            A PIL image rasterized from the generated SVG.

        Raises:
            Exception: wrapping any failure; the original error is chained as
            ``__cause__``.
        """
        try:
            debug_log(f"Handling request: {data}")
            prompt = self._extract_prompt(data)
            debug_log(f"Extracted prompt: {prompt}")

            svg_content = self.generate_svg(prompt)

            # cairosvg takes bytes; accept either str or already-encoded output.
            if isinstance(svg_content, str):
                svg_content = svg_content.encode("utf-8")
            png_data = cairosvg.svg2png(bytestring=svg_content)
            image = Image.open(io.BytesIO(png_data))
            # Image.open is lazy: decode now so a bad PNG surfaces (and is
            # wrapped) here rather than later during response serialization.
            image.load()
            debug_log("Generated image from SVG")

            debug_log("Returning image")
            return image
        except Exception as e:
            debug_log(f"Error in handler: {e}")
            debug_log(traceback.format_exc())
            raise Exception(f"Error generating image: {str(e)}") from e