Overview

Parakeet MLX implements four distinct ASR model architectures, each optimized for different use cases. All models share a common Conformer encoder but differ in their decoding strategies and output layers.

TDT

Token-and-Duration Transducer - Predicts both tokens and durations

RNNT

RNN Transducer - Classic transducer with frame-by-frame decoding

CTC

Connectionist Temporal Classification - Non-autoregressive decoding

TDT-CTC

Hybrid model with both TDT and auxiliary CTC decoders

Shared Architecture

All Parakeet models inherit from BaseParakeet and share these common components:

Conformer Encoder

The encoder transforms mel-spectrogram input into high-level acoustic features using the Conformer architecture:
# From conformer.py:332-366 (structural sketch)
class Conformer(nn.Module):
    # - Subsampling layer (downsampling by a factor of 4-8)
    # - Positional encoding (relative or local)
    # - N Conformer blocks (typically 17-24 layers)
Each ConformerBlock combines:
  • Feed-forward module (with SiLU activation)
  • Multi-head self-attention (with relative positional encoding)
  • Convolution module (depthwise separable convolution)
  • Layer normalization and residual connections
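
The macaron-style ordering of these modules can be sketched as a plain function. This is a toy illustration, not the library's implementation: the module arguments are stand-in callables rather than real layers, and the half-step feed-forward residuals follow the original Conformer design:

```python
def conformer_block(x, ff1, attn, conv, ff2, norm):
    # Macaron structure: half-step FFN, attention, convolution, half-step FFN
    x = x + 0.5 * ff1(x)   # first feed-forward module (half-step residual)
    x = x + attn(x)        # multi-head self-attention (residual)
    x = x + conv(x)        # depthwise separable convolution (residual)
    x = x + 0.5 * ff2(x)   # second feed-forward module (half-step residual)
    return norm(x)         # final layer normalization
```

Each sub-module contributes through a residual connection, so the block degrades gracefully if any one module outputs values near zero.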

Audio Preprocessing

Before encoding, audio undergoes:
  1. Resampling to 16kHz
  2. Log-mel spectrogram extraction (80 mel bins)
  3. Normalization
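
The extraction step can be sketched in plain numpy. This is a minimal illustration of log-mel computation in general; parameter values such as n_fft=512 and hop_length=160 are assumptions for the sketch, not the library's exact settings:

```python
import numpy as np

def hz_to_mel(hz):
    return 2595.0 * np.log10(1.0 + hz / 700.0)

def mel_to_hz(mel):
    return 700.0 * (10.0 ** (mel / 2595.0) - 1.0)

def mel_filterbank(sample_rate, n_fft, n_mels):
    # Triangular filters spaced evenly on the mel scale
    mel_points = np.linspace(hz_to_mel(0.0), hz_to_mel(sample_rate / 2), n_mels + 2)
    bins = np.floor((n_fft + 1) * mel_to_hz(mel_points) / sample_rate).astype(int)
    fb = np.zeros((n_mels, n_fft // 2 + 1))
    for m in range(1, n_mels + 1):
        left, center, right = bins[m - 1], bins[m], bins[m + 1]
        for k in range(left, center):
            fb[m - 1, k] = (k - left) / max(center - left, 1)
        for k in range(center, right):
            fb[m - 1, k] = (right - k) / max(right - center, 1)
    return fb

def log_mel_spectrogram(audio, sample_rate=16000, n_fft=512,
                        hop_length=160, n_mels=80):
    # Windowed STFT -> power spectrum -> mel projection -> log compression
    window = np.hanning(n_fft)
    n_frames = 1 + (len(audio) - n_fft) // hop_length
    frames = np.stack([audio[i * hop_length:i * hop_length + n_fft] * window
                       for i in range(n_frames)])
    power = np.abs(np.fft.rfft(frames, axis=1)) ** 2
    mel = power @ mel_filterbank(sample_rate, n_fft, n_mels).T
    return np.log(mel + 1e-10)  # [frames, n_mels]
```

The result is a [frames, 80] matrix that the Conformer encoder consumes after subsampling.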

Parakeet TDT

Token-and-Duration Transducer extends RNN-T by jointly predicting tokens and their durations, enabling more accurate timestamp alignment.

Architecture

1. Encoder

Conformer encoder produces acoustic features
features, lengths = self.encoder(mel)  # [B, S, d_model]
2. Prediction Network

LSTM-based network maintains language model state
# From rnnt.py:88-116 (structural sketch)
class PredictNetwork:
    # - Embedding layer (vocab_size -> pred_hidden)
    # - Multi-layer LSTM (maintains hidden state)
3. Joint Network

Combines encoder and predictor outputs
# From rnnt.py:119-155
joint_out = self.joint(enc_features, pred_features)
# Output: [batch, time, 1, vocab_size + 1 + num_durations]
The output has vocab_size + 1 token logits (including blank), plus additional logits for the duration predictions.
4. Duration Modeling

TDT predicts discrete duration values
# From parakeet.py:276-277
self.durations = args.decoding.durations  # e.g., [0, 1, 2, 3, 4]
  • Duration 0: emit the token without advancing time
  • Duration N (1-4): emit the token and advance by N frames
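
As a toy illustration (plain Python, independent of the actual model), here is how a scripted stream of (token, duration) predictions maps to frame positions and timestamps. The blank token is represented as -1 and time_ratio stands for an assumed seconds-per-encoder-frame constant:

```python
def tdt_advance(predictions, time_ratio=0.08):
    """Toy walk-through of TDT time advancement.

    `predictions` is a scripted list of (token_id, duration) pairs standing
    in for the joint network's argmax output; blank is token id -1.
    """
    step = 0
    emitted = []
    for token, duration in predictions:
        if token != -1:  # non-blank: emit with the current frame as start
            emitted.append((token, step * time_ratio))
        step += duration  # duration 0 keeps the time index in place
    return emitted, step

# Two tokens at frame 0 (durations 0 and 2), a blank that skips 3 frames,
# then a token at frame 5:
emitted, final_step = tdt_advance([(7, 0), (9, 2), (-1, 3), (4, 1)])
```

Note how duration 0 lets two tokens share the same start time, while a blank with a large duration skips silence in a single step, which is why TDT decoding is typically faster than frame-by-frame RNNT.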

Decoding Process

TDT supports both greedy and beam search decoding:
while step < length:
    # 1. Run prediction network (state is only committed on emission)
    decoder_out, new_hidden_state = self.decoder(last_token, hidden_state)
    
    # 2. Compute joint network output
    joint_out = self.joint(features[:, step:step+1], decoder_out)
    
    # 3. Pick the most likely token and duration
    token = argmax(joint_out[..., :vocab_size+1])
    duration = self.durations[argmax(joint_out[..., vocab_size+1:])]
    
    # 4. Emit non-blank tokens
    if token != blank:
        hypothesis.append(AlignedToken(
            id=token,
            start=step * time_ratio,
            duration=duration * time_ratio,
            confidence=confidence_score
        ))
        last_token = token
        hidden_state = new_hidden_state
    
    # 5. Advance time
    step += duration
Stuck Prevention: If duration == 0 for max_symbols consecutive steps, force step += 1.
Beam search explores multiple hypotheses in parallel:
# Key parameters
beam_size: int = 5           # Top-K hypotheses to maintain
length_penalty: float = 1.0  # Normalize by length^penalty
patience: float = 1.0        # Keep patience*beam_size candidates
duration_reward: float = 0.7 # Weight for duration vs token logprobs
Each hypothesis tracks:
  • score: Combined log probability
  • step: Current time position
  • last_token: Previous emitted token
  • hidden_state: Decoder LSTM state
  • hypothesis: List of AlignedToken
Scoring combines token and duration predictions:
score = (
    token_logprob * (1 - duration_reward) +
    duration_logprob * duration_reward
)
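
Putting the two knobs together, here is a hedged sketch of how one hypothesis's normalized score could be computed; the exact normalization in the library may differ:

```python
def hypothesis_score(token_logprobs, duration_logprobs,
                     duration_reward=0.7, length_penalty=1.0):
    """Illustrative beam-search scoring: blend token and duration
    log-probabilities per step, then length-normalize the sum."""
    steps = [t * (1 - duration_reward) + d * duration_reward
             for t, d in zip(token_logprobs, duration_logprobs)]
    # Length normalization keeps long hypotheses competitive with short ones
    return sum(steps) / (len(steps) ** length_penalty)
```

With length_penalty=1.0 this is the mean per-step score; setting it to 0 disables normalization entirely.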

Configuration

from parakeet_mlx import ParakeetTDTArgs, DecodingConfig, Beam

args = ParakeetTDTArgs(
    preprocessor=PreprocessArgs(...),
    encoder=ConformerArgs(...),
    decoder=PredictArgs(...),
    joint=JointArgs(...),
    decoding=TDTDecodingArgs(
        model_type="tdt",
        durations=[0, 1, 2, 3, 4],
        greedy={"max_symbols": 10}
    )
)

Parakeet RNNT

RNN Transducer uses the classic transducer formulation where each token has an implicit duration of 1 frame.

Key Differences from TDT

No Duration Prediction

Joint network outputs only vocab_size + 1 logits (no duration head)

Fixed Duration

Every emitted token is assigned a fixed duration of one encoder frame (1 * time_ratio seconds)

Greedy Only

Currently only supports greedy decoding (beam search not implemented)

Simpler Decoding

Blank advances time by 1, non-blank emits token with duration=1

Decoding Logic

# From parakeet.py:655-743
while step < length:
    decoder_out, new_hidden_state = self.decoder(last_token, hidden_state)
    joint_out = self.joint(features[:, step:step+1], decoder_out)
    
    pred_token = argmax(joint_out[0, 0])  # No duration dimension
    
    if pred_token != blank:
        hypothesis.append(AlignedToken(
            id=pred_token,
            start=step * time_ratio,
            duration=1 * time_ratio,  # Fixed duration of one frame
            confidence=confidence
        ))
        last_token = pred_token
        hidden_state = new_hidden_state
        # Step is not advanced, so multiple tokens can be emitted at one frame
    else:
        step += 1  # Blank advances time
RNNT can get stuck emitting tokens without advancing. The max_symbols parameter forces step += 1 after too many consecutive non-blank tokens.
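
The loop and its stuck-prevention guard can be exercised with a toy simulation. Here the scripted `predictions` list stands in for successive argmax outputs of the joint network; this is an illustration of the control flow, not the library's code:

```python
from collections import deque

def rnnt_greedy(predictions, length, blank=0, max_symbols=10):
    """Toy greedy RNNT loop over a scripted stream of joint-network outputs.

    Returns (token_id, frame_index) pairs. `max_symbols` caps how many
    non-blank tokens may be emitted at one frame before time is forced on.
    """
    preds = deque(predictions)
    step, emitted_at_step = 0, 0
    hypothesis = []
    while step < length and preds:
        token = preds.popleft()
        if token != blank:
            hypothesis.append((token, step))
            emitted_at_step += 1
            if emitted_at_step >= max_symbols:  # stuck prevention
                step += 1
                emitted_at_step = 0
        else:
            step += 1  # blank advances time
            emitted_at_step = 0
    return hypothesis
```

For example, the stream [5, blank, 6, 7, blank] emits token 5 at frame 0 and tokens 6 and 7 at frame 1, because non-blank emissions do not advance the time index.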

Configuration

args = ParakeetRNNTArgs(
    preprocessor=PreprocessArgs(...),
    encoder=ConformerArgs(...),
    decoder=PredictArgs(...),
    joint=JointArgs(...),
    decoding=RNNTDecodingArgs(
        greedy={"max_symbols": 10}
    )
)

Parakeet CTC

Connectionist Temporal Classification uses a simpler, non-autoregressive approach without a prediction network.

Architecture

1. Encoder

Same Conformer encoder as other models
2. CTC Decoder

Simple 1D convolution projects encoder features to vocabulary
# From ctc.py:19-34
class ConvASRDecoder:
    def __init__(self, args):
        self.decoder_layers = [
            nn.Conv1d(feat_in, num_classes, kernel_size=1, bias=True)
        ]
    
    def __call__(self, x):
        return nn.log_softmax(
            self.decoder_layers[0](x) / self.temperature
        )
3. CTC Decoding

Collapse repeated tokens and remove blanks

Decoding Process

# From parakeet.py:774-889
logits = self.decoder(features)  # [B, S, vocab_size+1]
predictions = argmax(logits, axis=-1)  # Greedy per-frame prediction over the vocabulary

# Collapse repetitions and remove blanks
hypothesis = []
prev_token = -1
for t, token_id in enumerate(predictions):
    if token_id == blank:  # Blank separates repeated tokens
        prev_token = -1
        continue
    if token_id == prev_token:  # Skip repeated frames of the same token
        continue
    
    # Token boundary detected; token_start marks the first frame of the
    # run, and the duration spans the frames until the token changes
    hypothesis.append(AlignedToken(
        id=token_id,
        start=token_start * time_ratio,
        duration=(t - token_start) * time_ratio,
        confidence=avg_confidence_over_frames
    ))
    prev_token = token_id
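
A self-contained way to express the same collapse, tracking runs of identical frame predictions explicitly (a toy sketch with blank assumed to be id 0 and time_ratio an assumed seconds-per-frame constant):

```python
def ctc_collapse(frame_ids, blank=0, time_ratio=0.08):
    """Collapse per-frame argmax ids: merge repeats, drop blanks, and derive
    each token's (id, start, duration) from the run of identical frames."""
    tokens = []
    run_start, run_id = 0, None
    for t, tid in enumerate(frame_ids + [blank]):  # sentinel flushes last run
        if tid != run_id:
            if run_id is not None and run_id != blank:
                tokens.append((run_id, run_start * time_ratio,
                               (t - run_start) * time_ratio))
            run_start, run_id = t, tid
    return tokens
```

Because a blank ends the current run, repeated tokens separated by a blank (e.g. "ll" in "hello") survive as two distinct tokens, while repeats without an intervening blank are merged.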

Advantages & Limitations

Advantages:
  • Fast: no autoregressive decoding loop
  • Simple: no hidden state management
  • Parallelizable: all frame predictions are independent
  • Memory efficient: smaller model (no prediction network)

Limitations:
  • No internal language model, which typically costs accuracy (see WER in the comparison below)
  • Timestamps are inferred from frame runs rather than modeled explicitly

Configuration

args = ParakeetCTCArgs(
    preprocessor=PreprocessArgs(...),
    encoder=ConformerArgs(...),
    decoder=ConvASRDecoderArgs(
        feat_in=512,
        num_classes=1024,  # Set to -1 to use len(vocabulary)
        vocabulary=vocabulary
    ),
    decoding=CTCDecodingArgs(
        greedy={}  # CTC only supports greedy
    )
)

Parakeet TDT-CTC

Hybrid architecture that combines TDT with an auxiliary CTC decoder for multi-task learning.

Architecture

# From parakeet.py:909-917
class ParakeetTDTCTC(ParakeetTDT):
    """Has ConvASRDecoder in .ctc_decoder but .generate uses TDT decoder"""
    
    def __init__(self, args: ParakeetTDTCTCArgs):
        super().__init__(args)  # Initialize TDT components
        self.ctc_decoder = ConvASRDecoder(args.aux_ctc.decoder)
During training, both decoders are used (multi-task learning). During inference, only the TDT decoder is used by default via .generate().
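
In multi-task setups like this, the two objectives are typically combined as a weighted sum. The sketch below illustrates that pattern only; the weighting scheme and the value ctc_weight=0.3 are assumptions, not the library's actual training configuration:

```python
def hybrid_loss(transducer_loss, ctc_loss, ctc_weight=0.3):
    """Illustrative multi-task objective: the auxiliary CTC term
    regularizes the shared encoder alongside the TDT transducer loss.
    (ctc_weight=0.3 is an assumed value for the sketch.)"""
    return (1 - ctc_weight) * transducer_loss + ctc_weight * ctc_loss
```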

Configuration

args = ParakeetTDTCTCArgs(
    # All TDT parameters
    preprocessor=PreprocessArgs(...),
    encoder=ConformerArgs(...),
    decoder=PredictArgs(...),
    joint=JointArgs(...),
    decoding=TDTDecodingArgs(...),
    # Additional CTC decoder
    aux_ctc=AuxCTCArgs(
        decoder=ConvASRDecoderArgs(...)
    )
)

Model Comparison

| Feature | TDT | RNNT | CTC | TDT-CTC |
| --- | --- | --- | --- | --- |
| Decoding | Autoregressive | Autoregressive | Non-autoregressive | Autoregressive (inference) |
| Beam Search | ✅ Yes | ❌ No | ❌ No | ✅ Yes |
| Duration Modeling | ✅ Explicit | ❌ Fixed (1 frame) | ⚠️ Inferred | ✅ Explicit |
| Timestamp Accuracy | ⭐⭐⭐ High | ⭐⭐ Medium | ⭐⭐ Medium | ⭐⭐⭐ High |
| Speed | ⭐⭐ Medium | ⭐⭐ Medium | ⭐⭐⭐ Fast | ⭐⭐ Medium |
| WER | ⭐⭐⭐ Best | ⭐⭐ Good | ⭐ Fair | ⭐⭐⭐ Best |
| Language Model | ✅ Internal | ✅ Internal | ❌ None | ✅ Internal |

Choosing a Model

Use TDT when...

  • You need accurate word-level timestamps
  • Quality is more important than speed
  • You have sufficient compute for beam search
  • Recommended: mlx-community/parakeet-tdt-0.6b-v3

Use RNNT when...

  • You need balanced accuracy and speed
  • Greedy decoding is sufficient
  • Model size is a concern
  • Recommended: mlx-community/parakeet-rnnt-1.1b

Use CTC when...

  • Speed is the top priority
  • You’re processing long audio files
  • Approximate timestamps are acceptable
  • Recommended: mlx-community/parakeet-ctc-1.1b

Use TDT-CTC when...

  • You want the best of both worlds
  • Training with multi-task learning
  • Maximum accuracy is required
  • Note: Uses TDT decoder for inference

Next Steps

Decoding Strategies

Learn about greedy vs beam search decoding

Timestamp System

Understand how timestamps are computed
