import json
import os

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification


def infer(text):
    """Classify a product query into one or more categories and format a recommendation."""
    # Load the tokenizer and model from the directory containing this script.
    model_path = os.path.dirname(os.path.abspath(__file__))
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForSequenceClassification.from_pretrained(model_path)

    # Load the category labels; fall back to a default list if the file is missing or invalid.
    try:
        with open(os.path.join(model_path, "categories.json"), "r") as f:
            categories = json.load(f)
    except Exception as e:
        print(f"Error loading categories: {e}")
        categories = ["electronics", "clothing", "home", "kitchen", "toys", "other"]

    # Tokenize the input query.
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)

    # Run inference; sigmoid gives an independent per-label probability (multi-label setup).
    with torch.no_grad():
        outputs = model(**inputs)
        predictions = torch.sigmoid(outputs.logits)

    # Keep every category whose score clears the 0.5 threshold.
    top_categories = []
    for i, score in enumerate(predictions[0]):
        if score > 0.5:
            top_categories.append((categories[i], score.item()))

    # Sort by score, highest first.
    top_categories.sort(key=lambda x: x[1], reverse=True)

    # Format the response, recommending the highest-scoring category.
    if top_categories:
        result = f"Top categories for '{text}':\n\n"
        for category, score in top_categories:
            result += f"- {category}: {score:.4f}\n"

        result += "\nBased on your query, I would recommend looking for deals in the "
        result += f"**{top_categories[0][0]}** category."
    else:
        result = f"No categories found for '{text}'. Please try a different query."

    return result
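

# Minimal usage sketch: assumes the fine-tuned model weights, tokenizer files, and
# categories.json sit alongside this script; the query string is only an illustrative example.
if __name__ == "__main__":
    print(infer("wireless noise-cancelling headphones under $100"))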