Skip to content

Commit

Permalink
Support YAML merge tags (#507)
Browse files Browse the repository at this point in the history
* Support YAML merge tags

This adds support for YAML merge tags (<< *ref) while retaining the
sanity check for duplicate keys. Note that the spec for merge keys
(https://yaml.org/type/merge.html) explicitly states that keys in the
current mapping override the ones in the merged mapping. Hence, the
check for duplicates is applied to scalar keys of the current mapping
only.

Fixes #470.
  • Loading branch information
omry authored Feb 3, 2021
2 parents 6b3e0d4 + 905cddb commit 877e839
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 64 deletions.
1 change: 1 addition & 0 deletions news/470.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix support for merge tags in YAML files
37 changes: 14 additions & 23 deletions omegaconf/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,27 +81,21 @@ def yaml_is_bool(b: str) -> bool:


def get_yaml_loader() -> Any:
# Custom constructor that checks for duplicate keys
# (from https://gist.github.com/pypt/94d747fe5180851196eb)
def no_duplicates_constructor(
loader: yaml.Loader, node: yaml.Node, deep: bool = False
) -> Any:
mapping: Dict[str, Any] = {}
for key_node, value_node in node.value:
key = loader.construct_object(key_node, deep=deep)
value = loader.construct_object(value_node, deep=deep)
if key in mapping:
raise yaml.constructor.ConstructorError(
"while constructing a mapping",
node.start_mark,
f"found duplicate key {key}",
key_node.start_mark,
)
mapping[key] = value
return loader.construct_mapping(node, deep)

class OmegaConfLoader(yaml.SafeLoader): # type: ignore
pass
def construct_mapping(self, node: yaml.Node, deep: bool = False) -> Any:
keys = set()
for key_node, value_node in node.value:
if key_node.tag != yaml.resolver.BaseResolver.DEFAULT_SCALAR_TAG:
continue
if key_node.value in keys:
raise yaml.constructor.ConstructorError(
"while constructing a mapping",
node.start_mark,
f"found duplicate key {key_node.value}",
key_node.start_mark,
)
keys.add(key_node.value)
return super().construct_mapping(node, deep=deep)

loader = OmegaConfLoader
loader.add_implicit_resolver(
Expand All @@ -126,9 +120,6 @@ class OmegaConfLoader(yaml.SafeLoader): # type: ignore
]
for key, resolvers in loader.yaml_implicit_resolvers.items()
}
loader.add_constructor(
yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, no_duplicates_constructor
)
return loader


Expand Down
46 changes: 46 additions & 0 deletions tests/test_create.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Testing for OmegaConf"""
import re
import sys
from textwrap import dedent
from typing import Any, Dict, List, Optional

import pytest
Expand Down Expand Up @@ -292,3 +293,48 @@ def test_create_untyped_dict() -> None:

cfg = DictConfig(ref_type=Dict, content={})
assert get_ref_type(cfg) == Optional[Dict]


@pytest.mark.parametrize(
"input_",
[
dedent(
"""\
a:
b: 1
c: 2
b: 3
"""
),
dedent(
"""\
a:
b: 1
a:
b: 2
"""
),
],
)
def test_yaml_duplicate_keys(input_: str) -> None:
with pytest.raises(yaml.constructor.ConstructorError):
OmegaConf.create(input_)


def test_yaml_merge() -> None:
cfg = OmegaConf.create(
dedent(
"""\
a: &A
x: 1
b: &B
y: 2
c:
<<: *A
<<: *B
x: 3
z: 1
"""
)
)
assert cfg == {"a": {"x": 1}, "b": {"y": 2}, "c": {"x": 3, "y": 2, "z": 1}}
41 changes: 0 additions & 41 deletions tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import pickle
import tempfile
from pathlib import Path
from textwrap import dedent
from typing import Any, Dict, List, Optional, Type

import pytest
Expand Down Expand Up @@ -142,46 +141,6 @@ def test_pickle(obj: Any) -> None:
assert c1._metadata.key_type is Any


def test_load_duplicate_keys_top() -> None:
from yaml.constructor import ConstructorError

try:
with tempfile.NamedTemporaryFile(delete=False) as fp:
content = dedent(
"""\
a:
b: 1
a:
b: 2
"""
)
fp.write(content.encode("utf-8"))
with pytest.raises(ConstructorError):
OmegaConf.load(fp.name)
finally:
os.unlink(fp.name)


def test_load_duplicate_keys_sub() -> None:
from yaml.constructor import ConstructorError

try:
with tempfile.NamedTemporaryFile(delete=False) as fp:
content = dedent(
"""\
a:
b: 1
c: 2
b: 3
"""
)
fp.write(content.encode("utf-8"))
with pytest.raises(ConstructorError):
OmegaConf.load(fp.name)
finally:
os.unlink(fp.name)


def test_load_empty_file(tmpdir: str) -> None:
empty = Path(tmpdir) / "test.yaml"
empty.touch()
Expand Down

0 comments on commit 877e839

Please sign in to comment.