Skip to content

Commit

Permalink
docs: add doctest entries
Browse files Browse the repository at this point in the history
  • Loading branch information
XuehaiPan committed Nov 5, 2023
1 parent f593ad8 commit 86f4f7c
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 1 deletion.
2 changes: 1 addition & 1 deletion docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ OpTree: Optimized PyTree Utilities
api.rst

.. toctree::
:maxdepth: 1
:maxdepth: 2

integration.rst

Expand Down
20 changes: 20 additions & 0 deletions optree/integration/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,26 @@ def tree_ravel(
) -> tuple[Array, Callable[[Array], ArrayTree]]:
r"""Ravel (flatten) a pytree of arrays down to a 1D array.
>>> tree = {
... 'layer1': {
... 'weight': jnp.arange(6, dtype=jnp.float32).reshape((2, 3)),
... 'bias': jnp.arange(2, dtype=jnp.float32).reshape((2,)),
... },
... 'layer2': {
... 'weight': jnp.arange(2, dtype=jnp.float32).reshape((1, 2)),
... 'bias': jnp.arange(1, dtype=jnp.float32).reshape((1,))
... },
... }
>>> flat, unravel_func = tree_ravel(tree)
>>> flat
Array([0., 1., 0., 1., 2., 3., 4., 5., 0., 0., 1.], dtype=float32)
>>> unravel_func(flat)
{'layer1': {'weight': Array([[0., 1., 2.],
[3., 4., 5.]], dtype=float32),
'bias': Array([0., 1.], dtype=float32)},
'layer2': {'weight': Array([[0., 1.]], dtype=float32),
'bias': Array([0.], dtype=float32)}}
Args:
tree (pytree): a pytree of arrays and scalars to ravel.
is_leaf (callable, optional): An optionally specified function that will be called at each
Expand Down
20 changes: 20 additions & 0 deletions optree/integration/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,26 @@ def tree_ravel(
) -> tuple[np.ndarray, Callable[[np.ndarray], ArrayTree]]:
r"""Ravel (flatten) a pytree of arrays down to a 1D array.
>>> tree = {
... 'layer1': {
... 'weight': np.arange(6, dtype=np.float32).reshape((2, 3)),
... 'bias': np.arange(2, dtype=np.float32).reshape((2,)),
... },
... 'layer2': {
... 'weight': np.arange(2, dtype=np.float32).reshape((1, 2)),
... 'bias': np.arange(1, dtype=np.float32).reshape((1,))
... },
... }
>>> flat, unravel_func = tree_ravel(tree)
>>> flat
array([0., 1., 0., 1., 2., 3., 4., 5., 0., 0., 1.], dtype=float32)
>>> unravel_func(flat)
{'layer1': {'weight': array([[0., 1., 2.],
[3., 4., 5.]], dtype=float32),
'bias': array([0., 1.], dtype=float32)},
'layer2': {'weight': array([[0., 1.]], dtype=float32),
'bias': array([0.], dtype=float32)}}
Args:
tree (pytree): a pytree of arrays and scalars to ravel.
is_leaf (callable, optional): An optionally specified function that will be called at each
Expand Down
20 changes: 20 additions & 0 deletions optree/integration/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,26 @@ def tree_ravel(
) -> tuple[torch.Tensor, Callable[[torch.Tensor], TensorTree]]:
r"""Ravel (flatten) a pytree of tensors down to a 1D tensor.
>>> tree = {
... 'layer1': {
... 'weight': torch.arange(6, dtype=torch.float64).reshape((2, 3)),
... 'bias': torch.arange(2, dtype=torch.float64).reshape((2,)),
... },
... 'layer2': {
... 'weight': torch.arange(2, dtype=torch.float64).reshape((1, 2)),
... 'bias': torch.arange(1, dtype=torch.float64).reshape((1,))
... },
... }
>>> flat, unravel_func = tree_ravel(tree)
>>> flat
tensor([0., 1., 0., 1., 2., 3., 4., 5., 0., 0., 1.], dtype=torch.float64)
>>> unravel_func(flat)
{'layer1': {'weight': tensor([[0., 1., 2.],
[3., 4., 5.]], dtype=torch.float64),
'bias': tensor([0., 1.], dtype=torch.float64)},
'layer2': {'weight': tensor([[0., 1.]], dtype=torch.float64),
'bias': tensor([0.], dtype=torch.float64)}}
Args:
tree (pytree): a pytree of tensors to ravel.
is_leaf (callable, optional): An optionally specified function that will be called at each
Expand Down

0 comments on commit 86f4f7c

Please sign in to comment.