Overview

ParakeetRNNT implements the RNN-Transducer architecture, a simpler variant than TDT that predicts only tokens (each with a fixed duration of one frame). Key features:
  • Standard RNN-T architecture
  • Simpler than TDT (no duration prediction)
  • Currently supports only greedy decoding
  • Good balance between speed and accuracy
  • Suitable for streaming applications

Class Definition

class ParakeetRNNT(BaseParakeet):
    def __init__(self, args: ParakeetRNNTArgs):
        ...

Inherited Methods

ParakeetRNNT inherits all methods from BaseParakeet:
  • transcribe() - Transcribe audio files
  • transcribe_stream() - Real-time streaming transcription
  • generate() - Low-level mel-spectrogram to text
See BaseParakeet documentation for details.

RNNT-Specific Methods

decode()

Low-level decoding method that converts encoder features to aligned tokens using greedy decoding.
def decode(
    self,
    features: mx.array,
    lengths: Optional[mx.array] = None,
    last_token: Optional[list[Optional[int]]] = None,
    hidden_state: Optional[list[Optional[tuple[mx.array, mx.array]]]] = None,
    *,
    config: DecodingConfig = DecodingConfig(),
) -> tuple[list[list[AlignedToken]], list[Optional[tuple[mx.array, mx.array]]]]
Currently only greedy decoding is supported for RNNT models. Passing Beam() in config will raise an assertion error.

Parameters

features
mx.array
required
Encoder output features with shape [batch, sequence, feature_dim]. Typically obtained from:
features, lengths = model.encoder(mel)
lengths
mx.array | None
default:"None"
Valid length of each sequence in the batch. Shape: [batch]. If None, all sequences are assumed to have full length features.shape[1].
last_token
list[Optional[int]] | None
default:"None"
Last predicted token ID for each batch item. Used for stateful decoding in streaming scenarios.
  • Pass None for each batch item to start fresh
  • Pass token IDs from previous decode call for continuity
hidden_state
list[Optional[tuple[mx.array, mx.array]]] | None
default:"None"
Hidden state (LSTM hidden and cell) for each batch item. Used for stateful decoding.
  • Pass None for each batch item to start fresh
  • Pass states from previous decode call for continuity
config
DecodingConfig
default:"DecodingConfig()"
Decoding configuration. Must use Greedy() - beam search is not yet supported.

Returns

tokens
list[list[AlignedToken]]
List of token sequences, one per batch item. Each token includes:
  • id - Token ID in vocabulary
  • text - Decoded text
  • start - Start time in seconds
  • duration - Duration in seconds (always equals time_ratio for RNNT)
  • confidence - Confidence score (0.0 to 1.0)
hidden_states
list[Optional[tuple[mx.array, mx.array]]]
Updated hidden states for each batch item. Pass these to subsequent decode() calls for streaming.
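The AlignedToken fields above can be combined into plain text with timestamps. A minimal sketch, using a stand-in dataclass for AlignedToken (illustrative only, not the library's class):

```python
from dataclasses import dataclass

# Stand-in for parakeet_mlx's AlignedToken, for illustration only
@dataclass
class Token:
    id: int
    text: str
    start: float
    duration: float
    confidence: float

def tokens_to_transcript(tokens: list[Token]) -> tuple[str, float, float]:
    """Join token texts and report the overall start and end times."""
    text = "".join(t.text for t in tokens)
    start = tokens[0].start if tokens else 0.0
    end = tokens[-1].start + tokens[-1].duration if tokens else 0.0
    return text, start, end

tokens = [
    Token(5, "he", 0.00, 0.08, 0.98),
    Token(9, "llo", 0.08, 0.08, 0.95),
]
text, start, end = tokens_to_transcript(tokens)
print(f"{text!r} from {start:.2f}s to {end:.2f}s")  # 'hello' from 0.00s to 0.16s
```

Because every RNNT token has the same fixed duration, the end time is simply the last token's start plus one frame's worth of seconds.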

Examples

Basic greedy decoding:
import mlx.core as mx
from parakeet_mlx import from_pretrained, DecodingConfig, Greedy, ParakeetRNNT
from parakeet_mlx.audio import get_logmel, load_audio
from typing import cast

model = from_pretrained("mlx-community/parakeet-rnnt-0.6b")
model_rnnt = cast(ParakeetRNNT, model)

# Prepare input
audio = load_audio("audio.wav", model.preprocessor_config.sample_rate)
mel = get_logmel(audio, model.preprocessor_config)

# Encode
features, lengths = model.encoder(mel)

# Decode
tokens, _ = model_rnnt.decode(
    features,
    lengths,
    config=DecodingConfig(decoding=Greedy())
)

# Print tokens
for token in tokens[0]:
    print(f"[{token.start:.2f}s] {token.text} (conf: {token.confidence:.2f})")
Stateful streaming decoding:
# First chunk
features1, lengths1 = model.encoder(mel1)
tokens1, hidden1 = model_rnnt.decode(features1, lengths1)

# Get last token for continuity
last_token_id = tokens1[0][-1].id if tokens1[0] else None

# Second chunk - continues from first
features2, lengths2 = model.encoder(mel2)
tokens2, hidden2 = model_rnnt.decode(
    features2,
    lengths2,
    last_token=[last_token_id],
    hidden_state=hidden1
)

# Combine results
all_tokens = tokens1[0] + tokens2[0]
text = "".join(t.text for t in all_tokens)
print(text)
Batch decoding:
# Process multiple mel-spectrograms (assumes each mel has shape [1, frames, mel_dim]
# and the same frame count; pad shorter inputs before batching)
batch_mel = mx.concatenate([mel1, mel2, mel3], axis=0)
features, lengths = model.encoder(batch_mel)

# Decode all at once
tokens_batch, _ = model_rnnt.decode(features, lengths)

for i, tokens in enumerate(tokens_batch):
    text = "".join(t.text for t in tokens)
    print(f"Input {i}: {text}")
With progress tracking in streaming:
last_token_id = None
hidden_state = None
all_tokens = []

# Process in chunks
for mel_chunk in mel_chunks:
    features, lengths = model.encoder(mel_chunk)
    
    tokens, hidden_state = model_rnnt.decode(
        features,
        lengths,
        last_token=[last_token_id],
        hidden_state=hidden_state  # decode() already returns a list (None on first chunk)
    )
    
    # Update state
    if tokens[0]:
        last_token_id = tokens[0][-1].id
        all_tokens.extend(tokens[0])
        
        # Print incremental result
        text = "".join(t.text for t in all_tokens)
        print(f"\rCurrent: {text}", end="")

print()  # New line

Decoding Algorithm

Greedy Decoding

RNNT currently supports only greedy decoding. Characteristics:
  • Fast, single-pass decoding
  • Deterministic output
  • Low memory usage
  • Good accuracy for clear audio
RNNT decoding process:
  1. For each encoder frame:
    • Get decoder prediction for current history
    • Compute joint output
    • Select most likely token (argmax)
    • If non-blank: Emit token, update history, continue
    • If blank: Advance to next frame
  2. Each non-blank token gets a duration of 1 frame
  3. Stuck prevention: If too many non-blank emissions without advancing, force advance
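The frame loop above can be sketched with toy stand-ins for the prediction and joint networks (argmax over a small logit table; blank is the last index). All names and values here are illustrative, not the library's API:

```python
BLANK = 3          # toy vocabulary: 0..2 are tokens, 3 is blank
MAX_SYMBOLS = 5    # stuck prevention: cap emissions per frame

# Toy joint network: logits depend on (frame index, last emitted token)
def joint_logits(frame: int, last_token: int) -> list[float]:
    table = {
        (0, BLANK): [2.0, 0.1, 0.1, 0.5],  # frame 0: emit token 0 ...
        (0, 0):     [0.1, 0.1, 0.1, 2.0],  # ... then blank, advance
        (1, 0):     [0.1, 2.0, 0.1, 0.5],  # frame 1: emit token 1 ...
        (1, 1):     [0.1, 0.1, 0.1, 2.0],  # ... then blank, advance
    }
    return table.get((frame, last_token), [0.0, 0.0, 0.0, 2.0])

def greedy_rnnt(num_frames: int) -> list[tuple[int, int]]:
    """Return (token_id, frame) pairs; each token spans exactly one frame."""
    emitted, last_token = [], BLANK
    for frame in range(num_frames):
        for _ in range(MAX_SYMBOLS):            # stuck prevention
            logits = joint_logits(frame, last_token)
            pred = max(range(len(logits)), key=logits.__getitem__)
            if pred == BLANK:
                break                           # blank: advance to next frame
            emitted.append((pred, frame))       # non-blank: emit, update history
            last_token = pred
    return emitted

print(greedy_rnnt(2))  # [(0, 0), (1, 1)]
```

Note how emitting a non-blank token stays on the same frame (the inner loop continues), while blank moves to the next frame without changing the token history.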
Differences from TDT:
  • No duration prediction (always 1 frame per token)
  • Simpler joint output (only token logits)
  • No duration_reward parameter
  • Beam search not yet implemented

Model Properties

model.vocabulary       # list[str] - Token vocabulary
model.max_symbols      # int | None - Max symbols before forcing advance
model.decoder          # PredictNetwork - Prediction network
model.joint            # JointNetwork - Joint network
model.encoder          # Conformer - Encoder network

Architecture Details

RNNT decoding process:
  1. Encoder: Converts mel-spectrogram to features
    features, lengths = model.encoder(mel)
    
  2. Decoder: Predicts next token based on history
    decoder_out, (hidden, cell) = model.decoder(last_token, hidden_state)
    
  3. Joint: Combines encoder and decoder outputs
    joint_out = model.joint(features, decoder_out)
    
  4. Decision: Extract token prediction
    token_logits = joint_out[0, 0]  # Shape: [vocab_size + 1]
    pred_token = mx.argmax(token_logits)  # Greedy selection
    
  5. Advance:
    • Non-blank token: Emit token (duration=1), update history, stay on frame
    • Blank token: Advance to next frame, keep same history
Comparison with TDT:
  Feature               RNNT    TDT
  Token prediction      Yes     Yes
  Duration prediction   No      Yes
  Beam search           No      Yes
  Greedy decoding       Yes     Yes
  Streaming support     Yes     Yes
  Speed                 Fast    Moderate
  Accuracy              Good    Better

Max Symbols Prevention

To prevent the model from getting stuck emitting non-blank tokens without advancing:
model.max_symbols  # Default from model config
When max_symbols consecutive non-blank tokens are emitted:
  • Force advance to next frame
  • Reset emission counter
  • Continue decoding
This ensures decoding always terminates and doesn’t loop infinitely.
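The force-advance behavior can be illustrated with a pathological joint that never predicts blank; without the per-frame cap, the loop below would never terminate (toy code, not the library's implementation):

```python
BLANK = 3
MAX_SYMBOLS = 4  # illustrative cap; the real value comes from the model config

def stuck_joint(_frame: int, _last: int) -> int:
    return 1  # pathological joint: never predicts blank

def count_emissions(num_frames: int) -> int:
    """Count total emissions when the joint is stuck on a non-blank token."""
    emissions = 0
    for frame in range(num_frames):
        symbols = 0                        # emission counter, reset each frame
        while symbols < MAX_SYMBOLS:
            pred = stuck_joint(frame, 1)
            if pred == BLANK:
                break
            emissions += 1
            symbols += 1
        # cap reached: force advance to the next frame
    return emissions

print(count_emissions(3))  # 12 = 3 frames x 4 capped emissions each
```

Even in this worst case, decoding emits at most MAX_SYMBOLS tokens per frame and then advances, so the total work is bounded by num_frames * MAX_SYMBOLS.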

Performance Tips

  1. Use for streaming: RNNT’s simpler architecture makes it well-suited for real-time streaming
  2. Batch processing: Process multiple files together for better throughput
  3. State management: Carefully manage last_token and hidden_state for streaming
  4. Memory efficiency: RNNT uses less memory than TDT (no duration prediction)

When to Use RNNT

Choose RNNT when:
  • You need streaming transcription
  • You want simpler architecture
  • Speed is more important than maximum accuracy
  • You don’t need beam search
  • Memory is constrained
Choose TDT when:
  • You need maximum accuracy
  • You want beam search capability
  • You can accept slightly slower inference
  • Duration prediction is valuable for your use case
