# Created by Fabio Sarracino

import logging
import re
import torch
import numpy as np
from typing import Optional

from .base_vibevoice import BaseVibeVoiceNode

# Setup logging
logger = logging.getLogger("VibeVoice")

class VibeVoiceMultipleSpeakersNode(BaseVibeVoiceNode):
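    """ComfyUI node that generates multi-speaker conversations (up to 4
    distinct voices) with Microsoft VibeVoice. Each speaker can optionally
    be conditioned on a reference voice sample."""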
    def __init__(self):
        super().__init__()
        # Register this instance for memory management
        try:
            from .free_memory_node import VibeVoiceFreeMemoryNode
            VibeVoiceFreeMemoryNode.register_multi_speaker(self)
        except Exception:
            # Memory-management registration is optional; continue without it
            pass
    
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "text": ("STRING", {
                    "multiline": True,
                    "default": "[1]: Hello, this is the first speaker.\n[2]: Hi there, I'm the second speaker.\n[1]: Nice to meet you!\n[2]: Nice to meet you too!", 
                    "tooltip": "Text with speaker labels. Use '[N]:' format where N is 1-4. Gets disabled when connected to another node.",
                    "forceInput": False,
                    "dynamicPrompts": True
                }),
                "model": (["VibeVoice-1.5B", "VibeVoice-Large", "VibeVoice-Large-Quant-4Bit","VibeVoice-Large-Q8"], {
                    "default": "VibeVoice-Large-Q8",  # Large recommended for multi-speaker
                    "tooltip": "Model to use. Large is recommended for multi-speaker generation, Quant-4Bit uses less VRAM (CUDA only)"
                }),
                "attention_type": (["auto", "eager", "sdpa", "flash_attention_2", "sage"], {
                    "default": "auto",
                    "tooltip": "Attention implementation. Auto selects the best available, eager is standard, sdpa is optimized PyTorch, flash_attention_2 requires compatible GPU, sage uses quantized attention for speedup (CUDA only)"
                }),
                "free_memory_after_generate": ("BOOLEAN", {"default": True, "tooltip": "Free model from memory after generation to save VRAM/RAM. Disable to keep model loaded for faster subsequent generations"}),
                "diffusion_steps": ("INT", {"default": 20, "min": 5, "max": 100, "step": 1, "tooltip": "Number of denoising steps. More steps = better quality but slower. Default: 20"}),
                "seed": ("INT", {"default": 42, "min": 0, "max": 2**32-1, "tooltip": "Random seed for generation. Default 42 is used in official examples"}),
                "cfg_scale": ("FLOAT", {"default": 1.3, "min": 0.5, "max": 3.5, "step": 0.05, "tooltip": "Classifier-free guidance scale (official default: 1.3)"}),
                "use_sampling": ("BOOLEAN", {"default": False, "tooltip": "Enable sampling mode. When False (default), uses deterministic generation like official examples"}),
            },
            "optional": {
                "speaker1_voice": ("AUDIO", {"tooltip": "Optional: Voice sample for Speaker 1. If not provided, synthetic voice will be used."}),
                "speaker2_voice": ("AUDIO", {"tooltip": "Optional: Voice sample for Speaker 2. If not provided, synthetic voice will be used."}),
                "speaker3_voice": ("AUDIO", {"tooltip": "Optional: Voice sample for Speaker 3. If not provided, synthetic voice will be used."}),
                "speaker4_voice": ("AUDIO", {"tooltip": "Optional: Voice sample for Speaker 4. If not provided, synthetic voice will be used."}),
                "temperature": ("FLOAT", {"default": 0.95, "min": 0.1, "max": 2.0, "step": 0.05, "tooltip": "Only used when sampling is enabled"}),
                "top_p": ("FLOAT", {"default": 0.95, "min": 0.1, "max": 1.0, "step": 0.05, "tooltip": "Only used when sampling is enabled"}),
            }
        }

    RETURN_TYPES = ("AUDIO",)
    RETURN_NAMES = ("audio",)
    FUNCTION = "generate_speech"
    CATEGORY = "VibeVoiceWrapper"
    DESCRIPTION = "Generate multi-speaker conversations with up to 4 distinct voices using Microsoft VibeVoice"

    def _prepare_voice_sample(self, voice_audio, speaker_idx: int) -> Optional[np.ndarray]:
        """Prepare a single voice sample from input audio"""
        return self._prepare_audio_from_comfyui(voice_audio)
    
    def generate_speech(self, text: str = "", model: str = "VibeVoice-Large-Q8",
                       attention_type: str = "auto", free_memory_after_generate: bool = True,
                       diffusion_steps: int = 20, seed: int = 42, cfg_scale: float = 1.3,
                       use_sampling: bool = False, speaker1_voice=None, speaker2_voice=None, 
                       speaker3_voice=None, speaker4_voice=None,
                       temperature: float = 0.95, top_p: float = 0.95):
        """Generate multi-speaker speech from text using VibeVoice"""
        
        try:
            # Check text input
            if not text or not text.strip():
                raise Exception("No text provided. Please enter text with speaker labels (e.g., '[1]: Hello' or '[2]: Hi')")
            
            # Detect which speakers appear in the text ([N]: format, 1-based)
            bracket_pattern = r'\[(\d+)\]\s*:'
            speakers_in_text = sorted(set(int(m) for m in re.findall(bracket_pattern, text)))
            
            # VibeVoice supports at most 4 distinct speakers
            if speakers_in_text and max(speakers_in_text) > 4:
                logger.warning(f"Found speaker numbers up to {max(speakers_in_text)}, but only 4 speakers are supported")
            
            # Convert [N]: directly to Speaker (N-1): for the VibeVoice processor;
            # this avoids multiple conversion steps
            converted_text = text
            speaker_segments_with_pauses = []  # populated below during pause parsing
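            # Example of the target conversion:
            #   "[1]: Hello\n[2]: Hi there"  ->  "Speaker 0: Hello\nSpeaker 1: Hi there"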
            
            if not speakers_in_text:
                # No [N]: format found, try Speaker N: format
                speaker_pattern = r'Speaker\s+(\d+)\s*:'
                speakers_in_text = sorted(list(set([int(m) for m in re.findall(speaker_pattern, text)])))
                
                if speakers_in_text:
                    # Text already in Speaker N format, convert to 0-based.
                    # Convert in ascending order so a label produced by an earlier
                    # substitution (e.g. Speaker 2 -> Speaker 1) is never matched
                    # again by a later one.
                    for speaker_num in sorted(speakers_in_text):
                        pattern = rf'Speaker\s+{speaker_num}\s*:'
                        replacement = f'Speaker {speaker_num - 1}:'
                        converted_text = re.sub(pattern, replacement, converted_text)
                else:
                    # No speaker format found
                    speakers_in_text = [1]
                    
                    # Parse pause keywords even for single speaker
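                    # _parse_pause_keywords (inherited from BaseVibeVoiceNode) is
                    # assumed to yield ('text', content) and ('pause', duration_ms)
                    # tuples; the pause duration is consumed verbatim further below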
                    pause_segments = self._parse_pause_keywords(text)
                    
                    # Store speaker segments for pause processing
                    speaker_segments_with_pauses = []
                    segments = []
                    
                    for seg_type, seg_content in pause_segments:
                        if seg_type == 'pause':
                            speaker_segments_with_pauses.append(('pause', seg_content, None))
                        else:
                            # Clean up newlines
                            text_clean = seg_content.replace('\n', ' ').replace('\r', ' ')
                            text_clean = ' '.join(text_clean.split())
                            
                            if text_clean:
                                speaker_segments_with_pauses.append(('text', text_clean, 1))
                                segments.append(f"Speaker 0: {text_clean}")
                    
                    # Join all segments for fallback
                    converted_text = '\n'.join(segments) if segments else f"Speaker 0: {text}"
            else:
                # Convert [N]: directly to Speaker (N-1): and handle multi-line text
                # Split text to preserve speaker segments while cleaning up newlines within each segment
                segments = []
                
                # Find all speaker markers with their positions
                speaker_matches = list(re.finditer(f'\\[({"|".join(map(str, speakers_in_text))})\\]\\s*:', converted_text))
                
                # Store speaker segments for pause processing
                speaker_segments_with_pauses = []
                
                for i, match in enumerate(speaker_matches):
                    speaker_num = int(match.group(1))
                    start = match.end()
                    
                    # Find where this speaker's text ends (at next speaker or end of text)
                    if i + 1 < len(speaker_matches):
                        end = speaker_matches[i + 1].start()
                    else:
                        end = len(converted_text)
                    
                    # Extract the speaker's text (keep pause keywords for now)
                    speaker_text = converted_text[start:end].strip()
                    
                    # Parse pause keywords within this speaker's text
                    pause_segments = self._parse_pause_keywords(speaker_text)
                    
                    # Process each segment (text or pause) for this speaker
                    for seg_type, seg_content in pause_segments:
                        if seg_type == 'pause':
                            # Add pause segment
                            speaker_segments_with_pauses.append(('pause', seg_content, None))
                        else:
                            # Clean up the text segment
                            text_clean = seg_content.replace('\n', ' ').replace('\r', ' ')
                            text_clean = ' '.join(text_clean.split())
                            
                            if text_clean:  # Only add non-empty text
                                # Add text segment with speaker info
                                speaker_segments_with_pauses.append(('text', text_clean, speaker_num))
                                # Also build the traditional segments for fallback
                                segments.append(f'Speaker {speaker_num - 1}: {text_clean}')
                
                # Join all segments with newlines (required for multi-speaker format) - for fallback
                converted_text = '\n'.join(segments) if segments else ""
            
            # Get model mapping and load model with attention type
            model_mapping = self._get_model_mapping()
            model_path = model_mapping.get(model, model)
            self.load_model(model, model_path, attention_type)
            
            voice_inputs = [speaker1_voice, speaker2_voice, speaker3_voice, speaker4_voice]
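            # voice_inputs[i] holds the optional sample for speaker i+1,
            # so a '[2]:' label maps to speaker2_voice at index 1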
            
            # Prepare voice samples in order of appearance
            voice_samples = []
            for i, speaker_num in enumerate(speakers_in_text):
                idx = speaker_num - 1  # Convert to 0-based for voice array
                
                # Try to use provided voice sample
                if idx < len(voice_inputs) and voice_inputs[idx] is not None:
                    voice_sample = self._prepare_voice_sample(voice_inputs[idx], idx)
                    if voice_sample is None:
                        # Use the actual speaker index for consistent synthetic voice
                        voice_sample = self._create_synthetic_voice_sample(idx)
                else:
                    # Use the actual speaker index for consistent synthetic voice
                    voice_sample = self._create_synthetic_voice_sample(idx)
                    
                voice_samples.append(voice_sample)
            
            # Ensure voice_samples count matches detected speakers
            if len(voice_samples) != len(speakers_in_text):
                logger.error(f"Mismatch: {len(speakers_in_text)} speakers but {len(voice_samples)} voice samples!")
                raise Exception(f"Voice sample count mismatch: expected {len(speakers_in_text)}, got {len(voice_samples)}")
            
            # Process pause-aware segments when any were collected (the list
            # stays empty for plain "Speaker N:" input, which skips pause parsing)
            if speaker_segments_with_pauses:
                # Process segments with pauses
                all_audio_segments = []
                sample_rate = 24000  # VibeVoice uses 24kHz
                
                # Group consecutive text segments from same speaker for efficiency
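                # e.g. [('text', a, 1), ('text', b, 1), ('pause', 500, None), ('text', c, 2)]
                #  ->  [('text_group', [a, b], 1), ('pause', 500, None), ('text_group', [c], 2)]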
                grouped_segments = []
                current_group = []
                current_speaker = None
                
                for seg_type, seg_content, speaker_num in speaker_segments_with_pauses:
                    if seg_type == 'pause':
                        # Save current group if any
                        if current_group:
                            grouped_segments.append(('text_group', current_group, current_speaker))
                            current_group = []
                            current_speaker = None
                        # Add pause
                        grouped_segments.append(('pause', seg_content, None))
                    else:
                        # Text segment
                        if speaker_num == current_speaker:
                            # Same speaker, add to current group
                            current_group.append(seg_content)
                        else:
                            # Different speaker, save current group and start new one
                            if current_group:
                                grouped_segments.append(('text_group', current_group, current_speaker))
                            current_group = [seg_content]
                            current_speaker = speaker_num
                
                # Save last group if any
                if current_group:
                    grouped_segments.append(('text_group', current_group, current_speaker))
                
                # Process grouped segments
                for seg_type, seg_content, speaker_num in grouped_segments:
                    if seg_type == 'pause':
                        # Generate silence
                        duration_ms = seg_content
                        logger.info(f"Adding {duration_ms}ms pause")
                        silence_audio = self._generate_silence(duration_ms, sample_rate)
                        all_audio_segments.append(silence_audio)
                    else:
                        # Process text group for a speaker
                        combined_text = ' '.join(seg_content)
                        formatted_text = f"Speaker {speaker_num - 1}: {combined_text}"
                        
                        # Get voice sample for this speaker
                        speaker_idx = speakers_in_text.index(speaker_num)
                        speaker_voice_samples = [voice_samples[speaker_idx]]
                        
                        logger.info(f"Generating audio for Speaker {speaker_num}: {len(combined_text.split())} words")
                        
                        # Generate audio for this speaker's text
                        segment_audio = self._generate_with_vibevoice(
                            formatted_text, speaker_voice_samples, cfg_scale, seed,
                            diffusion_steps, use_sampling, temperature, top_p
                        )
                        
                        all_audio_segments.append(segment_audio)
                
                # Concatenate all audio segments
                if all_audio_segments:
                    logger.info(f"Concatenating {len(all_audio_segments)} audio segments (including pauses)...")
                    
                    # Extract waveforms
                    waveforms = []
                    for audio_segment in all_audio_segments:
                        if isinstance(audio_segment, dict) and "waveform" in audio_segment:
                            waveforms.append(audio_segment["waveform"])
                    
                    if waveforms:
                        # Filter out None values if any
                        valid_waveforms = [w for w in waveforms if w is not None]
                        
                        if valid_waveforms:
                            # Concatenate along time dimension
                            combined_waveform = torch.cat(valid_waveforms, dim=-1)
                            
                            audio_dict = {
                                "waveform": combined_waveform,
                                "sample_rate": sample_rate
                            }
                            logger.info(f"Successfully generated multi-speaker audio with pauses")
                        else:
                            raise Exception("No valid audio waveforms generated")
                    else:
                        raise Exception("Failed to extract waveforms from audio segments")
                else:
                    raise Exception("No audio segments generated")
            else:
                # Fallback to original method without pause support
                logger.info("Processing without pause support (no pause keywords found)")
                audio_dict = self._generate_with_vibevoice(
                    converted_text, voice_samples, cfg_scale, seed, diffusion_steps,
                    use_sampling, temperature, top_p
                )
            
            # Free memory if requested
            if free_memory_after_generate:
                self.free_memory()
            
            return (audio_dict,)
                    
        except Exception as e:
            # Check if this is an interruption by the user
            import comfy.model_management as mm
            if isinstance(e, mm.InterruptProcessingException):
                # User interrupted - just log it and re-raise to stop the workflow
                logger.info("Generation interrupted by user")
                raise  # Propagate the interruption to stop the workflow
            else:
                # Real error - show it
                logger.error(f"Multi-speaker speech generation failed: {str(e)}")
                raise Exception(f"Error generating multi-speaker speech: {str(e)}")

    @classmethod
    def IS_CHANGED(cls, text="", model="VibeVoice-Large-Q8",
                   speaker1_voice=None, speaker2_voice=None, 
                   speaker3_voice=None, speaker4_voice=None, **kwargs):
        """Cache key for ComfyUI"""
        voices_hash = hash(str([speaker1_voice, speaker2_voice, speaker3_voice, speaker4_voice]))
        return f"{hash(text)}_{model}_{voices_hash}_{kwargs.get('cfg_scale', 1.3)}_{kwargs.get('seed', 0)}"