Skip to content
This repository was archived by the owner on Nov 7, 2024. It is now read-only.

adding random uniform initialization #412

Merged
merged 4 commits into from
Dec 12, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions tensornetwork/backends/base_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,27 @@ def randn(self,
raise NotImplementedError("Backend '{}' has not implemented randn.".format(
self.name))

def random_uniform(self,
shape: Tuple[int, ...],
boundaries: Optional[Tuple[float, float]] = (0.0, 1.0),
dtype: Optional[Type[np.number]] = None,
seed: Optional[int] = None) -> Tensor:
"""Return a random uniform matrix of dimension `dim`.
Depending on specific backends, `dim` has to be either an int
(numpy, torch, tensorflow) or a `ShapeType` object
(for block-sparse backends). Block-sparse
behavior is currently not supported
Args:
shape (int): The dimension of the returned matrix.
boundaries (tuple): The boundaries of the uniform distribution.
dtype: The dtype of the returned matrix.
seed: The seed for the random number generator
Returns:
Tensor : random uniform initialized tensor.
"""
raise NotImplementedError(("Backend '{}' has not implemented "
"random_uniform.").format(self.name))

def conj(self, tensor: Tensor) -> Tensor:
"""
Return the complex conjugate of `tensor`
Expand Down
36 changes: 36 additions & 0 deletions tensornetwork/backends/jax/jax_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,42 @@ def cmplx_randn(complex_dtype, real_dtype):

return self.jax.random.normal(key, shape).astype(dtype)

def random_uniform(self,
shape: Tuple[int, ...],
boundaries: Optional[Tuple[float, float]] = (0.0, 1.0),
dtype: Optional[np.dtype] = None,
seed: Optional[int] = None) -> Tensor:
if not seed:
seed = np.random.randint(0, 2**63)
key = self.jax.random.PRNGKey(seed)

dtype = dtype if dtype is not None else np.dtype(np.float64)

def cmplx_random_uniform(complex_dtype, real_dtype):
real_dtype = np.dtype(real_dtype)
complex_dtype = np.dtype(complex_dtype)

key_2 = self.jax.random.PRNGKey(seed + 1)

real_part = self.jax.random.uniform(key, shape, dtype=real_dtype,
minval=boundaries[0],
maxval=boundaries[1])
complex_part = self.jax.random.uniform(key_2, shape, dtype=real_dtype,
minval=boundaries[0],
maxval=boundaries[1])
unit = (
np.complex64(1j)
if complex_dtype == np.dtype(np.complex64) else np.complex128(1j))
return real_part + unit * complex_part

if np.dtype(dtype) is np.dtype(self.np.complex128):
return cmplx_random_uniform(dtype, self.np.float64)
if np.dtype(dtype) is np.dtype(self.np.complex64):
return cmplx_random_uniform(dtype, self.np.float32)

return self.jax.random.uniform(key, shape, minval=boundaries[0],
maxval=boundaries[1]).astype(dtype)

def eigs(self,
A: Callable,
initial_state: Optional[Tensor] = None,
Expand Down
49 changes: 49 additions & 0 deletions tensornetwork/backends/jax/jax_backend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,13 +161,27 @@ def test_randn(dtype):
assert a.shape == (4, 4)


@pytest.mark.parametrize("dtype", np_randn_dtypes)
def test_random_uniform(dtype):
backend = jax_backend.JaxBackend()
a = backend.random_uniform((4, 4), dtype=dtype)
assert a.shape == (4, 4)


@pytest.mark.parametrize("dtype", [np.complex64, np.complex128])
def test_randn_non_zero_imag(dtype):
backend = jax_backend.JaxBackend()
a = backend.randn((4, 4), dtype=dtype)
assert np.linalg.norm(np.imag(a)) != 0.0


@pytest.mark.parametrize("dtype", [np.complex64, np.complex128])
def test_random_uniform_non_zero_imag(dtype):
backend = jax_backend.JaxBackend()
a = backend.random_uniform((4, 4), dtype=dtype)
assert np.linalg.norm(np.imag(a)) != 0.0


@pytest.mark.parametrize("dtype", np_dtypes)
def test_eye_dtype(dtype):
backend = jax_backend.JaxBackend()
Expand Down Expand Up @@ -196,6 +210,13 @@ def test_randn_dtype(dtype):
assert a.dtype == dtype


@pytest.mark.parametrize("dtype", np_randn_dtypes)
def test_random_uniform_dtype(dtype):
backend = jax_backend.JaxBackend()
a = backend.random_uniform((4, 4), dtype=dtype)
assert a.dtype == dtype


@pytest.mark.parametrize("dtype", np_randn_dtypes)
def test_randn_seed(dtype):
backend = jax_backend.JaxBackend()
Expand All @@ -204,6 +225,34 @@ def test_randn_seed(dtype):
np.testing.assert_allclose(a, b)


@pytest.mark.parametrize("dtype", np_randn_dtypes)
def test_random_uniform_seed(dtype):
backend = jax_backend.JaxBackend()
a = backend.random_uniform((4, 4), seed=10, dtype=dtype)
b = backend.random_uniform((4, 4), seed=10, dtype=dtype)
np.testing.assert_allclose(a, b)


@pytest.mark.parametrize("dtype", np_randn_dtypes)
def test_random_uniform_boundaries(dtype):
lb = 1.2
ub = 4.8
backend = jax_backend.JaxBackend()
a = backend.random_uniform((4, 4), seed=10, dtype=dtype)
b = backend.random_uniform((4, 4), (lb, ub), seed=10, dtype=dtype)
assert((a >= 0).all() and (a <= 1).all() and
(b >= lb).all() and (b <= ub).all())


def test_random_uniform_behavior():
seed = 10
key = jax.random.PRNGKey(seed)
backend = jax_backend.JaxBackend()
a = backend.random_uniform((4, 4), seed=seed)
b = jax.random.uniform(key, (4, 4))
np.testing.assert_allclose(a, b)


def test_conj():
backend = jax_backend.JaxBackend()
real = np.random.rand(2, 2, 2)
Expand Down
18 changes: 18 additions & 0 deletions tensornetwork/backends/numpy/numpy_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,24 @@ def randn(self,
dtype) + 1j * self.np.random.randn(*shape).astype(dtype)
return self.np.random.randn(*shape).astype(dtype)

def random_uniform(self,
shape: Tuple[int, ...],
boundaries: Optional[Tuple[float, float]] = (0.0, 1.0),
dtype: Optional[numpy.dtype] = None,
seed: Optional[int] = None) -> Tensor:

if seed:
self.np.random.seed(seed)
dtype = dtype if dtype is not None else self.np.float64
if ((self.np.dtype(dtype) is self.np.dtype(self.np.complex128)) or
(self.np.dtype(dtype) is self.np.dtype(self.np.complex64))):
return self.np.random.uniform(boundaries[0], boundaries[1], shape).astype(
dtype) + 1j * self.np.random.uniform(boundaries[0],
boundaries[1],
shape).astype(dtype)
return self.np.random.uniform(boundaries[0],
boundaries[1], shape).astype(dtype)

def conj(self, tensor: Tensor) -> Tensor:
return self.np.conj(tensor)

Expand Down
48 changes: 48 additions & 0 deletions tensornetwork/backends/numpy/numpy_backend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,13 +159,27 @@ def test_randn(dtype):
assert a.shape == (4, 4)


@pytest.mark.parametrize("dtype", np_dtypes)
def test_random_uniform(dtype):
backend = numpy_backend.NumPyBackend()
a = backend.random_uniform((4, 4), dtype=dtype, seed=10)
assert a.shape == (4, 4)


@pytest.mark.parametrize("dtype", [np.complex64, np.complex128])
def test_randn_non_zero_imag(dtype):
backend = numpy_backend.NumPyBackend()
a = backend.randn((4, 4), dtype=dtype, seed=10)
assert np.linalg.norm(np.imag(a)) != 0.0


@pytest.mark.parametrize("dtype", [np.complex64, np.complex128])
def test_random_uniform_non_zero_imag(dtype):
backend = numpy_backend.NumPyBackend()
a = backend.random_uniform((4, 4), dtype=dtype, seed=10)
assert np.linalg.norm(np.imag(a)) != 0.0


@pytest.mark.parametrize("dtype", np_dtypes)
def test_eye_dtype(dtype):
backend = numpy_backend.NumPyBackend()
Expand Down Expand Up @@ -194,6 +208,13 @@ def test_randn_dtype(dtype):
assert a.dtype == dtype


@pytest.mark.parametrize("dtype", np_dtypes)
def test_random_uniform_dtype(dtype):
backend = numpy_backend.NumPyBackend()
a = backend.random_uniform((4, 4), dtype=dtype, seed=10)
assert a.dtype == dtype


@pytest.mark.parametrize("dtype", np_randn_dtypes)
def test_randn_seed(dtype):
backend = numpy_backend.NumPyBackend()
Expand All @@ -202,6 +223,33 @@ def test_randn_seed(dtype):
np.testing.assert_allclose(a, b)


@pytest.mark.parametrize("dtype", np_dtypes)
def test_random_uniform_seed(dtype):
backend = numpy_backend.NumPyBackend()
a = backend.random_uniform((4, 4), seed=10, dtype=dtype)
b = backend.random_uniform((4, 4), seed=10, dtype=dtype)
np.testing.assert_allclose(a, b)


@pytest.mark.parametrize("dtype", np_randn_dtypes)
def test_random_uniform_boundaries(dtype):
lb = 1.2
ub = 4.8
backend = numpy_backend.NumPyBackend()
a = backend.random_uniform((4, 4), seed=10, dtype=dtype)
b = backend.random_uniform((4, 4), (lb, ub), seed=10, dtype=dtype)
assert((a >= 0).all() and (a <= 1).all() and
(b >= lb).all() and (b <= ub).all())


def test_random_uniform_behavior():
backend = numpy_backend.NumPyBackend()
a = backend.random_uniform((4, 4), seed=10)
np.random.seed(10)
b = np.random.uniform(size=(4, 4))
np.testing.assert_allclose(a, b)


def test_conj():
backend = numpy_backend.NumPyBackend()
real = np.random.rand(2, 2, 2)
Expand Down
10 changes: 10 additions & 0 deletions tensornetwork/backends/pytorch/pytorch_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,16 @@ def randn(self,
dtype = dtype if dtype is not None else self.torch.float64
return self.torch.randn(shape, dtype=dtype)

def random_uniform(self,
shape: Tuple[int, ...],
boundaries: Optional[Tuple[float, float]] = (0.0, 1.0),
dtype: Optional[Any] = None,
seed: Optional[int] = None) -> Tensor:
if seed:
self.torch.manual_seed(seed)
dtype = dtype if dtype is not None else self.torch.float64
return self.torch.empty(shape, dtype=dtype).uniform_(*boundaries)

def conj(self, tensor: Tensor) -> Tensor:
return tensor #pytorch does not support complex dtypes

Expand Down
41 changes: 41 additions & 0 deletions tensornetwork/backends/pytorch/pytorch_backend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,13 @@ def test_randn(dtype):
assert a.shape == (4, 4)


@pytest.mark.parametrize("dtype", torch_randn_dtypes)
def test_random_uniform(dtype):
backend = pytorch_backend.PyTorchBackend()
a = backend.random_uniform((4, 4), dtype=dtype)
assert a.shape == (4, 4)


@pytest.mark.parametrize("dtype", torch_eye_dtypes)
def test_eye_dtype(dtype):
backend = pytorch_backend.PyTorchBackend()
Expand Down Expand Up @@ -176,6 +183,13 @@ def test_randn_dtype(dtype):
assert a.dtype == dtype


@pytest.mark.parametrize("dtype", torch_randn_dtypes)
def test_random_uniform_dtype(dtype):
backend = pytorch_backend.PyTorchBackend()
a = backend.random_uniform((4, 4), dtype=dtype)
assert a.dtype == dtype


@pytest.mark.parametrize("dtype", torch_randn_dtypes)
def test_randn_seed(dtype):
backend = pytorch_backend.PyTorchBackend()
Expand All @@ -184,6 +198,33 @@ def test_randn_seed(dtype):
np.testing.assert_allclose(a, b)


@pytest.mark.parametrize("dtype", torch_randn_dtypes)
def test_random_uniform_seed(dtype):
backend = pytorch_backend.PyTorchBackend()
a = backend.random_uniform((4, 4), seed=10, dtype=dtype)
b = backend.random_uniform((4, 4), seed=10, dtype=dtype)
torch.allclose(a, b)


@pytest.mark.parametrize("dtype", torch_randn_dtypes)
def test_random_uniform_boundaries(dtype):
lb = 1.2
ub = 4.8
backend = pytorch_backend.PyTorchBackend()
a = backend.random_uniform((4, 4), seed=10, dtype=dtype)
b = backend.random_uniform((4, 4), (lb, ub), seed=10, dtype=dtype)
assert(torch.ge(a, 0).byte().all() and torch.le(a, 1).byte().all() and
torch.ge(b, lb).byte().all() and torch.le(b, ub).byte().all())


def test_random_uniform_behavior():
backend = pytorch_backend.PyTorchBackend()
a = backend.random_uniform((4, 4), seed=10)
torch.manual_seed(10)
b = torch.empty((4, 4), dtype=torch.float64).uniform_()
torch.allclose(a, b)


def test_conj():
backend = pytorch_backend.PyTorchBackend()
real = np.random.rand(2, 2, 2)
Expand Down
7 changes: 7 additions & 0 deletions tensornetwork/backends/shell/shell_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,13 @@ def randn(self,
seed: Optional[int] = None) -> Tensor:
return ShellTensor(shape)

def random_uniform(self,
shape: Tuple[int, ...],
boundaries: Optional[Tuple[float, float]] = (0.0, 1.0),
dtype: Optional[Type[np.number]] = None,
seed: Optional[int] = None) -> Tensor:
return ShellTensor(shape)

def conj(self, tensor: Tensor) -> Tensor:
return tensor

Expand Down
5 changes: 5 additions & 0 deletions tensornetwork/backends/shell/shell_backend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,11 @@ def test_randn():
assertBackendsAgree("randn", args)


def test_random_uniform():
args = {"shape": (10, 4)}
assertBackendsAgree("random_uniform", args)


def test_eigsh_lanczos_1():
backend = shell_backend.ShellBackend()
D = 16
Expand Down
20 changes: 20 additions & 0 deletions tensornetwork/backends/tensorflow/tensorflow_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,26 @@ def randn(self,
self.tf.random.normal(shape=shape, dtype=dtype.real_dtype))
return self.tf.random.normal(shape=shape, dtype=dtype)

def random_uniform(self,
shape: Tuple[int, ...],
boundaries: Optional[Tuple[float, float]] = (0.0, 1.0),
dtype: Optional[Type[np.number]] = None,
seed: Optional[int] = None) -> Tensor:
if seed:
self.tf.random.set_seed(seed)

dtype = dtype if dtype is not None else self.tf.float64
if (dtype is self.tf.complex128) or (dtype is self.tf.complex64):
return self.tf.complex(
self.tf.random.uniform(shape=shape, minval=boundaries[0],
maxval=boundaries[1], dtype=dtype.real_dtype),
self.tf.random.uniform(shape=shape, minval=boundaries[0],
maxval=boundaries[1], dtype=dtype.real_dtype))
self.tf.random.set_seed(10)
a = self.tf.random.uniform(shape=shape, minval=boundaries[0],
maxval=boundaries[1], dtype=dtype)
return a

def conj(self, tensor: Tensor) -> Tensor:
return self.tf.math.conj(tensor)

Expand Down
Loading