Skip to content

Commit

Permalink
feat(types): add apply and apply_batch to parallel mixin
Browse files Browse the repository at this point in the history
  • Loading branch information
hanxiao committed Dec 2, 2021
1 parent db53ecc commit fe57705
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 13 deletions.
22 changes: 12 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,11 @@ Leveraging these three components, let's build an app that **find similar images
```python
from jina import DocumentArray, Document

docs = DocumentArray.from_files('img/*.jpg') # load all image filenames into a DocumentArray
for d in docs: # preprocess them
(d.load_uri_to_image_blob() # load
.set_image_blob_normalization() # normalize color
.set_image_blob_channel_axis(-1, 0)) # switch color axis
def preproc(d: Document):
return (d.load_uri_to_image_blob() # load
.set_image_blob_normalization() # normalize color
.set_image_blob_channel_axis(-1, 0)) # switch color axis
docs = DocumentArray.from_files('img/*.jpg').apply(preproc)

import torchvision
model = torchvision.models.resnet50(pretrained=True) # load ResNet50
Expand Down Expand Up @@ -117,17 +117,19 @@ With an extremely trivial refactoring and 10 extra lines of code, you can make t

1. Import what we need.
```python
from jina import DocumentArray, Executor, Flow, requests
from jina import Document, DocumentArray, Executor, Flow, requests
```
2. Copy-paste the preprocessing step and wrap it via `Executor`:
```python
class PreprocImg(Executor):
@requests
def foo(self, docs: DocumentArray, **kwargs):
for d in docs:
(d.load_uri_to_image_blob()
.set_image_blob_normalization()
.set_image_blob_channel_axis(-1, 0))
return docs.apply(preproc)

def preproc(d: Document):
return (d.load_uri_to_image_blob() # load
.set_image_blob_normalization() # normalize color
.set_image_blob_channel_axis(-1, 0)) # switch color axis
```
3. Copy-paste the embedding step and wrap it via `Executor`:

Expand Down
9 changes: 9 additions & 0 deletions docs/fundamentals/document/documentarray-api.md
Original file line number Diff line number Diff line change
Expand Up @@ -800,6 +800,13 @@ For processing batches in parallel, please refer to {meth}`~jina.types.arrays.mi

## Parallel processing

```{seealso}
- {meth}`~jina.types.arrays.mixins.parallel.ParallelMixin.map`: to process element by element in parallel, returning an iterator of elements;
- {meth}`~jina.types.arrays.mixins.parallel.ParallelMixin.map_batch`: to process batch by batch in parallel, returning an iterator of batches;
- {meth}`~jina.types.arrays.mixins.parallel.ParallelMixin.apply`: like `.map()`, but return a `DocumentArray`;
- {meth}`~jina.types.arrays.mixins.parallel.ParallelMixin.apply_batch`: like `.map_batch()`, but return a `DocumentArray`;
```

Working with a large `DocumentArray` element-wise can be time-consuming. The naive way is to run a for-loop and visit every `Document` one by one. Jina provides {meth}`~jina.types.arrays.mixins.parallel.ParallelMixin.map` to speed things up considerably. It is like Python's
built-in `map()` function, but it applies the function to every element of the `DocumentArray` in parallel. There is also {meth}`~jina.types.arrays.mixins.parallel.ParallelMixin.map_batch`, which works at the minibatch level.

Expand Down Expand Up @@ -873,6 +880,8 @@ list(da.map(func))
This follows the same convention as Python's built-in `map()`.
You can also use `.apply()` which always returns a `DocumentArray`.
````


Expand Down
85 changes: 82 additions & 3 deletions jina/types/arrays/mixins/parallel.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Callable, TYPE_CHECKING, Generator, Optional
from typing import Callable, TYPE_CHECKING, Generator, Optional, overload

if TYPE_CHECKING:
from ....helper import T
Expand All @@ -9,6 +9,42 @@
class ParallelMixin:
"""Helper functions that provide parallel map to :class:`DocumentArray` or :class:`DocumentArrayMemmap`."""

@overload
def apply(
    self: 'T',
    func: Callable[['Document'], 'Document'],
    backend: str = 'process',
    num_worker: Optional[int] = None,
) -> 'T':
    """Build a new :class:`DocumentArray` or :class:`DocumentArrayMemmap` by running ``func`` on every element.

    :param func: a function that takes a :class:`Document` as input and outputs a :class:`Document`.
    :param backend: the parallelization backend, either multi-`process` or multi-`thread`. In general, an
        IO-bound ``func`` is usually served well enough by `thread`, whereas a CPU-bound ``func`` may benefit
        from `process`. In practice, you should experiment to find the best value. However, if you wish to
        modify the elements in-place, regardless of IO/CPU-bound, you should always use the `thread` backend.

        .. warning::
            When using the `process` backend, do not expect ``func`` to modify elements in-place. The
            multiprocessing backend passes each variable via pickle and works on it in another process, so
            the passed object and the original object do **not** share the same memory.
    :param num_worker: the number of parallel workers. If not given, then the number of CPUs in the system will be used.
    """
    ...


