
Overview

ParakeetTDT implements the Token-and-Duration Transducer architecture, which jointly predicts both tokens and their durations. This is the recommended model variant for most use cases. Key features:
  • Simultaneous token and duration prediction
  • Supports both greedy and beam search decoding
  • Best accuracy among Parakeet variants
  • Suitable for both real-time and offline transcription

Class Definition

class ParakeetTDT(BaseParakeet):
    def __init__(self, args: ParakeetTDTArgs):
        ...

Inherited Methods

ParakeetTDT inherits all methods from BaseParakeet:
  • transcribe() - Transcribe audio files
  • transcribe_stream() - Real-time streaming transcription
  • generate() - Low-level mel-spectrogram to text
See BaseParakeet documentation for details.

TDT-Specific Methods

decode()

Low-level decoding method that converts encoder features to aligned tokens.
def decode(
    self,
    features: mx.array,
    lengths: Optional[mx.array] = None,
    last_token: Optional[list[Optional[int]]] = None,
    hidden_state: Optional[list[Optional[tuple[mx.array, mx.array]]]] = None,
    *,
    config: DecodingConfig = DecodingConfig(),
) -> tuple[list[list[AlignedToken]], list[Optional[tuple[mx.array, mx.array]]]]

Parameters

features
mx.array
required
Encoder output features with shape [batch, sequence, feature_dim]. Typically obtained from:
features, lengths = model.encoder(mel)
lengths
mx.array | None
default:"None"
Valid length of each sequence in the batch. Shape: [batch]. If None, every sequence is assumed to span the full length, features.shape[1].
last_token
list[Optional[int]] | None
default:"None"
Last predicted token ID for each batch item. Used for stateful decoding in streaming scenarios.
  • Pass None for each batch item to start fresh
  • Pass token IDs from previous decode call for continuity
hidden_state
list[Optional[tuple[mx.array, mx.array]]] | None
default:"None"
Hidden state (LSTM hidden and cell) for each batch item. Used for stateful decoding.
  • Pass None for each batch item to start fresh
  • Pass states from previous decode call for continuity
config
DecodingConfig
default:"DecodingConfig()"
Decoding configuration. TDT supports:
  • Greedy() - Fast greedy decoding
  • Beam() - Beam search with configurable parameters

Returns

tokens
list[list[AlignedToken]]
List of token sequences, one per batch item. Each token includes:
  • id - Token ID in vocabulary
  • text - Decoded text
  • start - Start time in seconds
  • duration - Duration in seconds
  • confidence - Confidence score (0.0 to 1.0)
hidden_states
list[Optional[tuple[mx.array, mx.array]]]
Updated hidden states for each batch item. Pass these to subsequent decode() calls for streaming.
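The returned AlignedToken fields compose naturally into transcript text and timing information. A minimal illustration using a toy stand-in dataclass (the real AlignedToken class lives in parakeet_mlx; only the field names above are taken from the docs, the values are made up):

```python
from dataclasses import dataclass

# Toy stand-in for AlignedToken, with the fields documented above.
@dataclass
class Tok:
    id: int
    text: str
    start: float
    duration: float
    confidence: float

tokens = [Tok(7, "hel", 0.0, 0.24, 0.98), Tok(9, "lo", 0.24, 0.16, 0.95)]

# Concatenate token text and compute the end time of the last token.
text = "".join(t.text for t in tokens)
end = tokens[-1].start + tokens[-1].duration
print(text, round(end, 2))  # hello 0.4
```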

Examples

Basic decoding with greedy search:
import mlx.core as mx
from parakeet_mlx import from_pretrained, DecodingConfig, Greedy, ParakeetTDT
from parakeet_mlx.audio import get_logmel, load_audio
from typing import cast

model = from_pretrained("mlx-community/parakeet-tdt-0.6b-v3")
model_tdt = cast(ParakeetTDT, model)

# Prepare input
audio = load_audio("audio.wav", model.preprocessor_config.sample_rate)
mel = get_logmel(audio, model.preprocessor_config)

# Encode
features, lengths = model.encoder(mel)

# Decode with greedy
tokens, _ = model_tdt.decode(
    features,
    lengths,
    config=DecodingConfig(decoding=Greedy())
)

# Print tokens
for token in tokens[0]:
    print(f"[{token.start:.2f}s] {token.text} (conf: {token.confidence:.2f})")
Beam search decoding:
from parakeet_mlx import DecodingConfig, Beam

config = DecodingConfig(
    decoding=Beam(
        beam_size=5,
        length_penalty=0.013,
        patience=3.5,
        duration_reward=0.67
    )
)

tokens, _ = model_tdt.decode(features, lengths, config=config)
Stateful streaming decoding:
# First chunk
features1, lengths1 = model.encoder(mel1)
tokens1, hidden1 = model_tdt.decode(features1, lengths1)

# Get last token for continuity
last_token_id = tokens1[0][-1].id if tokens1[0] else None

# Second chunk - continues from first
features2, lengths2 = model.encoder(mel2)
tokens2, hidden2 = model_tdt.decode(
    features2,
    lengths2,
    last_token=[last_token_id],
    hidden_state=hidden1
)

# Combine results
all_tokens = tokens1[0] + tokens2[0]
Batch decoding:
# Process multiple mel-spectrograms (they must share the same number of
# frames; pad shorter ones first)
batch_mel = mx.stack([mel1, mel2, mel3])
features, lengths = model.encoder(batch_mel)

# Decode all at once
tokens_batch, _ = model_tdt.decode(features, lengths)

for i, tokens in enumerate(tokens_batch):
    text = "".join(t.text for t in tokens)
    print(f"Input {i}: {text}")

Decoding Algorithms

Greedy Decoding

Fast, single-pass decoding that selects the most likely token at each step. Characteristics:
  • Fastest inference
  • Good accuracy for clear audio
  • Deterministic output
  • Low memory usage
Configuration:
from parakeet_mlx import DecodingConfig, Greedy

config = DecodingConfig(decoding=Greedy())

Beam Search Decoding

Explores multiple hypotheses simultaneously for better accuracy. Characteristics:
  • Higher accuracy, especially for challenging audio
  • Slower than greedy
  • Non-deterministic (can vary slightly)
  • Higher memory usage
Configuration:
from parakeet_mlx import DecodingConfig, Beam

config = DecodingConfig(
    decoding=Beam(
        beam_size=5,           # Number of hypotheses to track
        length_penalty=0.013,  # Penalty for longer sequences (0.0 to disable)
        patience=3.5,          # Search until patience × beam_size candidates found
        duration_reward=0.67   # Weight between token and duration logprobs
    )
)
Beam parameters:
beam_size
int
default:"5"
Number of top hypotheses to maintain. Higher values improve accuracy but increase computation. Typical values: 3-10.
length_penalty
float
default:"1.0"
Penalty applied based on sequence length. Helps prevent overly short or long predictions.
  • 0.0 - No penalty
  • < 1.0 - Favors shorter sequences
  • > 1.0 - Favors longer sequences
Formula: score / (sequence_length ** length_penalty)
patience
float
default:"1.0"
Controls when to stop searching. Search continues until patience × beam_size complete hypotheses are found.
  • 1.0 - Stop as soon as beam_size hypotheses complete
  • > 1.0 - Continue searching for potentially better hypotheses
duration_reward
float
default:"0.7"
Weight between token and duration predictions (TDT-specific).
  • 0.0 - Only use token logprobs
  • 1.0 - Only use duration logprobs
  • 0.5 - Equal weight
  • < 0.5 - Favor token predictions
  • > 0.5 - Favor duration predictions
Formula: token_logprob × (1 - duration_reward) + duration_logprob × duration_reward
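To make the two scoring formulas above concrete, here is a small worked example with hypothetical probabilities (the numbers are illustrative, not drawn from the library):

```python
import math

# Hypothetical beam hypothesis: one token and one duration prediction.
token_logprob = math.log(0.6)     # assumed token probability
duration_logprob = math.log(0.9)  # assumed duration probability

# TDT mix: token_logprob * (1 - duration_reward) + duration_logprob * duration_reward
duration_reward = 0.7
mixed = token_logprob * (1 - duration_reward) + duration_logprob * duration_reward

# Length normalization: score / (sequence_length ** length_penalty)
raw_score, seq_len = -12.0, 20
for length_penalty in (0.0, 0.5, 1.0):
    print(length_penalty, round(raw_score / (seq_len ** length_penalty), 3))
```

With length_penalty 0.0 the divisor is 1, so the score is unchanged; higher values normalize more aggressively by length.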

Model Properties

model.vocabulary       # list[str] - Token vocabulary
model.durations        # list[int] - Allowed duration values
model.max_symbols      # int | None - Max symbols before forcing advance
model.decoder          # PredictNetwork - Prediction network
model.joint            # JointNetwork - Joint network
model.encoder          # Conformer - Encoder network

Architecture Details

TDT decoding process:
  1. Encoder: Converts mel-spectrogram to features
    features, lengths = model.encoder(mel)
    
  2. Decoder: Predicts next token based on history
    decoder_out, (hidden, cell) = model.decoder(last_token, hidden_state)
    
  3. Joint: Combines encoder and decoder outputs
    joint_out = model.joint(features, decoder_out)
    
  4. Decision: Extract token and duration predictions
    token_logits = joint_out[..., :vocab_size+1]      # +1 for blank
    duration_logits = joint_out[..., vocab_size+1:]   # Duration options
    
  5. Advance: Move forward by predicted duration
    • Non-blank token: Update history, advance by duration
    • Blank token: Advance by duration, keep same history
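The five steps above can be sketched as a toy greedy loop. Everything here is illustrative: random logits stand in for the encoder/decoder/joint networks, and the real implementation in parakeet_mlx differs in detail (batching, state handling, exact max_symbols semantics):

```python
import numpy as np

rng = np.random.default_rng(0)
vocab_size = 8            # toy vocabulary; index vocab_size is blank
durations = [0, 1, 2, 4]  # toy set of allowed frame advances
T = 20                    # number of encoder frames
MAX_SYMBOLS = 5           # force an advance after this many tokens on one frame

def toy_joint(frame, last_token):
    """Stand-in for steps 1-3: returns random token and duration logits."""
    return rng.normal(size=vocab_size + 1), rng.normal(size=len(durations))

t, hyp, symbols, last = 0, [], 0, None
while t < T:
    token_logits, dur_logits = toy_joint(t, last)
    token = int(np.argmax(token_logits))         # step 4: token decision
    dur = durations[int(np.argmax(dur_logits))]  # step 4: duration decision
    if token != vocab_size:                      # non-blank: emit token
        hyp.append(token)
        last = token
        symbols += 1
    # Step 5: advance by the predicted duration; force minimal progress on
    # a zero-duration blank or once the per-frame symbol cap is reached.
    if dur > 0:
        t += dur
        symbols = 0
    elif token == vocab_size or symbols >= MAX_SYMBOLS:
        t += 1
        symbols = 0
print(f"{len(hyp)} tokens over {T} frames")
```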

Performance Tips

  1. Use greedy for real-time: Greedy decoding is 3-5x faster than beam search
  2. Use beam for accuracy: Beam search can reduce WER by 5-15% on challenging audio
  3. Tune duration_reward: Adjust based on your audio characteristics
    • Speech with clear pauses: higher values (0.7-0.8)
    • Fast speech or music: lower values (0.5-0.6)
  4. Batch when possible: Process multiple files together for better GPU utilization
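For tip 4, inputs in a batch must share the same number of frames, so shorter mel-spectrograms need padding plus a lengths vector for decode(). A hedged sketch using NumPy and toy shapes (80 mel bins is an assumption; the same padding logic applies to mx.array inputs):

```python
import numpy as np

# Three toy mel-spectrograms of different lengths: [frames, mel_bins]
mels = [np.ones((t, 80), dtype=np.float32) for t in (120, 95, 110)]

# Right-pad each along the time axis to the longest, then stack into a batch.
max_t = max(m.shape[0] for m in mels)
batch = np.stack([np.pad(m, ((0, max_t - m.shape[0]), (0, 0))) for m in mels])
lengths = np.array([m.shape[0] for m in mels])  # original lengths, per item

print(batch.shape, lengths.tolist())  # (3, 120, 80) [120, 95, 110]
```

The lengths vector lets the model ignore the padded frames, which is what the `lengths` parameter of decode() is for.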
