Skip to content

Commit a979a68

Browse files
committed
add mlx-whisper support
1 parent f0b630e commit a979a68

File tree

6 files changed

+338
-36
lines changed

6 files changed

+338
-36
lines changed

docling/datamodel/asr_model_specs.py

Lines changed: 149 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# ApiAsrOptions,
1212
InferenceAsrFramework,
1313
InlineAsrNativeWhisperOptions,
14+
InlineAsrMlxWhisperOptions,
1415
TransformersModelType,
1516
)
1617

@@ -27,16 +28,54 @@
2728
max_time_chunk=30.0,
2829
)
2930

30-
WHISPER_SMALL = InlineAsrNativeWhisperOptions(
31-
repo_id="small",
32-
inference_framework=InferenceAsrFramework.WHISPER,
33-
verbose=True,
34-
timestamps=True,
35-
word_timestamps=True,
36-
temperature=0.0,
37-
max_new_tokens=256,
38-
max_time_chunk=30.0,
39-
)
31+
def _get_whisper_small_model():
32+
"""
33+
Get the best Whisper Small model for the current hardware.
34+
35+
Automatically selects MLX Whisper Small for Apple Silicon (MPS) if available,
36+
otherwise falls back to native Whisper Small.
37+
"""
38+
# Check if MPS is available (Apple Silicon)
39+
try:
40+
import torch
41+
has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available()
42+
except ImportError:
43+
has_mps = False
44+
45+
# Check if mlx-whisper is available
46+
try:
47+
import mlx_whisper # type: ignore
48+
has_mlx_whisper = True
49+
except ImportError:
50+
has_mlx_whisper = False
51+
52+
# Use MLX Whisper if both MPS and mlx-whisper are available
53+
if has_mps and has_mlx_whisper:
54+
return InlineAsrMlxWhisperOptions(
55+
repo_id="mlx-community/whisper-small-mlx",
56+
inference_framework=InferenceAsrFramework.MLX,
57+
language="en",
58+
task="transcribe",
59+
word_timestamps=True,
60+
no_speech_threshold=0.6,
61+
logprob_threshold=-1.0,
62+
compression_ratio_threshold=2.4,
63+
)
64+
else:
65+
return InlineAsrNativeWhisperOptions(
66+
repo_id="small",
67+
inference_framework=InferenceAsrFramework.WHISPER,
68+
verbose=True,
69+
timestamps=True,
70+
word_timestamps=True,
71+
temperature=0.0,
72+
max_new_tokens=256,
73+
max_time_chunk=30.0,
74+
)
75+
76+
77+
# Create the model instance
78+
WHISPER_SMALL = _get_whisper_small_model()
4079

4180
WHISPER_MEDIUM = InlineAsrNativeWhisperOptions(
4281
repo_id="medium",
@@ -49,16 +88,54 @@
4988
max_time_chunk=30.0,
5089
)
5190

52-
WHISPER_BASE = InlineAsrNativeWhisperOptions(
53-
repo_id="base",
54-
inference_framework=InferenceAsrFramework.WHISPER,
55-
verbose=True,
56-
timestamps=True,
57-
word_timestamps=True,
58-
temperature=0.0,
59-
max_new_tokens=256,
60-
max_time_chunk=30.0,
61-
)
91+
def _get_whisper_base_model():
92+
"""
93+
Get the best Whisper Base model for the current hardware.
94+
95+
Automatically selects MLX Whisper Base for Apple Silicon (MPS) if available,
96+
otherwise falls back to native Whisper Base.
97+
"""
98+
# Check if MPS is available (Apple Silicon)
99+
try:
100+
import torch
101+
has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available()
102+
except ImportError:
103+
has_mps = False
104+
105+
# Check if mlx-whisper is available
106+
try:
107+
import mlx_whisper # type: ignore
108+
has_mlx_whisper = True
109+
except ImportError:
110+
has_mlx_whisper = False
111+
112+
# Use MLX Whisper if both MPS and mlx-whisper are available
113+
if has_mps and has_mlx_whisper:
114+
return InlineAsrMlxWhisperOptions(
115+
repo_id="mlx-community/whisper-base-mlx",
116+
inference_framework=InferenceAsrFramework.MLX,
117+
language="en",
118+
task="transcribe",
119+
word_timestamps=True,
120+
no_speech_threshold=0.6,
121+
logprob_threshold=-1.0,
122+
compression_ratio_threshold=2.4,
123+
)
124+
else:
125+
return InlineAsrNativeWhisperOptions(
126+
repo_id="base",
127+
inference_framework=InferenceAsrFramework.WHISPER,
128+
verbose=True,
129+
timestamps=True,
130+
word_timestamps=True,
131+
temperature=0.0,
132+
max_new_tokens=256,
133+
max_time_chunk=30.0,
134+
)
135+
136+
137+
# Create the model instance
138+
WHISPER_BASE = _get_whisper_base_model()
62139

