VibeVoiceStreamingForConditionalGenerationInference
The main inference model for VibeVoice streaming text-to-speech generation. This model enables real-time streaming of speech output with interleaved text processing and audio generation.

Class Signature
Initialization
- `config`: Configuration object containing model architecture settings
Key Properties
- The noise scheduler used for diffusion-based speech generation
- The diffusion head that predicts noise during speech token sampling
- Scaling factor applied to speech latents before decoding
- Bias factor applied to speech latents before decoding
- The acoustic tokenizer that decodes speech latents to audio waveforms
Methods
from_pretrained
Load a pretrained model from the HuggingFace Hub or a local directory.

Parameters:

- `pretrained_model_name_or_path`: Path to a pretrained model, or a model identifier from huggingface.co/models
- `torch_dtype`: Data type for model weights. Use `torch.bfloat16` for CUDA, `torch.float32` for MPS/CPU
- `device_map`: Device placement strategy. Options: `"cuda"`, `"cpu"`, `"mps"`, or `"auto"`
- `attn_implementation`: Attention implementation. Options: `"flash_attention_2"` (recommended for CUDA), `"sdpa"`

generate

Generate speech from text inputs with streaming support.

Parameters:

- Prompt input IDs (typically from the processor output)
- Configuration for generation. Set `do_sample=False` for deterministic output
- Optional streamer that receives audio chunks during generation
- Full text tokens to stream in windows during generation
- Classifier-free guidance scale for speech diffusion. Higher values (1.5-3.0) increase adherence to the conditioning
- Whether to concatenate and return speech audio tensors
- Optional callback that returns True to halt generation early
- Tokenizer instance (from `processor.tokenizer`)
- Cached prompt outputs containing KV caches for `lm`, `tts_lm`, `neg_lm`, and `neg_tts_lm`
- Maximum number of new tokens to generate. If None, uses `max_position_embeddings`
- Whether to print generation progress information
Generation output containing:

- `sequences` (`torch.LongTensor`): Generated token IDs
- `speech_outputs` (`List[torch.FloatTensor]`): List of audio waveforms for each sample
- `reach_max_step_sample` (`torch.BoolTensor`): Flags indicating samples that reached the maximum length
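The early-stop callback described above only needs to return True when generation should halt. As a sketch (the exact callback signature is not documented here, so a zero-argument form is assumed), a wall-clock budget could look like:

```python
import time

def make_time_budget_stop(seconds):
    """Build a stop-check callback that halts generation after a wall-clock budget.

    The zero-argument signature is an assumption; adapt it to whatever
    arguments the real callback receives.
    """
    start = time.monotonic()

    def should_stop():
        # Return True to ask generate() to halt early.
        return time.monotonic() - start > seconds

    return should_stop
```

The returned callable would then be passed as the stop-check argument to `generate()`.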
set_ddpm_inference_steps
Set the number of diffusion denoising steps for speech generation.

- Number of inference steps for diffusion sampling. The default comes from the config. Lower values (e.g. 5) are faster but may reduce quality
set_speech_tokenizers
Set the acoustic tokenizer used for encoding and decoding speech.

- Custom acoustic tokenizer instance
forward_lm
Single forward pass through the base language model (text encoding).

Parameters:

- Input token IDs of shape `(batch_size, sequence_length)`
- Attention mask of shape `(batch_size, sequence_length)`
- Cached key-value states from previous forward passes
- Whether to return the key-value cache for the next iteration
- Positions for cached tokens
Output containing:

- `last_hidden_state` (`torch.FloatTensor`): Hidden states from the final layer
- `past_key_values` (`Tuple`): Cached attention states
- `attentions` (`Tuple`, optional): Attention weights
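The cached key-value states enable incremental decoding: encode the full prompt once, then feed only the newest token together with the returned cache. A minimal sketch of that pattern (the keyword names `past_key_values` and `use_cache` are assumptions based on the parameter descriptions above, not a verified signature):

```python
def incremental_lm_pass(model, prompt_ids, new_token_id):
    """Illustrative two-step use of forward_lm with a KV cache.

    Keyword names (past_key_values, use_cache) are assumptions inferred
    from the parameter descriptions above.
    """
    # First pass: encode the whole prompt and keep the cache.
    out = model.forward_lm(prompt_ids, use_cache=True)
    cache = out.past_key_values

    # Later passes: feed only the newest token plus the cache, so each
    # step avoids re-encoding the full sequence.
    step = model.forward_lm(new_token_id, past_key_values=cache, use_cache=True)
    return step.last_hidden_state, step.past_key_values
```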
forward_tts_lm
Single forward pass through the TTS language model (text + speech encoding).

Parameters:

- Input token IDs of shape `(batch_size, sequence_length)`
- Hidden states from the base LM to splice into the input embeddings, shape `(batch_size, K, hidden_size)`
- Mask indicating text (1) vs. speech (0) positions, shape `(batch_size, 1)`
- Attention mask of shape `(batch_size, sequence_length)`
- Cached key-value states from previous forward passes
Output containing:

- `logits` (`torch.FloatTensor`): EOS prediction logits from the binary classifier
- `last_hidden_state` (`torch.FloatTensor`): Hidden states from the final layer
- `past_key_values` (`Tuple`): Cached attention states
sample_speech_tokens
Sample speech latent tokens using diffusion with classifier-free guidance.

Parameters:

- Positive conditioning from the TTS LM hidden states
- Negative (unconditional) conditioning from the TTS LM
- Classifier-free guidance scale

Returns sampled speech latent vectors of shape `(batch_size, acoustic_vae_dim)`.

Usage Example
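The original example for this section was lost; the sketch below reconstructs a plausible end-to-end flow from the methods documented on this page. The import path, the processor class, the model identifier, and the `cfg_scale` keyword are assumptions, not verified API (`return_speech`, `tokenizer`, and `set_ddpm_inference_steps` are taken from the descriptions above); the function is defined but not executed here.

```python
def streaming_tts_sketch():
    """Hypothetical end-to-end usage; names marked below are assumptions."""
    import torch
    # Import path is an assumption; match it to your local package layout.
    from vibevoice import (
        VibeVoiceStreamingForConditionalGenerationInference,
        VibeVoiceProcessor,
    )

    model_id = "microsoft/VibeVoice-1.5B"  # model identifier assumed
    model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,        # bfloat16 for CUDA, per the notes below
        device_map="cuda",
        attn_implementation="flash_attention_2",
    )
    processor = VibeVoiceProcessor.from_pretrained(model_id)

    model.set_ddpm_inference_steps(10)     # fewer steps = faster, lower quality

    inputs = processor(text="Hello from VibeVoice!", return_tensors="pt")
    outputs = model.generate(
        inputs["input_ids"],
        tokenizer=processor.tokenizer,
        cfg_scale=1.5,                     # keyword name assumed
        return_speech=True,
    )
    # outputs.speech_outputs holds 24 kHz waveforms, one per sample.
    return outputs
```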
Notes
- The model currently only supports a batch size of 1
- Text is processed in windows of 5 tokens (`TTS_TEXT_WINDOW_SIZE`)
- Speech is generated in windows of 6 tokens (`TTS_SPEECH_WINDOW_SIZE`)
- The `forward()` method is intentionally disabled - use `forward_lm()`, `forward_tts_lm()`, or `generate()` instead
- For CUDA, use `flash_attention_2` and `torch.bfloat16` for best performance
- For MPS (Apple Silicon), use `sdpa` attention and `torch.float32`
- For CPU, use `sdpa` attention and `torch.float32`
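The per-device recommendations above can be captured in a small helper. This is a sketch of the selection logic only; it returns string names so the dtype can be resolved later with e.g. `getattr(torch, dtype_name)` before calling `from_pretrained`:

```python
def pick_backend(device):
    """Map a device name to the recommended (dtype name, attention
    implementation) pair from the notes above."""
    if device == "cuda":
        return "bfloat16", "flash_attention_2"
    if device in ("mps", "cpu"):
        return "float32", "sdpa"
    raise ValueError(f"unsupported device: {device}")
```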
VibeVoiceGenerationOutput
Output dataclass returned by the `generate()` method.
Fields
- `sequences`: Generated token sequences of shape `(batch_size, sequence_length)` containing both input and generated tokens
- `speech_outputs`: List of generated speech waveforms. Each tensor has shape `(1, num_samples)` and contains audio at a 24 kHz sample rate. None if `return_speech=False`
- `reach_max_step_sample`: Boolean flags of shape `(batch_size,)` indicating which samples stopped because they reached the maximum generation length
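Since each waveform in `speech_outputs` is a float tensor at 24 kHz, writing one to disk only requires a float-to-16-bit-PCM conversion. A stdlib-only sketch, assuming the tensor has already been flattened to a list of Python floats (e.g. via `tensor.squeeze(0).tolist()`):

```python
import struct
import wave

def write_wav_24k(path, samples):
    """Write mono float samples in [-1.0, 1.0] as 16-bit PCM at 24 kHz."""
    with wave.open(path, "wb") as wf:
        wf.setnchannels(1)        # mono, matching the (1, num_samples) shape
        wf.setsampwidth(2)        # 16-bit PCM
        wf.setframerate(24_000)   # sample rate documented above
        frames = b"".join(
            struct.pack("<h", int(max(-1.0, min(1.0, s)) * 32767))
            for s in samples
        )
        wf.writeframes(frames)
```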