|
import os |
|
import io |
|
import sys |
|
import torch |
|
import numpy as np |
|
from PIL import Image |
|
import traceback |
|
|
|
|
|
def debug_log(message):
    """Emit ``DEBUG: <message>`` on stdout and flush immediately.

    Flushing on every call makes each line show up promptly in the endpoint's
    logs, even when stdout is buffered.
    """
    line = f"DEBUG: {message}"
    print(line, flush=True)
|
|
|
debug_log("Starting handler initialization")

# cairosvg is needed to rasterise the generated SVG into PNG. If it is not
# installed, install it (plus its rendering dependencies) at import time.
try:
    import cairosvg
    debug_log("Successfully imported cairosvg")
except ImportError:
    debug_log("cairosvg not found. Installing...")
    import importlib
    import subprocess

    # Use the running interpreter's pip: a bare "pip" on PATH may belong to a
    # different Python installation, in which case the re-import below fails.
    subprocess.check_call(
        [
            sys.executable, "-m", "pip", "install",
            "cairosvg", "cairocffi", "cssselect2", "defusedxml", "tinycss2",
        ]
    )
    # Packages installed after interpreter start-up may be invisible to the
    # import system's cached directory listings until the caches are cleared.
    importlib.invalidate_caches()
    import cairosvg
    debug_log("Installed and imported cairosvg")
|
|
|
|
|
# Make the DiffSketcher source tree importable; its ``models`` package provides
# the CLIP, diffusion and sketch models used by the handler. The location can
# be overridden with the DIFFSKETCHER_PATH environment variable.
sys.path.append(os.environ.get("DIFFSKETCHER_PATH", "/code/diffsketcher"))

try:
    from models.clip_model import ClipModel
    from models.diffusion_model import DiffusionModel
    from models.sketch_model import SketchModel
    debug_log("Successfully imported DiffSketcher models")
except ImportError as e:
    debug_log(f"Error importing DiffSketcher models: {e}")
    debug_log(traceback.format_exc())
    # Chain explicitly so the original import failure is kept as __cause__.
    raise ImportError(f"Failed to import DiffSketcher models: {e}") from e
