Update handler to return PIL Images for Inference API compatibility
Browse files- __pycache__/handler.cpython-312.pyc +0 -0
- config.json +1 -1
- handler.py +25 -5
__pycache__/handler.cpython-312.pyc
CHANGED
Binary files a/__pycache__/handler.cpython-312.pyc and b/__pycache__/handler.cpython-312.pyc differ
|
|
config.json
CHANGED
@@ -3,7 +3,7 @@
|
|
3 |
"model_type": "diffsketcher",
|
4 |
"task": "text-to-svg",
|
5 |
"framework": "pytorch",
|
6 |
-
"pipeline_tag": "text-
|
7 |
"library_name": "diffusers",
|
8 |
"inference": {
|
9 |
"parameters": {
|
|
|
3 |
"model_type": "diffsketcher",
|
4 |
"task": "text-to-svg",
|
5 |
"framework": "pytorch",
|
6 |
+
"pipeline_tag": "text-to-image",
|
7 |
"library_name": "diffusers",
|
8 |
"inference": {
|
9 |
"parameters": {
|
handler.py
CHANGED
@@ -5,6 +5,9 @@ import torch
|
|
5 |
import numpy as np
|
6 |
from typing import Dict, Any, List
|
7 |
import math
|
|
|
|
|
|
|
8 |
|
9 |
class EndpointHandler:
|
10 |
def __init__(self, path=""):
|
@@ -90,12 +93,20 @@ class EndpointHandler:
|
|
90 |
# Generate SVG content based on prompt
|
91 |
svg_content = self._generate_sketch_svg(prompt, width, height, num_paths, guidance_scale)
|
92 |
|
93 |
-
#
|
94 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
95 |
|
96 |
except Exception as e:
|
97 |
-
# Return error
|
98 |
-
|
|
|
99 |
|
100 |
def _generate_sketch_svg(self, prompt: str, width: int, height: int, num_paths: int, guidance_scale: float) -> str:
|
101 |
"""
|
@@ -147,7 +158,16 @@ class EndpointHandler:
|
|
147 |
self._add_sketch_lines(paths, width, height, colors, min(20, num_paths // 5))
|
148 |
|
149 |
svg_content = svg_header + '\n' + '\n'.join(paths) + '\n' + svg_footer
|
150 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
151 |
|
152 |
def _add_circular_elements(self, paths, width, height, colors, count):
|
153 |
"""Add circular elements to the SVG"""
|
|
|
5 |
import numpy as np
|
6 |
from typing import Dict, Any, List
|
7 |
import math
|
8 |
+
from PIL import Image
|
9 |
+
import cairosvg
|
10 |
+
import io
|
11 |
|
12 |
class EndpointHandler:
|
13 |
def __init__(self, path=""):
|
|
|
93 |
# Generate SVG content based on prompt
|
94 |
svg_content = self._generate_sketch_svg(prompt, width, height, num_paths, guidance_scale)
|
95 |
|
96 |
+
# Convert SVG to PIL Image
|
97 |
+
try:
|
98 |
+
png_data = cairosvg.svg2png(bytestring=svg_content.encode('utf-8'))
|
99 |
+
image = Image.open(io.BytesIO(png_data))
|
100 |
+
return image
|
101 |
+
except Exception as svg_error:
|
102 |
+
# Fallback: create a simple error image
|
103 |
+
error_image = Image.new('RGB', (width, height), color='white')
|
104 |
+
return error_image
|
105 |
|
106 |
except Exception as e:
|
107 |
+
# Return error image
|
108 |
+
error_image = Image.new('RGB', (224, 224), color='white')
|
109 |
+
return error_image
|
110 |
|
111 |
def _generate_sketch_svg(self, prompt: str, width: int, height: int, num_paths: int, guidance_scale: float) -> str:
|
112 |
"""
|
|
|
158 |
self._add_sketch_lines(paths, width, height, colors, min(20, num_paths // 5))
|
159 |
|
160 |
svg_content = svg_header + '\n' + '\n'.join(paths) + '\n' + svg_footer
|
161 |
+
|
162 |
+
# Convert SVG to PIL Image
|
163 |
+
try:
|
164 |
+
png_data = cairosvg.svg2png(bytestring=svg_content.encode('utf-8'))
|
165 |
+
image = Image.open(io.BytesIO(png_data))
|
166 |
+
return image
|
167 |
+
except Exception as e:
|
168 |
+
# Fallback: create a simple error image
|
169 |
+
error_image = Image.new('RGB', (width, height), color='white')
|
170 |
+
return error_image
|
171 |
|
172 |
def _add_circular_elements(self, paths, width, height, colors, count):
|
173 |
"""Add circular elements to the SVG"""
|