Spaces:
Runtime error
Runtime error
import gradio as gr | |
from typing import List, Dict, Tuple | |
import numpy as np | |
def get_stats(ids):
    """Count occurrences of every adjacent pair in ``ids``.

    Returns a dict mapping ``(left, right)`` to how many times that pair
    appears consecutively. Keys are inserted in order of each pair's first
    occurrence. Sequences shorter than two elements give an empty dict.
    """
    pair_counts = {}
    for pos in range(1, len(ids)):
        key = (ids[pos - 1], ids[pos])
        if key in pair_counts:
            pair_counts[key] += 1
        else:
            pair_counts[key] = 1
    return pair_counts
def merge(ids, pair, idx):
    """Return a new list where each occurrence of ``pair`` in ``ids`` becomes ``idx``.

    Matching scans left to right, and a match consumes both elements, so
    overlapping occurrences are not merged twice. For example, ``[7, 7, 7]``
    with pair ``(7, 7)`` becomes ``[idx, 7]``. ``ids`` itself is not modified.
    """
    merged = []
    n = len(ids)
    pos = 0
    while pos < n:
        is_match = (pos + 1 < n
                    and ids[pos] == pair[0]
                    and ids[pos + 1] == pair[1])
        if is_match:
            merged.append(idx)
            pos += 2
        else:
            merged.append(ids[pos])
            pos += 1
    return merged
