Skip to content

Commit ebbeb45

Browse files
committed
Fix pre-commit checks and add proper type safety
1 parent 9827068 commit ebbeb45

File tree

7 files changed

+135
-94
lines changed

7 files changed

+135
-94
lines changed

docling/cli/main.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -579,17 +579,27 @@ def convert( # noqa: C901
579579
ocr_options.lang = ocr_lang_list
580580

581581
accelerator_options = AcceleratorOptions(num_threads=num_threads, device=device)
582-
582+
583583
# Auto-detect pipeline based on input file formats
584584
if pipeline == ProcessingPipeline.STANDARD:
585585
# Check if any input files are audio files by extension
586-
audio_extensions = {'.mp3', '.wav', '.m4a', '.aac', '.ogg', '.flac', '.mp4', '.avi', '.mov'}
586+
audio_extensions = {
587+
".mp3",
588+
".wav",
589+
".m4a",
590+
".aac",
591+
".ogg",
592+
".flac",
593+
".mp4",
594+
".avi",
595+
".mov",
596+
}
587597
for path in input_doc_paths:
588598
if path.suffix.lower() in audio_extensions:
589599
pipeline = ProcessingPipeline.ASR
590600
_log.info(f"Auto-detected ASR pipeline for audio file: {path}")
591601
break
592-
602+
593603
# pipeline_options: PaginatedPipelineOptions
594604
pipeline_options: PipelineOptions
595605

docling/datamodel/asr_model_specs.py

Lines changed: 37 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -10,34 +10,37 @@
1010
# AsrResponseFormat,
1111
# ApiAsrOptions,
1212
InferenceAsrFramework,
13-
InlineAsrNativeWhisperOptions,
1413
InlineAsrMlxWhisperOptions,
14+
InlineAsrNativeWhisperOptions,
1515
TransformersModelType,
1616
)
1717

1818
_log = logging.getLogger(__name__)
1919

