Need Correct SNAC Detokenizing code for vLLM inference.

#6
by Hariprasath28 - opened

I'm getting gibberish audio when running inference through vLLM with the current SNAC detokenizing code.

Snorbyte org

Please share your code.

import requests
import torch
import numpy as np
import wave
import os
import json
import logging
from snac import SNAC
from typing import List, Optional

# Configure logging

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Module-level logger named after this module (``__name__``, not a bare ``name``).
logger = logging.getLogger(__name__)

class SnorTTSDecoder:
    """
    Complete snorTTS Audio Decoder

    Sends a prompt to a vLLM completions endpoint, converts the
    ``<custom_token_N>`` markers in the returned text back into vocabulary
    token IDs, regroups the audio tokens into three SNAC codebooks, decodes
    them to a waveform and writes a WAV file.

    Implements the approach from:
    https://huggingface.co/snorbyte/snorTTS-Indic-v0/blob/main/modal/snorTTS_Indic_v0_server.py
    """

    # Textual marker that precedes the number of each custom token in the
    # completion text.
    _CUSTOM_TOKEN_PREFIX = "<custom_token_"
    # Valid SNAC codes are 0..4095; slot j of a frame is offset by j * 4096.
    _CODEBOOK_SIZE = 4096
    # Codebook index fed by each of the 7 slots of one frame.
    _FRAME_LAYOUT = (0, 1, 2, 2, 1, 2, 2)

    def __init__(self, api_url: str):
        """
        Initialize the decoder with API endpoint

        Loads the SNAC 24 kHz model and moves it to CUDA when available,
        otherwise to the CPU.

        Args:
            api_url: The vLLM server endpoint URL
        """
        self.api_url = api_url
        self.headers = {
            "Content-Type": "application/json",
            "Authorization": "Bearer EMPTY",
            "ngrok-skip-browser-warning": "true"
        }

        # Token configuration (from official implementation)
        self.tokenizer_length = 128256
        self.end_of_speech_id = self.tokenizer_length + 2    # 128258
        self.pad_token_id = self.tokenizer_length + 7        # 128263
        self.audio_start_id = self.tokenizer_length + 10     # 128266

        # Load SNAC model
        logger.info("Loading SNAC model...")
        self.snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval()
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.snac_model = self.snac_model.to(self.device)
        logger.info(f"SNAC model loaded on {self.device}")

        self.sample_rate = 24000

    def _text_to_token_ids(self, text: str) -> List[int]:
        """
        Convert ``<custom_token_N>`` markers in ``text`` to vocabulary IDs.

        The number inside a marker is an offset from ``tokenizer_length``,
        not a vocabulary ID: the prompt's ``<custom_token_1>`` corresponds to
        ID ``tokenizer_length + 1``. The offset is therefore added back here
        so the result can be compared against ``audio_start_id``.

        Args:
            text: Completion text returned by the server

        Returns:
            Token IDs in order of appearance; malformed markers are skipped
            with a warning
        """
        prefix = self._CUSTOM_TOKEN_PREFIX
        token_ids = []
        start = text.find(prefix)
        while start != -1:
            end = text.find('>', start)
            if end == -1:
                # Truncated marker at the end of the text.
                break
            try:
                number = int(text[start + len(prefix):end])
                token_ids.append(self.tokenizer_length + number)
            except ValueError:
                logger.warning(f"Invalid token format: {text[start:end+1]}")
            # Continue scanning after this marker without copying the string.
            start = text.find(prefix, end + 1)
        return token_ids

    @classmethod
    def _split_into_codebooks(cls, snac_ids: List[int]):
        """
        Regroup flat 7-slot frames into the three SNAC codebook sequences.

        Slot ``j`` of each frame carries a code offset by ``j * 4096``; the
        offset is removed and the code appended to the codebook given by
        ``_FRAME_LAYOUT``. Codes outside ``[0, 4095]`` are clamped into range.
        A trailing incomplete frame is ignored.

        Args:
            snac_ids: Audio token IDs with ``audio_start_id`` already subtracted

        Returns:
            ``(codes, out_of_range)`` where ``codes`` is a list of three lists
            and ``out_of_range`` counts the codes that had to be clamped
        """
        size = cls._CODEBOOK_SIZE
        frame_len = len(cls._FRAME_LAYOUT)
        codes = [[], [], []]
        out_of_range = 0
        for base in range(0, len(snac_ids) - frame_len + 1, frame_len):
            for slot, codebook in enumerate(cls._FRAME_LAYOUT):
                code = snac_ids[base + slot] - slot * size
                if not 0 <= code < size:
                    out_of_range += 1
                    code = max(0, min(size - 1, code))
                codes[codebook].append(code)
        return codes, out_of_range

    def generate_tokens(self, prompt: str, max_tokens: int = 2048) -> List[int]:
        """
        Generate tokens from the vLLM model

        Args:
            prompt: Input prompt in snorTTS format
            max_tokens: Maximum tokens to generate

        Returns:
            List of generated token IDs (vocabulary IDs, i.e. custom token
            number + ``tokenizer_length``); empty list on any failure
        """
        payload = {
            "model": "snorbyte/snorTTS-Indic-v0",
            "prompt": prompt,
            "max_tokens": max_tokens,
            "temperature": 0.4,
            "top_p": 0.9,
            "repetition_penalty": 1.05,
            "stream": False,
            "add_special_tokens": False,
            "stop_token_ids": [self.end_of_speech_id]
        }

        logger.info(f"Sending request to: {self.api_url}")
        logger.info(f"Prompt: {prompt}")

        try:
            response = requests.post(self.api_url, headers=self.headers, json=payload, timeout=120)
            response.raise_for_status()

            result = response.json()
            generated_text = result['choices'][0]['text']

            logger.info(f"Generated text length: {len(generated_text)} characters")

            # Extract custom tokens and map them back to vocabulary IDs
            tokens = self._text_to_token_ids(generated_text)

            logger.info(f"Extracted {len(tokens)} tokens")
            if tokens:
                logger.info(f"Token range: {min(tokens)} - {max(tokens)}")

            return tokens

        except requests.exceptions.RequestException as e:
            logger.error(f"API request failed: {e}")
            return []
        except Exception as e:
            # Best-effort boundary: an unexpected response shape must not
            # crash the caller, which treats [] as failure.
            logger.error(f"Token generation failed: {e}")
            return []

    def decode_audio(self, tokens: List[int]) -> Optional[np.ndarray]:
        """
        Decode tokens to audio using official SNAC approach

        Args:
            tokens: List of vocabulary token IDs from model

        Returns:
            Audio numpy array or None if failed
        """
        if not tokens:
            logger.error("No tokens provided")
            return None

        logger.info(f"Decoding {len(tokens)} tokens to audio")

        # Step 1: Keep only audio tokens and remove the vocabulary offset
        snac_audio_ids = [
            token - self.audio_start_id
            for token in tokens
            if token >= self.audio_start_id
        ]
        logger.info(f"Found {len(snac_audio_ids)} audio tokens (>= {self.audio_start_id})")

        frame_len = len(self._FRAME_LAYOUT)
        if len(snac_audio_ids) < frame_len:
            logger.error(f"Need at least 7 audio tokens, got {len(snac_audio_ids)}")
            logger.error("This indicates the model is not generating proper audio tokens")
            return None

        logger.info(f"SNAC token range: {min(snac_audio_ids)} - {max(snac_audio_ids)}")

        # Step 2: Regroup whole frames into the three codebooks
        codes, out_of_range = self._split_into_codebooks(snac_audio_ids)

        if not codes[0]:
            logger.error("No SNAC codes generated")
            return None

        if out_of_range > 0:
            logger.warning(f"Found {out_of_range} out-of-range codes - clamped to [0, 4095]")
            logger.warning("Out-of-range codes usually indicate model token generation issues")

        logger.info(f"SNAC codebook shapes: {len(codes[0])}:{len(codes[1])}:{len(codes[2])}")

        # Step 3: Convert to tensors (with a leading batch dimension) and decode
        try:
            codes_tensors = [
                torch.tensor(codebook, device=self.device, dtype=torch.int32).unsqueeze(0)
                for codebook in codes
            ]

            with torch.inference_mode():
                audio_tensor = self.snac_model.decode(codes_tensors)
                audio_np = audio_tensor.detach().cpu().numpy().squeeze()

                # Report basic statistics of the decoded waveform
                duration = len(audio_np) / self.sample_rate
                rms = np.sqrt(np.mean(audio_np ** 2))
                peak = np.max(np.abs(audio_np)) if len(audio_np) > 0 else 0

                logger.info(f"Audio generated: {duration:.2f}s, RMS={rms:.4f}, Peak={peak:.4f}")

                return audio_np

        except Exception as e:
            logger.error(f"SNAC decoding failed: {e}")
            return None

    def save_audio(self, audio: np.ndarray, filename: str) -> bool:
        """
        Save audio to WAV file

        The signal is peak-normalized to 0.9 and written as mono 16-bit PCM
        at ``self.sample_rate``.

        Args:
            audio: Audio numpy array
            filename: Output filename

        Returns:
            True if successful, False otherwise
        """
        if audio is None or len(audio) == 0:
            logger.error("No audio data to save")
            return False

        try:
            # Create output directory ("." when filename has no directory part)
            os.makedirs(os.path.dirname(filename) or ".", exist_ok=True)

            # Normalize audio to prevent clipping; skip all-zero signals to
            # avoid dividing by zero.
            peak = np.max(np.abs(audio))
            if peak > 0:
                audio = audio / peak * 0.9

            # Convert to 16-bit PCM
            audio_int16 = (audio * 32767).astype(np.int16)

            # Save as WAV
            with wave.open(filename, 'wb') as wav_file:
                wav_file.setnchannels(1)          # Mono
                wav_file.setsampwidth(2)          # 16-bit
                wav_file.setframerate(self.sample_rate)  # 24kHz
                wav_file.writeframes(audio_int16.tobytes())

            logger.info(f"Audio saved: {filename}")
            return True

        except Exception as e:
            logger.error(f"Failed to save audio: {e}")
            return False

    def generate_audio(self, language: str, voice_id: int, text: str, output_file: str = "output.wav") -> bool:
        """
        Complete pipeline: text -> tokens -> audio -> file

        Args:
            language: Language name (e.g., "hindi", "tamil")
            voice_id: Voice ID number
            text: Text to synthesize
            output_file: Output audio filename

        Returns:
            True if successful, False otherwise
        """
        logger.info("=== snorTTS Audio Generation ===")
        logger.info(f"Language: {language}")
        logger.info(f"Voice ID: {voice_id}")
        logger.info(f"Text: {text}")
        logger.info(f"Output: {output_file}")

        # Step 1: Create prompt (official format)
        prompt = f"<custom_token_3><|begin_of_text|>{language}{voice_id}: {text}<|eot_id|><custom_token_4><custom_token_5><custom_token_1>"
        logger.info(f"Prompt: {prompt}")

        # Step 2: Generate tokens
        tokens = self.generate_tokens(prompt)
        if not tokens:
            logger.error("❌ Token generation failed")
            return False

        # Step 3: Decode to audio
        audio = self.decode_audio(tokens)
        if audio is None:
            logger.error("❌ Audio decoding failed")
            return False

        # Step 4: Save audio
        success = self.save_audio(audio, output_file)
        if success:
            logger.info(f"✅ Audio generation completed: {output_file}")
        else:
            logger.error("❌ Audio saving failed")

        return success

