# Extreme Learning Machine
A very simple implementation of an extreme learning machine (ELM) for regression. An ELM can be seen as a quick way to estimate a 'good enough' neural net: the hidden layer weights are drawn at random and left fixed, and only the output weights are estimated, so it can perform nearly as well with far less computational overhead. See <span class="pack" style = "">elmNN</span> and <span class="pack" style = "">ELMR</span> for some R package implementations. I also add a comparison to generalized additive models, since both ELMs/neural networks and GAMs are adaptive basis function models.
- http://www.extreme-learning-machines.org
- G.B. Huang, Q.Y. Zhu and C.K. Siew, Extreme Learning Machine: Theory and Applications.
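
In matrix terms, the estimation carried out by the function further below can be sketched as follows. The notation is my own shorthand rather than taken from the references: $X_s$ is the standardized feature matrix with $p$ columns, $k$ is the number of hidden nodes, $W_0$ holds the random input weights, $g$ is the activation function, and $H^{+}$ is the Moore-Penrose pseudoinverse of $H$.

$$
\begin{aligned}
H &= g\left([\mathbf{1}, X_s]\, W_0\right), \qquad W_0 \in \mathbb{R}^{(p+1) \times k}, \quad W_{0,ij} \sim N(0, 1) \\
\hat{B} &= H^{+} y \\
\hat{y} &= H \hat{B}
\end{aligned}
$$

Only $\hat{B}$ is estimated. $W_0$ is drawn once and never updated, which is where the computational savings relative to a fully trained network come from.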
## Data Setup
One variable, complex function.
```{r elm-setup1}
library(tidyverse)
library(mgcv)
set.seed(123)
n = 5000
x = runif(n)
# x = rnorm(n)
mu = sin(2*(4*x-2)) + 2* exp(-(16^2) * ((x-.5)^2))
y = rnorm(n, mu, .3)
d = data.frame(x, y)
qplot(x, y, color = I('#ff55001A'))
```
Motorcycle accident data.
```{r elm-setup2}
data('mcycle', package = 'MASS')
times = matrix(mcycle$times, ncol = 1)
accel = mcycle$accel
```
## Function
```{r elm-func}
elm <- function(X, y, n_hidden = NULL, active_fun = tanh) {
  # X: an N observations x p features matrix
  # y: the target
  # n_hidden: the number of hidden nodes (required)
  # active_fun: activation function
  if (is.null(n_hidden)) stop('n_hidden must be specified')

  pp1 = ncol(X) + 1
  w0 = matrix(rnorm(pp1*n_hidden), pp1, n_hidden)    # random input weights (fixed, not estimated)
  h = active_fun(cbind(1, scale(X)) %*% w0)          # compute hidden layer
  B = MASS::ginv(h) %*% y                            # output weights via pseudoinverse
  fit = h %*% B                                      # fitted values

  list(
    fit = fit,
    loss = crossprod(y - fit),
    B = B,
    w0 = w0
  )
}
```
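The list returned by `elm()` holds the weights needed for prediction, but not the activation function or the centering and scaling applied to the training features. As a minimal sketch, assuming we are free to store those extras, the hypothetical helpers `elm_fit_sketch()` and `elm_predict_sketch()` below (names of my choosing, not part of the original code) keep them alongside the weights so that new observations pass through the same transformation.

```{r elm-predict-sketch}
# Hypothetical helpers for illustration only; not part of the original code.
elm_fit_sketch <- function(X, y, n_hidden = 10, active_fun = tanh) {
  X_s = scale(X)                                       # standardize features
  pp1 = ncol(X) + 1
  w0  = matrix(rnorm(pp1 * n_hidden), pp1, n_hidden)   # random input weights
  h   = active_fun(cbind(1, X_s) %*% w0)               # hidden layer
  B   = MASS::ginv(h) %*% y                            # output weights

  list(
    fit        = h %*% B,
    B          = B,
    w0         = w0,
    active_fun = active_fun,
    center     = attr(X_s, 'scaled:center'),           # training means
    scale      = attr(X_s, 'scaled:scale')             # training sds
  )
}

elm_predict_sketch <- function(object, X_new) {
  # apply the training centering/scaling, then the stored weights
  X_s = scale(X_new, center = object$center, scale = object$scale)
  h   = object$active_fun(cbind(1, X_s) %*% object$w0)
  drop(h %*% object$B)
}

# simulated demo data (an assumption for this sketch)
set.seed(42)
X_demo = matrix(runif(200), ncol = 2)
y_demo = sin(2 * pi * X_demo[, 1]) + X_demo[, 2] + rnorm(100, sd = .1)

fit_demo  = elm_fit_sketch(X_demo, y_demo, n_hidden = 20)
pred_demo = elm_predict_sketch(fit_demo, X_demo)

# predicting on the training features should reproduce the fitted values
all.equal(pred_demo, fit_demo$fit[, 1])
```

Storing the activation function with the fit also removes the need to remember which one was used when predicting later.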
## Estimation
```{r elm-est1}
X_mat = matrix(x, ncol = 1)
fit_elm = elm(X_mat, y, n_hidden = 100)
str(fit_elm)
ggplot(aes(x, y), data = d) +
geom_point(color = '#ff55001A') +
geom_line(aes(y = fit_elm$fit), color = '#00aaff')
cor(fit_elm$fit[,1], y)^2
```
```{r elm-est2}
fit_elm_mcycle = elm(times, accel, n_hidden = 100)
cor(fit_elm_mcycle$fit[,1], accel)^2
```
## Comparison
We'll compare to a generalized additive model that uses a Gaussian process basis for the smooth term (`bs = 'gp'`).
```{r elm-compare1}
fit_gam = gam(y ~ s(x, bs = 'gp', k = 20), data = d)
summary(fit_gam)$r.sq
d %>%
mutate(fit_elm = fit_elm$fit,
fit_gam = fitted(fit_gam)) %>%
ggplot() +
geom_point(aes(x, y), color = '#ff55001A') +
geom_line(aes(x, y = fit_elm), color = '#1e90ff') +
geom_line(aes(x, y = fit_gam), color = '#990024')
```
```{r elm-compare2}
fit_gam_mcycle = gam(accel ~ s(times), data = mcycle)
summary(fit_gam_mcycle)$r.sq
mcycle %>%
ggplot(aes(times, accel)) +
geom_point(color = '#ff55001A') +
geom_line(aes(y = fit_elm_mcycle$fit), color = '#1e90ff') +
geom_line(aes(y = fitted(fit_gam_mcycle)), color = '#990024')
```
## Supplemental Example
Yet another example with additional covariates.
```{r elm-setup-3}
d = gamSim(eg = 7, n = 10000)
X = as.matrix(d[, 2:5])
y = d[, 1]
n_nodes = c(10, 25, 100, 250, 500, 1000)
```
The following estimation over multiple models will take several seconds.
```{r elm-est3}
fit_elm = purrr::map(n_nodes, function(n) elm(X, y, n_hidden = n))
```
Now find the best fitting model, here judged by the loss on the training data.
```{r elm-best}
# index of the model with the smallest training loss
best_loss = which.min(map_dbl(fit_elm, function(x) x$loss))
fit_best = fit_elm[[best_loss]]
```
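Selecting on training loss will tend to favor the largest hidden layer, so an alternative is to judge each candidate on held-out data instead. The following is a minimal sketch on simulated data; the data, the split sizes, and names such as `holdout_loss()` and `n_nodes_sim` are assumptions made for illustration and are not part of the original code.

```{r elm-holdout-sketch}
# Hypothetical illustration; simulated data and helper names are assumptions.
set.seed(1)
n_sim = 600
X_sim = matrix(runif(n_sim * 2), ncol = 2)
y_sim = sin(2 * pi * X_sim[, 1]) + 2 * X_sim[, 2]^2 + rnorm(n_sim, sd = .3)

# split into training and holdout sets
idx_train = sample(n_sim, 400)
X_tr = X_sim[idx_train, ]
y_tr = y_sim[idx_train]
X_ho = X_sim[-idx_train, ]
y_ho = y_sim[-idx_train]

# standardize both sets with the training means and standard deviations
X_tr_s = scale(X_tr)
X_ho_s = scale(
  X_ho,
  center = attr(X_tr_s, 'scaled:center'),
  scale  = attr(X_tr_s, 'scaled:scale')
)

holdout_loss = function(n_hidden) {
  w0   = matrix(rnorm(3 * n_hidden), 3, n_hidden)   # intercept + 2 features
  h_tr = tanh(cbind(1, X_tr_s) %*% w0)              # hidden layer, training
  h_ho = tanh(cbind(1, X_ho_s) %*% w0)              # hidden layer, holdout
  B    = MASS::ginv(h_tr) %*% y_tr                  # output weights

  c(
    train   = sum((y_tr - h_tr %*% B)^2),
    holdout = sum((y_ho - h_ho %*% B)^2)
  )
}

n_nodes_sim = c(5, 25, 100, 300)
losses = sapply(n_nodes_sim, holdout_loss)
colnames(losses) = n_nodes_sim
losses

# number of hidden nodes with the smallest holdout loss
n_nodes_sim[which.min(losses['holdout', ])]
```

Because the input weights are random, the preferred size can change from run to run, so repeating the split or averaging over several draws would give a steadier choice.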
A quick check of the fit.
```{r elm-fit}
# str(fit_best)
# qplot(fit_best$fit[, 1], y, alpha = .2)
cor(fit_best$fit[, 1], y)^2
```
And compare again to <span class="pack" style = "">mgcv</span>. In this case, we compare predictions on test data simulated from the same process.
```{r elm-compare3}
fit_gam = gam(y ~ s(x0) + s(x1) + s(x2) + s(x3), data = d)
gam.check(fit_gam)
summary(fit_gam)$r.sq
test_data0 = gamSim(eg = 7) # default n = 400
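# standardize the test features with the training means and sds, since
# elm() was fit to features standardized with the training values
test_data = cbind(1, scale(test_data0[, 2:5], center = colMeans(X), scale = apply(X, 2, sd)))
# the activation function here must match the one used in fitting (tanh by default)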
elm_prediction = tanh(test_data %*% fit_best$w0) %*% fit_best$B
gam_prediction = predict(fit_gam, newdata = test_data0)
cor(data.frame(elm_prediction, gam_prediction), test_data0$y)^2
```
## Source
Original code available at:
https://github.com/m-clark/Miscellaneous-R-Code/blob/master/ModelFitting/elm.R