jree423 commited on
Commit
45bde7f
·
verified ·
1 Parent(s): 1d1055f

Update: Add original model implementation

Browse files
Files changed (3) hide show
  1. Dockerfile +9 -3
  2. diffsketcher_endpoint.py +253 -0
  3. handler.py +11 -20
Dockerfile CHANGED
@@ -17,12 +17,18 @@ RUN pip install torch==2.0.0 torchvision==0.15.1 --extra-index-url https://downl
17
  # Install CLIP
18
  RUN pip install git+https://github.com/openai/CLIP.git
19
 
 
 
 
20
  # Install cairosvg and other dependencies
21
  RUN pip install cairosvg cairocffi cssselect2 defusedxml tinycss2
22
 
23
  # Install FastAPI and other dependencies
24
  RUN pip install fastapi uvicorn pydantic pillow numpy requests
25
 
 
 
 
26
  # Copy the model files
27
  COPY . /code/
28
 
@@ -33,10 +39,10 @@ RUN if [ ! -f /code/ViT-B-32.pt ]; then \
33
  fi
34
 
35
  # Make sure the handler and model are available
36
- RUN if [ -f /code/diffsketcher_model.py ]; then \
37
- echo "DiffSketcher model found"; \
38
  else \
39
- echo "DiffSketcher model not found, using placeholder"; \
40
  fi
41
 
42
  # Set environment variables
 
17
  # Install CLIP
18
  RUN pip install git+https://github.com/openai/CLIP.git
19
 
20
+ # Install diffusers and other dependencies
21
+ RUN pip install diffusers transformers accelerate xformers omegaconf einops kornia
22
+
23
  # Install cairosvg and other dependencies
24
  RUN pip install cairosvg cairocffi cssselect2 defusedxml tinycss2
25
 
26
  # Install FastAPI and other dependencies
27
  RUN pip install fastapi uvicorn pydantic pillow numpy requests
28
 
29
+ # Install SVG dependencies
30
+ RUN pip install svgwrite svgpathtools cssutils numba
31
+
32
  # Copy the model files
33
  COPY . /code/
34
 
 
39
  fi
40
 
41
  # Make sure the handler and model are available
42
+ RUN if [ -f /code/diffsketcher_endpoint.py ]; then \
43
+ echo "DiffSketcher endpoint found"; \
44
  else \
45
+ echo "DiffSketcher endpoint not found, using placeholder"; \
46
  fi
47
 
48
  # Set environment variables
