Skip to content

Commit e8fe6de

Browse files
committed
x-transformer decoder args in config
1 parent 902e31f commit e8fe6de

File tree

4 files changed

+11
-1
lines changed

4 files changed

+11
-1
lines changed

models.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def get_model(args):
129129
dim=args.dim,
130130
depth=args.num_layers,
131131
heads=args.heads,
132-
cross_attend=True
132+
**args.decoder_args
133133
)),
134134
pad_value=args.pad_token
135135
).to(args.device)

settings/config.yaml

+2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ batchsize: 10
22
bos_token: 1
33
channels: 1
44
debug: false
5+
decoder_args:
6+
cross_attend: true
57
device: cuda
68
dim: 256
79
encoder_depth: 4

settings/default.yaml

+6
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,12 @@ backbone_layers:
3737
- 9
3838
encoder_depth: 4
3939
num_layers: 4
40+
decoder_args:
41+
cross_attend: true
42+
ff_glu: true
43+
attn_on_attn: false
44+
use_scalenorm: true
45+
rel_pos_bias: false
4046
heads: 8
4147
num_tokens: 8000
4248
max_seq_len: 1024

utils/utils.py

+2
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ def parse_args(args, **kwargs):
5454
args.wandb = not kwargs.debug and not args.debug
5555
args.device = 'cuda' if torch.cuda.is_available() and not kwargs.no_cuda else 'cpu'
5656
args.max_dimensions = [args.max_width, args.max_height]
57+
if 'decoder_args' not in args or args.decoder_args is None:
58+
args.decoder_args = {}
5759
if 'model_path' in args:
5860
args.out_path = os.path.join(args.model_path, args.name)
5961
os.makedirs(args.out_path, exist_ok=True)

0 commit comments

Comments (0)