# AIKROOZ's picture
# add Chinese
# ead8913 verified
from transformers import pipeline
import numpy as np
# Translation pipeline used to turn the Chinese input into English text
# before it is classified.
translator = pipeline(
    task="translation",
    model="Helsinki-NLP/opus-mt-zh-en",
)

# Zero-shot classifier that scores English text against arbitrary labels.
classifier = pipeline(
    task="zero-shot-classification",
    model="facebook/bart-large-mnli",
)

# Candidate labels handed to the zero-shot classifier; this order also fixes
# the column order used when scores are averaged.
labels = ["complaint", "praise", "suggestion", "neutral"]
def classify_chinese_batch_with_avg(texts):
    """Translate texts to English, zero-shot classify them, and average the scores.

    Uses the module-level ``translator`` and ``classifier`` pipelines and the
    module-level ``labels`` list of candidate labels.

    Args:
        texts: A single string, or an iterable of strings, to classify.

    Returns:
        A ``(output, avg_scores)`` tuple.

        ``output`` is a list with one dict per input text, holding the keys
        ``"original"``, ``"translated"``, ``"classification"`` (label ->
        score) and ``"predicted"`` (the first label the classifier returned
        for that text).

        ``avg_scores`` maps every entry of ``labels`` to its mean score over
        the batch as a plain ``float``. It is an empty dict when ``texts`` is
        empty, because there is nothing to average.
    """
    # A bare string would otherwise be zipped character by character against
    # a single translation further down, so treat it as a one-item batch.
    texts = [texts] if isinstance(texts, str) else list(texts)
    if not texts:
        # Skip the pipelines entirely: the mean of zero rows is undefined and
        # previously surfaced as a TypeError when building the result dict.
        return [], {}

    translations = translator(texts, max_length=512)
    english_texts = [t["translation_text"] for t in translations]

    results = classifier(english_texts, candidate_labels=labels)
    if isinstance(results, dict):
        # NOTE(review): some transformers releases appear to return a bare
        # dict instead of a one-element list for a single sequence; normalise
        # so the zip below always iterates over per-text results.
        results = [results]

    output = []
    score_rows = []
    for orig, trans, res in zip(texts, english_texts, results):
        scores_dict = dict(zip(res["labels"], res["scores"]))
        # The label order inside each result may differ from ``labels``, so
        # re-order the scores to keep the columns aligned across rows.
        score_rows.append([scores_dict[label] for label in labels])
        output.append({
            "original": orig,
            "translated": trans,
            "classification": scores_dict,
            "predicted": res["labels"][0],
        })

    # Column-wise mean over the batch; ``tolist()`` converts the NumPy scalars
    # to built-in floats so the result prints and serialises cleanly.
    mean_per_label = np.asarray(score_rows, dtype=float).mean(axis=0).tolist()
    avg_scores = dict(zip(labels, mean_per_label))
    return output, avg_scores
# Sample Chinese customer-feedback sentences used to exercise the pipeline.
batch = [
    "客服态度非常差,让我很不满意。",
    "服务非常好,我非常满意!",
    "请考虑增加支付宝或微信支付选项。",
    "产品还行,和描述的一样。",
]

# Classify the whole batch at once and collect the per-label averages.
results, averages = classify_chinese_batch_with_avg(batch)

# Print each per-text result; the extra "\n" argument leaves a blank line
# between consecutive entries.
for record in results:
    print(record, "\n")

print("Average classification scores:", averages)