SPARKNLP-862 Adding ONNX support for E5 Embeddings #13927

Merged
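For context on the diff below: the E5 backend class now takes optional TensorFlow and ONNX wrappers and dispatches on a detectedEngine field, so an ONNX-exported E5 model can be loaded through the usual annotator entry points. The following is a minimal usage sketch, not taken from this PR: the export path is a placeholder, spark is an existing SparkSession, and the loadSavedModel/pipeline wiring follows standard Spark NLP conventions.

// Sketch only: "/tmp/e5_onnx_export" is a placeholder for a local ONNX export of an E5 model.
import com.johnsnowlabs.nlp.DocumentAssembler
import com.johnsnowlabs.nlp.embeddings.E5Embeddings
import org.apache.spark.ml.Pipeline

val documentAssembler = new DocumentAssembler()
  .setInputCol("text")
  .setOutputCol("document")

// loadSavedModel inspects the export and, with this PR, selects the ONNX engine when it finds one.
val e5 = E5Embeddings
  .loadSavedModel("/tmp/e5_onnx_export", spark)
  .setInputCols("document")
  .setOutputCol("e5_embeddings")

val pipeline = new Pipeline().setStages(Array(documentAssembler, e5))

import spark.implicits._
val data = Seq("query: how much protein should a female eat").toDF("text")
val embedded = pipeline.fit(data).transform(data)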
73 changes: 67 additions & 6 deletions src/main/scala/com/johnsnowlabs/ml/ai/E5.scala
@@ -16,15 +16,18 @@

package com.johnsnowlabs.ml.ai

import com.johnsnowlabs.ml.tensorflow.sentencepiece.SentencePieceWrapper
import ai.onnxruntime.OnnxTensor
import com.johnsnowlabs.ml.onnx.OnnxWrapper
import com.johnsnowlabs.ml.tensorflow.sign.{ModelSignatureConstants, ModelSignatureManager}
import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper}
import com.johnsnowlabs.nlp.{Annotation, AnnotatorType}
import com.johnsnowlabs.ml.util.{ONNX, TensorFlow}
import com.johnsnowlabs.nlp.annotators.common._
import com.johnsnowlabs.nlp.{Annotation, AnnotatorType}

import scala.collection.JavaConverters._

/** E5 Sentence embeddings model
* @param tensorflow
* @param tensorflowWrapper
* tensorflow wrapper
* @param configProtoBytes
* config proto bytes
@@ -36,7 +39,8 @@ import scala.collection.JavaConverters._
* signatures
*/
private[johnsnowlabs] class E5(
val tensorflow: TensorflowWrapper,
val tensorflowWrapper: Option[TensorflowWrapper],
val onnxWrapper: Option[OnnxWrapper],
configProtoBytes: Option[Array[Byte]] = None,
sentenceStartTokenId: Int,
sentenceEndTokenId: Int,
@@ -46,7 +50,11 @@ private[johnsnowlabs] class E5(
private val _tfInstructorSignatures: Map[String, String] =
signatures.getOrElse(ModelSignatureManager.apply())
private val paddingTokenId = 0
private val eosTokenId = 1

val detectedEngine: String =
if (tensorflowWrapper.isDefined) TensorFlow.name
else if (onnxWrapper.isDefined) ONNX.name
else TensorFlow.name

/** Get sentence embeddings for a batch of sentences
* @param batch
@@ -55,6 +63,16 @@
* sentence embeddings
*/
private def getSentenceEmbedding(batch: Seq[Array[Int]]): Array[Array[Float]] = {
val embeddings = detectedEngine match {
case ONNX.name =>
getSentenceEmbeddingFromOnnx(batch)
case _ =>
getSentenceEmbeddingFromTF(batch)
}
embeddings
}

private def getSentenceEmbeddingFromTF(batch: Seq[Array[Int]]): Array[Array[Float]] = {
// get max sentence length
val sequencesLength = batch.map(x => x.length).toArray
val maxSentenceLength = sequencesLength.max
@@ -90,7 +108,7 @@ private[johnsnowlabs] class E5(
tensorEncoder.createIntBufferTensor(shape, encoderAttentionMaskBuffers)

// run model
val runner = tensorflow
val runner = tensorflowWrapper.get
.getTFSessionWithSignature(
configProtoBytes = configProtoBytes,
initAllTables = false,
@@ -129,6 +147,49 @@
sentenceEmbeddingsFloatsArray
}

private def getSentenceEmbeddingFromOnnx(batch: Seq[Array[Int]]): Array[Array[Float]] = {
val batchLength = batch.length
val maxSentenceLength = batch.map(pieceIds => pieceIds.length).max

val (runner, env) = onnxWrapper.get.getSession()
val tokenTensors =
OnnxTensor.createTensor(env, batch.map(x => x.map(x => x.toLong)).toArray)
val maskTensors =
OnnxTensor.createTensor(
env,
batch.map(sentence => sentence.map(x => if (x == 0L) 0L else 1L)).toArray)

val segmentTensors =
OnnxTensor.createTensor(env, batch.map(x => Array.fill(maxSentenceLength)(0L)).toArray)

val inputs =
Map(
"input_ids" -> tokenTensors,
"attention_mask" -> maskTensors,
"token_type_ids" -> segmentTensors).asJava

// TODO: A try without a catch or finally is equivalent to putting its body in a block; no exceptions are handled.
try {
val results = runner.run(inputs)
try {
val embeddings = results
.get("last_hidden_state")
.get()
.asInstanceOf[OnnxTensor]
.getFloatBuffer
.array()
tokenTensors.close()
maskTensors.close()
segmentTensors.close()

val dim = embeddings.length / batchLength
// group embeddings
val sentenceEmbeddingsFloatsArray = embeddings.grouped(dim).toArray
sentenceEmbeddingsFloatsArray
} finally if (results != null) results.close()
}
}

/** Predict sentence embeddings for a batch of sentences
* @param sentences
* sentences
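A note on the new ONNX path above: ONNX Runtime returns last_hidden_state as one flat float buffer for the whole batch, and the code recovers per-sentence chunks by dividing the buffer length by the batch size and grouping. Below is a stand-alone sketch of that reshape with made-up dimensions (2 sentences, max length 3, hidden size 4); it is purely illustrative and not part of the PR.

// Illustrative only; dimensions are invented, not taken from the PR.
val batchLength = 2
val flatBuffer: Array[Float] = Array.tabulate(2 * 3 * 4)(_.toFloat) // stands in for getFloatBuffer.array()

val dim = flatBuffer.length / batchLength            // 12 floats per sentence (maxSentenceLength * hidden size)
val perSentence: Array[Array[Float]] = flatBuffer.grouped(dim).toArray

assert(perSentence.length == batchLength)            // one chunk per sentence in the batch
assert(perSentence.head.length == dim)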
76 changes: 57 additions & 19 deletions src/main/scala/com/johnsnowlabs/nlp/embeddings/E5Embeddings.scala
@@ -17,13 +17,14 @@
package com.johnsnowlabs.nlp.embeddings

import com.johnsnowlabs.ml.ai.E5
import com.johnsnowlabs.ml.onnx.{OnnxWrapper, ReadOnnxModel, WriteOnnxModel}
import com.johnsnowlabs.ml.tensorflow._
import com.johnsnowlabs.ml.util.LoadExternalModel.{
loadTextAsset,
modelSanityCheck,
notSupportedEngineError
}
import com.johnsnowlabs.ml.util.TensorFlow
import com.johnsnowlabs.ml.util.{ONNX, TensorFlow}
import com.johnsnowlabs.nlp._
import com.johnsnowlabs.nlp.annotators.common._
import com.johnsnowlabs.nlp.annotators.tokenizer.wordpiece.{BasicTokenizer, WordpieceEncoder}
@@ -144,6 +145,7 @@ class E5Embeddings(override val uid: String)
extends AnnotatorModel[E5Embeddings]
with HasBatchedAnnotate[E5Embeddings]
with WriteTensorflowModel
with WriteOnnxModel
with HasEmbeddingsProperties
with HasStorageRef
with HasCaseSensitiveProperties
@@ -228,12 +230,14 @@ class E5Embeddings(override val uid: String)
/** @group setParam */
def setModelIfNotSet(
spark: SparkSession,
tensorflowWrapper: TensorflowWrapper): E5Embeddings = {
tensorflowWrapper: Option[TensorflowWrapper],
onnxWrapper: Option[OnnxWrapper]): E5Embeddings = {
if (_model.isEmpty) {
_model = Some(
spark.sparkContext.broadcast(
new E5(
tensorflowWrapper,
onnxWrapper,
configProtoBytes = getConfigProtoBytes,
sentenceStartTokenId = sentenceStartTokenId,
sentenceEndTokenId = sentenceEndTokenId,
@@ -335,14 +339,29 @@ class E5Embeddings(override val uid: String)

override def onWrite(path: String, spark: SparkSession): Unit = {
super.onWrite(path, spark)
writeTensorflowModelV2(
path,
spark,
getModelIfNotSet.tensorflow,
"_e5",
E5Embeddings.tfFile,
configProtoBytes = getConfigProtoBytes,
savedSignatures = getSignatures)
val suffix = "_e5"

getEngine match {
case TensorFlow.name =>
writeTensorflowModelV2(
path,
spark,
getModelIfNotSet.tensorflowWrapper.get,
suffix,
E5Embeddings.tfFile,
configProtoBytes = getConfigProtoBytes,
savedSignatures = getSignatures)
case ONNX.name =>
writeOnnxModel(
path,
spark,
getModelIfNotSet.onnxWrapper.get,
suffix,
E5Embeddings.onnxFile)
case _ =>
throw new Exception(notSupportedEngineError)
}

}

/** @group getParam */
@@ -379,19 +398,33 @@ trait ReadablePretrainedE5Model
super.pretrained(name, lang, remoteLoc)
}

trait ReadE5DLModel extends ReadTensorflowModel {
trait ReadE5DLModel extends ReadTensorflowModel with ReadOnnxModel {
this: ParamsAndFeaturesReadable[E5Embeddings] =>

override val tfFile: String = "e5_tensorflow"
override val onnxFile: String = "e5_onnx"

def readModel(instance: E5Embeddings, path: String, spark: SparkSession): Unit = {

val tf = readTensorflowModel(
path,
spark,
"_e5_tf",
savedSignatures = instance.getSignatures,
initAllTables = false)
instance.setModelIfNotSet(spark, tf)
instance.getEngine match {
case TensorFlow.name => {
val tf = readTensorflowModel(
path,
spark,
"_e5_tf",
savedSignatures = instance.getSignatures,
initAllTables = false)
instance.setModelIfNotSet(spark, Some(tf), None)
}
case ONNX.name => {
val onnxWrapper =
readOnnxModel(path, spark, "_e5_onnx", zipped = true, useBundle = false, None)
instance.setModelIfNotSet(spark, None, Some(onnxWrapper))
}
case _ =>
throw new Exception(notSupportedEngineError)
}

}

addReader(readModel)
@@ -423,7 +456,12 @@ trait ReadE5DLModel extends ReadTensorflowModel {
*/
annotatorModel
.setSignatures(_signatures)
.setModelIfNotSet(spark, wrapper)
.setModelIfNotSet(spark, Some(wrapper), None)

case ONNX.name =>
val onnxWrapper = OnnxWrapper.read(localModelPath, zipped = false, useBundle = true)
annotatorModel
.setModelIfNotSet(spark, None, Some(onnxWrapper))

case _ =>
throw new Exception(notSupportedEngineError)
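With onWrite and readModel now branching on the engine, an ONNX-backed E5Embeddings should round-trip through the standard Spark ML save/load calls. A hedged sketch, assuming the e5 instance from the earlier example and placeholder paths:

// Sketch only: paths are placeholders and `e5` is the ONNX-backed instance loaded earlier.
e5.write.overwrite().save("/tmp/e5_onnx_spark_nlp")   // onWrite selects the ONNX branch via getEngine

// ReadE5DLModel.readModel matches on the stored engine and restores the OnnxWrapper
// instead of a TensorFlow graph.
val restored = E5Embeddings.load("/tmp/e5_onnx_spark_nlp")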