Create custom_st.py
custom_st.py (ADDED, +221 -0)
from io import BytesIO
from typing import Any, Dict, Optional, List

import torch
from PIL import Image
from sentence_transformers.models import Transformer as BaseTransformer
from transformers import AutoModelForVision2Seq, AutoProcessor


class MultiModalTransformer(BaseTransformer):
    def __init__(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        tokenizer_args: Optional[Dict[str, Any]] = None,
        min_image_tokens: int = 256,
        max_image_tokens: int = 1280,
        max_length: int = 1800,
        **kwargs,
    ):
        super().__init__(model_name_or_path, **kwargs)
        if tokenizer_args is None:
            tokenizer_args = {}
        tokenizer_args.pop("trust_remote_code", None)

        # Initialize processor
        min_pixels = min_image_tokens * 28 * 28
        max_pixels = max_image_tokens * 28 * 28
        self.processor = AutoProcessor.from_pretrained(
            model_name_or_path, min_pixels=min_pixels, max_pixels=max_pixels, **kwargs
        )
        self.processor.tokenizer.padding_side = 'right'
        self.sep = ' '
        self.max_length = max_length
        self.normalize = True

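    # Note on the defaults above: each image token corresponds to a 28x28-pixel
    # patch, so min_image_tokens=256 gives min_pixels = 256 * 784 = 200,704 and
    # max_image_tokens=1280 gives max_pixels = 1280 * 784 = 1,003,520.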
    def _load_model(
            self,
            model_name_or_path: str,
            config,
            cache_dir: str,
            backend: str,
            is_peft_model: bool,
            **model_args,
    ) -> None:
        model_args.pop("trust_remote_code", None)
        self.auto_model = AutoModelForVision2Seq.from_pretrained(
            model_name_or_path, torch_dtype=torch.float16, **model_args
        )

    def forward(
        self, features: Dict[str, torch.Tensor], **kwargs
    ) -> Dict[str, torch.Tensor]:
        if features.get("inputs_embeds", None) is None:
            features["inputs_embeds"] = self.auto_model.base_model.embed_tokens(features["input_ids"])
            if features.get("pixel_values", None) is not None:
                features["pixel_values"] = features["pixel_values"].type(self.auto_model.visual.get_dtype())
                image_embeds = self.auto_model.visual(
                    features["pixel_values"], grid_thw=features["image_grid_thw"]
                )
                # Scatter the visual embeddings into the token embedding sequence
                # at the positions of the image placeholder tokens.
                image_mask = features["input_ids"] == self.auto_model.config.image_token_id
                features["inputs_embeds"][image_mask] = image_embeds
                # features.pop("pixel_values")
                # features.pop("image_grid_thw")
        # features.pop("input_ids")
        inputs = {k: v for k, v in features.items() if k in ('position_ids', 'attention_mask', 'inputs_embeds')}
        outputs = self.auto_model.model(
            **inputs,
            return_dict=True,
            output_hidden_states=True,
            # **kwargs
        )
        # pooling_mask = features["attention_mask"] if features.get("pooling_mask", None) is None else features["pooling_mask"]
        # left_padding = (pooling_mask[:, -1].sum() == pooling_mask.shape[0])  # TODO
        # if left_padding:
        #     embeddings = outputs.last_hidden_state
        # else:
        #     sequence_lengths = pooling_mask.sum(dim=1) - 1
        #     embeddings = outputs.last_hidden_state[torch.arange(
        #         outputs.last_hidden_state.shape[0], device=outputs.last_hidden_state.device
        #     ), sequence_lengths]
        features.update({"token_embeddings": outputs.last_hidden_state})
        return features

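    # Note on forward(): it returns per-token hidden states under "token_embeddings";
    # sequence-level pooling (e.g. the last-token pooling sketched in the commented-out
    # block above) is assumed to be applied by a separate downstream Pooling module.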
    def tokenize(self, texts: List[Dict[str, Any]] | List[str]) -> Dict[str, torch.Tensor]:
        default_instruction = 'You are a helpful assistant.'

        all_texts, all_images = list(), list()
        for item in texts:
            if isinstance(item, str):
                txt, img, inst = item, None, default_instruction
            elif isinstance(item, dict):
                txt = item.get('text', None)
                img = item.get('image', None)
                inst = item.get('prompt', default_instruction)
            else:
                raise RuntimeError(f'Input format not supported! {item=}')

            input_str = ''
            if img is None:
                # All examples in a batch must be consistently text-only or multimodal,
                # otherwise the processor raises: ValueError: Could not make a flat list of images from xxxx
                all_images = None
            else:
                input_str += '<|vision_start|><|image_pad|><|vision_end|>'
                img = fetch_image(img)
                all_images.append(img)
            if txt is not None:
                input_str += txt
            msg = f'<|im_start|>system\n{inst}<|im_end|>\n<|im_start|>user\n{input_str}<|im_end|>\n<|im_start|>assistant\n<|endoftext|>'
            all_texts.append(msg)

        inputs = self.processor(
            text=all_texts,
            images=all_images,
            padding="longest",
            truncation=True,
            max_length=self.max_seq_length,  # set by the BaseTransformer init; the max_length argument above is stored but not used here
            return_tensors='pt'
        )
        return inputs


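# Example inputs accepted by MultiModalTransformer.tokenize (illustrative values
# only; the URL is hypothetical):
#   ["a photo of a cat"]                             # plain-text batch
#   [{"text": "a photo of a cat",
#     "image": "https://example.com/cat.jpg",
#     "prompt": "You are a helpful assistant."}]     # multimodal batch
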
### Copied from qwen_vl_utils.vision_process.py
import base64
import logging  # used by smart_resize below; missing from the original copy
import math  # used by the *_by_factor helpers below; missing from the original copy
from io import BytesIO
import requests

IMAGE_FACTOR = 28
MIN_PIXELS = 4 * 28 * 28
MAX_PIXELS = 16384 * 28 * 28
MAX_RATIO = 200


def round_by_factor(number: int, factor: int) -> int:
    """Returns the closest integer to 'number' that is divisible by 'factor'."""
    return round(number / factor) * factor


def ceil_by_factor(number: int, factor: int) -> int:
    """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
    return math.ceil(number / factor) * factor


def floor_by_factor(number: int, factor: int) -> int:
    """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
    return math.floor(number / factor) * factor


def smart_resize(
    height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
) -> tuple[int, int]:
    """
    Rescales the image so that the following conditions are met:

    1. Both dimensions (height and width) are divisible by 'factor'.

    2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].

    3. The aspect ratio of the image is maintained as closely as possible.
    """
    h_bar = max(factor, round_by_factor(height, factor))
    w_bar = max(factor, round_by_factor(width, factor))
    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = floor_by_factor(height / beta, factor)
        w_bar = floor_by_factor(width / beta, factor)
    elif h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = ceil_by_factor(height * beta, factor)
        w_bar = ceil_by_factor(width * beta, factor)

    if max(h_bar, w_bar) / min(h_bar, w_bar) > MAX_RATIO:
        logging.warning(
            f"Absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(h_bar, w_bar) / min(h_bar, w_bar)}"
        )
        if h_bar > w_bar:
            h_bar = w_bar * MAX_RATIO
        else:
            w_bar = h_bar * MAX_RATIO
    return h_bar, w_bar


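# Worked example (illustrative): smart_resize(750, 1000) rounds each side to the
# nearest multiple of 28, giving (756, 1008); 756 * 1008 = 762,048 pixels lies
# within [MIN_PIXELS, MAX_PIXELS] and the aspect ratio is ~1.33, so the rounded
# size is returned unchanged.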
def fetch_image(image: str | Image.Image, size_factor: int = IMAGE_FACTOR) -> Image.Image:
    image_obj = None
    if isinstance(image, Image.Image):
        image_obj = image
    elif image.startswith("http://") or image.startswith("https://"):
        image_obj = Image.open(requests.get(image, stream=True).raw)
    elif image.startswith("file://"):
        image_obj = Image.open(image[7:])
    elif image.startswith("data:image"):
        if "base64," in image:
            _, base64_data = image.split("base64,", 1)
            data = base64.b64decode(base64_data)
            image_obj = Image.open(BytesIO(data))
    else:
        image_obj = Image.open(image)
    if image_obj is None:
        raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}")
    image = image_obj.convert("RGB")
    ## resize
    # if "resized_height" in ele and "resized_width" in ele:
    #     resized_height, resized_width = smart_resize(
    #         ele["resized_height"],
    #         ele["resized_width"],
    #         factor=size_factor,
    #     )
    # else:
    width, height = image.size
    # min_pixels = ele.get("min_pixels", MIN_PIXELS)
    # max_pixels = ele.get("max_pixels", MAX_PIXELS)
    resized_height, resized_width = smart_resize(
        height,
        width,
        factor=size_factor,
        min_pixels=MIN_PIXELS,
        max_pixels=MAX_PIXELS,
    )
    image = image.resize((resized_width, resized_height))

    return image
###
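
Usage sketch (not part of the committed file): one way to wire this module into a SentenceTransformer stack. The checkpoint path, the last-token pooling choice, and the image URL below are illustrative assumptions, not values from this commit; in a published repository the module stack would normally be pinned by a modules.json.

# usage_example.py -- minimal sketch, assuming a Qwen2-VL-style checkpoint
# that this module can load; path, pooling mode, and URL are hypothetical.
from sentence_transformers import SentenceTransformer
from sentence_transformers.models import Normalize, Pooling

from custom_st import MultiModalTransformer

transformer = MultiModalTransformer("path/to/qwen2-vl-checkpoint")  # hypothetical path
pooling = Pooling(
    transformer.get_word_embedding_dimension(),
    pooling_mode="lasttoken",  # assumption, mirroring the commented-out pooling in forward()
)
model = SentenceTransformer(modules=[transformer, pooling, Normalize()])

# Text-only and text+image batches, in the formats tokenize() accepts.
text_emb = model.encode(["a photo of a cat"])
pair_emb = model.encode(
    [{"text": "a photo of a cat", "image": "https://example.com/cat.jpg"}]  # hypothetical URL
)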

