Need Correct SNAC Detokenizing code for vLLM inference.
I'm getting gibberish audio when running inference with vLLM using my current SNAC detokenization code.
Please share your code.
import requests
import torch
import numpy as np
import wave
import os
import json
import logging
from snac import SNAC
from typing import List, Optional
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
class SnorTTSDecoder:
    """
    Complete snorTTS Audio Decoder
    Implements the exact official approach from:
    https://huggingface.co/snorbyte/snorTTS-Indic-v0/blob/main/modal/snorTTS_Indic_v0_server.py
    """

    def __init__(self, api_url: str):
        """
        Initialize the decoder with API endpoint
        Args:
            api_url: The vLLM server endpoint URL
        """
        self.api_url = api_url
        self.headers = {
            "Content-Type": "application/json",
            "Authorization": "Bearer EMPTY",
            "ngrok-skip-browser-warning": "true"
        }
        # Token configuration (from official implementation)
        self.tokenizer_length = 128256
        self.end_of_speech_id = self.tokenizer_length + 2   # 128258
        self.pad_token_id = self.tokenizer_length + 7       # 128263
        self.audio_start_id = self.tokenizer_length + 10    # 128266
        # Load SNAC model
        logger.info("Loading SNAC model...")
        self.snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval()
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.snac_model = self.snac_model.to(self.device)
        logger.info(f"SNAC model loaded on {self.device}")
        self.sample_rate = 24000
    def generate_tokens(self, prompt: str, max_tokens: int = 2048) -> List[int]:
        """
        Generate tokens from the vLLM model
        Args:
            prompt: Input prompt in snorTTS format
            max_tokens: Maximum tokens to generate
        Returns:
            List of generated token IDs
        """
        payload = {
            "model": "snorbyte/snorTTS-Indic-v0",
            "prompt": prompt,
            "max_tokens": max_tokens,
            "temperature": 0.4,
            "top_p": 0.9,
            "repetition_penalty": 1.05,
            "stream": False,
            "add_special_tokens": False,
            "stop_token_ids": [self.end_of_speech_id]
        }
        logger.info(f"Sending request to: {self.api_url}")
        logger.info(f"Prompt: {prompt}")
        try:
            response = requests.post(self.api_url, headers=self.headers, json=payload, timeout=120)
            response.raise_for_status()
            result = response.json()
            generated_text = result['choices'][0]['text']
            logger.info(f"Generated text length: {len(generated_text)} characters")

            # Extract custom tokens
            tokens = []
            text = generated_text
            while '<custom_token_' in text:
                start = text.find('<custom_token_')
                if start == -1:
                    break
                end = text.find('>', start)
                if end == -1:
                    break
                try:
                    token_str = text[start + 14:end]
                    token_id = int(token_str)
                    tokens.append(token_id)
                except ValueError:
                    logger.warning(f"Invalid token format: {text[start:end + 1]}")
                text = text[end + 1:]

            logger.info(f"Extracted {len(tokens)} tokens")
            if tokens:
                logger.info(f"Token range: {min(tokens)} - {max(tokens)}")
            return tokens
        except requests.exceptions.RequestException as e:
            logger.error(f"API request failed: {e}")
            return []
        except Exception as e:
            logger.error(f"Token generation failed: {e}")
            return []
    def decode_audio(self, tokens: List[int]) -> Optional[np.ndarray]:
        """
        Decode tokens to audio using official SNAC approach
        Args:
            tokens: List of token IDs from model
        Returns:
            Audio numpy array or None if failed
        """
        if not tokens:
            logger.error("No tokens provided")
            return None
        logger.info(f"Decoding {len(tokens)} tokens to audio")

        # Step 1: Filter audio tokens (official approach)
        audio_ids = [token for token in tokens if token >= self.audio_start_id]
        logger.info(f"Found {len(audio_ids)} audio tokens (>= {self.audio_start_id})")
        if len(audio_ids) < 7:
            logger.error(f"Need at least 7 audio tokens, got {len(audio_ids)}")
            logger.error("This indicates the model is not generating proper audio tokens")
            return None

        # Step 2: Offset audio tokens (official approach)
        snac_audio_ids = []
        for i in range((len(audio_ids) + 1) // 7):
            for j in range(7):
                if 7 * i + j < len(audio_ids):
                    snac_audio_ids.append(audio_ids[7 * i + j] - self.audio_start_id)
        logger.info(f"After offset subtraction: {len(snac_audio_ids)} SNAC tokens")
        if snac_audio_ids:
            logger.info(f"SNAC token range: {min(snac_audio_ids)} - {max(snac_audio_ids)}")

        # Step 3: Prepare SNAC codes (exact official implementation)
        codes = [[], [], []]
        for i in range((len(snac_audio_ids) + 1) // 7):
            base_idx = 7 * i
            if base_idx + 6 < len(snac_audio_ids):
                # Official SNAC code mapping
                codes[0].append(snac_audio_ids[base_idx])                   # Position 0 -> Codebook 0
                codes[1].append(snac_audio_ids[base_idx + 1] - 4096)        # Position 1 -> Codebook 1
                codes[2].append(snac_audio_ids[base_idx + 2] - (2 * 4096))  # Position 2 -> Codebook 2
                codes[2].append(snac_audio_ids[base_idx + 3] - (3 * 4096))  # Position 3 -> Codebook 2
                codes[1].append(snac_audio_ids[base_idx + 4] - (4 * 4096))  # Position 4 -> Codebook 1
                codes[2].append(snac_audio_ids[base_idx + 5] - (5 * 4096))  # Position 5 -> Codebook 2
                codes[2].append(snac_audio_ids[base_idx + 6] - (6 * 4096))  # Position 6 -> Codebook 2
        if not codes[0]:
            logger.error("No SNAC codes generated")
            return None

        # Step 4: Validate and clamp codes
        negative_codes = 0
        for codebook_idx, codebook in enumerate(codes):
            for i, code in enumerate(codebook):
                if code < 0:
                    negative_codes += 1
                # Clamp to SNAC range [0, 4095]
                codebook[i] = max(0, min(4095, code))
        if negative_codes > 0:
            logger.warning(f"Found {negative_codes} negative codes - clamped to 0")
            logger.warning("Negative codes usually indicate model token generation issues")
        logger.info(f"SNAC codebook shapes: {len(codes[0])}:{len(codes[1])}:{len(codes[2])}")

        # Step 5: Convert to tensors and decode
        try:
            codes_tensors = [
                torch.tensor(codes[0], device=self.device, dtype=torch.int32).unsqueeze(0),
                torch.tensor(codes[1], device=self.device, dtype=torch.int32).unsqueeze(0),
                torch.tensor(codes[2], device=self.device, dtype=torch.int32).unsqueeze(0)
            ]
            with torch.inference_mode():
                audio_tensor = self.snac_model.decode(codes_tensors)
            audio_np = audio_tensor.detach().cpu().numpy().squeeze()

            # Validate audio
            duration = len(audio_np) / self.sample_rate
            rms = np.sqrt(np.mean(audio_np ** 2))
            peak = np.max(np.abs(audio_np)) if len(audio_np) > 0 else 0
            logger.info(f"Audio generated: {duration:.2f}s, RMS={rms:.4f}, Peak={peak:.4f}")
            return audio_np
        except Exception as e:
            logger.error(f"SNAC decoding failed: {e}")
            return None
    def save_audio(self, audio: np.ndarray, filename: str) -> bool:
        """
        Save audio to WAV file
        Args:
            audio: Audio numpy array
            filename: Output filename
        Returns:
            True if successful, False otherwise
        """
        if audio is None or len(audio) == 0:
            logger.error("No audio data to save")
            return False
        try:
            # Create output directory
            os.makedirs(os.path.dirname(filename) if os.path.dirname(filename) else ".", exist_ok=True)

            # Normalize audio to prevent clipping
            if np.max(np.abs(audio)) > 0:
                audio = audio / np.max(np.abs(audio)) * 0.9

            # Convert to 16-bit PCM
            audio_int16 = (audio * 32767).astype(np.int16)

            # Save as WAV
            with wave.open(filename, 'wb') as wav_file:
                wav_file.setnchannels(1)                 # Mono
                wav_file.setsampwidth(2)                 # 16-bit
                wav_file.setframerate(self.sample_rate)  # 24kHz
                wav_file.writeframes(audio_int16.tobytes())
            logger.info(f"Audio saved: {filename}")
            return True
        except Exception as e:
            logger.error(f"Failed to save audio: {e}")
            return False
    def generate_audio(self, language: str, voice_id: int, text: str, output_file: str = "output.wav") -> bool:
        """
        Complete pipeline: text -> tokens -> audio -> file
        Args:
            language: Language name (e.g., "hindi", "tamil")
            voice_id: Voice ID number
            text: Text to synthesize
            output_file: Output audio filename
        Returns:
            True if successful, False otherwise
        """
        logger.info("=== snorTTS Audio Generation ===")
        logger.info(f"Language: {language}")
        logger.info(f"Voice ID: {voice_id}")
        logger.info(f"Text: {text}")
        logger.info(f"Output: {output_file}")

        # Step 1: Create prompt (official format)
        prompt = f"<custom_token_3><|begin_of_text|>{language}{voice_id}: {text}<|eot_id|><custom_token_4><custom_token_5><custom_token_1>"
        logger.info(f"Prompt: {prompt}")

        # Step 2: Generate tokens
        tokens = self.generate_tokens(prompt)
        if not tokens:
            logger.error("❌ Token generation failed")
            return False

        # Step 3: Decode to audio
        audio = self.decode_audio(tokens)
        if audio is None:
            logger.error("❌ Audio decoding failed")
            return False

        # Step 4: Save audio
        success = self.save_audio(audio, output_file)
        if success:
            logger.info(f"✅ Audio generation completed: {output_file}")
        else:
            logger.error("❌ Audio saving failed")
        return success
def main():
    """
    Demo script showing complete snorTTS usage
    """
    # Configuration
    API_URL = "https://5729a4e0fea8.ngrok-free.app/v1/completions"

    # Test cases
    test_cases = [
        {
            "language": "hindi",
            "voice_id": 159,
            "text": "नमस्ते, आज का दिन कैसा है?",
            "output": "output/hindi_test.wav"
        },
        {
            "language": "tamil",
            "voice_id": 188,
            "text": "வணக்கம், நீங்கள் எப்படி இருக்கிறீர்கள்?",
            "output": "output/tamil_test.wav"
        },
        {
            "language": "hindi",
            "voice_id": 159,
            "text": "चलते रहो इस सफर में बिना रुके, क्योंकि मंज़िलें खुद राह दिखाने लगती हैं",
            "output": "output/hindi_user_input.wav"
        }
    ]

    print("=" * 60)
    print("snorTTS Complete Audio Decoder")
    print("Official Implementation for Author Review")
    print("=" * 60)

    # Initialize decoder
    decoder = SnorTTSDecoder(API_URL)

    # Process each test case
    results = []
    for i, case in enumerate(test_cases, 1):
        print(f"\n--- Test Case {i}/{len(test_cases)} ---")
        success = decoder.generate_audio(
            language=case["language"],
            voice_id=case["voice_id"],
            text=case["text"],
            output_file=case["output"]
        )
        results.append({
            "case": i,
            "language": case["language"],
            "success": success,
            "output": case["output"] if success else None
        })

    # Summary
    print("\n" + "=" * 60)
    print("RESULTS SUMMARY")
    print("=" * 60)
    successful = sum(1 for r in results if r["success"])
    total = len(results)
    print(f"Success Rate: {successful}/{total}")
    for result in results:
        status = "✅ SUCCESS" if result["success"] else "❌ FAILED"
        print(f"  Test {result['case']} ({result['language']}): {status}")
        if result["output"]:
            print(f"    Output: {result['output']}")

    if successful == 0:
        print("\n🔍 DIAGNOSTIC INFORMATION:")
        print("  - All tests failed")
        print("  - This likely indicates the model endpoint is not generating audio tokens")
        print("  - Expected: tokens >= 128266 (audio tokens)")
        print("  - Check model deployment and configuration")
    elif successful < total:
        print("  - Some tests failed")
        print("  - Check logs above for specific failure reasons")
    else:
        print("  - snorTTS implementation working correctly")
        print("  - Audio files generated in output/ directory")
    print("=" * 60)


if __name__ == "__main__":
    main()
I feel there is a trivial bug somewhere in the detokenization.
If you want the server classes, you can check these out:
vLLM: https://github.com/SaudxInu/indic-tts/blob/main/scripts/modal/snorTTS_Indic_v0_vllm.py
Generate: https://github.com/SaudxInu/indic-tts/blob/main/scripts/modal/snorTTS_Indic_v0_server.py
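One thing I have not ruled out yet (this is just a guess on my side, not something I have confirmed against the official server code): whether the N in the <custom_token_N> strings returned by the /v1/completions endpoint is the absolute token ID (>= 128266, which is what my audio_start_id filter assumes) or the tokenizer's custom-token index. If it is the index, the offset subtraction above would scramble the SNAC codes and could explain the gibberish. A minimal sanity check I plan to run, assuming the tokenizer for snorbyte/snorTTS-Indic-v0 loads locally:

```python
# Hypothetical sanity check (not from the official repo): inspect how the tokenizer
# maps "<custom_token_N>" strings to token IDs, to confirm which offset should be
# subtracted before building the SNAC codebooks.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("snorbyte/snorTTS-Indic-v0")

# What IDs do a few custom token strings map to?
for n in (1, 10, 11, 4106):
    token_str = f"<custom_token_{n}>"
    print(token_str, "->", tok.convert_tokens_to_ids(token_str))

# And what string does the assumed first audio-token ID (128266) map back to?
print(128266, "->", tok.convert_ids_to_tokens(128266))
```

If the printed IDs are already in the 128k range, the >= 128266 filter in decode_audio should be fine; if they come out as small indices, the extraction step needs a different offset before the per-position 4096 subtraction.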