Skip to content

Commit

Permalink
added throw_on_missing flag to _get_node() (#513)
Browse files Browse the repository at this point in the history
* added throw_on_missing flag to _get_node()

* tighter typing

* removed some extraneous assertions

* standardized error message
  • Loading branch information
omry authored Feb 3, 2021
1 parent 877e839 commit 96645ef
Show file tree
Hide file tree
Showing 9 changed files with 95 additions and 31 deletions.
7 changes: 6 additions & 1 deletion omegaconf/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,12 @@ def update_node(self, key: str, value: Any = None) -> None:
def select(self, key: str, throw_on_missing: bool = False) -> Any:
...

def _get_node(self, key: Any, validate_access: bool = True) -> Optional[Node]:
def _get_node(
self,
key: Any,
validate_access: bool = True,
throw_on_missing: bool = False,
) -> Union[Optional[Node], List[Optional[Node]]]:
...

@abstractmethod
Expand Down
18 changes: 14 additions & 4 deletions omegaconf/basecontainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def convert(val: Node) -> Any:
retdict: Dict[str, Any] = {}
for key in conf.keys():
node = conf._get_node(key)
assert node is not None
assert isinstance(node, Node)
if resolve:
node = node._dereference_node(
throw_on_missing=False, throw_on_resolution_failure=True
Expand All @@ -240,7 +240,7 @@ def convert(val: Node) -> Any:
retlist: List[Any] = []
for index in range(len(conf)):
node = conf._get_node(index)
assert node is not None
assert isinstance(node, Node)
if resolve:
node = node._dereference_node(
throw_on_missing=False, throw_on_resolution_failure=True
Expand Down Expand Up @@ -331,6 +331,8 @@ def expand(node: Container) -> None:
for key, src_value in src.items_ex(resolve=False):
src_node = src._get_node(key, validate_access=False)
dest_node = dest._get_node(key, validate_access=False)
assert src_node is None or isinstance(src_node, Node)
assert dest_node is None or isinstance(dest_node, Node)

if isinstance(dest_node, DictConfig):
dest_node._validate_merge(value=src_node)
Expand Down Expand Up @@ -558,6 +560,7 @@ def wrap(key: Any, val: Any) -> Node:
if is_structured_config(val):
ref_type = self._metadata.element_type
else:
assert isinstance(target, Node)
is_optional = target._is_optional()
ref_type = target._metadata.ref_type
return _maybe_wrap(
Expand Down Expand Up @@ -608,6 +611,9 @@ def _item_eq(
v2 = c2._get_node(k2)
assert v1 is not None and v2 is not None

assert isinstance(v1, Node)
assert isinstance(v2, Node)

if v1._is_none() and v2._is_none():
return True

Expand Down Expand Up @@ -642,11 +648,15 @@ def _item_eq(
elif not v1_inter and not v2_inter:
v1 = _get_value(v1)
v2 = _get_value(v2)
return v1 == v2
ret = v1 == v2
assert isinstance(ret, bool)
return ret
else:
dv1 = _get_value(dv1)
dv2 = _get_value(dv2)
return dv1 == dv2
ret = dv1 == dv2
assert isinstance(ret, bool)
return ret

def _is_none(self) -> bool:
return self.__dict__["_content"] is None
Expand Down
30 changes: 20 additions & 10 deletions omegaconf/dictconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,9 @@ def _validate_set(self, key: Any, value: Any) -> None:
if is_valid_target:
return

target_type = target._metadata.ref_type # type: ignore
assert isinstance(target, Node)

target_type = target._metadata.ref_type
value_type = OmegaConf.get_type(value)

if is_dict(value_type) and is_dict(target_type):
Expand Down Expand Up @@ -239,12 +241,14 @@ def _validate_non_optional(self, key: Any, value: Any) -> None:
if OmegaConf.is_none(value):
if key is not None:
child = self._get_node(key)
if child is not None and not child._is_optional():
self._format_and_raise(
key=key,
value=value,
cause=ValidationError("child '$FULL_KEY' is not Optional"),
)
if child is not None:
assert isinstance(child, Node)
if not child._is_optional():
self._format_and_raise(
key=key,
value=value,
cause=ValidationError("child '$FULL_KEY' is not Optional"),
)
else:
if not self._is_optional():
self._format_and_raise(
Expand Down Expand Up @@ -420,8 +424,11 @@ def _get_impl(self, key: DictKeyType, default_value: Any) -> Any:
)

def _get_node(
self, key: DictKeyType, validate_access: bool = True
) -> Optional[Node]:
self,
key: DictKeyType,
validate_access: bool = True,
throw_on_missing: bool = False,
) -> Union[Optional[Node], List[Optional[Node]]]:
try:
key = self._validate_and_normalize_key(key)
except KeyValidationError:
Expand All @@ -434,6 +441,8 @@ def _get_node(
self._validate_get(key)

value: Node = self.__dict__["_content"].get(key)
if throw_on_missing and value._is_missing():
raise MissingMandatoryValue("Missing mandatory value")

return value

Expand Down Expand Up @@ -487,7 +496,8 @@ def __contains__(self, key: object) -> bool:
return False

try:
node: Optional[Node] = self._get_node(key)
node = self._get_node(key)
assert node is None or isinstance(node, Node)
except (KeyError, AttributeError):
node = None

Expand Down
25 changes: 22 additions & 3 deletions omegaconf/listconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def _validate_set(self, key: Any, value: Any) -> None:
if 0 <= key < self.__len__():
target = self._get_node(key)
if target is not None:
assert isinstance(target, Node)
if value is None and not target._is_optional():
raise ValidationError(
"$FULL_KEY is not optional and cannot be assigned None"
Expand Down Expand Up @@ -274,6 +275,7 @@ def _update_keys(self) -> None:
for i in range(len(self)):
node = self._get_node(i)
if node is not None:
assert isinstance(node, Node)
node._metadata.key = i

def insert(self, index: int, item: Any) -> None:
Expand Down Expand Up @@ -367,8 +369,11 @@ def count(self, x: Any) -> int:
return c

def _get_node(
self, key: Union[int, slice], validate_access: bool = True
) -> Optional[Node]:
self,
key: Union[int, slice],
validate_access: bool = True,
throw_on_missing: bool = False,
) -> Union[Optional[Node], List[Optional[Node]]]:
try:
if self._is_none():
raise TypeError(
Expand All @@ -379,8 +384,22 @@ def _get_node(
assert isinstance(self.__dict__["_content"], list)
if validate_access:
self._validate_get(key)
return self.__dict__["_content"][key] # type: ignore

value = self.__dict__["_content"][key]
if value is not None:
if isinstance(key, slice):
assert isinstance(value, list)
for v in value:
if throw_on_missing and v._is_missing():
raise MissingMandatoryValue("Missing mandatory value")
else:
assert isinstance(value, Node)
if throw_on_missing and value._is_missing():
raise MissingMandatoryValue("Missing mandatory value")
return value
except (IndexError, TypeError, MissingMandatoryValue, KeyValidationError) as e:
if isinstance(e, MissingMandatoryValue) and throw_on_missing:
raise
if validate_access:
self._format_and_raise(key=key, value=None, cause=e)
assert False
Expand Down
5 changes: 4 additions & 1 deletion omegaconf/omegaconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,7 @@ def is_missing(cfg: Any, key: DictKeyType) -> bool:
node = cfg._get_node(key)
if node is None:
return False
assert isinstance(node, Node)
return node._is_missing()
except (UnsupportedInterpolationType, KeyError, AttributeError):
return False
Expand Down Expand Up @@ -901,8 +902,9 @@ def _select_one(
assert isinstance(c, (DictConfig, ListConfig)), f"Unexpected type : {c}"
if isinstance(c, DictConfig):
assert isinstance(ret_key, str)
val: Optional[Node] = c._get_node(ret_key, validate_access=False)
val = c._get_node(ret_key, validate_access=False)
if val is not None:
assert isinstance(val, Node)
if val._is_missing():
if throw_on_missing:
raise MissingMandatoryValue(
Expand Down Expand Up @@ -930,4 +932,5 @@ def _select_one(
else:
assert False

assert val is None or isinstance(val, Node)
return val, ret_key
3 changes: 2 additions & 1 deletion tests/structured_conf/test_structured_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from omegaconf import (
DictConfig,
IntegerNode,
Node,
OmegaConf,
ValidationError,
_utils,
Expand Down Expand Up @@ -218,7 +219,7 @@ def test_merge_optional_structured_onto_dict(self, class_type: str) -> None:
assert get_ref_type(c2, "user") == Optional[module.User]
assert isinstance(c2, DictConfig)
c2_user = c2._get_node("user")
assert c2_user is not None
assert isinstance(c2_user, Node)
# Compared to the previous assert, here we verify that the `ref_type` found
# in the metadata is *not* optional: instead, the `optional` flag must be set.
assert c2_user._metadata.ref_type == module.User
Expand Down
20 changes: 18 additions & 2 deletions tests/test_base_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
open_dict,
read_write,
)
from omegaconf.errors import ConfigAttributeError, ConfigKeyError
from omegaconf.errors import ConfigAttributeError, ConfigKeyError, MissingMandatoryValue
from tests import Color, StructuredWithMissing, User, does_not_raise


Expand Down Expand Up @@ -337,7 +337,7 @@ def test_set_flags() -> None:
def test_get_flag_after_dict_assignment(no_deepcopy_set_nodes: bool, node: Any) -> None:
cfg = OmegaConf.create({"c": node})
cfg._set_flag("foo", True)
nc = cfg._get_node("c")
nc: Any = cfg._get_node("c")
assert nc is not None
assert nc._flags_cache is None
assert nc._get_flag("foo") is True
Expand Down Expand Up @@ -739,3 +739,19 @@ def test_assign(parent: Any, key: Union[str, int], value: Any, expected: Any) ->
)
def test_get_node(cfg: Any, key: Any, expected: Any) -> None:
assert cfg._get_node(key) == expected


@pytest.mark.parametrize(
"cfg, key",
[
# dict
pytest.param({"foo": "???"}, "foo", id="dict"),
# list
pytest.param([10, "???", 30], 1, id="list_int"),
pytest.param([10, "???", 30], slice(1, 2), id="list_slice"),
],
)
def test_get_node_throw_on_missing(cfg: Any, key: Any) -> None:
cfg = OmegaConf.create(cfg)
with pytest.raises(MissingMandatoryValue, match="Missing mandatory value"):
cfg._get_node(key, throw_on_missing=True)
10 changes: 5 additions & 5 deletions tests/test_basic_ops_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -704,11 +704,11 @@ def test_is_missing() -> None:
"missing_node_inter": "${missing_node}",
}
)
assert cfg._get_node("foo")._is_missing() # type:ignore
assert cfg._get_node("inter")._is_missing() # type:ignore
assert not cfg._get_node("str_inter")._is_missing() # type:ignore
assert cfg._get_node("missing_node")._is_missing() # type:ignore
assert cfg._get_node("missing_node_inter")._is_missing() # type:ignore
assert cfg._get_node("foo")._is_missing() # type: ignore
assert cfg._get_node("inter")._is_missing() # type: ignore
assert not cfg._get_node("str_inter")._is_missing() # type: ignore
assert cfg._get_node("missing_node")._is_missing() # type: ignore
assert cfg._get_node("missing_node_inter")._is_missing() # type: ignore


@pytest.mark.parametrize("ref_type", [None, Any])
Expand Down
8 changes: 4 additions & 4 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,8 +265,8 @@ def test_value_kind(value: Any, kind: _utils.ValueKind) -> None:
def test_re_parent() -> None:
def validate(cfg1: DictConfig) -> None:
assert cfg1._get_parent() is None
assert cfg1._get_node("str")._get_parent() == cfg1 # type:ignore
assert cfg1._get_node("list")._get_parent() == cfg1 # type:ignore
assert cfg1._get_node("str")._get_parent() == cfg1 # type: ignore
assert cfg1._get_node("list")._get_parent() == cfg1 # type: ignore
assert cfg1.list._get_node(0)._get_parent() == cfg1.list

cfg = OmegaConf.create({})
Expand All @@ -276,8 +276,8 @@ def validate(cfg1: DictConfig) -> None:

validate(cfg)

cfg._get_node("str")._set_parent(None) # type:ignore
cfg._get_node("list")._set_parent(None) # type:ignore
cfg._get_node("str")._set_parent(None) # type: ignore
cfg._get_node("list")._set_parent(None) # type: ignore
cfg.list._get_node(0)._set_parent(None) # type:ignore
# noinspection PyProtectedMember
cfg._re_parent()
Expand Down

0 comments on commit 96645ef

Please sign in to comment.