File size: 1,795 Bytes
0cf3ada
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import functools
import json
import os

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Define the function that will be called when the widget is used


@functools.lru_cache(maxsize=None)
def _load_resources(model_path):
    """Load and cache the tokenizer, model, and category names for *model_path*.

    Loading a transformer from disk is expensive, so this runs once per
    distinct ``model_path`` instead of once per ``infer`` call.

    Returns:
        A ``(tokenizer, model, categories)`` tuple. ``categories`` is whatever
        ``categories.json`` decodes to, or a built-in fallback list when the
        file is missing, unreadable, or not valid JSON.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForSequenceClassification.from_pretrained(model_path)

    categories_path = os.path.join(model_path, "categories.json")
    try:
        with open(categories_path, "r", encoding="utf-8") as f:
            categories = json.load(f)
    except (OSError, ValueError) as e:
        # Deliberate best effort: a missing or malformed category file should
        # not break the widget. json.JSONDecodeError and UnicodeDecodeError
        # are both ValueError subclasses.
        print(f"Error loading categories: {str(e)}")
        categories = ["electronics", "clothing", "home", "kitchen", "toys", "other"]

    return tokenizer, model, categories


def _label_for(index, categories, model):
    """Return the display name for model output column *index*.

    Prefers ``categories`` when it is a list long enough to cover *index*.
    Otherwise falls back to the model config's ``id2label`` mapping and then
    to a generic ``LABEL_<index>`` name. A mismatch between the category file
    and the model head therefore no longer raises IndexError/KeyError.
    """
    if isinstance(categories, (list, tuple)) and 0 <= index < len(categories):
        return categories[index]
    id2label = getattr(model.config, "id2label", None) or {}
    return id2label.get(index, f"LABEL_{index}")


def infer(text, threshold=0.5):
    """Classify *text* and return a human-readable category recommendation.

    Args:
        text: The user's query string.
        threshold: Minimum sigmoid score a category must exceed to be
            reported (multi-label classification). Defaults to 0.5, the
            previously hard-coded value.

    Returns:
        A formatted string listing every category whose score exceeds
        *threshold*, sorted by descending score, followed by a recommendation
        for the highest-scoring category; or a "No categories found" message
        when nothing passes the threshold.
    """
    # The model files are expected to live next to this module.
    model_path = os.path.dirname(os.path.abspath(__file__))
    tokenizer, model, categories = _load_resources(model_path)

    # Prepare the input
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)

    # Get the model prediction
    with torch.no_grad():
        outputs = model(**inputs)
        # Sigmoid gives an independent probability per class (multi-label),
        # unlike softmax, which would force the scores to sum to 1.
        predictions = torch.sigmoid(outputs.logits)

    # A single text is tokenized, so only the first batch row is relevant.
    scores = predictions[0].tolist()
    top_categories = sorted(
        (
            (_label_for(i, categories, model), score)
            for i, score in enumerate(scores)
            if score > threshold
        ),
        key=lambda pair: pair[1],
        reverse=True,
    )

    if not top_categories:
        return f"No categories found for '{text}'. Please try a different query."

    # Format the results (joined once instead of repeated string +=).
    lines = [f"Top categories for '{text}':", ""]
    lines.extend(f"- {category}: {score:.4f}" for category, score in top_categories)
    lines.append("")
    lines.append(
        "Based on your query, I would recommend looking for deals in the "
        f"**{top_categories[0][0]}** category."
    )
    return "\n".join(lines)