File size: 1,889 Bytes
8de5d7e a96d3ab bebb135 a96d3ab bebb135 a96d3ab bebb135 8de5d7e bebb135 8de5d7e bebb135 a96d3ab 8de5d7e a96d3ab 8de5d7e a96d3ab bebb135 77cfdee a96d3ab bebb135 8de5d7e a96d3ab 8de5d7e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 |
import base64
import html
import io
import json
import os
from typing import Dict, Any

import torch
from PIL import Image
class EndpointHandler:
    """Placeholder inference handler returning a sketch-like SVG and a PNG.

    The current implementation does not run a model: it echoes the prompt
    into an SVG and returns a solid grey image.
    """

    def __init__(self, path: str = ""):
        """Load (or create) the pipeline config and select the compute device.

        Args:
            path: Directory expected to contain ``model_index.json``. An empty
                string (the default) resolves relative to the current working
                directory.
        """
        model_index_path = os.path.join(path, "model_index.json")
        if os.path.exists(model_index_path):
            with open(model_index_path, "r", encoding="utf-8") as f:
                self.config = json.load(f)
        else:
            # No config shipped with the model: fall back to a minimal default.
            self.config = {
                "architecture": "SimplePipeline",
                "format": "diffusers",
                "version": "0.1.0",
            }
            # Persisting the default is best-effort only. The model directory
            # may be missing or read-only in a deployed endpoint, and failing
            # to cache a placeholder config must not prevent the handler from
            # starting. The in-memory config above is used either way.
            try:
                with open(model_index_path, "w", encoding="utf-8") as f:
                    json.dump(self.config, f, indent=2)
            except OSError:
                pass
        # Not used by __call__ below; kept for callers/future model code.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    @staticmethod
    def _extract_prompt(data: Dict[str, Any]) -> str:
        """Return the request prompt, or "" when none is supplied.

        ``data["prompt"]`` wins when truthy; otherwise the first entry of
        ``data["prompts"]`` is used. A bare string in ``"prompts"`` is
        treated as a single prompt rather than indexed character-wise.
        """
        prompt = data.get("prompt", "")
        if not prompt:
            prompts = data.get("prompts")
            if isinstance(prompts, str):
                prompt = prompts
            else:
                prompt = prompts[0] if prompts else ""
        return "" if prompt is None else str(prompt)

    @staticmethod
    def _render_svg(prompt: str) -> str:
        """Return a 512x512 placeholder SVG showing *prompt* as centred text."""
        # The prompt comes straight from the request body, so it must be
        # escaped before being embedded in XML; otherwise `<`, `&` or quotes
        # break the document or allow markup injection.
        safe_prompt = html.escape(prompt, quote=True)
        return (
            '<svg xmlns="http://www.w3.org/2000/svg" width="512" height="512" '
            'viewBox="0 0 512 512"><text x="50%" y="50%" '
            'dominant-baseline="middle" text-anchor="middle" font-size="20">'
            f"diffsketcher: {safe_prompt}</text></svg>"
        )

    @staticmethod
    def _render_placeholder_png() -> str:
        """Return a base64-encoded (ASCII) 512x512 solid grey PNG."""
        image = Image.new("RGB", (512, 512), color=(100, 100, 100))
        buffer = io.BytesIO()
        image.save(buffer, format="PNG")
        return base64.b64encode(buffer.getvalue()).decode("ascii")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        """Produce placeholder outputs for the request's prompt.

        Args:
            data: Request payload. The prompt is read from ``"prompt"``,
                falling back to the first entry of ``"prompts"``.

        Returns:
            ``{"svg": <SVG markup containing the escaped prompt>,
            "image": <base64-encoded grey PNG>}``.
        """
        prompt = self._extract_prompt(data)
        return {
            "svg": self._render_svg(prompt),
            "image": self._render_placeholder_png(),
        }
|