From 067c6010a668c1ed9e83a991e589fb886304c959 Mon Sep 17 00:00:00 2001 From: liangbowen Date: Sat, 18 Mar 2023 15:34:16 +0800 Subject: [PATCH] [KYUUBI #4554] [CHAT] Code improvement in ChatGPTProvider ### _Why are the changes needed?_ - set authentication as default header in client construction instead of request construction - handle response's status code in scala style - transforming config's long value to int with `.intValue` instead of `asInstanceOf` casting - fix var name to `response` ### _How was this patch tested?_ - [ ] Add some test cases that check the changes thoroughly including negative and positive cases if possible - [ ] Add screenshots for manual tests if appropriate - [ ] [Run test](https://kyuubi.readthedocs.io/en/master/develop_tools/testing.html#running-tests) locally before make a pull request Closes #4554 from bowenliang123/chatgpt-http. Closes #4554 114484a4d [liangbowen] httpclient improvement in ChatGPTProvider Authored-by: liangbowen Signed-off-by: Cheng Pan --- .../chat/provider/ChatGPTProvider.scala | 42 +++++++++++-------- 1 file changed, 24 insertions(+), 18 deletions(-) diff --git a/externals/kyuubi-chat-engine/src/main/scala/org/apache/kyuubi/engine/chat/provider/ChatGPTProvider.scala b/externals/kyuubi-chat-engine/src/main/scala/org/apache/kyuubi/engine/chat/provider/ChatGPTProvider.scala index 28ef36bb4dd..a4cdb7c948c 100644 --- a/externals/kyuubi-chat-engine/src/main/scala/org/apache/kyuubi/engine/chat/provider/ChatGPTProvider.scala +++ b/externals/kyuubi-chat-engine/src/main/scala/org/apache/kyuubi/engine/chat/provider/ChatGPTProvider.scala @@ -20,12 +20,15 @@ package org.apache.kyuubi.engine.chat.provider import java.util import java.util.concurrent.TimeUnit +import scala.collection.JavaConverters._ + import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache} -import org.apache.http.{HttpHost, HttpStatus} +import org.apache.http.{HttpHeaders, HttpHost, HttpStatus} import org.apache.http.client.config.RequestConfig import org.apache.http.client.methods.HttpPost import org.apache.http.entity.{ContentType, StringEntity} import org.apache.http.impl.client.{CloseableHttpClient, HttpClientBuilder} +import org.apache.http.message.BasicHeader import org.apache.http.util.EntityUtils import org.apache.kyuubi.config.KyuubiConf @@ -39,11 +42,16 @@ class ChatGPTProvider(conf: KyuubiConf) extends ChatProvider { s"which could be got at https://platform.openai.com/account/api-keys") } - private val httpClient: CloseableHttpClient = HttpClientBuilder.create().build() + private val httpClient: CloseableHttpClient = { + HttpClientBuilder.create() + .setDefaultHeaders(List( + new BasicHeader(HttpHeaders.AUTHORIZATION, s"Bearer $gptApiKey")).asJava) + .build() + } - private val requestConfig = { - val connectTimeout = conf.get(KyuubiConf.ENGINE_CHAT_GPT_HTTP_CONNECT_TIMEOUT).asInstanceOf[Int] - val socketTimeout = conf.get(KyuubiConf.ENGINE_CHAT_GPT_HTTP_SOCKET_TIMEOUT).asInstanceOf[Int] + private val requestConfig: RequestConfig = { + val connectTimeout = conf.get(KyuubiConf.ENGINE_CHAT_GPT_HTTP_CONNECT_TIMEOUT).intValue() + val socketTimeout = conf.get(KyuubiConf.ENGINE_CHAT_GPT_HTTP_SOCKET_TIMEOUT).intValue() val builder: RequestConfig.Builder = RequestConfig.custom() .setConnectTimeout(connectTimeout) .setSocketTimeout(socketTimeout) @@ -70,8 +78,6 @@ class ChatGPTProvider(conf: KyuubiConf) extends ChatProvider { messages.addLast(Message("user", q)) val request = new HttpPost("https://api.openai.com/v1/chat/completions") - request.addHeader("Authorization", "Bearer " + gptApiKey) - val req = Map( "messages" -> messages, "model" -> "gpt-3.5-turbo", @@ -81,17 +87,17 @@ class ChatGPTProvider(conf: KyuubiConf) extends ChatProvider { val entity = new StringEntity(mapper.writeValueAsString(req), ContentType.APPLICATION_JSON) request.setEntity(entity) request.setConfig(requestConfig) - val responseEntity = httpClient.execute(request) - val respJson = mapper.readTree(EntityUtils.toString(responseEntity.getEntity)) - val statusCode = responseEntity.getStatusLine.getStatusCode - if (responseEntity.getStatusLine.getStatusCode == HttpStatus.SC_OK) { - val replyMessage = mapper.treeToValue[Message]( - respJson.get("choices").get(0).get("message")) - messages.addLast(replyMessage) - replyMessage.content - } else { - messages.removeLast() - s"Chat failed. Status: $statusCode. ${respJson.get("error").get("message").asText}" + val response = httpClient.execute(request) + val respJson = mapper.readTree(EntityUtils.toString(response.getEntity)) + response.getStatusLine.getStatusCode match { + case HttpStatus.SC_OK => + val replyMessage = mapper.treeToValue[Message]( + respJson.get("choices").get(0).get("message")) + messages.addLast(replyMessage) + replyMessage.content + case errorStatusCode => + messages.removeLast() + s"Chat failed. Status: $errorStatusCode. ${respJson.get("error").get("message").asText}" } }