-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathslpDIVA.R
317 lines (262 loc) · 11.6 KB
/
slpDIVA.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
# # # .diva.backprop
#' Backpropagate reconstruction error and update both weight layers
#'
#' @param out_wts Matrix of hidden-to-output weights for the channel
#'   being trained
#' @param in_wts Matrix of input-to-hidden weights
#' @param out_activation Output unit activations for that channel
#' @param current_target Target feature values for reconstruction
#'   error calculation
#' @param hid_activation Hidden unit activations, bias column included
#' @param hid_activation_raw Raw (pre-sigmoid) hidden unit activation
#' @param ins_w_bias Inputs with bias unit added
#' @param learning_rate Step size for the gradient-descent updates
#' @return List of updated in weights and out weights
.diva.backprop <- function(out_wts, in_wts, out_activation, current_target,
                           hid_activation, hid_activation_raw, ins_w_bias,
                           learning_rate){
  # error signal at the output units (derivative of the squared error)
  delta_out <- 2 * (out_activation - current_target)
  # push the error back through the output weights, drop the bias
  # column, and apply the sigmoid derivative at the hidden layer
  delta_hid <- delta_out %*% t(out_wts)
  delta_hid <- delta_hid[, 2:ncol(delta_hid)] *
    .diva.sigmoid_grad(hid_activation_raw)
  # outer-product gradients, scaled by the learning rate
  step_out <- learning_rate * (t(hid_activation) %*% delta_out)
  step_in <- learning_rate * (t(ins_w_bias) %*% delta_hid)
  # descend the gradient on both layers
  out_wts <- out_wts - step_out
  in_wts <- in_wts - step_in
  list(out_wts = out_wts,
       in_wts = in_wts)
}
# # # .diva.forward_pass
#' Conducts forward pass
#'
#' @param in_wts Matrix of weights from input to hidden layer
#' @param out_wts Array of weights from hidden layer to output
#'   channels
#' @param inputs Matrix of input features
#' @param continuous Boolean value to indicate if inputs are
#'   continuous
#' @return List of output unit activation, hidden unit activation, raw
#'   hidden unit activation and inputs with bias
#'
.diva.forward_pass <- function(in_wts, out_wts, inputs, continuous) {
  # sizes are implied by the output weight array
  n_feats <- ncol(out_wts)
  n_cats <- dim(out_wts)[3]
  n_stims <- nrow(inputs)
  # a single stimulus arrives as a vector, so nrow() is NULL
  if (is.null(n_stims)) {
    n_stims <- 1
  }
  # prepend a column of 1s so the bias weights take part in the product
  bias_col <- matrix(1, nrow = n_stims, ncol = 1)
  ins_w_bias <- cbind(bias_col,
                      matrix(inputs, nrow = n_stims, ncol = n_feats,
                             byrow = TRUE))
  # input -> hidden propagation through the sigmoid
  hid_activation_raw <- ins_w_bias %*% in_wts
  hid_activation <- .diva.sigmoid(hid_activation_raw)
  # the hidden layer also feeds a bias unit forward
  hid_activation <- cbind(bias_col, hid_activation)
  # hidden -> output propagation, one slice per category channel
  out_activation <- array(0, dim = c(n_stims, n_feats, n_cats))
  for (category in seq_len(n_cats)) {
    out_activation[, , category] <- hid_activation %*% out_wts[, , category]
  }
  # binomial features pass through the sigmoid; continuous stay linear
  if (!continuous) {
    out_activation <- .diva.sigmoid(out_activation)
  }
  list(out_activation = out_activation,
       hid_activation = hid_activation,
       hid_activation_raw = hid_activation_raw,
       ins_w_bias = ins_w_bias)
}
# # # .diva.get_wts
#' Generate input and output weights for initialization of DIVA
#'
#' @param num_feats Scalar value for the number of features in the
#'   input
#' @param num_hids Scalar value for the number of hidden units in the
#'   model architecture
#' @param num_cats Scalar value for the number of categories
#' @param wts_range Scalar value for the range of the generated
#'   weights
#' @param wts_center Scalar value for the center of the weights
#' @return List with input (input to hidden) weights and output
#'   weights (hidden to output channels)
.diva.get_wts <- function(num_feats, num_hids, num_cats, wts_range,
                          wts_center) {
  # each layer feeds one extra bias unit into the next
  bias <- 1
  # draw uniform weights, map (0, 1) -> (-1, 1), then scale and recenter
  raw_in <- matrix(runif((num_feats + bias) * num_hids), ncol = num_hids)
  in_wts <- wts_center + (wts_range * ((raw_in - 0.5) * 2))
  raw_out <- array(runif((num_hids + bias) * num_feats * num_cats),
                   dim = c((num_hids + bias), num_feats, num_cats))
  out_wts <- wts_center + (wts_range * ((raw_out - 0.5) * 2))
  list(in_wts = in_wts,
       out_wts = out_wts)
}
# # # .diva.global_scale
#' Scale model targets to 0 : 1 values appropriate for sigmoid output
#' unit activation
#'
#' @param inputs Matrix of inputs in format -1 : 1 that need to be
#'   scaled
#' @return Matrix of inputs scaled to 0 : 1
.diva.global_scale <- function(inputs) {
  # halve the span, then shift the midpoint up to 0.5
  0.5 + inputs / 2
}
# .diva.response_rule
#' convert output activations to classification
#'
#' @param out_activation Array of output channel activations
#'   (num_stims x num_feats x num_cats)
#' @param target_activation Array of output unit targets
#' @param beta_val Scalar value for the beta parameter (set in st);
#'   beta_val > 0 switches feature focusing on
#' @param phi Scalar value for response mapping constant (set in st)
#' @return List of classification probability, focusing weights and
#'   sum squared error
.diva.response_rule <- function(out_activation, target_activation, beta_val, phi){
  num_feats <- ncol(out_activation)
  num_cats <- dim(out_activation)[3]
  num_stims <- nrow(target_activation)
  # a single stimulus arrives as a vector, so nrow() is NULL
  if (is.null(num_stims)) {num_stims <- 1}
  # # # compute error: squared deviation of every channel's
  # # # reconstruction from the target, floored at 1e-7 so the
  # # # inverse-error step below cannot divide by zero
  ssqerror <- array(as.vector(
    apply(out_activation, 3, function(x) {x - target_activation})),
    c(num_stims, num_feats, num_cats))
  ssqerror <- ssqerror ^ 2
  ssqerror[ssqerror < 1e-7] <- 1e-7
  # # # if focusing is on:
  if (beta_val > 0) {
    # # # get list of channel comparisons
    pairwise_comps <- combn(1:num_cats, 2)
    # # # get differences for each feature between categories
    diff_matrix <-
      abs(apply(pairwise_comps, 2, function(x) {
        out_activation[,,x[1]] - out_activation[,,x[2]]}))
    # # # reconstruct activation array and get feature diversity means
    # NOTE(review): diff_matrix holds choose(num_cats, 2) pairwise
    # slices, but the array is dimensioned for num_cats slices, so R
    # silently recycles (num_cats = 2) or truncates (num_cats > 3) the
    # data -- confirm this matches the intended focusing math
    diff_array <- array(diff_matrix, dim = c(num_stims, num_feats, num_cats))
    feature_diffs <- apply(diff_array, 2, mean)
    # # # calculate diversities (exponentiated, capped against overflow)
    diversities <- exp(beta_val * feature_diffs)
    diversities[diversities > 1e+7] <- 1e+7
    # # # divide diversities by sum of diversities
    fweights <- diversities / sum(diversities)
    # # # or set fweights equal
  } else {
    fweights <- rep(1/num_feats, num_feats)
  }
  # # # apply focus weights; then get sum for each category
  ssqerror <- t(apply(ssqerror, 3, function(x) sum(x * fweights)))
  # # # calculate inverse sse (lower error -> stronger evidence)
  ssqerror <- 1 / ssqerror
  # # # subtract the max error for numerical stability
  ssqerror <- ssqerror - max(ssqerror)
  # # # apply response mapping and exponentiate
  ssqerror <- exp(phi * ssqerror)
  # # # normalize so response probabilities sum to 1
  ps <- ssqerror / sum(ssqerror)
  return(list(ps = ps,
              fweights = fweights,
              ssqerror = ssqerror))
}
# .diva.sigmoid
#'
#' Returns sigmoid evaluated element-wise in X
#'
#' @param x Matrix of values to be evaluated with sigmoid function
#' @return Same format of input, evaluated with the sigmoid function
.diva.sigmoid <- function(x) {
  # logistic function 1 / (1 + e^-x); vectorized, so matrices and
  # arrays keep their dimensions
  1 / (1 + exp(-x))
}
# sigmoid gradient
#'
#' Returns gradient of the sigmoid function evaluated at x
#'
#' @param x Values to be evaluated for the sigmoid gradient
#' @return Gradient of the sigmoid function for the input
.diva.sigmoid_grad <- function(x) {
  # evaluate the sigmoid once instead of twice; g'(x) = g(x) * (1 - g(x))
  g <- .diva.sigmoid(x)
  g * (1 - g)
}
# slpDIVA
#'
#' Train stateful list processor DIVA
#'
#' @param st A list of the model parameters: colskip, num_feats,
#'   num_cats, num_hids, wts_range, wts_center, continuous, beta_val,
#'   phi, learning_rate, plus the current in_wts / out_wts
#' @param tr A matrix of the input and class labels, one trial per
#'   row, with a 'ctrl' column: 1 = regenerate weights before the
#'   trial; any value < 2 = learn on the trial; values >= 2 = weights
#'   frozen (test-only trial)
#' @param xtdo A boolean value indicating if extended output is
#'   desired
#' @return List including a matrix of model classification
#'   probabilities and list of model's final state; when xtdo is TRUE
#'   the state is replaced by the weight history
slpDIVA <- function(st, tr, xtdo = FALSE) {
  # # # construct weight matrix history list
  # 'final' collects weights as they are retired, 'initial' the fresh
  # draws that replace them
  wts_history <- list(initial = list(), final = list())
  # # # convert targets to 0/1 for binomial input data ONLY
  # continuous targets are reconstructed on their original scale
  targets <- tr[,(st$colskip + 1):(st$colskip + st$num_feats), drop = FALSE]
  if (st$continuous == FALSE) targets <- .diva.global_scale(targets)
  # # # init size parameter variables
  # one row of category response probabilities per trial
  out <- matrix(rep(NA, st$num_cats * dim(tr)[1]),
                ncol = st$num_cats, nrow = dim(tr)[1])
  # # # iterate over each trial in the tr matrix
  for (trial_num in 1:dim(tr)[1]) {
    current_input <- tr[trial_num, (st$colskip + 1):(st$colskip +
                                                       st$num_feats)]
    current_target <- targets[trial_num,]
    # # # determine current class from MLP-style outs
    # start/end columns of the one-hot class units that follow the
    # feature columns
    out_unit_loc <- c((st$colskip + st$num_feats + 1),
                      (st$colskip + st$num_feats + st$num_cats))
    out_units <- tr[trial_num, out_unit_loc[1]:out_unit_loc[2]]
    current_class <- which.max(out_units)
    # # # if ctrl is set to 1 generate new weights
    if (tr[trial_num, 'ctrl'] == 1) {
      # # # save existing weights
      # NOTE(review): on the very first trial st$in_wts / st$out_wts
      # may still be NULL if the caller never initialized them --
      # confirm callers expect a NULL entry in wts_history$final then
      wts_history$final[[length(wts_history$final) + 1]] <-
        list(in_wts = st$in_wts, out_wts = st$out_wts)
      # # # generate new weights
      wts <- .diva.get_wts(st$num_feats, st$num_hids,
                           st$num_cats, st$wts_range,
                           st$wts_center)
      st$in_wts <- wts$in_wts
      st$out_wts <- wts$out_wts
      # # # save new weights
      wts_history$initial[[length(wts_history$initial) + 1]] <-
        list(in_wts = st$in_wts, out_wts = st$out_wts)
    }
    # # # complete forward pass
    fp <- .diva.forward_pass(st$in_wts, st$out_wts, current_input,
                             st$continuous)
    # # # calculate classification probability
    response <- .diva.response_rule(fp$out_activation, current_target,
                                    st$beta_val, st$phi)
    # # # store classification accuracy
    out[trial_num,] = response$ps
    # # # adjust weights based on ctrl setting
    # ctrl >= 2 freezes the weights; only the correct category's
    # channel is trained
    if (tr[trial_num, 'ctrl'] < 2) {
      # # # back propagate error to adjust weights
      class_wts <- st$out_wts[,,current_class]
      class_activation <- fp$out_activation[,,current_class]
      adjusted_wts <- .diva.backprop(class_wts, st$in_wts,
                                     class_activation, current_target,
                                     fp$hid_activation,
                                     fp$hid_activation_raw,
                                     fp$ins_w_bias, st$learning_rate)
      # # # set new weights
      st$out_wts[,,current_class] <- adjusted_wts$out_wts
      st$in_wts <- adjusted_wts$in_wts
    }
  }
  # # # save extended output
  if (xtdo == TRUE) {
    xtd_output <- list()
    xtd_output$wts_history <- wts_history
    return(list(out = out, xtd_output = xtd_output))
  }
  return(list(out = out, st = st))
}