import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
# Load the model and tokenizer
model_name = "jbochi/madlad400-3b-mt"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto"
)
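# Note (assumption, not part of the original configuration): device_map="auto" requires the
# `accelerate` package, and float16 weights assume a GPU is available. On a CPU-only Space,
# loading with torch_dtype=torch.float32 and dropping device_map is a reasonable fallback.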
def translate_text(text, source_lang, target_lang):
    """
    Translate text between English and Persian using MADLAD-400-3B.
    """
    # Language codes used by the model
    lang_codes = {
        "English": "en",
        "Persian": "fa"
    }
    target_code = lang_codes[target_lang]

    # MADLAD-400 expects a prompt of the form "<2xx> text", where xx is the target
    # language code; the source language is detected automatically, so only the
    # target code is needed in the prompt.
    prompt = f"<2{target_code}> {text}"

    try:
        # Tokenize the input
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)

        # Move inputs to the same device as the model
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

        # Generate the translation
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_length=512,
                num_beams=5,
                early_stopping=True,
                no_repeat_ngram_size=3,
                length_penalty=1.0
            )

        # Decode the output
        translated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return translated_text

    except Exception as e:
        return f"Error during translation: {str(e)}"
# Create the Gradio interface
with gr.Blocks(title="English-Persian Translator") as demo:
gr.Markdown(
"""
# 🌍 English-Persian Translator
**Powered by MADLAD-400-3B Model**
Translate text between English and Persian using the state-of-the-art MADLAD-400 model.
"""
)
with gr.Row():
with gr.Column():
source_lang = gr.Dropdown(
choices=["English", "Persian"],
value="English",
label="Source Language"
)
input_text = gr.Textbox(
lines=5,
placeholder="Enter text to translate...",
label="Input Text"
)
translate_btn = gr.Button("Translate", variant="primary")
with gr.Column():
target_lang = gr.Dropdown(
choices=["Persian", "English"],
value="Persian",
label="Target Language"
)
output_text = gr.Textbox(
lines=5,
label="Translated Text",
interactive=False
)
# Examples
gr.Examples(
examples=[
["Hello, how are you today?", "English", "Persian"],
["What is your name?", "English", "Persian"],
["سلام، حالتون چطوره؟", "Persian", "English"],
["امروز هوا خوب است", "Persian", "English"]
],
inputs=[input_text, source_lang, target_lang],
outputs=output_text,
fn=translate_text,
cache_examples=False
)
# Connect the button
translate_btn.click(
fn=translate_text,
inputs=[input_text, source_lang, target_lang],
outputs=output_text
)
# Auto-update target language based on source selection
def update_target_lang(source_lang):
return "Persian" if source_lang == "English" else "English"
source_lang.change(
fn=update_target_lang,
inputs=source_lang,
outputs=target_lang
)
if __name__ == "__main__":
# Launch the app
demo.launch(
server_name="0.0.0.0", # Allow external access
share=False, # Set to True to get a public URL
debug=True
) |
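# Suggested dependencies for this Space (a minimal sketch; the package list is an
# assumption, since no requirements.txt is shown alongside this file):
#   gradio
#   torch
#   transformers
#   accelerate      # needed for device_map="auto"
#   sentencepiece   # needed by the T5-based MADLAD-400 tokenizer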