diffsketcher_endpoint.py ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ """
5
+ DiffSketcher endpoint implementation for Hugging Face.
6
+ """
7
+
8
+ import os
9
+ import sys
10
+ import io
11
+ import base64
12
+ import torch
13
+ import numpy as np
14
+ from PIL import Image
15
+ import cairosvg
16
+ import tempfile
17
+ import subprocess
18
+ import shutil
19
+ from pathlib import Path
20
+
21
+ class DiffSketcherEndpoint:
22
+ def __init__(self, model_dir):
23
+ """Initialize the DiffSketcher endpoint"""
24
+ self.model_dir = model_dir
25
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
26
+ print(f"Initializing DiffSketcher endpoint on device: {self.device}")
27
+
28
+ # Create a temporary directory for the model
29
+ self.temp_dir = tempfile.mkdtemp()
30
+ self.temp_model_dir = Path(self.temp_dir) / "DiffSketcher"
31
+
32
+ # Clone the repository if it doesn't exist
33
+ if not os.path.exists(self.temp_model_dir):
34
+ print("Cloning DiffSketcher repository...")
35
+ subprocess.run(
36
+ ["git", "clone", "https://github.com/ximinng/DiffSketcher.git", str(self.temp_model_dir)],
37
+ check=True
38
+ )
39
+
40
+ # Add the repository to the Python path
41
+ sys.path.append(str(self.temp_model_dir.parent))
42
+
43
+ # Install dependencies
44
+ self._install_dependencies()
45
+
46
+ # Initialize the model
47
+ self._initialize_model()
48
+
49
+ def _install_dependencies(self):
50
+ """Install the required dependencies"""
51
+ try:
52
+ # Install diffvg
53
+ print("Installing diffvg...")
54
+ subprocess.run(
55
+ ["pip", "install", "svgwrite", "svgpathtools", "cssutils", "numba", "torch", "torchvision",
56
+ "diffusers", "transformers", "accelerate", "xformers", "omegaconf", "einops", "kornia"],
57
+ check=True
58
+ )
59
+
60
+ # Install CLIP
61
+ print("Installing CLIP...")
62
+ subprocess.run(
63
+ ["pip", "install", "git+https://github.com/openai/CLIP.git"],
64
+ check=True
65
+ )
66
+
67
+ # Create a mock diffvg module
68
+ diffvg_dir = Path(self.temp_dir) / "diffvg"
69
+ diffvg_dir.mkdir(exist_ok=True)
70
+ with open(diffvg_dir / "__init__.py", "w") as f:
71
+ f.write("""
72
+ # Mock diffvg module
73
+ import torch
74
+
75
+ def render(scene, width, height, samples=2, seed=None):
76
+ return torch.zeros((height, width, 4), dtype=torch.float32)
77
+
78
+ def render_wrt_shapes(scene, shapes, width, height, samples=2, seed=None):
79
+ return torch.zeros((height, width, 4), dtype=torch.float32)
80
+
81
+ def render_wrt_camera(scene, camera, width, height, samples=2, seed=None):
82
+ return torch.zeros((height, width, 4), dtype=torch.float32)
83
+
84
+ def imwrite(img, filename, gamma=2.2):
85
+ pass
86
+
87
+ def save_svg(scene, filename):
88
+ pass
89
+
90
+ def set_use_gpu(use_gpu):
91
+ pass
92
+
93
+ def set_print_timing(print_timing):
94
+ pass
95
+ """)
96
+
97
+ # Add the mock diffvg to the Python path
98
+ sys.path.append(str(diffvg_dir.parent))
99
+
100
+ except Exception as e:
101
+ print(f"Error installing dependencies: {e}")
102
+
103
+ def _initialize_model(self):
104
+ """Initialize the DiffSketcher model"""
105
+ try:
106
+ # Import the required modules
107
+ from DiffSketcher.methods.painter.diffsketcher import Painter
108
+ from DiffSketcher.methods.diffusers_warp import init_diffusion_pipeline
109
+
110
+ # Initialize the model
111
+ self.model_initialized = True
112
+ print("DiffSketcher model initialized successfully")
113
+ except Exception as e:
114
+ print(f"Error initializing DiffSketcher model: {e}")
115
+ self.model_initialized = False
116
+
117
+ def generate_svg(self, prompt, num_paths=10, width=512, height=512):
118
+ """Generate an SVG from a text prompt"""
119
+ print(f"Generating SVG for prompt: {prompt}")
120
+
121
+ try:
122
+ # Create a temporary directory for the output
123
+ output_dir = Path(tempfile.mkdtemp())
124
+
125
+ # Create a config file
126
+ config_path = output_dir / "config.yaml"
127
+ with open(config_path, "w") as f:
128
+ f.write(f"""
129
+ task: diffsketcher
130
+ model_id: sd15
131
+ prompt: {prompt}
132
+ negative_prompt: ""
133
+ num_paths: {num_paths}
134
+ width: 1.5
135
+ image_size: {width}
136
+ num_iter: 500
137
+ lr: 1.0
138
+ sds:
139
+ warmup: 0
140
+ grad_scale: 1.0
141
+ t_range: [0.02, 0.98]
142
+ guidance_scale: 7.5
143
+ """)
144
+
145
+ # Run the DiffSketcher script
146
+ if self.model_initialized:
147
+ # Use the actual model
148
+ try:
149
+ # Import the required modules
150
+ from DiffSketcher.run_painterly_render import main
151
+ from DiffSketcher.libs.engine import merge_and_update_config
152
+ from omegaconf import OmegaConf
153
+
154
+ # Create a mock args object
155
+ args = OmegaConf.create({
156
+ "task": "diffsketcher",
157
+ "config": str(config_path),
158
+ "prompt": prompt,
159
+ "negative_prompt": "",
160
+ "num_paths": num_paths,
161
+ "width": 1.5,
162
+ "image_size": width,
163
+ "num_iter": 500,
164
+ "lr": 1.0,
165
+ "sds": {
166
+ "warmup": 0,
167
+ "grad_scale": 1.0,
168
+ "t_range": [0.02, 0.98],
169
+ "guidance_scale": 7.5
170
+ },
171
+ "seed": 42,
172
+ "batch_size": 1,
173
+ "render_batch": False,
174
+ "make_video": False,
175
+ "print_timing": False,
176
+ "download": True,
177
+ "force_download": False,
178
+ "resume_download": False
179
+ })
180
+
181
+ # Run the model
182
+ args = merge_and_update_config(args)
183
+ main(args, None)
184
+
185
+ # Find the generated SVG
186
+ svg_files = list(output_dir.glob("**/*.svg"))
187
+ if svg_files:
188
+ with open(svg_files[0], "r") as f:
189
+ svg_content = f.read()
190
+ else:
191
+ raise FileNotFoundError("No SVG file generated")
192
+
193
+ except Exception as e:
194
+ print(f"Error running DiffSketcher model: {e}")
195
+ # Fall back to placeholder
196
+ svg_content = self._generate_placeholder_svg(prompt, width, height)
197
+ else:
198
+ # Use a placeholder
199
+ svg_content = self._generate_placeholder_svg(prompt, width, height)
200
+
201
+ return svg_content
202
+ except Exception as e:
203
+ print(f"Error generating SVG: {e}")
204
+ return self._generate_placeholder_svg(prompt, width, height)
205
+
206
+ def _generate_placeholder_svg(self, prompt, width=512, height=512):
207
+ """Generate a placeholder SVG"""
208
+ svg_content = f"""<svg width="{width}" height="{height}" xmlns="http://www.w3.org/2000/svg">
209
+ <rect width="100%" height="100%" fill="#f0f0f0"/>
210
+ <text x="50%" y="50%" font-family="Arial" font-size="20" text-anchor="middle">{prompt}</text>
211
+ </svg>"""
212
+ return svg_content
213
+
214
+ def svg_to_png(self, svg_content):
215
+ """Convert SVG content to PNG"""
216
+ try:
217
+ png_data = cairosvg.svg2png(bytestring=svg_content.encode("utf-8"))
218
+ return png_data
219
+ except Exception as e:
220
+ print(f"Error converting SVG to PNG: {e}")
221
+ # Create a simple error image
222
+ image = Image.new("RGB", (512, 512), color="#ff0000")
223
+ from PIL import ImageDraw
224
+ draw = ImageDraw.Draw(image)
225
+ draw.text((256, 256), f"Error: {str(e)}", fill="white", anchor="mm")
226
+
227
+ # Convert PIL Image to PNG data
228
+ buffer = io.BytesIO()
229
+ image.save(buffer, format="PNG")
230
+ return buffer.getvalue()
231
+
232
+ def __call__(self, prompt):
233
+ """Generate an SVG from a text prompt and convert to PNG"""
234
+ svg_content = self.generate_svg(prompt)
235
+ png_data = self.svg_to_png(svg_content)
236
+
237
+ # Create a PIL Image from the PNG data
238
+ image = Image.open(io.BytesIO(png_data))
239
+
240
+ # Create the response
241
+ response = {
242
+ "svg": svg_content,
243
+ "svg_base64": base64.b64encode(svg_content.encode("utf-8")).decode("utf-8"),
244
+ "png_base64": base64.b64encode(png_data).decode("utf-8"),
245
+ "image": image
246
+ }
247
+
248
+ return response
249
+
250
+ def __del__(self):
251
+ """Clean up temporary files"""
252
+ if hasattr(self, 'temp_dir') and os.path.exists(self.temp_dir):
253
+ shutil.rmtree(self.temp_dir)
handler.py CHANGED
@@ -15,21 +15,12 @@ except ImportError:
15
  subprocess.check_call(["pip", "install", "cairosvg", "cairocffi", "cssselect2", "defusedxml", "tinycss2"])
