
Commit b44f8d7

Fixing XlmRoBerta and Longformer embeddings annotators
1 parent eebdb4b commit b44f8d7

9 files changed: 87 additions & 323 deletions

NAMESPACE

Lines changed: 3 additions & 12 deletions
@@ -91,9 +91,6 @@ S3method(nlp_lemmatizer,spark_connection)
 S3method(nlp_lemmatizer,tbl_spark)
 S3method(nlp_light_pipeline,ml_pipeline_model)
 S3method(nlp_light_pipeline,nlp_pretrained_pipeline)
-S3method(nlp_longformer_embeddings,ml_pipeline)
-S3method(nlp_longformer_embeddings,spark_connection)
-S3method(nlp_longformer_embeddings,tbl_spark)
 S3method(nlp_marian_transformer,ml_pipeline)
 S3method(nlp_marian_transformer,spark_connection)
 S3method(nlp_marian_transformer,tbl_spark)
@@ -152,9 +149,6 @@ S3method(nlp_relation_extraction,tbl_spark)
 S3method(nlp_relation_extraction_dl,ml_pipeline)
 S3method(nlp_relation_extraction_dl,spark_connection)
 S3method(nlp_relation_extraction_dl,tbl_spark)
-S3method(nlp_roberta_embeddings,ml_pipeline)
-S3method(nlp_roberta_embeddings,spark_connection)
-S3method(nlp_roberta_embeddings,tbl_spark)
 S3method(nlp_sentence_detector,ml_pipeline)
 S3method(nlp_sentence_detector,spark_connection)
 S3method(nlp_sentence_detector,tbl_spark)
@@ -206,9 +200,6 @@ S3method(nlp_vivekn_sentiment_detector,tbl_spark)
 S3method(nlp_word_embeddings,ml_pipeline)
 S3method(nlp_word_embeddings,spark_connection)
 S3method(nlp_word_embeddings,tbl_spark)
-S3method(nlp_xlm_roberta_embeddings,ml_pipeline)
-S3method(nlp_xlm_roberta_embeddings,spark_connection)
-S3method(nlp_xlm_roberta_embeddings,tbl_spark)
 S3method(nlp_yake_model,ml_pipeline)
 S3method(nlp_yake_model,spark_connection)
 S3method(nlp_yake_model,tbl_spark)
@@ -266,7 +257,7 @@ export(nlp_language_detector_dl_pretrained)
 export(nlp_lemmatizer)
 export(nlp_lemmatizer_pretrained)
 export(nlp_light_pipeline)
-export(nlp_longformer_embeddings)
+export(nlp_longformer_embeddings_pretrained)
 export(nlp_marian_transformer)
 export(nlp_marian_transformer_pretrained)
 export(nlp_medical_ner)
@@ -298,7 +289,7 @@ export(nlp_relation_extraction)
 export(nlp_relation_extraction_dl)
 export(nlp_relation_extraction_dl_pretrained)
 export(nlp_relation_extraction_pretrained)
-export(nlp_roberta_embeddings)
+export(nlp_roberta_embeddings_pretrained)
 export(nlp_sentence_detector)
 export(nlp_sentence_detector_dl)
 export(nlp_sentence_detector_dl_pretrained)
@@ -331,7 +322,7 @@ export(nlp_vivekn_sentiment_pretrained)
 export(nlp_word_embeddings)
 export(nlp_word_embeddings_model)
 export(nlp_word_embeddings_pretrained)
-export(nlp_xlm_roberta_embeddings)
+export(nlp_xlm_roberta_embeddings_pretrained)
 export(nlp_xlnet_embeddings_pretrained)
 export(nlp_yake_model)
 export(set_nlp_version)

R/longformer-embeddings.R

Lines changed: 20 additions & 72 deletions
@@ -15,85 +15,35 @@
 #' @param storage_ref Unique identifier for storage (Default: this.uid)
 #'
 #' @export
-nlp_longformer_embeddings <- function(x, input_cols, output_col,
-                                      batch_size = NULL, case_sensitive = NULL, dimension = NULL,
-                                      max_sentence_length = NULL, storage_ref = NULL,
-                                      uid = random_string("longformer_embeddings_")) {
-  UseMethod("nlp_longformer_embeddings")
-}
-
-#' @export
-nlp_longformer_embeddings.spark_connection <- function(x, input_cols, output_col,
-                                                       batch_size = NULL, case_sensitive = NULL, dimension = NULL,
-                                                       max_sentence_length = NULL, storage_ref = NULL,
-                                                       uid = random_string("longformer_embeddings_")) {
+nlp_longformer_embeddings_pretrained <- function(sc, input_cols, output_col,
+                                                 batch_size = NULL, case_sensitive = NULL, dimension = NULL,
+                                                 max_sentence_length = NULL, storage_ref = NULL,
+                                                 name = NULL, lang = NULL, remote_loc = NULL) {
   args <- list(
     input_cols = input_cols,
     output_col = output_col,
     batch_size = batch_size,
     case_sensitive = case_sensitive,
     dimension = dimension,
     max_sentence_length = max_sentence_length,
-    storage_ref = storage_ref,
-    uid = uid
+    storage_ref = storage_ref
   ) %>%
-    validator_nlp_longformer_embeddings()
-
-  jobj <- sparklyr::spark_pipeline_stage(
-    x, "com.johnsnowlabs.nlp.embeddings.LongformerEmbeddings",
-    input_cols = args[["input_cols"]],
-    output_col = args[["output_col"]],
-    uid = args[["uid"]]
-  ) %>%
-    sparklyr::jobj_set_param("setBatchSize", args[["batch_size"]]) %>%
-    sparklyr::jobj_set_param("setCaseSensitive", args[["case_sensitive"]]) %>%
-    sparklyr::jobj_set_param("setDimension", args[["dimension"]]) %>%
-    sparklyr::jobj_set_param("setMaxSentenceLength", args[["max_sentence_length"]]) %>%
-    sparklyr::jobj_set_param("setStorageRef", args[["storage_ref"]])
-
-  new_nlp_longformer_embeddings(jobj)
-}
-
-#' @export
-nlp_longformer_embeddings.ml_pipeline <- function(x, input_cols, output_col,
-                                                  batch_size = NULL, case_sensitive = NULL, dimension = NULL,
-                                                  max_sentence_length = NULL, storage_ref = NULL,
-                                                  uid = random_string("longformer_embeddings_")) {
-
-  stage <- nlp_longformer_embeddings.spark_connection(
-    x = sparklyr::spark_connection(x),
-    input_cols = input_cols,
-    output_col = output_col,
-    batch_size = batch_size,
-    case_sensitive = case_sensitive,
-    dimension = dimension,
-    max_sentence_length = max_sentence_length,
-    storage_ref = storage_ref,
-    uid = uid
-  )
-
-  sparklyr::ml_add_stage(x, stage)
+    validator_nlp_longformer_embeddings()
+
+  model_class <- "com.johnsnowlabs.nlp.embeddings.LongformerEmbeddings"
+  model <- pretrained_model(sc, model_class, name, lang, remote_loc)
+  spark_jobj(model) %>%
+    sparklyr::jobj_set_param("setInputCols", args[["input_cols"]]) %>%
+    sparklyr::jobj_set_param("setOutputCol", args[["output_col"]]) %>%
+    sparklyr::jobj_set_param("setCaseSensitive", args[["case_sensitive"]]) %>%
+    sparklyr::jobj_set_param("setBatchSize", args[["batch_size"]]) %>%
+    sparklyr::jobj_set_param("setDimension", args[["dimension"]]) %>%
+    sparklyr::jobj_set_param("setMaxSentenceLength", args[["max_sentence_length"]]) %>%
+    sparklyr::jobj_set_param("setStorageRef", args[["storage_ref"]])
+
+  new_ml_transformer(model)
 }
 
-#' @export
-nlp_longformer_embeddings.tbl_spark <- function(x, input_cols, output_col,
-                                                batch_size = NULL, case_sensitive = NULL, dimension = NULL,
-                                                max_sentence_length = NULL, storage_ref = NULL,
-                                                uid = random_string("longformer_embeddings_")) {
-  stage <- nlp_longformer_embeddings.spark_connection(
-    x = sparklyr::spark_connection(x),
-    input_cols = input_cols,
-    output_col = output_col,
-    batch_size = batch_size,
-    case_sensitive = case_sensitive,
-    dimension = dimension,
-    max_sentence_length = max_sentence_length,
-    storage_ref = storage_ref,
-    uid = uid
-  )
-
-  stage %>% sparklyr::ml_transform(x)
-}
 #' @import forge
 validator_nlp_longformer_embeddings <- function(args) {
   args[["input_cols"]] <- cast_string_list(args[["input_cols"]])
@@ -106,9 +56,7 @@ validator_nlp_longformer_embeddings <- function(args) {
   args
 }
 
-nlp_float_params.nlp_longformer_embeddings <- function(x) {
-  return(c())
-}
 new_nlp_longformer_embeddings <- function(jobj) {
   sparklyr::new_ml_transformer(jobj, class = "nlp_longformer_embeddings")
 }
+
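
With this change the Longformer annotator is no longer constructed as a bare pipeline stage through S3 methods; nlp_longformer_embeddings_pretrained() downloads a pretrained Spark NLP model and returns it as an ordinary ml_transformer. A minimal usage sketch follows; it is not part of the commit, and the connection `sc`, the `annotated_tbl` input table, and the "document"/"token" columns from upstream annotators are assumptions (leaving `name` NULL relies on Spark NLP resolving its default pretrained model):

  # sc: an active Spark connection with Spark NLP available (assumed)
  longformer <- nlp_longformer_embeddings_pretrained(
    sc,
    input_cols = c("document", "token"),  # columns produced by upstream annotators (assumed)
    output_col = "embeddings",
    case_sensitive = TRUE
  )

  # The result is a regular transformer, so it can be applied directly
  # or added to an ml_pipeline like any other sparklyr stage.
  embedded_tbl <- sparklyr::ml_transform(longformer, annotated_tbl)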

R/roberta-embeddings.R

Lines changed: 0 additions & 96 deletions
@@ -56,99 +56,3 @@ validator_nlp_roberta_embeddings <- function(args) {
 new_nlp_roberta_embeddings <- function(jobj) {
   sparklyr::new_ml_transformer(jobj, class = "nlp_roberta_embeddings")
 }
-#'
-#'
-#'
-#' nlp_roberta_embeddings <- function(x, input_cols, output_col,
-#' batch_size = NULL, case_sensitive = NULL, dimension = NULL, max_sentence_length = NULL, storage_ref = NULL,
-#' uid = random_string("roberta_embeddings_")) {
-#' UseMethod("nlp_roberta_embeddings")
-#' }
-#'
-#' #' @export
-#' nlp_roberta_embeddings.spark_connection <- function(x, input_cols, output_col,
-#' batch_size = NULL, case_sensitive = NULL, dimension = NULL, max_sentence_length = NULL, storage_ref = NULL,
-#' uid = random_string("roberta_embeddings_")) {
-#' args <- list(
-#' input_cols = input_cols,
-#' output_col = output_col,
-#' batch_size = batch_size,
-#' case_sensitive = case_sensitive,
-#' dimension = dimension,
-#' max_sentence_length = max_sentence_length,
-#' storage_ref = storage_ref,
-#' uid = uid
-#' ) %>%
-#' validator_nlp_roberta_embeddings()
-#'
-#' jobj <- sparklyr::spark_pipeline_stage(
-#' x, "com.johnsnowlabs.nlp.embeddings.RoBertaEmbeddings",
-#' input_cols = args[["input_cols"]],
-#' output_col = args[["output_col"]],
-#' uid = args[["uid"]]
-#' ) %>%
-#' sparklyr::jobj_set_param("setBatchSize", args[["batch_size"]]) %>%
-#' sparklyr::jobj_set_param("setCaseSensitive", args[["case_sensitive"]]) %>%
-#' sparklyr::jobj_set_param("setDimension", args[["dimension"]]) %>%
-#' sparklyr::jobj_set_param("setMaxSentenceLength", args[["max_sentence_length"]]) %>%
-#' sparklyr::jobj_set_param("setStorageRef", args[["storage_ref"]])
-#'
-#' new_nlp_roberta_embeddings(jobj)
-#' }
-#'
-#' #' @export
-#' nlp_roberta_embeddings.ml_pipeline <- function(x, input_cols, output_col,
-#' batch_size = NULL, case_sensitive = NULL, dimension = NULL, max_sentence_length = NULL, storage_ref = NULL,
-#' uid = random_string("roberta_embeddings_")) {
-#'
-#' stage <- nlp_roberta_embeddings.spark_connection(
-#' x = sparklyr::spark_connection(x),
-#' input_cols = input_cols,
-#' output_col = output_col,
-#' batch_size = batch_size,
-#' case_sensitive = case_sensitive,
-#' dimension = dimension,
-#' max_sentence_length = max_sentence_length,
-#' storage_ref = storage_ref,
-#' uid = uid
-#' )
-#'
-#' sparklyr::ml_add_stage(x, stage)
-#' }
-#'
-#' #' @export
-#' nlp_roberta_embeddings.tbl_spark <- function(x, input_cols, output_col,
-#' batch_size = NULL, case_sensitive = NULL, dimension = NULL, max_sentence_length = NULL, storage_ref = NULL,
-#' uid = random_string("roberta_embeddings_")) {
-#' stage <- nlp_roberta_embeddings.spark_connection(
-#' x = sparklyr::spark_connection(x),
-#' input_cols = input_cols,
-#' output_col = output_col,
-#' batch_size = batch_size,
-#' case_sensitive = case_sensitive,
-#' dimension = dimension,
-#' max_sentence_length = max_sentence_length,
-#' storage_ref = storage_ref,
-#' uid = uid
-#' )
-#'
-#' stage %>% sparklyr::ml_transform(x)
-#' }
-#' #' @import forge
-#' validator_nlp_roberta_embeddings <- function(args) {
-#' args[["input_cols"]] <- cast_string_list(args[["input_cols"]])
-#' args[["output_col"]] <- cast_string(args[["output_col"]])
-#' args[["batch_size"]] <- cast_nullable_integer(args[["batch_size"]])
-#' args[["case_sensitive"]] <- cast_nullable_logical(args[["case_sensitive"]])
-#' args[["dimension"]] <- cast_nullable_integer(args[["dimension"]])
-#' args[["max_sentence_length"]] <- cast_nullable_integer(args[["max_sentence_length"]])
-#' args[["storage_ref"]] <- cast_nullable_string(args[["storage_ref"]])
-#' args
-#' }
-#'
-#' nlp_float_params.nlp_roberta_embeddings <- function(x) {
-#' return(c())
-#' }
-#' new_nlp_roberta_embeddings <- function(jobj) {
-#' sparklyr::new_ml_transformer(jobj, class = "nlp_roberta_embeddings")
-#' }
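
The RoBERTa file itself only drops the dead commented-out S3 constructors; per the NAMESPACE diff above, loading this annotator now goes through the exported nlp_roberta_embeddings_pretrained(). A sketch under the assumption that its signature mirrors the Longformer and XLM-RoBERTa constructors in this commit (the connection `sc` and the column names are likewise assumed):

  roberta <- nlp_roberta_embeddings_pretrained(
    sc,                                   # assumed Spark connection
    input_cols = c("document", "token"),  # assumed upstream columns
    output_col = "roberta_embeddings"
  )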

R/xlm-roberta-embeddings.R

Lines changed: 20 additions & 68 deletions
@@ -12,81 +12,35 @@
 #' @param storage_ref Unique identifier for storage (Default: this.uid)
 #'
 #' @export
-nlp_xlm_roberta_embeddings <- function(x, input_cols, output_col,
-                                       batch_size = NULL, case_sensitive = NULL, dimension = NULL, max_sentence_length = NULL, storage_ref = NULL,
-                                       uid = random_string("xlm_roberta_embeddings_")) {
-  UseMethod("nlp_xlm_roberta_embeddings")
-}
-
-#' @export
-nlp_xlm_roberta_embeddings.spark_connection <- function(x, input_cols, output_col,
-                                                        batch_size = NULL, case_sensitive = NULL, dimension = NULL, max_sentence_length = NULL, storage_ref = NULL,
-                                                        uid = random_string("xlm_roberta_embeddings_")) {
+nlp_xlm_roberta_embeddings_pretrained <- function(sc, input_cols, output_col,
+                                                  batch_size = NULL, case_sensitive = NULL, dimension = NULL,
+                                                  max_sentence_length = NULL, storage_ref = NULL,
+                                                  name = NULL, lang = NULL, remote_loc = NULL) {
   args <- list(
     input_cols = input_cols,
     output_col = output_col,
     batch_size = batch_size,
     case_sensitive = case_sensitive,
     dimension = dimension,
     max_sentence_length = max_sentence_length,
-    storage_ref = storage_ref,
-    uid = uid
+    storage_ref = storage_ref
   ) %>%
-    validator_nlp_xlm_roberta_embeddings()
-
-  jobj <- sparklyr::spark_pipeline_stage(
-    x, "com.johnsnowlabs.nlp.embeddings.XlmRoBertaEmbeddings",
-    input_cols = args[["input_cols"]],
-    output_col = args[["output_col"]],
-    uid = args[["uid"]]
-  ) %>%
-    sparklyr::jobj_set_param("setBatchSize", args[["batch_size"]]) %>%
-    sparklyr::jobj_set_param("setCaseSensitive", args[["case_sensitive"]]) %>%
-    sparklyr::jobj_set_param("setDimension", args[["dimension"]]) %>%
-    sparklyr::jobj_set_param("setMaxSentenceLength", args[["max_sentence_length"]]) %>%
-    sparklyr::jobj_set_param("setStorageRef", args[["storage_ref"]])
-
-  new_nlp_xlm_roberta_embeddings(jobj)
-}
-
-#' @export
-nlp_xlm_roberta_embeddings.ml_pipeline <- function(x, input_cols, output_col,
-                                                   batch_size = NULL, case_sensitive = NULL, dimension = NULL, max_sentence_length = NULL, storage_ref = NULL,
-                                                   uid = random_string("xlm_roberta_embeddings_")) {
-
-  stage <- nlp_xlm_roberta_embeddings.spark_connection(
-    x = sparklyr::spark_connection(x),
-    input_cols = input_cols,
-    output_col = output_col,
-    batch_size = batch_size,
-    case_sensitive = case_sensitive,
-    dimension = dimension,
-    max_sentence_length = max_sentence_length,
-    storage_ref = storage_ref,
-    uid = uid
-  )
-
-  sparklyr::ml_add_stage(x, stage)
+    validator_nlp_xlm_roberta_embeddings()
+
+  model_class <- "com.johnsnowlabs.nlp.embeddings.XlmRoBertaEmbeddings"
+  model <- pretrained_model(sc, model_class, name, lang, remote_loc)
+  spark_jobj(model) %>%
+    sparklyr::jobj_set_param("setInputCols", args[["input_cols"]]) %>%
+    sparklyr::jobj_set_param("setOutputCol", args[["output_col"]]) %>%
+    sparklyr::jobj_set_param("setCaseSensitive", args[["case_sensitive"]]) %>%
+    sparklyr::jobj_set_param("setBatchSize", args[["batch_size"]]) %>%
+    sparklyr::jobj_set_param("setDimension", args[["dimension"]]) %>%
+    sparklyr::jobj_set_param("setMaxSentenceLength", args[["max_sentence_length"]]) %>%
+    sparklyr::jobj_set_param("setStorageRef", args[["storage_ref"]])
+
+  new_ml_transformer(model)
 }
 
-#' @export
-nlp_xlm_roberta_embeddings.tbl_spark <- function(x, input_cols, output_col,
-                                                 batch_size = NULL, case_sensitive = NULL, dimension = NULL, max_sentence_length = NULL, storage_ref = NULL,
-                                                 uid = random_string("xlm_roberta_embeddings_")) {
-  stage <- nlp_xlm_roberta_embeddings.spark_connection(
-    x = sparklyr::spark_connection(x),
-    input_cols = input_cols,
-    output_col = output_col,
-    batch_size = batch_size,
-    case_sensitive = case_sensitive,
-    dimension = dimension,
-    max_sentence_length = max_sentence_length,
-    storage_ref = storage_ref,
-    uid = uid
-  )
-
-  stage %>% sparklyr::ml_transform(x)
-}
 #' @import forge
 validator_nlp_xlm_roberta_embeddings <- function(args) {
   args[["input_cols"]] <- cast_string_list(args[["input_cols"]])
@@ -99,9 +53,7 @@ validator_nlp_xlm_roberta_embeddings <- function(args) {
   args
 }
 
-nlp_float_params.nlp_xlm_roberta_embeddings <- function(x) {
-  return(c())
-}
 new_nlp_xlm_roberta_embeddings <- function(jobj) {
   sparklyr::new_ml_transformer(jobj, class = "nlp_xlm_roberta_embeddings")
 }
+
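
The XLM-RoBERTa constructor follows the same pattern: the pretrained model is fetched, its input/output columns and options are set on the underlying Java object, and a plain ml_transformer is returned, so it composes with existing sparklyr pipelines. A sketch, not taken from the commit; the connection `sc`, the pre-built `pipeline` object, and the column names are assumptions:

  xlm_roberta <- nlp_xlm_roberta_embeddings_pretrained(
    sc,
    input_cols = c("document", "token"),
    output_col = "xlm_embeddings",
    max_sentence_length = 128L
  )

  # Append the pretrained transformer to an existing ml_pipeline (assumed to exist).
  pipeline <- sparklyr::ml_add_stage(pipeline, xlm_roberta)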
