# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Based on:
# https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple

import torch

from .flash_attention_utils import flash_attention_forward


try:
    from transformers.models.qwen2_vl.modeling_qwen2_vl import (
        Qwen2VLAttention,
        apply_multimodal_rotary_pos_emb,
        repeat_kv,
    )
    from transformers.models.qwen2_vl.processing_qwen2_vl import Qwen2VLProcessor
except ImportError:
    pass

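# The two helpers below adapt Qwen2-VL to sequence parallelism (sp):
# `get_rope_index` builds the 3D (temporal, height, width) mRoPE position ids on the
# full sequence *before* it is sharded, and `qwen2_vl_attn_forward` is intended as a
# replacement for `Qwen2VLAttention.forward` that runs flash attention on the local
# `seq_length / sp_size` shard via `flash_attention_forward`.
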
def get_rope_index(
    processor: "Qwen2VLProcessor",
    input_ids: torch.Tensor,
    image_grid_thw: Optional[torch.Tensor] = None,
    video_grid_thw: Optional[torch.Tensor] = None,
    second_per_grid_ts: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Gets the position ids for Qwen2-VL; they must be generated before sharding the sequence.
    The batch dim has been removed, so `input_ids` should be a 1D tensor representing a single example.
    Adapted from:
    https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py#L1546
    """
    spatial_merge_size = processor.image_processor.merge_size
    tokens_per_second = 2
    image_token_id = processor.tokenizer.convert_tokens_to_ids("<|image_pad|>")
    video_token_id = processor.tokenizer.convert_tokens_to_ids("<|video_pad|>")
    vision_start_token_id = processor.tokenizer.convert_tokens_to_ids("<|vision_start|>")
    if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        position_ids = torch.ones(3, input_ids.size(0), dtype=input_ids.dtype, device=input_ids.device)  # (3, seqlen)
        image_index, video_index = 0, 0
        input_ids = input_ids[attention_mask == 1]
        image_nums, video_nums = 0, 0
        vision_start_indices = torch.argwhere(input_ids == vision_start_token_id)
        vision_tokens = input_ids[vision_start_indices + 1]
        image_nums = (vision_tokens == image_token_id).sum()
        video_nums = (vision_tokens == video_token_id).sum()
        input_tokens = input_ids.tolist()
        llm_pos_ids_list: list = []
        st = 0
        remain_images, remain_videos = image_nums, video_nums
        for _ in range(image_nums + video_nums):
            # Locate the next image / video pad token after `st` (or past the end if none remains).
            if image_token_id in input_tokens and remain_images > 0:
                ed_image = input_tokens.index(image_token_id, st)
            else:
                ed_image = len(input_tokens) + 1

            if video_token_id in input_tokens and remain_videos > 0:
                ed_video = input_tokens.index(video_token_id, st)
            else:
                ed_video = len(input_tokens) + 1

            if ed_image < ed_video:
                t, h, w = (
                    image_grid_thw[image_index][0],
                    image_grid_thw[image_index][1],
                    image_grid_thw[image_index][2],
                )
                second_per_grid_t = 0  # images do not advance the temporal axis
                image_index += 1
                remain_images -= 1
                ed = ed_image
            else:
                t, h, w = (
                    video_grid_thw[video_index][0],
                    video_grid_thw[video_index][1],
                    video_grid_thw[video_index][2],
                )
                if second_per_grid_ts is not None:
                    second_per_grid_t = second_per_grid_ts[video_index]
                else:
                    second_per_grid_t = 1.0

                video_index += 1
                remain_videos -= 1
                ed = ed_video
            # The vision grid seen by the LLM is downscaled by `spatial_merge_size` along h and w.
            llm_grid_t, llm_grid_h, llm_grid_w = (
                t.item(),
                h.item() // spatial_merge_size,
                w.item() // spatial_merge_size,
            )
            # Text tokens before the vision block: all three axes share the same running index.
            text_len = ed - st
            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
            llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
            # Vision tokens: enumerate the (t, h, w) grid, offset to start right after the text span.
            t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w)
            t_index = (t_index * second_per_grid_t * tokens_per_second).long().flatten()
            h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
            w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
            llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
            st = ed + llm_grid_t * llm_grid_h * llm_grid_w  # skip past the vision pad tokens

        if st < len(input_tokens):
            # Trailing text after the last vision block.
            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
            text_len = len(input_tokens) - st
            llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
        position_ids[..., attention_mask == 1] = llm_positions.to(position_ids.device)
    else:
        if attention_mask is not None:
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            position_ids = position_ids.unsqueeze(0).expand(3, -1).to(input_ids.device)
        else:
            # input_ids is 1D (batch dim removed), so index the last dim for the sequence length
            position_ids = torch.arange(input_ids.shape[-1], device=input_ids.device).view(1, -1).expand(3, -1)

    return position_ids

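# Example usage (a minimal sketch; `processor` is assumed to be a loaded
# `Qwen2VLProcessor` and `batch` a dict it produced — both hypothetical here):
#
#     position_ids = get_rope_index(
#         processor,
#         input_ids=batch["input_ids"][0],              # 1D: batch dim removed
#         image_grid_thw=batch.get("image_grid_thw"),
#         video_grid_thw=batch.get("video_grid_thw"),
#         attention_mask=batch["attention_mask"][0],
#     )  # -> LongTensor of shape (3, seqlen): temporal / height / width position ids
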
def qwen2_vl_attn_forward(
    self: "Qwen2VLAttention",
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
    **kwargs,
) -> Tuple[torch.Tensor, None, None]:
    bsz, q_len, _ = hidden_states.size()  # q_len = seq_length / sp_size
    query_states = self.q_proj(hidden_states)  # (batch_size, seq_length / sp_size, num_heads * head_size)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)
    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
    # Because the input can be padded, the absolute sequence length depends on the max position id.
    if position_embeddings is None:
        cos, sin = self.rotary_emb(value_states, position_ids)
    else:
        cos, sin = position_embeddings

    query_states, key_states = apply_multimodal_rotary_pos_emb(
        query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
    )
    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)
    dropout_rate = 0.0 if not self.training else self.attention_dropout
    sliding_window = None
    if (
        self.config.use_sliding_window
        and getattr(self.config, "sliding_window", None) is not None
        and self.layer_idx >= self.config.max_window_layers
    ):
        sliding_window = self.config.sliding_window
    attn_output, _ = flash_attention_forward(
        self,
        query_states,
        key_states,
        value_states,
        attention_mask,
        dropout=dropout_rate,
        sliding_window=sliding_window,
        position_ids=position_ids,  # important: pass position ids
    )  # (batch_size, seq_length, num_head / sp_size, head_size)
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
    attn_output = self.o_proj(attn_output)
    return attn_output, None, None
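

# A minimal monkey-patch sketch (an assumption, not part of the original module):
# replace the attention forward so every `Qwen2VLAttention` layer uses the
# sequence-parallel flash-attention path above. Depending on the installed
# transformers version, the model may instantiate a flash/sdpa attention subclass
# instead, which would need the same treatment.
#
#     from transformers.models.qwen2_vl import modeling_qwen2_vl
#
#     modeling_qwen2_vl.Qwen2VLAttention.forward = qwen2_vl_attn_forward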