1
1
#' Spark NLP DistilBertForTokenClassification
#'
#' DistilBertForTokenClassification can load Bert Models with a token classification head on top
#' (a linear layer on top of the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks.
#'
#' See \url{https://nlp.johnsnowlabs.com/docs/en/transformers#distilbertfortokenclassification}
#'
#' @template roxlate-nlp-algo
#' @param max_sentence_length Max sentence length to process (Default: 128)
#'
#' @export
# S3 generic: dispatches on the class of `x` (spark_connection, ml_pipeline,
# or tbl_spark) to the matching method defined below.
nlp_distilbert_for_token_classification <- function(x, input_cols, output_col,
                                                    batch_size = NULL, case_sensitive = NULL,
                                                    max_sentence_length = NULL,
                                                    uid = random_string("distilbert_for_token_classification_")) {
  UseMethod("nlp_distilbert_for_token_classification")
}
#' @export
nlp_distilbert_for_token_classification.spark_connection <- function(x, input_cols, output_col,
                                                                     batch_size = NULL, case_sensitive = NULL,
                                                                     max_sentence_length = NULL,
                                                                     uid = random_string("distilbert_for_token_classification_")) {
  # Validate and coerce the user-supplied arguments before touching the JVM.
  args <- list(
    input_cols = input_cols,
    output_col = output_col,
    batch_size = batch_size,
    case_sensitive = case_sensitive,
    max_sentence_length = max_sentence_length,
    uid = uid
  ) %>%
    validator_nlp_distilbert_token_classification()

  # Build the pipeline stage, then apply the optional setters; jobj_set_param
  # skips parameters whose value is NULL.
  jobj <- sparklyr::spark_pipeline_stage(
    x, "com.johnsnowlabs.nlp.annotators.classifier.dl.DistilBertForTokenClassification",
    input_cols = args[["input_cols"]],
    output_col = args[["output_col"]],
    uid = args[["uid"]]
  ) %>%
    sparklyr::jobj_set_param("setBatchSize", args[["batch_size"]]) %>%
    sparklyr::jobj_set_param("setCaseSensitive", args[["case_sensitive"]]) %>%
    sparklyr::jobj_set_param("setMaxSentenceLength", args[["max_sentence_length"]])

  new_nlp_distilbert_for_token_classification(jobj)
}

#' Load a pretrained Spark NLP DistilBertForTokenClassification model
#'
#' Loads a pretrained DistilBertForTokenClassification annotator and configures
#' it as a Spark ML transformer.
#'
#' @param sc A Spark connection.
#' @param input_cols Input columns. String array.
#' @param output_col Output column. String.
#' @param batch_size Batch size (optional; repository default used when NULL).
#' @param case_sensitive Whether to treat tokens case-sensitively (optional).
#' @param max_sentence_length Max sentence length to process (optional).
#' @param name Model name; forwarded to \code{pretrained_model()} (optional).
#' @param lang Model language; forwarded to \code{pretrained_model()} (optional).
#' @param remote_loc Remote location; forwarded to \code{pretrained_model()} (optional).
#' @return A pretrained DistilBertForTokenClassification transformer.
#' @export
nlp_distilbert_token_classification_pretrained <- function(sc, input_cols, output_col,
                                                           batch_size = NULL, case_sensitive = NULL,
                                                           max_sentence_length = NULL,
                                                           name = NULL, lang = NULL, remote_loc = NULL) {
  args <- list(
    input_cols = input_cols,
    output_col = output_col,
    batch_size = batch_size,
    case_sensitive = case_sensitive,
    max_sentence_length = max_sentence_length
  ) %>%
    validator_nlp_distilbert_token_classification()

  model_class <- "com.johnsnowlabs.nlp.annotators.classifier.dl.DistilBertForTokenClassification"
  model <- pretrained_model(sc, model_class, name, lang, remote_loc)
  # jobj_set_param mutates the underlying Java object in place (NULLs are
  # skipped), so the pipe result itself is intentionally discarded.
  spark_jobj(model) %>%
    sparklyr::jobj_set_param("setInputCols", args[["input_cols"]]) %>%
    sparklyr::jobj_set_param("setOutputCol", args[["output_col"]]) %>%
    sparklyr::jobj_set_param("setCaseSensitive", args[["case_sensitive"]]) %>%
    sparklyr::jobj_set_param("setBatchSize", args[["batch_size"]]) %>%
    sparklyr::jobj_set_param("setMaxSentenceLength", args[["max_sentence_length"]])

  new_ml_transformer(model)
}
#' @export
nlp_distilbert_for_token_classification.ml_pipeline <- function(x, input_cols, output_col,
                                                                batch_size = NULL, case_sensitive = NULL,
                                                                max_sentence_length = NULL,
                                                                uid = random_string("distilbert_for_token_classification_")) {
  # Build the stage against the pipeline's connection, then append it.
  stage <- nlp_distilbert_for_token_classification.spark_connection(
    x = sparklyr::spark_connection(x),
    input_cols = input_cols,
    output_col = output_col,
    batch_size = batch_size,
    case_sensitive = case_sensitive,
    max_sentence_length = max_sentence_length,
    uid = uid
  )

  sparklyr::ml_add_stage(x, stage)
}
#' @export
nlp_distilbert_for_token_classification.tbl_spark <- function(x, input_cols, output_col,
                                                              batch_size = NULL, case_sensitive = NULL,
                                                              max_sentence_length = NULL,
                                                              uid = random_string("distilbert_for_token_classification_")) {
  # Build the stage against the table's connection, then transform the table
  # immediately instead of returning the bare stage.
  stage <- nlp_distilbert_for_token_classification.spark_connection(
    x = sparklyr::spark_connection(x),
    input_cols = input_cols,
    output_col = output_col,
    batch_size = batch_size,
    case_sensitive = case_sensitive,
    max_sentence_length = max_sentence_length,
    uid = uid
  )

  stage %>% sparklyr::ml_transform(x)
}
#' @import forge
validator_nlp_distilbert_token_classification <- function(args) {
  # Coerce/validate user-supplied arguments; the forge cast_* helpers error on
  # type mismatches and let NULL pass through for the nullable parameters.
  args[["input_cols"]] <- cast_string_list(args[["input_cols"]])
  args[["output_col"]] <- cast_string(args[["output_col"]])
  args[["batch_size"]] <- cast_nullable_integer(args[["batch_size"]])
  # NOTE(review): the next two casts were elided by the diff hunk header
  # (@@ -87,9 +46,6 @@); reconstructed from this package's usual cast pattern —
  # confirm against the repository before relying on them.
  args[["case_sensitive"]] <- cast_nullable_logical(args[["case_sensitive"]])
  args[["max_sentence_length"]] <- cast_nullable_integer(args[["max_sentence_length"]])
  args
}
# This annotator exposes no float-valued parameters, so there is nothing for
# the float-specific setter machinery to handle; return an empty vector.
nlp_float_params.nlp_distilbert_for_token_classification <- function(x) {
  c()
}
# Wrap the Spark jobj in an ml_transformer, subclassed so S3 methods
# (e.g. nlp_float_params) can dispatch on this annotator type.
new_nlp_distilbert_for_token_classification <- function(jobj) {
  sparklyr::new_ml_transformer(jobj, class = "nlp_distilbert_for_token_classification")
}
0 commit comments