# NOTE: web-viewer residue (file-size banner, commit-hash gutter and line-number gutter) removed; the source starts below.

import nltk
nltk.download("punkt")
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, ServiceContext
from llama_index.llms import HuggingFaceLLM
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
import os

# Load DeepSeek-R1 once. These two objects serve the chat handler directly and
# are also handed to the LlamaIndex LLM wrapper below.
# trust_remote_code=True allows model code shipped in the Hub repo to execute.
deepseek_tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1", trust_remote_code=True)
deepseek_model = AutoModelForCausalLM.from_pretrained("deepseek-ai/DeepSeek-R1", trust_remote_code=True)

# Load IndicBART (sequence-to-sequence model used by restore_text)
indicbart_tokenizer = AutoTokenizer.from_pretrained("ai4bharat/IndicBART")
indicbart_model = AutoModelForSeq2SeqLM.from_pretrained("ai4bharat/IndicBART")

# Initialize LlamaIndex components
embed_model = HuggingFaceEmbedding(model_name="sentence-transformers/all-MiniLM-L6-v2")

# Pass the already-loaded model and tokenizer so the wrapper does not
# instantiate a second copy of DeepSeek-R1 from the Hub by name (the by-name
# path would also not receive trust_remote_code). The name arguments are kept
# so the wrapper still reports which model it is using.
llm = HuggingFaceLLM(
    model_name="deepseek-ai/DeepSeek-R1",
    tokenizer_name="deepseek-ai/DeepSeek-R1",
    model=deepseek_model,
    tokenizer=deepseek_tokenizer,
    max_new_tokens=512,
    context_window=4096,
)

# NOTE(review): ServiceContext appears to be deprecated in newer llama-index
# releases in favour of the global Settings object - confirm against the
# installed version before upgrading.
service_context = ServiceContext.from_defaults(llm=llm, embed_model=embed_model)

# Build index from documents in 'data' directory (read at import time)
documents = SimpleDirectoryReader("data").load_data()
index = VectorStoreIndex.from_documents(documents, service_context=service_context)

# Define functions for each task
def restore_text(input_text, task_type):
    """Run one IndicBART text-to-text task on the given text.

    Args:
        input_text: Text to process. ``None``, empty, or whitespace-only input
            returns an empty string without invoking the model.
        task_type: UI label of the selected task. Labels that are missing from
            the prefix table (including ``None`` when nothing is selected)
            fall back to the ``"restore: "`` prefix.

    Returns:
        The first decoded sequence produced by beam search, or ``""`` when
        there is nothing to process.
    """
    # Guard first: concatenating the prefix with None would raise TypeError,
    # and generating from a bare prefix only yields text unrelated to any input.
    if not input_text or not input_text.strip():
        return ""

    prefix_map = {
        "Restore & Correct Tamil Text": "restore: ",
        "Summarize in Tamil": "summarize: ",
        "Translate to English": "translate Tamil to English: ",
    }
    prefix = prefix_map.get(task_type, "restore: ")
    prompt = prefix + input_text

    # The tokenizer is given a one-element batch, so index 0 of the decoded
    # batch is the result for this prompt.
    inputs = indicbart_tokenizer([prompt], return_tensors="pt", padding=True)
    outputs = indicbart_model.generate(**inputs, max_length=256, num_beams=4, early_stopping=True)
    decoded_output = indicbart_tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return decoded_output[0]

def deepseek_chat(message):
    """Generate a single-turn DeepSeek-R1 reply to ``message``.

    Args:
        message: The user's message. ``None``, empty, or whitespace-only input
            returns an empty string without invoking the model.

    Returns:
        The decoded continuation only (the prompt tokens are stripped), with
        special tokens removed.
    """
    if not message or not message.strip():
        return ""

    # NOTE(review): the prompt is the raw message followed by the EOS token;
    # the tokenizer's chat template (apply_chat_template) may be the intended
    # prompt format for this model - confirm before changing.
    input_ids = deepseek_tokenizer.encode(message + deepseek_tokenizer.eos_token, return_tensors="pt")

    # max_length caps prompt + generated tokens together.
    outputs = deepseek_model.generate(input_ids, max_length=1024, pad_token_id=deepseek_tokenizer.eos_token_id)

    # For a causal LM, generate() returns the prompt followed by the
    # continuation; slice off the prompt so the reply does not echo the
    # user's own message.
    reply_ids = outputs[0][input_ids.shape[-1]:]
    return deepseek_tokenizer.decode(reply_ids, skip_special_tokens=True)

def query_documents(query):
    """Answer ``query`` against the module-level document index.

    A query engine is derived from ``index`` on each call; its response
    object is converted to ``str`` before being returned.
    """
    answer = index.as_query_engine().query(query)
    return str(answer)

# Gradio Interface
# Three tabs, each pairing input component(s) with one handler and one output.
with gr.Blocks() as demo:
    gr.Markdown("## 🕉️ Ancient Tamil Literature Expert AI")

    # Tab 1: IndicBART tasks -> restore_text(input_text, task_type)
    with gr.Tab("IndicBART Tasks"):
        input_text = gr.Textbox(
            label="Input Tamil Text",
            lines=8,
            placeholder="Enter ancient Tamil text here...",
        )
        task_type = gr.Radio(
            choices=[
                "Restore & Correct Tamil Text",
                "Summarize in Tamil",
                "Translate to English",
            ],
            label="Select Task",
        )
        output_text = gr.Textbox(label="Output")
        submit_button = gr.Button("Submit")
        submit_button.click(
            restore_text,
            inputs=[input_text, task_type],
            outputs=output_text,
        )

    # Tab 2: free-form chat -> deepseek_chat(message)
    with gr.Tab("DeepSeek-R1 Chat"):
        chat_input = gr.Textbox(label="Enter your message")
        chat_output = gr.Textbox(label="DeepSeek-R1 Response")
        chat_button = gr.Button("Send")
        chat_button.click(
            deepseek_chat,
            inputs=chat_input,
            outputs=chat_output,
        )

    # Tab 3: question answering over the indexed documents -> query_documents(query)
    with gr.Tab("Document Query"):
        query_input = gr.Textbox(label="Enter your query")
        query_output = gr.Textbox(label="Query Response")
        query_button = gr.Button("Search")
        query_button.click(
            query_documents,
            inputs=query_input,
            outputs=query_output,
        )

# Start the web server only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()