
Commit 4abca45

Fixing DistilBertForTokenClassification annotator
1 parent b44f8d7

5 files changed (+48 / -104 lines)

NAMESPACE

Lines changed: 1 addition & 4 deletions
@@ -56,9 +56,6 @@ S3method(nlp_date_normalizer,tbl_spark)
 S3method(nlp_dependency_parser,ml_pipeline)
 S3method(nlp_dependency_parser,spark_connection)
 S3method(nlp_dependency_parser,tbl_spark)
-S3method(nlp_distilbert_for_token_classification,ml_pipeline)
-S3method(nlp_distilbert_for_token_classification,spark_connection)
-S3method(nlp_distilbert_for_token_classification,tbl_spark)
 S3method(nlp_doc2chunk,ml_pipeline)
 S3method(nlp_doc2chunk,spark_connection)
 S3method(nlp_doc2chunk,tbl_spark)
@@ -241,7 +238,7 @@ export(nlp_date_normalizer)
 export(nlp_dependency_parser)
 export(nlp_dependency_parser_pretrained)
 export(nlp_distilbert_embeddings_pretrained)
-export(nlp_distilbert_for_token_classification)
+export(nlp_distilbert_token_classification_pretrained)
 export(nlp_doc2chunk)
 export(nlp_document_assembler)
 export(nlp_document_logreg_classifier)
Lines changed: 20 additions & 64 deletions
@@ -1,6 +1,7 @@
 #' Spark NLP DistilBertForTokenClassification
 #'
-#' Spark ML transformer that
+#' DistilBertForTokenClassification can load DistilBERT models with a token classification head on top
+#' (a linear layer on top of the hidden-states output), e.g. for Named-Entity-Recognition (NER) tasks.
 #' See \url{https://nlp.johnsnowlabs.com/docs/en/transformers#distilbertfortokenclassification}
 #'
 #' @template roxlate-nlp-algo
@@ -10,75 +11,33 @@
 #' @param max_sentence_length Max sentence length to process (Default: 128)
 #'
 #' @export
-nlp_distilbert_for_token_classification <- function(x, input_cols, output_col,
-                                                    batch_size = NULL, case_sensitive = NULL, max_sentence_length = NULL,
-                                                    uid = random_string("distilbert_for_token_classification_")) {
-  UseMethod("nlp_distilbert_for_token_classification")
-}
-
-#' @export
-nlp_distilbert_for_token_classification.spark_connection <- function(x, input_cols, output_col,
-                                                    batch_size = NULL, case_sensitive = NULL, max_sentence_length = NULL,
-                                                    uid = random_string("distilbert_for_token_classification_")) {
+nlp_distilbert_token_classification_pretrained <- function(sc, input_cols, output_col,
+                                                           batch_size = NULL, case_sensitive = NULL,
+                                                           max_sentence_length = NULL,
+                                                           name = NULL, lang = NULL, remote_loc = NULL) {
   args <- list(
     input_cols = input_cols,
     output_col = output_col,
     batch_size = batch_size,
     case_sensitive = case_sensitive,
-    max_sentence_length = max_sentence_length,
-    uid = uid
+    max_sentence_length = max_sentence_length
   ) %>%
-    validator_nlp_distilbert_for_token_classification()
-
-  jobj <- sparklyr::spark_pipeline_stage(
-    x, "com.johnsnowlabs.nlp.annotators.classifier.dl.DistilBertForTokenClassification",
-    input_cols = args[["input_cols"]],
-    output_col = args[["output_col"]],
-    uid = args[["uid"]]
-  ) %>%
-    sparklyr::jobj_set_param("setBatchSize", args[["batch_size"]]) %>%
-    sparklyr::jobj_set_param("setCaseSensitive", args[["case_sensitive"]]) %>%
-    sparklyr::jobj_set_param("setMaxSentenceLength", args[["max_sentence_length"]])
-
-  new_nlp_distilbert_for_token_classification(jobj)
+    validator_nlp_distilbert_token_classification()
+
+  model_class <- "com.johnsnowlabs.nlp.annotators.classifier.dl.DistilBertForTokenClassification"
+  model <- pretrained_model(sc, model_class, name, lang, remote_loc)
+  spark_jobj(model) %>%
+    sparklyr::jobj_set_param("setInputCols", args[["input_cols"]]) %>%
+    sparklyr::jobj_set_param("setOutputCol", args[["output_col"]]) %>%
+    sparklyr::jobj_set_param("setCaseSensitive", args[["case_sensitive"]]) %>%
+    sparklyr::jobj_set_param("setBatchSize", args[["batch_size"]]) %>%
+    sparklyr::jobj_set_param("setMaxSentenceLength", args[["max_sentence_length"]])
+
+  new_ml_transformer(model)
 }
 
-#' @export
-nlp_distilbert_for_token_classification.ml_pipeline <- function(x, input_cols, output_col,
-                                                                batch_size = NULL, case_sensitive = NULL, max_sentence_length = NULL,
-                                                                uid = random_string("distilbert_for_token_classification_")) {
-
-  stage <- nlp_distilbert_for_token_classification.spark_connection(
-    x = sparklyr::spark_connection(x),
-    input_cols = input_cols,
-    output_col = output_col,
-    batch_size = batch_size,
-    case_sensitive = case_sensitive,
-    max_sentence_length = max_sentence_length,
-    uid = uid
-  )
-
-  sparklyr::ml_add_stage(x, stage)
-}
-
-#' @export
-nlp_distilbert_for_token_classification.tbl_spark <- function(x, input_cols, output_col,
-                                                              batch_size = NULL, case_sensitive = NULL, max_sentence_length = NULL,
-                                                              uid = random_string("distilbert_for_token_classification_")) {
-  stage <- nlp_distilbert_for_token_classification.spark_connection(
-    x = sparklyr::spark_connection(x),
-    input_cols = input_cols,
-    output_col = output_col,
-    batch_size = batch_size,
-    case_sensitive = case_sensitive,
-    max_sentence_length = max_sentence_length,
-    uid = uid
-  )
-
-  stage %>% sparklyr::ml_transform(x)
-}
 #' @import forge
-validator_nlp_distilbert_for_token_classification <- function(args) {
+validator_nlp_distilbert_token_classification <- function(args) {
   args[["input_cols"]] <- cast_string_list(args[["input_cols"]])
   args[["output_col"]] <- cast_string(args[["output_col"]])
   args[["batch_size"]] <- cast_nullable_integer(args[["batch_size"]])
@@ -87,9 +46,6 @@ validator_nlp_distilbert_for_token_classification <- function(args) {
   args
 }
 
-nlp_float_params.nlp_distilbert_for_token_classification <- function(x) {
-  return(c())
-}
 new_nlp_distilbert_for_token_classification <- function(jobj) {
   sparklyr::new_ml_transformer(jobj, class = "nlp_distilbert_for_token_classification")
 }
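
For context, the new entry point is used roughly as follows. This is a minimal sketch, not taken from the commit: it assumes a sparklyr connection `sc` and this package's annotator constructors (`nlp_document_assembler`, `nlp_sentence_detector`, `nlp_tokenizer`, all exported per the NAMESPACE) to produce the `sentence` and `token` inputs the classifier expects; exact constructor signatures may differ.

```r
library(sparklyr)
sc <- spark_connect(master = "local")

# Upstream annotators produce the "sentence" and "token" columns the
# classifier consumes (signatures assumed from this package's conventions).
assembler <- nlp_document_assembler(sc, input_col = "text", output_col = "document")
sentencer <- nlp_sentence_detector(sc, input_cols = c("document"), output_col = "sentence")
tokenizer <- nlp_tokenizer(sc, input_cols = c("sentence"), output_col = "token")

# Download (or load from cache) the default pretrained model and configure it.
token_classifier <- nlp_distilbert_token_classification_pretrained(
  sc,
  input_cols = c("sentence", "token"),
  output_col = "ner",
  case_sensitive = TRUE,
  max_sentence_length = 128
)
```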

R/utils.R

Lines changed: 6 additions & 2 deletions
@@ -7,7 +7,9 @@
 #'
 #' @export
 nlp_set_input_cols <- function(jobj, input_cols) {
-  invoke(spark_jobj(jobj), "setInputCols", cast_string_list(input_cols))
+  newobj <- sparklyr:::ml_set_param(jobj, "inputCols", input_cols)
+  return(newobj)
+  #invoke(spark_jobj(jobj), "setInputCols", cast_string_list(input_cols))
 }
 
 #' Set the output column name
@@ -19,7 +21,9 @@ nlp_set_input_cols <- function(jobj, input_cols) {
 #'
 #' @export
 nlp_set_output_col <- function(jobj, output_col) {
-  invoke(spark_jobj(jobj), "setOutputCol", cast_string(output_col))
+  newobj <- sparklyr:::ml_set_param(jobj, "outputCol", output_col)
+  return(newobj)
+  # invoke(spark_jobj(jobj), "setOutputCol", cast_string(output_col))
 }
 
 #' Spark NLP version
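
The switch from a raw `invoke()` on the jobj to sparklyr's internal `ml_set_param()` also changes the calling convention: the setters now return an updated object rather than mutating the one passed in, so callers must reassign the result. A short sketch of the intended use, with `<model_dir>` as a placeholder for a model directory saved under `~/cache_pretrained/`:

```r
# Reassign the result: the setters return an updated object instead of
# mutating in place. "<model_dir>" is a placeholder, not a real path.
model <- ml_load(sc, "~/cache_pretrained/<model_dir>")
model <- nlp_set_input_cols(model, c("sentence", "token"))
model <- nlp_set_output_col(model, "label")
```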

man/nlp_distilbert_for_token_classification.Rd renamed to man/nlp_distilbert_token_classification_pretrained.Rd

Lines changed: 11 additions & 8 deletions
Some generated files are not rendered by default.

tests/testthat/testthat-distilbert-for-token-classification.R

Lines changed: 10 additions & 26 deletions
@@ -22,34 +22,18 @@ teardown({
   rm(test_data, envir = .GlobalEnv)
 })
 
-test_that("distilbert_for_token_classification param setting", {
-  # TODO: edit these to make them legal values for the parameters
-  test_args <- list(
-    input_cols = c("string1", "string2"),
-    output_col = "string1",
-    batch_size = 100,
-    case_sensitive = FALSE,
-    max_sentence_length = 200
-  )
 
-  test_param_setting(sc, nlp_distilbert_for_token_classification, test_args)
+test_that("nlp_distilbert_token_classification pretrained", {
+  model <- nlp_distilbert_token_classification_pretrained(sc, input_cols = c("sentence", "token"), output_col = "distilbert")
+  transformed_data <- ml_transform(model, test_data)
+  expect_true("distilbert" %in% colnames(transformed_data))
 })
 
-test_that("nlp_distilbert_for_token_classification spark_connection", {
-  test_annotator <- nlp_distilbert_for_token_classification(sc, input_cols = c("token", "document"), output_col = "label")
-  transformed_data <- ml_transform(test_annotator, test_data)
+test_that("nlp_distilbert_token_classification load", {
+  model_files <- list.files("~/cache_pretrained/")
+  bert_model_file <- max(Filter(function(s) startsWith(s, "distilbert_base_token"), model_files))
+  model <- ml_load(sc, paste0("~/cache_pretrained/", bert_model_file))
+  model <- nlp_set_output_col(model, "label")
+  transformed_data <- ml_transform(model, test_data)
   expect_true("label" %in% colnames(transformed_data))
-  expect_true(inherits(test_annotator, "nlp_distilbert_for_token_classification"))
 })
-
-test_that("nlp_distilbert_for_token_classification ml_pipeline", {
-  test_annotator <- nlp_distilbert_for_token_classification(pipeline, input_cols = c("token", "document"), output_col = "label")
-  transformed_data <- ml_fit_and_transform(test_annotator, test_data)
-  expect_true("label" %in% colnames(transformed_data))
-})
-
-test_that("nlp_distilbert_for_token_classification tbl_spark", {
-  transformed_data <- nlp_distilbert_for_token_classification(test_data, input_cols = c("token", "document"), output_col = "label")
-  expect_true("label" %in% colnames(transformed_data))
-})
-
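
After `ml_transform()`, the `label` column produced in the load test holds arrays of Spark NLP annotation structs. A quick way to eyeball the predictions, as a sketch assuming dplyr is attached and the column follows Spark NLP's usual annotation schema (`annotatorType`, `begin`, `end`, `result`, `metadata`):

```r
library(dplyr)

# Peek at a few rows of the annotation column produced by the classifier.
transformed_data %>%
  select(label) %>%
  head(5) %>%
  collect()
```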
