Description
What I am trying to do:
I am training an encoder-decoder neural network and I have multiple different encoders and decoders.
Each encoder and decoder has a separate dataclass config.
I want to be able to specify on the command line which encoder to use by giving its config name. Once it is selected, I also want to be able to adjust the fields of the selected config. So far I was able to accomplish this with subparsers.
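I have something like:
from dataclasses import dataclass, field
from typing import Union

@dataclass
class Model:
    encoder: Union[RNNEncoder, ConvEncoder] = field(default_factory=RNNEncoder)
    decoder: RNNDecoder = field(default_factory=RNNDecoder)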
And I simply parse it just like in the subparsers example. I then specify the encoder keyword and then it allows me to specify args specific to the selected architecture.
The problem is that if I want to do the same for the decoder, subparsers are no longer available and I get:
error: cannot have multiple subparser arguments
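The limitation comes from argparse itself rather than from simple_parsing: a single argparse parser accepts at most one add_subparsers() call. A minimal standard-library sketch reproduces it; the parser and dest names are illustrative. Older Python versions report this through parser.error(), which prints the message and exits with SystemExit, and newer versions may raise an exception instead, so the sketch catches both.

```python
import argparse


def second_subparsers_fails() -> bool:
    """Return True if argparse rejects a second add_subparsers() call."""
    parser = argparse.ArgumentParser(prog="script.py")
    parser.add_subparsers(dest="encoder")  # first group of sub-commands: accepted
    try:
        parser.add_subparsers(dest="decoder")  # second group: rejected by argparse
    except (SystemExit, ValueError, argparse.ArgumentError):
        # parser.error() prints "cannot have multiple subparser arguments" and exits;
        # some Python versions may raise an exception instead.
        return True
    return False


print(second_subparsers_fails())  # True
```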
Ideal state
How can I solve my problem with simple_parsing? My ideal syntax would probably be to use the Union type in the dataclasses and call it like so:
python script.py --some_unrelated_args \
encoder rnnencoder --args_related_to_rnnencoder \
decoder convdecoder --args_related_to_convdecoder
Solutions & workarounds
There are quite a few suggestions in this SO thread.
The issue can be solved with multiple subparsers and user-side argv parsing. The following code allows me to call my script like so:
python script.py \
--glob1.xx 7 \
--glob2.xx 5 \
encoder convencoder \
--y 2 \
decoder convdecoder \
--n 7
And get
Namespace(decoder=Decoder(decoder=ConvDecoder(n=7)), encoder=Encoder(encoder=ConvEncoder(y=2)), glob1=Global1(xx=7, yy='hello'), glob2=Global2(xx=5, yy='hello'))
The code:
import sys
import itertools
from functools import partial
from dataclasses import dataclass, field
from typing import Union
from simple_parsing import ArgumentParser
from argparse import Namespace

@dataclass
class RNNEncoder:
    x: int = 1

@dataclass
class ConvEncoder:
    y: int = 2

@dataclass
class RNNDecoder:
    m: int = 3

@dataclass
class ConvDecoder:
    n: int = 4

# One wrapper dataclass per Union field, each added to its own subparser.
# default_factory is needed because Python 3.11+ rejects dataclass instances as defaults.
@dataclass
class Encoder:
    encoder: Union[RNNEncoder, ConvEncoder] = field(default_factory=RNNEncoder)

@dataclass
class Decoder:
    decoder: Union[RNNDecoder, ConvDecoder] = field(default_factory=RNNDecoder)

@dataclass
class Global1:
    xx: int = 5
    yy: str = "hello"

@dataclass
class Global2:
    xx: int = 5
    yy: str = "hello"

parser = ArgumentParser()
parser.add_arguments(Global1, dest="glob1")
parser.add_arguments(Global2, dest="glob2")
sub = parser.add_subparsers()
encoder = sub.add_parser("encoder")
encoder.add_arguments(Encoder, dest="encoder")
decoder = sub.add_parser("decoder")
decoder.add_arguments(Decoder, dest="decoder")

def groupargs(arg, commands, currentarg=[None]):
    # The mutable default deliberately remembers the most recent command name
    # between calls, so every token is keyed by the command it follows.
    if arg in commands.keys():
        currentarg[0] = arg
    return currentarg[0]

# For a quick test, use e.g. "script.py encoder convencoder --y 7 decoder convdecoder --n 6".split()
argv = sys.argv
commandlines = [(cmd, list(args)) for cmd, args in itertools.groupby(argv, partial(groupargs, commands=sub.choices))]
commandlines[0][1].pop(0)  # drop the program name from the global chunk
namespaces = dict()
for cmd, cmdline in commandlines:
    # Parse each chunk separately, so every call uses at most one subparser.
    n, r = parser.parse_known_args(cmdline)
    assert len(r) == 0, f"Provided unknown args {r} for command {cmd}"
    if cmd is None:
        namespaces["global"] = n
    else:
        namespaces[cmd] = getattr(n, cmd)
args = Namespace(
    **vars(namespaces.pop("global")),
    **namespaces,
)
print(args)
The result looks correct and leverages simple_parsing's argument name resolution etc., so it's quite convenient.
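The argv splitting done by groupargs and itertools.groupby can also be written as a plain loop, without relying on a mutable default argument to carry state between calls. A minimal sketch; split_by_commands is a hypothetical helper (not part of simple_parsing), and it assumes the program name has already been dropped, i.e. it is given sys.argv[1:]:

```python
from typing import Iterable, List, Optional, Tuple


def split_by_commands(
    argv: List[str], commands: Iterable[str]
) -> List[Tuple[Optional[str], List[str]]]:
    """Split argv into (command, chunk) pairs.

    Tokens before the first command name go into a chunk keyed by None.
    Every other chunk keeps its command name as the first token, so it can be
    passed straight to a parser whose sub-commands include that name.
    """
    command_set = set(commands)
    chunks: List[Tuple[Optional[str], List[str]]] = [(None, [])]
    for token in argv:
        if token in command_set:
            chunks.append((token, [token]))  # a command name starts a new chunk
        else:
            chunks[-1][1].append(token)  # otherwise extend the current chunk
    return chunks


argv = "--glob1.xx 7 encoder convencoder --y 2 decoder convdecoder --n 7".split()
for cmd, chunk in split_by_commands(argv, {"encoder", "decoder"}):
    print(cmd, chunk)
# None ['--glob1.xx', '7']
# encoder ['encoder', 'convencoder', '--y', '2']
# decoder ['decoder', 'convdecoder', '--n', '7']
```

Like the original, an option value that happens to equal a command name would be mistaken for the start of a new command.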
Some caveats and ugliness:
- Each Union-like option must be in a separate subparser and needs a separate dataclass
- Attributes from the Union-like classes are accessed like args.encoder.encoder, but it would be nicer to have args.encoder and get the final class straight away
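Until something better exists, the second caveat can be softened with a small post-processing step that replaces each one-field wrapper with the object it holds. A minimal sketch; unwrap_single_field is a hypothetical helper, and it assumes the wrappers are dataclasses with exactly one field, like Encoder and Decoder in the workaround:

```python
from argparse import Namespace
from dataclasses import dataclass, field, fields, is_dataclass
from typing import Iterable


def unwrap_single_field(args: Namespace, names: Iterable[str]) -> Namespace:
    """Replace each args.<name> that holds a one-field dataclass with that field's value."""
    for name in names:
        wrapper = getattr(args, name, None)
        # Skip commands that were not given and anything that is not a dataclass instance.
        if wrapper is None or isinstance(wrapper, type) or not is_dataclass(wrapper):
            continue
        wrapper_fields = fields(wrapper)
        if len(wrapper_fields) == 1:
            setattr(args, name, getattr(wrapper, wrapper_fields[0].name))
    return args


@dataclass
class ConvEncoder:
    y: int = 2


@dataclass
class Encoder:
    encoder: ConvEncoder = field(default_factory=ConvEncoder)


args = Namespace(encoder=Encoder(encoder=ConvEncoder(y=2)))
print(unwrap_single_field(args, ["encoder"]))  # Namespace(encoder=ConvEncoder(y=2))
```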
What would be better:
- Allow multiple Union fields in simple-parsing and handle the subparsers internally as described above
- Directly fill the user-selected class into the main class to allow args.glob1.encoder calls
The configs could look like this:
@dataclass
class Global:
    letter: Union[A, B] = field(default_factory=A)
    number: Union[One, Two] = field(default_factory=One)
    some_other_arg: int = 5

parser = ArgumentParser()
parser.add_arguments(Global, dest="glob")
args = parser.parse_args()
and we would get the same output as in the example.
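For comparison, the proposed behaviour can be approximated with the standard library alone by giving every Union field its own parser, so each one may own a set of sub-commands. This is a minimal sketch, not simple_parsing's API. parse_with_unions and _add_fields are hypothetical names, A/B/One/Two get made-up fields, sub-command names are assumed to be the lower-cased class names, and only scalar fields with plain defaults are handled:

```python
import argparse
import dataclasses
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Union, get_args, get_type_hints


@dataclass
class A:
    a: int = 1

@dataclass
class B:
    b: str = "x"

@dataclass
class One:
    n: int = 1

@dataclass
class Two:
    m: float = 2.0

@dataclass
class Global:
    letter: Union[A, B] = field(default_factory=A)
    number: Union[One, Two] = field(default_factory=One)
    some_other_arg: int = 5


def _add_fields(parser: argparse.ArgumentParser, cls, skip=()) -> None:
    """Add one --option per dataclass field (assumes scalar types, plain defaults)."""
    hints = get_type_hints(cls)
    for f in dataclasses.fields(cls):
        if f.name not in skip:
            parser.add_argument(f"--{f.name}", type=hints[f.name], default=f.default)


def parse_with_unions(cls, argv: List[str]):
    hints = get_type_hints(cls)
    # Fields annotated as a Union of dataclasses, e.g. "letter" -> (A, B).
    unions = {
        f.name: get_args(hints[f.name])
        for f in dataclasses.fields(cls)
        if get_args(hints[f.name])
        and all(dataclasses.is_dataclass(t) for t in get_args(hints[f.name]))
    }
    # Split argv into chunks, each starting after a union field name.
    chunks: Dict[Optional[str], List[str]] = {None: []}
    current: Optional[str] = None
    for token in argv:
        if token in unions:
            current = token
            chunks[current] = []
        else:
            chunks[current].append(token)

    # Plain fields come from the tokens before the first union field name.
    top = argparse.ArgumentParser(prog="script.py")
    _add_fields(top, cls, skip=unions)
    values = vars(top.parse_args(chunks[None]))

    # A separate parser per union field, so several sets of sub-commands can coexist.
    for name, members in unions.items():
        if name not in chunks:
            continue  # field not given on the command line: keep its default
        parser = argparse.ArgumentParser(prog=name)
        subs = parser.add_subparsers(dest="_choice", required=True)
        by_command = {}
        for member in members:
            command = member.__name__.lower()
            _add_fields(subs.add_parser(command), member)
            by_command[command] = member
        parsed = vars(parser.parse_args(chunks[name]))
        values[name] = by_command[parsed.pop("_choice")](**parsed)
    return cls(**values)


argv = "--some_other_arg 3 letter b --b y number two --m 0.5".split()
print(parse_with_unions(Global, argv))
# Global(letter=B(b='y'), number=Two(m=0.5), some_other_arg=3)
```

The result is the user-selected class stored directly in the main dataclass, which is the access pattern asked for above.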
Other stuff I tried:
I also tried to use an Enum to store the configs for the encoder types. It allows me to select the correct config, but I cannot adjust the selected config's parameters. I also tried to use choice, but it did not accept a dataclass as an argument.
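One way to make the Enum approach work is to parse in two stages: first read only the selector, then build a second parser that exposes the fields of the selected dataclass. A minimal stdlib-only sketch; EncoderType, CONFIGS and parse_encoder are hypothetical names, an --encoder flag replaces the positional syntax used above, and only scalar fields with plain defaults are handled:

```python
import argparse
import dataclasses
from dataclasses import dataclass
from enum import Enum
from typing import List, get_type_hints


@dataclass
class RNNEncoder:
    x: int = 1


@dataclass
class ConvEncoder:
    y: int = 2


class EncoderType(Enum):
    rnn = "rnn"
    conv = "conv"


# Hypothetical mapping from each Enum member to the config class it selects.
CONFIGS = {EncoderType.rnn: RNNEncoder, EncoderType.conv: ConvEncoder}


def parse_encoder(argv: List[str]):
    # Stage 1: read only the selector; every other token is left for stage 2.
    selector = argparse.ArgumentParser(add_help=False)
    selector.add_argument("--encoder", type=EncoderType, default=EncoderType.rnn)
    known, _ = selector.parse_known_args(argv)
    config_cls = CONFIGS[known.encoder]

    # Stage 2: inherit --encoder and add the fields of the selected class only.
    parser = argparse.ArgumentParser(parents=[selector])
    hints = get_type_hints(config_cls)
    for f in dataclasses.fields(config_cls):
        parser.add_argument(f"--{f.name}", type=hints[f.name], default=f.default)
    args = parser.parse_args(argv)
    return config_cls(**{f.name: getattr(args, f.name) for f in dataclasses.fields(config_cls)})


print(parse_encoder(["--encoder", "conv", "--y", "7"]))  # ConvEncoder(y=7)
print(parse_encoder([]))  # RNNEncoder(x=1)
```

Because the selector parser is created with add_help=False, --help is handled in stage 2 and lists only the options of the selected config; options that belong to a different config are rejected as unrecognized.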