Spaces:
Sleeping
Sleeping
rbcurzon_laptop
refactor: enhance ASR model initialization and update translation configuration
30bd972
# Standard library imports | |
import os | |
import time | |
import tempfile | |
import logging | |
from timeit import default_timer as timer | |
from contextlib import asynccontextmanager | |
# Third-party imports | |
import numpy as np | |
import scipy.io.wavfile | |
import torch | |
import torchaudio | |
from fastapi import FastAPI, UploadFile, File, HTTPException, Form | |
from fastapi.responses import FileResponse | |
from fastapi.middleware.cors import CORSMiddleware | |
from starlette.background import BackgroundTask | |
from transformers import pipeline, VitsModel, VitsTokenizer, AutoModelForSpeechSeq2Seq, AutoProcessor | |
# External service imports | |
from google import genai | |
from google.genai import types | |
from silero_vad import ( | |
load_silero_vad, | |
read_audio, | |
get_speech_timestamps, | |
save_audio, | |
collect_chunks, | |
) | |
# Logging configuration: timestamped, level-tagged records at INFO and above.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the shared models once at startup and expose them on ``app.state``.

    Populates:
        app.state.pipe: Whisper ASR pipeline built from
            ``rbcurzon/whisper-large-v3-turbo``.
        app.state.vad_model: Silero VAD model used by ``remove_silence``.
        app.state.client: Gemini client authenticated with the
            ``GENAI_API_KEY`` environment variable.

    Starlette expects ``lifespan`` to be an async context manager. A bare
    async-generator function is only accepted through a deprecated
    compatibility path, hence the ``@asynccontextmanager`` decorator.
    """
    # First GPU in half precision when available, otherwise CPU in full precision.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

    model_id = "rbcurzon/whisper-large-v3-turbo"
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
    )
    model.to(device)
    processor = AutoProcessor.from_pretrained(model_id)

    app.state.pipe = pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        torch_dtype=torch_dtype,
        device=device,
    )
    app.state.vad_model = load_silero_vad()
    app.state.client = genai.Client(api_key=os.environ.get("GENAI_API_KEY"))

    # The application serves requests while suspended here.
    yield
    # No explicit shutdown cleanup: the objects above live for the process lifetime.
# FastAPI application. `lifespan` is the startup/shutdown handler that loads
# the shared models onto `app.state`.
app = FastAPI(
    lifespan=lifespan,
    title="Real-Time Audio Processor",
    description="Process and transcribe audio in real-time using Whisper",
)
def remove_silence(filename):
    """Strip non-speech regions from an audio file using Silero VAD.

    The audio is read at 16 kHz, speech segments are located with the VAD
    model stored on ``app.state.vad_model``, and only those segments are
    written to a new ``.wav`` temp file.

    Args:
        filename: Path of the audio file to process.

    Returns:
        Path of the new temp file. The caller owns it and must delete it.

    Raises:
        HTTPException: 422 when no speech is detected; 500 for any other
            failure (any partially written temp file is removed first).
    """
    sampling_rate = 16000
    temp_file = None
    try:
        wav = read_audio(filename, sampling_rate=sampling_rate)
        speech_timestamps = get_speech_timestamps(wav, app.state.vad_model, sampling_rate=sampling_rate)
        if not speech_timestamps:
            # Nothing to keep. Report it instead of handing an empty chunk
            # list to collect_chunks (which may fail to concatenate nothing)
            # or feeding an empty clip to the ASR model.
            raise HTTPException(status_code=422, detail="No speech detected in audio")

        # Reserve a unique path and close the handle right away; save_audio
        # reopens the file by name.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as reserved:
            temp_file = reserved.name
        save_audio(
            temp_file,
            collect_chunks(speech_timestamps, wav),
            sampling_rate=sampling_rate
        )
        return temp_file
    except HTTPException:
        raise
    except Exception as error:
        logging.error(f"Error removing silence from {filename}: {error}")
        # Do not leak the reserved temp file when writing it failed.
        if temp_file is not None and os.path.exists(temp_file):
            os.remove(temp_file)
        raise HTTPException(status_code=500, detail=str(error)) from error
def translate(text, srcLang, tgtLang):
    """Translate ``text`` from ``srcLang`` into ``tgtLang`` with Gemini.

    Uses the ``gemini-2.5-flash-lite`` model through the client stored on
    ``app.state.client``, with thinking disabled and temperature 0.6. The
    system instruction asks for the bare translation only.

    Returns:
        The response's ``text`` attribute, unmodified. Callers should treat a
        falsy value as a failed translation.
    """
    instruction = (
        "You are an expert translator. "
        f"Your task is to translate from {srcLang} to {tgtLang}. "
        "You must provide ONLY the translated text. "
        "Do not include any explanations, additional commentary, "
        "or conversational language. Just the translated text."
    )
    generation_config = types.GenerateContentConfig(
        temperature=0.6,
        system_instruction=instruction,
        thinking_config=types.ThinkingConfig(thinking_budget=0),  # no thinking tokens
    )
    reply = app.state.client.models.generate_content(
        model="gemini-2.5-flash-lite",
        contents=f"Translate the following text: '{text}'",
        config=generation_config,
    )
    return reply.text
def remove_file(file, delay=600):
    """Delete ``file`` after waiting ``delay`` seconds (background cleanup).

    Args:
        file: Path of the file to delete.
        delay: Seconds to sleep before deleting. The default of 600 (10 minutes)
            keeps the original behaviour; pass 0 to delete immediately.

    A file that is already gone is not an error: the goal of the call
    (the file no longer exists) is already met.
    """
    # NOTE(review): when run as a synchronous background task this sleep
    # occupies a worker thread for the whole delay; prefer a small delay.
    if delay > 0:
        time.sleep(delay)
    try:
        os.remove(file)
    except FileNotFoundError:
        logging.debug(f"File {file} was already removed")
# API Endpoints
# NOTE(review): no route decorator is visible on the handlers below; confirm
# they are registered on `app` (e.g. `@app.get("/")`).
def read_root():
    """Return the service banner as ``{"detail": <service name>}``."""
    service_name = "Philippine Regional Language Translator"
    return dict(detail=service_name)
async def translate_audio(
    file: UploadFile = File(...),
    srcLang: str = Form("Tagalog"),
    tgtLang: str = Form("Cebuano")
):
    """Transcribe an uploaded audio clip with Whisper, then translate it.

    Steps:
        1. Write the upload to a private temp file.
        2. Strip silence with ``remove_silence``.
        3. Transcribe with the ASR pipeline on ``app.state.pipe``.
        4. Translate the transcript with ``translate``.

    Args:
        file: Uploaded audio file.
        srcLang: Source-language name passed to the translator.
        tgtLang: Target-language name passed to the translator.

    Returns:
        dict with ``transcribed_text``, ``translated_text``, ``srcLang`` and
        ``tgtLang``.

    Raises:
        HTTPException: an ``HTTPException`` raised by a helper is propagated
            unchanged; any other failure becomes a 500 whose detail is the
            error text.
    """
    start = timer()
    upload_path = None  # private copy of the uploaded bytes
    temp_file = None  # silence-stripped audio produced by remove_silence
    try:
        content = await file.read()
        # Never use the client-supplied filename as a path: it can contain
        # directory components (path traversal) and collides when concurrent
        # requests upload files with the same name. Keep only its extension,
        # in case the audio loader relies on it to detect the format.
        suffix = os.path.splitext(file.filename or "")[1]
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as upload:
            upload_path = upload.name  # set first so cleanup runs even if write fails
            upload.write(content)
        logging.info(f"Successfully uploaded {file.filename}")

        generate_kwargs = {
            # Presumably leaves room for Whisper's prompt tokens within its
            # 448-token decoder limit. TODO confirm.
            "max_new_tokens": 448-4,
            "num_beams": 1,
            "condition_on_prev_tokens": False,
            "compression_ratio_threshold": 1.35,
            "temperature": 0.0,  # reduce temperature for more deterministic output
            "logprob_threshold": -1.0,
            "no_speech_threshold": 0.6,
            "return_timestamps": True,
            # NOTE(review): decoding language is fixed to Tagalog regardless of
            # srcLang; confirm this is intended for the fine-tuned model.
            "language": "tl"
        }
        temp_file = remove_silence(upload_path)
        result = app.state.pipe(
            temp_file,
            batch_size=2,
            generate_kwargs=generate_kwargs
        )
        transcript = result['text']
        return {
            "transcribed_text": transcript,
            "translated_text": translate(transcript, srcLang=srcLang, tgtLang=tgtLang),
            "srcLang": srcLang,
            "tgtLang": tgtLang
        }
    except HTTPException as error:
        # Already carries the intended status (e.g. 422 from remove_silence);
        # re-wrapping it would turn every such error into a 500.
        logging.error(f"Error translating audio {file.filename}: {error.detail}")
        raise
    except Exception as error:
        logging.error(f"Error translating audio {file.filename}: {error}")
        raise HTTPException(status_code=500, detail=str(error)) from error
    finally:
        if file.file:
            file.file.close()
        for path in (upload_path, temp_file):
            if path is not None and os.path.exists(path):
                os.remove(path)
        end = timer()
        logging.info(f"Translation completed for audio {file.filename} in {end - start:.2f} seconds")
async def translate_text(
    text: str,
    srcLang: str = Form(...),
    tgtLang: str = Form(...)
):
    """Translate ``text`` from ``srcLang`` into ``tgtLang``.

    Returns:
        dict with the original ``text``, the ``translated_text``, and the two
        language names.

    Raises:
        HTTPException: 500 when the translator returns an empty/falsy result.
    """
    # NOTE(review): `text` has no Form(...) default, so if this is registered
    # as a route FastAPI reads it from the query string while the language
    # fields come from the form body. Confirm this is intended.
    started = timer()
    translated = translate(text, srcLang, tgtLang)
    if not translated:
        logging.error("Translation failed for text: %s", text)
        raise HTTPException(status_code=500, detail="Translation failed")

    payload = dict(
        text=text,
        translated_text=translated,
        srcLang=srcLang,
        tgtLang=tgtLang,
    )
    elapsed = timer() - started
    logging.info(f"Translation completed for text: {text} in {elapsed:.2f} seconds")
    return payload
# Hugging Face checkpoint used for text-to-speech.
_TTS_MODEL_ID = "facebook/mms-tts-tgl"
# Already-loaded (model, tokenizer) pairs, keyed by device string, so the
# checkpoint is loaded once per process instead of on every request.
_TTS_CACHE = {}


def _load_tts(device):
    """Return the cached TTS ``(model, tokenizer)`` for ``device``, loading it on first use."""
    if device not in _TTS_CACHE:
        model = VitsModel.from_pretrained(_TTS_MODEL_ID)
        model.to(device)
        tokenizer = VitsTokenizer.from_pretrained(_TTS_MODEL_ID)
        _TTS_CACHE[device] = (model, tokenizer)
    return _TTS_CACHE[device]


async def synthesize(text: str = Form(...)):
    """Synthesize ``text`` into speech with the MMS-TTS Tagalog VITS model.

    Args:
        text: Text to speak.

    Returns:
        A ``FileResponse`` streaming a WAV file named ``speech.wav``. The
        temp file is deleted by a background task once the response is sent.

    Raises:
        HTTPException: 422 when the tokenizer produces no input ids for ``text``.
    """
    start = timer()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, tokenizer = _load_tts(device)

    inputs = tokenizer(text, return_tensors="pt")
    input_ids = inputs["input_ids"].to(device)
    if input_ids.shape[-1] == 0:
        # The tokenizer may yield no ids (e.g. when none of the characters are
        # in its vocabulary); the model cannot synthesize from an empty input.
        raise HTTPException(status_code=422, detail="Text contains no synthesizable characters")

    with torch.no_grad():
        outputs = model(input_ids)
    speech = outputs["waveform"]

    # Reserve a unique path and close the handle; torchaudio reopens by name.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as reserved:
        temp_file = reserved.name
    try:
        # Use the checkpoint's own output rate rather than a hard-coded value.
        torchaudio.save(temp_file, speech.cpu(), model.config.sampling_rate)
    except Exception:
        os.remove(temp_file)
        raise

    end = timer()
    logging.info(f"Synthesis completed for text: {text} in {end - start:.2f} seconds")
    # The background task runs after the response has been sent, so the file
    # can be removed right away instead of blocking a worker on a long sleep.
    return FileResponse(
        temp_file,
        media_type="audio/wav",
        filename="speech.wav",
        background=BackgroundTask(os.remove, temp_file)
    )