Skip to content

Conversation

Maalvi14
Copy link
Contributor

Proposed changes

Related to Issue #2665

CPU Backend (sort.cpp)

  1. New nan_aware_less function: Centralized NaN handling logic
  2. Updated all sort operations: sort, argsort, partition, argpartition
  3. Consistent NaN placement: NaNs always go to the end regardless of operation type
  4. Type safety: Only applies NaN logic to floating-point types

Metal Backend (sort.h)

  1. New NanLastLess comparator template: GPU-optimized NaN handling
  2. Specialized implementations: Separate logic for float and half types
  3. Bitwise operations: Efficient GPU implementation that pairs isnan() checks with bitwise comparison logic
  4. Updated kernel defaults: All sorting kernels now use NanLastLess by default

Behavioral Changes

  • Before: NaN ordering was undefined/inconsistent
  • After: NaNs are consistently placed at the end of sorted arrays
  • Compatibility: Matches NumPy's np.sort() behavior exactly
  • Performance: Minimal overhead with efficient NaN detection

Implementation Strategy

  • CPU: Uses std::isnan() for precise NaN detection
  • Metal: Uses Metal's isnan() function with bitwise logic for GPU efficiency
  • Consistency: Both backends implement identical NaN placement logic
  • Type coverage: Handles float32, float64, float16, and bfloat16 types

Example Unit Tests

import mlx.core as mx
import numpy as np
import unittest


