""" |
|
Gradio demo for the Shopping Assistant model |
|
""" |
|
import argparse

import gradio as gr
import numpy as np
import requests


def query_model(text, api_token=None, model_id="selvaonline/shopping-assistant"):
    """
    Query the model using the Hugging Face Inference API.
    """
    api_url = f"https://api-inference.huggingface.co/models/{model_id}"

    # Send an Authorization header only when a token is provided.
    headers = {}
    if api_token:
        headers["Authorization"] = f"Bearer {api_token}"

    payload = {
        "inputs": text,
        "options": {
            # Wait for the model to load instead of failing while it warms up.
            "wait_for_model": True
        }
    }

    # A timeout keeps a stuck request from hanging the demo indefinitely.
    response = requests.post(api_url, headers=headers, json=payload, timeout=60)

    if response.status_code == 200:
        return response.json()

    print(f"Error: {response.status_code}")
    print(response.text)
    return None
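
# NOTE: process_results() below assumes the endpoint returns one list of raw
# logits per input, e.g. [[-1.2, 3.4, 0.1, -0.5, -2.0, -1.8]] (numbers purely
# illustrative). If your deployment returns the more common
# [{"label": ..., "score": ...}] structure instead, adapt the parsing there.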
|
|
|
def process_results(results, text):
    """
    Process the results from the Inference API.
    """
    if not results or not isinstance(results, list) or len(results) == 0:
        return f"No results found for '{text}'"

    # Convert raw logits to independent per-category probabilities with a
    # sigmoid (multi-label classification).
    probabilities = 1 / (1 + np.exp(-np.array(results[0])))

    # Category labels, in the same order as the model's output dimensions.
    categories = ["electronics", "clothing", "home", "kitchen", "toys", "other"]

    # Keep every category whose probability exceeds 0.5.
    top_categories = []
    for i, score in enumerate(probabilities):
        if score > 0.5:
            top_categories.append((categories[i], float(score)))

    # Sort by probability, highest first.
    top_categories.sort(key=lambda x: x[1], reverse=True)

    if top_categories:
        result = f"Top categories for '{text}':\n\n"
        for category, score in top_categories:
            result += f"- {category}: {score:.4f}\n"

        result += (
            "\nBased on your query, I would recommend looking for deals in the "
            f"**{top_categories[0][0]}** category."
        )
    else:
        result = f"No categories found for '{text}'. Please try a different query."

    return result
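
# For illustration only (hypothetical scores): a query matched to "electronics"
# with probability 0.93 would yield Markdown along these lines:
#
#   Top categories for 'I'm looking for headphones':
#
#   - electronics: 0.9300
#
#   Based on your query, I would recommend looking for deals in the
#   **electronics** category.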
|
|
|
def classify_query(query, api_token=None, model_id="selvaonline/shopping-assistant"):
    """
    Classify a shopping query using the model.
    """
    results = query_model(query, api_token, model_id)
    return process_results(results, query)
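
# classify_query() can also be called directly, e.g.:
#   print(classify_query("I'm looking for headphones", api_token="<HF_API_TOKEN>"))
# (the token value above is a placeholder, not a real credential).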
|
|
|
def create_gradio_interface(api_token=None, model_id="selvaonline/shopping-assistant"):
    """
    Create a Gradio interface for the Shopping Assistant model.
    """
    demo = gr.Interface(
        fn=lambda query: classify_query(query, api_token, model_id),
        inputs=gr.Textbox(
            lines=2,
            placeholder="Enter your shopping query here...",
            label="Shopping Query"
        ),
        outputs=gr.Markdown(label="Results"),
        title="Shopping Assistant",
        description="""
        This demo shows how to use the Shopping Assistant model to classify shopping queries into categories.
        Enter a shopping query below to see which categories it belongs to.

        Examples:
        - "I'm looking for headphones"
        - "Do you have any kitchen appliance deals?"
        - "Show me the best laptop deals"
        - "I need a new smart TV"
        """,
        examples=[
            ["I'm looking for headphones"],
            ["Do you have any kitchen appliance deals?"],
            ["Show me the best laptop deals"],
            ["I need a new smart TV"]
        ],
        theme=gr.themes.Soft()
    )

    return demo
|
|
|
def main():
    parser = argparse.ArgumentParser(description="Gradio demo for the Shopping Assistant model")
    parser.add_argument("--token", type=str, help="Hugging Face API token")
    parser.add_argument("--model-id", type=str, default="selvaonline/shopping-assistant", help="Hugging Face model ID")
    parser.add_argument("--share", action="store_true", help="Create a public link")
    args = parser.parse_args()

    print(f"Starting Gradio demo for model {args.model_id}")
    demo = create_gradio_interface(args.token, args.model_id)
    demo.launch(share=args.share)


if __name__ == "__main__":
    main()
|
|