Alternate approach to serializing netcdfs for dask.distributed #1095

Status: Closed (wants to merge 1 commit)
xarray/backends/netCDF4_.py (20 additions, 0 deletions)
@@ -3,6 +3,12 @@

import numpy as np

try:
    import distributed.protocol
    HAS_DISTRIBUTED = True
except ImportError:
    HAS_DISTRIBUTED = False

from .. import Variable
from ..conventions import pop_to, cf_encoder
from ..core import indexing
@@ -37,6 +43,20 @@ def dtype(self):
            dtype = np.dtype('O')
        return dtype

    def __getstate__(self):
        # pickle by delegating to dask.distributed's serialization
        # protocol rather than pickling the netCDF4 object directly
        if not HAS_DISTRIBUTED:
            raise NotImplementedError
        header, frames = distributed.protocol.serialize(self.array)
        return (header, frames, self.is_remote)

    def __setstate__(self, state):
        if not HAS_DISTRIBUTED:
            raise NotImplementedError
        header, frames, is_remote = state
        array = distributed.protocol.deserialize(header, frames)
        self.array = array
        self.is_remote = is_remote


class NetCDF4ArrayWrapper(BaseNetCDF4Array):
    def __init__(self, array, is_remote=False):
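
The pickle hooks above delegate to dask.distributed's serialization protocol, which splits an object into a msgpack-compatible header plus a list of byte frames. A minimal sketch of the round-trip they rely on (assuming distributed is installed; the numpy array stands in for the wrapped netCDF4 variable):

import numpy as np
import distributed.protocol

arr = np.arange(10)
# serialize() returns a header dict and a list of bytes-like frames
header, frames = distributed.protocol.serialize(arr)
restored = distributed.protocol.deserialize(header, frames)
assert (restored == arr).all()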
xarray/test/test_distributed.py (55 additions, 0 deletions)
@@ -0,0 +1,55 @@
import numpy as np
import pandas as pd
import pytest
import xarray as xr

distributed = pytest.importorskip('distributed')
da = pytest.importorskip('dask.array')
from distributed.protocol import serialize, deserialize
from distributed.utils_test import cluster, loop, gen_cluster

from xarray.test.test_backends import create_tmp_file
from xarray.test.test_dataset import create_test_data


def test_dask_distributed_integration_test(loop):
    with cluster() as (s, _):
        with distributed.Client(('127.0.0.1', s['port']), loop=loop) as client:
            original = create_test_data()
            # removing the line below results in a test that never times out!
            del original['time']
            with create_tmp_file() as filename:
                original.to_netcdf(filename, engine='netcdf4')
                # TODO: should be able to serialize locks?
                # TODO: should be able to serialize array types from
                # xarray.conventions
                restored = xr.open_dataset(filename, chunks=3, lock=False)
                assert isinstance(restored.var1.data, da.Array)
                restored.load()
                assert original.identical(restored)


@gen_cluster(client=True)
def test_dask_distributed_integration_test_fast(c, s, a, b):
    # alternative test values for other dtypes; only the last
    # assignment takes effect
    values = [10, 20, 30]
    values = [0.2, 1.5, 1.8]
    values = ['ab', 'cd', 'ef']
    # does not work: ValueError: cannot include dtype 'M' in a buffer
    # values = pd.date_range('2010-01-01', periods=3).values
    original = xr.Dataset({'foo': ('x', values)})
    engine = 'netcdf4'
    # does not work: we don't know how to pickle h5netcdf objects, which wrap
    # h5py datasets/files
    # engine = 'h5netcdf'
    with create_tmp_file() as filename:
        original.to_netcdf(filename, engine=engine)
        # TODO: should be able to serialize locks?
        # TODO: should be able to serialize array types from
        # xarray.conventions
        restored = xr.open_dataset(filename, chunks=5, lock=False,
                                   engine=engine)
        print(restored.foo.data.dask)
        foo = c.compute(restored.foo.data)
        foo = yield foo._result()
        computed = xr.Dataset({'foo': ('x', foo)})
        assert computed.identical(original)
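
With this change, the intended workflow is to open a netCDF file with dask chunks while a distributed client is active; graph tasks then carry the array wrapper to workers, where __setstate__ rebuilds it. A hedged sketch of that usage (the file name, variable name, and chunk size are hypothetical; lock=False mirrors the tests above):

import xarray as xr
from distributed import Client

client = Client()  # connects to (or starts) a distributed scheduler
ds = xr.open_dataset('example.nc', chunks=3, lock=False, engine='netcdf4')
# loading pulls data through the cluster, exercising
# __getstate__/__setstate__ on the netCDF4 wrapper
ds.load()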