
Add NormalizedKernel #274

Merged
merged 11 commits into from
Apr 15, 2021

Conversation

rossviljoen
Contributor

This PR adds a composite kernel to normalize a kernel as follows:

k'(x, y) = k(x, y) / sqrt(k(x, x) * k(y, y))

As discussed in #259.
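For readers skimming the thread, a minimal sketch of what such a composite kernel could look like (the struct and method shown here are an illustration; the names and details in the merged PR may differ):

struct NormalizedKernel{K} <: Kernel
    kernel::K
end

# k'(x, y) = k(x, y) / sqrt(k(x, x) * k(y, y))
(κ::NormalizedKernel)(x, y) = κ.kernel(x, y) / sqrt(κ.kernel(x, x) * κ.kernel(y, y))

With this definition, k'(x, x) == 1 for every x, so the normalized kernel behaves like a correlation rather than a covariance.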

Review comments on src/kernels/normalizedkernel.jl (resolved)
@devmotion
Member

Thanks for your PR! It looks very good 👍 I left some comments; mostly, I am worried about the additional allocations. Could we try to make sure that broadcasting operations are fused as much as possible?

@rossviljoen
Contributor Author

rossviljoen commented Apr 9, 2021

Thanks for the comments! I think I fused the broadcasts as much as possible?

I also optimised the kernelmatrix_diag functions (with one input) to just return the one vector, but is there a better way to get the right return shape/type of the inner kernel than something like map(one, kernelmatrix_diag(κ.kernel, x))?

@devmotion
Member

devmotion commented Apr 9, 2021

Hm, in general it is not guaranteed that map(one, ...) returns a vector with the same element type as kernelmatrix_diag(κ, x, x), e.g., if the kernel does not return floating point numbers. So it is a bit unsatisfying, even if such a situation is probably quite rare. Maybe it would be better (and more efficient, I think) to define

function kernelmatrix_diag(κ::NormalizedKernel, x::AbstractVector)
    first_x = first(x)
    return Fill(κ(first_x, first_x), length(x))
end

Then the number type of kernelmatrix_diag(κ, x) and kernelmatrix_diag(κ, x, x) would be the same, it seems, and it would not even allocate a full vector. However, of course, this would imply that the exact types of kernelmatrix_diag(κ, x) and kernelmatrix_diag(κ, x, x) would still be different.
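A quick way to sanity-check that type claim at the REPL (hypothetical session; assumes KernelFunctions and the definition above are loaded):

julia> κ = NormalizedKernel(SqExponentialKernel());
julia> x = rand(5);
julia> eltype(kernelmatrix_diag(κ, x)) == eltype(kernelmatrix_diag(κ, x, x))
true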

Member

@devmotion devmotion left a comment


Looks good to me if tests pass. Thank you for this great PR! Can you update the version number in Project.toml to 0.9.2 so we can make a new release once the PR is merged?

@rossviljoen
Contributor Author

rossviljoen commented Apr 10, 2021

Thank you!

There is a failing test for the Zygote gradient of kernelmatrix_diag, seemingly because:

julia> k = NormalizedKernel(SqExponentialKernel())
julia> A = [1 2; 3 4]
julia> Zygote.gradient(x -> sum(kernelmatrix_diag(k, x)), A)

ERROR: MethodError: no method matching (::KernelFunctions.var"#ColVecs_pullback#175")(::Vector{Union{Nothing, Vector{Float64}}})

I'm not that familiar with Zygote, so I'm not sure if the problem is with ColVecs_pullback or elsewhere?
Incidentally, the gradient for a vector is fine:

julia> B = [1, 2]
julia> Zygote.gradient(x -> sum(kernelmatrix_diag(k, x)), B)
([0.0, 0.0],)

@theogf
Member

theogf commented Apr 10, 2021

Hey! kernelmatrix_diag expert here! I think the issue might be from using broadcast over each sample. Typically, using broadcast over ColVecs or RowVecs leads to problems with Zygote...

@theogf
Member

theogf commented Apr 10, 2021

Also, we did not have the FillArrays dependency before; is there no way to do without it?

@devmotion
Member

Does it work with fill instead of Fill?

@rossviljoen
Contributor Author

The gradient doesn't work with fill either, no.

@devmotion
Member

Also, we did not have the FillArrays dependency before; is there no way to do without it?

It seemed fine to me since it only depends on standard libraries and is used by many packages in the ecosystem (such as AbstractGPs 😉). In principle, it should be more efficient to use Fill instead of fill, I was just worried if it breaks AD but it doesn't seem to be the case.
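For context, Fill from FillArrays.jl is a lazy constant array: it stores only the value and the length, whereas Base.fill materializes a full vector (quick illustration, assuming FillArrays is installed):

using FillArrays

v = Fill(1.0, 10^6)   # lazy: stores one value and a length, no large buffer
w = fill(1.0, 10^6)   # eager: allocates a full Vector{Float64}
v == w                # the two compare equal elementwise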


function kernelmatrix_diag(κ::NormalizedKernel, x::AbstractVector)
    first_x = first(x)
    return Fill(κ(first_x, first_x), length(x))
end
Member

@theogf theogf Apr 10, 2021


Can't we replace Fill by Ones here?

Suggested change
return Fill(κ(first_x, first_x), length(x))
return Ones{typeof(κ(first_x, first_x))}(length(x))

Member

@theogf theogf Apr 10, 2021


Or could some kernel return a negative value...
In which case both suggestions are wrong :p

Contributor Author

@rossviljoen rossviljoen Apr 10, 2021


Looks like this fixes the AD at least, cheers :)

Contributor Author


I suppose k(x, x) could be zero as well?

Member


But then we're in trouble, because we would be dividing by 0, and you would probably have bigger problems generally.


function kernelmatrix(κ::NormalizedKernel, x::AbstractVector, y::AbstractVector)
    return kernelmatrix(κ.kernel, x, y) ./
        sqrt.(
Member


Isn't it generally more efficient to compute the sqrt first? This is probably a performance detail, though (and I don't know about machine accuracy).

Contributor Author


Doesn't this all get fused anyway though?
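For reference, Julia fuses nested dotted operations into a single loop, so neither version allocates intermediate matrices; the difference is only how many sqrt evaluations happen (illustrative sketch, not the PR's actual code):

K = rand(3, 4); dx = rand(3); dy = rand(4)

# Fully fused into one loop, but sqrt runs n*m times:
A = K ./ sqrt.(dx .* dy')

# Materialize the square roots first: two small allocations,
# sqrt runs only n + m times, and the final division is still fused:
sdx = sqrt.(dx)
sdy = sqrt.(dy)
B = K ./ (sdx .* sdy')

A ≈ B   # same result up to floating point rounding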

rossviljoen and others added 2 commits April 10, 2021 21:41
Co-authored-by: Théo Galy-Fajou <theo.galyfajou@gmail.com>
@devmotion
Member

Are you satisfied with the PR @theogf? To me it seems it could be merged after updating the version number again (version 0.9.2 already exists) and adding a compatibility bound for FillArrays.

@theogf
Member

theogf commented Apr 15, 2021

Sure LGTM!

@devmotion devmotion merged commit 376407d into JuliaGaussianProcesses:master Apr 15, 2021
@devmotion
Member

Thank you @rossviljoen!

@rossviljoen rossviljoen deleted the normalization branch August 16, 2021 21:20