From 8f721d67a5f5cb3caaff20957241035f659719cb Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 23 Jul 2023 00:32:47 +0800 Subject: [PATCH] feat: support Google PaLM2 (close #105) --- README.md | 1 + common/model-ratio.go | 1 + controller/model.go | 9 ++ controller/relay-palm.go | 211 +++++++++++++++++++++---- controller/relay-text.go | 30 ++++ web/src/constants/channel.constants.js | 1 + 6 files changed, 221 insertions(+), 32 deletions(-) diff --git a/README.md b/README.md index 2292311091d32..dc69b10ccd2fd 100644 --- a/README.md +++ b/README.md @@ -62,6 +62,7 @@ _✨ All in one 的 OpenAI 接口,整合各种 API 访问方式,开箱即用 + [x] OpenAI 官方通道(支持配置镜像) + [x] **Azure OpenAI API** + [x] [Anthropic Claude 系列模型](https://anthropic.com) + + [x] [Google PaLM2 系列模型](https://developers.generativeai.google) + [x] [百度文心一言系列模型](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html) + [x] [API Distribute](https://api.gptjk.top/register?aff=QGxj) + [x] [OpenAI-SB](https://openai-sb.com) diff --git a/common/model-ratio.go b/common/model-ratio.go index 8f034ec61459a..cc70fb22011d0 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -41,6 +41,7 @@ var ModelRatio = map[string]float64{ "claude-2": 30, "ERNIE-Bot": 1, // 0.012元/千tokens "ERNIE-Bot-turbo": 0.67, // 0.008元/千tokens + "PaLM-2": 1, } func ModelRatio2JSONString() string { diff --git a/controller/model.go b/controller/model.go index cfcb8d8744303..273b2c2209721 100644 --- a/controller/model.go +++ b/controller/model.go @@ -306,6 +306,15 @@ func init() { Root: "ERNIE-Bot-turbo", Parent: nil, }, + { + Id: "PaLM-2", + Object: "model", + Created: 1677649963, + OwnedBy: "google", + Permission: permission, + Root: "PaLM-2", + Parent: nil, + }, } openAIModelsMap = make(map[string]OpenAIModels) for _, model := range openAIModels { diff --git a/controller/relay-palm.go b/controller/relay-palm.go index ae739ca0c6398..d9a8249802572 100644 --- a/controller/relay-palm.go +++ b/controller/relay-palm.go @@ -1,10 +1,17 @@ package controller import ( + "encoding/json" "fmt" "github.com/gin-gonic/gin" + "io" + "net/http" + "one-api/common" ) +// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body +// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body + type PaLMChatMessage struct { Author string `json:"author"` Content string `json:"content"` @@ -15,45 +22,185 @@ type PaLMFilter struct { Message string `json:"message"` } -// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body +type PaLMPrompt struct { + Messages []PaLMChatMessage `json:"messages"` +} + type PaLMChatRequest struct { - Prompt []Message `json:"prompt"` - Temperature float64 `json:"temperature"` - CandidateCount int `json:"candidateCount"` - TopP float64 `json:"topP"` - TopK int `json:"topK"` + Prompt PaLMPrompt `json:"prompt"` + Temperature float64 `json:"temperature,omitempty"` + CandidateCount int `json:"candidateCount,omitempty"` + TopP float64 `json:"topP,omitempty"` + TopK int `json:"topK,omitempty"` +} + +type PaLMError struct { + Code int `json:"code"` + Message string `json:"message"` + Status string `json:"status"` } -// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body type PaLMChatResponse struct { - Candidates []Message `json:"candidates"` - Messages []Message `json:"messages"` - Filters []PaLMFilter `json:"filters"` + Candidates []PaLMChatMessage `json:"candidates"` + Messages []Message `json:"messages"` + Filters []PaLMFilter `json:"filters"` + Error PaLMError `json:"error"` } -func relayPaLM(openAIRequest GeneralOpenAIRequest, c *gin.Context) *OpenAIErrorWithStatusCode { - // https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage - messages := make([]PaLMChatMessage, 0, len(openAIRequest.Messages)) - for _, message := range openAIRequest.Messages { - var author string +func requestOpenAI2PaLM(textRequest GeneralOpenAIRequest) *PaLMChatRequest { + palmRequest := PaLMChatRequest{ + Prompt: PaLMPrompt{ + Messages: make([]PaLMChatMessage, 0, len(textRequest.Messages)), + }, + Temperature: textRequest.Temperature, + CandidateCount: textRequest.N, + TopP: textRequest.TopP, + TopK: textRequest.MaxTokens, + } + for _, message := range textRequest.Messages { + palmMessage := PaLMChatMessage{ + Content: message.Content, + } if message.Role == "user" { - author = "0" + palmMessage.Author = "0" } else { - author = "1" + palmMessage.Author = "1" } - messages = append(messages, PaLMChatMessage{ - Author: author, - Content: message.Content, - }) - } - request := PaLMChatRequest{ - Prompt: nil, - Temperature: openAIRequest.Temperature, - CandidateCount: openAIRequest.N, - TopP: openAIRequest.TopP, - TopK: openAIRequest.MaxTokens, - } - // TODO: forward request to PaLM & convert response - fmt.Print(request) - return nil + palmRequest.Prompt.Messages = append(palmRequest.Prompt.Messages, palmMessage) + } + return &palmRequest +} + +func responsePaLM2OpenAI(response *PaLMChatResponse) *OpenAITextResponse { + fullTextResponse := OpenAITextResponse{ + Choices: make([]OpenAITextResponseChoice, 0, len(response.Candidates)), + } + for i, candidate := range response.Candidates { + choice := OpenAITextResponseChoice{ + Index: i, + Message: Message{ + Role: "assistant", + Content: candidate.Content, + }, + FinishReason: "stop", + } + fullTextResponse.Choices = append(fullTextResponse.Choices, choice) + } + return &fullTextResponse +} + +func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *ChatCompletionsStreamResponse { + var choice ChatCompletionsStreamResponseChoice + if len(palmResponse.Candidates) > 0 { + choice.Delta.Content = palmResponse.Candidates[0].Content + } + choice.FinishReason = "stop" + var response ChatCompletionsStreamResponse + response.Object = "chat.completion.chunk" + response.Model = "palm2" + response.Choices = []ChatCompletionsStreamResponseChoice{choice} + return &response +} + +func palmStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) { + responseText := "" + responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) + createdTime := common.GetTimestamp() + dataChan := make(chan string) + stopChan := make(chan bool) + go func() { + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + common.SysError("error reading stream response: " + err.Error()) + stopChan <- true + return + } + err = resp.Body.Close() + if err != nil { + common.SysError("error closing stream response: " + err.Error()) + stopChan <- true + return + } + var palmResponse PaLMChatResponse + err = json.Unmarshal(responseBody, &palmResponse) + if err != nil { + common.SysError("error unmarshalling stream response: " + err.Error()) + stopChan <- true + return + } + fullTextResponse := streamResponsePaLM2OpenAI(&palmResponse) + fullTextResponse.Id = responseId + fullTextResponse.Created = createdTime + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + common.SysError("error marshalling stream response: " + err.Error()) + stopChan <- true + return + } + dataChan <- string(jsonResponse) + stopChan <- true + }() + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("Transfer-Encoding", "chunked") + c.Writer.Header().Set("X-Accel-Buffering", "no") + c.Stream(func(w io.Writer) bool { + select { + case data := <-dataChan: + c.Render(-1, common.CustomEvent{Data: "data: " + data}) + return true + case <-stopChan: + c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) + return false + } + }) + err := resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" + } + return nil, responseText +} + +func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) { + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + var palmResponse PaLMChatResponse + err = json.Unmarshal(responseBody, &palmResponse) + if err != nil { + return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 { + return &OpenAIErrorWithStatusCode{ + OpenAIError: OpenAIError{ + Message: palmResponse.Error.Message, + Type: palmResponse.Error.Status, + Param: "", + Code: palmResponse.Error.Code, + }, + StatusCode: resp.StatusCode, + }, nil + } + fullTextResponse := responsePaLM2OpenAI(&palmResponse) + completionTokens := countTokenText(palmResponse.Candidates[0].Content, model) + usage := Usage{ + PromptTokens: promptTokens, + CompletionTokens: completionTokens, + TotalTokens: promptTokens + completionTokens, + } + fullTextResponse.Usage = usage + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, &usage } diff --git a/controller/relay-text.go b/controller/relay-text.go index 0e7893a658d77..1fea959e25a3c 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -82,6 +82,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { apiType = APITypeClaude } else if strings.HasPrefix(textRequest.Model, "ERNIE") { apiType = APITypeBaidu + } else if strings.HasPrefix(textRequest.Model, "PaLM") { + apiType = APITypePaLM } baseURL := common.ChannelBaseURLs[channelType] requestURL := c.Request.URL.String() @@ -127,6 +129,11 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { apiKey := c.Request.Header.Get("Authorization") apiKey = strings.TrimPrefix(apiKey, "Bearer ") fullRequestURL += "?access_token=" + apiKey // TODO: access token expire in 30 days + case APITypePaLM: + fullRequestURL = "https://generativelanguage.googleapis.com/v1beta2/models/chat-bison-001:generateMessage" + apiKey := c.Request.Header.Get("Authorization") + apiKey = strings.TrimPrefix(apiKey, "Bearer ") + fullRequestURL += "?key=" + apiKey } var promptTokens int var completionTokens int @@ -186,6 +193,13 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) } requestBody = bytes.NewBuffer(jsonStr) + case APITypePaLM: + palmRequest := requestOpenAI2PaLM(textRequest) + jsonStr, err := json.Marshal(palmRequest) + if err != nil { + return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) + } + requestBody = bytes.NewBuffer(jsonStr) } req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) if err != nil { @@ -323,6 +337,22 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { textResponse.Usage = *usage return nil } + case APITypePaLM: + if textRequest.Stream { // PaLM2 API does not support stream + err, responseText := palmStreamHandler(c, resp) + if err != nil { + return err + } + streamResponseText = responseText + return nil + } else { + err, usage := palmHandler(c, resp, promptTokens, textRequest.Model) + if err != nil { + return err + } + textResponse.Usage = *usage + return nil + } default: return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError) } diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js index 28d20405021f4..2ff60fedd6e03 100644 --- a/web/src/constants/channel.constants.js +++ b/web/src/constants/channel.constants.js @@ -3,6 +3,7 @@ export const CHANNEL_OPTIONS = [ { key: 14, text: 'Anthropic', value: 14, color: 'black' }, { key: 8, text: '自定义', value: 8, color: 'pink' }, { key: 3, text: 'Azure', value: 3, color: 'olive' }, + { key: 11, text: 'PaLM', value: 11, color: 'orange' }, { key: 15, text: 'Baidu', value: 15, color: 'blue' }, { key: 2, text: 'API2D', value: 2, color: 'blue' }, { key: 4, text: 'CloseAI', value: 4, color: 'teal' },