Skip to content

Commit

Permalink
float_like fix
Browse files Browse the repository at this point in the history
  • Loading branch information
vabor112 committed Aug 15, 2023
1 parent 00573f7 commit f799178
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion geometric_kernels/lab_extras/jax/extras.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def float_like(reference: B.JAXNumeric):
Return the type of the reference if it is a floating point type.
Otherwise return `double` dtype of a backend based on the reference.
"""
reference_dtype = jnp.dtype(reference)
reference_dtype = reference.dtype
if jnp.issubdtype(reference_dtype, jnp.floating):
return reference_dtype
else:
Expand Down
2 changes: 1 addition & 1 deletion geometric_kernels/lab_extras/numpy/extras.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def float_like(reference: B.NPNumeric):
Return the type of the reference if it is a floating point type.
Otherwise return `double` dtype of a backend based on the reference.
"""
reference_dtype = np.dtype(reference)
reference_dtype = reference.dtype
if np.issubdtype(reference_dtype, np.floating):
return reference_dtype
else:
Expand Down
2 changes: 1 addition & 1 deletion geometric_kernels/lab_extras/tensorflow/extras.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def float_like(reference: B.TFNumeric):
Return the type of the reference if it is a floating point type.
Otherwise return `double` dtype of a backend based on the reference.
"""
reference_dtype = tf.dtype(reference)
reference_dtype = reference.dtype
if reference_dtype.is_floating:
return reference_dtype
else:
Expand Down

0 comments on commit f799178

Please sign in to comment.