forked from jmxpearson/bayesrl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
model2b.stan
87 lines (76 loc) · 3.09 KB
/
model2b.stan
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
// Model 2b: identical to Model 2, except that on each trial only the value
// of the chosen option is updated (the unchosen option is carried forward).
data {
  // Observation counts and index bounds.
  int<lower = 1> N; // number of observations (subject-trial rows, all subjects pooled)
  int<lower = 1> Nsub; // number of subjects
  int<lower = 1> Ncue; // number of cues (choice options)
  int<lower = 1> Ntrial; // number of trials per subject
  int<lower = 1> Ngroup; // number of experimental groups
  int<lower = 1> Ncond; // number of task conditions
  // Per-observation data; 0 (or -1 for outcome) encodes a missing trial.
  int<lower = 0> sub[N]; // subject index
  int<lower = 0> chosen[N]; // index of chosen option: 0 => missing
  int<lower = 0> unchosen[N]; // index of unchosen option: 0 => missing
  int<lower = 0> condition[N]; // delay condition code: 0 => missing
  int<lower = 1> trial[N]; // trial number
  int<lower = -1, upper = 1> outcome[N]; // binary outcome: -1 => missing
  int<lower = 1> group[Nsub]; // group assignment for each subject
}
parameters {
  vector<lower = 0>[Nsub] beta; // per-subject softmax (inverse temperature) parameter
  real<lower = 0, upper = 1> alpha[Nsub, Ncond]; // per-subject, per-condition learning rate
  // Hyperparameters: alpha[s, cond] ~ beta(a[group[s], cond], b[group[s], cond]).
  real<lower = 0> a[Ngroup, Ncond]; // parameter for group-specific alpha
  real<lower = 0> b[Ngroup, Ncond]; // parameter for group-specific alpha
}
transformed parameters {
  // Rescorla-Wagner value update, chosen option only.
  // NOTE(review): assumes observations are ordered so that, within each
  // subject, trial t appears before trial t + 1 -- verify against data prep.
  real<lower=0, upper=1> Q[Nsub, Ntrial, Ncue]; // value function for each target
  real Delta[Nsub, Ntrial, Ncue]; // prediction error
  for (idx in 1:N) {
    if (trial[idx] == 1) { // initialize values and errors on each subject's first trial
      for (c in 1:Ncue) {
        Q[sub[idx], trial[idx], c] <- 0.5;
        Delta[sub[idx], trial[idx], c] <- 0;
      }
    }
    if (trial[idx] < Ntrial) { // push forward this trial's values
      for (c in 1:Ncue) {
        Q[sub[idx], trial[idx] + 1, c] <- Q[sub[idx], trial[idx], c];
        // BUG FIX: zero the NEXT trial's prediction errors. The original
        // zeroed the current trial instead, which left Delta undefined (NaN)
        // for every unchosen cue on the final trial (trial == Ntrial).
        Delta[sub[idx], trial[idx] + 1, c] <- 0;
      }
    }
    if (outcome[idx] >= 0) { // outcome == -1 flags a missing trial
      // prediction error: chosen option
      Delta[sub[idx], trial[idx], chosen[idx]] <- outcome[idx] - Q[sub[idx], trial[idx], chosen[idx]];
      if (trial[idx] < Ntrial) { // update action values for next trial
        // update chosen option only; unchosen values were already carried forward
        Q[sub[idx], trial[idx] + 1, chosen[idx]] <- Q[sub[idx], trial[idx], chosen[idx]] + alpha[sub[idx], condition[idx]] * Delta[sub[idx], trial[idx], chosen[idx]];
      }
    }
  }
}
model {
  // Prior on each subject's inverse-temperature (softmax) parameter.
  beta ~ gamma(1, 0.2);

  // Weakly-informative hyperpriors on the group x condition beta-distribution
  // parameters governing the learning rates.
  for (g in 1:Ngroup) {
    for (k in 1:Ncond) {
      a[g, k] ~ gamma(1, 1);
      b[g, k] ~ gamma(1, 1);
    }
  }

  // Each subject's learning rate in each condition is drawn from the
  // corresponding group-level beta distribution.
  for (s in 1:Nsub) {
    for (k in 1:Ncond) {
      alpha[s, k] ~ beta(a[group[s], k], b[group[s], k]);
    }
  }

  // Choice likelihood: two-option softmax expressed as a Bernoulli on the
  // chosen option's value advantage; chosen == 0 marks a missing trial.
  for (n in 1:N) {
    if (chosen[n] > 0) {
      1 ~ bernoulli_logit(beta[sub[n]] * (Q[sub[n], trial[n], chosen[n]] - Q[sub[n], trial[n], unchosen[n]]));
    }
  }
}
generated quantities { // generate samples of learning rate from each group
  // Posterior-predictive learning rate for a new subject in every
  // group x condition cell.
  real<lower=0, upper=1> alpha_pred[Ngroup, Ncond];
  for (g in 1:Ngroup)
    for (k in 1:Ncond)
      alpha_pred[g, k] <- beta_rng(a[g, k], b[g, k]);
}