Spaces:
Sleeping
Sleeping
import gradio as gr | |
import pandas as pd | |
import os | |
import joblib | |
from sklearn.preprocessing import LabelEncoder | |
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer | |
from datasets import Dataset | |
import torch | |
MODEL_NAME = "distilbert-base-multilingual-cased" | |
TOKENIZER = AutoTokenizer.from_pretrained(MODEL_NAME) | |
LABEL_ENCODER_PATH = "label_encoder.pkl" | |
model = None | |
label_encoder = None | |
def train_model(file): | |
global model, label_encoder | |
df = pd.read_csv(file.name) | |
if 'text' not in df.columns or 'label' not in df.columns: | |
return "CSV phải có 2 cột: 'text' và 'label'" | |
# Encode labels | |
label_encoder = LabelEncoder() | |
df['label_encoded'] = label_encoder.fit_transform(df['label']) | |
joblib.dump(label_encoder, LABEL_ENCODER_PATH) | |
# Tokenize data | |
dataset = Dataset.from_pandas(df[['text', 'label_encoded']]) | |
def tokenize(batch): | |
return TOKENIZER(batch["text"], padding=True, truncation=True) | |
tokenized_dataset = dataset.map(tokenize, batched=True) | |
tokenized_dataset = tokenized_dataset.rename_column("label_encoded", "labels") | |
tokenized_dataset.set_format("torch", columns=["input_ids", "attention_mask", "labels"]) | |
# Training args | |
args = TrainingArguments( | |
output_dir="output", | |
evaluation_strategy="no", | |
per_device_train_batch_size=4, | |
num_train_epochs=3, | |
logging_dir="logs", | |
logging_steps=10, | |
save_strategy="no", | |
report_to="none" | |
) | |
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=len(label_encoder.classes_)) | |
trainer = Trainer( | |
model=model, | |
args=args, | |
train_dataset=tokenized_dataset, | |
) | |
trainer.train() | |
model.save_pretrained("finetuned_model") | |
return "✅ Huấn luyện thành công!" | |
def predict(text): | |
global model, label_encoder | |
if model is None or label_encoder is None: | |
try: | |
model = AutoModelForSequenceClassification.from_pretrained("finetuned_model") | |
label_encoder = joblib.load(LABEL_ENCODER_PATH) | |
except: | |
return "⚠️ Mô hình chưa được huấn luyện hoặc không thể load." | |
inputs = TOKENIZER(text, return_tensors="pt", padding=True, truncation=True) | |
with torch.no_grad(): | |
outputs = model(**inputs) | |
predicted = torch.argmax(outputs.logits, dim=1).item() | |
label = label_encoder.inverse_transform([predicted])[0] | |
return f"🔍 Dự đoán: {label}" | |
with gr.Blocks() as demo: | |
gr.Markdown("## 🚀 Fine-tune mô hình văn bản tiếng Việt") | |
with gr.Row(): | |
file_input = gr.File(label="📂 Upload file CSV huấn luyện") | |
train_btn = gr.Button("🔧 Huấn luyện mô hình") | |
train_output = gr.Textbox(label="Trạng thái huấn luyện") | |
train_btn.click(fn=train_model, inputs=file_input, outputs=train_output) | |
gr.Markdown("---") | |
with gr.Row(): | |
input_text = gr.Textbox(label="✍️ Nhập văn bản để phân loại") | |
output_label = gr.Textbox(label="📘 Kết quả dự đoán") | |
generate_btn = gr.Button("🔍 Dự đoán") | |
generate_btn.click(fn=predict, inputs=input_text, outputs=output_label) | |
demo.launch() | |