|
|
|
|
|
import base64
import html
import io
import json
import os
from typing import Any, Dict

import torch
from PIL import Image
|
|
|
|
|
class EndpointHandler:
    """Placeholder inference handler returning a text SVG and a blank PNG.

    No model is run. Each call echoes the request prompt inside an SVG
    ``<text>`` element and returns a solid grey 512x512 PNG, base64-encoded.
    """

    def __init__(self, path=""):
        """Load (or create) the pipeline config and select a torch device.

        Args:
            path: Directory expected to contain ``model_index.json``. If the
                file is missing, a default config is used and, when the
                directory is writable, saved there.
        """
        model_index_path = os.path.join(path, "model_index.json")

        if os.path.exists(model_index_path):
            with open(model_index_path, "r", encoding="utf-8") as f:
                self.config = json.load(f)
        else:
            self.config = {
                "architecture": "SimplePipeline",
                "format": "diffusers",
                "version": "0.1.0",
            }
            try:
                with open(model_index_path, "w", encoding="utf-8") as f:
                    json.dump(self.config, f, indent=2)
            except OSError:
                # Persisting the default is best-effort: a read-only model
                # directory must not stop the handler from starting, and the
                # in-memory config above is already complete.
                pass

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        """Handle one request.

        Args:
            data: Request payload. The prompt is read from ``"prompt"``, or,
                when that is empty/missing, from the first entry of
                ``"prompts"``.

        Returns:
            ``{"svg": <SVG markup>, "image": <base64-encoded PNG>}``.
        """
        prompt = self._extract_prompt(data)
        return {
            "svg": self._render_svg(prompt),
            "image": self._render_placeholder_png(),
        }

    @staticmethod
    def _extract_prompt(data: Dict[str, Any]) -> Any:
        """Return the request prompt, falling back to ``data["prompts"][0]``."""
        prompt = data.get("prompt", "")
        if not prompt and "prompts" in data:
            prompts = data.get("prompts", [""])
            # A bare string would otherwise be indexed to its first character.
            if isinstance(prompts, str):
                return prompts
            prompt = prompts[0] if prompts else ""
        return prompt

    @staticmethod
    def _render_svg(prompt: Any) -> str:
        """Build the 512x512 SVG that displays ``diffsketcher: <prompt>``."""
        # The prompt is untrusted input: escape it so it cannot break out of
        # the <text> element or inject markup. str() mirrors the original
        # f-string formatting for non-string prompts.
        text = html.escape(str(prompt), quote=False)
        return (
            '<svg xmlns="http://www.w3.org/2000/svg" width="512" height="512" '
            'viewBox="0 0 512 512"><text x="50%" y="50%" '
            'dominant-baseline="middle" text-anchor="middle" font-size="20">'
            f"diffsketcher: {text}</text></svg>"
        )

    @staticmethod
    def _render_placeholder_png() -> str:
        """Return a solid grey 512x512 RGB PNG as a base64 string."""
        image = Image.new("RGB", (512, 512), color=(100, 100, 100))
        buffered = io.BytesIO()
        image.save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode()
|
|
|