Merged
2 changes: 1 addition & 1 deletion sdks/python/apache_beam/ml/inference/api.py
@@ -55,7 +55,7 @@ class RunInference(beam.PTransform):

TODO(BEAM-14046): Add and link to help documentation
"""
def __init__(self, model_loader: base.ModelLoader):
def __init__(self, model_loader: base.ModelHandler):
self._model_loader = model_loader

def expand(self, pcoll: beam.PCollection) -> beam.PCollection:
49 changes: 20 additions & 29 deletions sdks/python/apache_beam/ml/inference/base.py
@@ -17,8 +17,8 @@

"""An extensible run inference transform.

Users of this module can extend the ModelLoader class for any ML framework. Then
pass their extended ModelLoader object into RunInference to create a
Users of this module can extend the ModelHandler class for any ML framework. Then
pass their extended ModelHandler object into RunInference to create a
RunInference Beam transform for that framework.

The transform will handle standard inference functionality like metric
@@ -64,8 +64,12 @@ def _to_microseconds(time_ns: int) -> int:
return int(time_ns / _NANOSECOND_TO_MICROSECOND)


class InferenceRunner(Generic[ExampleT, PredictionT, ModelT]):
"""Implements running inferences for a framework."""
class ModelHandler(Generic[ExampleT, PredictionT, ModelT]):
"""Has the ability to load and apply an ML model."""
def load_model(self) -> ModelT:
"""Loads and initializes a model for processing."""
raise NotImplementedError(type(self))

def run_inference(self, batch: List[ExampleT], model: ModelT,
**kwargs) -> Iterable[PredictionT]:
"""Runs inferences on a batch of examples and
@@ -80,18 +84,6 @@ def get_metrics_namespace(self) -> str:
"""Returns a namespace for metrics collected by RunInference transform."""
return 'RunInference'


class ModelLoader(Generic[ExampleT, PredictionT, ModelT]):
"""Has the ability to load an ML model."""
def load_model(self) -> ModelT:
"""Loads and initializes a model for processing."""
raise NotImplementedError(type(self))

def get_inference_runner(
self) -> InferenceRunner[ExampleT, PredictionT, ModelT]:
"""Returns an implementation of InferenceRunner for this model."""
raise NotImplementedError(type(self))

def get_resource_hints(self) -> dict:
"""Returns resource hints for the transform."""
return {}
@@ -105,30 +97,30 @@ class RunInference(beam.PTransform[beam.PCollection[ExampleT],
beam.PCollection[PredictionT]]):
"""An extensible transform for running inferences.
Args:
model_loader: An implementation of ModelLoader.
model_handler: An implementation of ModelHandler.
clock: A clock implementing get_current_time_in_microseconds.
"""
def __init__(
self,
model_loader: ModelLoader[ExampleT, PredictionT, Any],
model_handler: ModelHandler[ExampleT, PredictionT, Any],
clock=time,
**kwargs):
self._model_loader = model_loader
self._model_handler = model_handler
self._kwargs = kwargs
self._clock = clock

# TODO(BEAM-14208): Add batch_size back off in the case there
# are functional reasons large batch sizes cannot be handled.
def expand(
self, pcoll: beam.PCollection[ExampleT]) -> beam.PCollection[PredictionT]:
resource_hints = self._model_loader.get_resource_hints()
resource_hints = self._model_handler.get_resource_hints()
return (
pcoll
# TODO(BEAM-14044): Hook into the batching DoFn APIs.
| beam.BatchElements(**self._model_loader.batch_elements_kwargs())
| beam.BatchElements(**self._model_handler.batch_elements_kwargs())
| (
beam.ParDo(
_RunInferenceDoFn(self._model_loader, self._clock),
_RunInferenceDoFn(self._model_handler, self._clock),
**self._kwargs).with_resource_hints(**resource_hints)))


@@ -184,8 +176,8 @@ def update(
class _RunInferenceDoFn(beam.DoFn, Generic[ExampleT, PredictionT]):
"""A DoFn implementation generic to frameworks."""
def __init__(
self, model_loader: ModelLoader[ExampleT, PredictionT, Any], clock):
self._model_loader = model_loader
self, model_handler: ModelHandler[ExampleT, PredictionT, Any], clock):
self._model_handler = model_handler
self._shared_model_handle = shared.Shared()
self._clock = clock
self._model = None
@@ -195,7 +187,7 @@ def load():
"""Function for constructing shared LoadedModel."""
memory_before = _get_current_process_memory_in_bytes()
start_time = _to_milliseconds(self._clock.time_ns())
model = self._model_loader.load_model()
model = self._model_handler.load_model()
end_time = _to_milliseconds(self._clock.time_ns())
memory_after = _get_current_process_memory_in_bytes()
load_model_latency_ms = end_time - start_time
@@ -208,9 +200,8 @@ def load():
return self._shared_model_handle.acquire(load)

def setup(self):
self._inference_runner = self._model_loader.get_inference_runner()
self._metrics_collector = _MetricsCollector(
self._inference_runner.get_metrics_namespace())
self._model_handler.get_metrics_namespace())
self._model = self._load_model()

def process(self, batch, **kwargs):
@@ -225,13 +216,13 @@ def process(self, batch, **kwargs):
keys = None

start_time = _to_microseconds(self._clock.time_ns())
result_generator = self._inference_runner.run_inference(
result_generator = self._model_handler.run_inference(
examples, self._model, **kwargs)
predictions = list(result_generator)

end_time = _to_microseconds(self._clock.time_ns())
inference_latency = end_time - start_time
num_bytes = self._inference_runner.get_num_bytes(examples)
num_bytes = self._model_handler.get_num_bytes(examples)
num_elements = len(batch)
self._metrics_collector.update(num_elements, num_bytes, inference_latency)

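For readers skimming the diff, here is a minimal sketch of how the renamed API is meant to be used after this change: a custom handler subclasses ModelHandler, implements load_model and run_inference, and is passed directly to RunInference. DoublingModel, DoublingModelHandler, and the doubling logic are illustrative assumptions, not part of this PR.

# Illustrative sketch only; DoublingModel and DoublingModelHandler are
# hypothetical names, not part of this PR.
from typing import Iterable, List

import apache_beam as beam
from apache_beam.ml.inference import base


class DoublingModel:
  def predict(self, example: int) -> int:
    return example * 2


class DoublingModelHandler(base.ModelHandler[int, int, DoublingModel]):
  def load_model(self) -> DoublingModel:
    # Called once per shared model handle; expensive setup belongs here.
    return DoublingModel()

  def run_inference(self, batch: List[int], model: DoublingModel,
                    **kwargs) -> Iterable[int]:
    # Receives a whole batch produced by beam.BatchElements.
    for example in batch:
      yield model.predict(example)

  def get_metrics_namespace(self) -> str:
    # Optional override; the default namespace is 'RunInference'.
    return 'DoublingInference'


with beam.Pipeline() as pipeline:
  _ = (
      pipeline
      | beam.Create([1, 2, 3])
      | base.RunInference(DoublingModelHandler())
      | beam.Map(print))

The tests in base_test.py below exercise the same pattern with both plain and keyed examples.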
46 changes: 14 additions & 32 deletions sdks/python/apache_beam/ml/inference/base_test.py
@@ -35,10 +35,15 @@ def predict(self, example: int) -> int:
return example + 1


class FakeInferenceRunner(base.InferenceRunner[int, int, FakeModel]):
class FakeModelHandler(base.ModelHandler[int, int, FakeModel]):
def __init__(self, clock=None):
self._fake_clock = clock

def load_model(self):
if self._fake_clock:
self._fake_clock.current_time_ns += 500_000_000 # 500ms
return FakeModel()

def run_inference(self, batch: List[int], model: FakeModel,
**kwargs) -> Iterable[int]:
if self._fake_clock:
@@ -47,19 +52,6 @@ def run_inference(self, batch: List[int], model: FakeModel,
yield model.predict(example)


class FakeModelLoader(base.ModelLoader[int, int, FakeModel]):
def __init__(self, clock=None):
self._fake_clock = clock

def load_model(self):
if self._fake_clock:
self._fake_clock.current_time_ns += 500_000_000 # 500ms
return FakeModel()

def get_inference_runner(self):
return FakeInferenceRunner(self._fake_clock)


class FakeClock:
def __init__(self):
# Start at 10 seconds.
@@ -74,40 +66,30 @@ def process(self, prediction_result):
yield prediction_result.inference


class FakeInferenceRunnerNeedsBigBatch(FakeInferenceRunner):
class FakeModelHandlerNeedsBigBatch(FakeModelHandler):
def run_inference(self, batch, unused_model):
if len(batch) < 100:
raise ValueError('Unexpectedly small batch')
return batch


class FakeLoaderWithBatchArgForwarding(FakeModelLoader):
def get_inference_runner(self):
return FakeInferenceRunnerNeedsBigBatch()

def batch_elements_kwargs(self):
return {'min_batch_size': 9999}


class FakeInferenceRunnerKwargs(FakeInferenceRunner):
class FakeModelHandlerWithKwargs(FakeModelHandler):
def run_inference(self, batch, unused_model, **kwargs):
if not kwargs.get('key'):
raise ValueError('key should be True')
return batch


class FakeLoaderWithKwargs(FakeModelLoader):
def get_inference_runner(self):
return FakeInferenceRunnerKwargs()


class RunInferenceBaseTest(unittest.TestCase):
def test_run_inference_impl_simple_examples(self):
with TestPipeline() as pipeline:
examples = [1, 5, 3, 10]
expected = [example + 1 for example in examples]
pcoll = pipeline | 'start' >> beam.Create(examples)
actual = pcoll | base.RunInference(FakeModelLoader())
actual = pcoll | base.RunInference(FakeModelHandler())
assert_that(actual, equal_to(expected), label='assert:inferences')

def test_run_inference_impl_with_keyed_examples(self):
@@ -116,22 +98,22 @@ def test_run_inference_impl_with_keyed_examples(self):
keyed_examples = [(i, example) for i, example in enumerate(examples)]
expected = [(i, example + 1) for i, example in enumerate(examples)]
pcoll = pipeline | 'start' >> beam.Create(keyed_examples)
actual = pcoll | base.RunInference(FakeModelLoader())
actual = pcoll | base.RunInference(FakeModelHandler())
assert_that(actual, equal_to(expected), label='assert:inferences')

def test_run_inference_impl_kwargs(self):
with TestPipeline() as pipeline:
examples = [1, 5, 3, 10]
pcoll = pipeline | 'start' >> beam.Create(examples)
kwargs = {'key': True}
actual = pcoll | base.RunInference(FakeLoaderWithKwargs(), **kwargs)
actual = pcoll | base.RunInference(FakeModelHandlerWithKwargs(), **kwargs)
assert_that(actual, equal_to(examples), label='assert:inferences')

def test_counted_metrics(self):
pipeline = TestPipeline()
examples = [1, 5, 3, 10]
pcoll = pipeline | 'start' >> beam.Create(examples)
_ = pcoll | base.RunInference(FakeModelLoader())
_ = pcoll | base.RunInference(FakeModelHandler())
run_result = pipeline.run()
run_result.wait_until_finish()

@@ -161,7 +143,7 @@ def test_timing_metrics(self):
pcoll = pipeline | 'start' >> beam.Create(examples)
fake_clock = FakeClock()
_ = pcoll | base.RunInference(
FakeModelLoader(clock=fake_clock), clock=fake_clock)
FakeModelHandler(clock=fake_clock), clock=fake_clock)
res = pipeline.run()
res.wait_until_finish()

@@ -183,7 +165,7 @@ def test_forwards_batch_args(self):
examples = list(range(100))
with TestPipeline() as pipeline:
pcoll = pipeline | 'start' >> beam.Create(examples)
actual = pcoll | base.RunInference(FakeLoaderWithBatchArgForwarding())
actual = pcoll | base.RunInference(FakeModelHandlerNeedsBigBatch())
assert_that(actual, equal_to(examples), label='assert:inferences')


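The batch-forwarding test above relies on two optional hooks that this PR carries over from ModelLoader onto ModelHandler: batch_elements_kwargs, whose dict is splatted into beam.BatchElements, and get_resource_hints, whose dict is splatted into with_resource_hints on the inference ParDo. A hedged sketch of a handler that tunes both follows; the handler name, batch sizes, and min_ram value are arbitrary assumptions, not values from this PR.

# Hedged sketch: TunedModelHandler, the batch sizes, and the min_ram value
# are illustrative assumptions layered on the hooks visible in this PR.
from typing import Iterable, List

import apache_beam as beam
from apache_beam.ml.inference import base


class WordCountModel:
  def predict(self, example: str) -> int:
    return len(example.split())


class TunedModelHandler(base.ModelHandler[str, int, WordCountModel]):
  def load_model(self) -> WordCountModel:
    return WordCountModel()

  def run_inference(self, batch: List[str], model: WordCountModel,
                    **kwargs) -> Iterable[int]:
    return [model.predict(example) for example in batch]

  def batch_elements_kwargs(self) -> dict:
    # Splatted into beam.BatchElements by RunInference.expand.
    return {'min_batch_size': 8, 'max_batch_size': 64}

  def get_resource_hints(self) -> dict:
    # Splatted into with_resource_hints on the inference ParDo.
    return {'min_ram': '4GB'}


with beam.Pipeline() as pipeline:
  _ = (
      pipeline
      | beam.Create(['a b c', 'd e'])
      | base.RunInference(TunedModelHandler())
      | beam.Map(print))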
101 changes: 42 additions & 59 deletions sdks/python/apache_beam/ml/inference/pytorch_inference.py
@@ -28,20 +28,51 @@
import torch
from apache_beam.io.filesystems import FileSystems
from apache_beam.ml.inference.api import PredictionResult
from apache_beam.ml.inference.base import InferenceRunner
from apache_beam.ml.inference.base import ModelLoader
from apache_beam.ml.inference.base import ModelHandler


class PytorchInferenceRunner(InferenceRunner[torch.Tensor,
PredictionResult,
torch.nn.Module]):
"""
This class runs Pytorch inferences with the run_inference method. It also has
other methods to get the bytes of a batch of Tensors as well as the namespace
for Pytorch models.
class PytorchModelHandler(ModelHandler[torch.Tensor,
PredictionResult,
torch.nn.Module]):
""" Implementation of the ModelHandler interface for PyTorch.

NOTE: This API and its implementation are under development and
do not provide backward compatibility guarantees.
"""
def __init__(self, device: torch.device):
self._device = device
def __init__(
self,
state_dict_path: str,
model_class: Callable[..., torch.nn.Module],
model_params: Dict[str, Any],
device: str = 'CPU'):
"""
Initializes a PytorchModelHandler
:param state_dict_path: path to the saved dictionary of the model state.
:param model_class: class of the Pytorch model that defines the model
structure.
:param device: the device on which you wish to run the model. If
``device = GPU`` then a GPU device will be used if it is available.
Otherwise, it will be CPU.

See https://pytorch.org/tutorials/beginner/saving_loading_models.html
for details
"""
self._state_dict_path = state_dict_path
if device == 'GPU' and torch.cuda.is_available():
self._device = torch.device('cuda')
else:
self._device = torch.device('cpu')
self._model_class = model_class
self._model_params = model_params

def load_model(self) -> torch.nn.Module:
"""Loads and initializes a Pytorch model for processing."""
model = self._model_class(**self._model_params)
model.to(self._device)
file = FileSystems.open(self._state_dict_path, 'rb')
model.load_state_dict(torch.load(file))
model.eval()
return model

def _convert_to_device(self, examples: torch.Tensor) -> torch.Tensor:
"""
@@ -103,51 +134,3 @@ def get_metrics_namespace(self) -> str:
Returns a namespace for metrics collected by the RunInference transform.
"""
return 'RunInferencePytorch'


class PytorchModelLoader(ModelLoader[torch.Tensor,
PredictionResult,
torch.nn.Module]):
""" Implementation of the ModelLoader interface for PyTorch.

NOTE: This API and its implementation are under development and
do not provide backward compatibility guarantees.
"""
def __init__(
self,
state_dict_path: str,
model_class: Callable[..., torch.nn.Module],
model_params: Dict[str, Any],
device: str = 'CPU'):
"""
Initializes a PytorchModelLoader
:param state_dict_path: path to the saved dictionary of the model state.
:param model_class: class of the Pytorch model that defines the model
structure.
:param device: the device on which you wish to run the model. If
``device = GPU`` then a GPU device will be used if it is available.
Otherwise, it will be CPU.

See https://pytorch.org/tutorials/beginner/saving_loading_models.html
for details
"""
self._state_dict_path = state_dict_path
if device == 'GPU' and torch.cuda.is_available():
self._device = torch.device('cuda')
else:
self._device = torch.device('cpu')
self._model_class = model_class
self._model_params = model_params

def load_model(self) -> torch.nn.Module:
"""Loads and initializes a Pytorch model for processing."""
model = self._model_class(**self._model_params)
model.to(self._device)
file = FileSystems.open(self._state_dict_path, 'rb')
model.load_state_dict(torch.load(file))
model.eval()
return model

def get_inference_runner(self) -> PytorchInferenceRunner:
"""Returns a Pytorch implementation of InferenceRunner."""
return PytorchInferenceRunner(device=self._device)
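A hedged end-to-end sketch of the consolidated PytorchModelHandler follows. It assumes a trivial linear model whose state dict was saved beforehand; the LinearRegression module, the state-dict path, and the input tensors are illustrative assumptions, not taken from this PR.

# Illustrative sketch; LinearRegression, the state-dict path, and the input
# tensors are assumptions, not part of this PR.
import torch

import apache_beam as beam
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.pytorch_inference import PytorchModelHandler


class LinearRegression(torch.nn.Module):
  def __init__(self, input_dim: int = 1, output_dim: int = 1):
    super().__init__()
    self.linear = torch.nn.Linear(input_dim, output_dim)

  def forward(self, x):
    return self.linear(x)


# Save a state dict once so load_model has something to read.
torch.save(LinearRegression().state_dict(), '/tmp/linear_regression.pt')

handler = PytorchModelHandler(
    state_dict_path='/tmp/linear_regression.pt',
    model_class=LinearRegression,
    model_params={'input_dim': 1, 'output_dim': 1},
    device='CPU')

with beam.Pipeline() as pipeline:
  _ = (
      pipeline
      | beam.Create([torch.Tensor([1.0]), torch.Tensor([2.0])])
      | RunInference(handler)
      | beam.Map(print))

Note that passing device='GPU' only switches to CUDA when torch.cuda.is_available() returns True, as the constructor above shows; otherwise the handler falls back to CPU.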