Skip to content

Commit

Permalink
more reversion to previous code
Browse files Browse the repository at this point in the history
  • Loading branch information
topepo committed Aug 8, 2022
1 parent f8d4749 commit f7f1ab0
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 31 deletions.
14 changes: 5 additions & 9 deletions models/files/ranger.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@ modelInfo <- list(label = "Random Forest",
class = c("numeric", "character", "numeric"),
label = c("#Randomly Selected Predictors",
"Splitting Rule",
"Minimal Node Size"
)),
"Minimal Node Size")),
grid = function(x, y, len = NULL, search = "grid") {
if(search == "grid") {
srule <-
Expand All @@ -24,9 +23,9 @@ modelInfo <- list(label = "Random Forest",
else
"variance"
out <- expand.grid(mtry =
caret::var_seq(p = ncol(x),
classification = is.factor(y),
len = len),
caret::var_seq(p = ncol(x),
classification = is.factor(y),
len = len),
min.node.size = ifelse( is.factor(y), 1, 5),
splitrule = c(srule, "extratrees"))
} else {
Expand All @@ -38,8 +37,7 @@ modelInfo <- list(label = "Random Forest",
data.frame(
min.node.size= sample(1:(min(20,nrow(x))), size = len, replace = TRUE),
mtry = sample(1:ncol(x), size = len, replace = TRUE),
splitrule = sample(srules, size = len, replace = TRUE),
num.trees = 500
splitrule = sample(srules, size = len, replace = TRUE)
)
}
},
Expand All @@ -52,7 +50,6 @@ modelInfo <- list(label = "Random Forest",
mtry = min(param$mtry, ncol(x)),
min.node.size = param$min.node.size,
splitrule = as.character(param$splitrule),
num.trees = param$num.trees,
write.forest = TRUE,
probability = classProbs,
case.weights = wts,
Expand All @@ -63,7 +60,6 @@ modelInfo <- list(label = "Random Forest",
mtry = min(param$mtry, ncol(x)),
min.node.size = param$min.node.size,
splitrule = as.character(param$splitrule),
num.trees = param$num.trees,
write.forest = TRUE,
probability = classProbs,
...)
Expand Down
34 changes: 12 additions & 22 deletions models/files/rf.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,20 @@ modelInfo <- list(label = "Random Forest",
library = "randomForest",
loop = NULL,
type = c("Classification", "Regression"),
parameters = data.frame(parameter = c("mtry", "nodesize"),
class = c("numeric", "numeric"),
label = c("#Randomly Selected Predictors",
"Minimum Node Size")),
parameters = data.frame(parameter = "mtry",
class = "numeric",
label = "#Randomly Selected Predictors"),
grid = function(x, y, len = NULL, search = "grid") {
if(search == "grid") {
out <- expand.grid(mtry = caret::var_seq(p = ncol(x),
out <- data.frame(mtry = caret::var_seq(p = ncol(x),
classification = is.factor(y),
len = len),
nodesize = ifelse( is.factor(y), 1, 5)
)
len = len))
} else {
out <- data.frame(mtry = unique(sample(1:ncol(x), size = len, replace = TRUE)),
nodesize = sample(1:(min(20,nrow(x))), size = len, replace = TRUE)
)
out <- data.frame(mtry = unique(sample(1:ncol(x), size = len, replace = TRUE)))
}
},
fit = function(x, y, wts, param, lev, last, classProbs, ...)
randomForest::randomForest(x, y,
mtry = min(param$mtry, ncol(x)),
nodesize = param$nodesize,
...),
randomForest::randomForest(x, y, mtry = param$mtry, ...),
predict = function(modelFit, newdata, submodels = NULL)
if(!is.null(newdata)) predict(modelFit, newdata) else predict(modelFit),
prob = function(modelFit, newdata, submodels = NULL)
Expand All @@ -41,20 +33,17 @@ modelInfo <- list(label = "Random Forest",
varImp <- randomForest::importance(object, ...)
if(object$type == "regression") {
if("%IncMSE" %in% colnames(varImp)) {
varImp <- as.data.frame(varImp[,"%IncMSE", drop = FALSE])
colnames(varImp) <- "Overall"
varImp <- data.frame(Overall = varImp[,"%IncMSE"])
} else {
varImp <- as.data.frame(varImp[,1, drop = FALSE])
colnames(varImp) <- "Overall"
varImp <- data.frame(Overall = varImp[,1])
}
}
else {
retainNames <- levels(object$y)
if(all(retainNames %in% colnames(varImp))) {
varImp <- varImp[, retainNames, drop = FALSE]
varImp <- varImp[, retainNames]
} else {
varImp <- as.data.frame(varImp[,1, drop = FALSE])
colnames(varImp) <- "Overall"
varImp <- data.frame(Overall = varImp[,1])
}
}

Expand All @@ -76,3 +65,4 @@ modelInfo <- list(label = "Random Forest",
names(out) <- if(x$type == "regression") c("RMSE", "Rsquared") else c("Accuracy", "Kappa")
out
})

Binary file modified pkg/caret/inst/models/models.RData
Binary file not shown.

0 comments on commit f7f1ab0

Please sign in to comment.