import gradio as gr
from typing import List, Dict, Tuple

def get_stats(ids):
    """Count how often each adjacent pair of token ids occurs."""
    counts = {}
    for pair in zip(ids, ids[1:]):
        counts[pair] = counts.get(pair, 0) + 1
    return counts
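
# Example (illustrative): get_stats([1, 2, 3, 1, 2]) returns
# {(1, 2): 2, (2, 3): 1, (3, 1): 1}.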

def merge(ids, pair, idx):
    """Replace every occurrence of `pair` in `ids` with the new token `idx`."""
    newids = []
    i = 0
    while i < len(ids):
        if i < len(ids) - 1 and ids[i] == pair[0] and ids[i+1] == pair[1]:
            newids.append(idx)
            i += 2
        else:
            newids.append(ids[i])
            i += 1
    return newids
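
# Example (illustrative): merge([1, 2, 3, 1, 2], (1, 2), 256) returns
# [256, 3, 256].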

# Read the Telugu text file and train BPE
def train_bpe(vocab_size: int = 350):
    """Learn BPE merges from the preprocessed Telugu corpus.

    Ids 0-255 are reserved for raw bytes, so vocab_size=350 yields
    94 learned merges.
    """
    # Read the preprocessed Telugu text
    with open('telugu_preprocessed_file.txt', 'r', encoding='utf-8') as f:
        text = f.read()

    # Start from the raw UTF-8 byte stream
    ids = list(text.encode('utf-8'))

    # Repeatedly merge the most frequent adjacent pair into a new token
    num_merges = vocab_size - 256
    merges = {}
    for i in range(num_merges):
        stats = get_stats(ids)
        if not stats:  # nothing left to merge
            break
        pair = max(stats, key=stats.get)
        idx = 256 + i
        print(f"merging {pair} into a new token {idx}")  # optional: training progress
        ids = merge(ids, pair, idx)
        merges[pair] = idx

    return merges

# Train the tokenizer
merges = train_bpe()
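
# Retraining on every launch can be slow for larger corpora. A minimal sketch
# for caching the learned merges between runs (the file name
# 'telugu_merges.pkl' is an assumption, not part of the original app):
#
#   import os
#   import pickle
#   if os.path.exists('telugu_merges.pkl'):
#       with open('telugu_merges.pkl', 'rb') as f:
#           merges = pickle.load(f)
#   else:
#       merges = train_bpe()
#       with open('telugu_merges.pkl', 'wb') as f:
#           pickle.dump(merges, f)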

class OptimizedBPETokenizer:
    def __init__(self, merges: Dict[Tuple[int, int], int]):
        self.merges = merges
        # Reverse map used by decode() to expand merged tokens back to bytes
        self.idx_to_pair = {idx: pair for pair, idx in merges.items()}
    
    def encode(self, text: str) -> List[int]:
        if not isinstance(text, str):
            return []

        # Start from the raw UTF-8 byte stream
        ids = list(text.encode('utf-8'))

        # Apply merges in the order they were learned: always take the
        # present pair with the lowest merge index. Greedily merging the
        # most frequent pair instead can stop early and mis-tokenize.
        while len(ids) >= 2:
            stats = get_stats(ids)
            pair = min(stats, key=lambda p: self.merges.get(p, float('inf')))
            if pair not in self.merges:
                break  # no remaining pair is mergeable
            ids = merge(ids, pair, self.merges[pair])

        return ids
    
    def decode(self, ids: List[int]) -> str:
        result = []
        for token in ids:
            if token < 256:
                result.append(token)  # raw byte
            else:
                # Expand a merged token into the pair it was built from
                pair = self.idx_to_pair[token]
                result.extend(self._expand_token(pair[0]))
                result.extend(self._expand_token(pair[1]))
        # errors='replace' keeps decode from raising on arbitrary token
        # lists whose bytes do not form valid UTF-8
        return bytes(result).decode('utf-8', errors='replace')

    def _expand_token(self, token: int) -> List[int]:
        """Recursively expand a token id into the raw bytes it represents."""
        if token < 256:
            return [token]
        pair = self.idx_to_pair[token]
        result = []
        result.extend(self._expand_token(pair[0]))
        result.extend(self._expand_token(pair[1]))
        return result

# Initialize tokenizer
tokenizer = OptimizedBPETokenizer(merges)
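
# Example round-trip (illustrative; the exact token ids depend on the
# trained merges):
#   ids = tokenizer.encode("నమస్కారం")
#   tokenizer.decode(ids)  # -> "నమస్కారం"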

def encode_text(text: str) -> str:
    """Function to handle encoding"""
    if not text:
        return "Please enter text to encode"
    try:
        tokens = tokenizer.encode(text)
        return f"Encoded tokens: {tokens}\nToken count: {len(tokens)}"
    except Exception as e:
        return f"Encoding error: {str(e)}"

def decode_tokens(text: str) -> str:
    """Function to handle decoding"""
    if not text:
        return "Please enter tokens to decode"
    try:
        tokens = [int(x) for x in text.strip().strip('[]').split(',')]
        decoded_text = tokenizer.decode(tokens)
        return f"Decoded text: {decoded_text}"
    except Exception as e:
        return f"Error: Please provide valid integers for decoding. Details: {str(e)}"

# Create the Gradio interface
with gr.Blocks(title="Telugu BPE Tokenizer") as iface:
    gr.Markdown("# Telugu BPE Tokenizer")
    gr.Markdown("A byte-pair encoding tokenizer trained on Telugu text.")
    
    with gr.Row():
        # Encoding Section
        with gr.Column():
            gr.Markdown("### Encode Text")
            input_text = gr.Textbox(
                label="Input Text", 
                placeholder="Enter Telugu text to encode..."
            )
            encode_button = gr.Button("Encode")
            encode_output = gr.Textbox(label="Encoding Result")
            
        # Decoding Section
        with gr.Column():
            gr.Markdown("### Decode Tokens")
            input_tokens = gr.Textbox(
                label="Input Tokens", 
                placeholder="Enter comma-separated tokens (e.g., 256,257,258)"
            )
            decode_button = gr.Button("Decode")
            decode_output = gr.Textbox(label="Decoding Result")
    
    # Set up the button click events
    encode_button.click(
        fn=encode_text,
        inputs=input_text,
        outputs=encode_output
    )
    
    decode_button.click(
        fn=decode_tokens,
        inputs=input_tokens,
        outputs=decode_output
    )
    
    # Add examples
    with gr.Row():
        with gr.Column():
            gr.Examples(
                examples=[
                    ["నమస్కారం"],
                    ["తెలుగు భాష"],
                ],
                inputs=input_text,
                outputs=encode_output,
                fn=encode_text,
                label="Encoding Examples"
            )
        
        with gr.Column():
            gr.Examples(
                examples=[
                    ["256,257,258"],  # Example tokens
                ],
                inputs=input_tokens,
                outputs=decode_output,
                fn=decode_tokens,
                label="Decoding Examples"
            )

if __name__ == "__main__":
    iface.launch()