Fix handler to return PIL Images instead of dictionaries for HF API compatibility
Browse files- handler.py +73 -29
handler.py
CHANGED
@@ -81,29 +81,37 @@ class EndpointHandler:
|
|
81 |
|
82 |
# Process based on edit type
|
83 |
if edit_type == "replace" and len(prompts) >= 2:
|
84 |
-
|
85 |
elif edit_type == "refine":
|
86 |
-
|
87 |
elif edit_type == "reweight":
|
88 |
-
|
89 |
elif edit_type == "generate":
|
90 |
-
|
91 |
else:
|
92 |
# Default to refinement
|
93 |
-
|
94 |
|
95 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
|
97 |
except Exception as e:
|
98 |
print(f"Error in handler: {e}")
|
99 |
-
# Return fallback
|
100 |
fallback_svg = self.create_fallback_svg(prompts[0] if prompts else "error", width, height)
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
"error": str(e)
|
106 |
-
}
|
107 |
|
108 |
def word_replacement_edit(self, source_prompt: str, target_prompt: str, width: int, height: int, input_svg: str = None):
|
109 |
"""Perform word replacement editing"""
|
@@ -128,9 +136,7 @@ class EndpointHandler:
|
|
128 |
# Apply word replacement transformations
|
129 |
edited_svg = self.apply_word_replacement(base_svg, source_prompt, target_prompt, added_words, removed_words, width, height)
|
130 |
|
131 |
-
|
132 |
-
"svg": edited_svg,
|
133 |
-
"svg_base64": base64.b64encode(edited_svg.encode('utf-8')).decode('utf-8'),
|
134 |
"edit_type": "replace",
|
135 |
"source_prompt": source_prompt,
|
136 |
"target_prompt": target_prompt,
|
@@ -138,9 +144,13 @@ class EndpointHandler:
|
|
138 |
"removed_words": list(removed_words)
|
139 |
}
|
140 |
|
|
|
|
|
141 |
except Exception as e:
|
142 |
print(f"Error in word_replacement_edit: {e}")
|
143 |
-
|
|
|
|
|
144 |
|
145 |
def prompt_refinement_edit(self, prompt: str, width: int, height: int, input_svg: str = None):
|
146 |
"""Perform prompt refinement editing"""
|
@@ -156,16 +166,18 @@ class EndpointHandler:
|
|
156 |
# Apply refinement based on prompt analysis
|
157 |
refined_svg = self.apply_refinement(base_svg, prompt, width, height)
|
158 |
|
159 |
-
|
160 |
-
"svg": refined_svg,
|
161 |
-
"svg_base64": base64.b64encode(refined_svg.encode('utf-8')).decode('utf-8'),
|
162 |
"edit_type": "refine",
|
163 |
"prompt": prompt
|
164 |
}
|
165 |
|
|
|
|
|
166 |
except Exception as e:
|
167 |
print(f"Error in prompt_refinement_edit: {e}")
|
168 |
-
|
|
|
|
|
169 |
|
170 |
def attention_reweighting_edit(self, prompt: str, width: int, height: int, input_svg: str = None):
|
171 |
"""Perform attention reweighting editing"""
|
@@ -184,18 +196,20 @@ class EndpointHandler:
|
|
184 |
# Apply attention reweighting
|
185 |
reweighted_svg = self.apply_attention_reweighting(base_svg, weighted_prompt, attention_weights, width, height)
|
186 |
|
187 |
-
|
188 |
-
"svg": reweighted_svg,
|
189 |
-
"svg_base64": base64.b64encode(reweighted_svg.encode('utf-8')).decode('utf-8'),
|
190 |
"edit_type": "reweight",
|
191 |
"prompt": prompt,
|
192 |
"weighted_prompt": weighted_prompt,
|
193 |
"attention_weights": attention_weights
|
194 |
}
|
195 |
|
|
|
|
|
196 |
except Exception as e:
|
197 |
print(f"Error in attention_reweighting_edit: {e}")
|
198 |
-
|
|
|
|
|
199 |
|
200 |
def simple_generation(self, prompt: str, width: int, height: int):
|
201 |
"""Perform simple SVG generation"""
|
@@ -204,16 +218,18 @@ class EndpointHandler:
|
|
204 |
|
205 |
svg_content = self.generate_base_svg(prompt, width, height)
|
206 |
|
207 |
-
|
208 |
-
"svg": svg_content,
|
209 |
-
"svg_base64": base64.b64encode(svg_content.encode('utf-8')).decode('utf-8'),
|
210 |
"edit_type": "generate",
|
211 |
"prompt": prompt
|
212 |
}
|
213 |
|
|
|
|
|
214 |
except Exception as e:
|
215 |
print(f"Error in simple_generation: {e}")
|
216 |
-
|
|
|
|
|
217 |
|
218 |
def generate_base_svg(self, prompt: str, width: int, height: int):
|
219 |
"""Generate base SVG from prompt"""
|
@@ -727,6 +743,34 @@ class EndpointHandler:
|
|
727 |
"error": error
|
728 |
}
|
729 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
730 |
def create_fallback_svg(self, prompt: str, width: int, height: int):
|
731 |
"""Create simple fallback SVG"""
|
732 |
dwg = svgwrite.Drawing(size=(width, height))
|
|
|
81 |
|
82 |
# Process based on edit type
|
83 |
if edit_type == "replace" and len(prompts) >= 2:
|
84 |
+
svg_content, metadata = self.word_replacement_edit(prompts[0], prompts[1], width, height, input_svg)
|
85 |
elif edit_type == "refine":
|
86 |
+
svg_content, metadata = self.prompt_refinement_edit(prompts[0], width, height, input_svg)
|
87 |
elif edit_type == "reweight":
|
88 |
+
svg_content, metadata = self.attention_reweighting_edit(prompts[0], width, height, input_svg)
|
89 |
elif edit_type == "generate":
|
90 |
+
svg_content, metadata = self.simple_generation(prompts[0], width, height)
|
91 |
else:
|
92 |
# Default to refinement
|
93 |
+
svg_content, metadata = self.prompt_refinement_edit(prompts[0], width, height, input_svg)
|
94 |
|
95 |
+
# Convert SVG to PIL Image for HF API compatibility
|
96 |
+
pil_image = self.svg_to_pil_image(svg_content, width, height)
|
97 |
+
|
98 |
+
# Store metadata
|
99 |
+
for key, value in metadata.items():
|
100 |
+
if isinstance(value, (dict, list)):
|
101 |
+
pil_image.info[key] = json.dumps(value)
|
102 |
+
else:
|
103 |
+
pil_image.info[key] = str(value)
|
104 |
+
|
105 |
+
return pil_image
|
106 |
|
107 |
except Exception as e:
|
108 |
print(f"Error in handler: {e}")
|
109 |
+
# Return fallback image
|
110 |
fallback_svg = self.create_fallback_svg(prompts[0] if prompts else "error", width, height)
|
111 |
+
fallback_image = self.svg_to_pil_image(fallback_svg, width, height)
|
112 |
+
fallback_image.info['error'] = str(e)
|
113 |
+
fallback_image.info['edit_type'] = edit_type
|
114 |
+
return fallback_image
|
|
|
|
|
115 |
|
116 |
def word_replacement_edit(self, source_prompt: str, target_prompt: str, width: int, height: int, input_svg: str = None):
|
117 |
"""Perform word replacement editing"""
|
|
|
136 |
# Apply word replacement transformations
|
137 |
edited_svg = self.apply_word_replacement(base_svg, source_prompt, target_prompt, added_words, removed_words, width, height)
|
138 |
|
139 |
+
metadata = {
|
|
|
|
|
140 |
"edit_type": "replace",
|
141 |
"source_prompt": source_prompt,
|
142 |
"target_prompt": target_prompt,
|
|
|
144 |
"removed_words": list(removed_words)
|
145 |
}
|
146 |
|
147 |
+
return edited_svg, metadata
|
148 |
+
|
149 |
except Exception as e:
|
150 |
print(f"Error in word_replacement_edit: {e}")
|
151 |
+
fallback_svg = self.create_fallback_svg(source_prompt, width, height)
|
152 |
+
metadata = {"edit_type": "replace", "error": str(e)}
|
153 |
+
return fallback_svg, metadata
|
154 |
|
155 |
def prompt_refinement_edit(self, prompt: str, width: int, height: int, input_svg: str = None):
|
156 |
"""Perform prompt refinement editing"""
|
|
|
166 |
# Apply refinement based on prompt analysis
|
167 |
refined_svg = self.apply_refinement(base_svg, prompt, width, height)
|
168 |
|
169 |
+
metadata = {
|
|
|
|
|
170 |
"edit_type": "refine",
|
171 |
"prompt": prompt
|
172 |
}
|
173 |
|
174 |
+
return refined_svg, metadata
|
175 |
+
|
176 |
except Exception as e:
|
177 |
print(f"Error in prompt_refinement_edit: {e}")
|
178 |
+
fallback_svg = self.create_fallback_svg(prompt, width, height)
|
179 |
+
metadata = {"edit_type": "refine", "error": str(e)}
|
180 |
+
return fallback_svg, metadata
|
181 |
|
182 |
def attention_reweighting_edit(self, prompt: str, width: int, height: int, input_svg: str = None):
|
183 |
"""Perform attention reweighting editing"""
|
|
|
196 |
# Apply attention reweighting
|
197 |
reweighted_svg = self.apply_attention_reweighting(base_svg, weighted_prompt, attention_weights, width, height)
|
198 |
|
199 |
+
metadata = {
|
|
|
|
|
200 |
"edit_type": "reweight",
|
201 |
"prompt": prompt,
|
202 |
"weighted_prompt": weighted_prompt,
|
203 |
"attention_weights": attention_weights
|
204 |
}
|
205 |
|
206 |
+
return reweighted_svg, metadata
|
207 |
+
|
208 |
except Exception as e:
|
209 |
print(f"Error in attention_reweighting_edit: {e}")
|
210 |
+
fallback_svg = self.create_fallback_svg(prompt, width, height)
|
211 |
+
metadata = {"edit_type": "reweight", "error": str(e)}
|
212 |
+
return fallback_svg, metadata
|
213 |
|
214 |
def simple_generation(self, prompt: str, width: int, height: int):
|
215 |
"""Perform simple SVG generation"""
|
|
|
218 |
|
219 |
svg_content = self.generate_base_svg(prompt, width, height)
|
220 |
|
221 |
+
metadata = {
|
|
|
|
|
222 |
"edit_type": "generate",
|
223 |
"prompt": prompt
|
224 |
}
|
225 |
|
226 |
+
return svg_content, metadata
|
227 |
+
|
228 |
except Exception as e:
|
229 |
print(f"Error in simple_generation: {e}")
|
230 |
+
fallback_svg = self.create_fallback_svg(prompt, width, height)
|
231 |
+
metadata = {"edit_type": "generate", "error": str(e)}
|
232 |
+
return fallback_svg, metadata
|
233 |
|
234 |
def generate_base_svg(self, prompt: str, width: int, height: int):
|
235 |
"""Generate base SVG from prompt"""
|
|
|
743 |
"error": error
|
744 |
}
|
745 |
|
746 |
+
def svg_to_pil_image(self, svg_content, width, height):
|
747 |
+
"""Convert SVG content to PIL Image"""
|
748 |
+
try:
|
749 |
+
import cairosvg
|
750 |
+
import io
|
751 |
+
|
752 |
+
# Convert SVG to PNG bytes
|
753 |
+
png_bytes = cairosvg.svg2png(
|
754 |
+
bytestring=svg_content.encode('utf-8'),
|
755 |
+
output_width=width,
|
756 |
+
output_height=height
|
757 |
+
)
|
758 |
+
|
759 |
+
# Convert to PIL Image
|
760 |
+
image = Image.open(io.BytesIO(png_bytes)).convert('RGB')
|
761 |
+
return image
|
762 |
+
|
763 |
+
except ImportError:
|
764 |
+
print("cairosvg not available, creating simple image representation")
|
765 |
+
# Fallback: create a simple image with text
|
766 |
+
image = Image.new('RGB', (width, height), 'white')
|
767 |
+
return image
|
768 |
+
except Exception as e:
|
769 |
+
print(f"Error converting SVG to image: {e}")
|
770 |
+
# Fallback: create a simple image
|
771 |
+
image = Image.new('RGB', (width, height), 'white')
|
772 |
+
return image
|
773 |
+
|
774 |
def create_fallback_svg(self, prompt: str, width: int, height: int):
|
775 |
"""Create simple fallback SVG"""
|
776 |
dwg = svgwrite.Drawing(size=(width, height))
|