Installation

sudo apt-get update && sudo apt-get install cbm ffmpeg git-lfs

pip install unsloth
pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl==0.15.2 triton cut_cross_entropy unsloth_zoo
pip install sentencepiece protobuf 'datasets>=3.4.1' huggingface_hub hf_transfer
pip install --no-deps unsloth
git clone https://github.com/SparkAudio/Spark-TTS
pip install omegaconf einx

pip uninstall torch torchaudio torchvision -y
pip install torch torchaudio torchvision
pip install tf-keras
pip install soundfile soxr einops librosa

git clone https://huggingface.co/svjack/Spark-TTS-0.5B-Mavuika-Merged-Early
git clone https://huggingface.co/unsloth/Spark-TTS-0.5B
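
A quick sanity check (a minimal sketch; it assumes the three clones above landed in the current working directory) confirms that the core dependencies import cleanly and that a GPU is visible before moving on:

import os
import torch

print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
for repo in ("Spark-TTS", "Spark-TTS-0.5B", "Spark-TTS-0.5B-Mavuika-Merged-Early"):
    print(repo, "cloned:", os.path.isdir(repo))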

Inference

import sys
sys.path.append('Spark-TTS')

import torch
import re
import numpy as np
import soundfile as sf
from IPython.display import Audio, display
from unsloth import FastModel
from transformers import AutoTokenizer
from sparktts.models.audio_tokenizer import BiCodecTokenizer

class SparkTTSLoRAInference:
    def __init__(self, model_name="lora_model_merged_300/"):
        """初始化模型和tokenizer"""
        # 加载基础模型和LoRA适配器
        self.model, self.tokenizer = FastModel.from_pretrained(
            model_name=model_name,
            max_seq_length=2048,
            dtype=torch.float32,
            load_in_4bit=False,
        )
        # self.model.load_adapter(lora_path)  # load LoRA weights

        # Initialize the audio tokenizer
        self.audio_tokenizer = BiCodecTokenizer("Spark-TTS-0.5B", "cuda")
        FastModel.for_inference(self.model)  # Enable optimized inference mode

        # Print device info
        print(f"Model loaded on device: {next(self.model.parameters()).device}")

    def generate_speech_from_text(
            self,
            text: str,
            temperature: float = 0.8,
            top_k: int = 50,
            top_p: float = 1.0,
            max_new_audio_tokens: int = 2048,
            device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        ) -> np.ndarray:
        """
        Generates speech audio from text using default voice control parameters.
        Args:
            text (str): The text input to be converted to speech.
            temperature (float): Sampling temperature for generation.
            top_k (int): Top-k sampling parameter.
            top_p (float): Top-p (nucleus) sampling parameter.
            max_new_audio_tokens (int): Max number of new tokens to generate (limits audio length).
            device (torch.device): Device to run inference on.
        Returns:
            np.ndarray: Generated waveform as a NumPy array.
        """
        FastModel.for_inference(self.model)  # Enable native 2x faster inference
        prompt = "".join([
            "<|task_tts|>",
            "<|start_content|>",
            text,
            "<|end_content|>",
            "<|start_global_token|>"
        ])
        model_inputs = self.tokenizer([prompt], return_tensors="pt").to(device)
        print("Generating token sequence...")
        generated_ids = self.model.generate(
            **model_inputs,
            max_new_tokens=max_new_audio_tokens,  # Limit generation length
            do_sample=True,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            eos_token_id=self.tokenizer.eos_token_id,  # Stop token
            pad_token_id=self.tokenizer.pad_token_id  # Use the model's pad token id
        )
        print("Token sequence generated.")
        generated_ids_trimmed = generated_ids[:, model_inputs.input_ids.shape[1]:]
        predicts_text = self.tokenizer.batch_decode(generated_ids_trimmed, skip_special_tokens=False)[0]
        # print(f"\nGenerated Text (for parsing):\n{predicts_text}\n") # Debugging
        # Extract semantic token IDs using regex
        semantic_matches = re.findall(r"<\|bicodec_semantic_(\d+)\|>", predicts_text)
        if not semantic_matches:
            print("Warning: No semantic tokens found in the generated output.")
            return np.array([], dtype=np.float32)
        pred_semantic_ids = torch.tensor([int(token) for token in semantic_matches]).long().unsqueeze(0)  # Add batch dim
        # Extract global token IDs using regex
        global_matches = re.findall(r"<\|bicodec_global_(\d+)\|>", predicts_text)
        if not global_matches:
            print("Warning: No global tokens found in the generated output (controllable mode). Might use defaults or fail.")
            pred_global_ids = torch.zeros((1, 1), dtype=torch.long)
        else:
            pred_global_ids = torch.tensor([int(token) for token in global_matches]).long().unsqueeze(0)  # Add batch dim
        pred_global_ids = pred_global_ids.unsqueeze(0)  # Shape becomes (1, 1, N_global)
        print(f"Found {pred_semantic_ids.shape[1]} semantic tokens.")
        print(f"Found {pred_global_ids.shape[2]} global tokens.")
        # Detokenize using BiCodecTokenizer
        print("Detokenizing audio tokens...")
        # Ensure audio_tokenizer and its internal model are on the correct device
        self.audio_tokenizer.device = device
        self.audio_tokenizer.model.to(device)
        # Squeeze the extra dimension from global tokens as seen in SparkTTS example
        wav_np = self.audio_tokenizer.detokenize(
            pred_global_ids.to(device).squeeze(0),  # Shape (1, N_global)
            pred_semantic_ids.to(device)            # Shape (1, N_semantic)
        )
        print("Detokenization complete.")
        return wav_np

tts = SparkTTSLoRAInference("Spark-TTS-0.5B-Mavuika-Merged-Early")
generated_waveform = tts.generate_speech_from_text("「神是身份,是权力,是精神象征,但它不是『我』。」", max_new_audio_tokens=2048)
if generated_waveform.size > 0:
    output_filename = "infer1.wav"
    sample_rate = tts.audio_tokenizer.config.get("sample_rate", 16000)
    sf.write(output_filename, generated_waveform, sample_rate)
    print(f"Audio saved to {output_filename}")
    # Optional: Play audio
    display(Audio(generated_waveform, rate=sample_rate))

generated_waveform = tts.generate_speech_from_text("「追寻火焰在时间长河中的足迹,我们将至黑之地以血泪铸成家园;而今,长路向前延续,旅途正待新篇。」", max_new_audio_tokens=2048)
if generated_waveform.size > 0:
    output_filename = "infer2.wav"
    sample_rate = tts.audio_tokenizer.config.get("sample_rate", 16000)
    sf.write(output_filename, generated_waveform, sample_rate)
    print(f"Audio saved to {output_filename}")
    # Optional: Play audio
    display(Audio(generated_waveform, rate=sample_rate))

generated_waveform = tts.generate_speech_from_text("「牺牲再小也是痛——所以我选择守护所有人的笑容!」", max_new_audio_tokens=2048)
if generated_waveform.size > 0:
    output_filename = "infer3.wav"
    sample_rate = tts.audio_tokenizer.config.get("sample_rate", 16000)
    sf.write(output_filename, generated_waveform, sample_rate)
    print(f"Audio saved to {output_filename}")
    # Optional: Play audio
    display(Audio(generated_waveform, rate=sample_rate))
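
The save-and-play block above repeats verbatim for each clip. A small wrapper keeps further experiments shorter; synthesize_to_file below is a hypothetical helper sketched for this card, not part of the Spark-TTS repository:

def synthesize_to_file(tts, text, output_filename, max_new_audio_tokens=2048):
    """Generate speech for `text` (hypothetical helper), save it, and play it inline."""
    waveform = tts.generate_speech_from_text(text, max_new_audio_tokens=max_new_audio_tokens)
    if waveform.size == 0:
        print("Warning: empty waveform; nothing saved.")
        return
    sample_rate = tts.audio_tokenizer.config.get("sample_rate", 16000)  # same default as above
    sf.write(output_filename, waveform, sample_rate)
    print(f"Audio saved to {output_filename}")
    display(Audio(waveform, rate=sample_rate))

synthesize_to_file(tts, "「神是身份,是权力,是精神象征,但它不是『我』。」", "infer4.wav")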

Model details

svjack/Spark-TTS-0.5B-Mavuika-Merged-Early: 507M parameters, F32 tensors, Safetensors format. Fine-tuned from Spark-TTS-0.5B; one quantized variant is available.