Skip to content

Commit 4419a70

Browse files
authored
[SPARKNLP-1037] Adding addFile changes to replace broadcast in all ONNX based annotators (#14236)
* [SPARKNLP-1011] Adding changes to transfer ONNX files on executors through Spark files feature * [SPARKNLP-1011] Adding missing copyright comment * [SPARKNLP-1011] Adding changes to add prefix for models with onnx_data file * [SPARKNLP-1037] Adding changes to transfer ONNX files on executors via addFile * [SPARKNLP-1037] Adding unique suffix to avoid duplication in spark files
1 parent fcd4e9c commit 4419a70

File tree

49 files changed

+262
-221
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below for content that may be hidden.

49 files changed

+262
-221
lines changed

src/main/scala/com/johnsnowlabs/ml/ai/M2M100.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,11 @@ import com.johnsnowlabs.ml.onnx.OnnxWrapper.EncoderDecoderWithoutPastWrappers
2323
import com.johnsnowlabs.ml.onnx.TensorResources.implicits._
2424
import com.johnsnowlabs.ml.tensorflow.sentencepiece.SentencePieceWrapper
2525
import com.johnsnowlabs.nlp.Annotation
26-
27-
import scala.collection.JavaConverters._
2826
import com.johnsnowlabs.nlp.AnnotatorType.DOCUMENT
2927
import org.tensorflow.{Session, Tensor}
3028

29+
import scala.collection.JavaConverters._
30+
3131
private[johnsnowlabs] class M2M100(
3232
val onnxWrappers: EncoderDecoderWithoutPastWrappers,
3333
val spp: SentencePieceWrapper,

src/main/scala/com/johnsnowlabs/ml/ai/Whisper.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,10 @@ private[johnsnowlabs] class Whisper(
297297
case TensorFlow.name =>
298298
val session =
299299
tensorflowWrapper.get
300-
.getTFSessionWithSignature(configProtoBytes, savedSignatures = signatures)
300+
.getTFSessionWithSignature(
301+
configProtoBytes,
302+
savedSignatures = signatures,
303+
initAllTables = false)
301304

302305
val encodedBatchFeatures: Tensor =
303306
encode(featuresBatch, Some(session), None).asInstanceOf[Tensor]

src/main/scala/com/johnsnowlabs/ml/onnx/OnnxSerializeModel.scala

Lines changed: 66 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@
1616

1717
package com.johnsnowlabs.ml.onnx
1818

19-
import ai.onnxruntime.OrtSession.SessionOptions
20-
import com.johnsnowlabs.util.FileHelper
2119
import org.apache.commons.io.FileUtils
2220
import org.apache.hadoop.fs.{FileSystem, Path}
2321
import org.apache.spark.sql.SparkSession
@@ -32,11 +30,10 @@ trait WriteOnnxModel {
3230
path: String,
3331
spark: SparkSession,
3432
onnxWrappersWithNames: Seq[(OnnxWrapper, String)],
35-
suffix: String,
36-
dataFileSuffix: String = "_data"): Unit = {
33+
suffix: String): Unit = {
3734

3835
val uri = new java.net.URI(path.replaceAllLiterally("\\", "/"))
39-
val fs = FileSystem.get(uri, spark.sparkContext.hadoopConfiguration)
36+
val fileSystem = FileSystem.get(uri, spark.sparkContext.hadoopConfiguration)
4037

4138
// 1. Create tmp folder
4239
val tmpFolder = Files
@@ -51,15 +48,16 @@ trait WriteOnnxModel {
5148
onnxWrapper.saveToFile(onnxFile)
5249

5350
// 3. Copy to dest folder
54-
fs.copyFromLocalFile(new Path(onnxFile), new Path(path))
51+
fileSystem.copyFromLocalFile(new Path(onnxFile), new Path(path))
5552

5653
// 4. check if there is a onnx_data file
57-
if (onnxWrapper.onnxModelPath.isDefined) {
58-
val onnxDataFile = new Path(onnxWrapper.onnxModelPath.get + dataFileSuffix)
59-
if (fs.exists(onnxDataFile)) {
60-
fs.copyFromLocalFile(onnxDataFile, new Path(path))
54+
if (onnxWrapper.dataFileDirectory.isDefined) {
55+
val onnxDataFile = new Path(onnxWrapper.dataFileDirectory.get)
56+
if (fileSystem.exists(onnxDataFile)) {
57+
fileSystem.copyFromLocalFile(onnxDataFile, new Path(path))
6158
}
6259
}
60+
6361
}
6462

6563
// 4. Remove tmp folder
@@ -74,7 +72,6 @@ trait WriteOnnxModel {
7472
fileName: String): Unit = {
7573
writeOnnxModels(path, spark, Seq((onnxWrapper, fileName)), suffix)
7674
}
77-
7875
}
7976

8077
trait ReadOnnxModel {
@@ -86,38 +83,61 @@ trait ReadOnnxModel {
8683
suffix: String,
8784
zipped: Boolean = true,
8885
useBundle: Boolean = false,
89-
sessionOptions: Option[SessionOptions] = None,
90-
dataFileSuffix: String = "_data"): OnnxWrapper = {
86+
modelName: Option[String] = None,
87+
tmpFolder: Option[String] = None,
88+
dataFilePostfix: Option[String] = None): OnnxWrapper = {
89+
90+
// 1. Copy to local tmp dir
91+
val localModelFile = if (modelName.isDefined) modelName.get else onnxFile
92+
val srcPath = new Path(path, localModelFile)
93+
val fileSystem = getFileSystem(path, spark)
94+
val localTmpFolder = if (tmpFolder.isDefined) tmpFolder.get else createTmpDirectory(suffix)
95+
fileSystem.copyToLocalFile(srcPath, new Path(localTmpFolder))
96+
97+
// 2. Copy onnx_data file if exists
98+
val fsPath = new Path(path, localModelFile).toString
99+
100+
val onnxDataFile: Option[String] = if (modelName.isDefined && dataFilePostfix.isDefined) {
101+
Some(fsPath.replaceAll(modelName.get, s"${suffix}_${modelName.get}${dataFilePostfix.get}"))
102+
} else None
103+
104+
if (onnxDataFile.isDefined) {
105+
val onnxDataFilePath = new Path(onnxDataFile.get)
106+
if (fileSystem.exists(onnxDataFilePath)) {
107+
fileSystem.copyToLocalFile(onnxDataFilePath, new Path(localTmpFolder))
108+
}
109+
}
110+
111+
// 3. Read ONNX state
112+
val onnxFileTmpPath = new Path(localTmpFolder, localModelFile).toString
113+
val onnxWrapper =
114+
OnnxWrapper.read(
115+
spark,
116+
onnxFileTmpPath,
117+
zipped = zipped,
118+
useBundle = useBundle,
119+
modelName = if (modelName.isDefined) modelName.get else onnxFile,
120+
onnxFileSuffix = Some(suffix))
121+
122+
onnxWrapper
123+
124+
}
91125

126+
private def getFileSystem(path: String, sparkSession: SparkSession): FileSystem = {
92127
val uri = new java.net.URI(path.replaceAllLiterally("\\", "/"))
93-
val fs = FileSystem.get(uri, spark.sparkContext.hadoopConfiguration)
128+
val fileSystem = FileSystem.get(uri, sparkSession.sparkContext.hadoopConfiguration)
129+
fileSystem
130+
}
131+
132+
private def createTmpDirectory(suffix: String): String = {
94133

95134
// 1. Create tmp directory
96135
val tmpFolder = Files
97-
.createTempDirectory(UUID.randomUUID().toString.takeRight(12) + suffix)
136+
.createTempDirectory(s"${UUID.randomUUID().toString.takeRight(12)}_$suffix")
98137
.toAbsolutePath
99138
.toString
100139

101-
// 2. Copy to local dir
102-
fs.copyToLocalFile(new Path(path, onnxFile), new Path(tmpFolder))
103-
104-
val localPath = new Path(tmpFolder, onnxFile).toString
105-
106-
val fsPath = new Path(path, onnxFile)
107-
108-
// 3. Copy onnx_data file if exists
109-
val onnxDataFile = new Path(fsPath + dataFileSuffix)
110-
111-
if (fs.exists(onnxDataFile)) {
112-
fs.copyToLocalFile(onnxDataFile, new Path(tmpFolder))
113-
}
114-
// 4. Read ONNX state
115-
val onnxWrapper = OnnxWrapper.read(localPath, zipped = zipped, useBundle = useBundle)
116-
117-
// 5. Remove tmp folder
118-
FileHelper.delete(tmpFolder)
119-
120-
onnxWrapper
140+
tmpFolder
121141
}
122142

123143
def readOnnxModels(
@@ -127,43 +147,23 @@ trait ReadOnnxModel {
127147
suffix: String,
128148
zipped: Boolean = true,
129149
useBundle: Boolean = false,
130-
dataFileSuffix: String = "_data"): Map[String, OnnxWrapper] = {
150+
dataFilePostfix: String = "_data"): Map[String, OnnxWrapper] = {
131151

132-
val uri = new java.net.URI(path.replaceAllLiterally("\\", "/"))
133-
val fs = FileSystem.get(uri, spark.sparkContext.hadoopConfiguration)
134-
135-
// 1. Create tmp directory
136-
val tmpFolder = Files
137-
.createTempDirectory(UUID.randomUUID().toString.takeRight(12) + suffix)
138-
.toAbsolutePath
139-
.toString
152+
val tmpFolder = Some(createTmpDirectory(suffix))
140153

141154
val wrappers = (modelNames map { modelName: String =>
142-
// 2. Copy to local dir
143-
val localModelFile = modelName
144-
fs.copyToLocalFile(new Path(path, localModelFile), new Path(tmpFolder))
145-
146-
val localPath = new Path(tmpFolder, localModelFile).toString
147-
148-
val fsPath = new Path(path, localModelFile).toString
149-
150-
// 3. Copy onnx_data file if exists
151-
val onnxDataFile = new Path(fsPath + dataFileSuffix)
152-
153-
if (fs.exists(onnxDataFile)) {
154-
fs.copyToLocalFile(onnxDataFile, new Path(tmpFolder))
155-
}
156-
157-
// 4. Read ONNX state
158-
val onnxWrapper =
159-
OnnxWrapper.read(localPath, zipped = zipped, useBundle = useBundle, modelName = modelName)
160-
155+
val onnxWrapper = readOnnxModel(
156+
path,
157+
spark,
158+
suffix,
159+
zipped,
160+
useBundle,
161+
Some(modelName),
162+
tmpFolder,
163+
Option(dataFilePostfix))
161164
(modelName, onnxWrapper)
162165
}).toMap
163166

164-
// 4. Remove tmp folder
165-
FileHelper.delete(tmpFolder)
166-
167167
wrappers
168168
}
169169

src/main/scala/com/johnsnowlabs/ml/onnx/OnnxWrapper.scala

Lines changed: 30 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,16 @@ import ai.onnxruntime.OrtSession.SessionOptions.{ExecutionMode, OptLevel}
2121
import ai.onnxruntime.providers.OrtCUDAProviderOptions
2222
import ai.onnxruntime.{OrtEnvironment, OrtSession}
2323
import com.johnsnowlabs.util.{ConfigHelper, FileHelper, ZipArchiveUtil}
24-
import org.apache.commons.io.FileUtils
24+
import org.apache.spark.SparkFiles
25+
import org.apache.spark.sql.SparkSession
2526
import org.slf4j.{Logger, LoggerFactory}
26-
import org.apache.hadoop.fs.{FileSystem, Path}
27+
2728
import java.io._
2829
import java.nio.file.{Files, Paths}
2930
import java.util.UUID
3031
import scala.util.{Failure, Success, Try}
3132

32-
class OnnxWrapper(var onnxModel: Array[Byte], var onnxModelPath: Option[String] = None)
33+
class OnnxWrapper(var modelFileName: Option[String] = None, var dataFileDirectory: Option[String])
3334
extends Serializable {
3435

3536
/** For Deserialization */
@@ -43,10 +44,15 @@ class OnnxWrapper(var onnxModel: Array[Byte], var onnxModelPath: Option[String]
4344

4445
def getSession(onnxSessionOptions: Map[String, String]): (OrtSession, OrtEnvironment) =
4546
this.synchronized {
46-
// TODO: After testing it works remove the Map.empty
4747
if (ortSession == null && ortEnv == null) {
48+
val modelFilePath = if (modelFileName.isDefined) {
49+
SparkFiles.get(modelFileName.get)
50+
} else {
51+
throw new UnsupportedOperationException("modelFileName not defined")
52+
}
53+
4854
val (session, env) =
49-
OnnxWrapper.withSafeOnnxModelLoader(onnxModel, onnxSessionOptions, onnxModelPath)
55+
OnnxWrapper.withSafeOnnxModelLoader(onnxSessionOptions, Some(modelFilePath))
5056
ortEnv = env
5157
ortSession = session
5258
}
@@ -60,17 +66,11 @@ class OnnxWrapper(var onnxModel: Array[Byte], var onnxModelPath: Option[String]
6066
.toAbsolutePath
6167
.toString
6268

63-
// 2. Save onnx model
64-
val fileName = Paths.get(file).getFileName.toString
65-
val onnxFile = Paths
66-
.get(tmpFolder, fileName)
67-
.toString
68-
69-
FileUtils.writeByteArrayToFile(new File(onnxFile), onnxModel)
70-
// 4. Zip folder
71-
if (zip) ZipArchiveUtil.zip(tmpFolder, file)
69+
val tmpModelFilePath = SparkFiles.get(modelFileName.get)
70+
// 2. Zip folder
71+
if (zip) ZipArchiveUtil.zip(tmpModelFilePath, file)
7272

73-
// 5. Remove tmp directory
73+
// 3. Remove tmp directory
7474
FileHelper.delete(tmpFolder)
7575
}
7676

@@ -82,7 +82,6 @@ object OnnxWrapper {
8282

8383
// TODO: make sure this.synchronized is needed or it's not a bottleneck
8484
private def withSafeOnnxModelLoader(
85-
onnxModel: Array[Byte],
8685
sessionOptions: Map[String, String],
8786
onnxModelPath: Option[String] = None): (OrtSession, OrtEnvironment) =
8887
this.synchronized {
@@ -96,19 +95,18 @@ object OnnxWrapper {
9695
val session = env.createSession(onnxModelPath.get, sessionOptionsObject)
9796
(session, env)
9897
} else {
99-
val session = env.createSession(onnxModel, sessionOptionsObject)
100-
(session, env)
98+
throw new UnsupportedOperationException("onnxModelPath not defined")
10199
}
102100
}
103101

104-
// TODO: the parts related to onnx_data should be refactored once we support addFile()
105102
def read(
103+
sparkSession: SparkSession,
106104
modelPath: String,
107105
zipped: Boolean = true,
108106
useBundle: Boolean = false,
109107
modelName: String = "model",
110-
dataFileSuffix: String = "_data"): OnnxWrapper = {
111-
108+
dataFileSuffix: Option[String] = Some("_data"),
109+
onnxFileSuffix: Option[String] = None): OnnxWrapper = {
112110
// 1. Create tmp folder
113111
val tmpFolder = Files
114112
.createTempDirectory(UUID.randomUUID().toString.takeRight(12) + "_onnx")
@@ -118,11 +116,10 @@ object OnnxWrapper {
118116
// 2. Unpack archive
119117
val folder =
120118
if (zipped)
121-
ZipArchiveUtil.unzip(new File(modelPath), Some(tmpFolder))
119+
ZipArchiveUtil.unzip(new File(modelPath), Some(tmpFolder), onnxFileSuffix)
122120
else
123121
modelPath
124122

125-
val sessionOptions = new OnnxSession().getSessionOptions
126123
val onnxFile =
127124
if (useBundle) Paths.get(modelPath, s"$modelName.onnx").toString
128125
else Paths.get(folder, new File(folder).list().head).toString
@@ -134,38 +131,23 @@ object OnnxWrapper {
134131
val parentDir = if (zipped) Paths.get(modelPath).getParent.toString else modelPath
135132

136133
val onnxDataFileExist: Boolean = {
137-
onnxDataFile = Paths.get(parentDir, modelName + dataFileSuffix).toFile
138-
onnxDataFile.exists()
134+
if (onnxFileSuffix.isDefined && dataFileSuffix.isDefined) {
135+
val onnxDataFilePath = s"${onnxFileSuffix.get}_$modelName${dataFileSuffix.get}"
136+
onnxDataFile = Paths.get(parentDir, onnxDataFilePath).toFile
137+
onnxDataFile.exists()
138+
} else false
139139
}
140140

141141
if (onnxDataFileExist) {
142-
val onnxDataFileTmp =
143-
Paths.get(tmpFolder, modelName + dataFileSuffix).toFile
144-
FileUtils.copyFile(onnxDataFile, onnxDataFileTmp)
142+
sparkSession.sparkContext.addFile(onnxDataFile.toString)
145143
}
146144

147-
val modelFile = new File(onnxFile)
148-
val modelBytes = FileUtils.readFileToByteArray(modelFile)
149-
var session: OrtSession = null
150-
var env: OrtEnvironment = null
151-
if (onnxDataFileExist) {
152-
val (_session, _env) = withSafeOnnxModelLoader(modelBytes, sessionOptions, Some(onnxFile))
153-
session = _session
154-
env = _env
155-
} else {
156-
val (_session, _env) = withSafeOnnxModelLoader(modelBytes, sessionOptions, None)
157-
session = _session
158-
env = _env
145+
sparkSession.sparkContext.addFile(onnxFile)
159146

160-
}
161-
// 4. Remove tmp folder
162-
FileHelper.delete(tmpFolder)
147+
val onnxFileName = Some(new File(onnxFile).getName)
148+
val dataFileDirectory = if (onnxDataFileExist) Some(onnxDataFile.toString) else None
149+
val onnxWrapper = new OnnxWrapper(onnxFileName, dataFileDirectory)
163150

164-
val onnxWrapper =
165-
if (onnxDataFileExist) new OnnxWrapper(modelBytes, Option(onnxFile))
166-
else new OnnxWrapper(modelBytes)
167-
onnxWrapper.ortSession = session
168-
onnxWrapper.ortEnv = env
169151
onnxWrapper
170152
}
171153

0 commit comments

Comments
 (0)