Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Creation functions Fixup #1952

Merged
merged 21 commits into from
Jul 25, 2022
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 35 additions & 28 deletions ivy/functional/backends/jax/creation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# global
import jax.numpy as jnp
from typing import Union, Optional, List, Sequence
from typing import Union, Optional, Tuple, List, Sequence
from numbers import Number
import jaxlib.xla_extension
from jax.dlpack import from_dlpack as jax_from_dlpack

Expand All @@ -18,13 +19,13 @@


def arange(
start,
stop=None,
step=1,
start: Number,
iamjameskeane marked this conversation as resolved.
Show resolved Hide resolved
stop: Optional[Number] = None,
step: Number = 1,
*,
dtype: jnp.dtype = None,
device: jaxlib.xla_extension.Device,
):
dtype: Optional[jnp.dtype] = None,
device: jaxlib.xla_extension.Device
) -> JaxArray:
if dtype:
dtype = as_native_dtype(dtype)
res = _to_device(jnp.arange(start, stop, step=step, dtype=dtype), device=device)
Expand All @@ -37,12 +38,12 @@ def arange(


def asarray(
object_in,
object_in: Union[JaxArray, jnp.ndarray, List[Number], Tuple[Number]],
*,
copy: Optional[bool] = None,
dtype: jnp.dtype = None,
device: jaxlib.xla_extension.Device,
):
dtype: Optional[jnp.dtype] = None,
device: jaxlib.xla_extension.Device
) -> JaxArray:
if isinstance(object_in, ivy.NativeArray) and dtype != "bool":
dtype = object_in.dtype
elif (
Expand Down Expand Up @@ -70,7 +71,7 @@ def empty(
shape: Union[ivy.NativeShape, Sequence[int]],
*,
dtype: jnp.dtype,
device: jaxlib.xla_extension.Device,
device: jaxlib.xla_extension.Device
) -> JaxArray:
return _to_device(
jnp.empty(shape, as_native_dtype(default_dtype(dtype))), device=device
Expand All @@ -94,15 +95,15 @@ def eye(
k: Optional[int] = 0,
*,
dtype: jnp.dtype,
device: jaxlib.xla_extension.Device,
device: jaxlib.xla_extension.Device
) -> JaxArray:
dtype = as_native_dtype(default_dtype(dtype))
device = default_device(device)
return _to_device(jnp.eye(n_rows, n_cols, k, dtype), device=device)


# noinspection PyShadowingNames
def from_dlpack(x):
def from_dlpack(x: JaxArray) -> JaxArray:
return jax_from_dlpack(x)


Expand All @@ -111,7 +112,7 @@ def full(
fill_value: Union[int, float],
*,
dtype: jnp.dtype = None,
device: jaxlib.xla_extension.Device,
device: jaxlib.xla_extension.Device
) -> JaxArray:
return _to_device(
jnp.full(shape, fill_value, as_native_dtype(default_dtype(dtype, fill_value))),
Expand All @@ -121,10 +122,10 @@ def full(

def full_like(
x: JaxArray,
fill_value: Union[int, float],
fill_value: float,
*,
dtype: jnp.dtype,
device: jaxlib.xla_extension.Device,
device: jaxlib.xla_extension.Device
) -> JaxArray:
if dtype and str:
dtype = jnp.dtype(dtype)
Expand All @@ -140,15 +141,15 @@ def full_like(


def linspace(
start,
stop,
num,
axis=None,
endpoint=True,
start: Union[JaxArray, float],
stop: float,
num: int,
axis: Optional[int] = None,
endpoint: bool = True,
*,
dtype: jnp.dtype,
device: jaxlib.xla_extension.Device,
):
device: jaxlib.xla_extension.Device
) -> JaxArray:
if axis is None:
axis = -1
ans = jnp.linspace(start, stop, num, endpoint, dtype=dtype, axis=axis)
Expand All @@ -165,7 +166,7 @@ def ones(
shape: Union[ivy.NativeShape, Sequence[int]],
*,
dtype: Optional[Union[ivy.Dtype, jnp.dtype]] = None,
device: Optional[Union[ivy.Device, jaxlib.xla_extension.Device]] = None,
device: Optional[Union[ivy.Device, jaxlib.xla_extension.Device]] = None
) -> JaxArray:
return _to_device(
jnp.ones(shape, as_native_dtype(default_dtype(dtype))), device=device
Expand Down Expand Up @@ -194,7 +195,7 @@ def zeros(
shape: Union[ivy.NativeShape, Sequence[int]],
*,
dtype: jnp.dtype,
device: jaxlib.xla_extension.Device,
device: jaxlib.xla_extension.Device
) -> JaxArray:
return _to_device(
jnp.zeros(shape, dtype),
Expand All @@ -218,8 +219,14 @@ def zeros_like(


def logspace(
start, stop, num, base=10.0, axis=None, *, device: jaxlib.xla_extension.Device
):
start: Union[JaxArray, int],
stop: Union[JaxArray, int],
num: int,
base: float = 10.0,
axis: int = None,
*,
device: jaxlib.xla_extension.Device
) -> JaxArray:
if axis is None:
axis = -1
return _to_device(
Expand Down
53 changes: 40 additions & 13 deletions ivy/functional/backends/mxnet/creation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# global
import mxnet as mx
from typing import Union, List, Optional, Iterable, Sequence
from typing import Union, List, Optional, Iterable, Sequence, Tuple
from numbers import Number

# local
Expand Down Expand Up @@ -29,7 +29,13 @@ def _linspace(start, stop, num, cont):
return ret


def arange(stop, start=0, step=1, dtype=None, device=None):
def arange(
stop: Optional[Number] = None,
start: Number = 0,
step: Number = 1,
dtype: Optional[type] = None,
device: mx.context.Context = None,
) -> mx.nd.NDArray:
cont = _mxnet_init_context(default_device(device))
stop = stop if isinstance(stop, Number) else stop.asscalar()
start = start if isinstance(start, Number) else start.asscalar()
Expand All @@ -38,11 +44,11 @@ def arange(stop, start=0, step=1, dtype=None, device=None):


def asarray(
object_in,
dtype: Optional[Union[ivy.Dtype, type]] = None,
device: Optional[Union[ivy.Device, mx.context.Context]] = None,
object_in: Union[mx.nd.NDArray, List[Number], Tuple[Number]],
dtype: Optional[type] = None,
device: mx.context.Context = None,
copy: Optional[bool] = None,
):
) -> mx.nd.NDArray:
# mxnet don't have asarray implementation, haven't properly tested
cont = _mxnet_init_context(default_device(device))
if copy is None:
Expand Down Expand Up @@ -86,13 +92,16 @@ def eye(


# noinspection PyUnresolvedReferences
def from_dlpack(x):
def from_dlpack(x: mx.nd.NDArray) -> mx.nd.NDArray:
return mx.nd.from_dlpack(x)


def full(
shape: Union[ivy.NativeShape, Sequence[int]], fill_value, dtype=None, device=None
):
shape: Union[ivy.NativeShape, Sequence[int]],
fill_value: float,
dtype: Optional[type] = None,
device: Optional[mx.context.Context] = None,
) -> mx.nd.NDArray:
shape = ivy.shape_to_tuple(shape)
cont = _mxnet_init_context(default_device(device))
if len(shape) == 0 or 0 in shape:
Expand All @@ -109,7 +118,13 @@ def full(
)


def linspace(start, stop, num, axis=None, device=None):
def linspace(
start: Union[mx.nd.NDArray, float],
stop: Union[mx.nd.NDArray, float],
num: int,
axis: Optional[int] = None,
device: mx.context.Context = None,
) -> mx.nd.NDArray:
cont = _mxnet_init_context(default_device(device))
num = num.asnumpy()[0] if isinstance(num, mx.nd.NDArray) else num
start_is_array = isinstance(start, mx.nd.NDArray)
Expand Down Expand Up @@ -173,15 +188,19 @@ def zeros(
shape: Union[ivy.NativeShape, Sequence[int]],
*,
dtype: type,
device: mx.context.Context,
device: mx.context.Context
) -> mx.nd.NDArray:
cont = _mxnet_init_context(device)
if len(shape) == 0 or 0 in shape:
return _1_dim_array_to_flat_array(mx.nd.zeros((1,), ctx=cont).astype(dtype))
return mx.nd.zeros(shape, ctx=cont).astype(dtype)


def zeros_like(x, dtype=None, device=None):
def zeros_like(
x: mx.nd.NDArray,
dtype: Optional[type] = None,
device: Optional[mx.context.Context] = None,
) -> mx.nd.NDArray:
if x.shape == ():
return mx.nd.array(0.0, ctx=_mxnet_init_context(default_device(device)))
mx_zeros = mx.nd.zeros_like(x, ctx=_mxnet_init_context(default_device(device)))
Expand All @@ -195,6 +214,14 @@ def zeros_like(x, dtype=None, device=None):
array = asarray


def logspace(start, stop, num, base=10.0, axis=None, device=None):
def logspace(
start: Union[mx.nd.NDArray, int],
stop: Union[mx.nd.NDArray, int],
num: int,
base: float = 10.0,
axis: int = None,
*,
device: mx.context.Context
) -> mx.nd.NDArray:
power_seq = linspace(start, stop, num, axis, default_device(device))
return base**power_seq
45 changes: 37 additions & 8 deletions ivy/functional/backends/numpy/creation.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# global
import numpy
import numpy as np
from typing import Union, Optional, List, Sequence
from numbers import Number
from typing import Union, Tuple, Optional, List, Sequence

# local
import ivy
Expand All @@ -14,7 +15,14 @@
# -------------------#


def arange(start, stop=None, step=1, *, dtype: np.dtype = None, device: str):
def arange(
start: Number,
stop: Optional[Number] = None,
step: Number = 1,
*,
dtype: Optional[np.dtype] = None,
device: str
) -> np.ndarray:
if dtype:
dtype = as_native_dtype(dtype)
res = _to_device(np.arange(start, stop, step=step, dtype=dtype), device=device)
Expand All @@ -26,7 +34,13 @@ def arange(start, stop=None, step=1, *, dtype: np.dtype = None, device: str):
return res


def asarray(object_in, *, copy=None, dtype: np.dtype = None, device: str):
def asarray(
object_in: Union[np.ndarray, List[Number], Tuple[Number]],
*,
copy: Optional[bool] = None,
dtype: Optional[np.dtype] = None,
device: str
) -> np.ndarray:
# If copy=none then try using existing memory buffer
if isinstance(object_in, np.ndarray) and dtype is None:
dtype = object_in.dtype
Expand Down Expand Up @@ -81,7 +95,7 @@ def eye(


# noinspection PyShadowingNames
def from_dlpack(x):
def from_dlpack(x: np.ndarray) -> np.ndarray:
return np.from_dlpack(x)


Expand All @@ -99,7 +113,7 @@ def full(


def full_like(
x: np.ndarray, fill_value: Union[int, float], *, dtype: np.dtype, device: str
x: np.ndarray, fill_value: float, *, dtype: np.dtype, device: str
) -> np.ndarray:
if dtype:
dtype = "bool_" if dtype == "bool" else dtype
Expand All @@ -109,8 +123,15 @@ def full_like(


def linspace(
start, stop, num, axis=None, endpoint=True, *, dtype: np.dtype, device: str
):
start: Union[np.ndarray, float],
stop: Union[np.ndarray, float],
num: int,
axis: Optional[int] = None,
endpoint: bool = True,
*,
dtype: np.dtype,
device: str
) -> np.ndarray:
if axis is None:
axis = -1
ans = np.linspace(start, stop, num, endpoint, dtype=dtype, axis=axis)
Expand Down Expand Up @@ -176,7 +197,15 @@ def zeros_like(x: np.ndarray, *, dtype: np.dtype, device: str) -> np.ndarray:
array = asarray


def logspace(start, stop, num, base=10.0, axis=None, *, device: str):
def logspace(
start: Union[np.ndarray, int],
stop: Union[np.ndarray, int],
num: int,
base: float = 10.0,
axis: Optional[int] = None,
*,
device: str
) -> np.ndarray:
if axis is None:
axis = -1
return _to_device(
Expand Down
Loading