File size: 2,893 Bytes
e942bd1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82

from typing import Dict, List, Optional, Union
import torch
from diffusers import DiffusionPipeline
from PIL import Image
import numpy as np
import io
import base64

class DiffSketcherPipeline(DiffusionPipeline):
    """Placeholder text-to-vector-sketch pipeline with a diffusers-style API.

    The current implementation performs no diffusion or path optimisation:
    ``__call__`` builds a static SVG that echoes the prompt and a couple of
    the requested parameters, and returns it together with a blank white
    raster canvas.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): no real sub-model is wired in yet; ``model`` is
        # registered as an empty (None) slot via diffusers' register_modules.
        self.register_modules(
            model=None
        )

    @torch.no_grad()
    def __call__(
        self,
        prompt: str,
        negative_prompt: str = "",
        num_paths: int = 96,
        token_ind: int = 4,
        num_iter: int = 800,
        guidance_scale: float = 7.5,
        width: float = 1.5,
        seed: Optional[int] = None,
        return_dict: bool = True,
        output_type: str = "pil",
    ) -> Union[Dict, tuple]:
        """
        Generate a (placeholder) vector sketch based on a text prompt.

        Args:
            prompt: The text prompt to guide the sketch generation. It is
                XML-escaped before being embedded in the SVG.
            negative_prompt: Negative text prompt for guidance (currently unused).
            num_paths: Number of paths to use in the sketch (shown in the SVG caption).
            token_ind: Token index for attention (currently unused).
            num_iter: Number of optimization iterations (currently unused).
            guidance_scale: Scale for classifier-free guidance (currently unused).
            width: Stroke width (shown in the SVG caption). This is NOT the
                canvas width; the canvas is fixed at 512x512.
            seed: Random seed; when given, seeds both torch's and numpy's
                global RNGs.
            return_dict: Whether to return a dict or tuple.
            output_type: Output type, one of "pil", "np", or "svg". Any other
                value leaves the image as a PIL image.

        Returns:
            If return_dict is True, returns a dict with keys:
                - "svg": SVG string representation of the sketch
                - "image": a blank white 512x512 canvas (PIL image, numpy
                  array for output_type="np", or the SVG string itself for
                  output_type="svg"); it is not a rendering of the SVG.
            Otherwise, returns a tuple (svg_string, image)
        """
        from xml.sax.saxutils import escape

        # Set seed for reproducibility (global RNG state of torch and numpy).
        if seed is not None:
            torch.manual_seed(seed)
            np.random.seed(seed)

        # Canvas dimensions are kept separate from ``width`` (the stroke
        # width argument) so the caption below reports the caller's value.
        canvas_width, canvas_height = 512, 512
        image = Image.new('RGB', (canvas_width, canvas_height), color='white')

        # Escape &, < and > so arbitrary prompts cannot break (or inject
        # elements into) the SVG markup.
        safe_prompt = escape(prompt)

        # Create a simple SVG with the prompt text
        svg_str = f'''<svg width="{canvas_width}" height="{canvas_height}" xmlns="http://www.w3.org/2000/svg">
            <rect width="100%" height="100%" fill="white"/>
            <text x="50%" y="50%" font-family="Arial" font-size="20" text-anchor="middle" dominant-baseline="middle" fill="black">
                {safe_prompt}
            </text>
            <text x="50%" y="70%" font-family="Arial" font-size="12" text-anchor="middle" dominant-baseline="middle" fill="gray">
                Paths: {num_paths}, Width: {width}
            </text>
        </svg>'''

        # Convert output based on output_type
        if output_type == "np":
            image = np.array(image)
        elif output_type == "svg":
            image = svg_str

        if return_dict:
            return {"svg": svg_str, "image": image}
        return svg_str, image