File size: 5,070 Bytes
ffc93fc
f8b22af
8d52022
ffc93fc
 
c198636
6d9283b
8d52022
 
 
 
 
 
 
55dc40f
 
 
 
8d52022
55dc40f
8d52022
6d9283b
55dc40f
 
8d52022
1d1055f
6d9283b
 
 
 
 
 
 
 
 
 
 
 
69b2d5c
6d9283b
54f01ed
 
 
69b2d5c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6d9283b
69b2d5c
 
 
 
 
 
 
 
 
 
 
6d9283b
69b2d5c
6d9283b
69b2d5c
1d1055f
8d52022
 
 
 
69b2d5c
 
 
 
 
 
54f01ed
 
 
 
8d52022
 
54f01ed
 
 
 
 
 
 
 
f8b22af
d87b721
54f01ed
8d52022
 
 
 
 
 
69b2d5c
 
 
54f01ed
7739491
8d52022
7739491
55dc40f
8d52022
 
69b2d5c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import os
import io
import sys
import torch
import numpy as np
from PIL import Image
import traceback

# Lightweight debug logging to stdout
def debug_log(message):
    """Write ``message`` to stdout prefixed with ``DEBUG: `` and flush at once.

    The explicit flush makes each line visible immediately even when stdout
    is block-buffered (for example when it is redirected to a log collector).
    """
    line = f"DEBUG: {message}"
    print(line, flush=True)

debug_log("Starting handler initialization")

# Import cairosvg, installing it (and its rendering dependencies) on first use
# when the environment does not already provide it.
try:
    import cairosvg
    debug_log("Successfully imported cairosvg")
except ImportError:
    debug_log("cairosvg not found. Installing...")
    import importlib
    import subprocess
    # Invoke pip through the running interpreter so the packages are installed
    # into *this* environment; a bare "pip" on PATH may belong to another Python.
    subprocess.check_call(
        [sys.executable, "-m", "pip", "install",
         "cairosvg", "cairocffi", "cssselect2", "defusedxml", "tinycss2"]
    )
    # The import system caches directory listings; refresh them so a package
    # installed after interpreter start-up can actually be found.
    importlib.invalidate_caches()
    import cairosvg
    debug_log("Installed and imported cairosvg")

# Make the DiffSketcher source tree importable.
sys.path.append('/code/diffsketcher')

# Import the DiffSketcher model classes; fail loudly if they are unavailable.
try:
    from models.clip_model import ClipModel
    from models.diffusion_model import DiffusionModel
    from models.sketch_model import SketchModel
    debug_log("Successfully imported DiffSketcher models")
except ImportError as e:
    debug_log(f"Error importing DiffSketcher models: {e}")
    debug_log(traceback.format_exc())
    # Chain the original error so its traceback stays attached to the re-raise.
    raise ImportError(f"Failed to import DiffSketcher models: {e}") from e

