Skip to content

Commit

Permalink
#58: init client
Browse files Browse the repository at this point in the history
  • Loading branch information
mkrueger12 committed Jan 4, 2024
1 parent ff3dc22 commit eb69f6b
Show file tree
Hide file tree
Showing 6 changed files with 382 additions and 0 deletions.
186 changes: 186 additions & 0 deletions pkg/providers/octoml/chat.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
package octoml

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"

"glide/pkg/providers/errs"

"glide/pkg/api/schemas"
"go.uber.org/zap"
)

// ChatMessage is a single message in the OctoML chat payload.
type ChatMessage struct {
	// Role of the message author (e.g. "user" or "assistant" as seen in testdata).
	Role string `json:"role"`
	// Content is the message text.
	Content string `json:"content"`
}

// ChatRequest is an octoml-specific request schema
type ChatRequest struct {
	Model       string        `json:"model"`
	Messages    []ChatMessage `json:"messages"`
	Temperature float64       `json:"temperature,omitempty"`
	TopP        float64       `json:"top_p,omitempty"`
	MaxTokens   int           `json:"max_tokens,omitempty"`
	StopWords   []string      `json:"stop,omitempty"`
	// Stream is always false for now (see NewChatRequestFromConfig).
	Stream bool `json:"stream,omitempty"`
	// NOTE(review): penalties are ints here while Temperature/TopP are floats —
	// confirm the OctoML API does not accept fractional penalty values.
	FrequencyPenalty int `json:"frequency_penalty,omitempty"`
	PresencePenalty  int `json:"presence_penalty,omitempty"`
}

// NewChatRequestFromConfig fills the struct from the config. Not using
// reflection because of the performance penalty it gives.
//
// Falls back to DefaultParams() when cfg.DefaultParams is nil, so a Config
// built without going through DefaultConfig/UnmarshalYAML cannot cause a nil
// pointer dereference here.
func NewChatRequestFromConfig(cfg *Config) *ChatRequest {
	params := cfg.DefaultParams
	if params == nil {
		defaults := DefaultParams()
		params = &defaults
	}

	return &ChatRequest{
		Model:            cfg.Model,
		Temperature:      params.Temperature,
		TopP:             params.TopP,
		MaxTokens:        params.MaxTokens,
		StopWords:        params.StopWords,
		Stream:           false, // unsupported right now
		FrequencyPenalty: params.FrequencyPenalty,
		PresencePenalty:  params.PresencePenalty,
	}
}

// NewChatMessagesFromUnifiedRequest flattens the unified request into the
// provider message slice: the history first, the new chat message last.
func NewChatMessagesFromUnifiedRequest(request *schemas.UnifiedChatRequest) []ChatMessage {
	history := request.MessageHistory
	messages := make([]ChatMessage, 0, len(history)+1)

	for i := range history {
		messages = append(messages, ChatMessage{
			Role:    history[i].Role,
			Content: history[i].Content,
		})
	}

	return append(messages, ChatMessage{
		Role:    request.Message.Role,
		Content: request.Message.Content,
	})
}

// Chat sends a chat request to the specified octoml model.
// An HTTP/transport failure is returned as-is from doChatRequest; a response
// with an empty completion is surfaced as ErrEmptyResponse.
func (c *Client) Chat(ctx context.Context, request *schemas.UnifiedChatRequest) (*schemas.UnifiedChatResponse, error) {
	resp, err := c.doChatRequest(ctx, c.createChatRequestSchema(request))
	if err != nil {
		return nil, err
	}

	// Treat an empty completion as a provider failure rather than a success.
	if resp.ProviderResponse.Message.Content == "" {
		return nil, ErrEmptyResponse
	}

	return resp, nil
}

