Skip to content

Experimental support for sub-classing ArgumentParser to customize add_argument #661

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jan 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,16 @@ The semantic versioning only considers the public API as described in
paths are considered internals and can change in minor and patch releases.


v4.37.0 (2025-01-??)
--------------------

Added
^^^^^
- Experimental support for sub-classing ``ArgumentParser`` to customize
``add_argument`` (`#661
<https://github.com/omni-us/jsonargparse/pull/661>`__).


v4.36.0 (2025-01-17)
--------------------

Expand Down
27 changes: 27 additions & 0 deletions DOCUMENTATION.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1562,6 +1562,33 @@ that don't have attribute docstrings. To enable, do as follows:
set_docstring_parse_options(style=DocstringStyle.GOOGLE)
set_docstring_parse_options(attribute_docstrings=False)

Customization of arguments
--------------------------

Since the arguments are added automatically based on the function signatures,
the developer has limited control over their behavior. To customize some of the
arguments, you can create a subclass and override the
:py:meth:`.ArgumentParser.add_argument` method. For example, by default,
``bool`` arguments require a ``true|false`` value from the command line. To
change this behavior and use :class:`.ActionYesNo` instead, through a CLI based
on :func:`.auto_cli`, you can:

.. testcode::

from jsonargparse import ArgumentParser, auto_cli

class CustomArgumentParser(ArgumentParser):
def add_argument(self, *args, **kwargs):
if "type" in kwargs and kwargs["type"] == bool:
kwargs.pop("type")
kwargs["action"] = ActionYesNo
return super().add_argument(*args, **kwargs)

def main_function(flag: bool = False):
...

if __name__ == "__main__":
auto_cli(main_function, parser_class=CustomArgumentParser)

Classes from functions
----------------------
Expand Down
19 changes: 12 additions & 7 deletions jsonargparse/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@
Path,
argument_error,
change_to_path_dir,
get_argument_group_class,
get_private_kwargs,
identity,
return_parser_if_captured,
Expand All @@ -102,7 +103,7 @@
class ActionsContainer(SignatureArguments, argparse._ActionsContainer):
"""Extension of argparse._ActionsContainer to support additional functionalities."""

_action_groups: Sequence["_ArgumentGroup"] # type: ignore[assignment]
_action_groups: Sequence["ArgumentGroup"] # type: ignore[assignment]

def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
Expand Down Expand Up @@ -154,7 +155,7 @@ def add_argument(self, *args, enable_path: bool = False, **kwargs):
action.required = False
return action

