@@ -461,7 +461,7 @@ def check(self, value):
" positive definite matrices".format(value)
eps = 1e-5
condition = np.all(
-np.abs(value - np.swapaxes(value, -1, -2)) < eps, axis=-1)
+np.abs(value - value.mT) < eps, axis=-1)
condition = condition & (np.linalg.eigvals(value) > 0)
_value = constraint_check()(condition, err_msg) * value
return _value
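
The check above relies on `value.mT` being the transpose of the last two axes. Below is a minimal plain-NumPy sketch of the same symmetry-plus-positive-eigenvalue test, using `swapaxes` (which `.mT` abbreviates); the function name is illustrative and not part of the patch.

```python
import numpy as np

def is_symmetric_positive_definite(value, eps=1e-5):
    # value.mT is shorthand for swapping the last two axes
    symmetric = np.all(np.abs(value - np.swapaxes(value, -1, -2)) < eps, axis=(-1, -2))
    # require all eigenvalues to be strictly positive
    positive = np.all(np.linalg.eigvals(value).real > 0, axis=-1)
    return symmetric & positive

a = np.array([[2.0, 0.5], [0.5, 1.0]])    # symmetric positive definite
b = np.array([[1.0, 2.0], [0.0, 1.0]])    # not symmetric
print(is_symmetric_positive_definite(a))  # True
print(is_symmetric_positive_definite(b))  # False
```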
@@ -72,8 +72,7 @@ def _precision_to_scale_tril(self, P):
L = flip(Cholesky(flip(P))).T
"""
L_flip_inv_T = np.linalg.cholesky(np.flip(P, (-1, -2)))
-L = np.linalg.inv(np.swapaxes(
-np.flip(L_flip_inv_T, (-1, -2)), -1, -2))
+L = np.linalg.inv(np.flip(L_flip_inv_T, (-1, -2)).mT)
return L

@cached_property
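
For reference, the flip/Cholesky/flip construction used by `_precision_to_scale_tril` can be verified in plain NumPy. This sketch uses `swapaxes` in place of `.mT` and checks that the result is lower-triangular and reproduces `inv(P)`.

```python
import numpy as np

def precision_to_scale_tril(P):
    # flip both matrix axes, take the Cholesky factor, flip back,
    # transpose, invert: yields lower-triangular L with L @ L.T == inv(P)
    L_flip = np.linalg.cholesky(np.flip(P, (-1, -2)))
    return np.linalg.inv(np.swapaxes(np.flip(L_flip, (-1, -2)), -1, -2))

rng = np.random.default_rng(0)
A = rng.normal(size=(3, 3))
P = A @ A.T + 3 * np.eye(3)                     # a well-conditioned precision matrix
L = precision_to_scale_tril(P)
assert np.allclose(np.triu(L, 1), 0)            # L is lower-triangular
assert np.allclose(L @ L.T, np.linalg.inv(P))   # L @ L.T reproduces the covariance
```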
@@ -87,8 +86,7 @@ def scale_tril(self):
def cov(self):
# pylint: disable=method-hidden
if 'scale_tril' in self.__dict__:
-scale_triu = np.swapaxes(self.scale_tril, -1, -2)
-return np.matmul(self.scale_tril, scale_triu)
+return np.matmul(self.scale_tril, self.scale_tril.mT)
return np.linalg.inv(self.precision)

@cached_property
@@ -97,8 +95,7 @@ def precision(self):
if 'cov' in self.__dict__:
return np.linalg.inv(self.cov)
scale_tril_inv = np.linalg.inv(self.scale_tril)
-scale_triu_inv = np.swapaxes(scale_tril_inv, -1, -2)
-return np.matmul(scale_triu_inv, scale_tril_inv)
+return np.matmul(scale_tril_inv.mT, scale_tril_inv)

@property
def mean(self):
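
The two cached properties above encode `cov = L @ L.mT` and `precision = inv(L).mT @ inv(L)`. A quick plain-NumPy consistency check (with `swapaxes` standing in for `.mT`):

```python
import numpy as np

rng = np.random.default_rng(1)
A = rng.normal(size=(2, 3, 3))
# a batch of lower-triangular scale factors
L = np.linalg.cholesky(A @ np.swapaxes(A, -1, -2) + np.eye(3))

cov = L @ np.swapaxes(L, -1, -2)                 # cov = L @ L.mT
L_inv = np.linalg.inv(L)
precision = np.swapaxes(L_inv, -1, -2) @ L_inv   # precision = inv(L).mT @ inv(L)
assert np.allclose(precision, np.linalg.inv(cov))
```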
16 changes: 10 additions & 6 deletions python/mxnet/numpy/linalg.py
@@ -82,9 +82,9 @@ def matrix_transpose(a):

Notes
-----
-`matrix_transpose` is an alias for `transpose`. It is a standard API in
+`matrix_transpose` is new in array API spec:
https://data-apis.org/array-api/latest/extensions/linear_algebra_functions.html#linalg-matrix-transpose-x
-instead of an official NumPy operator.
+instead of an official NumPy operator. Unlike transpose, it only transposes the last two axes.

Parameters
----------
@@ -103,14 +103,18 @@ def matrix_transpose(a):
>>> x
array([[0., 1.],
[2., 3.]])
->>> np.transpose(x)
+>>> np.linalg.matrix_transpose(x)
array([[0., 2.],
[1., 3.]])
>>> x = np.ones((1, 2, 3))
->>> np.transpose(x, (1, 0, 2)).shape
-(2, 1, 3)
+>>> np.linalg.matrix_transpose(x)
+array([[[1., 1.],
+[1., 1.],
+[1., 1.]]])
"""
-return _mx_nd_np.transpose(a, axes=None)
+if a.ndim < 2:
+raise ValueError("x must be at least 2-dimensional for matrix_transpose")
+return _mx_nd_np.swapaxes(a, -1, -2)


def trace(a, offset=0):
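
The new behaviour differs from `transpose` only in which axes move. A plain-NumPy sketch of the same semantics, for comparison (the helper name is illustrative):

```python
import numpy as np

def matrix_transpose_ref(a):
    # same semantics as the patched function: swap only the last two
    # axes and reject arrays with fewer than two dimensions
    if a.ndim < 2:
        raise ValueError("x must be at least 2-dimensional for matrix_transpose")
    return np.swapaxes(a, -1, -2)

x = np.ones((4, 2, 3))
print(matrix_transpose_ref(x).shape)   # (4, 3, 2) -- batch dimension untouched
print(np.transpose(x).shape)           # (3, 2, 4) -- full axis reversal, for contrast
```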
55 changes: 54 additions & 1 deletion python/mxnet/numpy/multiarray.py
@@ -81,7 +81,7 @@
'nan_to_num', 'isnan', 'isinf', 'isposinf', 'isneginf', 'isfinite', 'polyval', 'where', 'bincount',
'atleast_1d', 'atleast_2d', 'atleast_3d', 'fill_diagonal', 'squeeze',
'diagflat', 'repeat', 'prod', 'pad', 'cumsum', 'sum', 'rollaxis', 'diag', 'diagonal',
-'positive', 'logaddexp', 'floor_divide']
+'positive', 'logaddexp', 'floor_divide', 'permute_dims']

__all__ += fallback.__all__

@@ -1333,9 +1333,22 @@ def nonzero(self):
# pylint: disable= invalid-name, undefined-variable
def T(self):
"""Same as self.transpose(). This always returns a copy of self."""
if self.ndim != 2:
warnings.warn('x.T requires x to have 2 dimensions. '
'Use x.mT to transpose stacks of matrices and '
'permute_dims() to permute dimensions.')
return self.transpose()
# pylint: enable= invalid-name, undefined-variable

@property
# pylint: disable= invalid-name, undefined-variable
def mT(self):
"""Same as self.transpose(). This always returns a copy of self."""
if self.ndim < 2:
raise ValueError("x must be at least 2-dimensional for matrix_transpose")
return _mx_nd_np.swapaxes(self, -1, -2)
# pylint: enable= invalid-name, undefined-variable

def all(self, axis=None, out=None, keepdims=False):
return _mx_nd_np.all(self, axis=axis, out=out, keepdims=keepdims)
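
With this patch, `x.T` still reverses all axes (and now warns when `x.ndim != 2`), while the new `x.mT` swaps only the last two. A short usage sketch, assuming an MXNet build that includes this change:

```python
from mxnet import np, npx
npx.set_np()

x = np.ones((2, 3, 4))
print(x.mT.shape)   # (2, 4, 3): only the last two axes are swapped
print(x.T.shape)    # (4, 3, 2): full reversal; also emits the new warning since x.ndim != 2
```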

@@ -6447,6 +6460,46 @@ def transpose(a, axes=None):
return _mx_nd_np.transpose(a, axes)


@set_module('mxnet.numpy')
def permute_dims(a, axes=None):
"""
Permute the dimensions of an array.

Parameters
----------
a : ndarray
Input array.
axes : list of ints, optional
By default, reverse the dimensions,
otherwise permute the axes according to the values given.

Returns
-------
p : ndarray
a with its axes permuted.

Note
--------
`permute_dims` is an alias for `transpose`. It is a standard API in
https://data-apis.org/array-api/latest/API_specification/manipulation_functions.html#permute-dims-x-axes
instead of an official NumPy operator.

Examples
--------
>>> x = np.arange(4).reshape((2,2))
>>> x
array([[0., 1.],
[2., 3.]])
>>> np.permute_dims(x)
array([[0., 2.],
[1., 3.]])
>>> x = np.ones((1, 2, 3))
>>> np.permute_dims(x, (1, 0, 2)).shape
(2, 1, 3)
"""
return _mx_nd_np.transpose(a, axes)
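
A short usage sketch of `permute_dims`, assuming an MXNet build that includes this change; with `axes=None` it behaves exactly like `transpose`:

```python
from mxnet import np, npx
npx.set_np()

x = np.arange(24).reshape(2, 3, 4)
# with axes=None the dimensions are reversed, exactly like np.transpose
assert np.permute_dims(x).shape == (4, 3, 2)
# with explicit axes, output axis j takes its length from input axis axes[j]
assert np.permute_dims(x, (1, 0, 2)).shape == (3, 2, 4)
```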


@set_module('mxnet.numpy')
def repeat(a, repeats, axis=None):
"""
14 changes: 5 additions & 9 deletions tests/python/unittest/test_gluon_probability_v2.py
@@ -1632,8 +1632,7 @@ def _stable_inv(cov):
Force the precision matrix to be symmetric.
"""
precision = np.linalg.inv(cov)
-precision_t = np.swapaxes(precision, -1, -2)
-return (precision + precision_t) / 2
+return (precision + precision.mT) / 2

event_shapes = [3, 5]
loc_shapes = [(), (2,), (4, 2)]
@@ -1653,8 +1652,7 @@ def _stable_inv(cov):
loc.attach_grad()
_s.attach_grad()
# Full covariance matrix
-sigma = np.matmul(_s, np.swapaxes(
-_s, -1, -2)) + np.eye(event_shape)
+sigma = np.matmul(_s, _s.mT) + np.eye(event_shape)
cov_param = cov_func[cov_type](sigma)
net = TestMVN('sample', cov_type)
if hybridize:
@@ -1678,8 +1676,7 @@ def _stable_inv(cov):
loc.attach_grad()
_s.attach_grad()
# Full covariance matrix
-sigma = np.matmul(_s, np.swapaxes(
-_s, -1, -2)) + np.eye(event_shape)
+sigma = np.matmul(_s, _s.mT) + np.eye(event_shape)
cov_param = cov_func[cov_type](sigma)
net = TestMVN('log_prob', cov_type)
if hybridize:
@@ -1709,8 +1706,7 @@ def _stable_inv(cov):
loc.attach_grad()
_s.attach_grad()
# Full covariance matrix
-sigma = np.matmul(_s, np.swapaxes(
-_s, -1, -2)) + np.eye(event_shape)
+sigma = np.matmul(_s, _s.mT) + np.eye(event_shape)
cov_param = cov_func[cov_type](sigma)
net = TestMVN('entropy', cov_type)
if hybridize:
@@ -2093,7 +2089,7 @@ def param2(): return np.random.uniform(0.5, 1.5, shape)
for loc_shape, cov_shape, event_shape in itertools.product(loc_shapes, cov_shapes, event_shapes):
loc = np.random.randn(*(loc_shape + (event_shape,)))
_s = np.random.randn(*(cov_shape + (event_shape, event_shape)))
-sigma = np.matmul(_s, np.swapaxes(_s, -1, -2)) + np.eye(event_shape)
+sigma = np.matmul(_s, _s.mT) + np.eye(event_shape)
dist = mgp.MultivariateNormal(loc, cov=sigma)
desired_shape = (loc + sigma[..., 0]).shape[:-1]
_test_zero_kl(dist, desired_shape)
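
The tests build covariance matrices as `_s @ _s.mT + I`, which is symmetric positive definite for any real `_s`. A small plain-NumPy check of that property (with `swapaxes` in place of `.mT`):

```python
import numpy as np

rng = np.random.default_rng(2)
s = rng.normal(size=(4, 3, 3))
# s @ s.mT is positive semidefinite; adding the identity makes it positive definite
sigma = s @ np.swapaxes(s, -1, -2) + np.eye(3)
assert np.allclose(sigma, np.swapaxes(sigma, -1, -2))   # symmetric
assert (np.linalg.eigvalsh(sigma) > 0).all()            # strictly positive eigenvalues
```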
99 changes: 99 additions & 0 deletions tests/python/unittest/test_numpy_op.py
@@ -2580,6 +2580,67 @@ def test_np_transpose_error():
pytest.raises(MXNetError, lambda: dat.transpose((0, 1, 3)))


@use_np
@pytest.mark.parametrize('hybridize', [True, False])
@pytest.mark.parametrize('dtype', [onp.float32, onp.float16, onp.int32])
@pytest.mark.parametrize('data_shape,axes_workload', [
[(), [(), None]],
[(2,), [(0,), None]],
[(0, 2), [(0, 1), (1, 0)]],
[(5, 10), [(0, 1), (1, 0), None]],
[(8, 2, 3), [(2, 0, 1), (0, 2, 1), (0, 1, 2), (2, 1, 0), (-1, 1, 0), None]],
[(8, 2, 16), [(0, 2, 1), (2, 0, 1), (0, 1, 2), (2, 1, 0), (-1, -2, -3)]],
[(8, 3, 4, 8), [(0, 2, 3, 1), (1, 2, 3, 0), (0, 3, 2, 1)]],
[(8, 3, 2, 3, 8), [(0, 1, 3, 2, 4), (0, 1, 2, 3, 4), (4, 0, 1, 2, 3)]],
[(3, 4, 3, 4, 3, 2), [(0, 1, 3, 2, 4, 5), (2, 3, 4, 1, 0, 5), None]],
[(3, 4, 3, 4, 3, 2, 2), [(0, 1, 3, 2, 4, 5, 6),
(2, 3, 4, 1, 0, 5, 6), None]],
[(3, 4, 3, 4, 3, 2, 3, 2), [(0, 1, 3, 2, 4, 5, 7, 6),
(2, 3, 4, 1, 0, 5, 7, 6), None]],
])
@pytest.mark.parametrize('grad_req', ['write', 'add'])
def test_np_permute_dims(data_shape, axes_workload, hybridize, dtype, grad_req):
def np_permute_dims_grad(out_shape, dtype, axes=None):
ograd = onp.ones(out_shape, dtype=dtype)
if axes is None or axes == ():
return onp.transpose(ograd, axes)
np_axes = onp.array(list(axes))
permute_dims_axes = onp.zeros_like(np_axes)
permute_dims_axes[np_axes] = onp.arange(len(np_axes))
return onp.transpose(ograd, tuple(list(permute_dims_axes)))

class TestPermuteDims(HybridBlock):
def __init__(self, axes=None):
super(TestPermuteDims, self).__init__()
self.axes = axes

def forward(self, a):
return np.permute_dims(a, self.axes)

for axes in axes_workload:
test_trans = TestPermuteDims(axes)
if hybridize:
test_trans.hybridize()
x = np.random.normal(0, 1, data_shape).astype(dtype)
x = x.astype(dtype)
x.attach_grad(grad_req=grad_req)
if grad_req == 'add':
x.grad[()] = np.random.normal(0, 1, x.grad.shape).astype(x.grad.dtype)
x_grad_np = x.grad.asnumpy()
np_out = onp.transpose(x.asnumpy(), axes)
with mx.autograd.record():
mx_out = test_trans(x)
assert mx_out.shape == np_out.shape
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, use_broadcast=False)
mx_out.backward()
np_backward = np_permute_dims_grad(np_out.shape, dtype, axes)
if grad_req == 'add':
assert_almost_equal(x.grad.asnumpy(), np_backward + x_grad_np,
rtol=1e-3, atol=1e-5, use_broadcast=False)
else:
assert_almost_equal(x.grad.asnumpy(), np_backward, rtol=1e-3, atol=1e-5, use_broadcast=False)
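
`np_permute_dims_grad` above computes the inverse permutation: the gradient of `transpose(a, axes)` is `transpose(ograd, inv_axes)`, where `inv_axes` undoes `axes`. A standalone plain-NumPy illustration of that identity (the helper name is illustrative):

```python
import numpy as np

def inverse_permutation(axes):
    # inv[axes[j]] = j, so transposing by axes and then by inv restores the array
    inv = np.empty(len(axes), dtype=np.intp)
    inv[list(axes)] = np.arange(len(axes))
    return tuple(int(i) for i in inv)

axes = (2, 0, 1)
print(inverse_permutation(axes))        # (1, 2, 0)
a = np.random.normal(size=(2, 3, 4))
roundtrip = np.transpose(np.transpose(a, axes), inverse_permutation(axes))
assert (roundtrip == a).all()
```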


@use_np
def test_np_meshgrid():
nx, ny = (4, 5)
@@ -6820,6 +6881,44 @@ def check_matrix_rank(rank, a_np, tol, hermitian):
check_matrix_rank(rank, a.asnumpy(), tol.asnumpy(), hermitian=False)


@use_np
@pytest.mark.parametrize('shape', [
(),
(1,),
(0, 1, 2),
(0, 1, 2),
(0, 1, 2),
(4, 5, 6, 7),
(4, 5, 6, 7),
(4, 5, 6, 7),
])
def test_np_linalg_matrix_transpose(shape):
class TestMatTranspose(HybridBlock):
def __init__(self):
super(TestMatTranspose, self).__init__()

def forward(self, x):
return np.linalg.matrix_transpose(x)

data_np = onp.random.uniform(size=shape)
data_mx = np.array(data_np, dtype=data_np.dtype)
if data_mx.ndim < 2:
assertRaises(ValueError, np.linalg.matrix_transpose, data_mx)
return
ret_np = onp.swapaxes(data_np, -1, -2)
ret_mx = np.linalg.matrix_transpose(data_mx)
assert same(ret_mx.asnumpy(), ret_np)

net = TestMatTranspose()
for hybrid in [False, True]:
if hybrid:
net.hybridize()
ret_mx = net(data_mx)
assert same(ret_mx.asnumpy(), ret_np)

assert same(data_mx.mT.asnumpy(), ret_np)


@use_np
def test_np_linalg_pinv():
class TestPinv(HybridBlock):