6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -1,3 +1,9 @@
+0.3 (unreleased)
+================
+
+- Closing the dataset returned by `open_ncml` will close the underlying opened files. By @huard
+
+
 0.2 (2023-02-23)
 ================
 
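In practice, the new entry means teardown goes through the returned dataset alone. A minimal sketch of the behaviour, reusing the `aggExisting.xml` fixture exercised by this PR's tests (the relative path is illustrative; any multi-file NcML aggregation behaves the same):

    import xncml

    ds = xncml.open_ncml('tests/data/aggExisting.xml')
    print(ds['time'].size)  # the aggregation reads as before
    ds.close()              # now also closes every file the aggregation opened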
2,037 changes: 1,946 additions & 91 deletions docs/source/tutorial.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions requirements.txt
@@ -4,3 +4,4 @@ xarray
 cftime
 netCDF4
 dask
+psutil
2 changes: 1 addition & 1 deletion setup.cfg
@@ -11,7 +11,7 @@ select = B,C,E,F,W,T4,B9
 
 [isort]
 known_first_party=xncml
-known_third_party=numpy,pkg_resources,pytest,setuptools,xarray,xmltodict,xsdata
+known_third_party=numpy,pkg_resources,psutil,pytest,setuptools,xarray,xmltodict,xsdata
 multi_line_output=3
 include_trailing_comma=True
 force_grid_wrap=0
55 changes: 40 additions & 15 deletions tests/test_parser.py
@@ -2,6 +2,7 @@
 from pathlib import Path
 
 import numpy as np
+import psutil
 import pytest
 
 import xncml
@@ -15,22 +16,44 @@
 data = Path(__file__).parent / 'data'
 
 
+class CheckClose(object):
+    """Check that files are closed after the test. Note that `close` has to be explicitly called within the
+    context manager for this to work."""
+
+    def __init__(self):
+        self.proc = psutil.Process()
+        self.before = None
+
+    def __enter__(self):
+        self.before = len(self.proc.open_files())
+
+    def __exit__(self, *args):
+        """Raise error if files are left open at the end of the test."""
+        after = len(self.proc.open_files())
+        if after != self.before:
+            raise AssertionError(f'Files left open after test: {after - self.before}')
+
+
 def test_aggexisting():
-    ds = xncml.open_ncml(data / 'aggExisting.xml')
-    check_dimension(ds)
-    check_coord_var(ds)
-    check_agg_coord_var(ds)
-    check_read_data(ds)
-    assert ds['time'].attrs['ncmlAdded'] == 'timeAtt'
+    with CheckClose():
+        ds = xncml.open_ncml(data / 'aggExisting.xml')
+        check_dimension(ds)
+        check_coord_var(ds)
+        check_agg_coord_var(ds)
+        check_read_data(ds)
+        assert ds['time'].attrs['ncmlAdded'] == 'timeAtt'
+        ds.close()
 
 
 def test_aggexisting_w_coords():
-    ds = xncml.open_ncml(data / 'aggExistingWcoords.xml')
-    check_dimension(ds)
-    check_coord_var(ds)
-    check_agg_coord_var(ds)
-    check_read_data(ds)
-    assert ds['time'].attrs['ncmlAdded'] == 'timeAtt'
+    with CheckClose():
+        ds = xncml.open_ncml(data / 'aggExistingWcoords.xml')
+        check_dimension(ds)
+        check_coord_var(ds)
+        check_agg_coord_var(ds)
+        check_read_data(ds)
+        assert ds['time'].attrs['ncmlAdded'] == 'timeAtt'
+        ds.close()
 
 
 def test_aggexisting_coords_var():
@@ -155,9 +178,11 @@ def test_agg_synthetic_3():
 
 
 def test_agg_syn_scan():
-    ds = xncml.open_ncml(data / 'aggSynScan.xml')
-    assert len(ds.time) == 3
-    assert all(ds.time == [0, 10, 20])
+    with CheckClose():
+        ds = xncml.open_ncml(data / 'aggSynScan.xml')
+        assert len(ds.time) == 3
+        assert all(ds.time == [0, 10, 20])
+        ds.close()
 
 
 def test_agg_syn_rename():
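For context, the leak check in these tests boils down to comparing `psutil.Process.open_files()` counts before and after the body runs. A standalone sketch of the same idea (the NetCDF path is illustrative, and what `open_files()` reports varies by platform):

    import psutil
    import xarray as xr

    proc = psutil.Process()
    before = len(proc.open_files())

    ds = xr.open_dataset('some_file.nc')  # illustrative path
    ds.close()

    # If close() did its job, no handle is left behind.
    assert len(proc.open_files()) == before, 'file handle leaked'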
40 changes: 29 additions & 11 deletions xncml/parser.py
@@ -35,6 +35,8 @@
 from __future__ import annotations
 
 import datetime as dt
+from functools import partial
+from itertools import chain
 from pathlib import Path
 
 import numpy as np
@@ -165,11 +167,14 @@ def read_aggregation(target: xr.Dataset, obj: Aggregation, ncml: Path) -> xr.Dataset:
     for attr in obj.promote_global_attribute:
         raise NotImplementedError
 
-    # Create list of items to aggregate.
-    items = []
+    # Create list of datasets to aggregate.
+    datasets = []
+    closers = []
 
     for item in obj.netcdf:
+        # Open dataset defined in <netcdf>'s `location` attribute
         tar = read_netcdf(xr.Dataset(), ref=xr.Dataset(), obj=item, ncml=ncml)
+        closers.append(getattr(tar, '_close'))
 
         # Select variables
         if names:
@@ -180,31 +185,35 @@ def read_aggregation(target: xr.Dataset, obj: Aggregation, ncml: Path) -> xr.Dataset:
             dtypes = [i[obj.dim_name].dtype.type for i in [tar, target] if obj.dim_name in i]
             coords = read_coord_value(item, obj, dtypes=dtypes)
             tar = tar.assign_coords({obj.dim_name: coords})
-        items.append(tar)
+        datasets.append(tar)
 
     # Handle <scan> element
     for item in obj.scan:
-        items.extend(read_scan(item, ncml))
+        dss = read_scan(item, ncml)
+        datasets.extend([ds.chunk() for ds in dss])
+        closers.extend([getattr(ds, '_close') for ds in dss])
 
     # Need to decode time variable
     if obj.time_units_change:
-        for i, ds in enumerate(items):
+        for i, ds in enumerate(datasets):
             t = xr.as_variable(ds[obj.dim_name], obj.dim_name)  # Maybe not the same name...
             encoded = CFDatetimeCoder(use_cftime=True).decode(t, name=t.name)
-            items[i] = ds.assign_coords({obj.dim_name: encoded})
+            datasets[i] = ds.assign_coords({obj.dim_name: encoded})
 
     # Translate different types of aggregation into xarray instructions.
     if obj.type == AggregationType.JOIN_EXISTING:
-        agg = xr.concat(items, obj.dim_name)
+        agg = xr.concat(datasets, obj.dim_name)
     elif obj.type == AggregationType.JOIN_NEW:
-        agg = xr.concat(items, obj.dim_name)
+        agg = xr.concat(datasets, obj.dim_name)
     elif obj.type == AggregationType.UNION:
-        agg = xr.merge(items)
+        agg = xr.merge(datasets)
     else:
         raise NotImplementedError
 
     agg = read_group(agg, None, obj)
-    return target.merge(agg, combine_attrs='no_conflicts')
+    out = target.merge(agg, combine_attrs='no_conflicts')
+    out.set_close(partial(_multi_file_closer, closers))
+    return out
 
 
 def read_ds(obj: Netcdf, ncml: Path) -> xr.Dataset:
@@ -319,7 +328,7 @@ def read_scan(obj: Aggregation.Scan, ncml: Path) -> [xr.Dataset]:
 
     files.sort()
 
-    return [xr.open_dataset(f, decode_times=False).chunk() for f in files]
+    return [xr.open_dataset(f, decode_times=False) for f in files]
 
 
 def read_coord_value(nc: Netcdf, agg: Aggregation, dtypes: list = ()):
@@ -575,3 +584,12 @@ def filter_by_class(iterable, klass):
     for item in iterable:
         if isinstance(item, klass):
            yield item
+
+
+def _multi_file_closer(closers):
+    """Close multiple files."""
+    # Note that if a closer is None, it probably means an alteration was made to the original dataset. Make sure
+    # that the `_close` attribute is obtained directly from the object returned by `open_dataset`.
+    for closer in closers:
+        if closer is not None:
+            closer()
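The closer wiring above rides on xarray's own mechanism: `open_dataset` stashes a closer on the dataset's `_close` attribute, and `Dataset.set_close` lets a derived object adopt several of them at once. A minimal sketch of the pattern outside xncml (file names are illustrative, and `_close` is a private attribute, as in the PR itself):

    from functools import partial

    import xarray as xr

    def close_all(closers):
        # Mirrors _multi_file_closer: skip closers lost to intermediate copies.
        for closer in closers:
            if closer is not None:
                closer()

    # Grab each file's closer before any transformation replaces it with None.
    parts = [xr.open_dataset(p, decode_times=False) for p in ('a.nc', 'b.nc')]
    closers = [getattr(ds, '_close') for ds in parts]

    combined = xr.concat(parts, dim='time')
    combined.set_close(partial(close_all, closers))
    combined.close()  # closes a.nc and b.nc as well

This is also why `read_scan` no longer calls `.chunk()` itself: chunking returns a new dataset whose `_close` is dropped, so the aggregation code collects the closers from the raw `open_dataset` results first and chunks afterwards.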