@@ -310,156 +310,124 @@ normL2 <- function(data, x, errmodel = NULL, times = NULL,
310310# ' and its derivatives with respect to p and sigma. Sigma parameters being
311311# ' passed to the function are ALWAYS assumed to be on a log scale, i.e. internally
312312# ' sigma parameters are converted by \code{exp()}.
313- # ' @examples
314- # ' mu <- c(A = 0, B = 0)
315- # ' sigma <- c(A = 0.1, B = 1)
316- # ' myfn <- constraintL2(mu, sigma)
317- # ' myfn(pars = c(A = 1, B = -1))
318- # '
319- # ' # Introduce sigma parameter but fix them (sigma parameters
320- # ' # are assumed to be passed on log scale)
321- # ' mu <- c(A = 0, B = 0)
322- # ' sigma <- paste("sigma", names(mu), sep = "_")
323- # ' myfn <- constraintL2(mu, sigma)
324- # ' pars <- c(A = .8, B = -.3, sigma_A = -1, sigma_B = 1)
325- # ' myfn(pars = pars[c(1, 3)], fixed = pars[c(2, 4)])
326- # '
327- # ' # Assume same sigma parameter for both A and B
328- # ' # sigma is assumed to be passed on log scale
329- # ' mu <- c(A = 0, B = 0)
330- # ' myfn <- constraintL2(mu, sigma = "sigma")
331- # ' pars <- c(A = .8, B = -.3, sigma = 0)
332- # ' myfn(pars = pars)
333- # '
334313# ' @export
335314constraintL2 <- function (mu , sigma = 1 , attr.name = " prior" , condition = NULL ) {
336315
316+ estimateSigma <- is.character(sigma )
337317
338- # Aktuell zu kompliziert aufgesetzt. Man sollte immer die komplette Hessematrix/Gradient
339- # auswerten und dann die Elemente streichen, die in fixed sind!
340-
341-
342- estimateSigma <- ifelse(is.character(sigma ), TRUE , FALSE )
343- if (length(sigma ) > 1 & length(sigma ) < length(mu ))
344- stop(" sigma must either have length 1 or at least length equal to length of mu." )
345-
346- # # Augment sigma if length = 1
347318 if (length(sigma ) == 1 )
348319 sigma <- structure(rep(sigma , length(mu )), names = names(mu ))
349320 if (is.null(names(sigma )))
350321 names(sigma ) <- names(mu )
351- if (! is.null(names(sigma )) & ! all(names(mu ) %in% names(sigma )))
352- stop(" Names of sigma and names of mu do not match." )
353-
354- # # Bring sigma in correct order (no matter if character or numeric)
355322 sigma <- sigma [names(mu )]
356323
357324 controls <- list (mu = mu , sigma = sigma , attr.name = attr.name )
358325
359- myfn <- function (... , fixed = NULL , deriv = TRUE , conditions = condition , env = NULL ) {
326+ myfn <- function (... , fixed = NULL , deriv = TRUE , deriv2 = FALSE , conditions = condition , env = NULL ) {
360327
361- arglist <- list (... )
362- arglist <- arglist [match.fnargs(arglist , " pars" )]
363- pouter <- arglist [[1 ]]
364-
365- # Import from controls
328+ pouter <- list (... )[[match.fnargs(list (... ), " pars" )]]
366329 mu <- controls $ mu
367330 sigma <- controls $ sigma
368- attr.name <- controls $ attr.name
369- nmu <- length(mu )
370331
371- # pouter can be a list (if result from a parameter transformation)
372- # In this case match with conditions and evaluate only those
373- # If there is no overlap, return NULL
374- # If pouter is not a list, evaluate the constraint function
375- # for this pouter.
332+ # Handle list input (multiple conditions)
376333 if (is.list(pouter ) && ! is.null(conditions )) {
377334 available <- intersect(names(pouter ), conditions )
378- defined <- ifelse(is.null(condition ), TRUE , condition %in% conditions )
379-
380- if (length(available ) == 0 | ! defined ) return ()
381- pouter <- pouter [intersect(available , condition )]
335+ if (length(available ) == 0 ) return ()
336+ pouter <- pouter [available ]
382337 }
383338 if (! is.list(pouter )) pouter <- list (pouter )
384339
385-
386340 outlist <- lapply(pouter , function (p ) {
387341
342+ # Get deriv attributes before any manipulation
343+ dP <- attr(p , " deriv" , exact = TRUE )
344+ dP2 <- attr(p , " deriv2" , exact = TRUE )
388345
389- pars <- c(p , fixed )[names(mu )]
390- p1 <- setdiff(intersect(names(mu ), names(p )), names(fixed ))
346+ # Combine and extract available parameters
347+ all_pars <- c(as.numeric(p ), fixed )
348+ names(all_pars ) <- c(names(p ), names(fixed ))
349+ avail <- intersect(names(mu ), names(all_pars ))
391350
392- # if estimate sigma, produce numeric sigma vector from the parameters provided in p and fixed
393- if (estimateSigma ) {
394- sigmapars <- sigma
395- sigma <- exp(c(p , fixed )[sigma ])
396- names(sigma ) <- names(mu )
397- Jsigma <- do.call(cbind , lapply(unique(sigmapars ), function (s ) {
398- (sigmapars == s )* sigma
399- }))
400- colnames(Jsigma ) <- unique(sigmapars )
401- rownames(Jsigma ) <- names(sigma )
402- p2 <- setdiff(intersect(unique(sigmapars ), names(p )), names(fixed ))
403- }
351+ if (length(avail ) == 0 )
352+ return (objlist(value = 0 ,
353+ gradient = if (deriv ) setNames(rep(0 , length(p )), names(p )) else NULL ,
354+ hessian = if (deriv ) matrix (0 , length(p ), length(p ), dimnames = list (names(p ), names(p ))) else NULL ))
404355
405- # Compute constraint value and derivatives
406- val <- sum((pars - mu )^ 2 / sigma ^ 2 ) + estimateSigma * sum(log(sigma ^ 2 ))
407- val.p <- 2 * (pars - mu )/ sigma ^ 2
408- val.sigma <- - 2 * (pars - mu )^ 2 / sigma ^ 3 + 2 / sigma
409- val.p.p <- diag(2 / sigma ^ 2 , nmu , nmu ); colnames(val.p.p ) <- rownames(val.p.p ) <- names(mu )
410- val.p.sigma <- diag(- 4 * (pars - mu )/ sigma ^ 3 , nmu , nmu ); colnames(val.p.sigma ) <- rownames(val.p.sigma ) <- names(mu )
411- val.sigma.sigma <- diag(6 * (pars - mu )^ 2 / sigma ^ 4 - 2 / sigma ^ 2 , nmu , nmu ); colnames(val.sigma.sigma ) <- rownames(val.sigma.sigma ) <- names(mu )
356+ pars <- all_pars [avail ]
357+ mu_a <- mu [avail ]
358+ sig_a <- sigma [avail ]
359+ n_a <- length(avail )
412360
413- # Multiply with Jacobian of sigma vector if estimate sigma
361+ p1 <- intersect(setdiff(names(mu ), names(fixed )), names(p ))
362+ p2 <- character (0 )
363+
364+ # Handle sigma estimation
414365 if (estimateSigma ) {
415- val.sigma.sigma <- t( Jsigma ) %*% val.sigma.sigma %*% Jsigma + diag((t( val.sigma ) %*% Jsigma )[ 1 ,], ncol( Jsigma ), ncol( Jsigma ) )
416- val.sigma <- ( val.sigma %*% Jsigma )[ 1 ,]
417- val.p.sigma <- ( val.p. sigma %*% Jsigma )
366+ sig_a <- exp( all_pars [ sig_a ] )
367+ names( sig_a ) <- avail
368+ p2 <- intersect(setdiff(unique( sigma [ avail ]), names( fixed )), names( p ) )
418369 }
419370
371+ # Compute value
372+ res <- pars - mu_a
373+ val <- sum(res ^ 2 / sig_a ^ 2 ) + estimateSigma * sum(2 * log(sig_a ))
420374
421375 gr <- hs <- NULL
422376 if (deriv ) {
423- # Produce output gradient and hessian
424- gr <- rep(0 , length(p )); names(gr ) <- names(p )
377+ gr <- setNames(rep(0 , length(p )), names(p ))
425378 hs <- matrix (0 , length(p ), length(p ), dimnames = list (names(p ), names(p )))
426379
427- # Set values in gradient and hessian
428- gr [p1 ] <- val.p [p1 ]
429- hs [p1 , p1 ] <- val.p.p [p1 , p1 ]
430- if (estimateSigma ) {
431- gr [p2 ] <- val.sigma [p2 ]
432- hs [p1 , p2 ] <- val.p.sigma [p1 , p2 ]
433- hs [p2 , p1 ] <- t(val.p.sigma )[p2 , p1 ]
434- hs [p2 , p2 ] <- val.sigma.sigma [p2 , p2 ]
380+ p1_a <- intersect(p1 , avail )
381+ if (length(p1_a ) > 0 ) {
382+ gr [p1_a ] <- 2 * res [p1_a ] / sig_a [p1_a ]^ 2
383+ diag(hs )[p1_a ] <- 2 / sig_a [p1_a ]^ 2
435384 }
436385
437- # Multiply with derivatives of incoming parameter
438- dP <- attr(p , " deriv" )
386+ if (estimateSigma && length(p2 ) > 0 ) {
387+ # Aggregate sigma derivatives by sigma parameter name
388+ for (sp in p2 ) {
389+ idx <- which(sigma [avail ] == sp )
390+ gr [sp ] <- sum(- 2 * res [idx ]^ 2 / sig_a [idx ]^ 2 + 2 )
391+ hs [sp , sp ] <- sum(4 * res [idx ]^ 2 / sig_a [idx ]^ 2 )
392+ }
393+ # Cross terms p1 x p2
394+ for (sp in p2 ) {
395+ idx <- names(sigma [avail ])[sigma [avail ] == sp ]
396+ common <- intersect(idx , p1_a )
397+ if (length(common ) > 0 ) {
398+ hs [common , sp ] <- - 4 * res [common ] / sig_a [common ]^ 2
399+ hs [sp , common ] <- hs [common , sp ]
400+ }
401+ }
402+ }
403+
404+ # Chain rule
439405 if (! is.null(dP )) {
440- gr <- as.vector(gr %*% dP ); names(gr ) <- colnames(dP )
441- hs <- t(dP ) %*% hs %*% dP ; colnames(hs ) <- colnames(dP ); rownames(hs ) <- colnames(dP )
406+ gr_inner <- gr
407+ gr <- as.vector(gr_inner %*% dP )
408+ names(gr ) <- colnames(dP )
409+ hs <- t(dP ) %*% hs %*% dP
410+ if (! is.null(dP2 )) {
411+ hs <- hs + colSums(gr_inner * matrix (dP2 , nrow = length(gr_inner )))
412+ dim(hs ) <- c(ncol(dP ), ncol(dP ))
413+ }
414+ dimnames(hs ) <- list (colnames(dP ), colnames(dP ))
442415 }
443416 }
444417
445418 objlist(value = val , gradient = gr , hessian = hs )
446-
447-
448419 })
449420
450421 out <- Reduce(" +" , outlist )
451422 attr(out , controls $ attr.name ) <- out $ value
452423 attr(out , " env" ) <- env
453- return (out )
454-
455-
424+ out
456425 }
426+
457427 class(myfn ) <- c(" objfn" , " fn" )
458428 attr(myfn , " conditions" ) <- condition
459429 attr(myfn , " parameters" ) <- names(mu )
460- return (myfn )
461-
462-
430+ myfn
463431}
464432
465433
0 commit comments