Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, HTTPException | |
| from fastapi.responses import JSONResponse | |
| from pydantic import BaseModel | |
| import numpy as np | |
| import io | |
| import soundfile as sf | |
| import base64 | |
| import logging | |
| import torch | |
| import librosa | |
| from pathlib import Path | |
| from pydub import AudioSegment | |
| from moviepy.editor import VideoFileClip | |
| import traceback | |
| from logging.handlers import RotatingFileHandler | |
| import os | |
| import boto3 | |
| from botocore.exceptions import NoCredentialsError | |
| import time | |
| import tempfile | |
| # Import functions from other modules | |
| from asr import transcribe, ASR_LANGUAGES | |
| from tts import synthesize, TTS_LANGUAGES | |
| from lid import identify | |
| from asr import ASR_SAMPLING_RATE | |
# Logging: console output via the root logger, plus a size-rotated "app.log"
# attached to this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Shared line format for the file handler.
formatter = logging.Formatter(
    fmt='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)

# Roll over at ~10 MB and keep five old files.
file_handler = RotatingFileHandler(
    filename='app.log',
    maxBytes=10_000_000,
    backupCount=5,
)
file_handler.setLevel(logging.INFO)
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
# The ASGI application; the title appears in the generated OpenAPI docs.
app = FastAPI(
    title="MMS: Scaling Speech Technology to 1000+ languages",
)
# S3 configuration read from the environment; any unset variable is None.
S3_BUCKET = os.getenv("S3_BUCKET")
S3_REGION = os.getenv("S3_REGION")
S3_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
S3_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")

# One module-level S3 client, created at import time and shared by all requests.
# NOTE(review): when the key variables are unset (None) boto3 presumably falls
# back to its default credential chain — confirm for the deployment target.
s3_client = boto3.client(
    's3',
    region_name=S3_REGION,
    aws_access_key_id=S3_ACCESS_KEY_ID,
    aws_secret_access_key=S3_SECRET_ACCESS_KEY,
)
# Request body schemas.
class AudioRequest(BaseModel):
    """Body for the audio endpoints (transcription and language identification)."""

    audio: str  # base64-encoded audio or video file contents
    language: str  # language string passed through to the ASR model


class TTSRequest(BaseModel):
    """Body for the speech-synthesis endpoint."""

    text: str  # text to synthesize; must be non-empty
    language: str  # its first whitespace-separated token is the language code
    speed: float  # speed factor for synthesize(); accepted range is [0.5, 2.0]
def _to_mono(audio_array):
    """Average a ``(frames, channels)`` array over its channel axis; 1-D input is returned unchanged."""
    if audio_array.ndim > 1:
        return audio_array.mean(axis=1)
    return audio_array


def _read_video_audio(path):
    """Return the audio track of the video at *path* as mono float32, peak-normalised to [-1, 1]."""
    video = VideoFileClip(path)
    try:
        audio = video.audio
        if audio is None:
            raise ValueError("Video file contains no audio")
        audio_array = _to_mono(audio.to_soundarray()).astype(np.float32)
        peak = np.max(np.abs(audio_array)) if audio_array.size else 0.0
        # Silent (all-zero) audio would otherwise turn into NaNs here.
        if peak > 0:
            audio_array /= peak
        return audio_array, audio.fps
    finally:
        # Always release the clip's reader, including on the no-audio / error paths.
        video.close()


def _read_pydub_audio(path):
    """Decode *path* with pydub and return mono float32 samples scaled to [-1, 1)."""
    audio = AudioSegment.from_file(path)
    samples = np.array(audio.get_array_of_samples())
    # Samples are signed integers ``sample_width`` bytes wide, so full scale is
    # 2**(bits - 1); this covers 8-, 16- and 32-bit data alike.
    full_scale = float(2 ** (8 * audio.sample_width - 1))
    audio_array = samples.astype(np.float32) / full_scale
    if audio.channels > 1:
        # Channel samples are interleaved frame by frame; fold them back into columns.
        audio_array = audio_array.reshape((-1, audio.channels)).mean(axis=1)
    return audio_array, audio.frame_rate


def _decode_audio_file(path):
    """Try soundfile, then moviepy, then pydub on *path*; return the first ``(samples, rate)``."""
    # A reader failing is the expected signal that the data is in another format,
    # so those errors are deliberately ignored and the next reader is tried.
    try:
        audio_array, sample_rate = sf.read(path, dtype='float32')
        return _to_mono(audio_array), sample_rate
    except Exception:
        pass
    try:
        return _read_video_audio(path)
    except Exception:
        pass
    try:
        return _read_pydub_audio(path)
    except Exception as e:
        raise ValueError(f"Unsupported file format: {str(e)}") from e


def extract_audio_from_file(input_bytes):
    """Decode encoded audio or video bytes into mono float32 samples.

    The bytes are written to a temporary file and decoded with, in order,
    soundfile, moviepy (audio track of a video) and pydub; the first reader
    that succeeds wins. The temporary file is always removed.

    Args:
        input_bytes: Raw contents of an audio or video file.

    Returns:
        ``(audio_array, sample_rate)``: a float32 array with multi-channel
        input averaged to mono, and the source sample rate.

    Raises:
        ValueError: If none of the readers can decode the data.
    """
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.tmp')
    temp_file_path = temp_file.name
    try:
        # Close (and flush) before decoding so the readers see the complete file.
        with temp_file:
            temp_file.write(input_bytes)
        return _decode_audio_file(temp_file_path)
    finally:
        os.unlink(temp_file_path)
async def transcribe_audio(request: AudioRequest):
    """Transcribe base64-encoded audio or video with the ASR model.

    The payload is decoded, converted to float32, resampled to
    ``ASR_SAMPLING_RATE`` if needed and passed to ``transcribe`` together with
    ``request.language``.

    Returns:
        JSONResponse with ``transcription`` and ``processing_time_seconds``.
        The status is 400 when the payload is not valid base64 or not a decodable
        audio/video file. It is 500 with error details for any other failure.
    """
    start_time = time.time()
    try:
        try:
            input_bytes = base64.b64decode(request.audio)
            audio_array, sample_rate = extract_audio_from_file(input_bytes)
        except ValueError as ve:
            # Bad base64 (binascii.Error subclasses ValueError) or an unsupported
            # file format is a client error, reported like synthesize_speech does.
            logger.error(f"ValueError in transcribe_audio: {str(ve)}", exc_info=True)
            processing_time = time.time() - start_time
            return JSONResponse(
                status_code=400,
                content={"message": "Invalid input", "details": str(ve), "processing_time_seconds": processing_time}
            )
        # Ensure audio_array is float32
        audio_array = audio_array.astype(np.float32)
        # Resample to the rate the ASR model is fed at
        if sample_rate != ASR_SAMPLING_RATE:
            audio_array = librosa.resample(audio_array, orig_sr=sample_rate, target_sr=ASR_SAMPLING_RATE)
        result = transcribe(audio_array, request.language)
        processing_time = time.time() - start_time
        return JSONResponse(content={"transcription": result, "processing_time_seconds": processing_time})
    except Exception as e:
        logger.error(f"Error in transcribe_audio: {str(e)}", exc_info=True)
        # NOTE(review): the full server traceback is returned to the client;
        # consider keeping it in the logs only.
        error_details = {
            "error": str(e),
            "traceback": traceback.format_exc()
        }
        processing_time = time.time() - start_time
        return JSONResponse(
            status_code=500,
            content={"message": "An error occurred during transcription", "details": error_details, "processing_time_seconds": processing_time}
        )
async def synthesize_speech(request: TTSRequest):
    """Synthesize speech for a TTS request and upload it to S3 as a WAV file.

    The request is validated, ``synthesize`` is called, and the audio is
    peak-normalised to 16-bit PCM. It is then written as WAV, uploaded to
    ``S3_BUCKET`` under a unique key, and its URL is returned.

    Args:
        request: Text (non-empty) and language. The first whitespace-separated
            token of the language must be in ``TTS_LANGUAGES``. Speed must be
            within [0.5, 2.0].

    Returns:
        JSONResponse with ``audio_url`` and ``processing_time_seconds``.
        The status is 400 for any ValueError (invalid input, failed or silent
        synthesis). It is 500 with error details for other unexpected errors.

    Raises:
        HTTPException: Status 500 when the S3 upload fails.
    """
    import uuid  # used only to make the S3 object key unique

    start_time = time.time()
    logger.info(f"Synthesize request received: text='{request.text}', language='{request.language}', speed={request.speed}")
    try:
        # Input validation
        if not request.text:
            raise ValueError("Text cannot be empty")
        # Extract the ISO code (first token) from the full language name. An empty
        # or whitespace-only language used to raise IndexError here (-> 500).
        language_tokens = request.language.split()
        if not language_tokens:
            raise ValueError(f"Unsupported language: {request.language}")
        lang_code = language_tokens[0]
        if lang_code not in TTS_LANGUAGES:
            raise ValueError(f"Unsupported language: {request.language}")
        if not 0.5 <= request.speed <= 2.0:
            raise ValueError(f"Speed must be between 0.5 and 2.0, got {request.speed}")

        logger.info(f"Calling synthesize function with lang_code: {lang_code}")
        # NOTE(review): synthesize() receives the full language string, not
        # lang_code — presumably it extracts the code itself; confirm.
        result, filtered_text = synthesize(request.text, request.language, request.speed)
        logger.info(f"Synthesize function completed. Filtered text: '{filtered_text}'")
        if result is None:
            logger.error("Synthesize function returned None")
            raise ValueError("Synthesis failed to produce audio")
        sample_rate, audio = result
        logger.info(f"Synthesis result: sample_rate={sample_rate}, audio_shape={audio.shape if isinstance(audio, np.ndarray) else 'not numpy array'}, audio_dtype={audio.dtype if isinstance(audio, np.ndarray) else type(audio)}")

        logger.info("Converting audio to numpy array")
        audio = np.array(audio, dtype=np.float32)
        logger.info(f"Converted audio shape: {audio.shape}, dtype: {audio.dtype}")

        logger.info("Normalizing audio")
        max_value = np.max(np.abs(audio))
        if max_value == 0:
            logger.warning("Audio array is all zeros")
            raise ValueError("Generated audio is silent (all zeros)")
        audio = audio / max_value
        logger.info(f"Normalized audio range: [{audio.min()}, {audio.max()}]")

        logger.info("Converting to int16")
        audio = (audio * 32767).astype(np.int16)
        logger.info(f"Int16 audio shape: {audio.shape}, dtype: {audio.dtype}")

        logger.info("Writing audio to buffer")
        buffer = io.BytesIO()
        sf.write(buffer, audio, sample_rate, format='wav')
        buffer.seek(0)
        logger.info(f"Buffer size: {buffer.getbuffer().nbytes} bytes")

        # A second-resolution timestamp alone collides when two requests finish in
        # the same second, silently overwriting one object; the UUID keeps keys unique.
        filename = f"synthesized_audio_{int(time.time())}_{uuid.uuid4().hex}.wav"
        # Upload to S3 without ACL
        try:
            s3_client.upload_fileobj(
                buffer,
                S3_BUCKET,
                filename,
                ExtraArgs={'ContentType': 'audio/wav'}
            )
        except NoCredentialsError as e:
            logger.error("AWS credentials not available or invalid")
            raise HTTPException(status_code=500, detail="Could not upload file to S3: Missing or invalid credentials") from e
        except Exception as e:
            logger.error(f"Failed to upload to S3: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Could not upload file to S3: {str(e)}") from e
        logger.info(f"File uploaded successfully to S3: {filename}")
        # Generate the public URL with the correct format
        url = f"https://s3.{S3_REGION}.amazonaws.com/{S3_BUCKET}/{filename}"
        logger.info(f"Public URL generated: {url}")
        processing_time = time.time() - start_time
        return JSONResponse(content={"audio_url": url, "processing_time_seconds": processing_time})
    except HTTPException:
        # Raised deliberately above with a specific detail. Without this clause the
        # generic handler below would swallow it into an opaque 500 body.
        raise
    except ValueError as ve:
        logger.error(f"ValueError in synthesize_speech: {str(ve)}", exc_info=True)
        processing_time = time.time() - start_time
        return JSONResponse(
            status_code=400,
            content={"message": "Invalid input", "details": str(ve), "processing_time_seconds": processing_time}
        )
    except Exception as e:
        logger.error(f"Unexpected error in synthesize_speech: {str(e)}", exc_info=True)
        # NOTE(review): the full server traceback is returned to the client;
        # consider keeping it in the logs only.
        error_details = {
            "error": str(e),
            "type": type(e).__name__,
            "traceback": traceback.format_exc()
        }
        processing_time = time.time() - start_time
        return JSONResponse(
            status_code=500,
            content={"message": "An unexpected error occurred during speech synthesis", "details": error_details, "processing_time_seconds": processing_time}
        )
    finally:
        logger.info("Synthesize request completed")
async def identify_language(request: AudioRequest):
    """Identify the spoken language of base64-encoded audio or video.

    Returns:
        JSONResponse with ``language_identification`` (the value returned by
        ``identify``) and ``processing_time_seconds``. The status is 400 when
        the payload is not valid base64 or not a decodable audio/video file.
        It is 500 with error details for any other failure.
    """
    start_time = time.time()
    try:
        try:
            input_bytes = base64.b64decode(request.audio)
            audio_array, sample_rate = extract_audio_from_file(input_bytes)
        except ValueError as ve:
            # Bad base64 (binascii.Error subclasses ValueError) or an unsupported
            # file format is a client error, reported like synthesize_speech does.
            logger.error(f"ValueError in identify_language: {str(ve)}", exc_info=True)
            processing_time = time.time() - start_time
            return JSONResponse(
                status_code=400,
                content={"message": "Invalid input", "details": str(ve), "processing_time_seconds": processing_time}
            )
        # NOTE(review): unlike transcribe_audio, sample_rate is neither forwarded
        # nor used for resampling here; confirm identify() copes with the input rate.
        result = identify(audio_array)
        processing_time = time.time() - start_time
        return JSONResponse(content={"language_identification": result, "processing_time_seconds": processing_time})
    except Exception as e:
        logger.error(f"Error in identify_language: {str(e)}", exc_info=True)
        # NOTE(review): the full server traceback is returned to the client;
        # consider keeping it in the logs only.
        error_details = {
            "error": str(e),
            "traceback": traceback.format_exc()
        }
        processing_time = time.time() - start_time
        return JSONResponse(
            status_code=500,
            content={"message": "An error occurred during language identification", "details": error_details, "processing_time_seconds": processing_time}
        )
async def get_asr_languages():
    """List the languages supported by speech recognition.

    Returns:
        JSONResponse with ``languages`` (``ASR_LANGUAGES``) and
        ``processing_time_seconds``. If building the response fails, a 500
        response with error details is returned instead.
    """
    started_at = time.time()
    try:
        elapsed = time.time() - started_at
        payload = {"languages": ASR_LANGUAGES, "processing_time_seconds": elapsed}
        return JSONResponse(content=payload)
    except Exception as exc:
        logger.error(f"Error in get_asr_languages: {str(exc)}", exc_info=True)
        details = {"error": str(exc), "traceback": traceback.format_exc()}
        elapsed = time.time() - started_at
        return JSONResponse(
            status_code=500,
            content={
                "message": "An error occurred while fetching ASR languages",
                "details": details,
                "processing_time_seconds": elapsed,
            },
        )
async def get_tts_languages():
    """List the languages supported by text-to-speech.

    Returns:
        JSONResponse with ``languages`` (``TTS_LANGUAGES``) and
        ``processing_time_seconds``. If building the response fails, a 500
        response with error details is returned instead.
    """
    started_at = time.time()
    try:
        elapsed = time.time() - started_at
        payload = {"languages": TTS_LANGUAGES, "processing_time_seconds": elapsed}
        return JSONResponse(content=payload)
    except Exception as exc:
        logger.error(f"Error in get_tts_languages: {str(exc)}", exc_info=True)
        details = {"error": str(exc), "traceback": traceback.format_exc()}
        elapsed = time.time() - started_at
        return JSONResponse(
            status_code=500,
            content={
                "message": "An error occurred while fetching TTS languages",
                "details": details,
                "processing_time_seconds": elapsed,
            },
        )