Skip to content

Commit

Permalink
chore: appease warnings
Browse files Browse the repository at this point in the history
  • Loading branch information
XuehaiPan committed Jun 22, 2024
1 parent 3b33775 commit a2444cb
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 4 deletions.
2 changes: 1 addition & 1 deletion optree/integration/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def _ravel_leaves(
leaves: list[ArrayLike],
) -> tuple[Array, Callable[[Array], list[ArrayLike]]]:
if not leaves:
return (jnp.array([]), _unravel_empty)
return (jnp.zeros(0), _unravel_empty)

Check warning on line 204 in optree/integration/jax.py

View check run for this annotation

Codecov / codecov/patch

optree/integration/jax.py#L204

Added line #L204 was not covered by tests

from_dtypes = tuple(dtypes.dtype(leaf) for leaf in leaves)
to_dtype = dtypes.result_type(*from_dtypes)
Expand Down
2 changes: 1 addition & 1 deletion optree/integration/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def _ravel_leaves(
leaves: list[np.ndarray],
) -> tuple[np.ndarray, Callable[[np.ndarray], list[np.ndarray]]]:
if not leaves:
return (np.array([]), _unravel_empty)
return (np.zeros(0), _unravel_empty)

Check warning on line 133 in optree/integration/numpy.py

View check run for this annotation

Codecov / codecov/patch

optree/integration/numpy.py#L133

Added line #L133 was not covered by tests

from_dtypes = tuple(np.result_type(leaf) for leaf in leaves)
to_dtype = np.result_type(*leaves)
Expand Down
2 changes: 1 addition & 1 deletion optree/integration/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def _ravel_leaves(
leaves: list[torch.Tensor],
) -> tuple[torch.Tensor, Callable[[torch.Tensor], list[torch.Tensor]]]:
if not leaves:
return (torch.tensor([]), _unravel_empty)
return (torch.zeros(0), _unravel_empty)

Check warning on line 130 in optree/integration/torch.py

View check run for this annotation

Codecov / codecov/patch

optree/integration/torch.py#L130

Added line #L130 was not covered by tests
if not all(torch.is_tensor(leaf) for leaf in leaves):
raise ValueError('All leaves must be tensors.')

Expand Down
3 changes: 2 additions & 1 deletion tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ def parametrize(**argvalues):
argvalues = list(itertools.product(*tuple(map(argvalues.get, arguments))))

ids = tuple(
'-'.join(f'{arg}({val!r})' for arg, val in zip(arguments, values)) for values in argvalues
'-'.join(f'{arg}({value!r})' for arg, value in zip(arguments, values))
for values in argvalues
)

return pytest.mark.parametrize(arguments, argvalues, ids=ids)
Expand Down

0 comments on commit a2444cb

Please sign in to comment.