OmniSVG commited on
Commit
0a475cf
·
verified ·
1 Parent(s): 2c72c48

Upload 3 files

Browse files
Files changed (2) hide show
  1. decoder.py +38 -0
  2. tokenizer.py +284 -0
decoder.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torch
3
+ from transformers import Qwen2_5_VLForConditionalGeneration, AutoConfig
4
+
5
+ class SketchDecoder(nn.Module):
6
+ """
7
+ Autoregressive generative model
8
+ """
9
+
10
+ def __init__(self,
11
+ **kwargs):
12
+ super().__init__()
13
+ self.vocab_size = 196042
14
+ self.bos_token_id = 151643
15
+ self.eos_token_id = 196041
16
+ self.pad_token_id = 151643
17
+
18
+ config = AutoConfig.from_pretrained(
19
+ "Qwen/Qwen2.5-VL-3B-Instruct",
20
+ #n_positions=8192,
21
+ vocab_size=self.vocab_size,
22
+ bos_token_id=self.bos_token_id,
23
+ eos_token_id=self.eos_token_id,
24
+ pad_token_id=self.pad_token_id)
25
+
26
+ self.transformer = Qwen2_5_VLForConditionalGeneration.from_pretrained(
27
+ "Qwen/Qwen2.5-VL-3B-Instruct",
28
+ config=config,
29
+ #torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2",
30
+ #device_map ="cuda",
31
+ ignore_mismatched_sizes=True
32
+ )
33
+
34
+ self.transformer.resize_token_embeddings(self.vocab_size)
35
+
36
+ def forward(self, *args, **kwargs):
37
+ raise NotImplementedError("Forward pass not included in open-source version")
38
+
tokenizer.py ADDED
@@ -0,0 +1,284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import yaml
4
+ from typing import List, Tuple, Dict, Optional, Union
5
+ from deepsvg.difflib.tensor import SVGTensor
6
+ from deepsvg.svglib.svg import SVG
7
+ from deepsvg.svglib.geom import Bbox
8
+
9
+
10
+ class SVGTokenizer:
11
+ """SVG tokenizer for converting between tokens and SVG representations"""
12
+
13
+ def __init__(self, config_path: str = "https://huggingface.co/OmniSVG/OmniSVG/resolve/main/config.yaml"):
14
+ with open(config_path, 'r') as f:
15
+ self.config = yaml.safe_load(f)
16
+
17
+ # Extract configuration values
18
+ self.tokens_config = self.config['tokens']
19
+ self.coordinates_config = self.config['coordinates']
20
+ self.colors_config = self.config['colors']
21
+ self.svg_commands = self.config['svg_commands']
22
+
23
+ self.pixel2xy = self._create_pixel2xy_mapping()
24
+
25
+ def _create_pixel2xy_mapping(self) -> Dict[int, np.ndarray]:
26
+ """Create mapping from pixel indices to xy coordinates"""
27
+ bbox = self.coordinates_config['bbox']
28
+ coord_pad = self.coordinates_config['coord_pad_offset']
29
+ svg_end = self.tokens_config['svg_end']
30
+
31
+ pixel2xy = {}
32
+ x = np.linspace(0, bbox-1, bbox)
33
+ y = np.linspace(0, bbox-1, bbox)
34
+ xx, yy = np.meshgrid(x, y)
35
+ xy_grid = (np.array((xx.ravel(), yy.ravel())).T).astype(int)
36
+
37
+ for pixel, xy in enumerate(xy_grid):
38
+ pixel2xy[pixel] = xy + coord_pad + svg_end
39
+
40
+ return pixel2xy
41
+
42
+ def token_to_color(self, color_token: int) -> str:
43
+ try:
44
+ color_token_start = self.colors_config['color_token_start']
45
+ max_color_tokens = self.colors_config['max_color_tokens']
46
+
47
+ # Check special color tokens
48
+ if color_token == color_token_start:
49
+ return "none" # No color
50
+ elif color_token == color_token_start + 1:
51
+ return "currentColor" # Special color
52
+
53
+ color_index = color_token - (color_token_start + 2)
54
+ if color_index < 0 or color_index >= max_color_tokens:
55
+ print(f"Warning: Color token {color_token} out of range, using default color")
56
+ return "#808080" # Gray as default
57
+
58
+ r = (color_index >> 8) & 0xF
59
+ g = (color_index >> 4) & 0xF
60
+ b = color_index & 0xF
61
+
62
+ r = (r << 4) | r
63
+ g = (g << 4) | g
64
+ b = (b << 4) | b
65
+
66
+ return f"#{r:02x}{g:02x}{b:02x}"
67
+
68
+ except Exception as e:
69
+ print(f"Error in token_to_color: {e}")
70
+ return "#808080"
71
+
72
+ def pixel_to_xy(self, pixel: int) -> np.ndarray:
73
+ """Convert pixel token to xy coordinates"""
74
+ base_offset = self.tokens_config['base_offset']
75
+ pix_pad = self.coordinates_config['pix_pad_offset']
76
+ svg_end = self.tokens_config['svg_end']
77
+
78
+ if self.tokens_config['eom'] < pixel < pix_pad + svg_end:
79
+ xy = np.array([pixel - base_offset, pixel - base_offset]).astype(int)
80
+ return xy
81
+ elif pix_pad + svg_end <= pixel < self.colors_config['cmd_fill'] + base_offset + svg_end:
82
+ pixel_index = pixel - pix_pad - svg_end
83
+ if pixel_index in self.pixel2xy:
84
+ return self.pixel2xy[pixel_index] - base_offset
85
+ else:
86
+ raise ValueError(f"Invalid pixel index: {pixel_index}")
87
+ else:
88
+ raise ValueError(f"Invalid pixel token: {pixel}")
89
+
90
+ def raster_svg(self, pixels: np.ndarray) -> List[List[torch.Tensor]]:
91
+ """Convert pixel sequence to SVG tensor representation"""
92
+ try:
93
+ adjustment = self.tokens_config['num_end_token'] + self.tokens_config['svg_end'] + 2 # 8
94
+ pixels = pixels - adjustment
95
+
96
+ svg_tensors = []
97
+ path_tensor = []
98
+ i = 0
99
+
100
+ while i < len(pixels):
101
+ try:
102
+ pix = pixels[i]
103
+
104
+ if pix[0] == self.svg_commands['move']: # Move command
105
+ cmd_tensor = np.zeros(14)
106
+ cmd_tensor[0] = 0
107
+
108
+ if i + 2 >= len(pixels):
109
+ break
110
+
111
+ cmd_tensor[12:14] = pixels[i+2]
112
+ start_pos = pixels[i+1]
113
+ end_pos = pixels[i+2]
114
+
115
+ if np.all(start_pos == end_pos) and path_tensor:
116
+ svg_tensors.append(torch.tensor(path_tensor))
117
+ path_tensor = []
118
+ path_tensor.append(cmd_tensor.tolist())
119
+ i += 3
120
+
121
+ elif pix[0] == self.svg_commands['line']: # Line command
122
+ cmd_tensor = np.zeros(14)
123
+ cmd_tensor[0] = 1
124
+
125
+ if i + 1 >= len(pixels):
126
+ break
127
+
128
+ cmd_tensor[12:14] = pixels[i+1]
129
+ path_tensor.append(cmd_tensor.tolist())
130
+ i += 2
131
+
132
+ elif pix[0] == self.svg_commands['curve']: # Curve command
133
+ cmd_tensor = np.zeros(14)
134
+ cmd_tensor[0] = 2
135
+
136
+ if i + 3 >= len(pixels):
137
+ break
138
+
139
+ cmd_tensor[8:10] = pixels[i+1]
140
+ cmd_tensor[10:12] = pixels[i+2]
141
+ cmd_tensor[12:14] = pixels[i+3]
142
+ path_tensor.append(cmd_tensor.tolist())
143
+ i += 4
144
+
145
+ elif pix[0] == self.svg_commands['arc']: # Arc command
146
+ cmd_tensor = np.zeros(14)
147
+ cmd_tensor[0] = 3
148
+
149
+ if i + 5 >= len(pixels):
150
+ break
151
+
152
+ radius = pixels[i+1]
153
+ x_axis_rot = pixels[i+2][0]
154
+ large_arc_flg = pixels[i+3][0]
155
+ sweep_flg = pixels[i+4][0]
156
+ end_pos = pixels[i+5]
157
+
158
+ cmd_tensor[1:3] = radius
159
+ cmd_tensor[3] = x_axis_rot
160
+ cmd_tensor[4] = large_arc_flg
161
+ cmd_tensor[5] = sweep_flg
162
+ cmd_tensor[12:14] = end_pos
163
+ path_tensor.append(cmd_tensor.tolist())
164
+ i += 6
165
+
166
+ elif pix[0] == self.svg_commands['close']: # Close command
167
+ cmd_tensor = np.zeros(14)
168
+ cmd_tensor[0] = 6
169
+
170
+ if i + 1 >= len(pixels):
171
+ break
172
+
173
+ cmd_tensor[12:14] = pixels[i+1]
174
+ path_tensor.append(cmd_tensor.tolist())
175
+ i += 2
176
+ else:
177
+ i += 1
178
+
179
+ except IndexError:
180
+ print(f"Index error at position {i}, stopping SVG processing")
181
+ break
182
+
183
+ if path_tensor:
184
+ svg_tensors.append(torch.tensor(path_tensor))
185
+
186
+ return [svg_tensors]
187
+
188
+ except Exception as e:
189
+ print(f"Error in raster_svg: {e}")
190
+ return []
191
+
192
+ def extract_colors_from_tokens(self, tokens: List[int]) -> List[int]:
193
+ colors = []
194
+ base_offset = self.tokens_config['base_offset']
195
+ color_start = self.colors_config['color_start_offset']
196
+ color_end = self.colors_config['color_end_offset']
197
+
198
+ for token in tokens:
199
+ if color_start <= token < color_end:
200
+ colors.append(token - 1 - base_offset)
201
+
202
+ return colors
203
+
204
+ def process_generated_tokens(self, output_ids: torch.Tensor) -> Tuple[np.ndarray, List[int]]:
205
+ # Remove <bos> and <eos> tokens
206
+ generated_pixels = output_ids[:, 1:-1].tolist()
207
+
208
+ generated_xy = []
209
+ generated_colors = []
210
+
211
+ for pixel_sequence in generated_pixels:
212
+ xy_sequence = []
213
+ colors = []
214
+
215
+ for pixel in pixel_sequence:
216
+ try:
217
+ if self.tokens_config['eom'] < pixel < self.coordinates_config['pix_pad_offset'] + self.tokens_config['svg_end']:
218
+ xy = self.pixel_to_xy(pixel)
219
+ xy_sequence.append(xy)
220
+ elif self.coordinates_config['pix_pad_offset'] + self.tokens_config['svg_end'] <= pixel < self.colors_config['cmd_fill'] + self.tokens_config['base_offset'] + self.tokens_config['svg_end']:
221
+ xy = self.pixel_to_xy(pixel)
222
+ xy_sequence.append(xy)
223
+ elif self.colors_config['color_start_offset'] <= pixel < self.colors_config['color_end_offset']:
224
+ colors.append(pixel - 1 - self.tokens_config['base_offset'])
225
+ except ValueError as e:
226
+ print(f"Error processing pixel {pixel}: {e}")
227
+ continue
228
+
229
+ if xy_sequence:
230
+ generated_xy = np.vstack(xy_sequence)
231
+ generated_colors = colors
232
+
233
+ return generated_xy, generated_colors
234
+
235
+ def apply_colors_to_svg(self, svg_tensors: Union[List[torch.Tensor], List[List[torch.Tensor]]], colors: Optional[List[int]]) -> SVG:
236
+ paths = []
237
+ bbox = self.coordinates_config['bbox']
238
+
239
+ flat_tensors = []
240
+ if svg_tensors and isinstance(svg_tensors[0], list):
241
+ for tensor_list in svg_tensors:
242
+ flat_tensors.extend(tensor_list)
243
+ else:
244
+ flat_tensors = svg_tensors
245
+
246
+ if not flat_tensors:
247
+ raise ValueError("No valid SVG tensors provided")
248
+
249
+ if colors is None:
250
+ colors = []
251
+
252
+ for i, path_tensor in enumerate(flat_tensors):
253
+ try:
254
+ path = SVGTensor.from_data(path_tensor)
255
+ path = SVG.from_tensor(path.data, viewbox=Bbox(bbox))
256
+
257
+ if i < len(colors):
258
+ color_token = colors[i]
259
+ actual_color = self.token_to_color(color_token)
260
+ else:
261
+ actual_color = "none"
262
+
263
+ for path_group in path:
264
+ path_group.color = actual_color
265
+ path_group.stroke_color = "none"
266
+
267
+ path.fill_(True)
268
+ paths.append(path)
269
+
270
+
271
+ except Exception as e:
272
+ print(f"Error processing path {i}: {e}")
273
+ continue
274
+
275
+ if not paths:
276
+ raise ValueError("No valid paths could be generated")
277
+ path_groups = paths[0].svg_path_groups
278
+ for i in range(1, len(paths)):
279
+ if i < len(paths):
280
+ path_groups.extend(paths[i].svg_path_groups)
281
+
282
+ svg = SVG(path_groups, viewbox=Bbox(bbox))
283
+
284
+ return svg