Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Type check datatree tests #9632

Merged
merged 5 commits into from
Oct 16, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
more type hints
  • Loading branch information
TomNicholas committed Oct 16, 2024
commit 8d3137afd9bd7a392e16c423653f3b21561aaa8d
78 changes: 39 additions & 39 deletions xarray/tests/test_datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,14 @@


class TestTreeCreation:
def test_empty(self):
def test_empty(self) -> None:
dt = DataTree(name="root")
assert dt.name == "root"
assert dt.parent is None
assert dt.children == {}
assert_identical(dt.to_dataset(), xr.Dataset())

def test_name(self):
def test_name(self) -> None:
dt = DataTree()
assert dt.name is None

Expand All @@ -50,14 +50,14 @@ def test_name(self):
detached.name = "bar"
assert detached.name == "bar"

def test_bad_names(self):
def test_bad_names(self) -> None:
with pytest.raises(TypeError):
DataTree(name=5) # type: ignore[arg-type]

with pytest.raises(ValueError):
DataTree(name="folder/data")

def test_data_arg(self):
def test_data_arg(self) -> None:
ds = xr.Dataset({"foo": 42})
tree: DataTree = DataTree(dataset=ds)
assert_identical(tree.to_dataset(), ds)
Expand All @@ -67,13 +67,13 @@ def test_data_arg(self):


class TestFamilyTree:
def test_dont_modify_children_inplace(self):
def test_dont_modify_children_inplace(self) -> None:
# GH issue 9196
child = DataTree()
DataTree(children={"child": child})
assert child.parent is None