63140
WHISPER_LARGE = InlineAsrNativeWhisperOptions(
64141
repo_id="large",
@@ -71,16 +148,58 @@
71148
max_time_chunk=30.0,
72149
)
73150

74-
WHISPER_TURBO = InlineAsrNativeWhisperOptions(
75-
repo_id="turbo",
76-
inference_framework=InferenceAsrFramework.WHISPER,
77-
verbose=True,
78-
timestamps=True,
79-
word_timestamps=True,
80-
temperature=0.0,
81-
max_new_tokens=256,
82-
max_time_chunk=30.0,
83-
)
151+
def _get_whisper_turbo_model():
152+
"""
153+
Get the best Whisper Turbo model for the current hardware.
154+
155+
Automatically selects MLX Whisper Turbo for Apple Silicon (MPS) if available,
156+
otherwise falls back to native Whisper Turbo.
157+
"""
158+
# Check if MPS is available (Apple Silicon)
159+
try:
160+
import torch
161+
has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available()
162+
except ImportError:
163+
has_mps = False
164+
165+
# Check if mlx-whisper is available
166+
try:
167+
import mlx_whisper # type: ignore
168+
has_mlx_whisper = True
169+
except ImportError:
170+
has_mlx_whisper = False
171+
172+
# Use MLX Whisper if both MPS and mlx-whisper are available
173+
if has_mps and has_mlx_whisper:
174+
return InlineAsrMlxWhisperOptions(
175+
repo_id="mlx-community/whisper-turbo",
176+
inference_framework=InferenceAsrFramework.MLX,
177+
language="en",
178+
task="transcribe",
179+
word_timestamps=True,
180+
no_speech_threshold=0.6,
181+
logprob_threshold=-1.0,
182+
compression_ratio_threshold=2.4,
183+
)
184+
else:
185+
return InlineAsrNativeWhisperOptions(
186+
repo_id="turbo",
187+
inference_framework=InferenceAsrFramework.WHISPER,
188+
verbose=True,
189+
timestamps=True,
190+
word_timestamps=True,
191+
temperature=0.0,
192+
max_new_tokens=256,
193+
max_time_chunk=30.0,
194+
)
195+
196+
197+
# Create the model instance
198+
WHISPER_TURBO = _get_whisper_turbo_model()
199+
200+
# Note: MLX Whisper models are now automatically selected when using
201+
# WHISPER_TURBO, WHISPER_BASE, WHISPER_SMALL, etc. on Apple Silicon systems
202+
# with mlx-whisper installed. No need for separate MLX-specific model specs.
84203

85204

86205
class AsrModelType(str, Enum):

docling/datamodel/pipeline_options_asr_model.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ class BaseAsrOptions(BaseModel):
1717

1818

1919
class InferenceAsrFramework(str, Enum):
20-
# MLX = "mlx" # disabled for now
20+
MLX = "mlx"
2121
# TRANSFORMERS = "transformers" # disabled for now
2222
WHISPER = "whisper"
2323

@@ -55,3 +55,22 @@ class InlineAsrNativeWhisperOptions(InlineAsrOptions):
5555
AcceleratorDevice.CUDA,
5656
]
5757
word_timestamps: bool = True
58+
59+
60+
class InlineAsrMlxWhisperOptions(InlineAsrOptions):
61+
"""
62+
MLX Whisper options for Apple Silicon optimization.
63+
64+
Uses mlx-whisper library for efficient inference on Apple Silicon devices.
65+
"""
66+
inference_framework: InferenceAsrFramework = InferenceAsrFramework.MLX
67+
68+
language: str = "en"
69+
task: str = "transcribe" # "transcribe" or "translate"
70+
supported_devices: List[AcceleratorDevice] = [
71+
AcceleratorDevice.MPS, # MLX is optimized for Apple Silicon
72+
]
73+
word_timestamps: bool = True
74+
no_speech_threshold: float = 0.6 # Threshold for detecting speech
75+
logprob_threshold: float = -1.0 # Log probability threshold
76+
compression_ratio_threshold: float = 2.4 # Compression ratio threshold

docling/pipeline/asr_pipeline.py

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
)
3333
from docling.datamodel.pipeline_options_asr_model import (
3434
InlineAsrNativeWhisperOptions,
35+
InlineAsrMlxWhisperOptions,
3536
# AsrResponseFormat,
3637
InlineAsrOptions,
3738
)
@@ -201,6 +202,130 @@ def transcribe(self, fpath: Path) -> list[_ConversationItem]:
201202
return convo
202203

203204

