from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import json
import os


# Define the function that will be called when the widget is used
def infer(text):
    # Load the model and tokenizer from the directory this script lives in
    model_path = os.path.dirname(os.path.abspath(__file__))
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForSequenceClassification.from_pretrained(model_path)

    # Load the categories
    try:
        with open(os.path.join(model_path, "categories.json"), "r") as f:
            categories = json.load(f)
    except Exception as e:
        print(f"Error loading categories: {str(e)}")
        categories = ["electronics", "clothing", "home", "kitchen", "toys", "other"]

    # Prepare the input
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)

    # Get the model prediction
    with torch.no_grad():
        outputs = model(**inputs)
        predictions = torch.sigmoid(outputs.logits)

    # Get the top categories
    top_categories = []
    for i, score in enumerate(predictions[0]):
        if score > 0.5:  # Threshold for multi-label classification
            top_categories.append((categories[i], score.item()))

    # Sort by score
    top_categories.sort(key=lambda x: x[1], reverse=True)

    # Format the results
    if top_categories:
        result = f"Top categories for '{text}':\n\n"
        for category, score in top_categories:
            result += f"- {category}: {score:.4f}\n"
        result += "\nBased on your query, I would recommend looking for deals in the "
        result += f"**{top_categories[0][0]}** category."
    else:
        result = f"No categories found for '{text}'. Please try a different query."

    return result
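

# Minimal usage sketch, assuming this script sits alongside the saved model
# weights and categories.json so infer() can load them from its own directory.
# The sample query below is a hypothetical placeholder, not from the original.
if __name__ == "__main__":
    sample_query = "wireless noise-cancelling headphones"
    print(infer(sample_query))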