Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@
" .address(\"localhost\", 8898, \"my_api\") \\\n",
" .option(\"name\", \"my_api\") \\\n",
" .load() \\\n",
" .parseRequest(test.schema)\n",
" .parseRequest(\"my_api\", test.schema)\n",
"\n",
"serving_outputs = model.transform(serving_inputs) \\\n",
" .makeReply(\"scored_labels\")\n",
Expand Down
6 changes: 3 additions & 3 deletions src/io/http/src/main/python/ServingImplicits.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,13 @@ def _writeContServer(self):

setattr(pyspark.sql.streaming.DataStreamWriter, 'continuousServer', _writeContServer)

def _parseRequest(self,schema,
idCol="id",requestCol="request"):
def _parseRequest(self, apiName, schema,
                  idCol="id", requestCol="request", parsingCheck="none"):
    """
    Parse the request column of a serving DataFrame by delegating to the
    JVM-side ``DataFrameServingExtensions.parseRequest``.

    :param apiName: name of the serving API; forwarded unchanged to the JVM call
    :param schema: pyspark data type describing the request body; it is serialized
        with ``schema.json()`` and rebuilt on the JVM via ``DataType.fromJson``
    :param idCol: name of the request id column (default ``"id"``)
    :param requestCol: name of the column holding the HTTP request (default ``"request"``)
    :param parsingCheck: forwarded unchanged to the JVM call (default ``"none"``)
    :return: a ``DataFrame`` wrapping the JVM result
    """
    ctx = SparkContext.getOrCreate()
    jvm = ctx._jvm
    extended = jvm.com.microsoft.ml.spark.DataFrameServingExtensions(self._jdf)
    dt = jvm.org.apache.spark.sql.types.DataType
    # Single call using the current five-argument JVM signature; the earlier
    # three-argument call no longer matches that signature.
    jResult = extended.parseRequest(apiName, dt.fromJson(schema.json()), idCol, requestCol, parsingCheck)
    sql_ctx = pyspark.SQLContext.getOrCreate(ctx)
    return DataFrame(jResult, sql_ctx)

