diff --git a/dev/dependencyList b/dev/dependencyList
index 2035b95dea8..9813bb56288 100644
--- a/dev/dependencyList
+++ b/dev/dependencyList
@@ -69,7 +69,7 @@ jackson-annotations/2.14.2//jackson-annotations-2.14.2.jar
jackson-core/2.14.2//jackson-core-2.14.2.jar
jackson-databind/2.14.2//jackson-databind-2.14.2.jar
jackson-dataformat-yaml/2.14.2//jackson-dataformat-yaml-2.14.2.jar
-jackson-datatype-jdk8/2.12.3//jackson-datatype-jdk8-2.12.3.jar
+jackson-datatype-jdk8/2.14.2//jackson-datatype-jdk8-2.14.2.jar
jackson-datatype-jsr310/2.14.2//jackson-datatype-jsr310-2.14.2.jar
jackson-jaxrs-base/2.14.2//jackson-jaxrs-base-2.14.2.jar
jackson-jaxrs-json-provider/2.14.2//jackson-jaxrs-json-provider-2.14.2.jar
diff --git a/kyuubi-common/src/main/scala/org/apache/kyuubi/service/AbstractBackendService.scala b/kyuubi-common/src/main/scala/org/apache/kyuubi/service/AbstractBackendService.scala
index e7c2d836573..b9e254508dd 100644
--- a/kyuubi-common/src/main/scala/org/apache/kyuubi/service/AbstractBackendService.scala
+++ b/kyuubi-common/src/main/scala/org/apache/kyuubi/service/AbstractBackendService.scala
@@ -156,11 +156,14 @@ abstract class AbstractBackendService(name: String)
queryId
}
- override def getOperationStatus(operationHandle: OperationHandle): OperationStatus = {
+ override def getOperationStatus(
+ operationHandle: OperationHandle,
+ maxWait: Option[Long]): OperationStatus = {
val operation = sessionManager.operationManager.getOperation(operationHandle)
if (operation.shouldRunAsync) {
try {
- operation.getBackgroundHandle.get(timeout, TimeUnit.MILLISECONDS)
+ val waitTime = maxWait.getOrElse(timeout)
+ operation.getBackgroundHandle.get(waitTime, TimeUnit.MILLISECONDS)
} catch {
case e: TimeoutException =>
debug(s"$operationHandle: Long polling timed out, ${e.getMessage}")
diff --git a/kyuubi-common/src/main/scala/org/apache/kyuubi/service/BackendService.scala b/kyuubi-common/src/main/scala/org/apache/kyuubi/service/BackendService.scala
index e1841156664..968a94197d2 100644
--- a/kyuubi-common/src/main/scala/org/apache/kyuubi/service/BackendService.scala
+++ b/kyuubi-common/src/main/scala/org/apache/kyuubi/service/BackendService.scala
@@ -91,7 +91,9 @@ trait BackendService {
foreignTable: String): OperationHandle
def getQueryId(operationHandle: OperationHandle): String
- def getOperationStatus(operationHandle: OperationHandle): OperationStatus
+ def getOperationStatus(
+ operationHandle: OperationHandle,
+ maxWait: Option[Long] = None): OperationStatus
def cancelOperation(operationHandle: OperationHandle): Unit
def closeOperation(operationHandle: OperationHandle): Unit
def getResultSetMetadata(operationHandle: OperationHandle): TGetResultSetMetadataResp
diff --git a/kyuubi-server/pom.xml b/kyuubi-server/pom.xml
index 4dd89e0e62c..478d08e40c8 100644
--- a/kyuubi-server/pom.xml
+++ b/kyuubi-server/pom.xml
@@ -221,6 +221,16 @@
jersey-media-multipart
+
+ com.fasterxml.jackson.core
+ jackson-databind
+
+
+
+ com.fasterxml.jackson.datatype
+ jackson-datatype-jdk8
+
+
com.zaxxer
HikariCP
diff --git a/kyuubi-server/src/main/scala/org/apache/kyuubi/server/BackendServiceMetric.scala b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/BackendServiceMetric.scala
index d8b66416375..68bf11d7f99 100644
--- a/kyuubi-server/src/main/scala/org/apache/kyuubi/server/BackendServiceMetric.scala
+++ b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/BackendServiceMetric.scala
@@ -152,9 +152,11 @@ trait BackendServiceMetric extends BackendService {
}
}
- abstract override def getOperationStatus(operationHandle: OperationHandle): OperationStatus = {
+ abstract override def getOperationStatus(
+ operationHandle: OperationHandle,
+ maxWait: Option[Long] = None): OperationStatus = {
MetricsSystem.timerTracing(MetricsConstants.BS_GET_OPERATION_STATUS) {
- super.getOperationStatus(operationHandle)
+ super.getOperationStatus(operationHandle, maxWait)
}
}
diff --git a/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/KyuubiTrinoOperationTranslator.scala b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/KyuubiTrinoOperationTranslator.scala
index 6ec9fc1c80e..5eba9c32777 100644
--- a/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/KyuubiTrinoOperationTranslator.scala
+++ b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/KyuubiTrinoOperationTranslator.scala
@@ -19,10 +19,9 @@ package org.apache.kyuubi.server.trino.api
import scala.collection.JavaConverters._
-import org.apache.hive.service.rpc.thrift.TProtocolVersion
-
import org.apache.kyuubi.operation.OperationHandle
import org.apache.kyuubi.service.BackendService
+import org.apache.kyuubi.session.SessionHandle
import org.apache.kyuubi.sql.parser.trino.KyuubiTrinoFeParser
import org.apache.kyuubi.sql.plan.PassThroughNode
import org.apache.kyuubi.sql.plan.trino.{GetCatalogs, GetColumns, GetSchemas, GetTables, GetTableTypes, GetTypeInfo}
@@ -32,17 +31,10 @@ class KyuubiTrinoOperationTranslator(backendService: BackendService) {
def transform(
statement: String,
- user: String,
- ipAddress: String,
+ sessionHandle: SessionHandle,
configs: Map[String, String],
runAsync: Boolean,
queryTimeout: Long): OperationHandle = {
- val sessionHandle = backendService.openSession(
- TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V11,
- user,
- "",
- ipAddress,
- configs)
parser.parsePlan(statement) match {
case GetSchemas(catalogName, schemaPattern) =>
backendService.getSchemas(sessionHandle, catalogName, schemaPattern)
diff --git a/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/Query.scala b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/Query.scala
new file mode 100644
index 00000000000..c8c9e2fd6c1
--- /dev/null
+++ b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/Query.scala
@@ -0,0 +1,206 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.server.trino.api
+
+import java.net.URI
+import java.security.SecureRandom
+import java.util.Objects.requireNonNull
+import java.util.UUID
+import java.util.concurrent.atomic.AtomicLong
+import javax.ws.rs.WebApplicationException
+import javax.ws.rs.core.{Response, UriInfo}
+
+import Slug.Context.{EXECUTING_QUERY, QUEUED_QUERY}
+import com.google.common.hash.Hashing
+import io.trino.client.QueryResults
+import org.apache.hive.service.rpc.thrift.TProtocolVersion
+
+import org.apache.kyuubi.operation.{FetchOrientation, OperationHandle}
+import org.apache.kyuubi.operation.OperationState.{FINISHED, INITIALIZED, OperationState, PENDING}
+import org.apache.kyuubi.service.BackendService
+import org.apache.kyuubi.session.SessionHandle
+
+case class Query(
+ queryId: QueryId,
+ context: TrinoContext,
+ be: BackendService) {
+
+ private val QUEUED_QUERY_PATH = "/v1/statement/queued/"
+ private val EXECUTING_QUERY_PATH = "/v1/statement/executing"
+
+ private val slug: Slug = Slug.createNewWithUUID(queryId.getQueryId)
+ private val lastToken = new AtomicLong
+
+ private val defaultMaxRows = 1000
+ private val defaultFetchOrientation = FetchOrientation.withName("FETCH_NEXT")
+
+ def getQueryResults(token: Long, uriInfo: UriInfo, maxWait: Long = 0): QueryResults = {
+ val status =
+ be.getOperationStatus(queryId.operationHandle, Some(maxWait))
+ val nextUri = if (status.exception.isEmpty) {
+ getNextUri(token + 1, uriInfo, toSlugContext(status.state))
+ } else null
+ val queryHtmlUri = uriInfo.getRequestUriBuilder
+ .replacePath("ui/query.html").replaceQuery(queryId.getQueryId).build()
+
+ status.state match {
+ case FINISHED =>
+ val metaData = be.getResultSetMetadata(queryId.operationHandle)
+ val resultSet = be.fetchResults(
+ queryId.operationHandle,
+ defaultFetchOrientation,
+ defaultMaxRows,
+ false)
+ TrinoContext.createQueryResults(
+ queryId.getQueryId,
+ nextUri,
+ queryHtmlUri,
+ status,
+ Option(metaData),
+ Option(resultSet))
+ case _ =>
+ TrinoContext.createQueryResults(
+ queryId.getQueryId,
+ nextUri,
+ queryHtmlUri,
+ status)
+ }
+ }
+
+ def getLastToken: Long = this.lastToken.get()
+
+ def getSlug: Slug = this.slug
+
+ def cancel: Unit = clear
+
+ private def clear = {
+ be.closeOperation(queryId.operationHandle)
+ context.session.get("sessionId").foreach { id =>
+ be.closeSession(SessionHandle.fromUUID(id))
+ }
+ }
+
+ private def setToken(token: Long): Unit = {
+ val lastToken = this.lastToken.get
+ if (token != lastToken && token != lastToken + 1) {
+ throw new WebApplicationException(Response.Status.GONE)
+ }
+ this.lastToken.compareAndSet(lastToken, token)
+ }
+
+ private def getNextUri(token: Long, uriInfo: UriInfo, slugContext: Slug.Context.Context): URI = {
+ val path = slugContext match {
+ case QUEUED_QUERY => QUEUED_QUERY_PATH
+ case EXECUTING_QUERY => EXECUTING_QUERY_PATH
+ }
+
+ uriInfo.getBaseUriBuilder.replacePath(path)
+ .path(queryId.getQueryId)
+ .path(slug.makeSlug(slugContext, token))
+ .path(String.valueOf(token))
+ .replaceQuery("")
+ .build()
+ }
+
+ private def toSlugContext(state: OperationState): Slug.Context.Context = {
+ state match {
+ case INITIALIZED | PENDING => Slug.Context.QUEUED_QUERY
+ case _ => Slug.Context.EXECUTING_QUERY
+ }
+ }
+
+}
+
+object Query {
+
+ def apply(
+ statement: String,
+ context: TrinoContext,
+ translator: KyuubiTrinoOperationTranslator,
+ backendService: BackendService,
+ queryTimeout: Long = 0): Query = {
+
+ val sessionHandle = createSession(context, backendService)
+ val operationHandle = translator.transform(
+ statement,
+ sessionHandle,
+ context.session,
+ true,
+ queryTimeout)
+ val newSessionProperties =
+ context.session + ("sessionId" -> sessionHandle.identifier.toString)
+ val updatedContext = context.copy(session = newSessionProperties)
+ Query(QueryId(operationHandle), updatedContext, backendService)
+ }
+
+ def apply(id: String, context: TrinoContext, backendService: BackendService): Query = {
+ Query(QueryId(id), context, backendService)
+ }
+
+ private def createSession(
+ context: TrinoContext,
+ backendService: BackendService): SessionHandle = {
+ backendService.openSession(
+ TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V11,
+ context.user,
+ "",
+ context.remoteUserAddress.getOrElse(""),
+ context.session)
+ }
+
+}
+
+case class QueryId(operationHandle: OperationHandle) {
+ def getQueryId: String = operationHandle.identifier.toString
+}
+
+object QueryId {
+ def apply(id: String): QueryId = QueryId(OperationHandle(id))
+}
+
+object Slug {
+
+ object Context extends Enumeration {
+ type Context = Value
+ val QUEUED_QUERY, EXECUTING_QUERY = Value
+ }
+
+ private val RANDOM = new SecureRandom
+
+ def createNew: Slug = {
+ val randomBytes = new Array[Byte](16)
+ RANDOM.nextBytes(randomBytes)
+ new Slug(randomBytes)
+ }
+
+ def createNewWithUUID(uuid: String): Slug = {
+ val uuidBytes = UUID.fromString(uuid).toString.getBytes("UTF-8")
+ new Slug(uuidBytes)
+ }
+}
+
+case class Slug(slugKey: Array[Byte]) {
+ val hmac = Hashing.hmacSha1(requireNonNull(slugKey, "slugKey is null"))
+
+ def makeSlug(context: Slug.Context.Context, token: Long): String = {
+ "y" + hmac.newHasher.putInt(context.id).putLong(token).hash.toString
+ }
+
+ def isValid(context: Slug.Context.Context, slug: String, token: Long): Boolean =
+ makeSlug(context, token) == slug
+}
diff --git a/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/TrinoContext.scala b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/TrinoContext.scala
index 4a0736ddb89..9e77139040a 100644
--- a/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/TrinoContext.scala
+++ b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/TrinoContext.scala
@@ -28,6 +28,7 @@ import io.trino.client.{ClientStandardTypes, ClientTypeSignature, Column, QueryE
import io.trino.client.ProtocolHeaders.TRINO_HEADERS
import org.apache.hive.service.rpc.thrift.{TGetResultSetMetadataResp, TRowSet, TTypeId}
+import org.apache.kyuubi.operation.OperationState.FINISHED
import org.apache.kyuubi.operation.OperationStatus
/**
@@ -58,6 +59,7 @@ case class TrinoContext(
source: Option[String] = None,
catalog: Option[String] = None,
schema: Option[String] = None,
+ remoteUserAddress: Option[String] = None,
language: Option[String] = None,
traceToken: Option[String] = None,
clientInfo: Option[String] = None,
@@ -72,10 +74,11 @@ object TrinoContext {
private val GENERIC_INTERNAL_ERROR_NAME = "GENERIC_INTERNAL_ERROR_NAME"
private val GENERIC_INTERNAL_ERROR_TYPE = "INTERNAL_ERROR"
- def apply(headers: HttpHeaders): TrinoContext = {
- apply(headers.getRequestHeaders.asScala.toMap.map {
+ def apply(headers: HttpHeaders, remoteAddress: Option[String]): TrinoContext = {
+ val context = apply(headers.getRequestHeaders.asScala.toMap.map {
case (k, v) => (k, v.asScala.toList)
})
+ context.copy(remoteUserAddress = remoteAddress)
}
def apply(headers: Map[String, List[String]]): TrinoContext = {
@@ -134,7 +137,6 @@ object TrinoContext {
}
}
- // TODO: Building response with TrinoContext and other information
def buildTrinoResponse(qr: QueryResults, trinoContext: TrinoContext): Response = {
val responseBuilder = Response.ok(qr)
@@ -156,8 +158,6 @@ object TrinoContext {
responseBuilder.header(TRINO_HEADERS.responseDeallocatedPrepare, urlEncode(v))
}
- responseBuilder.header(TRINO_HEADERS.responseClearSession, s"responseClearSession")
- responseBuilder.header(TRINO_HEADERS.responseClearTransactionId, "false")
responseBuilder.build()
}
@@ -192,11 +192,16 @@ object TrinoContext {
case None => null
}
+ val updatedNextUri = queryStatus.state match {
+ case FINISHED if rowList == null || rowList.isEmpty || rowList.get(0).isEmpty => null
+ case _ => nextUri
+ }
+
new QueryResults(
queryId,
queryHtmlUri,
nextUri,
- nextUri,
+ updatedNextUri,
columnList,
rowList,
StatementStats.builder.setState(queryStatus.state.name()).setQueued(false)
diff --git a/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/KyuubiScalaObjectMapper.scala b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/TrinoScalaObjectMapper.scala
similarity index 73%
rename from kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/KyuubiScalaObjectMapper.scala
rename to kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/TrinoScalaObjectMapper.scala
index 915b109b7b9..f6055927ac2 100644
--- a/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/KyuubiScalaObjectMapper.scala
+++ b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/TrinoScalaObjectMapper.scala
@@ -19,11 +19,14 @@ package org.apache.kyuubi.server.trino.api
import javax.ws.rs.ext.ContextResolver
-import com.fasterxml.jackson.databind.ObjectMapper
-import com.fasterxml.jackson.module.scala.DefaultScalaModule
+import com.fasterxml.jackson.databind.{DeserializationFeature, ObjectMapper}
+import com.fasterxml.jackson.datatype.jdk8.Jdk8Module
-class KyuubiScalaObjectMapper extends ContextResolver[ObjectMapper] {
- private val mapper = new ObjectMapper().registerModule(DefaultScalaModule)
+class TrinoScalaObjectMapper extends ContextResolver[ObjectMapper] {
+
+ private lazy val mapper = new ObjectMapper()
+ .disable(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES)
+ .registerModule(new Jdk8Module)
override def getContext(aClass: Class[_]): ObjectMapper = mapper
}
diff --git a/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/TrinoServerConfig.scala b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/TrinoServerConfig.scala
index d1f7de336ba..298e60c9cac 100644
--- a/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/TrinoServerConfig.scala
+++ b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/TrinoServerConfig.scala
@@ -21,6 +21,6 @@ import org.glassfish.jersey.server.ResourceConfig
class TrinoServerConfig extends ResourceConfig {
packages("org.apache.kyuubi.server.trino.api.v1")
- register(classOf[KyuubiScalaObjectMapper])
+ register(classOf[TrinoScalaObjectMapper])
register(classOf[RestExceptionMapper])
}
diff --git a/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/v1/StatementResource.scala b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/v1/StatementResource.scala
index 122b39ccd8c..ab783f8acce 100644
--- a/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/v1/StatementResource.scala
+++ b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/v1/StatementResource.scala
@@ -18,16 +18,24 @@
package org.apache.kyuubi.server.trino.api.v1
import javax.ws.rs._
-import javax.ws.rs.core.{Context, HttpHeaders, MediaType}
+import javax.ws.rs.core.{Context, HttpHeaders, MediaType, Response, UriInfo}
+import javax.ws.rs.core.MediaType.TEXT_PLAIN_TYPE
+import javax.ws.rs.core.Response.Status.{BAD_REQUEST, NOT_FOUND}
+import scala.util.Try
+import scala.util.control.NonFatal
+
+import io.airlift.units.Duration
import io.swagger.v3.oas.annotations.media.{Content, Schema}
import io.swagger.v3.oas.annotations.responses.ApiResponse
import io.swagger.v3.oas.annotations.tags.Tag
import io.trino.client.QueryResults
import org.apache.kyuubi.Logging
-import org.apache.kyuubi.server.trino.api.{ApiRequestContext, KyuubiTrinoOperationTranslator}
+import org.apache.kyuubi.server.trino.api.{ApiRequestContext, KyuubiTrinoOperationTranslator, Query, QueryId, Slug, TrinoContext}
+import org.apache.kyuubi.server.trino.api.Slug.Context.{EXECUTING_QUERY, QUEUED_QUERY}
import org.apache.kyuubi.server.trino.api.v1.dto.Ok
+import org.apache.kyuubi.service.BackendService
@Tag(name = "Statement")
@Produces(Array(MediaType.APPLICATION_JSON))
@@ -50,11 +58,32 @@ private[v1] class StatementResource extends ApiRequestContext with Logging {
schema = new Schema(implementation = classOf[QueryResults]))),
description =
"Create a query")
- @GET
+ @POST
@Path("/")
@Consumes(Array(MediaType.TEXT_PLAIN))
- def query(statement: String, @Context headers: HttpHeaders): QueryResults = {
- throw new UnsupportedOperationException
+ def query(
+ statement: String,
+ @Context headers: HttpHeaders,
+ @Context uriInfo: UriInfo): Response = {
+ if (statement == null || statement.isEmpty) {
+ throw badRequest(BAD_REQUEST, "SQL statement is empty")
+ }
+
+ val remoteAddr = Option(httpRequest.getRemoteAddr)
+ val trinoContext = TrinoContext(headers, remoteAddr)
+
+ try {
+ val query = Query(statement, trinoContext, translator, fe.be)
+ val qr = query.getQueryResults(query.getLastToken, uriInfo)
+ TrinoContext.buildTrinoResponse(qr, query.context)
+ } catch {
+ case e: Exception =>
+ val errorMsg =
+ s"Error submitting sql"
+ e.printStackTrace()
+ error(errorMsg, e)
+ throw badRequest(BAD_REQUEST, errorMsg)
+ }
}
@ApiResponse(
@@ -65,11 +94,31 @@ private[v1] class StatementResource extends ApiRequestContext with Logging {
@GET
@Path("/queued/{queryId}/{slug}/{token}")
def getQueuedStatementStatus(
- @Context headers: HttpHeaders,
@PathParam("queryId") queryId: String,
@PathParam("slug") slug: String,
- @PathParam("token") token: Long): QueryResults = {
- throw new UnsupportedOperationException
+ @PathParam("token") token: Long,
+ @QueryParam("maxWait") maxWait: Duration,
+ @Context headers: HttpHeaders,
+ @Context uriInfo: UriInfo): Response = {
+
+ val remoteAddr = Option(httpRequest.getRemoteAddr)
+ val trinoContext = TrinoContext(headers, remoteAddr)
+ val waitTime = if (maxWait == null) 0 else maxWait.toMillis
+ getQuery(fe.be, trinoContext, QueryId(queryId), slug, token, QUEUED_QUERY)
+ .flatMap(query =>
+ Try(TrinoContext.buildTrinoResponse(
+ query.getQueryResults(
+ token,
+ uriInfo,
+ waitTime),
+ query.context)))
+ .recover {
+ case NonFatal(e) =>
+ val errorMsg =
+ s"Error executing for query id $queryId"
+ error(errorMsg, e)
+ throw badRequest(NOT_FOUND, "Query not found")
+ }.get
}
@ApiResponse(
@@ -80,11 +129,28 @@ private[v1] class StatementResource extends ApiRequestContext with Logging {
@GET
@Path("/executing/{queryId}/{slug}/{token}")
def getExecutingStatementStatus(
- @Context headers: HttpHeaders,
@PathParam("queryId") queryId: String,
@PathParam("slug") slug: String,
- @PathParam("token") token: Long): QueryResults = {
- throw new UnsupportedOperationException
+ @PathParam("token") token: Long,
+ @QueryParam("maxWait") maxWait: Duration,
+ @Context headers: HttpHeaders,
+ @Context uriInfo: UriInfo): Response = {
+
+ val remoteAddr = Option(httpRequest.getRemoteAddr)
+ val trinoContext = TrinoContext(headers, remoteAddr)
+ val waitTime = if (maxWait == null) 0 else maxWait.toMillis
+ getQuery(fe.be, trinoContext, QueryId(queryId), slug, token, EXECUTING_QUERY)
+ .flatMap(query =>
+ Try(TrinoContext.buildTrinoResponse(
+ query.getQueryResults(token, uriInfo, waitTime),
+ query.context)))
+ .recover {
+ case NonFatal(e) =>
+ val errorMsg =
+ s"Error executing for query id $queryId"
+ error(errorMsg, e)
+ throw badRequest(NOT_FOUND, "Query not found")
+ }.get
}
@ApiResponse(
@@ -95,11 +161,23 @@ private[v1] class StatementResource extends ApiRequestContext with Logging {
@DELETE
@Path("/queued/{queryId}/{slug}/{token}")
def cancelQueuedStatement(
- @Context headers: HttpHeaders,
@PathParam("queryId") queryId: String,
@PathParam("slug") slug: String,
- @PathParam("token") token: Long): QueryResults = {
- throw new UnsupportedOperationException
+ @PathParam("token") token: Long,
+ @Context headers: HttpHeaders): Response = {
+
+ val remoteAddr = Option(httpRequest.getRemoteAddr)
+ val trinoContext = TrinoContext(headers, remoteAddr)
+ getQuery(fe.be, trinoContext, QueryId(queryId), slug, token, QUEUED_QUERY)
+ .flatMap(query => Try(query.cancel))
+ .recover {
+ case NonFatal(e) =>
+ val errorMsg =
+ s"Error executing for query id $queryId"
+ error(errorMsg, e)
+ throw badRequest(NOT_FOUND, "Query not found")
+ }.get
+ Response.noContent.build
}
@ApiResponse(
@@ -110,11 +188,44 @@ private[v1] class StatementResource extends ApiRequestContext with Logging {
@DELETE
@Path("/executing/{queryId}/{slug}/{token}")
def cancelExecutingStatementStatus(
- @Context headers: HttpHeaders,
@PathParam("queryId") queryId: String,
@PathParam("slug") slug: String,
- @PathParam("token") token: Long): QueryResults = {
- throw new UnsupportedOperationException
+ @PathParam("token") token: Long,
+ @Context headers: HttpHeaders): Response = {
+
+ val remoteAddr = Option(httpRequest.getRemoteAddr)
+ val trinoContext = TrinoContext(headers, remoteAddr)
+ getQuery(fe.be, trinoContext, QueryId(queryId), slug, token, EXECUTING_QUERY)
+ .flatMap(query => Try(query.cancel))
+ .recover {
+ case NonFatal(e) =>
+ val errorMsg =
+ s"Error executing for query id $queryId"
+ error(errorMsg, e)
+ throw badRequest(NOT_FOUND, "Query not found")
+ }.get
+
+ Response.noContent.build
}
+ private def getQuery(
+ be: BackendService,
+ context: TrinoContext,
+ queryId: QueryId,
+ slug: String,
+ token: Long,
+ slugContext: Slug.Context.Context): Try[Query] = {
+
+ Try(be.sessionManager.operationManager.getOperation(queryId.operationHandle)).map { _ =>
+ Query(queryId, context, be)
+ }.filter(_.getSlug.isValid(slugContext, slug, token))
+ }
+
+ private def badRequest(status: Response.Status, message: String) =
+ new WebApplicationException(
+ Response.status(status)
+ .`type`(TEXT_PLAIN_TYPE)
+ .entity(message)
+ .build)
+
}
diff --git a/kyuubi-server/src/test/scala/org/apache/kyuubi/RestFrontendTestHelper.scala b/kyuubi-server/src/test/scala/org/apache/kyuubi/RestFrontendTestHelper.scala
index b22783771ec..fafdcf4a7b1 100644
--- a/kyuubi-server/src/test/scala/org/apache/kyuubi/RestFrontendTestHelper.scala
+++ b/kyuubi-server/src/test/scala/org/apache/kyuubi/RestFrontendTestHelper.scala
@@ -37,7 +37,7 @@ import org.apache.kyuubi.service.AbstractFrontendService
object RestFrontendTestHelper {
- private class RestApiBaseSuite extends JerseyTest {
+ class RestApiBaseSuite extends JerseyTest {
override def configure: Application = new ResourceConfig(getClass)
.register(classOf[MultiPartFeature])
@@ -58,7 +58,7 @@ trait RestFrontendTestHelper extends WithKyuubiServer {
override protected val frontendProtocols: Seq[FrontendProtocol] =
FrontendProtocols.REST :: Nil
- private val restApiBaseSuite = new RestApiBaseSuite
+ protected val restApiBaseSuite: JerseyTest = new RestApiBaseSuite
override def beforeAll(): Unit = {
super.beforeAll()
diff --git a/kyuubi-server/src/test/scala/org/apache/kyuubi/TrinoClientTestHelper.scala b/kyuubi-server/src/test/scala/org/apache/kyuubi/TrinoClientTestHelper.scala
deleted file mode 100644
index c0b3949f4ee..00000000000
--- a/kyuubi-server/src/test/scala/org/apache/kyuubi/TrinoClientTestHelper.scala
+++ /dev/null
@@ -1,80 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.kyuubi
-
-import java.net.URI
-import java.time.ZoneId
-import java.util.{Locale, Optional}
-import java.util.concurrent.TimeUnit
-
-import scala.collection.JavaConverters._
-
-import io.airlift.units.Duration
-import io.trino.client.{ClientSelectedRole, ClientSession, StatementClient, StatementClientFactory}
-import okhttp3.OkHttpClient
-
-trait TrinoClientTestHelper extends RestFrontendTestHelper {
-
- override def afterAll(): Unit = {
- super.afterAll()
- }
-
- private val httpClient = new OkHttpClient.Builder().build()
-
- protected val clientSession = createClientSession(baseUri: URI)
-
- def getTrinoStatementClient(sql: String): StatementClient = {
- StatementClientFactory.newStatementClient(httpClient, clientSession, sql)
- }
-
- def createClientSession(connectUrl: URI): ClientSession = {
- new ClientSession(
- connectUrl,
- "kyuubi_test",
- Optional.of("test_user"),
- "kyuubi",
- Optional.of("test_token_tracing"),
- Set[String]().asJava,
- "test_client_info",
- "test_catalog",
- "test_schema",
- "test_path",
- ZoneId.systemDefault(),
- Locale.getDefault,
- Map[String, String](
- "test_resource_key0" -> "test_resource_value0",
- "test_resource_key1" -> "test_resource_value1").asJava,
- Map[String, String](
- "test_property_key0" -> "test_property_value0",
- "test_property_key1" -> "test_propert_value1").asJava,
- Map[String, String](
- "test_statement_key0" -> "select 1",
- "test_statement_key1" -> "select 2").asJava,
- Map[String, ClientSelectedRole](
- "test_role_key0" -> ClientSelectedRole.valueOf("ROLE"),
- "test_role_key2" -> ClientSelectedRole.valueOf("ALL")).asJava,
- Map[String, String](
- "test_credentials_key0" -> "test_credentials_value0",
- "test_credentials_key1" -> "test_credentials_value1").asJava,
- "test_transaction_id",
- new Duration(2, TimeUnit.MINUTES),
- true)
-
- }
-
-}
diff --git a/kyuubi-server/src/test/scala/org/apache/kyuubi/TrinoRestFrontendTestHelper.scala b/kyuubi-server/src/test/scala/org/apache/kyuubi/TrinoRestFrontendTestHelper.scala
new file mode 100644
index 00000000000..1ff00e64fa2
--- /dev/null
+++ b/kyuubi-server/src/test/scala/org/apache/kyuubi/TrinoRestFrontendTestHelper.scala
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi
+
+import org.glassfish.jersey.client.ClientConfig
+import org.glassfish.jersey.test.JerseyTest
+
+import org.apache.kyuubi.config.KyuubiConf.FrontendProtocols
+import org.apache.kyuubi.config.KyuubiConf.FrontendProtocols.FrontendProtocol
+import org.apache.kyuubi.server.trino.api.TrinoScalaObjectMapper
+
+trait TrinoRestFrontendTestHelper extends RestFrontendTestHelper {
+
+ private class TrinoRestBaseSuite extends RestFrontendTestHelper.RestApiBaseSuite {
+ override def configureClient(config: ClientConfig): Unit = {
+ config.register(classOf[TrinoScalaObjectMapper])
+ }
+ }
+
+ override protected val frontendProtocols: Seq[FrontendProtocol] =
+ FrontendProtocols.TRINO :: Nil
+
+ override protected val restApiBaseSuite: JerseyTest = new TrinoRestBaseSuite
+
+}
diff --git a/kyuubi-server/src/test/scala/org/apache/kyuubi/server/trino/api/TrinoClientApiSuite.scala b/kyuubi-server/src/test/scala/org/apache/kyuubi/server/trino/api/TrinoClientApiSuite.scala
new file mode 100644
index 00000000000..c88b5c9409f
--- /dev/null
+++ b/kyuubi-server/src/test/scala/org/apache/kyuubi/server/trino/api/TrinoClientApiSuite.scala
@@ -0,0 +1,145 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.server.trino.api
+
+import java.net.URI
+import java.time.ZoneId
+import java.util.{Collections, Locale, Optional}
+import java.util.concurrent.TimeUnit
+import java.util.concurrent.atomic.AtomicReference
+
+import scala.annotation.tailrec
+import scala.collection.JavaConverters._
+
+import com.google.common.base.Verify
+import io.airlift.units.Duration
+import io.trino.client.{ClientSession, StatementClient, StatementClientFactory}
+import okhttp3.OkHttpClient
+
+import org.apache.kyuubi.{KyuubiFunSuite, KyuubiSQLException, TrinoRestFrontendTestHelper}
+
+class TrinoClientApiSuite extends KyuubiFunSuite with TrinoRestFrontendTestHelper {
+
+ private val httpClient =
+ new OkHttpClient.Builder()
+ .readTimeout(5, TimeUnit.MINUTES)
+ .build()
+ private lazy val clientSession =
+ new AtomicReference[ClientSession](createTestClientSession(baseUri))
+
+ test("submit query with trino client api") {
+ val trino = getTrinoStatementClient("select 1")
+ val result = execute(trino)
+ val sessionId = trino.getSetSessionProperties.asScala.get("sessionId")
+ assert(result == List(List(1)))
+
+ updateClientSession(trino)
+
+ val trino1 = getTrinoStatementClient("select 2")
+ val result1 = execute(trino1)
+ val sessionId1 = trino1.getSetSessionProperties.asScala.get("sessionId")
+ assert(result1 == List(List(2)))
+ assert(sessionId != sessionId1)
+
+ trino.close()
+ }
+
+ private def updateClientSession(trino: StatementClient): Unit = {
+ val session = clientSession.get
+
+ var builder = ClientSession.builder(session)
+ // update catalog and schema
+ if (trino.getSetCatalog.isPresent || trino.getSetSchema.isPresent) {
+ builder = builder
+ .withCatalog(trino.getSetCatalog.orElse(session.getCatalog))
+ .withSchema(trino.getSetSchema.orElse(session.getSchema))
+ }
+
+ // update path if present
+ if (trino.getSetPath.isPresent) {
+ builder = builder.withPath(trino.getSetPath.get)
+ }
+
+ // update session properties if present
+ if (!trino.getSetSessionProperties.isEmpty || !trino.getResetSessionProperties.isEmpty) {
+ val properties = session.getProperties.asScala.clone()
+ properties ++= trino.getSetSessionProperties.asScala
+ properties --= trino.getResetSessionProperties.asScala
+ builder = builder.withProperties(properties.asJava)
+ }
+ clientSession.set(builder.build())
+ }
+
+ private def execute(trino: StatementClient): List[List[Any]] = {
+ @tailrec
+ def getData(trino: StatementClient): (Boolean, List[List[Any]]) = {
+ if (trino.isRunning) {
+ val data = trino.currentData().getData()
+ trino.advance()
+ if (data != null) {
+ (true, data.asScala.toList.map(_.asScala.toList))
+ } else {
+ getData(trino)
+ }
+ } else {
+ Verify.verify(trino.isFinished)
+ val finalStatus = trino.finalStatusInfo()
+ if (finalStatus.getError() != null) {
+ throw KyuubiSQLException(
+ s"Query ${finalStatus.getId} failed: ${finalStatus.getError.getMessage}")
+ }
+ (false, List[List[Any]]())
+ }
+ }
+ Iterator.continually(getData(trino)).takeWhile(_._1).flatMap(_._2).toList
+ }
+
+ private def getTrinoStatementClient(sql: String): StatementClient = {
+ StatementClientFactory.newStatementClient(httpClient, clientSession.get, sql)
+ }
+
+ private def createTestClientSession(connectUrl: URI): ClientSession = {
+ new ClientSession(
+ connectUrl,
+ "kyuubi_test",
+ Optional.of("test_user"),
+ "kyuubi",
+ Optional.of("test_token_tracing"),
+ Set[String]().asJava,
+ "test_client_info",
+ "test_catalog",
+ "test_schema",
+ null,
+ ZoneId.systemDefault(),
+ Locale.getDefault,
+ Collections.emptyMap(),
+ Map[String, String](
+ "test_property_key0" -> "test_property_value0",
+ "test_property_key1" -> "test_propert_value1").asJava,
+ Map[String, String](
+ "test_statement_key0" -> "select 1",
+ "test_statement_key1" -> "select 2").asJava,
+ Collections.emptyMap(),
+ Collections.emptyMap(),
+ null,
+ new Duration(2, TimeUnit.MINUTES),
+ true)
+
+ }
+
+}
diff --git a/kyuubi-server/src/test/scala/org/apache/kyuubi/server/trino/api/TrinoContextSuite.scala b/kyuubi-server/src/test/scala/org/apache/kyuubi/server/trino/api/TrinoContextSuite.scala
index 8d7b2bf2ccf..87c8eda968a 100644
--- a/kyuubi-server/src/test/scala/org/apache/kyuubi/server/trino/api/TrinoContextSuite.scala
+++ b/kyuubi-server/src/test/scala/org/apache/kyuubi/server/trino/api/TrinoContextSuite.scala
@@ -85,7 +85,7 @@ class TrinoContextSuite extends KyuubiFunSuite with RestFrontendTestHelper {
val metadataResp = fe.be.getResultSetMetadata(opHandle)
val tRowSet = fe.be.fetchResults(opHandle, FetchOrientation.FETCH_NEXT, 1000, false)
- val status = fe.be.getOperationStatus(opHandle)
+ val status = fe.be.getOperationStatus(opHandle, Some(0))
val uri = new URI("sfdsfsdfdsf")
val results = TrinoContext
@@ -112,7 +112,7 @@ class TrinoContextSuite extends KyuubiFunSuite with RestFrontendTestHelper {
val metadataResp = fe.be.getResultSetMetadata(opHandle)
val tRowSet = fe.be.fetchResults(opHandle, FetchOrientation.FETCH_NEXT, 1000, false)
- val status = fe.be.getOperationStatus(opHandle)
+ val status = fe.be.getOperationStatus(opHandle, Some(0))
val uri = new URI("sfdsfsdfdsf")
val results = TrinoContext
diff --git a/kyuubi-server/src/test/scala/org/apache/kyuubi/server/trino/api/v1/StatementResourceSuite.scala b/kyuubi-server/src/test/scala/org/apache/kyuubi/server/trino/api/v1/StatementResourceSuite.scala
index b60c7c67aa2..adbf389c931 100644
--- a/kyuubi-server/src/test/scala/org/apache/kyuubi/server/trino/api/v1/StatementResourceSuite.scala
+++ b/kyuubi-server/src/test/scala/org/apache/kyuubi/server/trino/api/v1/StatementResourceSuite.scala
@@ -17,15 +17,27 @@
package org.apache.kyuubi.server.trino.api.v1
-import org.apache.kyuubi.{KyuubiFunSuite, RestFrontendTestHelper}
-import org.apache.kyuubi.config.KyuubiConf.FrontendProtocols
-import org.apache.kyuubi.config.KyuubiConf.FrontendProtocols.FrontendProtocol
+import javax.ws.rs.client.Entity
+import javax.ws.rs.core.{MediaType, Response}
+
+import scala.collection.JavaConverters._
+
+import io.trino.client.{QueryError, QueryResults}
+import io.trino.client.ProtocolHeaders.TRINO_HEADERS
+
+import org.apache.kyuubi.{KyuubiFunSuite, KyuubiSQLException, TrinoRestFrontendTestHelper}
+import org.apache.kyuubi.operation.{OperationHandle, OperationState}
+import org.apache.kyuubi.server.trino.api.TrinoContext
import org.apache.kyuubi.server.trino.api.v1.dto.Ok
+import org.apache.kyuubi.session.SessionHandle
-class StatementResourceSuite extends KyuubiFunSuite with RestFrontendTestHelper {
+class StatementResourceSuite extends KyuubiFunSuite with TrinoRestFrontendTestHelper {
- override protected val frontendProtocols: Seq[FrontendProtocol] =
- FrontendProtocols.TRINO :: Nil
+ case class TrinoResponse(
+ response: Option[Response] = None,
+ queryError: Option[QueryError] = None,
+ data: List[List[Any]] = List[List[Any]](),
+ isEnd: Boolean = false)
test("statement test") {
val response = webTarget.path("v1/statement/test").request().get()
@@ -33,4 +45,74 @@ class StatementResourceSuite extends KyuubiFunSuite with RestFrontendTestHelper
assert(result == new Ok("trino server is running"))
}
+ test("statement submit for query error") {
+
+ val response = webTarget.path("v1/statement")
+ .request().post(Entity.entity("select a", MediaType.TEXT_PLAIN_TYPE))
+
+ val trinoResponseIter = Iterator.iterate(TrinoResponse(response = Option(response)))(getData)
+ val isErr = trinoResponseIter.takeWhile(_.isEnd == false).exists { t =>
+ t.queryError != None && t.response == None
+ }
+ assert(isErr == true)
+ }
+
+ test("statement submit and get result") {
+ val response = webTarget.path("v1/statement")
+ .request().post(Entity.entity("select 1", MediaType.TEXT_PLAIN_TYPE))
+
+ val trinoResponseIter = Iterator.iterate(TrinoResponse(response = Option(response)))(getData)
+ val dataSet = trinoResponseIter
+ .takeWhile(_.isEnd == false)
+ .map(_.data)
+ .flatten.toList
+ assert(dataSet == List(List(1)))
+ }
+
+ test("query cancel") {
+ val response = webTarget.path("v1/statement")
+ .request().post(Entity.entity("select 1", MediaType.TEXT_PLAIN_TYPE))
+ val qr = response.readEntity(classOf[QueryResults])
+ val sessionManager = fe.be.sessionManager
+ val sessionHandle =
+ response.getStringHeaders.get(TRINO_HEADERS.responseSetSession).asScala
+ .map(_.split("="))
+ .find {
+ case Array("sessionId", _) => true
+ }
+ .map {
+ case Array(_, value) => SessionHandle.fromUUID(TrinoContext.urlDecode(value))
+ }.get
+ sessionManager.getSession(sessionHandle)
+ val operationHandle = OperationHandle(qr.getId)
+ val operation = sessionManager.operationManager.getOperation(operationHandle)
+ assert(response.getStatus == 200)
+ val path = qr.getNextUri.getPath
+ val nextResponse = webTarget.path(path).request().header(
+ TRINO_HEADERS.requestSession(),
+ s"sessionId=${TrinoContext.urlEncode(sessionHandle.identifier.toString)}").delete()
+ assert(nextResponse.getStatus == 204)
+ assert(operation.getStatus.state == OperationState.CLOSED)
+ val exception = intercept[KyuubiSQLException](sessionManager.getSession(sessionHandle))
+ assert(exception.getMessage === s"Invalid $sessionHandle")
+
+ }
+
+ private def getData(current: TrinoResponse): TrinoResponse = {
+ current.response.map { response =>
+ assert(response.getStatus == 200)
+ val qr = response.readEntity(classOf[QueryResults])
+ val nextData = Option(qr.getData)
+ .map(_.asScala.toList.map(_.asScala.toList))
+ .getOrElse(List[List[Any]]())
+ val nextResponse = Option(qr.getNextUri).map {
+ uri =>
+ val path = uri.getPath
+ val headers = response.getHeaders
+ webTarget.path(path).request().headers(headers).get()
+ }
+ TrinoResponse(nextResponse, Option(qr.getError), nextData)
+ }.getOrElse(TrinoResponse(isEnd = true))
+ }
+
}
diff --git a/pom.xml b/pom.xml
index 9afc0cf1de5..2c2d9babe3a 100644
--- a/pom.xml
+++ b/pom.xml
@@ -777,6 +777,12 @@
${jackson.version}
+
+ com.fasterxml.jackson.datatype
+ jackson-datatype-jdk8
+ ${jackson.version}
+
+
com.fasterxml.jackson.jaxrs
jackson-jaxrs-base