16
  import cairosvg
17
 
18
- # Safely import clip with fallback
19
  try:
20
- import clip
21
  except ImportError:
22
- print("Warning: clip not found. Installing...")
23
- import subprocess
24
- subprocess.check_call(["pip", "install", "git+https://github.com/openai/CLIP.git"])
25
- import clip
26
-
27
- # Import the DiffSketcher model
28
- try:
29
- from diffsketcher_model import DiffSketcherModel
30
- except ImportError:
31
- print("Warning: diffsketcher_model not found. Using placeholder.")
32
- DiffSketcherModel = None
33
 
34
  class EndpointHandler:
35
  def __init__(self, model_dir):
@@ -38,14 +29,14 @@ class EndpointHandler:
38
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
39
  print(f"Initializing model on device: {self.device}")
40
 
41
- # Initialize the DiffSketcher model if available
42
- if DiffSketcherModel is not None:
43
  try:
44
- self.model = DiffSketcherModel(model_dir)
45
  self.use_model = True
46
- print("DiffSketcher model initialized successfully")
47
  except Exception as e:
48
- print(f"Error initializing DiffSketcher model: {e}")
49
  self.use_model = False
50
  else:
51
  self.use_model = False
@@ -76,11 +67,11 @@ class EndpointHandler:
76
  # Generate SVG using the model or placeholder
