base_model: snorbyte/snorTTS-Indic-v0
tags:
  - text-to-speech
  - tts
  - transformers
  - unsloth
  - llama
  - audio
  - speech-synthesis
license: apache-2.0
language:
  - hi
  - gu
  - mr
  - pa
  - bn
  - te
  - kn
  - ml
  - ta

snorTTS-Indic-v0

A human-sounding Indic text-to-speech (TTS) prototype by Snorbyte, covering 9 Indic languages.

Capabilities

  • Human-sounding Indic speech
  • Natural, human-like delivery of colloquial transcripts, including English code-mixing and disfluencies
  • Multilingual code-switching within a single utterance

Model Overview

| Item | Details |
|---|---|
| Base model | canopylabs/3b-hi-pretrain-research_release |
| Architecture | LLaMA-3.2-3B-Instruct (transformers) |
| Audio codec | SNAC @ 24 kHz, 3 codebooks (12,288 new tokens) |
| Training toolkit | Unsloth + HF TRL |
| Languages | Hindi (hi), Gujarati (gu), Marathi (mr), Punjabi (pa), Bengali (bn), Telugu (te), Kannada (kn), Malayalam (ml), Tamil (ta) |
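Because speech is generated as SNAC tokens, the context window also caps how much audio one prompt can produce. The sketch below estimates that budget. It assumes the published geometry of the hubertsiuzdak/snac_24khz checkpoint (a 512-sample hop for the finest codebook and strides 4/2/1 for the coarse/medium/fine codebooks, i.e. roughly 12/23/47 Hz); these constants come from SNAC, not from this README, and should be checked against the SNAC config. The 100-token prompt length is a hypothetical example.

```python
# Assumed SNAC 24 kHz geometry (verify against the snac_24khz config):
# 512-sample hop for the finest codebook, strides 4/2/1 for coarse/medium/fine.
SAMPLE_RATE = 24_000
HOP_LENGTH = 512
COARSE_STRIDE = 4
TOKENS_PER_FRAME = 7  # 1 coarse + 2 medium + 4 fine codes per coarse step

samples_per_frame = HOP_LENGTH * COARSE_STRIDE  # audio samples covered by one 7-token frame
frames_per_second = SAMPLE_RATE / samples_per_frame
tokens_per_second = frames_per_second * TOKENS_PER_FRAME


def max_audio_seconds(max_seq_length: int, prompt_tokens: int) -> float:
    """Upper bound on audio length that fits after a prompt of `prompt_tokens` tokens."""
    budget = max(max_seq_length - prompt_tokens, 0)
    frames = budget // TOKENS_PER_FRAME  # only whole 7-token frames can be decoded
    return frames * samples_per_frame / SAMPLE_RATE


# A 100-token prompt is a hypothetical example length.
print(f"{tokens_per_second:.1f} audio tokens per second")
print(f"inference context (4096): {max_audio_seconds(4096, 100):.1f} s")
print(f"training context (2048): {max_audio_seconds(2048, 100):.1f} s")
```

Under these assumptions one second of speech costs about 82 tokens, so the 4096-token inference context leaves room for about 48.6 s of audio and the 2048-token training context for about 23.7 s.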

Inference

pip install torch unsloth datasets loguru snac trl soundfile wandb transformers
from unsloth import FastLanguageModel
from snac import SNAC
import soundfile as sf
from loguru import logger
import os
import torch
from huggingface_hub import snapshot_download
from tqdm import tqdm

#Name of the model
MODEL_NAME = 'snorbyte/snorTTS-Indic-v0'
MAX_SEQ_LENGTH = 4096

#Download and Save Model Locally
snapshot_download(
    repo_id=f"{MODEL_NAME}",
    local_dir="./snorTTS-Indic-v0",
    resume_download=True,
)

# Load the model and tokenizer.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="./snorTTS-Indic-v0",
    # load_in_4bit=True,
    max_seq_length=MAX_SEQ_LENGTH,
)
logger.success(f"Loaded model: {MODEL_NAME}")

# Load Model for Inference
FastLanguageModel.for_inference(model)
model.eval()
logger.success(f"Loaded model for inference")

# Load SNAC Model
snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
logger.success("Loaded SNAC model for audio decoding.")
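# Special token IDs (same values as in the training script).
tokeniser_length = 128256
end_of_speech_id = tokeniser_length + 2  # generation stops at this token
audio_start_id = tokeniser_length + 10  # first SNAC audio token ID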

# Function to construct audio file from SNAC codes generated by Model
def generate_audio(
    row, model, user=False, temperature=0.4, top_p=0.9, repetition_penalty=1.05
):
    if user:
        prompt = row["eval_text_user"]
    else:
        prompt = row["eval_text_no_user"]
    inputs = tokenizer(prompt, add_special_tokens=False, return_tensors="pt")
    max_tokens = MAX_SEQ_LENGTH - inputs.input_ids.shape[1]
    output = model.generate(
        input_ids=inputs.input_ids.to("cuda"),
        attention_mask=inputs.attention_mask.to("cuda"),
        max_new_tokens=max_tokens,
        do_sample=True,  # temperature and top_p only take effect when sampling
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        eos_token_id=end_of_speech_id,
    )
    # Keep only audio tokens (IDs at or above audio_start_id), made relative to it.
    audio_ids = [
        token.item() - audio_start_id for token in output[0] if token >= audio_start_id
    ]
    # Each SNAC frame is 7 tokens; drop any trailing partial frame.
    num_frames = len(audio_ids) // 7
    clean_audio_ids = audio_ids[: num_frames * 7]
    codes = [[], [], []]
    for i in range(num_frames):
        codes[0].append(clean_audio_ids[7 * i])
        codes[1].append(clean_audio_ids[7 * i + 1] - 4096)
        codes[2].append(clean_audio_ids[7 * i + 2] - (2 * 4096))
        codes[2].append(clean_audio_ids[7 * i + 3] - (3 * 4096))
        codes[1].append(clean_audio_ids[7 * i + 4] - (4 * 4096))
        codes[2].append(clean_audio_ids[7 * i + 5] - (5 * 4096))
        codes[2].append(clean_audio_ids[7 * i + 6] - (6 * 4096))
    codes = [
        torch.tensor(codes[0]).unsqueeze(0),
        torch.tensor(codes[1]).unsqueeze(0),
        torch.tensor(codes[2]).unsqueeze(0),
    ]
    try:
        audio = snac_model.decode(codes)
    except Exception as e:
        logger.error(f"Error decoding audio: {e}")
        return None
    return audio.detach().squeeze().to("cpu").numpy()
# Tamil transcript without a speaker prefix, so use the eval_text_no_user key.
prompt = {
    "eval_text_no_user": "<custom_token_3><|begin_of_text|>நிச்சயமா. ரோம் ல் இரவு நேரம் ரொம்ப அழகா இருக்கு—piazzaகள் சுத்துறதுக்கு நல்ல நேரம்.<|eot_id|><custom_token_4><custom_token_5><custom_token_1>"
}
audio = generate_audio(prompt, model, user=False)
if audio is None:
    logger.error("Failed to generate audio")
else:
    sf.write("output.wav", audio, 24000)  # SNAC decodes at 24 kHz
    logger.success("Generated and saved audio as output.wav")
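The decoding loop in generate_audio undoes the frame interleaving used in training: each frame is 7 tokens (1 coarse, 2 medium and 4 fine codes), and slot k of a frame is offset by k × 4096 above audio_start_id (128266), mirroring flatten_and_get_audio_input_ids in the training code. The pure-Python sketch below shows both directions so the layout can be checked without a GPU or model. interleave and deinterleave are hypothetical helper names, not part of SNAC or any other library.

```python
# Constants matching audio_start_id and the 4096 offsets used in this README.
AUDIO_START_ID = 128266  # tokeniser_length (128256) + 10
CODEBOOK_SIZE = 4096


def interleave(codes):
    """[coarse, medium, fine] codes (lengths n, 2n, 4n) -> flat token IDs."""
    coarse, medium, fine = codes
    ids = []
    for i in range(len(coarse)):
        frame = [
            coarse[i],
            medium[2 * i],
            fine[4 * i],
            fine[4 * i + 1],
            medium[2 * i + 1],
            fine[4 * i + 2],
            fine[4 * i + 3],
        ]
        # Slot k is shifted by k * 4096 so each slot has its own ID range.
        ids += [AUDIO_START_ID + k * CODEBOOK_SIZE + c for k, c in enumerate(frame)]
    return ids


def deinterleave(ids):
    """Flat token IDs -> [coarse, medium, fine]; a trailing partial frame is dropped."""
    num_frames = len(ids) // 7
    coarse, medium, fine = [], [], []
    for i in range(num_frames):
        slot = [ids[7 * i + k] - AUDIO_START_ID - k * CODEBOOK_SIZE for k in range(7)]
        coarse.append(slot[0])
        medium += [slot[1], slot[4]]
        fine += [slot[2], slot[3], slot[5], slot[6]]
    return [coarse, medium, fine]


toy = [[1, 2], [3, 4, 5, 6], [7, 8, 9, 10, 11, 12, 13, 14]]
flat = interleave(toy)
print(len(flat), deinterleave(flat) == toy)
```

Dropping the incomplete last frame matters because the model can stop mid-frame, and SNAC needs the medium and fine codebooks to be exactly 2× and 4× the length of the coarse one.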

Types of Prompts

For best results, generate audio with one of the recommended speaker IDs listed below.

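  • Normal prompt: pass only the transcript, wrapped in the special tokens shown, under the eval_text_no_user key. The Tamil example reads roughly "Sure. Night-time in Rome is really beautiful, a good time to wander around the piazzas."
{
    "eval_text_no_user": "<custom_token_3><|begin_of_text|>நிச்சயமா. ரோம் ல் இரவு நேரம் ரொம்ப அழகா இருக்கு—piazzaகள் சுத்துறதுக்கு நல்ல நேரம்.<|eot_id|><custom_token_4><custom_token_5><custom_token_1>"
}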
  • Speaker-specific prompt: use the eval_text_user key and prefix the transcript with {language}{speakerId}: (for example, hindi159:). Any speaker can be made to speak any of the 9 languages. The Hindi example reads roughly "Keep going on this journey without stopping, because the destinations themselves begin to show the way."
{
    "eval_text_user": "<custom_token_3><|begin_of_text|>hindi159: चलते रहो इस सफर में बिना रुके, क्योंकि मंज़िलें खुद राह दिखाने लगती हैं<|eot_id|><custom_token_4><custom_token_5><custom_token_1>"
}

Recommended Speaker IDs

| Language | Speaker IDs |
|---|---|
| Hindi | 159, 49 |
| Tamil | 188, 128 |
| Bengali | 125 |
| Malayalam | 189, 124 |
| Kannada | 142, 138 |
| Telugu | 69, 133 |
| Punjabi | 191, 67, 201 |
| Gujarati | 62, 190, 187 |
| Marathi | 205, 82 |
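  • Multilingual transcript prompt: use the same format as the speaker-specific prompt, prefixing the transcript with {language}{speakerId}:, where {language} is the speaker's native language. The transcript itself may mix languages and scripts freely. The example below mixes Hindi, Tamil, English and Bengali and reads roughly "I thought he would come, but he showed up and made a full drama, and in the end he is blaming me again."
{
    "eval_text_user": "<custom_token_3><|begin_of_text|>bengali125: मुझे तो लगा वो आएगा, ஆனா அவன் வந்து full drama பண்ணிட்டான், আর শেষে আবার আমাকে দোষ দিচ্ছে<|eot_id|><custom_token_4><custom_token_5><custom_token_1>"
}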
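All three prompt types differ only in the text between <|begin_of_text|> and <|eot_id|>, so they can be assembled by one helper. build_prompt below is a hypothetical name; the token layout is inferred from format_text in the training code, where <custom_token_N> corresponds to token ID 128256 + N (start of human = 3, end of human = 4, start of AI = 5, start of speech = 1).

```python
# Token strings inferred from format_text in the training code;
# <custom_token_N> corresponds to token ID 128256 + N.
START_OF_HUMAN = "<custom_token_3>"
END_OF_HUMAN = "<custom_token_4>"
START_OF_AI = "<custom_token_5>"
START_OF_SPEECH = "<custom_token_1>"
BEGIN_OF_TEXT = "<|begin_of_text|>"
END_OF_TEXT = "<|eot_id|>"


def build_prompt(transcript, language=None, speaker_id=None):
    """Build a normal prompt, or a speaker-specific one when language and speaker_id are set."""
    if (language is None) != (speaker_id is None):
        raise ValueError("pass both language and speaker_id, or neither")
    text = transcript.strip()
    if language is not None:
        # Same "{language}{speakerId}: {transcript}" prefix used during training.
        text = f"{language}{speaker_id}: {text}"
    return (
        f"{START_OF_HUMAN}{BEGIN_OF_TEXT}{text}{END_OF_TEXT}"
        f"{END_OF_HUMAN}{START_OF_AI}{START_OF_SPEECH}"
    )


print(build_prompt("नमस्ते", "hindi", 159))
```

To use it with the inference script, wrap the result in the key generate_audio expects: {"eval_text_user": build_prompt(text, "bengali", 125)} with user=True, or {"eval_text_no_user": build_prompt(text)} with user=False.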

Training Details

  • Dataset: indic-tts-sample-snac-encoded, curated by Snorbyte

    • 135 hours (~68k samples), split into:
      • stage_1: Text reading (47k) + semi-spontaneous formal utterances (16k)
      • stage_2: Colloquial conversational snippets (4.4k)
      • eval: Held-out evaluation samples used during training (200)
    • 9 Indic languages, balanced across high- and low-quality speakers.
  • Hyperparameters:

    • LoRA rank: 192
    • LoRA alpha: 384
    • Per Device Train Batch Size: 8
    • Gradient Accumulation Steps: 4
    • Effective Batch Size: 32 (8 × 4 on 1 GPU)
    • Optimizer: adamw_8bit
    • Learning Rate: 2e-5
    • Scheduler: cosine
    • Warmup Ratio: 0.02
    • Epochs: 2
    • Max Seq Length: 2048
    • SFT Trainer Packing: True
  • Compute

    • GPU: 1 NVIDIA H100 on Vast.ai
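The hyperparameters above imply a few derived quantities. A back-of-the-envelope sketch, assuming a single GPU, the standard (non-rsLoRA) LoRA scaling alpha / r, and that packing fills every sequence up to Max Seq Length; the split sizes are the rounded counts from the dataset list, so the totals are approximate.

```python
# Values copied from the Training Details list.
PER_DEVICE_TRAIN_BATCH_SIZE = 8
GRADIENT_ACCUMULATION_STEPS = 4
NUM_GPUS = 1  # one H100
LORA_RANK = 192
LORA_ALPHA = 384
MAX_SEQ_LENGTH = 2048
TOTAL_HOURS = 135
SPLIT_SAMPLES = {"stage_1": 47_000 + 16_000, "stage_2": 4_400, "eval": 200}

# Sequences contributing to each optimizer step.
effective_batch = PER_DEVICE_TRAIN_BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS * NUM_GPUS
# Standard LoRA multiplies the low-rank update by alpha / r.
lora_scaling = LORA_ALPHA / LORA_RANK
# Upper bound on tokens per step when packing fills each sequence.
tokens_per_step = effective_batch * MAX_SEQ_LENGTH
# The rounded split sizes add up to the "~68k samples" figure.
total_samples = sum(SPLIT_SAMPLES.values())
avg_seconds_per_sample = TOTAL_HOURS * 3600 / total_samples

print(f"effective batch: {effective_batch}")
print(f"LoRA scaling: {lora_scaling}")
print(f"tokens per step (max): {tokens_per_step}")
print(f"samples: {total_samples}, average clip: {avg_seconds_per_sample:.1f} s")
```

This gives an effective batch of 32 sequences, a LoRA scale of 2.0, at most 65,536 tokens per optimizer step, and an average clip length of about 7.2 s.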

Training Code

pip install torch unsloth datasets loguru snac trl soundfile wandb transformers
from unsloth import FastLanguageModel

import os

from datasets import load_dataset
from loguru import logger
from snac import SNAC
from trl import SFTConfig, SFTTrainer
import soundfile as sf
import torch
import wandb
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer
# Set up constants and configurations.
BASE_MODEL = "canopylabs/3b-hi-pretrain-research_release"
STAGE = 1 #1 or 2 based on the dataset you are using
if STAGE == 1:
    TRAIN_CSV_PATH = "" #path to stage_1 csv dataset
else:
    TRAIN_CSV_PATH = "" #path to stage_2 csv dataset 
VALID_CSV_PATH = "" #path to eval csv dataset 
TRAIN_NUM_SAMPLES = None
EVAL_NUM_SAMPLES = None
MAX_SEQ_LENGTH = 2048
PER_DEVICE_TRAIN_BATCH_SIZE = 8
GRADIENT_ACCUMULATION_STEPS = 4
HUGGINGFACE_TOKEN = ""  # pass your Hugging Face access token
MODEL_NAME = "snorTTS-indic"
WANDB_USERNAME = "" #pass your wandb username
WANDB_PROJECT = "snorTTS-indic"
WANDB_LOG_MODEL = "checkpoint"
WANDB_RUN_NAME = "run-0"
WANDB_RUN_ID = None
SEED = 3407

# Set up environment variables for Weights & Biases.
os.environ["WANDB_PROJECT"] = WANDB_PROJECT
os.environ["WANDB_LOG_MODEL"] = WANDB_LOG_MODEL
# Load the model and tokenizer.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=BASE_MODEL,
    load_in_4bit=True,
    max_seq_length=MAX_SEQ_LENGTH,
    token=HUGGINGFACE_TOKEN,
)
logger.success(f"Loaded model: {BASE_MODEL}")

# Get parameter efficient fine-tuning model.
model = FastLanguageModel.get_peft_model(
    model,
    r=192,
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "up_proj",
        "down_proj",
        "gate_proj",
        "lm_head",
        "embed_tokens",
    ],
    lora_alpha=384,
    random_state=SEED,
)

# Load the special tokens for the tokenizer.
tokeniser_length = 128256

start_of_text_id = 128000
end_of_text_id = 128009
start_of_speech_id = tokeniser_length + 1
end_of_speech_id = tokeniser_length + 2
start_of_human_id = tokeniser_length + 3
end_of_human_id = tokeniser_length + 4
start_of_ai_id = tokeniser_length + 5
end_of_ai_id = tokeniser_length + 6
pad_token_id = tokeniser_length + 7
audio_start_id = tokeniser_length + 10

start_of_text_token = tokenizer.decode([start_of_text_id])
end_of_text_token = tokenizer.decode([end_of_text_id])
start_of_speech_token = tokenizer.decode([start_of_speech_id])
end_of_speech_token = tokenizer.decode([end_of_speech_id])
start_of_human_token = tokenizer.decode([start_of_human_id])
end_of_human_token = tokenizer.decode([end_of_human_id])
start_of_ai_token = tokenizer.decode([start_of_ai_id])
end_of_ai_token = tokenizer.decode([end_of_ai_id])
pad_token = tokenizer.decode([pad_token_id])
audio_start_token = tokenizer.decode([audio_start_id])

logger.success("Load special tokens for the tokenizer.")

# Set the padding token and padding side.
tokenizer.pad_token = pad_token
tokenizer.padding_side = "left"
logger.success("Set padding token and padding side for the tokenizer.")
# Load training and validation datasets.
train_dataset = load_dataset("csv", data_files=TRAIN_CSV_PATH)["train"]
eval_dataset = load_dataset("csv", data_files=VALID_CSV_PATH)["train"]

if TRAIN_NUM_SAMPLES:
    train_dataset = train_dataset.shuffle(seed=SEED).select(
        range(min(TRAIN_NUM_SAMPLES, len(train_dataset)))
    )

if EVAL_NUM_SAMPLES:
    eval_dataset = eval_dataset.shuffle(seed=SEED).select(
        range(min(EVAL_NUM_SAMPLES, len(eval_dataset)))
    )

logger.success(
    f"Loaded datasets: {len(train_dataset)} training samples, {len(eval_dataset)} evaluation samples."
)
import ast  # safe parsing of stringified code lists read from the CSV


# Flatten (interleave) and get SNAC token IDs from the audio codes.
# Each frame becomes 7 tokens: 1 coarse, 2 medium and 4 fine codes, with
# slot k shifted by k * 4096 above audio_start_id (128266).
def flatten_and_get_audio_input_ids(row):
    audio_codes = row["snac_codes"]
    if isinstance(audio_codes, str):
        # literal_eval parses the list without executing arbitrary CSV content.
        audio_codes = ast.literal_eval(audio_codes)
    snac_token_ids = []
    for i in range(len(audio_codes[0])):
        snac_token_ids.append(audio_codes[0][i] + audio_start_id)
        snac_token_ids.append(audio_codes[1][2 * i] + audio_start_id + 4096)
        snac_token_ids.append(audio_codes[2][4 * i] + audio_start_id + (2 * 4096))
        snac_token_ids.append(audio_codes[2][(4 * i) + 1] + audio_start_id + (3 * 4096))
        snac_token_ids.append(audio_codes[1][(2 * i) + 1] + audio_start_id + (4 * 4096))
        snac_token_ids.append(audio_codes[2][(4 * i) + 2] + audio_start_id + (5 * 4096))
        snac_token_ids.append(audio_codes[2][(4 * i) + 3] + audio_start_id + (6 * 4096))
    row["snac_token_ids"] = snac_token_ids
    return row


train_dataset = train_dataset.map(flatten_and_get_audio_input_ids)
eval_dataset = eval_dataset.map(flatten_and_get_audio_input_ids)
logger.success("Flattened and extracted SNAC token IDs from audio codes.")
# Filter out rows with empty or None audio codes.
train_dataset = train_dataset.filter(
    lambda x: x["snac_token_ids"] is not None and len(x["snac_token_ids"]) > 0
)
eval_dataset = eval_dataset.filter(
    lambda x: x["snac_token_ids"] is not None and len(x["snac_token_ids"]) > 0
)
logger.success("Filtered datasets to remove rows with empty or None audio codes.")
# Remove duplicate frames from the audio codes.
def remove_duplicate_frames(row):
    vals = row["snac_token_ids"]
    if len(vals) % 7 != 0:
        raise ValueError("Input list length must be divisible by 7")
    result = vals[:7]
    for i in range(7, len(vals), 7):
        current_first = vals[i]
        previous_first = result[-7]
        if current_first != previous_first:
            result.extend(vals[i : i + 7])
    row["snac_token_ids"] = result
    return row


train_dataset = train_dataset.map(remove_duplicate_frames)
eval_dataset = eval_dataset.map(remove_duplicate_frames)
logger.success("Removed duplicate frames from audio codes.")
# Define a function to format the prompt for each row in the dataset.
def format_text(row):
    text = (
        f"{start_of_human_token}{start_of_text_token}{row['language']}{row['user']}: {row['utterance']}{end_of_text_token}"
        f"{end_of_human_token}{start_of_ai_token}{start_of_speech_token}"
        f"{tokenizer.decode(row['snac_token_ids'])}{end_of_speech_token}{end_of_ai_token}"
    )
    eval_text_user = (
        f"{start_of_human_token}{start_of_text_token}{row['language']}{row['user']}: {row['utterance']}{end_of_text_token}"
        f"{end_of_human_token}{start_of_ai_token}{start_of_speech_token}"
    )
    eval_text_no_user = (
        f"{start_of_human_token}{start_of_text_token}{row['utterance']}{end_of_text_token}"
        f"{end_of_human_token}{start_of_ai_token}{start_of_speech_token}"
    )
    row["text"] = text
    row["eval_text_user"] = eval_text_user
    row["eval_text_no_user"] = eval_text_no_user
    return row


train_dataset = train_dataset.map(format_text)
eval_dataset = eval_dataset.map(format_text)
logger.success("Formatted text for training and evaluation datasets.")
# Tokenize the text in the datasets without adding special tokens.
def tokenize_function(example):
    return tokenizer(
        example["text"],
        add_special_tokens=False,
        truncation=True,
        max_length=MAX_SEQ_LENGTH,
    )


train_dataset = train_dataset.map(tokenize_function)
eval_dataset = eval_dataset.map(tokenize_function)
logger.success("Tokenized text in the datasets without adding special tokens.")
# Set training arguments.
training_args = SFTConfig(
    num_train_epochs=2,
    per_device_train_batch_size=PER_DEVICE_TRAIN_BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    optim="adamw_8bit",
    learning_rate=2e-5,
    lr_scheduler_type="cosine",
    warmup_ratio=0.02,
    do_eval=True,
    eval_strategy="steps",
    eval_steps=50,
    logging_strategy="steps",
    logging_steps=1,
    save_strategy="no",
    save_only_model=True,
    # save_steps=250,
    output_dir="outputs",
    report_to="wandb",
    run_name=WANDB_RUN_NAME,
    seed=SEED,
)

# Initialize the SFTTrainer.
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    dataset_text_field="text",
    max_seq_length=MAX_SEQ_LENGTH,
    dataset_num_proc=2,
    packing=True,
    args=training_args,
)

logger.success("Initialized SFTTrainer with the specified configuration.")
# Start the training process.
logger.info("Starting the training process...")

run = wandb.init()

if WANDB_RUN_ID:
    logger.info(f"Resuming from Weights & Biases run ID: {WANDB_RUN_ID}")

    artifact = run.use_artifact(
        f"{WANDB_USERNAME}/{WANDB_PROJECT}/{WANDB_RUN_ID}", type="model"
    )

    artifact_dir = artifact.download()

    trainer.train(resume_from_checkpoint=artifact_dir)
else:
    try:
        logger.info("Attempting to resume training from the last checkpoint...")

        trainer.train(resume_from_checkpoint=True)
    except Exception as err:
        logger.warning(f"Could not resume from a checkpoint ({err}); training from scratch.")
        trainer.train()

# Finish the Weights & Biases run.
wandb.finish()

logger.success("Training completed successfully.")
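The training code reads four CSV columns (language, user, utterance, snac_codes) and its flattening loop assumes the medium and fine codebooks are exactly 2× and 4× as long as the coarse one. A hypothetical pre-flight check, check_row, can catch malformed rows before the .map() calls. The 4096 code range is inferred from the 4096 offsets in the script, and ast.literal_eval is used so that CSV content is parsed, never executed.

```python
import ast

# Columns referenced by format_text and flatten_and_get_audio_input_ids.
REQUIRED_COLUMNS = ("language", "user", "utterance", "snac_codes")


def check_row(row):
    """Return the parsed SNAC codes if the row has the shape the training code expects."""
    missing = [c for c in REQUIRED_COLUMNS if c not in row]
    if missing:
        raise KeyError(f"missing columns: {missing}")
    codes = row["snac_codes"]
    if isinstance(codes, str):
        codes = ast.literal_eval(codes)  # stringified list from the CSV
    if len(codes) != 3:
        raise ValueError("expected 3 codebooks")
    n = len(codes[0])
    if len(codes[1]) != 2 * n or len(codes[2]) != 4 * n:
        raise ValueError("codebook lengths must be n, 2n and 4n")
    if any(not 0 <= c < 4096 for level in codes for c in level):
        raise ValueError("codes must lie in [0, 4096)")
    return codes


row = {"language": "hindi", "user": 159, "utterance": "...", "snac_codes": "[[1], [2, 3], [4, 5, 6, 7]]"}
print(check_row(row))
```

One way to use it is a single pass such as `for row in train_dataset: check_row(row)` before mapping, then dropping or fixing any row that raises.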

Citation

BibTeX:

@misc{indictextaudio2025,
  title={snorTTS-Indic-v0: Multilingual Indic TTS},
  author={snorbyte},
  year={2025},
  howpublished={\url{https://huggingface.co/snorbyte/snorTTS-Indic-v0}},
  note={Apache-2.0}
}