## Overview

Decoding converts the model's continuous probability distributions into discrete text sequences. Parakeet MLX implements two decoding strategies with different speed-accuracy tradeoffs:

- **Greedy decoding**: fast, deterministic single-path search
- **Beam search**: slower, explores multiple hypotheses for better accuracy
## Decoding Configuration

### Basic Usage

```python
from parakeet_mlx import from_pretrained, DecodingConfig, Greedy, Beam

model = from_pretrained("mlx-community/parakeet-tdt-0.6b-v3")

# Greedy decoding (default)
result = model.transcribe(
    "audio.wav",
    decoding_config=DecodingConfig(decoding=Greedy()),
)

# Beam search decoding
result = model.transcribe(
    "audio.wav",
    decoding_config=DecodingConfig(
        decoding=Beam(
            beam_size=5,
            length_penalty=1.0,
            patience=1.0,
            duration_reward=0.7,  # TDT only
        )
    ),
)
```
## Configuration Classes

### Greedy

```python
# From parakeet.py:77-78
@dataclass
class Greedy:
    pass  # No parameters needed
```

Greedy decoding is parameterless: it always selects the highest-probability token at each step.

### Beam

```python
# From parakeet.py:82-86
@dataclass
class Beam:
    beam_size: int = 5            # Number of hypotheses
    length_penalty: float = 1.0   # Length normalization
    patience: float = 1.0         # Candidate multiplier
    duration_reward: float = 0.7  # TDT only: duration weight
```
**beam_size**: Number of top hypotheses to maintain during search. Higher values improve accuracy but increase computation.

- Typical range: 3-10
- Memory usage: O(beam_size × sequence_length)
**length_penalty**: Normalizes hypothesis scores by length to prevent bias toward shorter sequences.

Formula: `score / (length ** penalty)`

- `penalty=0.0`: no normalization (favors short sequences)
- `penalty=1.0`: linear normalization
- `penalty>1.0`: favors longer sequences
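To make the effect concrete, here is a small illustrative sketch (the `normalized_score` helper and the scores are ours, not part of the library):

```python
def normalized_score(log_score: float, length: int, penalty: float) -> float:
    # Log-probability scores are negative; divide by length^penalty
    return log_score / (length ** penalty)

# A 5-token hypothesis vs. a 10-token one (illustrative log-scores)
short, long_ = -5.0, -9.0

# penalty=0.0: raw scores, the shorter hypothesis wins
assert normalized_score(short, 5, 0.0) > normalized_score(long_, 10, 0.0)

# penalty=1.0: per-token scores, the longer hypothesis now wins
assert normalized_score(short, 5, 1.0) < normalized_score(long_, 10, 1.0)
```

With linear normalization the longer hypothesis scores -0.9 per token against -1.0, so the length bias flips.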
**patience**: Controls early stopping. The search maintains `beam_size × patience` candidate hypotheses.

- `patience=1.0`: keep exactly beam_size candidates
- `patience=2.0`: keep 2× beam_size candidates (more thorough search)
**duration_reward**: TDT models only. Balances token and duration prediction scores.

Formula: `score = token_logprob × (1 - r) + duration_logprob × r`

- `r=0.0`: only consider token predictions
- `r=0.5`: equal weight
- `r=1.0`: only consider duration predictions
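A minimal sketch of the blend (the `combined_score` helper and the log-probability values are illustrative, not the library's API):

```python
def combined_score(token_logprob: float, duration_logprob: float, r: float) -> float:
    # Blend the two log-probabilities with duration weight r
    return token_logprob * (1 - r) + duration_logprob * r

tok, dur = -0.2, -1.5  # illustrative log-probabilities for one candidate
assert combined_score(tok, dur, 0.0) == -0.2  # tokens only
assert combined_score(tok, dur, 1.0) == -1.5  # durations only
assert abs(combined_score(tok, dur, 0.7) - (-1.11)) < 1e-9  # default blend
```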
### DecodingConfig

```python
# From parakeet.py:90-92
@dataclass
class DecodingConfig:
    decoding: Union[Greedy, Beam] = field(default_factory=Greedy)
    sentence: SentenceConfig = field(default_factory=SentenceConfig)
```

Top-level configuration that combines:

- Decoding strategy (Greedy or Beam)
- Sentence segmentation rules (see Timestamps)
## Greedy Decoding

### Algorithm

Greedy decoding follows the highest-probability path at each step:

1. **Initialize**
   - `step = 0` (current time frame)
   - `hypothesis = []` (output tokens)
   - `hidden_state = None` (decoder LSTM state)
   - `last_token = None` (previous emission)
2. **Loop** while `step < length`:
   - **Prediction network**: generate decoder output
     `decoder_out, hidden = self.decoder(last_token, hidden_state)`
   - **Joint network**: combine with encoder features
     `joint_out = self.joint(features[:, step:step + 1], decoder_out)`
   - **Sampling**: select the highest-probability token
     `token = argmax(joint_out[..., :vocab_size + 1])`
   - **Update state**:
     - If `token != blank`: emit the token, update `hidden_state` and `last_token`
     - Advance time by the predicted duration (TDT) or 1 frame (RNNT)
     - Check stuck-prevention rules
### TDT Greedy Implementation

```python
# From parakeet.py:526-618 (simplified)
def decode_greedy(self, features, lengths, last_token, hidden_state, config):
    results = []
    for batch in range(B):
        hypothesis = []
        step = 0
        new_symbols = 0  # For stuck prevention
        while step < length:
            # Run decoder
            decoder_out, (hidden, cell) = self.decoder(
                mx.array([[last_token[batch]]]) if last_token[batch] is not None else None,
                hidden_state[batch],
            )

            # Joint network
            joint_out = self.joint(features[:, step:step + 1], decoder_out)

            # Sample token and duration
            token_logits = joint_out[0, 0, :, :len(self.vocabulary) + 1]
            pred_token = int(mx.argmax(token_logits))
            duration_logits = joint_out[0, 0, :, len(self.vocabulary) + 1:]
            decision = int(mx.argmax(duration_logits))
            duration = self.durations[decision]  # e.g., [0, 1, 2, 3, 4]

            # Compute confidence using entropy
            token_probs = mx.softmax(token_logits, axis=-1)
            entropy = -mx.sum(token_probs * mx.log(token_probs + 1e-10))
            max_entropy = mx.log(mx.array(vocab_size))
            confidence = 1.0 - (entropy / max_entropy)

            # Emit non-blank tokens
            if pred_token != len(self.vocabulary):  # Not blank
                hypothesis.append(AlignedToken(
                    id=pred_token,
                    start=step * self.time_ratio,
                    duration=duration * self.time_ratio,
                    confidence=float(confidence),
                    text=decode([pred_token], self.vocabulary),
                ))
                last_token[batch] = pred_token
                hidden_state[batch] = (hidden, cell)

            # Advance time
            step += duration

            # Stuck prevention: force advance after too many duration=0 emissions
            new_symbols += 1
            if duration != 0:
                new_symbols = 0
            elif self.max_symbols and new_symbols >= self.max_symbols:
                step += 1
                new_symbols = 0

        results.append(hypothesis)
    return results, hidden_state
```
### Stuck Prevention Mechanism

**Problem**: Duration prediction can output 0 repeatedly, causing infinite loops.

**Solution**: Track consecutive zero-duration emissions with a `new_symbols` counter:

```python
# From parakeet.py:606-614
new_symbols += 1
if self.durations[decision] != 0:
    new_symbols = 0  # Reset on any forward movement
else:
    if self.max_symbols and new_symbols >= self.max_symbols:
        step += 1  # Force advance
        new_symbols = 0
```

The default `max_symbols=10` prevents getting stuck emitting tokens at the same frame.
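The counter logic can be exercised in isolation. The `advance` helper below is a standalone sketch of the bookkeeping, not a library function:

```python
def advance(step: int, new_symbols: int, duration: int, max_symbols: int = 10):
    """One iteration of the stuck-prevention bookkeeping."""
    new_symbols += 1
    if duration != 0:
        new_symbols = 0   # Any forward movement resets the counter
    elif max_symbols and new_symbols >= max_symbols:
        step += 1         # Force advance past the current frame
        new_symbols = 0
    return step + duration, new_symbols

# Ten consecutive zero-duration emissions trigger exactly one forced advance
step, counter = 0, 0
for _ in range(10):
    step, counter = advance(step, counter, duration=0)
assert (step, counter) == (1, 0)
```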
### RNNT Greedy Implementation

```python
# From parakeet.py:655-743 (simplified)
def decode(self, features, lengths, last_token, hidden_state, config):
    # Similar to TDT but simpler:
    while step < length:
        decoder_out, hidden = self.decoder(last_token, hidden_state)
        joint_out = self.joint(features[:, step:step + 1], decoder_out)
        pred_token = argmax(joint_out[0, 0])  # No duration dimension

        if pred_token != blank:
            hypothesis.append(AlignedToken(
                id=pred_token,
                start=step * time_ratio,
                duration=1 * time_ratio,  # Fixed duration
                confidence=confidence,
            ))
            last_token = pred_token
            hidden_state = new_hidden

            # Don't advance step - can emit multiple tokens per frame
            new_symbols += 1
            if max_symbols and new_symbols >= max_symbols:
                step += 1  # Force advance
                new_symbols = 0
        else:
            step += 1  # Blank advances time
            new_symbols = 0
```

RNNT differs from TDT: blank tokens advance time, while non-blank tokens stay at the same frame (allowing multiple emissions before moving forward).
### CTC Greedy Implementation

```python
# From parakeet.py:774-889 (simplified)
def decode(self, features, lengths, config):
    logits = self.decoder(features)  # [B, S, vocab_size+1]
    for batch in range(B):
        predictions = argmax(logits[batch, :length], axis=1)
        probs = exp(logits[batch, :length])

        hypothesis = []
        prev_token = -1
        token_boundaries = []

        # Collapse repetitions and remove blanks
        for t, token_id in enumerate(predictions):
            if token_id == blank:
                continue
            if token_id == prev_token:  # Skip repetition
                continue

            # New token boundary detected
            if prev_token != -1:
                # Finalize previous token
                start_frame = token_boundaries[-1][0]
                end_frame = t

                # Compute average confidence over token frames
                token_probs = probs[start_frame:end_frame]
                entropies = -sum(token_probs * log(token_probs + 1e-10))
                avg_entropy = mean(entropies)
                confidence = 1.0 - (avg_entropy / log(vocab_size))

                hypothesis.append(AlignedToken(
                    id=prev_token,
                    start=start_frame * time_ratio,
                    duration=(end_frame - start_frame) * time_ratio,
                    confidence=confidence,
                ))

            token_boundaries.append((t, None))
            prev_token = token_id

        # Handle last token
        # ...
```

CTC decoding is non-autoregressive: all predictions are made independently, then post-processed to collapse repetitions.
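The collapse rule itself fits in a few lines. This standalone sketch (not the library's API) applies the same collapse-repeats-then-drop-blanks rule to raw token ids:

```python
def ctc_collapse(predictions, blank):
    """Collapse consecutive repeats, then drop blanks (core CTC decoding rule)."""
    out, prev = [], None
    for token in predictions:
        if token != prev and token != blank:
            out.append(token)
        prev = token
    return out

# blank=0; frames "h h _ e _ l l _ l o" collapse to "h e l l o"
assert ctc_collapse([1, 1, 0, 2, 0, 3, 3, 0, 3, 4], blank=0) == [1, 2, 3, 3, 4]
```

Note that a blank between two identical tokens (the `3, 0, 3` run) separates them into two emissions, which is why repeats are collapsed before blanks are removed.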
| Metric | TDT | RNNT | CTC |
|---|---|---|---|
| Speed | ~100 ms/sec | ~100 ms/sec | ~50 ms/sec |
| Memory | O(1) per step | O(1) per step | O(1) total |
| Deterministic | ✅ Yes | ✅ Yes | ✅ Yes |
| Parallelizable | ❌ Sequential | ❌ Sequential | ✅ Fully parallel |
## Beam Search Decoding

### Algorithm

Beam search maintains multiple hypotheses and explores the most promising paths:

1. **Initialize** with a single empty hypothesis:

   ```python
   Hypothesis(
       score=0.0,
       step=0,
       last_token=None,
       hidden_state=None,
       stuck=0,
       hypothesis=[],
   )
   ```

2. **Expand hypotheses.** For each active hypothesis:
   - Run the decoder and joint network
   - Get the top-K tokens (by token logprob)
   - Get the top-K durations (by duration logprob)
   - Create K×K candidate hypotheses
3. **Score and prune**:
   - Compute combined scores
   - Merge hypotheses with the same token sequence
   - Keep the top beam_size active hypotheses
   - Move finished hypotheses (step ≥ length) to the completed list
4. **Terminate** when either:
   - `max_candidates = beam_size × patience` finished hypotheses exist, or
   - No active hypotheses remain
5. **Select the best** hypothesis by highest length-normalized score:

   ```python
   best = max(finished, key=lambda h: h.score / len(h.hypothesis) ** length_penalty)
   ```
### TDT Beam Search Implementation

```python
# From parakeet.py:313-524 (simplified)
def decode_beam(self, features, lengths, last_token, hidden_state, config):
    beam_size = config.decoding.beam_size
    beam_token = min(beam_size, len(vocabulary) + 1)
    beam_duration = min(beam_size, len(durations))
    max_candidates = round(beam_size * config.decoding.patience)

    finished_hypothesis = []
    active_beam = [initial_hypothesis]

    while len(finished_hypothesis) < max_candidates and active_beam:
        candidates = {}
        for hypothesis in active_beam:
            # Run decoder
            decoder_out, (hidden, cell) = self.decoder(
                hypothesis.last_token, hypothesis.hidden_state
            )

            # Joint network
            joint_out = self.joint(
                features[:, hypothesis.step:hypothesis.step + 1],
                decoder_out,
            )

            # Split into token and duration logits
            token_logits = joint_out[0, 0, 0, :vocab_size + 1]
            duration_logits = joint_out[0, 0, 0, vocab_size + 1:]
            token_logprobs = log_softmax(token_logits)
            duration_logprobs = log_softmax(duration_logits)

            # Get top-K tokens and durations
            token_k = argpartition(token_logprobs, -beam_token)[-beam_token:]
            duration_k = argpartition(duration_logprobs, -beam_duration)[-beam_duration:]

            # Expand to all combinations
            for token in token_k:
                is_blank = (token == vocab_size)
                for decision in duration_k:
                    duration = self.durations[decision]

                    # Compute new score
                    new_score = (
                        hypothesis.score
                        + token_logprobs[token] * (1 - duration_reward)
                        + duration_logprobs[decision] * duration_reward
                    )

                    # Handle stuck prevention
                    stuck = 0 if duration != 0 else hypothesis.stuck + 1
                    if max_symbols and stuck >= max_symbols:
                        step = hypothesis.step + 1
                        stuck = 0
                    else:
                        step = hypothesis.step + duration

                    # Create new hypothesis
                    new_hypothesis = Hypothesis(
                        score=new_score,
                        step=step,
                        last_token=hypothesis.last_token if is_blank else token,
                        hidden_state=hypothesis.hidden_state if is_blank else (hidden, cell),
                        stuck=stuck,
                        hypothesis=(
                            hypothesis.hypothesis if is_blank
                            else hypothesis.hypothesis + [AlignedToken(...)]
                        ),
                    )

                    # Merge if the same path already exists
                    key = hash(new_hypothesis)
                    if key in candidates:
                        # Log-sum-exp for probability combination
                        old = candidates[key]
                        maxima = max(old.score, new_hypothesis.score)
                        combined_score = maxima + log(
                            exp(old.score - maxima) + exp(new_hypothesis.score - maxima)
                        )
                        if new_hypothesis.score > old.score:
                            candidates[key] = new_hypothesis
                        candidates[key].score = combined_score
                    else:
                        candidates[key] = new_hypothesis

        # Partition candidates into finished and active
        finished_hypothesis.extend(
            [h for h in candidates.values() if h.step >= length]
        )
        active_beam = sorted(
            [h for h in candidates.values() if h.step < length],
            key=lambda h: h.score,
            reverse=True,
        )[:beam_size]

    # Select best with length normalization
    all_hypotheses = finished_hypothesis + active_beam
    best = max(
        all_hypotheses,
        key=lambda h: h.score / max(1, len(h.hypothesis)) ** length_penalty,
    )
    return best.hypothesis, best.hidden_state
```
When two search paths produce the same token sequence (but via different internal paths), they are merged using log-sum-exp:

```python
# From parakeet.py:469-485
key = hash(new_hypothesis)  # Based on (step, token_sequence)
if key in candidates:
    other = candidates[key]

    # Combine probabilities: P(path1 or path2) = P(path1) + P(path2)
    # In log space: log(exp(a) + exp(b)) = max + log(exp(a-max) + exp(b-max))
    maxima = max(other.score, new_hypothesis.score)
    combined_score = maxima + log(
        exp(other.score - maxima) + exp(new_hypothesis.score - maxima)
    )

    # Keep the hypothesis with the higher score for state tracking
    if new_hypothesis.score > other.score:
        candidates[key] = new_hypothesis
    candidates[key].score = combined_score
```

This prevents redundant computation and correctly combines probability mass.
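A numeric check of the merge rule, as a plain-Python sketch (`logsumexp2` is our name, not the library's):

```python
import math

def logsumexp2(a: float, b: float) -> float:
    # Numerically stable log(exp(a) + exp(b))
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))

# Two paths with probabilities 0.2 and 0.1 merge into probability 0.3
a, b = math.log(0.2), math.log(0.1)
assert abs(math.exp(logsumexp2(a, b)) - 0.3) < 1e-12

# The max-shift keeps the result finite even for extreme log scores,
# where naive log(exp(a) + exp(b)) would underflow exp() to 0
assert math.isfinite(logsumexp2(-1000.0, -1001.0))
```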
### Beam Search Parameters

**beam_size** controls the search width:

```python
# Recommended values
beam_size = 3   # Fast, slight improvement
beam_size = 5   # Default, good balance
beam_size = 10  # Slower, diminishing returns
```

Effect on accuracy (approximate relative WER improvements):

- beam=1 (greedy): baseline
- beam=3: ~5-10% relative improvement
- beam=5: ~10-15% relative improvement
- beam=10: ~12-18% relative improvement

Computational cost grows roughly linearly with beam_size.

**length_penalty** prevents bias toward short sequences:

```python
# CLI default
length_penalty = 0.013  # Very slight normalization

# Recommended for different scenarios
length_penalty = 0.0  # Favor short outputs (commands)
length_penalty = 0.5  # Light normalization
length_penalty = 1.0  # Full length normalization
length_penalty = 1.5  # Favor longer outputs
```

Formula: `normalized_score = score / length ** penalty`. High length penalties (>1.0) can cause excessive verbosity or hallucinations.

**patience** controls search thoroughness:

```python
patience = 1.0  # Strict: only beam_size candidates
patience = 2.0  # Allow 2× beam_size candidates
patience = 3.5  # CLI default: very thorough search
```

Higher patience:

- ✅ More thorough exploration
- ✅ Better final result
- ❌ Slower (more iterations before stopping)
- ❌ More memory for candidate storage

**duration_reward** (TDT only) balances token vs duration predictions:

```python
# CLI default
duration_reward = 0.67  # Slight preference for durations

# Alternative values
duration_reward = 0.0  # Ignore durations (not recommended)
duration_reward = 0.5  # Equal weight
duration_reward = 1.0  # Only durations (not recommended)
```

Score formula:

```python
score = (
    token_logprob * (1 - duration_reward) +
    duration_logprob * duration_reward
)
```

Values around 0.6-0.7 work well empirically, slightly favoring duration predictions, which tend to have lower confidence.
## Model Support

| Model | Greedy | Beam Search |
|---|---|---|
| TDT | ✅ Yes | ✅ Yes |
| RNNT | ✅ Yes | ❌ Not implemented |
| CTC | ✅ Yes | ❌ N/A (use greedy) |
| TDT-CTC | ✅ Yes | ✅ Yes (TDT decoder) |

Beam search is currently only available for TDT-based models; RNNT and CTC models support greedy decoding only.
## Speed Benchmarks

Typical transcription speed on an Apple M1 Max, measured as real-time factor (lower is better):

**TDT 0.6B**

| Strategy | RTF | Notes |
|---|---|---|
| Greedy | 0.08x | 12.5× faster than real-time |
| Beam (size=3) | 0.15x | 6.7× faster than real-time |
| Beam (size=5) | 0.22x | 4.5× faster than real-time |
| Beam (size=10) | 0.40x | 2.5× faster than real-time |

**RNNT 1.1B**

| Strategy | RTF | Notes |
|---|---|---|
| Greedy | 0.12x | 8.3× faster than real-time |

**CTC 1.1B**

| Strategy | RTF | Notes |
|---|---|---|
| Greedy | 0.04x | 25× faster than real-time |

RTF (Real-Time Factor) = processing_time / audio_duration. RTF < 1.0 means faster than real-time.
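A quick sanity check of the RTF arithmetic, using illustrative numbers that match the TDT table above (the helper function is ours):

```python
def real_time_factor(processing_seconds: float, audio_seconds: float) -> float:
    """RTF = processing time / audio duration; values below 1.0 beat real time."""
    return processing_seconds / audio_seconds

# 60 s of audio decoded in 4.8 s -> RTF 0.08, i.e. 12.5x faster than real time
rtf = real_time_factor(4.8, 60.0)
assert abs(rtf - 0.08) < 1e-12
assert abs(1 / rtf - 12.5) < 1e-9
```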
## Accuracy Impact

Word Error Rate (WER) improvements from beam search on Parakeet-TDT:

| Dataset | Greedy WER | Beam-5 WER | Relative Improvement |
|---|---|---|---|
| LibriSpeech test-clean | 3.2% | 2.8% | 12.5% |
| LibriSpeech test-other | 7.5% | 6.8% | 9.3% |

Results vary by model, audio quality, and beam parameters. Always validate on your specific use case.
## Best Practices

### When to use greedy decoding

✅ Use greedy if:

- Speed is critical (real-time processing)
- Audio is clear with minimal background noise
- You're processing large batches of audio
- Accuracy differences are acceptable
- You're using CTC or RNNT models (beam search not available)

**Example**: Live transcription, podcast processing

### When to use beam search

✅ Use beam search if:

- Accuracy is paramount
- Audio is challenging (accents, noise, technical terms)
- You have sufficient compute budget
- You're using TDT models
- You need high-quality timestamps

**Example**: Medical/legal transcription, subtitle generation

### Tuning parameters

Start with the defaults and adjust based on results:

- **If outputs are too short**: increase `length_penalty` (try 0.5 or 1.0) or `patience` (try 2.0 or 3.0)
- **If outputs are too long or repetitive**: decrease `length_penalty` (try 0.0) or `beam_size` (try 3)
- **If decoding still gets stuck**: decrease `max_symbols` (try 5) or adjust `duration_reward` (try 0.5)
- **To improve accuracy further**: increase `beam_size` (try 10) or `patience` (try 5.0), and expect slower inference
## Advanced: Confidence Scoring

All decoding strategies compute per-token confidence scores from the entropy of the token distribution:

```python
# From parakeet.py:579-584
token_probs = mx.softmax(token_logits, axis=-1)
vocab_size = len(self.vocabulary) + 1
entropy = -mx.sum(token_probs * mx.log(token_probs + 1e-10), axis=-1)
max_entropy = mx.log(mx.array(vocab_size, dtype=token_probs.dtype))
confidence = float(1.0 - (entropy / max_entropy))
```

### Interpretation

- `confidence = 1.0`: model is certain (low entropy)
- `confidence = 0.5`: model is uncertain (medium entropy)
- `confidence = 0.0`: uniform distribution (maximum entropy)
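This mapping can be checked with a plain-Python version of the entropy formula (the helper name is ours; the library computes this with MLX ops):

```python
import math

def entropy_confidence(probs: list[float]) -> float:
    """Entropy-based confidence over a token distribution."""
    entropy = -sum(p * math.log(p + 1e-10) for p in probs)
    return 1.0 - entropy / math.log(len(probs))

# Near one-hot distribution -> confidence close to 1
assert entropy_confidence([0.97, 0.01, 0.01, 0.01]) > 0.8

# Uniform distribution -> confidence close to 0
assert entropy_confidence([0.25] * 4) < 0.01
```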
Use confidence scores to:

- Filter low-confidence predictions
- Highlight uncertain regions for review
- Compute aggregate confidence per sentence

### Sentence Confidence

Sentence confidence is computed as the geometric mean of the token confidences:

```python
# From alignment.py:33-35
confidences = np.array([t.confidence for t in self.tokens])
self.confidence = float(np.exp(np.mean(np.log(confidences + 1e-10))))
```

The geometric mean is more sensitive to low-confidence tokens than the arithmetic mean.
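A quick comparison illustrates why (a NumPy-free sketch mirroring the formula above; the values are illustrative):

```python
import math

def geometric_mean(xs):
    # exp(mean(log(x))), with the same 1e-10 epsilon as the library formula
    return math.exp(sum(math.log(x + 1e-10) for x in xs) / len(xs))

confs = [0.95, 0.95, 0.95, 0.05]   # one low-confidence token
arith = sum(confs) / len(confs)    # 0.725: barely dented by the outlier
geo = geometric_mean(confs)        # ~0.46: pulled down hard

assert abs(arith - 0.725) < 1e-9
assert geo < 0.5 < arith
```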
## Next Steps

- **Model Architectures**: learn about the different model variants
- **Timestamp System**: understand alignment and timing