Commit 5bba9df: add doc
jieru-hu committed Nov 16, 2021
1 parent d95be42 commit 5bba9df
Showing 3 changed files with 283 additions and 58 deletions.
332 changes: 275 additions & 57 deletions docs/source/tutorials/pytorch_encoder.rst
Building full models
====================


Now let's build a full Transformer/xFormer model. Note that this is just one example; building the whole model from explicit parts is always an option, whether from pure PyTorch building blocks or with some xFormers primitives mixed in.

PyTorch Transformer
-------------------
There's also added flexibility with xFormers in that attention mechanisms can be chosen at will, on a per-layer basis.
.. code-block:: python

    my_config = [
        # Note that a sequence of different encoder blocks can be used, same for decoders
        {
            "reversible": False,  # Optionally make these layers reversible, to save memory
            "block_type": "encoder",
            "num_layers": 3,  # Optional, this means that this config will repeat N times
            "dim_model": EMB,
            "layer_norm_style": "pre",  # Optional, pre/post
            "position_encoding_config": {
                "name": "vocab",  # whatever position encoding makes sense
                "seq_len": 1024,
                "vocab_size": VOCAB,
            },
            "multi_head_config": {
                "num_heads": 4,
                "residual_dropout": 0,
                "attention": {
                    "name": "linformer",  # whatever attention mechanism
                    "dropout": 0,
                    "causal": False,
                    "seq_len": SEQ,
                },
            },
            "feedforward_config": {
                "name": "MLP",
                "dropout": 0,
                "activation": "relu",
                "hidden_layer_multiplier": 4,
            },
        },
        {
            "reversible": False,  # Optionally make these layers reversible, to save memory
            "block_type": "decoder",
            "num_layers": 3,  # Optional, this means that this config will repeat N times
            "dim_model": EMB,
            "layer_norm_style": "pre",  # Optional, pre/post
            "position_encoding_config": {
                "name": "vocab",  # whatever position encoding makes sense
                "seq_len": SEQ,
                "vocab_size": VOCAB,
            },
            "multi_head_config_masked": {
                "num_heads": 4,
                "residual_dropout": 0,
                "attention": {
                    "name": "nystrom",  # whatever attention mechanism
                    "dropout": 0,
                    "causal": True,
                    "seq_len": SEQ,
                },
            },
            "multi_head_config_cross": {
                "num_heads": 4,
                "residual_dropout": 0,
                "attention": {
                    "name": "favor",  # whatever attention mechanism
                    "dropout": 0,
                    "causal": True,
                    "seq_len": SEQ,
                },
            },
            "feedforward_config": {
                "name": "MLP",
                "dropout": 0,
                "activation": "relu",
                "hidden_layer_multiplier": 4,
            },
        },
    ]

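Given this list of block configs (bound to ``my_config`` above), the model factory can instantiate the whole stack. A minimal sketch, assuming ``EMB = 384``, ``SEQ = 1024`` and ``VOCAB = 64`` were defined beforehand (values consistent with the printouts further down this page):

.. code-block:: python

    # A minimal sketch, not verbatim from this tutorial: build the model from
    # the block list above with the xFormers model factory.
    import torch

    from xformers.factory.model_factory import xFormer, xFormerConfig

    config = xFormerConfig(my_config)
    model = xFormer.from_config(config)

    # Sanity check: push a dummy token batch through the encoder/decoder stacks
    x = torch.randint(0, VOCAB, (1, SEQ))
    y = model(src=x, tgt=x)
    print(y.shape)
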
Current results are as follows, on an NVIDIA V100 (PyTorch 1.9, Triton 1.1, xFormers):
| --------- | ----------------- | ------------------ | ------------------ |
| xformers | 89 | 1182 | 2709 |
| pytorch | 155 | 1950 | 4117 |
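As an aside, here is a minimal sketch of how such per-step timings can be collected with CUDA events. This is not the benchmark harness that produced the numbers above, and ``model`` and ``batch`` are assumed to be defined and resident on the GPU:

.. code-block:: python

    # A minimal timing sketch (not the actual benchmark used above):
    # measure the average forward latency in microseconds with CUDA events.
    import torch

    def time_forward_us(model, batch, warmup=10, iters=50):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        for _ in range(warmup):  # warm up kernels and the allocator
            model(batch)
        torch.cuda.synchronize()
        start.record()
        for _ in range(iters):
            model(batch)
        end.record()
        torch.cuda.synchronize()
        return start.elapsed_time(end) * 1e3 / iters  # elapsed_time() is in ms
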
Build an ``xFormer`` model with Hydra
-------------------------------------

Alternatively, you can use `Hydra <https://hydra.cc/>`_ to build an xFormer model.
We've included an example `here <https://github.com/facebookresearch/xformers/tree/main/examples/build_model>`_.
It replicates the model built above and demonstrates one way to use Hydra to minimize config duplication.



.. code-block:: yaml

    defaults:
      - /stack@xformer.stack_configs:
        - encoder_local
        - encoder_random
        - decoder_nystrom_favor
      - _self_

    xformer:
      _target_: xformers.factory.model_factory.xFormer

Building a model this way makes it possible for you to leverage many features Hydra has to offer.
For example, you can override the model architecture from the command line:

.. code-block:: bash

    $ python examples/build_model/my_model.py 'stack@xformer.stack_configs=[encoder_local]'
    Built a model with 1 stack: dict_keys(['encoder_local'])
    xFormer(
      (encoders): ModuleList(
        (0): xFormerEncoderBlock(
          (pose_encoding): VocabEmbedding(
            (dropout): Dropout(p=0, inplace=False)
            (position_embeddings): Embedding(1024, 384)
            (word_embeddings): Embedding(64, 384)
          )
          (mha): MultiHeadDispatch(
            (attention): LocalAttention(
              (attn_drop): Dropout(p=0.0, inplace=False)
            )
            (in_proj_container): InProjContainer()
            (resid_drop): Dropout(p=0, inplace=False)
            (proj): Linear(in_features=384, out_features=384, bias=True)
          )
          (feedforward): MLP(
            (mlp): Sequential(
              (0): Linear(in_features=384, out_features=1536, bias=True)
              (1): ReLU()
              (2): Dropout(p=0, inplace=False)
              (3): Linear(in_features=1536, out_features=384, bias=True)
              (4): Dropout(p=0, inplace=False)
            )
          )
          (wrap_att): Residual(
            (layer): PreNorm(
              (norm): FusedLayerNorm()
              (sublayer): MultiHeadDispatch(
                (attention): LocalAttention(
                  (attn_drop): Dropout(p=0.0, inplace=False)
                )
                (in_proj_container): InProjContainer()
                (resid_drop): Dropout(p=0, inplace=False)
                (proj): Linear(in_features=384, out_features=384, bias=True)
              )
            )
          )
          (wrap_ff): PostNorm(
            (norm): FusedLayerNorm()
            (sublayer): Residual(
              (layer): PreNorm(
                (norm): FusedLayerNorm()
                (sublayer): MLP(
                  (mlp): Sequential(
                    (0): Linear(in_features=384, out_features=1536, bias=True)
                    (1): ReLU()
                    (2): Dropout(p=0, inplace=False)
                    (3): Linear(in_features=1536, out_features=384, bias=True)
                    (4): Dropout(p=0, inplace=False)
                  )
                )
              )
            )
          )
        )
      )
      (decoders): ModuleList()
    )

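The same config can also be composed programmatically, outside of ``@hydra.main``. A short sketch using Hydra's compose API, assuming the ``conf/`` layout from the example:

.. code-block:: python

    # A short sketch (assuming the example's conf/ layout): compose the same
    # config programmatically with Hydra's compose API instead of @hydra.main.
    import hydra
    from hydra import compose, initialize

    with initialize(config_path="conf"):
        cfg = compose(
            config_name="config",
            overrides=["stack@xformer.stack_configs=[encoder_local]"],
        )
        model = hydra.utils.instantiate(cfg.xformer, _convert_="all")
        print(model)
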
You can also launch multiple runs of your application with different architectures:

.. code-block:: bash

    $ python my_model.py --multirun 'stack@xformer.stack_configs=[encoder_local], [encoder_random]'
    [HYDRA] Launching 2 jobs locally
    [HYDRA] #0 : stack@xformer.stack_configs=[encoder_local]
    Built a model with 1 stack: dict_keys(['encoder_local'])
    xFormer(
      (encoders): ModuleList(
        (0): xFormerEncoderBlock(
          (pose_encoding): VocabEmbedding(
            (dropout): Dropout(p=0, inplace=False)
            (position_embeddings): Embedding(1024, 384)
            (word_embeddings): Embedding(64, 384)
          )
          (mha): MultiHeadDispatch(
            (attention): LocalAttention(
              (attn_drop): Dropout(p=0.0, inplace=False)
            )
            (in_proj_container): InProjContainer()
            (resid_drop): Dropout(p=0, inplace=False)
            (proj): Linear(in_features=384, out_features=384, bias=True)
          )
          (feedforward): MLP(
            (mlp): Sequential(
              (0): Linear(in_features=384, out_features=1536, bias=True)
              (1): ReLU()
              (2): Dropout(p=0, inplace=False)
              (3): Linear(in_features=1536, out_features=384, bias=True)
              (4): Dropout(p=0, inplace=False)
            )
          )
          (wrap_att): Residual(
            (layer): PreNorm(
              (norm): FusedLayerNorm()
              (sublayer): MultiHeadDispatch(
                (attention): LocalAttention(
                  (attn_drop): Dropout(p=0.0, inplace=False)
                )
                (in_proj_container): InProjContainer()
                (resid_drop): Dropout(p=0, inplace=False)
                (proj): Linear(in_features=384, out_features=384, bias=True)
              )
            )
          )
          (wrap_ff): PostNorm(
            (norm): FusedLayerNorm()
            (sublayer): Residual(
              (layer): PreNorm(
                (norm): FusedLayerNorm()
                (sublayer): MLP(
                  (mlp): Sequential(
                    (0): Linear(in_features=384, out_features=1536, bias=True)
                    (1): ReLU()
                    (2): Dropout(p=0, inplace=False)
                    (3): Linear(in_features=1536, out_features=384, bias=True)
                    (4): Dropout(p=0, inplace=False)
                  )
                )
              )
            )
          )
        )
      )
      (decoders): ModuleList()
    )
    [HYDRA] #1 : stack@xformer.stack_configs=[encoder_random]
    Built a model with 1 stack: dict_keys(['encoder_random'])
    xFormer(
      (encoders): ModuleList(
        (0): xFormerEncoderBlock(
          (pose_encoding): VocabEmbedding(
            (dropout): Dropout(p=0, inplace=False)
            (position_embeddings): Embedding(1024, 384)
            (word_embeddings): Embedding(64, 384)
          )
          (mha): MultiHeadDispatch(
            (attention): RandomAttention(
              (attn_drop): Dropout(p=0.0, inplace=False)
            )
            (in_proj_container): InProjContainer()
            (resid_drop): Dropout(p=0, inplace=False)
            (proj): Linear(in_features=384, out_features=384, bias=True)
          )
          (feedforward): MLP(
            (mlp): Sequential(
              (0): Linear(in_features=384, out_features=1536, bias=True)
              (1): ReLU()
              (2): Dropout(p=0, inplace=False)
              (3): Linear(in_features=1536, out_features=384, bias=True)
              (4): Dropout(p=0, inplace=False)
            )
          )
          (wrap_att): Residual(
            (layer): PreNorm(
              (norm): FusedLayerNorm()
              (sublayer): MultiHeadDispatch(
                (attention): RandomAttention(
                  (attn_drop): Dropout(p=0.0, inplace=False)
                )
                (in_proj_container): InProjContainer()
                (resid_drop): Dropout(p=0, inplace=False)
                (proj): Linear(in_features=384, out_features=384, bias=True)
              )
            )
          )
          (wrap_ff): PostNorm(
            (norm): FusedLayerNorm()
            (sublayer): Residual(
              (layer): PreNorm(
                (norm): FusedLayerNorm()
                (sublayer): MLP(
                  (mlp): Sequential(
                    (0): Linear(in_features=384, out_features=1536, bias=True)
                    (1): ReLU()
                    (2): Dropout(p=0, inplace=False)
                    (3): Linear(in_features=1536, out_features=384, bias=True)
                    (4): Dropout(p=0, inplace=False)
                  )
                )
              )
            )
          )
        )
      )
      (decoders): ModuleList()
    )
3 changes: 3 additions & 0 deletions examples/build_model/my_model.py
@@ -7,6 +7,9 @@
@hydra.main(config_path="conf", config_name="config")
def my_app(cfg: DictConfig) -> None:
    model = hydra.utils.instantiate(cfg.xformer, _convert_="all")
    print(
        f"Built a model with {len(cfg.xformer.stack_configs)} stack: {cfg.xformer.stack_configs.keys()}"
    )
    print(model)


6 changes: 5 additions & 1 deletion xformers/factory/hydra_helper.py
@@ -1,12 +1,16 @@
# register component configs into Hydra's ConfigStore
# component config classes can be used to validate configs
import logging

from hydra.core.config_store import ConfigStore
from omegaconf.errors import ValidationError

from xformers.components.attention import ATTENTION_REGISTRY
from xformers.components.feedforward import FEEDFORWARD_REGISTRY
from xformers.components.positional_embedding import POSITION_EMBEDDING_REGISTRY

log = logging.getLogger(__name__)


def import_xformer_config_schema():
"""
@@ -24,4 +28,4 @@ def import_xformer_config_schema():
            try:
                cs.store(name=f"{kk}_schema", node=v[kk].config, group=f"xformers/{k}")
            except ValidationError as e:
                log.debug(f"Error registering {kk}_schema, error: {e}")
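
For context, a short usage sketch (mirroring the example app above, not verbatim library code): the helper is meant to be called before Hydra composes the config, so that user configs can be validated against the registered schemas:

.. code-block:: python

    # A usage sketch: register the xFormers component schemas with Hydra's
    # ConfigStore before composing the config (conf/ layout as in the example).
    import hydra
    from omegaconf import DictConfig

    from xformers.factory.hydra_helper import import_xformer_config_schema

    import_xformer_config_schema()  # populate the ConfigStore with schemas

    @hydra.main(config_path="conf", config_name="config")
    def my_app(cfg: DictConfig) -> None:
        model = hydra.utils.instantiate(cfg.xformer, _convert_="all")
        print(model)

    if __name__ == "__main__":
        my_app()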
