Spaces:
Paused
Paused
import base64 | |
import io | |
import os | |
import tempfile | |
import wave | |
import torch | |
import numpy as np | |
from typing import List | |
from pydantic import BaseModel | |
from fastapi import FastAPI, UploadFile, Body | |
from fastapi.responses import StreamingResponse | |
from TTS.tts.configs.xtts_config import XttsConfig | |
from TTS.tts.models.xtts import Xtts | |
from TTS.utils.generic_utils import get_user_data_dir | |
from TTS.utils.manage import ModelManager | |
import importlib.util

# Number of CPU threads torch may use. NUM_THREADS overrides the detected core
# count; os.cpu_count() can return None, so fall back to a single thread.
torch.set_num_threads(int(os.environ.get("NUM_THREADS", os.cpu_count() or 1)))

# USE_CPU=0 (the default) selects CUDA; any other value selects the CPU.
device = torch.device("cuda" if os.environ.get("USE_CPU", "0") == "0" else "cpu")

# Compare device.type rather than the device object itself: a torch.device is
# never equal to a plain string, so `device == "cuda"` is always False and the
# guard below would never trigger.
if device.type == "cuda" and not torch.cuda.is_available():
    raise RuntimeError("CUDA device unavailable, please use Dockerfile.cpu instead.")

# Prefer a user-supplied model directory when it contains a config.json;
# otherwise download the default XTTS v2 model.
custom_model_path = os.environ.get("CUSTOM_MODEL_PATH", "/app/tts_models")
if os.path.exists(custom_model_path) and os.path.isfile(
    os.path.join(custom_model_path, "config.json")
):
    model_path = custom_model_path
    print("Loading custom model from", model_path, flush=True)
else:
    print("Loading default model", flush=True)
    model_name = "tts_models/multilingual/multi-dataset/xtts_v2"
    print("Downloading XTTS Model:", model_name, flush=True)
    ModelManager().download_model(model_name)
    model_path = os.path.join(get_user_data_dir("tts"), model_name.replace("/", "--"))
    print("XTTS Model downloaded", flush=True)

print("Loading XTTS", flush=True)
config = XttsConfig()
config.load_json(os.path.join(model_path, "config.json"))
model = Xtts.init_from_config(config)

# DeepSpeed is only requested on CUDA, and only when the package is importable,
# so that correcting the device check cannot break an install without it.
use_deepspeed = (
    device.type == "cuda" and importlib.util.find_spec("deepspeed") is not None
)
model.load_checkpoint(
    config,
    checkpoint_dir=model_path,
    eval=True,
    use_deepspeed=use_deepspeed,
)
model.to(device)
print("XTTS Loaded.", flush=True)
print("Running XTTS Server ...", flush=True)

##### Run fastapi #####
# The interactive API docs are served from the root path ("/").
app = FastAPI(
    title="XTTS Streaming server",
    description="""XTTS Streaming server""",
    version="0.0.1",
    docs_url="/",
)
def predict_speaker(wav_file: UploadFile):
    """Compute conditioning inputs from reference audio file.

    The uploaded audio is written to a temporary file because the model's
    ``get_conditioning_latents`` is called with a file path. The file is fully
    written and closed before the model reads it, and is always removed
    afterwards.

    Args:
        wav_file: Uploaded reference audio.

    Returns:
        A dict with ``gpt_cond_latent`` and ``speaker_embedding``, each moved to
        the CPU, squeezed, cast to half precision and converted to nested lists.
    """
    # delete=False so the file can be reopened by path after it is closed.
    temp = tempfile.NamedTemporaryFile(delete=False)
    try:
        with temp:
            temp.write(wav_file.file.read())
        # The writer is closed here, so every byte is on disk before the model
        # opens the file by name.
        with torch.inference_mode():
            gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(
                temp.name
            )
    finally:
        os.remove(temp.name)
    return {
        "gpt_cond_latent": gpt_cond_latent.cpu().squeeze().half().tolist(),
        "speaker_embedding": speaker_embedding.cpu().squeeze().half().tolist(),
    }
def postprocess(wav):
    """Convert a waveform tensor (or list of tensors) to clipped int16 samples.

    A list input is concatenated along dim 0 first. The result is a NumPy
    int16 array with a new leading axis, values clipped to [-1, 1] and scaled
    by 32767.
    """
    tensor = torch.cat(wav, dim=0) if isinstance(wav, list) else wav
    samples = tensor.clone().detach().cpu().numpy()
    # Take the full first axis, then add a leading axis of length one.
    samples = np.expand_dims(samples[: int(samples.shape[0])], axis=0)
    clipped = samples.clip(-1, 1)
    return (clipped * 32767).astype(np.int16)
