import os

import torch
import torch.nn as nn
import streamlit as st
from pydantic import BaseModel
from fastapi import FastAPI, Request
from transformers import AutoTokenizer, AutoModel
from peft import PeftModel

# Hugging Face auth token (optional; only required for private repos).
hf_token = os.environ.get("HF_TOKEN")

# Model identifiers: a LoRA/PEFT adapter stacked on a multilingual GTE encoder.
adapter_model_id = "seniormgt/arabicmgt-test"
base_model_id = "Alibaba-NLP/gte-multilingual-base"


class GTEClassifier(nn.Module):
    """Binary classification head on top of a GTE encoder.

    Pools the first ([CLS]) token of the encoder output, applies a
    dense + tanh pooler, and projects to a single logit. Training loss
    is BCE-with-logits; at inference only the logit is used.
    """

    def __init__(self, model_name=base_model_id):
        super().__init__()
        self.base_model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
        self.config = self.base_model.config
        self.pooler = nn.Linear(self.config.hidden_size, self.config.hidden_size)
        self.pooler_activation = nn.Tanh()
        # p=0.0 keeps the layer a no-op but preserves checkpoint structure.
        self.dropout = nn.Dropout(0.0)
        self.classifier = nn.Linear(self.config.hidden_size, 1)
        self.loss_fn = nn.BCEWithLogitsLoss()

    def forward(self, input_ids=None, attention_mask=None, inputs_embeds=None,
                labels=None, **kwargs):
        """Run the encoder + head.

        Returns a dict {"loss", "logits"}; "loss" is None when no labels
        are supplied. Accepts either token ids or precomputed embeddings.
        """
        if inputs_embeds is not None:
            outputs = self.base_model(inputs_embeds=inputs_embeds,
                                      attention_mask=attention_mask)
        else:
            outputs = self.base_model(input_ids=input_ids,
                                      attention_mask=attention_mask)

        # [CLS] pooling: first token of the last hidden state.
        pooled_output = outputs.last_hidden_state[:, 0, :]
        pooled_output = self.pooler_activation(self.pooler(pooled_output))
        logits = self.classifier(self.dropout(pooled_output)).squeeze(-1)

        loss = None
        if labels is not None:
            loss = self.loss_fn(logits, labels.float())
        return {"loss": loss, "logits": logits}


# Load tokenizer and adapter-augmented model once at import time.
tokenizer = AutoTokenizer.from_pretrained(adapter_model_id, token=hf_token,
                                          trust_remote_code=True)
base_model = GTEClassifier()
peft_model = PeftModel.from_pretrained(base_model, adapter_model_id, token=hf_token)
# FIX: eval() was commented out, leaving the model in training mode during
# inference (dropout/norm layers active). Switch to eval for deterministic
# predictions.
peft_model.eval()


def classify_text(text):
    """Classify *text* as machine-generated (label "1") or not ("0").

    Returns {"label": "0"|"1", "confidence": sigmoid probability as float}.
    Input is truncated to 512 tokens.
    """
    inputs = tokenizer(text, max_length=512, padding=True,
                       return_attention_mask=True, return_tensors="pt",
                       truncation=True)
    with torch.no_grad():
        outputs = peft_model(input_ids=inputs["input_ids"],
                             attention_mask=inputs["attention_mask"])
    # Single-example batch -> squeeze to a scalar probability.
    probs = torch.sigmoid(outputs["logits"]).cpu().numpy().squeeze()
    pred_label = int(probs >= 0.5)
    return {"label": str(pred_label), "confidence": float(probs)}


# 🔹 Streamlit UI
st.title("Text Classification (MGT Detection)")
text = st.text_area("Enter text", height=150)
if st.button("Classify") and text.strip():
    result = classify_text(text)
    st.json(result)

# 🔹 FastAPI endpoint
app = FastAPI()


class Input(BaseModel):
    # NOTE(review): declared but not wired into the endpoint below, which
    # parses the raw request instead — consider `predict(inp: Input)` to get
    # automatic validation (would change error responses from 500 to 422).
    data: list


@app.post("/predict")
async def predict(request: Request):
    """Accept {"data": [{"text": ...}]} and return {"data": [result]}."""
    payload = await request.json()
    text = payload["data"][0]["text"]
    result = classify_text(text)
    return {"data": [result]}