import os
import torch
import gradio as gr
from train import CharTokenizer, Seq2Seq, Encoder, Decoder, TransformerTransliterator
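# CharTokenizer and the model classes above are defined in the local train.py of this Space;
# the .pt checkpoint files loaded below are expected to sit next to this script.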
# ----------------------
# 1️⃣ Load LSTM checkpoint
# ----------------------
lstm_ckpt_path = "lstm_transliterator.pt"
lstm_ckpt = torch.load(lstm_ckpt_path, map_location='cpu')
src_vocab = lstm_ckpt['src_vocab']
tgt_vocab = lstm_ckpt['tgt_vocab']
src_tokenizer = CharTokenizer(vocab=src_vocab)
tgt_tokenizer = CharTokenizer(vocab=tgt_vocab)
# Reconstruct the LSTM model architecture; these hyperparameters must match the ones the checkpoint was trained with
EMBED_DIM = 256
ENC_HIDDEN_DIM = 256
DEC_HIDDEN_DIM = 256
NUM_LAYERS_MODEL = 2
DROPOUT = 0.3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
encoder = Encoder(len(src_tokenizer), EMBED_DIM, ENC_HIDDEN_DIM, NUM_LAYERS_MODEL, DROPOUT)
decoder = Decoder(len(tgt_tokenizer), EMBED_DIM, ENC_HIDDEN_DIM, DEC_HIDDEN_DIM, NUM_LAYERS_MODEL, DROPOUT)
lstm_model = Seq2Seq(encoder, decoder, device=device).to(device)
lstm_model.load_state_dict(lstm_ckpt['model_state_dict'])
lstm_model.eval()
print("✅ LSTM model loaded")
# ----------------------
# 2️⃣ Load Transformer checkpoint
# ----------------------
transformer_ckpt_path = "transformer_transliterator.pt"
transformer_ckpt = torch.load(transformer_ckpt_path, map_location='cpu')
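# Rebuild the Transformer with the same hyperparameters used during training
# (assumed to match the defaults in train.py)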
transformer_model = TransformerTransliterator(
    src_vocab_size=len(src_tokenizer),
    tgt_vocab_size=len(tgt_tokenizer),
    d_model=256,
    nhead=8,
    num_encoder_layers=2,
    num_decoder_layers=2,
    dim_feedforward=512,
    dropout=0.1,
    max_len=100
).to(device)
transformer_model.load_state_dict(transformer_ckpt['model_state_dict'])
transformer_model.eval()
print("✅ Transformer model loaded")
# ----------------------
# 3️⃣ Load TinyLLaMA
# ----------------------
from transformers import AutoTokenizer, AutoModelForCausalLM
try:
    llm_model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    llm_tokenizer = AutoTokenizer.from_pretrained(llm_model_name, trust_remote_code=True)
    llm_model = AutoModelForCausalLM.from_pretrained(
        llm_model_name,
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        device_map="auto" if device == "cuda" else None,
        trust_remote_code=True,
    )
    if device != "cuda":
        llm_model = llm_model.to(device)
    llm_model.eval()
    print("✅ TinyLLaMA model loaded")
    has_llm = True
except Exception as e:
    print(f"⚠️ TinyLLaMA loading failed: {e}")
    print("⚠️ Will use only LSTM and Transformer models")
    has_llm = False
# ----------------------
# 4️⃣ Transliteration Function
# ----------------------
@torch.no_grad()
def transliterate(word):
    word = word.strip()
    if not word:
        return "❌ Empty input", "❌ Empty input", "❌ Empty input"
    try:
        # LSTM prediction
        lstm_pred = lstm_model.translate(word, src_tokenizer, tgt_tokenizer)
    except Exception as e:
        lstm_pred = f"Error: {str(e)[:50]}"
    try:
        # Transformer prediction (greedy decoding)
        transformer_pred = transformer_model.translate(
            word, src_tokenizer, tgt_tokenizer,
            device=device, decoding="greedy"
        )
    except Exception as e:
        transformer_pred = f"Error: {str(e)[:50]}"
    # TinyLLaMA prediction (prompt-based; the @torch.no_grad() decorator already disables gradients)
    if has_llm:
        try:
            prompt = f"Transliterate the following English word to Hindi (Devanagari script).\nEnglish word: {word}\nHindi transliteration:"
            inputs = llm_tokenizer(prompt, return_tensors="pt").to(device)
            output_ids = llm_model.generate(
                **inputs,
                max_new_tokens=30,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                pad_token_id=llm_tokenizer.eos_token_id,
                eos_token_id=llm_tokenizer.eos_token_id,
            )
            generated = llm_tokenizer.decode(output_ids[0], skip_special_tokens=True)
            # Strip the echoed prompt, then keep only non-ASCII, non-whitespace characters
            # so stray English letters from the completion are dropped
            llm_pred = generated.replace(prompt, "").strip()
            llm_pred = ''.join(c for c in llm_pred if not (c.isascii() and c.isalpha()) and c.strip())
        except Exception as e:
            llm_pred = f"Error: {str(e)[:50]}"
    else:
        llm_pred = "TinyLLaMA model not loaded (insufficient memory)"
    return lstm_pred, transformer_pred, llm_pred
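# Hypothetical usage (actual strings depend on the trained checkpoints and on LLM sampling):
#   lstm_out, transformer_out, llm_out = transliterate("namaste")
#   # each value is either a Devanagari string or an error message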
# ----------------------
# 5️⃣ Gradio Interface
# ----------------------
demo = gr.Interface(
    fn=transliterate,
    inputs=gr.Textbox(
        label="Input Hindi Roman Word",
        placeholder="e.g., namaste, dhanyavaad, bharat",
        lines=1
    ),
    outputs=[
        gr.Textbox(label="LSTM Prediction", interactive=False),
        gr.Textbox(label="Transformer Prediction", interactive=False),
        gr.Textbox(label="TinyLLaMA Prediction", interactive=False)
    ],
    title="Hindi Roman to Devanagari Transliteration",
    description="Compare three models: LSTM, Transformer, and TinyLLaMA.\nEnter a Hindi Roman word to get transliteration predictions.",
    examples=[
        ["namaste"],
        ["dhanyavaad"],
        ["bharat"],
        ["mumbai"],
        ["hindustan"],
        ["pranaam"]
    ],
    allow_flagging="never"
)
demo.launch(
    share=False,
    debug=False,
    server_name="0.0.0.0",
    server_port=7860
)