jree423 commited on
Commit
3791258
·
verified ·
1 Parent(s): 97a0903

Upload custom_handler_diffsketcher.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. custom_handler_diffsketcher.py +38 -0
custom_handler_diffsketcher.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from typing import Dict, Any
3
+ import torch
4
+ import base64
5
+ import io
6
+ import os
7
+ import json
8
+ from PIL import Image
9
+
10
+ class EndpointHandler:
11
+ def __init__(self, path=""):
12
+ # Initialize device
13
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
+ print(f"Initializing diffsketcher handler on {self.device}")
15
+
16
+ def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
17
+ # Extract prompt from the input data
18
+ prompt = data.get("prompt", "")
19
+ if not prompt and "prompts" in data:
20
+ prompts = data.get("prompts", [""])
21
+ prompt = prompts[0] if prompts else ""
22
+
23
+ # Generate a placeholder SVG
24
+ svg = f'<svg xmlns="http://www.w3.org/2000/svg" width="512" height="512" viewBox="0 0 512 512"><text x="50%" y="50%" dominant-baseline="middle" text-anchor="middle" font-size="20">diffsketcher: {prompt}</text></svg>'
25
+
26
+ # Create a placeholder image
27
+ image = Image.new('RGB', (512, 512), color = (100, 100, 100))
28
+
29
+ # Convert the image to base64
30
+ buffered = io.BytesIO()
31
+ image.save(buffered, format="PNG")
32
+ img_str = base64.b64encode(buffered.getvalue()).decode()
33
+
34
+ # Return the results
35
+ return {
36
+ "svg": svg,
37
+ "image": img_str
38
+ }