diff --git a/determine_tests.py b/determine_tests.py index 4222589801c83..1f72b9f0cb514 100644 --- a/determine_tests.py +++ b/determine_tests.py @@ -61,7 +61,9 @@ def get_all_tests(): + ["ivy_tests"] ) directories_filtered = [ - x for x in directories if not (x.endswith("__pycache__") or "hypothesis" in x) + x + for x in directories + if not (x.endswith("__pycache__") or "hypothesis" in x) ] directories = set(directories_filtered) for test_backend in new_tests[old_num_tests:num_tests]: diff --git a/ivy/functional/frontends/torch/miscellaneous_ops.py b/ivy/functional/frontends/torch/miscellaneous_ops.py index c50de5e6e1d05..f76e8cd805f9d 100644 --- a/ivy/functional/frontends/torch/miscellaneous_ops.py +++ b/ivy/functional/frontends/torch/miscellaneous_ops.py @@ -42,7 +42,9 @@ def trace(input): return ivy.astype(ivy.trace(input), target_type) -@with_unsupported_dtypes({"1.11.0 and below": ("int8", "float16", "bfloat16", "bool")}, "torch") +@with_unsupported_dtypes( + {"1.11.0 and below": ("int8", "float16", "bfloat16", "bool")}, "torch" +) @to_ivy_arrays_and_back def tril_indices(row, col, offset=0, *, dtype="int64", device="cpu", layout=None): sample_matrix = ivy.tril(ivy.ones((row, col), device=device), k=offset) @@ -153,7 +155,15 @@ def logcumsumexp(input, dim, *, out=None): return ret -@with_supported_dtypes({"1.11.0 and below": ("int32", "int64", )}, "torch") +@with_supported_dtypes( + { + "1.11.0 and below": ( + "int32", + "int64", + ) + }, + "torch", +) @to_ivy_arrays_and_back def repeat_interleave(input, repeats, dim=None, *, output_size=None): return ivy.repeat(input, repeats, axis=dim) diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_linalg.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_linalg.py index fd1969b82a98d..c729d043e3914 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_linalg.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_linalg.py @@ -542,16 +542,20 @@ def test_jax_numpy_solve( @st.composite def norm_helper(draw): - dtype, x = draw(helpers.dtype_and_values( - shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), - available_dtypes=helpers.get_dtypes("valid"), - min_num_dims=1, - safety_factor_scale="log", - large_abs_safety_factor=2, - )) - axis = draw(helpers.get_axis( - shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), - )) + dtype, x = draw( + helpers.dtype_and_values( + shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=1, + safety_factor_scale="log", + large_abs_safety_factor=2, + ) + ) + axis = draw( + helpers.get_axis( + shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), + ) + ) if type(axis) in [tuple, list]: if len(axis) == 2: ord_param = draw( diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_creation_ops.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_creation_ops.py index d4ff7a646d497..8c281232afd2d 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_creation_ops.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_creation_ops.py @@ -9,6 +9,7 @@ # Helper functions + @st.composite def _fill_value(draw): dtype = draw(st.shared(helpers.get_dtypes("numeric", full=False), key="dtype"))[0] @@ -24,13 +25,17 @@ def _start_stop_step(draw): start = draw(helpers.ints(min_value=0, max_value=50)) stop = draw(helpers.ints(min_value=0, max_value=50)) if start < stop: - step = draw(helpers.ints(min_value=0, max_value=50).filter( - lambda x: True if x != 0 else False - )) + step = draw( + helpers.ints(min_value=0, max_value=50).filter( + lambda x: True if x != 0 else False + ) + ) else: - step = draw(helpers.ints(min_value=-50, max_value=0).filter( - lambda x: True if x != 0 else False - )) + step = draw( + helpers.ints(min_value=-50, max_value=0).filter( + lambda x: True if x != 0 else False + ) + ) return start, stop, step @@ -443,7 +448,9 @@ def test_torch_empty_like( @handle_frontend_test( fn_tree="torch.full_like", dtype_and_x=helpers.dtype_and_values( - available_dtypes=st.shared(helpers.get_dtypes("numeric", full=False), key="dtype") + available_dtypes=st.shared( + helpers.get_dtypes("numeric", full=False), key="dtype" + ) ), fill_value=_fill_value(), dtype=st.shared(helpers.get_dtypes("numeric", full=False), key="dtype"),