# File size: 4,002 Bytes
# a381a76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import torch
from transformers import PreTrainedModel, PretrainedConfig
from PIL import Image
import numpy as np
from typing import Optional, Union, List
import json

class Pix2TextConfig(PretrainedConfig):
    """Configuration object for the ``pix2text`` model type.

    Each named argument is stored unchanged as an attribute of the same name;
    any additional keyword arguments are forwarded to ``PretrainedConfig``.

    Within this file, ``Pix2TextModel`` reads ``hidden_size`` (encoder output
    width), ``intermediate_size`` (decoder hidden width), ``dropout_prob``
    (decoder dropout) and ``vocab_size`` (number of output logits). The
    remaining fields (``num_attention_heads``, ``num_hidden_layers``,
    ``max_position_embeddings``) are only stored here.
    """

    model_type = "pix2text"

    def __init__(
        self,
        vocab_size=30000,
        hidden_size=768,
        num_attention_heads=12,
        num_hidden_layers=12,
        intermediate_size=3072,
        max_position_embeddings=512,
        dropout_prob=0.1,
        **kwargs
    ):
        # Let the base class consume the generic options first, then attach
        # the model-specific fields in their declared order.
        super().__init__(**kwargs)
        own_fields = (
            ("vocab_size", vocab_size),
            ("hidden_size", hidden_size),
            ("num_attention_heads", num_attention_heads),
            ("num_hidden_layers", num_hidden_layers),
            ("intermediate_size", intermediate_size),
            ("max_position_embeddings", max_position_embeddings),
            ("dropout_prob", dropout_prob),
        )
        for field_name, field_value in own_fields:
            setattr(self, field_name, field_value)

class Pix2TextModel(PreTrainedModel):
    """Small image-to-text model: a CNN encoder followed by an MLP head.

    The encoder flattens each image into a single ``hidden_size`` vector and
    the decoder maps that vector to ``vocab_size`` logits, so ``forward``
    yields one token distribution per image.
    """

    config_class = Pix2TextConfig

    # Characters known to the built-in tokenizer. A character's position in
    # this string is its token id: digits -> 0-9, a-z -> 10-35, A-Z -> 36-61,
    # then ' ', '.', ',', '!', '?' -> 62-66.
    _CHARSET = (
        "0123456789"
        "abcdefghijklmnopqrstuvwxyz"
        "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
        " .,!?"
    )

    def __init__(self, config):
        """Build the encoder, the decoder and the character vocabulary.

        Args:
            config: A ``Pix2TextConfig``; ``hidden_size``,
                ``intermediate_size``, ``dropout_prob`` and ``vocab_size``
                are read from it.
        """
        super().__init__(config)
        self.config = config

        # A simple CNN encoder. The adaptive pooling fixes the feature map at
        # 8x8 regardless of input resolution, so the Linear layer's input
        # size (256 * 8 * 8) holds for any image size.
        self.image_encoder = torch.nn.Sequential(
            torch.nn.Conv2d(3, 64, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2),
            torch.nn.Conv2d(64, 128, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2),
            torch.nn.Conv2d(128, 256, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.AdaptiveAvgPool2d((8, 8)),
            torch.nn.Flatten(),
            torch.nn.Linear(256 * 8 * 8, config.hidden_size)
        )

        # Text decoder: maps the image feature vector to vocabulary logits.
        self.text_decoder = torch.nn.Sequential(
            torch.nn.Linear(config.hidden_size, config.intermediate_size),
            torch.nn.ReLU(),
            torch.nn.Dropout(config.dropout_prob),
            torch.nn.Linear(config.intermediate_size, config.vocab_size)
        )

        # Vocabulary for the simple character tokenizer. Ids are the positions
        # in _CHARSET, which keeps them contiguous and collision-free (the
        # earlier ``ord(c) + offset`` scheme made a-z and A-Z overlap, so the
        # inverse mapping dropped characters).
        self.vocab = {char: token_id for token_id, char in enumerate(self._CHARSET)}
        self.inv_vocab = {token_id: char for char, token_id in self.vocab.items()}

    def forward(self, pixel_values, labels=None):
        """Encode the images and compute vocabulary logits.

        Args:
            pixel_values: Image batch fed to the CNN encoder; the first
                convolution requires 3 input channels.
            labels: Optional target token ids. When given, a cross-entropy
                loss against the logits is computed.

        Returns:
            A dict with ``"loss"`` (``None`` when ``labels`` is ``None``) and
            ``"logits"``.
        """
        # Encode the image.
        image_features = self.image_encoder(pixel_values)

        # Decode to vocabulary logits.
        logits = self.text_decoder(image_features)

        loss = None
        if labels is not None:
            loss_fct = torch.nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))

        return {"loss": loss, "logits": logits}

    def predict(
        self,
        image: Union[Image.Image, np.ndarray],
        max_tokens: int = 10,
    ) -> str:
        """Extract text from a single image.

        Args:
            image: A PIL image (any mode) or an array laid out as ``(H, W)``
                or ``(H, W, C)`` with ``C`` in ``{1, 3, 4}``. Values are
                divided by 255 before being fed to the model.
            max_tokens: Maximum number of predicted tokens to decode.

        Returns:
            The decoded string. Predicted ids that are not in the character
            vocabulary are skipped.

        Raises:
            ValueError: If the array does not have a supported shape.
        """
        if isinstance(image, Image.Image):
            # Normalise grayscale/palette/RGBA images to the three channels
            # the encoder's first convolution expects.
            image = np.array(image.convert("RGB"))

        # Pre-process the image.
        pixels = torch.tensor(image).float() / 255.0
        if pixels.dim() == 2:
            pixels = pixels.unsqueeze(-1)  # (H, W) -> (H, W, 1)
        if pixels.dim() != 3:
            raise ValueError(
                f"expected an (H, W) or (H, W, C) image, got shape {tuple(pixels.shape)}"
            )
        pixels = pixels.permute(2, 0, 1)  # (H, W, C) -> (C, H, W)

        channels = pixels.shape[0]
        if channels == 1:
            pixels = pixels.repeat(3, 1, 1)  # replicate gray into 3 channels
        elif channels == 4:
            pixels = pixels[:3]  # drop the alpha channel
        elif channels != 3:
            raise ValueError(f"expected 1, 3 or 4 channels, got {channels}")

        # Add the batch dimension and match the model's device and dtype.
        reference = next(self.parameters())
        batch = pixels.unsqueeze(0).to(device=reference.device, dtype=reference.dtype)

        # Inference must not apply dropout; restore the previous mode afterwards.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                logits = self(batch)["logits"]
        finally:
            self.train(was_training)

        # Pick the highest-probability token(s).
        predicted_ids = torch.argmax(logits, dim=-1)

        # argmax over (batch, vocab) logits leaves a 0-dim tensor per image;
        # flatten so the slice works for both scalar and sequence outputs.
        token_ids = predicted_ids[0].reshape(-1)[:max_tokens].tolist()

        # Convert the tokens to text.
        return "".join(self.inv_vocab.get(token_id, "") for token_id in token_ids)