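"""FastAPI service for Philippine regional language speech translation.

Uses Silero VAD to trim silence, a fine-tuned Whisper model for
transcription, the Gemini API for text translation, and MMS-TTS for
Tagalog speech synthesis.
"""
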
# Standard library imports
import os
import time
import tempfile
import logging
from timeit import default_timer as timer
from contextlib import asynccontextmanager

# Third-party imports
import torch
import torchaudio
from fastapi import FastAPI, UploadFile, File, HTTPException, Form
from fastapi.responses import FileResponse
from starlette.background import BackgroundTask
from transformers import (
    pipeline,
    VitsModel,
    VitsTokenizer,
    AutoModelForSpeechSeq2Seq,
    AutoProcessor,
)

# External service imports
from google import genai
from google.genai import types
from silero_vad import (
    load_silero_vad,
    read_audio,
    get_speech_timestamps,
    save_audio,
    collect_chunks,
)

# Logging configuration
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

@asynccontextmanager
async def lifespan(app: FastAPI):
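    """Load shared models and clients once for the app's lifetime."""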
    # Load models once at startup and store in app.state
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

    model_id = "rbcurzon/whisper-large-v3-turbo"
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
    )
    model.to(device)
    processor = AutoProcessor.from_pretrained(model_id)

    app.state.pipe = pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        torch_dtype=torch_dtype,
        device=device,
    )
    app.state.vad_model = load_silero_vad()
    app.state.client = genai.Client(api_key=os.environ.get("GENAI_API_KEY"))
    yield
    # Optionally, add cleanup code here


# FastAPI app initialization
app = FastAPI(
    title="Real-Time Audio Processor",
    description="Process and transcribe audio in real-time using Whisper",
    lifespan=lifespan
)


def remove_silence(filename):
    """Remove silence from an audio file using Silero VAD."""
    sampling_rate = 16000  # Silero VAD supports 8 kHz and 16 kHz audio
    try:
        wav = read_audio(filename, sampling_rate=sampling_rate)
        speech_timestamps = get_speech_timestamps(wav, app.state.vad_model, sampling_rate=sampling_rate)
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav").name
        # Keep only the detected speech chunks and write them to a temp file
        save_audio(
            temp_file,
            collect_chunks(speech_timestamps, wav),
            sampling_rate=sampling_rate
        )
        return temp_file
    except Exception as error:
        logging.error(f"Error removing silence from {filename}: {error}")
        raise HTTPException(status_code=500, detail=str(error))


def translate(text, srcLang, tgtLang):
    """Translate text from srcLang to tgtLang using the Gemini API."""
    prompt = f"Translate the following text: '{text}'"
    response = app.state.client.models.generate_content(
        model="gemini-2.5-flash-lite",
        contents=prompt,
        config=types.GenerateContentConfig(
            system_instruction=(
                f"You are an expert translator. Your task is to translate from {srcLang} to {tgtLang}. "
                "You must provide ONLY the translated text. Do not include any explanations, "
                "additional commentary, or conversational language. Just the translated text."
            ),
            thinking_config=types.ThinkingConfig(thinking_budget=0),  # disables thinking
            temperature=0.6,
        )
    )
    return response.text


def remove_file(file):
    """Remove a file after a delay (for background cleanup).

    Runs as a Starlette BackgroundTask after the response has been sent.
    """
    time.sleep(600)  # delay for 10 minutes
    os.remove(file)


# API Endpoints
@app.get("/")
def read_root():
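    """Return a short service description."""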
    return {
        "detail": "Philippine Regional Language Translator"
    }
@app.post("/translateAudio/")
async def translate_audio(
file: UploadFile = File(...),
srcLang: str = Form("Tagalog"),
tgtLang: str = Form("Cebuano")
):
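    """Transcribe an uploaded audio file with Whisper, then translate the transcript."""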
    start = timer()
    upload_path = None  # local copy of the uploaded audio
    temp_file = None  # silence-stripped intermediate file
    try:
        content = await file.read()
        # Strip any client-supplied directory components before writing locally
        upload_path = os.path.basename(file.filename)
        with open(upload_path, 'wb') as f:
            f.write(content)
        logging.info(f"Successfully uploaded {file.filename}")

        generate_kwargs = {
            "max_new_tokens": 448 - 4,  # Whisper's 448-token decoder limit, minus the forced prompt tokens
            "num_beams": 1,
            "condition_on_prev_tokens": False,
            "compression_ratio_threshold": 1.35,  # flag highly repetitive output
            "temperature": 0.0,  # deterministic decoding
            "logprob_threshold": -1.0,
            "no_speech_threshold": 0.6,
            "return_timestamps": True,
            "language": "tl"
        }
        temp_file = remove_silence(upload_path)
        result = app.state.pipe(
            temp_file,
            batch_size=2,
            generate_kwargs=generate_kwargs
        )
        result_dict = {
            "transcribed_text": result['text'],
            "translated_text": translate(result['text'], srcLang=srcLang, tgtLang=tgtLang),
            "srcLang": srcLang,
            "tgtLang": tgtLang
        }
        return result_dict
    except Exception as error:
        logging.error(f"Error translating audio {file.filename}: {error}")
        raise HTTPException(status_code=500, detail=str(error))
    finally:
        # Clean up the upload and any intermediate file, even on failure
        if file.file:
            file.file.close()
        if upload_path is not None and os.path.exists(upload_path):
            os.remove(upload_path)
        if temp_file is not None and os.path.exists(temp_file):
            os.remove(temp_file)
        end = timer()
        logging.info(f"Translation completed for audio {file.filename} in {end - start:.2f} seconds")
@app.post("/translateText/")
async def translate_text(
text: str,
srcLang: str = Form(...),
tgtLang: str = Form(...)
):
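    """Translate plain text from srcLang to tgtLang."""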
    start = timer()
    result = translate(text, srcLang, tgtLang)
    if not result:
        logging.error("Translation failed for text: %s", text)
        raise HTTPException(status_code=500, detail="Translation failed")
    result_dict = {
        "text": text,
        "translated_text": result,
        "srcLang": srcLang,
        "tgtLang": tgtLang
    }
    end = timer()
    logging.info(f"Translation completed for text: {text} in {end - start:.2f} seconds")
    return result_dict
@app.post("/synthesize/", response_class=FileResponse)
async def synthesize(text: str = Form(...)):
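    """Synthesize Tagalog speech for the given text and return it as a WAV file."""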
    start = timer()
    # Load the MMS Tagalog TTS model and tokenizer
    model = VitsModel.from_pretrained("facebook/mms-tts-tgl")
    tokenizer = VitsTokenizer.from_pretrained("facebook/mms-tts-tgl")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    inputs = tokenizer(text, return_tensors="pt")
    input_ids = inputs["input_ids"].to(device)
    with torch.no_grad():
        outputs = model(input_ids)
    speech = outputs["waveform"]

    # Write the waveform to a temp file; MMS-TTS outputs 16 kHz audio
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav").name
    torchaudio.save(temp_file, speech.cpu(), 16000)
    end = timer()
    logging.info(f"Synthesis completed for text: {text} in {end - start:.2f} seconds")
    # Delete the temp file in the background after the response is sent
    return FileResponse(
        temp_file,
        media_type="audio/wav",
        filename="speech.wav",
        background=BackgroundTask(remove_file, temp_file)
    )
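

# Run locally (assuming uvicorn is installed): uvicorn app:app --host 0.0.0.0 --port 8000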