class EndpointHandler:
    """Serve DiffSketcher text-to-sketch generation as a callable endpoint handler.

    Construction builds the CLIP, diffusion and sketch models on CUDA when
    available (CPU otherwise) and loads sketch-model weights. Each call
    extracts a text prompt from the request payload, generates an SVG,
    rasterizes it with cairosvg and returns the resulting PIL image.
    """

    # Release asset fetched when ``<model_dir>/checkpoint.pth`` does not exist.
    _CHECKPOINT_URL = (
        "https://github.com/ximinng/DiffSketcher/releases/download/v0.1-weights/diffvg_checkpoint.pth"
    )

    def __init__(self, model_dir):
        """Initialize the handler with model directory.

        Args:
            model_dir: Directory in which ``checkpoint.pth`` is looked up, and
                into which it is downloaded when missing.

        Raises:
            Whatever ``torch.load`` / ``load_state_dict`` raise when an existing
            ``checkpoint.pth`` cannot be loaded. Failures on the download path
            are logged and swallowed; the sketch model then keeps the weights
            it was constructed with.
        """
        debug_log(f"Initializing handler with model_dir: {model_dir}")
        self.model_dir = model_dir
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        debug_log(f"Using device: {self.device}")

        # Initialize the models
        self.clip_model = ClipModel(device=self.device)
        self.diffusion_model = DiffusionModel(device=self.device)
        self.sketch_model = SketchModel(device=self.device)

        weights_path = os.path.join(model_dir, "checkpoint.pth")
        if os.path.exists(weights_path):
            debug_log(f"Loading checkpoint from {weights_path}")
            self._load_checkpoint(weights_path)
            debug_log("Successfully loaded checkpoint")
        else:
            debug_log(f"Checkpoint not found at {weights_path}, using model without pre-trained weights")
            # Best effort: fetch the released checkpoint; on any failure keep
            # the freshly constructed (untrained) weights.
            try:
                debug_log("Attempting to download checkpoint...")
                self._download_checkpoint(weights_path)
                debug_log(f"Downloaded checkpoint to {weights_path}")
                self._load_checkpoint(weights_path)
                debug_log("Successfully loaded downloaded checkpoint")
            except Exception as e:
                debug_log(f"Error downloading checkpoint: {e}")
                debug_log(traceback.format_exc())
                debug_log("Continuing with uninitialized weights")

    def _load_checkpoint(self, path):
        """Load the ``'sketch_model'`` state dict from the checkpoint at ``path``."""
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from trusted sources (this includes the downloaded release asset).
        checkpoint = torch.load(path, map_location=self.device)
        self.sketch_model.load_state_dict(checkpoint['sketch_model'])

    def _download_checkpoint(self, weights_path):
        """Download the release checkpoint to ``weights_path`` without leaving partial files.

        The file is fetched into a temporary sibling and atomically renamed into
        place. An interrupted download therefore cannot leave a truncated
        ``checkpoint.pth`` that the next start-up would load outside any error
        handling.
        """
        import tempfile
        import urllib.request

        # dirname() is "" when model_dir is ""; makedirs("") would raise.
        target_dir = os.path.dirname(weights_path) or "."
        os.makedirs(target_dir, exist_ok=True)
        # Same directory as the target so os.replace stays on one filesystem.
        fd, tmp_path = tempfile.mkstemp(dir=target_dir, suffix=".part")
        os.close(fd)
        try:
            urllib.request.urlretrieve(self._CHECKPOINT_URL, tmp_path)
            os.replace(tmp_path, weights_path)
        except BaseException:
            try:
                os.remove(tmp_path)
            except OSError:
                pass  # cleanup is best-effort; the original error matters more
            raise

    def generate_svg(self, prompt, width=512, height=512, num_paths=20):
        """Generate an SVG from a text prompt.

        Args:
            prompt: Text prompt to sketch.
            width: Width passed to the sketch model.
            height: Height passed to the sketch model.
            num_paths: Number of paths requested from the sketch model
                (defaults to 20, the value previously hard-coded).

        Returns:
            Whatever ``SketchModel.generate`` returns; ``__call__`` treats it
            as SVG markup (``str``, or already-encoded ``bytes``).
        """
        debug_log(f"Generating SVG for prompt: {prompt}")

        # Pipeline: text embedding -> diffusion latent -> vector sketch.
        text_features = self.clip_model.encode_text(prompt)
        latent = self.diffusion_model.generate(text_features)
        svg_data = self.sketch_model.generate(latent, num_paths=num_paths, width=width, height=height)
        debug_log("Generated SVG using DiffSketcher")
        return svg_data

    @staticmethod
    def _extract_prompt(data):
        """Return the prompt from a request payload.

        Accepts ``{"inputs": "<text>"}`` or ``{"inputs": {"text": "<text>"}}``;
        anything else yields the placeholder ``"No prompt provided"``.
        """
        inputs = data.get("inputs") if isinstance(data, dict) else None
        if isinstance(inputs, str):
            return inputs
        if isinstance(inputs, dict) and "text" in inputs:
            return inputs["text"]
        # NOTE(review): the placeholder is itself fed to the model as a prompt;
        # consider rejecting malformed requests instead.
        return "No prompt provided"

    def __call__(self, data):
        """Handle a request to the model.

        Args:
            data: Request payload; see ``_extract_prompt`` for accepted shapes.

        Returns:
            A fully decoded ``PIL.Image.Image`` rasterized from the generated SVG.

        Raises:
            Exception: wrapping any failure, with the original error chained as
            ``__cause__``.
        """
        try:
            debug_log(f"Handling request: {data}")

            prompt = self._extract_prompt(data)
            debug_log(f"Extracted prompt: {prompt}")

            svg_content = self.generate_svg(prompt)

            # Convert SVG to PNG; accept either str or pre-encoded bytes.
            svg_bytes = svg_content if isinstance(svg_content, bytes) else svg_content.encode("utf-8")
            png_data = cairosvg.svg2png(bytestring=svg_bytes)
            image = Image.open(io.BytesIO(png_data))
            # Image.open is lazy; decode now so a malformed PNG is reported
            # here rather than later by whoever consumes the image.
            image.load()
            debug_log("Generated image from SVG")

            debug_log("Returning image")
            return image
        except Exception as e:
            debug_log(f"Error in handler: {e}")
            debug_log(traceback.format_exc())
            raise Exception(f"Error generating image: {str(e)}") from e