Add NormalizedKernel #274
Conversation
Thanks for your PR! It looks very good 👍 I left some comments, mostly I am worried about the additional allocations. Could we try to make sure that broadcasting operations are fused as much as possible?
Thanks for the comments! I think I fused the broadcasts as much as possible? I also optimised the …
Hm, in general it is not guaranteed that …

```julia
function kernelmatrix_diag(κ::NormalizedKernel, x::AbstractVector)
    first_x = first(x)
    return Fill(κ(first_x, first_x), length(x))
end
```

Then the number type of …
Looks good to me if tests pass. Thank you for this great PR! Can you update the version number in Project.toml to 0.9.2 so we can make a new release once the PR is merged?
Thank you! There is a failing test for the Zygote gradient of …
I'm not that familiar with Zygote, so I'm not sure if the problem is with …
Hey! kernelmatrix_diag expert here! I think the issue might be from using broadcast over each sample. Typically, using broadcast over ColVecs or RowVecs leads to problems with Zygote...
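As a hypothetical illustration of the pattern in question (not the PR's actual code): mapping or broadcasting the kernel over individual samples of a ColVecs wrapper can hit fragile Zygote adjoints, whereas the vectorised kernelmatrix_diag path has dedicated AD rules.

```julia
using KernelFunctions

k = SqExponentialKernel()
X = ColVecs(rand(3, 5))  # 5 samples, each a column of length 3

# Broadcasting/mapping over each sample: the pattern that can trip up Zygote.
diag_per_sample = map(x -> k(x, x), X)

# Vectorised alternative with existing AD rules.
diag_vectorised = kernelmatrix_diag(k, X)

diag_per_sample ≈ diag_vectorised  # same values, different code paths for AD
```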
Also, we did not have the FillArrays dependency before; is there no way to do without it?
Does it work with …?
The gradient doesn't work with …
It seemed fine to me since it only depends on standard libraries and is used by many packages in the ecosystem (such as AbstractGPs 😉). In principle, it should be more efficient to use …
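Presumably the efficiency point is that Fill is a lazy constant array; a quick illustration of the difference (assuming FillArrays is installed):

```julia
using FillArrays

v = Fill(1.0, 10^6)  # lazy: stores only the value and the length
w = fill(1.0, 10^6)  # dense Vector{Float64}: allocates a million elements

sum(v) == sum(w)     # same result; methods can specialise on v's constant structure
```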
src/kernels/normalizedkernel.jl (outdated diff excerpt):

```julia
function kernelmatrix_diag(κ::NormalizedKernel, x::AbstractVector)
    first_x = first(x)
    return Fill(κ(first_x, first_x), length(x))
end
```
Can't we replace Fill by Ones here? Suggested change:

```diff
- return Fill(κ(first_x, first_x), length(x))
+ return Ones{typeof(κ(first_x, first_x))}(length(x))
```
Or can some kernel return a negative value...
In which case both propositions are wrong :p
Looks like this fixes the AD at least, cheers :)
I suppose k(x, x) could be zero as well?
But then we're in trouble, because we would be dividing by 0, and you would probably have bigger problems generally.
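For reference, the reason either container works on the diagonal: whenever k(x, x) > 0 the normalised value is exactly one, a positive semi-definite kernel cannot give k(x, x) < 0, and k(x, x) == 0 breaks the normalisation no matter how the result is stored. A quick sanity check:

```julia
# For any k(x, x) > 0 the normalised diagonal entry is exactly one:
#   k(x, x) / sqrt(k(x, x) * k(x, x)) == 1
k_xx = 2.5
k_xx / sqrt(k_xx * k_xx)  # 1.0
```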
```julia
function kernelmatrix(κ::NormalizedKernel, x::AbstractVector, y::AbstractVector)
    return kernelmatrix(κ.kernel, x, y) ./
           sqrt.(
```

(diff excerpt truncated)
Isn't it generally more efficient to compute the sqrt first? This is probably a performance detail though (and I don't know about machine accuracy).
Doesn't this all get fused anyway though?
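A rough sketch of the trade-off with stand-in data (not the PR's code): the fully fused form evaluates sqrt once per matrix entry, whereas materialising the square roots first needs only one sqrt per sample, at the cost of two small temporaries.

```julia
kx = rand(100) .+ 1   # stand-in for kernelmatrix_diag(κ.kernel, x)
ky = rand(50) .+ 1    # stand-in for kernelmatrix_diag(κ.kernel, y)
K  = rand(100, 50)    # stand-in for kernelmatrix(κ.kernel, x, y)

# Fully fused: a single broadcast, but sqrt is evaluated 100 * 50 times.
K1 = K ./ sqrt.(kx .* ky')

# sqrt first: 100 + 50 sqrt evaluations, plus two small temporaries.
skx = sqrt.(kx)
sky = sqrt.(ky)
K2 = K ./ (skx .* sky')

K1 ≈ K2  # equal up to floating-point rounding
```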
Co-authored-by: Théo Galy-Fajou <theo.galyfajou@gmail.com>
Are you satisfied with the PR @theogf? To me it seems it could be merged after updating the version number again (version 0.9.2 already exists) and adding a compatibility bound for FillArrays.
Sure LGTM!
Thank you @rossviljoen!
This PR adds a composite kernel to normalize a kernel as follows:
k'(x, y) = k(x, y) / sqrt(k(x, x) * k(y, y))
As discussed in #259.
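A minimal standalone sketch of the idea (hypothetical names; the PR's actual NormalizedKernel plugs into the KernelFunctions.jl Kernel hierarchy and adds the optimised kernelmatrix / kernelmatrix_diag methods discussed above):

```julia
# Hypothetical illustration only; MyNormalizedKernel is not the PR's type.
struct MyNormalizedKernel{K}
    kernel::K
end

(κ::MyNormalizedKernel)(x, y) = κ.kernel(x, y) / sqrt(κ.kernel(x, x) * κ.kernel(y, y))

base = (x, y) -> exp(-abs2(x - y)) + 1.0   # an unnormalised positive-definite kernel
k = MyNormalizedKernel(base)
k(0.3, 0.3)   # == 1.0: the diagonal is one by construction
k(0.3, 1.7)   # ≤ 1 by Cauchy-Schwarz
```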