Skip to content

Commit 2ffbd2a

Browse files
committed
Fixing defects with setters
1 parent 4d24680 commit 2ffbd2a

File tree

4 files changed

+21
-2
lines changed

4 files changed

+21
-2
lines changed

R/utils.R

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
#'
88
#' @export
99
nlp_set_input_cols <- function(jobj, input_cols) {
10-
newobj <- sparklyr:::ml_set_param(jobj, "inputCols", input_col)
10+
input_cols <- forge::cast_string_list(input_cols)
11+
newobj <- sparklyr:::ml_set_param(jobj, "inputCols", input_cols)
1112
return(newobj)
1213
#invoke(spark_jobj(jobj), "setInputCols", cast_string_list(input_cols))
1314
}
@@ -21,6 +22,7 @@ nlp_set_input_cols <- function(jobj, input_cols) {
2122
#'
2223
#' @export
2324
nlp_set_output_col <- function(jobj, output_col) {
25+
output_col <- forge::cast_string(output_col)
2426
newobj <- sparklyr:::ml_set_param(jobj, "outputCol", output_col)
2527
return(newobj)
2628
# invoke(spark_jobj(jobj), "setOutputCol", cast_string(output_col))

tests/testthat/helper-initialize.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ test_classifier_text <- data.frame(description = c("The cats are laying in front
1515
# helper functions from sparklyr tests
1616
# https://github.com/rstudio/sparklyr/blob/master/tests/testthat/helper-initialize.R
1717
testthat_spark_connection <- function() {
18-
version <- Sys.getenv("SPARK_VERSION", unset = "3.1.1")
18+
version <- Sys.getenv("SPARK_VERSION", unset = "3.1.2")
1919

2020
spark_installed <- sparklyr::spark_installed_versions()
2121
if (nrow(spark_installed[spark_installed$spark == version, ]) == 0) {

tests/testthat/testthat-medical-ner.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ test_that("medical_ner param setting", {
101101
# })
102102

103103
test_that("nlp_medical_ner pretrained", {
104+
print(nlp_version())
104105
model <- nlp_medical_ner_pretrained(sc, input_cols = c("sentence", "token", "embeddings"),
105106
output_col = "ner",
106107
name = "ner_clinical", lang = "en", remote_loc = "clinical/models")

tests/testthat/testthat-utils.R

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,22 @@ test_that("nlp_set_param", {
3636
expect_equal(newvalue, 0.8)
3737
})
3838

39+
test_that("nlp_set_input_cols", {
40+
model <- nlp_sentence_detector_dl_pretrained(sc, input_cols = c("document"), output_col = "sentence")
41+
42+
model <- nlp_set_input_cols(model, c("new_document"))
43+
newvalue <- ml_param(model, "input_cols")
44+
expect_equal(newvalue, list("new_document"))
45+
})
46+
47+
test_that("nlp_set_output_cols", {
48+
model <- nlp_sentence_detector_dl_pretrained(sc, input_cols = c("document"), output_col = "sentence")
49+
50+
model <- nlp_set_output_col(model, list("new_sentence"))
51+
newvalue <- ml_param(model, "output_col")
52+
expect_equal(newvalue, "new_sentence")
53+
})
54+
3955
test_that("nlp_conll_read_dataset", {
4056
conll_data <- nlp_conll_read_dataset(sc, here::here("tests", "testthat", "data", "eng.testa.conll"))
4157
expect_true("text" %in% colnames(conll_data))

0 commit comments

Comments
 (0)