jree423 commited on
Commit
69b2d5c
·
verified ·
1 Parent(s): 40f5d6e

Add: diffsketcher handler.py with original implementation

Browse files
Files changed (1) hide show
  1. handler.py +44 -85
handler.py CHANGED
@@ -36,83 +36,58 @@ try:
36
  except ImportError as e:
37
  debug_log(f"Error importing DiffSketcher models: {e}")
38
  debug_log(traceback.format_exc())
 
39
 
40
  class EndpointHandler:
41
  def __init__(self, model_dir):
42
  """Initialize the handler with model directory"""
43
- try:
44
- debug_log(f"Initializing handler with model_dir: {model_dir}")
45
- self.model_dir = model_dir
46
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
47
- debug_log(f"Using device: {self.device}")
48
-
49
- # Initialize the model
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  try:
51
- self.clip_model = ClipModel(device=self.device)
52
- self.diffusion_model = DiffusionModel(device=self.device)
53
- self.sketch_model = SketchModel(device=self.device)
54
-
55
- # Load checkpoint if available
56
- weights_path = os.path.join(model_dir, "checkpoint.pth")
57
- if os.path.exists(weights_path):
58
- debug_log(f"Loading checkpoint from {weights_path}")
59
- checkpoint = torch.load(weights_path, map_location=self.device)
60
- self.sketch_model.load_state_dict(checkpoint['sketch_model'])
61
- debug_log("Successfully loaded checkpoint")
62
- self.use_model = True
63
- else:
64
- debug_log(f"Checkpoint not found at {weights_path}, using model without pre-trained weights")
65
- self.use_model = True
66
  except Exception as e:
67
- debug_log(f"Error initializing model: {e}")
68
  debug_log(traceback.format_exc())
69
- self.use_model = False
70
- except Exception as e:
71
- debug_log(f"Error in handler initialization: {e}")
72
- debug_log(traceback.format_exc())
73
- self.use_model = False
74
 
75
  def generate_svg(self, prompt, width=512, height=512):
76
  """Generate an SVG from a text prompt"""
77
  debug_log(f"Generating SVG for prompt: {prompt}")
78
 
79
- if self.use_model:
80
- try:
81
- debug_log("Using initialized model")
82
-
83
- # Generate SVG using DiffSketcher
84
- text_features = self.clip_model.encode_text(prompt)
85
- latent = self.diffusion_model.generate(text_features)
86
- svg_data = self.sketch_model.generate(latent, num_paths=20, width=width, height=height)
87
- debug_log("Generated SVG using DiffSketcher")
88
- return svg_data
89
- except Exception as e:
90
- debug_log(f"Error generating SVG with model: {e}")
91
- debug_log(traceback.format_exc())
92
- return self._generate_placeholder_svg(prompt, width, height)
93
- else:
94
- debug_log("Model not initialized, using placeholder")
95
- return self._generate_placeholder_svg(prompt, width, height)
96
-
97
- def _generate_placeholder_svg(self, prompt, width=512, height=512):
98
- """Generate a placeholder SVG"""
99
- debug_log(f"Generating placeholder SVG for prompt: {prompt}")
100
-
101
- # Create a more interesting placeholder that looks like a sketch
102
- svg_content = f"""<svg width="{width}" height="{height}" xmlns="http://www.w3.org/2000/svg">
103
- <rect width="100%" height="100%" fill="#ffffff"/>
104
- <g stroke="#000000" fill="none">
105
- <!-- Draw a simple sketch based on the prompt -->
106
- <circle cx="{width/2}" cy="{height/2}" r="{min(width, height)/4}" stroke-width="2"/>
107
- <ellipse cx="{width/2}" cy="{height/2}" rx="{width/3}" ry="{height/4}" stroke-width="1.5"/>
108
- <path d="M {width/4} {height/4} Q {width/2} {height/8} {3*width/4} {height/4}" stroke-width="2"/>
109
- <path d="M {width/4} {3*height/4} Q {width/2} {7*height/8} {3*width/4} {3*height/4}" stroke-width="2"/>
110
- </g>
111
- <text x="50%" y="50%" font-family="Arial" font-size="20" text-anchor="middle" fill="#333333">{prompt}</text>
112
- </svg>"""
113
-
114
- debug_log("Generated placeholder SVG")
115
- return svg_content
116
 
117
  def __call__(self, data):
118
  """Handle a request to the model"""
@@ -136,19 +111,9 @@ class EndpointHandler:
136
  svg_content = self.generate_svg(prompt)
137
 
138
  # Convert SVG to PNG
