MohamedRashad commited on
Commit
27fa9cc
·
1 Parent(s): 11b119e

Refactor skin weight calculations to handle division by zero and ensure valid index access in Exporter and SAMPart3DDataset classes

Browse files
src/data/exporter.py CHANGED
@@ -290,11 +290,15 @@ class Exporter():
290
  vertex_group_reweight = skin[np.arange(skin.shape[0])[..., None], argsorted]
291
  if group_per_vertex == -1:
292
  group_per_vertex = vertex_group_reweight.shape[-1]
 
 
293
  if not do_not_normalize:
294
- vertex_group_reweight = vertex_group_reweight / vertex_group_reweight[..., :group_per_vertex].sum(axis=1)[...,None]
 
 
295
 
296
  for v, w in enumerate(skin):
297
- for ii in range(group_per_vertex):
298
  i = argsorted[v, ii]
299
  if i >= J:
300
  continue
 
290
  vertex_group_reweight = skin[np.arange(skin.shape[0])[..., None], argsorted]
291
  if group_per_vertex == -1:
292
  group_per_vertex = vertex_group_reweight.shape[-1]
293
+ # Ensure we don't access more columns than available
294
+ max_groups = min(group_per_vertex, argsorted.shape[1])
295
  if not do_not_normalize:
296
+ vertex_group_sum = vertex_group_reweight[..., :max_groups].sum(axis=1)[..., None]
297
+ vertex_group_sum = np.where(vertex_group_sum == 0, 1.0, vertex_group_sum) # Avoid division by zero
298
+ vertex_group_reweight = vertex_group_reweight / vertex_group_sum
299
 
300
  for v, w in enumerate(skin):
301
+ for ii in range(max_groups):
302
  i = argsorted[v, ii]
303
  if i >= J:
304
  continue
