|
|
|
from typing import Dict, List, Optional, Union |
|
import torch |
|
from diffusers import DiffusionPipeline |
|
from PIL import Image |
|
import numpy as np |
|
import io |
|
import base64 |
|
|
|
class DiffSketcherPipeline(DiffusionPipeline):
    """Placeholder DiffSketcher pipeline that returns a stub vector sketch.

    The current implementation performs no diffusion sampling or path
    optimization. ``__call__`` returns a blank white 512x512 raster image
    and an SVG document that displays the prompt together with the requested
    path count and stroke width.
    """

    def __init__(self):
        super().__init__()
        # Register an empty ``model`` slot; no real model is loaded here.
        self.register_modules(model=None)

    @torch.no_grad()
    def __call__(
        self,
        prompt: str,
        negative_prompt: str = "",
        num_paths: int = 96,
        token_ind: int = 4,
        num_iter: int = 800,
        guidance_scale: float = 7.5,
        width: float = 1.5,
        seed: Optional[int] = None,
        return_dict: bool = True,
        output_type: str = "pil",
    ) -> Union[Dict, tuple]:
        """
        Generate a (placeholder) vector sketch based on a text prompt.

        Args:
            prompt: The text prompt to guide the sketch generation. It is
                XML-escaped before being embedded in the SVG.
            negative_prompt: Negative text prompt for guidance (currently unused).
            num_paths: Number of paths to use in the sketch (only echoed in the SVG).
            token_ind: Token index for attention (currently unused).
            num_iter: Number of optimization iterations (currently unused).
            guidance_scale: Scale for classifier-free guidance (currently unused).
            width: Stroke width (only echoed in the SVG). This is not the
                canvas width, which is fixed at 512 px.
            seed: Random seed. When given, it seeds the *global* torch and
                NumPy random number generators.
            return_dict: Whether to return a dict or a tuple.
            output_type: Type of the ``image`` output. One of ``"pil"`` (a
                PIL image), ``"np"`` (a NumPy array of the PIL image) or
                ``"svg"`` (the SVG string itself).

        Returns:
            If return_dict is True, returns a dict with keys:
                - "svg": SVG string representation of the sketch
                - "image": The sketch image in the format chosen by ``output_type``
            Otherwise, returns a tuple (svg_string, image).

        Raises:
            ValueError: If ``output_type`` is not one of "pil", "np" or "svg".
        """
        # Fail fast on unsupported output types instead of silently
        # returning a PIL image.
        if output_type not in ("pil", "np", "svg"):
            raise ValueError(
                f"output_type must be one of 'pil', 'np' or 'svg', got {output_type!r}"
            )

        if seed is not None:
            # Note: this mutates the process-wide RNG state of both libraries.
            torch.manual_seed(seed)
            np.random.seed(seed)

        # Use distinct names for the canvas size. Reusing ``width`` here would
        # clobber the caller's stroke-width argument.
        canvas_width, canvas_height = 512, 512
        image = Image.new("RGB", (canvas_width, canvas_height), color="white")

        svg_str = self._build_placeholder_svg(
            prompt, num_paths, width, canvas_width, canvas_height
        )

        if output_type == "np":
            image = np.array(image)
        elif output_type == "svg":
            image = svg_str

        if return_dict:
            return {"svg": svg_str, "image": image}
        return svg_str, image

    @staticmethod
    def _build_placeholder_svg(
        prompt: str,
        num_paths: int,
        stroke_width: float,
        canvas_width: int,
        canvas_height: int,
    ) -> str:
        """Return an SVG document showing the prompt and sketch settings as text.

        All interpolated values are XML-escaped so that arbitrary prompt text
        cannot break or inject markup into the document.
        """
        from xml.sax.saxutils import escape

        # Attributes shared by both centered text lines.
        text_attrs = (
            'x="50%" font-family="Arial" text-anchor="middle" '
            'dominant-baseline="middle"'
        )
        prompt_text = escape(str(prompt))
        info_text = escape(f"Paths: {num_paths}, Width: {stroke_width}")
        lines = [
            f'<svg width="{canvas_width}" height="{canvas_height}" '
            'xmlns="http://www.w3.org/2000/svg">',
            '  <rect width="100%" height="100%" fill="white"/>',
            f'  <text {text_attrs} y="50%" font-size="20" fill="black">'
            f"{prompt_text}</text>",
            f'  <text {text_attrs} y="70%" font-size="12" fill="gray">'
            f"{info_text}</text>",
            "</svg>",
        ]
        return "\n".join(lines)
|
|