import os
import torch
import gradio as gr

from train import CharTokenizer, Seq2Seq, Encoder, Decoder, TransformerTransliterator
# ----------------------
# 1️⃣ Load LSTM checkpoint
# ----------------------
lstm_ckpt_path = "lstm_transliterator.pt"
lstm_ckpt = torch.load(lstm_ckpt_path, map_location='cpu')
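# NOTE (assumption about the checkpoint contents): on PyTorch >= 2.6, torch.load
# defaults to weights_only=True; if this checkpoint stores anything beyond plain
# tensors/dicts, pass weights_only=False explicitly.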

src_vocab = lstm_ckpt['src_vocab']
tgt_vocab = lstm_ckpt['tgt_vocab']
src_tokenizer = CharTokenizer(vocab=src_vocab)
tgt_tokenizer = CharTokenizer(vocab=tgt_vocab)
# Reconstruct the LSTM architecture; these hyperparameters must match the
# training run exactly, or load_state_dict below will raise a size mismatch.
EMBED_DIM = 256
ENC_HIDDEN_DIM = 256
DEC_HIDDEN_DIM = 256
NUM_LAYERS_MODEL = 2
DROPOUT = 0.3

device = 'cuda' if torch.cuda.is_available() else 'cpu'
encoder = Encoder(len(src_tokenizer), EMBED_DIM, ENC_HIDDEN_DIM, NUM_LAYERS_MODEL, DROPOUT)
decoder = Decoder(len(tgt_tokenizer), EMBED_DIM, ENC_HIDDEN_DIM, DEC_HIDDEN_DIM, NUM_LAYERS_MODEL, DROPOUT)
lstm_model = Seq2Seq(encoder, decoder, device=device).to(device)
lstm_model.load_state_dict(lstm_ckpt['model_state_dict'])
lstm_model.eval()  # disable dropout for deterministic inference
print("✅ LSTM model loaded")

# ----------------------
# 2️⃣ Load Transformer checkpoint
# ----------------------
transformer_ckpt_path = "transformer_transliterator.pt"
transformer_ckpt = torch.load(transformer_ckpt_path, map_location='cpu')

transformer_model = TransformerTransliterator(
    src_vocab_size=len(src_tokenizer),
    tgt_vocab_size=len(tgt_tokenizer),
    d_model=256,
    nhead=8,
    num_encoder_layers=2,
    num_decoder_layers=2,
    dim_feedforward=512,
    dropout=0.1,
    max_len=100,
).to(device)
transformer_model.load_state_dict(transformer_ckpt['model_state_dict'])
transformer_model.eval()
print("✅ Transformer model loaded")

# ----------------------
# 3️⃣ Load TinyLLaMA
# ----------------------
from transformers import AutoTokenizer, AutoModelForCausalLM

try:
    llm_model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    llm_tokenizer = AutoTokenizer.from_pretrained(llm_model_name, trust_remote_code=True)
    llm_model = AutoModelForCausalLM.from_pretrained(
        llm_model_name,
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        device_map="auto" if device == "cuda" else None,
        trust_remote_code=True,
    )
    if device != "cuda":
        llm_model = llm_model.to(device)
    llm_model.eval()
    print("✅ TinyLLaMA model loaded")
    has_llm = True
except Exception as e:
    print(f"⚠️ TinyLLaMA loading failed: {e}")
    print("⚠️ Will use only LSTM and Transformer models")
    has_llm = False
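# NOTE: on a CPU-only Space, a 1.1B-parameter model in float32 needs roughly
# 4+ GB of RAM (1.1B params × 4 bytes), which is the usual reason the
# except branch above fires.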

# ----------------------
# 4️⃣ Transliteration Function
# ----------------------
def transliterate(word):
    word = word.strip()
    if not word:
        return "❌ Empty input", "❌ Empty input", "❌ Empty input"

    try:
        # LSTM prediction
        lstm_pred = lstm_model.translate(word, src_tokenizer, tgt_tokenizer)
    except Exception as e:
        lstm_pred = f"Error: {str(e)[:50]}"

    try:
        # Transformer prediction (greedy decoding)
        transformer_pred = transformer_model.translate(
            word, src_tokenizer, tgt_tokenizer,
            device=device, decoding="greedy"
        )
    except Exception as e:
        transformer_pred = f"Error: {str(e)[:50]}"

    # TinyLLaMA prediction
    if has_llm:
        try:
            prompt = f"Transliterate the following English word to Hindi (Devanagari script).\nEnglish word: {word}\nHindi transliteration:"
            inputs = llm_tokenizer(prompt, return_tensors="pt").to(device)
            with torch.no_grad():
                output_ids = llm_model.generate(
                    **inputs,
                    max_new_tokens=30,
                    temperature=0.7,
                    top_p=0.9,
                    do_sample=True,
                    pad_token_id=llm_tokenizer.eos_token_id,
                    eos_token_id=llm_tokenizer.eos_token_id,
                )
            # Decode only the newly generated tokens; re-decoding the full
            # sequence and string-replacing the prompt is unreliable because
            # the decoded prompt may not match the input text exactly.
            new_tokens = output_ids[0][inputs["input_ids"].shape[1]:]
            llm_pred = llm_tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
            # Drop ASCII letters and whitespace so the Devanagari output
            # survives even if the model echoes English text around it.
            llm_pred = ''.join(c for c in llm_pred if not (c.isascii() and c.isalpha()) and c.strip())
        except Exception as e:
            llm_pred = f"Error: {str(e)[:50]}"
    else:
        llm_pred = "TinyLLaMA model not loaded (insufficient memory)"

    return lstm_pred, transformer_pred, llm_pred
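
# Quick local sanity check (illustrative only; actual outputs depend on the
# trained checkpoints):
#   print(transliterate("namaste"))  # e.g. ('नमस्ते', 'नमस्ते', 'नमस्ते')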

# ----------------------
# 5️⃣ Gradio Interface
# ----------------------
demo = gr.Interface(
    fn=transliterate,
    inputs=gr.Textbox(
        label="Input Hindi Roman Word",
        placeholder="e.g., namaste, dhanyavaad, bharat",
        lines=1,
    ),
    outputs=[
        gr.Textbox(label="LSTM Prediction", interactive=False),
        gr.Textbox(label="Transformer Prediction", interactive=False),
        gr.Textbox(label="TinyLLaMA Prediction", interactive=False),
    ],
    title="Hindi Roman to Devanagari Transliteration",
    description="Compare three models: LSTM, Transformer, and TinyLLaMA.\nEnter a Hindi Roman word to get transliteration predictions.",
    examples=[
        ["namaste"],
        ["dhanyavaad"],
        ["bharat"],
        ["mumbai"],
        ["hindustan"],
        ["pranaam"],
    ],
    allow_flagging="never",  # renamed to flagging_mode in newer Gradio releases
)

demo.launch(
    share=False,
    debug=False,
    server_name="0.0.0.0",  # bind all interfaces so the container can serve requests
    server_port=7860,
)