Skip to content

Commit 68dfd60

Browse files
Update internal functions to reflect deprecated keras functions (#639)
* Add keras_predict_classes to replace use of keras::predict_classes * use keras_predict_classes * update mlp_keras reference calculations in tests * move from predict_proba() to predict() * Create conditional tensorflow checking * use newer version of tensorflow in GHA * Add old tensorflow version GHA * Set seeds for tensorflow * only set tensorflow seed when you need it * conditionally set seed in tensorflow by tensorflow version * do conditional check innside keras_predict_* functions as well * Add missing set_seed to keras logistic reg test * seperate out old-tensorflow GHA * add last missing tensorflow set_seed * you need tensorflow AND R seed... * use keras version as switch * Conditionally transform predictions depending on tensorflow version * skip test if tensorflow version can't be found * refactor keras_predict_classes to avoid post function * Add keras_set_seed function * remove remotes * rename keras_set_seed * add news * rename to set_tf_seed in all tests * adjust "R CMD Check" and "old tensorflow" GHA * don't use keras_predict_proba anymore Co-authored-by: Max Kuhn <mxkuhn@gmail.com>
1 parent 8f26aaa commit 68dfd60

17 files changed

+229
-48
lines changed

.github/workflows/R-CMD-check.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,14 +59,14 @@ jobs:
5959

6060
- name: Install Miniconda
6161
# conda can fail at downgrading python, so we specify python version in advance
62-
env:
63-
RETICULATE_MINICONDA_PYTHON_VERSION: "3.7"
62+
env:
63+
RETICULATE_MINICONDA_PYTHON_VERSION: "3.7"
6464
run: reticulate::install_miniconda() # creates r-reticulate conda env by default
6565
shell: Rscript {0}
6666

6767
- name: Install TensorFlow
6868
run: |
69-
tensorflow::install_tensorflow(version='1.15', conda_python_version = NULL)
69+
tensorflow::install_tensorflow(version='2.7', conda_python_version = NULL)
7070
shell: Rscript {0}
7171

7272
- uses: r-lib/actions/check-r-package@v2

.github/workflows/old-tensorflow.yaml

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# Workflow derived from https://github.com/r-lib/actions/tree/master/examples
2+
# Need help debugging build failures? Start at https://github.com/r-lib/actions#where-to-find-help
3+
#
4+
# NOTE: This workflow is overkill for most R packages and
5+
# check-standard.yaml is likely a better choice.
6+
# usethis::use_github_action("check-standard") will install it.
7+
on:
8+
push:
9+
branches: [main, master]
10+
pull_request:
11+
branches: [main, master]
12+
workflow_dispatch:
13+
14+
name: old-tensorflow
15+
16+
jobs:
17+
old-tensorflow:
18+
runs-on: ${{ matrix.config.os }}
19+
20+
name: ${{ matrix.config.os }} (${{ matrix.config.r }})
21+
22+
strategy:
23+
fail-fast: false
24+
matrix:
25+
config:
26+
- {os: windows-latest, r: 'release'}
27+
# Use older ubuntu to maximise backward compatibility
28+
- {os: ubuntu-18.04, r: 'devel', http-user-agent: 'release'}
29+
env:
30+
GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }}
31+
R_KEEP_PKG_SOURCE: yes
32+
CXX14: g++
33+
CXX14STD: -std=c++1y
34+
CXX14FLAGS: -Wall -g -02
35+
36+
steps:
37+
- uses: actions/checkout@v2
38+
39+
- uses: r-lib/actions/setup-pandoc@v2
40+
41+
- uses: r-lib/actions/setup-r@v2
42+
with:
43+
r-version: ${{ matrix.config.r }}
44+
http-user-agent: ${{ matrix.config.http-user-agent }}
45+
use-public-rspm: true
46+
47+
- uses: r-lib/actions/setup-r-dependencies@v2
48+
with:
49+
extra-packages: rcmdcheck
50+
51+
- name: Install dev reticulate
52+
run: pak::pkg_install('rstudio/reticulate')
53+
shell: Rscript {0}
54+
55+
- name: Install Miniconda
56+
# conda can fail at downgrading python, so we specify python version in advance
57+
env:
58+
RETICULATE_MINICONDA_PYTHON_VERSION: "3.7"
59+
run: reticulate::install_miniconda() # creates r-reticulate conda env by default
60+
shell: Rscript {0}
61+
62+
- name: Install TensorFlow
63+
run: |
64+
tensorflow::install_tensorflow(version='1.15', conda_python_version = NULL)
65+
shell: Rscript {0}
66+
67+
- uses: r-lib/actions/check-r-package@v2
68+
69+
- name: Show testthat output
70+
if: always()
71+
run: find check -name 'testthat.Rout*' -exec cat '{}' \; || true
72+
shell: bash
73+
74+
- name: Upload check results
75+
if: failure()
76+
uses: actions/upload-artifact@main
77+
with:
78+
name: ${{ runner.os }}-r${{ matrix.config.r }}-results
79+
path: check

.github/workflows/pkgdown.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ jobs:
3434
pak::pkg_install('rstudio/reticulate')
3535
reticulate::install_miniconda()
3636
reticulate::conda_create('r-reticulate', packages = c('python==3.6.9'))
37-
tensorflow::install_tensorflow(version='1.14.0')
37+
tensorflow::install_tensorflow(version='2.7.0')
3838
shell: Rscript {0}
3939

4040
- name: Install package

.github/workflows/test-coverage.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ jobs:
3131
pak::pkg_install('rstudio/reticulate')
3232
reticulate::install_miniconda()
3333
reticulate::conda_create('r-reticulate', packages = c('python==3.6.9'))
34-
tensorflow::install_tensorflow(version='1.14.0')
34+
tensorflow::install_tensorflow(version='2.7.0')
3535
shell: Rscript {0}
3636

3737
- name: Test coverage

DESCRIPTION

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ Suggests:
4040
covr,
4141
dials (>= 0.0.10.9001),
4242
earth,
43+
tensorflow,
4344
ggplot2,
4445
keras,
4546
kernlab,
@@ -86,6 +87,3 @@ Encoding: UTF-8
8687
LazyData: true
8788
Roxygen: list(markdown = TRUE)
8889
RoxygenNote: 7.1.2
89-
Remotes:
90-
tidymodels/dials,
91-
tidymodels/hardhat

NAMESPACE

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,7 @@ export(glance)
206206
export(has_multi_predict)
207207
export(is_varying)
208208
export(keras_mlp)
209+
export(keras_predict_classes)
209210
export(knit_engine_docs)
210211
export(linear_reg)
211212
export(list_md_problems)
@@ -271,6 +272,7 @@ export(set_model_engine)
271272
export(set_model_mode)
272273
export(set_new_model)
273274
export(set_pred)
275+
export(set_tf_seed)
274276
export(show_call)
275277
export(show_engines)
276278
export(show_fit)