def test_create_two_children(self):
def test_create_two_children(self) -> None:
root_data = xr.Dataset({"a": ("y", [6, 7, 8]), "set0": ("x", [9, 10])})
set1_data = xr.Dataset({"a": 0, "b": 1})
root = DataTree.from_dict(
Expand All @@ -82,7 +82,7 @@ def test_create_two_children(self):
assert root["/set1"].name == "set1"
assert root["/set1/set2"].name == "set2"

def test_create_full_tree(self, simple_datatree):
def test_create_full_tree(self, simple_datatree) -> None:
d = simple_datatree.to_dict()
d_keys = list(d.keys())

Expand All @@ -100,12 +100,12 @@ def test_create_full_tree(self, simple_datatree):


class TestNames:
def test_child_gets_named_on_attach(self):
def test_child_gets_named_on_attach(self) -> None:
sue = DataTree()
mary = DataTree(children={"Sue": sue}) # noqa
assert mary.children["Sue"].name == "Sue"

def test_dataset_containing_slashes(self):
def test_dataset_containing_slashes(self) -> None:
xda: xr.DataArray = xr.DataArray(
[[1, 2]],
coords={"label": ["a"], "R30m/y": [30, 60]},
Expand All @@ -124,7 +124,7 @@ def test_dataset_containing_slashes(self):


class TestPaths:
def test_path_property(self):
def test_path_property(self) -> None:
john = DataTree.from_dict(
{
"/Mary/Sue": DataTree(),
Expand All @@ -133,15 +133,15 @@ def test_path_property(self):
assert john["/Mary/Sue"].path == "/Mary/Sue"
assert john.path == "/"

def test_path_roundtrip(self):
def test_path_roundtrip(self) -> None:
john = DataTree.from_dict(
{
"/Mary/Sue": DataTree(),
}
)
assert john["/Mary/Sue"].name == "Sue"

def test_same_tree(self):
def test_same_tree(self) -> None:
john = DataTree.from_dict(
{
"/Mary": DataTree(),
Expand All @@ -150,7 +150,7 @@ def test_same_tree(self):
)
assert john["/Mary"].same_tree(john["/Kate"])

def test_relative_paths(self):
def test_relative_paths(self) -> None:
john = DataTree.from_dict(
{
"/Mary/Sue": DataTree(),
Expand Down Expand Up @@ -179,7 +179,7 @@ def test_relative_paths(self):


class TestStoreDatasets:
def test_create_with_data(self):
def test_create_with_data(self) -> None:
dat = xr.Dataset({"a": 0})
john = DataTree(name="john", dataset=dat)

Expand All @@ -188,7 +188,7 @@ def test_create_with_data(self):
with pytest.raises(TypeError):
DataTree(name="mary", dataset="junk") # type: ignore[arg-type]

def test_set_data(self):
def test_set_data(self) -> None:
john = DataTree(name="john")
dat = xr.Dataset({"a": 0})
john.dataset = dat # type: ignore[assignment]
Expand All @@ -198,14 +198,14 @@ def test_set_data(self):
with pytest.raises(TypeError):
john.dataset = "junk" # type: ignore[assignment]

def test_has_data(self):
def test_has_data(self) -> None:
john = DataTree(name="john", dataset=xr.Dataset({"a": 0}))
assert john.has_data

john_no_data = DataTree(name="john", dataset=None)
assert not john_no_data.has_data

def test_is_hollow(self):
def test_is_hollow(self) -> None:
john = DataTree(dataset=xr.Dataset({"a": 0}))
assert john.is_hollow

Expand All @@ -217,7 +217,7 @@ def test_is_hollow(self):


class TestToDataset:
def test_to_dataset_inherited(self):
def test_to_dataset_inherited(self) -> None:
base = xr.Dataset(coords={"a": [1], "b": 2})
sub = xr.Dataset(coords={"c": [3]})
tree = DataTree.from_dict({"/": base, "/sub": sub})
Expand All @@ -232,16 +232,16 @@ def test_to_dataset_inherited(self):


class TestVariablesChildrenNameCollisions:
def test_parent_already_has_variable_with_childs_name(self):
def test_parent_already_has_variable_with_childs_name(self) -> None:
with pytest.raises(KeyError, match="already contains a variable named a"):
DataTree.from_dict({"/": xr.Dataset({"a": [0], "b": 1}), "/a": None})

def test_parent_already_has_variable_with_childs_name_update(self):
def test_parent_already_has_variable_with_childs_name_update(self) -> None:
dt = DataTree(dataset=xr.Dataset({"a": [0], "b": 1}))
with pytest.raises(ValueError, match="already contains a variable named a"):
dt.update({"a": DataTree()})

def test_assign_when_already_child_with_variables_name(self):
def test_assign_when_already_child_with_variables_name(self) -> None:
dt = DataTree.from_dict(
{
"/a": DataTree(),
Expand All @@ -262,7 +262,7 @@ class TestGet: ...


class TestGetItem:
def test_getitem_node(self):
def test_getitem_node(self) -> None:
folder1 = DataTree.from_dict(
{
"/results/highres": DataTree(),
Expand All @@ -272,16 +272,16 @@ def test_getitem_node(self):
assert folder1["results"].name == "results"
assert folder1["results/highres"].name == "highres"

def test_getitem_self(self):
def test_getitem_self(self) -> None:
dt = DataTree()
assert dt["."] is dt

def test_getitem_single_data_variable(self):
def test_getitem_single_data_variable(self) -> None:
data = xr.Dataset({"temp": [0, 50]})
results = DataTree(name="results", dataset=data)
assert_identical(results["temp"], data["temp"])

def test_getitem_single_data_variable_from_node(self):
def test_getitem_single_data_variable_from_node(self) -> None:
data = xr.Dataset({"temp": [0, 50]})
folder1 = DataTree.from_dict(
{
Expand All @@ -290,67 +290,67 @@ def test_getitem_single_data_variable_from_node(self):
)
assert_identical(folder1["results/highres/temp"], data["temp"])

def test_getitem_nonexistent_node(self):
def test_getitem_nonexistent_node(self) -> None:
folder1 = DataTree.from_dict({"/results": DataTree()}, name="folder1")
with pytest.raises(KeyError):
folder1["results/highres"]

def test_getitem_nonexistent_variable(self):
def test_getitem_nonexistent_variable(self) -> None:
data = xr.Dataset({"temp": [0, 50]})
results = DataTree(name="results", dataset=data)
with pytest.raises(KeyError):
results["pressure"]

@pytest.mark.xfail(reason="Should be deprecated in favour of .subset")
def test_getitem_multiple_data_variables(self):
def test_getitem_multiple_data_variables(self) -> None:
data = xr.Dataset({"temp": [0, 50], "p": [5, 8, 7]})
results = DataTree(name="results", dataset=data)
assert_identical(results[["temp", "p"]], data[["temp", "p"]]) # type: ignore[index]

@pytest.mark.xfail(
reason="Indexing needs to return whole tree (GH https://github.com/xarray-contrib/datatree/issues/77)"
)
def test_getitem_dict_like_selection_access_to_dataset(self):
def test_getitem_dict_like_selection_access_to_dataset(self) -> None:
data = xr.Dataset({"temp": [0, 50]})
results = DataTree(name="results", dataset=data)
assert_identical(results[{"temp": 1}], data[{"temp": 1}]) # type: ignore[index]


class TestUpdate:
def test_update(self):
def test_update(self) -> None:
dt = DataTree()
dt.update({"foo": xr.DataArray(0), "a": DataTree()})
expected = DataTree.from_dict({"/": xr.Dataset({"foo": 0}), "a": None})
assert_equal(dt, expected)
assert dt.groups == ("/", "/a")

def test_update_new_named_dataarray(self):
def test_update_new_named_dataarray(self) -> None:
da = xr.DataArray(name="temp", data=[0, 50])
folder1 = DataTree(name="folder1")
folder1.update({"results": da})
expected = da.rename("results")
assert_equal(folder1["results"], expected)

def test_update_doesnt_alter_child_name(self):
def test_update_doesnt_alter_child_name(self) -> None:
dt = DataTree()
dt.update({"foo": xr.DataArray(0), "a": DataTree(name="b")})
assert "a" in dt.children
child = dt["a"]
assert child.name == "a"

def test_update_overwrite(self):
def test_update_overwrite(self) -> None:
actual = DataTree.from_dict({"a": DataTree(xr.Dataset({"x": 1}))})
actual.update({"a": DataTree(xr.Dataset({"x": 2}))})
expected = DataTree.from_dict({"a": DataTree(xr.Dataset({"x": 2}))})
assert_equal(actual, expected)

def test_update_coordinates(self):
def test_update_coordinates(self) -> None:
expected = DataTree.from_dict({"/": xr.Dataset(coords={"a": 1})})
actual = DataTree.from_dict({"/": xr.Dataset()})
actual.update(xr.Dataset(coords={"a": 1}))
assert_equal(actual, expected)

def test_update_inherited_coords(self):
def test_update_inherited_coords(self) -> None:
expected = DataTree.from_dict(
{
"/": xr.Dataset(coords={"a": 1}),
Expand All @@ -375,7 +375,7 @@ def test_update_inherited_coords(self):


class TestCopy:
def test_copy(self, create_test_datatree):
def test_copy(self, create_test_datatree) -> None:
dt = create_test_datatree()

for node in dt.root.subtree:
Expand All @@ -402,7 +402,7 @@ def test_copy(self, create_test_datatree):
assert "foo" not in node.attrs
assert node.attrs["Test"] is copied_node.attrs["Test"]

def test_copy_subtree(self):
def test_copy_subtree(self) -> None:
dt = DataTree.from_dict({"/level1/level2/level3": xr.Dataset()})

actual = dt["/level1/level2"].copy()
Expand All @@ -426,7 +426,7 @@ def test_copy_coord_inheritance(self) -> None:
expected = DataTree(name="c")
assert_identical(expected, actual)

def test_deepcopy(self, create_test_datatree):
def test_deepcopy(self, create_test_datatree) -> None:
dt = create_test_datatree()

for node in dt.root.subtree:
Expand Down Expand Up @@ -454,7 +454,7 @@ def test_deepcopy(self, create_test_datatree):
assert node.attrs["Test"] is not copied_node.attrs["Test"]

@pytest.mark.xfail(reason="data argument not yet implemented")
def test_copy_with_data(self, create_test_datatree):
def test_copy_with_data(self, create_test_datatree) -> None:
orig = create_test_datatree()
# TODO use .data_vars once that property is available
data_vars = {
Expand Down
Loading