selvaonline committed on
Commit
995060b
·
verified ·
1 Parent(s): 0e09fb7

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +61 -17
app.py CHANGED
@@ -83,18 +83,38 @@ def process_deals_data(deals_data):
83
 
84
  return processed_deals
85
 
86
- # Load the model and tokenizer
87
- model_path = os.path.dirname(os.path.abspath(__file__))
88
- tokenizer = AutoTokenizer.from_pretrained(model_path)
89
- model = AutoModelForSequenceClassification.from_pretrained(model_path)
90
-
91
- # Load the categories
92
  try:
93
- with open(os.path.join(model_path, "categories.json"), "r") as f:
94
- categories = json.load(f)
 
 
 
 
 
 
 
 
 
 
 
 
95
  except Exception as e:
96
- print(f"Error loading categories: {str(e)}")
97
- categories = ["electronics", "clothing", "home", "kitchen", "toys", "other"]
 
 
 
 
 
 
 
 
 
 
 
 
 
98
 
99
  # Global variable to store deals data
100
  deals_cache = None
@@ -111,13 +131,37 @@ def classify_text(text, fetch_deals=True):
111
  # Get the model prediction
112
  with torch.no_grad():
113
  outputs = model(**inputs)
114
- predictions = torch.sigmoid(outputs.logits)
115
-
116
- # Get the top categories
117
- top_categories = []
118
- for i, score in enumerate(predictions[0]):
119
- if score > 0.5: # Threshold for multi-label classification
120
- top_categories.append((categories[i], score.item()))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
 
122
  # Sort by score
123
  top_categories.sort(key=lambda x: x[1], reverse=True)
 
83
 
84
  return processed_deals
85
 
86
+ # Load the e-commerce specific model and tokenizer
 
 
 
 
 
87
  try:
88
+ # Try to load the e-commerce BERT model
89
+ tokenizer = AutoTokenizer.from_pretrained("prithivida/ecommerce-bert-base-uncased")
90
+ model = AutoModelForSequenceClassification.from_pretrained("prithivida/ecommerce-bert-base-uncased")
91
+
92
+ # E-commerce BERT categories
93
+ categories = [
94
+ "electronics", "computers", "mobile_phones", "accessories",
95
+ "clothing", "footwear", "watches", "jewelry",
96
+ "home", "kitchen", "furniture", "decor",
97
+ "beauty", "personal_care", "health", "wellness",
98
+ "toys", "games", "sports", "outdoors",
99
+ "books", "stationery", "music", "movies"
100
+ ]
101
+ print("Using e-commerce BERT model")
102
  except Exception as e:
103
+ # Fall back to local model if e-commerce BERT fails to load
104
+ print(f"Error loading e-commerce BERT model: {str(e)}")
105
+ print("Falling back to local model")
106
+
107
+ model_path = os.path.dirname(os.path.abspath(__file__))
108
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
109
+ model = AutoModelForSequenceClassification.from_pretrained(model_path)
110
+
111
+ # Load the local categories
112
+ try:
113
+ with open(os.path.join(model_path, "categories.json"), "r") as f:
114
+ categories = json.load(f)
115
+ except Exception as e:
116
+ print(f"Error loading categories: {str(e)}")
117
+ categories = ["electronics", "clothing", "home", "kitchen", "toys", "other"]
118
 
119
  # Global variable to store deals data
120
  deals_cache = None
 
131
  # Get the model prediction
132
  with torch.no_grad():
133
  outputs = model(**inputs)
134
+
135
+ # Handle different model output formats
136
+ if hasattr(outputs, 'logits'):
137
+ # For models that return logits
138
+ if outputs.logits.shape[1] == len(categories):
139
+ # Multi-label classification
140
+ predictions = torch.sigmoid(outputs.logits)
141
+
142
+ # Get the top categories
143
+ top_categories = []
144
+ for i, score in enumerate(predictions[0]):
145
+ if score > 0.3: # Lower threshold for e-commerce model
146
+ top_categories.append((categories[i], score.item()))
147
+ else:
148
+ # Single-label classification
149
+ probabilities = torch.softmax(outputs.logits, dim=1)
150
+ values, indices = torch.topk(probabilities, 3)
151
+
152
+ top_categories = []
153
+ for i, idx in enumerate(indices[0]):
154
+ if idx < len(categories):
155
+ top_categories.append((categories[idx.item()], values[0][i].item()))
156
+ else:
157
+ # Fallback for other model formats
158
+ predictions = torch.sigmoid(outputs[0])
159
+
160
+ # Get the top categories
161
+ top_categories = []
162
+ for i, score in enumerate(predictions[0]):
163
+ if score > 0.5:
164
+ top_categories.append((categories[i], score.item()))
165
 
166
  # Sort by score
167
  top_categories.sort(key=lambda x: x[1], reverse=True)