@@ -10,19 +10,19 @@ import com.microsoft.azure.synapse.ml.fabric.FabricClient
10
10
import com .microsoft .azure .synapse .ml .io .http ._
11
11
import com .microsoft .azure .synapse .ml .logging .SynapseMLLogging
12
12
import com .microsoft .azure .synapse .ml .logging .common .PlatformDetails
13
- import com .microsoft .azure .synapse .ml .param .{ GlobalKey , GlobalParams , HasGlobalParams , ServiceParam }
14
- import com .microsoft .azure .synapse .ml .stages .{ DropColumns , Lambda }
13
+ import com .microsoft .azure .synapse .ml .param .{GlobalKey , GlobalParams , HasGlobalParams , ServiceParam , TypedArrayParam }
14
+ import com .microsoft .azure .synapse .ml .stages .{DropColumns , Lambda }
15
15
import org .apache .commons .lang .StringUtils
16
16
import org .apache .http .NameValuePair
17
- import org .apache .http .client .methods .{ HttpEntityEnclosingRequestBase , HttpPost , HttpRequestBase }
17
+ import org .apache .http .client .methods .{HttpEntityEnclosingRequestBase , HttpPost , HttpRequestBase }
18
18
import org .apache .http .client .utils .URLEncodedUtils
19
19
import org .apache .http .entity .AbstractHttpEntity
20
20
import org .apache .http .impl .client .CloseableHttpClient
21
21
import org .apache .spark .ml .param ._
22
- import org .apache .spark .ml .{ ComplexParamsWritable , NamespaceInjections , PipelineModel , Transformer }
23
- import org .apache .spark .sql .functions .{ col , lit , struct }
22
+ import org .apache .spark .ml .{ComplexParamsWritable , NamespaceInjections , PipelineModel , Transformer }
23
+ import org .apache .spark .sql .functions .{col , lit , struct }
24
24
import org .apache .spark .sql .types ._
25
- import org .apache .spark .sql .{ DataFrame , Dataset , Row }
25
+ import org .apache .spark .sql .{DataFrame , Dataset , Row }
26
26
import spray .json .DefaultJsonProtocol ._
27
27
28
28
import java .net .URI
@@ -206,13 +206,37 @@ trait HasCustomHeaders extends HasServiceParams {
206
206
207
207
// For Pyspark compatability accept Java HashMap as input to parameter
208
208
// py4J only natively supports conversions from Python Dict to Java HashMap
209
- def setCustomHeaders (v : java.util.HashMap [String ,String ]): this .type = {
209
+ def setCustomHeaders (v : java.util.HashMap [String , String ]): this .type = {
210
210
setCustomHeaders(v.asScala.toMap)
211
211
}
212
+ }
213
+
214
+ trait HasTelemHeaders extends HasServiceParams {
215
+
216
+ private [ml] val telemHeaders = new ServiceParam [Map [String , String ]](
217
+ this , " telemHeaders" , " Map of Custom Header Key-Value Tuples."
218
+ )
219
+
220
+ private [ml] def setTelemHeaders (v : Map [String , String ]): this .type = {
221
+ setScalarParam(telemHeaders, v)
222
+ }
223
+
224
+ // For Pyspark compatability accept Java HashMap as input to parameter
225
+ // py4J only natively supports conversions from Python Dict to Java HashMap
226
+ private [ml] def setTelemHeaders (v : java.util.HashMap [String , String ]): this .type = {
227
+ setTelemHeaders(v.asScala.toMap)
228
+ }
229
+
230
+ setDefault(telemHeaders -> Left (Map (" x-ai-telemetry-properties" ->
231
+ s """ {
232
+ |"OriginatingService": "SynapseML",
233
+ |"ClientArtifactType": "Spark",
234
+ |"OperationName": " ${this .getClass.getName}"
235
+ |} """ .stripMargin.replaceAll(" \n " , " " ))))
212
236
213
- def getCustomHeaders : Map [String , String ] = getScalarParam(customHeaders)
214
237
}
215
238
239
+
216
240
trait HasCustomCogServiceDomain extends Wrappable with HasURL with HasUrlPath {
217
241
def setCustomServiceName (v : String ): this .type = {
218
242
setUrl(s " https:// $v.cognitiveservices.azure.com/ " + urlPath.stripPrefix(" /" ))
@@ -281,7 +305,7 @@ object URLEncodingUtils {
281
305
}
282
306
283
307
trait HasCognitiveServiceInput extends HasURL with HasSubscriptionKey with HasAADToken with HasCustomAuthHeader
284
- with HasCustomHeaders with SynapseMLLogging {
308
+ with HasCustomHeaders with HasTelemHeaders with SynapseMLLogging {
285
309
286
310
val customUrlRoot : Param [String ] = new Param [String ](
287
311
this , " customUrlRoot" , " The custom URL root for the service. " +
@@ -334,7 +358,7 @@ trait HasCognitiveServiceInput extends HasURL with HasSubscriptionKey with HasAA
334
358
335
359
protected def getCustomAuthHeader (row : Row ): Option [String ] = {
336
360
val providedCustomAuthHeader = getValueOpt(row, CustomAuthHeader )
337
- if (providedCustomAuthHeader .isEmpty && PlatformDetails .runningOnFabric()) {
361
+ if (providedCustomAuthHeader.isEmpty && PlatformDetails .runningOnFabric()) {
338
362
logInfo(" Using Default AAD Token On Fabric" )
339
363
Option (FabricClient .getCognitiveMWCTokenAuthHeader)
340
364
} else {
@@ -362,6 +386,7 @@ trait HasCognitiveServiceInput extends HasURL with HasSubscriptionKey with HasAA
362
386
val contentTypeValue = contentType(row)
363
387
val customAuthHeaderOpt = getCustomAuthHeader(row)
364
388
val customHeadersOpt = getCustomHeaders(row)
389
+ val telemHeadersOpt = getValueOpt(row, telemHeaders)
365
390
366
391
if (subscriptionKeyOpt.nonEmpty) {
367
392
headers += (subscriptionKeyHeaderName -> getValue(row, subscriptionKey))
@@ -376,6 +401,7 @@ trait HasCognitiveServiceInput extends HasURL with HasSubscriptionKey with HasAA
376
401
headers += (" x-ms-workload-resource-moniker" -> UUID .randomUUID().toString)
377
402
}
378
403
}
404
+
379
405
if (customHeadersOpt.nonEmpty) {
380
406
customHeadersOpt.foreach { m =>
381
407
m.foreach { case (headerName, headerValue) =>
@@ -384,10 +410,17 @@ trait HasCognitiveServiceInput extends HasURL with HasSubscriptionKey with HasAA
384
410
}
385
411
}
386
412
413
+ if (telemHeadersOpt.nonEmpty) {
414
+ telemHeadersOpt.foreach { m =>
415
+ m.foreach { case (headerName, headerValue) =>
416
+ headers += (headerName -> headerValue)
417
+ }
418
+ }
419
+ }
420
+
387
421
if (addContentType && ! StringUtils .isEmpty(contentTypeValue)) {
388
422
headers += (" Content-Type" -> contentTypeValue)
389
423
}
390
-
391
424
new scala.collection.immutable.TreeMap [String , String ]() ++ headers
392
425
}
393
426
@@ -514,7 +547,7 @@ abstract class CognitiveServicesBaseNoHandler(val uid: String) extends Transform
514
547
errorCol -> (this .uid + " _error" )
515
548
)
516
549
517
- if (PlatformDetails .runningOnFabric()) {
550
+ if (PlatformDetails .runningOnFabric()) {
518
551
setDefaultInternalEndpoint(FabricClient .MLWorkloadEndpointML )
519
552
}
520
553
0 commit comments