Overview

Decoding converts the model’s continuous probability distributions into discrete text sequences. Parakeet MLX implements two decoding strategies with different speed-accuracy tradeoffs.

Greedy Decoding

Fast, deterministic single-path search

Beam Search

Slower, explores multiple hypotheses for better accuracy

Decoding Configuration

Basic Usage

from parakeet_mlx import from_pretrained, DecodingConfig, Greedy, Beam

model = from_pretrained("mlx-community/parakeet-tdt-0.6b-v3")

# Greedy decoding (default)
result = model.transcribe(
    "audio.wav",
    decoding_config=DecodingConfig(decoding=Greedy())
)

# Beam search decoding
result = model.transcribe(
    "audio.wav",
    decoding_config=DecodingConfig(
        decoding=Beam(
            beam_size=5,
            length_penalty=1.0,
            patience=1.0,
            duration_reward=0.7  # TDT only
        )
    )
)

Configuration Classes

# From parakeet.py:77-78
@dataclass
class Greedy:
    pass  # No parameters needed
Greedy decoding is parameterless: it always selects the highest-probability token at each step.
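For reference, the Beam configuration carries the parameters shown in the usage example above. A minimal sketch of its shape (the defaults here are assumptions inferred from that example, not necessarily the library's actual values):

```python
from dataclasses import dataclass

@dataclass
class Beam:
    beam_size: int = 5            # number of hypotheses kept active
    length_penalty: float = 1.0   # exponent used for length normalization
    patience: float = 1.0         # multiplier on the finished-hypothesis budget
    duration_reward: float = 0.7  # TDT only: weight on the duration logprob

config = Beam(beam_size=3)
```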

Greedy Decoding

Algorithm

Greedy decoding follows the highest probability path at each step:

1. Initialize

  • step = 0 (current time frame)
  • hypothesis = [] (output tokens)
  • hidden_state = None (decoder LSTM state)
  • last_token = None (previous emission)

2. Loop until end

While step < length:
  1. Prediction Network: Generate decoder output
    decoder_out, hidden = self.decoder(last_token, hidden_state)
  2. Joint Network: Combine with encoder features
    joint_out = self.joint(features[:, step:step+1], decoder_out)
  3. Sampling: Select the highest-probability token
    token = argmax(joint_out[..., :vocab_size+1])

3. Update state

  • If token != blank: emit token, update hidden_state and last_token
  • Advance time by duration (TDT) or 1 frame (RNNT)
  • Check stuck-prevention rules

TDT Greedy Implementation

# From parakeet.py:526-618 (simplified)
def decode_greedy(self, features, lengths, last_token, hidden_state, config):
    results = []
    
    for batch in range(B):
        hypothesis = []
        step = 0
        new_symbols = 0  # For stuck prevention
        
        while step < length:
            # Run decoder
            decoder_out, (hidden, cell) = self.decoder(
            mx.array([[last_token[batch]]]) if last_token[batch] is not None else None,
                hidden_state[batch]
            )
            
            # Joint network
            joint_out = self.joint(features[:, step:step+1], decoder_out)
            
            # Sample token and duration
            token_logits = joint_out[0, 0, :, :len(self.vocabulary)+1]
            pred_token = int(mx.argmax(token_logits))
            
            duration_logits = joint_out[0, 0, :, len(self.vocabulary)+1:]
            decision = int(mx.argmax(duration_logits))
            duration = self.durations[decision]  # e.g., [0,1,2,3,4]
            
            # Compute confidence using entropy
            token_probs = mx.softmax(token_logits, axis=-1)
            entropy = -mx.sum(token_probs * mx.log(token_probs + 1e-10))
            max_entropy = mx.log(mx.array(len(self.vocabulary) + 1))
            confidence = 1.0 - (entropy / max_entropy)
            
            # Emit non-blank tokens
            if pred_token != len(self.vocabulary):  # Not blank
                hypothesis.append(AlignedToken(
                    id=pred_token,
                    start=step * self.time_ratio,
                    duration=duration * self.time_ratio,
                    confidence=float(confidence),
                    text=decode([pred_token], self.vocabulary)
                ))
                last_token[batch] = pred_token
                hidden_state[batch] = (hidden, cell)
            
            # Advance time
            step += duration
            
            # Stuck prevention: force advance if too many duration=0
            new_symbols += 1
            if duration != 0:
                new_symbols = 0
            elif self.max_symbols and new_symbols >= self.max_symbols:
                step += 1
                new_symbols = 0
        
        results.append(hypothesis)
    
    return results, hidden_state
Problem: Duration prediction can output 0 repeatedly, causing infinite loops.

Solution: Track consecutive zero-duration emissions with the new_symbols counter:
# From parakeet.py:606-614
new_symbols += 1

if self.durations[decision] != 0:
    new_symbols = 0  # Reset on any forward movement
else:
    if self.max_symbols and new_symbols >= self.max_symbols:
        step += 1  # Force advance
        new_symbols = 0
Default max_symbols=10 prevents getting stuck emitting tokens at the same frame.
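The rule can be exercised in isolation. This standalone sketch mirrors the update order shown above (advance by duration, increment the counter, force an advance when stuck) outside the model:

```python
def advance(step, new_symbols, duration, max_symbols=10):
    """Apply one TDT time-advance with stuck prevention (max_symbols defaults to 10)."""
    step += duration
    new_symbols += 1
    if duration != 0:
        new_symbols = 0  # any forward movement resets the counter
    elif max_symbols and new_symbols >= max_symbols:
        step += 1        # force past the stuck frame
        new_symbols = 0
    return step, new_symbols

# Ten consecutive zero-duration emissions trigger exactly one forced advance.
step, counter = 0, 0
for _ in range(10):
    step, counter = advance(step, counter, duration=0)
```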

RNNT Greedy Implementation

# From parakeet.py:655-743 (simplified)
def decode(self, features, lengths, last_token, hidden_state, config):
    # Similar to TDT but simpler:
    
    while step < length:
        decoder_out, hidden = self.decoder(last_token, hidden_state)
        joint_out = self.joint(features[:, step:step+1], decoder_out)
        
        pred_token = argmax(joint_out[0, 0])  # No duration dimension
        
        if pred_token != blank:
            hypothesis.append(AlignedToken(
                id=pred_token,
                start=step * time_ratio,
                duration=1 * time_ratio,  # Fixed duration
                confidence=confidence
            ))
            last_token = pred_token
            hidden_state = new_hidden
            # Don't advance step - can emit multiple tokens
            
            new_symbols += 1
            if max_symbols and new_symbols >= max_symbols:
                step += 1  # Force advance
                new_symbols = 0
        else:
            step += 1  # Blank advances time
            new_symbols = 0
RNNT differs from TDT: blank tokens advance time, while non-blank tokens stay at the same frame (allowing multiple emissions before moving forward).
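The difference in time advancement can be captured in two tiny helpers (illustrative sketches, not library code):

```python
def rnnt_advance(step, is_blank):
    # RNNT: only a blank emission moves to the next frame;
    # non-blank tokens stay put, allowing multiple emissions per frame
    return step + 1 if is_blank else step

def tdt_advance(step, duration):
    # TDT: the predicted duration moves time, blank or not
    return step + duration
```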

CTC Greedy Implementation

# From parakeet.py:774-889 (simplified)
def decode(self, features, lengths, config):
    logits = self.decoder(features)  # [B, S, vocab_size+1]
    
    for batch in range(B):
        predictions = argmax(logits[batch, :length], axis=1)
        probs = exp(logits[batch, :length])
        
        hypothesis = []
        prev_token = -1
        token_boundaries = []
        
        # Collapse repetitions and remove blanks
        for t, token_id in enumerate(predictions):
            if token_id == blank:
                continue
            if token_id == prev_token:  # Skip repetition
                continue
            
            # New token boundary detected
            if prev_token != -1:
                # Finalize previous token
                start_frame = token_boundaries[-1][0]
                end_frame = t
                
                # Compute average confidence over token frames
                token_probs = probs[start_frame:end_frame]
                entropies = -sum(token_probs * log(token_probs + 1e-10))
                avg_entropy = mean(entropies)
                confidence = 1.0 - (avg_entropy / log(vocab_size))
                
                hypothesis.append(AlignedToken(
                    id=prev_token,
                    start=start_frame * time_ratio,
                    duration=(end_frame - start_frame) * time_ratio,
                    confidence=confidence
                ))
            
            token_boundaries.append((t, None))
            prev_token = token_id
        
        # Handle last token
        # ...
CTC decoding is non-autoregressive: all predictions are made independently, then post-processed to collapse repetitions.
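The collapse rule (drop consecutive repeats, then blanks) is easy to verify on a toy frame sequence. A standalone sketch, using 0 as the blank id for simplicity (the real model uses vocab_size):

```python
BLANK = 0  # blank id; the real model uses len(vocabulary)

def ctc_collapse(frame_predictions):
    """Collapse consecutive repeats, then remove blanks (standard CTC rule)."""
    tokens = []
    prev = None
    for token_id in frame_predictions:
        if token_id != prev and token_id != BLANK:
            tokens.append(token_id)
        prev = token_id
    return tokens

# A blank between two 5s preserves the genuine repetition:
collapsed = ctc_collapse([5, 5, 0, 5, 3, 3, 0, 0, 7])
```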

Performance Characteristics

| Metric         | TDT           | RNNT          | CTC               |
|----------------|---------------|---------------|-------------------|
| Speed          | ~100ms/sec    | ~100ms/sec    | ~50ms/sec         |
| Memory         | O(1) per step | O(1) per step | O(1) total        |
| Deterministic  | ✅ Yes        | ✅ Yes        | ✅ Yes            |
| Parallelizable | ❌ Sequential | ❌ Sequential | ✅ Fully parallel |

Beam Search Decoding

Algorithm

Beam search maintains multiple hypotheses and explores the most promising paths:

1. Initialize

Create the initial hypothesis:
Hypothesis(
    score=0.0,
    step=0,
    last_token=None,
    hidden_state=None,
    stuck=0,
    hypothesis=[]
)

2. Expand hypotheses

For each active hypothesis:
  1. Run the decoder and joint network
  2. Get the top-K tokens (by token logprob)
  3. Get the top-K durations (by duration logprob)
  4. Create K×K candidate hypotheses

3. Score and prune

  • Compute combined scores
  • Merge hypotheses with the same token sequence
  • Keep the top beam_size active hypotheses
  • Move finished hypotheses (step ≥ length) to the completed list

4. Termination

Stop when:
  • There are max_candidates = beam_size × patience finished hypotheses, OR
  • No active hypotheses remain

5. Select best

Return the hypothesis with the highest normalized score:
best = max(finished, key=lambda h: h.score / len(h.hypothesis)**length_penalty)
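Length normalization in the final step rewards longer hypotheses when length_penalty > 0. A toy illustration with invented scores:

```python
def normalized_score(score, n_tokens, length_penalty=1.0):
    # Same shape as the selection rule above
    return score / max(1, n_tokens) ** length_penalty

# (log-prob score, token count) — invented numbers for illustration
hypotheses = [(-4.0, 2), (-5.4, 3), (-9.0, 6)]

# With length_penalty=1.0 the longest hypothesis wins despite its lower raw score
best = max(hypotheses, key=lambda h: normalized_score(*h))
# With length_penalty=0.0 the raw score decides instead
best_raw = max(hypotheses, key=lambda h: normalized_score(*h, length_penalty=0.0))
```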

TDT Beam Search Implementation

# From parakeet.py:313-524 (simplified)
def decode_beam(self, features, lengths, last_token, hidden_state, config):
    beam_size = config.decoding.beam_size
    beam_token = min(beam_size, len(vocabulary) + 1)
    beam_duration = min(beam_size, len(durations))
    max_candidates = round(beam_size * config.decoding.patience)
    
    finished_hypothesis = []
    active_beam = [initial_hypothesis]
    
    while len(finished_hypothesis) < max_candidates and active_beam:
        candidates = {}
        
        for hypothesis in active_beam:
            # Run decoder
            decoder_out, (hidden, cell) = self.decoder(
                hypothesis.last_token, hypothesis.hidden_state
            )
            
            # Joint network
            joint_out = self.joint(
                features[:, hypothesis.step:hypothesis.step+1],
                decoder_out
            )
            
            # Split into token and duration logits
            token_logits = joint_out[0, 0, 0, :vocab_size+1]
            duration_logits = joint_out[0, 0, 0, vocab_size+1:]
            
            token_logprobs = log_softmax(token_logits)
            duration_logprobs = log_softmax(duration_logits)
            
            # Get top-K tokens and durations
            token_k = argpartition(token_logprobs, -beam_token)[-beam_token:]
            duration_k = argpartition(duration_logprobs, -beam_duration)[-beam_duration:]
            
            # Expand to all combinations
            for token in token_k:
                is_blank = (token == vocab_size)
                
                for decision in duration_k:
                    duration = self.durations[decision]
                    
                    # Compute new score
                    new_score = (
                        hypothesis.score +
                        token_logprobs[token] * (1 - duration_reward) +
                        duration_logprobs[decision] * duration_reward
                    )
                    
                    # Handle stuck prevention
                    stuck = 0 if duration != 0 else hypothesis.stuck + 1
                    if max_symbols and stuck >= max_symbols:
                        step = hypothesis.step + 1
                        stuck = 0
                    else:
                        step = hypothesis.step + duration
                    
                    # Create new hypothesis
                    new_hypothesis = Hypothesis(
                        score=new_score,
                        step=step,
                        last_token=hypothesis.last_token if is_blank else token,
                        hidden_state=hypothesis.hidden_state if is_blank else (hidden, cell),
                        stuck=stuck,
                        hypothesis=(
                            hypothesis.hypothesis if is_blank else
                            hypothesis.hypothesis + [AlignedToken(...)]
                        )
                    )
                    
                    # Merge if same path already exists
                    key = hash(new_hypothesis)
                    if key in candidates:
                        # Log-sum-exp for probability combination
                        old = candidates[key]
                        maxima = max(old.score, new_hypothesis.score)
                        combined_score = maxima + log(
                            exp(old.score - maxima) + exp(new_hypothesis.score - maxima)
                        )
                        if new_hypothesis.score > old.score:
                            candidates[key] = new_hypothesis
                        candidates[key].score = combined_score
                    else:
                        candidates[key] = new_hypothesis
        
        # Partition candidates
        finished_hypothesis.extend([h for h in candidates.values() if h.step >= length])
        active_beam = sorted(
            [h for h in candidates.values() if h.step < length],
            key=lambda h: h.score,
            reverse=True
        )[:beam_size]
    
    # Select best with length normalization
    all_hypotheses = finished_hypothesis + active_beam
    best = max(
        all_hypotheses,
        key=lambda h: h.score / max(1, len(h.hypothesis))**length_penalty
    )
    
    return best.hypothesis, best.hidden_state
When two search paths produce the same token sequence (but via different internal paths), they are merged using log-sum-exp:
# From parakeet.py:469-485
key = hash(new_hypothesis)  # Based on (step, token_sequence)

if key in candidates:
    other = candidates[key]
    
    # Combine probabilities: P(path1 or path2) = P(path1) + P(path2)
    # In log space: log(exp(a) + exp(b)) = max + log(exp(a-max) + exp(b-max))
    maxima = max(other.score, new_hypothesis.score)
    combined_score = maxima + log(
        exp(other.score - maxima) + exp(new_hypothesis.score - maxima)
    )
    
    # Keep hypothesis with higher score for state tracking
    if new_hypothesis.score > other.score:
        candidates[key] = new_hypothesis
    candidates[key].score = combined_score
This prevents redundant computation and correctly combines probability mass.

Beam Search Parameters

beam_size controls the search width:
# Recommended values
beam_size = 3   # Fast, slight improvement
beam_size = 5   # Default, good balance
beam_size = 10  # Slower, diminishing returns
Effect on accuracy (approximate WER improvements):
  • beam=1 (greedy): Baseline
  • beam=3: ~5-10% relative improvement
  • beam=5: ~10-15% relative improvement
  • beam=10: ~12-18% relative improvement
Computational cost: Roughly linear in beam_size

Model Support

| Model   | Greedy | Beam Search          |
|---------|--------|----------------------|
| TDT     | ✅ Yes | ✅ Yes               |
| RNNT    | ✅ Yes | ❌ Not implemented   |
| CTC     | ✅ Yes | ❌ N/A (use greedy)  |
| TDT-CTC | ✅ Yes | ✅ Yes (TDT decoder) |
Beam search is currently only available for TDT-based models. RNNT and CTC only support greedy decoding.

Performance Comparison

Speed Benchmarks

Typical transcription speed on Apple M1 Max (measured in real-time factor, lower is better):
| Strategy       | RTF   | Notes                       |
|----------------|-------|-----------------------------|
| Greedy         | 0.08x | 12.5× faster than real-time |
| Beam (size=3)  | 0.15x | 6.7× faster than real-time  |
| Beam (size=5)  | 0.22x | 4.5× faster than real-time  |
| Beam (size=10) | 0.40x | 2.5× faster than real-time  |
RTF (Real-Time Factor) = processing_time / audio_duration. RTF < 1.0 means faster than real-time.
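For example, the greedy row corresponds to roughly 4.8 s of compute per 60 s of audio:

```python
def rtf(processing_time_s, audio_duration_s):
    """Real-Time Factor: values below 1.0 mean faster than real-time."""
    return processing_time_s / audio_duration_s

factor = rtf(4.8, 60.0)   # ≈ 0.08, matching the greedy row
speedup = 1.0 / factor    # ≈ 12.5× faster than real-time
```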

Accuracy Impact

Word Error Rate (WER) improvements from beam search on Parakeet-TDT:
| Dataset                | Greedy WER | Beam-5 WER | Relative Improvement |
|------------------------|------------|------------|----------------------|
| LibriSpeech test-clean | 3.2%       | 2.8%       | 12.5%                |
| LibriSpeech test-other | 7.5%       | 6.8%       | 9.3%                 |
Results will vary by model, audio quality, and beam parameters. Always validate on your specific use case.
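The relative-improvement column is simply the WER delta divided by the greedy baseline:

```python
def relative_improvement_pct(greedy_wer, beam_wer):
    """Relative WER improvement of beam search over greedy, in percent."""
    return (greedy_wer - beam_wer) / greedy_wer * 100

clean = relative_improvement_pct(3.2, 2.8)  # test-clean row
other = relative_improvement_pct(7.5, 6.8)  # test-other row
```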

Best Practices

Use greedy if:
  • Speed is critical (real-time processing)
  • Audio is clear with minimal background noise
  • You’re processing large batches of audio
  • Accuracy differences are acceptable
  • Using CTC or RNNT models (beam not available)
Example: Live transcription, podcast processing
For beam search, start with the defaults and adjust based on results:
  1. If outputs are too short:
    • Increase length_penalty (try 0.5 or 1.0)
    • Increase patience (try 2.0 or 3.0)
  2. If outputs are too long or repetitive:
    • Decrease length_penalty (try 0.0)
    • Decrease beam_size (try 3)
  3. If still getting stuck:
    • Decrease max_symbols (try 5)
    • Adjust duration_reward (try 0.5)
  4. To improve accuracy further:
    • Increase beam_size (try 10)
    • Increase patience (try 5.0)
    • Be prepared for slower inference
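The tuning steps above can be bundled into starting points. This sketch uses a local stand-in for the beam config, and the preset names and groupings are assumptions, not part of parakeet_mlx:

```python
from dataclasses import dataclass

@dataclass
class BeamParams:  # local stand-in, not the library's Beam class
    beam_size: int
    length_penalty: float
    patience: float
    duration_reward: float

# Hypothetical starting points matching the tuning steps above
PRESETS = {
    "default":         BeamParams(5, 1.0, 1.0, 0.7),
    "longer_outputs":  BeamParams(5, 1.0, 2.0, 0.7),   # step 1: raise patience
    "less_repetition": BeamParams(3, 0.0, 1.0, 0.7),   # step 2: smaller beam, lower penalty
    "max_accuracy":    BeamParams(10, 1.0, 5.0, 0.7),  # step 4: slower, more accurate
}
```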

Advanced: Confidence Scoring

All decoding strategies compute per-token confidence scores using entropy:
# From parakeet.py:579-584
token_probs = mx.softmax(token_logits, axis=-1)
vocab_size = len(self.vocabulary) + 1

entropy = -mx.sum(token_probs * mx.log(token_probs + 1e-10), axis=-1)
max_entropy = mx.log(mx.array(vocab_size, dtype=token_probs.dtype))
confidence = float(1.0 - (entropy / max_entropy))
  • confidence = 1.0: Model is certain (low entropy)
  • confidence = 0.5: Model is uncertain (medium entropy)
  • confidence = 0.0: Uniform distribution (maximum entropy)
Use confidence scores to:
  • Filter low-confidence predictions
  • Highlight uncertain regions for review
  • Compute aggregate confidence per sentence
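A pure-Python version of the same formula shows the boundary behavior (uniform distribution → confidence near 0, peaked distribution → confidence near 1):

```python
import math

def entropy_confidence(probs):
    """1 - H(p)/H_max, mirroring the entropy-based snippet above."""
    entropy = -sum(p * math.log(p + 1e-10) for p in probs)
    max_entropy = math.log(len(probs))
    return 1.0 - entropy / max_entropy

uniform = [0.25] * 4               # maximum entropy
peaked = [0.97, 0.01, 0.01, 0.01]  # model nearly certain

low = entropy_confidence(uniform)  # ≈ 0.0
high = entropy_confidence(peaked)  # ≈ 0.88
```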

Next Steps

Model Architectures

Learn about different model variants

Timestamp System

Understand alignment and timing