|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| """
|
| Processor class for Centurio.
|
| """
|
| import timm
|
| import torch
|
| import transformers
|
| from tokenizers import AddedToken
|
| from torchvision.transforms import InterpolationMode, Compose, Resize, ToTensor, Normalize
|
| from transformers import BaseImageProcessor, AutoTokenizer, AutoProcessor, AutoImageProcessor
|
| from typing import List, Union, Optional
|
|
|
| from transformers.feature_extraction_utils import BatchFeature
|
| from transformers.image_utils import ImageInput, get_image_size, to_numpy_array
|
| from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order
|
| from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
|
| from transformers.utils import logging
|
|
|
| logger = logging.get_logger(__name__)
|
|
|
class CenturioTimmImageProcessor(BaseImageProcessor):
    r"""
    Image processor that prepares images for the timm vision encoder used by Centurio.

    The target resolution, interpolation mode and normalisation statistics are read from the pretrained
    configuration that `timm.get_pretrained_cfg` returns for `timm_model`. Images are resized to a square whose
    side is `config.input_size[1]`, converted to tensors and normalised.

    Args:
        timm_model (`str`, *optional*, defaults to `"vit_so400m_patch14_siglip_384"`):
            Model name passed to `timm.get_pretrained_cfg`.
        tiling (`int`, *optional*, defaults to 1):
            Grid size `n`. When greater than 1, every image additionally yields `n * n` tiles cut from a copy of
            the image resized to `n` times the base resolution.
        **kwargs:
            Forwarded to `BaseImageProcessor.__init__`.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        timm_model="vit_so400m_patch14_siglip_384",
        tiling=1,
        **kwargs,
    ) -> None:
        # Initialise the base class first so that the values derived from the timm config below always take
        # precedence over same-named entries that may arrive through `kwargs` (e.g. from a serialized config).
        super().__init__(**kwargs)
        config = timm.get_pretrained_cfg(timm_model)
        # Index 1 of `config.input_size` is used as the side length of the square resize target.
        side = config.input_size[1]
        self.timm_model = timm_model
        self.interpolation = config.interpolation
        self.mean = config.mean
        self.std = config.std
        self.tiling = tiling
        self.input_size = (side, side)

    def __call__(
        self,
        images: ImageInput,
        **kwargs
    ):
        """Alias for [`~CenturioTimmImageProcessor.preprocess`]."""
        return self.preprocess(images, **kwargs)

    def _build_transform(self, size):
        """Return the resize -> to-tensor -> normalise pipeline for the given `(height, width)` target."""
        return Compose([
            Resize(size, interpolation=InterpolationMode(self.interpolation)),
            ToTensor(),
            Normalize(mean=self.mean, std=self.std),
        ])

    def preprocess(
        self,
        images: ImageInput,
        **kwargs
    ):
        """
        Resize, tensorise and normalise one or more images.

        Args:
            images (`ImageInput`):
                A single image or a list/tuple of images accepted by the torchvision transforms used here.
            **kwargs:
                Accepted for interface compatibility; not used.

        Returns:
            [`BatchFeature`]: A batch with a single key, `"pixel_values"`, holding all processed images stacked
            along a new first dimension. With `tiling > 1` each input contributes `1 + tiling * tiling` entries:
            the base-resolution image followed by its tiles in row-major order.
        """
        # A bare image is wrapped; tuples are treated like lists instead of being mistaken for a single image.
        if not isinstance(images, (list, tuple)):
            images = [images]

        tiling = self.tiling
        base_transform = self._build_transform(self.input_size)
        large_transform = None
        if tiling > 1:
            # Kept as an instance attribute because earlier versions exposed it after the first call.
            self.input_size_large = (self.input_size[0] * tiling, self.input_size[0] * tiling)
            large_transform = self._build_transform(self.input_size_large)

        tile_h, tile_w = self.input_size
        processed_images = []
        for image_pil in images:
            processed_images.append(base_transform(image_pil))
            if large_transform is not None:
                image_large = large_transform(image_pil)
                # Cut the enlarged image into a tiling x tiling grid of base-resolution crops.
                processed_images.extend(
                    image_large[:, row * tile_h:(row + 1) * tile_h, col * tile_w:(col + 1) * tile_w]
                    for row in range(tiling)
                    for col in range(tiling)
                )

        return BatchFeature(
            data={"pixel_values": torch.stack(processed_images, dim=0)}
        )
|
|
|
# Register the image processor with the Auto* machinery under its own class name.
# NOTE(review): `AutoImageProcessor.register` normally takes a config class as its first argument; a plain
# string key is passed here - confirm this is intentional.
AutoImageProcessor.register("CenturioTimmImageProcessor", CenturioTimmImageProcessor)

# Expose the class as an attribute of the `transformers` package. Presumably this lets the class-name string
# in `CenturioProcessor.image_processor_class` be resolved by name - TODO confirm.
transformers.CenturioTimmImageProcessor = CenturioTimmImageProcessor
|
|
|
class CenturioProcessor(ProcessorMixin):
    r"""
    Bundles an image processor and a tokenizer behind a single callable for Centurio.

    Args:
        image_processor (*optional*):
            Object called on `images`; expected to return a mapping such as the one produced by
            `CenturioTimmImageProcessor`.
        tokenizer (*optional*):
            Tokenizer called on the text prompts.
        tiling (`int`, *optional*, defaults to 1):
            Accepted for signature compatibility; not used by this class.
        **kwargs:
            Only `chat_template` is read; any other entries are ignored.
    """

    attributes = ["image_processor", "tokenizer"]
    optional_attributes = ["chat_template"]
    image_processor_class = "CenturioTimmImageProcessor"
    tokenizer_class = "AutoTokenizer"
    image_token = "<image_placeholder>"

    def __init__(
        self,
        image_processor=None,
        tokenizer=None,
        tiling=1,
        **kwargs,
    ):
        # NOTE(review): `ProcessorMixin.__init__` is deliberately not invoked here, matching the original
        # behaviour; the attributes it would set are assigned by hand instead.
        self.image_processor = image_processor
        self.tokenizer = tokenizer
        # `chat_template` is declared in `optional_attributes`, so make sure the attribute always exists
        # (None when not supplied) rather than being silently dropped.
        self.chat_template = kwargs.get("chat_template")

    def __call__(
        self,
        images: ImageInput = None,
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
        **kwargs,
    ) -> BatchFeature:
        """
        Prepare images and/or text for the model.

        Args:
            images (`ImageInput`, *optional*):
                Image or images handed to the image processor.
            text (`str`, `List[str]`, `List[List[str]]`, *optional*):
                Prompt or prompts handed to the tokenizer. A single string is wrapped in a list.
            **kwargs:
                Forwarded to the tokenizer only.

        Returns:
            [`BatchFeature`]: The tokenizer outputs merged with the image processor outputs. When `text` is
            `None` only the image features are returned.

        Raises:
            ValueError: If neither `images` nor `text` is given, or if `text` is not a string or a list of
            strings.
        """
        if images is None and text is None:
            raise ValueError("You have to specify at least one of `images` or `text`.")

        # Normalises the argument order in case the two inputs were passed swapped.
        images, text = _validate_images_text_input_order(images, text)

        image_inputs = self.image_processor(images) if images is not None else {}

        # Images-only call: there is nothing to tokenize, so return the image features alone instead of
        # failing while validating a missing prompt.
        if text is None:
            return BatchFeature(data={**image_inputs})

        if isinstance(text, str):
            text = [text]
        elif not isinstance(text, (list, tuple)) or any(
            not isinstance(item, (str, list, tuple)) for item in text
        ):
            # Nested lists are let through so pre-tokenized input keeps working; anything else is rejected
            # here rather than deep inside the tokenizer.
            raise ValueError("Invalid input text. Please provide a string, or a list of strings")

        text_inputs = self.tokenizer(text, **kwargs)
        return BatchFeature(data={**text_inputs, **image_inputs})

    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to the wrapped tokenizer's
        [`~PreTrainedTokenizer.batch_decode`]. Please refer to the docstring of that method for more
        information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to the wrapped tokenizer's [`~PreTrainedTokenizer.decode`].
        Please refer to the docstring of that method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| pass |