class TestNaNAwareSorting(unittest.TestCase):
    """Test that MLX sort places NaNs at the end like NumPy's np.sort.

    Every case is checked on the CPU backend and, when Metal is available,
    on the GPU backend as well. Comparison against NumPy is done in two
    parts because NaN != NaN: the boolean NaN masks must be identical, and
    the non-NaN entries must be numerically close.
    """

    def setUp(self):
        """Store the original default device so tests can switch freely."""
        self.original_device = mx.default_device()

    def tearDown(self):
        """Restore the default device saved in setUp."""
        mx.set_default_device(self.original_device)

    def _each_device(self):
        """Yield every device to test: CPU always, GPU when Metal is available."""
        yield mx.cpu
        if mx.metal.is_available():
            yield mx.gpu

    def _assert_nan_sorted_equal(self, expected, actual, context=""):
        """Assert `actual` has the same NaN layout and non-NaN values as `expected`.

        Args:
            expected: NumPy reference array (already sorted).
            actual: NumPy array converted from the MLX result.
            context: extra text appended to failure messages (e.g. " for axis=0").
        """
        expected_mask = ~np.isnan(expected)
        actual_mask = ~np.isnan(actual)
        self.assertTrue(
            np.array_equal(expected_mask, actual_mask),
            f"NaN positions don't match{context}. NumPy: {expected}, MLX: {actual}",
        )
        # Select each side with its own mask; the masks were just proven equal.
        self.assertTrue(
            np.allclose(expected[expected_mask], actual[actual_mask]),
            f"Non-NaN values don't match{context}. "
            f"NumPy: {expected[expected_mask]}, MLX: {actual[actual_mask]}",
        )

    def _test_sort_matches_numpy(self, data, device):
        """Sort `data` with MLX on `device` and compare against np.sort."""
        mx.set_default_device(device)
        expected = np.sort(np.array(data, dtype=np.float32))
        actual = np.array(mx.sort(mx.array(data)))
        self._assert_nan_sorted_equal(expected, actual)

    def _check_sort_on_all_devices(self, data):
        """Run the sort-vs-NumPy comparison for `data` on every available device."""
        for device in self._each_device():
            self._test_sort_matches_numpy(data, device)

    def test_single_nan_middle(self):
        """Test array with single NaN in the middle."""
        self._check_sort_on_all_devices([1.0, np.nan, 2.0, 0.0])

    def test_single_nan_start(self):
        """Test array with NaN at the start."""
        self._check_sort_on_all_devices([np.nan, 3.0, 1.0, 2.0])

    def test_single_nan_end(self):
        """Test array with NaN at the end."""
        self._check_sort_on_all_devices([3.0, 1.0, 2.0, np.nan])

    def test_multiple_nans(self):
        """Test array with multiple NaNs."""
        self._check_sort_on_all_devices([1.0, np.nan, 2.0, np.nan, 0.0, np.nan, 3.0])

    def test_all_nans(self):
        """Test array with all NaNs."""
        self._check_sort_on_all_devices([np.nan, np.nan, np.nan, np.nan])

    def test_no_nans(self):
        """Test array with no NaNs (plain sort must be unaffected)."""
        self._check_sort_on_all_devices([3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0])

    def test_negative_values_with_nans(self):
        """Test array with negative values and NaNs."""
        self._check_sort_on_all_devices([-1.0, np.nan, -3.0, 2.0, np.nan, 0.0, -2.0])

    def test_zeros_with_nans(self):
        """Test array with positive and negative zeros plus NaNs."""
        self._check_sort_on_all_devices([0.0, np.nan, -0.0, 1.0, np.nan])

    def test_large_array_with_nans(self):
        """Test a larger array with NaNs scattered throughout."""
        np.random.seed(42)  # fixed seed keeps the test deterministic
        data = np.random.randn(100).tolist()
        for i in (5, 15, 30, 50, 75, 99):
            data[i] = np.nan
        self._check_sort_on_all_devices(data)

    def test_argsort_with_nans(self):
        """Test that argsort places the indices of NaN entries at the end."""
        data = [1.0, np.nan, 2.0, 0.0, np.nan]

        # NumPy reference: gather through argsort indices, then compare layouts.
        np_data = np.array(data, dtype=np.float32)
        expected = np_data[np.argsort(np_data)]

        for device in self._each_device():
            mx.set_default_device(device)
            mx_data = mx.array(data)
            actual = np.array(mx_data[mx.argsort(mx_data)])
            self._assert_nan_sorted_equal(expected, actual, " in argsort")

    def test_2d_array_with_nans(self):
        """Test sorting a 2D array with NaNs along both axes."""
        data = [
            [3.0, np.nan, 1.0, 2.0],
            [np.nan, 0.0, np.nan, 4.0],
            [5.0, 2.0, 3.0, 1.0],
        ]
        for axis in (0, 1):
            expected = np.sort(np.array(data, dtype=np.float32), axis=axis)
            for device in self._each_device():
                mx.set_default_device(device)
                actual = np.array(mx.sort(mx.array(data), axis=axis))
                self._assert_nan_sorted_equal(expected, actual, f" for axis={axis}")

    def test_cpu_gpu_consistency(self):
        """Test that CPU and GPU produce identical results."""
        if not mx.metal.is_available():
            self.skipTest("Metal GPU not available")

        data = [3.0, np.nan, 1.0, np.nan, 2.0, 0.0, np.nan, 4.0]

        mx.set_default_device(mx.cpu)
        cpu_np = np.array(mx.sort(mx.array(data)))

        mx.set_default_device(mx.gpu)
        gpu_np = np.array(mx.sort(mx.array(data)))

        cpu_nan_mask = np.isnan(cpu_np)
        gpu_nan_mask = np.isnan(gpu_np)
        self.assertTrue(
            np.array_equal(cpu_nan_mask, gpu_nan_mask),
            f"NaN positions differ between CPU and GPU. CPU: {cpu_np}, GPU: {gpu_np}",
        )
        self.assertTrue(
            np.allclose(cpu_np[~cpu_nan_mask], gpu_np[~gpu_nan_mask]),
            f"Non-NaN values differ between CPU and GPU. CPU: {cpu_np}, GPU: {gpu_np}",
        )


# Allow running this test file directly (pytest/unittest discovery also works).
if __name__ == '__main__':
    unittest.main()
Screenshot 2025-10-12 at 8 10 55 PM

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

@awni
Copy link
Member

awni commented Oct 13, 2025

Thanks. I think CUDA already does the right thing here, so we should be good to go if the tests clear.

Copy link
Member

@awni awni left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the fix!

@awni awni merged commit 9cbb1b0 into ml-explore:main Oct 13, 2025
7 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants