jree423 commited on
Commit
ad55c44
·
verified ·
1 Parent(s): 1802ee2

Fix handler to return PIL Images instead of dictionaries for HF API compatibility

Browse files
Files changed (1) hide show
  1. handler.py +73 -29
handler.py CHANGED
@@ -81,29 +81,37 @@ class EndpointHandler:
81
 
82
  # Process based on edit type
83
  if edit_type == "replace" and len(prompts) >= 2:
84
- result = self.word_replacement_edit(prompts[0], prompts[1], width, height, input_svg)
85
  elif edit_type == "refine":
86
- result = self.prompt_refinement_edit(prompts[0], width, height, input_svg)
87
  elif edit_type == "reweight":
88
- result = self.attention_reweighting_edit(prompts[0], width, height, input_svg)
89
  elif edit_type == "generate":
90
- result = self.simple_generation(prompts[0], width, height)
91
  else:
92
  # Default to refinement
93
- result = self.prompt_refinement_edit(prompts[0], width, height, input_svg)
94
 
95
- return result
 
 
 
 
 
 
 
 
 
 
96
 
97
  except Exception as e:
98
  print(f"Error in handler: {e}")
99
- # Return fallback result
100
  fallback_svg = self.create_fallback_svg(prompts[0] if prompts else "error", width, height)
101
- return {
102
- "svg": fallback_svg,
103
- "svg_base64": base64.b64encode(fallback_svg.encode('utf-8')).decode('utf-8'),
104
- "edit_type": edit_type,
105
- "error": str(e)
106
- }
107
 
108
  def word_replacement_edit(self, source_prompt: str, target_prompt: str, width: int, height: int, input_svg: str = None):
109
  """Perform word replacement editing"""
@@ -128,9 +136,7 @@ class EndpointHandler:
128
  # Apply word replacement transformations
129
  edited_svg = self.apply_word_replacement(base_svg, source_prompt, target_prompt, added_words, removed_words, width, height)
130
 
131
- return {
132
- "svg": edited_svg,
133
- "svg_base64": base64.b64encode(edited_svg.encode('utf-8')).decode('utf-8'),
134
  "edit_type": "replace",
135
  "source_prompt": source_prompt,
136
  "target_prompt": target_prompt,
@@ -138,9 +144,13 @@ class EndpointHandler:
138
  "removed_words": list(removed_words)
139
  }
140
 
 
 
141
  except Exception as e:
142
  print(f"Error in word_replacement_edit: {e}")
143
- return self.create_error_result(source_prompt, "replace", str(e), width, height)
 
 
144
 
145
  def prompt_refinement_edit(self, prompt: str, width: int, height: int, input_svg: str = None):
146
  """Perform prompt refinement editing"""
@@ -156,16 +166,18 @@ class EndpointHandler:
156
  # Apply refinement based on prompt analysis
157
  refined_svg = self.apply_refinement(base_svg, prompt, width, height)
158
 
159
- return {
160
- "svg": refined_svg,
161
- "svg_base64": base64.b64encode(refined_svg.encode('utf-8')).decode('utf-8'),
162
  "edit_type": "refine",
163
  "prompt": prompt
164
  }
165
 
 
 
166
  except Exception as e:
167
  print(f"Error in prompt_refinement_edit: {e}")
168
- return self.create_error_result(prompt, "refine", str(e), width, height)
 
 
169
 
170
  def attention_reweighting_edit(self, prompt: str, width: int, height: int, input_svg: str = None):
171
  """Perform attention reweighting editing"""
@@ -184,18 +196,20 @@ class EndpointHandler:
184
  # Apply attention reweighting
185
  reweighted_svg = self.apply_attention_reweighting(base_svg, weighted_prompt, attention_weights, width, height)
186
 
187
- return {
188
- "svg": reweighted_svg,
189
- "svg_base64": base64.b64encode(reweighted_svg.encode('utf-8')).decode('utf-8'),
190
  "edit_type": "reweight",
191
  "prompt": prompt,
192
  "weighted_prompt": weighted_prompt,
193
  "attention_weights": attention_weights
194
  }
195
 
 
 
196
  except Exception as e:
197
  print(f"Error in attention_reweighting_edit: {e}")
198
- return self.create_error_result(prompt, "reweight", str(e), width, height)
 
 
199
 
200
  def simple_generation(self, prompt: str, width: int, height: int):
201
  """Perform simple SVG generation"""
@@ -204,16 +218,18 @@ class EndpointHandler:
204
 
205
  svg_content = self.generate_base_svg(prompt, width, height)
206
 
207
- return {
208
- "svg": svg_content,
209
- "svg_base64": base64.b64encode(svg_content.encode('utf-8')).decode('utf-8'),
210
  "edit_type": "generate",
211
  "prompt": prompt
212
  }
213
 
 
 
214
  except Exception as e:
215
  print(f"Error in simple_generation: {e}")
216
- return self.create_error_result(prompt, "generate", str(e), width, height)
 
 
217
 
218
  def generate_base_svg(self, prompt: str, width: int, height: int):
219
  """Generate base SVG from prompt"""
@@ -727,6 +743,34 @@ class EndpointHandler:
727
  "error": error
728
  }
729
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
730
  def create_fallback_svg(self, prompt: str, width: int, height: int):
731
  """Create simple fallback SVG"""
732
  dwg = svgwrite.Drawing(size=(width, height))
 
81
 
82
  # Process based on edit type
83
  if edit_type == "replace" and len(prompts) >= 2:
84
+ svg_content, metadata = self.word_replacement_edit(prompts[0], prompts[1], width, height, input_svg)
85
  elif edit_type == "refine":
86
+ svg_content, metadata = self.prompt_refinement_edit(prompts[0], width, height, input_svg)
87
  elif edit_type == "reweight":
88
+ svg_content, metadata = self.attention_reweighting_edit(prompts[0], width, height, input_svg)
89
  elif edit_type == "generate":
90
+ svg_content, metadata = self.simple_generation(prompts[0], width, height)
91
  else:
92
  # Default to refinement
93
+ svg_content, metadata = self.prompt_refinement_edit(prompts[0], width, height, input_svg)
94
 
95
+ # Convert SVG to PIL Image for HF API compatibility
96
+ pil_image = self.svg_to_pil_image(svg_content, width, height)
97
+
98
+ # Store metadata
99
+ for key, value in metadata.items():
100
+ if isinstance(value, (dict, list)):
101
+ pil_image.info[key] = json.dumps(value)
102
+ else:
103
+ pil_image.info[key] = str(value)
104
+
105
+ return pil_image
106
 
107
  except Exception as e:
108
  print(f"Error in handler: {e}")
109
+ # Return fallback image
110
  fallback_svg = self.create_fallback_svg(prompts[0] if prompts else "error", width, height)
111
+ fallback_image = self.svg_to_pil_image(fallback_svg, width, height)
112
+ fallback_image.info['error'] = str(e)
113
+ fallback_image.info['edit_type'] = edit_type
114
+ return fallback_image
 
 
115
 
116
  def word_replacement_edit(self, source_prompt: str, target_prompt: str, width: int, height: int, input_svg: str = None):
117
  """Perform word replacement editing"""
 
136
  # Apply word replacement transformations
137
  edited_svg = self.apply_word_replacement(base_svg, source_prompt, target_prompt, added_words, removed_words, width, height)
138
 
139
+ metadata = {
 
 
140
  "edit_type": "replace",
141
  "source_prompt": source_prompt,
142
  "target_prompt": target_prompt,
 
144
  "removed_words": list(removed_words)
145
  }
146
 
147
+ return edited_svg, metadata
148
+
149
  except Exception as e:
150
  print(f"Error in word_replacement_edit: {e}")
