Use pytorch as backend for xarrays #3232
If pytorch implements overrides of NumPy's API via the `__array_function__` protocol, this could work. I think there has been some discussion about this, but I don't know the current status (CC @rgommers). The biggest challenge for pytorch would be defining the translation layer that implements NumPy's API. Personally, I think the most viable way to achieve seamless integration with deep learning libraries would be to support integration with JAX, which already implements NumPy's API almost exactly. I have an experimental pull request adding `__array_function__` to JAX. |
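For context, a minimal sketch of what a NEP-18 duck array looks like; the DuckArray class and its dispatch logic below are purely illustrative, not pytorch's or xarray's actual code:

import numpy as np

class DuckArray:
    """Illustrative duck array: wraps a NumPy array and intercepts NumPy API calls."""

    def __init__(self, data):
        self.data = np.asarray(data)

    # minimal attributes a duck array is expected to expose
    @property
    def shape(self):
        return self.data.shape

    @property
    def dtype(self):
        return self.data.dtype

    @property
    def ndim(self):
        return self.data.ndim

    def __array_function__(self, func, types, args, kwargs):
        # NEP-18 hook: calls like np.mean(duck) land here instead of being
        # coerced to ndarray; a real backend would call its own kernels.
        unwrapped = tuple(a.data if isinstance(a, DuckArray) else a for a in args)
        return DuckArray(func(*unwrapped, **kwargs))

print(np.mean(DuckArray([1.0, 2.0, 3.0])).data)  # dispatched via __array_function__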
The PyTorch team is definitely receptive to the idea of adding `__array_function__`. Also, they want a … The tracking issue for all of this is pytorch/pytorch#22402. |
Agreed. No one is working on |
Less familiar with that, but pytorch does have experimental XLA support, so that's a start. |
Yes, this is a concern for JAX as well. This is a definite downside of reusing NumPy's existing namespace. It turns out even xarray was relying on this behavior with dask in at least one edge case: #3215 |
We didn't discuss an alternative very explicitly, I think, but at least we'll have wide adoption fast. Hopefully the pain is limited... |
I haven't used JAX, but was just browsing through its documentation and it looks super cool. Any ideas on how it compares with PyTorch in terms of: a) execution speed, esp. on GPU |
Within a … For data loading and deep learning algorithms, take a look at the examples in the … |
While it is pretty straightforward to implement a lot of standard xarray operations with a pytorch / JAX backend (since they just fall back on native functions), it will be interesting to think about how to implement rolling / expanding / exponential window operations in a way that is both efficient and maintains differentiability. Expanding and exponential window operations would be easy to do leveraging RNN semantics, but doing rolling using convolutions is going to be very inefficient. Do you have any thoughts on this? |
I have not thought too much about these yet. But I agree that they will
probably require backend specific logic to do efficiently.
|
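For the rolling case specifically, a rough sketch of a differentiable rolling mean written as a 1-D convolution in pytorch (illustrative only, not xarray API; as noted above, a convolution-based rolling window can be inefficient for large windows):

import torch
import torch.nn.functional as F

x = torch.randn(1000, requires_grad=True)
window = 5
kernel = torch.full((1, 1, window), 1.0 / window)

# conv1d expects (batch, channels, length); the "valid" output corresponds to
# a rolling mean with min_periods equal to the window size.
rolling_mean = F.conv1d(x.view(1, 1, -1), kernel).view(-1)  # length 1000 - window + 1

rolling_mean.sum().backward()  # gradients flow back through the rolling operation
print(x.grad.shape)            # torch.Size([1000])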
This might be a good time to revive this thread and see if there is wider interest (and bandwidth) in having xarray use CuPy (https://cupy.chainer.org/) as a backend (along with numpy). It appears to be a plug-and-play replacement for numpy, so it might not have all the issues that were brought up regarding pytorch/jax? Any thoughts? |
Just chiming in quickly. I think there's definitely interest in doing this through NEP-18. It looks like CuPy has implemented `__array_function__`. Have you tried creating xarray objects backed by CuPy arrays? Practically, our approach so far has been to add a number of xfailed tests. |
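The kind of experiment being asked about would look roughly like this (assumes a GPU and an installed cupy; whether each operation keeps the CuPy array or silently coerces to NumPy is exactly what needs testing):

import cupy as cp
import xarray as xr

data = cp.random.rand(4, 3)
da = xr.DataArray(data, dims=("x", "y"))

print(type(da.data))            # ideally cupy.ndarray, not numpy.ndarray
print(type((da + 1).data))      # arithmetic goes through __array_ufunc__
print(type(da.mean("x").data))  # reductions go through __array_function__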
@jacobtomlinson gave CuPy a go a few months back. I seem to remember that he ran into a few problems but it would be good to get those documented here. |
Yeah Jacob and I played with this a few months back. There were some issues, but my recollection is pretty hazy. If someone gives this another try, it would be interesting to hear how things go. |
If you have any pointers on how to go about this, I can give it a try. |
Well here's a blogpost on using Dask + CuPy. Maybe start there and build up to using Xarray. |
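A common pattern in that ecosystem (a sketch, not necessarily the blogpost's exact code) is a dask array whose chunks are CuPy arrays; xarray already handles dask-backed arrays, so the GPU part comes along with it:

import cupy as cp
import dask.array as da
import xarray as xr

# Dask array with CuPy chunks (requires a GPU); reductions are then computed
# chunk-wise on the GPU when the graph is executed.
arr = da.random.random((1000, 1000), chunks=(250, 250)).map_blocks(cp.asarray)
xarr = xr.DataArray(arr, dims=("x", "y"))
result = xarr.mean("x").compute()
print(type(result.data))  # typically a cupy.ndarray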
I've been test driving xarray objects backed by CuPy arrays, and one issue I keep running into is that operations (such as plotting) that expect numpy arrays fail due to xarray's implicit conversion to NumPy arrays via `np.asarray()`. I am wondering whether there is a plan for dealing with this issue? Here's a small, reproducible example:
[23]: ds.tmin.data.device
<CUDA Device 0>
[24]: ds.isel(time=0, lev=0).tmin.plot()  # Fails
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-21-69a72de2b9fd> in <module>
----> 1 ds.isel(time=0, lev=0).tmin.plot()
/glade/work/abanihi/softwares/miniconda3/envs/rapids/lib/python3.7/site-packages/xarray/plot/plot.py in __call__(self, **kwargs)
444
445 def __call__(self, **kwargs):
--> 446 return plot(self._da, **kwargs)
447
448 @functools.wraps(hist)
/glade/work/abanihi/softwares/miniconda3/envs/rapids/lib/python3.7/site-packages/xarray/plot/plot.py in plot(darray, row, col, col_wrap, ax, hue, rtol, subplot_kws, **kwargs)
198 kwargs["ax"] = ax
199
--> 200 return plotfunc(darray, **kwargs)
201
202
/glade/work/abanihi/softwares/miniconda3/envs/rapids/lib/python3.7/site-packages/xarray/plot/plot.py in newplotfunc(darray, x, y, figsize, size, aspect, ax, row, col, col_wrap, xincrease, yincrease, add_colorbar, add_labels, vmin, vmax, cmap, center, robust, extend, levels, infer_intervals, colors, subplot_kws, cbar_ax, cbar_kwargs, xscale, yscale, xticks, yticks, xlim, ylim, norm, **kwargs)
684
685 # Pass the data as a masked ndarray too
--> 686 zval = darray.to_masked_array(copy=False)
687
688 # Replace pd.Intervals if contained in xval or yval.
/glade/work/abanihi/softwares/miniconda3/envs/rapids/lib/python3.7/site-packages/xarray/core/dataarray.py in to_masked_array(self, copy)
2325 Masked where invalid values (nan or inf) occur.
2326 """
-> 2327 values = self.values # only compute lazy arrays once
2328 isnull = pd.isnull(values)
2329 return np.ma.MaskedArray(data=values, mask=isnull, copy=copy)
/glade/work/abanihi/softwares/miniconda3/envs/rapids/lib/python3.7/site-packages/xarray/core/dataarray.py in values(self)
556 def values(self) -> np.ndarray:
557 """The array's data as a numpy.ndarray"""
--> 558 return self.variable.values
559
560 @values.setter
/glade/work/abanihi/softwares/miniconda3/envs/rapids/lib/python3.7/site-packages/xarray/core/variable.py in values(self)
444 def values(self):
445 """The variable's data as a numpy.ndarray"""
--> 446 return _as_array_or_item(self._data)
447
448 @values.setter
/glade/work/abanihi/softwares/miniconda3/envs/rapids/lib/python3.7/site-packages/xarray/core/variable.py in _as_array_or_item(data)
247 TODO: remove this (replace with np.asarray) once these issues are fixed
248 """
--> 249 data = np.asarray(data)
250 if data.ndim == 0:
251 if data.dtype.kind == "M":
/glade/work/abanihi/softwares/miniconda3/envs/rapids/lib/python3.7/site-packages/numpy/core/_asarray.py in asarray(a, dtype, order)
83
84 """
---> 85 return array(a, dtype, copy=False, order=order)
86
87
ValueError: object __array__ method not producing an array |
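One workaround in the meantime (assuming ds is the CuPy-backed dataset from the example above) is to copy the data back to host memory explicitly before handing it to matplotlib:

import cupy as cp

tmin = ds.isel(time=0, lev=0).tmin
tmin.copy(data=cp.asnumpy(tmin.data)).plot()  # plot a NumPy-backed copy instead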
@andersy005 I'm about to start working actively on this. CuPy requests that instead of calling `np.asarray` you explicitly call `.get()` to copy the data back to the host. |
Do you have a sense of the overhead / effort of making jax vs cupy the gpu backend for xarray? One advantage of jax would be built-in autodiff functionality that would enable xarray to be plugged directly into deep learning pipelines. The downside is that it is not as numpy-compatible as cupy. How much of a non-starter would this be? |
@fjanoos I'm afraid I don't. In RAPIDS we support cupy as our GPU array implementation. So this request has come from the desire to make xarray compatible with the RAPIDS suite of tools. We commonly see folks using cupy to switch straight over to a tool like pytorch using DLPack. https://docs-cupy.chainer.org/en/stable/reference/interoperability.html#dlpack But I don't really see #4212 as an effort to make cupy the GPU backend for xarray. I see it as adding support for another backend to xarray. The more the merrier! |
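For reference, the DLPack hand-off mentioned here is zero-copy in both directions; a sketch following the CuPy interoperability docs linked above (requires a GPU):

import cupy as cp
import torch
from torch.utils.dlpack import from_dlpack, to_dlpack

cupy_arr = cp.arange(6, dtype=cp.float32).reshape(2, 3)
tensor = from_dlpack(cupy_arr.toDlpack())  # CuPy -> PyTorch, shares GPU memory
back = cp.fromDlpack(to_dlpack(tensor))    # PyTorch -> CuPy, shares GPU memory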
I'd like to cast my vote in favor of getting this functionality in. It would be nice to autodiff through xarray operations. From reading this and related threads, I'm trying to determine a game plan to make this happen. I'm not familiar with xarray code, so any guidance would be much appreciated. This is what I'm thinking:
My first attempts at this haven't been successful. Whatever custom class I make and pass to the DataArray constructor gets coerced (see xarray/xarray/core/dataarray.py, line 408 in bc35548).
Any suggestions would be appreciated. I'm hoping to figure out the shortest path to a working prototype. |
@rgommers Do you expect this solution to work with a PyTorch Tensor custom subclass? Or is monkey patching necessary? |
If you use PyTorch 1.7.1 or later, then Tensor subclasses are much better preserved through pytorch functions and operations like slicing. So a custom subclass adding the attributes and methods xarray requires for a duck array should be feasible.
Looks like you need to patch that internally just a bit, probably by adding pytorch to the list of duck array types xarray recognizes. Note that I do not expect anymore that we'll be adding `__array_function__` to torch.Tensor itself. |
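A rough sketch of the subclass approach being discussed; the XArrayTensor name matches the snippets further down, but the dispatch below is illustrative and incomplete rather than a working backend:

import torch

class XArrayTensor(torch.Tensor):
    # torch.Tensor already exposes shape, dtype and ndim; the main addition
    # for NEP-18 dispatch is __array_function__.
    def __array_function__(self, func, types, args, kwargs):
        # Map NumPy API calls to same-named torch functions where they exist
        # (np.mean -> torch.mean); many functions would need explicit mappings.
        torch_func = getattr(torch, func.__name__, None)
        if torch_func is None:
            return NotImplemented
        return torch_func(*args, **kwargs)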
Note that the main work in adding … |
@Duane321 |
I really hope so. I explored named tensors at first, but the lack of an index for each dimension was a non-starter. So, I'll keep an eye out. |
Glad to hear there's progress I can lean on. I'll come back with a minimum version that does the API matching for maybe 1-2 methods, just to get feedback on the overall structure. If it works, I can brute-force through a lot of the rest 🤞
Thank you; I hesitated to change xarray code, but not anymore.
Does this mean I shouldn't fill out … |
I can't reproduce that: In [4]: da.loc["a1"]
Out[4]:
<xarray.DataArray (b: 2)>
tensor([0.4793, 0.7493], dtype=torch.float32)
Coordinates:
a <U2 'a1'
* b (b) <U2 'b1' 'b2'
maybe this is an environment issue?
Edit: the missing feature list includes indexing with tensor-valued coordinates:
xr.DataArray(
[0, 1, 2],
coords={"x": XArrayTensor(torch.Tensor([10, 12, 14]))},
dims="x",
).loc[{"x": XArrayTensor(torch.Tensor([10, 14]))}] does not work, but xr.DataArray(
XArrayTensor(torch.Tensor([0, 1, 2])),
coords={"x": [10, 12, 14]},
dims="x",
).loc[{"x": [10, 14]}] should work just fine. |
Thanks again @keewis, that was indeed the case. It was due to my older PyTorch version (1.6.0). |
@Duane321: with |
I don't, unfortunately (there's the partial example in #3232 (comment), though). This is nothing usable right now. You (or anyone interested) might still want to maintain a "pytorch-xarray" convenience library to allow something like … |
Thanks for the prompt response. Would love to contribute but I have to climb the learning curve first. |
changing the … We might still be a bit too early with this, though: the PR which adds support for the array API standard to numpy (numpy/numpy#18585) is not merged yet. |
@keewis @shoyer now that numpy/numpy#18585 is merged in numpy … |
I started having a look at making xarray work with the array API here: tomwhite@c72a1c4. Some basic operations work (preserving the underlying array): tomwhite@929812a. If there's interest, I'd be happy to turn this into a PR with some tests. |
Absolutely! |
Opened #6804 |
Glad to see progress on this!! 👏 Just curious though, seeing this comment in the PR:
Are we sure this closes the issue? And, how can we try it out? Even lacking docs, a comment explaining how to set it up would be great, and I can do some testing on my end. I understand that it's an experimental feature. |
Hi @hsharrison - thanks for offering to do some testing. Here's a little demo script that you could try, by switching |
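The original script is elided above, so as a guess at the kind of demo meant here, a sketch using NumPy's experimental array_api module (assumes a NumPy version that still ships numpy.array_api and an xarray release that includes #6804):

import numpy.array_api as xp  # any Array API namespace could be substituted here
import xarray as xr

arr = xp.asarray([[1.0, 2.0], [3.0, 4.0]])
da = xr.DataArray(arr, dims=("x", "y"))

print(type(da.data))            # the array_api Array type, not numpy.ndarray
print(type(da.isel(x=0).data))  # basic operations keep the wrapped array type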
Nice that it's so simple. I think it can't be tested with pytorch until they complete pytorch/pytorch#58743, right? Or should we just try passing in torch tensors anyway? |
It needs the Array API to be implemented in pytorch. |
Makes sense, then I'll wait for pytorch/pytorch#58743 to try it. |
While it is true that using PyTorch Tensors directly would need the Array API implemented in PyTorch, one could use them indirectly by converting them zero-copy to CuPy arrays, which do have Array API support. |
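Concretely, the indirect route described here might look like the following sketch (requires a GPU plus CuPy and PyTorch versions recent enough to support the __dlpack__ protocol):

import cupy as cp
import torch
import xarray as xr

t = torch.arange(6, dtype=torch.float32, device="cuda").reshape(2, 3)
cupy_view = cp.from_dlpack(t)  # zero-copy view of the tensor's GPU memory
da = xr.DataArray(cupy_view, dims=("x", "y"))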
I would be interested in using pytorch as a backend for xarray because:
a) pytorch is very similar to numpy, so the conceptual overhead is small
b) [most helpful] it would enable having a GPU as the underlying hardware for compute, which would provide a non-trivial speed-up
c) it would allow seamless integration with deep-learning algorithms and techniques
Any thoughts on what the interest for such a feature might be? I would be open to implementing parts of it, so any suggestions on where I could start?
Thanks