# translit-demo / app.py
import torch
import gradio as gr
from train import CharTokenizer, Seq2Seq, Encoder, Decoder, TransformerTransliterator
# ----------------------
# 1️⃣ Load LSTM checkpoint
# ----------------------
lstm_ckpt_path = "lstm_transliterator.pt"
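# NOTE: torch>=2.6 changes the torch.load default to weights_only=True; if
# loading this checkpoint fails there (it stores vocabs alongside weights),
# pass weights_only=False explicitly. This is a version-dependent assumption.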
lstm_ckpt = torch.load(lstm_ckpt_path, map_location='cpu')
src_vocab = lstm_ckpt['src_vocab']
tgt_vocab = lstm_ckpt['tgt_vocab']
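# Rebuild the character tokenizers from the vocabularies saved in the
# checkpoint so inference uses exactly the same mapping as training.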
src_tokenizer = CharTokenizer(vocab=src_vocab)
tgt_tokenizer = CharTokenizer(vocab=tgt_vocab)
# Reconstruct LSTM model architecture
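# (These hyperparameters must match the values used in train.py; otherwise
# load_state_dict below will fail with shape mismatches.)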
EMBED_DIM = 256
ENC_HIDDEN_DIM = 256
DEC_HIDDEN_DIM = 256
NUM_LAYERS_MODEL = 2
DROPOUT = 0.3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
encoder = Encoder(len(src_tokenizer), EMBED_DIM, ENC_HIDDEN_DIM, NUM_LAYERS_MODEL, DROPOUT)
decoder = Decoder(len(tgt_tokenizer), EMBED_DIM, ENC_HIDDEN_DIM, DEC_HIDDEN_DIM, NUM_LAYERS_MODEL, DROPOUT)
lstm_model = Seq2Seq(encoder, decoder, device=device).to(device)
lstm_model.load_state_dict(lstm_ckpt['model_state_dict'])
lstm_model.eval()
print("✅ LSTM model loaded")
# ----------------------
# 2️⃣ Load Transformer checkpoint
# ----------------------
transformer_ckpt_path = "transformer_transliterator.pt"
transformer_ckpt = torch.load(transformer_ckpt_path, map_location='cpu')
transformer_model = TransformerTransliterator(
    src_vocab_size=len(src_tokenizer),
    tgt_vocab_size=len(tgt_tokenizer),
    d_model=256,
    nhead=8,
    num_encoder_layers=2,
    num_decoder_layers=2,
    dim_feedforward=512,
    dropout=0.1,
    max_len=100
).to(device)
transformer_model.load_state_dict(transformer_ckpt['model_state_dict'])
transformer_model.eval()
print("✅ Transformer model loaded")
# ----------------------
# 3️⃣ Load TinyLLaMA
# ----------------------
try:
    # Import inside the try block so a missing `transformers` install is
    # caught by the same fallback as any other loading failure.
    from transformers import AutoTokenizer, AutoModelForCausalLM

    llm_model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    llm_tokenizer = AutoTokenizer.from_pretrained(llm_model_name, trust_remote_code=True)
    llm_model = AutoModelForCausalLM.from_pretrained(
        llm_model_name,
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        device_map="auto" if device == "cuda" else None,
        trust_remote_code=True,
    )
    if device != "cuda":
        llm_model = llm_model.to(device)
    llm_model.eval()
    print("✅ TinyLLaMA model loaded")
    has_llm = True
except Exception as e:
    print(f"⚠️ TinyLLaMA loading failed: {e}")
    print("⚠️ Will use only LSTM and Transformer models")
    has_llm = False
# ----------------------
# 4️⃣ Transliteration Function
# ----------------------
@torch.no_grad()
def transliterate(word):
    word = word.strip()
    if not word:
        return "❌ Empty input", "❌ Empty input", "❌ Empty input"

    # LSTM prediction
    try:
        lstm_pred = lstm_model.translate(word, src_tokenizer, tgt_tokenizer)
    except Exception as e:
        lstm_pred = f"Error: {str(e)[:50]}"

    # Transformer prediction (greedy decoding)
    try:
        transformer_pred = transformer_model.translate(
            word, src_tokenizer, tgt_tokenizer,
            device=device, decoding="greedy"
        )
    except Exception as e:
        transformer_pred = f"Error: {str(e)[:50]}"

    # TinyLLaMA prediction
    if has_llm:
        try:
            prompt = (
                "Transliterate the following English word to Hindi (Devanagari script).\n"
                f"English word: {word}\nHindi transliteration:"
            )
            inputs = llm_tokenizer(prompt, return_tensors="pt").to(device)
            # Gradients are already disabled by the @torch.no_grad() decorator.
            output_ids = llm_model.generate(
                **inputs,
                max_new_tokens=30,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                pad_token_id=llm_tokenizer.eos_token_id,
                eos_token_id=llm_tokenizer.eos_token_id,
            )
            generated = llm_tokenizer.decode(output_ids[0], skip_special_tokens=True)
            # Strip the echoed prompt, then drop ASCII letters and whitespace so
            # only the Devanagari output (plus any punctuation) remains.
            llm_pred = generated.replace(prompt, "").strip()
            llm_pred = ''.join(c for c in llm_pred if not (c.isascii() and c.isalpha()) and c.strip())
        except Exception as e:
            llm_pred = f"Error: {str(e)[:50]}"
    else:
        llm_pred = "TinyLLaMA model not loaded (insufficient memory)"

    return lstm_pred, transformer_pred, llm_pred
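
# Optional smoke test before launching the UI, kept commented out so the app
# boots straight into Gradio. The expected output is illustrative only; the
# actual strings depend on the trained checkpoints.
# print(transliterate("namaste"))  # e.g. ('नमस्ते', 'नमस्ते', 'नमस्ते')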
# ----------------------
# 5️⃣ Gradio Interface
# ----------------------
demo = gr.Interface(
    fn=transliterate,
    inputs=gr.Textbox(
        label="Romanized Hindi Word",
        placeholder="e.g., namaste, dhanyavaad, bharat",
        lines=1
    ),
    outputs=[
        gr.Textbox(label="LSTM Prediction", interactive=False),
        gr.Textbox(label="Transformer Prediction", interactive=False),
        gr.Textbox(label="TinyLLaMA Prediction", interactive=False)
    ],
    title="Hindi Roman to Devanagari Transliteration",
    description="Compare three models: LSTM, Transformer, and TinyLLaMA.\nEnter a romanized Hindi word to get transliteration predictions.",
    examples=[
        ["namaste"],
        ["dhanyavaad"],
        ["bharat"],
        ["mumbai"],
        ["hindustan"],
        ["pranaam"]
    ],
    allow_flagging="never"
)
demo.launch(
    share=False,
    debug=False,
    server_name="0.0.0.0",
    server_port=7860
)