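"""Multilingual translation + semantic search demo.

Pipeline: detect the input language (XLM-R), translate it with NLLB-200,
embed the translation and retrieve the closest sentences from a small
Sanskrit-themed corpus (FAISS), optionally score BLEU against a human
reference, and serve everything through a Gradio UI with a downloadable
report.
"""
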
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForSeq2SeqLM
from sentence_transformers import SentenceTransformer
import torch
import torch.nn.functional as F
import faiss
import matplotlib.pyplot as plt
import gradio as gr
from sacrebleu import corpus_bleu
import os
import tempfile


# Load Models
lang_detect_model = AutoModelForSequenceClassification.from_pretrained("papluca/xlm-roberta-base-language-detection")
lang_detect_tokenizer = AutoTokenizer.from_pretrained("papluca/xlm-roberta-base-language-detection")
trans_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
trans_tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
embed_model = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
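
# Everything below runs on CPU as written. A minimal GPU sketch (assumption:
# a CUDA-enabled torch build) would move both the models and each tokenized
# batch, e.g.:
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   trans_model.to(device)
#   encoded = {k: v.to(device) for k, v in encoded.items()}  # inside translate()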

# Language Mappings
id2lang = lang_detect_model.config.id2label

nllb_langs = {
    "eng_Latn": "English", "fra_Latn": "French", "hin_Deva": "Hindi",
    "spa_Latn": "Spanish", "deu_Latn": "German", "tam_Taml": "Tamil",
    "tel_Telu": "Telugu", "jpn_Jpan": "Japanese", "zho_Hans": "Chinese",
    "arb_Arab": "Arabic", "san_Deva": "Sanskrit"
}

xlm_to_nllb = {
    "en": "eng_Latn", "fr": "fra_Latn", "hi": "hin_Deva", "es": "spa_Latn", "de": "deu_Latn",
    "ta": "tam_Taml", "te": "tel_Telu", "ja": "jpn_Jpan", "zh": "zho_Hans", "ar": "arb_Arab",
    "sa": "san_Deva"
}
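
# The detector emits ISO 639-1 codes ("en", "fr", ...), while NLLB expects
# FLORES-200 codes ("eng_Latn", ...); the mapping above bridges the two, and
# unmapped detections fall back to English in full_pipeline().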

# Static Corpus
corpus = [
    "धर्म एव हतो हन्ति धर्मो रक्षति रक्षितः",
    "Dharma when destroyed, destroys; when protected, protects.",
    "The moon affects tides and mood, according to Jyotisha",
    "One should eat according to the season – Rituacharya",
    "Balance of Tridosha is health – Ayurveda principle",
    "Ethics in Mahabharata reflect situational dharma",
    "Meditation improves memory and mental clarity",
    "Jyotisha links planetary motion with life patterns"
]
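# The corpus deliberately mixes Devanagari and English lines: the multilingual
# embedding model below maps both into one vector space, so a translated query
# can match either.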
corpus_embeddings = embed_model.encode(corpus, convert_to_numpy=True)
dimension = corpus_embeddings.shape[1]
index = faiss.IndexFlatL2(dimension)
index.add(corpus_embeddings)
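
# IndexFlatL2 performs exact nearest-neighbour search on squared Euclidean
# distance, so smaller scores mean closer matches. A cosine-similarity variant
# (a sketch, not what this app uses) would normalise the vectors and switch to
# an inner-product index:
#   faiss.normalize_L2(corpus_embeddings)
#   index = faiss.IndexFlatIP(dimension)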

# Detect Language
def detect_language(text):
    inputs = lang_detect_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        outputs = lang_detect_model(**inputs)
        probs = F.softmax(outputs.logits, dim=1)
        pred = torch.argmax(probs, dim=1).item()
    return id2lang[pred]
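
# Illustrative usage: detect_language("Bonjour tout le monde") should return
# "fr" (the model predicts ISO 639-1 codes; exact output depends on the
# checkpoint).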

# Translate
def translate(text, src_code, tgt_code):
    trans_tokenizer.src_lang = src_code
    encoded = trans_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    try:
        # Force the decoder to start with the target-language token.
        target_lang_id = trans_tokenizer.convert_tokens_to_ids(tgt_code)
        generated = trans_model.generate(**encoded, forced_bos_token_id=target_lang_id)
        return trans_tokenizer.decode(generated[0], skip_special_tokens=True)
    except Exception:
        # An empty string signals failure to the caller.
        return ""

# Semantic Search
def search_semantic(query, top_k=3):
    query_embedding = embed_model.encode([query], convert_to_numpy=True)
    distances, indices = index.search(query_embedding, top_k)
    # Pair each matched sentence with its L2 distance to the query.
    return [(corpus[i], float(dist)) for i, dist in zip(indices[0], distances[0])]

# Create downloadable output file
def save_output_to_file(detected_lang, translated, sem_results, bleu_score):
    with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".txt") as f:
        f.write(f"Detected Language: {detected_lang}\n")
        f.write(f"Translated Text: {translated}\n\n")
        f.write("Top Semantic Matches:\n")
        for i, (text, score) in enumerate(sem_results):
            f.write(f"{i+1}. {text} (Score: {score:.2f})\n")
        if bleu_score:
            f.write(f"\nBLEU Score: {bleu_score}")
        return f.name
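
# delete=False keeps the temp file on disk so Gradio's File component can
# serve it; nothing here cleans these files up, which is left to the OS.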

def full_pipeline(user_input_text, target_lang_code, human_ref=""):
    if not user_input_text.strip():
        return "Empty input", "", "", None, "", None

    if len(user_input_text) > 2048:
        return "Input too long", "Please enter shorter text (under 2048 characters).", "", None, "", None

    detected_lang = detect_language(user_input_text)
    # Fall back to English if the detected code has no NLLB mapping.
    src_nllb = xlm_to_nllb.get(detected_lang, "eng_Latn")

    translated = translate(user_input_text, src_nllb, target_lang_code)
    if not translated:
        return detected_lang, "Translation failed", "", None, "", None

    sem_results = search_semantic(translated)
    result_list = [f"{i+1}. {txt} (Score: {score:.2f})" for i, (txt, score) in enumerate(sem_results)]

    # Plot
    labels = [f"{i+1}" for i in range(len(sem_results))]
    scores = [score for _, score in sem_results]
    plt.figure(figsize=(6, 4))
    bars = plt.barh(labels, scores, color="lightgreen")
    plt.xlabel("Similarity Score")
    plt.title("Top Semantic Matches")
    plt.gca().invert_yaxis()
    for bar in bars:
        plt.text(bar.get_width() + 0.01, bar.get_y() + 0.1, f"{bar.get_width():.2f}", fontsize=8)
    plt.tight_layout()
    plot_path = os.path.join(tempfile.gettempdir(), "sem_plot.png")
    plt.savefig(plot_path)
    plt.close()

    bleu_score = ""
    if human_ref.strip():
        # BLEU on a single sentence pair is a rough signal only; sacrebleu is
        # designed for corpus-level scoring.
        bleu = corpus_bleu([translated], [[human_ref]])
        bleu_score = f"{bleu.score:.2f}"

    download_file_path = save_output_to_file(detected_lang, translated, sem_results, bleu_score)
    return detected_lang, translated, "\n".join(result_list), plot_path, bleu_score, download_file_path


# Gradio Interface
gr.Interface(
    fn=full_pipeline,
    inputs=[
        gr.Textbox(label="Input Text", lines=4, placeholder="Enter text to translate..."),
        gr.Dropdown(label="Target Language", choices=[(name, code) for code, name in nllb_langs.items()], value="eng_Latn"),
        gr.Textbox(label="(Optional) Human Reference Translation", lines=2, placeholder="Paste human translation here (for BLEU)...")
    ],
    outputs=[
        gr.Textbox(label="Detected Language"),
        gr.Textbox(label="Translated Text"),
        gr.Textbox(label="Top Semantic Matches"),
        gr.Image(label="Semantic Similarity Plot"),
        gr.Textbox(label="BLEU Score"),
        gr.File(label="Download Translation Report")  # NEW OUTPUT
    ],
    title=" Multilingual Translator + Semantic Search",
    description="Detects language → Translates → Finds related Sanskrit concepts → BLEU optional → Downloadable report."
).launch()
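
# Run directly (e.g. `python app.py` -- filename assumed); Gradio serves the
# UI on http://127.0.0.1:7860 by default.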