Skip to content

Commit

Permalink
lintfixbot: Auto-commit fixed lint errors in codebase
Browse files Browse the repository at this point in the history
  • Loading branch information
ivy-branch committed Nov 23, 2022
1 parent 6794d27 commit 9d3c3d2
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 20 deletions.
4 changes: 3 additions & 1 deletion determine_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@ def get_all_tests():
+ ["ivy_tests"]
)
directories_filtered = [
x for x in directories if not (x.endswith("__pycache__") or "hypothesis" in x)
x
for x in directories
if not (x.endswith("__pycache__") or "hypothesis" in x)
]
directories = set(directories_filtered)
for test_backend in new_tests[old_num_tests:num_tests]:
Expand Down
14 changes: 12 additions & 2 deletions ivy/functional/frontends/torch/miscellaneous_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ def trace(input):
return ivy.astype(ivy.trace(input), target_type)


@with_unsupported_dtypes({"1.11.0 and below": ("int8", "float16", "bfloat16", "bool")}, "torch")
@with_unsupported_dtypes(
{"1.11.0 and below": ("int8", "float16", "bfloat16", "bool")}, "torch"
)
@to_ivy_arrays_and_back
def tril_indices(row, col, offset=0, *, dtype="int64", device="cpu", layout=None):
sample_matrix = ivy.tril(ivy.ones((row, col), device=device), k=offset)
Expand Down Expand Up @@ -153,7 +155,15 @@ def logcumsumexp(input, dim, *, out=None):
return ret


@with_supported_dtypes({"1.11.0 and below": ("int32", "int64", )}, "torch")
@with_supported_dtypes(
{
"1.11.0 and below": (
"int32",
"int64",
)
},
"torch",
)
@to_ivy_arrays_and_back
def repeat_interleave(input, repeats, dim=None, *, output_size=None):
return ivy.repeat(input, repeats, axis=dim)
Expand Down
24 changes: 14 additions & 10 deletions ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,16 +542,20 @@ def test_jax_numpy_solve(

@st.composite
def norm_helper(draw):
dtype, x = draw(helpers.dtype_and_values(
shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"),
available_dtypes=helpers.get_dtypes("valid"),
min_num_dims=1,
safety_factor_scale="log",
large_abs_safety_factor=2,
))
axis = draw(helpers.get_axis(
shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"),
))
dtype, x = draw(
helpers.dtype_and_values(
shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"),
available_dtypes=helpers.get_dtypes("valid"),
min_num_dims=1,
safety_factor_scale="log",
large_abs_safety_factor=2,
)
)
axis = draw(
helpers.get_axis(
shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"),
)
)
if type(axis) in [tuple, list]:
if len(axis) == 2:
ord_param = draw(
Expand Down
21 changes: 14 additions & 7 deletions ivy_tests/test_ivy/test_frontends/test_torch/test_creation_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

# Helper functions


@st.composite
def _fill_value(draw):
dtype = draw(st.shared(helpers.get_dtypes("numeric", full=False), key="dtype"))[0]
Expand All @@ -24,13 +25,17 @@ def _start_stop_step(draw):
start = draw(helpers.ints(min_value=0, max_value=50))
stop = draw(helpers.ints(min_value=0, max_value=50))
if start < stop:
step = draw(helpers.ints(min_value=0, max_value=50).filter(
lambda x: True if x != 0 else False
))
step = draw(
helpers.ints(min_value=0, max_value=50).filter(
lambda x: True if x != 0 else False
)
)
else:
step = draw(helpers.ints(min_value=-50, max_value=0).filter(
lambda x: True if x != 0 else False
))
step = draw(
helpers.ints(min_value=-50, max_value=0).filter(
lambda x: True if x != 0 else False
)
)
return start, stop, step


Expand Down Expand Up @@ -443,7 +448,9 @@ def test_torch_empty_like(
@handle_frontend_test(
fn_tree="torch.full_like",
dtype_and_x=helpers.dtype_and_values(
available_dtypes=st.shared(helpers.get_dtypes("numeric", full=False), key="dtype")
available_dtypes=st.shared(
helpers.get_dtypes("numeric", full=False), key="dtype"
)
),
fill_value=_fill_value(),
dtype=st.shared(helpers.get_dtypes("numeric", full=False), key="dtype"),
Expand Down

0 comments on commit 9d3c3d2

Please sign in to comment.