take 2: use Hydra to build xformer model
Showing 20 changed files with 264 additions and 112 deletions.
@@ -0,0 +1,7 @@
defaults:
  # validate schema using xFormer's config dataclass
  - /xformers/attention/favor_schema@_here_

name: favor
dropout: 0
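For context, the /xformers/attention/favor_schema@_here_ entry refers to a structured-config node that import_xformer_config_schema() (added later in this diff) registers in Hydra's ConfigStore, so the values above are validated against the attention's config dataclass. A rough sketch (not part of the commit) of what that registration amounts to for this one option, mirroring the loop in hydra_helper.py shown further down:

# Sketch only: the real registration loops over all registries, see hydra_helper.py in this diff.
from hydra.core.config_store import ConfigStore
from xformers.components.attention import ATTENTION_REGISTRY

cs = ConfigStore.instance()
cs.store(
    name="favor_schema",
    node=ATTENTION_REGISTRY["favor"].config,  # config dataclass for the "favor" attention
    group="xformers/attention",
)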
@@ -0,0 +1,6 @@
defaults:
  # validate schema using xFormer's config dataclass
  - /xformers/attention/local_schema@_here_

name: local
dropout: 0
@@ -0,0 +1,4 @@
name: nystrom
dropout: 0
causal: True
seq_len: ${seq}
@@ -0,0 +1,8 @@
defaults:
  - /xformers/attention/random_schema@_here_

name: random
dropout: 0
r: 0.01
constant_masking: True
force_sparsity: False
@@ -0,0 +1,13 @@
emb: 384
seq: 1024
vocab: 64

defaults:
  - /stack@xformer.stack_configs:
      - encoder_local
      - encoder_random
      - decoder_nystrom_favor
  - _self_

xformer:
  _target_: xformers.factory.model_factory.xFormer
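To make the composition concrete: the defaults entry above selects three stack options and places each one under the xformer.stack_configs package, which hydra.utils.instantiate later hands to the xFormer factory. A minimal, hypothetical sketch (assuming xformers and hydra-core are installed and the conf/ tree above sits next to the script) for printing the composed node without running the example app:

# Sketch only - inspect the composed config with Hydra's compose API.
from hydra import compose, initialize
from omegaconf import OmegaConf

from xformers.factory.hydra_helper import import_xformer_config_schema

import_xformer_config_schema()  # register the schema nodes referenced by the attention configs

with initialize(config_path="conf"):
    cfg = compose(config_name="config")
    # xformer.stack_configs now holds encoder_local, encoder_random and
    # decoder_nystrom_favor, each under the package name chosen by its stack config.
    print(OmegaConf.to_yaml(cfg.xformer))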
@@ -0,0 +1,27 @@
# base decoder settings that can be extended and overridden
# we leave out the attention part for other configs to override

_target_: xformers.factory.block_factory.xFormerDecoderConfig
reversible: False  # Optionally make these layers reversible to save memory
num_layers: 3  # Optional, this config will be repeated N times
block_type: decoder
dim_model: ${emb}
layer_norm_style: pre  # Optional, pre/post
position_encoding_config:
  name: vocab  # whatever position encoding makes sense
  seq_len: ${seq}
  vocab_size: ${vocab}
  dropout: 0
multi_head_config_masked:
  num_heads: 4
  residual_dropout: 0
  attention: ???
multi_head_config_cross:
  num_heads: 4
  residual_dropout: 0
  attention: ???
feedforward_config:
  name: MLP
  dropout: 0
  activation: relu
  hidden_layer_multiplier: 4
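The attention: ??? entries use OmegaConf's mandatory-missing marker: the base config deliberately leaves the attention unspecified so that each stack config has to fill it in, and resolving the value without an override fails. A tiny illustrative sketch of that behaviour (not part of the commit):

# Sketch only: '???' marks a mandatory value that must be overridden before use.
from omegaconf import OmegaConf
from omegaconf.errors import MissingMandatoryValue

cfg = OmegaConf.create({"attention": "???"})
try:
    _ = cfg.attention
except MissingMandatoryValue:
    print("attention must be provided, e.g. by a stack config override")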
@@ -0,0 +1,23 @@
# base encoder settings that can be extended and overridden
# we leave out the attention part for other configs to override

_target_: xformers.factory.block_factory.xFormerEncoderConfig
reversible: False
num_layers: 4
use_triton: True
dim_model: ${emb}
layer_norm_style: pre
position_encoding_config:
  name: vocab
  seq_len: 1024
  vocab_size: ${vocab}
  dropout: 0
multi_head_config:
  num_heads: 4
  residual_dropout: 0
  attention: ???
feedforward_config:
  name: MLP
  dropout: 0
  activation: relu
  hidden_layer_multiplier: 4
examples/build_model/conf/stack/decoder_nystrom_favor.yaml (16 additions, 0 deletions)
@@ -0,0 +1,16 @@
defaults:
  # move configs from base_decoder to the decoder_nystrom_favor package
  # the resulting config would look like
  #
  # decoder_nystrom_favor:
  #   _target_: xformers.factory.block_factory.xFormerDecoderConfig
  #   reversible: False
  #   ...
  #
  # this helps with organizing the configs at the model level
  # the package name is arbitrary but should be unique within the stack config groups
  # to avoid conflicts.
  - base_decoder@decoder_nystrom_favor
  # override the attentions :)
  - /attention@decoder_nystrom_favor.multi_head_config_masked.attention: nystrom
  - /attention@decoder_nystrom_favor.multi_head_config_cross.attention: favor
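After composition, decoder_nystrom_favor carries the full base_decoder content shown earlier, with multi_head_config_masked.attention replaced by the nystrom attention config and multi_head_config_cross.attention replaced by the favor attention config from this diff.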
@@ -0,0 +1,8 @@
defaults:
  # move configs from base_encoder to under the encoder_local package
  # this helps with merging the configs at the model level
  # the package name is arbitrary but should be unique within the stack config groups
  # to avoid conflicts.
  - base_encoder@encoder_local
  # override the attention
  - /attention@encoder_local.multi_head_config.attention: local
@@ -0,0 +1,8 @@
defaults:
  # move configs from base_encoder to under the encoder_random package
  # this helps with merging the configs at the model level
  # the package name is arbitrary but should be unique within the stack config groups
  # to avoid conflicts.
  - base_encoder@encoder_random
  # override the attention
  - /attention@encoder_random.multi_head_config.attention: random
@@ -0,0 +1,17 @@
import hydra
from omegaconf import DictConfig

from xformers.factory.hydra_helper import import_xformer_config_schema


@hydra.main(config_path="conf", config_name="config")
def my_app(cfg: DictConfig) -> None:
    model = hydra.utils.instantiate(cfg.xformer, _convert_="all")
    print(model)


if __name__ == "__main__":
    # optional - only needed when you want to use the xFormer config dataclasses
    # to validate config values.
    import_xformer_config_schema()
    my_app()
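As a usage sketch (the script name is not shown in this diff; call it my_model.py for illustration, run from examples/build_model with xformers and hydra-core installed): running python my_model.py composes the config and prints the instantiated model, python my_model.py --cfg job prints the composed config without running the app, and top-level values can be overridden on the command line in the usual Hydra way, for example python my_model.py emb=512 seq=512.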
@@ -1 +1,2 @@
hydra-core>1.1
lightning-bolts
@@ -0,0 +1,27 @@
# register component configs into Hydra's ConfigStore
# component config classes can be used to validate configs
from hydra.core.config_store import ConfigStore
from omegaconf.errors import ValidationError

from xformers.components.attention import ATTENTION_REGISTRY
from xformers.components.feedforward import FEEDFORWARD_REGISTRY
from xformers.components.positional_embedding import POSITION_EMBEDDING_REGISTRY


def import_xformer_config_schema():
    """
    Best effort - OmegaConf supports limited typing, so we may fail to import
    certain config classes. For example, PyTorch types are not supported.
    """
    cs = ConfigStore.instance()

    for k, v in {
        "ff": FEEDFORWARD_REGISTRY,
        "pe": POSITION_EMBEDDING_REGISTRY,
        "attention": ATTENTION_REGISTRY,
    }.items():
        for kk in v.keys():
            try:
                cs.store(name=f"{kk}_schema", node=v[kk].config, group=f"xformers/{k}")
            except ValidationError as e:
                print(f"Error registering {kk}_schema, error: {e}")