Skip to content

Commit

Permalink
[KYUUBI #3935] Support use Trino client to submit SQL
Browse files Browse the repository at this point in the history
### _Why are the changes needed?_

Close #3935

### _How was this patch tested?_
- [x] Add some test cases that check the changes thoroughly including negative and positive cases if possible

- [ ] Add screenshots for manual tests if appropriate

- [x] [Run test](https://kyuubi.readthedocs.io/en/master/develop_tools/testing.html#running-tests) locally before make a pull request

Closes #4232 from iodone/kyuubi-3935.

Closes #3935

936ea1f [odone] address
e7bd01a [odone] support trino client connect kyuubi trino server
9ea8b6a [odone] [WIP] trion request/response implementation

Authored-by: odone <odone.zhang@gmail.com>
Signed-off-by: ulyssesyou <ulyssesyou@apache.org>
(cherry picked from commit 41f0805)
Signed-off-by: ulyssesyou <ulyssesyou@apache.org>
  • Loading branch information
iodone authored and ulysses-you committed Feb 13, 2023
1 parent 354df96 commit 2e5a5e7
Show file tree
Hide file tree
Showing 18 changed files with 661 additions and 134 deletions.
2 changes: 1 addition & 1 deletion dev/dependencyList
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ jackson-annotations/2.14.2//jackson-annotations-2.14.2.jar
jackson-core/2.14.2//jackson-core-2.14.2.jar
jackson-databind/2.14.2//jackson-databind-2.14.2.jar
jackson-dataformat-yaml/2.14.2//jackson-dataformat-yaml-2.14.2.jar
jackson-datatype-jdk8/2.12.3//jackson-datatype-jdk8-2.12.3.jar
jackson-datatype-jdk8/2.14.2//jackson-datatype-jdk8-2.14.2.jar
jackson-datatype-jsr310/2.14.2//jackson-datatype-jsr310-2.14.2.jar
jackson-jaxrs-base/2.14.2//jackson-jaxrs-base-2.14.2.jar
jackson-jaxrs-json-provider/2.14.2//jackson-jaxrs-json-provider-2.14.2.jar
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,11 +156,14 @@ abstract class AbstractBackendService(name: String)
queryId
}

override def getOperationStatus(operationHandle: OperationHandle): OperationStatus = {
override def getOperationStatus(
operationHandle: OperationHandle,
maxWait: Option[Long]): OperationStatus = {
val operation = sessionManager.operationManager.getOperation(operationHandle)
if (operation.shouldRunAsync) {
try {
operation.getBackgroundHandle.get(timeout, TimeUnit.MILLISECONDS)
val waitTime = maxWait.getOrElse(timeout)
operation.getBackgroundHandle.get(waitTime, TimeUnit.MILLISECONDS)
} catch {
case e: TimeoutException =>
debug(s"$operationHandle: Long polling timed out, ${e.getMessage}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,9 @@ trait BackendService {
foreignTable: String): OperationHandle
def getQueryId(operationHandle: OperationHandle): String

def getOperationStatus(operationHandle: OperationHandle): OperationStatus
def getOperationStatus(
operationHandle: OperationHandle,
maxWait: Option[Long] = None): OperationStatus
def cancelOperation(operationHandle: OperationHandle): Unit
def closeOperation(operationHandle: OperationHandle): Unit
def getResultSetMetadata(operationHandle: OperationHandle): TGetResultSetMetadataResp
Expand Down
10 changes: 10 additions & 0 deletions kyuubi-server/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,16 @@
<artifactId>jersey-media-multipart</artifactId>
</dependency>

<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
</dependency>

<dependency>
<groupId>com.fasterxml.jackson.datatype</groupId>
<artifactId>jackson-datatype-jdk8</artifactId>
</dependency>

<dependency>
<groupId>com.zaxxer</groupId>
<artifactId>HikariCP</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,9 +152,11 @@ trait BackendServiceMetric extends BackendService {
}
}

abstract override def getOperationStatus(operationHandle: OperationHandle): OperationStatus = {
abstract override def getOperationStatus(
operationHandle: OperationHandle,
maxWait: Option[Long] = None): OperationStatus = {
MetricsSystem.timerTracing(MetricsConstants.BS_GET_OPERATION_STATUS) {
super.getOperationStatus(operationHandle)
super.getOperationStatus(operationHandle, maxWait)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,9 @@ package org.apache.kyuubi.server.trino.api

import scala.collection.JavaConverters._

import org.apache.hive.service.rpc.thrift.TProtocolVersion

import org.apache.kyuubi.operation.OperationHandle
import org.apache.kyuubi.service.BackendService
import org.apache.kyuubi.session.SessionHandle
import org.apache.kyuubi.sql.parser.trino.KyuubiTrinoFeParser
import org.apache.kyuubi.sql.plan.PassThroughNode
import org.apache.kyuubi.sql.plan.trino.{GetCatalogs, GetColumns, GetSchemas, GetTables, GetTableTypes, GetTypeInfo}
Expand All @@ -32,17 +31,10 @@ class KyuubiTrinoOperationTranslator(backendService: BackendService) {

def transform(
statement: String,
user: String,
ipAddress: String,
sessionHandle: SessionHandle,
configs: Map[String, String],
runAsync: Boolean,
queryTimeout: Long): OperationHandle = {
val sessionHandle = backendService.openSession(
TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V11,
user,
"",
ipAddress,
configs)
parser.parsePlan(statement) match {
case GetSchemas(catalogName, schemaPattern) =>
backendService.getSchemas(sessionHandle, catalogName, schemaPattern)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.kyuubi.server.trino.api

import java.net.URI
import java.security.SecureRandom
import java.util.Objects.requireNonNull
import java.util.UUID
import java.util.concurrent.atomic.AtomicLong
import javax.ws.rs.WebApplicationException
import javax.ws.rs.core.{Response, UriInfo}

import Slug.Context.{EXECUTING_QUERY, QUEUED_QUERY}
import com.google.common.hash.Hashing
import io.trino.client.QueryResults
import org.apache.hive.service.rpc.thrift.TProtocolVersion

import org.apache.kyuubi.operation.{FetchOrientation, OperationHandle}
import org.apache.kyuubi.operation.OperationState.{FINISHED, INITIALIZED, OperationState, PENDING}
import org.apache.kyuubi.service.BackendService
import org.apache.kyuubi.session.SessionHandle

case class Query(
queryId: QueryId,
context: TrinoContext,
be: BackendService) {

private val QUEUED_QUERY_PATH = "/v1/statement/queued/"
private val EXECUTING_QUERY_PATH = "/v1/statement/executing"

private val slug: Slug = Slug.createNewWithUUID(queryId.getQueryId)
private val lastToken = new AtomicLong

private val defaultMaxRows = 1000
private val defaultFetchOrientation = FetchOrientation.withName("FETCH_NEXT")

def getQueryResults(token: Long, uriInfo: UriInfo, maxWait: Long = 0): QueryResults = {
val status =
be.getOperationStatus(queryId.operationHandle, Some(maxWait))
val nextUri = if (status.exception.isEmpty) {
getNextUri(token + 1, uriInfo, toSlugContext(status.state))
} else null
val queryHtmlUri = uriInfo.getRequestUriBuilder
.replacePath("ui/query.html").replaceQuery(queryId.getQueryId).build()

status.state match {
case FINISHED =>
val metaData = be.getResultSetMetadata(queryId.operationHandle)
val resultSet = be.fetchResults(
queryId.operationHandle,
defaultFetchOrientation,
defaultMaxRows,
false)
TrinoContext.createQueryResults(
queryId.getQueryId,
nextUri,
queryHtmlUri,
status,
Option(metaData),
Option(resultSet))
case _ =>
TrinoContext.createQueryResults(
queryId.getQueryId,
nextUri,
queryHtmlUri,
status)
}
}

def getLastToken: Long = this.lastToken.get()

def getSlug: Slug = this.slug

def cancel: Unit = clear

private def clear = {
be.closeOperation(queryId.operationHandle)
context.session.get("sessionId").foreach { id =>
be.closeSession(SessionHandle.fromUUID(id))
}
}

private def setToken(token: Long): Unit = {
val lastToken = this.lastToken.get
if (token != lastToken && token != lastToken + 1) {
throw new WebApplicationException(Response.Status.GONE)
}
this.lastToken.compareAndSet(lastToken, token)
}

private def getNextUri(token: Long, uriInfo: UriInfo, slugContext: Slug.Context.Context): URI = {
val path = slugContext match {
case QUEUED_QUERY => QUEUED_QUERY_PATH
case EXECUTING_QUERY => EXECUTING_QUERY_PATH
}

uriInfo.getBaseUriBuilder.replacePath(path)
.path(queryId.getQueryId)
.path(slug.makeSlug(slugContext, token))
.path(String.valueOf(token))
.replaceQuery("")
.build()
}

private def toSlugContext(state: OperationState): Slug.Context.Context = {
state match {
case INITIALIZED | PENDING => Slug.Context.QUEUED_QUERY
case _ => Slug.Context.EXECUTING_QUERY
}
}

}

object Query {

def apply(
statement: String,
context: TrinoContext,
translator: KyuubiTrinoOperationTranslator,
backendService: BackendService,
queryTimeout: Long = 0): Query = {

val sessionHandle = createSession(context, backendService)
val operationHandle = translator.transform(
statement,
sessionHandle,
context.session,
true,
queryTimeout)
val newSessionProperties =
context.session + ("sessionId" -> sessionHandle.identifier.toString)
val updatedContext = context.copy(session = newSessionProperties)
Query(QueryId(operationHandle), updatedContext, backendService)
}

def apply(id: String, context: TrinoContext, backendService: BackendService): Query = {
Query(QueryId(id), context, backendService)
}

private def createSession(
context: TrinoContext,
backendService: BackendService): SessionHandle = {
backendService.openSession(
TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V11,
context.user,
"",
context.remoteUserAddress.getOrElse(""),
context.session)
}

}

case class QueryId(operationHandle: OperationHandle) {
def getQueryId: String = operationHandle.identifier.toString
}

object QueryId {
def apply(id: String): QueryId = QueryId(OperationHandle(id))
}

object Slug {

object Context extends Enumeration {
type Context = Value
val QUEUED_QUERY, EXECUTING_QUERY = Value
}

private val RANDOM = new SecureRandom

def createNew: Slug = {
val randomBytes = new Array[Byte](16)
RANDOM.nextBytes(randomBytes)
new Slug(randomBytes)
}

def createNewWithUUID(uuid: String): Slug = {
val uuidBytes = UUID.fromString(uuid).toString.getBytes("UTF-8")
new Slug(uuidBytes)
}
}

case class Slug(slugKey: Array[Byte]) {
val hmac = Hashing.hmacSha1(requireNonNull(slugKey, "slugKey is null"))

def makeSlug(context: Slug.Context.Context, token: Long): String = {
"y" + hmac.newHasher.putInt(context.id).putLong(token).hash.toString
}

def isValid(context: Slug.Context.Context, slug: String, token: Long): Boolean =
makeSlug(context, token) == slug
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import io.trino.client.{ClientStandardTypes, ClientTypeSignature, Column, QueryE
import io.trino.client.ProtocolHeaders.TRINO_HEADERS
import org.apache.hive.service.rpc.thrift.{TGetResultSetMetadataResp, TRowSet, TTypeId}

import org.apache.kyuubi.operation.OperationState.FINISHED
import org.apache.kyuubi.operation.OperationStatus

/**
Expand Down Expand Up @@ -58,6 +59,7 @@ case class TrinoContext(
source: Option[String] = None,
catalog: Option[String] = None,
schema: Option[String] = None,
remoteUserAddress: Option[String] = None,
language: Option[String] = None,
traceToken: Option[String] = None,
clientInfo: Option[String] = None,
Expand All @@ -72,10 +74,11 @@ object TrinoContext {
private val GENERIC_INTERNAL_ERROR_NAME = "GENERIC_INTERNAL_ERROR_NAME"
private val GENERIC_INTERNAL_ERROR_TYPE = "INTERNAL_ERROR"

def apply(headers: HttpHeaders): TrinoContext = {
apply(headers.getRequestHeaders.asScala.toMap.map {
def apply(headers: HttpHeaders, remoteAddress: Option[String]): TrinoContext = {
val context = apply(headers.getRequestHeaders.asScala.toMap.map {
case (k, v) => (k, v.asScala.toList)
})
context.copy(remoteUserAddress = remoteAddress)
}

def apply(headers: Map[String, List[String]]): TrinoContext = {
Expand Down Expand Up @@ -134,7 +137,6 @@ object TrinoContext {
}
}

// TODO: Building response with TrinoContext and other information
def buildTrinoResponse(qr: QueryResults, trinoContext: TrinoContext): Response = {
val responseBuilder = Response.ok(qr)

Expand All @@ -156,8 +158,6 @@ object TrinoContext {
responseBuilder.header(TRINO_HEADERS.responseDeallocatedPrepare, urlEncode(v))
}

responseBuilder.header(TRINO_HEADERS.responseClearSession, s"responseClearSession")
responseBuilder.header(TRINO_HEADERS.responseClearTransactionId, "false")
responseBuilder.build()
}

Expand Down Expand Up @@ -192,11 +192,16 @@ object TrinoContext {
case None => null
}

val updatedNextUri = queryStatus.state match {
case FINISHED if rowList == null || rowList.isEmpty || rowList.get(0).isEmpty => null
case _ => nextUri
}

new QueryResults(
queryId,
queryHtmlUri,
nextUri,
nextUri,
updatedNextUri,
columnList,
rowList,
StatementStats.builder.setState(queryStatus.state.name()).setQueued(false)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,14 @@ package org.apache.kyuubi.server.trino.api

import javax.ws.rs.ext.ContextResolver

import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.module.scala.DefaultScalaModule
import com.fasterxml.jackson.databind.{DeserializationFeature, ObjectMapper}
import com.fasterxml.jackson.datatype.jdk8.Jdk8Module

class KyuubiScalaObjectMapper extends ContextResolver[ObjectMapper] {
private val mapper = new ObjectMapper().registerModule(DefaultScalaModule)
class TrinoScalaObjectMapper extends ContextResolver[ObjectMapper] {

private lazy val mapper = new ObjectMapper()
.disable(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES)
.registerModule(new Jdk8Module)

override def getContext(aClass: Class[_]): ObjectMapper = mapper
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,6 @@ import org.glassfish.jersey.server.ResourceConfig

class TrinoServerConfig extends ResourceConfig {
packages("org.apache.kyuubi.server.trino.api.v1")
register(classOf[KyuubiScalaObjectMapper])
register(classOf[TrinoScalaObjectMapper])
register(classOf[RestExceptionMapper])
}
Loading

0 comments on commit 2e5a5e7

Please sign in to comment.