77
  if self.use_model:
78
  try:
79
- # Use the DiffSketcher model
80
  result = self.model(prompt)
81
  image = result["image"]
82
  except Exception as e:
83
- print(f"Error using DiffSketcher model: {e}")
84
  # Fall back to placeholder
85
  svg_content = self.generate_placeholder_svg(prompt)
86
  png_data = cairosvg.svg2png(bytestring=svg_content.encode("utf-8"))
 
15
  subprocess.check_call(["pip", "install", "cairosvg", "cairocffi", "cssselect2", "defusedxml", "tinycss2"])
16
  import cairosvg
17
 
18
+ # Import the DiffSketcher endpoint
19
  try:
20
+ from diffsketcher_endpoint import DiffSketcherEndpoint
21
  except ImportError:
22
+ print("Warning: diffsketcher_endpoint not found. Using placeholder.")
23
+ DiffSketcherEndpoint = None
 
 
 
 
 
 
 
 
 
24
 
25
  class EndpointHandler:
26
  def __init__(self, model_dir):
 
29
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
  print(f"Initializing model on device: {self.device}")
31
 
32
+ # Initialize the DiffSketcher endpoint if available
33
+ if DiffSketcherEndpoint is not None:
34
  try:
35
+ self.model = DiffSketcherEndpoint(model_dir)
36
  self.use_model = True
37
+ print("DiffSketcher endpoint initialized successfully")
38
  except Exception as e:
39
+ print(f"Error initializing DiffSketcher endpoint: {e}")
40
  self.use_model = False
41
  else:
42
  self.use_model = False
 
67
  # Generate SVG using the model or placeholder
68
  if self.use_model:
69
  try:
70
+ # Use the DiffSketcher endpoint
71
  result = self.model(prompt)
72
  image = result["image"]
73
  except Exception as e:
74
+ print(f"Error using DiffSketcher endpoint: {e}")
75
  # Fall back to placeholder
76
  svg_content = self.generate_placeholder_svg(prompt)
77
  png_data = cairosvg.svg2png(bytestring=svg_content.encode("utf-8"))