selvaonline commited on
Commit
1732d93
·
verified ·
1 Parent(s): 9a3c339

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +84 -0
app.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
+ import json
5
+ import os
6
+
7
+ # Load the model and tokenizer
8
+ model_path = os.path.dirname(os.path.abspath(__file__))
9
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
10
+ model = AutoModelForSequenceClassification.from_pretrained(model_path)
11
+
12
+ # Load the categories
13
+ try:
14
+ with open(os.path.join(model_path, "categories.json"), "r") as f:
15
+ categories = json.load(f)
16
+ except Exception as e:
17
+ print(f"Error loading categories: {str(e)}")
18
+ categories = ["electronics", "clothing", "home", "kitchen", "toys", "other"]
19
+
20
+ def classify_text(text):
21
+ """
22
+ Classify the text using the model
23
+ """
24
+ # Prepare the input
25
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
26
+
27
+ # Get the model prediction
28
+ with torch.no_grad():
29
+ outputs = model(**inputs)
30
+ predictions = torch.sigmoid(outputs.logits)
31
+
32
+ # Get the top categories
33
+ top_categories = []
34
+ for i, score in enumerate(predictions[0]):
35
+ if score > 0.5: # Threshold for multi-label classification
36
+ top_categories.append((categories[i], score.item()))
37
+
38
+ # Sort by score
39
+ top_categories.sort(key=lambda x: x[1], reverse=True)
40
+
41
+ # Format the results
42
+ if top_categories:
43
+ result = f"Top categories for '{text}':\n\n"
44
+ for category, score in top_categories:
45
+ result += f"- {category}: {score:.4f}\n"
46
+
47
+ result += f"\nBased on your query, I would recommend looking for deals in the **{top_categories[0][0]}** category."
48
+ else:
49
+ result = f"No categories found for '{text}'. Please try a different query."
50
+
51
+ return result
52
+
53
+ # Create the Gradio interface
54
+ demo = gr.Interface(
55
+ fn=classify_text,
56
+ inputs=gr.Textbox(
57
+ lines=2,
58
+ placeholder="Enter your shopping query here...",
59
+ label="Shopping Query"
60
+ ),
61
+ outputs=gr.Markdown(label="Results"),
62
+ title="Shopping Assistant",
63
+ description="""
64
+ This demo shows how to use the Shopping Assistant model to classify shopping queries into categories.
65
+ Enter a shopping query below to see which categories it belongs to.
66
+
67
+ Examples:
68
+ - "I'm looking for headphones"
69
+ - "Do you have any kitchen appliance deals?"
70
+ - "Show me the best laptop deals"
71
+ - "I need a new smart TV"
72
+ """,
73
+ examples=[
74
+ ["I'm looking for headphones"],
75
+ ["Do you have any kitchen appliance deals?"],
76
+ ["Show me the best laptop deals"],
77
+ ["I need a new smart TV"]
78
+ ],
79
+ theme=gr.themes.Soft()
80
+ )
81
+
82
+ # Launch the app
83
+ if __name__ == "__main__":
84
+ demo.launch()