"""Gradio demo comparing four NusaBERT + RNN hybrid sentiment classifiers.

The same input text is sent to every model, and each model's predicted class
and softmax confidence are shown side by side in separate output boxes.
"""

import gradio as gr
import torch
from bert_gru_classifier import BERTBiGRUClassifier
from bert_lstm_classifier import BERTBiLSTMClassifier
from transformers import AutoTokenizer

# Class index -> human-readable sentiment label.
CLASS_MAP = {0: "Negative", 1: "Neutral", 2: "Positive"}

# Load the tokenizer (the same tokenizer is shared by all models).
tokenizer = AutoTokenizer.from_pretrained("LazarusNLP/NusaBERT-large")


def _load_model(model_cls, repo_id):
    """Load ``model_cls`` from ``repo_id`` and switch it to evaluation mode."""
    model = model_cls.from_pretrained(repo_id)
    model.eval()  # inference only: disable training-time behaviour such as dropout
    return model


# Load models (same order as before: BiGRU-ace, BiGRU-translate, BiLSTM-ace, BiLSTM-translate).
bigru_model = _load_model(
    BERTBiGRUClassifier, "Amal17/NusaBERT-concate-BiGRU-NusaX-ace"
)
bigru_translate_model = _load_model(
    BERTBiGRUClassifier, "Amal17/NusaBERT-concate-BiGRU-NusaTranslate-senti"
)
bilstm_model = _load_model(
    BERTBiLSTMClassifier, "Amal17/NusaBERT-concate-BiLSTM-NusaX-ace"
)
bilstm_translate_model = _load_model(
    BERTBiLSTMClassifier, "Amal17/NusaBERT-concate-BiLSTM-NusaTranslate-senti"
)


# Inference helper
def predict_with_model(model, text):
    """Classify ``text`` with ``model``.

    Args:
        model: one of the loaded classifiers; its output is indexed with
            ``["logits"]``.
        text: the raw input string.

    Returns:
        ``(pred, confidence)``: the argmax class index (int) and its softmax
        probability (float).
    """
    # Fixed-length padding is kept deliberately; the RNN heads may not mask
    # padding, so changing it could change predictions.
    inputs = tokenizer(
        text,
        padding="max_length",
        truncation=True,
        max_length=128,
        return_tensors="pt",
    )
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs["logits"]
    # NOTE(review): assumes logits are (batch, num_classes); batch is 1 here.
    probs = torch.softmax(logits, dim=1)
    pred = torch.argmax(probs, dim=1).item()
    confidence = probs[0][pred].item()
    return pred, confidence


def _format_prediction(pred, confidence):
    """Render one prediction as the text shown in an output box."""
    return f"Class: {pred} ({CLASS_MAP[pred]}) with confidence: {confidence:.4f}"


# Gradio interface function
def compare_models(text):
    """Run ``text`` through every model and return one formatted string each.

    The tuple order must match the ``outputs`` list of the Gradio interface.
    """
    models = (
        bigru_model,
        bilstm_model,
        bigru_translate_model,
        bilstm_translate_model,
    )
    return tuple(
        _format_prediction(*predict_with_model(model, text)) for model in models
    )


# Build Gradio UI
interface = gr.Interface(
    fn=compare_models,
    inputs=gr.Textbox(label="Input Text"),
    outputs=[
        gr.Textbox(label="NusaBERT-BiGRU-ace"),
        gr.Textbox(label="NusaBERT-BiLSTM-ace"),
        gr.Textbox(label="NusaBERT-BiGRU-translate"),
        gr.Textbox(label="NusaBERT-BiLSTM-translate"),
    ],
    title="Hybrid Model NusaBERT + RNN",
)

if __name__ == "__main__":
    interface.launch()