import base64
import html
import io
from typing import Dict, Any, List, Union

import torch
from PIL import Image
class Pipeline:
    """Placeholder diffsketcher pipeline.

    This class does not load or run any model: each call returns a stub SVG
    that echoes the requested prompt, plus a solid grey PNG encoded as base64.
    """

    def __init__(self):
        # Pick the GPU when one is visible. The device is only reported for
        # now, since nothing in this class is moved onto it.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Initializing diffsketcher pipeline on {self.device}")

    def __call__(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        """Handle one request and return placeholder outputs.

        Args:
            inputs: Request payload. The prompt is read from ``"prompt"``;
                if that is missing or empty, the first entry of ``"prompts"``
                is used instead (a bare string under ``"prompts"`` is taken
                as the prompt itself).

        Returns:
            A dict with two string entries:
            ``"svg"``, a 512x512 SVG document that displays
            ``"diffsketcher: <prompt>"`` with the prompt XML-escaped; and
            ``"image"``, a base64-encoded PNG of a 512x512 RGB canvas filled
            with (100, 100, 100).
        """
        prompt = self._extract_prompt(inputs)
        return {
            "svg": self._placeholder_svg(prompt),
            "image": self._placeholder_png_base64(),
        }

    @staticmethod
    def _extract_prompt(inputs: Dict[str, Any]) -> str:
        """Return the prompt text from ``inputs``, or ``""`` when none is given."""
        prompt = inputs.get("prompt")
        if not prompt:
            prompts = inputs.get("prompts")
            if isinstance(prompts, str):
                # Indexing a bare string would keep only its first character.
                prompt = prompts
            elif prompts:
                prompt = prompts[0]
        # Normalise None (missing key or explicit null) to "" so the SVG
        # never renders the literal text "None".
        return "" if prompt is None else str(prompt)

    @staticmethod
    def _placeholder_svg(prompt: str) -> str:
        """Build the stub SVG document that displays ``prompt`` centred."""
        # The prompt comes straight from the request: escape &, < and > so it
        # cannot break the XML or inject extra elements into the SVG.
        text = html.escape(prompt, quote=False)
        return (
            '<svg xmlns="http://www.w3.org/2000/svg" width="512" height="512" '
            'viewBox="0 0 512 512">'
            '<text x="50%" y="50%" dominant-baseline="middle" '
            'text-anchor="middle" font-size="20">'
            f"diffsketcher: {text}"
            "</text></svg>"
        )

    @staticmethod
    def _placeholder_png_base64() -> str:
        """Return a solid grey 512x512 PNG encoded as a base64 ASCII string."""
        image = Image.new("RGB", (512, 512), color=(100, 100, 100))
        with io.BytesIO() as buffer:
            image.save(buffer, format="PNG")
            return base64.b64encode(buffer.getvalue()).decode("ascii")