Skip to content

Commit eee8ea9

Browse files
authored
Experimental support for sub-classing ArgumentParser to customize add_argument (#661)
1 parent 2285c1a commit eee8ea9

File tree

7 files changed

+109
-11
lines changed

7 files changed

+109
-11
lines changed

CHANGELOG.rst

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,16 @@ The semantic versioning only considers the public API as described in
1212
paths are considered internals and can change in minor and patch releases.
1313

1414

15+
v4.37.0 (2025-01-??)
16+
--------------------
17+
18+
Added
19+
^^^^^
20+
- Experimental support for sub-classing ``ArgumentParser`` to customize
21+
``add_argument`` (`#661
22+
<https://github.com/omni-us/jsonargparse/pull/661>`__).
23+
24+
1525
v4.36.0 (2025-01-17)
1626
--------------------
1727

DOCUMENTATION.rst

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1562,6 +1562,33 @@ that don't have attribute docstrings. To enable, do as follows:
15621562
set_docstring_parse_options(style=DocstringStyle.GOOGLE)
15631563
set_docstring_parse_options(attribute_docstrings=False)
15641564

1565+
Customization of arguments
1566+
--------------------------
1567+
1568+
Since the arguments are added automatically based on the function signatures,
1569+
the developer has limited control over their behavior. To customize some of the
1570+
arguments, you can create a subclass and override the
1571+
:py:meth:`.ArgumentParser.add_argument` method. For example, by default,
1572+
``bool`` arguments require a ``true|false`` value from the command line. To
1573+
change this behavior and use :class:`.ActionYesNo` instead, through a CLI based
1574+
on :func:`.auto_cli`, you can:
1575+
1576+
.. testcode::
1577+
1578+
from jsonargparse import ArgumentParser, auto_cli
1579+
1580+
class CustomArgumentParser(ArgumentParser):
1581+
def add_argument(self, *args, **kwargs):
1582+
if "type" in kwargs and kwargs["type"] == bool:
1583+
kwargs.pop("type")
1584+
kwargs["action"] = ActionYesNo
1585+
return super().add_argument(*args, **kwargs)
1586+
1587+
def main_function(flag: bool = False):
1588+
...
1589+
1590+
if __name__ == "__main__":
1591+
auto_cli(main_function, parser_class=CustomArgumentParser)
15651592

15661593
Classes from functions
15671594
----------------------

jsonargparse/_core.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@
8888
Path,
8989
argument_error,
9090
change_to_path_dir,
91+
get_argument_group_class,
9192
get_private_kwargs,
9293
identity,
9394
return_parser_if_captured,
@@ -102,7 +103,7 @@
102103
class ActionsContainer(SignatureArguments, argparse._ActionsContainer):
103104
"""Extension of argparse._ActionsContainer to support additional functionalities."""
104105

105-
_action_groups: Sequence["_ArgumentGroup"] # type: ignore[assignment]
106+
_action_groups: Sequence["ArgumentGroup"] # type: ignore[assignment]
106107

107108
def __init__(self, *args, **kwargs) -> None:
108109
super().__init__(*args, **kwargs)
@@ -154,7 +155,7 @@ def add_argument(self, *args, enable_path: bool = False, **kwargs):
154155
action.required = False
155156
return action
156157

157-
def add_argument_group(self, *args, name: Optional[str] = None, **kwargs) -> "_ArgumentGroup":
158+
def add_argument_group(self, *args, name: Optional[str] = None, **kwargs) -> "ArgumentGroup":
158159
"""Adds a group to the parser.
159160
160161
All the arguments from `argparse.ArgumentParser.add_argument_group
@@ -173,7 +174,8 @@ def add_argument_group(self, *args, name: Optional[str] = None, **kwargs) -> "_A
173174
parser = self.parser if hasattr(self, "parser") else self
174175
if name is not None and name in parser.groups: # type: ignore[union-attr]
175176
raise ValueError(f"Group with name {name} already exists.")
176-
group = _ArgumentGroup(parser, *args, logger=parser._logger, **kwargs)
177+
group_class = getattr(parser, "_group_class", ArgumentGroup)
178+
group = group_class(parser, *args, logger=parser._logger, **kwargs)
177179
group.parser = parser
178180
parser._action_groups.append(group) # type: ignore[union-attr]
179181
if name is not None:
@@ -208,7 +210,7 @@ def set_defaults(self, *args: Dict[str, Any], **kwargs: Any) -> None:
208210
self.set_defaults(kwargs)
209211

210212

211-
class _ArgumentGroup(ActionsContainer, argparse._ArgumentGroup):
213+
class ArgumentGroup(ActionsContainer, argparse._ArgumentGroup):
212214
"""Extension of argparse._ArgumentGroup to support additional functionalities."""
213215

214216
dest: Optional[str] = None
@@ -219,7 +221,8 @@ class ArgumentParser(ParserDeprecations, ActionsContainer, ArgumentLinking, argp
219221
"""Parser for command line, configuration files and environment variables."""
220222

221223
formatter_class: Type[DefaultHelpFormatter]
222-
groups: Optional[Dict[str, "_ArgumentGroup"]] = None
224+
groups: Optional[Dict[str, ArgumentGroup]] = None
225+
_group_class: Type[ArgumentGroup]
223226
_subcommands_action: Optional[_ActionSubCommands] = None
224227
_instantiators: Optional[InstantiatorsDictType] = None
225228

@@ -258,6 +261,7 @@ def __init__(
258261
default_meta: Set the default value on whether to include metadata in config objects.
259262
"""
260263
super().__init__(*args, formatter_class=formatter_class, logger=logger, **kwargs)
264+
self._group_class = get_argument_group_class(self)
261265
if self.groups is None:
262266
self.groups = {}
263267
self.exit_on_error = exit_on_error
@@ -1183,7 +1187,7 @@ def instantiate_classes(
11831187
Returns:
11841188
A configuration object with all subclasses and class groups instantiated.
11851189
"""
1186-
components: List[Union[ActionTypeHint, _ActionConfigLoad, _ArgumentGroup]] = []
1190+
components: List[Union[ActionTypeHint, _ActionConfigLoad, ArgumentGroup]] = []
11871191
for action in filter_default_actions(self._actions):
11881192
if isinstance(action, ActionTypeHint) or (
11891193
isinstance(action, _ActionConfigLoad) and is_dataclass_like(action.basetype)
@@ -1439,7 +1443,8 @@ def default_config_files(self, default_config_files: Optional[Sequence[Union[str
14391443
if len(self._default_config_files) > 0:
14401444
if not hasattr(self, "_default_config_files_group"):
14411445
group_title = "default config file locations"
1442-
group = _ArgumentGroup(self, title=group_title)
1446+
group_class = getattr(self, "_group_class", ArgumentGroup)
1447+
group = group_class(self, title=group_title)
14431448
self._action_groups = [group] + self._action_groups # type: ignore[operator]
14441449
self._default_config_files_group = group
14451450
elif hasattr(self, "_default_config_files_group"):

jsonargparse/_link_arguments.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
)
2121
from ._namespace import Namespace, split_key_leaf
2222
from ._parameter_resolvers import get_signature_parameters
23-
from ._type_checking import ArgumentParser, _ArgumentGroup
23+
from ._type_checking import ArgumentGroup, ArgumentParser
2424

2525
__all__ = ["ArgumentLinking"]
2626

@@ -47,7 +47,7 @@ def find_subclass_action_or_class_group(
4747
parser: "ArgumentParser",
4848
key: str,
4949
exclude: Optional[Union[Type[ArgparseAction], Tuple[Type[ArgparseAction], ...]]] = None,
50-
) -> Optional[Union[ArgparseAction, "_ArgumentGroup"]]:
50+
) -> Optional[Union[ArgparseAction, "ArgumentGroup"]]:
5151
from ._typehints import ActionTypeHint
5252

5353
action = _find_parent_action(parser, key, exclude=exclude)

jsonargparse/_type_checking.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from typing import TYPE_CHECKING
22

33
__all__ = [
4-
"_ArgumentGroup",
4+
"ArgumentGroup",
55
"ActionsContainer",
66
"ArgumentParser",
77
"ruyamlCommentedMap",
@@ -10,6 +10,6 @@
1010
if TYPE_CHECKING: # pragma: no cover
1111
from ruyaml.comments import CommentedMap as ruyamlCommentedMap
1212

13-
from ._core import ActionsContainer, ArgumentParser, _ArgumentGroup
13+
from ._core import ActionsContainer, ArgumentGroup, ArgumentParser
1414
else:
1515
globals().update({k: None for k in __all__})

jsonargparse/_util.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -458,6 +458,31 @@ def resolve_relative_path(path: str) -> str:
458458
return "/".join(resolved)
459459

460460

461+
def get_argument_group_class(parser):
462+
import ast
463+
464+
from ._core import ActionsContainer, ArgumentGroup
465+
466+
if parser.__class__.add_argument != ActionsContainer.add_argument:
467+
try:
468+
add_argument = parser.__class__.add_argument
469+
source = inspect.getsource(add_argument)
470+
source = "class _ArgumentGroupAutoSubclass(ArgumentGroup):\n" + source
471+
class_ast = ast.parse(source)
472+
code = compile(class_ast, filename="<ast>", mode="exec")
473+
namespace = {**add_argument.__globals__, "ArgumentGroup": ArgumentGroup}
474+
exec(code, namespace)
475+
group_class = namespace["_ArgumentGroupAutoSubclass"]
476+
group_class.__module__ = parser.__class__.__module__
477+
add_argument.__globals__[group_class.__name__] = group_class
478+
return group_class
479+
except Exception as ex:
480+
parser.logger.debug(
481+
f"Failed to create ArgumentGroup subclass based on {parser.__class__.__name__}: {ex}", exc_info=ex
482+
)
483+
return ArgumentGroup
484+
485+
461486
class PathError(TypeError):
462487
"""Exception raised for errors in the Path class."""
463488

jsonargparse_tests/test_core.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
skip_if_not_posix,
4444
skip_if_responses_unavailable,
4545
skip_if_running_as_root,
46+
source_unavailable,
4647
)
4748

4849

@@ -1068,3 +1069,33 @@ def test_default_meta_property():
10681069
def test_pickle_parser(example_parser):
10691070
parser = pickle.loads(pickle.dumps(example_parser))
10701071
assert example_parser.get_defaults() == parser.get_defaults()
1072+
1073+
1074+
def replace_underscores(args):
1075+
return [arg.replace("_", "-") for arg in args]
1076+
1077+
1078+
class UnderscoresToDashesParser(ArgumentParser):
1079+
def add_argument(self, *args, **kwargs):
1080+
args = replace_underscores(args)
1081+
return super().add_argument(*args, **kwargs)
1082+
1083+
1084+
def test_get_argument_group_class_failure(logger):
1085+
with source_unavailable(), capture_logs(logger) as logs:
1086+
UnderscoresToDashesParser(logger=logger)
1087+
assert "Failed to create ArgumentGroup subclass based on" in logs.getvalue()
1088+
1089+
1090+
def test_get_argument_group_class_underscores_to_dashes():
1091+
parser = UnderscoresToDashesParser()
1092+
parser.add_argument("--int_value", type=int, default=1)
1093+
group = parser.add_argument_group("group")
1094+
group.add_argument("--group_float_value", type=float, default=2.0)
1095+
help_str = get_parser_help(parser)
1096+
assert "--int-value" in help_str
1097+
assert "--group-float-value" in help_str
1098+
cfg = parser.parse_args(["--int-value=2", "--group-float-value=3.0"])
1099+
assert cfg == Namespace(int_value=2, group_float_value=3.0)
1100+
clone_parser = pickle.loads(pickle.dumps(parser))
1101+
assert clone_parser.get_defaults() == parser.get_defaults()

0 commit comments

Comments
 (0)