def encode_audio_common(
    frame_input, encode_base64=True, sample_rate=24000, sample_width=2, channels=1
):
    """Wrap raw audio frames in a WAV container.

    Returns the WAV data as a base64 string when ``encode_base64`` is true,
    otherwise as raw bytes.
    """
    buffer = io.BytesIO()
    writer = wave.open(buffer, "wb")
    try:
        writer.setnchannels(channels)
        writer.setsampwidth(sample_width)
        writer.setframerate(sample_rate)
        writer.writeframes(frame_input)
    finally:
        # Closing the writer finalizes the header; the BytesIO stays open.
        writer.close()

    payload = buffer.getvalue()
    if not encode_base64:
        return payload
    return base64.b64encode(payload).decode("utf-8")
class StreamingInputs(BaseModel):
    # Request payload consumed by predict_streaming_generator.

    # Flat speaker vector; the generator adds a leading and a trailing axis.
    speaker_embedding: List[float]
    # Conditioning latents; the generator reshapes them to (-1, 1024) and adds
    # a leading axis.
    gpt_cond_latent: List[List[float]]
    # Text to synthesize and the language to synthesize it in.
    text: str
    language: str
    # When true, a WAV header is emitted before the first audio chunk.
    add_wav_header: bool = True
    # Passed as a string and converted with int() by the generator.
    stream_chunk_size: str = "20"
def predict_streaming_generator(parsed_input: dict = Body(...)):
    """Yield synthesized audio as raw byte chunks.

    Despite the ``dict`` annotation, the fields of ``parsed_input`` are read as
    attributes. When ``add_wav_header`` is set, an empty-payload WAV header is
    yielded once, immediately before the first audio chunk.
    """
    embedding = torch.tensor(parsed_input.speaker_embedding).unsqueeze(0).unsqueeze(-1)
    latent = torch.tensor(parsed_input.gpt_cond_latent).reshape((-1, 1024)).unsqueeze(0)
    chunk_size = int(parsed_input.stream_chunk_size)
    want_header = parsed_input.add_wav_header

    stream = model.inference_stream(
        parsed_input.text,
        parsed_input.language,
        latent,
        embedding,
        stream_chunk_size=chunk_size,
        enable_text_splitting=True,
    )

    for index, piece in enumerate(stream):
        pcm = postprocess(piece)
        if want_header and index == 0:
            # Header only: the audio frames follow as separate chunks.
            yield encode_audio_common(b"", encode_base64=False)
        yield pcm.tobytes()
def predict_streaming_endpoint(parsed_input: StreamingInputs):
    """Return a streaming ``audio/wav`` response for the given request."""
    audio_stream = predict_streaming_generator(parsed_input)
    return StreamingResponse(audio_stream, media_type="audio/wav")
class TTSInputs(BaseModel):
    # Request payload consumed by predict_speech.

    # Flat speaker vector; predict_speech adds a leading and a trailing axis.
    speaker_embedding: List[float]
    # Conditioning latents; predict_speech reshapes them to (-1, 1024) and adds
    # a leading axis.
    gpt_cond_latent: List[List[float]]
    # Text to synthesize and the language to synthesize it in.
    text: str
    language: str
def predict_speech(parsed_input: TTSInputs):
    """Synthesize the whole request at once and return base64-encoded WAV audio."""
    embedding = torch.tensor(parsed_input.speaker_embedding).unsqueeze(0).unsqueeze(-1)
    latent = torch.tensor(parsed_input.gpt_cond_latent).reshape((-1, 1024)).unsqueeze(0)

    result = model.inference(
        parsed_input.text,
        parsed_input.language,
        latent,
        embedding,
    )

    pcm = postprocess(torch.tensor(result["wav"]))
    return encode_audio_common(pcm.tobytes())
def get_speakers():
    """Return the model's built-in speakers with their conditioning data.

    The result maps each speaker name to its ``speaker_embedding`` and
    ``gpt_cond_latent`` as nested lists. It is empty when the model has no
    ``speaker_manager`` or the manager has no ``speakers`` attribute.
    """
    has_speakers = hasattr(model, "speaker_manager") and hasattr(
        model.speaker_manager, "speakers"
    )
    if not has_speakers:
        return {}

    voices = model.speaker_manager.speakers
    catalog = {}
    for name in voices.keys():
        entry = voices[name]
        catalog[name] = {
            "speaker_embedding": entry["speaker_embedding"].cpu().squeeze().half().tolist(),
            "gpt_cond_latent": entry["gpt_cond_latent"].cpu().squeeze().half().tolist(),
        }
    return catalog
def get_languages():
    """Return the ``languages`` entry of the loaded XTTS config."""
    supported = config.languages
    return supported