Commit
take 2: use Hydra to build xformer model
jieru-hu committed Nov 11, 2021
1 parent c92931b commit 019f21f
Showing 20 changed files with 264 additions and 112 deletions.
5 changes: 5 additions & 0 deletions .gitignore
@@ -47,3 +47,8 @@ my_runs.md
# examples demo files
examples/input.txt
examples/lightning_logs
examples/data

# Hydra default output dir
multirun
outputs
7 changes: 7 additions & 0 deletions examples/build_model/conf/attention/favor.yaml
@@ -0,0 +1,7 @@
defaults:
# validate schema using xFormer's config dataclass
- /xformers/attention/favor_schema@_here_

name: favor
dropout: 0

6 changes: 6 additions & 0 deletions examples/build_model/conf/attention/local.yaml
@@ -0,0 +1,6 @@
defaults:
# validate schema using xFormer's config dataclass
- /xformers/attention/local_schema@_here_

name: local
dropout: 0
4 changes: 4 additions & 0 deletions examples/build_model/conf/attention/nystrom.yaml
@@ -0,0 +1,4 @@
name: nystrom
dropout: 0
causal: True
seq_len: ${seq}
8 changes: 8 additions & 0 deletions examples/build_model/conf/attention/random.yaml
@@ -0,0 +1,8 @@
defaults:
- /xformers/attention/random_schema@_here_

name: random
dropout: 0
r: 0.01
constant_masking: True
force_sparsity: False
13 changes: 13 additions & 0 deletions examples/build_model/conf/config.yaml
@@ -0,0 +1,13 @@
emb: 384
seq: 1024
vocab: 64

defaults:
  - /stack@xformer.stack_configs:
    - encoder_local
    - encoder_random
    - decoder_nystrom_favor
  - _self_

xformer:
_target_: xformers.factory.model_factory.xFormer
27 changes: 27 additions & 0 deletions examples/build_model/conf/stack/base_decoder.yaml
@@ -0,0 +1,27 @@
# base decoder settings that can be extended and overridden
# we leave out the attention part for other configs to override

_target_: xformers.factory.block_factory.xFormerDecoderConfig
reversible: False # Optionally make these layers reversible to save memory
num_layers: 3 # Optional; this block config will be repeated N times
block_type: decoder
dim_model: ${emb}
layer_norm_style: pre # Optional pre/post
position_encoding_config:
  name: vocab # whatever position encoding makes sense
  seq_len: ${seq}
  vocab_size: ${vocab}
  dropout: 0
multi_head_config_masked:
  num_heads: 4
  residual_dropout: 0
  attention: ???
multi_head_config_cross:
  num_heads: 4
  residual_dropout: 0
  attention: ???
feedforward_config:
  name: MLP
  dropout: 0
  activation: relu
  hidden_layer_multiplier: 4
23 changes: 23 additions & 0 deletions examples/build_model/conf/stack/base_encoder.yaml
@@ -0,0 +1,23 @@
# base encoder settings that can be extended and overridden
# we leave out the attention part for other configs to override

_target_: xformers.factory.block_factory.xFormerEncoderConfig
reversible: False
num_layers: 4
use_triton: True
dim_model: ${emb}
layer_norm_style: pre
position_encoding_config:
  name: vocab
  seq_len: 1024
  vocab_size: ${vocab}
  dropout: 0
multi_head_config:
  num_heads: 4
  residual_dropout: 0
  attention: ???
feedforward_config:
  name: MLP
  dropout: 0
  activation: relu
  hidden_layer_multiplier: 4
16 changes: 16 additions & 0 deletions examples/build_model/conf/stack/decoder_nystrom_favor.yaml
@@ -0,0 +1,16 @@
defaults:
# move configs from base_decoder to under the decoder_nystrom_favor package
# the resulting config would look like:
#
# decoder_nystrom_favor:
#   _target_: xformers.factory.block_factory.xFormerDecoderConfig
#   reversible: False
#   ...
#
# this helps with organizing the configs at a model level
# the package name is arbitrary but should be unique within the stack config groups
# to avoid conflicts.
- base_decoder@decoder_nystrom_favor
# override the attentions :)
- /attention@decoder_nystrom_favor.multi_head_config_masked.attention: nystrom
- /attention@decoder_nystrom_favor.multi_head_config_cross.attention: favor
8 changes: 8 additions & 0 deletions examples/build_model/conf/stack/encoder_local.yaml
@@ -0,0 +1,8 @@
defaults:
# move configs from base_encoder to under the encoder_local package
# this helps with merging the configs at a model level
# the package name is arbitrary but should be unique within the stack config groups
# to avoid conflicts.
- base_encoder@encoder_local
# override the attention
- /attention@encoder_local.multi_head_config.attention: local
8 changes: 8 additions & 0 deletions examples/build_model/conf/stack/encoder_random.yaml
@@ -0,0 +1,8 @@
defaults:
# move configs from base_encoder to under the encoder_random package
# this helps with merging the configs at a model level
# the package name is arbitrary but should be unique within the stack config groups
# to avoid conflicts.
- base_encoder@encoder_random
# override the attention
- /attention@encoder_random.multi_head_config.attention: random
17 changes: 17 additions & 0 deletions examples/build_model/my_model.py
@@ -0,0 +1,17 @@
import hydra
from omegaconf import DictConfig

from xformers.factory.hydra_helper import import_xformer_config_schema


@hydra.main(config_path="conf", config_name="config")
def my_app(cfg: DictConfig) -> None:
    model = hydra.utils.instantiate(cfg.xformer, _convert_="all")
    print(model)


if __name__ == "__main__":
    # optional - only needed when you want to use xformer config dataclass
    # to validate config values.
    import_xformer_config_schema()
    my_app()
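
The @hydra.main entry point above composes conf/config.yaml and instantiates cfg.xformer through the factory. As a side note (not part of this commit), the same composition can be done programmatically with Hydra's compose API; a minimal sketch, assuming it is run from examples/build_model/ so that the relative "conf" path resolves:

from hydra import compose, initialize
from hydra.utils import instantiate

from xformers.factory.hydra_helper import import_xformer_config_schema

import_xformer_config_schema()  # register the *_schema nodes referenced by the yaml defaults

with initialize(config_path="conf"):
    # overrides use the usual Hydra syntax; emb, seq and vocab are plain top-level values
    cfg = compose(config_name="config", overrides=["emb=512"])
    model = instantiate(cfg.xformer, _convert_="all")
    print(model)

From the command line the equivalent would be along the lines of "python my_model.py emb=512".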
51 changes: 25 additions & 26 deletions examples/microGPT.py
@@ -49,34 +49,33 @@ def __init__(
         # A list of the encoder or decoder blocks which constitute the Transformer.
         xformer_config = [
             {
-                "block_config": {
-                    "block_type": "encoder",
-                    "num_layers": self.hparams.n_layer,
-                    "dim_model": self.hparams.n_embd,
-                    "layer_norm_style": "pre",
-                    "position_encoding_config": {
-                        "name": "vocab",
-                        "seq_len": self.hparams.block_size,
-                        "vocab_size": self.hparams.vocab_size,
-                    },
-                    "multi_head_config": {
-                        "num_heads": self.hparams.n_head,
-                        "residual_dropout": self.hparams.resid_pdrop,
-                        "use_rotary_embeddings": True,
-                        "attention": {
-                            "name": self.hparams.attention,
-                            "dropout": self.hparams.attn_pdrop,
-                            "causal": True,
-                            "seq_len": self.hparams.block_size,
-                        },
-                    },
-                    "feedforward_config": {
-                        "name": "MLP",
-                        "dropout": self.hparams.mlp_pdrop,
-                        "activation": "gelu",
-                        "hidden_layer_multiplier": self.hparams.hidden_layer_multiplier,
-                    },
-                }
+                "block_type": "encoder",
+                "num_layers": self.hparams.n_layer,
+                "dim_model": self.hparams.n_embd,
+                "layer_norm_style": "pre",
+                "position_encoding_config": {
+                    "name": "vocab",
+                    "seq_len": self.hparams.block_size,
+                    "vocab_size": self.hparams.vocab_size,
+                },
+                "multi_head_config": {
+                    "num_heads": self.hparams.n_head,
+                    "residual_dropout": self.hparams.resid_pdrop,
+                    "use_rotary_embeddings": True,
+                    "attention": {
+                        "name": self.hparams.attention,
+                        "dropout": self.hparams.attn_pdrop,
+                        "causal": True,
+                        "seq_len": self.hparams.block_size,
+                    },
+                },
+                "feedforward_config": {
+                    "name": "MLP",
+                    "dropout": self.hparams.mlp_pdrop,
+                    "activation": "gelu",
+                    "hidden_layer_multiplier": self.hparams.hidden_layer_multiplier,
+                },
             }
         ]

43 changes: 21 additions & 22 deletions examples/microViT.py
@@ -62,30 +62,29 @@ def __init__(
         # A list of the encoder or decoder blocks which constitute the Transformer.
         xformer_config = [
             {
-                "block_config": {
-                    "block_type": "encoder",
-                    "num_layers": n_layer,
-                    "dim_model": dim,
-                    "seq_len": num_patches,
-                    "layer_norm_style": "pre",
-                    "multi_head_config": {
-                        "num_heads": n_head,
-                        "residual_dropout": resid_pdrop,
-                        "use_rotary_embeddings": True,
-                        "attention": {
-                            "name": attention,
-                            "dropout": attn_pdrop,
-                            "causal": False,
-                        },
-                    },
-                    "feedforward_config": {
-                        "name": "MLP",
-                        "dropout": mlp_pdrop,
-                        "activation": "gelu",
-                        "hidden_layer_multiplier": hidden_layer_multiplier,
-                    },
-                }
+                "block_type": "encoder",
+                "num_layers": n_layer,
+                "dim_model": dim,
+                "seq_len": num_patches,
+                "layer_norm_style": "pre",
+                "multi_head_config": {
+                    "num_heads": n_head,
+                    "residual_dropout": resid_pdrop,
+                    "use_rotary_embeddings": True,
+                    "attention": {
+                        "name": attention,
+                        "dropout": attn_pdrop,
+                        "causal": False,
+                    },
+                },
+                "feedforward_config": {
+                    "name": "MLP",
+                    "dropout": mlp_pdrop,
+                    "activation": "gelu",
+                    "hidden_layer_multiplier": hidden_layer_multiplier,
+                },
             }
 
         ]

config = xFormerConfig(xformer_config)
1 change: 1 addition & 0 deletions examples/requirements.txt
@@ -1 +1,2 @@
+hydra-core>1.1
 lightning-bolts
5 changes: 4 additions & 1 deletion xformers/components/__init__.py
@@ -48,7 +48,10 @@ def build_multi_head_attention(
"num_heads"
]

if "dim_features" not in multi_head_config["attention"]:
if (
"dim_features" not in multi_head_config["attention"]
or multi_head_config["attention"]["dim_features"] is None
):
multi_head_config["attention"]["dim_features"] = (
multi_head_config["dim_model"] // multi_head_config["num_heads"]
)
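
The added "is None" branch presumably matters once attention configs come through the structured schemas registered in hydra_helper.py below: optional dataclass fields such as dim_features then arrive as explicit None values instead of being absent. A small self-contained sketch of the resulting default (the numbers simply reuse emb: 384 and num_heads: 4 from the example configs):

# sketch: dim_features arrives as an explicit None, so it gets derived
multi_head_config = {
    "dim_model": 384,
    "num_heads": 4,
    "attention": {"name": "favor", "dropout": 0, "dim_features": None},
}

if (
    "dim_features" not in multi_head_config["attention"]
    or multi_head_config["attention"]["dim_features"] is None
):
    multi_head_config["attention"]["dim_features"] = (
        multi_head_config["dim_model"] // multi_head_config["num_heads"]
    )

assert multi_head_config["attention"]["dim_features"] == 96  # 384 // 4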
6 changes: 3 additions & 3 deletions xformers/components/attention/local.py
@@ -26,9 +26,9 @@

@dataclass
class LocalAttentionConfig(AttentionConfig):
-    causal: Optional[bool]
-    window_size: Optional[int]
-    force_sparsity: Optional[bool]
+    causal: Optional[bool] = None
+    window_size: Optional[int] = None
+    force_sparsity: Optional[bool] = None


@register_attention("local", LocalAttentionConfig)
8 changes: 6 additions & 2 deletions xformers/factory/block_factory.py
@@ -100,6 +100,8 @@ class xFormerBlockConfig:
     layer_norm_style: LayerNormStyle
     layer_position: LayerPosition
     use_triton: bool
+    reversible: bool
+    num_layers: int
 
     def __init__(
         self,
@@ -108,10 +110,14 @@ def __init__(
         position_encoding_config: Optional[Dict[str, Any]],
         block_type: BlockType,
         layer_norm_style: LayerNormStyle = LayerNormStyle("post"),
+        reversible: bool = False,
+        num_layers: int = 1,
     ):
         self.dim_model = dim_model
         self.block_type = block_type
         self.layer_norm_style = layer_norm_style
+        self.reversible = reversible
+        self.num_layers = num_layers
 
         # Fill in possible gaps in the config for subparts of the block
         self.feedforward_config = generate_matching_config(
@@ -256,7 +262,6 @@ def __init__(self, config: xFormerEncoderConfig, **kwargs):
         # Wrappers handle the different layer norm styles (pre- and post-) and the residual path
         self.wrap_att = ln_factory(self.mha)
         self.wrap_ff: Union[Residual, PostNorm] = ln_factory(self.feedforward)
-
         if (
             config.layer_norm_style == LayerNormStyle.Pre
             and config.layer_position.is_last()
@@ -334,7 +339,6 @@ def __init__(self, config: xFormerDecoderConfig, **kwargs):
         self.wrap_att = ln_factory(self.mha)
         self.wrap_cross = ln_factory(self.cross_mha)
         self.wrap_ff: Union[Residual, PostNorm] = ln_factory(self.feedforward)
-
         if (
             config.layer_norm_style == LayerNormStyle.Pre
             and config.layer_position.is_last()
27 changes: 27 additions & 0 deletions xformers/factory/hydra_helper.py
@@ -0,0 +1,27 @@
# register component configs into the Hydra ConfigStore
# the component config dataclasses can be used to validate configs
from hydra.core.config_store import ConfigStore
from omegaconf.errors import ValidationError

from xformers.components.attention import ATTENTION_REGISTRY
from xformers.components.feedforward import FEEDFORWARD_REGISTRY
from xformers.components.positional_embedding import POSITION_EMBEDDING_REGISTRY


def import_xformer_config_schema():
    """
    Best effort - OmegaConf supports limited typing, so we may fail to import
    certain config classes. For example, PyTorch types are not supported.
    """
    cs = ConfigStore.instance()

    for k, v in {
        "ff": FEEDFORWARD_REGISTRY,
        "pe": POSITION_EMBEDDING_REGISTRY,
        "attention": ATTENTION_REGISTRY,
    }.items():
        for kk in v.keys():
            try:
                cs.store(name=f"{kk}_schema", node=v[kk].config, group=f"xformers/{k}")
            except ValidationError as e:
                print(f"Error registering {kk}_schema, error: {e}")