SmitaGautam committed
Commit eecacbe · verified · 1 Parent(s): bf9fb64

Update app.py

Files changed (1)
  1. app.py +90 -51
app.py CHANGED
@@ -2,9 +2,6 @@ import os
 import torch
 import gradio as gr
 from train import CharTokenizer, Seq2Seq, Encoder, Decoder, TransformerTransliterator
-from huggingface_hub import login
-hf_token = os.getenv('HF_TOKEN')
-login(token=hf_token)
 
 # ----------------------
 # 1️⃣ Load LSTM checkpoint
@@ -25,12 +22,16 @@ DEC_HIDDEN_DIM = 256
 NUM_LAYERS_MODEL = 2
 DROPOUT = 0.3
 
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
 encoder = Encoder(len(src_tokenizer), EMBED_DIM, ENC_HIDDEN_DIM, NUM_LAYERS_MODEL, DROPOUT)
 decoder = Decoder(len(tgt_tokenizer), EMBED_DIM, ENC_HIDDEN_DIM, DEC_HIDDEN_DIM, NUM_LAYERS_MODEL, DROPOUT)
-lstm_model = Seq2Seq(encoder, decoder, device='cpu')
+lstm_model = Seq2Seq(encoder, decoder, device=device).to(device)
 lstm_model.load_state_dict(lstm_ckpt['model_state_dict'])
 lstm_model.eval()
 
+print("✅ LSTM model loaded")
+
 # ----------------------
 # 2️⃣ Load Transformer checkpoint
 # ----------------------
@@ -47,71 +48,109 @@ transformer_model = TransformerTransliterator(
     dim_feedforward=512,
     dropout=0.1,
     max_len=100
-)
+).to(device)
 transformer_model.load_state_dict(transformer_ckpt['model_state_dict'])
 transformer_model.eval()
 
+print("✅ Transformer model loaded")
+
 # ----------------------
-# 3️⃣ Load LLaMA 7B (Hugging Face)
+# 3️⃣ Load lightweight LLM (DistilBERT-based or small model)
 # ----------------------
-# from transformers import LlamaForCausalLM, LlamaTokenizer
-
-# llama_model_name = "meta-llama/Llama-2-7b-hf" # adjust if using local
-# llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model_name)
-# llama_model = LlamaForCausalLM.from_pretrained(llama_model_name, device_map="auto")
-# llama_model.eval()
-
-from transformers import AutoTokenizer, AutoModelForCausalLM
-
-indic_model_name = "mistralai/Mistral-7B-Instruct-v0.3"
-indic_tokenizer = AutoTokenizer.from_pretrained(indic_model_name)
-indic_model = AutoModelForCausalLM.from_pretrained(indic_model_name)
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+
+# Use a lightweight T5 model instead of Mistral 7B
+try:
+    llm_model_name = "google/flan-t5-small"  # 60M params, ~240MB
+    llm_tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
+    llm_model = AutoModelForSeq2SeqLM.from_pretrained(llm_model_name).to(device)
+    llm_model.eval()
+    print("✅ LLM model loaded (Flan-T5 Small)")
+    has_llm = True
+except Exception as e:
+    print(f"⚠️ LLM loading failed: {e}")
+    print("⚠️ Will use only LSTM and Transformer models")
+    has_llm = False
 
 # ----------------------
 # 4️⃣ Transliteration Function
 # ----------------------
+@torch.no_grad()
 def transliterate(word):
     word = word.strip()
-
-    # LSTM prediction
-    lstm_pred = lstm_model.translate(word, src_tokenizer, tgt_tokenizer)
-
-    # Transformer prediction (greedy)
-    transformer_pred = transformer_model.translate(word, src_tokenizer, tgt_tokenizer, decoding="greedy")
-
-    # LLaMA prediction
-    # prompt = f"Transliterate this Hindi Roman word to Devanagari: {word}"
-    # inputs = llama_tokenizer(prompt, return_tensors="pt").to("cuda" if torch.cuda.is_available() else "cpu")
-    # output_ids = llama_model.generate(**inputs, max_new_tokens=50)
-    # llama_pred = llama_tokenizer.decode(output_ids[0], skip_special_tokens=True).replace(prompt, "").strip()
-
-    inputs = indic_tokenizer(f"transliterate to Devanagari: {word}", return_tensors="pt")
-    output = indic_model.generate(**inputs, max_new_tokens=50)
-    llama_pred = indic_tokenizer.decode(output[0], skip_special_tokens=True)
-
-    return lstm_pred, transformer_pred, llama_pred
+
+    if not word:
+        return "❌ Empty input", "❌ Empty input", "❌ Empty input"
+
+    try:
+        # LSTM prediction
+        lstm_pred = lstm_model.translate(word, src_tokenizer, tgt_tokenizer)
+    except Exception as e:
+        lstm_pred = f"Error: {str(e)[:50]}"
+
+    try:
+        # Transformer prediction (greedy)
+        transformer_pred = transformer_model.translate(
+            word, src_tokenizer, tgt_tokenizer,
+            device=device, decoding="greedy"
+        )
+    except Exception as e:
+        transformer_pred = f"Error: {str(e)[:50]}"
+
+    # LLM prediction (lightweight T5)
+    if has_llm:
+        try:
+            prompt = f"Transliterate to Devanagari: {word}"
+            inputs = llm_tokenizer(prompt, return_tensors="pt").to(device)
+            output_ids = llm_model.generate(
+                **inputs,
+                max_length=20,
+                num_beams=2,
+                early_stopping=True
+            )
+            llm_pred = llm_tokenizer.decode(output_ids[0], skip_special_tokens=True)
+            # Clean up: remove the input prompt if it appears in output
+            llm_pred = llm_pred.replace(prompt, "").strip()
+        except Exception as e:
+            llm_pred = f"Error: {str(e)[:50]}"
+    else:
+        llm_pred = "LLM model not loaded (insufficient memory)"
+
+    return lstm_pred, transformer_pred, llm_pred
 
 # ----------------------
 # 5️⃣ Gradio Interface
 # ----------------------
-iface = gr.Interface(
+demo = gr.Interface(
     fn=transliterate,
-    inputs=gr.Textbox(label="Input Hindi Roman Word"),
+    inputs=gr.Textbox(
+        label="Input Hindi Roman Word",
+        placeholder="e.g., namaste, dhanyavaad, bharat",
+        lines=1
+    ),
     outputs=[
-        gr.Textbox(label="LSTM Prediction"),
-        gr.Textbox(label="Transformer Prediction"),
-        gr.Textbox(label="Mistral 7B Prediction")
+        gr.Textbox(label="LSTM Prediction", interactive=False),
+        gr.Textbox(label="Transformer Prediction", interactive=False),
+        gr.Textbox(label="Flan-T5 Small Prediction", interactive=False)
     ],
     title="Hindi Roman to Devanagari Transliteration",
-    description="Enter a Hindi Roman word and get predictions from LSTM, Transformer, and Mistral 7B models."
+    description="Compare three models: LSTM, Transformer, and Flan-T5.\nEnter a Hindi Roman word to get transliteration predictions.",
+    examples=[
+        ["namaste"],
+        ["dhanyavaad"],
+        ["bharat"],
+        ["mumbai"],
+        ["hindustan"],
+        ["pranaam"]
+    ],
+    allow_flagging="never"
 )
 
-print('Hello')
-iface.launch(
-    share=True,
-    debug=True,
-    inbrowser=False,
-    server_name="0.0.0.0",
-    server_port=7860,
-    block=True
-)
+if __name__ == "__main__":
+    print("🚀 Starting Gradio interface...")
+    demo.launch(
+        share=False,
+        debug=False,
        server_name="0.0.0.0",
+        server_port=7860
+    )
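
Note: the substantive change here swaps the 7B-parameter Mistral model for google/flan-t5-small (~60M params, per the inline comment) and isolates each model call in its own try/except so one failure no longer breaks the whole app. Below is a minimal standalone sketch of that load-with-fallback and beam-search generation path, useful for sanity-checking it outside the Space. It assumes torch and transformers are installed and the Hugging Face Hub is reachable; the helper name llm_transliterate is illustrative and not part of app.py.

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load-with-fallback, mirroring the commit: a download or memory failure
# degrades to a flag instead of crashing at import time.
try:
    llm_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
    llm_model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small").to(device)
    llm_model.eval()
    has_llm = True
except Exception as e:
    print(f"LLM loading failed: {e}")
    has_llm = False

@torch.no_grad()
def llm_transliterate(word: str) -> str:
    # Illustrative helper (not in app.py); uses the same prompt and
    # generation settings as the commit's transliterate().
    if not has_llm:
        return "LLM model not loaded"
    prompt = f"Transliterate to Devanagari: {word}"
    inputs = llm_tokenizer(prompt, return_tensors="pt").to(device)
    output_ids = llm_model.generate(
        **inputs, max_length=20, num_beams=2, early_stopping=True
    )
    return llm_tokenizer.decode(output_ids[0], skip_special_tokens=True)

if __name__ == "__main__":
    print(llm_transliterate("namaste"))

Since flan-t5-small is a general instruction-tuned model rather than a dedicated transliteration model, its output serves as a baseline for comparison and may not be well-formed Devanagari.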