def add_argument_group(self, *args, name: Optional[str] = None, **kwargs) -> "_ArgumentGroup":
def add_argument_group(self, *args, name: Optional[str] = None, **kwargs) -> "ArgumentGroup":
"""Adds a group to the parser.

All the arguments from `argparse.ArgumentParser.add_argument_group
Expand All @@ -173,7 +174,8 @@ def add_argument_group(self, *args, name: Optional[str] = None, **kwargs) -> "_A
parser = self.parser if hasattr(self, "parser") else self
if name is not None and name in parser.groups: # type: ignore[union-attr]
raise ValueError(f"Group with name {name} already exists.")
group = _ArgumentGroup(parser, *args, logger=parser._logger, **kwargs)
group_class = getattr(parser, "_group_class", ArgumentGroup)
group = group_class(parser, *args, logger=parser._logger, **kwargs)
group.parser = parser
parser._action_groups.append(group) # type: ignore[union-attr]
if name is not None:
Expand Down Expand Up @@ -208,7 +210,7 @@ def set_defaults(self, *args: Dict[str, Any], **kwargs: Any) -> None:
self.set_defaults(kwargs)


class _ArgumentGroup(ActionsContainer, argparse._ArgumentGroup):
class ArgumentGroup(ActionsContainer, argparse._ArgumentGroup):
"""Extension of argparse._ArgumentGroup to support additional functionalities."""

dest: Optional[str] = None
Expand All @@ -219,7 +221,8 @@ class ArgumentParser(ParserDeprecations, ActionsContainer, ArgumentLinking, argp
"""Parser for command line, configuration files and environment variables."""

formatter_class: Type[DefaultHelpFormatter]
groups: Optional[Dict[str, "_ArgumentGroup"]] = None
groups: Optional[Dict[str, ArgumentGroup]] = None
_group_class: Type[ArgumentGroup]
_subcommands_action: Optional[_ActionSubCommands] = None
_instantiators: Optional[InstantiatorsDictType] = None

Expand Down Expand Up @@ -258,6 +261,7 @@ def __init__(
default_meta: Set the default value on whether to include metadata in config objects.
"""
super().__init__(*args, formatter_class=formatter_class, logger=logger, **kwargs)
self._group_class = get_argument_group_class(self)
if self.groups is None:
self.groups = {}
self.exit_on_error = exit_on_error
Expand Down Expand Up @@ -1183,7 +1187,7 @@ def instantiate_classes(
Returns:
A configuration object with all subclasses and class groups instantiated.
"""
components: List[Union[ActionTypeHint, _ActionConfigLoad, _ArgumentGroup]] = []
components: List[Union[ActionTypeHint, _ActionConfigLoad, ArgumentGroup]] = []
for action in filter_default_actions(self._actions):
if isinstance(action, ActionTypeHint) or (
isinstance(action, _ActionConfigLoad) and is_dataclass_like(action.basetype)
Expand Down Expand Up @@ -1439,7 +1443,8 @@ def default_config_files(self, default_config_files: Optional[Sequence[Union[str
if len(self._default_config_files) > 0:
if not hasattr(self, "_default_config_files_group"):
group_title = "default config file locations"
group = _ArgumentGroup(self, title=group_title)
group_class = getattr(self, "_group_class", ArgumentGroup)
group = group_class(self, title=group_title)
self._action_groups = [group] + self._action_groups # type: ignore[operator]
self._default_config_files_group = group
elif hasattr(self, "_default_config_files_group"):
Expand Down
4 changes: 2 additions & 2 deletions jsonargparse/_link_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
)
from ._namespace import Namespace, split_key_leaf
from ._parameter_resolvers import get_signature_parameters
from ._type_checking import ArgumentParser, _ArgumentGroup
from ._type_checking import ArgumentGroup, ArgumentParser

__all__ = ["ArgumentLinking"]

Expand All @@ -47,7 +47,7 @@ def find_subclass_action_or_class_group(
parser: "ArgumentParser",
key: str,
exclude: Optional[Union[Type[ArgparseAction], Tuple[Type[ArgparseAction], ...]]] = None,
) -> Optional[Union[ArgparseAction, "_ArgumentGroup"]]:
) -> Optional[Union[ArgparseAction, "ArgumentGroup"]]:
from ._typehints import ActionTypeHint

action = _find_parent_action(parser, key, exclude=exclude)
Expand Down
4 changes: 2 additions & 2 deletions jsonargparse/_type_checking.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import TYPE_CHECKING

__all__ = [
"_ArgumentGroup",
"ArgumentGroup",
"ActionsContainer",
"ArgumentParser",
"ruyamlCommentedMap",
Expand All @@ -10,6 +10,6 @@
if TYPE_CHECKING: # pragma: no cover
from ruyaml.comments import CommentedMap as ruyamlCommentedMap

from ._core import ActionsContainer, ArgumentParser, _ArgumentGroup
from ._core import ActionsContainer, ArgumentGroup, ArgumentParser
else:
globals().update({k: None for k in __all__})
25 changes: 25 additions & 0 deletions jsonargparse/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,31 @@ def resolve_relative_path(path: str) -> str:
return "/".join(resolved)


def get_argument_group_class(parser):
import ast

from ._core import ActionsContainer, ArgumentGroup

if parser.__class__.add_argument != ActionsContainer.add_argument:
try:
add_argument = parser.__class__.add_argument
source = inspect.getsource(add_argument)
source = "class _ArgumentGroupAutoSubclass(ArgumentGroup):\n" + source
class_ast = ast.parse(source)
code = compile(class_ast, filename="<ast>", mode="exec")
namespace = {**add_argument.__globals__, "ArgumentGroup": ArgumentGroup}
exec(code, namespace)
group_class = namespace["_ArgumentGroupAutoSubclass"]
group_class.__module__ = parser.__class__.__module__
add_argument.__globals__[group_class.__name__] = group_class
return group_class
except Exception as ex:
parser.logger.debug(
f"Failed to create ArgumentGroup subclass based on {parser.__class__.__name__}: {ex}", exc_info=ex
)
return ArgumentGroup


class PathError(TypeError):
"""Exception raised for errors in the Path class."""

Expand Down
31 changes: 31 additions & 0 deletions jsonargparse_tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
skip_if_not_posix,
skip_if_responses_unavailable,
skip_if_running_as_root,
source_unavailable,
)


Expand Down Expand Up @@ -1068,3 +1069,33 @@ def test_default_meta_property():
def test_pickle_parser(example_parser):
parser = pickle.loads(pickle.dumps(example_parser))
assert example_parser.get_defaults() == parser.get_defaults()


def replace_underscores(args):
return [arg.replace("_", "-") for arg in args]


class UnderscoresToDashesParser(ArgumentParser):
def add_argument(self, *args, **kwargs):
args = replace_underscores(args)
return super().add_argument(*args, **kwargs)


def test_get_argument_group_class_failure(logger):
with source_unavailable(), capture_logs(logger) as logs:
UnderscoresToDashesParser(logger=logger)
assert "Failed to create ArgumentGroup subclass based on" in logs.getvalue()


def test_get_argument_group_class_underscores_to_dashes():
parser = UnderscoresToDashesParser()
parser.add_argument("--int_value", type=int, default=1)
group = parser.add_argument_group("group")
group.add_argument("--group_float_value", type=float, default=2.0)
help_str = get_parser_help(parser)
assert "--int-value" in help_str
assert "--group-float-value" in help_str
cfg = parser.parse_args(["--int-value=2", "--group-float-value=3.0"])
assert cfg == Namespace(int_value=2, group_float_value=3.0)
clone_parser = pickle.loads(pickle.dumps(parser))
assert clone_parser.get_defaults() == parser.get_defaults()
Loading