def main(api_url: str = "https://5729a4e0fea8.ngrok-free.app/v1/completions") -> None:
    """
    Demo script showing complete snorTTS usage

    Runs a fixed set of test sentences through ``SnorTTSDecoder`` and prints
    a per-case and overall summary.

    Args:
        api_url: vLLM completions endpoint to call; defaults to the
            endpoint that was hard-coded in the original script
    """
    # Test cases
    test_cases = [
        {
            "language": "hindi",
            "voice_id": 159,
            "text": "नमस्ते, आज का दिन कैसा है?",
            "output": "output/hindi_test.wav"
        },
        {
            "language": "tamil",
            "voice_id": 188,
            "text": "வணக்கம், நீங்கள் எப்படி இருக்கிறீர்கள்?",
            "output": "output/tamil_test.wav"
        },
        {
            "language": "hindi",
            "voice_id": 159,
            "text": "चलते रहो इस सफर में बिना रुके, क्योंकि मंज़िलें खुद राह दिखाने लगती हैं",
            "output": "output/hindi_user_input.wav"
        }
    ]

    print("=" * 60)
    print("snorTTS Complete Audio Decoder")
    print("Official Implementation for Author Review")
    print("=" * 60)

    # Initialize decoder
    decoder = SnorTTSDecoder(api_url)

    # Process each test case
    results = []
    for i, case in enumerate(test_cases, 1):
        print(f"\n--- Test Case {i}/{len(test_cases)} ---")

        success = decoder.generate_audio(
            language=case["language"],
            voice_id=case["voice_id"],
            text=case["text"],
            output_file=case["output"]
        )

        results.append({
            "case": i,
            "language": case["language"],
            "success": success,
            # Only report an output path for cases that actually produced a file.
            "output": case["output"] if success else None
        })

    # Summary
    print("\n" + "=" * 60)
    print("RESULTS SUMMARY")
    print("=" * 60)

    successful = sum(1 for r in results if r["success"])
    total = len(results)

    print(f"Success Rate: {successful}/{total}")

    for result in results:
        status = "✅ SUCCESS" if result["success"] else "❌ FAILED"
        print(f"  Test {result['case']} ({result['language']}): {status}")
        if result["output"]:
            print(f"    Output: {result['output']}")

    if successful == 0:
        print("\n🔍 DIAGNOSTIC INFORMATION:")
        print("  - All tests failed")
        print("  - This likely indicates the model endpoint is not generating audio tokens")
        print("  - Expected: tokens >= 128266 (audio tokens)")
        print("  - Check model deployment and configuration")
    elif successful < total:
        print("  - Some tests failed")
        print("  - Check logs above for specific failure reasons")
    else:
        print("  - snorTTS implementation working correctly")
        print("  - Audio files generated in output/ directory")

    print("=" * 60)

# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()

Snorbyte org

Sign up or log in to comment