use Hydra to build xformer model (facebookresearch#93)
jieru-hu authored Dec 1, 2021
1 parent 97c826d commit c1b0325
Showing 32 changed files with 1,113 additions and 762 deletions.
5 changes: 5 additions & 0 deletions .gitignore
@@ -47,3 +47,8 @@ my_runs.md
# examples demo files
examples/input.txt
examples/lightning_logs
examples/data

# Hydra default output dir
multirun
outputs
335 changes: 278 additions & 57 deletions docs/source/tutorials/pytorch_encoder.rst
@@ -98,7 +98,7 @@ Building full models
====================


Now let's build a full Transformer/xFormer model. Please note that this is just an example: building the whole model from explicit parts is always an option, whether from pure PyTorch building blocks or by adding some xFormers primitives.

PyTorch Transformer
-------------------
@@ -155,73 +155,69 @@ There's also an added flexibility with xFormers in that attention mechanisms can
# Note that a sequence of different encoder blocks can be used, same for decoders
{
    "reversible": False,  # Optionally make these layers reversible, to save memory
    "block_type": "encoder",
    "num_layers": 3,  # Optional, this means that this config will repeat N times
    "dim_model": EMB,
    "layer_norm_style": "pre",  # Optional, pre/post
    "position_encoding_config": {
        "name": "vocab",  # whatever position encoding makes sense
        "seq_len": 1024,
        "vocab_size": VOCAB,
    },
    "multi_head_config": {
        "num_heads": 4,
        "residual_dropout": 0,
        "attention": {
            "name": "linformer",  # whatever attention mechanism
            "dropout": 0,
            "causal": False,
            "seq_len": SEQ,
        },
    },
    "feedforward_config": {
        "name": "MLP",
        "dropout": 0,
        "activation": "relu",
        "hidden_layer_multiplier": 4,
    },
},
{
    "reversible": False,  # Optionally make these layers reversible, to save memory
    "block_type": "decoder",
    "num_layers": 3,  # Optional, this means that this config will repeat N times
    "dim_model": EMB,
    "layer_norm_style": "pre",  # Optional, pre/post
    "position_encoding_config": {
        "name": "vocab",  # whatever position encoding makes sense
        "seq_len": SEQ,
        "vocab_size": VOCAB,
    },
    "multi_head_config_masked": {
        "num_heads": 4,
        "residual_dropout": 0,
        "attention": {
            "name": "nystrom",  # whatever attention mechanism
            "dropout": 0,
            "causal": True,
            "seq_len": SEQ,
        },
    },
    "multi_head_config_cross": {
        "num_heads": 4,
        "residual_dropout": 0,
        "attention": {
            "name": "favor",  # whatever attention mechanism
            "dropout": 0,
            "causal": True,
            "seq_len": SEQ,
        },
    },
    "feedforward_config": {
        "name": "MLP",
        "dropout": 0,
        "activation": "relu",
        "hidden_layer_multiplier": 4,
    },
},
]
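Once such a list of block configs is assembled, it can be handed to the xFormers model factory. A minimal sketch, assuming ``EMB``, ``SEQ`` and ``VOCAB`` are defined as in the snippet above and that ``my_config`` is the variable holding the list:

.. code-block:: python

  from xformers.factory.model_factory import xFormer, xFormerConfig

  config = xFormerConfig(my_config)    # validate the dict-based configuration
  model = xFormer.from_config(config)  # build the encoder/decoder stacks
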
@@ -255,3 +251,228 @@ Current results are as follows, on an NVIDIA V100 (PyTorch 1.9, Triton 1.1, xForm
| --------- | ----------------- | ------------------ | ------------------ |
| xformers | 89 | 1182 | 2709 |
| pytorch | 155 | 1950 | 4117 |
Build an `xFormer` model with Hydra
-----------------------------------

Alternatively, you can use `Hydra <https://hydra.cc/>`__ to build an xFormer model.
We've included an example `here <https://github.com/facebookresearch/xformers/tree/main/examples/build_model>`__.
It replicates the model built above and demonstrates one way to use Hydra to minimize config duplication.
The example is built on top of some more advanced Hydra features. If you are new to Hydra, you can start with these docs:
`basic tutorials <https://hydra.cc/docs/tutorials/intro/>`__, `extending configs <https://hydra.cc/docs/patterns/extending_configs/>`__,
`Hydra packages <https://hydra.cc/docs/advanced/overriding_packages/>`__ and the
`instantiation API <https://hydra.cc/docs/advanced/instantiate_objects/overview/>`__.


.. code-block:: yaml

  defaults:
    - /stack@xformer.stack_configs:
      - encoder_local
      - encoder_random
      - decoder_nystrom_favor
    - _self_

  xformer:
    _target_: xformers.factory.model_factory.xFormer

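The entry point used below (``examples/build_model/my_model.py``) composes this configuration and builds the model through Hydra's instantiation API. A minimal sketch of what such a script can look like; the decorator arguments and the print formatting are assumptions rather than the exact example code:

.. code-block:: python

  import hydra
  from hydra.utils import instantiate
  from omegaconf import DictConfig


  @hydra.main(config_path="conf", config_name="config")
  def my_app(cfg: DictConfig) -> None:
      # cfg.xformer carries the _target_ plus the composed stack_configs
      model = instantiate(cfg.xformer, _convert_="all")
      print(f"Built a model with {len(cfg.xformer.stack_configs)} stack: "
            f"{dict(cfg.xformer.stack_configs).keys()}")
      print(model)


  if __name__ == "__main__":
      my_app()
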
Building a model this way makes it possible to leverage many of the features Hydra has to offer.
For example, you can override the model architecture from the command line:

.. code-block:: bash

$ python examples/build_model/my_model.py 'stack@xformer.stack_configs=[encoder_local]'
Built a model with 1 stack: dict_keys(['encoder_local'])
xFormer(
(encoders): ModuleList(
(0): xFormerEncoderBlock(
(pose_encoding): VocabEmbedding(
(dropout): Dropout(p=0, inplace=False)
(position_embeddings): Embedding(1024, 384)
(word_embeddings): Embedding(64, 384)
)
(mha): MultiHeadDispatch(
(attention): LocalAttention(
(attn_drop): Dropout(p=0.0, inplace=False)
)
(in_proj_container): InProjContainer()
(resid_drop): Dropout(p=0, inplace=False)
(proj): Linear(in_features=384, out_features=384, bias=True)
)
(feedforward): MLP(
(mlp): Sequential(
(0): Linear(in_features=384, out_features=1536, bias=True)
(1): ReLU()
(2): Dropout(p=0, inplace=False)
(3): Linear(in_features=1536, out_features=384, bias=True)
(4): Dropout(p=0, inplace=False)
)
)
(wrap_att): Residual(
(layer): PreNorm(
(norm): FusedLayerNorm()
(sublayer): MultiHeadDispatch(
(attention): LocalAttention(
(attn_drop): Dropout(p=0.0, inplace=False)
)
(in_proj_container): InProjContainer()
(resid_drop): Dropout(p=0, inplace=False)
(proj): Linear(in_features=384, out_features=384, bias=True)
)
)
)
(wrap_ff): PostNorm(
(norm): FusedLayerNorm()
(sublayer): Residual(
(layer): PreNorm(
(norm): FusedLayerNorm()
(sublayer): MLP(
(mlp): Sequential(
(0): Linear(in_features=384, out_features=1536, bias=True)
(1): ReLU()
(2): Dropout(p=0, inplace=False)
(3): Linear(in_features=1536, out_features=384, bias=True)
(4): Dropout(p=0, inplace=False)
)
)
)
)
)
)
)
(decoders): ModuleList()
)
You can also launch multiple runs of your application with different architectures:

.. code-block:: bash

$ python my_model.py --multirun 'stack@xformer.stack_configs=[encoder_local], [encoder_random]'
[HYDRA] Launching 2 jobs locally
[HYDRA] #0 : stack@xformer.stack_configs=[encoder_local]
Built a model with 1 stack: dict_keys(['encoder_local'])
xFormer(
(encoders): ModuleList(
(0): xFormerEncoderBlock(
(pose_encoding): VocabEmbedding(
(dropout): Dropout(p=0, inplace=False)
(position_embeddings): Embedding(1024, 384)
(word_embeddings): Embedding(64, 384)
)
(mha): MultiHeadDispatch(
(attention): LocalAttention(
(attn_drop): Dropout(p=0.0, inplace=False)
)
(in_proj_container): InProjContainer()
(resid_drop): Dropout(p=0, inplace=False)
(proj): Linear(in_features=384, out_features=384, bias=True)
)
(feedforward): MLP(
(mlp): Sequential(
(0): Linear(in_features=384, out_features=1536, bias=True)
(1): ReLU()
(2): Dropout(p=0, inplace=False)
(3): Linear(in_features=1536, out_features=384, bias=True)
(4): Dropout(p=0, inplace=False)
)
)
(wrap_att): Residual(
(layer): PreNorm(
(norm): FusedLayerNorm()
(sublayer): MultiHeadDispatch(
(attention): LocalAttention(
(attn_drop): Dropout(p=0.0, inplace=False)
)
(in_proj_container): InProjContainer()
(resid_drop): Dropout(p=0, inplace=False)
(proj): Linear(in_features=384, out_features=384, bias=True)
)
)
)
(wrap_ff): PostNorm(
(norm): FusedLayerNorm()
(sublayer): Residual(
(layer): PreNorm(
(norm): FusedLayerNorm()
(sublayer): MLP(
(mlp): Sequential(
(0): Linear(in_features=384, out_features=1536, bias=True)
(1): ReLU()
(2): Dropout(p=0, inplace=False)
(3): Linear(in_features=1536, out_features=384, bias=True)
(4): Dropout(p=0, inplace=False)
)
)
)
)
)
)
)
(decoders): ModuleList()
)
[HYDRA] #1 : stack@xformer.stack_configs=[encoder_random]
Built a model with 1 stack: dict_keys(['encoder_random'])
xFormer(
(encoders): ModuleList(
(0): xFormerEncoderBlock(
(pose_encoding): VocabEmbedding(
(dropout): Dropout(p=0, inplace=False)
(position_embeddings): Embedding(1024, 384)
(word_embeddings): Embedding(64, 384)
)
(mha): MultiHeadDispatch(
(attention): RandomAttention(
(attn_drop): Dropout(p=0.0, inplace=False)
)
(in_proj_container): InProjContainer()
(resid_drop): Dropout(p=0, inplace=False)
(proj): Linear(in_features=384, out_features=384, bias=True)
)
(feedforward): MLP(
(mlp): Sequential(
(0): Linear(in_features=384, out_features=1536, bias=True)
(1): ReLU()
(2): Dropout(p=0, inplace=False)
(3): Linear(in_features=1536, out_features=384, bias=True)
(4): Dropout(p=0, inplace=False)
)
)
(wrap_att): Residual(
(layer): PreNorm(
(norm): FusedLayerNorm()
(sublayer): MultiHeadDispatch(
(attention): RandomAttention(
(attn_drop): Dropout(p=0.0, inplace=False)
)
(in_proj_container): InProjContainer()
(resid_drop): Dropout(p=0, inplace=False)
(proj): Linear(in_features=384, out_features=384, bias=True)
)
)
)
(wrap_ff): PostNorm(
(norm): FusedLayerNorm()
(sublayer): Residual(
(layer): PreNorm(
(norm): FusedLayerNorm()
(sublayer): MLP(
(mlp): Sequential(
(0): Linear(in_features=384, out_features=1536, bias=True)
(1): ReLU()
(2): Dropout(p=0, inplace=False)
(3): Linear(in_features=1536, out_features=384, bias=True)
(4): Dropout(p=0, inplace=False)
)
)
)
)
)
)
)
(decoders): ModuleList()
)
7 changes: 7 additions & 0 deletions examples/build_model/conf/attention/favor.yaml
@@ -0,0 +1,7 @@
defaults:
  # validate schema using xFormer's config dataclass
  - /xformers/attention/favor_schema@_here_

name: favor
dropout: 0

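A defaults entry like ``/xformers/attention/favor_schema@_here_`` points at a structured config registered in Hydra's ConfigStore rather than at a YAML file on disk, which is what lets the YAML above be validated against xFormers' config dataclass. A minimal sketch of how such a schema can be registered; the dataclass and its fields here are assumptions for illustration, not the actual xFormers registration code:

.. code-block:: python

  from dataclasses import dataclass

  from hydra.core.config_store import ConfigStore


  @dataclass
  class FavorSchema:
      # hypothetical fields mirroring the YAML above
      name: str = "favor"
      dropout: float = 0.0


  cs = ConfigStore.instance()
  # makes the node addressable as /xformers/attention/favor_schema in a defaults list
  cs.store(group="xformers/attention", name="favor_schema", node=FavorSchema)
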
6 changes: 6 additions & 0 deletions examples/build_model/conf/attention/local.yaml
@@ -0,0 +1,6 @@
defaults:
  # validate schema using xFormer's config dataclass
  - /xformers/attention/local_schema@_here_

name: local
dropout: 0
4 changes: 4 additions & 0 deletions examples/build_model/conf/attention/nystrom.yaml
@@ -0,0 +1,4 @@
name: nystrom
dropout: 0
causal: True
seq_len: ${seq}
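The ``seq_len: ${seq}`` entry is an OmegaConf interpolation: it resolves against a value defined elsewhere in the composed config instead of hard-coding the sequence length per attention mechanism. A minimal sketch of the kind of top-level values it could resolve against; the names and numbers are illustrative, loosely matching the model printout above:

.. code-block:: yaml

  # hypothetical globals in the example's top-level config
  seq: 1024
  emb: 384
  vocab: 64
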
