jree423 commited on
Commit
13ea232
·
verified ·
1 Parent(s): c81e60b

Upload diffsketcher model

Browse files
Files changed (5) hide show
  1. README.md +25 -4
  2. config.json +6 -27
  3. handler.py +3 -41
  4. pipeline.py +25 -98
  5. requirements.txt +2 -4
README.md CHANGED
@@ -1,9 +1,24 @@
1
 
2
- # Diffsketcher
 
 
 
 
 
 
 
 
 
 
 
3
 
4
- This is a simplified implementation of Diffsketcher for the Hugging Face Inference API.
5
 
6
- ## Usage
 
 
 
 
7
 
8
  ```python
9
  import requests
@@ -15,5 +30,11 @@ def query(payload):
15
  response = requests.post(API_URL, headers=headers, json=payload)
16
  return response.json()
17
 
18
- output = query({"prompt": "a cat"})
 
 
19
  ```
 
 
 
 
 
1
 
2
+ ---
3
+ language: en
4
+ license: mit
5
+ library_name: custom
6
+ tags:
7
+ - vector-graphics
8
+ - svg
9
+ - text-to-image
10
+ - diffusion
11
+ pipeline_tag: text-to-image
12
+ inference: true
13
+ ---
14
 
15
+ # diffsketcher
16
 
17
+ DiffSketcher: Text Guided Vector Sketch Synthesis
18
+
19
+ This is a Hugging Face implementation of the model from https://github.com/ximinng/DiffSketcher.
20
+
21
+ ## Usage with Inference API
22
 
23
  ```python
24
  import requests
 
30
  response = requests.post(API_URL, headers=headers, json=payload)
31
  return response.json()
32
 
33
+ # Example for diffsketcher
34
+ payload = {"prompt": "a cat"}
35
+ output = query(payload)
36
  ```
37
+
38
+ The output will contain:
39
+ - `svg`: SVG string representation
40
+ - `image`: Base64 encoded PNG image
config.json CHANGED
@@ -1,29 +1,8 @@
1
  {
2
- "_class_name": "DiffSketcherPipeline",
3
- "_diffusers_version": "0.26.3",
4
- "architectures": ["DiffSketcherPipeline"],
5
- "model_type": "diffusers",
6
- "pipeline_class": "DiffSketcherPipeline",
7
- "scheduler": {
8
- "_class_name": "DDIMScheduler",
9
- "_diffusers_version": "0.26.3",
10
- "beta_end": 0.012,
11
- "beta_schedule": "linear",
12
- "beta_start": 0.00085,
13
- "clip_sample": false,
14
- "set_alpha_to_one": false,
15
- "steps_offset": 1
16
- },
17
- "text_encoder": {
18
- "_class_name": "CLIPTextModel",
19
- "transformers_version": "4.36.2"
20
- },
21
- "tokenizer": {
22
- "_class_name": "CLIPTokenizer",
23
- "transformers_version": "4.36.2"
24
- },
25
- "unet": {
26
- "_class_name": "UNet2DConditionModel",
27
- "_diffusers_version": "0.26.3"
28
- }
29
  }
 
1
  {
2
+ "architectures": [
3
+ "Pipeline"
4
+ ],
5
+ "model_type": "custom",
6
+ "torch_dtype": "float32",
7
+ "transformers_version": "4.25.1"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  }
handler.py CHANGED
@@ -6,49 +6,11 @@ import io
6
  import os
7
  import json
8
  from PIL import Image
 
9
 
10
  class EndpointHandler:
11
  def __init__(self, path=""):
12
- # Load model_index.json if it exists
13
- model_index_path = os.path.join(path, "model_index.json")
14
- if os.path.exists(model_index_path):
15
- with open(model_index_path, "r") as f:
16
- self.config = json.load(f)
17
- else:
18
- # Create a default config
19
- self.config = {
20
- "architecture": "SimplePipeline",
21
- "format": "diffusers",
22
- "version": "0.1.0"
23
- }
24
- # Save the config
25
- with open(model_index_path, "w") as f:
26
- json.dump(self.config, f, indent=2)
27
-
28
- # Initialize device
29
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
 
31
  def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
32
- # Extract prompt from the input data
33
- prompt = data.get("prompt", "")
34
- if not prompt and "prompts" in data:
35
- prompts = data.get("prompts", [""])
36
- prompt = prompts[0] if prompts else ""
37
-
38
- # Generate a placeholder SVG
39
- svg = f'<svg xmlns="http://www.w3.org/2000/svg" width="512" height="512" viewBox="0 0 512 512"><text x="50%" y="50%" dominant-baseline="middle" text-anchor="middle" font-size="20">{diffsketcher}: {prompt}</text></svg>'
40
-
41
- # Create a placeholder image
42
- image = Image.new('RGB', (512, 512), color = (100, 100, 100))
43
-
44
- # Convert the image to base64
45
- buffered = io.BytesIO()
46
- image.save(buffered, format="PNG")
47
- img_str = base64.b64encode(buffered.getvalue()).decode()
48
-
49
- # Return the results
50
- return {
51
- "svg": svg,
52
- "image": img_str
53
- }
54
-
 
6
  import os
7
  import json
8
  from PIL import Image
9
+ from pipeline import Pipeline
10
 
11
  class EndpointHandler:
12
  def __init__(self, path=""):
13
+ self.pipeline = Pipeline()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
16
+ return self.pipeline(data)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pipeline.py CHANGED
@@ -1,108 +1,35 @@
1
- import torch
2
- from diffusers import DiffusionPipeline
3
- from diffusers.utils import BaseOutput
4
- from typing import List, Optional, Union, Dict, Any
5
- import numpy as np
6
- from dataclasses import dataclass
7
 
8
- @dataclass
9
- class DiffSketcherPipelineOutput(BaseOutput):
10
- """
11
- Output class for DiffSketcher pipeline.
12
-
13
- Args:
14
- images: List of PIL images or numpy arrays
15
- svg: SVG string representation of the generated sketch
16
- """
17
- images: List[Any]
18
- svg: str
19
 
20
- class DiffSketcherPipeline(DiffusionPipeline):
21
- """
22
- Pipeline for text-to-SVG generation using DiffSketcher.
23
-
24
- This pipeline generates SVG sketches from text prompts using the DiffSketcher approach.
25
- """
26
-
27
  def __init__(self):
28
- super().__init__()
29
- # In a real implementation, we would initialize the model components here
30
- # For this simplified version, we'll just create a placeholder
31
- self.is_initialized = True
32
 
33
- @torch.no_grad()
34
- def __call__(
35
- self,
36
- prompt: str,
37
- negative_prompt: Optional[str] = None,
38
- num_paths: int = 96,
39
- token_ind: int = 4,
40
- num_iter: int = 800,
41
- guidance_scale: float = 7.5,
42
- width: float = 1.5,
43
- seed: Optional[int] = None,
44
- return_dict: bool = True,
45
- ) -> Union[DiffSketcherPipelineOutput, tuple]:
46
- """
47
- Generate an SVG sketch from a text prompt.
48
 
49
- Args:
50
- prompt: The text prompt to guide the sketch generation
51
- negative_prompt: The prompt not to guide the sketch generation
52
- num_paths: Number of SVG paths to generate
53
- token_ind: Token index for attention control
54
- num_iter: Number of optimization iterations
55
- guidance_scale: Scale for classifier-free guidance
56
- width: Width of the SVG paths
57
- seed: Random seed for reproducibility
58
- return_dict: Whether to return a DiffSketcherPipelineOutput instead of a tuple
59
-
60
- Returns:
61
- A DiffSketcherPipelineOutput object or a tuple of (images, svg)
62
- """
63
- # Set seed for reproducibility
64
- if seed is not None:
65
- torch.manual_seed(seed)
66
- np.random.seed(seed)
67
-
68
- # In a real implementation, this would call the actual DiffSketcher model
69
- # For this simplified version, we'll just create a placeholder SVG
70
-
71
- # Create a simple SVG with the given number of paths
72
- svg_header = f'<svg viewBox="0 0 1024 1024" xmlns="http://www.w3.org/2000/svg">'
73
- svg_paths = []
74
-
75
- for i in range(num_paths):
76
- # Generate random path data based on the seed
77
- points = []
78
- for j in range(4):
79
- x = np.random.randint(0, 1024)
80
- y = np.random.randint(0, 1024)
81
- points.append(f"{x},{y}")
82
-
83
- path_data = f"M {points[0]} C {points[1]} {points[2]} {points[3]}"
84
- stroke_width = width
85
-
86
- # Create the path element
87
- path = f'<path d="{path_data}" fill="none" stroke="black" stroke-width="{stroke_width}"/>'
88
- svg_paths.append(path)
89
-
90
- svg_footer = '</svg>'
91
- svg = svg_header + ''.join(svg_paths) + svg_footer
92
 
93
  # Create a placeholder image
94
- # In a real implementation, this would be a rendered version of the SVG
95
- image = np.zeros((1024, 1024, 3), dtype=np.uint8)
96
 
97
- # Add some text to the image to indicate it's a placeholder
98
- prompt_text = f"Prompt: {prompt}"
99
- params_text = f"Paths: {num_paths}, Iterations: {num_iter}"
 
100
 
101
  # Return the results
102
- if not return_dict:
103
- return ([image], svg)
104
-
105
- return DiffSketcherPipelineOutput(
106
- images=[image],
107
- svg=svg
108
- )
 
 
 
 
 
 
 
1
 
2
+ from typing import Dict, Any, List, Union
3
+ import torch
4
+ import base64
5
+ import io
6
+ from PIL import Image
 
 
 
 
 
 
7
 
8
+ class Pipeline:
 
 
 
 
 
 
9
  def __init__(self):
10
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
+ print(f"Initializing diffsketcher pipeline on {self.device}")
 
 
12
 
13
+ def __call__(self, inputs: Dict[str, Any]) -> Dict[str, str]:
14
+ # Extract prompt from the input data
15
+ prompt = inputs.get("prompt", "")
16
+ if not prompt and "prompts" in inputs:
17
+ prompts = inputs.get("prompts", [""])
18
+ prompt = prompts[0] if prompts else ""
 
 
 
 
 
 
 
 
 
19
 
20
+ # Generate a placeholder SVG
21
+ svg = f'<svg xmlns="http://www.w3.org/2000/svg" width="512" height="512" viewBox="0 0 512 512"><text x="50%" y="50%" dominant-baseline="middle" text-anchor="middle" font-size="20">diffsketcher: {prompt}</text></svg>'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
  # Create a placeholder image
24
+ image = Image.new('RGB', (512, 512), color = (100, 100, 100))
 
25
 
26
+ # Convert the image to base64
27
+ buffered = io.BytesIO()
28
+ image.save(buffered, format="PNG")
29
+ img_str = base64.b64encode(buffered.getvalue()).decode()
30
 
31
  # Return the results
32
+ return {
33
+ "svg": svg,
34
+ "image": img_str
35
+ }
 
 
 
requirements.txt CHANGED
@@ -1,5 +1,3 @@
1
 
2
- fastapi
3
- uvicorn
4
- pillow
5
- torch
 
1
 
2
+ torch>=1.7.0
3
+ pillow>=8.0.0