Skip to content

Commit

Permalink
Improve handling of interpolations pointing to missing nodes (#545)
Browse files Browse the repository at this point in the history
* Interpolations are never considered to be missing anymore, even if
  they point to a missing node

* When resolving an expression containing an interpolation pointing to a
  missing node, an `InterpolationToMissingValueError` exception is raised

* When resolving an expression containing an interpolation pointing to a
  node that does not exist, an `InterpolationKeyError` exception is raised

* `key in cfg` returns True whenever `key` is an interpolation, even if it cannot be resolved (including interpolations to missing nodes)

* `get()` and `pop()` no longer return the default value in case of interpolation resolution failure (same thing for `OmegaConf.select()`)

* If `throw_on_resolution_failure` is False, then resolving an
  interpolation resulting in a resolution failure always leads to the
  result being `None` (instead of potentially being an expression computed
  from `None`)

Fixes #543
Fixes #561
Fixes #562
Fixes #565
  • Loading branch information
odelalleau authored Mar 3, 2021
1 parent 8b78699 commit cb5e556
Show file tree
Hide file tree
Showing 24 changed files with 619 additions and 414 deletions.
File renamed without changes.
1 change: 0 additions & 1 deletion news/462.api_change.2

This file was deleted.

1 change: 1 addition & 0 deletions news/543.api_change
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
`OmegaConf.select()`, `DictConfig.{get(),pop()}`, `ListConfig.{get(),pop()}` no longer return the specified default value when the accessed key is an interpolation that cannot be resolved: instead, an exception is raised.
1 change: 1 addition & 0 deletions news/561.api_change
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
All exceptions raised during the resolution of an interpolation are either `InterpolationResolutionError` or a subclass of it.
1 change: 1 addition & 0 deletions news/562.api_change
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
`key in cfg` now returns True when `key` is an interpolation even if the interpolation target is a missing ("???") value.
2 changes: 1 addition & 1 deletion news/563.bugfix
Original file line number Diff line number Diff line change
@@ -1 +1 @@
Calling `OmegaConf.select()` on a missing node from a ListConfig with `throw_on_missing` set to True now raises the intended exception.
`OmegaConf.select()` of a missing ("???") node from a ListConfig with `throw_on_missing` set to True now raises the intended exception.
2 changes: 2 additions & 0 deletions news/565.api_change
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
`OmegaConf.select()` as well as container methods `get()` and `pop()` do not return their default value anymore when the accessed key is an interpolation that cannot be resolved: instead, an exception is raised.

183 changes: 84 additions & 99 deletions omegaconf/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import copy
import sys
from abc import ABC, abstractmethod
from collections import defaultdict
from dataclasses import dataclass, field
Expand All @@ -10,13 +11,15 @@
from ._utils import (
ValueKind,
_get_value,
_is_missing_literal,
_is_missing_value,
format_and_raise,
get_value_kind,
)
from .errors import (
ConfigKeyError,
InterpolationKeyError,
InterpolationResolutionError,
InterpolationToMissingValueError,
MissingMandatoryValue,
OmegaConfBaseException,
UnsupportedInterpolationType,
Expand Down Expand Up @@ -190,33 +193,26 @@ def _get_full_key(self, key: Optional[Union[DictKeyType, int]]) -> str:

def _dereference_node(
self,
throw_on_missing: bool = False,
throw_on_resolution_failure: bool = True,
) -> Optional["Node"]:
if self._is_interpolation():
parent = self._get_parent()
if parent is None:
raise OmegaConfBaseException(
"Cannot resolve interpolation for a node without a parent"
)
assert parent is not None
key = self._key()
return parent._resolve_interpolation_from_parse_tree(
parent=parent,
key=key,
value=self,
parse_tree=parse(_get_value(self)),
throw_on_missing=throw_on_missing,
throw_on_resolution_failure=throw_on_resolution_failure,
)
else:
# not interpolation, compare directly
if throw_on_missing:
value = self._value()
if _is_missing_literal(value):
raise MissingMandatoryValue("Missing mandatory value")
if not self._is_interpolation():
return self

parent = self._get_parent()
if parent is None:
raise OmegaConfBaseException(
"Cannot resolve interpolation for a node without a parent"
)
assert parent is not None
key = self._key()
return parent._resolve_interpolation_from_parse_tree(
parent=parent,
key=key,
value=self,
parse_tree=parse(_get_value(self)),
throw_on_resolution_failure=throw_on_resolution_failure,
)

def _get_root(self) -> "Container":
root: Optional[Container] = self._get_parent()
if root is None:
Expand All @@ -228,6 +224,9 @@ def _get_root(self) -> "Container":
assert root is not None and isinstance(root, Container)
return root

def _is_missing(self) -> bool:
return _is_missing_value(self)

@abstractmethod
def __eq__(self, other: Any) -> bool:
...
Expand Down Expand Up @@ -256,10 +255,6 @@ def _is_none(self) -> bool:
def _is_optional(self) -> bool:
...

@abstractmethod
def _is_missing(self) -> bool:
...

@abstractmethod
def _is_interpolation(self) -> bool:
...
Expand Down Expand Up @@ -365,7 +360,6 @@ def _select_impl(
)
if isinstance(ret, Node):
ret = ret._dereference_node(
throw_on_missing=throw_on_missing,
throw_on_resolution_failure=throw_on_resolution_failure,
)

Expand Down Expand Up @@ -394,7 +388,6 @@ def _select_impl(
parent=root,
key=last_key,
value=value,
throw_on_missing=throw_on_missing,
throw_on_resolution_failure=throw_on_resolution_failure,
)
return root, last_key, value
Expand All @@ -405,22 +398,23 @@ def _resolve_interpolation_from_parse_tree(
value: "Node",
key: Any,
parse_tree: OmegaConfGrammarParser.ConfigValueContext,
throw_on_missing: bool,
throw_on_resolution_failure: bool,
) -> Optional["Node"]:
from .nodes import StringNode

resolved = self.resolve_parse_tree(
parse_tree=parse_tree,
key=key,
parent=parent,
throw_on_missing=throw_on_missing,
throw_on_resolution_failure=throw_on_resolution_failure,
)

if resolved is None:
try:
resolved = self.resolve_parse_tree(
parse_tree=parse_tree,
key=key,
parent=parent,
)
except InterpolationResolutionError:
if throw_on_resolution_failure:
raise
return None
elif isinstance(resolved, str):

assert resolved is not None
if isinstance(resolved, str):
# Result is a string: create a new StringNode for it.
return StringNode(
value=resolved,
Expand All @@ -435,72 +429,67 @@ def _resolve_interpolation_from_parse_tree(
def _resolve_node_interpolation(
self,
inter_key: str,
throw_on_missing: bool,
throw_on_resolution_failure: bool,
) -> Optional["Node"]:
) -> "Node":
"""A node interpolation is of the form `${foo.bar}`"""
root_node, inter_key = self._resolve_key_and_root(inter_key)
parent, last_key, value = root_node._select_impl(
inter_key,
throw_on_missing=throw_on_missing,
throw_on_resolution_failure=throw_on_resolution_failure,
)
try:
root_node, inter_key = self._resolve_key_and_root(inter_key)
except ConfigKeyError as exc:
raise InterpolationKeyError(
f"ConfigKeyError while resolving interpolation: {exc}"
).with_traceback(sys.exc_info()[2])

try:
parent, last_key, value = root_node._select_impl(
inter_key,
throw_on_missing=True,
throw_on_resolution_failure=True,
)
except MissingMandatoryValue as exc:
raise InterpolationToMissingValueError(
f"MissingMandatoryValue while resolving interpolation: {exc}"
).with_traceback(sys.exc_info()[2])
except ConfigKeyError as exc:
raise InterpolationKeyError(
f"ConfigKeyError while resolving interpolation: {exc}"
).with_traceback(sys.exc_info()[2])

if parent is None or value is None:
if throw_on_resolution_failure:
raise InterpolationResolutionError(
f"Interpolation key '{inter_key}' not found"
)
else:
return None
assert isinstance(value, Node)
return value
raise InterpolationKeyError(f"Interpolation key '{inter_key}' not found")
else:
return value

def _evaluate_custom_resolver(
self,
key: Any,
inter_type: str,
inter_args: Tuple[Any, ...],
throw_on_missing: bool,
throw_on_resolution_failure: bool,
inter_args_str: Tuple[str, ...],
) -> Optional["Node"]:
) -> Any:
from omegaconf import OmegaConf

from .nodes import ValueNode

resolver = OmegaConf.get_resolver(inter_type)
if resolver is not None:
root_node = self._get_root()
try:
value = resolver(root_node, inter_args, inter_args_str)
return ValueNode(
value=value,
parent=self,
metadata=Metadata(
ref_type=Any, object_type=Any, key=key, optional=True
),
)
except Exception as e:
if throw_on_resolution_failure:
self._format_and_raise(key=None, value=None, cause=e)
assert False
else:
return None
value = resolver(root_node, inter_args, inter_args_str)
return ValueNode(
value=value,
parent=self,
metadata=Metadata(
ref_type=Any, object_type=Any, key=key, optional=True
),
)
else:
if throw_on_resolution_failure:
raise UnsupportedInterpolationType(
f"Unsupported interpolation type {inter_type}"
)
else:
return None
raise UnsupportedInterpolationType(
f"Unsupported interpolation type {inter_type}"
)

def _maybe_resolve_interpolation(
self,
parent: Optional["Container"],
key: Any,
value: "Node",
throw_on_missing: bool,
throw_on_resolution_failure: bool,
) -> Any:
value_kind = get_value_kind(value)
Expand All @@ -513,7 +502,6 @@ def _maybe_resolve_interpolation(
value=value,
key=key,
parse_tree=parse_tree,
throw_on_missing=throw_on_missing,
throw_on_resolution_failure=throw_on_resolution_failure,
)

Expand All @@ -522,8 +510,6 @@ def resolve_parse_tree(
parse_tree: ParserRuleContext,
key: Optional[Any] = None,
parent: Optional["Container"] = None,
throw_on_missing: bool = True,
throw_on_resolution_failure: bool = True,
) -> Any:
"""
Resolve a given parse tree into its value.
Expand All @@ -533,26 +519,17 @@ def resolve_parse_tree(
"""
from .nodes import StringNode

# Common arguments to all callbacks.
callback_args: Dict[str, Any] = dict(
throw_on_missing=throw_on_missing,
throw_on_resolution_failure=throw_on_resolution_failure,
)

def node_interpolation_callback(inter_key: str) -> Optional["Node"]:
return self._resolve_node_interpolation(
inter_key=inter_key, **callback_args
)
return self._resolve_node_interpolation(inter_key=inter_key)

def resolver_interpolation_callback(
name: str, args: Tuple[Any, ...], args_str: Tuple[str, ...]
) -> Optional["Node"]:
) -> Any:
return self._evaluate_custom_resolver(
key=key,
inter_type=name,
inter_args=args,
inter_args_str=args_str,
**callback_args,
)

def quoted_string_callback(quoted_str: str) -> str:
Expand All @@ -565,7 +542,7 @@ def quoted_string_callback(quoted_str: str) -> str:
parent=parent,
is_optional=False,
),
**callback_args,
throw_on_resolution_failure=True,
)
return str(quoted_val)

Expand All @@ -574,7 +551,15 @@ def quoted_string_callback(quoted_str: str) -> str:
resolver_interpolation_callback=resolver_interpolation_callback,
quoted_string_callback=quoted_string_callback,
)
return visitor.visit(parse_tree)
try:
return visitor.visit(parse_tree)
except InterpolationResolutionError:
raise
except Exception as exc:
# Other kinds of exceptions are wrapped in an `InterpolationResolutionError`.
raise InterpolationResolutionError(
f"{type(exc).__name__} raised while resolving interpolation: {exc}"
).with_traceback(sys.exc_info()[2])

def _re_parent(self) -> None:
from .dictconfig import DictConfig
Expand Down
Loading

0 comments on commit cb5e556

Please sign in to comment.