mrfakename committed
Commit d7016b3 · 1 Parent(s): 3e6810f

add app v1

Files changed (1)
  1. app.py +124 -0
app.py CHANGED
@@ -0,0 +1,124 @@
+import multiprocessing as mp
+import torch
+import os
+from functools import partial
+import gradio as gr
+import traceback
+from tts.infer_cli import MegaTTS3DiTInfer, convert_to_wav, cut_wav
+
+
+def model_worker(input_queue, output_queue, device_id):
+    device = None
+    if device_id is not None:
+        device = torch.device(f'cuda:{device_id}')
+    infer_pipe = MegaTTS3DiTInfer(device=device)
+
+    while True:
+        task = input_queue.get()
+        inp_audio_path, inp_text, infer_timestep, p_w, t_w = task
+        try:
+            convert_to_wav(inp_audio_path)
+            wav_path = os.path.splitext(inp_audio_path)[0] + '.wav'
+            cut_wav(wav_path, max_len=28)
+            with open(wav_path, 'rb') as file:
+                file_content = file.read()
+            resource_context = infer_pipe.preprocess(file_content)
+            wav_bytes = infer_pipe.forward(resource_context, inp_text, time_step=infer_timestep, p_w=p_w, t_w=t_w)
+            output_queue.put(wav_bytes)
+        except Exception as e:
+            traceback.print_exc()
+            print(task, str(e))
+            output_queue.put(None)
+
+
+def generate_speech(inp_audio, inp_text, infer_timestep, p_w, t_w, processes, input_queue, output_queue):
+    if not inp_audio or not inp_text:
+        gr.Warning("Please provide both reference audio and text to generate.")
+        return None
+
+    print("Generating speech with:", inp_audio, inp_text, infer_timestep, p_w, t_w)
+    input_queue.put((inp_audio, inp_text, infer_timestep, p_w, t_w))
+    res = output_queue.get()
+    if res is not None:
+        return res
+    else:
+        gr.Warning("Speech generation failed. Please try again.")
+        return None
+
+
+if __name__ == '__main__':
+    mp.set_start_method('spawn', force=True)
+    mp_manager = mp.Manager()
+
+    devices = os.environ.get('CUDA_VISIBLE_DEVICES', '')
+    if devices != '':
+        devices = os.environ.get('CUDA_VISIBLE_DEVICES', '').split(",")
+    else:
+        devices = None
+
+    num_workers = 1
+    input_queue = mp_manager.Queue()
+    output_queue = mp_manager.Queue()
+    processes = []
+
+    print("Starting workers...")
+    for i in range(num_workers):
+        p = mp.Process(target=model_worker, args=(input_queue, output_queue, i % len(devices) if devices is not None else None))
+        p.start()
+        processes.append(p)
+
+    with gr.Blocks(title="MegaTTS3 Voice Cloning") as demo:
+        gr.Markdown("# MegaTTS3 Voice Cloning")
+        gr.Markdown("Upload a reference audio clip and enter text to generate speech with the cloned voice.")
+
+        with gr.Row():
+            with gr.Column():
+                reference_audio = gr.Audio(
+                    label="Reference Audio",
+                    type="filepath",
+                    sources=["upload", "microphone"]
+                )
+                text_input = gr.Textbox(
+                    label="Text to Generate",
+                    placeholder="Enter the text you want to synthesize...",
+                    lines=3
+                )
+
+                with gr.Accordion("Advanced Options", open=False):
+                    infer_timestep = gr.Number(
+                        label="Inference Timesteps",
+                        value=32,
+                        minimum=1,
+                        maximum=100,
+                        step=1
+                    )
+                    p_w = gr.Number(
+                        label="Intelligibility Weight",
+                        value=1.4,
+                        minimum=0.1,
+                        maximum=5.0,
+                        step=0.1
+                    )
+                    t_w = gr.Number(
+                        label="Similarity Weight",
+                        value=3.0,
+                        minimum=0.1,
+                        maximum=10.0,
+                        step=0.1
+                    )
+
+                generate_btn = gr.Button("Generate Speech", variant="primary")
+
+            with gr.Column():
+                output_audio = gr.Audio(label="Generated Audio")
+
+        generate_btn.click(
+            fn=partial(generate_speech, processes=processes, input_queue=input_queue, output_queue=output_queue),
+            inputs=[reference_audio, text_input, infer_timestep, p_w, t_w],
+            outputs=[output_audio]
+        )
+
+    demo.launch(server_name='0.0.0.0', server_port=7860, debug=True)
+
+    for p in processes:
+        p.join()
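
For reference, a minimal sketch of how the launched demo could be called programmatically with `gradio_client`, assuming the app is reachable at `http://localhost:7860`. The click handler above does not set an explicit `api_name`, so the endpoint name below is a placeholder and should be confirmed with `Client.view_api()` before use; the reference file path and text are likewise hypothetical.

```python
# Sketch (assumptions noted inline): query the running Gradio demo from Python.
from gradio_client import Client, handle_file

client = Client("http://localhost:7860")  # assumed local address from demo.launch(...)
print(client.view_api())                  # list the endpoints Gradio actually registered

result = client.predict(
    handle_file("reference.wav"),         # hypothetical reference audio clip
    "Text to synthesize with the cloned voice.",
    32,                                   # inference timesteps
    1.4,                                  # intelligibility weight (p_w)
    3.0,                                  # similarity weight (t_w)
    api_name="/generate_speech",          # placeholder; verify via view_api()
)
print(result)                             # local path of the generated audio returned by the client
```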