Fix image embedding logic to be mps-compatible
#45 by DefOs9 · opened

modeling_phi4mm.py  CHANGED  (+70 -40)
@@ -325,7 +325,7 @@ class Phi4MMImageEmbedding(nn.Module):
             bs = img_embeds.shape[0]
             # Nx(HW)xC
             if image_attention_mask is not None and len(image_attention_mask) > 0:
-                img_features = self.get_img_features(img_embeds.flatten(0, 1), attention_mask=image_attention_mask.
+                img_features = self.get_img_features(img_embeds.flatten(0, 1), attention_mask=image_attention_mask.bool().flatten(0,1).to(target_device))
             else:
                 img_features = self.get_img_features(img_embeds.flatten(0, 1))

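The `.bool()` call is the device-sensitive part of this hunk: `Tensor.type(torch.BoolTensor)` always produces a CPU tensor, while `.bool()` only changes the dtype and keeps the tensor on its current device. A minimal sketch with a made-up mask tensor (not the model's real preprocessing output):

```python
import torch

device = "mps" if torch.backends.mps.is_available() else "cpu"
mask = torch.ones(2, 3, 32, 32, device=device)   # stand-in for image_attention_mask

# Casting through a tensor *type* pins the result to CPU regardless of the source device,
# which later clashes with tensors that stayed on MPS (or CUDA).
cpu_mask = mask.type(torch.BoolTensor)
print(cpu_mask.device)                 # cpu

# .bool() is a pure dtype cast; the device is preserved.
keep_device_mask = mask.bool().flatten(0, 1)
print(keep_device_mask.device)         # mps (or cpu when MPS is unavailable)
```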
@@ -337,13 +337,12 @@ class Phi4MMImageEmbedding(nn.Module):

             assert base_feat_height == base_feat_height_target and base_feat_width == base_feat_height_target, f'base_feat_height: {base_feat_height}, base_feat_width: {base_feat_width}, expect {base_feat_height_target} features for hd transform'

-            # bs x max_num_crops x (
+            # bs x max_num_crops x (HxH) x C
             img_features = img_features.view(bs, -1, base_feat_height * base_feat_width, self.image_dim_out)
             C = self.image_dim_out
             H = base_feat_height

             output_imgs = []
-            output_len = []
             # training is tensor, inference is list
             if isinstance(img_sizes, torch.Tensor):
                 img_sizes = img_sizes.view(-1, 2)
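For orientation, a shape walk-through of the `view` above with illustrative sizes only (the batch size, crop count, grid side and feature width here are made up, not the model's configuration):

```python
import torch

bs, max_num_crops, H, C = 2, 5, 32, 1024          # illustrative sizes only

# get_img_features ran on the crop-flattened batch: (bs*max_num_crops) x (HxH) x C
img_features = torch.randn(bs * max_num_crops, H * H, C)

# Restore the crop dimension: bs x max_num_crops x (HxH) x C
img_features = img_features.view(bs, -1, H * H, C)
print(img_features.shape)                         # torch.Size([2, 5, 1024, 1024])
```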
@@ -353,39 +352,71 @@ class Phi4MMImageEmbedding(nn.Module):
                 w = w // base_resolution
                 B_ = h * w

-                # 1 x (
+                # 1 x (HxH) x C
                 global_img_feature = img_features[_bs, :1]

-                # 1 x
-                glb_img =
-
+                # 1 x H x H x C
+                glb_img = (
+                    global_img_feature
+                    .reshape(1, H, H, C)
+                    .reshape(1, H // base_feat_height_reduction, base_feat_height_reduction,
+                             H // base_feat_height_reduction, base_feat_height_reduction, C)
+                    .contiguous()
+                    .permute(0, 1, 3, 2, 4, 5)
+                    .reshape(1, H // base_feat_height_reduction, H // base_feat_height_reduction,
+                             base_feat_height_reduction * base_feat_height_reduction * C)
+                    .contiguous()
+                )
+                temp_glb_GN = self.sub_GN.repeat(1, H // base_feat_height_reduction, 1, 1)

-                # 1 x
-                glb_img =
+                # 1 x (HxH+H) x C
+                glb_img = (
+                    torch.cat([glb_img, temp_glb_GN], dim=2)
+                    .reshape(1, -1, base_feat_height_reduction * base_feat_height_reduction * C)
+                )

-                # (max_num_crops-1) x (
+                # (max_num_crops-1) x (HxH) x C
                 sub_img = img_features[_bs, 1:]
-                # 16x574x1024
                 # get rid of padding sub_img
                 sub_img = sub_img[:B_]

-
-
-
+                sub_img = (
+                    sub_img
+                    .reshape(B_, H, H, C)
+                    .reshape(B_, H // base_feat_height_reduction, base_feat_height_reduction,
+                             H // base_feat_height_reduction, base_feat_height_reduction, C)
+                    .contiguous()
+                    .permute(0, 1, 3, 2, 4, 5)
+                    .reshape(B_, -1, base_feat_height_reduction * base_feat_height_reduction * C)
+                    .contiguous()
+                )
+                sub_img = (
+                    sub_img
+                    .reshape(1, h, w, base_feat_height // base_feat_height_reduction,
+                             base_feat_width // base_feat_height_reduction, -1)
+                    .permute(0, 1, 3, 2, 4, 5)
+                    .reshape(1, h * base_feat_height // base_feat_height_reduction,
+                             w * base_feat_width // base_feat_height_reduction,
+                             base_feat_height_reduction * base_feat_height_reduction * C)
+                )

                 if image_attention_mask is not None and len(image_attention_mask) > 0:
-                    reshaped_image_attention_mask =
-
-
-
+                    reshaped_image_attention_mask = (
+                        image_attention_mask[_bs, 1:B_ + 1, 0::2, 0::2]
+                        .reshape(1, h, w, base_feat_height // base_feat_height_reduction,
+                                 base_feat_width // base_feat_height_reduction)
+                        .permute(0, 1, 3, 2, 4)
+                        .reshape(1, h * base_feat_height // base_feat_height_reduction,
+                                 w * base_feat_width // base_feat_height_reduction)
+                    )
+                    useful_height = int(reshaped_image_attention_mask[0, :, 0].sum().item())
+                    useful_width = int(reshaped_image_attention_mask[0, 0, :].sum().item())
+                    sub_img = sub_img[:, :useful_height, :useful_width]
                     temp_sub_GN = self.sub_GN.repeat(1, useful_height, 1, 1)
-                    temp_len = int(image_attention_mask[_bs,:B_+1,0::2,0::2].sum().item()) + (useful_height+1) + base_feat_height//base_feat_height_reduction
                 else:
-                    temp_sub_GN = self.sub_GN.repeat(1, h*base_feat_height//base_feat_height_reduction, 1, 1)
-                    temp_len = int((h*w+1)*self.num_img_tokens+ 1 + (h+1)*base_feat_height//base_feat_height_reduction)
+                    temp_sub_GN = self.sub_GN.repeat(1, h * base_feat_height // base_feat_height_reduction, 1, 1)

-                sub_img = torch.cat([sub_img, temp_sub_GN], dim=2).reshape(1
-                # (1, num_img_tokens, 1024*4)
+                sub_img = torch.cat([sub_img, temp_sub_GN], dim=2).reshape(1, -1, base_feat_height_reduction * base_feat_height_reduction * C)

                 # glb + sub
                 if self.hd_transform_order == 'glb_sub':
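The reflowed chains above all follow the same pattern: fold every `base_feat_height_reduction x base_feat_height_reduction` block of the HxH feature grid into a single token of width `reduction**2 * C`. A self-contained sketch of that pattern with small made-up sizes (`R` stands in for `base_feat_height_reduction`):

```python
import torch

H, R, C = 8, 2, 16                       # made-up sizes: grid side, reduction factor, feature dim
feat = torch.randn(1, H * H, C)          # stand-in for global_img_feature

# Same reshape/permute chain as in the diff: group each RxR block into one token of size R*R*C.
merged = (
    feat
    .reshape(1, H, H, C)
    .reshape(1, H // R, R, H // R, R, C)
    .permute(0, 1, 3, 2, 4, 5)
    .reshape(1, H // R, H // R, R * R * C)
)
print(merged.shape)                      # torch.Size([1, 4, 4, 64])

# Token (i, j) is exactly the RxR block of the original grid, flattened row-major.
grid = feat.reshape(1, H, H, C)
assert torch.equal(merged[0, 0, 0], grid[0, 0:R, 0:R, :].reshape(-1))
```

The separator embeddings (`sub_GN`, `glb_GN`) are then concatenated per row with `torch.cat`, exactly as the hunk shows.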
@@ -395,17 +426,11 @@ class Phi4MMImageEmbedding(nn.Module):
                 else:
                     raise NotImplementedError(f'hd_transform_order = {self.hd_transform_order}, not implemented')

-                #temp_len = int((h*w+1)*144 + 1 + (h+1)*12)
-                assert temp_len == output_imgs[-1].shape[1], f'temp_len: {temp_len}, output_imgs[-1].shape[1]: {output_imgs[-1].shape[1]}'
-                output_len.append(temp_len)
-
-            num_img_tokens = output_len
             img_set_tensor = []
             for _output_img in output_imgs:
                 img_feature_proj = self.img_projection(_output_img.to(target_device).to(target_dtype))
                 img_set_tensor.append(img_feature_proj)
-            #logger.info(f'img_embeds size: {img_embeds.size()}, image sizes: {img_sizes} loading time {datetime.now() - start_time}')
-            #assert sum(num_img_tokens) == len(g_values), f'(branch 1) sum(num_img_tokens): {sum(num_img_tokens)}, g_values size: {len(g_values)}, g_values {g_values}'
+            # logger.info(f'img_embeds size: {img_embeds.size()}, image sizes: {img_sizes} loading time {datetime.now() - start_time}')

         else:
             raise NotImplementedError
@@ -420,7 +445,7 @@ class Phi4MMImageEmbedding(nn.Module):
                 self.get_img_features(img_embeds)
                 .to(target_device)
                 .to(target_dtype)
-                .reshape(-1,
+                .reshape(-1, self.image_dim_out)
             )
             if self.use_hd_transform:
                 img_set_tensor = self.img_projection(tt.reshape(-1, self.image_dim_out*self.base_feat_height_reduction**2) * self.glb_GN[0] * self.sub_GN[0, 0])
@@ -442,14 +467,19 @@ class Phi4MMImageEmbedding(nn.Module):
             # Shape: (merged_N_tokens, C)
             merged_img_set_tensor = torch.cat(img_set_tensor, dim=1).squeeze(0)
             merged_img_set_tensor = merged_img_set_tensor.to(hidden_states.dtype).to(hidden_states.device)
-
-
-
-            new_hidden_states =
-
-
-
-            )
+            if hidden_states.device.type == "mps":
+                # For MPS, assign using direct indexing to avoid index_put issues.
+                new_hidden_states = hidden_states.clone()
+                new_hidden_states[positions_tuple] = merged_img_set_tensor
+            else:
+                # Temporarily disable autocast to avoid issue on bf16 tensors
+                # Ref: https://github.com/pytorch/pytorch/issues/132715
+                with torch.autocast(device_type=hidden_states.device.type, enabled=False):
+                    new_hidden_states = hidden_states.index_put(
+                        indices=positions_tuple,
+                        values=merged_img_set_tensor,
+                        accumulate=False
+                    )
             hidden_states = new_hidden_states
         else:
             raise NotImplementedError
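This is the core MPS fix: on `mps`, the projected image embeddings are scattered into the placeholder token positions with plain advanced indexing instead of `index_put`, while other devices keep the existing autocast workaround. A runnable sketch on toy tensors; `positions_tuple` is faked from a small boolean mask here, whereas in the model it comes from the image-placeholder positions computed earlier in `forward`:

```python
import torch

# Toy stand-ins: a (batch, seq, hidden) tensor and the positions of image tokens.
hidden_states = torch.zeros(1, 6, 4, dtype=torch.bfloat16)
image_token_mask = torch.tensor([[False, True, True, False, True, False]])
positions_tuple = torch.nonzero(image_token_mask, as_tuple=True)
merged_img_set_tensor = torch.ones(int(image_token_mask.sum()), 4, dtype=hidden_states.dtype)

if hidden_states.device.type == "mps":
    # index_put has been problematic on MPS, so assign through plain advanced indexing.
    new_hidden_states = hidden_states.clone()
    new_hidden_states[positions_tuple] = merged_img_set_tensor
else:
    # Autocast is disabled around index_put to avoid a bf16 issue
    # (https://github.com/pytorch/pytorch/issues/132715).
    with torch.autocast(device_type=hidden_states.device.type, enabled=False):
        new_hidden_states = hidden_states.index_put(
            indices=positions_tuple,
            values=merged_img_set_tensor,
            accumulate=False,
        )
print(new_hidden_states[0, :, 0])  # positions 1, 2 and 4 now hold the image embeddings
```

Both branches write the same values; only the scatter mechanism differs per device.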
@@ -2096,7 +2126,7 @@ class Phi4MMForCausalLM(Phi4MMPreTrainedModel, GenerationMixin):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if isinstance(input_mode, torch.Tensor):
-
+            assert len(input_mode) == 1
            input_mode = input_mode[0].item()
        input_mode = InputMode(input_mode)
