Skip to content

Commit

Permalink
[KYUUBI apache#4554] [CHAT] Code improvement in ChatGPTProvider
Browse files Browse the repository at this point in the history
### _Why are the changes needed?_

- set authentication as default header in client construction instead of  request construction
- handle response's status code in scala style
- transforming config's long value to int with `.intValue` instead of `asInstanceOf` casting
- fix var name to `response`

### _How was this patch tested?_
- [ ] Add some test cases that check the changes thoroughly including negative and positive cases if possible

- [ ] Add screenshots for manual tests if appropriate

- [ ] [Run test](https://kyuubi.readthedocs.io/en/master/develop_tools/testing.html#running-tests) locally before make a pull request

Closes apache#4554 from bowenliang123/chatgpt-http.

Closes apache#4554

114484a [liangbowen] httpclient improvement in ChatGPTProvider

Authored-by: liangbowen <liangbowen@gf.com.cn>
Signed-off-by: Cheng Pan <chengpan@apache.org>
  • Loading branch information
bowenliang123 authored and pan3793 committed Mar 18, 2023
1 parent 6ded079 commit 067c601
Showing 1 changed file with 24 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,15 @@ package org.apache.kyuubi.engine.chat.provider
import java.util
import java.util.concurrent.TimeUnit

import scala.collection.JavaConverters._

import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache}
import org.apache.http.{HttpHost, HttpStatus}
import org.apache.http.{HttpHeaders, HttpHost, HttpStatus}
import org.apache.http.client.config.RequestConfig
import org.apache.http.client.methods.HttpPost
import org.apache.http.entity.{ContentType, StringEntity}
import org.apache.http.impl.client.{CloseableHttpClient, HttpClientBuilder}
import org.apache.http.message.BasicHeader
import org.apache.http.util.EntityUtils

import org.apache.kyuubi.config.KyuubiConf
Expand All @@ -39,11 +42,16 @@ class ChatGPTProvider(conf: KyuubiConf) extends ChatProvider {
s"which could be got at https://platform.openai.com/account/api-keys")
}

private val httpClient: CloseableHttpClient = HttpClientBuilder.create().build()
private val httpClient: CloseableHttpClient = {
HttpClientBuilder.create()
.setDefaultHeaders(List(
new BasicHeader(HttpHeaders.AUTHORIZATION, s"Bearer $gptApiKey")).asJava)
.build()
}

private val requestConfig = {
val connectTimeout = conf.get(KyuubiConf.ENGINE_CHAT_GPT_HTTP_CONNECT_TIMEOUT).asInstanceOf[Int]
val socketTimeout = conf.get(KyuubiConf.ENGINE_CHAT_GPT_HTTP_SOCKET_TIMEOUT).asInstanceOf[Int]
private val requestConfig: RequestConfig = {
val connectTimeout = conf.get(KyuubiConf.ENGINE_CHAT_GPT_HTTP_CONNECT_TIMEOUT).intValue()
val socketTimeout = conf.get(KyuubiConf.ENGINE_CHAT_GPT_HTTP_SOCKET_TIMEOUT).intValue()
val builder: RequestConfig.Builder = RequestConfig.custom()
.setConnectTimeout(connectTimeout)
.setSocketTimeout(socketTimeout)
Expand All @@ -70,8 +78,6 @@ class ChatGPTProvider(conf: KyuubiConf) extends ChatProvider {
messages.addLast(Message("user", q))

val request = new HttpPost("https://api.openai.com/v1/chat/completions")
request.addHeader("Authorization", "Bearer " + gptApiKey)

val req = Map(
"messages" -> messages,
"model" -> "gpt-3.5-turbo",
Expand All @@ -81,17 +87,17 @@ class ChatGPTProvider(conf: KyuubiConf) extends ChatProvider {
val entity = new StringEntity(mapper.writeValueAsString(req), ContentType.APPLICATION_JSON)
request.setEntity(entity)
request.setConfig(requestConfig)
val responseEntity = httpClient.execute(request)
val respJson = mapper.readTree(EntityUtils.toString(responseEntity.getEntity))
val statusCode = responseEntity.getStatusLine.getStatusCode
if (responseEntity.getStatusLine.getStatusCode == HttpStatus.SC_OK) {
val replyMessage = mapper.treeToValue[Message](
respJson.get("choices").get(0).get("message"))
messages.addLast(replyMessage)
replyMessage.content
} else {
messages.removeLast()
s"Chat failed. Status: $statusCode. ${respJson.get("error").get("message").asText}"
val response = httpClient.execute(request)
val respJson = mapper.readTree(EntityUtils.toString(response.getEntity))
response.getStatusLine.getStatusCode match {
case HttpStatus.SC_OK =>
val replyMessage = mapper.treeToValue[Message](
respJson.get("choices").get(0).get("message"))
messages.addLast(replyMessage)
replyMessage.content
case errorStatusCode =>
messages.removeLast()
s"Chat failed. Status: $errorStatusCode. ${respJson.get("error").get("message").asText}"
}
}

Expand Down

0 comments on commit 067c601

Please sign in to comment.