Installation
sudo apt-get update && sudo apt-get install cbm ffmpeg git-lfs
pip install unsloth
pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl==0.15.2 triton cut_cross_entropy unsloth_zoo
pip install sentencepiece protobuf 'datasets>=3.4.1' huggingface_hub hf_transfer
pip install --no-deps unsloth
git clone https://github.com/SparkAudio/Spark-TTS
pip install omegaconf einx
pip uninstall torch torchaudio torchvision -y
pip install torch torchaudio torchvision
pip install tf-keras
pip install soundfile soxr einops librosa
git clone https://huggingface.co/svjack/Spark-TTS-0.5B-Mavuika-Merged-Early
git clone https://huggingface.co/unsloth/Spark-TTS-0.5B
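A quick sanity check (a minimal sketch, assuming the two repositories above were cloned into the current working directory) confirms that PyTorch can see the GPU and that the expected folders are in place before running inference:

import os
import torch

# The inference script below assumes a CUDA device is available.
print("CUDA available:", torch.cuda.is_available())

# These directory names match the git clone commands above.
for path in ("Spark-TTS", "Spark-TTS-0.5B", "Spark-TTS-0.5B-Mavuika-Merged-Early"):
    print(path, "found" if os.path.isdir(path) else "MISSING")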
Inference
import sys
sys.path.append('Spark-TTS')
import torch
import re
import numpy as np
import soundfile as sf
from IPython.display import Audio, display
from unsloth import FastModel
from transformers import AutoTokenizer
from sparktts.models.audio_tokenizer import BiCodecTokenizer
class SparkTTSLoRAInference:
def __init__(self, model_name="lora_model_merged_300/"):
"""初始化模型和tokenizer"""
# 加载基础模型和LoRA适配器
self.model, self.tokenizer = FastModel.from_pretrained(
model_name=model_name,
max_seq_length=2048,
dtype=torch.float32,
load_in_4bit=False,
)
        # self.model.load_adapter(lora_path)  # Load LoRA weights
        # Initialize the audio tokenizer
self.audio_tokenizer = BiCodecTokenizer("Spark-TTS-0.5B", "cuda")
        FastModel.for_inference(self.model)  # Enable optimized inference mode
        # Print device info
print(f"Model loaded on device: {next(self.model.parameters()).device}")
def generate_speech_from_text(
self,
text: str,
temperature: float = 0.8,
top_k: int = 50,
top_p: float = 1,
max_new_audio_tokens: int = 2048,
device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
) -> np.ndarray:
"""
Generates speech audio from text using default voice control parameters.
Args:
text (str): The text input to be converted to speech.
temperature (float): Sampling temperature for generation.
top_k (int): Top-k sampling parameter.
top_p (float): Top-p (nucleus) sampling parameter.
max_new_audio_tokens (int): Max number of new tokens to generate (limits audio length).
device (torch.device): Device to run inference on.
Returns:
np.ndarray: Generated waveform as a NumPy array.
"""
FastModel.for_inference(self.model) # Enable native 2x faster inference
prompt = "".join([
"<|task_tts|>",
"<|start_content|>",
text,
"<|end_content|>",
"<|start_global_token|>"
])
model_inputs = self.tokenizer([prompt], return_tensors="pt").to(device)
print("Generating token sequence...")
generated_ids = self.model.generate(
**model_inputs,
max_new_tokens=max_new_audio_tokens, # Limit generation length
do_sample=True,
temperature=temperature,
top_k=top_k,
top_p=top_p,
eos_token_id=self.tokenizer.eos_token_id, # Stop token
            pad_token_id=self.tokenizer.pad_token_id  # Use the model's pad token id
)
print("Token sequence generated.")
generated_ids_trimmed = generated_ids[:, model_inputs.input_ids.shape[1]:]
predicts_text = self.tokenizer.batch_decode(generated_ids_trimmed, skip_special_tokens=False)[0]
# print(f"\nGenerated Text (for parsing):\n{predicts_text}\n") # Debugging
# Extract semantic token IDs using regex
semantic_matches = re.findall(r"<\|bicodec_semantic_(\d+)\|>", predicts_text)
if not semantic_matches:
print("Warning: No semantic tokens found in the generated output.")
return np.array([], dtype=np.float32)
pred_semantic_ids = torch.tensor([int(token) for token in semantic_matches]).long().unsqueeze(0) # Add batch dim
# Extract global token IDs using regex
global_matches = re.findall(r"<\|bicodec_global_(\d+)\|>", predicts_text)
if not global_matches:
print("Warning: No global tokens found in the generated output (controllable mode). Might use defaults or fail.")
pred_global_ids = torch.zeros((1, 1), dtype=torch.long)
else:
pred_global_ids = torch.tensor([int(token) for token in global_matches]).long().unsqueeze(0) # Add batch dim
pred_global_ids = pred_global_ids.unsqueeze(0) # Shape becomes (1, 1, N_global)
print(f"Found {pred_semantic_ids.shape[1]} semantic tokens.")
print(f"Found {pred_global_ids.shape[2]} global tokens.")
# Detokenize using BiCodecTokenizer
print("Detokenizing audio tokens...")
# Ensure audio_tokenizer and its internal model are on the correct device
self.audio_tokenizer.device = device
self.audio_tokenizer.model.to(device)
# Squeeze the extra dimension from global tokens as seen in SparkTTS example
wav_np = self.audio_tokenizer.detokenize(
pred_global_ids.to(device).squeeze(0), # Shape (1, N_global)
pred_semantic_ids.to(device) # Shape (1, N_semantic)
)
print("Detokenization complete.")
return wav_np
tts = SparkTTSLoRAInference("Spark-TTS-0.5B-Mavuika-Merged-Early")
generated_waveform = tts.generate_speech_from_text("「神是身份,是权力,是精神象征,但它不是『我』。」", max_new_audio_tokens = 2048)
if generated_waveform.size > 0:
output_filename = "infer1.wav"
sample_rate = tts.audio_tokenizer.config.get("sample_rate", 16000)
sf.write(output_filename, generated_waveform, sample_rate)
print(f"Audio saved to {output_filename}")
# Optional: Play audio
display(Audio(generated_waveform, rate=sample_rate))
generated_waveform = tts.generate_speech_from_text("「追寻火焰在时间长河中的足迹,我们将至黑之地以血泪铸成家园;而今,长路向前延续,旅途正待新篇。」", max_new_audio_tokens = 2048)
if generated_waveform.size > 0:
output_filename = "infer2.wav"
sample_rate = tts.audio_tokenizer.config.get("sample_rate", 16000)
sf.write(output_filename, generated_waveform, sample_rate)
print(f"Audio saved to {output_filename}")
# Optional: Play audio
display(Audio(generated_waveform, rate=sample_rate))
generated_waveform = tts.generate_speech_from_text("「牺牲再小也是痛——所以我选择守护所有人的笑容!」", max_new_audio_tokens = 2048)
if generated_waveform.size > 0:
output_filename = "infer3.wav"
sample_rate = tts.audio_tokenizer.config.get("sample_rate", 16000)
sf.write(output_filename, generated_waveform, sample_rate)
print(f"Audio saved to {output_filename}")
# Optional: Play audio
display(Audio(generated_waveform, rate=sample_rate))
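The three blocks above repeat the same generate, save, and play pattern. A short loop can do the same work; this is a minimal sketch that reuses the tts instance defined above, and the infer_batch_N.wav filenames are an arbitrary choice:

texts = [
    "「神是身份,是权力,是精神象征,但它不是『我』。」",
    "「牺牲再小也是痛——所以我选择守护所有人的笑容!」",
]
for i, text in enumerate(texts, start=1):
    waveform = tts.generate_speech_from_text(text, max_new_audio_tokens=2048)
    if waveform.size == 0:
        continue  # Skip inputs for which no semantic tokens were generated
    sample_rate = tts.audio_tokenizer.config.get("sample_rate", 16000)
    output_filename = f"infer_batch_{i}.wav"
    sf.write(output_filename, waveform, sample_rate)
    print(f"Audio saved to {output_filename}")
    display(Audio(waveform, rate=sample_rate))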