diff --git a/geometric_kernels/lab_extras/jax/extras.py b/geometric_kernels/lab_extras/jax/extras.py index addf54e8..99d84bc3 100644 --- a/geometric_kernels/lab_extras/jax/extras.py +++ b/geometric_kernels/lab_extras/jax/extras.py @@ -87,7 +87,7 @@ def float_like(reference: B.JAXNumeric): Return the type of the reference if it is a floating point type. Otherwise return `double` dtype of a backend based on the reference. """ - reference_dtype = jnp.dtype(reference) + reference_dtype = reference.dtype if jnp.issubdtype(reference_dtype, jnp.floating): return reference_dtype else: diff --git a/geometric_kernels/lab_extras/numpy/extras.py b/geometric_kernels/lab_extras/numpy/extras.py index 97a89f81..b4be60df 100644 --- a/geometric_kernels/lab_extras/numpy/extras.py +++ b/geometric_kernels/lab_extras/numpy/extras.py @@ -74,7 +74,7 @@ def float_like(reference: B.NPNumeric): Return the type of the reference if it is a floating point type. Otherwise return `double` dtype of a backend based on the reference. """ - reference_dtype = np.dtype(reference) + reference_dtype = reference.dtype if np.issubdtype(reference_dtype, np.floating): return reference_dtype else: diff --git a/geometric_kernels/lab_extras/tensorflow/extras.py b/geometric_kernels/lab_extras/tensorflow/extras.py index 24f56fa0..09d6a4f0 100644 --- a/geometric_kernels/lab_extras/tensorflow/extras.py +++ b/geometric_kernels/lab_extras/tensorflow/extras.py @@ -95,7 +95,7 @@ def float_like(reference: B.TFNumeric): Return the type of the reference if it is a floating point type. Otherwise return `double` dtype of a backend based on the reference. """ - reference_dtype = tf.dtype(reference) + reference_dtype = reference.dtype if reference_dtype.is_floating: return reference_dtype else: