from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import json
import os


# Define the function that will be called when the widget is used
def infer(text):
    # Load the model and tokenizer
    model_path = os.path.dirname(os.path.abspath(__file__))
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForSequenceClassification.from_pretrained(model_path)

    # Load the categories
    try:
        with open(os.path.join(model_path, "categories.json"), "r") as f:
            categories = json.load(f)
    except Exception as e:
        print(f"Error loading categories: {str(e)}")
        categories = ["electronics", "clothing", "home", "kitchen", "toys", "other"]

    # Prepare the input
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)

    # Get the model prediction
    with torch.no_grad():
        outputs = model(**inputs)
        predictions = torch.sigmoid(outputs.logits)

    # Get the top categories
    top_categories = []
    for i, score in enumerate(predictions[0]):
        if score > 0.5:  # Threshold for multi-label classification
            top_categories.append((categories[i], score.item()))

    # Sort by score
    top_categories.sort(key=lambda x: x[1], reverse=True)

    # Format the results
    if top_categories:
        result = f"Top categories for '{text}':\n\n"
        for category, score in top_categories:
            result += f"- {category}: {score:.4f}\n"
        result += "\nBased on your query, I would recommend looking for deals in the "
        result += f"**{top_categories[0][0]}** category."
    else:
        result = f"No categories found for '{text}'. Please try a different query."

    return result
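

# Minimal local usage sketch, assuming the fine-tuned weights, tokenizer files,
# and an optional categories.json sit alongside this script (as the code above
# expects). The query string below is a hypothetical example, not from the model.
if __name__ == "__main__":
    sample_query = "deals on wireless headphones"  # hypothetical test input
    print(infer(sample_query))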