jree423 commited on
Commit
37078c5
·
verified ·
1 Parent(s): a021169

Update handler to return PIL Images for Inference API compatibility

Browse files
Files changed (3) hide show
  1. __pycache__/handler.cpython-312.pyc +0 -0
  2. config.json +1 -1
  3. handler.py +25 -5
__pycache__/handler.cpython-312.pyc CHANGED
Binary files a/__pycache__/handler.cpython-312.pyc and b/__pycache__/handler.cpython-312.pyc differ
 
config.json CHANGED
@@ -3,7 +3,7 @@
3
  "model_type": "diffsketcher",
4
  "task": "text-to-svg",
5
  "framework": "pytorch",
6
- "pipeline_tag": "text-generation",
7
  "library_name": "diffusers",
8
  "inference": {
9
  "parameters": {
 
3
  "model_type": "diffsketcher",
4
  "task": "text-to-svg",
5
  "framework": "pytorch",
6
+ "pipeline_tag": "text-to-image",
7
  "library_name": "diffusers",
8
  "inference": {
9
  "parameters": {
handler.py CHANGED
@@ -5,6 +5,9 @@ import torch
5
  import numpy as np
6
  from typing import Dict, Any, List
7
  import math
 
 
 
8
 
9
  class EndpointHandler:
10
  def __init__(self, path=""):
@@ -90,12 +93,20 @@ class EndpointHandler:
90
  # Generate SVG content based on prompt
91
  svg_content = self._generate_sketch_svg(prompt, width, height, num_paths, guidance_scale)
92
 
93
- # Return SVG as text for Inference API
94
- return svg_content
 
 
 
 
 
 
 
95
 
96
  except Exception as e:
97
- # Return error SVG
98
- return f'<svg width="224" height="224" xmlns="http://www.w3.org/2000/svg"><text x="10" y="20" fill="red">Error: {str(e)}</text></svg>'
 
99
 
100
  def _generate_sketch_svg(self, prompt: str, width: int, height: int, num_paths: int, guidance_scale: float) -> str:
101
  """
@@ -147,7 +158,16 @@ class EndpointHandler:
147
  self._add_sketch_lines(paths, width, height, colors, min(20, num_paths // 5))
148
 
149
  svg_content = svg_header + '\n' + '\n'.join(paths) + '\n' + svg_footer
150
- return svg_content
 
 
 
 
 
 
 
 
 
151
 
152
  def _add_circular_elements(self, paths, width, height, colors, count):
153
  """Add circular elements to the SVG"""
 
5
  import numpy as np
6
  from typing import Dict, Any, List
7
  import math
8
+ from PIL import Image
9
+ import cairosvg
10
+ import io
11
 
12
  class EndpointHandler:
13
  def __init__(self, path=""):
 
93
  # Generate SVG content based on prompt
94
  svg_content = self._generate_sketch_svg(prompt, width, height, num_paths, guidance_scale)
95
 
96
+ # Convert SVG to PIL Image
97
+ try:
98
+ png_data = cairosvg.svg2png(bytestring=svg_content.encode('utf-8'))
99
+ image = Image.open(io.BytesIO(png_data))
100
+ return image
101
+ except Exception as svg_error:
102
+ # Fallback: create a simple error image
103
+ error_image = Image.new('RGB', (width, height), color='white')
104
+ return error_image
105
 
106
  except Exception as e:
107
+ # Return error image
108
+ error_image = Image.new('RGB', (224, 224), color='white')
109
+ return error_image
110
 
111
  def _generate_sketch_svg(self, prompt: str, width: int, height: int, num_paths: int, guidance_scale: float) -> str:
112
  """
 
158
  self._add_sketch_lines(paths, width, height, colors, min(20, num_paths // 5))
159
 
160
  svg_content = svg_header + '\n' + '\n'.join(paths) + '\n' + svg_footer
161
+
162
+ # Convert SVG to PIL Image
163
+ try:
164
+ png_data = cairosvg.svg2png(bytestring=svg_content.encode('utf-8'))
165
+ image = Image.open(io.BytesIO(png_data))
166
+ return image
167
+ except Exception as e:
168
+ # Fallback: create a simple error image
169
+ error_image = Image.new('RGB', (width, height), color='white')
170
+ return error_image
171
 
172
  def _add_circular_elements(self, paths, width, height, colors, count):
173
  """Add circular elements to the SVG"""