Skip to content

Commit a77f42e

Browse files
committed
trustL1 update for new dMod structure
1 parent 7fcadfd commit a77f42e

12 files changed

+490
-1485
lines changed

R/objClass.R

Lines changed: 67 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -310,156 +310,124 @@ normL2 <- function(data, x, errmodel = NULL, times = NULL,
310310
#' and its derivatives with respect to p and sigma. Sigma parameters being
311311
#' passed to the function are ALWAYS assumed to be on a log scale, i.e. internally
312312
#' sigma parameters are converted by \code{exp()}.
313-
#' @examples
314-
#' mu <- c(A = 0, B = 0)
315-
#' sigma <- c(A = 0.1, B = 1)
316-
#' myfn <- constraintL2(mu, sigma)
317-
#' myfn(pars = c(A = 1, B = -1))
318-
#'
319-
#' # Introduce sigma parameter but fix them (sigma parameters
320-
#' # are assumed to be passed on log scale)
321-
#' mu <- c(A = 0, B = 0)
322-
#' sigma <- paste("sigma", names(mu), sep = "_")
323-
#' myfn <- constraintL2(mu, sigma)
324-
#' pars <- c(A = .8, B = -.3, sigma_A = -1, sigma_B = 1)
325-
#' myfn(pars = pars[c(1, 3)], fixed = pars[c(2, 4)])
326-
#'
327-
#' # Assume same sigma parameter for both A and B
328-
#' # sigma is assumed to be passed on log scale
329-
#' mu <- c(A = 0, B = 0)
330-
#' myfn <- constraintL2(mu, sigma = "sigma")
331-
#' pars <- c(A = .8, B = -.3, sigma = 0)
332-
#' myfn(pars = pars)
333-
#'
334313
#' @export
335314
constraintL2 <- function(mu, sigma = 1, attr.name = "prior", condition = NULL) {
336315

316+
estimateSigma <- is.character(sigma)
337317

338-
# Aktuell zu kompliziert aufgesetzt. Man sollte immer die komplette Hessematrix/Gradient
339-
# auswerten und dann die Elemente streichen, die in fixed sind!
340-
341-
342-
estimateSigma <- ifelse(is.character(sigma), TRUE, FALSE)
343-
if (length(sigma) > 1 & length(sigma) < length(mu))
344-
stop("sigma must either have length 1 or at least length equal to length of mu.")
345-
346-
## Augment sigma if length = 1
347318
if (length(sigma) == 1)
348319
sigma <- structure(rep(sigma, length(mu)), names = names(mu))
349320
if (is.null(names(sigma)))
350321
names(sigma) <- names(mu)
351-
if (!is.null(names(sigma)) & !all(names(mu) %in% names(sigma)))
352-
stop("Names of sigma and names of mu do not match.")
353-
354-
## Bring sigma in correct order (no matter if character or numeric)
355322
sigma <- sigma[names(mu)]
356323

357324
controls <- list(mu = mu, sigma = sigma, attr.name = attr.name)
358325

359-
myfn <- function(..., fixed = NULL, deriv = TRUE, conditions = condition, env = NULL) {
326+
myfn <- function(..., fixed = NULL, deriv = TRUE, deriv2 = FALSE, conditions = condition, env = NULL) {
360327

361-
arglist <- list(...)
362-
arglist <- arglist[match.fnargs(arglist, "pars")]
363-
pouter <- arglist[[1]]
364-
365-
# Import from controls
328+
pouter <- list(...)[[match.fnargs(list(...), "pars")]]
366329
mu <- controls$mu
367330
sigma <- controls$sigma
368-
attr.name <- controls$attr.name
369-
nmu <- length(mu)
370331

371-
# pouter can be a list (if result from a parameter transformation)
372-
# In this case match with conditions and evaluate only those
373-
# If there is no overlap, return NULL
374-
# If pouter is not a list, evaluate the constraint function
375-
# for this pouter.
332+
# Handle list input (multiple conditions)
376333
if (is.list(pouter) && !is.null(conditions)) {
377334
available <- intersect(names(pouter), conditions)
378-
defined <- ifelse(is.null(condition), TRUE, condition %in% conditions)
379-
380-
if (length(available) == 0 | !defined) return()
381-
pouter <- pouter[intersect(available, condition)]
335+
if (length(available) == 0) return()
336+
pouter <- pouter[available]
382337
}
383338
if (!is.list(pouter)) pouter <- list(pouter)
384339

385-
386340
outlist <- lapply(pouter, function(p) {
387341

342+
# Get deriv attributes before any manipulation
343+
dP <- attr(p, "deriv", exact = TRUE)
344+
dP2 <- attr(p, "deriv2", exact = TRUE)
388345

389-
pars <- c(p, fixed)[names(mu)]
390-
p1 <- setdiff(intersect(names(mu), names(p)), names(fixed))
346+
# Combine and extract available parameters
347+
all_pars <- c(as.numeric(p), fixed)
348+
names(all_pars) <- c(names(p), names(fixed))
349+
avail <- intersect(names(mu), names(all_pars))
391350

392-
# if estimate sigma, produce numeric sigma vector from the parameters provided in p and fixed
393-
if (estimateSigma) {
394-
sigmapars <- sigma
395-
sigma <- exp(c(p, fixed)[sigma])
396-
names(sigma) <- names(mu)
397-
Jsigma <- do.call(cbind, lapply(unique(sigmapars), function(s) {
398-
(sigmapars == s)*sigma
399-
}))
400-
colnames(Jsigma) <- unique(sigmapars)
401-
rownames(Jsigma) <- names(sigma)
402-
p2 <- setdiff(intersect(unique(sigmapars), names(p)), names(fixed))
403-
}
351+
if (length(avail) == 0)
352+
return(objlist(value = 0,
353+
gradient = if(deriv) setNames(rep(0, length(p)), names(p)) else NULL,
354+
hessian = if(deriv) matrix(0, length(p), length(p), dimnames = list(names(p), names(p))) else NULL))
404355

405-
# Compute constraint value and derivatives
406-
val <- sum((pars - mu)^2/sigma^2) + estimateSigma * sum(log(sigma^2))
407-
val.p <- 2*(pars - mu)/sigma^2
408-
val.sigma <- -2*(pars-mu)^2/sigma^3 + 2/sigma
409-
val.p.p <- diag(2/sigma^2, nmu, nmu); colnames(val.p.p) <- rownames(val.p.p) <- names(mu)
410-
val.p.sigma <- diag(-4*(pars-mu)/sigma^3, nmu, nmu); colnames(val.p.sigma) <- rownames(val.p.sigma) <- names(mu)
411-
val.sigma.sigma <- diag(6*(pars-mu)^2/sigma^4 - 2/sigma^2, nmu, nmu); colnames(val.sigma.sigma) <- rownames(val.sigma.sigma) <- names(mu)
356+
pars <- all_pars[avail]
357+
mu_a <- mu[avail]
358+
sig_a <- sigma[avail]
359+
n_a <- length(avail)
412360

413-
# Multiply with Jacobian of sigma vector if estimate sigma
361+
p1 <- intersect(setdiff(names(mu), names(fixed)), names(p))
362+
p2 <- character(0)
363+
364+
# Handle sigma estimation
414365
if (estimateSigma) {
415-
val.sigma.sigma <- t(Jsigma) %*% val.sigma.sigma %*% Jsigma + diag((t(val.sigma) %*% Jsigma)[1,], ncol(Jsigma), ncol(Jsigma))
416-
val.sigma <- (val.sigma %*% Jsigma)[1,]
417-
val.p.sigma <- (val.p.sigma %*% Jsigma)
366+
sig_a <- exp(all_pars[sig_a])
367+
names(sig_a) <- avail
368+
p2 <- intersect(setdiff(unique(sigma[avail]), names(fixed)), names(p))
418369
}
419370

371+
# Compute value
372+
res <- pars - mu_a
373+
val <- sum(res^2 / sig_a^2) + estimateSigma * sum(2 * log(sig_a))
420374

421375
gr <- hs <- NULL
422376
if (deriv) {
423-
# Produce output gradient and hessian
424-
gr <- rep(0, length(p)); names(gr) <- names(p)
377+
gr <- setNames(rep(0, length(p)), names(p))
425378
hs <- matrix(0, length(p), length(p), dimnames = list(names(p), names(p)))
426379

427-
# Set values in gradient and hessian
428-
gr[p1] <- val.p[p1]
429-
hs[p1, p1] <- val.p.p[p1, p1]
430-
if (estimateSigma) {
431-
gr[p2] <- val.sigma[p2]
432-
hs[p1, p2] <- val.p.sigma[p1, p2]
433-
hs[p2, p1] <- t(val.p.sigma)[p2, p1]
434-
hs[p2, p2] <- val.sigma.sigma[p2, p2]
380+
p1_a <- intersect(p1, avail)
381+
if (length(p1_a) > 0) {
382+
gr[p1_a] <- 2 * res[p1_a] / sig_a[p1_a]^2
383+
diag(hs)[p1_a] <- 2 / sig_a[p1_a]^2
435384
}
436385

437-
# Multiply with derivatives of incoming parameter
438-
dP <- attr(p, "deriv")
386+
if (estimateSigma && length(p2) > 0) {
387+
# Aggregate sigma derivatives by sigma parameter name
388+
for (sp in p2) {
389+
idx <- which(sigma[avail] == sp)
390+
gr[sp] <- sum(-2 * res[idx]^2 / sig_a[idx]^2 + 2)
391+
hs[sp, sp] <- sum(4 * res[idx]^2 / sig_a[idx]^2)
392+
}
393+
# Cross terms p1 x p2
394+
for (sp in p2) {
395+
idx <- names(sigma[avail])[sigma[avail] == sp]
396+
common <- intersect(idx, p1_a)
397+
if (length(common) > 0) {
398+
hs[common, sp] <- -4 * res[common] / sig_a[common]^2
399+
hs[sp, common] <- hs[common, sp]
400+
}
401+
}
402+
}
403+
404+
# Chain rule
439405
if (!is.null(dP)) {
440-
gr <- as.vector(gr %*% dP); names(gr) <- colnames(dP)
441-
hs <- t(dP) %*% hs %*% dP; colnames(hs) <- colnames(dP); rownames(hs) <- colnames(dP)
406+
gr_inner <- gr
407+
gr <- as.vector(gr_inner %*% dP)
408+
names(gr) <- colnames(dP)
409+
hs <- t(dP) %*% hs %*% dP
410+
if (!is.null(dP2)) {
411+
hs <- hs + colSums(gr_inner * matrix(dP2, nrow = length(gr_inner)))
412+
dim(hs) <- c(ncol(dP), ncol(dP))
413+
}
414+
dimnames(hs) <- list(colnames(dP), colnames(dP))
442415
}
443416
}
444417

445418
objlist(value = val, gradient = gr, hessian = hs)
446-
447-
448419
})
449420

450421
out <- Reduce("+", outlist)
451422
attr(out, controls$attr.name) <- out$value
452423
attr(out, "env") <- env
453-
return(out)
454-
455-
424+
out
456425
}
426+
457427
class(myfn) <- c("objfn", "fn")
458428
attr(myfn, "conditions") <- condition
459429
attr(myfn, "parameters") <- names(mu)
460-
return(myfn)
461-
462-
430+
myfn
463431
}
464432

465433

0 commit comments

Comments
 (0)