|
|
|
class EndpointHandler:
    """Request handler that turns a text prompt into a rendered sketch image.

    Pipeline: ``ClipModel.encode_text`` -> ``DiffusionModel.generate`` ->
    ``SketchModel.generate`` (SVG text) -> ``cairosvg.svg2png`` -> PIL image.
    """

    # Release asset downloaded when no local checkpoint exists.
    CHECKPOINT_URL = (
        "https://github.com/ximinng/DiffSketcher/releases/download/"
        "v0.1-weights/diffvg_checkpoint.pth"
    )

    def __init__(self, model_dir):
        """Build the sub-models and load the sketch-model checkpoint.

        Args:
            model_dir: Directory expected to hold ``checkpoint.pth``. If the
                file is missing, a best-effort download into this directory is
                attempted; if that also fails, the sketch model keeps whatever
                weights ``SketchModel`` was constructed with.

        Raises:
            Any error from loading an *existing* ``checkpoint.pth`` propagates
            (only the download path is best-effort).
        """
        debug_log(f"Initializing handler with model_dir: {model_dir}")
        self.model_dir = model_dir
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        debug_log(f"Using device: {self.device}")

        self.clip_model = ClipModel(device=self.device)
        self.diffusion_model = DiffusionModel(device=self.device)
        self.sketch_model = SketchModel(device=self.device)

        weights_path = os.path.join(model_dir, "checkpoint.pth")
        if os.path.exists(weights_path):
            debug_log(f"Loading checkpoint from {weights_path}")
            self._load_checkpoint(weights_path)
            debug_log("Successfully loaded checkpoint")
        else:
            debug_log(f"Checkpoint not found at {weights_path}, using model without pre-trained weights")
            self._download_checkpoint(weights_path)

    def _load_checkpoint(self, path):
        """Load the ``'sketch_model'`` state dict from the checkpoint at *path*."""
        # NOTE(review): torch.load unpickles the file, and the checkpoint may
        # come from the network. Consider weights_only=True once the checkpoint
        # is confirmed to contain only tensors/containers.
        checkpoint = torch.load(path, map_location=self.device)
        self.sketch_model.load_state_dict(checkpoint['sketch_model'])

    def _download_checkpoint(self, weights_path):
        """Best-effort: download the checkpoint, load it, then keep it at *weights_path*.

        The file is fetched to ``<weights_path>.part`` and moved into place only
        after it loads successfully. An interrupted or unusable download
        therefore never leaves a file at *weights_path*; such a file would make
        the next start-up fail in the unguarded "checkpoint exists" branch.
        Failures are logged and swallowed.
        """
        import urllib.request

        tmp_path = weights_path + ".part"
        try:
            debug_log("Attempting to download checkpoint...")
            target_dir = os.path.dirname(weights_path)
            if target_dir:  # os.makedirs("") raises; "" means the current directory
                os.makedirs(target_dir, exist_ok=True)
            urllib.request.urlretrieve(self.CHECKPOINT_URL, tmp_path)
            self._load_checkpoint(tmp_path)
            # Same directory, so this is a rename rather than a copy.
            os.replace(tmp_path, weights_path)
            debug_log(f"Downloaded checkpoint to {weights_path}")
            debug_log("Successfully loaded downloaded checkpoint")
        except Exception as e:
            debug_log(f"Error downloading checkpoint: {e}")
            debug_log(traceback.format_exc())
            debug_log("Continuing with uninitialized weights")
            try:
                os.remove(tmp_path)
            except OSError:
                pass  # nothing was written, or cleanup is not possible; stay best-effort

    def generate_svg(self, prompt, width=512, height=512, num_paths=20):
        """Generate SVG output for a text prompt.

        Args:
            prompt: Text passed to ``ClipModel.encode_text``.
            width: Forwarded to ``SketchModel.generate``.
            height: Forwarded to ``SketchModel.generate``.
            num_paths: Forwarded to ``SketchModel.generate`` (presumably the
                number of strokes; confirm against SketchModel).

        Returns:
            Whatever ``SketchModel.generate`` returns. ``__call__`` treats it as
            SVG text and UTF-8 encodes it.
        """
        debug_log(f"Generating SVG for prompt: {prompt}")
        text_features = self.clip_model.encode_text(prompt)
        latent = self.diffusion_model.generate(text_features)
        svg_data = self.sketch_model.generate(latent, num_paths=num_paths, width=width, height=height)
        debug_log("Generated SVG using DiffSketcher")
        return svg_data

    @staticmethod
    def _extract_prompt(data):
        """Return the prompt from a request payload.

        Uses ``data["inputs"]`` when it is a string, or ``data["inputs"]["text"]``
        when ``inputs`` is a dict that has a ``"text"`` key. Otherwise returns
        the literal ``"No prompt provided"``.
        """
        # NOTE(review): the fallback string is then rendered as a real prompt;
        # consider rejecting such requests instead.
        inputs = data.get("inputs") if isinstance(data, dict) else None
        if isinstance(inputs, str):
            return inputs
        if isinstance(inputs, dict) and "text" in inputs:
            return inputs["text"]
        return "No prompt provided"

    def __call__(self, data):
        """Handle one request and return the rendered sketch as a PIL image.

        Args:
            data: Request payload; see ``_extract_prompt`` for accepted shapes.

        Returns:
            ``PIL.Image.Image`` decoded from the PNG rasterisation of the SVG.

        Raises:
            Exception: wraps any failure, with the original error chained as
                ``__cause__``.
        """
        try:
            debug_log(f"Handling request: {data}")
            prompt = self._extract_prompt(data)
            debug_log(f"Extracted prompt: {prompt}")

            svg_content = self.generate_svg(prompt)
            png_data = cairosvg.svg2png(bytestring=svg_content.encode("utf-8"))
            image = Image.open(io.BytesIO(png_data))
            # Image.open is lazy; decode now so a bad PNG is reported here
            # rather than later, outside this error handler.
            image.load()
            debug_log("Generated image from SVG")

            debug_log("Returning image")
            return image
        except Exception as e:
            debug_log(f"Error in handler: {e}")
            debug_log(traceback.format_exc())
            raise Exception(f"Error generating image: {str(e)}") from e