Bug Report: Incorrect dtype comparison using is instead of == in _integrate.py #678

@ArianAmani

Summary

Two dtype checks in diffrax/_integrate.py compare dtypes with the is operator instead of ==. When two dtype objects are equal but are not the same object, the assertion on line 388 fails and the float64 branch on line 265 is silently skipped, even though the dtypes are logically equivalent. This can occur in scenarios involving serialization/deserialization, distributed computing, or different JAX environments.

Affected Lines

  1. Line 388: assert jnp.result_type(keep_step) is jnp.dtype(bool)
  2. Line 265: if tnext.dtype is jnp.dtype("float64"):

Problem Description

The is operator checks for object identity (same object in memory), while == checks for value equality. For dtypes, different instances representing the same type may not be the exact same object, especially in scenarios involving:

  • Serialization/deserialization (model saving/loading)
  • Distributed computing
  • Cross-process communication
  • Different JAX versions or environments
  • Pickling/unpickling operations
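
The distinction between identity and equality can be sketched without JAX at all. The snippet below is only an illustration: FakeDType is a hypothetical stand-in for a dtype object, not a JAX or NumPy class, and the two instances model a dtype that has been reconstructed elsewhere.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FakeDType:
    """Hypothetical stand-in for a dtype object; not part of JAX or NumPy."""

    name: str


# Two separately constructed objects describing the same type.
first = FakeDType("bool")
second = FakeDType("bool")

same_value = first == second   # compares field values
same_object = first is second  # compares object identity

print(f"Equal: {same_value}")
print(f"Identical: {same_object}")
```

Any check written with is depends on both sides being the very same object, which is an implementation detail rather than a property of the type being described.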

Reproduction

The difference between the two comparisons can be demonstrated with this script, which does not require diffrax:

import pickle
import jax.numpy as jnp

# Create a boolean value
keep_step = jnp.array(True)

# Create different instances of the same dtype
original_dtype = jnp.result_type(keep_step)
pickled_dtype = pickle.loads(pickle.dumps(jnp.dtype(bool)))

print(f"Equal: {original_dtype == pickled_dtype}")      # True
print(f"Identical: {original_dtype is pickled_dtype}")  # False

# This would fail (the buggy assertion):
try:
    assert jnp.result_type(keep_step) is pickled_dtype
    print("✓ Assertion with 'is' passed")
except AssertionError:
    print("✗ Assertion with 'is' FAILED")

# This works correctly (the fixed assertion):
try:
    assert jnp.result_type(keep_step) == pickled_dtype
    print("✓ Assertion with '==' passed")
except AssertionError:
    print("✗ Assertion with '==' failed")

Output:

Equal: True
Identical: False
✗ Assertion with 'is' FAILED
✓ Assertion with '==' passed

Real-world Impact

This bug can cause diffrax to fail unexpectedly in production environments where:

  • Models are saved and loaded (serialization)
  • Code runs in distributed systems
  • Different JAX configurations are used
  • Any scenario where dtype objects are reconstructed

Proposed Fix

Replace is with == for all dtype comparisons:

Line 388:

# Current (buggy)
assert jnp.result_type(keep_step) is jnp.dtype(bool)

# Fixed
assert jnp.result_type(keep_step) == jnp.dtype(bool)

Line 265:

# Current (buggy)
if tnext.dtype is jnp.dtype("float64"):

# Fixed
if tnext.dtype == jnp.dtype("float64"):
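
As a rough sketch of what a value-based check looks like in isolation, the snippet below wraps the comparison in a small helper. It assumes the dtype objects involved are NumPy dtype instances and uses NumPy in place of jax.numpy so that it runs without JAX; same_dtype is a hypothetical name and is not part of diffrax.

```python
import numpy as np


def same_dtype(a, b) -> bool:
    """Hypothetical helper: compare two dtype-like values by value, not identity."""
    # np.dtype() normalizes strings, Python types and dtype objects alike.
    return np.dtype(a) == np.dtype(b)


keep_step = np.array(True)

# Mirrors the intent of the two checks above, using == semantics.
assert same_dtype(np.result_type(keep_step), bool)
assert same_dtype("float64", np.float64)
assert not same_dtype(np.float32, np.float64)
```

Comparing by value keeps the check independent of how each dtype object was obtained.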

Testing

I've tested the fixes and confirmed that they preserve the intended behavior. They are backward compatible and leave the logic unchanged; they only make the comparisons more robust.

Environment

  • JAX version: [varies]
  • Python version: [varies]
  • Platform: [varies - affects all platforms]

This is a compatibility and robustness issue that affects the reliability of diffrax across different deployment scenarios.