139
- try:
140
- png_data = cairosvg.svg2png(bytestring=svg_content.encode("utf-8"))
141
- image = Image.open(io.BytesIO(png_data))
142
- debug_log("Generated image from SVG")
143
- except Exception as e:
144
- debug_log(f"Error converting SVG to PNG: {e}")
145
- debug_log(traceback.format_exc())
146
- # Create a simple placeholder image
147
- image = Image.new("RGB", (512, 512), color="#f0f0f0")
148
- from PIL import ImageDraw
149
- draw = ImageDraw.Draw(image)
150
- draw.text((256, 256), prompt, fill="black", anchor="mm")
151
- debug_log("Created placeholder image")
152
 
153
  # Return the PIL Image directly
154
  debug_log("Returning image")
@@ -156,10 +121,4 @@ class EndpointHandler:
156
  except Exception as e:
157
  debug_log(f"Error in handler: {e}")
158
  debug_log(traceback.format_exc())
159
- # Return a simple error image
160
- image = Image.new("RGB", (512, 512), color="#ff0000")
161
- from PIL import ImageDraw
162
- draw = ImageDraw.Draw(image)
163
- draw.text((256, 256), f"Error: {str(e)}", fill="white", anchor="mm")
164
- debug_log("Returning error image")
165
- return image
 
36
  except ImportError as e:
37
  debug_log(f"Error importing DiffSketcher models: {e}")
38
  debug_log(traceback.format_exc())
39
+ raise ImportError(f"Failed to import DiffSketcher models: {e}")
40
 
41
  class EndpointHandler:
42
  def __init__(self, model_dir):
43
  """Initialize the handler with model directory"""
44
+ debug_log(f"Initializing handler with model_dir: {model_dir}")
45
+ self.model_dir = model_dir
46
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
47
+ debug_log(f"Using device: {self.device}")
48
+
49
+ # Initialize the model
50
+ self.clip_model = ClipModel(device=self.device)
51
+ self.diffusion_model = DiffusionModel(device=self.device)
52
+ self.sketch_model = SketchModel(device=self.device)
53
+
54
+ # Load checkpoint if available
55
+ weights_path = os.path.join(model_dir, "checkpoint.pth")
56
+ if os.path.exists(weights_path):
57
+ debug_log(f"Loading checkpoint from {weights_path}")
58
+ checkpoint = torch.load(weights_path, map_location=self.device)
59
+ self.sketch_model.load_state_dict(checkpoint['sketch_model'])
60
+ debug_log("Successfully loaded checkpoint")
61
+ else:
62
+ debug_log(f"Checkpoint not found at {weights_path}, using model without pre-trained weights")
63
+ # Download the checkpoint if not available
64
  try:
65
+ debug_log("Attempting to download checkpoint...")
66
+ import urllib.request
67
+ os.makedirs(os.path.dirname(weights_path), exist_ok=True)
68
+ urllib.request.urlretrieve(
69
+ "https://github.com/ximinng/DiffSketcher/releases/download/v0.1-weights/diffvg_checkpoint.pth",
70
+ weights_path
71
+ )
72
+ debug_log(f"Downloaded checkpoint to {weights_path}")
73
+ checkpoint = torch.load(weights_path, map_location=self.device)
74
+ self.sketch_model.load_state_dict(checkpoint['sketch_model'])
75
+ debug_log("Successfully loaded downloaded checkpoint")
 
 
 
 
76
  except Exception as e:
77
+ debug_log(f"Error downloading checkpoint: {e}")
78
  debug_log(traceback.format_exc())
79
+ debug_log("Continuing with uninitialized weights")
 
 
 
 
80
 
81
  def generate_svg(self, prompt, width=512, height=512):
82
  """Generate an SVG from a text prompt"""
83
  debug_log(f"Generating SVG for prompt: {prompt}")
84
 
85
+ # Generate SVG using DiffSketcher
86
+ text_features = self.clip_model.encode_text(prompt)
87
+ latent = self.diffusion_model.generate(text_features)
88
+ svg_data = self.sketch_model.generate(latent, num_paths=20, width=width, height=height)
89
+ debug_log("Generated SVG using DiffSketcher")
90
+ return svg_data
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
  def __call__(self, data):
93
  """Handle a request to the model"""
 
111
  svg_content = self.generate_svg(prompt)
112
 
113
  # Convert SVG to PNG
114
+ png_data = cairosvg.svg2png(bytestring=svg_content.encode("utf-8"))
115
+ image = Image.open(io.BytesIO(png_data))
116
+ debug_log("Generated image from SVG")
 
 
 
 
 
 
 
 
 
 
117
 
118
  # Return the PIL Image directly
119
  debug_log("Returning image")
 
121
  except Exception as e:
122
  debug_log(f"Error in handler: {e}")
123
  debug_log(traceback.format_exc())
124
+ raise Exception(f"Error generating image: {str(e)}")