-
Notifications
You must be signed in to change notification settings - Fork 20
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
ff3dc22
commit eb69f6b
Showing
6 changed files
with
382 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,186 @@ | ||
package octoml | ||
|
||
import ( | ||
"bytes" | ||
"context" | ||
"encoding/json" | ||
"fmt" | ||
"io" | ||
"net/http" | ||
|
||
"glide/pkg/providers/errs" | ||
|
||
"glide/pkg/api/schemas" | ||
"go.uber.org/zap" | ||
) | ||
|
||
type ChatMessage struct { | ||
Role string `json:"role"` | ||
Content string `json:"content"` | ||
} | ||
|
||
// ChatRequest is an octoml-specific request schema | ||
type ChatRequest struct { | ||
Model string `json:"model"` | ||
Messages []ChatMessage `json:"messages"` | ||
Temperature float64 `json:"temperature,omitempty"` | ||
TopP float64 `json:"top_p,omitempty"` | ||
MaxTokens int `json:"max_tokens,omitempty"` | ||
StopWords []string `json:"stop,omitempty"` | ||
Stream bool `json:"stream,omitempty"` | ||
FrequencyPenalty int `json:"frequency_penalty,omitempty"` | ||
PresencePenalty int `json:"presence_penalty,omitempty"` | ||
} | ||
|
||
// NewChatRequestFromConfig fills the struct from the config. Not using reflection because of performance penalty it gives | ||
func NewChatRequestFromConfig(cfg *Config) *ChatRequest { | ||
return &ChatRequest{ | ||
Model: cfg.Model, | ||
Temperature: cfg.DefaultParams.Temperature, | ||
TopP: cfg.DefaultParams.TopP, | ||
MaxTokens: cfg.DefaultParams.MaxTokens, | ||
StopWords: cfg.DefaultParams.StopWords, | ||
Stream: false, // unsupported right now | ||
FrequencyPenalty: cfg.DefaultParams.FrequencyPenalty, | ||
PresencePenalty: cfg.DefaultParams.PresencePenalty, | ||
} | ||
} | ||
|
||
func NewChatMessagesFromUnifiedRequest(request *schemas.UnifiedChatRequest) []ChatMessage { | ||
messages := make([]ChatMessage, 0, len(request.MessageHistory)+1) | ||
|
||
// Add items from messageHistory first and the new chat message last | ||
for _, message := range request.MessageHistory { | ||
messages = append(messages, ChatMessage{Role: message.Role, Content: message.Content}) | ||
} | ||
|
||
messages = append(messages, ChatMessage{Role: request.Message.Role, Content: request.Message.Content}) | ||
|
||
return messages | ||
} | ||
|
||
// Chat sends a chat request to the specified octoml model. | ||
func (c *Client) Chat(ctx context.Context, request *schemas.UnifiedChatRequest) (*schemas.UnifiedChatResponse, error) { | ||
// Create a new chat request | ||
chatRequest := c.createChatRequestSchema(request) | ||
|
||
chatResponse, err := c.doChatRequest(ctx, chatRequest) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
if len(chatResponse.ProviderResponse.Message.Content) == 0 { | ||
return nil, ErrEmptyResponse | ||
} | ||
|
||
return chatResponse, nil | ||
} | ||
|
||
func (c *Client) createChatRequestSchema(request *schemas.UnifiedChatRequest) *ChatRequest { | ||
// TODO: consider using objectpool to optimize memory allocation | ||
chatRequest := c.chatRequestTemplate // hoping to get a copy of the template | ||
chatRequest.Messages = NewChatMessagesFromUnifiedRequest(request) | ||
|
||
return chatRequest | ||
} | ||
|
||
func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schemas.UnifiedChatResponse, error) { | ||
// Build request payload | ||
rawPayload, err := json.Marshal(payload) | ||
if err != nil { | ||
return nil, fmt.Errorf("unable to marshal octoml chat request payload: %w", err) | ||
} | ||
|
||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.chatURL, bytes.NewBuffer(rawPayload)) | ||
if err != nil { | ||
return nil, fmt.Errorf("unable to create octoml chat request: %w", err) | ||
} | ||
|
||
req.Header.Set("Authorization", "Bearer "+string(c.config.APIKey)) | ||
req.Header.Set("Content-Type", "application/json") | ||
|
||
// TODO: this could leak information from messages which may not be a desired thing to have | ||
c.telemetry.Logger.Debug( | ||
"octoml chat request", | ||
zap.String("chat_url", c.chatURL), | ||
zap.Any("payload", payload), | ||
) | ||
|
||
resp, err := c.httpClient.Do(req) | ||
if err != nil { | ||
return nil, fmt.Errorf("failed to send octoml chat request: %w", err) | ||
} | ||
|
||
defer resp.Body.Close() // TODO: handle this error | ||
|
||
if resp.StatusCode != http.StatusOK { | ||
bodyBytes, err := io.ReadAll(resp.Body) | ||
if err != nil { | ||
c.telemetry.Logger.Error("failed to read octoml chat response", zap.Error(err)) | ||
} | ||
|
||
// TODO: Handle failure conditions | ||
// TODO: return errors | ||
c.telemetry.Logger.Error( | ||
"octoml chat request failed", | ||
zap.Int("status_code", resp.StatusCode), | ||
zap.String("response", string(bodyBytes)), | ||
zap.Any("headers", resp.Header), | ||
) | ||
|
||
return nil, errs.ErrProviderUnavailable | ||
} | ||
|
||
// Read the response body into a byte slice | ||
bodyBytes, err := io.ReadAll(resp.Body) | ||
if err != nil { | ||
c.telemetry.Logger.Error("failed to read octoml chat response", zap.Error(err)) | ||
return nil, err | ||
} | ||
|
||
// Parse the response JSON | ||
var responseJSON map[string]interface{} | ||
|
||
err = json.Unmarshal(bodyBytes, &responseJSON) | ||
if err != nil { | ||
c.telemetry.Logger.Error("failed to parse octoml chat response", zap.Error(err)) | ||
return nil, err | ||
} | ||
|
||
// Parse response | ||
var response schemas.UnifiedChatResponse | ||
|
||
var responsePayload schemas.ProviderResponse | ||
|
||
var tokenCount schemas.TokenCount | ||
|
||
message := responseJSON["choices"].([]interface{})[0].(map[string]interface{})["message"].(map[string]interface{}) | ||
messageStruct := schemas.ChatMessage{ | ||
Role: message["role"].(string), | ||
Content: message["content"].(string), | ||
} | ||
|
||
tokenCount = schemas.TokenCount{ | ||
PromptTokens: responseJSON["usage"].(map[string]interface{})["prompt_tokens"].(float64), | ||
ResponseTokens: responseJSON["usage"].(map[string]interface{})["completion_tokens"].(float64), | ||
TotalTokens: responseJSON["usage"].(map[string]interface{})["total_tokens"].(float64), | ||
} | ||
|
||
responsePayload = schemas.ProviderResponse{ | ||
ResponseID: map[string]string{"system_fingerprint": "none"}, | ||
Message: messageStruct, | ||
TokenCount: tokenCount, | ||
} | ||
|
||
response = schemas.UnifiedChatResponse{ | ||
ID: responseJSON["id"].(string), | ||
Created: responseJSON["created"].(float64), | ||
Provider: "octoml", | ||
Router: "chat", //TODO: Update this with actual router | ||
Model: responseJSON["model"].(string), | ||
Cached: false, | ||
ProviderResponse: responsePayload, | ||
} | ||
|
||
return &response, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
package octoml | ||
|
||
import ( | ||
"errors" | ||
"net/http" | ||
"net/url" | ||
"time" | ||
|
||
"glide/pkg/telemetry" | ||
) | ||
|
||
// TODO: Explore resource pooling | ||
// TODO: Optimize Type use | ||
// TODO: Explore Hertz TLS & resource pooling | ||
|
||
const ( | ||
providerName = "octoml" | ||
) | ||
|
||
// ErrEmptyResponse is returned when the OctoML API returns an empty response. | ||
var ( | ||
ErrEmptyResponse = errors.New("empty response") | ||
) | ||
|
||
// Client is a client for accessing OctoML API | ||
type Client struct { | ||
baseURL string | ||
chatURL string | ||
chatRequestTemplate *ChatRequest | ||
config *Config | ||
httpClient *http.Client | ||
telemetry *telemetry.Telemetry | ||
} | ||
|
||
// NewClient creates a new OctoML client for the OctoML API. | ||
func NewClient(cfg *Config, tel *telemetry.Telemetry) (*Client, error) { | ||
chatURL, err := url.JoinPath(cfg.BaseURL, cfg.ChatEndpoint) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
c := &Client{ | ||
baseURL: cfg.BaseURL, | ||
chatURL: chatURL, | ||
config: cfg, | ||
chatRequestTemplate: NewChatRequestFromConfig(cfg), | ||
httpClient: &http.Client{ | ||
// TODO: use values from the config | ||
Timeout: time.Second * 30, | ||
Transport: &http.Transport{ | ||
MaxIdleConns: 100, | ||
MaxIdleConnsPerHost: 2, | ||
}, | ||
}, | ||
telemetry: tel, | ||
} | ||
|
||
return c, nil | ||
} | ||
|
||
func (c *Client) Provider() string { | ||
return providerName | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
package octoml | ||
|
||
import ( | ||
"context" | ||
"testing" | ||
"fmt" | ||
|
||
"glide/pkg/api/schemas" | ||
|
||
"glide/pkg/telemetry" | ||
|
||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
func TestOpenAIClient_ChatRequest(t *testing.T) { | ||
|
||
|
||
ctx := context.Background() | ||
cfg := DefaultConfig() | ||
|
||
client, err := NewClient(cfg, telemetry.NewTelemetryMock()) | ||
require.NoError(t, err) | ||
|
||
request := schemas.UnifiedChatRequest{Message: schemas.ChatMessage{ | ||
Role: "user", | ||
Content: "What's the biggest animal?", | ||
}} | ||
|
||
response, err := client.Chat(ctx, &request) | ||
require.NoError(t, err) | ||
|
||
fmt.Println(response) | ||
|
||
//require.Equal(t, "chatcmpl-123", response.ID) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
package octoml | ||
|
||
import ( | ||
"glide/pkg/config/fields" | ||
) | ||
|
||
// Params defines OctoML-specific model params with the specific validation of values | ||
// TODO: Add validations | ||
type Params struct { | ||
Temperature float64 `yaml:"temperature,omitempty" json:"temperature"` | ||
TopP float64 `yaml:"top_p,omitempty" json:"top_p"` | ||
MaxTokens int `yaml:"max_tokens,omitempty" json:"max_tokens"` | ||
StopWords []string `yaml:"stop,omitempty" json:"stop"` | ||
FrequencyPenalty int `yaml:"frequency_penalty,omitempty" json:"frequency_penalty"` | ||
PresencePenalty int `yaml:"presence_penalty,omitempty" json:"presence_penalty"` | ||
// Stream bool `json:"stream,omitempty"` // TODO: we are not supporting this at the moment | ||
} | ||
|
||
func DefaultParams() Params { | ||
return Params{ | ||
Temperature: 1, | ||
TopP: 1, | ||
MaxTokens: 100, | ||
StopWords: []string{}, | ||
} | ||
} | ||
|
||
func (p *Params) UnmarshalYAML(unmarshal func(interface{}) error) error { | ||
*p = DefaultParams() | ||
|
||
type plain Params // to avoid recursion | ||
|
||
return unmarshal((*plain)(p)) | ||
} | ||
|
||
type Config struct { | ||
BaseURL string `yaml:"base_url" json:"baseUrl" validate:"required"` | ||
ChatEndpoint string `yaml:"chat_endpoint" json:"chatEndpoint" validate:"required"` | ||
Model string `yaml:"model" json:"model" validate:"required"` | ||
APIKey fields.Secret `yaml:"api_key" json:"-" validate:"required"` | ||
DefaultParams *Params `yaml:"default_params,omitempty" json:"defaultParams"` | ||
} | ||
|
||
// DefaultConfig for OctoML models | ||
func DefaultConfig() *Config { | ||
defaultParams := DefaultParams() | ||
|
||
return &Config{ | ||
BaseURL: "https://text.octoai.run/v1", | ||
ChatEndpoint: "/chat/completions", | ||
Model: "mistral-7b-instruct-fp16", | ||
DefaultParams: &defaultParams, | ||
} | ||
} | ||
|
||
func (c *Config) UnmarshalYAML(unmarshal func(interface{}) error) error { | ||
*c = *DefaultConfig() | ||
|
||
type plain Config // to avoid recursion | ||
|
||
return unmarshal((*plain)(c)) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
{ | ||
"model": "gpt-3.5-turbo", | ||
"messages": [ | ||
{ | ||
"role": "human", | ||
"content": "What's the biggest animal?" | ||
} | ||
], | ||
"temperature": 0.8, | ||
"top_p": 1, | ||
"max_tokens": 100, | ||
"n": 1, | ||
"user": null, | ||
"seed": null | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
{ | ||
"id": "chatcmpl-123", | ||
"object": "chat.completion", | ||
"created": 1677652288, | ||
"model": "gpt-3.5-turbo-0613", | ||
"system_fingerprint": "fp_44709d6fcb", | ||
"choices": [{ | ||
"index": 0, | ||
"message": { | ||
"role": "assistant", | ||
"content": "\n\nHello there, how may I assist you today?" | ||
}, | ||
"logprobs": null, | ||
"finish_reason": "stop" | ||
}], | ||
"usage": { | ||
"prompt_tokens": 9, | ||
"completion_tokens": 12, | ||
"total_tokens": 21 | ||
} | ||
} |