File size: 2,893 Bytes
e942bd1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 |
from typing import Dict, List, Optional, Union
import torch
from diffusers import DiffusionPipeline
from PIL import Image
import numpy as np
import io
import base64
class DiffSketcherPipeline(DiffusionPipeline):
    """Placeholder text-to-sketch pipeline.

    The current implementation does not run any model or optimization: it
    returns a blank white canvas plus a small SVG that echoes the prompt and
    a couple of the requested settings.
    """

    def __init__(self):
        super().__init__()
        # No real sub-model is attached; "model" is registered as None.
        self.register_modules(
            model=None
        )

    @torch.no_grad()
    def __call__(
        self,
        prompt: str,
        negative_prompt: str = "",
        num_paths: int = 96,
        token_ind: int = 4,
        num_iter: int = 800,
        guidance_scale: float = 7.5,
        width: float = 1.5,
        seed: Optional[int] = None,
        return_dict: bool = True,
        output_type: str = "pil",
    ) -> Union[Dict, tuple]:
        """
        Generate a (placeholder) vector sketch based on a text prompt.

        Args:
            prompt: The text prompt; it is embedded (XML-escaped) in the SVG.
            negative_prompt: Negative text prompt. Currently unused.
            num_paths: Number of paths; only reported in the SVG label.
            token_ind: Token index for attention. Currently unused.
            num_iter: Number of optimization iterations. Currently unused.
            guidance_scale: Classifier-free guidance scale. Currently unused.
            width: Stroke width; only reported in the SVG label.
            seed: If not None, seeds both torch and numpy RNGs.
            return_dict: Whether to return a dict or a tuple.
            output_type: "pil" keeps the PIL image, "np" converts it to a
                numpy array, "svg" substitutes the SVG string for the image.
                Any other value behaves like "pil".

        Returns:
            If return_dict is True, a dict with keys:
                - "svg": SVG string representation of the sketch
                - "image": the image in the form selected by output_type
            Otherwise, the tuple (svg_string, image).
        """
        # Imported locally so this block does not depend on a new
        # module-level import.
        from xml.sax.saxutils import escape

        # Set seed for reproducibility
        if seed is not None:
            torch.manual_seed(seed)
            np.random.seed(seed)

        # Canvas size gets its own names: the original code reused `width`
        # here, which overwrote the stroke-width argument and made the SVG
        # label always report 512.
        canvas_width, canvas_height = 512, 512

        # Generate a placeholder image
        image = Image.new('RGB', (canvas_width, canvas_height), color='white')

        # The prompt is caller-supplied text; escape &, < and > so it cannot
        # break the SVG document or inject markup.
        safe_prompt = escape(str(prompt))

        # Create a simple SVG with the prompt text
        svg_str = "\n".join([
            f'<svg width="{canvas_width}" height="{canvas_height}" xmlns="http://www.w3.org/2000/svg">',
            '<rect width="100%" height="100%" fill="white"/>',
            '<text x="50%" y="50%" font-family="Arial" font-size="20" text-anchor="middle" dominant-baseline="middle" fill="black">',
            f'{safe_prompt}',
            '</text>',
            '<text x="50%" y="70%" font-family="Arial" font-size="12" text-anchor="middle" dominant-baseline="middle" fill="gray">',
            f'Paths: {num_paths}, Width: {width}',
            '</text>',
            '</svg>',
        ])

        # Convert output based on output_type
        if output_type == "np":
            image = np.array(image)
        elif output_type == "svg":
            image = svg_str

        if return_dict:
            return {"svg": svg_str, "image": image}
        return svg_str, image
|