Expand Down
70 changes: 65 additions & 5 deletions src/io/http/src/main/scala/ServingImplicits.scala
Original file line number Diff line number Diff line change
Expand Up @@ -55,21 +55,81 @@ case class DataStreamWriterExtensions[T](dsw: DataStreamWriter[T]) {

case class DataFrameServingExtensions(df: DataFrame) {

def parseRequest(schema: DataType,
/**
  * Builds the error text sent back to a client whose request body failed the parsing check.
  *
  * The schema is taken in its own parameter list so the method can be partially
  * applied to a schema and then lifted into a UDF over the body column.
  *
  * @param schema schema the body was expected to match, rendered with `simpleString`
  * @param body   request body text
  * @return the message describing the expected schema and the received body
  */
private def jsonParsingError(schema: DataType)(body: String): String = {
  s"JSON Parsing error, expected schema:\n ${schema.simpleString}\n received:\n $body"
}

/**
  * Recursively checks that a parsed value contains no nulls.
  *
  * @param a a parsed value: a sequence, a Row, a leaf value, or null
  * @return false if `a` is null or any element/field reachable through nested
  *         sequences and rows is null; true otherwise
  */
private def fullJsonParsingSuccess(a: Any): Boolean = a match {
  // Type patterns never match null, so testing for it first changes nothing
  case null => false
  case elements: Seq[_] => elements.forall(fullJsonParsingSuccess)
  case row: Row => row.toSeq.forall(fullJsonParsingSuccess)
  case _ => true
}

/**
  * Parses the HTTP request column of a serving DataFrame into typed columns.
  *
  * For `BinaryType` the request's entity content is returned in a column named
  * "bytes" and `parsingCheck` is not consulted. For any other schema the request
  * body is converted to a string, parsed with `from_json`, and the parsed fields
  * are returned next to the id column.
  *
  * With `parsingCheck` set to "full" or "partial", rows that fail the check are
  * answered through `ServingUDFs.sendReplyUDF` with code 400 and reason
  * "JSON Parsing Failure"; rows whose "didReply" value is non-null are then
  * filtered out of the result.
  *
  * @param apiName      name of the serving API, passed to `ServingUDFs.sendReplyUDF`
  *                     when a request is rejected
  * @param schema       expected type of the request body; `BinaryType` skips JSON parsing
  * @param idCol        name of the request id column; its type must equal `HTTPSourceV2.ID_SCHEMA`
  * @param requestCol   name of the request column; its type must equal `HTTPRequestData.schema`
  * @param parsingCheck case-insensitive; "none": to accept all requests,
  *                     "full": to ensure all fields and subfields are non-null,
  *                     "partial": to ensure the root structure was parsed correctly
  * @return for `BinaryType`, the id column plus "bytes"; otherwise the id column plus
  *         one column per field of the parsed body
  * @throws IllegalArgumentException if the schema is not `BinaryType` and `parsingCheck`
  *                                  is not one of "none", "full" or "partial"
  */
def parseRequest(apiName: String,
                 schema: DataType,
                 idCol: String = "id",
                 requestCol: String = "request",
                 parsingCheck: String = "none"): DataFrame = {
  assert(df.schema(idCol).dataType == HTTPSourceV2.ID_SCHEMA &&
    df.schema(requestCol).dataType == HTTPRequestData.schema)
  schema match {
    case BinaryType =>
      df.select(col(idCol), col(requestCol).getItem("entity").getItem("content").alias("bytes"))
    case _ =>
      val parsedDf = df
        .withColumn("body", HTTPSchema.request_to_string(col(requestCol)))
        .withColumn("parsed", from_json(col("body"), schema))
      // Locale.ROOT keeps the comparison stable under locales with special casing
      // rules (e.g. Turkish, where "PARTIAL" would lower-case to a dotless i).
      val checkMode = parsingCheck.toLowerCase(java.util.Locale.ROOT)
      if (checkMode == "none") {
        parsedDf.select(idCol, "parsed.*")
      } else {
        // True for rows that FAILED the requested check and must be rejected
        val failedCol = checkMode match {
          case "full" => udf({ x: Any => !fullJsonParsingSuccess(x) }, BooleanType)(col("parsed"))
          case "partial" => col("parsed").isNull
          case _ => throw new IllegalArgumentException(
            s"Need to use either full, partial, or none. Received $parsingCheck")
        }

        val df1 = parsedDf
          .withColumn("didReply",
            when(failedCol,
              ServingUDFs.sendReplyUDF(
                lit(apiName),
                ServingUDFs.makeReplyUDF(
                  udf(jsonParsingError(schema) _, StringType)(col("body")),
                  StringType,
                  code = lit(400),
                  reason = lit("JSON Parsing Failure")),
                // Reply against the caller-supplied id column, not a hard-coded "id"
                col(idCol)
              )
            )
              .otherwise(lit(null)))
          .filter(col("didReply").isNull)

        // NOTE(review): the identity UDF over "parsed" is kept because it may be
        // load-bearing for query planning -- confirm before removing it. Only the
        // per-row debug println it used to contain has been dropped.
        df1.withColumn("parsed", udf({ x: Row => x }, df1.schema("parsed").dataType)(col("parsed")))
          .select(idCol, "parsed.*")
      }
  }
}

def makeReply(replyCol: String, name: String = "reply"): DataFrame ={
/**
  * Adds a reply column built from an existing column with `ServingUDFs.makeReplyUDF`.
  *
  * @param replyCol name of the column whose values become the reply
  * @param name     name of the column that receives the reply (default "reply")
  * @return the DataFrame with the reply column added
  */
def makeReply(replyCol: String, name: String = "reply"): DataFrame = {
  val replyType = df.schema(replyCol).dataType
  val reply = ServingUDFs.makeReplyUDF(col(replyCol), replyType)
  df.withColumn(name, reply)
}

Expand Down
2 changes: 1 addition & 1 deletion src/io/http/src/test/scala/ContinuousHTTPSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class ContinuousHTTPSuite extends TestBase with HTTPTestUtils {
.address(host, port, apiPath)
.option("name", apiName)
.load()
.parseRequest(BinaryType)
.parseRequest(apiName, BinaryType)
.withColumn("length", length(col("bytes")))
.makeReply("length")
.writeStream
Expand Down
9 changes: 5 additions & 4 deletions src/io/http/src/test/scala/DistributedHTTPSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ trait HTTPTestUtils extends WithFreeUrl {
def sendStringRequest(client: CloseableHttpClient,
url: String = url,
payload: String = "foo",
accept400: Boolean = false): (String, Double) = {
targetCode: Int = 200): (String, Double) = {
val post = new HttpPost(url)
val e = new StringEntity(payload)
post.setEntity(e)
Expand All @@ -50,7 +50,8 @@ trait HTTPTestUtils extends WithFreeUrl {
val res = client.execute(post)
val t1 = System.nanoTime()

val out = if (accept400 && res.getStatusLine.getStatusCode == 400) {
assert(targetCode == res.getStatusLine.getStatusCode)
val out = if (targetCode == res.getStatusLine.getStatusCode && !targetCode.toString.startsWith("2")) {
null
} else {
new BasicResponseHandler().handleResponse(res)
Expand Down Expand Up @@ -242,7 +243,7 @@ class DistributedHTTPSuite extends TestBase with HTTPTestUtils {
.option("name", "foo")
.option("maxPartitions", 5)
.load()
.parseRequest(BinaryType)
.parseRequest(apiName, BinaryType)
.withColumn("length", length(col("bytes")))
.makeReply("length")
.writeStream
Expand Down Expand Up @@ -275,7 +276,7 @@ class DistributedHTTPSuite extends TestBase with HTTPTestUtils {
.option("maxPartitions", 5)
.option("name", "foo")
.load()
.parseRequest(BinaryType)
.parseRequest(apiName, BinaryType)
.withColumn("length", length(col("bytes")))
.makeReply("length")
.writeStream
Expand Down
61 changes: 57 additions & 4 deletions src/io/http/src/test/scala/HTTPv2Suite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -172,9 +172,62 @@ class HTTPv2Suite extends TestBase with HTTPTestUtils {
}
}

test("can reply to bad requests immediately partial") {
  val requestSchema = new StructType().add("value", IntegerType)
  val parsedRequests = baseDF()
    .parseRequest(apiName, requestSchema, parsingCheck = "partial")
  val server = baseWrite(parsedRequests.makeReply("value")).start()

  // Sends the same payload ten times; sendStringRequest checks each status code against targetCode
  def sendTen(payload: String, expectedCode: Int = 200) =
    (1 to 10).map(_ => sendStringRequest(client, payload = payload, targetCode = expectedCode))

  using(server) {
    waitForServer(server)

    // Well-formed body with the expected field
    val r1 = sendTen("""{"value": 1}""")
    // Valid JSON with an unknown key: still expected to succeed under the "partial" check
    val r2 = sendTen("""{"valu111e": 1}""")
    // Not JSON at all: expected to be rejected with a 400
    val r3 = sendTen("""jskdfjkdhdjfdjkh""", expectedCode = 400)

    Seq(r1, r2, r3).foreach(assertLatency(_, 60))
    r3.foreach { case (body, _) => assert(Option(body).isEmpty) }
  }
}

test("can reply to bad requests immediately") {
  val requestSchema = new StructType().add("value", IntegerType)
  val parsedRequests = baseDF()
    .parseRequest(apiName, requestSchema, parsingCheck = "full")
  val server = baseWrite(parsedRequests.makeReply("value")).start()

  // Sends the same payload ten times; sendStringRequest checks each status code against targetCode
  def sendTen(payload: String, expectedCode: Int = 200) =
    (1 to 10).map(_ => sendStringRequest(client, payload = payload, targetCode = expectedCode))

  using(server) {
    waitForServer(server)

    // Well-formed body with the expected field
    val r1 = sendTen("""{"value": 1}""")
    // Valid JSON with an unknown key: expected to be rejected with a 400 under the "full" check
    val r2 = sendTen("""{"valu111e": 1}""", expectedCode = 400)
    // Not JSON at all: expected to be rejected with a 400
    val r3 = sendTen("""jskdfjkdhdjfdjkh""", expectedCode = 400)

    Seq(r1, r2, r3).foreach(assertLatency(_, 60))

    (r2 ++ r3).foreach { case (body, _) => assert(Option(body).isEmpty) }
  }
}

test("can reply from the middle of the pipeline") {
val server = baseWrite(baseDF()
.parseRequest(new StructType().add("value", IntegerType))
.parseRequest(apiName, new StructType().add("value", IntegerType))
.withColumn("didReply",
when(col("value").isNull,
ServingUDFs.sendReplyUDF(
Expand All @@ -195,11 +248,11 @@ class HTTPv2Suite extends TestBase with HTTPTestUtils {
)

val r2 = (1 to 100).map(i =>
sendStringRequest(client, payload = """{"valu111e": 1}""", accept400 = true)
sendStringRequest(client, payload = """{"valu111e": 1}""", targetCode = 400)
)

val r3 = (1 to 100).map(i =>
sendStringRequest(client, payload = """jskdfjkdhdjfdjkh""", accept400 = true)
sendStringRequest(client, payload = """jskdfjkdhdjfdjkh""", targetCode = 400)
)
assertLatency(r1, 60)
assertLatency(r2, 60)
Expand Down Expand Up @@ -246,7 +299,7 @@ class HTTPv2Suite extends TestBase with HTTPTestUtils {
.map(i => (i, i.toString + "_foo"))
.toDF("key", "value").cache()

val df1 = baseDF(1).parseRequest(new StructType().add("data", IntegerType))
val df1 = baseDF(1).parseRequest(apiName, new StructType().add("data", IntegerType))

df1.printSchema()

Expand Down