adhsdksdjsbdk committed on
Commit
15c7196
·
verified ·
1 Parent(s): 6ad7bf6

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -0
app.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ import streamlit as st
5
+ from pydantic import BaseModel
6
+ from fastapi import FastAPI, Request
7
+ from transformers import AutoTokenizer, AutoModel
8
+ from peft import PeftModel
9
+
10
+ # Get the token from environment variable (optional)
11
# Optional Hugging Face access token; this is None when HF_TOKEN is unset.
hf_token = os.getenv("HF_TOKEN")

# Hub repository IDs: the base encoder and the adapter loaded on top of it.
base_model_id = "Alibaba-NLP/gte-multilingual-base"
adapter_model_id = "seniormgt/arabicmgt-test"
16
+
17
+ # Define your model
18
class GTEClassifier(nn.Module):
    """Single-logit classification head on top of a pretrained encoder.

    The hidden state at sequence position 0 is fed through a linear layer,
    a tanh activation, dropout (probability 0.0), and a final linear layer
    that produces one logit per input sequence.
    """

    def __init__(self, model_name=base_model_id):
        """Load the encoder ``model_name`` and create the head layers."""
        super().__init__()
        encoder = AutoModel.from_pretrained(model_name, trust_remote_code=True)
        self.base_model = encoder
        self.config = encoder.config
        width = self.config.hidden_size
        self.pooler = nn.Linear(width, width)
        self.pooler_activation = nn.Tanh()
        self.dropout = nn.Dropout(0.0)
        self.classifier = nn.Linear(width, 1)
        self.loss_fn = nn.BCEWithLogitsLoss()

    def forward(self, input_ids=None, attention_mask=None, inputs_embeds=None, labels=None, **kwargs):
        """Run the encoder and head.

        Returns a dict with ``"logits"`` (last dimension squeezed away) and
        ``"loss"``, which is the BCE-with-logits loss against ``labels`` cast
        to float, or ``None`` when no labels are given.

        ``inputs_embeds`` and any extra keyword arguments are accepted but
        not used by this method.
        """
        encoded = self.base_model(input_ids=input_ids, attention_mask=attention_mask)
        # Take the hidden state at position 0 of every sequence.
        first_token = encoded.last_hidden_state[:, 0, :]
        pooled = self.pooler_activation(self.pooler(first_token))
        logits = self.classifier(self.dropout(pooled)).squeeze(-1)

        if labels is None:
            loss = None
        else:
            loss = self.loss_fn(logits, labels.float())
        return {"loss": loss, "logits": logits}
37
+
38
+ # Load tokenizer and model
39
@st.cache_resource
def _load_model_bundle():
    """Load the tokenizer, the bare classifier, and the adapter-wrapped model.

    Streamlit re-executes this script on every widget interaction; caching
    the loaded objects as a shared resource avoids reloading the encoder and
    adapter weights on each rerun.

    Returns:
        A ``(tokenizer, base_model, peft_model)`` tuple. ``peft_model`` has
        already been switched to eval mode.
    """
    tok = AutoTokenizer.from_pretrained(adapter_model_id, token=hf_token, trust_remote_code=True)
    classifier = GTEClassifier()
    adapted = PeftModel.from_pretrained(classifier, adapter_model_id, token=hf_token)
    adapted.eval()
    return tok, classifier, adapted


# Same module-level names as before, now backed by the cached loader.
tokenizer, base_model, peft_model = _load_model_bundle()
43
+
44
+ # Define prediction
45
def classify_text(text):
    """Score ``text`` with the adapter-wrapped classifier.

    The input is tokenized (truncated to 512 tokens), run through
    ``peft_model`` without gradient tracking, and the logit is turned into a
    probability with a sigmoid.

    Returns:
        A dict with ``"label"`` (``"1"`` when the probability is >= 0.5,
        otherwise ``"0"``) and ``"confidence"`` (the sigmoid probability as a
        Python float).
    """
    encoded = tokenizer(
        text,
        truncation=True,
        max_length=512,
        padding=True,
        return_attention_mask=True,
        return_tensors="pt",
    )

    with torch.no_grad():
        model_out = peft_model(
            input_ids=encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
        )
        logits = model_out["logits"]

    probability = torch.sigmoid(logits).cpu().numpy().squeeze()
    is_positive = probability >= 0.5
    return {"label": str(int(is_positive)), "confidence": float(probability)}
57
+
58
+ # 🔹 Streamlit UI
59
# Streamlit page: a title, a text box, and a button that triggers scoring.
st.title("Text Classification (MGT Detection)")
text = st.text_area("Enter text", height=150)

# Classify only when the button was pressed and the box is not blank.
if st.button("Classify"):
    if text.strip():
        result = classify_text(text)
        st.json(result)
65
+
66
+ # 🔹 FastAPI endpoint
67
# FastAPI application object that the routes in this file are registered on.
app = FastAPI()


# Request-body schema with a single ``data`` list field.
# NOTE(review): the /predict handler in this file reads the raw request JSON
# instead of this model -- confirm whether it is still needed.
class Input(BaseModel):
    data: list
71
+
72
@app.post("/predict")
async def predict(request: Request):
    """Classify the first item of a JSON request body.

    Expects a body shaped like ``{"data": [{"text": "<string>"}, ...]}``.
    Only ``data[0]["text"]`` is scored; any further items are ignored.

    Returns:
        ``{"data": [result]}`` where ``result`` is the dict returned by
        ``classify_text``.

    Raises:
        HTTPException: status 400 when the body is not valid JSON, or when
            ``data[0]["text"]`` is missing or is not a string.
    """
    # Local import keeps this handler self-contained; fastapi is already a
    # dependency of this file.
    from fastapi import HTTPException

    try:
        payload = await request.json()
    except ValueError as exc:  # json.JSONDecodeError is a ValueError subclass
        raise HTTPException(
            status_code=400,
            detail="Request body must be valid JSON.",
        ) from exc

    try:
        text = payload["data"][0]["text"]
    except (KeyError, IndexError, TypeError) as exc:
        # A malformed payload is a client error, not an internal server error.
        raise HTTPException(
            status_code=400,
            detail='Expected a body shaped like {"data": [{"text": "..."}]}.',
        ) from exc

    if not isinstance(text, str):
        raise HTTPException(
            status_code=400,
            detail='"text" must be a string.',
        )

    result = classify_text(text)
    return {"data": [result]}