vnanhtuan's picture
Update app.py
add87ad verified
import gradio as gr
import pandas as pd
import os
import joblib
from sklearn.preprocessing import LabelEncoder
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
from datasets import Dataset
import torch
MODEL_NAME = "distilbert-base-multilingual-cased"
TOKENIZER = AutoTokenizer.from_pretrained(MODEL_NAME)
LABEL_ENCODER_PATH = "label_encoder.pkl"
model = None
label_encoder = None
def train_model(file):
global model, label_encoder
df = pd.read_csv(file.name)
if 'text' not in df.columns or 'label' not in df.columns:
return "CSV phải có 2 cột: 'text' và 'label'"
# Encode labels
label_encoder = LabelEncoder()
df['label_encoded'] = label_encoder.fit_transform(df['label'])
joblib.dump(label_encoder, LABEL_ENCODER_PATH)
# Tokenize data
dataset = Dataset.from_pandas(df[['text', 'label_encoded']])
def tokenize(batch):
return TOKENIZER(batch["text"], padding=True, truncation=True)
tokenized_dataset = dataset.map(tokenize, batched=True)
tokenized_dataset = tokenized_dataset.rename_column("label_encoded", "labels")
tokenized_dataset.set_format("torch", columns=["input_ids", "attention_mask", "labels"])
# Training args
args = TrainingArguments(
output_dir="output",
evaluation_strategy="no",
per_device_train_batch_size=4,
num_train_epochs=3,
logging_dir="logs",
logging_steps=10,
save_strategy="no",
report_to="none"
)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=len(label_encoder.classes_))
trainer = Trainer(
model=model,
args=args,
train_dataset=tokenized_dataset,
)
trainer.train()
model.save_pretrained("finetuned_model")
return "✅ Huấn luyện thành công!"
def predict(text):
global model, label_encoder
if model is None or label_encoder is None:
try:
model = AutoModelForSequenceClassification.from_pretrained("finetuned_model")
label_encoder = joblib.load(LABEL_ENCODER_PATH)
except:
return "⚠️ Mô hình chưa được huấn luyện hoặc không thể load."
inputs = TOKENIZER(text, return_tensors="pt", padding=True, truncation=True)
with torch.no_grad():
outputs = model(**inputs)
predicted = torch.argmax(outputs.logits, dim=1).item()
label = label_encoder.inverse_transform([predicted])[0]
return f"🔍 Dự đoán: {label}"
with gr.Blocks() as demo:
gr.Markdown("## 🚀 Fine-tune mô hình văn bản tiếng Việt")
with gr.Row():
file_input = gr.File(label="📂 Upload file CSV huấn luyện")
train_btn = gr.Button("🔧 Huấn luyện mô hình")
train_output = gr.Textbox(label="Trạng thái huấn luyện")
train_btn.click(fn=train_model, inputs=file_input, outputs=train_output)
gr.Markdown("---")
with gr.Row():
input_text = gr.Textbox(label="✍️ Nhập văn bản để phân loại")
output_label = gr.Textbox(label="📘 Kết quả dự đoán")
generate_btn = gr.Button("🔍 Dự đoán")
generate_btn.click(fn=predict, inputs=input_text, outputs=output_label)
demo.launch()