Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions news/negative_interpolation.news
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Allows negative list access.

```python
cfg = OmegaConf.create("list: [1, 2, 3]")
print(cfg.list[-1]) # 3
```
6 changes: 6 additions & 0 deletions omegaconf/listconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,12 @@ def __getitem__(self, index: Union[int, slice]) -> Any:
result.reverse()
return result
else:
# Handle negative indices by converting to positive indices
if index < 0:
length = len(self)
if -index > length:
raise IndexError("list index out of range")
index = length + index
return self._resolve_with_default(
key=index, value=self.__dict__["_content"][index]
)
Expand Down
54 changes: 24 additions & 30 deletions omegaconf/omegaconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
from .base import Box, Container, ListMergeMode, Node, SCMode, UnionNode
from .basecontainer import BaseContainer
from .errors import (
InterpolationKeyError,
MissingMandatoryValue,
OmegaConfBaseException,
UnsupportedInterpolationType,
Expand Down Expand Up @@ -170,11 +171,7 @@ def create( # noqa F811
parent: Optional[BaseContainer] = None,
flags: Optional[Dict[str, bool]] = None,
) -> Union[DictConfig, ListConfig]:
return OmegaConf._create_impl(
obj=obj,
parent=parent,
flags=flags,
)
return OmegaConf._create_impl(obj=obj, parent=parent, flags=flags)

@staticmethod
def load(file_: Union[str, pathlib.Path, IO[Any]]) -> Union[DictConfig, ListConfig]:
Expand Down Expand Up @@ -271,10 +268,7 @@ def merge(
assert isinstance(target, (DictConfig, ListConfig))

with flag_override(target, "readonly", False):
target.merge_with(
*configs[1:],
list_merge_mode=list_merge_mode,
)
target.merge_with(*configs[1:], list_merge_mode=list_merge_mode)
turned_readonly = target._get_flag("readonly") is True

if turned_readonly:
Expand Down Expand Up @@ -315,10 +309,7 @@ def unsafe_merge(
with flag_override(
target, ["readonly", "no_deepcopy_set_nodes"], [False, True]
):
target.merge_with(
*configs[1:],
list_merge_mode=list_merge_mode,
)
target.merge_with(*configs[1:], list_merge_mode=list_merge_mode)
turned_readonly = target._get_flag("readonly") is True

if turned_readonly:
Expand Down Expand Up @@ -384,11 +375,7 @@ def resolver_wrapper(

@staticmethod
def register_new_resolver(
name: str,
resolver: Resolver,
*,
replace: bool = False,
use_cache: bool = False,
name: str, resolver: Resolver, *, replace: bool = False, use_cache: bool = False
) -> None:
"""
Register a resolver.
Expand Down Expand Up @@ -946,10 +933,7 @@ def _get_obj_type(c: Any) -> Optional[Type[Any]]:
def _get_resolver(
name: str,
) -> Optional[
Callable[
[Container, Container, Node, Tuple[Any, ...], Tuple[str, ...]],
Any,
]
Callable[[Container, Container, Node, Tuple[Any, ...], Tuple[str, ...]], Any]
]:
# noinspection PyProtectedMember
return (
Expand Down Expand Up @@ -1005,11 +989,7 @@ def open_dict(config: Container) -> Generator[Container, None, None]:


def _node_wrap(
parent: Optional[Box],
is_optional: bool,
value: Any,
key: Any,
ref_type: Any = Any,
parent: Optional[Box], is_optional: bool, value: Any, key: Any, ref_type: Any = Any
) -> Node:
node: Node
if is_dict_annotation(ref_type) or (is_primitive_dict(value) and ref_type is Any):
Expand Down Expand Up @@ -1149,10 +1129,24 @@ def _select_one(
val = None
else:
ret_key = int(ret_key)
if ret_key < 0 or ret_key + 1 > len(c):
val = None
# Handle negative indices in interpolation
if ret_key < 0:
list_len = len(c)
if -ret_key > list_len:
if throw_on_missing:
raise InterpolationKeyError(
f"list index {ret_key} is out of range"
)
else:
val = None
else:
ret_key = list_len + ret_key # Convert negative index to positive
val = c._get_child(ret_key)
else:
val = c._get_child(ret_key)
if ret_key + 1 > len(c):
val = None
else:
val = c._get_child(ret_key)
else:
assert False

Expand Down
6 changes: 5 additions & 1 deletion tests/test_grammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,11 @@
# Node interpolations.
("dict_access", "${dict.a}", 0),
("list_access", "${list.0}", -1),
("list_access", "${list.10}", 9),
("list_access", "${list.11}", InterpolationKeyError),
("list_access_good_negative_1", "${list.-1}", 9),
("list_access_good_negative_11", "${list.-11}", -1),
("list_access_bad_negative_12", "${list.-12}", InterpolationKeyError),
("dict_access_getitem", "${dict[a]}", 0),
("list_access_getitem", "${list[0]}", -1),
("getitem_first_1", "${[dict].a}", 0),
Expand All @@ -250,7 +255,6 @@
("dict_access_deep_3", "${dict.b[c]}", 1),
("dict_access_deep_4", "${dict[b][c]}", 1),
("list_access_underscore", "${list.1_0}", 9),
("list_access_bad_negative", "${list.-1}", InterpolationKeyError),
("dict_access_list_like_1", "${0}", 0),
("dict_access_list_like_2", "${1.2}", 12),
("bool_like_keys", "${FalsE.TruE}", True),
Expand Down