ViT5BaseNode / app.py
VietCat's picture
reduce processing time
fb4a646
raw
history blame
2.26 kB
import time
import logging
from fastapi import FastAPI, Request
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
# Route INFO-and-above records through the root logger's default handler.
logging.basicConfig(level=logging.INFO)
# Module-scoped logger used by the request handlers in this file.
logger = logging.getLogger(__name__)

# ASGI application instance that the route decorators register against.
app = FastAPI()
# Load the tokenizer and seq2seq model once, at import time.
_MODEL_ID = "VietAI/vit5-large-vietnews-summarization"
tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID)
model = AutoModelForSeq2SeqLM.from_pretrained(_MODEL_ID)

# Use the GPU when one is visible, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model.to(device)

# Run one short generation up front so the first real request does not
# pay the first-call latency.
dummy_input = tokenizer(
    "Tin nhanh: Đây là văn bản mẫu để warmup mô hình.",
    return_tensors="pt",
).to(device)
with torch.no_grad():
    _ = model.generate(**dummy_input, max_length=32)
class SummarizeRequest(BaseModel):
    """JSON payload accepted by the ``/summarize`` endpoint."""

    # Raw text to be summarized.
    text: str
@app.get("/")
async def root():
    """Return a static readiness message."""
    payload = {"message": "Model is ready."}
    return payload
@app.post("/summarize")
async def summarize(req: Request, body: SummarizeRequest):
    """Summarize the text carried in the request body.

    The text is stripped, tagged with a prefix ("Vietnews: " when it opens
    with one of a few news-style phrases, "Tin nhanh: " otherwise), suffixed
    with " </s>", tokenized, and decoded greedily by the module-level model.

    Args:
        req: Incoming request; used only to log the client address.
        body: Parsed payload whose ``text`` field is summarized.

    Returns:
        A dict of the form ``{"summary": <str>}``. Input that is empty after
        stripping yields an empty summary without invoking the model.
    """
    # perf_counter is monotonic, so the elapsed time cannot go negative
    # if the wall clock is adjusted mid-request.
    started = time.perf_counter()

    # ``req.client`` can be None when the peer address is not available.
    client_ip = req.client.host if req.client else "unknown"
    logger.info(
        "[%s] 🔵 Received request from %s",
        time.strftime('%Y-%m-%d %H:%M:%S'),
        client_ip,
    )

    text = body.text.strip()
    if not text:
        # Nothing to summarize; skip generation on a bare prefix.
        summary = ""
    else:
        # Preprocessing: if the text does not look like a news article,
        # prepend "Tin nhanh:"; otherwise prepend "Vietnews:".
        news_openers = ("theo", "trong khi", "bộ", "ngày", "việt nam", "công an")
        if text.lower().startswith(news_openers):
            text = "Vietnews: " + text
        else:
            text = "Tin nhanh: " + text
        input_text = text + " </s>"

        # NOTE(review): the input is not truncated, so a very long request
        # body is tokenized and fed to the model in full — consider a cap.
        encoding = tokenizer(input_text, return_tensors="pt")
        input_ids = encoding["input_ids"].to(device)
        attention_mask = encoding["attention_mask"].to(device)

        # Generate the summary with a stable configuration
        # (early_stopping removed, greedy decoding used).
        # NOTE(review): this call is synchronous inside a coroutine, so the
        # event loop is blocked until generation finishes.
        with torch.no_grad():
            outputs = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=256,
                num_beams=1,  # greedy decoding
                no_repeat_ngram_size=2,
            )
        summary = tokenizer.decode(
            outputs[0],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )

    elapsed = time.perf_counter() - started
    logger.info(
        "[%s] ✅ Response sent — total time: %.2fs",
        time.strftime('%Y-%m-%d %H:%M:%S'),
        elapsed,
    )
    return {"summary": summary}