

Overview

Decoding converts the model’s continuous probability distributions into discrete text sequences. Parakeet MLX implements two decoding strategies with different speed-accuracy tradeoffs.

  • Greedy Decoding: fast, deterministic single-path search
  • Beam Search: slower, explores multiple hypotheses for better accuracy

Decoding Configuration

Basic Usage

from parakeet_mlx import from_pretrained, DecodingConfig, Greedy, Beam

model = from_pretrained("mlx-community/parakeet-tdt-0.6b-v3")

# Greedy decoding (default)
result = model.transcribe(
    "audio.wav",
    decoding_config=DecodingConfig(decoding=Greedy())
)

# Beam search decoding
result = model.transcribe(
    "audio.wav",
    decoding_config=DecodingConfig(
        decoding=Beam(
            beam_size=5,
            length_penalty=1.0,
            patience=1.0,
            duration_reward=0.7  # TDT only
        )
    )
)

Configuration Classes

from dataclasses import dataclass

# From parakeet.py:77-78
@dataclass
class Greedy:
    pass  # No parameters needed
Greedy decoding is parameterless: it always selects the highest-probability token at each step.
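The Beam parameters from the Basic Usage example can be pictured as a similar dataclass. The sketch below is hypothetical. The field names come from that example, but the defaults (copied from the example values), the validation checks, and the `max_candidates` helper are assumptions, not the library's actual definition. `Greedy` is repeated so the block runs on its own.

```python
from dataclasses import dataclass, field
from typing import Union


@dataclass
class Greedy:
    pass  # no parameters: always take the argmax


@dataclass
class Beam:
    # Hypothetical defaults taken from the Basic Usage example values.
    beam_size: int = 5
    length_penalty: float = 1.0
    patience: float = 1.0
    duration_reward: float = 0.7  # TDT only: weight of duration vs token log-probs

    def __post_init__(self):
        # Assumed sanity checks (not confirmed in the library).
        if self.beam_size < 1:
            raise ValueError("beam_size must be >= 1")
        if self.patience <= 0:
            raise ValueError("patience must be > 0")
        if not 0.0 <= self.duration_reward <= 1.0:
            raise ValueError("duration_reward must be in [0, 1]")

    @property
    def max_candidates(self) -> int:
        # Finished hypotheses to collect before beam search stops.
        return round(self.beam_size * self.patience)


@dataclass
class DecodingConfig:
    # Hypothetical stand-in for parakeet_mlx.DecodingConfig.
    decoding: Union[Greedy, Beam] = field(default_factory=Greedy)


cfg = DecodingConfig(decoding=Beam(beam_size=4, patience=1.5))
print(cfg.decoding.max_candidates)
```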

Greedy Decoding

Algorithm

Greedy decoding follows the highest-probability path at each step:

1. Initialize

  • step = 0 (current time frame)
  • hypothesis = [] (output tokens)
  • hidden_state = None (decoder LSTM state)
  • last_token = None (previous emission)

2. Loop until the end of the audio

While step < length:
  1. Prediction network: generate the decoder output
    decoder_out, hidden = self.decoder(last_token, hidden_state)

  2. Joint network: combine with encoder features
    joint_out = self.joint(features[:, step:step+1], decoder_out)

  3. Sampling: select the highest-probability token
    token = argmax(joint_out[..., :vocab_size+1])

3. Update state

  • If token != blank: emit the token and update hidden_state and last_token
  • Advance time by the predicted duration (TDT); for RNNT, advance 1 frame only when the token is blank
  • Apply the stuck-prevention rules
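In TDT, the sampling step reads one joint output as two concatenated score vectors. The first vocab_size+1 entries score tokens, with blank last, and the remaining entries score durations. The NumPy sketch below shows one greedy step. It assumes a 1-D joint vector, and `greedy_tdt_step` is a hypothetical helper name.

```python
import numpy as np


def greedy_tdt_step(joint_vec, vocab_size, durations):
    """Pick (token, duration, is_blank) from one TDT joint output vector.

    vocab_size counts real tokens only; index vocab_size is blank.
    """
    n_tokens = vocab_size + 1                  # real tokens + blank
    token_logits = joint_vec[:n_tokens]
    duration_logits = joint_vec[n_tokens:]
    token = int(np.argmax(token_logits))
    duration = durations[int(np.argmax(duration_logits))]
    return token, duration, token == vocab_size


# Toy example: 3 real tokens + blank, followed by 5 duration bins.
joint = np.array([0.1, 2.0, 0.3, 0.5,         # token scores (index 3 = blank)
                  0.0, 0.2, 1.5, 0.1, 0.0])   # duration scores
print(greedy_tdt_step(joint, vocab_size=3, durations=[0, 1, 2, 3, 4]))
# (1, 2, False)
```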

TDT Greedy Implementation

# From parakeet.py:526-618 (simplified)
def decode_greedy(self, features, lengths, last_token, hidden_state, config):
    B = features.shape[0]
    vocab_size = len(self.vocabulary) + 1  # real tokens + blank
    blank = len(self.vocabulary)           # blank is the last token id
    results = []

    for batch in range(B):
        hypothesis = []
        step = 0
        length = int(lengths[batch])
        new_symbols = 0  # For stuck prevention

        while step < length:
            # Run decoder. None means no token has been emitted yet;
            # compare with None so that token id 0 is not treated as missing.
            decoder_out, (hidden, cell) = self.decoder(
                mx.array([[last_token[batch]]]) if last_token[batch] is not None else None,
                hidden_state[batch]
            )

            # Joint network on this utterance's current frame
            joint_out = self.joint(features[batch:batch+1, step:step+1], decoder_out)

            # Sample token and duration
            token_logits = joint_out[0, 0, :, :vocab_size]
            pred_token = int(mx.argmax(token_logits))

            duration_logits = joint_out[0, 0, :, vocab_size:]
            decision = int(mx.argmax(duration_logits))
            duration = self.durations[decision]  # e.g., [0,1,2,3,4]

            # Compute confidence using entropy
            token_probs = mx.softmax(token_logits, axis=-1)
            entropy = -mx.sum(token_probs * mx.log(token_probs + 1e-10))
            max_entropy = mx.log(mx.array(vocab_size, dtype=token_probs.dtype))
            confidence = 1.0 - (entropy / max_entropy)

            # Emit non-blank tokens
            if pred_token != blank:
                hypothesis.append(AlignedToken(
                    id=pred_token,
                    start=step * self.time_ratio,
                    duration=duration * self.time_ratio,
                    confidence=float(confidence),
                    text=decode([pred_token], self.vocabulary)
                ))
                last_token[batch] = pred_token
                hidden_state[batch] = (hidden, cell)

            # Advance time
            step += duration

            # Stuck prevention: force advance if too many duration=0
            new_symbols += 1
            if duration != 0:
                new_symbols = 0
            elif self.max_symbols and new_symbols >= self.max_symbols:
                step += 1
                new_symbols = 0

        results.append(hypothesis)

    return results, hidden_state
Problem: Duration prediction can output 0 repeatedly, causing infinite loops.
Solution: Track consecutive zero-duration emissions with a new_symbols counter:
# From parakeet.py:606-614
new_symbols += 1

if self.durations[decision] != 0:
    new_symbols = 0  # Reset on any forward movement
else:
    if self.max_symbols and new_symbols >= self.max_symbols:
        step += 1  # Force advance
        new_symbols = 0
Default max_symbols=10 prevents getting stuck emitting tokens at the same frame.
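The toy loop below shows the counter in action. It applies the same rules as the snippet above, but a scripted predictor replaces the real decoder and joint network. The predictor and `tdt_greedy_toy` are hypothetical. The predictor always asks to stay on the current frame with duration 0, so only the max_symbols rule moves time forward.

```python
def tdt_greedy_toy(predict, length, max_symbols=10):
    """Greedy TDT loop over a scripted predictor.

    predict(step) -> (token, duration); token None means blank.
    Returns the emitted (token, frame) pairs.
    """
    hypothesis = []
    step = 0
    new_symbols = 0
    while step < length:
        token, duration = predict(step)
        if token is not None:
            hypothesis.append((token, step))
        step += duration

        # Stuck prevention, mirroring the snippet above
        new_symbols += 1
        if duration != 0:
            new_symbols = 0
        elif max_symbols and new_symbols >= max_symbols:
            step += 1  # force advance
            new_symbols = 0
    return hypothesis


# A pathological predictor that always emits "a" with duration 0.
stuck = lambda step: ("a", 0)
out = tdt_greedy_toy(stuck, length=3, max_symbols=2)
print(out)
# [('a', 0), ('a', 0), ('a', 1), ('a', 1), ('a', 2), ('a', 2)]
```

With max_symbols set to 0 or None, the same predictor would loop forever, which is the failure the counter exists to prevent.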

RNNT Greedy Implementation

# From parakeet.py:655-743 (simplified)
def decode(self, features, lengths, last_token, hidden_state, config):
    # Similar to TDT but simpler:
    
    while step < length:
        decoder_out, new_hidden = self.decoder(last_token, hidden_state)
        joint_out = self.joint(features[:, step:step+1], decoder_out)

        pred_token = argmax(joint_out[0, 0])  # No duration dimension

        if pred_token != blank:
            hypothesis.append(AlignedToken(
                id=pred_token,
                start=step * time_ratio,
                duration=1 * time_ratio,  # Fixed duration
                confidence=confidence  # entropy-based, as in TDT
            ))
            last_token = pred_token
            hidden_state = new_hidden  # decoder state advances only on non-blank
            # Don't advance step: can emit multiple tokens at this frame
            
            new_symbols += 1
            if max_symbols and new_symbols >= max_symbols:
                step += 1  # Force advance
                new_symbols = 0
        else:
            step += 1  # Blank advances time
            new_symbols = 0
RNNT differs from TDT: blank tokens advance time, while non-blank tokens stay at the same frame (allowing multiple emissions before moving forward).
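The toy version of the RNNT rule below runs without a model. A hypothetical scripted predictor stands in for the decoder and joint network. Each call returns the argmax token for a given frame and a given number of tokens already emitted at that frame. Returning BLANK moves decoding to the next frame.

```python
BLANK = "<b>"


def rnnt_greedy_toy(predict, length, max_symbols=10):
    """Greedy RNNT loop: non-blank stays on the frame, blank advances.

    predict(step, emitted_here) -> token (BLANK advances time).
    """
    hypothesis = []
    step = 0
    new_symbols = 0
    while step < length:
        token = predict(step, new_symbols)
        if token != BLANK:
            hypothesis.append((token, step))  # fixed 1-frame duration
            new_symbols += 1
            if max_symbols and new_symbols >= max_symbols:
                step += 1  # force advance
                new_symbols = 0
        else:
            step += 1      # blank advances time
            new_symbols = 0
    return hypothesis


# Frame 0 emits "h", "i", then blank; frame 1 is silent; frame 2 emits "!".
script = {(0, 0): "h", (0, 1): "i", (2, 0): "!"}
predict = lambda step, k: script.get((step, k), BLANK)
print(rnnt_greedy_toy(predict, length=3))
# [('h', 0), ('i', 0), ('!', 2)]
```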

CTC Greedy Implementation

# From parakeet.py:774-889 (simplified)
def decode(self, features, lengths, config):
    logits = self.decoder(features)  # [B, S, vocab_size+1], log-probabilities

    for batch in range(B):
        length = lengths[batch]
        predictions = argmax(logits[batch, :length], axis=1)
        probs = exp(logits[batch, :length])  # log-probs -> probabilities

        hypothesis = []
        prev_token = -1     # last emitted token (-1 = none yet)
        prev_frame = blank  # argmax at the previous frame
        token_boundaries = []

        # Collapse repetitions and remove blanks
        for t, token_id in enumerate(predictions):
            is_repeat = (token_id == prev_frame)
            prev_frame = token_id
            if token_id == blank:
                continue
            if is_repeat:  # Same token on consecutive frames: one emission.
                continue   # A blank in between resets prev_frame, so "a _ a" keeps both.

            # New token boundary detected
            if prev_token != -1:
                # Finalize previous token
                start_frame = token_boundaries[-1][0]
                end_frame = t

                # Compute average confidence over token frames
                token_probs = probs[start_frame:end_frame]
                entropies = -sum(token_probs * log(token_probs + 1e-10), axis=-1)
                avg_entropy = mean(entropies)
                confidence = 1.0 - (avg_entropy / log(vocab_size))

                hypothesis.append(AlignedToken(
                    id=prev_token,
                    start=start_frame * time_ratio,
                    duration=(end_frame - start_frame) * time_ratio,
                    confidence=confidence
                ))

            token_boundaries.append((t, None))
            prev_token = token_id

        # Handle last token
        # ...
CTC decoding is non-autoregressive: all predictions are made independently, then post-processed to collapse repetitions.
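The collapse rule can be checked on a toy input. The NumPy sketch below assumes per-frame log-probabilities with blank as the last index. It merges repeated frames, drops blanks, and keeps a repeated token when a blank separates the two runs. It leaves out the confidence computation. As a simplification of the boundary logic above, each token's span covers only its own run of frames. `ctc_greedy_collapse` is a hypothetical helper name.

```python
import numpy as np


def ctc_greedy_collapse(log_probs, blank):
    """Greedy CTC: argmax per frame, merge repeats, drop blanks.

    log_probs: [frames, vocab+1] array. Returns (token, start, n_frames) tuples.
    """
    predictions = np.argmax(log_probs, axis=1)
    tokens = []
    prev = blank
    for t, tok in enumerate(predictions.tolist()):
        if tok == blank:
            prev = blank            # a blank separates runs of the same token
            continue
        if tok == prev:
            token, start, n = tokens[-1]
            tokens[-1] = (token, start, n + 1)  # extend the current run
        else:
            tokens.append((tok, t, 1))          # start a new token
        prev = tok
    return tokens


# Frame-level argmax path: a a _ a b b _  (a=0, b=1, blank=2)
path = [0, 0, 2, 0, 1, 1, 2]
log_probs = np.log(np.full((len(path), 3), 0.05))
log_probs[np.arange(len(path)), path] = np.log(0.9)  # each row sums to 1
print(ctc_greedy_collapse(log_probs, blank=2))
# [(0, 0, 2), (0, 3, 1), (1, 4, 2)]
```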

Performance Characteristics

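| Metric | TDT | RNNT | CTC |
|---|---|---|---|
| Speed (processing time per second of audio) | ~100 ms | ~100 ms | ~50 ms |
| Memory | O(1) per step | O(1) per step | O(1) total |
| Deterministic | ✅ Yes | ✅ Yes | ✅ Yes |
| Parallelizable | ❌ Sequential | ❌ Sequential | ✅ Fully parallel |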

Beam Search Decoding

Algorithm

Beam search maintains multiple hypotheses and explores the most promising paths:
1. Initialize

Create the initial hypothesis:
Hypothesis(
    score=0.0,
    step=0,
    last_token=None,
    hidden_state=None,
    stuck=0,
    hypothesis=[]
)

2. Expand hypotheses

For each active hypothesis:
  1. Run the decoder and joint network
  2. Get the top-K tokens (by token log-probability)
  3. Get the top-K durations (by duration log-probability)
  4. Create up to K×K candidate hypotheses

3. Score and prune

  • Compute combined scores
  • Merge hypotheses that reach the same step with the same token sequence
  • Keep the top beam_size active hypotheses
  • Move finished hypotheses (step ≥ length) to the completed list

4. Termination

Stop when:
  • max_candidates = round(beam_size × patience) finished hypotheses have been collected, OR
  • No active hypotheses remain

5. Select best

Return the hypothesis with the highest length-normalized score, considering finished hypotheses plus any still active:
best = max(finished + active, key=lambda h: h.score / max(1, len(h.hypothesis)) ** length_penalty)
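Length normalization matters because each added token's log-probability is at most zero, so a hypothesis's score can only fall as it grows and raw scores favor short outputs. The small sketch below uses made-up toy scores to show length_penalty flipping the choice. `Hyp` is a hypothetical stand-in for the library's Hypothesis that keeps only the fields the selection rule reads.

```python
from collections import namedtuple

# Hypothetical stand-in: only the fields the selection rule reads.
Hyp = namedtuple("Hyp", ["score", "hypothesis"])


def select_best(hypotheses, length_penalty):
    # Divide the summed log-prob by len**length_penalty; max(1, ...) avoids /0.
    return max(
        hypotheses,
        key=lambda h: h.score / max(1, len(h.hypothesis)) ** length_penalty,
    )


short = Hyp(score=-3.0, hypothesis=["a", "b"])            # -3.0 total
long_ = Hyp(score=-4.0, hypothesis=["a", "b", "c", "d"])  # -4.0 total

print(select_best([short, long_], length_penalty=0.0) is short)  # True
print(select_best([short, long_], length_penalty=1.0) is long_)  # True
```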

TDT Beam Search Implementation

# From parakeet.py:313-524 (simplified)
def decode_beam(self, features, lengths, last_token, hidden_state, config):
    beam_size = config.decoding.beam_size
    beam_token = min(beam_size, len(vocabulary) + 1)
    beam_duration = min(beam_size, len(durations))
    max_candidates = round(beam_size * config.decoding.patience)
    
    finished_hypothesis = []
    active_beam = [initial_hypothesis]
    
    while len(finished_hypothesis) < max_candidates and active_beam:
        candidates = {}
        
        for hypothesis in active_beam:
            # Run decoder
            decoder_out, (hidden, cell) = self.decoder(
                hypothesis.last_token, hypothesis.hidden_state
            )
            
            # Joint network
            joint_out = self.joint(
                features[:, hypothesis.step:hypothesis.step+1],
                decoder_out
            )
            
            # Split into token and duration logits
            token_logits = joint_out[0, 0, 0, :vocab_size+1]
            duration_logits = joint_out[0, 0, 0, vocab_size+1:]
            
            token_logprobs = log_softmax(token_logits)
            duration_logprobs = log_softmax(duration_logits)
            
            # Get top-K tokens and durations
            token_k = argpartition(token_logprobs, -beam_token)[-beam_token:]
            duration_k = argpartition(duration_logprobs, -beam_duration)[-beam_duration:]
            
            # Expand to all combinations
            for token in token_k:
                is_blank = (token == vocab_size)
                
                for decision in duration_k:
                    duration = self.durations[decision]
                    
                    # Compute new score
                    new_score = (
                        hypothesis.score +
                        token_logprobs[token] * (1 - duration_reward) +
                        duration_logprobs[decision] * duration_reward
                    )
                    
                    # Handle stuck prevention
                    stuck = 0 if duration != 0 else hypothesis.stuck + 1
                    if max_symbols and stuck >= max_symbols:
                        step = hypothesis.step + 1
                        stuck = 0
                    else:
                        step = hypothesis.step + duration
                    
                    # Create new hypothesis
                    new_hypothesis = Hypothesis(
                        score=new_score,
                        step=step,
                        last_token=hypothesis.last_token if is_blank else token,
                        hidden_state=hypothesis.hidden_state if is_blank else (hidden, cell),
                        stuck=stuck,
                        hypothesis=(
                            hypothesis.hypothesis if is_blank else
                            hypothesis.hypothesis + [AlignedToken(...)]
                        )
                    )
                    
                    # Merge if same path already exists
                    key = hash(new_hypothesis)
                    if key in candidates:
                        # Log-sum-exp for probability combination
                        old = candidates[key]
                        maxima = max(old.score, new_hypothesis.score)
                        combined_score = maxima + log(
                            exp(old.score - maxima) + exp(new_hypothesis.score - maxima)
                        )
                        if new_hypothesis.score > old.score:
                            candidates[key] = new_hypothesis
                        candidates[key].score = combined_score
                    else:
                        candidates[key] = new_hypothesis
        
        # Partition candidates
        finished_hypothesis.extend([h for h in candidates.values() if h.step >= length])
        active_beam = sorted(
            [h for h in candidates.values() if h.step < length],
            key=lambda h: h.score,
            reverse=True
        )[:beam_size]
    
    # Select best with length normalization
    all_hypotheses = finished_hypothesis + active_beam
    best = max(
        all_hypotheses,
        key=lambda h: h.score / max(1, len(h.hypothesis))**length_penalty
    )
    
    return best.hypothesis, best.hidden_state
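Each candidate's score adds a weighted mix of token and duration log-probabilities. duration_reward weights the duration term, and 1 - duration_reward weights the token term. The NumPy sketch below shows the top-K expansion for a single hypothesis using toy logits. It covers only scoring, not state handling or merging, and `expand_scores` is a hypothetical helper name.

```python
import numpy as np


def log_softmax(x):
    x = x - np.max(x)
    return x - np.log(np.sum(np.exp(x)))


def expand_scores(prev_score, token_logits, duration_logits, k, duration_reward):
    """Return {(token, duration_index): score} over the top-k of each."""
    token_lp = log_softmax(token_logits)
    duration_lp = log_softmax(duration_logits)
    k_tok = min(k, len(token_lp))
    k_dur = min(k, len(duration_lp))
    top_tokens = np.argpartition(token_lp, -k_tok)[-k_tok:]
    top_durations = np.argpartition(duration_lp, -k_dur)[-k_dur:]
    return {
        (int(t), int(d)): float(
            prev_score
            + token_lp[t] * (1 - duration_reward)
            + duration_lp[d] * duration_reward
        )
        for t in top_tokens
        for d in top_durations
    }


scores = expand_scores(0.0, np.array([2.0, 0.5, 1.0]), np.array([0.1, 1.0]),
                       k=2, duration_reward=0.7)
best = max(scores, key=scores.get)
print(len(scores), best)  # 4 (0, 1)
```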
When two search paths reach the same step with the same token sequence (via different internal alignments), they are merged using log-sum-exp:
# From parakeet.py:469-485
key = hash(new_hypothesis)  # Based on (step, token_sequence)

if key in candidates:
    other = candidates[key]
    
    # Combine probabilities: P(path1 or path2) = P(path1) + P(path2)
    # In log space: log(exp(a) + exp(b)) = max + log(exp(a-max) + exp(b-max))
    maxima = max(other.score, new_hypothesis.score)
    combined_score = maxima + log(
        exp(other.score - maxima) + exp(new_hypothesis.score - maxima)
    )
    
    # Keep hypothesis with higher score for state tracking
    if new_hypothesis.score > other.score:
        candidates[key] = new_hypothesis
    candidates[key].score = combined_score
This prevents redundant computation and correctly combines probability mass.
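The log-sum-exp identity is easy to check numerically. Below is a standalone sketch of the merge, using a hypothetical helper name, `merge_scores`.

```python
import math


def merge_scores(a, b):
    """log(exp(a) + exp(b)) computed stably by factoring out the max."""
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


# Two paths with probability 0.2 and 0.1 merge into probability 0.3.
merged = merge_scores(math.log(0.2), math.log(0.1))
print(round(math.exp(merged), 6))  # 0.3

# The naive form would underflow to log(0) here; the stable form returns
# -1000 + log(2), about -999.307.
print(merge_scores(-1000.0, -1000.0))
```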

Beam Search Parameters

beam_size controls the search width:
# Recommended values
beam_size = 3   # Fast, slight improvement
beam_size = 5   # Default, good balance
beam_size = 10  # Slower, diminishing returns
Effect on accuracy (approximate WER improvements):
  • beam=1 (greedy): Baseline
  • beam=3: ~5-10% relative improvement
  • beam=5: ~10-15% relative improvement
  • beam=10: ~12-18% relative improvement
Computational cost: Roughly linear in beam_size
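The per-step work can be counted from the quantities in the TDT beam search code above. Each active hypothesis needs one decoder and joint call. It also produces min(beam_size, vocab+1) × min(beam_size, number of durations) candidates. The calculator below assumes the active beam stays full. It uses a hypothetical 1,024-token vocabulary and 5 duration bins.

```python
def beam_step_cost(beam_size, vocab_size, num_durations, patience=1.0):
    """Per-step work of TDT beam search, assuming a full active beam."""
    beam_token = min(beam_size, vocab_size + 1)    # top-K tokens (incl. blank)
    beam_duration = min(beam_size, num_durations)  # top-K durations
    return {
        "network_calls": beam_size,                # one decoder+joint call per hypothesis
        "candidates": beam_size * beam_token * beam_duration,
        "max_candidates": round(beam_size * patience),
    }


# Hypothetical model sizes: 1024 tokens, 5 duration bins.
for size in (1, 3, 5, 10):
    print(size, beam_step_cost(size, vocab_size=1024, num_durations=5))
```

Network calls grow linearly with beam_size, which is presumably what the "roughly linear" estimate refers to. Candidate scoring grows faster, but it is cheap arithmetic compared with a decoder or joint call.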

Model Support

| Model | Greedy | Beam Search |
|---|---|---|
| TDT | ✅ Yes | ✅ Yes |
| RNNT | ✅ Yes | ❌ Not implemented |
| CTC | ✅ Yes | ❌ N/A (use greedy) |
| TDT-CTC | ✅ Yes | ✅ Yes (TDT decoder) |
Beam search is currently only available for TDT-based models. RNNT and CTC only support greedy decoding.

Performance Comparison

Speed Benchmarks

Typical transcription speed on Apple M1 Max (measured in real-time factor, lower is better):
| Strategy | RTF | Notes |
|---|---|---|
| Greedy | 0.08x | 12.5× faster than real-time |
| Beam (size=3) | 0.15x | 6.7× faster than real-time |
| Beam (size=5) | 0.22x | 4.5× faster than real-time |
| Beam (size=10) | 0.40x | 2.5× faster than real-time |
RTF (Real-Time Factor) = processing_time / audio_duration. RTF < 1.0 means faster than real-time.
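As a worked example with hypothetical timings, processing a 60-second clip in 4.8 seconds gives RTF 0.08, the figure listed for greedy decoding:

```python
def real_time_factor(processing_time_s, audio_duration_s):
    """RTF = processing_time / audio_duration (lower is better)."""
    return processing_time_s / audio_duration_s


rtf = real_time_factor(processing_time_s=4.8, audio_duration_s=60.0)
speedup = 1 / rtf  # how many times faster than real time
print(f"RTF {rtf:.2f}, {speedup:.1f}x faster than real time")
```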

Accuracy Impact

Word Error Rate (WER) improvements from beam search on Parakeet-TDT:
| Dataset | Greedy WER | Beam-5 WER | Relative Improvement |
|---|---|---|---|
| LibriSpeech test-clean | 3.2% | 2.8% | 12.5% |
| LibriSpeech test-other | 7.5% | 6.8% | 9.3% |
Results will vary by model, audio quality, and beam parameters. Always validate on your specific use case.
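The Relative Improvement column is the WER reduction divided by the greedy WER. The snippet below reproduces it from the table's own numbers:

```python
def relative_improvement(baseline_wer, new_wer):
    """Relative WER reduction in percent: (baseline - new) / baseline * 100."""
    return (baseline_wer - new_wer) / baseline_wer * 100


for name, greedy, beam in [("test-clean", 3.2, 2.8), ("test-other", 7.5, 6.8)]:
    print(f"{name}: {relative_improvement(greedy, beam):.1f}%")
# test-clean: 12.5%
# test-other: 9.3%
```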

Best Practices

Use greedy if:
  • Speed is critical (real-time processing)
  • Audio is clear with minimal background noise
  • You’re processing large batches of audio
  • Accuracy differences are acceptable
  • Using CTC or RNNT models (beam not available)
Examples: live transcription, podcast processing

When tuning beam search, start with the defaults and adjust based on results:
  1. If outputs are too short:
    • Increase length_penalty (try 0.5 or 1.0)
    • Increase patience (try 2.0 or 3.0)
  2. If outputs are too long or repetitive:
    • Decrease length_penalty (try 0.0)
    • Decrease beam_size (try 3)
  3. If still getting stuck:
    • Decrease max_symbols (try 5)
    • Adjust duration_reward (try 0.5)
  4. To improve accuracy further:
    • Increase beam_size (try 10)
    • Increase patience (try 5.0)
    • Be prepared for slower inference

Advanced: Confidence Scoring

All decoding strategies compute per-token confidence scores using entropy:
# From parakeet.py:579-584
token_probs = mx.softmax(token_logits, axis=-1)
vocab_size = len(self.vocabulary) + 1

entropy = -mx.sum(token_probs * mx.log(token_probs + 1e-10), axis=-1)
max_entropy = mx.log(mx.array(vocab_size, dtype=token_probs.dtype))
confidence = float(1.0 - (entropy / max_entropy))
  • confidence = 1.0: Model is certain (low entropy)
  • confidence = 0.5: Model is uncertain (medium entropy)
  • confidence = 0.0: Uniform distribution (maximum entropy)
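The NumPy port below makes those endpoints concrete. The MLX code above uses `mx.softmax`. This sketch swaps in NumPy so it runs without MLX, and it adds the usual max-subtraction for a stable softmax.

```python
import numpy as np


def entropy_confidence(logits):
    """1 - H(p) / log(V), with p = softmax(logits) over V classes."""
    z = logits - np.max(logits)                 # stabilize softmax
    probs = np.exp(z) / np.sum(np.exp(z))
    entropy = -np.sum(probs * np.log(probs + 1e-10))
    return float(1.0 - entropy / np.log(len(logits)))


print(entropy_confidence(np.zeros(8)))                   # uniform: about 0.0
print(entropy_confidence(np.array([20.0] + [0.0] * 7)))  # peaked: close to 1.0
```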
Use confidence scores to:
  • Filter low-confidence predictions
  • Highlight uncertain regions for review
  • Compute aggregate confidence per sentence
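Those uses can be sketched over a list of aligned tokens. `Tok` is a hypothetical stand-in that carries only the text and confidence fields AlignedToken has in the code above. The 0.5 review threshold is an arbitrary example value.

```python
from dataclasses import dataclass
from statistics import mean


@dataclass
class Tok:
    # Hypothetical stand-in for AlignedToken (text + confidence only).
    text: str
    confidence: float


def summarize(tokens, threshold=0.5):
    """Mean/min sentence confidence plus tokens flagged for review."""
    confs = [t.confidence for t in tokens]
    return {
        "mean": mean(confs),
        "min": min(confs),
        "flagged": [t.text for t in tokens if t.confidence < threshold],
    }


sentence = [Tok("the", 0.95), Tok(" cat", 0.40), Tok(" sat", 0.90)]
summary = summarize(sentence)
print(summary["flagged"])  # [' cat']
```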

Next Steps

  • Model Architectures: learn about different model variants
  • Timestamp System: understand alignment and timing
