Local attention is a memory-efficient alternative to full attention that restricts the attention mechanism to a fixed-size context window. This optimization is particularly valuable for processing long audio files without chunking.
## Why Local Attention?
Full self-attention has quadratic memory complexity O(n²) with respect to sequence length. For long audio files, this can lead to excessive memory usage. Local attention reduces this to O(n·w) where w is the context window size, making it practical to transcribe hours of audio without chunking.
Local attention is especially useful when transcribing long audio files (30+ minutes) without chunking, or when running on devices with limited memory.
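To make the quadratic growth concrete, here is a back-of-the-envelope sketch of the attention-score matrix size. This is illustrative only: the 80 ms-per-encoder-frame figure assumes the ~8x subsampling typical of FastConformer-style encoders (check your model's config), and it counts only the score matrix, not other activations.

```python
def attn_matrix_bytes(n_frames, n_heads=8, bytes_per_el=4, window=None):
    """Approximate size of the attention-score matrix.

    Full attention scores n x n pairs per head; local attention only
    scores n x w pairs, where w = left + right + 1.
    """
    cols = n_frames if window is None else min(n_frames, sum(window) + 1)
    return n_frames * cols * n_heads * bytes_per_el

# 60 minutes of audio at ~80 ms per encoder frame -> 45,000 frames
n = 60 * 60 * 1000 // 80
full = attn_matrix_bytes(n)                      # O(n^2) scores
local = attn_matrix_bytes(n, window=(256, 256))  # O(n*w) scores
print(f"score memory ratio: ~{full / local:.0f}x smaller with local attention")
```

For an hour of audio the banded score matrix is nearly two orders of magnitude smaller, which is where the memory savings come from.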
## How It Works
Instead of attending to all positions in the sequence, local attention restricts each position to attend only to a fixed window of neighboring positions:
- **Left context**: how many frames before the current position can be attended to
- **Right context**: how many frames after the current position can be attended to
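The windowing can be illustrated with a plain NumPy boolean mask. This is a toy sketch of the banded pattern, not the Metal-kernel implementation parakeet-mlx actually uses:

```python
import numpy as np

def local_attention_mask(n, left, right):
    """True where position i may attend to position j,
    i.e. only when i - left <= j <= i + right."""
    i = np.arange(n)[:, None]  # query positions (rows)
    j = np.arange(n)[None, :]  # key positions (columns)
    return (j >= i - left) & (j <= i + right)

mask = local_attention_mask(6, left=1, right=1)
print(mask.astype(int))
# Each row has at most left + right + 1 ones, so scores outside
# the band never need to be computed or stored.
```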
```python
from parakeet_mlx import from_pretrained

model = from_pretrained("mlx-community/parakeet-tdt-0.6b-v3")

# Switch to local attention with 256 frames of context on each side
model.encoder.set_attention_model(
    "rel_pos_local_attn",  # NeMo naming convention
    (256, 256),            # (left_context, right_context)
)

result = model.transcribe("long_audio.wav")
print(result.text)
```
## Context Size Selection
The context size determines the trade-off between memory usage and model accuracy:
| Context Size | Memory Usage | Accuracy | Use Case |
|---|---|---|---|
| `(128, 128)` | Low | Good | Memory-constrained devices |
| `(256, 256)` | Medium | Better | Recommended default |
| `(512, 512)` | High | Best | When memory allows |
Start with `(256, 256)` for most use cases. Increase the window if you notice accuracy degradation on complex audio; decrease it if you run into memory limits.
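For intuition about what these windows cover in time, assume the encoder emits roughly one frame every 80 ms (8x subsampling of 10 ms feature frames, typical for FastConformer-style encoders; verify this against your model's configuration):

```python
FRAME_MS = 80  # assumed encoder frame hop; check your model's subsampling

for left, right in [(128, 128), (256, 256), (512, 512)]:
    # window = left + current frame + right
    span_s = (left + right + 1) * FRAME_MS / 1000
    print(f"({left}, {right}) -> ~{span_s:.1f} s of audio visible per frame")
```

Even the smallest recommended window spans well over ten seconds of audio, far more than most acoustic and linguistic dependencies in speech require.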
## Memory Savings Example
```python
from parakeet_mlx import from_pretrained

model = from_pretrained("mlx-community/parakeet-tdt-0.6b-v3")

# For a 60-minute audio file without chunking:

# Full attention: ~8-12 GB peak memory
result_full = model.transcribe("long_audio.wav")

# Local attention: ~2-4 GB peak memory
model.encoder.set_attention_model("rel_pos_local_attn", (256, 256))
result_local = model.transcribe("long_audio.wav")

# Results are nearly identical in most cases
print(f"Full attention: {result_full.text}")
print(f"Local attention: {result_local.text}")
```
## Technical Implementation
Local attention in Parakeet MLX uses custom Metal kernels for efficient computation on Apple Silicon:
```python
# From attention.py:147-165
class RelPositionMultiHeadLocalAttention(RelPositionMultiHeadAttention):
    def __init__(
        self,
        n_head: int,
        n_feat: int,
        bias: bool = True,
        pos_bias_u: mx.array | None = None,
        pos_bias_v: mx.array | None = None,
        context_size: tuple[int, int] = (256, 256),
    ):
        super().__init__(n_head, n_feat, bias, pos_bias_u, pos_bias_v)
        self.context_size = context_size
        if min(context_size) <= 0:
            raise ValueError(
                "Context size for RelPositionMultiHeadLocalAttention must be > 0."
            )
```
## Combining with Chunking
Local attention and chunking serve different purposes. Local attention optimizes memory during the forward pass, while chunking processes audio in segments. You can use both together for very long files.
```python
from parakeet_mlx import from_pretrained

model = from_pretrained("mlx-community/parakeet-tdt-0.6b-v3")

# Use local attention for each chunk
model.encoder.set_attention_model("rel_pos_local_attn", (256, 256))

# Process in 2-minute chunks with 15-second overlap
result = model.transcribe(
    "very_long_audio.wav",
    chunk_duration=120,
    overlap_duration=15,
)
print(result.text)
```
## Switching Back to Full Attention
```python
# Restore full attention
model.encoder.set_attention_model("rel_pos")

result = model.transcribe("audio.wav")
```
## CLI Usage
Enable local attention from the command line:
```bash
# Use local attention with the default context size (256, 256)
parakeet-mlx audio.wav --local-attention

# Specify a custom context size
parakeet-mlx audio.wav --local-attention --local-attention-context-size 512

# Environment variables
export PARAKEET_LOCAL_ATTENTION=true
export PARAKEET_LOCAL_ATTENTION_CTX=256
parakeet-mlx audio.wav
```
## Time Complexity
- Full attention: O(n² · d) where n is sequence length, d is feature dimension
- Local attention: O(n · w · d) where w is context window size
## Memory Complexity
- Full attention: O(n² · h) where h is number of heads
- Local attention: O(n · w · h)
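The crossover is worth noting: for sequences shorter than the window, local attention degenerates to full attention and saves nothing. A sketch of the scored query-key pairs at different audio durations (again assuming ~80 ms per encoder frame, an assumption to verify against your model):

```python
def attended_pairs(n, w=None):
    """Number of (query, key) pairs scored per head."""
    if w is None or w >= n:   # full attention, or window covers everything
        return n * n
    return n * w              # banded: each query sees at most w keys

w = 256 + 256 + 1  # (256, 256) window plus the current frame
for seconds in (10, 60, 600, 3600):
    n = seconds * 1000 // 80  # ~80 ms per encoder frame (assumption)
    ratio = attended_pairs(n, w) / attended_pairs(n)
    print(f"{seconds:>5} s: local/full score ratio = {ratio:.3f}")
```

At 10 seconds the ratio is 1.0 (no savings); by an hour the banded pattern scores roughly 1% of the pairs, which is why local attention targets long-form audio.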
## Accuracy Impact
- For most speech, local context of 256 frames is sufficient
- Minimal degradation compared to full attention
- May affect very long-range dependencies (rare in speech)
## Best Practices
- **Default Context**: Use `(256, 256)` for most applications
- **Long Audio**: Enable local attention for files longer than 30 minutes
- **Memory Constraints**: Reduce the context size to `(128, 128)` if needed
- **Quality Critical**: Increase to `(512, 512)` for maximum accuracy
- **Benchmarking**: Test on your specific audio to find the optimal settings
Local attention is automatically enabled in streaming mode (`transcribe_stream`) to optimize real-time performance.