import base64
import io
import os
import tempfile
import wave
from typing import List

import numpy as np
import torch
from fastapi import FastAPI, UploadFile
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
from TTS.utils.generic_utils import get_user_data_dir
from TTS.utils.manage import ModelManager

# Respect NUM_THREADS / USE_CPU from the environment; default to CUDA.
torch.set_num_threads(int(os.environ.get("NUM_THREADS", os.cpu_count())))
device = torch.device("cuda" if os.environ.get("USE_CPU", "0") == "0" else "cpu")
# Compare device.type rather than the device object itself: equality between
# torch.device and str is not reliable across PyTorch versions.
if device.type == "cuda" and not torch.cuda.is_available():
    raise RuntimeError("CUDA device unavailable, please use Dockerfile.cpu instead.")

custom_model_path = os.environ.get("CUSTOM_MODEL_PATH", "/app/tts_models")
if os.path.isfile(os.path.join(custom_model_path, "config.json")):
    model_path = custom_model_path
    print("Loading custom model from", model_path, flush=True)
else:
    print("Loading default model", flush=True)
    model_name = "tts_models/multilingual/multi-dataset/xtts_v2"
    print("Downloading XTTS Model:", model_name, flush=True)
    ModelManager().download_model(model_name)
    model_path = os.path.join(get_user_data_dir("tts"), model_name.replace("/", "--"))
    print("XTTS Model downloaded", flush=True)

| print("Loading XTTS", flush=True) | |
| config = XttsConfig() | |
| config.load_json(os.path.join(model_path, "config.json")) | |
| model = Xtts.init_from_config(config) | |
| model.load_checkpoint( | |
| config, | |
| checkpoint_dir=model_path, | |
| eval=True, | |
| use_deepspeed=True if device == "cuda" else False, | |
| ) | |
| model.to(device) | |
| print("XTTS Loaded.", flush=True) | |
| print("Running XTTS Server ...", flush=True) | |

##### Run fastapi #####
app = FastAPI(
    title="XTTS Streaming server",
    description="""XTTS Streaming server""",
    version="0.0.1",
    docs_url="/",
)

@app.post("/clone_speaker")
def predict_speaker(wav_file: UploadFile):
    """Compute conditioning inputs from a reference audio file."""
    # Persist the upload to a temporary file so the model can read it by path,
    # then clean it up once the latents are computed.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp:
        temp.write(wav_file.file.read())
        temp_audio_name = temp.name
    try:
        with torch.inference_mode():
            gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(
                temp_audio_name
            )
    finally:
        os.remove(temp_audio_name)
    return {
        "gpt_cond_latent": gpt_cond_latent.cpu().squeeze().half().tolist(),
        "speaker_embedding": speaker_embedding.cpu().squeeze().half().tolist(),
    }
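
# Example client call (a sketch; host, port, and file name are assumptions):
#   curl -X POST http://localhost:8000/clone_speaker -F "wav_file=@reference.wav"
# The JSON response can be passed straight back as the speaker fields of the
# /tts and /tts_stream payloads below.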

def postprocess(wav):
    """Post-process the output waveform into 16-bit PCM samples."""
    if isinstance(wav, list):
        wav = torch.cat(wav, dim=0)
    wav = wav.clone().detach().cpu().numpy()
    wav = wav[None, : int(wav.shape[0])]
    # Clamp to [-1, 1] and convert float samples to signed 16-bit integers.
    wav = np.clip(wav, -1, 1)
    wav = (wav * 32767).astype(np.int16)
    return wav

def encode_audio_common(
    frame_input, encode_base64=True, sample_rate=24000, sample_width=2, channels=1
):
    """Wrap raw PCM frames in a WAV container, optionally base64-encoded."""
    wav_buf = io.BytesIO()
    with wave.open(wav_buf, "wb") as vfout:
        vfout.setnchannels(channels)
        vfout.setsampwidth(sample_width)
        vfout.setframerate(sample_rate)
        vfout.writeframes(frame_input)
    wav_buf.seek(0)
    if encode_base64:
        b64_encoded = base64.b64encode(wav_buf.getbuffer()).decode("utf-8")
        return b64_encoded
    else:
        return wav_buf.read()
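
# Note: calling this with frame_input=b"" yields just the RIFF/WAV header with
# a zero-length data chunk; the streaming endpoint below relies on this so
# WAV-expecting clients can start decoding before the PCM frames arrive.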

class StreamingInputs(BaseModel):
    speaker_embedding: List[float]
    gpt_cond_latent: List[List[float]]
    text: str
    language: str
    add_wav_header: bool = True
    stream_chunk_size: str = "20"
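
# Shape assumptions (they must match the reshape/unsqueeze calls below):
# speaker_embedding is a flat vector and gpt_cond_latent is a sequence of
# 1024-dim rows, e.g. the values returned by /clone_speaker above.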

def predict_streaming_generator(parsed_input: StreamingInputs):
    speaker_embedding = (
        torch.tensor(parsed_input.speaker_embedding).unsqueeze(0).unsqueeze(-1)
    )
    gpt_cond_latent = (
        torch.tensor(parsed_input.gpt_cond_latent).reshape((-1, 1024)).unsqueeze(0)
    )
    text = parsed_input.text
    language = parsed_input.language
    stream_chunk_size = int(parsed_input.stream_chunk_size)
    add_wav_header = parsed_input.add_wav_header

    chunks = model.inference_stream(
        text,
        language,
        gpt_cond_latent,
        speaker_embedding,
        stream_chunk_size=stream_chunk_size,
        enable_text_splitting=True,
    )

    for i, chunk in enumerate(chunks):
        chunk = postprocess(chunk)
        if i == 0 and add_wav_header:
            # Emit the WAV header exactly once, then raw PCM for every chunk;
            # repeating the header per chunk would corrupt the stream.
            yield encode_audio_common(b"", encode_base64=False)
            yield chunk.tobytes()
        else:
            yield chunk.tobytes()

@app.post("/tts_stream")
def predict_streaming_endpoint(parsed_input: StreamingInputs):
    return StreamingResponse(
        predict_streaming_generator(parsed_input),
        media_type="audio/wav",
    )
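
# Minimal streaming client (a sketch; host, port, and speaker.json are
# assumptions; speaker.json holds a /clone_speaker response):
#
#   import json, requests
#   payload = {**json.load(open("speaker.json")), "text": "Hello!", "language": "en"}
#   with requests.post("http://localhost:8000/tts_stream", json=payload, stream=True) as resp:
#       with open("out.wav", "wb") as f:
#           for block in resp.iter_content(chunk_size=4096):
#               f.write(block)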

class TTSInputs(BaseModel):
    speaker_embedding: List[float]
    gpt_cond_latent: List[List[float]]
    text: str
    language: str

@app.post("/tts")
def predict_speech(parsed_input: TTSInputs):
    speaker_embedding = (
        torch.tensor(parsed_input.speaker_embedding).unsqueeze(0).unsqueeze(-1)
    )
    gpt_cond_latent = (
        torch.tensor(parsed_input.gpt_cond_latent).reshape((-1, 1024)).unsqueeze(0)
    )
    text = parsed_input.text
    language = parsed_input.language
    out = model.inference(
        text,
        language,
        gpt_cond_latent,
        speaker_embedding,
    )
    wav = postprocess(torch.tensor(out["wav"]))
    return encode_audio_common(wav.tobytes())
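
# Unlike /tts_stream, this endpoint synthesizes the whole utterance and returns
# it as a single base64-encoded WAV string (JSON-encoded by FastAPI).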

@app.get("/studio_speakers")
def get_speakers():
    if hasattr(model, "speaker_manager") and hasattr(model.speaker_manager, "speakers"):
        return {
            speaker: {
                "speaker_embedding": model.speaker_manager.speakers[speaker][
                    "speaker_embedding"
                ]
                .cpu()
                .squeeze()
                .half()
                .tolist(),
                "gpt_cond_latent": model.speaker_manager.speakers[speaker][
                    "gpt_cond_latent"
                ]
                .cpu()
                .squeeze()
                .half()
                .tolist(),
            }
            for speaker in model.speaker_manager.speakers.keys()
        }
    else:
        return {}

@app.get("/languages")
def get_languages():
    return config.languages
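
# Convenience entry point for local runs (port 8000 is an assumption; the
# Docker images typically launch uvicorn directly):
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)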