
Commit 98c8ffa

SPARKNLP-743: Add parameter to SparkNLP.start (#13510)
- Added a `params` parameter, which can supply custom configurations to the SparkSession
1 parent: e8f4bed

File tree

4 files changed (+60 -18 lines):

  python/sparknlp/__init__.py
  src/main/scala/com/johnsnowlabs/nlp/SparkNLP.scala
  src/test/java/com/johnsnowlabs/nlp/GeneralAnnotationsTest.java
  src/test/scala/com/johnsnowlabs/nlp/SparkNLPTestSpec.scala

python/sparknlp/__init__.py (+3 -1)

@@ -107,11 +107,13 @@ def start(gpu=False,
         for WordEmbeddings. By default, this locations is the location of
         `hadoop.tmp.dir` set via Hadoop configuration for Apache Spark. NOTE: `S3` is
         not supported and it must be local, HDFS, or DBFS.
+    params : dict, optional
+        Custom parameters to set for the Spark configuration, by default None.
     cluster_tmp_dir : str, optional
         The location to save logs from annotators during training. If not set, it will
         be in the users home directory under `annotator_logs`.
     real_time_output : bool, optional
-        Whether to output in real time, by default False
+        Whether to read and print JVM output in real time, by default False
     output_level : int, optional
         Output level for logs, by default 1
 
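For orientation, here is a minimal sketch of how the new argument might be used from the Python side, assuming the runtime signature mirrors the docstring above; the configuration keys and values are arbitrary examples, not defaults:

```python
import sparknlp

# Illustrative values only; any valid Spark configuration entries may be passed.
spark = sparknlp.start(params={
    "spark.driver.memory": "8G",
    "spark.hadoop.fs.s3a.path.style.access": "true",
})

print("Spark NLP version:", sparknlp.version())
print("Apache Spark version:", spark.version)
```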

src/main/scala/com/johnsnowlabs/nlp/SparkNLP.scala (+27 -14)

@@ -47,6 +47,8 @@ object SparkNLP {
     * @param cluster_tmp_dir
     *   The location to save logs from annotators during training (By default, it will be in the
     *   users home directory under `annotator_logs`.)
+    * @param params
+    *   Custom parameters to set for the Spark configuration (Default: `Map.empty`)
     * @return
     *   SparkSession
     */
@@ -57,9 +59,13 @@ object SparkNLP {
       memory: String = "16G",
       cache_folder: String = "",
       log_folder: String = "",
-      cluster_tmp_dir: String = ""): SparkSession = {
+      cluster_tmp_dir: String = "",
+      params: Map[String, String] = Map.empty): SparkSession = {
 
-    val build = SparkSession
+    if (SparkSession.getActiveSession.isDefined)
+      println("Warning: Spark Session already created, some configs may not be applied.")
+
+    val builder = SparkSession
       .builder()
       .appName("Spark NLP")
       .master("local[*]")
@@ -68,26 +74,33 @@ object SparkNLP {
       .config("spark.kryoserializer.buffer.max", "2000M")
       .config("spark.driver.maxResultSize", "0")
 
-    if (apple_silicon) {
-      build.config("spark.jars.packages", MavenSparkSilicon)
-    } else if (aarch64) {
-      build.config("spark.jars.packages", MavenSparkAarch64)
-    } else if (gpu) {
-      build.config("spark.jars.packages", MavenGpuSpark3)
-    } else {
-      build.config("spark.jars.packages", MavenSpark3)
+    val sparkNlpJar =
+      if (apple_silicon) MavenSparkSilicon
+      else if (aarch64) MavenSparkAarch64
+      else if (gpu) MavenGpuSpark3
+      else MavenSpark3
+
+    if (!params.contains("spark.jars.packages")) {
+      builder.config("spark.jars.packages", sparkNlpJar)
+    }
+
+    params.foreach {
+      case (key, value) if key == "spark.jars.packages" =>
+        builder.config(key, sparkNlpJar + "," + value)
+      case (key, value) =>
+        builder.config(key, value)
     }
 
     if (cache_folder.nonEmpty)
-      build.config("spark.jsl.settings.pretrained.cache_folder", cache_folder)
+      builder.config("spark.jsl.settings.pretrained.cache_folder", cache_folder)
 
     if (log_folder.nonEmpty)
-      build.config("spark.jsl.settings.annotator.log_folder", log_folder)
+      builder.config("spark.jsl.settings.annotator.log_folder", log_folder)
 
     if (cluster_tmp_dir.nonEmpty)
-      build.config("spark.jsl.settings.storage.cluster_tmp_dir", cluster_tmp_dir)
+      builder.config("spark.jsl.settings.storage.cluster_tmp_dir", cluster_tmp_dir)
 
-    build.getOrCreate()
+    builder.getOrCreate()
   }
 
   def version(): String = {
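A short usage sketch against the new Scala signature (the configuration values are illustrative, not defaults). Note the merge rule above: a user-supplied `spark.jars.packages` is appended after the Spark NLP artifact rather than replacing it:

```scala
import com.johnsnowlabs.nlp.SparkNLP

// Illustrative values only; any Spark configuration keys may be passed.
val spark = SparkNLP.start(params = Map(
  "spark.executor.memory" -> "8G",
  // Merged by start() as sparkNlpJar + "," + this value
  "spark.jars.packages" -> "org.apache.hadoop:hadoop-aws:3.3.1"))

println(spark.conf.get("spark.jars.packages"))
```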

src/test/java/com/johnsnowlabs/nlp/GeneralAnnotationsTest.java (+3 -3)

@@ -27,6 +27,7 @@
 import org.apache.spark.sql.Encoders;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.SparkSession;
+import scala.collection.immutable.HashMap;
 
 import java.util.LinkedList;
 
@@ -46,14 +47,13 @@ public static void main(String[] args) {
         Pipeline pipeline = new Pipeline();
         pipeline.setStages(new PipelineStage[]{document, tokenizer});
 
-        SparkSession spark = com.johnsnowlabs.nlp.SparkNLP.start(
-                false,
+        SparkSession spark = com.johnsnowlabs.nlp.SparkNLP.start(false,
                 false,
                 false,
                 "16G",
                 "",
                 "",
-                "");
+                "", new HashMap<>());
 
         LinkedList<String> text = new java.util.LinkedList<>();
 
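Because Scala default arguments are not visible from Java, every parameter has to be spelled out, with `new HashMap<>()` standing in for the `Map.empty` default. Passing a non-empty map from Java could look like the following sketch, assuming the Scala 2.12 immutable `HashMap` API that the test above already relies on; the class name and config entry are arbitrary examples:

```java
import org.apache.spark.sql.SparkSession;
import scala.collection.immutable.HashMap;

public class StartWithParamsExample {
    public static void main(String[] args) {
        // updated() returns a new immutable map containing the extra entry.
        // Hypothetical example entry; any Spark configuration key/value works.
        HashMap<String, String> params = new HashMap<String, String>()
                .updated("spark.hadoop.fs.s3a.path.style.access", "true");

        SparkSession spark = com.johnsnowlabs.nlp.SparkNLP.start(
                false, false, false, "16G", "", "", "", params);

        System.out.println(spark.version());
        spark.stop();
    }
}
```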

src/test/scala/com/johnsnowlabs/nlp/SparkNLPTestSpec.scala (new file, +27 -0)

@@ -0,0 +1,27 @@
+package com.johnsnowlabs.nlp
+
+import com.johnsnowlabs.tags.SlowTest
+import com.johnsnowlabs.util.ConfigHelper.{awsJavaSdkVersion, hadoopAwsVersion}
+import org.scalatest.flatspec.AnyFlatSpec
+
+class SparkNLPTestSpec extends AnyFlatSpec {
+
+  behavior of "SparkNLPTestSpec"
+
+  it should "start with extra parameters" taggedAs SlowTest ignore {
+    val extraParams: Map[String, String] = Map(
+      "spark.jars.packages" -> ("org.apache.hadoop:hadoop-aws:" + hadoopAwsVersion + ",com.amazonaws:aws-java-sdk:" + awsJavaSdkVersion),
+      "spark.hadoop.fs.s3a.path.style.access" -> "true")
+
+    val spark = SparkNLP.start(params = extraParams)
+
+    assert(spark.conf.get("spark.hadoop.fs.s3a.path.style.access") == "true")
+
+    Seq(
+      "com.johnsnowlabs.nlp:spark-nlp",
+      "org.apache.hadoop:hadoop-aws",
+      "com.amazonaws:aws-java-sdk").foreach { pkg =>
+      assert(spark.conf.get("spark.jars.packages").contains(pkg))
+    }
+  }
+}
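The spec is registered with `ignore` and tagged as a `SlowTest`, so it is skipped in the default suite; once re-enabled, it could be run in isolation with sbt's standard test filter (a sketch, assuming the project's usual sbt layout):

```
sbt "testOnly com.johnsnowlabs.nlp.SparkNLPTestSpec"
```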
