Commit d45e0b2
Parent(s): 1ea9622

add video

Files changed:
- modeling_sa2va_chat.py  +56 -23
- sam2.py  +2 -2
modeling_sa2va_chat.py
CHANGED
@@ -485,6 +485,7 @@ class Sa2VAChatModel(PreTrainedModel):
                     objects_prompt_masks = objects_prompt_masks.reshape(n_obj, -1)
                     vp_embeds.append(tile_vit_embeds[objects_prompt_masks])
                     i_vp_img += 1
+
             vp_embeds = torch.cat(vp_embeds, dim=0)
         else:
             vp_embeds = None

@@ -583,6 +584,7 @@ class Sa2VAChatModel(PreTrainedModel):
     def predict_forward(
             self,
             image=None,
+            video=None,
             text=None,
             past_text='',
             mask_prompts=None,

@@ -593,29 +595,57 @@ class Sa2VAChatModel(PreTrainedModel):
         self.preparing_for_generation(tokenizer=tokenizer)

         input_dict = {}
+        if video is not None:
+            pixel_values = []
+            extra_pixel_values = []
+            ori_image_size = video[0].size
+            for frame_idx, frame_image in enumerate(video):
+                assert ori_image_size == frame_image.size
+                g_image = np.array(frame_image)  # for grounding
+                g_image = self.extra_image_processor.apply_image(g_image)
+                g_image = torch.from_numpy(g_image).permute(2, 0, 1).contiguous()
+                extra_pixel_values.append(g_image)
+                if frame_idx < 5:
+                    img = self.transformer(frame_image)
+                    pixel_values.append(img)
+
+            pixel_values = torch.stack(pixel_values, dim=0).to(self.torch_dtype)  # (n_f, 3, h, w)
+            g_pixel_values = torch.stack([
+                self.grounding_encoder.preprocess_image(pixel) for pixel in extra_pixel_values
+            ]).to(self.torch_dtype)
+            num_image_tokens = self.patch_token
+            num_frames = 5

-
+            input_dict['vp_overall_mask'] = None
+        else:
+            ori_image_size = image.size

-
-
-
-
-
+            # prepare grounding images
+            g_image = np.array(image)  # for grounding
+            g_image = self.extra_image_processor.apply_image(g_image)
+            g_pixel_values = torch.from_numpy(g_image).permute(2, 0, 1).contiguous().to(self.torch_dtype)
+            extra_pixel_values = [g_pixel_values]
+            g_pixel_values = torch.stack([
+                self.grounding_encoder.preprocess_image(pixel) for pixel in extra_pixel_values
+            ]).to(self.torch_dtype)

-
-
-
+            images = dynamic_preprocess(image, self.min_dynamic_patch,
+                                        self.max_dynamic_patch,
+                                        self.image_size, self.use_thumbnail)

-
-
-
-
-
+            if mask_prompts is not None:
+                vp_overall_mask = torch.Tensor([False] * (len(images) - 1) + [True])
+                input_dict['vp_overall_mask'] = vp_overall_mask
+            else:
+                input_dict['vp_overall_mask'] = None

-
-
+            pixel_values = [self.transformer(image) for image in images]
+            pixel_values = torch.stack(pixel_values).to(self.torch_dtype)
+            num_image_tokens = pixel_values.shape[0] * self.patch_token
+            num_frames = 1
+            input_dict['g_pixel_values'] = g_pixel_values
         input_dict['pixel_values'] = pixel_values
-
+

         if mask_prompts is not None:
             # reshape mask prompts to feature size

@@ -627,7 +657,7 @@ class Sa2VAChatModel(PreTrainedModel):
                                       mode='nearest').squeeze(0) for item in mask_prompts]
             region_pixels = []
             for mask_prompt in mask_prompts[0]:
-                region_pixels.append(mask_prompt.to(torch.int64).sum())
+                region_pixels.append(mask_prompt.bool().to(torch.int64).sum())

             vp_token_str = '\nThere are {} part regions in the picture: '.format(len(mask_prompts[0]))
             for i in range(len(mask_prompts[0])):

@@ -645,6 +675,9 @@ class Sa2VAChatModel(PreTrainedModel):
         image_token_str = f'{self.IMG_START_TOKEN}' \
                           f'{self.IMG_CONTEXT_TOKEN * num_image_tokens}' \
                           f'{self.IMG_END_TOKEN}'
+        image_token_str = image_token_str + '\n'
+        image_token_str = image_token_str * num_frames
+        image_token_str = image_token_str.strip()

         ret_masks = []

@@ -695,16 +728,14 @@ class Sa2VAChatModel(PreTrainedModel):

         for seg_hidden_states in all_seg_hidden_states:
             seg_hidden_states = seg_hidden_states.unsqueeze(0)
-            g_pixel_values =
-                self.grounding_encoder.preprocess_image(pixel, dtype=self.torch_dtype)
-                for pixel in [input_dict['g_pixel_values']]])
+            g_pixel_values = input_dict['g_pixel_values']
             sam_states = self.grounding_encoder.get_sam2_embeddings(g_pixel_values)
-            pred_masks = self.grounding_encoder.
+            pred_masks = self.grounding_encoder.language_embd_inference(sam_states, [seg_hidden_states] * num_frames)
             w, h = ori_image_size
             masks = F.interpolate(pred_masks, size=(h, w), mode='bilinear', align_corners=False)
             masks = masks[:, 0]
             masks = masks.sigmoid() > 0.5
-            masks = masks.
+            masks = masks.cpu().numpy()
             ret_masks.append(masks)

         return {'prediction': predict, 'prediction_masks': ret_masks,}

@@ -712,6 +743,8 @@ class Sa2VAChatModel(PreTrainedModel):
 def get_seg_hidden_states(hidden_states, output_ids, seg_id):
     seg_mask = output_ids == seg_id
     n_out = len(seg_mask)
+    if n_out == 0:
+        return hidden_states[0:0]
     return hidden_states[-n_out:][seg_mask]

 def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height,
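Usage note (not part of the diff): a minimal sketch of how the new video path of predict_forward could be called, based only on the signature and the {'prediction', 'prediction_masks'} return dict visible above. The model id, loading code, frame sampling, the '<image>' placeholder in the prompt, and the assumption that predict_forward accepts a tokenizer argument (it is referenced in the body shown above) are illustrative, not taken from this commit.

import torch
from PIL import Image
from transformers import AutoModel, AutoTokenizer

# Hypothetical checkpoint id and loading code; adjust to the actual model.
path = "ByteDance/Sa2VA-4B"
model = AutoModel.from_pretrained(
    path, torch_dtype=torch.bfloat16, trust_remote_code=True).eval().cuda()
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)

# Video input: a list of same-size PIL frames. Per the diff, only the first
# 5 frames are passed to the LLM (frame_idx < 5, num_frames = 5), while every
# frame goes through the grounding (SAM-2) preprocessing.
frames = [Image.open(f"frames/{i:05d}.jpg").convert("RGB") for i in range(8)]

result = model.predict_forward(
    video=frames,
    text="<image>Please segment the person on the left.",
    tokenizer=tokenizer,
)
answer = result["prediction"]       # generated text (may contain segmentation tokens)
masks = result["prediction_masks"]  # per segmentation token: stacked per-frame boolean masks (numpy)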
sam2.py
CHANGED
@@ -623,8 +623,8 @@ class CXBlock(nn.Module):
         x = self.pwconv1(x)
         x = self.act(x)
         x = self.pwconv2(x)
-        if self.
-            x = self.
+        if self.g_weight is not None:
+            x = self.g_weight * x
         x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)

         x = input + self.drop_path(x)