205+
class _MlxWhisperModel:
206+
def __init__(
207+
self,
208+
enabled: bool,
209+
artifacts_path: Optional[Path],
210+
accelerator_options: AcceleratorOptions,
211+
asr_options: InlineAsrMlxWhisperOptions,
212+
):
213+
"""
214+
Transcriber using MLX Whisper for Apple Silicon optimization.
215+
"""
216+
self.enabled = enabled
217+
218+
_log.info(f"artifacts-path: {artifacts_path}")
219+
_log.info(f"accelerator_options: {accelerator_options}")
220+
221+
if self.enabled:
222+
try:
223+
import mlx_whisper # type: ignore
224+
except ImportError:
225+
raise ImportError(
226+
"mlx-whisper is not installed. Please install it via `pip install mlx-whisper` or do `uv sync --extra asr`."
227+
)
228+
self.asr_options = asr_options
229+
self.mlx_whisper = mlx_whisper
230+
231+
self.device = decide_device(
232+
accelerator_options.device,
233+
supported_devices=asr_options.supported_devices,
234+
)
235+
_log.info(f"Available device for MLX Whisper: {self.device}")
236+
237+
self.model_name = asr_options.repo_id
238+
_log.info(f"loading _MlxWhisperModel({self.model_name})")
239+
240+
# MLX Whisper models are loaded differently - they use HuggingFace repos
241+
self.model_path = self.model_name
242+
243+
# Store MLX-specific options
244+
self.language = asr_options.language
245+
self.task = asr_options.task
246+
self.word_timestamps = asr_options.word_timestamps
247+
self.no_speech_threshold = asr_options.no_speech_threshold
248+
self.logprob_threshold = asr_options.logprob_threshold
249+
self.compression_ratio_threshold = asr_options.compression_ratio_threshold
250+
251+
def run(self, conv_res: ConversionResult) -> ConversionResult:
252+
audio_path: Path = Path(conv_res.input.file).resolve()
253+
254+
try:
255+
conversation = self.transcribe(audio_path)
256+
257+
# Ensure we have a proper DoclingDocument
258+
origin = DocumentOrigin(
259+
filename=conv_res.input.file.name or "audio.wav",
260+
mimetype="audio/x-wav",
261+
binary_hash=conv_res.input.document_hash,
262+
)
263+
conv_res.document = DoclingDocument(
264+
name=conv_res.input.file.stem or "audio.wav", origin=origin
265+
)
266+
267+
for citem in conversation:
268+
conv_res.document.add_text(
269+
label=DocItemLabel.TEXT, text=citem.to_string()
270+
)
271+
272+
conv_res.status = ConversionStatus.SUCCESS
273+
return conv_res
274+
275+
except Exception as exc:
276+
_log.error(f"MLX Audio transcription has an error: {exc}")
277+
278+
conv_res.status = ConversionStatus.FAILURE
279+
return conv_res
280+
281+
def transcribe(self, fpath: Path) -> list[_ConversationItem]:
282+
"""
283+
Transcribe audio using MLX Whisper.
284+
285+
Args:
286+
fpath: Path to audio file
287+
288+
Returns:
289+
List of conversation items with timestamps
290+
"""
291+
result = self.mlx_whisper.transcribe(
292+
str(fpath),
293+
path_or_hf_repo=self.model_path,
294+
language=self.language,
295+
task=self.task,
296+
word_timestamps=self.word_timestamps,
297+
no_speech_threshold=self.no_speech_threshold,
298+
logprob_threshold=self.logprob_threshold,
299+
compression_ratio_threshold=self.compression_ratio_threshold,
300+
)
301+
302+
convo: list[_ConversationItem] = []
303+
304+
# MLX Whisper returns segments similar to native Whisper
305+
for segment in result.get("segments", []):
306+
item = _ConversationItem(
307+
start_time=segment.get("start"),
308+
end_time=segment.get("end"),
309+
text=segment.get("text", "").strip(),
310+
words=[]
311+
)
312+
313+
# Add word-level timestamps if available
314+
if self.word_timestamps and "words" in segment:
315+
item.words = []
316+
for word_data in segment["words"]:
317+
item.words.append(
318+
_ConversationWord(
319+
start_time=word_data.get("start"),
320+
end_time=word_data.get("end"),
321+
text=word_data.get("word", ""),
322+
)
323+
)
324+
convo.append(item)
325+
326+
return convo
327+
328+
204329
class AsrPipeline(BasePipeline):
205330
def __init__(self, pipeline_options: AsrPipelineOptions):
206331
super().__init__(pipeline_options)
@@ -218,6 +343,16 @@ def __init__(self, pipeline_options: AsrPipelineOptions):
218343
accelerator_options=pipeline_options.accelerator_options,
219344
asr_options=asr_options,
220345
)
346+
elif isinstance(self.pipeline_options.asr_options, InlineAsrMlxWhisperOptions):
347+
asr_options: InlineAsrMlxWhisperOptions = (
348+
self.pipeline_options.asr_options
349+
)
350+
self._model = _MlxWhisperModel(
351+
enabled=True, # must be always enabled for this pipeline to make sense.
352+
artifacts_path=self.artifacts_path,
353+
accelerator_options=pipeline_options.accelerator_options,
354+
asr_options=asr_options,
355+
)
221356
else:
222357
_log.error(f"No model support for {self.pipeline_options.asr_options}")
223358

0 commit comments

Comments
 (0)