diffsketcher / pipeline.py
jree423's picture
Upload pipeline.py with huggingface_hub
e942bd1 verified
raw
history blame
2.89 kB
import base64
import html
import io
from typing import Dict, List, Optional, Union

import numpy as np
import torch
from diffusers import DiffusionPipeline
from PIL import Image
class DiffSketcherPipeline(DiffusionPipeline):
    """Placeholder DiffSketcher pipeline that emits a text-only SVG "sketch".

    No model is loaded and no optimization is performed: ``__call__`` returns
    a blank white 512x512 raster image together with an SVG that shows the
    prompt and a summary of the requested path count and stroke width.
    """

    def __init__(self) -> None:
        super().__init__()
        # Register a single ``model`` slot with nothing attached.
        # NOTE(review): presumably a stand-in until a real model is wired in.
        self.register_modules(
            model=None
        )

    @staticmethod
    def _build_placeholder_svg(
        prompt: str,
        num_paths: int,
        stroke_width: float,
        canvas_width: int,
        canvas_height: int,
    ) -> str:
        """Return the placeholder SVG document as a string.

        The prompt is XML-escaped so characters such as ``<`` or ``&`` cannot
        break the document or inject markup.
        """
        safe_prompt = html.escape(prompt)
        svg_lines = [
            f'<svg width="{canvas_width}" height="{canvas_height}" xmlns="http://www.w3.org/2000/svg">',
            '<rect width="100%" height="100%" fill="white"/>',
            '<text x="50%" y="50%" font-family="Arial" font-size="20" text-anchor="middle" dominant-baseline="middle" fill="black">',
            safe_prompt,
            "</text>",
            '<text x="50%" y="70%" font-family="Arial" font-size="12" text-anchor="middle" dominant-baseline="middle" fill="gray">',
            f"Paths: {num_paths}, Width: {stroke_width}",
            "</text>",
            "</svg>",
        ]
        return "\n".join(svg_lines)

    @torch.no_grad()
    def __call__(
        self,
        prompt: str,
        negative_prompt: str = "",
        num_paths: int = 96,
        token_ind: int = 4,
        num_iter: int = 800,
        guidance_scale: float = 7.5,
        width: float = 1.5,
        seed: Optional[int] = None,
        return_dict: bool = True,
        output_type: str = "pil",
    ) -> Union[Dict, tuple]:
        """
        Generate a vector sketch based on a text prompt.

        This implementation is a placeholder: ``negative_prompt``,
        ``token_ind``, ``num_iter`` and ``guidance_scale`` are accepted for
        API compatibility but are not used.

        Args:
            prompt: The text prompt to guide the sketch generation.
            negative_prompt: Negative text prompt for guidance (unused).
            num_paths: Number of paths to use in the sketch.
            token_ind: Token index for attention (unused).
            num_iter: Number of optimization iterations (unused).
            guidance_scale: Scale for classifier-free guidance (unused).
            width: Stroke width.
            seed: Random seed for reproducibility. When given, it seeds the
                global torch and NumPy RNGs (a process-wide side effect).
            return_dict: Whether to return a dict or tuple.
            output_type: Output type, one of "pil", "np", or "svg". Any other
                value falls through and returns the PIL image.

        Returns:
            If return_dict is True, returns a dict with keys:
                - "svg": SVG string representation of the sketch
                - "image": Rendered image of the sketch
            Otherwise, returns a tuple (svg_string, image)
        """
        # Set seed for reproducibility (global RNG state).
        if seed is not None:
            torch.manual_seed(seed)
            np.random.seed(seed)

        # Canvas size is kept separate from ``width`` (the stroke width) so the
        # caller's value is not overwritten before it is reported in the SVG.
        canvas_width, canvas_height = 512, 512

        # Generate a placeholder raster image.
        image = Image.new('RGB', (canvas_width, canvas_height), color='white')

        svg_str = self._build_placeholder_svg(
            prompt, num_paths, width, canvas_width, canvas_height
        )

        # Convert output based on output_type.
        if output_type == "np":
            image = np.array(image)
        elif output_type == "svg":
            image = svg_str

        if return_dict:
            return {"svg": svg_str, "image": image}
        return svg_str, image