import spaces
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
from xcodec2.modeling_xcodec2 import XCodec2Model
import torchaudio
import gradio as gr
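# Load the Llasa-3B TTS language model and its tokenizer.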
llasa_model_id = 'OmniAICreator/Galgame-Llasa-3B'
tokenizer = AutoTokenizer.from_pretrained(llasa_model_id)
model = AutoModelForCausalLM.from_pretrained(
    llasa_model_id,
    trust_remote_code=True,
)
model.eval().cuda()
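# XCodec2 neural codec: converts between 16 kHz waveforms and discrete speech tokens.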
xcodec2_model_id = "HKUSTAudio/xcodec2"
codec_model = XCodec2Model.from_pretrained(xcodec2_model_id)
codec_model.eval().cuda()
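# Whisper transcribes the reference audio so its text can be prepended to the target text.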
whisper_turbo_pipe = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-large-v3-turbo",
    torch_dtype=torch.float16,
    device='cuda',
)
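# Helpers for converting between integer codec IDs and <|s_N|> token strings.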
def ids_to_speech_tokens(speech_ids):
    speech_tokens_str = []
    for speech_id in speech_ids:
        speech_tokens_str.append(f"<|s_{speech_id}|>")
    return speech_tokens_str
def extract_speech_ids(speech_tokens_str):
    speech_ids = []
    for token_str in speech_tokens_str:
        if token_str.startswith('<|s_') and token_str.endswith('|>'):
            num_str = token_str[4:-2]
            num = int(num_str)
            speech_ids.append(num)
        else:
            print(f"Unexpected token: {token_str}")
    return speech_ids
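# ZeroGPU: request a GPU for up to 60 seconds per call.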
@spaces.GPU(duration=60)
def infer(sample_audio_path, target_text, temperature, top_p, progress=gr.Progress()):
    if not target_text or not target_text.strip():
        gr.Warning("Please input text to generate audio.")
        return None
    if len(target_text) > 300:
        gr.Warning("Text is too long. Please keep it under 300 characters.")
        target_text = target_text[:300]
    with torch.no_grad():
        if sample_audio_path:
            progress(0, 'Loading and trimming audio...')
            waveform, sample_rate = torchaudio.load(sample_audio_path)
            if len(waveform[0]) / sample_rate > 15:
                gr.Warning("Trimming audio to the first 15 seconds.")
                waveform = waveform[:, :sample_rate * 15]
            # Convert stereo to mono by averaging the channels
            if waveform.size(0) > 1:
                waveform_mono = torch.mean(waveform, dim=0, keepdim=True)
            else:
                waveform_mono = waveform
            # Resample the reference audio to the 16 kHz expected by XCodec2 and Whisper
            prompt_wav = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform_mono)
            prompt_wav_len = prompt_wav.shape[1]
            prompt_text = whisper_turbo_pipe(prompt_wav[0].numpy())['text'].strip()
            progress(0.5, 'Transcribed! Encoding audio...')
            # Encode the prompt wav into discrete speech tokens
            vq_code_prompt = codec_model.encode_code(input_waveform=prompt_wav)[0, 0, :]
            # Convert int 12345 to token <|s_12345|>
            speech_ids_prefix = ids_to_speech_tokens(vq_code_prompt)
            input_text = prompt_text + ' ' + target_text
            assistant_content = "<|SPEECH_GENERATION_START|>" + ''.join(speech_ids_prefix)
        else:
            progress(0, "Preparing...")
            input_text = target_text
            assistant_content = "<|SPEECH_GENERATION_START|>"
            speech_ids_prefix = []
            prompt_wav_len = 0
        progress(0.75, "Generating audio...")
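        # The text goes between TEXT_UNDERSTANDING markers; the assistant turn starts with
        # <|SPEECH_GENERATION_START|> plus any reference speech tokens, so the model
        # continues by emitting <|s_N|> tokens for the target text.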
        formatted_text = f"<|TEXT_UNDERSTANDING_START|>{input_text}<|TEXT_UNDERSTANDING_END|>"
        # Tokenize the text and the speech prefix
        chat = [
            {"role": "user", "content": "Convert the text to speech:" + formatted_text},
            {"role": "assistant", "content": assistant_content}
        ]
        input_ids = tokenizer.apply_chat_template(
            chat,
            tokenize=True,
            return_tensors='pt',
            continue_final_message=True
        ).to('cuda')
        speech_end_id = tokenizer.convert_tokens_to_ids('<|SPEECH_GENERATION_END|>')
        # Generate the speech autoregressively
        outputs = model.generate(
            input_ids,
            max_length=2048,  # We trained our model with a max length of 2048
            eos_token_id=speech_end_id,
            do_sample=True,
            top_p=top_p,
            temperature=temperature
        )
        # Extract the speech tokens (keep the reference prefix, if any; drop the trailing EOS)
        if sample_audio_path:
            generated_ids = outputs[0][input_ids.shape[1] - len(speech_ids_prefix):-1]
        else:
            generated_ids = outputs[0][input_ids.shape[1]:-1]
        speech_tokens = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        # Convert token <|s_23456|> to int 23456
        speech_tokens = extract_speech_ids(speech_tokens)
        if not speech_tokens:
            raise gr.Error("Audio generation failed.")
        speech_tokens = torch.tensor(speech_tokens).cuda().unsqueeze(0).unsqueeze(0)
        # Decode the speech tokens to a speech waveform
        gen_wav = codec_model.decode_code(speech_tokens)
        # Keep only the newly generated part (strip the reference-audio prefix)
        if sample_audio_path and prompt_wav_len > 0:
            gen_wav = gen_wav[:, :, prompt_wav_len:]
    progress(1, 'Synthesized!')
    return (16000, gen_wav[0, 0, :].cpu().numpy())
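# Gradio UI: TTS tab with reference audio, target text, and sampling controls.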
with gr.Blocks() as app_tts:
    gr.Markdown("# Galgame Llasa 3B")
    ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
    gen_text_input = gr.Textbox(label="Text to Generate", lines=10)
    with gr.Row():
        temperature_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.8, step=0.05, label="Temperature")
        top_p_slider = gr.Slider(minimum=0.0, maximum=1.0, value=1.0, step=0.05, label="Top-p")
    generate_btn = gr.Button("Synthesize", variant="primary")
    audio_output = gr.Audio(label="Synthesized Audio")
    generate_btn.click(
        infer,
        inputs=[
            ref_audio_input,
            gen_text_input,
            temperature_slider,
            top_p_slider,
        ],
        outputs=[audio_output],
    )
with gr.Blocks() as app_credits:
    gr.Markdown("""
# Credits
* [zhenye234](https://github.com/zhenye234) for the original [repo](https://github.com/zhenye234/LLaSA_training)
* [mrfakename](https://huggingface.co/mrfakename) for the [gradio demo code](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
* [SunderAli17](https://huggingface.co/SunderAli17) for the [gradio demo code](https://huggingface.co/spaces/SunderAli17/llasa-3b-tts)
""")
with gr.Blocks() as app:
    gr.Markdown(
        """
# Galgame Llasa 3B
This is a web UI for the Galgame Llasa 3B TTS model. You can check out the model [here](https://huggingface.co/OmniAICreator/Galgame-Llasa-3B).
The model is fine-tuned on Japanese audio data.
If you're having issues, try converting your reference audio to WAV or MP3, clipping it to 15 seconds, and shortening your prompt.
"""
    )
    gr.TabbedInterface([app_tts, app_credits], ["TTS", "Credits"])
app.launch()