def apply(self: 'T', *args, **kwargs) -> 'T':
    """Collect the results of :meth:`.map` into a fresh container of the same type as ``self``.

    # noqa: DAR102
    # noqa: DAR101
    # noqa: DAR201
    :return: a new :class:`DocumentArray` or :class:`DocumentArrayMemmap`
    """
    container_cls = type(self)
    collected = container_cls()
    # ``map`` receives the arguments untouched; its results are consumed here
    processed = self.map(*args, **kwargs)
    collected.extend(processed)
    return collected

def map(
self,
func: Callable[['Document'], 'T'],
Expand All @@ -18,7 +54,8 @@ def map(
"""Return an iterator that applies function to every **element** of iterable in parallel, yielding the results.
.. seealso::
To process on a batch of elements, please use :meth:`.map_batch`
- To process on a batch of elements, please use :meth:`.map_batch`;
- To return a :class:`DocumentArray`/:class:`DocumentArrayMemmap`, please use :meth:`.apply`.
:param func: a function that takes :class:`Document` as input and outputs anything. You can either modify elements
in-place (only with `thread` backend) or work later on return elements.
Expand All @@ -39,6 +76,46 @@ def map(
for x in p.imap(func, self):
yield x

@overload
def apply_batch(
    self: 'T',
    func: Callable[['DocumentArray'], 'DocumentArray'],
    batch_size: int,
    backend: str = 'process',
    num_worker: Optional[int] = None,
    shuffle: bool = False,
) -> 'T':
    """Return a new :class:`DocumentArray` or :class:`DocumentArrayMemmap` where each minibatch is processed by ``func``.

    :param func: a function that takes a minibatch (a :class:`DocumentArray`) as input and outputs the processed
        minibatch. The output of every call is appended to the returned array, so it must be an iterable
        of :class:`Document`.
    :param batch_size: Size of each generated batch (except the last one, which might be smaller).
    :param backend: if to use multi-`process` or multi-`thread` as the parallelization backend. In general, if your
        ``func`` is IO-bound then perhaps `thread` is good enough. If your ``func`` is CPU-bound then you may use `process`.
        In practice, you should try yourself to figure out the best value. However, if you wish to modify the elements
        in-place, regardless of IO/CPU-bound, you should always use the `thread` backend.

        .. warning::
            When using the `process` backend, you should not expect ``func`` to modify elements in-place. This is because
            the multiprocessing backend passes the variable via pickle and works in another process. The passed object
            and the original object do **not** share the same memory.
    :param num_worker: the number of parallel workers. If not given, then the number of CPUs in the system will be used.
    :param shuffle: If set, shuffle the Documents before dividing into minibatches.
    """
    ...


def apply_batch(self: 'T', *args, **kwargs) -> 'T':
    """Run :meth:`.map_batch` with the given arguments and gather every resulting batch into one new array.

    # noqa: DAR102
    # noqa: DAR101
    # noqa: DAR201
    :return: a new :class:`DocumentArray` or :class:`DocumentArrayMemmap`
    """
    new_da = type(self)()
    # each yielded item is one processed minibatch; append its elements in order
    for _b in self.map_batch(*args, **kwargs):
        new_da.extend(_b)
    return new_da

def map_batch(
self,
func: Callable[['DocumentArray'], 'T'],
Expand All @@ -48,9 +125,11 @@ def map_batch(
shuffle: bool = False,
):
"""Return an iterator that applies function to every **minibatch** of iterable in parallel, yielding the results.
Each element in the returned iterator is :class:`DocumentArray`.
.. seealso::
To process single element, please use :meth:`.map`
- To process single element, please use :meth:`.map`;
- To return :class:`DocumentArray` or :class:`DocumentArrayMemmap`, please use :meth:`.apply_batch`.
:param batch_size: Size of each generated batch (except the last one, which might be smaller, default: 32)
:param shuffle: If set, shuffle the Documents before dividing into minibatches.
Expand Down
7 changes: 7 additions & 0 deletions tests/unit/types/arrays/mixins/test_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ def test_parallel_map(pytestconfig, da_cls, backend, num_worker):
else:
assert da.blobs is None

da = da_cls.from_files(f'{pytestconfig.rootdir}/docs/**/*.png')[:10]
da_new = da.apply(foo)
assert da_new.blobs.shape == (len(da_new), 3, 222, 222)


@pytest.mark.parametrize('da_cls', [DocumentArray, DocumentArrayMemmap])
@pytest.mark.parametrize('backend', ['process', 'thread'])
Expand Down Expand Up @@ -64,3 +68,6 @@ def test_parallel_map(pytestconfig, da_cls, backend, num_worker, b_size):
assert da.blobs.shape == (len(da), 3, 222, 222)
else:
assert da.blobs is None

da_new = da.apply_batch(foo_batch, batch_size=b_size)
assert da_new.blobs.shape == (len(da_new), 3, 222, 222)

0 comments on commit fe57705

Please sign in to comment.