# Created by Fabio Sarracino
import logging
import re
from typing import Optional

import numpy as np
import torch
from .base_vibevoice import BaseVibeVoiceNode
# Setup logging
logger = logging.getLogger("VibeVoice")
class VibeVoiceMultipleSpeakersNode(BaseVibeVoiceNode):
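    """ComfyUI node that generates multi-speaker conversations (up to 4 distinct voices)
    with Microsoft VibeVoice.

    Expected input text format (converted internally to the 0-based "Speaker N-1:"
    labels the VibeVoice processor expects):

        [1]: Hello, this is the first speaker.
        [2]: Hi there, I'm the second speaker.
    """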
def __init__(self):
super().__init__()
# Register this instance for memory management
try:
from .free_memory_node import VibeVoiceFreeMemoryNode
VibeVoiceFreeMemoryNode.register_multi_speaker(self)
        except Exception:
            # The memory-management node is optional; continue without registering
            pass
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"text": ("STRING", {
"multiline": True,
"default": "[1]: Hello, this is the first speaker.\n[2]: Hi there, I'm the second speaker.\n[1]: Nice to meet you!\n[2]: Nice to meet you too!",
"tooltip": "Text with speaker labels. Use '[N]:' format where N is 1-4. Gets disabled when connected to another node.",
"forceInput": False,
"dynamicPrompts": True
}),
"model": (["VibeVoice-1.5B", "VibeVoice-Large", "VibeVoice-Large-Quant-4Bit","VibeVoice-Large-Q8"], {
"default": "VibeVoice-Large-Q8", # Large recommended for multi-speaker
"tooltip": "Model to use. Large is recommended for multi-speaker generation, Quant-4Bit uses less VRAM (CUDA only)"
}),
"attention_type": (["auto", "eager", "sdpa", "flash_attention_2", "sage"], {
"default": "auto",
"tooltip": "Attention implementation. Auto selects the best available, eager is standard, sdpa is optimized PyTorch, flash_attention_2 requires compatible GPU, sage uses quantized attention for speedup (CUDA only)"
}),
"free_memory_after_generate": ("BOOLEAN", {"default": True, "tooltip": "Free model from memory after generation to save VRAM/RAM. Disable to keep model loaded for faster subsequent generations"}),
"diffusion_steps": ("INT", {"default": 20, "min": 5, "max": 100, "step": 1, "tooltip": "Number of denoising steps. More steps = better quality but slower. Default: 20"}),
"seed": ("INT", {"default": 42, "min": 0, "max": 2**32-1, "tooltip": "Random seed for generation. Default 42 is used in official examples"}),
"cfg_scale": ("FLOAT", {"default": 1.3, "min": 0.5, "max": 3.5, "step": 0.05, "tooltip": "Classifier-free guidance scale (official default: 1.3)"}),
"use_sampling": ("BOOLEAN", {"default": False, "tooltip": "Enable sampling mode. When False (default), uses deterministic generation like official examples"}),
},
"optional": {
"speaker1_voice": ("AUDIO", {"tooltip": "Optional: Voice sample for Speaker 1. If not provided, synthetic voice will be used."}),
"speaker2_voice": ("AUDIO", {"tooltip": "Optional: Voice sample for Speaker 2. If not provided, synthetic voice will be used."}),
"speaker3_voice": ("AUDIO", {"tooltip": "Optional: Voice sample for Speaker 3. If not provided, synthetic voice will be used."}),
"speaker4_voice": ("AUDIO", {"tooltip": "Optional: Voice sample for Speaker 4. If not provided, synthetic voice will be used."}),
"temperature": ("FLOAT", {"default": 0.95, "min": 0.1, "max": 2.0, "step": 0.05, "tooltip": "Only used when sampling is enabled"}),
"top_p": ("FLOAT", {"default": 0.95, "min": 0.1, "max": 1.0, "step": 0.05, "tooltip": "Only used when sampling is enabled"}),
}
}
RETURN_TYPES = ("AUDIO",)
RETURN_NAMES = ("audio",)
FUNCTION = "generate_speech"
CATEGORY = "VibeVoiceWrapper"
DESCRIPTION = "Generate multi-speaker conversations with up to 4 distinct voices using Microsoft VibeVoice"
def _prepare_voice_sample(self, voice_audio, speaker_idx: int) -> Optional[np.ndarray]:
"""Prepare a single voice sample from input audio"""
return self._prepare_audio_from_comfyui(voice_audio)
def generate_speech(self, text: str = "", model: str = "VibeVoice-7B-Preview",
attention_type: str = "auto", free_memory_after_generate: bool = True,
diffusion_steps: int = 20, seed: int = 42, cfg_scale: float = 1.3,
use_sampling: bool = False, speaker1_voice=None, speaker2_voice=None,
speaker3_voice=None, speaker4_voice=None,
temperature: float = 0.95, top_p: float = 0.95):
"""Generate multi-speaker speech from text using VibeVoice"""
try:
# Check text input
if not text or not text.strip():
raise Exception("No text provided. Please enter text with speaker labels (e.g., '[1]: Hello' or '[2]: Hi')")
# First detect how many speakers are in the text
bracket_pattern = r'\[(\d+)\]\s*:'
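            # Matches labels like "[1]:" or "[2] :" and captures the speaker number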
speakers_numbers = sorted(list(set([int(m) for m in re.findall(bracket_pattern, text)])))
# Limit to 1-4 speakers
if not speakers_numbers:
num_speakers = 1 # Default to 1 if no speaker format found
else:
num_speakers = min(max(speakers_numbers), 4) # Max speaker number, capped at 4
if max(speakers_numbers) > 4:
print(f"[VibeVoice] Warning: Found {max(speakers_numbers)} speakers, limiting to 4")
# Direct conversion from [N]: to Speaker (N-1): for VibeVoice processor
# This avoids multiple conversion steps
converted_text = text
            # Reuse the [N]: speaker detection from above
            speakers_in_text = list(speakers_numbers)
if not speakers_in_text:
# No [N]: format found, try Speaker N: format
speaker_pattern = r'Speaker\s+(\d+)\s*:'
speakers_in_text = sorted(list(set([int(m) for m in re.findall(speaker_pattern, text)])))
                if speakers_in_text:
                    # Text already in "Speaker N:" format; convert to 0-based in a single
                    # pass so an already-converted label cannot be matched a second time
                    converted_text = re.sub(
                        speaker_pattern,
                        lambda m: f'Speaker {int(m.group(1)) - 1}:',
                        converted_text
                    )
else:
# No speaker format found
speakers_in_text = [1]
# Parse pause keywords even for single speaker
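                    # Segments come back as ('text', str) or ('pause', duration_ms) tuples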
pause_segments = self._parse_pause_keywords(text)
# Store speaker segments for pause processing
speaker_segments_with_pauses = []
segments = []
for seg_type, seg_content in pause_segments:
if seg_type == 'pause':
speaker_segments_with_pauses.append(('pause', seg_content, None))
else:
# Clean up newlines
text_clean = seg_content.replace('\n', ' ').replace('\r', ' ')
text_clean = ' '.join(text_clean.split())
if text_clean:
speaker_segments_with_pauses.append(('text', text_clean, 1))
segments.append(f"Speaker 0: {text_clean}")
# Join all segments for fallback
converted_text = '\n'.join(segments) if segments else f"Speaker 0: {text}"
else:
# Convert [N]: directly to Speaker (N-1): and handle multi-line text
# Split text to preserve speaker segments while cleaning up newlines within each segment
segments = []
# Find all speaker markers with their positions
speaker_matches = list(re.finditer(f'\\[({"|".join(map(str, speakers_in_text))})\\]\\s*:', converted_text))
# Store speaker segments for pause processing
speaker_segments_with_pauses = []
for i, match in enumerate(speaker_matches):
speaker_num = int(match.group(1))
start = match.end()
# Find where this speaker's text ends (at next speaker or end of text)
if i + 1 < len(speaker_matches):
end = speaker_matches[i + 1].start()
else:
end = len(converted_text)
# Extract the speaker's text (keep pause keywords for now)
speaker_text = converted_text[start:end].strip()
# Parse pause keywords within this speaker's text
pause_segments = self._parse_pause_keywords(speaker_text)
# Process each segment (text or pause) for this speaker
for seg_type, seg_content in pause_segments:
if seg_type == 'pause':
# Add pause segment
speaker_segments_with_pauses.append(('pause', seg_content, None))
else:
# Clean up the text segment
text_clean = seg_content.replace('\n', ' ').replace('\r', ' ')
text_clean = ' '.join(text_clean.split())
if text_clean: # Only add non-empty text
# Add text segment with speaker info
speaker_segments_with_pauses.append(('text', text_clean, speaker_num))
# Also build the traditional segments for fallback
segments.append(f'Speaker {speaker_num - 1}: {text_clean}')
# Join all segments with newlines (required for multi-speaker format) - for fallback
converted_text = '\n'.join(segments) if segments else ""
# Build speaker names list - these are just for logging, not used by processor
# The processor uses the speaker labels in the text itself
speakers = [f"Speaker {i}" for i in range(len(speakers_in_text))]
# Get model mapping and load model with attention type
model_mapping = self._get_model_mapping()
model_path = model_mapping.get(model, model)
self.load_model(model, model_path, attention_type)
voice_inputs = [speaker1_voice, speaker2_voice, speaker3_voice, speaker4_voice]
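            # Index i of this list holds the optional AUDIO input for speaker i+1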
# Prepare voice samples in order of appearance
voice_samples = []
for i, speaker_num in enumerate(speakers_in_text):
idx = speaker_num - 1 # Convert to 0-based for voice array
# Try to use provided voice sample
if idx < len(voice_inputs) and voice_inputs[idx] is not None:
voice_sample = self._prepare_voice_sample(voice_inputs[idx], idx)
if voice_sample is None:
# Use the actual speaker index for consistent synthetic voice
voice_sample = self._create_synthetic_voice_sample(idx)
else:
# Use the actual speaker index for consistent synthetic voice
voice_sample = self._create_synthetic_voice_sample(idx)
voice_samples.append(voice_sample)
# Ensure voice_samples count matches detected speakers
if len(voice_samples) != len(speakers_in_text):
logger.error(f"Mismatch: {len(speakers_in_text)} speakers but {len(voice_samples)} voice samples!")
raise Exception(f"Voice sample count mismatch: expected {len(speakers_in_text)}, got {len(voice_samples)}")
# Check if we have pause segments to process
if 'speaker_segments_with_pauses' in locals() and speaker_segments_with_pauses:
# Process segments with pauses
all_audio_segments = []
sample_rate = 24000 # VibeVoice uses 24kHz
# Group consecutive text segments from same speaker for efficiency
grouped_segments = []
current_group = []
current_speaker = None
for seg_type, seg_content, speaker_num in speaker_segments_with_pauses:
if seg_type == 'pause':
# Save current group if any
if current_group:
grouped_segments.append(('text_group', current_group, current_speaker))
current_group = []
current_speaker = None
# Add pause
grouped_segments.append(('pause', seg_content, None))
else:
# Text segment
if speaker_num == current_speaker:
# Same speaker, add to current group
current_group.append(seg_content)
else:
# Different speaker, save current group and start new one
if current_group:
grouped_segments.append(('text_group', current_group, current_speaker))
current_group = [seg_content]
current_speaker = speaker_num
# Save last group if any
if current_group:
grouped_segments.append(('text_group', current_group, current_speaker))
# Process grouped segments
for seg_type, seg_content, speaker_num in grouped_segments:
if seg_type == 'pause':
# Generate silence
duration_ms = seg_content
logger.info(f"Adding {duration_ms}ms pause")
silence_audio = self._generate_silence(duration_ms, sample_rate)
all_audio_segments.append(silence_audio)
else:
# Process text group for a speaker
combined_text = ' '.join(seg_content)
formatted_text = f"Speaker {speaker_num - 1}: {combined_text}"
# Get voice sample for this speaker
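                        # speakers_in_text is sorted, so .index() maps the 1-based speaker
                        # number to its position in the voice_samples list built above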
speaker_idx = speakers_in_text.index(speaker_num)
speaker_voice_samples = [voice_samples[speaker_idx]]
logger.info(f"Generating audio for Speaker {speaker_num}: {len(combined_text.split())} words")
# Generate audio for this speaker's text
segment_audio = self._generate_with_vibevoice(
formatted_text, speaker_voice_samples, cfg_scale, seed,
diffusion_steps, use_sampling, temperature, top_p
)
all_audio_segments.append(segment_audio)
# Concatenate all audio segments
if all_audio_segments:
logger.info(f"Concatenating {len(all_audio_segments)} audio segments (including pauses)...")
# Extract waveforms
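                    # Each segment is a ComfyUI AUDIO dict: {"waveform": tensor, "sample_rate": int}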
waveforms = []
for audio_segment in all_audio_segments:
if isinstance(audio_segment, dict) and "waveform" in audio_segment:
waveforms.append(audio_segment["waveform"])
if waveforms:
# Filter out None values if any
valid_waveforms = [w for w in waveforms if w is not None]
if valid_waveforms:
# Concatenate along time dimension
combined_waveform = torch.cat(valid_waveforms, dim=-1)
audio_dict = {
"waveform": combined_waveform,
"sample_rate": sample_rate
}
logger.info(f"Successfully generated multi-speaker audio with pauses")
else:
raise Exception("No valid audio waveforms generated")
else:
raise Exception("Failed to extract waveforms from audio segments")
else:
raise Exception("No audio segments generated")
else:
# Fallback to original method without pause support
logger.info("Processing without pause support (no pause keywords found)")
audio_dict = self._generate_with_vibevoice(
converted_text, voice_samples, cfg_scale, seed, diffusion_steps,
use_sampling, temperature, top_p
)
# Free memory if requested
if free_memory_after_generate:
self.free_memory()
return (audio_dict,)
except Exception as e:
# Check if this is an interruption by the user
import comfy.model_management as mm
if isinstance(e, mm.InterruptProcessingException):
# User interrupted - just log it and re-raise to stop the workflow
logger.info("Generation interrupted by user")
raise # Propagate the interruption to stop the workflow
else:
# Real error - show it
logger.error(f"Multi-speaker speech generation failed: {str(e)}")
raise Exception(f"Error generating multi-speaker speech: {str(e)}")
@classmethod
def IS_CHANGED(cls, text="", model="VibeVoice-7B-Preview",
speaker1_voice=None, speaker2_voice=None,
speaker3_voice=None, speaker4_voice=None, **kwargs):
"""Cache key for ComfyUI"""
voices_hash = hash(str([speaker1_voice, speaker2_voice, speaker3_voice, speaker4_voice]))
return f"{hash(text)}_{model}_{voices_hash}_{kwargs.get('cfg_scale', 1.3)}_{kwargs.get('seed', 0)}" |