// createChatRequestSchema builds a per-call ChatRequest from the client's
// template.
//
// The template is stored as a *ChatRequest, so it must be dereferenced to get
// a value copy. The original `chatRequest := c.chatRequestTemplate` copied the
// pointer, meaning every request mutated the shared template's Messages field
// — a data race for concurrent Chat calls.
func (c *Client) createChatRequestSchema(request *schemas.UnifiedChatRequest) *ChatRequest {
	// TODO: consider using objectpool to optimize memory allocation
	chatRequest := *c.chatRequestTemplate // shallow value copy of the template
	chatRequest.Messages = NewChatMessagesFromUnifiedRequest(request)

	return &chatRequest
}

// doChatRequest marshals the payload, POSTs it to the configured chat URL and
// maps the OpenAI-compatible completion body into the unified response schema.
//
// Any non-200 status, or a body that does not decode into the expected shape,
// is returned as an error. The previous map[string]interface{} parsing used
// unchecked type assertions and would panic on any malformed response.
func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schemas.UnifiedChatResponse, error) {
	// Build request payload
	rawPayload, err := json.Marshal(payload)
	if err != nil {
		return nil, fmt.Errorf("unable to marshal octoml chat request payload: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.chatURL, bytes.NewBuffer(rawPayload))
	if err != nil {
		return nil, fmt.Errorf("unable to create octoml chat request: %w", err)
	}

	req.Header.Set("Authorization", "Bearer "+string(c.config.APIKey))
	req.Header.Set("Content-Type", "application/json")

	// TODO: this could leak information from messages which may not be a desired thing to have
	c.telemetry.Logger.Debug(
		"octoml chat request",
		zap.String("chat_url", c.chatURL),
		zap.Any("payload", payload),
	)

	resp, err := c.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to send octoml chat request: %w", err)
	}

	defer resp.Body.Close() // TODO: handle this error

	if resp.StatusCode != http.StatusOK {
		// Best-effort read of the error body for logging; a read failure is
		// logged but must not shadow/mask the provider failure itself.
		bodyBytes, readErr := io.ReadAll(resp.Body)
		if readErr != nil {
			c.telemetry.Logger.Error("failed to read octoml chat response", zap.Error(readErr))
		}

		// TODO: Handle failure conditions (retryable vs terminal statuses)
		c.telemetry.Logger.Error(
			"octoml chat request failed",
			zap.Int("status_code", resp.StatusCode),
			zap.String("response", string(bodyBytes)),
			zap.Any("headers", resp.Header),
		)

		return nil, errs.ErrProviderUnavailable
	}

	// Read the response body into a byte slice
	bodyBytes, err := io.ReadAll(resp.Body)
	if err != nil {
		c.telemetry.Logger.Error("failed to read octoml chat response", zap.Error(err))
		return nil, err
	}

	// Decode into a typed struct instead of map[string]interface{} so a
	// malformed or truncated payload yields an error rather than a panic on an
	// unchecked type assertion. Field set mirrors testdata/chat.success.json.
	var parsed struct {
		ID      string  `json:"id"`
		Created float64 `json:"created"`
		Model   string  `json:"model"`
		Choices []struct {
			Message struct {
				Role    string `json:"role"`
				Content string `json:"content"`
			} `json:"message"`
		} `json:"choices"`
		Usage struct {
			PromptTokens     float64 `json:"prompt_tokens"`
			CompletionTokens float64 `json:"completion_tokens"`
			TotalTokens      float64 `json:"total_tokens"`
		} `json:"usage"`
	}

	if err := json.Unmarshal(bodyBytes, &parsed); err != nil {
		c.telemetry.Logger.Error("failed to parse octoml chat response", zap.Error(err))
		return nil, err
	}

	if len(parsed.Choices) == 0 {
		c.telemetry.Logger.Error(
			"octoml chat response contained no choices",
			zap.String("response", string(bodyBytes)),
		)

		return nil, errs.ErrProviderUnavailable
	}

	message := parsed.Choices[0].Message

	response := schemas.UnifiedChatResponse{
		ID:       parsed.ID,
		Created:  parsed.Created,
		Provider: "octoml",
		Router:   "chat", // TODO: Update this with actual router
		Model:    parsed.Model,
		Cached:   false,
		ProviderResponse: schemas.ProviderResponse{
			ResponseID: map[string]string{"system_fingerprint": "none"},
			Message: schemas.ChatMessage{
				Role:    message.Role,
				Content: message.Content,
			},
			TokenCount: schemas.TokenCount{
				PromptTokens:   parsed.Usage.PromptTokens,
				ResponseTokens: parsed.Usage.CompletionTokens,
				TotalTokens:    parsed.Usage.TotalTokens,
			},
		},
	}

	return &response, nil
}
63 changes: 63 additions & 0 deletions pkg/providers/octoml/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package octoml

import (
"errors"
"net/http"
"net/url"
"time"

"glide/pkg/telemetry"
)

// TODO: Explore resource pooling
// TODO: Optimize Type use
// TODO: Explore Hertz TLS & resource pooling

const (
	// providerName identifies this provider in routing/telemetry (see Provider()).
	providerName = "octoml"
)

// ErrEmptyResponse is returned when the OctoML API returns an empty response.
var (
	ErrEmptyResponse = errors.New("empty response")
)

// Client is a client for accessing OctoML API
type Client struct {
	// baseURL is the provider root URL from the config.
	baseURL string
	// chatURL is baseURL joined with the chat endpoint path (built in NewClient).
	chatURL string
	// chatRequestTemplate holds per-config defaults copied into each request.
	chatRequestTemplate *ChatRequest
	config              *Config
	httpClient          *http.Client
	telemetry           *telemetry.Telemetry
}

// NewClient creates a new OctoML client for the OctoML API.
// It precomputes the chat URL and the request template so per-call work stays
// minimal; an invalid base URL / endpoint combination is reported here.
func NewClient(cfg *Config, tel *telemetry.Telemetry) (*Client, error) {
	chatURL, err := url.JoinPath(cfg.BaseURL, cfg.ChatEndpoint)
	if err != nil {
		return nil, err
	}

	transport := &http.Transport{
		MaxIdleConns:        100,
		MaxIdleConnsPerHost: 2,
	}

	httpClient := &http.Client{
		// TODO: use values from the config
		Timeout:   30 * time.Second,
		Transport: transport,
	}

	return &Client{
		baseURL:             cfg.BaseURL,
		chatURL:             chatURL,
		config:              cfg,
		chatRequestTemplate: NewChatRequestFromConfig(cfg),
		httpClient:          httpClient,
		telemetry:           tel,
	}, nil
}

// Provider returns the canonical name of this provider ("octoml").
func (c *Client) Provider() string {
	return providerName
}
35 changes: 35 additions & 0 deletions pkg/providers/octoml/client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package octoml

import (
"context"
"testing"
"fmt"

"glide/pkg/api/schemas"

"glide/pkg/telemetry"

"github.com/stretchr/testify/require"
)

// TestOctoMLClient_ChatRequest exercises the full Chat flow against the
// endpoint from DefaultConfig. Renamed from TestOpenAIClient_ChatRequest —
// the old name was copy-pasted from the OpenAI provider.
//
// NOTE(review): this is an integration test — it performs a live network call
// and requires a valid API key; consider serving testdata/chat.success.json
// from an httptest server instead.
func TestOctoMLClient_ChatRequest(t *testing.T) {
	ctx := context.Background()
	cfg := DefaultConfig()

	client, err := NewClient(cfg, telemetry.NewTelemetryMock())
	require.NoError(t, err)

	request := schemas.UnifiedChatRequest{Message: schemas.ChatMessage{
		Role:    "user",
		Content: "What's the biggest animal?",
	}}

	response, err := client.Chat(ctx, &request)
	require.NoError(t, err)

	fmt.Println(response)

	// require.Equal(t, "chatcmpl-123", response.ID)
}
62 changes: 62 additions & 0 deletions pkg/providers/octoml/config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package octoml

import (
"glide/pkg/config/fields"
)

// Params defines OctoML-specific model params with the specific validation of values
// TODO: Add validations
type Params struct {
	Temperature float64  `yaml:"temperature,omitempty" json:"temperature"`
	TopP        float64  `yaml:"top_p,omitempty" json:"top_p"`
	MaxTokens   int      `yaml:"max_tokens,omitempty" json:"max_tokens"`
	StopWords   []string `yaml:"stop,omitempty" json:"stop"`
	// NOTE(review): penalties are declared int here and in ChatRequest —
	// confirm the OctoML API does not accept fractional penalty values.
	FrequencyPenalty int `yaml:"frequency_penalty,omitempty" json:"frequency_penalty"`
	PresencePenalty  int `yaml:"presence_penalty,omitempty" json:"presence_penalty"`
	// Stream bool `json:"stream,omitempty"` // TODO: we are not supporting this at the moment
}

// DefaultParams returns the model params used when the config does not
// override them. Penalty fields are left at their zero values.
func DefaultParams() Params {
	var p Params

	p.Temperature = 1
	p.TopP = 1
	p.MaxTokens = 100
	p.StopWords = []string{}

	return p
}

// UnmarshalYAML seeds the struct with DefaultParams and then overlays any keys
// present in the YAML, so omitted fields keep their defaults.
func (p *Params) UnmarshalYAML(unmarshal func(interface{}) error) error {
	*p = DefaultParams()

	// The alias type has no UnmarshalYAML method, so decoding it does not
	// re-enter this function.
	type plain Params // to avoid recursion

	return unmarshal((*plain)(p))
}

// Config describes a single OctoML model endpoint.
type Config struct {
	BaseURL      string `yaml:"base_url" json:"baseUrl" validate:"required"`
	ChatEndpoint string `yaml:"chat_endpoint" json:"chatEndpoint" validate:"required"`
	Model        string `yaml:"model" json:"model" validate:"required"`
	// APIKey is excluded from JSON output (`json:"-"`) so it is not leaked
	// when the config is serialized.
	APIKey fields.Secret `yaml:"api_key" json:"-" validate:"required"`
	// DefaultParams may be nil when the Config is constructed directly;
	// DefaultConfig and UnmarshalYAML always populate it.
	DefaultParams *Params `yaml:"default_params,omitempty" json:"defaultParams"`
}

// DefaultConfig for OctoML models
func DefaultConfig() *Config {
	params := DefaultParams()

	cfg := Config{
		BaseURL:       "https://text.octoai.run/v1",
		ChatEndpoint:  "/chat/completions",
		Model:         "mistral-7b-instruct-fp16",
		DefaultParams: &params,
	}

	return &cfg
}

// UnmarshalYAML seeds the struct from DefaultConfig and then overlays any keys
// present in the YAML, so omitted fields keep their defaults.
func (c *Config) UnmarshalYAML(unmarshal func(interface{}) error) error {
	*c = *DefaultConfig()

	// The alias type has no UnmarshalYAML method, so decoding it does not
	// re-enter this function.
	type plain Config // to avoid recursion

	return unmarshal((*plain)(c))
}
15 changes: 15 additions & 0 deletions pkg/providers/octoml/testdata/chat.req.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
{
"model": "gpt-3.5-turbo",
"messages": [
{
"role": "user",
"content": "What's the biggest animal?"
}
],
"temperature": 0.8,
"top_p": 1,
"max_tokens": 100,
"n": 1,
"user": null,
"seed": null
}
21 changes: 21 additions & 0 deletions pkg/providers/octoml/testdata/chat.success.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
{
"id": "chatcmpl-123",
"object": "chat.completion",
"created": 1677652288,
"model": "gpt-3.5-turbo-0613",
"system_fingerprint": "fp_44709d6fcb",
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": "\n\nHello there, how may I assist you today?"
},
"logprobs": null,
"finish_reason": "stop"
}],
"usage": {
"prompt_tokens": 9,
"completion_tokens": 12,
"total_tokens": 21
}
}

0 comments on commit eb69f6b

Please sign in to comment.