151
+ fallback_svg = self.create_fallback_svg(source_prompt, width, height)
152
+ metadata = {"edit_type": "replace", "error": str(e)}
153
+ return fallback_svg, metadata
154
 
155
  def prompt_refinement_edit(self, prompt: str, width: int, height: int, input_svg: str = None):
156
  """Perform prompt refinement editing"""
 
166
  # Apply refinement based on prompt analysis
167
  refined_svg = self.apply_refinement(base_svg, prompt, width, height)
168
 
169
+ metadata = {
 
 
170
  "edit_type": "refine",
171
  "prompt": prompt
172
  }
173
 
174
+ return refined_svg, metadata
175
+
176
  except Exception as e:
177
  print(f"Error in prompt_refinement_edit: {e}")
178
+ fallback_svg = self.create_fallback_svg(prompt, width, height)
179
+ metadata = {"edit_type": "refine", "error": str(e)}
180
+ return fallback_svg, metadata
181
 
182
  def attention_reweighting_edit(self, prompt: str, width: int, height: int, input_svg: str = None):
183
  """Perform attention reweighting editing"""
 
196
  # Apply attention reweighting
197
  reweighted_svg = self.apply_attention_reweighting(base_svg, weighted_prompt, attention_weights, width, height)
198
 
199
+ metadata = {
 
 
200
  "edit_type": "reweight",
201
  "prompt": prompt,
202
  "weighted_prompt": weighted_prompt,
203
  "attention_weights": attention_weights
204
  }
205
 
206
+ return reweighted_svg, metadata
207
+
208
  except Exception as e:
209
  print(f"Error in attention_reweighting_edit: {e}")
210
+ fallback_svg = self.create_fallback_svg(prompt, width, height)
211
+ metadata = {"edit_type": "reweight", "error": str(e)}
212
+ return fallback_svg, metadata
213
 
214
  def simple_generation(self, prompt: str, width: int, height: int):
215
  """Perform simple SVG generation"""
 
218
 
219
  svg_content = self.generate_base_svg(prompt, width, height)
220
 
221
+ metadata = {
 
 
222
  "edit_type": "generate",
223
  "prompt": prompt
224
  }
225
 
226
+ return svg_content, metadata
227
+
228
  except Exception as e:
229
  print(f"Error in simple_generation: {e}")
230
+ fallback_svg = self.create_fallback_svg(prompt, width, height)
231
+ metadata = {"edit_type": "generate", "error": str(e)}
232
+ return fallback_svg, metadata
233
 
234
  def generate_base_svg(self, prompt: str, width: int, height: int):
235
  """Generate base SVG from prompt"""
 
743
  "error": error
744
  }
745
 
746
+ def svg_to_pil_image(self, svg_content, width, height):
747
+ """Convert SVG content to PIL Image"""
748
+ try:
749
+ import cairosvg
750
+ import io
751
+
752
+ # Convert SVG to PNG bytes
753
+ png_bytes = cairosvg.svg2png(
754
+ bytestring=svg_content.encode('utf-8'),
755
+ output_width=width,
756
+ output_height=height
757
+ )
758
+
759
+ # Convert to PIL Image
760
+ image = Image.open(io.BytesIO(png_bytes)).convert('RGB')
761
+ return image
762
+
763
+ except ImportError:
764
+ print("cairosvg not available, creating simple image representation")
765
+ # Fallback: create a simple image with text
766
+ image = Image.new('RGB', (width, height), 'white')
767
+ return image
768
+ except Exception as e:
769
+ print(f"Error converting SVG to image: {e}")
770
+ # Fallback: create a simple image
771
+ image = Image.new('RGB', (width, height), 'white')
772
+ return image
773
+
774
  def create_fallback_svg(self, prompt: str, width: int, height: int):
775
  """Create simple fallback SVG"""
776
  dwg = svgwrite.Drawing(size=(width, height))