import math
from functools import partial, reduce

import torch
from PIL import Image
from einops import rearrange
from transformers.image_transforms import (
    convert_to_rgb,
    normalize,
    rescale,
    to_channel_dimension_format,
)
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from transformers.image_utils import ImageInput, ChannelDimension, PILImageResampling, to_numpy_array


class LlavaUHDV3ImageProcessor(BaseImageProcessor):
    """Image processor for LLaVA-UHD v3: resizes each image to a patch-aligned
    resolution, rescales and normalizes it, and flattens it into a sequence of
    fixed-size patches plus the per-image patch-grid shape."""

    model_input_names = ["pixel_values", "grid_hws"]

    def __init__(
        self,
        image_mean=(0.5, 0.5, 0.5),
        image_std=(0.5, 0.5, 0.5),
        size=(400, 400),
        crop_size=None,
        resample=PILImageResampling.BICUBIC,
        rescale_factor=1 / 255,
        data_format=ChannelDimension.FIRST,
        scale_resolution=1580,
        patch_size=10,
        any_res=True,
        allow_upscale=True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        crop_size = crop_size if crop_size is not None else {"height": 400, "width": 400}
        crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")
        self.image_mean = image_mean
        self.image_std = image_std
        self.size = size
        self.resample = resample
        self.rescale_factor = rescale_factor
        self.data_format = data_format
        self.crop_size = crop_size
        self.scale_resolution = scale_resolution
        self.patch_size = patch_size
        self.any_res = any_res
        self.allow_upscale = allow_upscale

    def preprocess(self, images: ImageInput, max_resolution=None, upscale_rate=1.4, return_tensors="pt", **kwargs) -> BatchFeature:
        # An explicit max_resolution overrides the configured scale_resolution.
        scale_resolution = max_resolution if max_resolution is not None else self.scale_resolution
        if images is None:
            return None
        pixel_values, grid_hws = [], []
        for image in images if isinstance(images, list) else [images]:
            image = self._preprocess(
                image, scale_resolution, self.patch_size, self.any_res, self.allow_upscale, upscale_rate=upscale_rate
            )
            if not torch.is_tensor(image):
                image = torch.tensor(image)
            # Height and width are patch-aligned by find_best_resize, so the
            # patch grid divides exactly.
            _, H, W = image.shape
            grid_h = int(H // self.patch_size)
            grid_w = int(W // self.patch_size)
            # Split the (C, H, W) image into a flat patch sequence of shape
            # (grid_h * grid_w, C, patch_size, patch_size).
            patches = rearrange(image, "c (h p1) (w p2) -> h w c p1 p2", h=grid_h, w=grid_w)
            patches = rearrange(patches, "h w c p1 p2 -> (h w) c p1 p2")
            pixel_values.append(patches)
            grid_hws.append((grid_h, grid_w))
        # Patch sequences from all images are concatenated along dim 0; grid_hws
        # records each image's (grid_h, grid_w) so the model can split them back.
        pixel_values = torch.concat(pixel_values, dim=0)
        grid_hws = torch.tensor(grid_hws)
        data = {"pixel_values": pixel_values, "grid_hws": grid_hws}
        return BatchFeature(data=data, tensor_type=return_tensors)

    def _preprocess(self, image, scale_resolution=1580, patch_size=10, any_res=True, allow_upscale=True, upscale_rate=1.4):
        original_size = image.size
        # Align the target resolution to a coarser grid of 8 patches so the
        # resized image divides evenly into patches.
        soft_patch_size = patch_size * 8
        best_size = self.find_best_resize(
            original_size,
            scale_resolution,
            soft_patch_size,
            allow_upscale=allow_upscale,
            upscale_rate=upscale_rate,
            any_res=any_res,
        )
        source_image = image.resize(best_size, Image.Resampling.BICUBIC)
        # Pipeline: RGB conversion -> numpy array -> rescale to [0, 1] ->
        # normalize -> channels-first layout.
        transforms = [
            convert_to_rgb,
            to_numpy_array,
            partial(rescale, scale=self.rescale_factor, data_format=self.data_format),
            partial(normalize, mean=self.image_mean, std=self.image_std, data_format=self.data_format),
            partial(to_channel_dimension_format, channel_dim=self.data_format, input_channel_dim=self.data_format),
        ]
        image = reduce(lambda x, f: [*map(f, x)], transforms, [source_image])
        return image[0] if len(image) == 1 else image

    def ensure_divide(self, length, patch_size):
        # Round down to a multiple of patch_size, but never below one patch.
        return max(math.floor(length / patch_size) * patch_size, patch_size)

    def find_best_resize(self, original_size, scale_resolution, patch_size, allow_upscale=False, upscale_rate=1.4, any_res=False):
        width, height = original_size
        max_edge = 5120
        scale_resolution_low = 512
        if any_res:
            if allow_upscale:
                width *= upscale_rate
                height *= upscale_rate
                scale_resolution_low = 560
            r = width / height
            # Shrink if the area exceeds scale_resolution^2, preserving the aspect ratio r.
            if width * height > scale_resolution * scale_resolution:
                height = int(scale_resolution / math.sqrt(r))
                width = int(height * r)
            # Grow if the area falls below scale_resolution_low^2.
            if width * height < scale_resolution_low * scale_resolution_low:
                height = int(scale_resolution_low / math.sqrt(r))
                width = int(height * r)
            # Clamp the longer edge to max_edge.
            if max(width, height) > max_edge:
                scale = max_edge / max(width, height)
                width = int(width * scale)
                height = int(height * scale)
        else:
            if (width * height > scale_resolution * scale_resolution) or allow_upscale:
                # e.g. width=672, height=448 -> r=1.5; with scale_resolution=336,
                # height = int(336 / sqrt(1.5)) = 274 and width = int(274 * 1.5) = 411.
                r = width / height
                height = int(scale_resolution / math.sqrt(r))
                width = int(height * r)
        best_width = self.ensure_divide(width, patch_size)
        best_height = self.ensure_divide(height, patch_size)
        best_width = min(best_width, max_edge)
        best_height = min(best_height, max_edge)
        return (best_width, best_height)


__all__ = ["LlavaUHDV3ImageProcessor"]
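

# --- Usage sketch (an illustrative addition, not part of the original module) ---
# A minimal, self-contained demo on a synthetic 672x448 gray image; it simply
# exercises preprocess() end to end and prints the output shapes. With the
# default config (scale_resolution=1580, patch_size=10, upscale_rate=1.4), the
# image upscales to 940.8x627.2 and patch-aligns to 880x560, i.e. an 88x56 grid.
if __name__ == "__main__":
    processor = LlavaUHDV3ImageProcessor()
    demo_image = Image.new("RGB", (672, 448), color=(127, 127, 127))
    batch = processor.preprocess(demo_image, return_tensors="pt")
    # pixel_values: (num_patches, 3, patch_size, patch_size);
    # grid_hws: one (grid_h, grid_w) row per image.
    print(batch["pixel_values"].shape)  # expected: torch.Size([4928, 3, 10, 10])
    print(batch["grid_hws"])            # expected: tensor([[56, 88]])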