
Allow parsing multiple nested sub commands #130

@janvainer

Description


What I am trying to do:

I am training an encoder-decoder neural network and I have multiple different encoders and decoders.
Each encoder and decoder has a separate dataclass config.

I want to choose the encoder from the command line by giving its config name. Once I have chosen it, I also want to be able to adjust the fields of the selected config. So far I have been able to accomplish this with subparsers:

I have something like

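from dataclasses import dataclass, field
from typing import Union

@dataclass
class Model:
    encoder: Union[RNNEncoder, ConvEncoder] = field(default_factory=RNNEncoder)
    decoder: RNNDecoder = field(default_factory=RNNDecoder)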

I parse it just like in the subparsers example: I give the encoder keyword on the command line, and it then lets me pass args specific to the selected architecture.

The problem is that if I want to do the same for the decoder, subparsers are no longer available and I get:

error: cannot have multiple subparser arguments
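The error comes from argparse itself rather than from simple_parsing: a single ArgumentParser accepts at most one subparsers action. A minimal reproduction with plain argparse (the dest names are illustrative) shows the second add_subparsers() call being rejected. How it is rejected (exiting via parser.error() or raising an exception) may differ between Python versions, so this sketch catches either.

```python
import argparse

parser = argparse.ArgumentParser(prog="script.py")
parser.add_subparsers(dest="encoder")  # the first subparsers action is accepted

try:
    # A second subparsers action on the same parser is what fails.
    parser.add_subparsers(dest="decoder")
    rejected = False
except (SystemExit, argparse.ArgumentError, ValueError):
    # Depending on the Python version, argparse either calls parser.error()
    # (prints "cannot have multiple subparser arguments" and exits) or raises.
    rejected = True

print("second add_subparsers rejected:", rejected)
```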

Ideal state

How can I solve my problem with simple_parsing? My ideal syntax would probably be to use the Union type in the dataclasses and call it like so:

python script.py --some_unrelated_args \
    encoder rnnencoder --args_related_to_rnnencoder \
    decoder convdecoder --args_related_to_convdecoder
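For reference, plain argparse can already parse a command line shaped like this if the subparsers are nested instead of placed side by side: each encoder-kind parser owns its own decoder subparsers action, so no single parser holds two. This is a sketch under assumptions, not simple_parsing's API: ENCODERS, DECODERS, add_choice and build_parser are hypothetical names, each architecture gets one integer flag mirroring the dataclasses in the workaround code, and the result holds plain strings and ints rather than dataclass instances.

```python
import argparse

# Hypothetical: one integer flag per architecture, mirroring the dataclass fields.
ENCODERS = {"rnnencoder": "--x", "convencoder": "--y"}
DECODERS = {"rnndecoder": "--m", "convdecoder": "--n"}


def add_choice(subparsers, section, dest, kinds):
    """Add '<section> <kind> [--flag N]' below `subparsers`; return the kind parsers."""
    section_parser = subparsers.add_parser(section)
    kind_subparsers = section_parser.add_subparsers(dest=dest, required=True)
    kind_parsers = []
    for name, flag in kinds.items():
        kind_parser = kind_subparsers.add_parser(name)
        kind_parser.add_argument(flag, type=int)
        kind_parsers.append(kind_parser)
    return kind_parsers


def build_parser():
    parser = argparse.ArgumentParser(prog="script.py")
    parser.add_argument("--some_unrelated_arg", type=int, default=0)
    top = parser.add_subparsers(dest="first_section", required=True)
    # Every encoder-kind parser gets its own 'decoder' subparsers action, so no
    # single parser ever holds more than one (which argparse forbids).
    for encoder_parser in add_choice(top, "encoder", "encoder", ENCODERS):
        nested = encoder_parser.add_subparsers(dest="second_section", required=True)
        add_choice(nested, "decoder", "decoder", DECODERS)
    return parser


args = build_parser().parse_args(
    "--some_unrelated_arg 1 encoder convencoder --y 2 decoder convdecoder --n 7".split()
)
print(args)
```

The cost of this design is that the decoder subtree is duplicated under every encoder kind, and the sections must appear in a fixed order (encoder before decoder).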

Solutions & workarounds

There are quite a few suggestions in this SO thread.
The issue can be worked around with multiple subparsers and user-side argv parsing. The following code allows me to call my script like so:

python script.py \
    --glob1.xx 7 \
    --glob2.xx 5 \
    encoder convencoder \
      --y 2 \
    decoder convdecoder \
      --n 7

And get

Namespace(decoder=Decoder(decoder=ConvDecoder(n=7)), encoder=Encoder(encoder=ConvEncoder(y=2)), glob1=Global1(xx=7, yy='hello'), glob2=Global2(xx=5, yy='hello'))

The code:

import sys
import itertools
from argparse import Namespace
from dataclasses import dataclass, field
from typing import Union

from simple_parsing import ArgumentParser


@dataclass
class RNNEncoder:
    x: int = 1


@dataclass
class ConvEncoder:
    y: int = 2


@dataclass
class RNNDecoder:
    m: int = 3


@dataclass
class ConvDecoder:
    n: int = 4


# Dataclass instances are unhashable, so Python 3.11+ rejects them as plain
# defaults; default_factory works on every version.
@dataclass
class Encoder:
    encoder: Union[RNNEncoder, ConvEncoder] = field(default_factory=RNNEncoder)


@dataclass
class Decoder:
    decoder: Union[RNNDecoder, ConvDecoder] = field(default_factory=RNNDecoder)


@dataclass
class Global1:
    xx: int = 5
    yy: str = "hello"


@dataclass
class Global2:
    xx: int = 5
    yy: str = "hello"


parser = ArgumentParser()

parser.add_arguments(Global1, dest="glob1")
parser.add_arguments(Global2, dest="glob2")

# One top-level subcommand per Union field, each wrapping the field in its own dataclass.
sub = parser.add_subparsers()
encoder = sub.add_parser("encoder")
encoder.add_arguments(Encoder, dest="encoder")
decoder = sub.add_parser("decoder")
decoder.add_arguments(Decoder, dest="decoder")


def command_key(commands):
    """Key function for itertools.groupby: tags each arg with the most recent
    command name seen so far (None for the leading global args)."""
    current = [None]

    def key(arg):
        if arg in commands:
            current[0] = arg
        return current[0]

    return key


# e.g. ['script.py', '--glob1.xx', '7', 'encoder', 'convencoder', '--y', '2', 'decoder', ...]
argv = sys.argv
commandlines = [
    (cmd, list(args))
    for cmd, args in itertools.groupby(argv, key=command_key(sub.choices))
]
commandlines[0][1].pop(0)  # drop the script name from the global group

# Parse each chunk separately, so argparse only ever sees one subcommand at a time.
namespaces = dict()
for cmd, cmdline in commandlines:
    n, r = parser.parse_known_args(cmdline)
    assert len(r) == 0, f"Provided unknown args {r} for command {cmd}"
    if cmd is None:
        namespaces["global"] = n
    else:
        namespaces[cmd] = getattr(n, cmd)

args = Namespace(
    **vars(namespaces.pop("global")),
    **namespaces,
)
print(args)

The result looks correct and leverages simple_parsing's argument-name resolution, among other things, so it's quite convenient.
Some caveats and ugly parts:

  1. Each Union-like option must be in a separate subparser and needs a separate dataclass.
  2. Attributes from the Union-like classes are accessed like args.encoder.encoder, but it would be nicer to have args.encoder and get the final class straight away.
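The user-side argv splitting in the workaround can be tried on its own, without simple_parsing. This sketch factors it into a hypothetical split_by_command helper (the same itertools.groupby idea, with the tracking state kept in a closure instead of a mutable default argument) and feeds it an argv that, unlike sys.argv, has no script name in front.

```python
import itertools


def split_by_command(argv, commands):
    """Group argv into (command, args) chunks with itertools.groupby.

    Each arg is tagged with the most recent command name seen; args before the
    first command are tagged None (the global args).
    """
    current = [None]

    def key(arg):
        if arg in commands:
            current[0] = arg
        return current[0]

    return [(cmd, list(args)) for cmd, args in itertools.groupby(argv, key=key)]


argv = "--glob1.xx 7 encoder convencoder --y 2 decoder convdecoder --n 7".split()
chunks = split_by_command(argv, {"encoder", "decoder"})
for cmd, chunk in chunks:
    print(cmd, chunk)
```

Any argument that happens to equal a command name, such as an option value of "encoder", would also start a new chunk, so the workaround relies on command names not colliding with values.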

What would be better:

  1. Allow multiple Union fields in simple-parsing and handle the subparsers internally, as described above
  2. Directly assign the user-selected class to the field of the main class, so that calls like args.glob1.encoder work

The configs could look like this:

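from dataclasses import dataclass, field
from typing import Union

from simple_parsing import ArgumentParser

# A, B, One, Two stand for any config dataclasses.
@dataclass
class Global:
    letter: Union[A, B] = field(default_factory=A)
    number: Union[One, Two] = field(default_factory=One)
    some_other_arg: int = 5

parser = ArgumentParser()
parser.add_arguments(Global, dest="glob")
args = parser.parse_args()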

and we would get the same output as in the example above.
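To make the target of point 2 concrete, here is a small pure-Python sketch, independent of simple_parsing, of what filling the selected class straight into the field would produce. resolve is a hypothetical helper, and the choices dict stands in for what a parser would extract from the command line: it picks the Union member whose lowercased class name matches the given subcommand name (as with rnnencoder/convencoder above) and does no argv parsing itself.

```python
from dataclasses import dataclass, field, fields
from typing import Union, get_args


@dataclass
class RNNEncoder:
    x: int = 1


@dataclass
class ConvEncoder:
    y: int = 2


@dataclass
class RNNDecoder:
    m: int = 3


@dataclass
class ConvDecoder:
    n: int = 4


@dataclass
class Model:
    encoder: Union[RNNEncoder, ConvEncoder] = field(default_factory=RNNEncoder)
    decoder: Union[RNNDecoder, ConvDecoder] = field(default_factory=RNNDecoder)


def resolve(cls, choices):
    """Build `cls`, putting the chosen Union member straight into each selected field.

    `choices` maps a field name to (lowercased class name, {attribute: value}).
    Assumes annotations are real types, not strings (no `from __future__ import annotations`).
    Fields not mentioned in `choices` keep their defaults.
    """
    kwargs = {}
    for f in fields(cls):
        if f.name in choices:
            name, overrides = choices[f.name]
            options = {t.__name__.lower(): t for t in get_args(f.type)}
            kwargs[f.name] = options[name](**overrides)
    return cls(**kwargs)


model = resolve(
    Model,
    {"encoder": ("convencoder", {"y": 2}), "decoder": ("convdecoder", {"n": 7})},
)
print(model)
```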

Other stuff I tried:

I also tried using an Enum to store the configs for the encoder types. It lets me select the correct config, but I cannot adjust the parameters of the selected config. I also tried choice, but it did not accept a dataclass as an argument.
