import logging
from typing import Dict, Tuple, Union

import gradio as gr
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline

from .base import BaseModel

logger = logging.getLogger(__name__)

class TextClassificationModel(BaseModel):
    """Lightweight text classification model using tiny BERT."""

    def __init__(self):
        super().__init__(
            name="Lightweight Text Classifier",
            description="Fast text classification using a tiny BERT model (~4.4M parameters)"
        )
        self.model_name = "prajjwal1/bert-tiny"
        self._model = None

    def load_model(self) -> None:
        """Load the classification model."""
        try:
            logger.info(f"Loading model: {self.model_name}")
            # Initialize with a binary classification head. bert-tiny is a base
            # checkpoint, so this head starts out randomly initialized and
            # should be fine-tuned before its labels are meaningful.
            model = AutoModelForSequenceClassification.from_pretrained(
                self.model_name,
                num_labels=2
            )
            tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            self._model = pipeline(
                "text-classification",
                model=model,
                tokenizer=tokenizer,
                device=-1  # CPU; use device=0 for GPU
            )
            # Log the in-memory parameter size in megabytes
            model_size_mb = sum(p.numel() * p.element_size() for p in model.parameters()) / (1024 * 1024)
            logger.info(f"Model loaded successfully. Size: {model_size_mb:.2f} MB")
        except Exception as e:
            logger.error(f"Error loading model: {str(e)}")
            raise

    async def predict(self, text: str) -> Dict[str, Union[str, float]]:
        """Make a prediction using the model."""
        try:
            if self._model is None:
                self.load_model()
            logger.info(f"Processing text: {text[:50]}...")
            result = self._model(text)[0]
            # Map the pipeline's default raw labels to sentiment names
            label_map = {
                "LABEL_0": "NEGATIVE",
                "LABEL_1": "POSITIVE"
            }
            prediction = {
                "label": label_map.get(result["label"], result["label"]),
                "confidence": float(result["score"])
            }
            logger.info(f"Prediction result: {prediction}")
            return prediction
        except Exception as e:
            logger.error(f"Prediction error: {str(e)}")
            raise

    async def predict_for_interface(self, text: str) -> Tuple[str, float]:
        """Make a prediction and return it in a format suitable for the Gradio interface."""
        result = await self.predict(text)
        return result["label"], result["confidence"]

    def create_interface(self) -> gr.Interface:
        """Create a Gradio interface for text classification."""
        if self._model is None:
            self.load_model()
        examples = [
            ["This movie was fantastic! I really enjoyed it."],
            ["The service was terrible and the food was cold."],
            ["It was an okay experience, nothing special."],
            ["The weather is nice today!"],
            ["I'm feeling sick and tired."]
        ]
        return gr.Interface(
            fn=self.predict_for_interface,  # interface-specific prediction function
            inputs=gr.Textbox(
                lines=3,
                placeholder="Enter text to classify...",
                label="Input Text"
            ),
            outputs=[
                gr.Label(label="Sentiment"),
                gr.Number(label="Confidence", precision=4)
            ],
            title=self.name,
            description=self.description + "\n\nThis model is also available via API!",
            examples=examples,
            api_name="predict"
        )
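

# Minimal usage sketch (an assumption, not part of the original module): the
# relative import of BaseModel means this file must live inside a package, so
# run it as a module, e.g. `python -m <package>.<this_module>`. launch() is the
# standard Gradio call that serves the interface locally.
if __name__ == "__main__":
    demo_model = TextClassificationModel()
    interface = demo_model.create_interface()
    interface.launch()  # add share=True for a temporary public URL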