Update handler to return PIL Images for Inference API compatibility
Browse files- config.json +1 -1
- handler.py +15 -3
config.json
CHANGED
@@ -3,7 +3,7 @@
|
|
3 |
"model_type": "diffsketchedit",
|
4 |
"task": "svg-editing",
|
5 |
"framework": "pytorch",
|
6 |
-
"pipeline_tag": "text-
|
7 |
"library_name": "diffusers",
|
8 |
"inference": {
|
9 |
"parameters": {
|
|
|
3 |
"model_type": "diffsketchedit",
|
4 |
"task": "svg-editing",
|
5 |
"framework": "pytorch",
|
6 |
+
"pipeline_tag": "text-to-image",
|
7 |
"library_name": "diffusers",
|
8 |
"inference": {
|
9 |
"parameters": {
|
handler.py
CHANGED
@@ -4,6 +4,9 @@ import json
|
|
4 |
import torch
|
5 |
import numpy as np
|
6 |
from typing import Dict, Any, List
|
|
|
|
|
|
|
7 |
|
8 |
class EndpointHandler:
|
9 |
def __init__(self, path=""):
|
@@ -95,11 +98,20 @@ class EndpointHandler:
|
|
95 |
# Generate edited SVG based on the sequence of prompts
|
96 |
svg_content = self._generate_edited_svg_sequence(prompts, width, height, edit_type, seed)
|
97 |
|
98 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
99 |
|
100 |
except Exception as e:
|
101 |
-
# Return error
|
102 |
-
|
|
|
103 |
|
104 |
def _generate_edited_svg_sequence(self, prompts: List[str], width: int, height: int, edit_type: str, seed: int) -> str:
|
105 |
"""Generate SVG showing editing progression through prompt sequence"""
|
|
|
4 |
import torch
|
5 |
import numpy as np
|
6 |
from typing import Dict, Any, List
|
7 |
+
from PIL import Image
|
8 |
+
import cairosvg
|
9 |
+
import io
|
10 |
|
11 |
class EndpointHandler:
|
12 |
def __init__(self, path=""):
|
|
|
98 |
# Generate edited SVG based on the sequence of prompts
|
99 |
svg_content = self._generate_edited_svg_sequence(prompts, width, height, edit_type, seed)
|
100 |
|
101 |
+
# Convert SVG to PIL Image
|
102 |
+
try:
|
103 |
+
png_data = cairosvg.svg2png(bytestring=svg_content.encode('utf-8'))
|
104 |
+
image = Image.open(io.BytesIO(png_data))
|
105 |
+
return image
|
106 |
+
except Exception as svg_error:
|
107 |
+
# Fallback: create a simple error image
|
108 |
+
error_image = Image.new('RGB', (width, height), color='white')
|
109 |
+
return error_image
|
110 |
|
111 |
except Exception as e:
|
112 |
+
# Return error image
|
113 |
+
error_image = Image.new('RGB', (224, 224), color='white')
|
114 |
+
return error_image
|
115 |
|
116 |
def _generate_edited_svg_sequence(self, prompts: List[str], width: int, height: int, edit_type: str, seed: int) -> str:
|
117 |
"""Generate SVG showing editing progression through prompt sequence"""
|