Skip to content

Commit 85a6687

Browse files
authored
chore: fix propagation of fabric telemetry (#2403)
* chore: fix propagation of fabric telemetry * chore: refine names
1 parent 298c7ed commit 85a6687

File tree

2 files changed

+46
-16
lines changed

2 files changed

+46
-16
lines changed

cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/CognitiveServiceBase.scala

Lines changed: 45 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,19 @@ import com.microsoft.azure.synapse.ml.fabric.FabricClient
1010
import com.microsoft.azure.synapse.ml.io.http._
1111
import com.microsoft.azure.synapse.ml.logging.SynapseMLLogging
1212
import com.microsoft.azure.synapse.ml.logging.common.PlatformDetails
13-
import com.microsoft.azure.synapse.ml.param.{ GlobalKey, GlobalParams, HasGlobalParams, ServiceParam }
14-
import com.microsoft.azure.synapse.ml.stages.{ DropColumns, Lambda }
13+
import com.microsoft.azure.synapse.ml.param.{GlobalKey, GlobalParams, HasGlobalParams, ServiceParam, TypedArrayParam}
14+
import com.microsoft.azure.synapse.ml.stages.{DropColumns, Lambda}
1515
import org.apache.commons.lang.StringUtils
1616
import org.apache.http.NameValuePair
17-
import org.apache.http.client.methods.{ HttpEntityEnclosingRequestBase, HttpPost, HttpRequestBase }
17+
import org.apache.http.client.methods.{HttpEntityEnclosingRequestBase, HttpPost, HttpRequestBase}
1818
import org.apache.http.client.utils.URLEncodedUtils
1919
import org.apache.http.entity.AbstractHttpEntity
2020
import org.apache.http.impl.client.CloseableHttpClient
2121
import org.apache.spark.ml.param._
22-
import org.apache.spark.ml.{ ComplexParamsWritable, NamespaceInjections, PipelineModel, Transformer }
23-
import org.apache.spark.sql.functions.{ col, lit, struct }
22+
import org.apache.spark.ml.{ComplexParamsWritable, NamespaceInjections, PipelineModel, Transformer}
23+
import org.apache.spark.sql.functions.{col, lit, struct}
2424
import org.apache.spark.sql.types._
25-
import org.apache.spark.sql.{ DataFrame, Dataset, Row }
25+
import org.apache.spark.sql.{DataFrame, Dataset, Row}
2626
import spray.json.DefaultJsonProtocol._
2727

2828
import java.net.URI
@@ -206,13 +206,37 @@ trait HasCustomHeaders extends HasServiceParams {
206206

207207
// For PySpark compatibility, accept a Java HashMap as input to the parameter
208208
// py4J only natively supports conversions from Python Dict to Java HashMap
209-
def setCustomHeaders(v: java.util.HashMap[String,String]): this.type = {
209+
def setCustomHeaders(v: java.util.HashMap[String, String]): this.type = {
210210
setCustomHeaders(v.asScala.toMap)
211211
}
212+
}
213+
214+
trait HasTelemHeaders extends HasServiceParams {
215+
216+
private[ml] val telemHeaders = new ServiceParam[Map[String, String]](
217+
this, "telemHeaders", "Map of Custom Header Key-Value Tuples."
218+
)
219+
220+
private[ml] def setTelemHeaders(v: Map[String, String]): this.type = {
221+
setScalarParam(telemHeaders, v)
222+
}
223+
224+
// For PySpark compatibility, accept a Java HashMap as input to the parameter
225+
// py4J only natively supports conversions from Python Dict to Java HashMap
226+
private[ml] def setTelemHeaders(v: java.util.HashMap[String, String]): this.type = {
227+
setTelemHeaders(v.asScala.toMap)
228+
}
229+
230+
setDefault(telemHeaders -> Left(Map("x-ai-telemetry-properties"->
231+
s"""{
232+
|"OriginatingService": "SynapseML",
233+
|"ClientArtifactType": "Spark",
234+
|"OperationName": "${this.getClass.getName}"
235+
|}""".stripMargin.replaceAll("\n", ""))))
212236

213-
def getCustomHeaders: Map[String, String] = getScalarParam(customHeaders)
214237
}
215238

239+
216240
trait HasCustomCogServiceDomain extends Wrappable with HasURL with HasUrlPath {
217241
def setCustomServiceName(v: String): this.type = {
218242
setUrl(s"https://$v.cognitiveservices.azure.com/" + urlPath.stripPrefix("/"))
@@ -281,7 +305,7 @@ object URLEncodingUtils {
281305
}
282306

283307
trait HasCognitiveServiceInput extends HasURL with HasSubscriptionKey with HasAADToken with HasCustomAuthHeader
284-
with HasCustomHeaders with SynapseMLLogging {
308+
with HasCustomHeaders with HasTelemHeaders with SynapseMLLogging {
285309

286310
val customUrlRoot: Param[String] = new Param[String](
287311
this, "customUrlRoot", "The custom URL root for the service. " +
@@ -334,7 +358,7 @@ trait HasCognitiveServiceInput extends HasURL with HasSubscriptionKey with HasAA
334358

335359
protected def getCustomAuthHeader(row: Row): Option[String] = {
336360
val providedCustomAuthHeader = getValueOpt(row, CustomAuthHeader)
337-
if (providedCustomAuthHeader .isEmpty && PlatformDetails.runningOnFabric()) {
361+
if (providedCustomAuthHeader.isEmpty && PlatformDetails.runningOnFabric()) {
338362
logInfo("Using Default AAD Token On Fabric")
339363
Option(FabricClient.getCognitiveMWCTokenAuthHeader)
340364
} else {
@@ -362,6 +386,7 @@ trait HasCognitiveServiceInput extends HasURL with HasSubscriptionKey with HasAA
362386
val contentTypeValue = contentType(row)
363387
val customAuthHeaderOpt = getCustomAuthHeader(row)
364388
val customHeadersOpt = getCustomHeaders(row)
389+
val telemHeadersOpt = getValueOpt(row, telemHeaders)
365390

366391
if (subscriptionKeyOpt.nonEmpty) {
367392
headers += (subscriptionKeyHeaderName -> getValue(row, subscriptionKey))
@@ -376,6 +401,7 @@ trait HasCognitiveServiceInput extends HasURL with HasSubscriptionKey with HasAA
376401
headers += ("x-ms-workload-resource-moniker" -> UUID.randomUUID().toString)
377402
}
378403
}
404+
379405
if (customHeadersOpt.nonEmpty) {
380406
customHeadersOpt.foreach { m =>
381407
m.foreach { case (headerName, headerValue) =>
@@ -384,10 +410,17 @@ trait HasCognitiveServiceInput extends HasURL with HasSubscriptionKey with HasAA
384410
}
385411
}
386412

413+
if (telemHeadersOpt.nonEmpty) {
414+
telemHeadersOpt.foreach { m =>
415+
m.foreach { case (headerName, headerValue) =>
416+
headers += (headerName -> headerValue)
417+
}
418+
}
419+
}
420+
387421
if (addContentType && !StringUtils.isEmpty(contentTypeValue)) {
388422
headers += ("Content-Type" -> contentTypeValue)
389423
}
390-
391424
new scala.collection.immutable.TreeMap[String, String]() ++ headers
392425
}
393426

@@ -514,7 +547,7 @@ abstract class CognitiveServicesBaseNoHandler(val uid: String) extends Transform
514547
errorCol -> (this.uid + "_error")
515548
)
516549

517-
if(PlatformDetails.runningOnFabric()) {
550+
if (PlatformDetails.runningOnFabric()) {
518551
setDefaultInternalEndpoint(FabricClient.MLWorkloadEndpointML)
519552
}
520553

core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/HTTPSchema.scala

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -199,10 +199,7 @@ case class HTTPRequestData(requestLine: RequestLineData,
199199
request.setProtocolVersion(pv.toHTTPCore))
200200
request.setHeaders(headers.map(_.toHTTPCore) ++
201201
Array(new BasicHeader(
202-
"User-Agent", s"synapseml/${BuildInfo.version}${HeaderValues.PlatformInfo}"),
203-
new BasicHeader(
204-
"x-ai-telemetry-properties", "{\"ClientArtifactType\": \"AIFunctionsSpark\"}"
205-
)))
202+
"User-Agent", s"synapseml/${BuildInfo.version}${HeaderValues.PlatformInfo}")))
206203
request
207204
}
208205
//scalastyle:on cyclomatic.complexity

0 commit comments

Comments
 (0)