jree423 commited on
Commit
192684f
·
verified ·
1 Parent(s): e6754e8

Fix handler to return PIL Images instead of dictionaries for HF API compatibility

Browse files
Files changed (1) hide show
  1. handler.py +47 -24
handler.py CHANGED
@@ -78,36 +78,31 @@ class EndpointHandler:
78
  prompt, num_paths, num_iter, guidance_scale, width, height
79
  )
80
 
81
- # Convert SVG to base64 for transmission
82
- svg_base64 = base64.b64encode(svg_content.encode('utf-8')).decode('utf-8')
83
 
84
- # Return result
85
- result = {
86
- "svg": svg_content,
87
- "svg_base64": svg_base64,
88
- "prompt": prompt,
89
- "parameters": {
90
- "num_paths": num_paths,
91
- "num_iter": num_iter,
92
- "guidance_scale": guidance_scale,
93
- "width": width,
94
- "height": height,
95
- "seed": seed
96
- }
97
- }
98
 
99
- return result
100
 
101
  except Exception as e:
102
  print(f"Error in handler: {e}")
103
- # Return a simple fallback SVG
104
  fallback_svg = self.create_fallback_svg(prompt, width, height)
105
- return {
106
- "svg": fallback_svg,
107
- "svg_base64": base64.b64encode(fallback_svg.encode('utf-8')).decode('utf-8'),
108
- "prompt": prompt,
109
- "error": str(e)
110
- }
111
 
112
  def generate_svg_sketch(self, prompt, num_paths, num_iter, guidance_scale, width, height):
113
  """Generate SVG sketch using simplified DiffSketcher approach"""
@@ -284,6 +279,34 @@ class EndpointHandler:
284
 
285
  return dwg.tostring()
286
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
287
  def create_fallback_svg(self, prompt, width=224, height=224):
288
  """Create a simple fallback SVG"""
289
  dwg = svgwrite.Drawing(size=(width, height))
 
78
  prompt, num_paths, num_iter, guidance_scale, width, height
79
  )
80
 
81
+ # Convert SVG to PIL Image for HF API compatibility
82
+ pil_image = self.svg_to_pil_image(svg_content, width, height)
83
 
84
+ # Store SVG data as image metadata
85
+ pil_image.info['svg_content'] = svg_content
86
+ pil_image.info['prompt'] = prompt
87
+ pil_image.info['parameters'] = json.dumps({
88
+ "num_paths": num_paths,
89
+ "num_iter": num_iter,
90
+ "guidance_scale": guidance_scale,
91
+ "width": width,
92
+ "height": height,
93
+ "seed": seed
94
+ })
 
 
 
95
 
96
+ return pil_image
97
 
98
  except Exception as e:
99
  print(f"Error in handler: {e}")
100
+ # Return a simple fallback image
101
  fallback_svg = self.create_fallback_svg(prompt, width, height)
102
+ fallback_image = self.svg_to_pil_image(fallback_svg, width, height)
103
+ fallback_image.info['error'] = str(e)
104
+ fallback_image.info['prompt'] = prompt
105
+ return fallback_image
 
 
106
 
107
  def generate_svg_sketch(self, prompt, num_paths, num_iter, guidance_scale, width, height):
108
  """Generate SVG sketch using simplified DiffSketcher approach"""
 
279
 
280
  return dwg.tostring()
281
 
282
+ def svg_to_pil_image(self, svg_content, width, height):
283
+ """Convert SVG content to PIL Image"""
284
+ try:
285
+ import cairosvg
286
+ import io
287
+
288
+ # Convert SVG to PNG bytes
289
+ png_bytes = cairosvg.svg2png(
290
+ bytestring=svg_content.encode('utf-8'),
291
+ output_width=width,
292
+ output_height=height
293
+ )
294
+
295
+ # Convert to PIL Image
296
+ image = Image.open(io.BytesIO(png_bytes)).convert('RGB')
297
+ return image
298
+
299
+ except ImportError:
300
+ print("cairosvg not available, creating simple image representation")
301
+ # Fallback: create a simple image with text
302
+ image = Image.new('RGB', (width, height), 'white')
303
+ return image
304
+ except Exception as e:
305
+ print(f"Error converting SVG to image: {e}")
306
+ # Fallback: create a simple image
307
+ image = Image.new('RGB', (width, height), 'white')
308
+ return image
309
+
310
  def create_fallback_svg(self, prompt, width=224, height=224):
311
  """Create a simple fallback SVG"""
312
  dwg = svgwrite.Drawing(size=(width, height))