Skip to content

Commit d78d59b

Browse files
committed
Merge remote-tracking branch 'origin/master' into survival
2 parents 00631d3 + e215f8a commit d78d59b

File tree

4 files changed

+20
-6
lines changed

4 files changed

+20
-6
lines changed

NEWS.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010
* helpLearner, helpLearnerParam: open the help for a learner or get a
1111
description of its parameters
1212

13+
## measures - general
14+
* measure "arsq" now has ID "arsq"
15+
1316
## measures - new
1417
* measureBER, measureRMSLE, measureF1
1518

R/Task_operators.R

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -402,10 +402,7 @@ changeData = function(task, data, costs, weights) {
402402
weights = task$weights
403403
task$env = new.env(parent = emptyenv())
404404
task$env$data = data
405-
if (is.null(weights))
406-
task["weights"] = list(NULL)
407-
else
408-
task$weights = weights
405+
task["weights"] = list(weights) # so also 'NULL' gets set
409406
td = task$task.desc
410407
# FIXME: this is bad style but I see no other way right now
411408
task$task.desc = switch(td$type,
@@ -414,7 +411,7 @@ changeData = function(task, data, costs, weights) {
414411
"cluster" = makeClusterTaskDesc(td$id, data, task$weights, task$blocking),
415412
"surv" = makeSurvTaskDesc(td$id, data, td$target, task$weights, task$blocking),
416413
"costsens" = makeCostSensTaskDesc(td$id, data, td$target, task$blocking, costs),
417-
"multilabel" = makeMultilabelTaskDesc(td$id, data, td$target, td$weights, task$blocking)
414+
"multilabel" = makeMultilabelTaskDesc(td$id, data, td$target, task$weights, task$blocking)
418415
)
419416

420417
return(task)

R/measures.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,7 @@ measureEXPVAR = function(truth, response) {
277277
#' @export arsq
278278
#' @rdname measures
279279
#' @format none
280-
arsq = makeMeasure(id = "adjrsq", minimize = FALSE, best = 1, worst = 0,
280+
arsq = makeMeasure(id = "arsq", minimize = FALSE, best = 1, worst = 0,
281281
properties = c("regr", "req.pred", "req.truth"),
282282
name = "Adjusted coefficient of determination",
283283
note = "Defined as: 1 - (1 - rsq) * (p / (n - p - 1L)). Adjusted R-squared is only defined for normal linear regression.",

tests/testthat/test_base_weights.R

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,17 @@ test_that("weights", {
2727

2828
expect_error(train(lrn, rger.task, weights = 1:2))
2929
})
30+
31+
test_that("weights remain after subset", {
32+
tasks = list(binaryclass.task, multiclass.task, multilabel.task, regr.task, surv.task, noclass.task)
33+
for (t in tasks) {
34+
expect_false(getTaskDesc(t)$has.weights)
35+
ws = seq_len(getTaskDesc(t)$size)
36+
wtask = changeData(t, weights = ws)
37+
expect_true(getTaskDesc(wtask)$has.weights)
38+
expect_equal(wtask$weights, ws)
39+
expect_equal(subsetTask(wtask, 1:10)$weights, 1:10)
40+
expect_true(getTaskDesc(subsetTask(wtask, 1:10))$has.weights)
41+
}
42+
})
43+

0 commit comments

Comments
 (0)