# Spaces: Paused
import gradio as gr | |
import requests | |
import base64 | |
import tempfile | |
import json | |
import os | |
# Base URL of the TTS server this demo sends its HTTP requests to.
SERVER_URL = 'http://localhost:7860'
# Root directory for everything the demo writes to disk.
OUTPUT = "./demo_outputs"
# Maps a cloned speaker's name to the embeddings JSON saved for it.
cloned_speakers = {}

print("Preparing file structure...")
_cloned_speakers_dir = os.path.join(OUTPUT, "cloned_speakers")
_generated_audios_dir = os.path.join(OUTPUT, "generated_audios")
# Remember whether speakers may already be stored before creating anything,
# so the "loading" messages are only printed when there is something to load.
_had_cloned_speakers_dir = os.path.isdir(_cloned_speakers_dir)
# makedirs(exist_ok=True) also repairs a partially created output tree
# (e.g. OUTPUT exists but one of its sub-directories is missing).
os.makedirs(_cloned_speakers_dir, exist_ok=True)
os.makedirs(_generated_audios_dir, exist_ok=True)

if _had_cloned_speakers_dir:
    print("Loading existing cloned speakers...")
    for file in os.listdir(_cloned_speakers_dir):
        if file.endswith(".json"):
            with open(os.path.join(_cloned_speakers_dir, file), "r") as fp:
                # Strip the ".json" suffix to recover the speaker name.
                cloned_speakers[file[:-5]] = json.load(fp)
    print("Available cloned speakers:", ", ".join(cloned_speakers.keys()))

try:
    print("Getting metadata from server ...")
    _languages_response = requests.get(SERVER_URL + "/languages")
    _languages_response.raise_for_status()
    # NOTE: the misspelled name is kept on purpose; the UI code reads it.
    LANUGAGES = _languages_response.json()
    print("Available languages:", ", ".join(LANUGAGES))
    _speakers_response = requests.get(SERVER_URL + "/studio_speakers")
    _speakers_response.raise_for_status()
    STUDIO_SPEAKERS = _speakers_response.json()
    print("Available studio speakers:", ", ".join(STUDIO_SPEAKERS.keys()))
except (requests.RequestException, ValueError) as err:
    # Only connection/HTTP/JSON failures mean "server not ready"; anything
    # else (including Ctrl-C) propagates unchanged. The cause is chained so
    # the real error stays visible in the traceback.
    raise Exception("Please make sure the server is running first.") from err
def clone_speaker(upload_file, clone_speaker_name, cloned_speaker_names):
    """Clone a speaker from a reference recording and register it.

    Uploads the audio file to the server's ``/clone_speaker`` endpoint,
    stores the returned JSON under ``OUTPUT/cloned_speakers/<name>.json``
    and in the module-level ``cloned_speakers`` dict, and adds the name to
    the list of known cloned speakers.

    Args:
        upload_file: Path of the reference audio file to upload.
        clone_speaker_name: Name to register the speaker under; it is also
            used as the stem of the JSON file name.
        cloned_speaker_names: List of already known cloned speaker names.

    Returns:
        A tuple ``(upload_file, clone_speaker_name, updated_names,
        dropdown_update)`` where ``dropdown_update`` refreshes the choices
        of the cloned-speaker dropdown.

    Raises:
        gr.Error: If no audio file was given or the speaker name is empty
            or contains a path separator.
        requests.HTTPError: If the server answers with an error status.
    """
    if not upload_file:
        raise gr.Error("Please upload a reference audio file first.")
    # The name comes straight from a text box and ends up in a file path,
    # so refuse anything that could escape the cloned_speakers directory.
    if not clone_speaker_name or "/" in clone_speaker_name or "\\" in clone_speaker_name:
        raise gr.Error("Speaker name must be non-empty and must not contain path separators.")

    # Use a context manager so the upload handle is closed after the request.
    with open(upload_file, "rb") as wav_fp:
        files = {"wav_file": ("reference.wav", wav_fp)}
        response = requests.post(SERVER_URL + "/clone_speaker", files=files)
    # Never persist an error payload as if it were speaker embeddings.
    response.raise_for_status()
    embeddings = response.json()

    with open(os.path.join(OUTPUT, "cloned_speakers", clone_speaker_name + ".json"), "w") as fp:
        json.dump(embeddings, fp)
    cloned_speakers[clone_speaker_name] = embeddings

    # Re-cloning an existing name overwrites its embeddings; do not list it twice.
    updated_names = list(cloned_speaker_names)
    if clone_speaker_name not in updated_names:
        updated_names.append(clone_speaker_name)
    return upload_file, clone_speaker_name, updated_names, gr.Dropdown.update(choices=updated_names)
