diff --git a/docs/docs.go b/docs/docs.go index 65fe3ef2..2630f3e4 100644 --- a/docs/docs.go +++ b/docs/docs.go @@ -98,7 +98,7 @@ const docTemplate = `{ "in": "body", "required": true, "schema": { - "$ref": "#/definitions/schemas.UnifiedChatRequest" + "$ref": "#/definitions/schemas.ChatRequest" } } ], @@ -106,7 +106,7 @@ const docTemplate = `{ "200": { "description": "OK", "schema": { - "$ref": "#/definitions/schemas.UnifiedChatResponse" + "$ref": "#/definitions/schemas.ChatResponse" } }, "400": { @@ -676,49 +676,7 @@ const docTemplate = `{ } } }, - "schemas.OverrideChatRequest": { - "type": "object", - "properties": { - "message": { - "$ref": "#/definitions/schemas.ChatMessage" - }, - "model_id": { - "type": "string" - } - } - }, - "schemas.ProviderResponse": { - "type": "object", - "properties": { - "message": { - "$ref": "#/definitions/schemas.ChatMessage" - }, - "responseId": { - "type": "object", - "additionalProperties": { - "type": "string" - } - }, - "tokenCount": { - "$ref": "#/definitions/schemas.TokenUsage" - } - } - }, - "schemas.TokenUsage": { - "type": "object", - "properties": { - "promptTokens": { - "type": "number" - }, - "responseTokens": { - "type": "number" - }, - "totalTokens": { - "type": "number" - } - } - }, - "schemas.UnifiedChatRequest": { + "schemas.ChatRequest": { "type": "object", "properties": { "message": { @@ -735,7 +693,7 @@ const docTemplate = `{ } } }, - "schemas.UnifiedChatResponse": { + "schemas.ChatResponse": { "type": "object", "properties": { "cached": { @@ -763,6 +721,48 @@ const docTemplate = `{ "type": "string" } } + }, + "schemas.OverrideChatRequest": { + "type": "object", + "properties": { + "message": { + "$ref": "#/definitions/schemas.ChatMessage" + }, + "model_id": { + "type": "string" + } + } + }, + "schemas.ProviderResponse": { + "type": "object", + "properties": { + "message": { + "$ref": "#/definitions/schemas.ChatMessage" + }, + "responseId": { + "type": "object", + "additionalProperties": { + "type": "string" + } + }, + "tokenCount": { + "$ref": "#/definitions/schemas.TokenUsage" + } + } + }, + "schemas.TokenUsage": { + "type": "object", + "properties": { + "promptTokens": { + "type": "number" + }, + "responseTokens": { + "type": "number" + }, + "totalTokens": { + "type": "number" + } + } } }, "externalDocs": { diff --git a/docs/swagger.json b/docs/swagger.json index 9ee643b1..a90bce42 100644 --- a/docs/swagger.json +++ b/docs/swagger.json @@ -95,7 +95,7 @@ "in": "body", "required": true, "schema": { - "$ref": "#/definitions/schemas.UnifiedChatRequest" + "$ref": "#/definitions/schemas.ChatRequest" } } ], @@ -103,7 +103,7 @@ "200": { "description": "OK", "schema": { - "$ref": "#/definitions/schemas.UnifiedChatResponse" + "$ref": "#/definitions/schemas.ChatResponse" } }, "400": { @@ -673,49 +673,7 @@ } } }, - "schemas.OverrideChatRequest": { - "type": "object", - "properties": { - "message": { - "$ref": "#/definitions/schemas.ChatMessage" - }, - "model_id": { - "type": "string" - } - } - }, - "schemas.ProviderResponse": { - "type": "object", - "properties": { - "message": { - "$ref": "#/definitions/schemas.ChatMessage" - }, - "responseId": { - "type": "object", - "additionalProperties": { - "type": "string" - } - }, - "tokenCount": { - "$ref": "#/definitions/schemas.TokenUsage" - } - } - }, - "schemas.TokenUsage": { - "type": "object", - "properties": { - "promptTokens": { - "type": "number" - }, - "responseTokens": { - "type": "number" - }, - "totalTokens": { - "type": "number" - } - } - }, - "schemas.UnifiedChatRequest": { + "schemas.ChatRequest": { "type": "object", "properties": { "message": { @@ -732,7 +690,7 @@ } } }, - "schemas.UnifiedChatResponse": { + "schemas.ChatResponse": { "type": "object", "properties": { "cached": { @@ -760,6 +718,48 @@ "type": "string" } } + }, + "schemas.OverrideChatRequest": { + "type": "object", + "properties": { + "message": { + "$ref": "#/definitions/schemas.ChatMessage" + }, + "model_id": { + "type": "string" + } + } + }, + "schemas.ProviderResponse": { + "type": "object", + "properties": { + "message": { + "$ref": "#/definitions/schemas.ChatMessage" + }, + "responseId": { + "type": "object", + "additionalProperties": { + "type": "string" + } + }, + "tokenCount": { + "$ref": "#/definitions/schemas.TokenUsage" + } + } + }, + "schemas.TokenUsage": { + "type": "object", + "properties": { + "promptTokens": { + "type": "number" + }, + "responseTokens": { + "type": "number" + }, + "totalTokens": { + "type": "number" + } + } } }, "externalDocs": { diff --git a/docs/swagger.yaml b/docs/swagger.yaml index 1baf24a6..6ff19b63 100644 --- a/docs/swagger.yaml +++ b/docs/swagger.yaml @@ -375,34 +375,7 @@ definitions: or assistant. type: string type: object - schemas.OverrideChatRequest: - properties: - message: - $ref: '#/definitions/schemas.ChatMessage' - model_id: - type: string - type: object - schemas.ProviderResponse: - properties: - message: - $ref: '#/definitions/schemas.ChatMessage' - responseId: - additionalProperties: - type: string - type: object - tokenCount: - $ref: '#/definitions/schemas.TokenUsage' - type: object - schemas.TokenUsage: - properties: - promptTokens: - type: number - responseTokens: - type: number - totalTokens: - type: number - type: object - schemas.UnifiedChatRequest: + schemas.ChatRequest: properties: message: $ref: '#/definitions/schemas.ChatMessage' @@ -413,7 +386,7 @@ definitions: override: $ref: '#/definitions/schemas.OverrideChatRequest' type: object - schemas.UnifiedChatResponse: + schemas.ChatResponse: properties: cached: type: boolean @@ -432,6 +405,33 @@ definitions: router: type: string type: object + schemas.OverrideChatRequest: + properties: + message: + $ref: '#/definitions/schemas.ChatMessage' + model_id: + type: string + type: object + schemas.ProviderResponse: + properties: + message: + $ref: '#/definitions/schemas.ChatMessage' + responseId: + additionalProperties: + type: string + type: object + tokenCount: + $ref: '#/definitions/schemas.TokenUsage' + type: object + schemas.TokenUsage: + properties: + promptTokens: + type: number + responseTokens: + type: number + totalTokens: + type: number + type: object externalDocs: description: Documentation url: https://glide.einstack.ai/ @@ -497,14 +497,14 @@ paths: name: payload required: true schema: - $ref: '#/definitions/schemas.UnifiedChatRequest' + $ref: '#/definitions/schemas.ChatRequest' produces: - application/json responses: "200": description: OK schema: - $ref: '#/definitions/schemas.UnifiedChatResponse' + $ref: '#/definitions/schemas.ChatResponse' "400": description: Bad Request schema: diff --git a/pkg/api/http/handlers.go b/pkg/api/http/handlers.go index 611bdedb..c97f542b 100644 --- a/pkg/api/http/handlers.go +++ b/pkg/api/http/handlers.go @@ -21,17 +21,17 @@ type Handler = func(c *fiber.Ctx) error // @Description Talk to different LLMs Chat API via unified endpoint // @tags Language // @Param router path string true "Router ID" -// @Param payload body schemas.UnifiedChatRequest true "Request Data" +// @Param payload body schemas.ChatRequest true "Request Data" // @Accept json // @Produce json -// @Success 200 {object} schemas.UnifiedChatResponse +// @Success 200 {object} schemas.ChatResponse // @Failure 400 {object} http.ErrorSchema // @Failure 404 {object} http.ErrorSchema // @Router /v1/language/{router}/chat [POST] func LangChatHandler(routerManager *routers.RouterManager) Handler { return func(c *fiber.Ctx) error { // Unmarshal request body - var req *schemas.UnifiedChatRequest + var req *schemas.ChatRequest err := c.BodyParser(&req) if err != nil { diff --git a/pkg/api/schemas/language.go b/pkg/api/schemas/language.go index 8dc1b7d3..7e2a2cdc 100644 --- a/pkg/api/schemas/language.go +++ b/pkg/api/schemas/language.go @@ -1,7 +1,7 @@ package schemas -// UnifiedChatRequest defines Glide's Chat Request Schema unified across all language models -type UnifiedChatRequest struct { +// ChatRequest defines Glide's Chat Request Schema unified across all language models +type ChatRequest struct { Message ChatMessage `json:"message"` MessageHistory []ChatMessage `json:"messageHistory"` Override OverrideChatRequest `json:"override,omitempty"` @@ -12,8 +12,8 @@ type OverrideChatRequest struct { Message ChatMessage `json:"message"` } -func NewChatFromStr(message string) *UnifiedChatRequest { - return &UnifiedChatRequest{ +func NewChatFromStr(message string) *ChatRequest { + return &ChatRequest{ Message: ChatMessage{ "human", message, @@ -22,8 +22,8 @@ func NewChatFromStr(message string) *UnifiedChatRequest { } } -// UnifiedChatResponse defines Glide's Chat Response Schema unified across all language models -type UnifiedChatResponse struct { +// ChatResponse defines Glide's Chat Response Schema unified across all language models +type ChatResponse struct { ID string `json:"id,omitempty"` Created int `json:"created,omitempty"` Provider string `json:"provider,omitempty"` @@ -58,120 +58,3 @@ type ChatMessage struct { // with a maximum length of 64 characters. Name string `json:"name,omitempty"` } - -// OpenAI Chat Response (also used by Azure OpenAI and OctoML) -// TODO: Should this live here? -type OpenAIChatCompletion struct { - ID string `json:"id"` - Object string `json:"object"` - Created int `json:"created"` - Model string `json:"model"` - SystemFingerprint string `json:"system_fingerprint"` - Choices []Choice `json:"choices"` - Usage Usage `json:"usage"` -} - -type Choice struct { - Index int `json:"index"` - Message ChatMessage `json:"message"` - Logprobs interface{} `json:"logprobs"` - FinishReason string `json:"finish_reason"` -} - -type Usage struct { - PromptTokens float64 `json:"prompt_tokens"` - CompletionTokens float64 `json:"completion_tokens"` - TotalTokens float64 `json:"total_tokens"` -} - -// Cohere Chat Response -type CohereChatCompletion struct { - Text string `json:"text"` - GenerationID string `json:"generation_id"` - ResponseID string `json:"response_id"` - TokenCount CohereTokenCount `json:"token_count"` - Citations []Citation `json:"citations"` - Documents []Documents `json:"documents"` - SearchQueries []SearchQuery `json:"search_queries"` - SearchResults []SearchResults `json:"search_results"` - Meta Meta `json:"meta"` - ToolInputs map[string]interface{} `json:"tool_inputs"` -} - -type CohereTokenCount struct { - PromptTokens float64 `json:"prompt_tokens"` - ResponseTokens float64 `json:"response_tokens"` - TotalTokens float64 `json:"total_tokens"` - BilledTokens float64 `json:"billed_tokens"` -} - -type Meta struct { - APIVersion struct { - Version string `json:"version"` - } `json:"api_version"` - BilledUnits struct { - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` - } `json:"billed_units"` -} - -type Citation struct { - Start int `json:"start"` - End int `json:"end"` - Text string `json:"text"` - DocumentID []string `json:"document_id"` -} - -type Documents struct { - ID string `json:"id"` - Data map[string]string `json:"data"` // TODO: This needs to be updated -} - -type SearchQuery struct { - Text string `json:"text"` - GenerationID string `json:"generation_id"` -} - -type SearchResults struct { - SearchQuery []SearchQueryObject `json:"search_query"` - Connectors []ConnectorsResponse `json:"connectors"` - DocumentID []string `json:"documentId"` -} - -type SearchQueryObject struct { - Text string `json:"text"` - GenerationID string `json:"generationId"` -} - -type ConnectorsResponse struct { - ID string `json:"id"` - UserAccessToken string `json:"user_access_token"` - ContOnFail string `json:"continue_on_failure"` - Options map[string]string `json:"options"` -} - -// Anthropic Chat Response -type AnthropicChatCompletion struct { - ID string `json:"id"` - Type string `json:"type"` - Model string `json:"model"` - Role string `json:"role"` - Content []Content `json:"content"` - StopReason string `json:"stop_reason"` - StopSequence string `json:"stop_sequence"` -} - -type Content struct { - Type string `json:"type"` - Text string `json:"text"` -} - -// Bedrock Chat Response -type BedrockChatCompletion struct { - InputTextTokenCount int `json:"inputTextTokenCount"` - Results []struct { - TokenCount int `json:"tokenCount"` - OutputText string `json:"outputText"` - CompletionReason string `json:"completionReason"` - } `json:"results"` -} diff --git a/pkg/providers/anthropic/chat.go b/pkg/providers/anthropic/chat.go index b525bcb9..5a8d8ee3 100644 --- a/pkg/providers/anthropic/chat.go +++ b/pkg/providers/anthropic/chat.go @@ -49,7 +49,7 @@ func NewChatRequestFromConfig(cfg *Config) *ChatRequest { } } -func NewChatMessagesFromUnifiedRequest(request *schemas.UnifiedChatRequest) []ChatMessage { +func NewChatMessagesFromUnifiedRequest(request *schemas.ChatRequest) []ChatMessage { messages := make([]ChatMessage, 0, len(request.MessageHistory)+1) // Add items from messageHistory first and the new chat message last @@ -63,7 +63,7 @@ func NewChatMessagesFromUnifiedRequest(request *schemas.UnifiedChatRequest) []Ch } // Chat sends a chat request to the specified anthropic model. -func (c *Client) Chat(ctx context.Context, request *schemas.UnifiedChatRequest) (*schemas.UnifiedChatResponse, error) { +func (c *Client) Chat(ctx context.Context, request *schemas.ChatRequest) (*schemas.ChatResponse, error) { // Create a new chat request chatRequest := c.createChatRequestSchema(request) @@ -79,7 +79,7 @@ func (c *Client) Chat(ctx context.Context, request *schemas.UnifiedChatRequest) return chatResponse, nil } -func (c *Client) createChatRequestSchema(request *schemas.UnifiedChatRequest) *ChatRequest { +func (c *Client) createChatRequestSchema(request *schemas.ChatRequest) *ChatRequest { // TODO: consider using objectpool to optimize memory allocation chatRequest := c.chatRequestTemplate // hoping to get a copy of the template chatRequest.Messages = NewChatMessagesFromUnifiedRequest(request) @@ -87,7 +87,7 @@ func (c *Client) createChatRequestSchema(request *schemas.UnifiedChatRequest) *C return chatRequest } -func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schemas.UnifiedChatResponse, error) { +func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schemas.ChatResponse, error) { // Build request payload rawPayload, err := json.Marshal(payload) if err != nil { @@ -154,7 +154,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche } // Parse the response JSON - var anthropicCompletion schemas.AnthropicChatCompletion + var anthropicCompletion ChatCompletion err = json.Unmarshal(bodyBytes, &anthropicCompletion) if err != nil { @@ -162,8 +162,8 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche return nil, err } - // Map response to UnifiedChatResponse schema - response := schemas.UnifiedChatResponse{ + // Map response to ChatResponse schema + response := schemas.ChatResponse{ ID: anthropicCompletion.ID, Created: int(time.Now().UTC().Unix()), // not provided by anthropic Provider: providerName, diff --git a/pkg/providers/anthropic/client_test.go b/pkg/providers/anthropic/client_test.go index c8927a37..9c301365 100644 --- a/pkg/providers/anthropic/client_test.go +++ b/pkg/providers/anthropic/client_test.go @@ -56,7 +56,7 @@ func TestAnthropicClient_ChatRequest(t *testing.T) { client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) - request := schemas.UnifiedChatRequest{Message: schemas.ChatMessage{ + request := schemas.ChatRequest{Message: schemas.ChatMessage{ Role: "human", Content: "What's the biggest animal?", }} @@ -86,7 +86,7 @@ func TestAnthropicClient_BadChatRequest(t *testing.T) { client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) - request := schemas.UnifiedChatRequest{Message: schemas.ChatMessage{ + request := schemas.ChatRequest{Message: schemas.ChatMessage{ Role: "human", Content: "What's the biggest animal?", }} diff --git a/pkg/providers/anthropic/schamas.go b/pkg/providers/anthropic/schamas.go new file mode 100644 index 00000000..69b00248 --- /dev/null +++ b/pkg/providers/anthropic/schamas.go @@ -0,0 +1,17 @@ +package anthropic + +// Anthropic Chat Response +type ChatCompletion struct { + ID string `json:"id"` + Type string `json:"type"` + Model string `json:"model"` + Role string `json:"role"` + Content []Content `json:"content"` + StopReason string `json:"stop_reason"` + StopSequence string `json:"stop_sequence"` +} + +type Content struct { + Type string `json:"type"` + Text string `json:"text"` +} diff --git a/pkg/providers/azureopenai/chat.go b/pkg/providers/azureopenai/chat.go index 6fda0305..f961587c 100644 --- a/pkg/providers/azureopenai/chat.go +++ b/pkg/providers/azureopenai/chat.go @@ -9,9 +9,11 @@ import ( "net/http" "time" + "glide/pkg/api/schemas" + "glide/pkg/providers/openai" + "glide/pkg/providers/clients" - "glide/pkg/api/schemas" "go.uber.org/zap" ) @@ -59,7 +61,7 @@ func NewChatRequestFromConfig(cfg *Config) *ChatRequest { } } -func NewChatMessagesFromUnifiedRequest(request *schemas.UnifiedChatRequest) []ChatMessage { +func NewChatMessagesFromUnifiedRequest(request *schemas.ChatRequest) []ChatMessage { messages := make([]ChatMessage, 0, len(request.MessageHistory)+1) // Add items from messageHistory first and the new chat message last @@ -73,7 +75,7 @@ func NewChatMessagesFromUnifiedRequest(request *schemas.UnifiedChatRequest) []Ch } // Chat sends a chat request to the specified azure openai model. -func (c *Client) Chat(ctx context.Context, request *schemas.UnifiedChatRequest) (*schemas.UnifiedChatResponse, error) { +func (c *Client) Chat(ctx context.Context, request *schemas.ChatRequest) (*schemas.ChatResponse, error) { // Create a new chat request chatRequest := c.createChatRequestSchema(request) @@ -89,7 +91,7 @@ func (c *Client) Chat(ctx context.Context, request *schemas.UnifiedChatRequest) return chatResponse, nil } -func (c *Client) createChatRequestSchema(request *schemas.UnifiedChatRequest) *ChatRequest { +func (c *Client) createChatRequestSchema(request *schemas.ChatRequest) *ChatRequest { // TODO: consider using objectpool to optimize memory allocation chatRequest := c.chatRequestTemplate // hoping to get a copy of the template chatRequest.Messages = NewChatMessagesFromUnifiedRequest(request) @@ -97,7 +99,7 @@ func (c *Client) createChatRequestSchema(request *schemas.UnifiedChatRequest) *C return chatRequest } -func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schemas.UnifiedChatResponse, error) { +func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schemas.ChatResponse, error) { // Build request payload rawPayload, err := json.Marshal(payload) if err != nil { @@ -164,7 +166,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche } // Parse the response JSON - var openAICompletion schemas.OpenAIChatCompletion + var openAICompletion openai.ChatCompletion err = json.Unmarshal(bodyBytes, &openAICompletion) if err != nil { @@ -175,7 +177,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche openAICompletion.SystemFingerprint = "" // Azure OpenAI doesn't return this // Map response to UnifiedChatResponse schema - response := schemas.UnifiedChatResponse{ + response := schemas.ChatResponse{ ID: openAICompletion.ID, Created: openAICompletion.Created, Provider: providerName, diff --git a/pkg/providers/azureopenai/client_test.go b/pkg/providers/azureopenai/client_test.go index 62080029..8f5de037 100644 --- a/pkg/providers/azureopenai/client_test.go +++ b/pkg/providers/azureopenai/client_test.go @@ -55,7 +55,7 @@ func TestAzureOpenAIClient_ChatRequest(t *testing.T) { client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) - request := schemas.UnifiedChatRequest{Message: schemas.ChatMessage{ + request := schemas.ChatRequest{Message: schemas.ChatMessage{ Role: "user", Content: "What's the biggest animal?", }} @@ -88,7 +88,7 @@ func TestAzureOpenAIClient_ChatError(t *testing.T) { client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) - request := schemas.UnifiedChatRequest{Message: schemas.ChatMessage{ + request := schemas.ChatRequest{Message: schemas.ChatMessage{ Role: "user", Content: "What's the biggest animal?", }} diff --git a/pkg/providers/bedrock/chat.go b/pkg/providers/bedrock/chat.go index 41eb604f..14feb9bc 100644 --- a/pkg/providers/bedrock/chat.go +++ b/pkg/providers/bedrock/chat.go @@ -46,7 +46,7 @@ func NewChatRequestFromConfig(cfg *Config) *ChatRequest { } } -func NewChatMessagesFromUnifiedRequest(request *schemas.UnifiedChatRequest) string { +func NewChatMessagesFromUnifiedRequest(request *schemas.ChatRequest) string { // message history not yet supported for AWS models message := fmt.Sprintf("Role: %s, Content: %s", request.Message.Role, request.Message.Content) @@ -54,7 +54,7 @@ func NewChatMessagesFromUnifiedRequest(request *schemas.UnifiedChatRequest) stri } // Chat sends a chat request to the specified bedrock model. -func (c *Client) Chat(ctx context.Context, request *schemas.UnifiedChatRequest) (*schemas.UnifiedChatResponse, error) { +func (c *Client) Chat(ctx context.Context, request *schemas.ChatRequest) (*schemas.ChatResponse, error) { // Create a new chat request chatRequest := c.createChatRequestSchema(request) @@ -70,7 +70,7 @@ func (c *Client) Chat(ctx context.Context, request *schemas.UnifiedChatRequest) return chatResponse, nil } -func (c *Client) createChatRequestSchema(request *schemas.UnifiedChatRequest) *ChatRequest { +func (c *Client) createChatRequestSchema(request *schemas.ChatRequest) *ChatRequest { // TODO: consider using objectpool to optimize memory allocation chatRequest := c.chatRequestTemplate // hoping to get a copy of the template chatRequest.Messages = NewChatMessagesFromUnifiedRequest(request) @@ -78,7 +78,7 @@ func (c *Client) createChatRequestSchema(request *schemas.UnifiedChatRequest) *C return chatRequest } -func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schemas.UnifiedChatResponse, error) { +func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schemas.ChatResponse, error) { rawPayload, err := json.Marshal(payload) if err != nil { return nil, fmt.Errorf("unable to marshal chat request payload: %w", err) @@ -94,7 +94,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche return nil, err } - var bedrockCompletion schemas.BedrockChatCompletion + var bedrockCompletion ChatCompletion err = json.Unmarshal(result.Body, &bedrockCompletion) if err != nil { @@ -102,7 +102,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche return nil, err } - response := schemas.UnifiedChatResponse{ + response := schemas.ChatResponse{ ID: uuid.NewString(), Created: int(time.Now().Unix()), Provider: "aws-bedrock", diff --git a/pkg/providers/bedrock/client_test.go b/pkg/providers/bedrock/client_test.go index f261841b..bcbd0fa1 100644 --- a/pkg/providers/bedrock/client_test.go +++ b/pkg/providers/bedrock/client_test.go @@ -61,7 +61,7 @@ func TestBedrockClient_ChatRequest(t *testing.T) { client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) - request := schemas.UnifiedChatRequest{Message: schemas.ChatMessage{ + request := schemas.ChatRequest{Message: schemas.ChatMessage{ Role: "user", Content: "What's the biggest animal?", }} diff --git a/pkg/providers/bedrock/schemas.go b/pkg/providers/bedrock/schemas.go new file mode 100644 index 00000000..ac03de8e --- /dev/null +++ b/pkg/providers/bedrock/schemas.go @@ -0,0 +1,11 @@ +package bedrock + +// Bedrock Chat Response +type ChatCompletion struct { + InputTextTokenCount int `json:"inputTextTokenCount"` + Results []struct { + TokenCount int `json:"tokenCount"` + OutputText string `json:"outputText"` + CompletionReason string `json:"completionReason"` + } `json:"results"` +} diff --git a/pkg/providers/cohere/chat.go b/pkg/providers/cohere/chat.go index 28712887..165b67bd 100644 --- a/pkg/providers/cohere/chat.go +++ b/pkg/providers/cohere/chat.go @@ -65,7 +65,7 @@ func NewChatRequestFromConfig(cfg *Config) *ChatRequest { } // Chat sends a chat request to the specified cohere model. -func (c *Client) Chat(ctx context.Context, request *schemas.UnifiedChatRequest) (*schemas.UnifiedChatResponse, error) { +func (c *Client) Chat(ctx context.Context, request *schemas.ChatRequest) (*schemas.ChatResponse, error) { // Create a new chat request chatRequest := c.createChatRequestSchema(request) @@ -81,7 +81,7 @@ func (c *Client) Chat(ctx context.Context, request *schemas.UnifiedChatRequest) return chatResponse, nil } -func (c *Client) createChatRequestSchema(request *schemas.UnifiedChatRequest) *ChatRequest { +func (c *Client) createChatRequestSchema(request *schemas.ChatRequest) *ChatRequest { // TODO: consider using objectpool to optimize memory allocation chatRequest := c.chatRequestTemplate // hoping to get a copy of the template chatRequest.Message = request.Message.Content @@ -103,7 +103,7 @@ func (c *Client) createChatRequestSchema(request *schemas.UnifiedChatRequest) *C return chatRequest } -func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schemas.UnifiedChatResponse, error) { +func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schemas.ChatResponse, error) { // Build request payload rawPayload, err := json.Marshal(payload) if err != nil { @@ -170,7 +170,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche } // Parse the response JSON - var cohereCompletion schemas.CohereChatCompletion + var cohereCompletion ChatCompletion err = json.Unmarshal(bodyBytes, &cohereCompletion) if err != nil { @@ -178,8 +178,8 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche return nil, err } - // Map response to UnifiedChatResponse schema - response := schemas.UnifiedChatResponse{ + // Map response to ChatResponse schema + response := schemas.ChatResponse{ ID: cohereCompletion.ResponseID, Created: int(time.Now().UTC().Unix()), // Cohere doesn't provide this Provider: providerName, @@ -206,7 +206,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche return &response, nil } -func (c *Client) handleErrorResponse(resp *http.Response) (*schemas.UnifiedChatResponse, error) { +func (c *Client) handleErrorResponse(resp *http.Response) (*schemas.ChatResponse, error) { bodyBytes, err := io.ReadAll(resp.Body) if err != nil { c.telemetry.Logger.Error("failed to read cohere chat response", zap.Error(err)) diff --git a/pkg/providers/cohere/client_test.go b/pkg/providers/cohere/client_test.go index 7828aa37..439e44d6 100644 --- a/pkg/providers/cohere/client_test.go +++ b/pkg/providers/cohere/client_test.go @@ -55,7 +55,7 @@ func TestCohereClient_ChatRequest(t *testing.T) { client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) - request := schemas.UnifiedChatRequest{Message: schemas.ChatMessage{ + request := schemas.ChatRequest{Message: schemas.ChatMessage{ Role: "human", Content: "What's the biggest animal?", }} diff --git a/pkg/providers/cohere/schemas.go b/pkg/providers/cohere/schemas.go new file mode 100644 index 00000000..c807aa56 --- /dev/null +++ b/pkg/providers/cohere/schemas.go @@ -0,0 +1,67 @@ +package cohere + +// Cohere Chat Response +type ChatCompletion struct { + Text string `json:"text"` + GenerationID string `json:"generation_id"` + ResponseID string `json:"response_id"` + TokenCount TokenCount `json:"token_count"` + Citations []Citation `json:"citations"` + Documents []Documents `json:"documents"` + SearchQueries []SearchQuery `json:"search_queries"` + SearchResults []SearchResults `json:"search_results"` + Meta Meta `json:"meta"` + ToolInputs map[string]interface{} `json:"tool_inputs"` +} + +type TokenCount struct { + PromptTokens float64 `json:"prompt_tokens"` + ResponseTokens float64 `json:"response_tokens"` + TotalTokens float64 `json:"total_tokens"` + BilledTokens float64 `json:"billed_tokens"` +} + +type Meta struct { + APIVersion struct { + Version string `json:"version"` + } `json:"api_version"` + BilledUnits struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + } `json:"billed_units"` +} + +type Citation struct { + Start int `json:"start"` + End int `json:"end"` + Text string `json:"text"` + DocumentID []string `json:"document_id"` +} + +type Documents struct { + ID string `json:"id"` + Data map[string]string `json:"data"` // TODO: This needs to be updated +} + +type SearchQuery struct { + Text string `json:"text"` + GenerationID string `json:"generation_id"` +} + +type SearchResults struct { + SearchQuery []SearchQueryObject `json:"search_query"` + Connectors []ConnectorsResponse `json:"connectors"` + DocumentID []string `json:"documentId"` +} + +type SearchQueryObject struct { + Text string `json:"text"` + GenerationID string `json:"generationId"` +} + +type ConnectorsResponse struct { + ID string `json:"id"` + UserAccessToken string `json:"user_access_token"` + ContOnFail string `json:"continue_on_failure"` + Options map[string]string `json:"options"` +} diff --git a/pkg/providers/octoml/chat.go b/pkg/providers/octoml/chat.go index 29ca6b7d..4860a0b9 100644 --- a/pkg/providers/octoml/chat.go +++ b/pkg/providers/octoml/chat.go @@ -9,6 +9,8 @@ import ( "net/http" "time" + "glide/pkg/providers/openai" + "glide/pkg/providers/clients" "glide/pkg/api/schemas" @@ -47,7 +49,7 @@ func NewChatRequestFromConfig(cfg *Config) *ChatRequest { } } -func NewChatMessagesFromUnifiedRequest(request *schemas.UnifiedChatRequest) []ChatMessage { +func NewChatMessagesFromUnifiedRequest(request *schemas.ChatRequest) []ChatMessage { messages := make([]ChatMessage, 0, len(request.MessageHistory)+1) // Add items from messageHistory first and the new chat message last @@ -61,7 +63,7 @@ func NewChatMessagesFromUnifiedRequest(request *schemas.UnifiedChatRequest) []Ch } // Chat sends a chat request to the specified octoml model. -func (c *Client) Chat(ctx context.Context, request *schemas.UnifiedChatRequest) (*schemas.UnifiedChatResponse, error) { +func (c *Client) Chat(ctx context.Context, request *schemas.ChatRequest) (*schemas.ChatResponse, error) { // Create a new chat request chatRequest := c.createChatRequestSchema(request) @@ -77,7 +79,7 @@ func (c *Client) Chat(ctx context.Context, request *schemas.UnifiedChatRequest) return chatResponse, nil } -func (c *Client) createChatRequestSchema(request *schemas.UnifiedChatRequest) *ChatRequest { +func (c *Client) createChatRequestSchema(request *schemas.ChatRequest) *ChatRequest { // TODO: consider using objectpool to optimize memory allocation chatRequest := c.chatRequestTemplate // hoping to get a copy of the template chatRequest.Messages = NewChatMessagesFromUnifiedRequest(request) @@ -85,7 +87,7 @@ func (c *Client) createChatRequestSchema(request *schemas.UnifiedChatRequest) *C return chatRequest } -func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schemas.UnifiedChatResponse, error) { +func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schemas.ChatResponse, error) { // Build request payload rawPayload, err := json.Marshal(payload) if err != nil { @@ -152,7 +154,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche } // Parse the response JSON - var openAICompletion schemas.OpenAIChatCompletion // Octo uses the same response schema as OpenAI + var openAICompletion openai.ChatCompletion // Octo uses the same response schema as OpenAI err = json.Unmarshal(bodyBytes, &openAICompletion) if err != nil { @@ -161,7 +163,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche } // Map response to UnifiedChatResponse schema - response := schemas.UnifiedChatResponse{ + response := schemas.ChatResponse{ ID: openAICompletion.ID, Created: openAICompletion.Created, Provider: providerName, diff --git a/pkg/providers/octoml/client_test.go b/pkg/providers/octoml/client_test.go index 1c5c7e63..c8a438c1 100644 --- a/pkg/providers/octoml/client_test.go +++ b/pkg/providers/octoml/client_test.go @@ -55,7 +55,7 @@ func TestOctoMLClient_ChatRequest(t *testing.T) { client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) - request := schemas.UnifiedChatRequest{Message: schemas.ChatMessage{ + request := schemas.ChatRequest{Message: schemas.ChatMessage{ Role: "human", Content: "What's the biggest animal?", }} @@ -88,7 +88,7 @@ func TestOctoMLClient_Chat_Error(t *testing.T) { require.NoError(t, err) // Create a chat request - request := schemas.UnifiedChatRequest{ + request := schemas.ChatRequest{ Message: schemas.ChatMessage{ Role: "human", Content: "What's the biggest animal?", diff --git a/pkg/providers/openai/chat.go b/pkg/providers/openai/chat.go index c296c080..bbcc4ff4 100644 --- a/pkg/providers/openai/chat.go +++ b/pkg/providers/openai/chat.go @@ -61,7 +61,7 @@ func NewChatRequestFromConfig(cfg *Config) *ChatRequest { } } -func NewChatMessagesFromUnifiedRequest(request *schemas.UnifiedChatRequest) []ChatMessage { +func NewChatMessagesFromUnifiedRequest(request *schemas.ChatRequest) []ChatMessage { messages := make([]ChatMessage, 0, len(request.MessageHistory)+1) // Add items from messageHistory first and the new chat message last @@ -75,7 +75,7 @@ func NewChatMessagesFromUnifiedRequest(request *schemas.UnifiedChatRequest) []Ch } // Chat sends a chat request to the specified OpenAI model. -func (c *Client) Chat(ctx context.Context, request *schemas.UnifiedChatRequest) (*schemas.UnifiedChatResponse, error) { +func (c *Client) Chat(ctx context.Context, request *schemas.ChatRequest) (*schemas.ChatResponse, error) { // Create a new chat request chatRequest := c.createChatRequestSchema(request) @@ -91,7 +91,7 @@ func (c *Client) Chat(ctx context.Context, request *schemas.UnifiedChatRequest) return chatResponse, nil } -func (c *Client) createChatRequestSchema(request *schemas.UnifiedChatRequest) *ChatRequest { +func (c *Client) createChatRequestSchema(request *schemas.ChatRequest) *ChatRequest { // TODO: consider using objectpool to optimize memory allocation chatRequest := c.chatRequestTemplate // hoping to get a copy of the template chatRequest.Messages = NewChatMessagesFromUnifiedRequest(request) @@ -99,7 +99,7 @@ func (c *Client) createChatRequestSchema(request *schemas.UnifiedChatRequest) *C return chatRequest } -func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schemas.UnifiedChatResponse, error) { +func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schemas.ChatResponse, error) { // Build request payload rawPayload, err := json.Marshal(payload) if err != nil { @@ -166,7 +166,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche } // Parse the response JSON - var openAICompletion schemas.OpenAIChatCompletion + var openAICompletion ChatCompletion err = json.Unmarshal(bodyBytes, &openAICompletion) if err != nil { @@ -174,8 +174,8 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche return nil, err } - // Map response to UnifiedChatResponse schema - response := schemas.UnifiedChatResponse{ + // Map response to ChatResponse schema + response := schemas.ChatResponse{ ID: openAICompletion.ID, Created: openAICompletion.Created, Provider: providerName, diff --git a/pkg/providers/openai/client_test.go b/pkg/providers/openai/client_test.go index db080ce4..6bd8298d 100644 --- a/pkg/providers/openai/client_test.go +++ b/pkg/providers/openai/client_test.go @@ -56,7 +56,7 @@ func TestOpenAIClient_ChatRequest(t *testing.T) { client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) - request := schemas.UnifiedChatRequest{Message: schemas.ChatMessage{ + request := schemas.ChatRequest{Message: schemas.ChatMessage{ Role: "user", Content: "What's the biggest animal?", }} diff --git a/pkg/providers/openai/schemas.go b/pkg/providers/openai/schemas.go new file mode 100644 index 00000000..cf41aebf --- /dev/null +++ b/pkg/providers/openai/schemas.go @@ -0,0 +1,26 @@ +package openai + +// OpenAI Chat Response (also used by Azure OpenAI and OctoML) + +type ChatCompletion struct { + ID string `json:"id"` + Object string `json:"object"` + Created int `json:"created"` + Model string `json:"model"` + SystemFingerprint string `json:"system_fingerprint"` + Choices []Choice `json:"choices"` + Usage Usage `json:"usage"` +} + +type Choice struct { + Index int `json:"index"` + Message ChatMessage `json:"message"` + Logprobs interface{} `json:"logprobs"` + FinishReason string `json:"finish_reason"` +} + +type Usage struct { + PromptTokens float64 `json:"prompt_tokens"` + CompletionTokens float64 `json:"completion_tokens"` + TotalTokens float64 `json:"total_tokens"` +} diff --git a/pkg/providers/provider.go b/pkg/providers/provider.go index 4a3774b2..399d6ee7 100644 --- a/pkg/providers/provider.go +++ b/pkg/providers/provider.go @@ -15,7 +15,7 @@ import ( // LangModelProvider defines an interface a provider should fulfill to be able to serve language chat requests type LangModelProvider interface { Provider() string - Chat(ctx context.Context, request *schemas.UnifiedChatRequest) (*schemas.UnifiedChatResponse, error) + Chat(ctx context.Context, request *schemas.ChatRequest) (*schemas.ChatResponse, error) } type Model interface { @@ -78,7 +78,7 @@ func (m *LangModel) Weight() int { return m.weight } -func (m *LangModel) Chat(ctx context.Context, request *schemas.UnifiedChatRequest) (*schemas.UnifiedChatResponse, error) { +func (m *LangModel) Chat(ctx context.Context, request *schemas.ChatRequest) (*schemas.ChatResponse, error) { startedAt := time.Now() resp, err := m.client.Chat(ctx, request) diff --git a/pkg/providers/testing.go b/pkg/providers/testing.go index f408380c..890421a0 100644 --- a/pkg/providers/testing.go +++ b/pkg/providers/testing.go @@ -14,8 +14,8 @@ type ResponseMock struct { Err *error } -func (m *ResponseMock) Resp() *schemas.UnifiedChatResponse { - return &schemas.UnifiedChatResponse{ +func (m *ResponseMock) Resp() *schemas.ChatResponse { + return &schemas.ChatResponse{ ID: "rsp0001", ModelResponse: schemas.ProviderResponse{ SystemID: map[string]string{ @@ -40,7 +40,7 @@ func NewProviderMock(responses []ResponseMock) *ProviderMock { } } -func (c *ProviderMock) Chat(_ context.Context, _ *schemas.UnifiedChatRequest) (*schemas.UnifiedChatResponse, error) { +func (c *ProviderMock) Chat(_ context.Context, _ *schemas.ChatRequest) (*schemas.ChatResponse, error) { response := c.responses[c.idx] c.idx++ diff --git a/pkg/routers/router.go b/pkg/routers/router.go index c2149c7a..13d89fa3 100644 --- a/pkg/routers/router.go +++ b/pkg/routers/router.go @@ -55,7 +55,7 @@ func (r *LangRouter) ID() string { return r.routerID } -func (r *LangRouter) Chat(ctx context.Context, request *schemas.UnifiedChatRequest) (*schemas.UnifiedChatResponse, error) { +func (r *LangRouter) Chat(ctx context.Context, request *schemas.ChatRequest) (*schemas.ChatResponse, error) { if len(r.models) == 0 { return nil, ErrNoModels }