20+
2021
def _get_whisper_tiny_model():
2122
"""
2223
Get the best Whisper Tiny model for the current hardware.
23-
24+
2425
Automatically selects MLX Whisper Tiny for Apple Silicon (MPS) if available,
2526
otherwise falls back to native Whisper Tiny.
2627
"""
2728
# Check if MPS is available (Apple Silicon)
2829
try:
2930
import torch
31+
3032
has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available()
3133
except ImportError:
3234
has_mps = False
33-
35+
3436
# Check if mlx-whisper is available
3537
try:
3638
import mlx_whisper # type: ignore
39+
3740
has_mlx_whisper = True
3841
except ImportError:
3942
has_mlx_whisper = False
40-
43+
4144
# Use MLX Whisper if both MPS and mlx-whisper are available
4245
if has_mps and has_mlx_whisper:
4346
return InlineAsrMlxWhisperOptions(
@@ -66,27 +69,30 @@ def _get_whisper_tiny_model():
6669
# Create the model instance
6770
WHISPER_TINY = _get_whisper_tiny_model()
6871

72+
6973
def _get_whisper_small_model():
7074
"""
7175
Get the best Whisper Small model for the current hardware.
72-
76+
7377
Automatically selects MLX Whisper Small for Apple Silicon (MPS) if available,
7478
otherwise falls back to native Whisper Small.
7579
"""
7680
# Check if MPS is available (Apple Silicon)
7781
try:
7882
import torch
83+
7984
has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available()
8085
except ImportError:
8186
has_mps = False
82-
87+
8388
# Check if mlx-whisper is available
8489
try:
8590
import mlx_whisper # type: ignore
91+
8692
has_mlx_whisper = True
8793
except ImportError:
8894
has_mlx_whisper = False
89-
95+
9096
# Use MLX Whisper if both MPS and mlx-whisper are available
9197
if has_mps and has_mlx_whisper:
9298
return InlineAsrMlxWhisperOptions(
@@ -115,27 +121,30 @@ def _get_whisper_small_model():
115121
# Create the model instance
116122
WHISPER_SMALL = _get_whisper_small_model()
117123

124+
118125
def _get_whisper_medium_model():
119126
"""
120127
Get the best Whisper Medium model for the current hardware.
121-
128+
122129
Automatically selects MLX Whisper Medium for Apple Silicon (MPS) if available,
123130
otherwise falls back to native Whisper Medium.
124131
"""
125132
# Check if MPS is available (Apple Silicon)
126133
try:
127134
import torch
135+
128136
has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available()
129137
except ImportError:
130138
has_mps = False
131-
139+
132140
# Check if mlx-whisper is available
133141
try:
134142
import mlx_whisper # type: ignore
143+
135144
has_mlx_whisper = True
136145
except ImportError:
137146
has_mlx_whisper = False
138-
147+
139148
# Use MLX Whisper if both MPS and mlx-whisper are available
140149
if has_mps and has_mlx_whisper:
141150
return InlineAsrMlxWhisperOptions(
@@ -164,27 +173,30 @@ def _get_whisper_medium_model():
164173
# Create the model instance
165174
WHISPER_MEDIUM = _get_whisper_medium_model()
166175

176+
167177
def _get_whisper_base_model():
168178
"""
169179
Get the best Whisper Base model for the current hardware.
170-
180+
171181
Automatically selects MLX Whisper Base for Apple Silicon (MPS) if available,
172182
otherwise falls back to native Whisper Base.
173183
"""
174184
# Check if MPS is available (Apple Silicon)
175185
try:
176186
import torch
187+
177188
has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available()
178189
except ImportError:
179190
has_mps = False
180-
191+
181192
# Check if mlx-whisper is available
182193
try:
183194
import mlx_whisper # type: ignore
195+
184196
has_mlx_whisper = True
185197
except ImportError:
186198
has_mlx_whisper = False
187-
199+
188200
# Use MLX Whisper if both MPS and mlx-whisper are available
189201
if has_mps and has_mlx_whisper:
190202
return InlineAsrMlxWhisperOptions(
@@ -213,27 +225,30 @@ def _get_whisper_base_model():
213225
# Create the model instance
214226
WHISPER_BASE = _get_whisper_base_model()
215227

228+
216229
def _get_whisper_large_model():
217230
"""
218231
Get the best Whisper Large model for the current hardware.
219-
232+
220233
Automatically selects MLX Whisper Large for Apple Silicon (MPS) if available,
221234
otherwise falls back to native Whisper Large.
222235
"""
223236
# Check if MPS is available (Apple Silicon)
224237
try:
225238
import torch
239+
226240
has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available()
227241
except ImportError:
228242
has_mps = False
229-
243+
230244
# Check if mlx-whisper is available
231245
try:
232246
import mlx_whisper # type: ignore
247+
233248
has_mlx_whisper = True
234249
except ImportError:
235250
has_mlx_whisper = False
236-
251+
237252
# Use MLX Whisper if both MPS and mlx-whisper are available
238253
if has_mps and has_mlx_whisper:
239254
return InlineAsrMlxWhisperOptions(
@@ -262,27 +277,30 @@ def _get_whisper_large_model():
262277
# Create the model instance
263278
WHISPER_LARGE = _get_whisper_large_model()
264279

280+
265281
def _get_whisper_turbo_model():
266282
"""
267283
Get the best Whisper Turbo model for the current hardware.
268-
284+
269285
Automatically selects MLX Whisper Turbo for Apple Silicon (MPS) if available,
270286
otherwise falls back to native Whisper Turbo.
271287
"""
272288
# Check if MPS is available (Apple Silicon)
273289
try:
274290
import torch
291+
275292
has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available()
276293
except ImportError:
277294
has_mps = False
278-
295+
279296
# Check if mlx-whisper is available
280297
try:
281298
import mlx_whisper # type: ignore
299+
282300
has_mlx_whisper = True
283301
except ImportError:
284302
has_mlx_whisper = False
285-
303+
286304
# Use MLX Whisper if both MPS and mlx-whisper are available
287305
if has_mps and has_mlx_whisper:
288306
return InlineAsrMlxWhisperOptions(

docling/datamodel/pipeline_options_asr_model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,10 @@ class InlineAsrNativeWhisperOptions(InlineAsrOptions):
6060
class InlineAsrMlxWhisperOptions(InlineAsrOptions):
6161
"""
6262
MLX Whisper options for Apple Silicon optimization.
63-
63+
6464
Uses mlx-whisper library for efficient inference on Apple Silicon devices.
6565
"""
66+
6667
inference_framework: InferenceAsrFramework = InferenceAsrFramework.MLX
6768

6869
language: str = "en"

docling/pipeline/asr_pipeline.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import re
44
from io import BytesIO
55
from pathlib import Path
6-
from typing import List, Optional, Union, cast
6+
from typing import TYPE_CHECKING, List, Optional, Union, cast
77

88
from docling_core.types.doc import DoclingDocument, DocumentOrigin
99

@@ -31,8 +31,8 @@
3131
AsrPipelineOptions,
3232
)
3333
from docling.datamodel.pipeline_options_asr_model import (
34-
InlineAsrNativeWhisperOptions,
3534
InlineAsrMlxWhisperOptions,
35+
InlineAsrNativeWhisperOptions,
3636
# AsrResponseFormat,
3737
InlineAsrOptions,
3838
)
@@ -236,7 +236,7 @@ def __init__(
236236

237237
self.model_name = asr_options.repo_id
238238
_log.info(f"loading _MlxWhisperModel({self.model_name})")
239-
239+
240240
# MLX Whisper models are loaded differently - they use HuggingFace repos
241241
self.model_path = self.model_name
242242

@@ -281,10 +281,10 @@ def run(self, conv_res: ConversionResult) -> ConversionResult:
281281
def transcribe(self, fpath: Path) -> list[_ConversationItem]:
282282
"""
283283
Transcribe audio using MLX Whisper.
284-
284+
285285
Args:
286286
fpath: Path to audio file
287-
287+
288288
Returns:
289289
List of conversation items with timestamps
290290
"""
@@ -300,16 +300,16 @@ def transcribe(self, fpath: Path) -> list[_ConversationItem]:
300300
)
301301

302302
convo: list[_ConversationItem] = []
303-
303+
304304
# MLX Whisper returns segments similar to native Whisper
305305
for segment in result.get("segments", []):
306306
item = _ConversationItem(
307307
start_time=segment.get("start"),
308308
end_time=segment.get("end"),
309309
text=segment.get("text", "").strip(),
310-
words=[]
310+
words=[],
311311
)
312-
312+
313313
# Add word-level timestamps if available
314314
if self.word_timestamps and "words" in segment:
315315
item.words = []
@@ -332,26 +332,27 @@ def __init__(self, pipeline_options: AsrPipelineOptions):
332332
self.keep_backend = True
333333

334334
self.pipeline_options: AsrPipelineOptions = pipeline_options
335+
self._model: Union[_NativeWhisperModel, _MlxWhisperModel]
335336

336337
if isinstance(self.pipeline_options.asr_options, InlineAsrNativeWhisperOptions):
337-
asr_options: InlineAsrNativeWhisperOptions = (
338+
native_asr_options: InlineAsrNativeWhisperOptions = (
338339
self.pipeline_options.asr_options
339340
)
340341
self._model = _NativeWhisperModel(
341342
enabled=True, # must be always enabled for this pipeline to make sense.
342343
artifacts_path=self.artifacts_path,
343344
accelerator_options=pipeline_options.accelerator_options,
344-
asr_options=asr_options,
345+
asr_options=native_asr_options,
345346
)
346347
elif isinstance(self.pipeline_options.asr_options, InlineAsrMlxWhisperOptions):
347-
asr_options: InlineAsrMlxWhisperOptions = (
348+
mlx_asr_options: InlineAsrMlxWhisperOptions = (
348349
self.pipeline_options.asr_options
349350
)
350351
self._model = _MlxWhisperModel(
351352
enabled=True, # must be always enabled for this pipeline to make sense.
352353
artifacts_path=self.artifacts_path,
353354
accelerator_options=pipeline_options.accelerator_options,
354-
asr_options=asr_options,
355+
asr_options=mlx_asr_options,
355356
)
356357
else:
357358
_log.error(f"No model support for {self.pipeline_options.asr_options}")

docs/examples/minimal_asr_pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def get_asr_converter():
4343
implementation for your hardware:
4444
- MLX Whisper Turbo for Apple Silicon (M1/M2/M3) with mlx-whisper installed
4545
- Native Whisper Turbo as fallback
46-
46+
4747
You can swap in another model spec from `docling.datamodel.asr_model_specs`
4848
to experiment with different model sizes.
4949
"""

0 commit comments

Comments
 (0)