import gradio as gr from transformers import AutoTokenizer, AutoModelForSequenceClassification import torch # Load models and tokenizers tokenizer_sentiment = AutoTokenizer.from_pretrained("tmt3103/VSFC-sentiment-classify-phoBERT") model_sentiment = AutoModelForSequenceClassification.from_pretrained("tmt3103/VSFC-sentiment-classify-phoBERT") tokenizer_topic = AutoTokenizer.from_pretrained("tmt3103/VSFC-topic-classify-phoBERT") model_topic = AutoModelForSequenceClassification.from_pretrained("tmt3103/VSFC-topic-classify-phoBERT") # Define prediction functions def predict_sentiment(text): model_sentiment.eval() with torch.no_grad(): encoding = tokenizer_sentiment(text, return_tensors='pt', truncation=True, padding=True, max_length=128) output = model_sentiment(**encoding) pred = torch.argmax(output.logits, dim=1).item() sentiment_labels = ["Tiêu cực", "Trung lập", "Tích cực"] # Cập nhật nhãn theo yêu cầu return sentiment_labels[pred] # Trả về nhãn cảm xúc def predict_topic(text): model_topic.eval() with torch.no_grad(): encoding = tokenizer_topic(text, return_tensors='pt', truncation=True, padding=True, max_length=128) output = model_topic(**encoding) pred = torch.argmax(output.logits, dim=1).item() topic_labels = ["Giảng viên", "Chương trình đào tạo", "Cơ sở vật chất", "Khác"] # Cập nhật nhãn theo yêu cầu return topic_labels[pred] # Trả về nhãn chủ đề # Define a function to handle both predictions at once def classify_sentiment_and_topic(text): sentiment = predict_sentiment(text) topic = predict_topic(text) return sentiment, topic # Create Gradio Interface iface = gr.Interface( fn=classify_sentiment_and_topic, inputs=gr.Textbox(label="Nhập văn bản để phân loại", placeholder="Gõ văn bản của bạn ở đây..."), outputs=[gr.Textbox(label="Kết quả phân loại cảm xúc"), gr.Textbox(label="Kết quả phân loại chủ đề")], title="Phân loại cảm xúc và chủ đề", description="Phân loại cảm xúc (Tiêu cực, Trung lập, Tích cực) và chủ đề (Giảng viên, Chương trình đào tạo, Cơ sở vật chất, Khác) của văn bản.", ) # Launch the app iface.launch()