Type validation of interpolations (#578)
Validate and convert interpolation results to their intended type

This PR fixes several issues:

* For nested resolvers (ex: `${f:${g:x}}`), intermediate resolver
  outputs (of `g` in this example) were wrapped in a ValueNode just to
  be unwrapped immediately, which was wasteful. This commit pushes the
  node wrapping to the very last step of the interpolation resolution.

* There was no type checking to make sure that the result of an
  interpolation had a type consistent with the node's type (when
  specified). Now a check is made and the interpolation result may be
  converted into the desired type (see #488).

* If a resolver interpolation returns a dict / list, it is now wrapped
  into a DictConfig / ListConfig, instead of a ValueNode. This makes it
  possible to generate configs from resolvers (see #540). These configs
  are read-only since they are being re-generated on each access.


Fixes #488
Fixes #540
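
The first two points can be illustrated with a minimal, self-contained sketch. It does not use OmegaConf's real classes: `resolve_interpolation`, `validate_and_convert_int`, `f`, `g` and the stand-in `InterpolationValidationError` are hypothetical names chosen for illustration, and the `int()`-based conversion is a simplification of the actual node validation.

```python
from typing import Any, Callable, Optional


class InterpolationValidationError(ValueError):
    """Stand-in for the error raised when a result fails validation."""


def validate_and_convert_int(value: Any) -> int:
    """Hypothetical validator for a node annotated as `int`."""
    try:
        return int(value)
    except (TypeError, ValueError) as e:
        raise InterpolationValidationError(
            f"Value {value!r} could not be converted to Integer"
        ) from e


def resolve_interpolation(
    resolve: Callable[[], Any],
    validate: Optional[Callable[[Any], Any]] = None,
    throw_on_resolution_failure: bool = True,
) -> Any:
    # Step 1: compute the raw result. Intermediate resolver outputs stay
    # unwrapped, so nested calls such as f(g(x)) pass plain values around.
    resolved = resolve()
    # Step 2: only typed targets get validated (and possibly converted).
    if validate is not None:
        try:
            resolved = validate(resolved)
        except InterpolationValidationError:
            if throw_on_resolution_failure:
                raise
            return None
    return resolved


def g(x: str) -> str:
    return x * 2  # "12" -> "1212"


def f(x: str) -> str:
    return x + "0"  # "1212" -> "12120"


result = resolve_interpolation(lambda: f(g("12")), validate_and_convert_int)
print(result)  # 12120
```

In this sketch, validation runs once on the final value, which mirrors the idea of deferring node wrapping and type checks to the last step of the resolution.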
odelalleau authored Mar 13, 2021
1 parent 8740730 commit 3bf2c59
Showing 14 changed files with 446 additions and 56 deletions.
30 changes: 25 additions & 5 deletions docs/source/structured_config.rst
@@ -309,7 +309,7 @@ Optional fields
Interpolations
^^^^^^^^^^^^^^

:ref:`interpolation` works normally with Structured configs but static type checkers may object to you assigning a string to an other types.
:ref:`interpolation` works normally with Structured configs but static type checkers may object to you assigning a string to another type.
To work around it, use SI and II described below.

.. doctest::
@@ -333,18 +333,38 @@ To work around it, use SI and II described below.
>>> assert conf.c == 100


Type validation is performed on assignment, but not on values returned by interpolation, e.g:
Interpolated values are validated, and converted when possible, to the annotated type when the interpolation is accessed, e.g:

.. doctest::

>>> from omegaconf import SI
>>> from omegaconf import II
>>> @dataclass
... class Interpolation:
... int_key: int = II("str_key")
... str_key: str = "string"
... int_key: int = II("str_key")

>>> cfg = OmegaConf.structured(Interpolation)
>>> assert cfg.int_key == "string"
>>> cfg.int_key # fails due to type mismatch
Traceback (most recent call last):
...
omegaconf.errors.InterpolationValidationError: Value 'string' could not be converted to Integer
full_key: int_key
object_type=Interpolation
>>> cfg.str_key = "1234" # string value
>>> assert cfg.int_key == 1234 # automatically convert str to int

Note however that this validation step is currently skipped for container node interpolations:

.. doctest::

>>> @dataclass
... class NotValidated:
... some_int: int = 0
... some_dict: Dict[str, str] = II("some_int")

>>> cfg = OmegaConf.structured(NotValidated)
>>> assert cfg.some_dict == 0 # type mismatch, but no error


Frozen
^^^^^^
13 changes: 13 additions & 0 deletions docs/source/usage.rst
@@ -472,6 +472,19 @@ simply use quotes to bypass character limitations in strings.
'Hello, World'


Custom resolvers can return lists or dictionaries, which are automatically converted into ListConfig and DictConfig, respectively:

.. doctest::

>>> OmegaConf.register_new_resolver(
... "min_max", lambda *a: {"min": min(a), "max": max(a)}
... )
>>> c = OmegaConf.create({'stats': '${min_max: -1, 3, 2, 5, -10}'})
>>> assert isinstance(c.stats, DictConfig)
>>> c.stats.min, c.stats.max
(-10, 5)


You can take advantage of nested interpolations to perform custom operations over variables:

.. doctest::
1 change: 1 addition & 0 deletions news/488.api_change
@@ -0,0 +1 @@
When resolving an interpolation of a typed config value, the interpolated value is validated and possibly converted based on the node's type.
1 change: 1 addition & 0 deletions news/540.api_change
@@ -0,0 +1 @@
A custom resolver interpolation whose output is a list or dictionary is now automatically converted into a ListConfig or DictConfig.
1 change: 1 addition & 0 deletions news/540.feature
@@ -0,0 +1 @@
Custom resolvers can now generate transient config nodes dynamically.
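
As a rough analogy for why such transient nodes are read-only, the following sketch recomputes a resolver's output on every access and exposes it through a read-only view. It is not OmegaConf code: `TransientConfig` is a hypothetical class, and `types.MappingProxyType` merely stands in for a read-only DictConfig.

```python
from types import MappingProxyType
from typing import Any, Callable, Dict, Mapping


def min_max(*a: int) -> Dict[str, int]:
    """Same idea as the `min_max` resolver from the documentation."""
    return {"min": min(a), "max": max(a)}


class TransientConfig:
    """Hypothetical holder that re-runs its resolver on every access."""

    def __init__(self, resolver: Callable[..., Dict[str, Any]], *args: Any):
        self._resolver = resolver
        self._args = args

    @property
    def stats(self) -> Mapping[str, Any]:
        # A fresh dict is built on each access and exposed read-only,
        # since any modification would be lost on the next access anyway.
        return MappingProxyType(self._resolver(*self._args))


c = TransientConfig(min_max, -1, 3, 2, 5, -10)
print(c.stats["min"], c.stats["max"])  # -10 5

try:
    c.stats["min"] = 0  # type: ignore[index]
except TypeError as e:
    print(f"read-only: {e}")
```

Because each access returns a new object, a successful write would silently disappear; rejecting the write makes that loss visible.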
141 changes: 119 additions & 22 deletions omegaconf/base.py
@@ -20,9 +20,12 @@
InterpolationKeyError,
InterpolationResolutionError,
InterpolationToMissingValueError,
InterpolationValidationError,
KeyValidationError,
MissingMandatoryValue,
OmegaConfBaseException,
UnsupportedInterpolationType,
ValidationError,
)
from .grammar.gen.OmegaConfGrammarParser import OmegaConfGrammarParser
from .grammar_parser import parse
@@ -358,8 +361,6 @@ def _select_impl(
) -> Tuple[Optional["Container"], Optional[str], Optional[Node]]:
"""
Select a value using dot separated key sequence
:param key:
:return:
"""
from .omegaconf import _select_one

@@ -421,8 +422,34 @@ def _resolve_interpolation_from_parse_tree(
parse_tree: OmegaConfGrammarParser.ConfigValueContext,
throw_on_resolution_failure: bool,
) -> Optional["Node"]:
from .nodes import StringNode

"""
Resolve an interpolation.
This happens in two steps:
1. The parse tree is visited, which outputs either a `Node` (e.g.,
for node interpolations "${foo}"), a string (e.g., for string
interpolations "hello ${name}"), or any other arbitrary value
(e.g., for custom interpolations "${foo:bar}").
2. This output is potentially validated and converted when the node
being resolved (`value`) is typed.
If an error occurs in one of the above steps, an `InterpolationResolutionError`
(or a subclass of it) is raised, *unless* `throw_on_resolution_failure` is set
to `False` (in which case the return value is `None`).
:param parent: Parent of the node being resolved.
:param value: Node being resolved.
:param key: The associated key in the parent.
:param parse_tree: The parse tree as obtained from `grammar_parser.parse()`.
:param throw_on_resolution_failure: If `False`, then exceptions raised during
the resolution of the interpolation are silenced, and instead `None` is
returned.
:return: A `Node` that contains the interpolation result. This may be an existing
node in the config (in the case of a node interpolation "${foo}"), or a new
node that is created to wrap the interpolated value. It is `None` if and only if
`throw_on_resolution_failure` is `False` and an error occurs during resolution.
"""
try:
resolved = self.resolve_parse_tree(
parse_tree=parse_tree,
@@ -434,19 +461,98 @@ def _resolve_interpolation_from_parse_tree(
raise
return None

assert resolved is not None
if isinstance(resolved, str):
# Result is a string: create a new StringNode for it.
return StringNode(
value=resolved,
key=key,
return self._validate_and_convert_interpolation_result(
parent=parent,
value=value,
key=key,
resolved=resolved,
throw_on_resolution_failure=throw_on_resolution_failure,
)

def _validate_and_convert_interpolation_result(
self,
parent: Optional["Container"],
value: "Node",
key: Any,
resolved: Any,
throw_on_resolution_failure: bool,
) -> Optional["Node"]:
from .nodes import AnyNode, ValueNode

# If the output is not a Node already (e.g., because it is the output of a
# custom resolver), then we will need to wrap it within a Node.
must_wrap = not isinstance(resolved, Node)

# If the node is typed, validate (and possibly convert) the result.
if isinstance(value, ValueNode) and not isinstance(value, AnyNode):
res_value = _get_value(resolved)
try:
conv_value = value.validate_and_convert(res_value)
except ValidationError as e:
if throw_on_resolution_failure:
self._format_and_raise(
key=key,
value=res_value,
cause=e,
type_override=InterpolationValidationError,
)
return None

# If the converted value is of the same type, it means that no conversion
# was actually needed. As a result, we can keep the original `resolved`
# (and otherwise, the converted value must be wrapped into a new node).
if type(conv_value) != type(res_value):
must_wrap = True
resolved = conv_value

if must_wrap:
return self._wrap_interpolation_result(
parent=parent,
is_optional=value._metadata.optional,
value=value,
key=key,
resolved=resolved,
throw_on_resolution_failure=throw_on_resolution_failure,
)
else:
assert isinstance(resolved, Node)
return resolved

def _wrap_interpolation_result(
self,
parent: Optional["Container"],
value: "Node",
key: Any,
resolved: Any,
throw_on_resolution_failure: bool,
) -> Optional["Node"]:
from .basecontainer import BaseContainer
from .omegaconf import _node_wrap

assert parent is None or isinstance(parent, BaseContainer)
try:
wrapped = _node_wrap(
type_=value._metadata.ref_type,
parent=parent,
is_optional=value._metadata.optional,
value=resolved,
key=key,
ref_type=value._metadata.ref_type,
)
except (KeyValidationError, ValidationError) as e:
if throw_on_resolution_failure:
self._format_and_raise(
key=key,
value=resolved,
cause=e,
type_override=InterpolationValidationError,
)
return None
# Since we created a new node on the fly, future changes to this node are
# likely to be lost. We thus set the "readonly" flag to `True` to reduce
# the risk of accidental modifications.
wrapped._set_flag("readonly", True)
return wrapped

def _resolve_node_interpolation(
self,
inter_key: str,
@@ -488,19 +594,10 @@ def _evaluate_custom_resolver(
) -> Any:
from omegaconf import OmegaConf

from .nodes import ValueNode

resolver = OmegaConf.get_resolver(inter_type)
if resolver is not None:
root_node = self._get_root()
value = resolver(root_node, inter_args, inter_args_str)
return ValueNode(
value=value,
parent=self,
metadata=Metadata(
ref_type=Any, object_type=Any, key=key, optional=True
),
)
return resolver(root_node, inter_args, inter_args_str)
else:
raise UnsupportedInterpolationType(
f"Unsupported interpolation type {inter_type}"
@@ -561,7 +658,7 @@ def quoted_string_callback(quoted_str: str) -> str:
value=quoted_str,
key=key,
parent=parent,
is_optional=False,
is_optional=True,
),
throw_on_resolution_failure=True,
)
4 changes: 1 addition & 3 deletions omegaconf/dictconfig.py
@@ -291,9 +291,7 @@ def _s_validate_and_normalize_key(self, key_type: Any, key: Any) -> DictKeyType:
return key # type: ignore
elif issubclass(key_type, Enum):
try:
ret = EnumNode.validate_and_convert_to_enum(key_type, key)
assert ret is not None
return ret
return EnumNode.validate_and_convert_to_enum(key_type, key)
except ValidationError:
valid = ", ".join([x for x in key_type.__members__.keys()])
raise KeyValidationError(
6 changes: 6 additions & 0 deletions omegaconf/errors.py
@@ -81,6 +81,12 @@ class InterpolationToMissingValueError(InterpolationResolutionError):
"""


class InterpolationValidationError(InterpolationResolutionError, ValidationError):
"""
Thrown when the result of an interpolation fails the validation step.
"""


class ConfigKeyError(OmegaConfBaseException, KeyError):
"""
Thrown from DictConfig when a regular dict access would have caused a KeyError.
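
Since `InterpolationValidationError` inherits from both `InterpolationResolutionError` and `ValidationError`, code that catches either parent also catches the new error. The sketch below shows this with simplified stand-in classes: their bases are reduced to `Exception` for illustration (an assumption, not the library's actual hierarchy), and `caught_as` is a hypothetical helper.

```python
from typing import Type


class ValidationError(Exception):
    """Simplified stand-in; the real base classes are not reproduced here."""


class InterpolationResolutionError(Exception):
    """Simplified stand-in; the real base classes are not reproduced here."""


class InterpolationValidationError(InterpolationResolutionError, ValidationError):
    """Raised when the result of an interpolation fails the validation step."""


def caught_as(exc_type: Type[BaseException]) -> bool:
    """Return True if an InterpolationValidationError is caught by `exc_type`."""
    try:
        raise InterpolationValidationError("Value 'string' could not be converted")
    except exc_type:
        return True


print(caught_as(ValidationError))               # True
print(caught_as(InterpolationResolutionError))  # True
```

This means existing `except ValidationError` and `except InterpolationResolutionError` handlers keep working when the failure comes from validating an interpolation result.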
24 changes: 7 additions & 17 deletions omegaconf/grammar_visitor.py
@@ -39,7 +39,7 @@ class GrammarVisitor(OmegaConfGrammarParserVisitor):
def __init__(
self,
node_interpolation_callback: Callable[[str], Optional["Node"]],
resolver_interpolation_callback: Callable[..., Optional["Node"]],
resolver_interpolation_callback: Callable[..., Any],
quoted_string_callback: Callable[[str], str],
**kw: Dict[Any, Any],
):
@@ -96,22 +96,16 @@ def visitConfigKey(self, ctx: OmegaConfGrammarParser.ConfigKeyContext) -> str:
)
return child.symbol.text

def visitConfigValue(
self, ctx: OmegaConfGrammarParser.ConfigValueContext
) -> Union[str, Optional["Node"]]:
def visitConfigValue(self, ctx: OmegaConfGrammarParser.ConfigValueContext) -> Any:
# (toplevelStr | (toplevelStr? (interpolation toplevelStr?)+)) EOF
# Visit all children (except last one which is EOF)
vals = [self.visit(c) for c in list(ctx.getChildren())[:-1]]
assert vals
if len(vals) == 1 and isinstance(
ctx.getChild(0), OmegaConfGrammarParser.InterpolationContext
):
from .base import Node # noqa F811

# Single interpolation: return the resulting node "as is".
ret = vals[0]
assert ret is None or isinstance(ret, Node), ret
return ret
# Single interpolation: return the result "as is".
return vals[0]
# Concatenation of multiple components.
return "".join(map(str, vals))

@@ -135,13 +129,9 @@ def visitElement(self, ctx: OmegaConfGrammarParser.ElementContext) -> Any:

def visitInterpolation(
self, ctx: OmegaConfGrammarParser.InterpolationContext
) -> Optional["Node"]:
from .base import Node # noqa F811

) -> Any:
assert ctx.getChildCount() == 1 # interpolationNode | interpolationResolver
ret = self.visit(ctx.getChild(0))
assert ret is None or isinstance(ret, Node)
return ret
return self.visit(ctx.getChild(0))

def visitInterpolationNode(
self, ctx: OmegaConfGrammarParser.InterpolationNodeContext
@@ -168,7 +158,7 @@ def visitInterpolationNode(

def visitInterpolationResolver(
self, ctx: OmegaConfGrammarParser.InterpolationResolverContext
) -> Optional["Node"]:
) -> Any:

# INTER_OPEN resolverName COLON sequence? BRACE_CLOSE
assert 4 <= ctx.getChildCount() <= 5
4 changes: 3 additions & 1 deletion omegaconf/nodes.py
@@ -1,6 +1,7 @@
import copy
import math
import sys
from abc import abstractmethod
from enum import Enum
from typing import Any, Dict, Optional, Type, Union

@@ -55,8 +56,9 @@ def validate_and_convert(self, value: Any) -> Any:
# Subclasses can assume that `value` is not None in `_validate_and_convert_impl()`.
return self._validate_and_convert_impl(value)

@abstractmethod
def _validate_and_convert_impl(self, value: Any) -> Any:
return value
...

def __str__(self) -> str:
return str(self._val)