diff --git a/dev/dependencyList b/dev/dependencyList index 2035b95dea8..9813bb56288 100644 --- a/dev/dependencyList +++ b/dev/dependencyList @@ -69,7 +69,7 @@ jackson-annotations/2.14.2//jackson-annotations-2.14.2.jar jackson-core/2.14.2//jackson-core-2.14.2.jar jackson-databind/2.14.2//jackson-databind-2.14.2.jar jackson-dataformat-yaml/2.14.2//jackson-dataformat-yaml-2.14.2.jar -jackson-datatype-jdk8/2.12.3//jackson-datatype-jdk8-2.12.3.jar +jackson-datatype-jdk8/2.14.2//jackson-datatype-jdk8-2.14.2.jar jackson-datatype-jsr310/2.14.2//jackson-datatype-jsr310-2.14.2.jar jackson-jaxrs-base/2.14.2//jackson-jaxrs-base-2.14.2.jar jackson-jaxrs-json-provider/2.14.2//jackson-jaxrs-json-provider-2.14.2.jar diff --git a/kyuubi-common/src/main/scala/org/apache/kyuubi/service/AbstractBackendService.scala b/kyuubi-common/src/main/scala/org/apache/kyuubi/service/AbstractBackendService.scala index e7c2d836573..b9e254508dd 100644 --- a/kyuubi-common/src/main/scala/org/apache/kyuubi/service/AbstractBackendService.scala +++ b/kyuubi-common/src/main/scala/org/apache/kyuubi/service/AbstractBackendService.scala @@ -156,11 +156,14 @@ abstract class AbstractBackendService(name: String) queryId } - override def getOperationStatus(operationHandle: OperationHandle): OperationStatus = { + override def getOperationStatus( + operationHandle: OperationHandle, + maxWait: Option[Long]): OperationStatus = { val operation = sessionManager.operationManager.getOperation(operationHandle) if (operation.shouldRunAsync) { try { - operation.getBackgroundHandle.get(timeout, TimeUnit.MILLISECONDS) + val waitTime = maxWait.getOrElse(timeout) + operation.getBackgroundHandle.get(waitTime, TimeUnit.MILLISECONDS) } catch { case e: TimeoutException => debug(s"$operationHandle: Long polling timed out, ${e.getMessage}") diff --git a/kyuubi-common/src/main/scala/org/apache/kyuubi/service/BackendService.scala b/kyuubi-common/src/main/scala/org/apache/kyuubi/service/BackendService.scala index e1841156664..968a94197d2 100644 --- a/kyuubi-common/src/main/scala/org/apache/kyuubi/service/BackendService.scala +++ b/kyuubi-common/src/main/scala/org/apache/kyuubi/service/BackendService.scala @@ -91,7 +91,9 @@ trait BackendService { foreignTable: String): OperationHandle def getQueryId(operationHandle: OperationHandle): String - def getOperationStatus(operationHandle: OperationHandle): OperationStatus + def getOperationStatus( + operationHandle: OperationHandle, + maxWait: Option[Long] = None): OperationStatus def cancelOperation(operationHandle: OperationHandle): Unit def closeOperation(operationHandle: OperationHandle): Unit def getResultSetMetadata(operationHandle: OperationHandle): TGetResultSetMetadataResp diff --git a/kyuubi-server/pom.xml b/kyuubi-server/pom.xml index 4dd89e0e62c..478d08e40c8 100644 --- a/kyuubi-server/pom.xml +++ b/kyuubi-server/pom.xml @@ -221,6 +221,16 @@ jersey-media-multipart + + com.fasterxml.jackson.core + jackson-databind + + + + com.fasterxml.jackson.datatype + jackson-datatype-jdk8 + + com.zaxxer HikariCP diff --git a/kyuubi-server/src/main/scala/org/apache/kyuubi/server/BackendServiceMetric.scala b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/BackendServiceMetric.scala index d8b66416375..68bf11d7f99 100644 --- a/kyuubi-server/src/main/scala/org/apache/kyuubi/server/BackendServiceMetric.scala +++ b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/BackendServiceMetric.scala @@ -152,9 +152,11 @@ trait BackendServiceMetric extends BackendService { } } - abstract override def getOperationStatus(operationHandle: OperationHandle): OperationStatus = { + abstract override def getOperationStatus( + operationHandle: OperationHandle, + maxWait: Option[Long] = None): OperationStatus = { MetricsSystem.timerTracing(MetricsConstants.BS_GET_OPERATION_STATUS) { - super.getOperationStatus(operationHandle) + super.getOperationStatus(operationHandle, maxWait) } } diff --git a/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/KyuubiTrinoOperationTranslator.scala b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/KyuubiTrinoOperationTranslator.scala index 6ec9fc1c80e..5eba9c32777 100644 --- a/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/KyuubiTrinoOperationTranslator.scala +++ b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/KyuubiTrinoOperationTranslator.scala @@ -19,10 +19,9 @@ package org.apache.kyuubi.server.trino.api import scala.collection.JavaConverters._ -import org.apache.hive.service.rpc.thrift.TProtocolVersion - import org.apache.kyuubi.operation.OperationHandle import org.apache.kyuubi.service.BackendService +import org.apache.kyuubi.session.SessionHandle import org.apache.kyuubi.sql.parser.trino.KyuubiTrinoFeParser import org.apache.kyuubi.sql.plan.PassThroughNode import org.apache.kyuubi.sql.plan.trino.{GetCatalogs, GetColumns, GetSchemas, GetTables, GetTableTypes, GetTypeInfo} @@ -32,17 +31,10 @@ class KyuubiTrinoOperationTranslator(backendService: BackendService) { def transform( statement: String, - user: String, - ipAddress: String, + sessionHandle: SessionHandle, configs: Map[String, String], runAsync: Boolean, queryTimeout: Long): OperationHandle = { - val sessionHandle = backendService.openSession( - TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V11, - user, - "", - ipAddress, - configs) parser.parsePlan(statement) match { case GetSchemas(catalogName, schemaPattern) => backendService.getSchemas(sessionHandle, catalogName, schemaPattern) diff --git a/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/Query.scala b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/Query.scala new file mode 100644 index 00000000000..c8c9e2fd6c1 --- /dev/null +++ b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/Query.scala @@ -0,0 +1,206 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.server.trino.api + +import java.net.URI +import java.security.SecureRandom +import java.util.Objects.requireNonNull +import java.util.UUID +import java.util.concurrent.atomic.AtomicLong +import javax.ws.rs.WebApplicationException +import javax.ws.rs.core.{Response, UriInfo} + +import Slug.Context.{EXECUTING_QUERY, QUEUED_QUERY} +import com.google.common.hash.Hashing +import io.trino.client.QueryResults +import org.apache.hive.service.rpc.thrift.TProtocolVersion + +import org.apache.kyuubi.operation.{FetchOrientation, OperationHandle} +import org.apache.kyuubi.operation.OperationState.{FINISHED, INITIALIZED, OperationState, PENDING} +import org.apache.kyuubi.service.BackendService +import org.apache.kyuubi.session.SessionHandle + +case class Query( + queryId: QueryId, + context: TrinoContext, + be: BackendService) { + + private val QUEUED_QUERY_PATH = "/v1/statement/queued/" + private val EXECUTING_QUERY_PATH = "/v1/statement/executing" + + private val slug: Slug = Slug.createNewWithUUID(queryId.getQueryId) + private val lastToken = new AtomicLong + + private val defaultMaxRows = 1000 + private val defaultFetchOrientation = FetchOrientation.withName("FETCH_NEXT") + + def getQueryResults(token: Long, uriInfo: UriInfo, maxWait: Long = 0): QueryResults = { + val status = + be.getOperationStatus(queryId.operationHandle, Some(maxWait)) + val nextUri = if (status.exception.isEmpty) { + getNextUri(token + 1, uriInfo, toSlugContext(status.state)) + } else null + val queryHtmlUri = uriInfo.getRequestUriBuilder + .replacePath("ui/query.html").replaceQuery(queryId.getQueryId).build() + + status.state match { + case FINISHED => + val metaData = be.getResultSetMetadata(queryId.operationHandle) + val resultSet = be.fetchResults( + queryId.operationHandle, + defaultFetchOrientation, + defaultMaxRows, + false) + TrinoContext.createQueryResults( + queryId.getQueryId, + nextUri, + queryHtmlUri, + status, + Option(metaData), + Option(resultSet)) + case _ => + TrinoContext.createQueryResults( + queryId.getQueryId, + nextUri, + queryHtmlUri, + status) + } + } + + def getLastToken: Long = this.lastToken.get() + + def getSlug: Slug = this.slug + + def cancel: Unit = clear + + private def clear = { + be.closeOperation(queryId.operationHandle) + context.session.get("sessionId").foreach { id => + be.closeSession(SessionHandle.fromUUID(id)) + } + } + + private def setToken(token: Long): Unit = { + val lastToken = this.lastToken.get + if (token != lastToken && token != lastToken + 1) { + throw new WebApplicationException(Response.Status.GONE) + } + this.lastToken.compareAndSet(lastToken, token) + } + + private def getNextUri(token: Long, uriInfo: UriInfo, slugContext: Slug.Context.Context): URI = { + val path = slugContext match { + case QUEUED_QUERY => QUEUED_QUERY_PATH + case EXECUTING_QUERY => EXECUTING_QUERY_PATH + } + + uriInfo.getBaseUriBuilder.replacePath(path) + .path(queryId.getQueryId) + .path(slug.makeSlug(slugContext, token)) + .path(String.valueOf(token)) + .replaceQuery("") + .build() + } + + private def toSlugContext(state: OperationState): Slug.Context.Context = { + state match { + case INITIALIZED | PENDING => Slug.Context.QUEUED_QUERY + case _ => Slug.Context.EXECUTING_QUERY + } + } + +} + +object Query { + + def apply( + statement: String, + context: TrinoContext, + translator: KyuubiTrinoOperationTranslator, + backendService: BackendService, + queryTimeout: Long = 0): Query = { + + val sessionHandle = createSession(context, backendService) + val operationHandle = translator.transform( + statement, + sessionHandle, + context.session, + true, + queryTimeout) + val newSessionProperties = + context.session + ("sessionId" -> sessionHandle.identifier.toString) + val updatedContext = context.copy(session = newSessionProperties) + Query(QueryId(operationHandle), updatedContext, backendService) + } + + def apply(id: String, context: TrinoContext, backendService: BackendService): Query = { + Query(QueryId(id), context, backendService) + } + + private def createSession( + context: TrinoContext, + backendService: BackendService): SessionHandle = { + backendService.openSession( + TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V11, + context.user, + "", + context.remoteUserAddress.getOrElse(""), + context.session) + } + +} + +case class QueryId(operationHandle: OperationHandle) { + def getQueryId: String = operationHandle.identifier.toString +} + +object QueryId { + def apply(id: String): QueryId = QueryId(OperationHandle(id)) +} + +object Slug { + + object Context extends Enumeration { + type Context = Value + val QUEUED_QUERY, EXECUTING_QUERY = Value + } + + private val RANDOM = new SecureRandom + + def createNew: Slug = { + val randomBytes = new Array[Byte](16) + RANDOM.nextBytes(randomBytes) + new Slug(randomBytes) + } + + def createNewWithUUID(uuid: String): Slug = { + val uuidBytes = UUID.fromString(uuid).toString.getBytes("UTF-8") + new Slug(uuidBytes) + } +} + +case class Slug(slugKey: Array[Byte]) { + val hmac = Hashing.hmacSha1(requireNonNull(slugKey, "slugKey is null")) + + def makeSlug(context: Slug.Context.Context, token: Long): String = { + "y" + hmac.newHasher.putInt(context.id).putLong(token).hash.toString + } + + def isValid(context: Slug.Context.Context, slug: String, token: Long): Boolean = + makeSlug(context, token) == slug +} diff --git a/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/TrinoContext.scala b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/TrinoContext.scala index 4a0736ddb89..9e77139040a 100644 --- a/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/TrinoContext.scala +++ b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/TrinoContext.scala @@ -28,6 +28,7 @@ import io.trino.client.{ClientStandardTypes, ClientTypeSignature, Column, QueryE import io.trino.client.ProtocolHeaders.TRINO_HEADERS import org.apache.hive.service.rpc.thrift.{TGetResultSetMetadataResp, TRowSet, TTypeId} +import org.apache.kyuubi.operation.OperationState.FINISHED import org.apache.kyuubi.operation.OperationStatus /** @@ -58,6 +59,7 @@ case class TrinoContext( source: Option[String] = None, catalog: Option[String] = None, schema: Option[String] = None, + remoteUserAddress: Option[String] = None, language: Option[String] = None, traceToken: Option[String] = None, clientInfo: Option[String] = None, @@ -72,10 +74,11 @@ object TrinoContext { private val GENERIC_INTERNAL_ERROR_NAME = "GENERIC_INTERNAL_ERROR_NAME" private val GENERIC_INTERNAL_ERROR_TYPE = "INTERNAL_ERROR" - def apply(headers: HttpHeaders): TrinoContext = { - apply(headers.getRequestHeaders.asScala.toMap.map { + def apply(headers: HttpHeaders, remoteAddress: Option[String]): TrinoContext = { + val context = apply(headers.getRequestHeaders.asScala.toMap.map { case (k, v) => (k, v.asScala.toList) }) + context.copy(remoteUserAddress = remoteAddress) } def apply(headers: Map[String, List[String]]): TrinoContext = { @@ -134,7 +137,6 @@ object TrinoContext { } } - // TODO: Building response with TrinoContext and other information def buildTrinoResponse(qr: QueryResults, trinoContext: TrinoContext): Response = { val responseBuilder = Response.ok(qr) @@ -156,8 +158,6 @@ object TrinoContext { responseBuilder.header(TRINO_HEADERS.responseDeallocatedPrepare, urlEncode(v)) } - responseBuilder.header(TRINO_HEADERS.responseClearSession, s"responseClearSession") - responseBuilder.header(TRINO_HEADERS.responseClearTransactionId, "false") responseBuilder.build() } @@ -192,11 +192,16 @@ object TrinoContext { case None => null } + val updatedNextUri = queryStatus.state match { + case FINISHED if rowList == null || rowList.isEmpty || rowList.get(0).isEmpty => null + case _ => nextUri + } + new QueryResults( queryId, queryHtmlUri, nextUri, - nextUri, + updatedNextUri, columnList, rowList, StatementStats.builder.setState(queryStatus.state.name()).setQueued(false) diff --git a/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/KyuubiScalaObjectMapper.scala b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/TrinoScalaObjectMapper.scala similarity index 73% rename from kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/KyuubiScalaObjectMapper.scala rename to kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/TrinoScalaObjectMapper.scala index 915b109b7b9..f6055927ac2 100644 --- a/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/KyuubiScalaObjectMapper.scala +++ b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/TrinoScalaObjectMapper.scala @@ -19,11 +19,14 @@ package org.apache.kyuubi.server.trino.api import javax.ws.rs.ext.ContextResolver -import com.fasterxml.jackson.databind.ObjectMapper -import com.fasterxml.jackson.module.scala.DefaultScalaModule +import com.fasterxml.jackson.databind.{DeserializationFeature, ObjectMapper} +import com.fasterxml.jackson.datatype.jdk8.Jdk8Module -class KyuubiScalaObjectMapper extends ContextResolver[ObjectMapper] { - private val mapper = new ObjectMapper().registerModule(DefaultScalaModule) +class TrinoScalaObjectMapper extends ContextResolver[ObjectMapper] { + + private lazy val mapper = new ObjectMapper() + .disable(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES) + .registerModule(new Jdk8Module) override def getContext(aClass: Class[_]): ObjectMapper = mapper } diff --git a/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/TrinoServerConfig.scala b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/TrinoServerConfig.scala index d1f7de336ba..298e60c9cac 100644 --- a/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/TrinoServerConfig.scala +++ b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/TrinoServerConfig.scala @@ -21,6 +21,6 @@ import org.glassfish.jersey.server.ResourceConfig class TrinoServerConfig extends ResourceConfig { packages("org.apache.kyuubi.server.trino.api.v1") - register(classOf[KyuubiScalaObjectMapper]) + register(classOf[TrinoScalaObjectMapper]) register(classOf[RestExceptionMapper]) } diff --git a/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/v1/StatementResource.scala b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/v1/StatementResource.scala index 122b39ccd8c..ab783f8acce 100644 --- a/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/v1/StatementResource.scala +++ b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/v1/StatementResource.scala @@ -18,16 +18,24 @@ package org.apache.kyuubi.server.trino.api.v1 import javax.ws.rs._ -import javax.ws.rs.core.{Context, HttpHeaders, MediaType} +import javax.ws.rs.core.{Context, HttpHeaders, MediaType, Response, UriInfo} +import javax.ws.rs.core.MediaType.TEXT_PLAIN_TYPE +import javax.ws.rs.core.Response.Status.{BAD_REQUEST, NOT_FOUND} +import scala.util.Try +import scala.util.control.NonFatal + +import io.airlift.units.Duration import io.swagger.v3.oas.annotations.media.{Content, Schema} import io.swagger.v3.oas.annotations.responses.ApiResponse import io.swagger.v3.oas.annotations.tags.Tag import io.trino.client.QueryResults import org.apache.kyuubi.Logging -import org.apache.kyuubi.server.trino.api.{ApiRequestContext, KyuubiTrinoOperationTranslator} +import org.apache.kyuubi.server.trino.api.{ApiRequestContext, KyuubiTrinoOperationTranslator, Query, QueryId, Slug, TrinoContext} +import org.apache.kyuubi.server.trino.api.Slug.Context.{EXECUTING_QUERY, QUEUED_QUERY} import org.apache.kyuubi.server.trino.api.v1.dto.Ok +import org.apache.kyuubi.service.BackendService @Tag(name = "Statement") @Produces(Array(MediaType.APPLICATION_JSON)) @@ -50,11 +58,32 @@ private[v1] class StatementResource extends ApiRequestContext with Logging { schema = new Schema(implementation = classOf[QueryResults]))), description = "Create a query") - @GET + @POST @Path("/") @Consumes(Array(MediaType.TEXT_PLAIN)) - def query(statement: String, @Context headers: HttpHeaders): QueryResults = { - throw new UnsupportedOperationException + def query( + statement: String, + @Context headers: HttpHeaders, + @Context uriInfo: UriInfo): Response = { + if (statement == null || statement.isEmpty) { + throw badRequest(BAD_REQUEST, "SQL statement is empty") + } + + val remoteAddr = Option(httpRequest.getRemoteAddr) + val trinoContext = TrinoContext(headers, remoteAddr) + + try { + val query = Query(statement, trinoContext, translator, fe.be) + val qr = query.getQueryResults(query.getLastToken, uriInfo) + TrinoContext.buildTrinoResponse(qr, query.context) + } catch { + case e: Exception => + val errorMsg = + s"Error submitting sql" + e.printStackTrace() + error(errorMsg, e) + throw badRequest(BAD_REQUEST, errorMsg) + } } @ApiResponse( @@ -65,11 +94,31 @@ private[v1] class StatementResource extends ApiRequestContext with Logging { @GET @Path("/queued/{queryId}/{slug}/{token}") def getQueuedStatementStatus( - @Context headers: HttpHeaders, @PathParam("queryId") queryId: String, @PathParam("slug") slug: String, - @PathParam("token") token: Long): QueryResults = { - throw new UnsupportedOperationException + @PathParam("token") token: Long, + @QueryParam("maxWait") maxWait: Duration, + @Context headers: HttpHeaders, + @Context uriInfo: UriInfo): Response = { + + val remoteAddr = Option(httpRequest.getRemoteAddr) + val trinoContext = TrinoContext(headers, remoteAddr) + val waitTime = if (maxWait == null) 0 else maxWait.toMillis + getQuery(fe.be, trinoContext, QueryId(queryId), slug, token, QUEUED_QUERY) + .flatMap(query => + Try(TrinoContext.buildTrinoResponse( + query.getQueryResults( + token, + uriInfo, + waitTime), + query.context))) + .recover { + case NonFatal(e) => + val errorMsg = + s"Error executing for query id $queryId" + error(errorMsg, e) + throw badRequest(NOT_FOUND, "Query not found") + }.get } @ApiResponse( @@ -80,11 +129,28 @@ private[v1] class StatementResource extends ApiRequestContext with Logging { @GET @Path("/executing/{queryId}/{slug}/{token}") def getExecutingStatementStatus( - @Context headers: HttpHeaders, @PathParam("queryId") queryId: String, @PathParam("slug") slug: String, - @PathParam("token") token: Long): QueryResults = { - throw new UnsupportedOperationException + @PathParam("token") token: Long, + @QueryParam("maxWait") maxWait: Duration, + @Context headers: HttpHeaders, + @Context uriInfo: UriInfo): Response = { + + val remoteAddr = Option(httpRequest.getRemoteAddr) + val trinoContext = TrinoContext(headers, remoteAddr) + val waitTime = if (maxWait == null) 0 else maxWait.toMillis + getQuery(fe.be, trinoContext, QueryId(queryId), slug, token, EXECUTING_QUERY) + .flatMap(query => + Try(TrinoContext.buildTrinoResponse( + query.getQueryResults(token, uriInfo, waitTime), + query.context))) + .recover { + case NonFatal(e) => + val errorMsg = + s"Error executing for query id $queryId" + error(errorMsg, e) + throw badRequest(NOT_FOUND, "Query not found") + }.get } @ApiResponse( @@ -95,11 +161,23 @@ private[v1] class StatementResource extends ApiRequestContext with Logging { @DELETE @Path("/queued/{queryId}/{slug}/{token}") def cancelQueuedStatement( - @Context headers: HttpHeaders, @PathParam("queryId") queryId: String, @PathParam("slug") slug: String, - @PathParam("token") token: Long): QueryResults = { - throw new UnsupportedOperationException + @PathParam("token") token: Long, + @Context headers: HttpHeaders): Response = { + + val remoteAddr = Option(httpRequest.getRemoteAddr) + val trinoContext = TrinoContext(headers, remoteAddr) + getQuery(fe.be, trinoContext, QueryId(queryId), slug, token, QUEUED_QUERY) + .flatMap(query => Try(query.cancel)) + .recover { + case NonFatal(e) => + val errorMsg = + s"Error executing for query id $queryId" + error(errorMsg, e) + throw badRequest(NOT_FOUND, "Query not found") + }.get + Response.noContent.build } @ApiResponse( @@ -110,11 +188,44 @@ private[v1] class StatementResource extends ApiRequestContext with Logging { @DELETE @Path("/executing/{queryId}/{slug}/{token}") def cancelExecutingStatementStatus( - @Context headers: HttpHeaders, @PathParam("queryId") queryId: String, @PathParam("slug") slug: String, - @PathParam("token") token: Long): QueryResults = { - throw new UnsupportedOperationException + @PathParam("token") token: Long, + @Context headers: HttpHeaders): Response = { + + val remoteAddr = Option(httpRequest.getRemoteAddr) + val trinoContext = TrinoContext(headers, remoteAddr) + getQuery(fe.be, trinoContext, QueryId(queryId), slug, token, EXECUTING_QUERY) + .flatMap(query => Try(query.cancel)) + .recover { + case NonFatal(e) => + val errorMsg = + s"Error executing for query id $queryId" + error(errorMsg, e) + throw badRequest(NOT_FOUND, "Query not found") + }.get + + Response.noContent.build } + private def getQuery( + be: BackendService, + context: TrinoContext, + queryId: QueryId, + slug: String, + token: Long, + slugContext: Slug.Context.Context): Try[Query] = { + + Try(be.sessionManager.operationManager.getOperation(queryId.operationHandle)).map { _ => + Query(queryId, context, be) + }.filter(_.getSlug.isValid(slugContext, slug, token)) + } + + private def badRequest(status: Response.Status, message: String) = + new WebApplicationException( + Response.status(status) + .`type`(TEXT_PLAIN_TYPE) + .entity(message) + .build) + } diff --git a/kyuubi-server/src/test/scala/org/apache/kyuubi/RestFrontendTestHelper.scala b/kyuubi-server/src/test/scala/org/apache/kyuubi/RestFrontendTestHelper.scala index b22783771ec..fafdcf4a7b1 100644 --- a/kyuubi-server/src/test/scala/org/apache/kyuubi/RestFrontendTestHelper.scala +++ b/kyuubi-server/src/test/scala/org/apache/kyuubi/RestFrontendTestHelper.scala @@ -37,7 +37,7 @@ import org.apache.kyuubi.service.AbstractFrontendService object RestFrontendTestHelper { - private class RestApiBaseSuite extends JerseyTest { + class RestApiBaseSuite extends JerseyTest { override def configure: Application = new ResourceConfig(getClass) .register(classOf[MultiPartFeature]) @@ -58,7 +58,7 @@ trait RestFrontendTestHelper extends WithKyuubiServer { override protected val frontendProtocols: Seq[FrontendProtocol] = FrontendProtocols.REST :: Nil - private val restApiBaseSuite = new RestApiBaseSuite + protected val restApiBaseSuite: JerseyTest = new RestApiBaseSuite override def beforeAll(): Unit = { super.beforeAll() diff --git a/kyuubi-server/src/test/scala/org/apache/kyuubi/TrinoClientTestHelper.scala b/kyuubi-server/src/test/scala/org/apache/kyuubi/TrinoClientTestHelper.scala deleted file mode 100644 index c0b3949f4ee..00000000000 --- a/kyuubi-server/src/test/scala/org/apache/kyuubi/TrinoClientTestHelper.scala +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.kyuubi - -import java.net.URI -import java.time.ZoneId -import java.util.{Locale, Optional} -import java.util.concurrent.TimeUnit - -import scala.collection.JavaConverters._ - -import io.airlift.units.Duration -import io.trino.client.{ClientSelectedRole, ClientSession, StatementClient, StatementClientFactory} -import okhttp3.OkHttpClient - -trait TrinoClientTestHelper extends RestFrontendTestHelper { - - override def afterAll(): Unit = { - super.afterAll() - } - - private val httpClient = new OkHttpClient.Builder().build() - - protected val clientSession = createClientSession(baseUri: URI) - - def getTrinoStatementClient(sql: String): StatementClient = { - StatementClientFactory.newStatementClient(httpClient, clientSession, sql) - } - - def createClientSession(connectUrl: URI): ClientSession = { - new ClientSession( - connectUrl, - "kyuubi_test", - Optional.of("test_user"), - "kyuubi", - Optional.of("test_token_tracing"), - Set[String]().asJava, - "test_client_info", - "test_catalog", - "test_schema", - "test_path", - ZoneId.systemDefault(), - Locale.getDefault, - Map[String, String]( - "test_resource_key0" -> "test_resource_value0", - "test_resource_key1" -> "test_resource_value1").asJava, - Map[String, String]( - "test_property_key0" -> "test_property_value0", - "test_property_key1" -> "test_propert_value1").asJava, - Map[String, String]( - "test_statement_key0" -> "select 1", - "test_statement_key1" -> "select 2").asJava, - Map[String, ClientSelectedRole]( - "test_role_key0" -> ClientSelectedRole.valueOf("ROLE"), - "test_role_key2" -> ClientSelectedRole.valueOf("ALL")).asJava, - Map[String, String]( - "test_credentials_key0" -> "test_credentials_value0", - "test_credentials_key1" -> "test_credentials_value1").asJava, - "test_transaction_id", - new Duration(2, TimeUnit.MINUTES), - true) - - } - -} diff --git a/kyuubi-server/src/test/scala/org/apache/kyuubi/TrinoRestFrontendTestHelper.scala b/kyuubi-server/src/test/scala/org/apache/kyuubi/TrinoRestFrontendTestHelper.scala new file mode 100644 index 00000000000..1ff00e64fa2 --- /dev/null +++ b/kyuubi-server/src/test/scala/org/apache/kyuubi/TrinoRestFrontendTestHelper.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi + +import org.glassfish.jersey.client.ClientConfig +import org.glassfish.jersey.test.JerseyTest + +import org.apache.kyuubi.config.KyuubiConf.FrontendProtocols +import org.apache.kyuubi.config.KyuubiConf.FrontendProtocols.FrontendProtocol +import org.apache.kyuubi.server.trino.api.TrinoScalaObjectMapper + +trait TrinoRestFrontendTestHelper extends RestFrontendTestHelper { + + private class TrinoRestBaseSuite extends RestFrontendTestHelper.RestApiBaseSuite { + override def configureClient(config: ClientConfig): Unit = { + config.register(classOf[TrinoScalaObjectMapper]) + } + } + + override protected val frontendProtocols: Seq[FrontendProtocol] = + FrontendProtocols.TRINO :: Nil + + override protected val restApiBaseSuite: JerseyTest = new TrinoRestBaseSuite + +} diff --git a/kyuubi-server/src/test/scala/org/apache/kyuubi/server/trino/api/TrinoClientApiSuite.scala b/kyuubi-server/src/test/scala/org/apache/kyuubi/server/trino/api/TrinoClientApiSuite.scala new file mode 100644 index 00000000000..c88b5c9409f --- /dev/null +++ b/kyuubi-server/src/test/scala/org/apache/kyuubi/server/trino/api/TrinoClientApiSuite.scala @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.server.trino.api + +import java.net.URI +import java.time.ZoneId +import java.util.{Collections, Locale, Optional} +import java.util.concurrent.TimeUnit +import java.util.concurrent.atomic.AtomicReference + +import scala.annotation.tailrec +import scala.collection.JavaConverters._ + +import com.google.common.base.Verify +import io.airlift.units.Duration +import io.trino.client.{ClientSession, StatementClient, StatementClientFactory} +import okhttp3.OkHttpClient + +import org.apache.kyuubi.{KyuubiFunSuite, KyuubiSQLException, TrinoRestFrontendTestHelper} + +class TrinoClientApiSuite extends KyuubiFunSuite with TrinoRestFrontendTestHelper { + + private val httpClient = + new OkHttpClient.Builder() + .readTimeout(5, TimeUnit.MINUTES) + .build() + private lazy val clientSession = + new AtomicReference[ClientSession](createTestClientSession(baseUri)) + + test("submit query with trino client api") { + val trino = getTrinoStatementClient("select 1") + val result = execute(trino) + val sessionId = trino.getSetSessionProperties.asScala.get("sessionId") + assert(result == List(List(1))) + + updateClientSession(trino) + + val trino1 = getTrinoStatementClient("select 2") + val result1 = execute(trino1) + val sessionId1 = trino1.getSetSessionProperties.asScala.get("sessionId") + assert(result1 == List(List(2))) + assert(sessionId != sessionId1) + + trino.close() + } + + private def updateClientSession(trino: StatementClient): Unit = { + val session = clientSession.get + + var builder = ClientSession.builder(session) + // update catalog and schema + if (trino.getSetCatalog.isPresent || trino.getSetSchema.isPresent) { + builder = builder + .withCatalog(trino.getSetCatalog.orElse(session.getCatalog)) + .withSchema(trino.getSetSchema.orElse(session.getSchema)) + } + + // update path if present + if (trino.getSetPath.isPresent) { + builder = builder.withPath(trino.getSetPath.get) + } + + // update session properties if present + if (!trino.getSetSessionProperties.isEmpty || !trino.getResetSessionProperties.isEmpty) { + val properties = session.getProperties.asScala.clone() + properties ++= trino.getSetSessionProperties.asScala + properties --= trino.getResetSessionProperties.asScala + builder = builder.withProperties(properties.asJava) + } + clientSession.set(builder.build()) + } + + private def execute(trino: StatementClient): List[List[Any]] = { + @tailrec + def getData(trino: StatementClient): (Boolean, List[List[Any]]) = { + if (trino.isRunning) { + val data = trino.currentData().getData() + trino.advance() + if (data != null) { + (true, data.asScala.toList.map(_.asScala.toList)) + } else { + getData(trino) + } + } else { + Verify.verify(trino.isFinished) + val finalStatus = trino.finalStatusInfo() + if (finalStatus.getError() != null) { + throw KyuubiSQLException( + s"Query ${finalStatus.getId} failed: ${finalStatus.getError.getMessage}") + } + (false, List[List[Any]]()) + } + } + Iterator.continually(getData(trino)).takeWhile(_._1).flatMap(_._2).toList + } + + private def getTrinoStatementClient(sql: String): StatementClient = { + StatementClientFactory.newStatementClient(httpClient, clientSession.get, sql) + } + + private def createTestClientSession(connectUrl: URI): ClientSession = { + new ClientSession( + connectUrl, + "kyuubi_test", + Optional.of("test_user"), + "kyuubi", + Optional.of("test_token_tracing"), + Set[String]().asJava, + "test_client_info", + "test_catalog", + "test_schema", + null, + ZoneId.systemDefault(), + Locale.getDefault, + Collections.emptyMap(), + Map[String, String]( + "test_property_key0" -> "test_property_value0", + "test_property_key1" -> "test_propert_value1").asJava, + Map[String, String]( + "test_statement_key0" -> "select 1", + "test_statement_key1" -> "select 2").asJava, + Collections.emptyMap(), + Collections.emptyMap(), + null, + new Duration(2, TimeUnit.MINUTES), + true) + + } + +} diff --git a/kyuubi-server/src/test/scala/org/apache/kyuubi/server/trino/api/TrinoContextSuite.scala b/kyuubi-server/src/test/scala/org/apache/kyuubi/server/trino/api/TrinoContextSuite.scala index 8d7b2bf2ccf..87c8eda968a 100644 --- a/kyuubi-server/src/test/scala/org/apache/kyuubi/server/trino/api/TrinoContextSuite.scala +++ b/kyuubi-server/src/test/scala/org/apache/kyuubi/server/trino/api/TrinoContextSuite.scala @@ -85,7 +85,7 @@ class TrinoContextSuite extends KyuubiFunSuite with RestFrontendTestHelper { val metadataResp = fe.be.getResultSetMetadata(opHandle) val tRowSet = fe.be.fetchResults(opHandle, FetchOrientation.FETCH_NEXT, 1000, false) - val status = fe.be.getOperationStatus(opHandle) + val status = fe.be.getOperationStatus(opHandle, Some(0)) val uri = new URI("sfdsfsdfdsf") val results = TrinoContext @@ -112,7 +112,7 @@ class TrinoContextSuite extends KyuubiFunSuite with RestFrontendTestHelper { val metadataResp = fe.be.getResultSetMetadata(opHandle) val tRowSet = fe.be.fetchResults(opHandle, FetchOrientation.FETCH_NEXT, 1000, false) - val status = fe.be.getOperationStatus(opHandle) + val status = fe.be.getOperationStatus(opHandle, Some(0)) val uri = new URI("sfdsfsdfdsf") val results = TrinoContext diff --git a/kyuubi-server/src/test/scala/org/apache/kyuubi/server/trino/api/v1/StatementResourceSuite.scala b/kyuubi-server/src/test/scala/org/apache/kyuubi/server/trino/api/v1/StatementResourceSuite.scala index b60c7c67aa2..adbf389c931 100644 --- a/kyuubi-server/src/test/scala/org/apache/kyuubi/server/trino/api/v1/StatementResourceSuite.scala +++ b/kyuubi-server/src/test/scala/org/apache/kyuubi/server/trino/api/v1/StatementResourceSuite.scala @@ -17,15 +17,27 @@ package org.apache.kyuubi.server.trino.api.v1 -import org.apache.kyuubi.{KyuubiFunSuite, RestFrontendTestHelper} -import org.apache.kyuubi.config.KyuubiConf.FrontendProtocols -import org.apache.kyuubi.config.KyuubiConf.FrontendProtocols.FrontendProtocol +import javax.ws.rs.client.Entity +import javax.ws.rs.core.{MediaType, Response} + +import scala.collection.JavaConverters._ + +import io.trino.client.{QueryError, QueryResults} +import io.trino.client.ProtocolHeaders.TRINO_HEADERS + +import org.apache.kyuubi.{KyuubiFunSuite, KyuubiSQLException, TrinoRestFrontendTestHelper} +import org.apache.kyuubi.operation.{OperationHandle, OperationState} +import org.apache.kyuubi.server.trino.api.TrinoContext import org.apache.kyuubi.server.trino.api.v1.dto.Ok +import org.apache.kyuubi.session.SessionHandle -class StatementResourceSuite extends KyuubiFunSuite with RestFrontendTestHelper { +class StatementResourceSuite extends KyuubiFunSuite with TrinoRestFrontendTestHelper { - override protected val frontendProtocols: Seq[FrontendProtocol] = - FrontendProtocols.TRINO :: Nil + case class TrinoResponse( + response: Option[Response] = None, + queryError: Option[QueryError] = None, + data: List[List[Any]] = List[List[Any]](), + isEnd: Boolean = false) test("statement test") { val response = webTarget.path("v1/statement/test").request().get() @@ -33,4 +45,74 @@ class StatementResourceSuite extends KyuubiFunSuite with RestFrontendTestHelper assert(result == new Ok("trino server is running")) } + test("statement submit for query error") { + + val response = webTarget.path("v1/statement") + .request().post(Entity.entity("select a", MediaType.TEXT_PLAIN_TYPE)) + + val trinoResponseIter = Iterator.iterate(TrinoResponse(response = Option(response)))(getData) + val isErr = trinoResponseIter.takeWhile(_.isEnd == false).exists { t => + t.queryError != None && t.response == None + } + assert(isErr == true) + } + + test("statement submit and get result") { + val response = webTarget.path("v1/statement") + .request().post(Entity.entity("select 1", MediaType.TEXT_PLAIN_TYPE)) + + val trinoResponseIter = Iterator.iterate(TrinoResponse(response = Option(response)))(getData) + val dataSet = trinoResponseIter + .takeWhile(_.isEnd == false) + .map(_.data) + .flatten.toList + assert(dataSet == List(List(1))) + } + + test("query cancel") { + val response = webTarget.path("v1/statement") + .request().post(Entity.entity("select 1", MediaType.TEXT_PLAIN_TYPE)) + val qr = response.readEntity(classOf[QueryResults]) + val sessionManager = fe.be.sessionManager + val sessionHandle = + response.getStringHeaders.get(TRINO_HEADERS.responseSetSession).asScala + .map(_.split("=")) + .find { + case Array("sessionId", _) => true + } + .map { + case Array(_, value) => SessionHandle.fromUUID(TrinoContext.urlDecode(value)) + }.get + sessionManager.getSession(sessionHandle) + val operationHandle = OperationHandle(qr.getId) + val operation = sessionManager.operationManager.getOperation(operationHandle) + assert(response.getStatus == 200) + val path = qr.getNextUri.getPath + val nextResponse = webTarget.path(path).request().header( + TRINO_HEADERS.requestSession(), + s"sessionId=${TrinoContext.urlEncode(sessionHandle.identifier.toString)}").delete() + assert(nextResponse.getStatus == 204) + assert(operation.getStatus.state == OperationState.CLOSED) + val exception = intercept[KyuubiSQLException](sessionManager.getSession(sessionHandle)) + assert(exception.getMessage === s"Invalid $sessionHandle") + + } + + private def getData(current: TrinoResponse): TrinoResponse = { + current.response.map { response => + assert(response.getStatus == 200) + val qr = response.readEntity(classOf[QueryResults]) + val nextData = Option(qr.getData) + .map(_.asScala.toList.map(_.asScala.toList)) + .getOrElse(List[List[Any]]()) + val nextResponse = Option(qr.getNextUri).map { + uri => + val path = uri.getPath + val headers = response.getHeaders + webTarget.path(path).request().headers(headers).get() + } + TrinoResponse(nextResponse, Option(qr.getError), nextData) + }.getOrElse(TrinoResponse(isEnd = true)) + } + } diff --git a/pom.xml b/pom.xml index 9afc0cf1de5..2c2d9babe3a 100644 --- a/pom.xml +++ b/pom.xml @@ -777,6 +777,12 @@ ${jackson.version} + + com.fasterxml.jackson.datatype + jackson-datatype-jdk8 + ${jackson.version} + + com.fasterxml.jackson.jaxrs jackson-jaxrs-base