
Thread-safety in xarray #9836

Open
@max-sixty


What is your issue?

There's a process that can't use xarray because it runs in a multithreaded context: xarray objects aren't thread-safe, because pandas indexes aren't thread-safe (pandas-dev/pandas#2728).

I did some quick exploration of where the issue is. I don't have a clear answer, but thought it would be worthwhile to post some results.

Here's a script that runs a few operations on shared objects from a thread pool:

import traceback
import pandas as pd
import xarray as xr
import numpy as np
import concurrent.futures
import copy

# Dict of test cases - just function and its string representation
tests = {
    'x.reindex(dim_0=x.dim_0)': lambda s, x, idx: x.reindex(dim_0=x.dim_0),
    'x.copy().reindex(dim_0=idx.copy())': lambda s, x, idx: x.copy().reindex(dim_0=idx.copy()),
    'copy.deepcopy(idx)': lambda s, x, idx: copy.deepcopy(idx),
    'x.reindex(dim_0=idx.copy())': lambda s, x, idx: x.reindex(dim_0=idx.copy()),
    'x.reindex(dim_0=idx)': lambda s, x, idx: x.reindex(dim_0=idx),
    'x.reindex(dim_0=x.dim_0); x.reindex(dim_0=idx)': 
        lambda s, x, idx: (x.reindex(dim_0=x.dim_0), x.reindex(dim_0=idx)),
    'x.sel(dim_0=idx); x.reindex(dim_0=idx.copy())':
        lambda s, x, idx: (x.sel(dim_0=idx), x.reindex(dim_0=idx.copy()))
}

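def run_test(test_fn, n=1000):
    try:
        a = np.arange(0, 30000)
        def gen_args():
            for i in range(n):
                # Every 1000 iterations, build fresh objects and call test_fn once
                # in the submitting (main) thread before sharing them with the pool.
                if i % 1000 == 0:
                    s = pd.Series(data=a, index=a)
                    x = xr.DataArray(a, dims=['dim_0'], coords={'dim_0': a})
                    test_fn(s, x, a)
                # Every task gets the same shared s and x, plus a new 1000-element idx.
                yield s, x, np.arange(0, 1000)

        # Run test_fn concurrently on up to 4 worker threads.
        with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
            list(executor.map(lambda args: test_fn(*args), gen_args()))
        return True
    except Exception as e:
        # An exception from any worker task is re-raised here by list(...).
        return False, str(e)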

# Run all tests and collect results
results = {}
for test_str, fn in tests.items():
    print(f"Running: {test_str}")
    results[test_str] = run_test(fn)

# Print summary
print("\nResults Summary:")
print("=" * 50)
for test_str, result in results.items():
    status = "✓ PASS" if result is True else "✗ FAIL"
    print(f"{status} - {test_str}")

The results:


Results Summary:
==================================================
✓ PASS - x.reindex(dim_0=x.dim_0)
✓ PASS - x.copy().reindex(dim_0=idx.copy())
✓ PASS - copy.deepcopy(idx)
✗ FAIL - x.reindex(dim_0=idx.copy())
✗ FAIL - x.reindex(dim_0=idx)
✗ FAIL - x.reindex(dim_0=x.dim_0); x.reindex(dim_0=idx)
✓ PASS - x.sel(dim_0=idx); x.reindex(dim_0=idx.copy())

A couple of things to note:

  • Most .sel operations seemed to pass, while most .reindex operations seemed to fail
  • Running x.reindex(dim_0=idx.copy()) fails, but running x.sel(dim_0=idx) beforehand makes that same .reindex pass (?!)
  • Reindexing with x.reindex(dim_0=x.dim_0) works, but reindexing with an index that's passed in, x.reindex(dim_0=idx), doesn't. (Could we claim that xarray objects are safe when accessed by different threads? I'd be surprised if that were the case, but I couldn't immediately find a case that falsified it...)
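The passing `copy.deepcopy(idx)` and `x.copy().reindex(dim_0=idx.copy())` cases suggest one workaround that avoids sharing objects across threads at all: give each worker thread its own deep copy. Here is a minimal sketch using `threading.local`. `PerThreadCopy` is a hypothetical helper, not an xarray API, and the approach assumes that `copy.deepcopy` of a DataArray also copies its underlying pandas index, which is consistent with the results above but not verified here.

```python
import copy
import threading


class PerThreadCopy:
    """Hypothetical helper (not part of xarray): lazily gives each thread its
    own deep copy of a template object, so no object is shared across threads."""

    def __init__(self, template):
        self._template = template
        self._template_lock = threading.Lock()  # serializes reads of the template
        self._local = threading.local()         # per-thread storage for the copy

    def get(self):
        try:
            # Fast path: this thread has already made its copy.
            return self._local.obj
        except AttributeError:
            # First call from this thread: copy the template while holding the
            # lock, so only one thread reads the shared template at a time.
            with self._template_lock:
                obj = copy.deepcopy(self._template)
            self._local.obj = obj
            return obj
```

With `holder = PerThreadCopy(x)`, each task in the harness would call something like `holder.get().reindex(dim_0=idx)` instead of using the shared `x`. The cost is one full copy of the data and index per worker thread, which may matter for large arrays.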

I don't think there are any easy answers here. If we wanted xarray (or parts of xarray) to be thread-safe, I think the options are all bad:

  • Remove the pandas dependency (could there be a slower, dumber indexing option as a plugin?)
  • Protect every indexing operation with a lock
  • ?
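The lock option could look roughly like this: a minimal, hypothetical `LockedProxy` wrapper (not an xarray API) that forwards method calls to the wrapped object while holding one process-wide `threading.RLock`. A single shared lock is used on the assumption that different xarray objects may end up sharing the same pandas index, in which case per-object locks might not be enough.

```python
import threading

# Hypothetical process-wide lock shared by every proxy (assumption: objects
# may share pandas indexes, so one global lock is the conservative choice).
_INDEXING_LOCK = threading.RLock()


class LockedProxy:
    """Hypothetical helper (not part of xarray): forwards method calls to the
    wrapped object while holding a lock, so at most one wrapped call runs at a time."""

    def __init__(self, obj, lock=_INDEXING_LOCK):
        self._obj = obj
        self._lock = lock

    def __getattr__(self, name):
        # Only called for names not found on the proxy itself.
        attr = getattr(self._obj, name)
        if not callable(attr):
            # Plain attribute reads are passed through without locking.
            return attr

        def locked_call(*args, **kwargs):
            with self._lock:
                return attr(*args, **kwargs)

        return locked_call
```

Under this sketch, a failing case from the results would become `LockedProxy(x).reindex(dim_0=idx)`. The returned DataArray is not wrapped, attribute reads are not locked, and every wrapped call is serialized, so this gives up most of the parallelism the thread pool was meant to provide.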
