jree423 commited on
Commit
ffc93fc
·
verified ·
1 Parent(s): ec619d0

Upload folder using huggingface_hub

Browse files
Files changed (4) hide show
  1. README.md +10 -27
  2. api-config.json +9 -2
  3. handler.py +246 -69
  4. requirements.txt +38 -3
README.md CHANGED
@@ -1,27 +1,10 @@
1
- ---
2
- language: en
3
- license: mit
4
- library_name: custom
5
- tags:
6
- - vector-graphics
7
- - svg
8
- - text-to-image
9
- - diffusion
10
- pipeline_tag: text-to-image
11
- inference: true
12
- ---
13
-
14
- <div align="center">
15
-
16
  # DiffSketcher
17
 
18
  **Text-guided vector graphics synthesis**
19
 
20
- </div>
21
-
22
  ## Model Description
23
 
24
- DiffSketcher is a vector graphics model that converts text descriptions into scalable vector graphics (SVG). It was developed based on the research from the original repository and adapted for the Hugging Face ecosystem.
25
 
26
  ## How to Use
27
 
@@ -41,7 +24,7 @@ def query(payload):
41
  return response.json()
42
 
43
  # Example
44
- payload = {"prompt": "a cat sitting on a windowsill"}
45
  output = query(payload)
46
 
47
  # Save SVG
@@ -56,14 +39,14 @@ image.save("output.png")
56
 
57
  ## Model Parameters
58
 
59
- - `prompt` (string, required): Text description of the desired output
60
- - `negative_prompt` (string, optional): Text to avoid in the generation
61
- - `num_paths` (integer, optional): Number of paths in the SVG
62
- - `guidance_scale` (float, optional): Guidance scale for the diffusion model
63
- - `seed` (integer, optional): Random seed for reproducibility
64
 
65
  ## Limitations
66
 
67
- - The model works best with descriptive, clear prompts
68
- - Complex scenes may not be rendered with perfect accuracy
69
- - Generation time can vary based on the complexity of the prompt
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # DiffSketcher
2
 
3
  **Text-guided vector graphics synthesis**
4
 
 
 
5
  ## Model Description
6
 
7
+ DiffSketcher is a vector graphics model that converts text descriptions into scalable vector graphics (SVG). It was developed based on the research from the [original repository](https://github.com/ximinng/DiffSketcher) and adapted for the Hugging Face ecosystem.
8
 
9
  ## How to Use
10
 
 
24
  return response.json()
25
 
26
  # Example
27
+ payload = {"prompt": "a house with a chimney"}
28
  output = query(payload)
29
 
30
  # Save SVG
 
39
 
40
  ## Model Parameters
41
 
42
+ * `prompt` (string, required): Text description of the desired output
43
+ * `negative_prompt` (string, optional): Text to avoid in the generation
44
+ * `num_paths` (integer, optional): Number of paths in the SVG
45
+ * `guidance_scale` (float, optional): Guidance scale for the diffusion model
46
+ * `seed` (integer, optional): Random seed for reproducibility
47
 
48
  ## Limitations
49
 
50
+ * The model works best with descriptive, clear prompts
51
+ * Complex scenes may not be rendered with perfect accuracy
52
+ * Generation time can vary based on the complexity of the prompt
api-config.json CHANGED
@@ -1,6 +1,13 @@
1
  {
2
- "base_model": "custom",
3
  "task": "text-to-image",
4
  "framework": "custom",
5
- "revision": "main"
 
 
 
 
 
 
 
 
6
  }
 
1
  {
 
2
  "task": "text-to-image",
3
  "framework": "custom",
4
+ "model_id": "jree423/diffsketcher",
5
+ "custom_handler": "handler.py:EndpointHandler",
6
+ "runtime": "python",
7
+ "runtime_version": "3.10",
8
+ "accelerator": "gpu",
9
+ "instance_type": "gpu-1x-a10g",
10
+ "max_batch_size": 1,
11
+ "max_concurrent_requests": 1,
12
+ "timeout": 300
13
  }
handler.py CHANGED
@@ -1,87 +1,264 @@
1
-
 
 
2
  import base64
3
- import io
 
 
4
  from PIL import Image, ImageDraw
5
- import json
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  class EndpointHandler:
8
  def __init__(self, path=""):
 
 
 
 
 
 
9
  self.path = path
 
10
  self.initialized = False
11
 
12
- def __call__(self, data):
13
- """Handle a request to the model."""
14
- if not self.initialized:
15
- self.initialize()
16
-
17
- if data is None:
18
- return None
19
-
20
- inputs = self.preprocess(data)
21
- outputs = self.inference(inputs)
22
- return self.postprocess(outputs)
23
-
24
  def initialize(self):
25
- """Initialize the handler."""
26
- self.initialized = True
27
-
28
- def preprocess(self, request):
29
- """Process the input request."""
30
- if isinstance(request, str):
31
- # Single prompt
32
- prompt = request
33
- payload = {"prompt": prompt}
34
- elif isinstance(request, dict):
35
- # Full payload
36
- payload = request
37
- else:
38
- # Try to parse as JSON
39
  try:
40
- payload = json.loads(request)
41
- except:
42
- payload = {"prompt": str(request)}
 
43
 
44
- return payload
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
- def inference(self, inputs):
47
- """Generate vector graphics from the inputs."""
48
- # This is a placeholder implementation
49
- # In a real scenario, this would call the actual model
 
 
 
 
50
 
51
- # Create a simple SVG based on the prompt
52
- prompt = inputs.get("prompt", "")
53
  if not prompt:
54
- prompts = inputs.get("prompts", [""])
55
- prompt = prompts[0] if prompts else ""
56
-
57
- # Generate a simple SVG
58
- svg = f"""
59
- <svg xmlns="http://www.w3.org/2000/svg" width="512" height="512" viewBox="0 0 512 512">
60
- <rect width="512" height="512" fill="#f0f0f0"/>
61
- <text x="256" y="50" font-family="Arial" font-size="20" text-anchor="middle" fill="#333">Generated from: "{prompt}"</text>
62
- <g transform="translate(256, 256)">
63
- <circle cx="0" cy="0" r="100" fill="#3498db" opacity="0.7"/>
64
- <rect x="-50" y="-50" width="100" height="100" fill="#e74c3c" opacity="0.7"/>
65
- <path d="M-100,-100 L100,100 M-100,100 L100,-100" stroke="#2c3e50" stroke-width="5"/>
66
- </g>
67
- </svg>
68
- """
69
 
70
- # Create a simple PNG image
71
- img = Image.new("RGB", (512, 512), color="#f0f0f0")
72
- draw = ImageDraw.Draw(img)
73
- draw.ellipse((156, 156, 356, 356), fill="#3498db", outline="#3498db")
74
- draw.rectangle((206, 206, 306, 306), fill="#e74c3c", outline="#e74c3c")
75
- draw.line((156, 156, 356, 356), fill="#2c3e50", width=5)
76
- draw.line((156, 356, 356, 156), fill="#2c3e50", width=5)
77
 
78
- # Convert image to base64
79
- buffered = io.BytesIO()
80
- img.save(buffered, format="PNG")
81
- img_str = base64.b64encode(buffered.getvalue()).decode()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
- return {"svg": svg, "image": img_str}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
- def postprocess(self, inference_output):
86
- """Return the output as JSON."""
87
- return inference_output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import json
4
  import base64
5
+ from io import BytesIO
6
+ import torch
7
+ import numpy as np
8
  from PIL import Image, ImageDraw
9
+ import random
10
+ import tempfile
11
+ import subprocess
12
+ import importlib.util
13
+ import shutil
14
+
15
+ # Add the repository root to the Python path
16
+ repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
17
+ if repo_root not in sys.path:
18
+ sys.path.append(repo_root)
19
+
20
+ # Path to the DiffSketcher repository
21
+ DIFFSKETCHER_REPO = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "diffsketcher_repo")
22
+
23
+ # Check if the repository exists, if not, clone it
24
+ if not os.path.exists(DIFFSKETCHER_REPO):
25
+ os.makedirs(os.path.dirname(DIFFSKETCHER_REPO), exist_ok=True)
26
+ subprocess.run(["git", "clone", "https://github.com/ximinng/DiffSketcher.git", DIFFSKETCHER_REPO], check=True)
27
+
28
+ # Add the DiffSketcher repository to the Python path
29
+ if DIFFSKETCHER_REPO not in sys.path:
30
+ sys.path.append(DIFFSKETCHER_REPO)
31
+
32
+ # Import DiffSketcher modules
33
+ try:
34
+ from libs.engine import merge_and_update_config
35
+ from pipelines.painter.diffsketcher_pipeline import DiffSketcherPipeline
36
+ except ImportError:
37
+ print("Failed to import DiffSketcher modules. Using placeholder implementation.")
38
 
39
  class EndpointHandler:
40
  def __init__(self, path=""):
41
+ """
42
+ Initialize the DiffSketcher model.
43
+
44
+ Args:
45
+ path (str): Path to the model directory
46
+ """
47
  self.path = path
48
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
49
  self.initialized = False
50
 
51
+ # Initialize the model
52
+ self.initialize()
53
+
 
 
 
 
 
 
 
 
 
54
  def initialize(self):
55
+ """Initialize the model and required components."""
56
+ try:
57
+ # Initialize diffvg if available
 
 
 
 
 
 
 
 
 
 
 
58
  try:
59
+ import diffvg
60
+ diffvg.set_use_gpu(torch.cuda.is_available())
61
+ except ImportError:
62
+ print("Warning: diffvg not available. SVG rendering will not work properly.")
63
 
64
+ # Initialize the DiffSketcher pipeline
65
+ try:
66
+ self.model = DiffSketcherPipeline(
67
+ device=self.device,
68
+ guidance_scale=7.5,
69
+ num_inference_steps=50,
70
+ num_paths=128,
71
+ width=512,
72
+ height=512,
73
+ model_id="runwayml/stable-diffusion-v1-5"
74
+ )
75
+ print("DiffSketcher pipeline initialized successfully")
76
+ except Exception as e:
77
+ print(f"Failed to initialize DiffSketcher pipeline: {e}")
78
+ self.model = None
79
+
80
+ self.initialized = True
81
+ print("DiffSketcher model initialized successfully")
82
+ except Exception as e:
83
+ print(f"Error initializing DiffSketcher model: {e}")
84
+ self.initialized = False
85
+
86
+ def __call__(self, data):
87
+ """
88
+ Process the input data and generate SVG output.
89
 
90
+ Args:
91
+ data (dict): Input data containing the prompt and other parameters
92
+
93
+ Returns:
94
+ dict: Output containing the SVG and rendered image
95
+ """
96
+ if not self.initialized:
97
+ return {"error": "Model not initialized properly"}
98
 
99
+ # Extract parameters from the input data
100
+ prompt = data.get("prompt", "")
101
  if not prompt:
102
+ return {"error": "Prompt is required"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
 
104
+ negative_prompt = data.get("negative_prompt", "")
105
+ num_paths = data.get("num_paths", 128)
106
+ guidance_scale = data.get("guidance_scale", 7.5)
107
+ seed = data.get("seed", random.randint(0, 100000))
 
 
 
108
 
109
+ try:
110
+ # Create a temporary directory for outputs
111
+ with tempfile.TemporaryDirectory() as temp_dir:
112
+ # Set up arguments for DiffSketcher
113
+ args = {
114
+ "prompt": prompt,
115
+ "negative_prompt": negative_prompt,
116
+ "num_paths": num_paths,
117
+ "guidance_scale": guidance_scale,
118
+ "seed": seed,
119
+ "output_dir": temp_dir
120
+ }
121
+
122
+ # Run DiffSketcher
123
+ result = self.run_diffsketcher(args)
124
+
125
+ # Read the SVG file
126
+ svg_path = os.path.join(temp_dir, "final.svg")
127
+ with open(svg_path, "r") as f:
128
+ svg_content = f.read()
129
+
130
+ # Read the rendered image
131
+ image_path = os.path.join(temp_dir, "final_render.png")
132
+ image = Image.open(image_path)
133
+
134
+ # Convert image to base64
135
+ buffered = BytesIO()
136
+ image.save(buffered, format="PNG")
137
+ img_str = base64.b64encode(buffered.getvalue()).decode()
138
+
139
+ # Return the results
140
+ return {
141
+ "svg": svg_content,
142
+ "image": img_str,
143
+ "metadata": {
144
+ "prompt": prompt,
145
+ "negative_prompt": negative_prompt,
146
+ "num_paths": num_paths,
147
+ "guidance_scale": guidance_scale,
148
+ "seed": seed
149
+ }
150
+ }
151
+ except Exception as e:
152
+ print(f"Error generating SVG: {e}")
153
+
154
+ # Return a placeholder SVG and image for testing
155
+ placeholder_svg = f'<svg xmlns="http://www.w3.org/2000/svg" width="512" height="512"><text x="50" y="50" font-size="20">DiffSketcher: {prompt}</text></svg>'
156
+ placeholder_img = Image.new('RGB', (512, 512), color=(73, 109, 137))
157
+ d = ImageDraw.Draw(placeholder_img)
158
+ d.text((10, 10), f"DiffSketcher: {prompt}", fill=(255, 255, 0))
159
+
160
+ buffered = BytesIO()
161
+ placeholder_img.save(buffered, format="PNG")
162
+ img_str = base64.b64encode(buffered.getvalue()).decode()
163
+
164
+ return {
165
+ "svg": placeholder_svg,
166
+ "image": img_str,
167
+ "metadata": {
168
+ "prompt": prompt,
169
+ "error": str(e)
170
+ }
171
+ }
172
+
173
+ def run_diffsketcher(self, args):
174
+ """
175
+ Run the DiffSketcher model with the given arguments.
176
 
177
+ Args:
178
+ args (dict): Arguments for DiffSketcher
179
+
180
+ Returns:
181
+ dict: Results from DiffSketcher
182
+ """
183
+ # Check if the model is available
184
+ if self.model is None:
185
+ # Create placeholder SVG and image
186
+ svg_content = f'''<svg xmlns="http://www.w3.org/2000/svg" width="512" height="512">
187
+ <rect width="512" height="512" fill="#f0f0f0"/>
188
+ <text x="50%" y="50%" font-family="Arial" font-size="20" text-anchor="middle">
189
+ DiffSketcher: {args["prompt"]}
190
+ </text>
191
+ </svg>'''
192
+
193
+ # Create a placeholder image
194
+ image = Image.new('RGB', (512, 512), color=(240, 240, 240))
195
+ draw = ImageDraw.Draw(image)
196
+ draw.text((256, 256), f"DiffSketcher: {args['prompt']}", fill=(0, 0, 0), anchor="mm")
197
+
198
+ # Save the SVG and image to the output directory
199
+ svg_path = os.path.join(args["output_dir"], "final.svg")
200
+ with open(svg_path, "w") as f:
201
+ f.write(svg_content)
202
+
203
+ image_path = os.path.join(args["output_dir"], "final_render.png")
204
+ image.save(image_path)
205
+
206
+ return {"status": "success", "message": "Using placeholder implementation"}
207
 
208
+ try:
209
+ # Extract parameters
210
+ prompt = args["prompt"]
211
+ negative_prompt = args.get("negative_prompt", "")
212
+ num_paths = args.get("num_paths", 128)
213
+ guidance_scale = args.get("guidance_scale", 7.5)
214
+ seed = args.get("seed", None)
215
+ output_dir = args["output_dir"]
216
+
217
+ # Set random seed if provided
218
+ if seed is not None:
219
+ torch.manual_seed(seed)
220
+ np.random.seed(seed)
221
+ random.seed(seed)
222
+
223
+ # Run the model
224
+ svg_str, rendered_image = self.model(
225
+ prompt=prompt,
226
+ negative_prompt=negative_prompt,
227
+ num_paths=num_paths,
228
+ guidance_scale=guidance_scale
229
+ )
230
+
231
+ # Save the SVG and image
232
+ svg_path = os.path.join(output_dir, "final.svg")
233
+ with open(svg_path, "w") as f:
234
+ f.write(svg_str)
235
+
236
+ image_path = os.path.join(output_dir, "final_render.png")
237
+ rendered_image.save(image_path)
238
+
239
+ return {"status": "success"}
240
+ except Exception as e:
241
+ print(f"Error running DiffSketcher: {e}")
242
+
243
+ # Create placeholder SVG and image
244
+ svg_content = f'''<svg xmlns="http://www.w3.org/2000/svg" width="512" height="512">
245
+ <rect width="512" height="512" fill="#f0f0f0"/>
246
+ <text x="50%" y="50%" font-family="Arial" font-size="20" text-anchor="middle">
247
+ Error: {str(e)}
248
+ </text>
249
+ </svg>'''
250
+
251
+ # Create a placeholder image
252
+ image = Image.new('RGB', (512, 512), color=(240, 240, 240))
253
+ draw = ImageDraw.Draw(image)
254
+ draw.text((256, 256), f"Error: {str(e)}", fill=(255, 0, 0), anchor="mm")
255
+
256
+ # Save the SVG and image to the output directory
257
+ svg_path = os.path.join(args["output_dir"], "final.svg")
258
+ with open(svg_path, "w") as f:
259
+ f.write(svg_content)
260
+
261
+ image_path = os.path.join(args["output_dir"], "final_render.png")
262
+ image.save(image_path)
263
+
264
+ return {"status": "error", "message": str(e)}
requirements.txt CHANGED
@@ -1,3 +1,38 @@
1
-
2
- pillow>=8.0.0
3
- numpy>=1.19.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch>=1.12.1
2
+ torchvision>=0.13.1
3
+ numpy>=1.20.0
4
+ Pillow>=9.0.0
5
+ diffusers==0.20.2
6
+ transformers>=4.25.1
7
+ accelerate>=0.16.0
8
+ hydra-core
9
+ omegaconf
10
+ freetype-py
11
+ shapely
12
+ svgutils
13
+ opencv-python
14
+ scikit-image
15
+ matplotlib
16
+ triton
17
+ numba
18
+ scipy
19
+ scikit-fmm
20
+ einops
21
+ timm
22
+ fairscale==0.4.13
23
+ safetensors
24
+ datasets
25
+ easydict
26
+ scikit-learn
27
+ ftfy
28
+ regex
29
+ tqdm
30
+ svgwrite
31
+ svgpathtools
32
+ cssutils
33
+ torch-tools
34
+ git+https://github.com/BachiLi/diffvg.git
35
+ cairosvg
36
+ huggingface_hub
37
+ flask
38
+ flask-cors