feat/pipeop-transformer-layer #388
Merged
Changes from all commits
Commits
94 commits
50bc8e9
double quotes
cxzhang4 8aa4470
style
cxzhang4 4b5fafe
copied in old attic code to test file, still need to try
cxzhang4 258ea42
idrk
cxzhang4 a8a8787
Merge branch 'main' into feat/pipeop-transformer-layer
cxzhang4 b81e9dd
changed d_token in test
cxzhang4 e7288f8
small cleanup
cxzhang4 da059e2
cleanup
cxzhang4 866a7ec
factored out d_token
cxzhang4 e6e67f6
idk
cxzhang4 9d37216
test passes for now
cxzhang4 31ef199
intermrediate docs
cxzhang4 ebcb3c0
TODO: implement custom checks for parameters that are nn_modules or n…
cxzhang4 946dba0
more TODOs
cxzhang4 00c91eb
docs
cxzhang4 da21ff2
change title of nn_ft_transformer_layer module
cxzhang4 7d65f09
removed is_first_layer param
cxzhang4 3872ce0
some comments
cxzhang4 ce4809b
a comment
cxzhang4 41724f0
added back is_first_layer param
cxzhang4 1c2ee1e
added back comment on prenormalization condition
cxzhang4 6b4c34a
comment on parameters
cxzhang4 24645c4
Merge branch 'main' into feat/pipeop-transformer-layer
sebffischer 8b671ee
some changes
sebffischer f812cb5
some notes
sebffischer 14c98c5
some more changes
sebffischer 8d5f641
...
sebffischer c479d9c
factored out last_layer_query_idx from layer
cxzhang4 0c020cf
query_idx should be -1L (last dim) for last transformer layer
cxzhang4 0f17330
deleted file with old name (Layer, not Block)
cxzhang4 916648d
formatting
cxzhang4 40d5e44
check_nn_module_generator
cxzhang4 44f003b
Merge branch 'main' into feat/pipeop-transformer-layer
cxzhang4 726e8af
got rid of some defaults
cxzhang4 5b12f99
idk
cxzhang4 c69af36
reduce blocks in learner when there are multipleg
cxzhang4 8f80d18
delete TODO
cxzhang4 d0ec4fc
fix test
cxzhang4 f6a6326
some comments
cxzhang4 a11b217
small changes
0416ead
Merge branch 'main' into feat/pipeop-transformer-layer
3e609bd
added custom error messages
cxzhang4 852d69c
x_residual
cxzhang4 ec4e8ab
set block dependent default vals
cxzhang4 d5c3cc4
some comments
cxzhang4 e73100a
Merge branch 'main' into feat/pipeop-transformer-layer
cxzhang4 92ced45
intermediate
cxzhang4 7a23834
Merge branch 'main' into feat/pipeop-transformer-layer
cxzhang4 31571ff
remove first_prenormalization
cxzhang4 b5172a6
address TODOs
cxzhang4 17d25cb
looks ok 2 me, still has compression
cxzhang4 d50e1e0
removed kv compression
cxzhang4 c9ad1e9
update docs for learner
cxzhang4 5bf9ada
added block-dependent defaults, removed required tags from learner pa…
cxzhang4 7707922
added activation-dependent ffn_d_hidden
cxzhang4 b62f95c
a comment
cxzhang4 abb2094
man
cxzhang4 173b72c
defaults look ok
cxzhang4 85b1d71
Merge branch 'main' into feat/pipeop-transformer-layer
cxzhang4 5ecef42
removed comment
cxzhang4 b977f56
removed test for mixed input:
cxzhang4 552c6d1
removed browser())
cxzhang4 b98ab8f
removed print statemetn from test:
cxzhang4 ac110dd
add test for only categorical input
cxzhang4 358ef1c
simple script for benchmarking ft transformer
cxzhang4 0ac5b92
init -> default, shouldn't have duplicated anyway
cxzhang4 c2449c9
added query_idx = NULL in the constructor in the module generator
cxzhang4 064ede4
allow vectors for cardinalities param
cxzhang4 d27858d
TODO: debug switch to torch attention implmntn
cxzhang4 d1a87b4
hopefully refactored to use torch multihead_attention
cxzhang4 0c80578
removed benchmark file from attic
cxzhang4 12cb3cb
stuff
cxzhang4 711acbb
remove buggy heuristic for setting params based on n_blocks
cxzhang4 169033e
cleanup
cxzhang4 5b17cfa
Merge branch 'main' into feat/pipeop-transformer-layer
sebffischer f8eb599
docs
cxzhang4 5c4bcdc
docs
cxzhang4 716bee4
docs
cxzhang4 f98391c
docs
cxzhang4 3ed1d11
rename step -> step_valid in callback
sebffischer d10c4c9
fix(tb-callback): log train loss every epoch (#405)
sebffischer 6adf000
fix pkgdown: mlr_pipeops() call in pipeop, remove wrong parameter as …
cxzhang4 193b85c
docs
cxzhang4 970cfab
Merge branch 'main' into feat/pipeop-transformer-layer
cxzhang4 f7d7db6
tried checking out tests/testthat folder from main
cxzhang4 9267761
notes: put ffn_d_hidden_multiplier param back
cxzhang4 b5d7876
some logic adding d_hidden_multiplier param back
cxzhang4 42faf80
docs
cxzhang4 4e8b11d
stuff
cxzhang4 d0e804f
Docs
cxzhang4 ec01aa7
man
cxzhang4 1033d28
Apply suggestions from code review
sebffischer 0856ff2
some last fixes
sebffischer fb60e38
...
sebffischer
@@ -0,0 +1,198 @@
#' @title FT-Transformer
#' @templateVar name ft_transformer
#' @templateVar task_types classif, regr
#' @templateVar param_vals n_blocks = 2, d_token = 32, ffn_d_hidden_multiplier = 4/3
#' @template params_learner
#' @template learner
#' @template learner_example
#'
#' @description
#' Feature-Tokenizer Transformer for tabular data that can work either on [`lazy_tensor`] inputs
#' or on standard tabular features.
#'
#' Some differences from the paper implementation: no attention compression and no option to use prenormalization in the first layer.
#'
#' If training is unstable, consider a combination of standardizing features (e.g. using `po("scale")`), using an adaptive optimizer (e.g. Adam), reducing the learning rate,
#' and using a learning rate scheduler (see [`CallbackSetLRScheduler`] for options).
#'
#' @section Parameters:
#' Parameters from [`LearnerTorch`] and [`PipeOpTorchFTTransformerBlock`], as well as:
#' * `n_blocks` :: `integer(1)`\cr
#'   The number of transformer blocks.
#' * `d_token` :: `integer(1)`\cr
#'   The dimension of the embedding.
#' * `cardinalities` :: `integer()`\cr
#'   The number of categories for each categorical feature. This only needs to be specified when working with [`lazy_tensor`] inputs.
#' * `init_token` :: `character(1)`\cr
#'   The initialization method for the embedding weights. Either "uniform" or "normal". Defaults to "uniform".
#' * `ingress_tokens` :: named `list()` or `NULL`\cr
#'   A list of `TorchIngressToken`s. Only required when using lazy tensor features.
#'   The names must be "num.input" and/or "categ.input", and the values are lazy tensor ingress tokens constructed by, e.g., `ingress_ltnsr(<num_feat_name>)`.
#'
#' @references
#' `r format_bib("gorishniy2021revisiting")`
#' @export
LearnerTorchFTTransformer = R6Class("LearnerTorchFTTransformer",
  inherit = LearnerTorch,
  public = list(
    #' @description
    #' Creates a new instance of this [R6][R6::R6Class] class.
    initialize = function(task_type, optimizer = NULL, loss = NULL, callbacks = list()) {
      private$.block = PipeOpTorchFTTransformerBlock$new()

      check_ingress_tokens = crate(function(ingress_tokens, task) {
        if (is.null(ingress_tokens)) {
          return(TRUE)
        }
        msg = check_list(ingress_tokens, types = "TorchIngressToken", min.len = 1L, names = "unique")
        if (!isTRUE(msg)) {
          return(msg)
        }
        check_permutation(names(ingress_tokens), c("num.input", "categ.input"))
      })

      private$.param_set_base = ps(
        n_blocks = p_int(lower = 0L, default = 3L, tags = "train"),
        d_token = p_int(lower = 1L, default = 192L, tags = "train"),
        cardinalities = p_uty(custom_check = function(input) check_integerish(input, null.ok = TRUE), tags = "train"),
        init_token = p_fct(init = "uniform", levels = c("uniform", "normal"), tags = "train"),
        ingress_tokens = p_uty(tags = "train", custom_check = check_ingress_tokens)
      )
      param_set = alist(private$.block$param_set, private$.param_set_base)

      super$initialize(
        task_type = task_type,
        id = paste0(task_type, ".ft_transformer"),
        label = "FT-Transformer",
        param_set = param_set,
        optimizer = optimizer,
        callbacks = callbacks,
        loss = loss,
        man = "mlr3torch::mlr_learners.ft_transformer",
        feature_types = c("numeric", "integer", "logical", "factor", "ordered", "lazy_tensor"),
        # Because the CLS token does resizing that depends dynamically on the input shape,
        # specifically, the batch size
        jittable = FALSE
      )
    }
  ),
  private = list(
    .block = NULL,
    .ingress_tokens = function(task, param_vals) {
      if ("lazy_tensor" %in% task$feature_types$type) {
        if (!all(task$feature_types$type == "lazy_tensor")) {
          stopf("Learner '%s' received an input task '%s' that is mixing lazy_tensors with other feature types.", self$id, task$id) # nolint
        }
        if (task$n_features > 2L) {
          stopf("Learner '%s' received an input task '%s' that has more than two lazy tensors.", self$id, task$id) # nolint
        }
        if (is.null(param_vals$ingress_tokens)) {
          stopf("Learner '%s' received an input task '%s' with lazy tensors, but no parameter 'ingress_tokens' was specified.", self$id, task$id) # nolint
        }

        ingress_tokens = param_vals$ingress_tokens
        row = task$head(1L)
        for (i in seq_along(ingress_tokens)) {
          feat = ingress_tokens[[i]]$features(task)
          if (!length(feat) == 1L) {
            stopf("Learner '%s' received an input task '%s' with lazy tensors, but the ingress token '%s' does not select exactly one feature.", self$id, task$id, names(ingress_tokens)[[i]]) # nolint
          }
          if (is.null(ingress_tokens[[i]]$shape)) {
            ingress_tokens[[i]]$shape = lazy_shape(row[[feat]])
          }
          if (is.null(ingress_tokens[[i]]$shape)) {
            stopf("Learner '%s' received an input task '%s' with lazy tensors, but neither the ingress token for '%s', nor the 'lazy_tensor' specify the shape, which makes it impossible to build the network.", self$id, task$id, feat) # nolint
          }
        }
        return(ingress_tokens)
      }
      num_features = n_num_features(task)
      categ_features = n_categ_features(task)
      output = list()
      if (num_features > 0L) {
        output$num.input = ingress_num(shape = c(NA, num_features))
      }
      if (categ_features > 0L) {
        output$categ.input = ingress_categ(shape = c(NA, categ_features))
      }
      output
    },
    .network = function(task, param_vals) {
      its = private$.ingress_tokens(task, param_vals)
      mds = list()

      path_num = if (!is.null(its$num.input)) {
        mds$tokenizer_num.input = ModelDescriptor(
          po("nop", id = "num"),
          its["num.input"],
          task$clone(deep = TRUE)$select(its[["num.input"]]$features(task)),
          pointer = c("num", "output"),
          pointer_shape = its[["num.input"]]$shape
        )
        nn("tokenizer_num",
          d_token = param_vals$d_token,
          bias = TRUE,
          initialization = param_vals$init_token
        )
      }
      path_categ = if (!is.null(its$categ.input)) {
        mds$tokenizer_categ.input = ModelDescriptor(
          po("nop", id = "categ"),
          its["categ.input"],
          task$clone(deep = TRUE)$select(its[["categ.input"]]$features(task)),
          pointer = c("categ", "output"),
          pointer_shape = its[["categ.input"]]$shape
        )
        nn("tokenizer_categ",
          d_token = param_vals$d_token,
          bias = TRUE,
          initialization = param_vals$init_token,
          param_vals = discard(param_vals["cardinalities"], is.null)
        )
      }

      input_paths = discard(list(path_num, path_categ), is.null)

      graph_tokenizer = if (length(input_paths) == 1L) {
        input_paths[[1L]]
      } else {
        gunion(input_paths) %>>%
          nn("merge_cat", param_vals = list(dim = 2))
      }

      blocks = map(seq_len(param_vals$n_blocks), function(i) {
        block = private$.block$clone(deep = TRUE)
        block$id = sprintf("block_%i", i)

        if (i == 1) {
          block$param_set$values$is_first_layer = TRUE
        } else {
          block$param_set$values$is_first_layer = FALSE
        }
        if (i == param_vals$n_blocks) {
          block$param_set$values$query_idx = -1L
        } else {
          block$param_set$values$query_idx = NULL
        }
        block
      })

      if (length(blocks) > 1L) {
        blocks = Reduce(`%>>%`, blocks)
      }

      graph = graph_tokenizer %>>%
        nn("ft_cls", initialization = "uniform") %>>%
        blocks %>>%
        nn("fn", fn = function(x) x[, -1]) %>>%
        nn("layer_norm", dims = 1) %>>%
        nn("relu") %>>%
        nn("head")

      model_descriptor_to_module(graph$train(mds, FALSE)[[1L]])
    }
  )
)

register_learner("regr.ft_transformer", LearnerTorchFTTransformer)
register_learner("classif.ft_transformer", LearnerTorchFTTransformer)