Skip to content

Commit 818de8d

Browse files
authored
Sparknlp 967 add onnx support to xlm roberta classifiers (#14130)
* fixing typo + adding support for ONNX to XLM-Roberta * adding conversion notebooks
1 parent 8585b7e commit 818de8d

14 files changed

+7116
-219
lines changed

examples/python/transformers/onnx/HuggingFace_ONNX_in_Spark_NLP_XlmRoBertaForQuestionAnswering.ipynb

+2,433
Large diffs are not rendered by default.

examples/python/transformers/onnx/HuggingFace_ONNX_in_Spark_NLP_XlmRoBertaForSequenceClassification.ipynb

+2,173
Large diffs are not rendered by default.

examples/python/transformers/onnx/HuggingFace_ONNX_in_Spark_NLP_XlmRoBertaForTokenClassification.ipynb

+2,144
Large diffs are not rendered by default.

src/main/scala/com/johnsnowlabs/ml/ai/AlbertClassification.scala

+3-3
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ private[johnsnowlabs] class AlbertClassification(
108108
val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max
109109

110110
val rawScores = detectedEngine match {
111-
case ONNX.name => getRowScoresWithOnnx(batch, maxSentenceLength, sequence = true)
111+
case ONNX.name => getRawScoresWithOnnx(batch, maxSentenceLength, sequence = true)
112112
case _ => getRawScoresWithTF(batch, maxSentenceLength)
113113
}
114114

@@ -128,7 +128,7 @@ private[johnsnowlabs] class AlbertClassification(
128128
val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max
129129

130130
val rawScores = detectedEngine match {
131-
case ONNX.name => getRowScoresWithOnnx(batch, maxSentenceLength, sequence = true)
131+
case ONNX.name => getRawScoresWithOnnx(batch, maxSentenceLength, sequence = true)
132132
case _ => getRawScoresWithTF(batch, maxSentenceLength)
133133
}
134134

@@ -203,7 +203,7 @@ private[johnsnowlabs] class AlbertClassification(
203203
rawScores
204204
}
205205

206-
private def getRowScoresWithOnnx(
206+
private def getRawScoresWithOnnx(
207207
batch: Seq[Array[Int]],
208208
maxSentenceLength: Int,
209209
sequence: Boolean): Array[Float] = {

src/main/scala/com/johnsnowlabs/ml/ai/BertClassification.scala

+3-3
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ private[johnsnowlabs] class BertClassification(
149149

150150
val rawScores = detectedEngine match {
151151
case ONNX.name =>
152-
getRowScoresWithOnnx(batch, maxSentenceLength)
152+
getRawScoresWithOnnx(batch, maxSentenceLength)
153153
case _ => getRawScoresWithTF(batch, maxSentenceLength)
154154
}
155155

@@ -218,7 +218,7 @@ private[johnsnowlabs] class BertClassification(
218218
rawScores
219219
}
220220

221-
private def getRowScoresWithOnnx(
221+
private def getRawScoresWithOnnx(
222222
batch: Seq[Array[Int]],
223223
maxSentenceLength: Int): Array[Float] = {
224224

@@ -265,7 +265,7 @@ private[johnsnowlabs] class BertClassification(
265265
val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max
266266
val rawScores = detectedEngine match {
267267
case ONNX.name =>
268-
getRowScoresWithOnnx(batch, maxSentenceLength)
268+
getRawScoresWithOnnx(batch, maxSentenceLength)
269269
case _ => getRawScoresWithTF(batch, maxSentenceLength)
270270
}
271271

src/main/scala/com/johnsnowlabs/ml/ai/CamemBertClassification.scala

+3-3
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ private[johnsnowlabs] class CamemBertClassification(
123123
val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max
124124

125125
val rawScores = detectedEngine match {
126-
case ONNX.name => getRowScoresWithOnnx(batch)
126+
case ONNX.name => getRawScoresWithOnnx(batch)
127127
case _ => getRawScoresWithTF(batch, maxSentenceLength)
128128
}
129129

@@ -189,7 +189,7 @@ private[johnsnowlabs] class CamemBertClassification(
189189
rawScores
190190
}
191191

192-
private def getRowScoresWithOnnx(batch: Seq[Array[Int]]): Array[Float] = {
192+
private def getRawScoresWithOnnx(batch: Seq[Array[Int]]): Array[Float] = {
193193

194194
// [nb of encoded sentences , maxSentenceLength]
195195
val (runner, env) = onnxWrapper.get.getSession(onnxSessionOptions)
@@ -227,7 +227,7 @@ private[johnsnowlabs] class CamemBertClassification(
227227
val batchLength = batch.length
228228

229229
val rawScores = detectedEngine match {
230-
case ONNX.name => getRowScoresWithOnnx(batch)
230+
case ONNX.name => getRawScoresWithOnnx(batch)
231231
case _ => getRawScoresWithTF(batch, maxSentenceLength)
232232
}
233233

src/main/scala/com/johnsnowlabs/ml/ai/DeBertaClassification.scala

+3-3
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ private[johnsnowlabs] class DeBertaClassification(
109109
val batchLength = batch.length
110110

111111
val rawScores = detectedEngine match {
112-
case ONNX.name => getRowScoresWithOnnx(batch)
112+
case ONNX.name => getRawScoresWithOnnx(batch)
113113
case _ => getRawScoresWithTF(batch)
114114
}
115115

@@ -182,7 +182,7 @@ private[johnsnowlabs] class DeBertaClassification(
182182
rawScores
183183
}
184184

185-
private def getRowScoresWithOnnx(batch: Seq[Array[Int]]): Array[Float] = {
185+
private def getRawScoresWithOnnx(batch: Seq[Array[Int]]): Array[Float] = {
186186

187187
// [nb of encoded sentences , maxSentenceLength]
188188
val (runner, env) = onnxWrapper.get.getSession(onnxSessionOptions)
@@ -219,7 +219,7 @@ private[johnsnowlabs] class DeBertaClassification(
219219
val batchLength = batch.length
220220

221221
val rawScores = detectedEngine match {
222-
case ONNX.name => getRowScoresWithOnnx(batch)
222+
case ONNX.name => getRawScoresWithOnnx(batch)
223223
case _ => getRawScoresWithTF(batch)
224224
}
225225

src/main/scala/com/johnsnowlabs/ml/ai/DistilBertClassification.scala

+3-3
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ private[johnsnowlabs] class DistilBertClassification(
148148
val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max
149149

150150
val rawScores = detectedEngine match {
151-
case ONNX.name => getRowScoresWithOnnx(batch)
151+
case ONNX.name => getRawScoresWithOnnx(batch)
152152
case _ => getRawScoresWithTF(batch, maxSentenceLength)
153153
}
154154

@@ -211,7 +211,7 @@ private[johnsnowlabs] class DistilBertClassification(
211211
rawScores
212212
}
213213

214-
private def getRowScoresWithOnnx(batch: Seq[Array[Int]]): Array[Float] = {
214+
private def getRawScoresWithOnnx(batch: Seq[Array[Int]]): Array[Float] = {
215215

216216
val (runner, env) = onnxWrapper.get.getSession(onnxSessionOptions)
217217

@@ -247,7 +247,7 @@ private[johnsnowlabs] class DistilBertClassification(
247247
val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max
248248

249249
val rawScores = detectedEngine match {
250-
case ONNX.name => getRowScoresWithOnnx(batch)
250+
case ONNX.name => getRawScoresWithOnnx(batch)
251251
case _ => getRawScoresWithTF(batch, maxSentenceLength)
252252
}
253253

src/main/scala/com/johnsnowlabs/ml/ai/RoBertaClassification.scala

+3-3
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ private[johnsnowlabs] class RoBertaClassification(
141141
val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max
142142

143143
val rawScores = detectedEngine match {
144-
case ONNX.name => getRowScoresWithOnnx(batch)
144+
case ONNX.name => getRawScoresWithOnnx(batch)
145145
case _ => getRawScoresWithTF(batch, maxSentenceLength)
146146
}
147147

@@ -207,7 +207,7 @@ private[johnsnowlabs] class RoBertaClassification(
207207
rawScores
208208
}
209209

210-
private def getRowScoresWithOnnx(batch: Seq[Array[Int]]): Array[Float] = {
210+
private def getRawScoresWithOnnx(batch: Seq[Array[Int]]): Array[Float] = {
211211

212212
// [nb of encoded sentences , maxSentenceLength]
213213
val (runner, env) = onnxWrapper.get.getSession(onnxSessionOptions)
@@ -244,7 +244,7 @@ private[johnsnowlabs] class RoBertaClassification(
244244
val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max
245245

246246
val rawScores = detectedEngine match {
247-
case ONNX.name => getRowScoresWithOnnx(batch)
247+
case ONNX.name => getRawScoresWithOnnx(batch)
248248
case _ => getRawScoresWithTF(batch, maxSentenceLength)
249249
}
250250

0 commit comments

Comments
 (0)