-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathRBF_block.R
89 lines (78 loc) · 3.35 KB
/
RBF_block.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
## Copyright 2019 Andrew Zammit Mangion
##
## Licensed under the Apache License, Version 2.0 (the "License");
## you may not use this file except in compliance with the License.
## You may obtain a copy of the License at
##
## http://www.apache.org/licenses/LICENSE-2.0
##
## Unless required by applicable law or agreed to in writing, software
## distributed under the License is distributed on an "AS IS" BASIS,
## WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
## See the License for the specific language governing permissions and
## limitations under the License.
#' @title Radial Basis Function Warpings
#' @description Sets up a composition of radial basis functions (RBFs) for use
#'   in a deep compositional spatial model. The function places the RBF centres
#'   on a regular grid over a prescribed square domain, at a given resolution.
#'   It returns a list containing all the functions in the single-resolution
#'   RBF unit. See Value for more details.
#' @param res the resolution; the RBF centres form a \code{3^res} by
#'   \code{3^res} grid, giving \code{(3^res)^2} RBFs in total
#' @param lims the limits of one side of the square 2D domain on which to set
#'   up the RBFs (a numeric vector of length 2)
#' @return \code{RBF_block} returns a list containing a list for each RBF in
#'   the block with the following components:
#' \describe{
#' \item{"f"}{An encapsulated function that takes an input and evaluates the RBF warping over that input using \code{TensorFlow}}
#' \item{"fR"}{Same as \code{f} but uses \code{R}}
#' \item{"fMC"}{Same as \code{f} but does it in parallel for several inputs indexed by the first dimension of the tensor}
#' \item{"r"}{The number of basis functions (one for each layer)}
#' \item{"trans"}{The transformation applied to the weights before estimation}
#' \item{"fix_weights"}{Flag indicating whether the weights are fixed or not (FALSE for RBFs)}
#' \item{"name"}{Name of layer}
#' }
#' @export
#' @examples
#' layer <- RBF_block(res = 1L)
RBF_block <- function(res = 1L, lims = c(-0.5, 0.5)) {
  stopifnot(is.numeric(lims), length(lims) == 2L)

  ## RBF parameters: one row per RBF, holding its centre (s1, s2) and a
  ## scale term `a` shared by every RBF at this resolution. How `a` enters
  ## the basis function is determined by RBF()/RBF_tf(), defined elsewhere.
  n_side <- 3^res
  r <- n_side^2
  cx1d <- seq(lims[1], lims[2], length.out = n_side)
  cxgrid <- as.matrix(expand.grid(s1 = cx1d, s2 = cx1d))
  a <- 2 * (n_side - 1)^2
  theta <- cbind(cxgrid, a)
  theta_tf <- tf$constant(theta, dtype = "float32")

  ## Maps an unconstrained weight onto (-1, exp(3/2)/2) using a scaled and
  ## shifted logistic sigmoid: sigmoid(x) * (1 + exp(3/2)/2) - 1.
  ## NOTE(review): presumably these bounds keep each warping injective;
  ## confirm against the model derivation.
  trans <- function(transeta) {
    tf$exp(-transeta) %>%
      tf$add(tf$constant(1, dtype = "float32")) %>%
      tf$reciprocal() %>%
      tf$multiply(tf$constant(1 + exp(3/2)/2, dtype = "float32")) %>%
      tf$add(tf$constant(-1, dtype = "float32"))
  }

  ## Builds the layer entry for the j-th RBF. `j` is forced so that each
  ## closure refers to its own row of `theta` / `theta_tf`.
  make_rbf_unit <- function(j) {
    force(j)

    ## TensorFlow warping: s + eta * phi_j(s).
    warp_tf <- function(s_tf, eta_tf) {
      PHI_tf <- RBF_tf(s_tf, theta_tf[j, , drop = FALSE])
      tf$add(tf$multiply(PHI_tf, eta_tf), s_tf)
    }

    ## Plain-R counterpart of warp_tf.
    warp_R <- function(s, eta) {
      PHI <- RBF(s, theta[j, , drop = FALSE])
      PHI * eta + s
    }

    ## fMC uses the same implementation as f; batching over a leading
    ## dimension relies on broadcasting in tf$multiply/tf$add.
    ## NOTE(review): confirm RBF_tf handles batched (MC) input.
    list(f = warp_tf,
         fMC = warp_tf,
         fR = warp_R,
         r = 1L,
         trans = trans,
         fix_weights = FALSE,
         name = "RBF")
  }

  lapply(seq_len(r), make_rbf_unit)
}