Skip to content
This repository was archived by the owner on Oct 24, 2024. It is now read-only.

Open datatree #15

Merged
Merged 2 commits
Aug 24, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions datatree/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .datatree import DataTree, map_over_subtree, DataNode
from .io import open_datatree
2 changes: 1 addition & 1 deletion datatree/datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,7 +639,7 @@ def as_dataarray(self) -> DataArray:
@property
def groups(self):
"""Return all netCDF4 groups in the tree, given as a tuple of path-like strings."""
return tuple(node.path for node in self.subtree_nodes)
return tuple(node.pathstr for node in self.subtree_nodes)

def to_netcdf(self, filename: str):
from .io import _datatree_to_netcdf
Expand Down
63 changes: 32 additions & 31 deletions datatree/io.py
Original file line number Diff line number Diff line change
@@ -1,46 +1,47 @@
from typing import Sequence
from typing import Sequence, Dict
import os

from netCDF4 import Dataset as nc_dataset
import netCDF4

from xarray import open_dataset

from .datatree import DataTree, PathType
from .datatree import DataTree, DataNode, PathType


def _get_group_names(file):
rootgrp = nc_dataset("test.nc", "r", format="NETCDF4")
def _open_group_children_recursively(filename, node, ncgroup, chunks, **kwargs):
for g in ncgroup.groups.values():

def walktree(top):
yield top.groups.values()
for value in top.groups.values():
yield from walktree(value)
# Open and add this node's dataset to the tree
name = os.path.basename(g.path)
ds = open_dataset(filename, group=g.path, chunks=chunks, **kwargs)
child_node = DataNode(name, ds)
node.add_child(child_node)

groups = []
for children in walktree(rootgrp):
for child in children:
# TODO include parents in saved path
groups.append(child.name)
_open_group_children_recursively(filename, node[name], g, chunks, **kwargs)

rootgrp.close()
return groups


def open_datatree(filename_or_obj, engine=None, chunks=None, **kwargs) -> DataTree:
"""
Open and decode a dataset from a file or file-like object, creating one DataTree node
for each group in the file.
def open_datatree(filename: str, chunks: Dict = None, **kwargs) -> DataTree:
"""
Open and decode a dataset from a file or file-like object, creating one Tree node for each group in the file.

# TODO find all the netCDF groups in the file
file_groups = _get_group_names(filename_or_obj)
Parameters
----------
filename
chunks

Returns
-------
DataTree
"""

# Populate the DataTree with the groups
groups_and_datasets = {group_path: open_dataset(engine=engine, chunks=chunks, **kwargs)
for group_path in file_groups}
return DataTree(data_objects=groups_and_datasets)
with netCDF4.Dataset(filename, mode='r') as ncfile:
ds = open_dataset(filename, chunks=chunks, **kwargs)
tree_root = DataTree(data_objects={'root': ds})
_open_group_children_recursively(filename, tree_root, ncfile, chunks, **kwargs)
return tree_root


def open_mfdatatree(filepaths, rootnames: Sequence[PathType] = None, engine=None, chunks=None, **kwargs) -> DataTree:
def open_mfdatatree(filepaths, rootnames: Sequence[PathType] = None, chunks=None, **kwargs) -> DataTree:
"""
Open multiple files as a single DataTree.

Expand All @@ -55,11 +56,11 @@ def open_mfdatatree(filepaths, rootnames: Sequence[PathType] = None, engine=None
full_tree = DataTree()

for file, root in zip(filepaths, rootnames):
dt = open_datatree(file, engine=engine, chunks=chunks, **kwargs)
full_tree._set_item(path=root, value=dt, new_nodes_along_path=True, allow_overwrites=False)
dt = open_datatree(file, chunks=chunks, **kwargs)
full_tree.set_node(path=root, node=dt, new_nodes_along_path=True, allow_overwrite=False)

return full_tree


def _datatree_to_netcdf(dt: DataTree, path_or_file: str):
def _datatree_to_netcdf(dt: DataTree, filepath: str):
raise NotImplementedError