Commit
Add AdaGrad optimizer in R
ziyeqinghan committed Jul 26, 2016
1 parent e969d45 commit 095d742
Showing 2 changed files with 114 additions and 0 deletions.
80 changes: 80 additions & 0 deletions R-package/R/optimizer.R
@@ -252,6 +252,83 @@ mx.opt.adam <- function(learning.rate=0.001,
  return(list(create.state=create.state, update=update))
}

#' Create an AdaGrad optimizer with respective parameters.
#' AdaGrad optimizer of Duchi et al., 2011.
#'
#' This code follows the version in http://arxiv.org/pdf/1212.5701v1.pdf Eq(5)
#' by Matthew D. Zeiler, 2012. AdaGrad adapts the learning rate per parameter,
#' which can help the network converge faster in some cases.
#'
#' @param learning.rate float, default=0.05
#'      Step size.
#' @param epsilon float, default=1e-8
#'      Small constant added for numerical stability.
#' @param wd float, default=0.0
#'      L2 regularization coefficient added to all the weights.
#' @param rescale.grad float, default=1.0
#'      Rescaling factor of gradient.
#' @param clip_gradient float, optional
#'      Clip gradient to the range of [-clip_gradient, clip_gradient].
#' @param lr_scheduler function, optional
#'      The learning rate scheduler.
#'
mx.opt.adagrad <- function(learning.rate=0.05,
                           epsilon=1e-8,
                           wd=0,
                           rescale.grad=1,
                           clip_gradient = NULL,
                           lr_scheduler = NULL) {
  # use lr as short for learning rate.
  lr <- learning.rate
  count <- 0
  num_update <- 0

  adagrad <- new.env()
  adagrad$lr <- lr
  adagrad$count <- 0
  adagrad$num_update <- 0

  create.state <- function(index, weight) {
    # The per-weight state is the accumulated sum of squared gradients (history).
    return (mx.nd.zeros(dim(weight), ctx(weight)))
  }

  update <- function(index, weight, grad, state) {
    if (!is.null(lr_scheduler)){
      lr_scheduler(adagrad) ## changing lr
      lr <- adagrad$lr
      ## update the per-index count used by the scheduler
      indexKey <- paste0('ik', index)
      if (!exists(envir = adagrad, x = indexKey)){
        assign(x = indexKey, value = 0, envir = adagrad)
      } else {
        indexValue <- get(envir = adagrad, x = indexKey)
        assign(x = indexKey, value = indexValue + 1, envir = adagrad)
        adagrad$num_update <- max(adagrad$num_update, get(envir = adagrad, x = indexKey))
      }
    }

    grad <- grad * rescale.grad
    if (!is.null(clip_gradient)){
      if (clip_gradient >= 0){
        grad_ctx <- ctx(grad)
        grad <- as.array(grad)
        grad <- pmax(grad, -1 * clip_gradient)
        grad <- pmin(grad, clip_gradient)
        grad <- mx.nd.array(grad, grad_ctx)
      } else {
        stop("Error: clip_gradient should be a positive number.")
      }
    }

    # AdaGrad update: accumulate squared gradients, then scale the step
    # by 1 / sqrt(history + epsilon); wd adds L2 weight decay.
    history <- state
    history <- history + (grad * grad)
    weight <- weight - lr * (grad / mx.nd.sqrt(history + epsilon) + wd * weight)
    state <- history

    return(list(weight=weight, state=state))
  }
  return(list(create.state=create.state, update=update))
}

#' Create an optimizer by name and parameters
#'
#' @param name The name of the optimizer
@@ -268,6 +345,9 @@ mx.opt.create <- function(name, ...) {
else if (name == "adam") {
return (mx.opt.adam(...))
}
else if (name == "adagrad") {
return (mx.opt.adagrad(...))
}
stop(paste("Unknown optimizer ", name))
}

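Not part of the commit: a minimal sketch of driving the new optimizer's closures directly, assuming the mxnet R package is installed and loaded. The shapes and values are illustrative only.

library(mxnet)

# Build the optimizer and a toy 2x2 parameter on the CPU.
opt    <- mx.opt.adagrad(learning.rate = 0.1, epsilon = 1e-8)
weight <- mx.nd.ones(c(2, 2), mx.cpu())
grad   <- mx.nd.ones(c(2, 2), mx.cpu()) * 0.5

# The state (history) starts as zeros of the same shape as the weight.
state <- opt$create.state(index = 1, weight = weight)

# One step: history <- history + grad^2,
# weight <- weight - lr * grad / sqrt(history + epsilon).
res    <- opt$update(index = 1, weight = weight, grad = grad, state = state)
weight <- res$weight
state  <- res$state
as.array(weight)  # each entry moves from 1.0 to roughly 0.9 with these values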
34 changes: 34 additions & 0 deletions R-package/man/mx.opt.adagrad.Rd

(Generated documentation file; not rendered in the diff.)

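With the new branch in mx.opt.create above, the optimizer can also be requested by name. A hedged sketch, again not part of the commit: the symbol net and the arrays train.x/train.y are hypothetical placeholders, and extra arguments to mx.model.FeedForward.create are forwarded to the optimizer constructor.

# Equivalent to calling mx.opt.adagrad(...) directly.
opt <- mx.opt.create("adagrad", learning.rate = 0.05, wd = 1e-4)

# Hypothetical training run; net, train.x, and train.y are placeholders.
model <- mx.model.FeedForward.create(net, X = train.x, y = train.y,
                                     ctx = mx.cpu(), num.round = 10,
                                     array.batch.size = 32,
                                     learning.rate = 0.05,
                                     optimizer = "adagrad")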