From a40eb4693de0c9e43dfea0337832cbfc1aba3768 Mon Sep 17 00:00:00 2001 From: David Ittah Date: Mon, 23 Sep 2024 18:22:35 -0400 Subject: [PATCH] linting --- frontend/catalyst/api_extensions/function_maps.py | 5 +++-- frontend/test/pytest/test_vmap.py | 2 ++ 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/frontend/catalyst/api_extensions/function_maps.py b/frontend/catalyst/api_extensions/function_maps.py index 8b6e21593e..1b3ab49b29 100644 --- a/frontend/catalyst/api_extensions/function_maps.py +++ b/frontend/catalyst/api_extensions/function_maps.py @@ -320,10 +320,11 @@ def _get_batch_size(self, args_flat, axes_flat, axis_size): batch_sizes = [] for i, (arg, d) in enumerate(zip(args_flat, axes_flat)): - shape = np.shape(arg) if d is None: continue - elif len(shape) > d: + + shape = np.shape(arg) + if len(shape) > d: batch_sizes.append(shape[d]) else: raise ValueError( diff --git a/frontend/test/pytest/test_vmap.py b/frontend/test/pytest/test_vmap.py index 3cfc644829..e2555bf851 100644 --- a/frontend/test/pytest/test_vmap.py +++ b/frontend/test/pytest/test_vmap.py @@ -21,6 +21,8 @@ from catalyst import qjit, vmap +# pylint: disable=bug-vmap-shape + class TestVectorizeMap: """Test QJIT compatibility with JAX vectorization."""