# diffsketcher / pipeline.py
# Uploaded by jree423 in commit 13ea232 ("Upload diffsketcher model", verified).
import base64
import io
from typing import Dict, Any, List, Union
from xml.sax.saxutils import escape

import torch
from PIL import Image
class Pipeline:
    """Placeholder diffsketcher inference pipeline.

    No model is loaded or run: each call returns a fixed 512x512 SVG that
    echoes the prompt as centred text, plus a solid grey 512x512 PNG encoded
    as base64.
    """

    def __init__(self):
        # The device is selected and reported, but nothing in this placeholder
        # implementation places work on it yet.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Initializing diffsketcher pipeline on {self.device}")

    def __call__(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        """Produce placeholder outputs for a single prompt.

        Args:
            inputs: Request payload. The prompt is read from ``"prompt"``;
                when that is missing, ``None`` or empty, the ``"prompts"``
                entry is used instead (its first element if it is a
                sequence, or the value itself if it is a plain string).

        Returns:
            A dict with ``"svg"`` (an SVG document string containing the
            XML-escaped prompt) and ``"image"`` (a base64-encoded PNG,
            decoded to ``str``).
        """
        prompt = inputs.get("prompt") or ""
        if not prompt:
            prompts = inputs.get("prompts")
            if isinstance(prompts, str):
                # A bare string would otherwise be indexed to its first char.
                prompt = prompts
            elif prompts:
                prompt = prompts[0] or ""

        # The prompt is caller-supplied text placed inside an XML document:
        # escape &, < and > so it can neither break the markup nor inject
        # additional SVG elements.
        text = escape(str(prompt))
        svg = (
            '<svg xmlns="http://www.w3.org/2000/svg" width="512" height="512" '
            'viewBox="0 0 512 512"><text x="50%" y="50%" '
            'dominant-baseline="middle" text-anchor="middle" font-size="20">'
            f"diffsketcher: {text}</text></svg>"
        )

        # Solid grey stand-in for a rendered sketch.
        image = Image.new("RGB", (512, 512), color=(100, 100, 100))
        with io.BytesIO() as buffer:
            image.save(buffer, format="PNG")
            png_bytes = buffer.getvalue()
        img_str = base64.b64encode(png_bytes).decode("ascii")

        return {
            "svg": svg,
            "image": img_str,
        }