
Overview

The from_pretrained() function is the primary way to load Parakeet models. It downloads a model from the Hugging Face Hub, or reads it from a local directory, and automatically detects the model variant (TDT, RNNT, CTC, or TDT-CTC) from its config.

Function Signature

from parakeet_mlx import from_pretrained

def from_pretrained(
    hf_id_or_path: str,
    *,
    dtype: mx.Dtype = mx.bfloat16,
    cache_dir: str | Path | None = None,
) -> BaseParakeet

Parameters

hf_id_or_path
str
required
Hugging Face repository ID (e.g., "mlx-community/parakeet-tdt-0.6b-v3") or path to a local model directory containing config.json and model.safetensors.
dtype
mx.Dtype
default:"mx.bfloat16"
Data type for model weights. Common options:
  • mx.bfloat16 (default) - Recommended for Apple Silicon, good balance of speed and accuracy
  • mx.float32 - Higher precision, slower inference
  • mx.float16 - Faster but may have numerical stability issues
cache_dir
str | Path | None
default:"None"
Directory to cache downloaded models. If None, uses Hugging Face's default cache location: ~/.cache/huggingface/hub, unless the HF_HUB_CACHE environment variable points elsewhere or HF_HOME is set (in which case the cache is the hub subdirectory of HF_HOME).
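The lookup order for cache_dir=None can be sketched roughly as below. This is a simplified illustration of how the Hugging Face cache location is usually resolved, not code taken from parakeet_mlx or huggingface_hub. The helper name resolve_hf_cache is hypothetical, and the sketch ignores other variables such as XDG_CACHE_HOME.

```python
import os
from pathlib import Path


def resolve_hf_cache(cache_dir=None, env=None) -> Path:
    """Hypothetical helper: pick the directory downloads would land in."""
    env = os.environ if env is None else env

    # An explicit cache_dir always wins.
    if cache_dir is not None:
        return Path(cache_dir).expanduser()

    # HF_HUB_CACHE names the hub cache directly.
    if env.get("HF_HUB_CACHE"):
        return Path(env["HF_HUB_CACHE"]).expanduser()

    # Otherwise the cache is the "hub" folder under HF_HOME,
    # which itself defaults to ~/.cache/huggingface.
    hf_home = env.get("HF_HOME") or "~/.cache/huggingface"
    return Path(hf_home).expanduser() / "hub"


print(resolve_hf_cache())
```

Running this before a large download is a quick way to see which disk the weights are likely to end up on.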

Returns

model
BaseParakeet
Returns one of the following model instances based on the config:
  • ParakeetTDT - Token-and-Duration Transducer model
  • ParakeetRNNT - RNN-Transducer model
  • ParakeetCTC - Connectionist Temporal Classification model
  • ParakeetTDTCTC - Hybrid TDT-CTC model
All models inherit from BaseParakeet and share the same core interface.

Examples

Basic Usage

from parakeet_mlx import from_pretrained

# Load model from Hugging Face Hub
model = from_pretrained("mlx-community/parakeet-tdt-0.6b-v3")

result = model.transcribe("audio.wav")
print(result.text)

Loading with Custom Cache Directory

from pathlib import Path

from parakeet_mlx import from_pretrained

# Use custom cache location (cache_dir accepts either a str or a Path)
model = from_pretrained(
    "mlx-community/parakeet-tdt-0.6b-v3",
    cache_dir=Path("./models_cache"),
)

Loading from Local Directory

from parakeet_mlx import from_pretrained

# Load from local directory containing config.json and model.safetensors
model = from_pretrained("./local_models/parakeet-tdt")
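Because a local directory must contain both config.json and model.safetensors, it can help to check for them before loading. The following is a minimal sketch using only the standard library. The helper missing_model_files is hypothetical and not part of parakeet_mlx, and the path in the example call is the placeholder used above.

```python
from pathlib import Path

# The two files a local model directory is expected to contain.
REQUIRED_FILES = ("config.json", "model.safetensors")


def missing_model_files(model_dir) -> list[str]:
    """Hypothetical helper: list required files absent from model_dir."""
    root = Path(model_dir)
    return [name for name in REQUIRED_FILES if not (root / name).is_file()]


missing = missing_model_files("./local_models/parakeet-tdt")
if missing:
    print("Cannot load, missing:", ", ".join(missing))
```

An empty list means both files are present; it says nothing about whether their contents are valid.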

Using Different Precision

import mlx.core as mx
from parakeet_mlx import from_pretrained

# Use float32 for higher precision
model = from_pretrained(
    "mlx-community/parakeet-tdt-0.6b-v3",
    dtype=mx.float32
)
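The main cost of float32 is memory: each weight takes 4 bytes instead of 2. A back-of-the-envelope estimate is sketched below. It assumes roughly 0.6 billion parameters, a figure read off the model name rather than measured, and it counts weights only, not activations or caches.

```python
# Bytes used to store one parameter in each dtype.
BYTES_PER_PARAM = {"bfloat16": 2, "float16": 2, "float32": 4}


def weight_memory_gb(num_params: float, dtype_name: str) -> float:
    """Approximate weight storage in gigabytes (1 GB = 1e9 bytes)."""
    return num_params * BYTES_PER_PARAM[dtype_name] / 1e9


# Assumed parameter count, taken from the "0.6b" in the model name.
ASSUMED_PARAMS = 0.6e9

for name in BYTES_PER_PARAM:
    print(f"{name}: ~{weight_memory_gb(ASSUMED_PARAMS, name):.1f} GB")
```

Under that assumption the weights take about 1.2 GB in bfloat16 or float16 and about 2.4 GB in float32.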

Type Casting for Variant-Specific Methods

from typing import cast
from parakeet_mlx import from_pretrained, ParakeetTDT

# Load model and cast to specific type
model = from_pretrained("mlx-community/parakeet-tdt-0.6b-v3")
model_tdt = cast(ParakeetTDT, model)

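# Now you can use TDT-specific methods without type checker warnings.
# `mel` is not defined in this snippet: it is assumed to be a log-mel
# spectrogram computed from the audio beforehand.
features, lengths = model_tdt.encoder(mel)
results, hidden_states = model_tdt.decode(features, lengths)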
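Note that typing.cast only informs the type checker; it performs no check at runtime, so casting a CTC model to ParakeetTDT would fail later with a less obvious error. One alternative is a small helper that verifies the type with isinstance. The sketch below uses stand-in classes so that it runs without the library installed; narrow, BaseModel, TDTModel, and CTCModel are all hypothetical names.

```python
from typing import TypeVar

T = TypeVar("T")


def narrow(obj: object, expected: type[T]) -> T:
    """Return obj typed as `expected`, raising if it is not an instance."""
    if not isinstance(obj, expected):
        raise TypeError(
            f"expected {expected.__name__}, got {type(obj).__name__}"
        )
    return obj


# Stand-ins for BaseParakeet and two of its subclasses.
class BaseModel: ...
class TDTModel(BaseModel): ...
class CTCModel(BaseModel): ...


loaded: BaseModel = TDTModel()
tdt = narrow(loaded, TDTModel)  # passes, and type checkers see TDTModel
```

With the real library the equivalent call would be narrow(model, ParakeetTDT), which fails immediately if the repository holds a different variant.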

Available Models

Popular Parakeet models on Hugging Face:
  • mlx-community/parakeet-tdt-0.6b-v3 - Latest TDT model, recommended
  • mlx-community/parakeet-tdt-1.1b - Larger TDT model
  • mlx-community/parakeet-rnnt-0.6b - RNNT variant
  • mlx-community/parakeet-ctc-0.6b - CTC variant
  • mlx-community/parakeet-tdt-ctc-0.6b - Hybrid TDT-CTC model
See the full collection on Hugging Face.

Implementation Details

The function:
  1. Downloads config.json and model.safetensors from the Hugging Face Hub, or reads them from a local directory
  2. Detects model type based on config metadata:
    • Checks target field for model architecture
    • Checks model_defaults.tdt_durations to distinguish TDT from RNNT
  3. Instantiates the appropriate model class
  4. Loads weights from safetensors file
  5. Casts weights to specified dtype
  6. Sets model to eval mode
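The detection in step 2 can be sketched as a function over the parsed config. This is an illustration of the logic described above, not the library's actual code: it returns class names as strings, and matching the target field by substring ("hybrid", "rnnt", "ctc") is an assumption, as are the target values in the example configs.

```python
def detect_variant(config: dict) -> str:
    """Sketch: map a parsed config.json to a Parakeet class name."""
    target = str(config.get("target", "")).lower()
    durations = config.get("model_defaults", {}).get("tdt_durations")
    has_durations = durations is not None

    # Check the hybrid case first: its target also mentions "rnnt" and "ctc".
    if "hybrid" in target:
        if has_durations:
            return "ParakeetTDTCTC"
        raise ValueError("hybrid config without tdt_durations is not handled")

    # Transducer targets: tdt_durations separates TDT from plain RNNT.
    if "rnnt" in target:
        return "ParakeetTDT" if has_durations else "ParakeetRNNT"

    if "ctc" in target:
        return "ParakeetCTC"

    raise ValueError(f"unrecognized target: {config.get('target')!r}")


# Hypothetical configs, reduced to the two fields that matter here.
print(detect_variant({"target": "models.rnnt_bpe", "model_defaults": {"tdt_durations": [0, 1, 2]}}))
print(detect_variant({"target": "models.rnnt_bpe", "model_defaults": {}}))
```

The ordering matters: testing for the hybrid target before the plain transducer and CTC targets keeps a hybrid config from being misclassified.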
