
Commit cba732f

[WIP] Added support for converting LayoutLMv3 to CoreML
1 parent 7a54597 commit cba732f

4 files changed: 186 additions, 0 deletions


src/exporters/coreml/config.py

Lines changed: 60 additions & 0 deletions
@@ -327,6 +327,40 @@ def _input_descriptions(self) -> "OrderedDict[str, InputDescription]":
                     ),
                 ]
             )
+
+        if self.modality == "multimodal" and self.task in [
+            "token-classification",
+        ]:
+            return OrderedDict(
+                [
+                    (
+                        "input_ids",
+                        InputDescription(
+                            "input_ids",
+                            "Indices of input sequence tokens in the vocabulary",
+                            sequence_length=self.input_ids_sequence_length,
+                        )
+                    ),
+                    (
+                        "bbox",
+                        InputDescription(
+                            "bbox",
+                            "Bounding Boxes"
+                        )
+                    ),
+                    (
+                        "attention_mask",
+                        InputDescription(
+                            "attention_mask",
+                            "Mask to avoid performing attention on padding token indices (1 = not masked, 0 = masked)",
+                        )
+                    ),
+                    (
+                        "pixel_values",
+                        InputDescription("image", "Input image", color_layout="RGB")
+                    ),
+                ]
+            )
 
         if self.task == "image-classification":
             return OrderedDict(
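
For reference (not part of this commit): the four inputs declared above line up with what the transformers LayoutLMv3 processor emits. A minimal sketch, assuming LayoutLMv3Processor with apply_ocr=False, a 64-token padded sequence, and the default 224x224 image size; the example words and boxes are made up:

from PIL import Image
from transformers import LayoutLMv3Processor

# apply_ocr=False because words and boxes are supplied manually (illustrative data)
processor = LayoutLMv3Processor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False)

image = Image.new("RGB", (224, 224), color="white")
words = ["hello", "world"]
boxes = [[10, 10, 60, 30], [70, 10, 130, 30]]  # coordinates normalized to the 0-1000 range

encoding = processor(image, words, boxes=boxes, padding="max_length",
                     max_length=64, truncation=True, return_tensors="np")
print({name: arr.shape for name, arr in encoding.items()})
# input_ids (1, 64), attention_mask (1, 64), bbox (1, 64, 4), pixel_values (1, 3, 224, 224)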
@@ -931,6 +965,32 @@ def generate_dummy_inputs(
                 bool_masked_pos = np.random.randint(low=0, high=2, size=(1, num_patches)).astype(bool)
                 dummy_inputs["bool_masked_pos"] = (bool_masked_pos, bool_masked_pos.astype(np.int32))
 
+        elif self.modality == "multimodal":
+            input_ids_name = "input_ids"
+            attention_mask_name = "attention_mask"
+
+            input_desc = input_descs[input_ids_name]
+
+            # the dummy input will always use the maximum sequence length
+            sequence_length = self._get_max_sequence_length(input_desc, 64)
+
+            shape = (batch_size, sequence_length)
+
+            input_ids = np.random.randint(0, preprocessor.tokenizer.vocab_size, shape)
+            dummy_inputs[input_ids_name] = (input_ids, input_ids.astype(np.int32))
+
+            if attention_mask_name in input_descs:
+                attention_mask = np.ones(shape, dtype=np.int64)
+                dummy_inputs[attention_mask_name] = (attention_mask, attention_mask.astype(np.int32))
+
+            bbox_shape = (batch_size, sequence_length, 4)
+            bboxes = np.random.randint(low=0, high=1000, size=bbox_shape).astype(np.int64)
+            dummy_inputs["bbox"] = (bboxes, bboxes.astype(np.int32))
+
+            dummy_inputs["pixel_values"] = self._generate_dummy_image(preprocessor.image_processor, framework)
+
+            print(dummy_inputs)
+
         elif self.modality == "audio" and isinstance(preprocessor, ProcessorMixin):
             if self.seq2seq != "decoder":
                 if "input_features" in input_descs:

src/exporters/coreml/convert.py

Lines changed: 64 additions & 0 deletions
@@ -254,6 +254,63 @@ def get_input_types(
             )
         )
 
+    elif config.modality == "multimodal":
+        input_ids_name = "input_ids"
+        attention_mask_name = "attention_mask"
+
+        input_desc = input_descs[input_ids_name]
+        dummy_input = dummy_inputs[input_ids_name]
+        shape = get_shape(config, input_desc, dummy_input)
+        input_types.append(
+            ct.TensorType(name=input_desc.name, shape=shape, dtype=np.int32)
+        )
+
+        if attention_mask_name in input_descs:
+            input_desc = input_descs[attention_mask_name]
+            input_types.append(
+                ct.TensorType(name=input_desc.name, shape=shape, dtype=np.int32)
+            )
+        else:
+            logger.info(f"Skipping {attention_mask_name} input")
+
+        bbox_desc = input_descs["bbox"]
+        dummy_input = dummy_inputs["bbox"]
+        shape = get_shape(config, bbox_desc, dummy_input)
+        input_types.append(
+            ct.TensorType(name=bbox_desc.name, shape=shape, dtype=np.int32)
+        )
+
+        if hasattr(preprocessor.image_processor, "image_mean"):
+            bias = [
+                -preprocessor.image_processor.image_mean[0],
+                -preprocessor.image_processor.image_mean[1],
+                -preprocessor.image_processor.image_mean[2],
+            ]
+        else:
+            bias = [0.0, 0.0, 0.0]
+
+        # If the stddev values are all equal, they can be folded into `bias` and
+        # `scale`. If not, Wrapper will insert an additional division operation.
+        if hasattr(preprocessor.image_processor, "image_std") and is_image_std_same(preprocessor.image_processor):
+            bias[0] /= preprocessor.image_processor.image_std[0]
+            bias[1] /= preprocessor.image_processor.image_std[1]
+            bias[2] /= preprocessor.image_processor.image_std[2]
+            scale = 1.0 / (preprocessor.image_processor.image_std[0] * 255.0)
+        else:
+            scale = 1.0 / 255
+
+        input_desc = input_descs["pixel_values"]
+        input_types.append(
+            ct.ImageType(
+                name=input_desc.name,
+                shape=dummy_inputs["pixel_values"][0].shape,
+                scale=scale,
+                bias=bias,
+                color_layout=input_desc.color_layout or "RGB",
+                channel_first=True,
+            )
+        )
+
     elif config.modality == "audio":
         if "input_features" in input_descs:
             input_desc = input_descs["input_features"]
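
A worked example of the mean/std folding above (not part of this commit): Core ML's image preprocessing applies a single affine step, out = scale * pixel + bias, so when all three std values are equal, (pixel / 255 - mean) / std collapses to scale = 1 / (std * 255) and bias = -mean / std. With illustrative mean/std values:

# Illustrative values; LayoutLMv3's actual image_mean / image_std come from its image processor.
mean = [0.5, 0.5, 0.5]
std = [0.5, 0.5, 0.5]

scale = 1.0 / (std[0] * 255.0)
bias = [-m / s for m, s in zip(mean, std)]

pixel = 200  # one raw 8-bit channel value
folded = scale * pixel + bias[0]                # what Core ML computes from the ImageType settings
reference = (pixel / 255.0 - mean[0]) / std[0]  # what the PyTorch-side image processor computes
assert abs(folded - reference) < 1e-12
print(folded, reference)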
@@ -510,6 +567,9 @@ def export_pytorch(
     # Put the inputs in the order from the config.
     example_input = [dummy_inputs[key][0] for key in list(config.inputs.keys())]
 
+    print(config)
+    print(example_input)
+
     wrapper = Wrapper(preprocessor, model, config).eval()
 
     # Running the model once with gradients disabled prevents an error during JIT tracing
@@ -529,6 +589,8 @@ def export_pytorch(
     else:
         example_output = [example_output.numpy()]
 
+    print(example_output)
+
     convert_kwargs = {}
     if not config.use_legacy_format:
         convert_kwargs["compute_precision"] = ct.precision.FLOAT16 if quantize == "float16" else ct.precision.FLOAT32
@@ -540,6 +602,8 @@ def export_pytorch(
 
     input_tensors = get_input_types(preprocessor, config, dummy_inputs)
 
+    print(input_tensors)
+
     patched_ops = config.patch_pytorch_ops()
     restore_ops = {}
     if patched_ops is not None:

src/exporters/coreml/features.py

Lines changed: 4 additions & 0 deletions
@@ -278,6 +278,10 @@ class FeaturesManager:
             "text-classification",
             coreml_config_cls="models.gpt_neox.GPTNeoXCoreMLConfig",
         ),
+        "layoutlmv3": supported_features_mapping(
+            "token-classification",
+            coreml_config_cls="models.layoutlmv3.LayoutLMv3CoreMLConfig",
+        ),
         "levit": supported_features_mapping(
             "feature-extraction", "image-classification", coreml_config_cls="models.levit.LevitCoreMLConfig"
         ),
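
With the model type registered above, the export should be reachable programmatically (not part of this commit). A rough sketch, assuming the package's usual export() entry point and config constructor; the checkpoint name is illustrative and would normally be a fine-tuned token-classification model:

from transformers import AutoModelForTokenClassification, LayoutLMv3Processor
from exporters.coreml import export
from exporters.coreml.models import LayoutLMv3CoreMLConfig

model_ckpt = "microsoft/layoutlmv3-base"  # illustrative checkpoint
preprocessor = LayoutLMv3Processor.from_pretrained(model_ckpt, apply_ocr=False)
model = AutoModelForTokenClassification.from_pretrained(model_ckpt, torchscript=True)

coreml_config = LayoutLMv3CoreMLConfig(model.config, task="token-classification")
mlmodel = export(preprocessor, model, coreml_config)
mlmodel.save("LayoutLMv3.mlpackage")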

src/exporters/coreml/models.py

Lines changed: 58 additions & 0 deletions
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 from collections import OrderedDict
+from typing import Callable, Mapping
 
 from .config import (
     CoreMLConfig,
@@ -348,6 +349,63 @@ def inputs(self) -> OrderedDict[str, InputDescription]:
         return input_descs
 
 
+class LayoutLMv3CoreMLConfig(CoreMLConfig):
+    modality = "multimodal"
+
+    @property
+    def outputs(self) -> OrderedDict[str, OutputDescription]:
+        output_descs = super().outputs
+        self._add_pooler_output(output_descs)
+        return output_descs
+
+    def patch_pytorch_ops(self):
+        def clip(context, node):
+            from coremltools.converters.mil import Builder as mb
+            from coremltools.converters.mil.frontend.torch.ops import _get_inputs
+            from coremltools.converters.mil.mil.var import Var
+            import numpy as _np
+            from coremltools.converters.mil.mil import types
+            from coremltools.converters.mil.mil.ops.defs._utils import promote_input_dtypes
+            inputs = _get_inputs(context, node, expected=[1,2,3])
+            x = inputs[0]
+            min_val = inputs[1] if (len(inputs) > 1 and inputs[1]) else mb.const(val=_np.finfo(_np.float32).min)
+            max_val = inputs[2] if (len(inputs) > 2 and inputs[2]) else mb.const(val=_np.finfo(_np.float32).max)
+
+            if isinstance(min_val, Var) and isinstance(max_val, Var) and min_val.val >= max_val.val:
+                # When min >= max, PyTorch sets all values to max.
+                context.add(mb.fill(shape=mb.shape(x=x), value=max_val.val, name=node.name))
+                return
+
+            is_input_int = types.is_int(x.dtype)
+            if not types.is_float(x.dtype):
+                # The `mb.clip` op requires parameters from type domain ['fp16', 'fp32'].
+                x = mb.cast(x=x, dtype="fp32")
+            x, min_val, max_val = promote_input_dtypes([x, min_val, max_val])
+            if is_input_int:
+                clip_res = mb.clip(x=x, alpha=min_val, beta=max_val)
+                context.add(mb.cast(x=clip_res, dtype="int32", name=node.name))
+            else:
+                context.add(mb.clip(x=x, alpha=min_val, beta=max_val, name=node.name))
+
+        def one_hot(context, node):
+            from coremltools.converters.mil import Builder as mb
+            from coremltools.converters.mil.frontend.torch.ops import _get_inputs
+            from coremltools.converters.mil.mil import types
+            from coremltools.converters.mil.mil.ops.defs._utils import promote_input_dtypes
+            inputs = _get_inputs(context, node, expected=[2,3])
+            indices = inputs[0]
+            num_classes = inputs[1]
+
+            if not types.is_int(indices.dtype):
+                indices = mb.cast(x=indices, dtype="int32")
+            if not types.is_int(num_classes.dtype):
+                num_classes = mb.cast(x=num_classes, dtype="int32")
+            indices, num_classes = promote_input_dtypes([indices, num_classes])
+            one_hot_res = mb.one_hot(indices=indices, one_hot_vector_size=num_classes)
+            context.add(one_hot_res, node.name)
+
+        return {"clip": clip, "one_hot": one_hot}
+
 class LevitCoreMLConfig(CoreMLConfig):
     modality = "vision"
 
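Context for the two patched ops above (not part of this commit): the traced LayoutLMv3 graph appears to feed integer tensors through torch.clip and torch.nn.functional.one_hot, which the stock coremltools translations do not handle as-is, so the patches cast to fp32/int32 around mb.clip and mb.one_hot. A small PyTorch sketch of the behavior the patches need to reproduce:

import torch
import torch.nn.functional as F

# torch.clip on an integer tensor returns an integer tensor; the patched "clip"
# casts to fp32 for mb.clip and back to int32 to preserve that.
bbox = torch.tensor([[-5, 250, 1200, 999]])
print(torch.clip(bbox, 0, 1000))  # tensor([[   0,  250, 1000,  999]])

# A missing bound behaves like -inf / +inf, which is why the patch substitutes
# float32 min/max constants when min or max is absent.
print(torch.clip(bbox.float(), min=100.0))

# F.one_hot takes integer indices plus a class count and returns an integer tensor;
# the patched "one_hot" casts both operands to int32 before calling mb.one_hot.
ids = torch.tensor([0, 2, 1])
print(F.one_hot(ids, num_classes=3))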