---
base_model: snorbyte/snorTTS-Indic-v0
tags:
  - text-to-speech
  - tts
  - transformers
  - unsloth
  - llama
  - audio
  - speech-synthesis
license: apache-2.0
language:
  - hi
  - gu
  - mr
  - pa
  - bn
  - te
  - kn
  - ml
  - ta
---

# snorTTS-Indic-v0

**Open-source multilingual Indic TTS model**

A human-sounding Indic TTS model, finetuned in multiple stages by Snorbyte on 140 hours of proprietary speech across 9 Indic languages. The base model is LLaMA-3.2-3B-Instruct, pretrained by canopylabs on 100k hours of English and finetuned on Hindi.


## Capabilities

- Human-sounding speech
- Natural, human-like delivery of colloquial transcripts (including English code-mixing and disfluencies)
- Multilingual code-switching

## Model Overview

| Item | Details |
|---|---|
| Base model | `canopylabs/3b-hi-pretrain-research_release` |
| Architecture | LLaMA-3.2-3B-Instruct (transformers) |
| Audio codec | SNAC @ 24 kHz, 3 codebooks (12,288 new tokens) |
| Training toolkit | Unsloth + HF TRL |
| Languages | Hindi (hi), Gujarati (gu), Marathi (mr), Punjabi (pa), Bengali (bn), Telugu (te), Kannada (kn), Malayalam (ml), Tamil (ta) |

## Inference
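The snippet below assumes that `model` and `tokenizer` are already loaded and that `MAX_SEQ_LENGTH`, `audio_start_id`, and `end_of_speech_id` are defined as in the Training Code section further down. A minimal setup sketch follows; loading the released checkpoint by the repo id `snorbyte/snorTTS-Indic-v0` is an assumption here, so substitute your own finetuned checkpoint if you have one:

```python
# Minimal setup sketch (assumes the repo id snorbyte/snorTTS-Indic-v0;
# swap in your own finetuned checkpoint if needed).
from unsloth import FastLanguageModel
from snac import SNAC
from loguru import logger
import soundfile as sf
import torch

MAX_SEQ_LENGTH = 2048

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="snorbyte/snorTTS-Indic-v0",
    max_seq_length=MAX_SEQ_LENGTH,
)
FastLanguageModel.for_inference(model)  # enable Unsloth's fast inference path

# Special-token layout (same as in the Training Code section).
tokeniser_length = 128256
end_of_speech_id = tokeniser_length + 2  # generation stops at this token
audio_start_id = tokeniser_length + 10   # first SNAC audio token id
```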

```python
# Load the SNAC model (decodes SNAC codes back into a 24 kHz waveform).
snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
logger.success("Loaded SNAC model for audio decoding.")

# Generate speech for a prompt and reconstruct the waveform from the SNAC
# codes produced by the model.
def generate_audio(
    row, model, user=False, temperature=0.4, top_p=0.9, repetition_penalty=1.05
):
    if user:
        prompt = row["eval_text_user"]
    else:
        prompt = row["eval_text_no_user"]
    inputs = tokenizer(prompt, add_special_tokens=False, return_tensors="pt")
    max_tokens = MAX_SEQ_LENGTH - inputs.input_ids.shape[1]
    output = model.generate(
        input_ids=inputs.input_ids.to("cuda"),
        attention_mask=inputs.attention_mask.to("cuda"),
        max_new_tokens=max_tokens,
        do_sample=True,  # required for temperature/top_p to take effect
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        eos_token_id=end_of_speech_id,
    )
    # Keep only audio tokens and rebase them to start at 0.
    audio_ids = [id.item() for id in output[0] if id >= audio_start_id]
    clean_audio_ids = []
    for i in range(len(audio_ids) // 7):  # complete 7-token frames only
        for j in range(7):
            clean_audio_ids.append(audio_ids[7 * i + j] - audio_start_id)
    # De-interleave each 7-token frame into SNAC's 3 codebooks (1 code from
    # level 0, 2 from level 1, 4 from level 2 per frame); each of the 7 slots
    # sits in its own 4096-wide id band, hence the per-slot offsets.
    codes = [[], [], []]
    for i in range(len(clean_audio_ids) // 7):
        codes[0].append(clean_audio_ids[7 * i])
        codes[1].append(clean_audio_ids[7 * i + 1] - 4096)
        codes[2].append(clean_audio_ids[7 * i + 2] - (2 * 4096))
        codes[2].append(clean_audio_ids[7 * i + 3] - (3 * 4096))
        codes[1].append(clean_audio_ids[7 * i + 4] - (4 * 4096))
        codes[2].append(clean_audio_ids[7 * i + 5] - (5 * 4096))
        codes[2].append(clean_audio_ids[7 * i + 6] - (6 * 4096))
    codes = [
        torch.tensor(codes[0]).unsqueeze(0),
        torch.tensor(codes[1]).unsqueeze(0),
        torch.tensor(codes[2]).unsqueeze(0),
    ]
    try:
        audio = snac_model.decode(codes)
    except Exception as e:
        logger.error(f"Error decoding audio: {e}")
        return None
    return audio.detach().squeeze().to("cpu").numpy()

prompt = {
    "eval_text_no_user": "<custom_token_3><|begin_of_text|>நிச்சயமா. ரோம் ல் இரவு நேரம் ரொம்ப அழகா இருக்கு—piazzaகள் சுத்துறதுக்கு நல்ல நேரம்.<|eot_id|><custom_token_4><custom_token_5><custom_token_1>"
}

# This prompt carries no speaker id, so use user=False.
train_sample = generate_audio(prompt, model, user=False)
if train_sample is None:
    logger.error("Failed to generate audio")
else:
    sf.write("output.wav", train_sample, 24000)
    logger.success("Generated and saved audio as output.wav")
```

## Types of Prompts

For better results, generate audio with the specific speaker IDs listed below.

- **Normal prompt:** just pass the transcript in the format below:

```
{
    "eval_text_no_user": "<custom_token_3><|begin_of_text|>நிச்சயமா. ரோம் ல் இரவு நேரம் ரொம்ப அழகா இருக்கு—piazzaகள் சுத்துறதுக்கு நல்ல நேரம்.<|eot_id|><custom_token_4><custom_token_5><custom_token_1>"
}
```

- **Speaker-specific prompt:** stick to the same format, just pass `{speakerId}: ` before the transcript. You can make any speaker speak any of the 9 languages:

```
{
    "eval_text_user": "<custom_token_3><|begin_of_text|>hindi159:  चलते रहो इस सफर में बिना रुके, क्योंकि मंज़िलें खुद राह दिखाने लगती हैं <|eot_id|><custom_token_4><custom_token_5><custom_token_1>"
}
```

### Recommended Speaker IDs

| Language | Speakers |
|---|---|
| Hindi | 159, 49 |
| Tamil | 188, 128 |
| Bengali | 125 |
| Malayalam | 189, 124 |
| Kannada | 142, 138 |
| Telugu | 69, 133 |
| Punjabi | 191, 67, 201 |
| Gujarati | 62, 190, 187 |
| Marathi | 205, 82 |
- **Multilingual transcript prompt:** stick to the same format, just pass `{speakerId}: ` before the transcript, using the native language of the speaker ID. You can mix multiple languages within a single transcript, as in the example below:

```
{
    "eval_text_user": "<custom_token_3><|begin_of_text|>bengali125: मुझे तो लगा वो आएगा, ஆனா அவன் வந்து full drama பண்ணிட்டான், আর শেষে আবার আমাকে দোষ দিচ্ছে <|eot_id|><custom_token_4><custom_token_5><custom_token_1>"
}
```
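For convenience, a small helper can assemble these prompt dicts from plain text. A hedged sketch; `build_prompt` and its signature are illustrative, not part of the released code:

```python
from typing import Optional

# Illustrative helper (not part of the released code): builds the prompt
# formats shown above from a transcript and an optional speaker id.
def build_prompt(transcript: str, speaker: Optional[str] = None) -> dict:
    prefix = "<custom_token_3><|begin_of_text|>"
    suffix = "<|eot_id|><custom_token_4><custom_token_5><custom_token_1>"
    if speaker:  # e.g. "hindi159" or "bengali125"
        return {"eval_text_user": f"{prefix}{speaker}: {transcript}{suffix}"}
    return {"eval_text_no_user": f"{prefix}{transcript}{suffix}"}

# Usage with generate_audio from the Inference section:
# prompt = build_prompt("चलते रहो इस सफर में बिना रुके", speaker="hindi159")
# audio = generate_audio(prompt, model, user=True)
```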

## Training Details

- **Dataset:** indic-tts-sample-snac-encoded, curated by snorbyte
  - 135 hours (~68k samples) split into:
    - `stage_1`: text-reading (47k scripted) + semi-spontaneous dialogue (16k)
    - `stage_2`: colloquial conversational snippets (4.4k)
    - `eval`: evaluation samples used during training (200)
  - 9 Indic languages, balanced across high-/low-quality speakers.
- **Hyperparameters:**
  - LoRA rank: 192
  - LoRA alpha: 384
  - Per-device train batch size: 8
  - Gradient accumulation steps: 4
  - Optimizer: adamw_8bit
  - Learning rate: 2e-5
  - Scheduler: cosine
  - Warmup ratio: 0.02
  - Epochs: 2
  - Max sequence length: 2048
  - SFT trainer packing: True
- **Compute:**
  - GPU: 1× NVIDIA H100 on Vast.ai

## Training Code

```bash
pip install torch unsloth datasets loguru snac trl soundfile wandb transformers
```

```python
from unsloth import FastLanguageModel

import ast
import os

from datasets import load_dataset
from loguru import logger
from snac import SNAC
from trl import SFTConfig, SFTTrainer
import soundfile as sf
import torch
import wandb
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer
# Set up constants and configurations.
BASE_MODEL = "canopylabs/3b-hi-pretrain-research_release"
STAGE = 1  # 1 or 2, depending on the dataset you are using
if STAGE == 1:
    TRAIN_CSV_PATH = ""  # path to stage_1 csv dataset
else:
    TRAIN_CSV_PATH = ""  # path to stage_2 csv dataset
VALID_CSV_PATH = ""  # path to eval csv dataset
TRAIN_NUM_SAMPLES = None  # optionally cap the number of training samples
EVAL_NUM_SAMPLES = None  # optionally cap the number of evaluation samples

MAX_SEQ_LENGTH = 2048
PER_DEVICE_TRAIN_BATCH_SIZE = 8
GRADIENT_ACCUMULATION_STEPS = 4
HUGGINGFACE_TOKEN = ""  # pass your Hugging Face token
MODEL_NAME = "snorTTS-indic"
WANDB_USERNAME = ""  # pass your wandb username
WANDB_PROJECT = "snorTTS-indic"
WANDB_LOG_MODEL = "checkpoint"
WANDB_RUN_NAME = "run-0"
WANDB_RUN_ID = None
SEED = 3407

# Set up environment variables for Weights & Biases.
os.environ["WANDB_PROJECT"] = WANDB_PROJECT
os.environ["WANDB_LOG_MODEL"] = WANDB_LOG_MODEL
# Load the model and tokenizer.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=BASE_MODEL,
    load_in_4bit=True,
    max_seq_length=MAX_SEQ_LENGTH,
    token=HUGGINGFACE_TOKEN,
)
logger.success(f"Loaded model: {BASE_MODEL}")

# Get parameter efficient fine-tuning model.
model = FastLanguageModel.get_peft_model(
    model,
    r=192,
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "up_proj",
        "down_proj",
        "gate_proj",
        "lm_head",
        "embed_tokens",
    ],
    lora_alpha=384,
    random_state=SEED,
)

# Load the special tokens for the tokenizer.
tokeniser_length = 128256

start_of_text_id = 128000
end_of_text_id = 128009
start_of_speech_id = tokeniser_length + 1
end_of_speech_id = tokeniser_length + 2
start_of_human_id = tokeniser_length + 3
end_of_human_id = tokeniser_length + 4
start_of_ai_id = tokeniser_length + 5
end_of_ai_id = tokeniser_length + 6
pad_token_id = tokeniser_length + 7
audio_start_id = tokeniser_length + 10
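# Audio tokens occupy ids above audio_start_id (128256 + 10 = 128266); each
# SNAC frame takes 7 consecutive slots, one 4096-wide band per slot, matching
# the offsets used in flatten_and_get_audio_input_ids below.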

start_of_text_token = tokenizer.decode([start_of_text_id])
end_of_text_token = tokenizer.decode([end_of_text_id])
start_of_speech_token = tokenizer.decode([start_of_speech_id])
end_of_speech_token = tokenizer.decode([end_of_speech_id])
start_of_human_token = tokenizer.decode([start_of_human_id])
end_of_human_token = tokenizer.decode([end_of_human_id])
start_of_ai_token = tokenizer.decode([start_of_ai_id])
end_of_ai_token = tokenizer.decode([end_of_ai_id])
pad_token = tokenizer.decode([pad_token_id])
audio_start_token = tokenizer.decode([audio_start_id])

logger.success("Load special tokens for the tokenizer.")

# Set the padding token and padding side.
tokenizer.pad_token = pad_token
tokenizer.padding_side = "left"
logger.success("Set padding token and padding side for the tokenizer.")
# Load training and validation datasets.
train_dataset = load_dataset("csv", data_files=TRAIN_CSV_PATH)["train"]
eval_dataset = load_dataset("csv", data_files=VALID_CSV_PATH)["train"]

if TRAIN_NUM_SAMPLES:
    train_dataset = train_dataset.shuffle(seed=SEED).select(
        range(min(TRAIN_NUM_SAMPLES, len(train_dataset)))
    )

if EVAL_NUM_SAMPLES:
    eval_dataset = eval_dataset.shuffle(seed=SEED).select(
        range(min(EVAL_NUM_SAMPLES, len(eval_dataset)))
    )

logger.success(
    f"Loaded datasets: {len(train_dataset)} training samples, {len(eval_dataset)} evaluation samples."
)
# Flatten (interleave) the hierarchical SNAC codes into one token stream.
# SNAC emits 3 codebooks at different rates: per frame, level 0 contributes
# 1 code, level 1 contributes 2, and level 2 contributes 4, giving 7 tokens
# per frame. Each of the 7 slots is shifted into its own 4096-wide band above
# the text vocabulary (128266 = audio_start_id).
def flatten_and_get_audio_input_ids(row):
    audio_codes = row["snac_codes"]
    if isinstance(audio_codes, str):
        # Codes are stored as stringified lists in the csv.
        audio_codes = ast.literal_eval(audio_codes)
    snac_token_ids = []
    for i in range(len(audio_codes[0])):
        snac_token_ids.append(audio_codes[0][i] + 128266)
        snac_token_ids.append(audio_codes[1][2 * i] + 128266 + 4096)
        snac_token_ids.append(audio_codes[2][4 * i] + 128266 + (2 * 4096))
        snac_token_ids.append(audio_codes[2][(4 * i) + 1] + 128266 + (3 * 4096))
        snac_token_ids.append(audio_codes[1][(2 * i) + 1] + 128266 + (4 * 4096))
        snac_token_ids.append(audio_codes[2][(4 * i) + 2] + 128266 + (5 * 4096))
        snac_token_ids.append(audio_codes[2][(4 * i) + 3] + 128266 + (6 * 4096))
    row["snac_token_ids"] = snac_token_ids
    return row


train_dataset = train_dataset.map(flatten_and_get_audio_input_ids)
eval_dataset = eval_dataset.map(flatten_and_get_audio_input_ids)
logger.success("Flattened and extracted SNAC token IDs from audio codes.")
# Filter out rows with empty or None audio codes.
train_dataset = train_dataset.filter(
    lambda x: x["snac_token_ids"] is not None and len(x["snac_token_ids"]) > 0
)
eval_dataset = eval_dataset.filter(
    lambda x: x["snac_token_ids"] is not None and len(x["snac_token_ids"]) > 0
)
logger.success("Filtered datasets to remove rows with empty or None audio codes.")
# Remove duplicate frames (consecutive frames repeating the same level-0 code)
# from the audio codes.
def remove_duplicate_frames(row):
    vals = row["snac_token_ids"]
    if len(vals) % 7 != 0:
        raise ValueError("Input list length must be divisible by 7")
    result = vals[:7]
    for i in range(7, len(vals), 7):
        current_first = vals[i]
        previous_first = result[-7]
        if current_first != previous_first:
            result.extend(vals[i : i + 7])
    row["snac_token_ids"] = result
    return row


train_dataset = train_dataset.map(remove_duplicate_frames)
eval_dataset = eval_dataset.map(remove_duplicate_frames)
logger.success("Removed duplicate frames from audio codes.")
# Define a function to format the prompt for each row in the dataset.
def format_text(row):
    text = (
        f"{start_of_human_token}{start_of_text_token}{row['language']}{row['user']}: {row['utterance']}{end_of_text_token}"
        f"{end_of_human_token}{start_of_ai_token}{start_of_speech_token}"
        f"{tokenizer.decode(row['snac_token_ids'])}{end_of_speech_token}{end_of_ai_token}"
    )
    eval_text_user = (
        f"{start_of_human_token}{start_of_text_token}{row['language']}{row['user']}: {row['utterance']}{end_of_text_token}"
        f"{end_of_human_token}{start_of_ai_token}{start_of_speech_token}"
    )
    eval_text_no_user = (
        f"{start_of_human_token}{start_of_text_token}{row['utterance']}{end_of_text_token}"
        f"{end_of_human_token}{start_of_ai_token}{start_of_speech_token}"
    )
    row["text"] = text
    row["eval_text_user"] = eval_text_user
    row["eval_text_no_user"] = eval_text_no_user
    return row
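# Note: when decoded, the special tokens render as the literal strings seen in
# the prompt examples above: start_of_human -> <custom_token_3>, start_of_text
# -> <|begin_of_text|>, end_of_text -> <|eot_id|>, end_of_human ->
# <custom_token_4>, start_of_ai -> <custom_token_5>, and start_of_speech ->
# <custom_token_1>.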


train_dataset = train_dataset.map(format_text)
eval_dataset = eval_dataset.map(format_text)
logger.success("Formatted text for training and evaluation datasets.")
# Tokenize the text in the datasets without adding special tokens.
def tokenize_function(example):
    return tokenizer(
        example["text"],
        add_special_tokens=False,
        truncation=True,
        max_length=MAX_SEQ_LENGTH,
    )


train_dataset = train_dataset.map(tokenize_function)
eval_dataset = eval_dataset.map(tokenize_function)
logger.success("Tokenized text in the datasets without adding special tokens.")
# Set training arguments.
training_args = SFTConfig(
    num_train_epochs=2,
    per_device_train_batch_size=PER_DEVICE_TRAIN_BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    optim="adamw_8bit",
    learning_rate=2e-5,
    lr_scheduler_type="cosine",
    warmup_ratio=0.02,
    do_eval=True,
    eval_strategy="steps",
    eval_steps=50,
    logging_strategy="steps",
    logging_steps=1,
    save_strategy="no",
    save_only_model=True,
    # save_steps=250,
    output_dir="outputs",
    report_to="wandb",
    run_name=WANDB_RUN_NAME,
    seed=SEED,
)
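# Effective batch size per optimizer step: 8 (per device) x 4 (grad accum) = 32.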

# Initialize the SFTTrainer.
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    dataset_text_field="text",
    max_seq_length=MAX_SEQ_LENGTH,
    dataset_num_proc=2,
    packing=True,
    args=training_args,
)

logger.success("Initialized SFTTrainer with the specified configuration.")
# Start the training process.
logger.info("Starting the training process...")

run = wandb.init()

if WANDB_RUN_ID:
    logger.info(f"Resuming from Weights & Biases run ID: {WANDB_RUN_ID}")

    artifact = run.use_artifact(
        f"{WANDB_USERNAME}/{WANDB_PROJECT}/{WANDB_RUN_ID}", type="model"
    )

    artifact_dir = artifact.download()

    trainer.train(resume_from_checkpoint=artifact_dir)
else:
    try:
        logger.info("Attempting to resume training from the last checkpoint...")

        trainer.train(resume_from_checkpoint=True)
    except Exception:
        logger.info("No checkpoint found; starting training from scratch.")
        trainer.train()

# Finish the Weights & Biases run.
wandb.finish()

logger.success("Training completed successfully.")

## Citation

BibTeX:

```bibtex
@misc{indictextaudio2025,
  title={snorTTS-Indic-v0: Multilingual Indic TTS},
  author={snorbyte},
  year={2025},
  howpublished={\url{snorbyte/snorTTS-Indic-v0}},
  note={Apache-2.0}
}
```