src/inference/merge.py CHANGED
@@ -262,7 +262,11 @@ def make_armature(
262
 
263
  argsorted = np.argsort(-skin, axis=1)
264
  vertex_group_reweight = skin[np.arange(skin.shape[0])[..., None], argsorted]
265
- vertex_group_reweight = vertex_group_reweight / vertex_group_reweight[..., :group_per_vertex].sum(axis=1)[...,None]
 
 
 
 
266
  vertex_group_reweight = np.nan_to_num(vertex_group_reweight)
267
  tree = cKDTree(vertices)
268
  for ob in objects:
@@ -286,7 +290,9 @@ def make_armature(
286
  _, index = tree.query(n_vertices)
287
 
288
  for v, co in enumerate(tqdm(n_vertices)):
289
- for ii in range(group_per_vertex):
 
 
290
  i = argsorted[index[v], ii]
291
  if i >= len(names):
292
  continue
 
262
 
263
  argsorted = np.argsort(-skin, axis=1)
264
  vertex_group_reweight = skin[np.arange(skin.shape[0])[..., None], argsorted]
265
+ # Handle division by zero and limit to group_per_vertex
266
+ max_groups = min(group_per_vertex, vertex_group_reweight.shape[1])
267
+ vertex_group_sum = vertex_group_reweight[..., :max_groups].sum(axis=1)[..., None]
268
+ vertex_group_sum = np.where(vertex_group_sum == 0, 1.0, vertex_group_sum) # Avoid division by zero
269
+ vertex_group_reweight = vertex_group_reweight / vertex_group_sum
270
  vertex_group_reweight = np.nan_to_num(vertex_group_reweight)
271
  tree = cKDTree(vertices)
272
  for ob in objects:
 
290
  _, index = tree.query(n_vertices)
291
 
292
  for v, co in enumerate(tqdm(n_vertices)):
293
+ # Ensure we don't access more columns than available in argsorted
294
+ max_groups = min(group_per_vertex, argsorted.shape[1])
295
+ for ii in range(max_groups):
296
  i = argsorted[index[v], ii]
297
  if i >= len(names):
298
  continue
src/model/michelangelo/models/modules/transformer_blocks.py CHANGED
@@ -35,13 +35,13 @@ def init_linear(l, stddev):
35
  nn.init.constant_(l.bias, 0.0)
36
 
37
  def flash_attention(q, k, v):
38
- with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
39
- q = q.transpose(1, 2)
40
- k = k.transpose(1, 2)
41
- v = v.transpose(1, 2)
42
- out = F.scaled_dot_product_attention(q, k, v)
43
- out = out.transpose(1, 2)
44
- # print("use flash atten 2")
45
 
46
  return out
47
 
 
35
  nn.init.constant_(l.bias, 0.0)
36
 
37
  def flash_attention(q, k, v):
38
+ # Use torch.nn.functional.scaled_dot_product_attention directly instead of deprecated context manager
39
+ q = q.transpose(1, 2)
40
+ k = k.transpose(1, 2)
41
+ v = v.transpose(1, 2)
42
+ out = F.scaled_dot_product_attention(q, k, v)
43
+ out = out.transpose(1, 2)
44
+ # print("use flash atten 2")
45
 
46
  return out
47
 
src/model/pointcept/datasets/dataset_render_16views.py CHANGED
@@ -387,13 +387,17 @@ class SAMPart3DDataset16Views(Dataset):
387
  if per_pixel_index.shape[-1] == 1:
388
  per_pixel_mask = per_pixel_index.squeeze()
389
  else:
 
 
390
  per_pixel_mask = torch.gather(
391
- per_pixel_index, 1, random_index.unsqueeze(-1)
392
  ).squeeze()
 
 
393
  per_pixel_mask_ = torch.gather(
394
  per_pixel_index,
395
  1,
396
- torch.max(random_index.unsqueeze(-1) - 1, torch.Tensor([0]).int()),
397
  ).squeeze()
398
 
399
  mask_id[i : i + npximg] = per_pixel_mask.to(self.device)
 
387
  if per_pixel_index.shape[-1] == 1:
388
  per_pixel_mask = per_pixel_index.squeeze()
389
  else:
390
+ # Clamp random_index to valid range to prevent out of bounds error
391
+ random_index_clamped = torch.clamp(random_index.unsqueeze(-1), 0, per_pixel_index.shape[1] - 1)
392
  per_pixel_mask = torch.gather(
393
+ per_pixel_index, 1, random_index_clamped
394
  ).squeeze()
395
+ # Clamp the previous index to valid range as well
396
+ prev_index_clamped = torch.clamp(random_index.unsqueeze(-1) - 1, 0, per_pixel_index.shape[1] - 1)
397
  per_pixel_mask_ = torch.gather(
398
  per_pixel_index,
399
  1,
400
+ prev_index_clamped,
401
  ).squeeze()
402
 
403
  mask_id[i : i + npximg] = per_pixel_mask.to(self.device)
src/model/unirig_skin.py CHANGED
@@ -356,7 +356,7 @@ class UniRigSkin(ModelSpec):
356
  ptv3_input = {
357
  'coord': vertices.reshape(-1, 3),
358
  'feat': feat.reshape(-1, 9),
359
- 'offset': torch.tensor(batch['offset']),
360
  'grid_size': self.grid_size,
361
  }
362
  if not self.training:
@@ -420,7 +420,7 @@ class UniRigSkin(ModelSpec):
420
  input_features = attn_weight[i, :, :num_bones[i], :].reshape(-1, attn_weight.shape[-1])
421
 
422
  pred = self.skinweight_pred(input_features).reshape(cur_N, num_bones[i])
423
- skin_pred[i, :, :num_bones[i]] = F.softmax(pred)
424
  skin_pred_list.append(skin_pred)
425
  skin_pred_list = torch.cat(skin_pred_list, dim=1)
426
  for i in range(B):
@@ -437,4 +437,5 @@ class UniRigSkin(ModelSpec):
437
  outputs = []
438
  for i in range(skin_pred.shape[0]):
439
  outputs.append(skin_pred[i, :, :num_bones[i]])
 
440
  return outputs
 
356
  ptv3_input = {
357
  'coord': vertices.reshape(-1, 3),
358
  'feat': feat.reshape(-1, 9),
359
+ 'offset': batch['offset'].detach().clone(),
360
  'grid_size': self.grid_size,
361
  }
362
  if not self.training:
 
420
  input_features = attn_weight[i, :, :num_bones[i], :].reshape(-1, attn_weight.shape[-1])
421
 
422
  pred = self.skinweight_pred(input_features).reshape(cur_N, num_bones[i])
423
+ skin_pred[i, :, :num_bones[i]] = F.softmax(pred, dim=-1)
424
  skin_pred_list.append(skin_pred)
425
  skin_pred_list = torch.cat(skin_pred_list, dim=1)
426
  for i in range(B):
 
437
  outputs = []
438
  for i in range(skin_pred.shape[0]):
439
  outputs.append(skin_pred[i, :, :num_bones[i]])
440
+
441
  return outputs
src/system/skin.py CHANGED
@@ -118,7 +118,7 @@ class SkinWriter(BasePredictionWriter):
118
  J = num_bones[id]
119
  F = num_faces[id]
120
  o_vertices = vertices[id, :N]
121
-
122
  _parents = parents_list[id]
123
  parents = []
124
  for i in range(J):
@@ -156,36 +156,34 @@ class SkinWriter(BasePredictionWriter):
156
  raw_data = RawSkin(skin=skin_pred, vertices=sampled_vertices[id], joints=joints[id, :J])
157
  if self.export_npz is not None:
158
  raw_data.save(path=make_path(self.export_npz, 'npz'))
 
159
  if self.export_fbx is not None:
160
- try:
161
- exporter = Exporter()
162
- names = RawData.load(path=os.path.join(paths[id], data_names[id])).names
163
- if names is None:
164
- names = [f"bone_{i}" for i in range(J)]
165
- if self.user_mode:
166
- if self.output_name is not None:
167
- path = self.output_name
168
- else:
169
- path = make_path(self.save_name, 'fbx', trim=True)
170
  else:
171
- path = make_path(self.export_fbx, 'fbx')
172
- exporter._export_fbx(
173
- path=path,
174
- vertices=o_vertices,
175
- joints=joints[id, :J],
176
- skin=skin_resampled,
177
- parents=parents,
178
- names=names,
179
- faces=faces[id, :F],
180
- group_per_vertex=4,
181
- tails=tails[id, :J],
182
- use_extrude_bone=False,
183
- use_connect_unique_child=False,
184
- # do_not_normalize=True,
185
- )
186
- except Exception as e:
187
- print(str(e))
188
-
189
  def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
190
  self._epoch += 1
191
 
 
118
  J = num_bones[id]
119
  F = num_faces[id]
120
  o_vertices = vertices[id, :N]
121
+
122
  _parents = parents_list[id]
123
  parents = []
124
  for i in range(J):
 
156
  raw_data = RawSkin(skin=skin_pred, vertices=sampled_vertices[id], joints=joints[id, :J])
157
  if self.export_npz is not None:
158
  raw_data.save(path=make_path(self.export_npz, 'npz'))
159
+
160
  if self.export_fbx is not None:
161
+ exporter = Exporter()
162
+ names = RawData.load(path=os.path.join(paths[id], data_names[id])).names
163
+ if names is None:
164
+ names = [f"bone_{i}" for i in range(J)]
165
+ if self.user_mode:
166
+ if self.output_name is not None:
167
+ path = self.output_name
 
 
 
168
  else:
169
+ path = make_path(self.save_name, 'fbx', trim=True)
170
+ else:
171
+ path = make_path(self.export_fbx, 'fbx')
172
+ exporter._export_fbx(
173
+ path=path,
174
+ vertices=o_vertices,
175
+ joints=joints[id, :J],
176
+ skin=skin_resampled,
177
+ parents=parents,
178
+ names=names,
179
+ faces=faces[id, :F],
180
+ group_per_vertex=4,
181
+ tails=tails[id, :J],
182
+ use_extrude_bone=False,
183
+ use_connect_unique_child=False,
184
+ # do_not_normalize=True,
185
+ )
186
+
187
  def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
188
  self._epoch += 1
189