import torch
import gradio as gr
from train import CharTokenizer, Seq2Seq, Encoder, Decoder, TransformerTransliterator
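# The tokenizer and model classes are defined in the local train.py, which
# must sit alongside this app.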

# ----------------------
# 1️⃣ Load LSTM checkpoint
# ----------------------
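# The checkpoint is expected to bundle the vocabularies with the weights
# (keys: 'src_vocab', 'tgt_vocab', 'model_state_dict') so the tokenizers can
# be rebuilt without the training data. Note: on torch >= 2.6, torch.load
# defaults to weights_only=True; pass weights_only=False if non-tensor
# objects in the checkpoint fail to load.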
lstm_ckpt_path = "lstm_transliterator.pt"
lstm_ckpt = torch.load(lstm_ckpt_path, map_location='cpu')

src_vocab = lstm_ckpt['src_vocab']
tgt_vocab = lstm_ckpt['tgt_vocab']

src_tokenizer = CharTokenizer(vocab=src_vocab)
tgt_tokenizer = CharTokenizer(vocab=tgt_vocab)

# Reconstruct LSTM model architecture
EMBED_DIM = 256
ENC_HIDDEN_DIM = 256
DEC_HIDDEN_DIM = 256
NUM_LAYERS_MODEL = 2
DROPOUT = 0.3
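# These values must match the training configuration exactly, otherwise
# load_state_dict() below fails on mismatched tensor shapes.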

device = 'cuda' if torch.cuda.is_available() else 'cpu'

encoder = Encoder(len(src_tokenizer), EMBED_DIM, ENC_HIDDEN_DIM, NUM_LAYERS_MODEL, DROPOUT)
decoder = Decoder(len(tgt_tokenizer), EMBED_DIM, ENC_HIDDEN_DIM, DEC_HIDDEN_DIM, NUM_LAYERS_MODEL, DROPOUT)
lstm_model = Seq2Seq(encoder, decoder, device=device).to(device)
lstm_model.load_state_dict(lstm_ckpt['model_state_dict'])
lstm_model.eval()

print("✅ LSTM model loaded")

# ----------------------
# 2️⃣ Load Transformer checkpoint
# ----------------------
transformer_ckpt_path = "transformer_transliterator.pt"
transformer_ckpt = torch.load(transformer_ckpt_path, map_location='cpu')
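# As with the LSTM, the constructor arguments below must mirror the values
# used during training for the state dict to load cleanly.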

transformer_model = TransformerTransliterator(
    src_vocab_size=len(src_tokenizer),
    tgt_vocab_size=len(tgt_tokenizer),
    d_model=256,
    nhead=8,
    num_encoder_layers=2,
    num_decoder_layers=2,
    dim_feedforward=512,
    dropout=0.1,
    max_len=100
).to(device)
transformer_model.load_state_dict(transformer_ckpt['model_state_dict'])
transformer_model.eval()

print("✅ Transformer model loaded")

# ----------------------
# 3️⃣ Load TinyLLaMA
# ----------------------
from transformers import AutoTokenizer, AutoModelForCausalLM
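
# TinyLLaMA is optional: if it fails to load (e.g. not enough memory on the
# host), the app degrades gracefully to the LSTM and Transformer only.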

try:
    llm_model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    llm_tokenizer = AutoTokenizer.from_pretrained(llm_model_name, trust_remote_code=True)
    llm_model = AutoModelForCausalLM.from_pretrained(
        llm_model_name,
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        device_map="auto" if device == "cuda" else None,
        trust_remote_code=True,
    )
    if device != "cuda":
        llm_model = llm_model.to(device)
    llm_model.eval()
    print("✅ TinyLLaMA model loaded")
    has_llm = True
except Exception as e:
    print(f"⚠️ TinyLLaMA loading failed: {e}")
    print("⚠️ Will use only LSTM and Transformer models")
    has_llm = False

# ----------------------
# 4️⃣ Transliteration Function
# ----------------------
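# Returns one prediction per model, in the order of the three Gradio output
# boxes below. Per-model errors are caught so one failure does not hide the
# other predictions.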
@torch.no_grad()
def transliterate(word):
    word = word.strip()
    
    if not word:
        return "❌ Empty input", "❌ Empty input", "❌ Empty input"
    
    try:
        # LSTM prediction
        lstm_pred = lstm_model.translate(word, src_tokenizer, tgt_tokenizer)
    except Exception as e:
        lstm_pred = f"Error: {str(e)[:50]}"
    
    try:
        # Transformer prediction (greedy)
        transformer_pred = transformer_model.translate(
            word, src_tokenizer, tgt_tokenizer, 
            device=device, decoding="greedy"
        )
    except Exception as e:
        transformer_pred = f"Error: {str(e)[:50]}"
    
    # TinyLLaMA prediction
    if has_llm:
        try:
            prompt = f"Transliterate the following English word to Hindi (Devanagari script).\nEnglish word: {word}\nHindi transliteration:"
            inputs = llm_tokenizer(prompt, return_tensors="pt").to(device)
            
            # @torch.no_grad() on transliterate() already disables gradients,
            # so no inner no_grad context is needed here.
            output_ids = llm_model.generate(
                **inputs,
                max_new_tokens=30,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                pad_token_id=llm_tokenizer.eos_token_id,
                eos_token_id=llm_tokenizer.eos_token_id,
            )
            
            generated = llm_tokenizer.decode(output_ids[0], skip_special_tokens=True)
            # Strip the echoed prompt, then drop ASCII letters and whitespace
            # so that (ideally) only Devanagari characters remain.
            llm_pred = generated.replace(prompt, "").strip()
            llm_pred = ''.join(c for c in llm_pred if not (c.isascii() and c.isalpha()) and c.strip())
        except Exception as e:
            llm_pred = f"Error: {str(e)[:50]}"
    else:
        llm_pred = "TinyLLaMA model not loaded (insufficient memory)"
    
    return lstm_pred, transformer_pred, llm_pred
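
# Quick manual check, commented out so it does not run on deployment
# ("namaste" is taken from the examples below):
# print(transliterate("namaste"))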

# ----------------------
# 5️⃣ Gradio Interface
# ----------------------
demo = gr.Interface(
    fn=transliterate,
    inputs=gr.Textbox(
        label="Input Hindi Roman Word",
        placeholder="e.g., namaste, dhanyavaad, bharat",
        lines=1
    ),
    outputs=[
        gr.Textbox(label="LSTM Prediction", interactive=False),
        gr.Textbox(label="Transformer Prediction", interactive=False),
        gr.Textbox(label="TinyLLaMA Prediction", interactive=False)
    ],
    title="Hindi Roman to Devanagari Transliteration",
    description="Compare three models: LSTM, Transformer, and TinyLLaMA.\nEnter a Hindi Roman word to get transliteration predictions.",
    examples=[
        ["namaste"],
        ["dhanyavaad"],
        ["bharat"],
        ["mumbai"],
        ["hindustan"],
        ["pranaam"]
    ],
    allow_flagging="never"
)

# Bind to all interfaces on port 7860, the usual setup for containerized
# deployments such as Hugging Face Spaces.
demo.launch(
    share=False,
    debug=False,
    server_name="0.0.0.0",
    server_port=7860
)