def tts(text, speaker_type, speaker_name_studio, speaker_name_custom, lang):
    """Synthesize ``text`` with the selected speaker and save it as a WAV file.

    Args:
        text: Text to synthesize.
        speaker_type: ``'Studio'`` to use a studio speaker; any other value
            selects a cloned speaker.
        speaker_name_studio: Key into ``STUDIO_SPEAKERS`` (used for 'Studio').
        speaker_name_custom: Key into ``cloned_speakers`` (used otherwise).
        lang: Language code forwarded to the server.

    Returns:
        Path of the generated ``.wav`` file under ``OUTPUT/generated_audios``.

    Raises:
        gr.Error: If the selected speaker is not available.
        requests.HTTPError: If the server answers with an error status.
    """
    if speaker_type == 'Studio':
        speakers, speaker_name = STUDIO_SPEAKERS, speaker_name_studio
    else:
        speakers, speaker_name = cloned_speakers, speaker_name_custom
    # A dropdown with no selection yields None; report that clearly instead
    # of failing with a bare KeyError.
    if speaker_name not in speakers:
        raise gr.Error("Please select an available speaker first.")
    embeddings = speakers[speaker_name]

    response = requests.post(
        SERVER_URL + "/tts",
        json={
            "text": text,
            "language": lang,
            "speaker_embedding": embeddings["speaker_embedding"],
            "gpt_cond_latent": embeddings["gpt_cond_latent"],
        },
    )
    # Do not try to base64-decode an error body as audio.
    response.raise_for_status()

    audio_dir = os.path.join(OUTPUT, "generated_audios")
    os.makedirs(audio_dir, exist_ok=True)
    # NamedTemporaryFile picks a unique name atomically via the public API;
    # delete=False keeps the file on disk so it can be served afterwards.
    with tempfile.NamedTemporaryFile(mode="wb", suffix=".wav", dir=audio_dir, delete=False) as fp:
        # The response body is base64 text; decode it to raw audio bytes.
        fp.write(base64.b64decode(response.content))
    return fp.name
# Gradio UI: one tab for synthesis, one tab for cloning a new speaker.
# NOTE(review): the source this was recovered from had lost its indentation;
# the layout nesting below is a reconstruction -- confirm it against the
# rendered page.
with gr.Blocks() as demo:
    # State holding the list of cloned speaker names, seeded from the
    # speakers known at import time.
    cloned_speaker_names = gr.State(list(cloned_speakers.keys()))

    # Hand plain lists to the dropdowns rather than dict views.
    studio_speaker_choices = list(STUDIO_SPEAKERS)
    cloned_speaker_choices = cloned_speaker_names.value
    # Prefer English, but fall back to whatever the server offers so the
    # default value is always one of the available choices.
    default_language = "en" if "en" in LANUGAGES else next(iter(LANUGAGES), None)

    with gr.Tab("TTS"):
        with gr.Column() as row4:
            with gr.Row() as col4:
                speaker_name_studio = gr.Dropdown(
                    label="Studio speaker",
                    choices=studio_speaker_choices,
                    value="Asya Anara" if "Asya Anara" in STUDIO_SPEAKERS else None,
                )
                speaker_name_custom = gr.Dropdown(
                    label="Cloned speaker",
                    choices=cloned_speaker_choices,
                    value=cloned_speaker_choices[0] if cloned_speaker_choices else None,
                )
                speaker_type = gr.Dropdown(label="Speaker type", choices=["Studio", "Cloned"], value="Studio")
            with gr.Column() as col2:
                lang = gr.Dropdown(label="Language", choices=LANUGAGES, value=default_language)
                text = gr.Textbox(label="text", value="A quick brown fox jumps over the lazy dog.")
                tts_button = gr.Button(value="TTS")
        with gr.Column() as col3:
            generated_audio = gr.Audio(label="Generated audio", autoplay=True)

    with gr.Tab("Clone a new speaker"):
        with gr.Column() as col1:
            upload_file = gr.Audio(label="Upload reference audio", type="filepath")
            clone_speaker_name = gr.Textbox(label="Speaker name", value="default_speaker")
            clone_button = gr.Button(value="Clone speaker")

    # Cloning also updates the cloned-speaker dropdown on the TTS tab.
    clone_button.click(
        fn=clone_speaker,
        inputs=[upload_file, clone_speaker_name, cloned_speaker_names],
        outputs=[upload_file, clone_speaker_name, cloned_speaker_names, speaker_name_custom],
    )
    tts_button.click(
        fn=tts,
        inputs=[text, speaker_type, speaker_name_studio, speaker_name_custom, lang],
        outputs=[generated_audio],
    )
if __name__ == "__main__":
    # Send one synthesis request before the UI comes up, using the speaker
    # embeddings stored in the test fixture.
    print("Warming up server...")
    with open("test/default_speaker.json", "r") as speaker_file:
        warmup_speaker = json.load(speaker_file)

    warmup_payload = {
        "text": "This is a warmup request.",
        "language": "en",
        "speaker_embedding": warmup_speaker["speaker_embedding"],
        "gpt_cond_latent": warmup_speaker["gpt_cond_latent"],
    }
    warmup_response = requests.post(SERVER_URL + "/tts", json=warmup_payload)
    # Abort startup if the server answered the warm-up with an error status.
    warmup_response.raise_for_status()

    print("Starting the demo...")
    launch_options = {
        "share": False,
        "debug": False,
        "server_port": 3009,
        "server_name": "0.0.0.0",
    }
    demo.launch(**launch_options)