Skip to content

Commit a2243b4

Browse files
committed
[KYUUBI #3934] Compatiable with Trino rest dto
1 parent 534fc9f commit a2243b4

File tree

2 files changed

+321
-0
lines changed

2 files changed

+321
-0
lines changed
Lines changed: 239 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,239 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF AnyRef KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.kyuubi.server.trino.api
19+
20+
import java.net.URI
21+
import java.util
22+
23+
import scala.collection.JavaConverters._
24+
25+
import io.trino.client.{ClientTypeSignature, Column, QueryError, QueryResults, StatementStats, Warning}
26+
import org.apache.hive.service.rpc.thrift.{TGetResultSetMetadataResp, TRowSet}
27+
28+
import org.apache.kyuubi.operation.OperationStatus
29+
30+
31+
object KyuubiTrinoQueryResultAdapt {
32+
private val defaultWarning: util.List[Warning] = new util.ArrayList[Warning]()
33+
private val GENERIC_INTERNAL_ERROR_CODE = 65536
34+
private val GENERIC_INTERNAL_ERROR_NAME = "GENERIC_INTERNAL_ERROR_NAME"
35+
private val GENERIC_INTERNAL_ERROR_TYPE = "INTERNAL_ERROR"
36+
37+
def createQueryResults(
38+
queryId: String,
39+
nextUri: URI,
40+
queryHtmlUri: URI,
41+
queryStatus: OperationStatus,
42+
columns: Option[TGetResultSetMetadataResp] = None,
43+
data: Option[TRowSet] = None): QueryResults = {
44+
45+
// val queryHtmlUri = uriInfo.getRequestUriBuilder
46+
// .replacePath("ui/query.html").replaceQuery(queryId).build()
47+
48+
new QueryResults(queryId, queryHtmlUri, nextUri, nextUri,
49+
convertTColumn(columns),
50+
convertTRowSet(data),
51+
StatementStats.builder.setState(queryStatus.state.name()).build(),
52+
toQueryError(queryStatus), defaultWarning, null, 0L)
53+
}
54+
55+
def convertTColumn(columns: Option[TGetResultSetMetadataResp]): util.List[Column] = {
56+
if (columns.isEmpty) {
57+
return null
58+
}
59+
60+
columns.get.getSchema.getColumns.asScala.map(c => {
61+
val tp = c.getTypeDesc.getTypes.get(0).getPrimitiveEntry.getType.name()
62+
new Column(c.getColumnName, tp, new ClientTypeSignature(tp))
63+
}).toList.asJava
64+
}
65+
66+
def convertTRowSet(data: Option[TRowSet]): util.List[util.List[Object]] = {
67+
if (data.isEmpty) {
68+
return null
69+
}
70+
val rowSet = data.get
71+
var dataSet: Array[scala.List[Object]] = Array()
72+
73+
if (rowSet.getColumns == null) {
74+
return rowSet.getRows.asScala
75+
.map(t => t.getColVals.asScala.map(v => v.getFieldValue.asInstanceOf[Object]).asJava)
76+
.asJava
77+
}
78+
79+
rowSet.getColumns.asScala.foreach {
80+
case tColumn if tColumn.isSetBoolVal =>
81+
val nulls = util.BitSet.valueOf(tColumn.getBoolVal.getNulls)
82+
if (dataSet.isEmpty) {
83+
dataSet = tColumn.getBoolVal.getValues.asScala.zipWithIndex
84+
.foldLeft(Array[scala.List[Object]]()) {
85+
case (acc, x) if nulls.get(x._2) =>
86+
acc ++ List(List(None))
87+
case (acc, x) if !nulls.get(x._2) =>
88+
acc ++ List(List(x._1))
89+
}
90+
} else {
91+
tColumn.getBoolVal.getValues.asScala.zipWithIndex.foreach {
92+
case (_, rowIdx) if nulls.get(rowIdx) =>
93+
dataSet(rowIdx) = dataSet(rowIdx) ++ List(None)
94+
case (v, rowIdx) =>
95+
dataSet(rowIdx) = dataSet(rowIdx) ++ List(v)
96+
}
97+
}
98+
case tColumn if tColumn.isSetByteVal =>
99+
val nulls = util.BitSet.valueOf(tColumn.getByteVal.getNulls)
100+
if (dataSet.isEmpty) {
101+
dataSet = tColumn.getByteVal.getValues.asScala.zipWithIndex
102+
.foldLeft(Array[scala.List[Object]]()) {
103+
case (acc, x) if nulls.get(x._2) =>
104+
acc ++ List(scala.List(None))
105+
case (acc, x) if !nulls.get(x._2) =>
106+
acc ++ List(scala.List(x._1))
107+
}
108+
} else {
109+
tColumn.getByteVal.getValues.asScala.zipWithIndex.foreach {
110+
case (_, rowIdx) if nulls.get(rowIdx) =>
111+
dataSet(rowIdx) = dataSet(rowIdx) ++ List(None)
112+
case (v, rowIdx) =>
113+
dataSet(rowIdx) = dataSet(rowIdx) ++ List(v)
114+
}
115+
}
116+
case tColumn if tColumn.isSetI16Val =>
117+
val nulls = util.BitSet.valueOf(tColumn.getI16Val.getNulls)
118+
if (dataSet.isEmpty) {
119+
dataSet = tColumn.getI16Val.getValues.asScala.zipWithIndex
120+
.foldLeft(Array[scala.List[Object]]()) {
121+
case (acc, x) if nulls.get(x._2) =>
122+
acc ++ List(List(None))
123+
case (acc, x) if !nulls.get(x._2) =>
124+
acc ++ List(List(x._1))
125+
}
126+
} else {
127+
tColumn.getI16Val.getValues.asScala.zipWithIndex.foreach {
128+
case (_, rowIdx) if nulls.get(rowIdx) =>
129+
dataSet(rowIdx) = dataSet(rowIdx) ++ List(None)
130+
case (v, rowIdx) =>
131+
dataSet(rowIdx) = dataSet(rowIdx) ++ List(v)
132+
}
133+
}
134+
case tColumn if tColumn.isSetI32Val =>
135+
val nulls = util.BitSet.valueOf(tColumn.getI32Val.getNulls)
136+
if (dataSet.isEmpty) {
137+
dataSet = tColumn.getI32Val.getValues.asScala.zipWithIndex
138+
.foldLeft(Array[scala.List[Object]]()) {
139+
case (acc, x) if nulls.get(x._2) =>
140+
acc ++ List(List(None))
141+
case (acc, x) if !nulls.get(x._2) =>
142+
acc ++ List(List(x._1))
143+
}
144+
} else {
145+
tColumn.getI32Val.getValues.asScala.zipWithIndex.foreach {
146+
case (_, rowIdx) if nulls.get(rowIdx) =>
147+
dataSet(rowIdx) = dataSet(rowIdx) ++ List(None)
148+
case (v, rowIdx) =>
149+
dataSet(rowIdx) = dataSet(rowIdx) ++ List(v)
150+
}
151+
}
152+
case tColumn if tColumn.isSetI64Val =>
153+
val nulls = util.BitSet.valueOf(tColumn.getI64Val.getNulls)
154+
if (dataSet.isEmpty) {
155+
dataSet = tColumn.getI64Val.getValues.asScala.zipWithIndex
156+
.foldLeft(Array[scala.List[Object]]()) {
157+
case (acc, x) if nulls.get(x._2) =>
158+
acc ++ List(List(None))
159+
case (acc, x) if !nulls.get(x._2) =>
160+
acc ++ List(List(x._1))
161+
}
162+
} else {
163+
tColumn.getI64Val.getValues.asScala.zipWithIndex.foreach {
164+
case (_, rowIdx) if nulls.get(rowIdx) =>
165+
dataSet(rowIdx) = dataSet(rowIdx) ++ List(None)
166+
case (v, rowIdx) =>
167+
dataSet(rowIdx) = dataSet(rowIdx) ++ List(v)
168+
}
169+
}
170+
case tColumn if tColumn.isSetDoubleVal =>
171+
val nulls = util.BitSet.valueOf(tColumn.getDoubleVal.getNulls)
172+
if (dataSet.isEmpty) {
173+
dataSet = tColumn.getDoubleVal.getValues.asScala.zipWithIndex
174+
.foldLeft(Array[scala.List[Object]]()) {
175+
case (acc, x) if nulls.get(x._2) =>
176+
acc ++ List(List(None))
177+
case (acc, x) if !nulls.get(x._2) =>
178+
acc ++ List(List(x._1))
179+
}
180+
} else {
181+
tColumn.getDoubleVal.getValues.asScala.zipWithIndex.foreach {
182+
case (_, rowIdx) if nulls.get(rowIdx) =>
183+
dataSet(rowIdx) = dataSet(rowIdx) ++ List(None)
184+
case (v, rowIdx) =>
185+
dataSet(rowIdx) = dataSet(rowIdx) ++ List(v)
186+
}
187+
}
188+
case tColumn if tColumn.isSetBinaryVal =>
189+
val nulls = util.BitSet.valueOf(tColumn.getBinaryVal.getNulls)
190+
if (dataSet.isEmpty) {
191+
dataSet = tColumn.getBinaryVal.getValues.asScala.zipWithIndex
192+
.foldLeft(Array[scala.List[Object]]()) {
193+
case (acc, x) if nulls.get(x._2) =>
194+
acc ++ List(List(None))
195+
case (acc, x) if !nulls.get(x._2) =>
196+
acc ++ List(List(x._1))
197+
}
198+
} else {
199+
tColumn.getBinaryVal.getValues.asScala.zipWithIndex.foreach {
200+
case (_, rowIdx) if nulls.get(rowIdx) =>
201+
dataSet(rowIdx) = dataSet(rowIdx) ++ List(None)
202+
case (v, rowIdx) =>
203+
dataSet(rowIdx) = dataSet(rowIdx) ++ List(v)
204+
}
205+
}
206+
case tColumn =>
207+
val nulls = util.BitSet.valueOf(tColumn.getStringVal.getNulls)
208+
if (dataSet.isEmpty) {
209+
dataSet = tColumn.getStringVal.getValues.asScala.zipWithIndex
210+
.foldLeft(Array[scala.List[Object]]()) {
211+
case (acc, x) if nulls.get(x._2) =>
212+
acc ++ List(List(None))
213+
case (acc, x) if !nulls.get(x._2) =>
214+
acc ++ List(List(x._1))
215+
}
216+
} else {
217+
tColumn.getStringVal.getValues.asScala.zipWithIndex.foreach {
218+
case (_, rowIdx) if nulls.get(rowIdx) =>
219+
dataSet(rowIdx) = dataSet(rowIdx) ++ List(None)
220+
case (v, rowIdx) =>
221+
dataSet(rowIdx) = dataSet(rowIdx) ++ List(v)
222+
}
223+
}
224+
}
225+
dataSet.toList.map(_.asJava).asJava
226+
}
227+
228+
def toQueryError(queryStatus: OperationStatus): QueryError = {
229+
val exception = queryStatus.exception
230+
if (exception.isEmpty) {
231+
null
232+
} else {
233+
new QueryError(exception.get.getMessage, queryStatus.state.name(),
234+
GENERIC_INTERNAL_ERROR_CODE, GENERIC_INTERNAL_ERROR_NAME,
235+
GENERIC_INTERNAL_ERROR_TYPE,
236+
null, null)
237+
}
238+
}
239+
}
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
package org.apache.kyuubi.server.trino.api.v1
18+
19+
import java.net.URI
20+
import javax.ws.rs.core.MediaType
21+
22+
import scala.collection.JavaConverters._
23+
24+
import org.apache.hive.service.rpc.thrift.TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V2
25+
import org.scalatest.concurrent.PatienceConfiguration.Timeout
26+
import org.scalatest.time.SpanSugar.convertIntToGrainOfTime
27+
28+
import org.apache.kyuubi.{KyuubiFunSuite, RestFrontendTestHelper}
29+
import org.apache.kyuubi.events.KyuubiOperationEvent
30+
import org.apache.kyuubi.operation.{FetchOrientation, OperationHandle}
31+
import org.apache.kyuubi.operation.OperationState.{FINISHED, OperationState}
32+
import org.apache.kyuubi.server.trino.api.KyuubiTrinoQueryResultAdapt
33+
34+
class KyuubiTrinoQueryResultAdaptSuite extends KyuubiFunSuite with RestFrontendTestHelper {
35+
36+
test("test convert") {
37+
val opHandle = getOpHandle("select 1")
38+
val opHandleStr = opHandle.identifier.toString
39+
checkOpState(opHandleStr, FINISHED)
40+
41+
val metadataResp = fe.be.getResultSetMetadata(opHandle)
42+
val tRowSet = fe.be.fetchResults(opHandle, FetchOrientation.FETCH_NEXT, 1000, false)
43+
val status = fe.be.getOperationStatus(opHandle)
44+
45+
val uri = new URI("sfdsfsdfdsf")
46+
val results = KyuubiTrinoQueryResultAdapt
47+
.createQueryResults("/xdfd/xdf", uri, uri, status, Option(metadataResp), Option(tRowSet))
48+
49+
print(results.toString)
50+
assert(results.getColumns.get(0).getType.equals("INT_TYPE"))
51+
assert(results.getData.asScala.last.get(0).toString.equals("TI32Value(value:1)"))
52+
}
53+
54+
def getOpHandleStr(statement: String = "show tables"): String = {
55+
getOpHandle(statement).identifier.toString
56+
}
57+
58+
def getOpHandle(statement: String = "show tables"): OperationHandle = {
59+
val sessionHandle = fe.be.openSession(
60+
HIVE_CLI_SERVICE_PROTOCOL_V2,
61+
"admin",
62+
"123456",
63+
"localhost",
64+
Map("testConfig" -> "testValue"))
65+
66+
if (statement.nonEmpty) {
67+
fe.be.executeStatement(sessionHandle, statement, Map.empty, runAsync = true, 3000)
68+
} else {
69+
fe.be.getCatalogs(sessionHandle)
70+
}
71+
}
72+
73+
private def checkOpState(opHandleStr: String, state: OperationState): Unit = {
74+
eventually(Timeout(5.seconds)) {
75+
val response = webTarget.path(s"api/v1/operations/$opHandleStr/event")
76+
.request(MediaType.APPLICATION_JSON_TYPE).get()
77+
assert(response.getStatus === 200)
78+
val operationEvent = response.readEntity(classOf[KyuubiOperationEvent])
79+
assert(operationEvent.state === state.name())
80+
}
81+
}
82+
}

0 commit comments

Comments
 (0)