Skip to content

Commit

Permalink
address
Browse files Browse the repository at this point in the history
  • Loading branch information
iodone committed Feb 13, 2023
1 parent e7bd01a commit 936ea1f
Show file tree
Hide file tree
Showing 13 changed files with 75 additions and 93 deletions.
2 changes: 1 addition & 1 deletion dev/dependencyList
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ jackson-annotations/2.14.2//jackson-annotations-2.14.2.jar
jackson-core/2.14.2//jackson-core-2.14.2.jar
jackson-databind/2.14.2//jackson-databind-2.14.2.jar
jackson-dataformat-yaml/2.14.2//jackson-dataformat-yaml-2.14.2.jar
jackson-datatype-jdk8/2.12.3//jackson-datatype-jdk8-2.12.3.jar
jackson-datatype-jdk8/2.14.2//jackson-datatype-jdk8-2.14.2.jar
jackson-datatype-jsr310/2.14.2//jackson-datatype-jsr310-2.14.2.jar
jackson-jaxrs-base/2.14.2//jackson-jaxrs-base-2.14.2.jar
jackson-jaxrs-json-provider/2.14.2//jackson-jaxrs-json-provider-2.14.2.jar
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,11 +158,11 @@ abstract class AbstractBackendService(name: String)

override def getOperationStatus(
operationHandle: OperationHandle,
maxWait: Long = timeout): OperationStatus = {
maxWait: Option[Long]): OperationStatus = {
val operation = sessionManager.operationManager.getOperation(operationHandle)
if (operation.shouldRunAsync) {
try {
val waitTime = maxWait
val waitTime = maxWait.getOrElse(timeout)
operation.getBackgroundHandle.get(waitTime, TimeUnit.MILLISECONDS)
} catch {
case e: TimeoutException =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,9 @@ trait BackendService {
foreignTable: String): OperationHandle
def getQueryId(operationHandle: OperationHandle): String

def getOperationStatus(operationHandle: OperationHandle, maxWait: Long = 0): OperationStatus
def getOperationStatus(
operationHandle: OperationHandle,
maxWait: Option[Long] = None): OperationStatus
def cancelOperation(operationHandle: OperationHandle): Unit
def closeOperation(operationHandle: OperationHandle): Unit
def getResultSetMetadata(operationHandle: OperationHandle): TGetResultSetMetadataResp
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,9 +154,9 @@ trait BackendServiceMetric extends BackendService {

abstract override def getOperationStatus(
operationHandle: OperationHandle,
maxWait: Long): OperationStatus = {
maxWait: Option[Long] = None): OperationStatus = {
MetricsSystem.timerTracing(MetricsConstants.BS_GET_OPERATION_STATUS) {
super.getOperationStatus(operationHandle)
super.getOperationStatus(operationHandle, maxWait)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ case class Query(

def getQueryResults(token: Long, uriInfo: UriInfo, maxWait: Long = 0): QueryResults = {
val status =
be.getOperationStatus(queryId.operationHandle, maxWait)
be.getOperationStatus(queryId.operationHandle, Some(maxWait))
val nextUri = if (status.exception.isEmpty) {
getNextUri(token + 1, uriInfo, toSlugContext(status.state))
} else null
Expand Down Expand Up @@ -155,14 +155,12 @@ object Query {
private def createSession(
context: TrinoContext,
backendService: BackendService): SessionHandle = {
context.session
.get("sessionId")
.fold(backendService.openSession(
TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V11,
context.user,
"",
context.remoteUserAddress.getOrElse(""),
context.session))(SessionHandle.fromUUID)
backendService.openSession(
TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V11,
context.user,
"",
context.remoteUserAddress.getOrElse(""),
context.session)
}

}
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,9 @@ import scala.collection.JavaConverters._

import io.trino.client.{ClientStandardTypes, ClientTypeSignature, Column, QueryError, QueryResults, StatementStats, Warning}
import io.trino.client.ProtocolHeaders.TRINO_HEADERS
import org.apache.hive.service.rpc.thrift.{TGetResultSetMetadataResp, TRowSet}

import org.apache.hive.service.rpc.thrift.{TGetResultSetMetadataResp, TRowSet, TTypeId}

import org.apache.kyuubi.operation.OperationState.FINISHED
import org.apache.kyuubi.operation.OperationStatus

/**
Expand Down Expand Up @@ -193,14 +192,16 @@ object TrinoContext {
case None => null
}

val updatedNextUri =
if (rowList == null || rowList.isEmpty || rowList.get(0).isEmpty) null else nextUri
val updatedNextUri = queryStatus.state match {
case FINISHED if rowList == null || rowList.isEmpty || rowList.get(0).isEmpty => null
case _ => nextUri
}

new QueryResults(
queryId,
queryHtmlUri,
nextUri,
nextUri,
updatedNextUri,
columnList,
rowList,
StatementStats.builder.setState(queryStatus.state.name()).setQueued(false)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,6 @@ class TrinoScalaObjectMapper extends ContextResolver[ObjectMapper] {

private lazy val mapper = new ObjectMapper()
.disable(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES)
// .disable(MapperFeature.AUTO_DETECT_CREATORS)
// .disable(MapperFeature.AUTO_DETECT_FIELDS)
// .disable(MapperFeature.AUTO_DETECT_SETTERS)
// .disable(MapperFeature.AUTO_DETECT_GETTERS)
// .disable(MapperFeature.AUTO_DETECT_IS_GETTERS)
// .disable(MapperFeature.USE_GETTERS_AS_SETTERS)
// .disable(MapperFeature.CAN_OVERRIDE_ACCESS_MODIFIERS)
// .disable(MapperFeature.INFER_PROPERTY_MUTATORS)
// .disable(MapperFeature.ALLOW_FINAL_FIELDS_AS_MUTATORS)
.registerModule(new Jdk8Module)

override def getContext(aClass: Class[_]): ObjectMapper = mapper
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ import io.swagger.v3.oas.annotations.tags.Tag
import io.trino.client.QueryResults

import org.apache.kyuubi.Logging
import org.apache.kyuubi.server.trino.api.{ApiRequestContext, KyuubiTrinoOperationTranslator, Query, QueryId, QueryManager, Slug, TrinoContext}
import org.apache.kyuubi.server.trino.api.{ApiRequestContext, KyuubiTrinoOperationTranslator, Query, QueryId, Slug, TrinoContext}
import org.apache.kyuubi.server.trino.api.Slug.Context.{EXECUTING_QUERY, QUEUED_QUERY}
import org.apache.kyuubi.server.trino.api.v1.dto.Ok
import org.apache.kyuubi.service.BackendService
Expand All @@ -42,7 +42,6 @@ import org.apache.kyuubi.service.BackendService
private[v1] class StatementResource extends ApiRequestContext with Logging {

lazy val translator = new KyuubiTrinoOperationTranslator(fe.be)
lazy val queryManager = new QueryManager()

@ApiResponse(
responseCode = "200",
Expand Down Expand Up @@ -117,7 +116,6 @@ private[v1] class StatementResource extends ApiRequestContext with Logging {
case NonFatal(e) =>
val errorMsg =
s"Error executing for query id $queryId"
e.printStackTrace()
error(errorMsg, e)
throw badRequest(NOT_FOUND, "Query not found")
}.get
Expand Down Expand Up @@ -150,7 +148,6 @@ private[v1] class StatementResource extends ApiRequestContext with Logging {
case NonFatal(e) =>
val errorMsg =
s"Error executing for query id $queryId"
e.printStackTrace()
error(errorMsg, e)
throw badRequest(NOT_FOUND, "Query not found")
}.get
Expand Down Expand Up @@ -225,7 +222,7 @@ private[v1] class StatementResource extends ApiRequestContext with Logging {
}

private def badRequest(status: Response.Status, message: String) =
throw new WebApplicationException(
new WebApplicationException(
Response.status(status)
.`type`(TEXT_PLAIN_TYPE)
.entity(message)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class TrinoClientApiSuite extends KyuubiFunSuite with TrinoRestFrontendTestHelpe
val result1 = execute(trino1)
val sessionId1 = trino1.getSetSessionProperties.asScala.get("sessionId")
assert(result1 == List(List(2)))
assert(sessionId == sessionId1)
assert(sessionId != sessionId1)

trino.close()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ class TrinoContextSuite extends KyuubiFunSuite with RestFrontendTestHelper {

val metadataResp = fe.be.getResultSetMetadata(opHandle)
val tRowSet = fe.be.fetchResults(opHandle, FetchOrientation.FETCH_NEXT, 1000, false)
val status = fe.be.getOperationStatus(opHandle)
val status = fe.be.getOperationStatus(opHandle, Some(0))

val uri = new URI("sfdsfsdfdsf")
val results = TrinoContext
Expand All @@ -112,7 +112,7 @@ class TrinoContextSuite extends KyuubiFunSuite with RestFrontendTestHelper {

val metadataResp = fe.be.getResultSetMetadata(opHandle)
val tRowSet = fe.be.fetchResults(opHandle, FetchOrientation.FETCH_NEXT, 1000, false)
val status = fe.be.getOperationStatus(opHandle)
val status = fe.be.getOperationStatus(opHandle, Some(0))

val uri = new URI("sfdsfsdfdsf")
val results = TrinoContext
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,27 +22,51 @@ import javax.ws.rs.core.{MediaType, Response}

import scala.collection.JavaConverters._

import io.trino.client.{QueryError, QueryResults}
import io.trino.client.ProtocolHeaders.TRINO_HEADERS
import io.trino.client.QueryResults

import org.apache.kyuubi.{KyuubiSQLException, TrinoRestFrontendTestHelper}
import org.apache.kyuubi.{KyuubiFunSuite, KyuubiSQLException, TrinoRestFrontendTestHelper}
import org.apache.kyuubi.operation.{OperationHandle, OperationState}
import org.apache.kyuubi.server.trino.api.TrinoContext
import org.apache.kyuubi.server.trino.api.v1.dto.Ok
import org.apache.kyuubi.session.SessionHandle

class StatementResourceSuite extends TrinoRestFrontendTestHelper {
class StatementResourceSuite extends KyuubiFunSuite with TrinoRestFrontendTestHelper {

case class TrinoResponse(
response: Option[Response] = None,
queryError: Option[QueryError] = None,
data: List[List[Any]] = List[List[Any]](),
isEnd: Boolean = false)

test("statement test") {
val response = webTarget.path("v1/statement/test").request().get()
val result = response.readEntity(classOf[Ok])
assert(result == new Ok("trino server is running"))
}

test("statement submit for query error") {

val response = webTarget.path("v1/statement")
.request().post(Entity.entity("select a", MediaType.TEXT_PLAIN_TYPE))

val trinoResponseIter = Iterator.iterate(TrinoResponse(response = Option(response)))(getData)
val isErr = trinoResponseIter.takeWhile(_.isEnd == false).exists { t =>
t.queryError != None && t.response == None
}
assert(isErr == true)
}

test("statement submit and get result") {
val response = webTarget.path("v1/statement")
.request().post(Entity.entity("select 1", MediaType.TEXT_PLAIN_TYPE))
checkResult(response)

val trinoResponseIter = Iterator.iterate(TrinoResponse(response = Option(response)))(getData)
val dataSet = trinoResponseIter
.takeWhile(_.isEnd == false)
.map(_.data)
.flatten.toList
assert(dataSet == List(List(1)))
}

test("query cancel") {
Expand Down Expand Up @@ -74,20 +98,21 @@ class StatementResourceSuite extends TrinoRestFrontendTestHelper {

}

private def checkResult(response: Response): Unit = {
assert(response.getStatus == 200)
val qr = response.readEntity(classOf[QueryResults])
if (qr.getData.iterator().hasNext) {
val resultSet = qr.getData.iterator()
assert(resultSet.next.asScala == List(1))
}
if (qr.getNextUri != null) {
val path = qr.getNextUri.getPath
val headers = response.getHeaders
val nextResponse = webTarget.path(path).request().headers(headers).get()
checkResult(nextResponse)
}

private def getData(current: TrinoResponse): TrinoResponse = {
current.response.map { response =>
assert(response.getStatus == 200)
val qr = response.readEntity(classOf[QueryResults])
val nextData = Option(qr.getData)
.map(_.asScala.toList.map(_.asScala.toList))
.getOrElse(List[List[Any]]())
val nextResponse = Option(qr.getNextUri).map {
uri =>
val path = uri.getPath
val headers = response.getHeaders
webTarget.path(path).request().headers(headers).get()
}
TrinoResponse(nextResponse, Option(qr.getError), nextData)
}.getOrElse(TrinoResponse(isEnd = true))
}

}
13 changes: 6 additions & 7 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,6 @@
<scopt.version>4.1.0</scopt.version>
<slf4j.version>1.7.36</slf4j.version>
<snakeyaml.version>1.33</snakeyaml.version>
<jackson.datatype.version>2.12.3</jackson.datatype.version>
<!--
DO NOT forget to change the following properties when change the minor version of Spark:
`delta.version`, `maven.plugin.scalatest.exclude.tags`
Expand Down Expand Up @@ -779,6 +778,12 @@
<version>${jackson.version}</version>
</dependency>

<dependency>
<groupId>com.fasterxml.jackson.datatype</groupId>
<artifactId>jackson-datatype-jdk8</artifactId>
<version>${jackson.version}</version>
</dependency>

<dependency>
<groupId>com.fasterxml.jackson.jaxrs</groupId>
<artifactId>jackson-jaxrs-base</artifactId>
Expand Down Expand Up @@ -1642,12 +1647,6 @@
<artifactId>py4j</artifactId>
<version>${py4j.version}</version>
</dependency>

<dependency>
<groupId>com.fasterxml.jackson.datatype</groupId>
<artifactId>jackson-datatype-jdk8</artifactId>
<version>${jackson.datatype.version}</version>
</dependency>
</dependencies>
</dependencyManagement>

Expand Down

0 comments on commit 936ea1f

Please sign in to comment.