@@ -47,6 +47,8 @@ object SparkNLP {
47
47
* @param cluster_tmp_dir
48
48
* The location to save logs from annotators during training (By default, it will be in the
49
49
* user's home directory under `annotator_logs`.)
50
+ * @param params
51
+ * Custom parameters to set for the Spark configuration (Default: `Map.empty`)
50
52
* @return
51
53
* SparkSession
52
54
*/
@@ -57,9 +59,13 @@ object SparkNLP {
57
59
memory : String = " 16G" ,
58
60
cache_folder : String = " " ,
59
61
log_folder : String = " " ,
60
- cluster_tmp_dir : String = " " ): SparkSession = {
62
+ cluster_tmp_dir : String = " " ,
63
+ params : Map [String , String ] = Map .empty): SparkSession = {
61
64
62
- val build = SparkSession
65
+ if (SparkSession .getActiveSession.isDefined)
66
+ println(" Warning: Spark Session already created, some configs may not be applied." )
67
+
68
+ val builder = SparkSession
63
69
.builder()
64
70
.appName(" Spark NLP" )
65
71
.master(" local[*]" )
@@ -68,26 +74,33 @@ object SparkNLP {
68
74
.config(" spark.kryoserializer.buffer.max" , " 2000M" )
69
75
.config(" spark.driver.maxResultSize" , " 0" )
70
76
71
- if (apple_silicon) {
72
- build.config(" spark.jars.packages" , MavenSparkSilicon )
73
- } else if (aarch64) {
74
- build.config(" spark.jars.packages" , MavenSparkAarch64 )
75
- } else if (gpu) {
76
- build.config(" spark.jars.packages" , MavenGpuSpark3 )
77
- } else {
78
- build.config(" spark.jars.packages" , MavenSpark3 )
77
+ val sparkNlpJar =
78
+ if (apple_silicon) MavenSparkSilicon
79
+ else if (aarch64) MavenSparkAarch64
80
+ else if (gpu) MavenGpuSpark3
81
+ else MavenSpark3
82
+
83
+ if (! params.contains(" spark.jars.packages" )) {
84
+ builder.config(" spark.jars.packages" , sparkNlpJar)
85
+ }
86
+
87
+ params.foreach {
88
+ case (key, value) if key == " spark.jars.packages" =>
89
+ builder.config(key, sparkNlpJar + " ," + value)
90
+ case (key, value) =>
91
+ builder.config(key, value)
79
92
}
80
93
81
94
if (cache_folder.nonEmpty)
82
- build .config(" spark.jsl.settings.pretrained.cache_folder" , cache_folder)
95
+ builder .config(" spark.jsl.settings.pretrained.cache_folder" , cache_folder)
83
96
84
97
if (log_folder.nonEmpty)
85
- build .config(" spark.jsl.settings.annotator.log_folder" , log_folder)
98
+ builder .config(" spark.jsl.settings.annotator.log_folder" , log_folder)
86
99
87
100
if (cluster_tmp_dir.nonEmpty)
88
- build .config(" spark.jsl.settings.storage.cluster_tmp_dir" , cluster_tmp_dir)
101
+ builder .config(" spark.jsl.settings.storage.cluster_tmp_dir" , cluster_tmp_dir)
89
102
90
- build .getOrCreate()
103
+ builder .getOrCreate()
91
104
}
92
105
93
106
def version (): String = {
0 commit comments