# Read the Telugu text file and train BPE
def train_bpe(vocab_size: int = 350,
              path: str = 'telugu_preprocessed_file.txt',
              verbose: bool = True) -> Dict[Tuple[int, int], int]:
    """Learn byte-level BPE merges from a UTF-8 text file.

    Training starts from the 256 raw byte values. It repeatedly replaces the
    most frequent adjacent pair of ids with a new id (256, 257, ...). It
    stops after ``vocab_size - 256`` merges or when fewer than two ids remain.

    Args:
        vocab_size: Target vocabulary size, counting the 256 byte ids.
            Values of 256 or less learn no merges.
        path: Text file to train on, read as UTF-8.
        verbose: If true, print each merge as it is learned.

    Returns:
        Dict mapping each merged pair ``(left, right)`` to its new id. Ids
        are assigned in increasing order as the merges are learned.

    Raises:
        OSError: If ``path`` cannot be opened.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    with open(path, 'r', encoding='utf-8') as f:
        text = f.read()

    # Start from raw UTF-8 bytes; iterating a bytes object yields ints.
    ids = list(text.encode('utf-8'))
    merges = {}
    for i in range(vocab_size - 256):  # empty range when vocab_size <= 256
        stats = get_stats(ids)
        if not stats:  # no adjacent pairs left to merge
            break
        # On ties, max() keeps the pair seen first in the text.
        pair = max(stats, key=stats.get)
        idx = 256 + i
        if verbose:
            print(f"merging {pair} into a new token {idx}")
        ids = merge(ids, pair, idx)
        merges[pair] = idx
    return merges
# Train the tokenizer at import time. The resulting merge table is what
# OptimizedBPETokenizer is built from further down this module.
merges = train_bpe(vocab_size=350)
class OptimizedBPETokenizer:
    """Byte-level BPE tokenizer that applies a learned merge table.

    Token ids 0-255 are raw UTF-8 byte values. Every other id comes from
    ``merges`` and stands for the pair of ids it was merged from.

    Args:
        merges: Mapping ``(left_id, right_id) -> new_id``, such as the dict
            returned by ``train_bpe``. Encoding applies merges in ascending
            ``new_id`` order. This assumes ids were assigned in the order the
            merges were learned, as ``train_bpe`` does (256, 257, ...).
    """

    def __init__(self, merges: Dict[Tuple[int, int], int]):
        self.merges = merges
        # Reverse table: merged id -> the pair it replaces (used by decode).
        self.idx_to_pair = {idx: pair for pair, idx in merges.items()}
        # Nested lookup first -> {second: idx}. encode/decode do not use it;
        # it is kept as a public attribute for backward compatibility.
        self.merge_lookup: Dict[int, Dict[int, int]] = {}
        for (first, second), idx in merges.items():
            self.merge_lookup.setdefault(first, {})[second] = idx

    @staticmethod
    def _merge_pair(ids: List[int], pair: Tuple[int, int], idx: int) -> List[int]:
        """Return ``ids`` with each non-overlapping ``pair`` (scanned left to right) replaced by ``idx``."""
        out = []
        n = len(ids)
        i = 0
        while i < n:
            if i + 1 < n and ids[i] == pair[0] and ids[i + 1] == pair[1]:
                out.append(idx)
                i += 2
            else:
                out.append(ids[i])
                i += 1
        return out

    def encode(self, text: str, chunk_size: int = 1000000) -> List[int]:
        """Encode ``text`` into token ids.

        The text is first converted to UTF-8 bytes. Learned merges are then
        applied in the order they were learned: at each step, the adjacent
        pair with the lowest merge id is replaced. This repeats until no
        adjacent pair has a learned merge.

        Args:
            text: String to encode. Non-``str`` input yields ``[]``.
            chunk_size: Unused; accepted for backward compatibility.

        Returns:
            List of token ids (plain ints).
        """
        if not isinstance(text, str):
            return []
        ids = list(text.encode('utf-8'))
        no_merge = float('inf')
        while len(ids) >= 2:
            # Select by merge rank, not by pair frequency in this text. That
            # reproduces the training order. Picking the most frequent pair
            # would apply merges out of order and stop early whenever that
            # pair was never learned.
            pair = min(zip(ids, ids[1:]),
                       key=lambda p: self.merges.get(p, no_merge))
            if pair not in self.merges:
                break  # no adjacent pair has a learned merge
            ids = self._merge_pair(ids, pair, self.merges[pair])
        return ids

    def decode(self, ids: List[int]) -> str:
        """Decode token ids back into text.

        Each id is expanded recursively into its raw byte values, and the
        bytes are decoded as UTF-8. Byte sequences that are not valid UTF-8
        produce U+FFFD replacement characters instead of raising. This can
        happen, for example, when the ids cut a multi-byte character in half.

        Raises:
            KeyError: If an id >= 256 is not a known merged token.
            ValueError: If an id < 256 is not a valid byte (e.g. negative).
        """
        raw = []
        for token in ids:
            raw.extend(self._expand_token(token))
        return bytes(raw).decode('utf-8', errors='replace')

    def _expand_token(self, token: int) -> List[int]:
        """Recursively expand ``token`` into the list of byte values it represents."""
        if token < 256:
            return [token]
        left, right = self.idx_to_pair[token]
        return self._expand_token(left) + self._expand_token(right)
# Build the module-wide tokenizer from the merge table learned above.
# The Gradio callbacks encode_text and decode_tokens use this instance.
tokenizer = OptimizedBPETokenizer(merges=merges)
def encode_text(text: str) -> str:
    """Gradio callback: encode ``text`` and format the ids for display.

    Returns one of three strings:
    - a prompt, for empty input;
    - the token list and its length, on success;
    - an ``Encoding error: ...`` message, if encoding raises.
    """
    if text:
        try:
            token_ids = tokenizer.encode(text)
            summary = f"Encoded tokens: {token_ids}\nToken count: {len(token_ids)}"
        except Exception as err:
            return f"Encoding error: {str(err)}"
        return summary
    return "Please enter text to encode"
def decode_tokens(text: str) -> str:
    """Gradio callback: parse comma-separated token ids and decode them.

    Accepts input such as ``256,257,258`` or ``[256, 257, 258]``, which is
    the list format ``encode_text`` prints. Surrounding whitespace and empty
    entries (e.g. a trailing comma) are ignored.

    Returns a display string: the decoded text, a prompt if no ids were
    given, or an error message if parsing or decoding fails.
    """
    if not text:
        return "Please enter tokens to decode"
    # Broad catch: this is the UI boundary, and any failure should be shown
    # to the user rather than raised into Gradio.
    try:
        # Strip whitespace before brackets so " [1, 2] " parses correctly.
        pieces = text.strip().strip('[]').split(',')
        tokens = [int(piece) for piece in pieces if piece.strip()]
        if not tokens:
            return "Please enter tokens to decode"
        decoded_text = tokenizer.decode(tokens)
        return f"Decoded text: {decoded_text}"
    except Exception as e:
        return f"Error: Please provide valid integers for decoding. Details: {str(e)}"
# Create the Gradio interface.
# Components render inside the enclosing Row/Column context in the order
# they are created, so statement order here defines the page layout.
with gr.Blocks(title="Telugu BPE Tokenizer") as iface:
    gr.Markdown("# Telugu BPE Tokenizer")
    gr.Markdown("A byte-pair encoding tokenizer trained on Telugu text.")
    # Two side-by-side columns: encoding on the left, decoding on the right.
    with gr.Row():
        # Encoding Section
        with gr.Column():
            gr.Markdown("### Encode Text")
            input_text = gr.Textbox(
                label="Input Text",
                placeholder="Enter Telugu text to encode..."
            )
            encode_button = gr.Button("Encode")
            encode_output = gr.Textbox(label="Encoding Result")
        # Decoding Section
        with gr.Column():
            gr.Markdown("### Decode Tokens")
            input_tokens = gr.Textbox(
                label="Input Tokens",
                placeholder="Enter comma-separated tokens (e.g., 256,257,258)"
            )
            decode_button = gr.Button("Decode")
            decode_output = gr.Textbox(label="Decoding Result")
    # Set up the button click events.
    # Each handler returns a display string for its output textbox.
    encode_button.click(
        fn=encode_text,
        inputs=input_text,
        outputs=encode_output
    )
    decode_button.click(
        fn=decode_tokens,
        inputs=input_tokens,
        outputs=decode_output
    )
    # Add examples
    # NOTE(review): whether Gradio calls fn/outputs for these examples
    # (caching or run-on-click) depends on the installed Gradio version's
    # cache_examples / run_on_click defaults; confirm.
    with gr.Row():
        with gr.Column():
            gr.Examples(
                examples=[
                    ["నమస్కారం"],
                    ["తెలుగు భాష"],
                ],
                inputs=input_text,
                outputs=encode_output,
                fn=encode_text,
                label="Encoding Examples"
            )
        with gr.Column():
            # train_bpe assigns merge ids from 256 upward, so these are the
            # first three learned merges. If fewer merges were learned,
            # decode_tokens returns its error message instead.
            gr.Examples(
                examples=[
                    ["256,257,258"],  # Example tokens
                ],
                inputs=input_tokens,
                outputs=decode_output,
                fn=decode_tokens,
                label="Decoding Examples"
            )
# Start the web server only when run as a script, not on import.
if __name__ == "__main__":
    iface.launch()