Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions notebooks/samples/SparkServing - Deploying a Classifier.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -113,14 +113,13 @@
"metadata": {},
"outputs": [],
"source": [
"from pyspark.sql.functions import col, from_json\n",
"from pyspark.sql.types import *\n",
"import uuid\n",
"from mmlspark import request_to_string, string_to_response\n",
"\n",
"serving_inputs = spark.readStream.server() \\\n",
" .address(\"localhost\", 8898, \"my_api\") \\\n",
" .load()\\\n",
" .option(\"name\", \"my_api\") \\\n",
" .load() \\\n",
" .parseRequest(test.schema)\n",
"\n",
"serving_outputs = model.transform(serving_inputs) \\\n",
Expand Down
4 changes: 2 additions & 2 deletions src/core/test/base/src/main/scala/SparkSessionFactory.scala
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,10 @@ object SparkSessionFactory {
}
def currentDir(): String = System.getProperty("user.dir")

def getSession(name: String, logLevel: String = "WARN"): SparkSession = {
def getSession(name: String, logLevel: String = "WARN", numRetries: Int): SparkSession = {
val conf = new SparkConf()
.setAppName(name)
.setMaster("local[*]")
.setMaster(if (numRetries == 1){"local[*]"}else{s"local[*, $numRetries]"})
.set("spark.logConf", "true")
.set("spark.sql.warehouse.dir", SparkSessionFactory.localWarehousePath)
val sess = SparkSession.builder()
Expand Down
7 changes: 4 additions & 3 deletions src/core/test/base/src/main/scala/TestBase.scala
Original file line number Diff line number Diff line change
Expand Up @@ -41,19 +41,20 @@ trait LinuxOnly extends TestBase {

abstract class TestBase extends FunSuite with BeforeAndAfterEachTestData with BeforeAndAfterAll {

println(s"\n>>>-------------------- $this --------------------<<<")

// "This Is A Bad Thing" according to my research. However, this is
// just for tests so maybe ok. A better design would be to break the
// session stuff into TestSparkSession as a trait and have test suites
// that need it "with TestSparkSession" instead, but that's a lot of
// changes right now and maybe not desired.

protected val numRetries = 1
protected val logLevel = "WARN"
private var sessionInitialized = false
protected lazy val session: SparkSession = {
info(s"Creating a spark session for suite $this")
sessionInitialized = true
SparkSessionFactory
.getSession(s"$this", logLevel = "WARN")
.getSession(s"$this", logLevel = logLevel, numRetries)
}

protected lazy val sc: SparkContext = session.sparkContext
Expand Down
4 changes: 2 additions & 2 deletions src/io/http/src/main/scala/DistributedHTTPSource.scala
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ class DistributedHTTPSource(name: String,
s.updateCurrentBatch(currentOffset.offset)
s.getRequests(startOrdinal, endOrdinal)
.map{ case (id, request) =>
Row.fromSeq(Seq(Row(id, null), toRow(request)))
Row.fromSeq(Seq(Row(null, id, null), toRow(request)))
}.toIterator
}(RowEncoder(HTTPSourceV2.SCHEMA))
}
Expand Down Expand Up @@ -434,7 +434,7 @@ class DistributedHTTPSink(val options: Map[String, String])

val irToResponseData = HTTPResponseData.makeFromInternalRowConverter
data.queryExecution.toRdd.map { ir =>
(ir.getStruct(idColIndex, 2).getString(0), irToResponseData(ir.getStruct(replyColIndex, 4)))
(ir.getStruct(idColIndex, 3).getString(1), irToResponseData(ir.getStruct(replyColIndex, 4)))
}.foreach { case (id, value) =>
server.get.respond(batchId, id, value)
}
Expand Down
79 changes: 55 additions & 24 deletions src/io/http/src/main/scala/HTTPSchema.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,22 @@ package com.microsoft.ml.spark

import java.net.{SocketException, URI}

import com.microsoft.ml.spark.schema.SparkBindings
import com.microsoft.ml.spark.StreamUtilities.using
import com.microsoft.ml.spark.schema.SparkBindings
import com.sun.net.httpserver.HttpExchange
import org.apache.commons.io.IOUtils
import org.apache.http._
import org.apache.http.client.methods._
import org.apache.http.entity.{ByteArrayEntity, ContentType, StringEntity}
import org.apache.http.entity.{ByteArrayEntity, StringEntity}
import org.apache.http.message.BasicHeader
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.internal.{Logging => SLogging}
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions.{col, struct, typedLit, udf}
import org.apache.spark.sql.functions.{col, struct, typedLit, udf, lit}
import org.apache.spark.sql.types.{DataType, StringType}
import org.apache.spark.sql.{Column, Row}

import collection.JavaConverters._
import collection.JavaConversions._
import scala.collection.JavaConversions._
import scala.collection.JavaConverters._

case class HeaderData(name: String, value: String) {

Expand Down Expand Up @@ -106,16 +106,29 @@ case class HTTPResponseData(headers: Array[HeaderData],
if (headersToAdd.nonEmpty) {
headersToAdd.foreach(h => responseHeaders.add(h.name, h.value))
}
request.sendResponseHeaders(statusLine.statusCode,
entity.flatMap(_.contentLength).getOrElse(0L))
entity.foreach(entity =>using(request.getResponseBody) {
_.write(entity.content)
}.get)
try {
request.sendResponseHeaders(statusLine.statusCode,
entity.flatMap(_.contentLength).getOrElse(0L))
} catch {
case e: java.io.IOException =>
HTTPResponseData.warn(s"Could not write headers: ${e.getMessage}")
}

try {
entity.foreach(entity => using(request.getResponseBody) {
_.write(entity.content)
}.get)
} catch {
case e: java.io.IOException =>
HTTPResponseData.warn(s"Could not send bytes: ${e.getMessage}")
}
}

}

object HTTPResponseData extends SparkBindings[HTTPResponseData]
object HTTPResponseData extends SparkBindings[HTTPResponseData] with SLogging {
def warn(msg: => String): Unit = logWarning(msg)
}

case class ProtocolVersionData(protocol: String, major: Int, minor: Int) {

Expand Down Expand Up @@ -265,29 +278,47 @@ object HTTPSchema {

def request_to_string(c: Column): Column = request_to_string_udf(c)

def stringToResponse(x: String): HTTPResponseData = {
def stringToResponse(x: String, code: Int, reason: String): HTTPResponseData = {
HTTPResponseData(
Array(),
Some(stringToEntity(x)),
StatusLineData(null, 200, "Success"),
StatusLineData(null, code, reason),
"en")
}

private val string_to_response_udf: UserDefinedFunction =
udf(stringToResponse _, HTTPResponseData.schema)

def string_to_response(c: Column): Column = string_to_response_udf(c)
def string_to_response(str: Column, code: Column = lit(200), reason: Column = lit("Success")): Column =
string_to_response_udf(str, code, reason)

def emptyResponse(code: Int, reason: String): HTTPResponseData = {
HTTPResponseData(
Array(),
None,
StatusLineData(null, code, reason),
"en")
}

private val empty_response_udf: UserDefinedFunction =
udf(emptyResponse _, HTTPResponseData.schema)

def empty_response(code: Column = lit(200), reason: Column = lit("Success")): Column =
empty_response_udf(code, reason)

def binaryToResponse(x: Array[Byte], code: Int, reason: String): HTTPResponseData = {
HTTPResponseData(
Array(),
Some(binaryToEntity(x)),
StatusLineData(null, code, reason),
"en")
}

private val binary_to_response_udf: UserDefinedFunction =
udf({ x: Array[Byte] =>
HTTPResponseData(
Array(),
Some(binaryToEntity(x)),
StatusLineData(null, 200, "Success"),
"en")
}, HTTPResponseData.schema)

def binary_to_response(c: Column): Column = binary_to_response_udf(c)
udf(binaryToResponse _, HTTPResponseData.schema)

def binary_to_response(ba: Column, code: Column = lit(200), reason: Column = lit("Success")): Column =
binary_to_response_udf(ba, code, reason)

def to_http_request(urlCol: Column, headersCol: Column, methodCol: Column, jsonEntityCol: Column): Column = {
val pvd: Option[ProtocolVersionData] = None
Expand Down
106 changes: 106 additions & 0 deletions src/io/http/src/main/scala/HTTPSinkV2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package org.apache.spark.sql.execution.streaming.continuous

import com.microsoft.ml.spark._
import org.apache.spark.TaskContext
import org.apache.spark.internal.Logging
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.DataSourceRegister
import org.apache.spark.sql.sources.v2._
import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, WriterCommitMessage}
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types._

import scala.collection.mutable

/** DataSourceV2 provider registering the "HTTPv2" streaming sink.
  *
  * Placed in org.apache.spark.sql.execution.streaming.continuous so it can use
  * Spark's package-private streaming internals.
  */
class HTTPSinkProviderV2 extends DataSourceV2
  with StreamWriteSupport
  with DataSourceRegister {

  /** Creates the writer that sends rows back as HTTP responses.
    *
    * @param queryId unique id of the streaming query (unused here)
    * @param schema  schema of the rows to be written; must contain the id and reply columns
    * @param mode    requested output mode (unused here)
    * @param options user options; must include "name" (see HTTPWriter)
    */
  override def createStreamWriter(queryId: String,
                                  schema: StructType,
                                  mode: OutputMode,
                                  options: DataSourceOptions): StreamWriter = {
    new HTTPWriter(schema, options)
  }

  // `override` added for consistency with createStreamWriter above and so the
  // compiler flags this method if DataSourceRegister's signature ever changes.
  override def shortName(): String = "HTTPv2"
}

/** StreamWriter for the HTTP sink.
  *
  * Resolves the id and reply column positions from the schema up front so that
  * each per-partition writer only needs the integer indices.
  * (Original header said "console sink" — copy-paste typo, fixed.)
  */
class HTTPWriter(schema: StructType, options: DataSourceOptions)
  extends StreamWriter with Logging {

  // Column names are configurable; default to "id" and "reply".
  protected val idCol: String = options.get("idCol").orElse("id")
  protected val replyCol: String = options.get("replyCol").orElse("reply")

  // The name is mandatory — it keys the shared server state in HTTPSourceStateHolder.
  // A bare .get on the Optional would throw a message-less NoSuchElementException,
  // so fail fast with an explicit error instead.
  protected val name: String = {
    val nameOpt = options.get("name")
    require(nameOpt.isPresent, "The HTTPv2 sink requires a 'name' option")
    nameOpt.get
  }

  // fieldIndex throws if the column is absent, which surfaces misconfiguration early.
  val idColIndex: Int = schema.fieldIndex(idCol)
  val replyColIndex: Int = schema.fieldIndex(replyCol)

  assert(SparkSession.getActiveSession.isDefined)

  def createWriterFactory(): DataWriterFactory[InternalRow] =
    HTTPWriterFactory(idColIndex, replyColIndex, name)

  // Nothing to do on commit: replies are sent eagerly by the data writers.
  override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}

  // `override` added for consistency with commit above.
  override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
    HTTPSourceStateHolder.cleanUp(name)
  }

}

/** Serializable factory shipped to executors; creates one HTTPDataWriter per partition task. */
private[streaming] case class HTTPWriterFactory(idColIndex: Int,
                                                replyColIndex: Int,
                                                name: String)
  extends DataWriterFactory[InternalRow] {
  // One writer per (partition, task, epoch) triple.
  def createDataWriter(partitionId: Int, taskId: Long, epochId: Long): DataWriter[InternalRow] = {
    new HTTPDataWriter(partitionId, idColIndex, replyColIndex, name, epochId)
  }
}

/** Per-partition writer that replies to HTTP requests with the rows it receives.
  *
  * @param partitionId   partition this writer serves (logged; paired with request ids)
  * @param idColIndex    index of the id struct column in each row
  * @param replyColIndex index of the HTTPResponseData struct column in each row
  * @param name          key of the shared server in HTTPSourceStateHolder
  * @param epochId       current epoch; epoch - 1 is committed on construction
  */
private[streaming] class HTTPDataWriter(partitionId: Int,
                                        val idColIndex: Int,
                                        val replyColIndex: Int,
                                        val name: String,
                                        epochId: Long)
  extends DataWriter[InternalRow] with Logging {
  logInfo(s"Creating writer on PID:$partitionId")
  // Creating the writer for epoch N implies epoch N-1 is complete on this partition.
  HTTPSourceStateHolder.getServer(name).commit(epochId - 1, partitionId)

  // (requestId, partitionId) pairs answered during this epoch; reported at commit time.
  private val ids: mutable.ListBuffer[(String, Int)] = new mutable.ListBuffer[(String, Int)]()

  private val fromRow = HTTPResponseData.makeFromInternalRowConverter

  override def write(row: InternalRow): Unit = {
    // The id struct has 3 fields (presumably machineId, requestId, partitionId —
    // matching the 3-field id schema used by the sources in this change).
    // BUG FIX: the original passed 2 as the field count but then read ordinal 2,
    // which is out of bounds for a 2-field struct.
    val id = row.getStruct(idColIndex, 3)
    val mid = id.getString(0)
    val rid = id.getString(1)
    val pid = id.getInt(2)
    // 4 = number of fields of HTTPResponseData.
    val reply = fromRow(row.getStruct(replyColIndex, 4))
    HTTPSourceStateHolder.getServer(name).replyTo(mid, rid, reply)
    ids.append((rid, pid))
  }

  override def commit(): HTTPCommitMessage = {
    val msg = HTTPCommitMessage(ids.toArray)
    // Hoist the loop-invariant server lookup; only the request id is needed here.
    val server = HTTPSourceStateHolder.getServer(name)
    ids.foreach { case (rid, _) => server.commit(rid) }
    ids.clear()
    msg
  }

  override def abort(): Unit = {
    // Only tear down the shared server state when the whole stage is cancelled,
    // not on an individual task retry.
    if (TaskContext.get().getKillReason().contains("Stage cancelled")) {
      HTTPSourceStateHolder.cleanUp(name)
    }
  }
}

private[streaming] case class HTTPCommitMessage(ids: Array[(String, Int)]) extends WriterCommitMessage
10 changes: 6 additions & 4 deletions src/io/http/src/main/scala/HTTPSource.scala
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,10 @@ class HTTPSource(name: String, host: String, port: Int, sqlContext: SQLContext)
val sliceEnd = endOrdinal - lastOffsetCommitted.offset.toInt - 1
requests.slice(sliceStart, sliceEnd).map{ case(id, request) =>
val row = new GenericInternalRow(2)
val idRow = new GenericInternalRow(2)
idRow.update(0, UTF8String.fromString(id.toString))
idRow.update(1, null)
val idRow = new GenericInternalRow(3)
idRow.update(0, null)
idRow.update(1, UTF8String.fromString(id.toString))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

UTF8String.fromString(id.toString) looks expensive. what type is id?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a Long; for some reason Spark's internal row construction needs this UTF8String instead of a regular String. The HTTP Source is mainly useful for debugging, though, and will be deprecated in favor of the v2 implementation soon.

idRow.update(2, null)
row.update(0, idRow)
row.update(1, hrdToIr(HTTPRequestData.fromHTTPExchange(request)))
row.asInstanceOf[InternalRow]
Expand Down Expand Up @@ -199,8 +200,9 @@ class HTTPSink(val options: Map[String, String]) extends Sink with Logging {
assert(idType == HTTPSourceV2.ID_SCHEMA, s"id col is $idType, need ${HTTPSourceV2.ID_SCHEMA}")

val irToResponseData = HTTPResponseData.makeFromInternalRowConverter

val replies = data.queryExecution.toRdd.map { ir =>
(ir.getStruct(idColIndex, 2).getString(0), irToResponseData(ir.getStruct(replyColIndex, 4)))
(ir.getStruct(idColIndex, 3).getString(1), irToResponseData(ir.getStruct(replyColIndex, 4)))
// 4 is the Number of fields of HTTPResponseData,
// there does not seem to be a way to get this w/o reflection
}.collect()
Expand Down
Loading