diff --git a/frontend/catalyst/api_extensions/function_maps.py b/frontend/catalyst/api_extensions/function_maps.py index 8b6e21593e..1b3ab49b29 100644 --- a/frontend/catalyst/api_extensions/function_maps.py +++ b/frontend/catalyst/api_extensions/function_maps.py @@ -320,10 +320,11 @@ def _get_batch_size(self, args_flat, axes_flat, axis_size): batch_sizes = [] for i, (arg, d) in enumerate(zip(args_flat, axes_flat)): - shape = np.shape(arg) if d is None: continue - elif len(shape) > d: + + shape = np.shape(arg) + if len(shape) > d: batch_sizes.append(shape[d]) else: raise ValueError( diff --git a/frontend/test/pytest/test_vmap.py b/frontend/test/pytest/test_vmap.py index 3cfc644829..e2555bf851 100644 --- a/frontend/test/pytest/test_vmap.py +++ b/frontend/test/pytest/test_vmap.py @@ -21,6 +21,8 @@ from catalyst import qjit, vmap +# pylint: disable=bug-vmap-shape + class TestVectorizeMap: """Test QJIT compatibility with JAX vectorization."""