Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions sdks/python/apache_beam/ml/transforms/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -797,3 +797,42 @@ def get_metrics_namespace(self) -> str:
return (
self._underlying.get_metrics_namespace() or
'BeamML_ImageEmbeddingHandler')


class _MultiModalEmbeddingHandler(_EmbeddingHandler):
"""
A ModelHandler intended to be work on
list[dict[str, TypedDict(Image, Video, str)]] inputs.

The inputs to the model handler are expected to be a list of dicts.

For example, if the original mode is used with RunInference to take a
PCollection[E] to a PCollection[P], this ModelHandler would take a
PCollection[dict[str, E]] to a PCollection[dict[str, P]].

_MultiModalEmbeddingHandler will accept an EmbeddingsManager instance, which
contains the details of the model to be loaded and the inference_fn to be
used. The purpose of _MultiMOdalEmbeddingHandler is to generate embeddings
for image, video, and text inputs using the EmbeddingsManager instance.

If the input is not an Image representation column, a RuntimeError will be
raised.

This is an internal class and offers no backwards compatibility guarantees.

Args:
embeddings_manager: An EmbeddingsManager instance.
"""
def _validate_column_data(self, batch):
# Don't want to require framework-specific imports
# here, so just catch columns of primatives for now.
if isinstance(batch[0], (int, str, float, bool)):
raise TypeError(
'Embeddings can only be generated on '
' dict[str, dataclass] types. '
f'Got dict[str, {type(batch[0])}] instead.')

def get_metrics_namespace(self) -> str:
return (
self._underlying.get_metrics_namespace() or
'BeamML_MultiModalEmbeddingHandler')
117 changes: 117 additions & 0 deletions sdks/python/apache_beam/ml/transforms/base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import time
import unittest
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any
from typing import Optional

Expand Down Expand Up @@ -629,6 +630,122 @@ def test_handler_with_dict_inputs(self):
)


@dataclass
class FakeMultiModalInput:
image: Optional[PIL_Image] = None
video: Optional[Any] = None
text: Optional[str] = None


class FakeMultiModalModel:
def __call__(self,
example: list[FakeMultiModalInput]) -> list[FakeMultiModalInput]:
for i in range(len(example)):
if not isinstance(example[i], FakeMultiModalInput):
raise TypeError('Input must be a MultiModalInput')
return example


class FakeMultiModalModelHandler(ModelHandler):
def run_inference(
self,
batch: Sequence[FakeMultiModalInput],
model: Any,
inference_args: Optional[dict[str, Any]] = None):
return model(batch)

def load_model(self):
return FakeMultiModalModel()


class FakeMultiModalEmbeddingsManager(base.EmbeddingsManager):
def __init__(self, columns, **kwargs):
super().__init__(columns=columns, **kwargs)

def get_model_handler(self) -> ModelHandler:
FakeModelHandler.__repr__ = lambda x: 'FakeMultiModalEmbeddingsManager' # type: ignore[method-assign]
return FakeMultiModalModelHandler()

def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform:
return (RunInference(model_handler=base._MultiModalEmbeddingHandler(self)))

def __repr__(self):
return 'FakeMultiModalEmbeddingsManager'


class TestMultiModalEmbeddingHandler(unittest.TestCase):
def setUp(self) -> None:
self.embedding_config = FakeMultiModalEmbeddingsManager(columns=['x'])
self.artifact_location = tempfile.mkdtemp()

def tearDown(self) -> None:
shutil.rmtree(self.artifact_location)

@unittest.skipIf(PIL is None, 'PIL module is not installed.')
def test_handler_with_non_dict_datatype(self):
image_handler = base._MultiModalEmbeddingHandler(
embeddings_manager=self.embedding_config)
data = [
('x', 'hi there'),
('x', 'not an image'),
('x', 'image_path.jpg'),
]
with self.assertRaises(TypeError):
image_handler.run_inference(data, None, None)

@unittest.skipIf(PIL is None, 'PIL module is not installed.')
def test_handler_with_incorrect_datatype(self):
image_handler = base._MultiModalEmbeddingHandler(
embeddings_manager=self.embedding_config)
data = [
{
'x': 'hi there'
},
{
'x': 'not an image'
},
{
'x': 'image_path.jpg'
},
]
with self.assertRaises(TypeError):
image_handler.run_inference(data, None, None)

@unittest.skipIf(PIL is None, 'PIL module is not installed.')
def test_handler_with_dict_inputs(self):
input_one = FakeMultiModalInput(
image=PIL.Image.new(mode='RGB', size=(1, 1)), text="test image one")
input_two = FakeMultiModalInput(
image=PIL.Image.new(mode='RGB', size=(1, 1)), text="test image two")
input_three = FakeMultiModalInput(
image=PIL.Image.new(mode='RGB', size=(1, 1)),
video=bytes.fromhex('2Ef0 F1f2 '),
text="test image three with video")
data = [
{
'x': input_one
},
{
'x': input_two
},
{
'x': input_three
},
]
expected_data = [{key: value for key, value in d.items()} for d in data]
with beam.Pipeline() as p:
result = (
p
| beam.Create(data)
| base.MLTransform(
write_artifact_location=self.artifact_location).with_transform(
self.embedding_config))
assert_that(
result,
equal_to(expected_data),
)


class TestUtilFunctions(unittest.TestCase):
def test_dict_input_fn_normal(self):
input_list = [{'a': 1, 'b': 2}, {'a': 3, 'b': 4}]
Expand Down
Loading
Loading