Preserve order of variables in combine_by_coords #9070

Open
wants to merge 6 commits into
base: main

Conversation

kmuehlbauer
Contributor

@kmuehlbauer kmuehlbauer commented Jun 5, 2024

@kmuehlbauer kmuehlbauer changed the title FIX: do not sort datasets in combine_by_coords Preserve order of variables in in combine_by_coords Jun 5, 2024
@kmuehlbauer kmuehlbauer changed the title Preserve order of variables in in combine_by_coords Preserve order of variables in combine_by_coords Jun 5, 2024
@TomNicholas
Member

I think the reason I originally put the sort call in there was because in the itertools.groupby docs https://docs.python.org/3/library/itertools.html#itertools.groupby it says

Generally, the iterable needs to already be sorted on the same key function.

But if it seems to work without that then I guess it's fine?

@kmuehlbauer
Contributor Author

I think this dates from the time when we had to deal with unordered dicts. At least that is what I understood from a previous comment at that position in the code.

@kmuehlbauer
Contributor Author

Generally, the iterable needs to already be sorted on the same key function.

Now, in light of this... But somehow it works.

@kmuehlbauer
Contributor Author

OK, let me add some more testing to be sure this works in datasets of any order.

@keewis
Collaborator

keewis commented Jun 5, 2024

The reason itertools.groupby needs sorted iterables is that it combines groups locally, not globally. So [(k, list(g)) for k, g in itertools.groupby([1, 1, 2, 1, 3], key=lambda x: x)] would result in [(1, [1, 1]), (2, [2]), (1, [1]), (3, [3])], not [(1, [1, 1, 1]), (2, [2]), (3, [3])] (this is not the case for more_itertools.bucket and toolz.itertoolz.groupby).
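To make the local-versus-global distinction concrete, here is a minimal sketch using only the standard library. The dict-based variant is an illustrative alternative, not code from this PR.

```python
import itertools
from collections import defaultdict

data = [1, 1, 2, 1, 3]

# itertools.groupby starts a new group every time the key changes,
# so equal items that are not adjacent end up in separate groups.
local_groups = [(k, list(g)) for k, g in itertools.groupby(data)]

# Sorting first makes equal items adjacent, so each key appears once.
sorted_groups = [(k, list(g)) for k, g in itertools.groupby(sorted(data))]

# A dict collects items globally without sorting; keys keep first-seen order.
buckets = defaultdict(list)
for item in data:
    buckets[item].append(item)
global_groups = list(buckets.items())

print(local_groups)
print(sorted_groups)
print(global_groups)
```

With unsorted input, the first variant yields four groups (the key 1 appears twice), while the other two yield three.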

@kmuehlbauer
Contributor Author

kmuehlbauer commented Jun 6, 2024

Thanks @keewis. AFAICT, sorting before the groupby makes sure that all Datasets with the same variables are grouped together, regardless of the variable order within each Dataset. Update: and regardless of their position in the object list.

Removing the sorting will result in more groups but doesn't break the tests. Does that mean we are undertesting? Or is it just fixed by the subsequent merge?

The issue this PR should solve is that the sorting rearranges the input objects and, with compat="override", can move the first object to another position, resulting in wrong output. I then thought about special-casing compat="override" to keep the first object in its place, but didn't find a way to combine that with the groupby.
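The reordering described above can be reproduced without xarray. This is a minimal sketch that assumes plain dicts as stand-ins for Datasets (iterating either yields the variable names), with a key function that mirrors the `vars_as_keys` helper posted later in this thread.

```python
def vars_as_keys(obj):
    # Sorted variable names, so variable order inside an object is irrelevant.
    return tuple(sorted(obj))

# Plain dicts standing in for xarray Datasets (an assumption for illustration).
x1 = {"a": 1, "c": 1, "b": 1}
x2 = {"d": 1, "a": 2}
x3 = {"a": 3, "d": 2}

objects = [x2, x1, x3]
reordered = sorted(objects, key=vars_as_keys)

# ("a", "b", "c") sorts before ("a", "d"), so x1 overtakes x2.
print(objects[0] is x2)    # the caller's first object
print(reordered[0] is x1)  # the first object after sorting
```

Any option that treats the first object specially would then see x1 instead of x2.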

@kmuehlbauer
Contributor Author

After reading up on itertools and collections, I found that it's possible to replace itertools.groupby with a collections.defaultdict implementation that preserves order, with some performance gain on the side:

import xarray as xr
import collections
import itertools

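def vars_as_keys(ds):
    # Key on the sorted variable names so variable order within a Dataset is irrelevant.
    return tuple(sorted(ds))

def groupby_defaultdict(objs, key=lambda x: x):
    # Group globally; dict insertion order keeps the keys in first-seen order.
    # objs must be a sequence, because it is indexed below.
    idx = collections.defaultdict(list)
    for i, obj in enumerate(objs):
        idx[key(obj)].append(i)
    for k, ix in idx.items():
        yield k, (objs[i] for i in ix)

def groupby_itertools(objs, key=lambda x: x):
    # Sort first: itertools.groupby only merges adjacent items with equal keys.
    objs = sorted(objs, key=key)
    return itertools.groupby(objs, key=key)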

x1 = xr.Dataset({"a": (("y", "x"), [[1]]),
                 "c": (("y", "x"), [[1]]),
                 "b": (("y", "x"), [[1]]),}, 
                coords={"y": [0], "x": [0]})
x2 = xr.Dataset({"d": (("y", "x"), [[1]]),
                 "a": (("y", "x"), [[2]]),},
                coords={"y": [0], "x": [0]})
x3 = xr.Dataset({"a": (("y", "x"), [[3]]),
                 "d": (("y", "x"), [[2]]),},
                coords={"y": [0], "x": [1]})
data_objects = [x2, x1, x3]
%%timeit
grouped_by_vars = groupby_defaultdict(data_objects, key=vars_as_keys)
274 ns ± 0.934 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
%%timeit
grouped_by_vars = groupby_itertools(data_objects, key=vars_as_keys)
4.98 µs ± 14 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
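One caveat when reading these timings: both helpers are lazy. `groupby_defaultdict` is a generator function, so calling it only creates a generator object, and `itertools.groupby` returns an iterator, so the numbers above probably measure mostly setup (plus the eager `sorted` call in the itertools variant). The sketch below materializes the groups before comparing. It assumes plain dicts as stand-ins for the Datasets and uses eager variants of the two helpers, which are illustrative rewrites rather than the code in this PR.

```python
import collections
import itertools
import timeit

def vars_as_keys(obj):
    return tuple(sorted(obj))

def groupby_defaultdict(objs, key):
    # Eager variant: collect the objects themselves, keyed in first-seen order.
    groups = collections.defaultdict(list)
    for obj in objs:
        groups[key(obj)].append(obj)
    return list(groups.items())

def groupby_itertools(objs, key):
    # Eager variant: sort, then consume every group.
    ordered = sorted(objs, key=key)
    return [(k, list(g)) for k, g in itertools.groupby(ordered, key=key)]

# Plain dicts standing in for x1, x2, x3 (an assumption for illustration).
x1 = {"a": 1, "c": 1, "b": 1}
x2 = {"d": 1, "a": 2}
x3 = {"a": 3, "d": 2}
data_objects = [x2, x1, x3]

by_dict = groupby_defaultdict(data_objects, vars_as_keys)
by_sort = groupby_itertools(data_objects, vars_as_keys)

# Same groups either way; only the order of the groups differs.
print([k for k, _ in by_dict])
print([k for k, _ in by_sort])

for fn in (groupby_defaultdict, groupby_itertools):
    seconds = timeit.timeit(lambda: fn(data_objects, vars_as_keys), number=10_000)
    print(fn.__name__, seconds)
```

The dict-based variant keeps the group containing `x2` first, matching the input order, while the sorted variant puts the group containing `x1` first.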

@kmuehlbauer
Contributor Author

It looks like this can be replaced here too:

# TODO: is the sorted need?
combined_ids = dict(sorted(combined_ids.items(), key=_new_tile_id))
grouped = itertools.groupby(combined_ids.items(), key=_new_tile_id)
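For illustration, an order-preserving version of this second call site might look like the sketch below. `new_tile_id` is a hypothetical stand-in that is assumed to drop the first element of the tile id, and the string values stand in for datasets. Whether the surrounding combine logic relies on the sorted order is not something this sketch establishes.

```python
from collections import defaultdict

def new_tile_id(item):
    # Hypothetical stand-in for the key function: assumed to drop the
    # leading element of the tile id.
    tile_id, _ = item
    return tile_id[1:]

# Tile ids mapped to placeholder values, deliberately not in sorted order.
combined_ids = {(0, 1): "b", (0, 0): "a", (1, 1): "d", (1, 0): "c"}

# Group globally without sorting; keys appear in first-seen order.
buckets = defaultdict(list)
for item in combined_ids.items():
    buckets[new_tile_id(item)].append(item)
grouped = dict(buckets)

for new_id, members in grouped.items():
    print(new_id, members)
```

Here the group for `(1,)` comes first because `(0, 1)` is the first id in the input, whereas the sorted version would start with `(0,)`.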

@dcherian dcherian requested a review from TomNicholas June 6, 2024 18:15
@kmuehlbauer kmuehlbauer added the run-benchmark Run the ASV benchmark workflow label Jun 7, 2024
Labels
run-benchmark Run the ASV benchmark workflow
Development

Successfully merging this pull request may close these issues.

Dataset combine_by_coords unexpected behavior
3 participants