Added ONNX code, and fixed examples for saving to avoid confusion. (#20)

Commit d69aea178cf98e849c0ebf38a6c224997f15b120

README.md CHANGED
@@ -59,6 +59,7 @@ from docling_core.types.doc import DoclingDocument
 from docling_core.types.doc.document import DocTagsDocument
 from transformers import AutoProcessor, AutoModelForVision2Seq
 from transformers.image_utils import load_image
+from pathlib import Path
 
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
@@ -107,7 +108,8 @@ doc.load_from_doctags(doctags_doc)
 
 # export as any format
 # HTML
-# 
+# output_path_html = Path("Out/") / "example.html"
+# doc.save_as_html(output_path_html)
 # MD
 print(doc.export_to_markdown())
 ```
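Note on the hunk above: the commented-out HTML example and the Markdown export now share one Path-based pattern. For anyone adapting it, here is a minimal sketch of the corrected saving flow; the `Out/` directory name is illustrative, and `doc` is assumed to be an already-populated `DoclingDocument`:

```python
from pathlib import Path

# Illustrative output directory; create it so the save calls don't fail.
out_dir = Path("Out")
out_dir.mkdir(parents=True, exist_ok=True)

# Both save methods accept a path-like target, as used in this commit.
doc.save_as_html(out_dir / "example.html")
doc.save_as_markdown(out_dir / "example.md")
```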
@@ -129,6 +131,7 @@ from vllm import LLM, SamplingParams
 from PIL import Image
 from docling_core.types.doc import DoclingDocument
 from docling_core.types.doc.document import DocTagsDocument
+from pathlib import Path
 
 # Configuration
 MODEL_PATH = "ds4sd/SmolDocling-256M-preview"
@@ -175,15 +178,144 @@ for idx, img_file in enumerate(image_files, 1):
     doc.load_from_doctags(doctags_doc)
     # export as any format
     # HTML
-    # 
+    # output_path_html = Path(OUTPUT_DIR) / f"{img_fn}.html"
+    # doc.save_as_html(output_path_html)
     # MD
-    
-    output_path_md = os.path.join(OUTPUT_DIR, output_filename_md)
+    output_path_md = Path(OUTPUT_DIR) / f"{img_fn}.md"
     doc.save_as_markdown(output_path_md)
-
 print(f"Total time: {time.time() - start_time:.2f} sec")
 ```
 </details>
+<details>
+<summary> ONNX Inference</summary>
+
+```python
+# Prerequisites:
+# pip install onnxruntime
+# pip install onnxruntime-gpu
+from transformers import AutoConfig, AutoProcessor
+from transformers.image_utils import load_image
+import onnxruntime
+import numpy as np
+import os
+from docling_core.types.doc import DoclingDocument
+from docling_core.types.doc.document import DocTagsDocument
+
+os.environ["OMP_NUM_THREADS"] = "1"
+# cuda
+os.environ["ORT_CUDA_USE_MAX_WORKSPACE"] = "1"
+
+# 1. Load models
+## Load config and processor
+model_id = "ds4sd/SmolDocling-256M-preview"
+config = AutoConfig.from_pretrained(model_id)
+processor = AutoProcessor.from_pretrained(model_id)
+
+## Load sessions
+# !wget https://huggingface.co/ds4sd/SmolDocling-256M-preview/resolve/main/onnx/vision_encoder.onnx
+# !wget https://huggingface.co/ds4sd/SmolDocling-256M-preview/resolve/main/onnx/embed_tokens.onnx
+# !wget https://huggingface.co/ds4sd/SmolDocling-256M-preview/resolve/main/onnx/decoder_model_merged.onnx
+# cpu
+# vision_session = onnxruntime.InferenceSession("vision_encoder.onnx")
+# embed_session = onnxruntime.InferenceSession("embed_tokens.onnx")
+# decoder_session = onnxruntime.InferenceSession("decoder_model_merged.onnx")
+
+# cuda
+vision_session = onnxruntime.InferenceSession("vision_encoder.onnx", providers=["CUDAExecutionProvider"])
+embed_session = onnxruntime.InferenceSession("embed_tokens.onnx", providers=["CUDAExecutionProvider"])
+decoder_session = onnxruntime.InferenceSession("decoder_model_merged.onnx", providers=["CUDAExecutionProvider"])
+
+## Set config values
+num_key_value_heads = config.text_config.num_key_value_heads
+head_dim = config.text_config.head_dim
+num_hidden_layers = config.text_config.num_hidden_layers
+eos_token_id = config.text_config.eos_token_id
+image_token_id = config.image_token_id
+end_of_utterance_id = processor.tokenizer.convert_tokens_to_ids("<end_of_utterance>")
+
+# 2. Prepare inputs
+## Create input messages
+messages = [
+    {
+        "role": "user",
+        "content": [
+            {"type": "image"},
+            {"type": "text", "text": "Convert this page to docling."}
+        ]
+    },
+]
+
+## Load image and apply processor
+image = load_image("https://ibm.biz/docling-page-with-table")
+prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
+inputs = processor(text=prompt, images=[image], return_tensors="np")
+
+## Prepare decoder inputs
+batch_size = inputs['input_ids'].shape[0]
+past_key_values = {
+    f'past_key_values.{layer}.{kv}': np.zeros([batch_size, num_key_value_heads, 0, head_dim], dtype=np.float32)
+    for layer in range(num_hidden_layers)
+    for kv in ('key', 'value')
+}
+image_features = None
+input_ids = inputs['input_ids']
+attention_mask = inputs['attention_mask']
+position_ids = np.cumsum(inputs['attention_mask'], axis=-1)
+
+
+# 3. Generation loop
+max_new_tokens = 8192
+generated_tokens = np.array([[]], dtype=np.int64)
+for i in range(max_new_tokens):
+  inputs_embeds = embed_session.run(None, {'input_ids': input_ids})[0]
+
+  if image_features is None:
+    ## Only compute vision features if not already computed
+    image_features = vision_session.run(
+        ['image_features'],  # List of output names or indices
+        {
+            'pixel_values': inputs['pixel_values'],
+            'pixel_attention_mask': inputs['pixel_attention_mask'].astype(np.bool_)
+        }
+    )[0]
+
+    ## Merge text and vision embeddings
+    inputs_embeds[inputs['input_ids'] == image_token_id] = image_features.reshape(-1, image_features.shape[-1])
+
+  logits, *present_key_values = decoder_session.run(None, dict(
+      inputs_embeds=inputs_embeds,
+      attention_mask=attention_mask,
+      position_ids=position_ids,
+      **past_key_values,
+  ))
+
+  ## Update values for next generation loop
+  input_ids = logits[:, -1].argmax(-1, keepdims=True)
+  attention_mask = np.ones_like(input_ids)
+  position_ids = position_ids[:, -1:] + 1
+  for j, key in enumerate(past_key_values):
+    past_key_values[key] = present_key_values[j]
+
+  generated_tokens = np.concatenate([generated_tokens, input_ids], axis=-1)
+  if (input_ids == eos_token_id).all() or (input_ids == end_of_utterance_id).all():
+    break  # Stop predicting
+
+doctags = processor.batch_decode(
+    generated_tokens,
+    skip_special_tokens=False,
+)[0].lstrip()
+
+print(doctags)
+
+doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([doctags], [image])
+# create a docling document
+doc = DoclingDocument(name="Document")
+doc.load_from_doctags(doctags_doc)
+
+print(doc.export_to_markdown())
+```
+</details>
+
 
 💻 Local inference on Apple Silicon with MLX: [see here](https://huggingface.co/ds4sd/SmolDocling-256M-preview-mlx-bf16)
 
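A few notes on the new ONNX example. First, it hardcodes `CUDAExecutionProvider` for all three sessions, with the CPU variants left commented out. Below is a minimal sketch of selecting the provider at runtime instead, using `onnxruntime.get_available_providers()` (a standard onnxruntime API); behavior when a requested provider is unavailable varies by onnxruntime version, so checking first is the safer pattern:

```python
import onnxruntime

# Prefer CUDA when this build exposes it; otherwise run on CPU.
available = onnxruntime.get_available_providers()
providers = (
    ["CUDAExecutionProvider", "CPUExecutionProvider"]
    if "CUDAExecutionProvider" in available
    else ["CPUExecutionProvider"]
)

vision_session = onnxruntime.InferenceSession("vision_encoder.onnx", providers=providers)
embed_session = onnxruntime.InferenceSession("embed_tokens.onnx", providers=providers)
decoder_session = onnxruntime.InferenceSession("decoder_model_merged.onnx", providers=providers)
```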
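Second, the example assumes the decoder's inputs follow the `past_key_values.{layer}.{key|value}` naming and that its first output is `logits`. If those names ever need verifying against the downloaded ONNX files, the sessions can be introspected with the standard `get_inputs()`/`get_outputs()` methods:

```python
# List the I/O names and shapes the merged decoder actually expects.
for node in decoder_session.get_inputs():
    print("input: ", node.name, node.shape)
for node in decoder_session.get_outputs():
    print("output:", node.name, node.shape)
```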
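Finally, one subtlety in the generation loop: `for j, key in enumerate(past_key_values): past_key_values[key] = present_key_values[j]` pairs cache entries with decoder outputs purely by position, which works only because Python dicts preserve insertion order and the merged decoder returns its present-KV outputs in the matching layer/key-value order. A sketch of a name-based update that avoids relying on positional order; it assumes the outputs are named `present.{layer}.{key|value}`, the usual convention for Optimum-style exports, which should be verified with `get_outputs()` first:

```python
# Names of the decoder outputs, e.g. ['logits', 'present.0.key', 'present.0.value', ...]
output_names = [out.name for out in decoder_session.get_outputs()]

outputs = decoder_session.run(None, dict(
    inputs_embeds=inputs_embeds,
    attention_mask=attention_mask,
    position_ids=position_ids,
    **past_key_values,
))
logits = outputs[0]

# 'present.N.key' feeds the next step's 'past_key_values.N.key'.
for name, value in zip(output_names[1:], outputs[1:]):
    past_key_values[name.replace("present", "past_key_values")] = value
```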

