Implement luz_callback_validation_check #56

Draft · wants to merge 2 commits into main

Conversation

mattwarkentin (Contributor):

Related to #5 (comment)

Hi @dfalbel,

This is a first attempt at implementing the validation check callback. It may still need some work, so I am submitting it as a draft PR for now. By design, this check only runs a few batches and computes the loss. It does not strictly follow the standard validation loop, because it does not call the validation-related callbacks.

We may also want to compute the validation metrics in this check. I will await your thoughts before making more changes.

# Skip the check when there is no validation data or no batches to run.
if (is.null(ctx$valid_data)) return()
if (self$batches <= 0) return()

ctx$model$eval()  # put the model in evaluation mode
dfalbel (Member):

We might want to extract this into a function:

luz/R/callbacks.R (lines 298 to 300 at 6e0bb77):

ctx$model$eval()
ctx$training <- FALSE
ctx$loss <- list()

And reuse it here, so we make sure the same state is always set?
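
Something along these lines, perhaps; the function name is just a placeholder, not existing luz API:

# Possible shared helper (hypothetical name) that both the regular
# validation loop and this check would call:
prepare_valid_step <- function(ctx) {
  ctx$model$eval()
  ctx$training <- FALSE
  ctx$loss <- list()
}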

# Run a forward pass on a single batch and compute the loss.
input <- list(batch[[1]])
target <- batch[[2]]
pred <- do.call(ctx$model, input)
self$loss <- ctx$model$loss(pred, target)
dfalbel (Member):

I think in general we would want to run the full validation step, because the errors could be in any of the callbacks, etc. But we would need to take care of the side effects this might cause.

We would need to call valid_one_step() and then make sure we can reset the state. I am not sure yet what the best way to do that would be.
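
One rough possibility, assuming valid_one_step() runs the full validation step (callbacks included) and that ctx$training and ctx$loss are the fields that need restoring:

# Sketch only: snapshot the mutable state, run one validation step,
# then put everything back. The exact set of fields is an assumption.
saved <- list(training = ctx$training, loss = ctx$loss)
valid_one_step()
ctx$training <- saved$training
ctx$loss <- saved$loss
ctx$model$train()  # put the model back in training mode

The hard part is knowing which pieces of state the callbacks themselves might touch.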

dfalbel (Member):

I was thinking about this problem. It seems to me that the safest approach would be, in on_fit_begin(), to call fit again with an added callback that breaks the loop after `batches` steps, for both training and validation. This way no side effects would interfere with the actual training loop, but we would still run the full loop, which would detect the other possible bugs.

I think this is possible if the first thing we do in the ctx object is save a list of all the arguments that were passed to fit, before we do any kind of manipulation (like we do for callbacks). To avoid infinite recursion, we could inspect ctx$callbacks to see whether the loop-breaking callback is present.
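
A concrete sketch of that idea; luz_callback() is real luz API, but break_fit_loop(), ctx$fit_args, and the callback class name used in the recursion guard are assumptions for illustration:

# Hypothetical callback that stops both loops after `batches` steps.
luz_callback_break_loop <- luz_callback(
  "break_loop",
  initialize = function(batches = 2) {
    self$batches <- batches
  },
  on_train_batch_end = function() {
    # stop the training loop once `batches` steps have run
    if (ctx$iter >= self$batches) break_fit_loop()  # assumed helper
  },
  on_valid_batch_end = function() {
    # likewise for the validation loop
    if (ctx$iter >= self$batches) break_fit_loop()  # assumed helper
  }
)

luz_callback_validation_check <- luz_callback(
  "validation_check",
  on_fit_begin = function() {
    # Recursion guard: if the break callback is already present, this is
    # the inner check fit, so do nothing. (Class name is an assumption.)
    if (any(vapply(ctx$callbacks, inherits, logical(1), "break_loop_callback"))) {
      return()
    }
    args <- ctx$fit_args  # assumed: arguments saved before any manipulation
    args$callbacks <- c(args$callbacks, list(luz_callback_break_loop()))
    do.call(fit, args)
  }
)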

mattwarkentin (Author):

Yeah, I've kind of gone in circles here. We want to call the validation callbacks so that it is a complete check of the validation loop, but I was worried about the changes in state this might cause. I did consider using valid_one_batch(), but at the time decided against it for the reasons above.

mattwarkentin (Author) · Jul 30, 2021:

A somewhat related question: when ctx$call_callbacks("on_..._...") is called and multiple callbacks have methods available for that breakpoint, in what order are they called? Default callbacks first, user-supplied callbacks second?

dfalbel (Member):

Yes, they are called in that order: default callbacks, then user callbacks.

I think that if we call fit again there would be no interference; the only difference is that it would also test the training loop. But we could skip that anyway...

mattwarkentin (Author):

I did actually think about calling fit() again inside on_fit_begin(), but I decided against it. You're right, though: it would be a good way to check both the training and validation loops before committing to a full fit.

dfalbel (Member) · Jul 30, 2021:

Thinking again, there could still be some side effects, e.g. the callbacks passed by the user can have side effects outside of the R session (maybe writing to a file or something like that). So maybe we want to call fit again with only the default callbacks plus the one that breaks the training loop.

This is not completely ideal, because there would still be callbacks that could fail in the 'real' pass. But it sounds like enough, I guess.
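
Roughly, then (default_callbacks() and ctx$fit_args are assumed helpers, and luz_callback_break_loop() is the hypothetical loop-breaking callback sketched above):

# Sketch: re-run fit with only the default callbacks plus the
# loop-breaking one, so user callbacks and their side effects never run.
args <- ctx$fit_args
args$callbacks <- c(default_callbacks(), list(luz_callback_break_loop()))
do.call(fit, args)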

mattwarkentin (Author):

I agree with only calling the default callbacks. My original reason for avoiding the callbacks was loggers and other things written to disk, but if we only run the default callbacks we avoid this issue. The function docs can just point out that user callbacks aren't validated.
