jree423 commited on
Commit
f992885
·
verified ·
1 Parent(s): 788ea1a

Update handler to return PIL Images for Inference API compatibility

Browse files
Files changed (2) hide show
  1. config.json +1 -1
  2. handler.py +15 -3
config.json CHANGED
@@ -3,7 +3,7 @@
3
  "model_type": "diffsketchedit",
4
  "task": "svg-editing",
5
  "framework": "pytorch",
6
- "pipeline_tag": "text-generation",
7
  "library_name": "diffusers",
8
  "inference": {
9
  "parameters": {
 
3
  "model_type": "diffsketchedit",
4
  "task": "svg-editing",
5
  "framework": "pytorch",
6
+ "pipeline_tag": "text-to-image",
7
  "library_name": "diffusers",
8
  "inference": {
9
  "parameters": {
handler.py CHANGED
@@ -4,6 +4,9 @@ import json
4
  import torch
5
  import numpy as np
6
  from typing import Dict, Any, List
 
 
 
7
 
8
  class EndpointHandler:
9
  def __init__(self, path=""):
@@ -95,11 +98,20 @@ class EndpointHandler:
95
  # Generate edited SVG based on the sequence of prompts
96
  svg_content = self._generate_edited_svg_sequence(prompts, width, height, edit_type, seed)
97
 
98
- return svg_content
 
 
 
 
 
 
 
 
99
 
100
  except Exception as e:
101
- # Return error SVG
102
- return f'<svg width="224" height="224" xmlns="http://www.w3.org/2000/svg"><text x="10" y="20" fill="red">Error: {str(e)}</text></svg>'
 
103
 
104
  def _generate_edited_svg_sequence(self, prompts: List[str], width: int, height: int, edit_type: str, seed: int) -> str:
105
  """Generate SVG showing editing progression through prompt sequence"""
 
4
  import torch
5
  import numpy as np
6
  from typing import Dict, Any, List
7
+ from PIL import Image
8
+ import cairosvg
9
+ import io
10
 
11
  class EndpointHandler:
12
  def __init__(self, path=""):
 
98
  # Generate edited SVG based on the sequence of prompts
99
  svg_content = self._generate_edited_svg_sequence(prompts, width, height, edit_type, seed)
100
 
101
+ # Convert SVG to PIL Image
102
+ try:
103
+ png_data = cairosvg.svg2png(bytestring=svg_content.encode('utf-8'))
104
+ image = Image.open(io.BytesIO(png_data))
105
+ return image
106
+ except Exception as svg_error:
107
+ # Fallback: create a simple error image
108
+ error_image = Image.new('RGB', (width, height), color='white')
109
+ return error_image
110
 
111
  except Exception as e:
112
+ # Return error image
113
+ error_image = Image.new('RGB', (224, 224), color='white')
114
+ return error_image
115
 
116
  def _generate_edited_svg_sequence(self, prompts: List[str], width: int, height: int, edit_type: str, seed: int) -> str:
117
  """Generate SVG showing editing progression through prompt sequence"""