NEWS.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,13 @@
3535
* Argument `interval` was added for prediction: For types "survival" and "quantile", estimates for the confidence or prediction interval can be added if available (#615).
3636

3737
* `set_dependency()` now allows developers to create package requirements that are specific to the model's mode (#604).
38-
38+
*
3939
* `varying()` is soft-deprecated in favor of `tune()`.
4040

4141
* `varying_args()` is soft-deprecated in favor of `tune_args()`.
4242

43+
* parsnip is now more robust working with keras and tensorflow for a larger range of versions (#596).
44+
4345
# parsnip 0.1.7
4446

4547
## Model Specification Changes

R/logistic_reg_data.R

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -438,13 +438,11 @@ set_pred(
438438
type = "class",
439439
value = list(
440440
pre = NULL,
441-
post = function(x, object) {
442-
object$lvl[x + 1]
443-
},
444-
func = c(pkg = "keras", fun = "predict_classes"),
441+
post = NULL,
442+
func = c(pkg = "parsnip", fun = "keras_predict_classes"),
445443
args =
446444
list(
447-
object = quote(object$fit),
445+
object = quote(object),
448446
x = quote(as.matrix(new_data))
449447
)
450448
)
@@ -462,7 +460,7 @@ set_pred(
462460
x <- as_tibble(x)
463461
x
464462
},
465-
func = c(pkg = "keras", fun = "predict_proba"),
463+
func = c(fun = "predict"),
466464
args =
467465
list(
468466
object = quote(object$fit),

R/mlp.R

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -436,5 +436,40 @@ reformat_torch_num <- function(results, object) {
436436
results
437437
}
438438

439+
#' Wrapper for keras class predictions
440+
#' @param object A keras model fit
441+
#' @param x A data set.
442+
#' @export
443+
#' @keywords internal
444+
keras_predict_classes <- function(object, x) {
445+
if (utils::packageVersion("keras") >= package_version("2.6")) {
446+
preds <- predict(object$fit, x)
447+
if (tensorflow::tf_version() <= package_version("2.0.0")) {
448+
# -1 to assign with keras' zero indexing
449+
index <- apply(preds, 1, which.max) - 1
450+
} else {
451+
index <- preds %>% keras::k_argmax() %>% as.integer()
452+
}
453+
} else {
454+
index <- keras::predict_classes(object$fit, x)
455+
}
456+
object$lvl[index + 1]
457+
}
439458

440-
459+
#' Set seed in R and TensorFlow at the same time
460+
#'
461+
#' Some Keras models requires seeds to be set in both R and TensorFlow to
462+
#' achieve reproducible results. This function sets these seeds at the same
463+
#' time using version appropriate functions.
464+
#'
465+
#' @param seed 1 integer value.
466+
#' @export
467+
#' @keywords internal
468+
set_tf_seed <- function(seed) {
469+
set.seed(seed)
470+
if (tensorflow::tf_version() >= package_version("2.0")) {
471+
tensorflow::tf$random$set_seed(seed)
472+
} else {
473+
tensorflow::tf$random$set_random_seed(seed)
474+
}
475+
}

R/mlp_data.R

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -146,13 +146,11 @@ set_pred(
146146
type = "class",
147147
value = list(
148148
pre = NULL,
149-
post = function(x, object) {
150-
object$lvl[x + 1]
151-
},
152-
func = c(pkg = "keras", fun = "predict_classes"),
149+
post = NULL,
150+
func = c(pkg = "parsnip", fun = "keras_predict_classes"),
153151
args =
154152
list(
155-
object = quote(object$fit),
153+
object = quote(object),
156154
x = quote(as.matrix(new_data))
157155
)
158156
)
@@ -170,7 +168,7 @@ set_pred(
170168
x <- as_tibble(x)
171169
x
172170
},
173-
func = c(pkg = "keras", fun = "predict_proba"),
171+
func = c(fun = "predict"),
174172
args =
175173
list(
176174
object = quote(object$fit),

R/multinom_reg_data.R

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -234,12 +234,10 @@ set_pred(
234234
type = "class",
235235
value = list(
236236
pre = NULL,
237-
post = function(x, object) {
238-
object$lvl[x + 1]
239-
},
240-
func = c(pkg = "keras", fun = "predict_classes"),
237+
post = NULL,
238+
func = c(pkg = "parsnip", fun = "keras_predict_classes"),
241239
args =
242-
list(object = quote(object$fit),
240+
list(object = quote(object),
243241
x = quote(as.matrix(new_data)))
244242
)
245243
)
@@ -256,7 +254,7 @@ set_pred(
256254
x <- as_tibble(x)
257255
x
258256
},
259-
func = c(pkg = "keras", fun = "predict_proba"),
257+
func = c(fun = "predict"),
260258
args =
261259
list(object = quote(object$fit),
262260
x = quote(as.matrix(new_data)))

man/keras_predict_classes.Rd

Lines changed: 17 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/set_tf_seed.Rd

Lines changed: 17 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test_linear_reg_keras.R

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,10 @@ ctrl <- control_parsnip(verbosity = 0, catch = FALSE)
2727
test_that('model fitting', {
2828
skip_on_cran()
2929
skip_if_not_installed("keras")
30+
skip_if(is.null(tensorflow::tf_version()))
31+
32+
set_tf_seed(257)
3033

31-
set.seed(257)
3234
expect_error(
3335
fit1 <-
3436
fit_xy(
@@ -40,7 +42,8 @@ test_that('model fitting', {
4042
regexp = NA
4143
)
4244

43-
set.seed(257)
45+
set_tf_seed(257)
46+
4447
expect_error(
4548
fit2 <-
4649
fit_xy(
@@ -94,6 +97,7 @@ test_that('model fitting', {
9497
test_that('regression prediction', {
9598
skip_on_cran()
9699
skip_if_not_installed("keras")
100+
skip_if(is.null(tensorflow::tf_version()))
97101

98102
library(keras)
99103

0 commit comments

Comments
 (0)