fix
Browse files- modeling_sa2va_chat.py +2 -2
 
    	
        modeling_sa2va_chat.py
    CHANGED
    
    | 
         @@ -545,7 +545,7 @@ class Sa2VAChatModel(PreTrainedModel): 
     | 
|
| 545 | 
         
             
                    self.gen_config = GenerationConfig(**default_generation_kwargs)
         
     | 
| 546 | 
         
             
                    self.init_prediction_config = True
         
     | 
| 547 | 
         
             
                    self.torch_dtype = torch_dtype
         
     | 
| 548 | 
         
            -
                    self.to(torch_dtype)
         
     | 
| 549 | 
         
             
                    self.extra_image_processor = DirectResize(target_length=1024, )
         
     | 
| 550 | 
         
             
                    # for multi image process
         
     | 
| 551 | 
         
             
                    self.min_dynamic_patch = 1
         
     | 
| 
         @@ -623,7 +623,7 @@ class Sa2VAChatModel(PreTrainedModel): 
     | 
|
| 623 | 
         
             
                            extra_pixel_values = []
         
     | 
| 624 | 
         
             
                            ori_image_size = video[0].size
         
     | 
| 625 | 
         
             
                            for frame_idx, frame_image in enumerate(video):
         
     | 
| 626 | 
         
            -
                                assert ori_image_size == frame_image.size
         
     | 
| 627 | 
         
             
                                g_image = np.array(frame_image)  # for grounding
         
     | 
| 628 | 
         
             
                                g_image = self.extra_image_processor.apply_image(g_image)
         
     | 
| 629 | 
         
             
                                g_image = torch.from_numpy(g_image).permute(2, 0, 1).contiguous()
         
     | 
| 
         | 
|
| 545 | 
         
             
                    self.gen_config = GenerationConfig(**default_generation_kwargs)
         
     | 
| 546 | 
         
             
                    self.init_prediction_config = True
         
     | 
| 547 | 
         
             
                    self.torch_dtype = torch_dtype
         
     | 
| 548 | 
         
            +
                    # self.to(torch_dtype)
         
     | 
| 549 | 
         
             
                    self.extra_image_processor = DirectResize(target_length=1024, )
         
     | 
| 550 | 
         
             
                    # for multi image process
         
     | 
| 551 | 
         
             
                    self.min_dynamic_patch = 1
         
     | 
| 
         | 
|
| 623 | 
         
             
                            extra_pixel_values = []
         
     | 
| 624 | 
         
             
                            ori_image_size = video[0].size
         
     | 
| 625 | 
         
             
                            for frame_idx, frame_image in enumerate(video):
         
     | 
| 626 | 
         
            +
                                # assert ori_image_size == frame_image.size
         
     | 
| 627 | 
         
             
                                g_image = np.array(frame_image)  # for grounding
         
     | 
| 628 | 
         
             
                                g_image = self.extra_image_processor.apply_image(g_image)
         
     | 
| 629 | 
         
             
                                g_image = torch.from_numpy(g_image).permute(2, 0, 1).contiguous()
         
     |