Skip to content
3 changes: 3 additions & 0 deletions monai/bundle/reference_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,9 @@ def update_refs_pattern(cls, value: str, refs: dict) -> str:
"""
# regular expression pattern to match "@XXX" or "@XXX#YYY"
result = cls.id_matcher.findall(value)
# reversely sort the matched references by length
# and handle the longer first in case a reference item is substring of another longer item
result.sort(key=len, reverse=True)
value_is_expr = ConfigExpression.is_expression(value)
for item in result:
# only update reference when string starts with "$" or the whole content is "@XXX"
Expand Down
7 changes: 7 additions & 0 deletions tests/test_config_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,8 @@ def __call__(self, a, b):

TEST_CASE_4 = [{"A": 1, "B": "@A", "C": "@D", "E": "$'test' + '@F'"}]

TEST_CASE_5 = [{"training": {"A": 1, "A_B": 2}, "total": "$@training#A + @training#A_B + 1"}, 4]


class TestConfigParser(unittest.TestCase):
def test_config_content(self):
Expand Down Expand Up @@ -296,6 +298,11 @@ def test_builtin(self):
config = {"import statements": "$import math", "calc": {"_target_": "math.isclose", "a": 0.001, "b": 0.001}}
self.assertEqual(ConfigParser(config).calc, True)

@parameterized.expand([TEST_CASE_5])
def test_substring_reference(self, config, expected):
parser = ConfigParser(config=config)
self.assertEqual(parser.get_parsed_content("total"), expected)


if __name__ == "__main__":
unittest.main()