Overview

The from_pretrained() function is the primary way to load Parakeet models. It downloads a model from the Hugging Face Hub, or reads one from a local directory, and automatically detects the model variant (TDT, RNNT, CTC, or TDT-CTC).

Function Signature

from parakeet_mlx import from_pretrained

def from_pretrained(
    hf_id_or_path: str,
    *,
    dtype: mx.Dtype = mx.bfloat16,
    cache_dir: str | Path | None = None,
) -> BaseParakeet

Parameters

hf_id_or_path
str
required
Hugging Face repository ID (e.g., "mlx-community/parakeet-tdt-0.6b-v3") or path to a local model directory containing config.json and model.safetensors.
dtype
mx.Dtype
default:"mx.bfloat16"
Data type for model weights. Common options:
  • mx.bfloat16 (default) - Recommended for Apple Silicon, good balance of speed and accuracy
  • mx.float32 - Higher precision, slower inference
  • mx.float16 - Faster than float32, but its narrower exponent range can cause numerical stability issues
cache_dir
str | Path | None
default:"None"
Directory to cache downloaded models. If None, uses Hugging Face’s default cache location: ~/.cache/huggingface/hub, unless the HF_HUB_CACHE environment variable (the hub cache itself) or HF_HOME (its parent directory, with the cache in the hub subdirectory) is set.
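
The lookup order for the cache location can be sketched as follows. This is a simplified illustration, not the huggingface_hub implementation: resolve_cache_dir is a hypothetical helper, and it ignores other settings such as XDG_CACHE_HOME.

```python
import os
from pathlib import Path


def resolve_cache_dir(cache_dir=None, env=None):
    """Hypothetical helper: pick the directory downloaded models would be cached in."""
    env = os.environ if env is None else env
    if cache_dir is not None:
        # An explicit cache_dir argument always wins.
        return Path(cache_dir).expanduser()
    if env.get("HF_HUB_CACHE"):
        # HF_HUB_CACHE points directly at the hub cache.
        return Path(env["HF_HUB_CACHE"]).expanduser()
    if env.get("HF_HOME"):
        # HF_HOME is the Hugging Face base directory; the hub cache lives in "hub".
        return Path(env["HF_HOME"]).expanduser() / "hub"
    # Fallback when nothing is configured.
    return Path.home() / ".cache" / "huggingface" / "hub"


print(resolve_cache_dir("./models_cache"))
print(resolve_cache_dir(env={"HF_HOME": "/data/hf"}))
```

Passing env explicitly, as in the second call, keeps the sketch easy to try without touching the real environment.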

Returns

model
BaseParakeet
Returns one of the following model instances based on the config:
  • ParakeetTDT - Token-and-Duration Transducer model
  • ParakeetRNNT - RNN-Transducer model
  • ParakeetCTC - Connectionist Temporal Classification model
  • ParakeetTDTCTC - Hybrid TDT-CTC model
All models inherit from BaseParakeet and share the same core interface.

Examples

Basic Usage

from parakeet_mlx import from_pretrained

# Load model from Hugging Face Hub
model = from_pretrained("mlx-community/parakeet-tdt-0.6b-v3")

result = model.transcribe("audio.wav")
print(result.text)

Loading with Custom Cache Directory

from pathlib import Path

from parakeet_mlx import from_pretrained

# Use custom cache location (a str or a Path is accepted)
model = from_pretrained(
    "mlx-community/parakeet-tdt-0.6b-v3",
    cache_dir=Path("./models_cache"),
)

Loading from Local Directory

from parakeet_mlx import from_pretrained

# Load from local directory containing config.json and model.safetensors
model = from_pretrained("./local_models/parakeet-tdt")

Using Different Precision

import mlx.core as mx
from parakeet_mlx import from_pretrained

# Use float32 for higher precision
model = from_pretrained(
    "mlx-community/parakeet-tdt-0.6b-v3",
    dtype=mx.float32
)
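
To see why the dtype choice matters for memory, the weight footprint can be estimated from the parameter count and the bytes per element. This is a rough sketch: the 0.6 billion parameter count is an assumption read off the model name, weights_gib is a hypothetical helper, and activations and other runtime buffers are not counted.

```python
# Bytes per element for each dtype (fixed by the formats themselves).
BYTES_PER_ELEMENT = {"bfloat16": 2, "float16": 2, "float32": 4}


def weights_gib(num_params, dtype):
    """Hypothetical helper: approximate size of the weights in GiB."""
    return num_params * BYTES_PER_ELEMENT[dtype] / 1024**3


NUM_PARAMS = 600_000_000  # assumed from the "0.6b" in the model name

for dtype in BYTES_PER_ELEMENT:
    print(f"{dtype}: {weights_gib(NUM_PARAMS, dtype):.2f} GiB")
```

Under these assumptions, float32 doubles the weight memory relative to bfloat16 or float16.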

Type Casting for Variant-Specific Methods

from typing import cast
from parakeet_mlx import from_pretrained, ParakeetTDT

# Load model and cast to specific type
model = from_pretrained("mlx-community/parakeet-tdt-0.6b-v3")
model_tdt = cast(ParakeetTDT, model)

# Now you can use TDT-specific methods without type checker warnings.
# `mel` is a log-mel spectrogram array computed from the audio beforehand (not shown here).
features, lengths = model_tdt.encoder(mel)
results, hidden_states = model_tdt.decode(features, lengths)
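
Note that typing.cast only informs the type checker; it performs no check at runtime. An isinstance test narrows the type and also verifies it. The sketch below uses stand-in classes so it runs without the library installed: describe and the class bodies are hypothetical, and only the class names mirror those listed under Returns.

```python
from typing import cast


class BaseParakeet:  # stand-in for the library's base class
    pass


class ParakeetTDT(BaseParakeet):  # stand-in for the TDT variant
    pass


class ParakeetCTC(BaseParakeet):  # stand-in for the CTC variant
    pass


def describe(model: BaseParakeet) -> str:
    """Hypothetical helper: branch on the concrete variant at runtime."""
    if isinstance(model, ParakeetTDT):
        # Inside this branch, type checkers treat `model` as ParakeetTDT.
        return "TDT"
    return "other"


model: BaseParakeet = ParakeetCTC()
wrong = cast(ParakeetTDT, model)  # no error: cast returns its argument unchanged
print(wrong is model)
print(describe(model))
```

When the variant is not known in advance, the isinstance branch is the safer choice, since a mistaken cast only surfaces later as an attribute error.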

Available Models

Popular Parakeet models on Hugging Face:
  • mlx-community/parakeet-tdt-0.6b-v3 - Latest TDT model, recommended
  • mlx-community/parakeet-tdt-1.1b - Larger TDT model
  • mlx-community/parakeet-rnnt-0.6b - RNNT variant
  • mlx-community/parakeet-ctc-0.6b - CTC variant
  • mlx-community/parakeet-tdt-ctc-0.6b - Hybrid TDT-CTC model
See the full collection on Hugging Face.

Implementation Details

The function:
  1. Downloads config.json and model.safetensors from Hugging Face, or reads them from a local directory
  2. Detects model type based on config metadata:
    • Checks target field for model architecture
    • Checks model_defaults.tdt_durations to distinguish TDT from RNNT
  3. Instantiates the appropriate model class
  4. Loads weights from safetensors file
  5. Casts weights to specified dtype
  6. Sets model to eval mode
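
The detection in step 2 can be sketched roughly as below. This is an illustration under assumptions, not the library's code: detect_variant is a hypothetical helper, and matching on substrings such as "hybrid", "rnnt", and "ctc" in the target field is an assumed convention. Only the target and model_defaults.tdt_durations fields come from the description above.

```python
def detect_variant(config: dict) -> str:
    """Hypothetical helper: map a config dict to a model class name."""
    target = str(config.get("target", "")).lower()
    # A non-null tdt_durations entry marks a duration-predicting (TDT) decoder.
    has_durations = config.get("model_defaults", {}).get("tdt_durations") is not None

    if "hybrid" in target:  # assumed marker for the hybrid TDT-CTC model
        return "ParakeetTDTCTC"
    if "rnnt" in target:  # transducer family: TDT or plain RNNT
        return "ParakeetTDT" if has_durations else "ParakeetRNNT"
    if "ctc" in target:
        return "ParakeetCTC"
    raise ValueError(f"Unsupported model target: {target!r}")


# Example configs with made-up target strings.
tdt_config = {
    "target": "example.rnnt_model",
    "model_defaults": {"tdt_durations": [0, 1, 2, 3, 4]},
}
rnnt_config = {"target": "example.rnnt_model", "model_defaults": {}}
print(detect_variant(tdt_config))
print(detect_variant(rnnt_config))
```

The hybrid check comes first in this sketch because a hybrid target string may also contain "rnnt" or "ctc".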