diff --git a/news/470.bugfix b/news/470.bugfix new file mode 100644 index 000000000..0bda78b3d --- /dev/null +++ b/news/470.bugfix @@ -0,0 +1 @@ +Fix support for merge tags in YAML files diff --git a/omegaconf/_utils.py b/omegaconf/_utils.py index 1dcf7d9f6..1b9f461b6 100644 --- a/omegaconf/_utils.py +++ b/omegaconf/_utils.py @@ -81,27 +81,21 @@ def yaml_is_bool(b: str) -> bool: def get_yaml_loader() -> Any: - # Custom constructor that checks for duplicate keys - # (from https://gist.github.com/pypt/94d747fe5180851196eb) - def no_duplicates_constructor( - loader: yaml.Loader, node: yaml.Node, deep: bool = False - ) -> Any: - mapping: Dict[str, Any] = {} - for key_node, value_node in node.value: - key = loader.construct_object(key_node, deep=deep) - value = loader.construct_object(value_node, deep=deep) - if key in mapping: - raise yaml.constructor.ConstructorError( - "while constructing a mapping", - node.start_mark, - f"found duplicate key {key}", - key_node.start_mark, - ) - mapping[key] = value - return loader.construct_mapping(node, deep) - class OmegaConfLoader(yaml.SafeLoader): # type: ignore - pass + def construct_mapping(self, node: yaml.Node, deep: bool = False) -> Any: + keys = set() + for key_node, value_node in node.value: + if key_node.tag != yaml.resolver.BaseResolver.DEFAULT_SCALAR_TAG: + continue + if key_node.value in keys: + raise yaml.constructor.ConstructorError( + "while constructing a mapping", + node.start_mark, + f"found duplicate key {key_node.value}", + key_node.start_mark, + ) + keys.add(key_node.value) + return super().construct_mapping(node, deep=deep) loader = OmegaConfLoader loader.add_implicit_resolver( @@ -126,9 +120,6 @@ class OmegaConfLoader(yaml.SafeLoader): # type: ignore ] for key, resolvers in loader.yaml_implicit_resolvers.items() } - loader.add_constructor( - yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, no_duplicates_constructor - ) return loader diff --git a/tests/test_create.py b/tests/test_create.py index 5faa45442..2a0b28432 100644 --- a/tests/test_create.py +++ b/tests/test_create.py @@ -1,6 +1,7 @@ """Testing for OmegaConf""" import re import sys +from textwrap import dedent from typing import Any, Dict, List, Optional import pytest @@ -292,3 +293,48 @@ def test_create_untyped_dict() -> None: cfg = DictConfig(ref_type=Dict, content={}) assert get_ref_type(cfg) == Optional[Dict] + + +@pytest.mark.parametrize( + "input_", + [ + dedent( + """\ + a: + b: 1 + c: 2 + b: 3 + """ + ), + dedent( + """\ + a: + b: 1 + a: + b: 2 + """ + ), + ], +) +def test_yaml_duplicate_keys(input_: str) -> None: + with pytest.raises(yaml.constructor.ConstructorError): + OmegaConf.create(input_) + + +def test_yaml_merge() -> None: + cfg = OmegaConf.create( + dedent( + """\ + a: &A + x: 1 + b: &B + y: 2 + c: + <<: *A + <<: *B + x: 3 + z: 1 + """ + ) + ) + assert cfg == {"a": {"x": 1}, "b": {"y": 2}, "c": {"x": 3, "y": 2, "z": 1}} diff --git a/tests/test_serialization.py b/tests/test_serialization.py index 74db61fd7..2299b124f 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -5,7 +5,6 @@ import pickle import tempfile from pathlib import Path -from textwrap import dedent from typing import Any, Dict, List, Optional, Type import pytest @@ -142,46 +141,6 @@ def test_pickle(obj: Any) -> None: assert c1._metadata.key_type is Any -def test_load_duplicate_keys_top() -> None: - from yaml.constructor import ConstructorError - - try: - with tempfile.NamedTemporaryFile(delete=False) as fp: - content = dedent( - """\ - a: - b: 1 - a: - b: 2 - """ - ) - fp.write(content.encode("utf-8")) - with pytest.raises(ConstructorError): - OmegaConf.load(fp.name) - finally: - os.unlink(fp.name) - - -def test_load_duplicate_keys_sub() -> None: - from yaml.constructor import ConstructorError - - try: - with tempfile.NamedTemporaryFile(delete=False) as fp: - content = dedent( - """\ - a: - b: 1 - c: 2 - b: 3 - """ - ) - fp.write(content.encode("utf-8")) - with pytest.raises(ConstructorError): - OmegaConf.load(fp.name) - finally: - os.unlink(fp.name) - - def test_load_empty_file(tmpdir: str) -> None: empty = Path(tmpdir) / "test.yaml" empty.touch()