Skip to content

Commit

Permalink
Resolver output is now validated and converted for typed nodes
Browse files Browse the repository at this point in the history
This is related to omry#488
  • Loading branch information
odelalleau committed Feb 10, 2021
1 parent 6d4cd86 commit 0535e8e
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 2 deletions.
File renamed without changes.
1 change: 1 addition & 0 deletions news/445.api_change.2
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
If the value of a typed node is computed by a resolver (including `env`), it is now validated (and possibly converted) based on the type
9 changes: 8 additions & 1 deletion omegaconf/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,13 +463,20 @@ def _evaluate_custom_resolver(
) -> Optional["Node"]:
from omegaconf import OmegaConf

from .nodes import ValueNode
from .nodes import AnyNode, ValueNode

resolver = OmegaConf.get_resolver(inter_type)
if resolver is not None:
root_node = self._get_root()
try:
value = resolver(root_node, inter_args, inter_args_str)

# Ensure the resolver output is compatible with the node's type.
if key is not None:
node = self._get_node(key)
if isinstance(node, ValueNode) and not isinstance(node, AnyNode):
value = node.validate_and_convert(value)

return ValueNode(
value=value,
parent=self,
Expand Down
24 changes: 24 additions & 0 deletions tests/test_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from omegaconf import (
II,
SI,
Container,
DictConfig,
IntegerNode,
Expand All @@ -30,6 +31,8 @@
UnsupportedInterpolationType,
)

from . import User

# file deepcode ignore CopyPasteError: there are several tests of the form `c.k == c.k`
# (this is intended to trigger multiple accesses to the same config key)

Expand Down Expand Up @@ -1082,3 +1085,24 @@ def _check_is_same_type(value: Any, expected: Any) -> None:
assert expected is None
else:
raise NotImplementedError(type(value))


def test_custom_resolver_return_validated(restore_resolvers: Any) -> Any:
def cast(t: Any, v: Any) -> Any:
if t == "str":
return str(v)
if t == "int":
return int(v)
assert False

OmegaConf.new_register_resolver("cast", cast)
cfg = OmegaConf.structured(User(name="Bond", age=SI("${cast:int,'7'}")))
assert cfg.age == 7

# converted to int per the dataclass age field
cfg = OmegaConf.structured(User(name="Bond", age=SI("${cast:str,'7'}")))
assert cfg.age == 7

cfg = OmegaConf.structured(User(name="Bond", age=SI("${cast:str,seven}")))
with pytest.raises(ValidationError):
cfg.age
2 changes: 1 addition & 1 deletion tests/test_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def test_none_construction(self, node_type: Any, values: Any) -> None:
def test_interpolation(
self, node_type: Any, values: Any, restore_resolvers: Any
) -> None:
resolver_output = 9999
resolver_output = "9999"
OmegaConf.new_register_resolver("func", lambda: resolver_output)
values = copy.deepcopy(values)
for value in values:
Expand Down

0 comments on commit 0535e8e

Please sign in to comment.